import numpy as np from multiprocessing import Process, Array, RLock, cpu_count from typing import List, Dict, Any, Tuple, Optional import heapq import ctypes import json import base64 class VectorStorage: """内存向量存储与检索类""" def __init__(self, dim: int = 1024, max_vectors: int = 10000): self.dim = dim self.max_vectors = max_vectors # 使用共享内存存储所有数据(向量和元数据) # 我们将向量和元数据一起序列化存储 self.data_strings = Array(ctypes.c_char, max_vectors * 16384, lock=False) # 假设每条数据最大16384字节 self.data_offsets = Array(ctypes.c_int, max_vectors, lock=False) # 存储每条数据的偏移量 self.data_lengths = Array(ctypes.c_int, max_vectors, lock=False) # 存储每条数据的长度 # 使用共享计数器记录数据数量 self.data_count = Array(ctypes.c_int, 1, lock=True) self.lock = RLock() def get_data(self, index: int) -> Dict[str, Any]: """从共享内存中获取数据(包含向量和元数据)""" if index >= self.data_count[0]: return {} offset = self.data_offsets[index] length = self.data_lengths[index] # 从共享内存中提取字节串 data_bytes = bytes(self.data_strings[offset:offset+length]) # 反序列化为字典 try: data_dict = json.loads(data_bytes.decode('utf-8')) # 将base64编码的向量转换回numpy数组 if 'vector_base64' in data_dict: vector_bytes = base64.b64decode(data_dict['vector_base64']) data_dict['vector'] = np.frombuffer(vector_bytes, dtype=np.float32) del data_dict['vector_base64'] # 删除base64编码的向量 return data_dict except: return { 'table_num': -1, 'table_index': -1, 'measure_name': 'unknown', 'measure_value': 'unknown', 'measure_unit': 'unknown', 'file_id': 'unknown', 'vector': np.zeros(self.dim, dtype=np.float32) } def set_data(self, index: int, data: Dict[str, Any]) -> None: """将数据(包含向量和元数据)存储到共享内存中""" # 复制数据字典,避免修改原始数据 data_copy = data.copy() # 将numpy向量转换为base64编码的字符串 if 'vector' in data_copy and isinstance(data_copy['vector'], np.ndarray): vector_bytes = data_copy['vector'].astype(np.float32).tobytes() data_copy['vector_base64'] = base64.b64encode(vector_bytes).decode('utf-8') del data_copy['vector'] # 删除原始向量 # 序列化为JSON字符串 data_str = json.dumps(data_copy) data_bytes = data_str.encode('utf-8') # 计算偏移量 offset = 0 if index > 0: offset = self.data_offsets[index-1] + self.data_lengths[index-1] # 检查是否有足够空间 if offset + len(data_bytes) > len(self.data_strings): raise ValueError("Not enough space for data") # 存储到共享内存 self.data_offsets[index] = offset self.data_lengths[index] = len(data_bytes) # 将字节数据复制到共享内存 for i, byte in enumerate(data_bytes): self.data_strings[offset + i] = byte def add_data(self, data_list: List[Dict[str, Any]]) -> None: """添加数据(包含向量和元数据)到存储""" num_new = len(data_list) with self.data_count.get_lock(): current_count = self.data_count[0] if current_count + num_new > self.max_vectors: raise ValueError("Exceeded maximum data storage capacity.") # 将数据添加到共享内存 for i, data in enumerate(data_list): self.set_data(current_count + i, data) # 更新计数器 self.data_count[0] = current_count + num_new def search_similar_vectors(self, query_vector: np.ndarray, top_k: int = 10, similarity_threshold: float = 0.0) -> List[Tuple[float, Dict[str, Any]]]: """搜索相似向量""" query_vector = np.asarray(query_vector).astype(np.float32) current_count = self.data_count[0] if current_count == 0: return [] # 计算所有向量的相似度 similarities = [] for i in range(current_count): data = self.get_data(i) if 'vector' in data: vector = data['vector'] # 计算余弦相似度 dot_product = np.dot(vector, query_vector) norm_vector = np.linalg.norm(vector) norm_query = np.linalg.norm(query_vector) similarity = dot_product / (norm_vector * norm_query + 1e-10) similarities.append((similarity, data)) # 筛选出相似度大于阈值的向量 above_threshold = [(sim, data) for sim, data in similarities if sim > similarity_threshold] if len(above_threshold) == 0: return [] # 按相似度降序排序 above_threshold.sort(key=lambda x: x[0], reverse=True) # 返回top_k个结果 return above_threshold[:top_k] def vectorize_financial_data(financial_data: Dict[str, Any]) -> np.ndarray: """将财务指标数据转化为1024维向量""" # 实际应用中替换为您的向量化算法 vector = np.random.rand(1024).astype(np.float32) # 随机生成示例向量 return vector def worker_process(data_chunk: List[Dict[str, Any]], storage: VectorStorage) -> None: """工作进程函数""" data_list = [] for data in data_chunk: # 生成向量 vector = vectorize_financial_data(data) # 准备完整数据(包含向量和元数据) full_data = { 'table_num': data.get('table_num', -1), 'table_index': data.get('table_index', -1), 'measure_name': data.get('measure_name', 'unknown'), 'measure_value': data.get('measure_value', 'unknown'), 'measure_unit': data.get('measure_unit', 'unknown'), 'file_id': data.get('file_id', 'unknown'), 'vector': vector # 将向量作为数据的一部分 } data_list.append(full_data) # 添加数据到存储 storage.add_data(data_list) print(f"Processed {len(data_chunk)} financial data entries. Total data now: {storage.data_count[0]}") if __name__ == '__main__': # 初始化 dim = 1024 max_vectors = 100 shared_storage = VectorStorage(dim, max_vectors) # 准备数据 - 模拟您的数据格式 financial_data = [] for i in range(100): measure_data = { 'table_num': i // 10 + 1, # 模拟10个表 'table_index': i % 10, 'measure_name': f'measure_{i}', 'measure_value': i * 1.5, 'measure_unit': 'unit', 'file_id': f'file_{i // 20}' # 模拟5个文件 } financial_data.append(measure_data) # 分割数据 num_processes = 4 chunk_size = len(financial_data) // num_processes data_chunks = [ financial_data[i*chunk_size:(i+1)*chunk_size] for i in range(num_processes) ] # 处理可能的不整除情况 if len(financial_data) % num_processes != 0: data_chunks[-1].extend(financial_data[num_processes * chunk_size:]) # 启动多进程 processes = [] for chunk in data_chunks: p = Process(target=worker_process, args=(chunk, shared_storage)) processes.append(p) p.start() # 等待所有进程完成 for p in processes: p.join() print(f"\nAll processes finished. Total data stored: {shared_storage.data_count[0]}") # 示例查询 try: # 创建一个查询向量 query_vector = np.random.rand(1024).astype(np.float32) # 使用相似度阈值进行查询 similarity_threshold = 0.5 similar_items = shared_storage.search_similar_vectors( query_vector, top_k=5, similarity_threshold=similarity_threshold ) print(f"\nTop vectors with similarity > {similarity_threshold}:") if similar_items: for score, data in similar_items: print(f"Score: {score:.4f}, Table: {data['table_num']}, Index: {data['table_index']}, Name: {data['measure_name']}") else: print("No vectors found above the similarity threshold.") except Exception as e: print(f"Error during query: {e}") import traceback traceback.print_exc()