pdf_code/zzb_data_prod/vector_storage.py

236 lines
9.1 KiB
Python

import numpy as np
from multiprocessing import Process, Array, RLock, cpu_count
from typing import List, Dict, Any, Tuple, Optional
import heapq
import ctypes
import json
import base64
class VectorStorage:
"""内存向量存储与检索类"""
def __init__(self, dim: int = 1024, max_vectors: int = 10000):
self.dim = dim
self.max_vectors = max_vectors
# 使用共享内存存储所有数据(向量和元数据)
# 我们将向量和元数据一起序列化存储
self.data_strings = Array(ctypes.c_char, max_vectors * 16384, lock=False) # 假设每条数据最大16384字节
self.data_offsets = Array(ctypes.c_int, max_vectors, lock=False) # 存储每条数据的偏移量
self.data_lengths = Array(ctypes.c_int, max_vectors, lock=False) # 存储每条数据的长度
# 使用共享计数器记录数据数量
self.data_count = Array(ctypes.c_int, 1, lock=True)
self.lock = RLock()
def get_data(self, index: int) -> Dict[str, Any]:
"""从共享内存中获取数据(包含向量和元数据)"""
if index >= self.data_count[0]:
return {}
offset = self.data_offsets[index]
length = self.data_lengths[index]
# 从共享内存中提取字节串
data_bytes = bytes(self.data_strings[offset:offset+length])
# 反序列化为字典
try:
data_dict = json.loads(data_bytes.decode('utf-8'))
# 将base64编码的向量转换回numpy数组
if 'vector_base64' in data_dict:
vector_bytes = base64.b64decode(data_dict['vector_base64'])
data_dict['vector'] = np.frombuffer(vector_bytes, dtype=np.float32)
del data_dict['vector_base64'] # 删除base64编码的向量
return data_dict
except:
return {
'table_num': -1,
'table_index': -1,
'measure_name': 'unknown',
'measure_value': 'unknown',
'measure_unit': 'unknown',
'file_id': 'unknown',
'vector': np.zeros(self.dim, dtype=np.float32)
}
def set_data(self, index: int, data: Dict[str, Any]) -> None:
"""将数据(包含向量和元数据)存储到共享内存中"""
# 复制数据字典,避免修改原始数据
data_copy = data.copy()
# 将numpy向量转换为base64编码的字符串
if 'vector' in data_copy and isinstance(data_copy['vector'], np.ndarray):
vector_bytes = data_copy['vector'].astype(np.float32).tobytes()
data_copy['vector_base64'] = base64.b64encode(vector_bytes).decode('utf-8')
del data_copy['vector'] # 删除原始向量
# 序列化为JSON字符串
data_str = json.dumps(data_copy)
data_bytes = data_str.encode('utf-8')
# 计算偏移量
offset = 0
if index > 0:
offset = self.data_offsets[index-1] + self.data_lengths[index-1]
# 检查是否有足够空间
if offset + len(data_bytes) > len(self.data_strings):
raise ValueError("Not enough space for data")
# 存储到共享内存
self.data_offsets[index] = offset
self.data_lengths[index] = len(data_bytes)
# 将字节数据复制到共享内存
for i, byte in enumerate(data_bytes):
self.data_strings[offset + i] = byte
def add_data(self, data_list: List[Dict[str, Any]]) -> None:
"""添加数据(包含向量和元数据)到存储"""
num_new = len(data_list)
with self.data_count.get_lock():
current_count = self.data_count[0]
if current_count + num_new > self.max_vectors:
raise ValueError("Exceeded maximum data storage capacity.")
# 将数据添加到共享内存
for i, data in enumerate(data_list):
self.set_data(current_count + i, data)
# 更新计数器
self.data_count[0] = current_count + num_new
def search_similar_vectors(self, query_vector: np.ndarray,
top_k: int = 10, similarity_threshold: float = 0.0) -> List[Tuple[float, Dict[str, Any]]]:
"""搜索相似向量"""
query_vector = np.asarray(query_vector).astype(np.float32)
current_count = self.data_count[0]
if current_count == 0:
return []
# 计算所有向量的相似度
similarities = []
for i in range(current_count):
data = self.get_data(i)
if 'vector' in data:
vector = data['vector']
# 计算余弦相似度
dot_product = np.dot(vector, query_vector)
norm_vector = np.linalg.norm(vector)
norm_query = np.linalg.norm(query_vector)
similarity = dot_product / (norm_vector * norm_query + 1e-10)
similarities.append((similarity, data))
# 筛选出相似度大于阈值的向量
above_threshold = [(sim, data) for sim, data in similarities if sim > similarity_threshold]
if len(above_threshold) == 0:
return []
# 按相似度降序排序
above_threshold.sort(key=lambda x: x[0], reverse=True)
# 返回top_k个结果
return above_threshold[:top_k]
def vectorize_financial_data(financial_data: Dict[str, Any]) -> np.ndarray:
"""将财务指标数据转化为1024维向量"""
# 实际应用中替换为您的向量化算法
vector = np.random.rand(1024).astype(np.float32) # 随机生成示例向量
return vector
def worker_process(data_chunk: List[Dict[str, Any]], storage: VectorStorage) -> None:
"""工作进程函数"""
data_list = []
for data in data_chunk:
# 生成向量
vector = vectorize_financial_data(data)
# 准备完整数据(包含向量和元数据)
full_data = {
'table_num': data.get('table_num', -1),
'table_index': data.get('table_index', -1),
'measure_name': data.get('measure_name', 'unknown'),
'measure_value': data.get('measure_value', 'unknown'),
'measure_unit': data.get('measure_unit', 'unknown'),
'file_id': data.get('file_id', 'unknown'),
'vector': vector # 将向量作为数据的一部分
}
data_list.append(full_data)
# 添加数据到存储
storage.add_data(data_list)
print(f"Processed {len(data_chunk)} financial data entries. Total data now: {storage.data_count[0]}")
if __name__ == '__main__':
# 初始化
dim = 1024
max_vectors = 100
shared_storage = VectorStorage(dim, max_vectors)
# 准备数据 - 模拟您的数据格式
financial_data = []
for i in range(100):
measure_data = {
'table_num': i // 10 + 1, # 模拟10个表
'table_index': i % 10,
'measure_name': f'measure_{i}',
'measure_value': i * 1.5,
'measure_unit': 'unit',
'file_id': f'file_{i // 20}' # 模拟5个文件
}
financial_data.append(measure_data)
# 分割数据
num_processes = 4
chunk_size = len(financial_data) // num_processes
data_chunks = [
financial_data[i*chunk_size:(i+1)*chunk_size]
for i in range(num_processes)
]
# 处理可能的不整除情况
if len(financial_data) % num_processes != 0:
data_chunks[-1].extend(financial_data[num_processes * chunk_size:])
# 启动多进程
processes = []
for chunk in data_chunks:
p = Process(target=worker_process, args=(chunk, shared_storage))
processes.append(p)
p.start()
# 等待所有进程完成
for p in processes:
p.join()
print(f"\nAll processes finished. Total data stored: {shared_storage.data_count[0]}")
# 示例查询
try:
# 创建一个查询向量
query_vector = np.random.rand(1024).astype(np.float32)
# 使用相似度阈值进行查询
similarity_threshold = 0.5
similar_items = shared_storage.search_similar_vectors(
query_vector, top_k=5, similarity_threshold=similarity_threshold
)
print(f"\nTop vectors with similarity > {similarity_threshold}:")
if similar_items:
for score, data in similar_items:
print(f"Score: {score:.4f}, Table: {data['table_num']}, Index: {data['table_index']}, Name: {data['measure_name']}")
else:
print("No vectors found above the similarity threshold.")
except Exception as e:
print(f"Error during query: {e}")
import traceback
traceback.print_exc()