diff --git a/zzb_data_prod/vector_storage.py b/zzb_data_prod/vector_storage.py new file mode 100644 index 0000000..9da1ed2 --- /dev/null +++ b/zzb_data_prod/vector_storage.py @@ -0,0 +1,236 @@ +import numpy as np +from multiprocessing import Process, Array, RLock, cpu_count +from typing import List, Dict, Any, Tuple, Optional +import heapq +import ctypes +import json +import base64 + +class VectorStorage: + """内存向量存储与检索类""" + + def __init__(self, dim: int = 1024, max_vectors: int = 10000): + self.dim = dim + self.max_vectors = max_vectors + + # 使用共享内存存储所有数据(向量和元数据) + # 我们将向量和元数据一起序列化存储 + self.data_strings = Array(ctypes.c_char, max_vectors * 16384, lock=False) # 假设每条数据最大16384字节 + self.data_offsets = Array(ctypes.c_int, max_vectors, lock=False) # 存储每条数据的偏移量 + self.data_lengths = Array(ctypes.c_int, max_vectors, lock=False) # 存储每条数据的长度 + + # 使用共享计数器记录数据数量 + self.data_count = Array(ctypes.c_int, 1, lock=True) + + self.lock = RLock() + + def get_data(self, index: int) -> Dict[str, Any]: + """从共享内存中获取数据(包含向量和元数据)""" + if index >= self.data_count[0]: + return {} + + offset = self.data_offsets[index] + length = self.data_lengths[index] + + # 从共享内存中提取字节串 + data_bytes = bytes(self.data_strings[offset:offset+length]) + + # 反序列化为字典 + try: + data_dict = json.loads(data_bytes.decode('utf-8')) + + # 将base64编码的向量转换回numpy数组 + if 'vector_base64' in data_dict: + vector_bytes = base64.b64decode(data_dict['vector_base64']) + data_dict['vector'] = np.frombuffer(vector_bytes, dtype=np.float32) + del data_dict['vector_base64'] # 删除base64编码的向量 + + return data_dict + except: + return { + 'table_num': -1, + 'table_index': -1, + 'measure_name': 'unknown', + 'measure_value': 'unknown', + 'measure_unit': 'unknown', + 'file_id': 'unknown', + 'vector': np.zeros(self.dim, dtype=np.float32) + } + + def set_data(self, index: int, data: Dict[str, Any]) -> None: + """将数据(包含向量和元数据)存储到共享内存中""" + # 复制数据字典,避免修改原始数据 + data_copy = data.copy() + + # 将numpy向量转换为base64编码的字符串 + if 'vector' in data_copy and isinstance(data_copy['vector'], np.ndarray): + vector_bytes = data_copy['vector'].astype(np.float32).tobytes() + data_copy['vector_base64'] = base64.b64encode(vector_bytes).decode('utf-8') + del data_copy['vector'] # 删除原始向量 + + # 序列化为JSON字符串 + data_str = json.dumps(data_copy) + data_bytes = data_str.encode('utf-8') + + # 计算偏移量 + offset = 0 + if index > 0: + offset = self.data_offsets[index-1] + self.data_lengths[index-1] + + # 检查是否有足够空间 + if offset + len(data_bytes) > len(self.data_strings): + raise ValueError("Not enough space for data") + + # 存储到共享内存 + self.data_offsets[index] = offset + self.data_lengths[index] = len(data_bytes) + + # 将字节数据复制到共享内存 + for i, byte in enumerate(data_bytes): + self.data_strings[offset + i] = byte + + def add_data(self, data_list: List[Dict[str, Any]]) -> None: + """添加数据(包含向量和元数据)到存储""" + num_new = len(data_list) + + with self.data_count.get_lock(): + current_count = self.data_count[0] + if current_count + num_new > self.max_vectors: + raise ValueError("Exceeded maximum data storage capacity.") + + # 将数据添加到共享内存 + for i, data in enumerate(data_list): + self.set_data(current_count + i, data) + + # 更新计数器 + self.data_count[0] = current_count + num_new + + def search_similar_vectors(self, query_vector: np.ndarray, + top_k: int = 10, similarity_threshold: float = 0.0) -> List[Tuple[float, Dict[str, Any]]]: + """搜索相似向量""" + query_vector = query_vector.astype(np.float32) + current_count = self.data_count[0] + if current_count == 0: + return [] + + # 计算所有向量的相似度 + similarities = [] + for i in range(current_count): + data = self.get_data(i) + if 'vector' in data: + vector = data['vector'] + # 计算余弦相似度 + dot_product = np.dot(vector, query_vector) + norm_vector = np.linalg.norm(vector) + norm_query = np.linalg.norm(query_vector) + similarity = dot_product / (norm_vector * norm_query + 1e-10) + similarities.append((similarity, data)) + + # 筛选出相似度大于阈值的向量 + above_threshold = [(sim, data) for sim, data in similarities if sim > similarity_threshold] + + if len(above_threshold) == 0: + return [] + + # 按相似度降序排序 + above_threshold.sort(key=lambda x: x[0], reverse=True) + + # 返回top_k个结果 + return above_threshold[:top_k] + +def vectorize_financial_data(financial_data: Dict[str, Any]) -> np.ndarray: + """将财务指标数据转化为1024维向量""" + # 实际应用中替换为您的向量化算法 + vector = np.random.rand(1024).astype(np.float32) # 随机生成示例向量 + return vector + +def worker_process(data_chunk: List[Dict[str, Any]], storage: VectorStorage) -> None: + """工作进程函数""" + data_list = [] + + for data in data_chunk: + # 生成向量 + vector = vectorize_financial_data(data) + + # 准备完整数据(包含向量和元数据) + full_data = { + 'table_num': data.get('table_num', -1), + 'table_index': data.get('table_index', -1), + 'measure_name': data.get('measure_name', 'unknown'), + 'measure_value': data.get('measure_value', 'unknown'), + 'measure_unit': data.get('measure_unit', 'unknown'), + 'file_id': data.get('file_id', 'unknown'), + 'vector': vector # 将向量作为数据的一部分 + } + data_list.append(full_data) + + # 添加数据到存储 + storage.add_data(data_list) + print(f"Processed {len(data_chunk)} financial data entries. Total data now: {storage.data_count[0]}") + +if __name__ == '__main__': + # 初始化 + dim = 1024 + max_vectors = 100 + shared_storage = VectorStorage(dim, max_vectors) + + # 准备数据 - 模拟您的数据格式 + financial_data = [] + for i in range(100): + measure_data = { + 'table_num': i // 10 + 1, # 模拟10个表 + 'table_index': i % 10, + 'measure_name': f'measure_{i}', + 'measure_value': i * 1.5, + 'measure_unit': 'unit', + 'file_id': f'file_{i // 20}' # 模拟5个文件 + } + financial_data.append(measure_data) + + # 分割数据 + num_processes = 4 + chunk_size = len(financial_data) // num_processes + data_chunks = [ + financial_data[i*chunk_size:(i+1)*chunk_size] + for i in range(num_processes) + ] + + # 处理可能的不整除情况 + if len(financial_data) % num_processes != 0: + data_chunks[-1].extend(financial_data[num_processes * chunk_size:]) + + # 启动多进程 + processes = [] + for chunk in data_chunks: + p = Process(target=worker_process, args=(chunk, shared_storage)) + processes.append(p) + p.start() + + # 等待所有进程完成 + for p in processes: + p.join() + + print(f"\nAll processes finished. Total data stored: {shared_storage.data_count[0]}") + + # 示例查询 + try: + # 创建一个查询向量 + query_vector = np.random.rand(1024).astype(np.float32) + + # 使用相似度阈值进行查询 + similarity_threshold = 0.5 + similar_items = shared_storage.search_similar_vectors( + query_vector, top_k=5, similarity_threshold=similarity_threshold + ) + + print(f"\nTop vectors with similarity > {similarity_threshold}:") + if similar_items: + for score, data in similar_items: + print(f"Score: {score:.4f}, Table: {data['table_num']}, Index: {data['table_index']}, Name: {data['measure_name']}") + else: + print("No vectors found above the similarity threshold.") + + except Exception as e: + print(f"Error during query: {e}") + import traceback + traceback.print_exc() \ No newline at end of file