In-memory vector store to replace Milvus; first version, untested backup
This commit is contained in:
parent edbcc245a6
commit b5af2301f0
@@ -0,0 +1,236 @@
import numpy as np
from multiprocessing import Process, Array, RLock
from typing import List, Dict, Any, Tuple
import ctypes
import json
import base64


class VectorStorage:
    """In-memory vector storage and retrieval."""

    def __init__(self, dim: int = 1024, max_vectors: int = 10000):
        self.dim = dim
        self.max_vectors = max_vectors

        # All data (vectors and metadata) lives in shared memory; each
        # record is serialized together with its vector into one byte string.
        self.data_strings = Array(ctypes.c_char, max_vectors * 16384, lock=False)  # assumes each record is at most 16384 bytes
        self.data_offsets = Array(ctypes.c_int, max_vectors, lock=False)  # byte offset of each record
        self.data_lengths = Array(ctypes.c_int, max_vectors, lock=False)  # byte length of each record

        # Shared counter tracking how many records are stored
        self.data_count = Array(ctypes.c_int, 1, lock=True)

        self.lock = RLock()

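    # Capacity note (rough worst-case arithmetic for the defaults): a
    # 1024-dim float32 vector is 4096 bytes, base64 inflates that to
    # 4 * ceil(4096 / 3) = 5464 bytes, and the JSON metadata adds a few
    # hundred more, so the assumed 16384-byte per-record budget leaves
    # comfortable headroom.
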
    def get_data(self, index: int) -> Dict[str, Any]:
        """Fetch one record (vector plus metadata) from shared memory."""
        if index >= self.data_count[0]:
            return {}

        offset = self.data_offsets[index]
        length = self.data_lengths[index]

        # Extract the raw byte string from shared memory
        data_bytes = bytes(self.data_strings[offset:offset + length])

        # Deserialize back into a dict
        try:
            data_dict = json.loads(data_bytes.decode('utf-8'))

            # Convert the base64-encoded vector back into a numpy array
            if 'vector_base64' in data_dict:
                vector_bytes = base64.b64decode(data_dict['vector_base64'])
                data_dict['vector'] = np.frombuffer(vector_bytes, dtype=np.float32)
                del data_dict['vector_base64']  # drop the encoded copy

            return data_dict
        except ValueError:
            # json.JSONDecodeError, binascii.Error and UnicodeDecodeError all
            # subclass ValueError; return a sentinel record on failure
            return {
                'table_num': -1,
                'table_index': -1,
                'measure_name': 'unknown',
                'measure_value': 'unknown',
                'measure_unit': 'unknown',
                'file_id': 'unknown',
                'vector': np.zeros(self.dim, dtype=np.float32)
            }

    def set_data(self, index: int, data: Dict[str, Any]) -> None:
        """Store one record (vector plus metadata) into shared memory."""
        # Copy the dict so the caller's data is not mutated
        data_copy = data.copy()

        # Encode the numpy vector as a base64 string so it survives JSON
        if 'vector' in data_copy and isinstance(data_copy['vector'], np.ndarray):
            vector_bytes = data_copy['vector'].astype(np.float32).tobytes()
            data_copy['vector_base64'] = base64.b64encode(vector_bytes).decode('utf-8')
            del data_copy['vector']  # drop the raw vector

        # Serialize to a JSON byte string
        data_str = json.dumps(data_copy)
        data_bytes = data_str.encode('utf-8')

        # The offset is where the previous record ends
        offset = 0
        if index > 0:
            offset = self.data_offsets[index - 1] + self.data_lengths[index - 1]

        # Check that the buffer has room left
        if offset + len(data_bytes) > len(self.data_strings):
            raise ValueError("Not enough space for data")

        # Record the offset and length of this record
        self.data_offsets[index] = offset
        self.data_lengths[index] = len(data_bytes)

        # Copy the bytes into shared memory in one slice assignment
        # (assigning single ints element by element would raise a
        # TypeError for a c_char array)
        self.data_strings[offset:offset + len(data_bytes)] = data_bytes

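    # Layout note: records are packed back to back, so set_data is only
    # correct when indices are written strictly in order, and an existing
    # record cannot be rewritten in place with a longer payload without
    # overwriting its neighbour. add_data below preserves this by writing
    # sequentially under the counter lock.
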
    def add_data(self, data_list: List[Dict[str, Any]]) -> None:
        """Append a batch of records (vectors plus metadata) to the store."""
        num_new = len(data_list)

        # Hold the counter lock across the whole reserve-write-publish
        # sequence so concurrent processes cannot interleave writes
        with self.data_count.get_lock():
            current_count = self.data_count[0]
            if current_count + num_new > self.max_vectors:
                raise ValueError("Exceeded maximum data storage capacity.")

            # Write each record into shared memory
            for i, data in enumerate(data_list):
                self.set_data(current_count + i, data)

            # Publish the new count
            self.data_count[0] = current_count + num_new

    def search_similar_vectors(self, query_vector: np.ndarray,
                               top_k: int = 10, similarity_threshold: float = 0.0) -> List[Tuple[float, Dict[str, Any]]]:
        """Search for vectors similar to the query vector."""
        query_vector = query_vector.astype(np.float32)
        current_count = self.data_count[0]
        if current_count == 0:
            return []

        # Score every stored vector against the query
        similarities = []
        for i in range(current_count):
            data = self.get_data(i)
            if 'vector' in data:
                vector = data['vector']
                # Cosine similarity
                dot_product = np.dot(vector, query_vector)
                norm_vector = np.linalg.norm(vector)
                norm_query = np.linalg.norm(query_vector)
                similarity = dot_product / (norm_vector * norm_query + 1e-10)
                similarities.append((similarity, data))

        # Keep only vectors whose similarity exceeds the threshold
        above_threshold = [(sim, data) for sim, data in similarities if sim > similarity_threshold]

        if len(above_threshold) == 0:
            return []

        # Sort by similarity, descending
        above_threshold.sort(key=lambda x: x[0], reverse=True)

        # Return the top_k results
        return above_threshold[:top_k]

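# The linear scan above decodes and scores one record at a time. Below is a
# sketch of a batched alternative: stack all stored vectors into one matrix
# and score them with a single matrix-vector product. This helper is
# illustrative, not part of the original design, and assumes every record
# carries a 'vector' of the same dimension.
def search_similar_vectors_batch(storage: VectorStorage, query_vector: np.ndarray,
                                 top_k: int = 10, similarity_threshold: float = 0.0
                                 ) -> List[Tuple[float, Dict[str, Any]]]:
    count = storage.data_count[0]
    records = [storage.get_data(i) for i in range(count)]
    records = [r for r in records if 'vector' in r]
    if not records:
        return []

    # One (count, dim) matrix; compute all cosine similarities at once
    matrix = np.stack([r['vector'] for r in records]).astype(np.float32)
    query = query_vector.astype(np.float32)
    sims = matrix @ query / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(query) + 1e-10)

    # Same threshold-then-top-k semantics as the method above
    order = np.argsort(sims)[::-1][:top_k]
    return [(float(sims[i]), records[i]) for i in order if sims[i] > similarity_threshold]
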
def vectorize_financial_data(financial_data: Dict[str, Any]) -> np.ndarray:
    """Turn one financial-measure record into a 1024-dimensional vector."""
    # Replace this with your real vectorization algorithm
    vector = np.random.rand(1024).astype(np.float32)  # random placeholder vector
    return vector

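# The random placeholder above returns a different vector for the same record
# on every call, so results cannot be reproduced across runs or processes. A
# deterministic stand-in (still not a real embedding): seed a local RNG from
# a stable hash of the record. `vectorize_deterministic` is a hypothetical
# helper, shown only as a sketch.
import hashlib

def vectorize_deterministic(financial_data: Dict[str, Any], dim: int = 1024) -> np.ndarray:
    # Stable 64-bit seed derived from the serialized record
    digest = hashlib.sha256(json.dumps(financial_data, sort_keys=True).encode('utf-8')).digest()
    seed = int.from_bytes(digest[:8], 'little')
    rng = np.random.default_rng(seed)
    return rng.random(dim, dtype=np.float32)
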
def worker_process(data_chunk: List[Dict[str, Any]], storage: VectorStorage) -> None:
    """Worker process: vectorize a chunk of records and store them."""
    data_list = []

    for data in data_chunk:
        # Generate the vector
        vector = vectorize_financial_data(data)

        # Assemble the full record (metadata plus vector)
        full_data = {
            'table_num': data.get('table_num', -1),
            'table_index': data.get('table_index', -1),
            'measure_name': data.get('measure_name', 'unknown'),
            'measure_value': data.get('measure_value', 'unknown'),
            'measure_unit': data.get('measure_unit', 'unknown'),
            'file_id': data.get('file_id', 'unknown'),
            'vector': vector  # store the vector alongside the metadata
        }
        data_list.append(full_data)

    # Append the whole chunk to the store
    storage.add_data(data_list)
    print(f"Processed {len(data_chunk)} financial data entries. Total data now: {storage.data_count[0]}")

if __name__ == '__main__':
    # Initialization
    dim = 1024
    max_vectors = 100
    shared_storage = VectorStorage(dim, max_vectors)

    # Prepare sample data that mimics the real data format
    financial_data = []
    for i in range(100):
        measure_data = {
            'table_num': i // 10 + 1,  # simulate 10 tables
            'table_index': i % 10,
            'measure_name': f'measure_{i}',
            'measure_value': i * 1.5,
            'measure_unit': 'unit',
            'file_id': f'file_{i // 20}'  # simulate 5 files
        }
        financial_data.append(measure_data)

    # Split the data into chunks, one per process
    num_processes = 4
    chunk_size = len(financial_data) // num_processes
    data_chunks = [
        financial_data[i*chunk_size:(i+1)*chunk_size]
        for i in range(num_processes)
    ]

    # Handle the remainder when the data does not divide evenly
    if len(financial_data) % num_processes != 0:
        data_chunks[-1].extend(financial_data[num_processes * chunk_size:])
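    # (np.array_split(financial_data, num_processes) would produce a
    # near-even split in a single call, with the remainder spread over the
    # leading chunks rather than appended to the last one.)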

    # Launch the worker processes
    processes = []
    for chunk in data_chunks:
        p = Process(target=worker_process, args=(chunk, shared_storage))
        processes.append(p)
        p.start()

    # Wait for all processes to finish
    for p in processes:
        p.join()

    print(f"\nAll processes finished. Total data stored: {shared_storage.data_count[0]}")

    # Example query
    try:
        # Create a query vector
        query_vector = np.random.rand(1024).astype(np.float32)

        # Query with a similarity threshold
        similarity_threshold = 0.5
        similar_items = shared_storage.search_similar_vectors(
            query_vector, top_k=5, similarity_threshold=similarity_threshold
        )

        print(f"\nTop vectors with similarity > {similarity_threshold}:")
        if similar_items:
            for score, data in similar_items:
                print(f"Score: {score:.4f}, Table: {data['table_num']}, Index: {data['table_index']}, Name: {data['measure_name']}")
        else:
            print("No vectors found above the similarity threshold.")

    except Exception as e:
        print(f"Error during query: {e}")
        import traceback
        traceback.print_exc()