内存向量替代milvus, 初版未测试备份

This commit is contained in:
朱思南 2025-09-09 18:59:30 +08:00
parent edbcc245a6
commit e41d0c8dbc
5 changed files with 349 additions and 71 deletions

View File

@ -15,6 +15,7 @@ import threading
from Mil_unit import create_partition_by_hour
from datetime import datetime, timedelta
from log_config import logger
from vector_storage import VectorStorage
app = FastAPI()
cpu_count = 4
@ -110,18 +111,14 @@ def run_job():
logger.info(f'开始表格指标抽取任务ID:{file_id}')
time_start = time.time()
# 获取当前时间
now = datetime.now()
current_hour = now.strftime("%Y%m%d%H")
partition_name = f"partition_{current_hour}"
# 判断是否创建新的分区
create_partition_by_hour(current_hour)
time.sleep(10)
# 初始化向量存储类
dim = 1024
max_vectors = 5000
shared_storage = VectorStorage(dim, max_vectors)
# 判断是否为3季报
if db_service.file_type_check_v2(file_id) == 3:
main.start_table_measure_job(file_id,partition_name)
main.start_table_measure_job(file_id, shared_storage)
time_start_end = time.time()
process_time = time_start_end - time_start
db_service.process_time(file_id,'2',process_time,time_start,time_start_end)
@ -131,7 +128,7 @@ def run_job():
logger.info(f'启动这个指标归一化任务ID-修改测试:{file_id}')
time_update = time.time()
main.update_measure_data(file_id,file_path,parent_table_pages,partition_name)
main.update_measure_data(file_id,file_path,parent_table_pages,shared_storage)
logger.info(f'归一化完成任务ID:{file_id}')
end_time = time.time()
@ -141,7 +138,7 @@ def run_job():
db_service.process_time(file_id,'3',process_time,time_update,time_update_end)
# 不是三季报就直接按照年报和半年报走
else:
main.start_table_measure_job(file_id,partition_name)
main.start_table_measure_job(file_id, shared_storage)
time_start_end = time.time()
process_time = time_start_end - time_start
db_service.process_time(file_id,'2',process_time,time_start,time_start_end)
@ -151,7 +148,7 @@ def run_job():
logger.info(f'启动这个指标归一化任务ID-修改测试:{file_id}')
time_update = time.time()
main.update_measure_data(file_id,file_path,parent_table_pages,partition_name)
main.update_measure_data(file_id,file_path,parent_table_pages, shared_storage)
logger.info(f'归一化完成任务ID:{file_id}')
end_time = time.time()

View File

@ -11,6 +11,7 @@ import mysql.connector
import threading
import redis
from log_config import logger
from vector_storage import VectorStorage
@ -298,7 +299,7 @@ def update_ori_measure(conn,cursor,file_id):
end_time = time.time()
logger.info(f"更新数据写入 {(end_time - start_time):.2f} 秒。")
def insert_table_from_vector_mul_process(parent_table_pages,file_id,file_name,records,record_range,black_array,partition_name,):
def insert_table_from_vector_mul_process(parent_table_pages,file_id,file_name,records,record_range,black_array,vector_storage:VectorStorage,):
create_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
logger.info(f'Run task {record_range} ({os.getpid()})...')
@ -389,9 +390,6 @@ def insert_table_from_vector_mul_process(parent_table_pages,file_id,file_name,re
table_index_array = []
measure_index_array = []
client = MilvusClient(
uri=MILVUS_CLIENT,
)
try:
for index in range(int(record_start),int(record_end)):
@ -405,32 +403,17 @@ def insert_table_from_vector_mul_process(parent_table_pages,file_id,file_name,re
measure_list = ast.literal_eval(measure_vector)
data = [measure_list]
filter_str = 'file_id == "'+file_id+'"'
res = client.search(
collection_name="pdf_measure_v4", # Replace with the actual name of your collection
# Replace with your query vector
data=data,
limit=3, # Max. number of search results to return
search_params={"metric_type": "COSINE", "params": {}}, # Search parameters
output_fields=["measure_name", "measure_value", "table_num", "table_index", "measure_unit"],
filter=filter_str,
partition_name=partition_name
)
# Convert the output to a formatted JSON string
# for i in range(len(res[0])):
for i in range(len(res[0])):
vector_distance = float(res[0][i]["distance"])
pdf_measure = res[0][i]["entity"]["measure_name"]
measure_value = res[0][i]["entity"]["measure_value"]
table_num = res[0][i]["entity"]["table_num"]
table_index = res[0][i]["entity"]["table_index"]
unit = res[0][i]["entity"]["measure_unit"]
# if pdf_measure == '2023年6月30日货币资金合计':
# print(f'{pdf_measure} 的相似度是 {vector_distance},其值为 {measure_value},页码在 {table_num}')
res = vector_storage.search_similar_vectors(measure_vector, top_k=3, similarity_threshold=distance)
# 返回格式为 [(similarity, {metadata}), ...]
for i in range(len(res)):
vector_distance = float(res[i][0])
pdf_measure = res[i][1]["measure_name"]
measure_value = float(res[i][1]["measure_value"])
table_num = int(res[i][1]["table_num"])
table_index = res[i][1]["table_index"]
unit = res[i][1]["measure_unit"]
#先过滤页码为0的情况暂时不知道原因
if table_num == 0:
@ -569,12 +552,11 @@ def insert_table_from_vector_mul_process(parent_table_pages,file_id,file_name,re
logger.info(e)
finally:
parent_table_pages = []
client.close()
redis_client.close()
cursor.close()
conn.close()
def insert_table_measure_from_vector_async_process(cursor,parent_table_pages,file_id,file_name, partition_name):
def insert_table_measure_from_vector_async_process(cursor,parent_table_pages,file_id,file_name, vector_storage: VectorStorage):
select_year_select = f"""select report_type,year from report_check where id = {file_id}"""
cursor.execute(select_year_select)
record_select = cursor.fetchall()
@ -619,7 +601,7 @@ def insert_table_measure_from_vector_async_process(cursor,parent_table_pages,fil
records_range_parts = utils.get_range(len(records),MEASURE_COUNT)
processes = []
for record_range in records_range_parts:
p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array, partition_name))
p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array, vector_storage))
processes.append(p)
p.start()
elif report_type == 2:
@ -633,7 +615,7 @@ def insert_table_measure_from_vector_async_process(cursor,parent_table_pages,fil
records_range_parts = utils.get_range(len(records),MEASURE_COUNT)
processes = []
for record_range in records_range_parts:
p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array,partition_name))
p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array,vector_storage))
processes.append(p)
p.start()
elif report_type == 3:
@ -647,7 +629,7 @@ def insert_table_measure_from_vector_async_process(cursor,parent_table_pages,fil
records_range_parts = utils.get_range(len(records),MEASURE_COUNT)
processes = []
for record_range in records_range_parts:
p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array,partition_name))
p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array,vector_storage))
processes.append(p)
p.start()
# p.apply_async(insert_table_from_vector_mul, args=(parent_table_pages,file_id,file_name,records,record_range,))
@ -662,7 +644,7 @@ def insert_table_measure_from_vector_async_process(cursor,parent_table_pages,fil
records_range_parts = utils.get_range(len(records),MEASURE_COUNT)
processes = []
for record_range in records_range_parts:
p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array,partition_name))
p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array,vector_storage))
processes.append(p)
p.start()
@ -767,7 +749,7 @@ def insert_table_measure_from_vector(conn,cursor,client,parent_table_pages,file_
start_time = time.time()
def insert_measure_data_to_milvus(client,partition_name,table_info,cursor,conn):
def insert_measure_data_to_milvus(vector_storage: VectorStorage,table_info,cursor,conn):
insert_query = '''
INSERT INTO measure_parse_process
(file_id, page_num, content)
@ -806,9 +788,10 @@ def insert_measure_data_to_milvus(client,partition_name,table_info,cursor,conn):
measure_unit = measure['measure_unit']
if re.match(r'^[+-]?(\d+(\.\d*)?|\.\d+)(%?)$', measure_value) and any(key_word in measure_name for key_word in measure_name_keywords):
vector_obj = utils.embed_with_str(measure_name_1)
vector = vector_obj.output["embeddings"][0]["embedding"]
# vector_obj = utils.embed_with_str(measure_name_1)
# vector = vector_obj.output["embeddings"][0]["embedding"]
vector = utils.embed_with_str_local(measure_name_1)
measure_data = {}
measure_data['vector'] = vector
measure_data['table_num'] = int(table_num)
@ -836,8 +819,10 @@ def insert_measure_data_to_milvus(client,partition_name,table_info,cursor,conn):
if crease_type == '减少' or crease_type == '下降':
measure_value = f'-{match.group(2)}'
vector_obj = utils.embed_with_str(measure_name_1)
vector = vector_obj.output["embeddings"][0]["embedding"]
# vector_obj = utils.embed_with_str(measure_name_1)
# vector = vector_obj.output["embeddings"][0]["embedding"]
vector = utils.embed_with_str_local(measure_name_1)
measure_data = {}
measure_data['vector'] = vector
measure_data['table_num'] = int(table_num)
@ -857,11 +842,8 @@ def insert_measure_data_to_milvus(client,partition_name,table_info,cursor,conn):
else:
pass#print(f"数据值的格式错误:{measure_value}。或者字段名不在名单内{measure_name}")
res = client.insert(
collection_name="pdf_measure_v4",
data=data,
partition_name=partition_name
)
vector_storage.add_data(data)
logger.info(f"向量插入结束")
except Exception as e:

View File

@ -3,7 +3,7 @@ import re
from multiprocessing import Pool
import os, time, random
import json
from config import MILVUS_CLIENT,MYSQL_HOST,MYSQL_USER,MYSQL_PASSWORD,MYSQL_DB,MEASURE_COUNT,MYSQL_HOST_APP,MYSQL_USER_APP,MYSQL_PASSWORD_APP,MYSQL_DB_APP
from config import MYSQL_HOST,MYSQL_USER,MYSQL_PASSWORD,MYSQL_DB,MEASURE_COUNT,MYSQL_HOST_APP,MYSQL_USER_APP,MYSQL_PASSWORD_APP,MYSQL_DB_APP
from datetime import datetime
# 读取PDF
import PyPDF2
@ -21,8 +21,8 @@ import numpy as np
from multiprocessing import Process
from config import REDIS_HOST,REDIS_PORT,REDIS_PASSWORD
import redis
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection,MilvusClient
from log_config import logger
from vector_storage import VectorStorage
'''
已知发现问题
@ -715,7 +715,7 @@ def get_table_text_info(file_id,line_text,page_num,table_index):
return table_info
# 读取pdf中的表格,并将表格中指标和表头合并eg: 2022年1季度营业收入为xxxxx
def get_table_measure(file_id, pdf_tables, record_range,partition_name,):
def get_table_measure(file_id, pdf_tables, record_range, vector_storage:VectorStorage,):
"""
:return: pdf中的表格,并将表格中指标和表头合并eg: 2022年1季度营业收入为xxxxx
"""
@ -743,10 +743,6 @@ def get_table_measure(file_id, pdf_tables, record_range,partition_name,):
report_type = record_select[0][0]
report_year = record_select[0][1]
client = MilvusClient(
uri=MILVUS_CLIENT,
)
logger.info('提取指标任务 %s (%s)...' % (record_range, os.getpid()))
start = time.time()
record_start = record_range.split('-')[0]
@ -861,7 +857,7 @@ def get_table_measure(file_id, pdf_tables, record_range,partition_name,):
data_dict["page_num"] = f"{str(t['page_num'])}_{str(t['table_index'])}"
data_dict['file_id'] = file_id
measure_obj.append(data_dict)
db_service.insert_measure_data_to_milvus(client,partition_name,measure_obj,cursor_app,conn_app)
db_service.insert_measure_data_to_milvus(vector_storage,measure_obj,cursor_app,conn_app)
except Exception as e:
logger.info(f"循环获取表格数据这里报错了,数据是{t['data']},位置在{index}")
logger.info(f"错误是:{e}")
@ -903,7 +899,7 @@ def dispatch_job(job_info):
#指标归一化处理
def update_measure_data(file_id,file_path,parent_table_pages,partition_name):
def update_measure_data(file_id,file_path,parent_table_pages,vector_storage:VectorStorage):
conn = mysql.connector.connect(
host = MYSQL_HOST,
user = MYSQL_USER,
@ -925,7 +921,7 @@ def update_measure_data(file_id,file_path,parent_table_pages,partition_name):
cursor_app = conn_app.cursor(buffered=True)
logger.info(f'目录黑名单为:{parent_table_pages}')
db_service.delete_to_run(conn,cursor,file_id)
db_service.insert_table_measure_from_vector_async_process(cursor,parent_table_pages,file_id,file_path, partition_name)
db_service.insert_table_measure_from_vector_async_process(cursor,parent_table_pages,file_id,file_path, vector_storage)
# #指标归一化处理
db_service.update_ori_measure(conn,cursor,file_id)
@ -1047,7 +1043,8 @@ def merge_consecutive_arrays_v1(pdf_info):
merged_objects.append(temp_array)
return merged_objects
def start_table_measure_job(file_id,partition_name):
def start_table_measure_job(file_id, vector_storage: VectorStorage):
conn_app = mysql.connector.connect(
host = MYSQL_HOST_APP,
user = MYSQL_USER_APP,
@ -1083,7 +1080,7 @@ def start_table_measure_job(file_id,partition_name):
for record_range in records_range_parts:
p = Process(target=get_table_measure, args=(file_id,pdf_tables,record_range,partition_name,))
p = Process(target=get_table_measure, args=(file_id,pdf_tables,record_range,vector_storage,))
processes.append(p)
p.start()

View File

@ -14,6 +14,18 @@ from config import api_key
import logging
logger = logging.getLogger(__name__)
# Requires transformers>=4.51.0
import torch
import torch.nn.functional as F
from torch import Tensor
from modelscope import AutoTokenizer, AutoModel
from http import HTTPStatus
dashscope.api_key = api_key
# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-0.6B', padding_side='left')
local_model = AutoModel.from_pretrained("Qwen/Qwen3-Embedding-0.6B")
dashscope.api_key = api_key
@ -23,6 +35,43 @@ def get_md5(str):
m.update(str.encode('utf-8'))
return m.hexdigest()
def embed_with_str_local(input):
# Tokenize the input text
max_length = 8192
input_texts = [input] if isinstance(input, str) else input
# Tokenize the input texts
batch_dict = tokenizer(
input_texts,
padding=True,
truncation=True,
max_length=max_length,
return_tensors="pt",
)
batch_dict.to(local_model.device)
outputs = local_model(**batch_dict)
embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
# normalize embeddings
embeddings = F.normalize(embeddings, p=2, dim=1)
# # 补零到1536维
# zeros = torch.zeros(embeddings.shape[0], 1536 - 1024).to(embeddings.device)
# embeddings = torch.cat([embeddings, zeros], dim=1)
ret = embeddings.tolist()[0]
return ret
def last_token_pool(last_hidden_states: Tensor,
attention_mask: Tensor) -> Tensor:
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
if left_padding:
return last_hidden_states[:, -1]
else:
sequence_lengths = attention_mask.sum(dim=1) - 1
batch_size = last_hidden_states.shape[0]
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
def embed_with_str(input):
retry = 0
max_retry = 5
@ -532,7 +581,24 @@ def check_black_table_list(data):
if __name__ == '__main__':
logger.debug(len('我是我'))
# logger.debug(len('我是我'))
vector_a = embed_with_str_local("扣除非经常性损益后归属于公司普通股股东的净利润加权平均净资产收益率变动比例")
vector_b = embed_with_str_local("扣非加权平均净资产收益率同比变动")
print(type(vector_a))
similarity = cosine_similarity(vector_a, vector_b)
print(f"余弦相似度: {similarity}")
print(f"维度 {len(vector_a)} {len(vector_b)}")
print(vector_a[:10])
print(vector_b[:10])
vector_a = embed_with_str("扣除非经常性损益后归属于公司普通股股东的净利润加权平均净资产收益率变动比例").output["embeddings"][0]["embedding"]
vector_b = embed_with_str("扣非加权平均净资产收益率同比变动").output["embeddings"][0]["embedding"]
print(type(vector_a))
similarity = cosine_similarity(vector_a, vector_b)
print(f"余弦相似度: {similarity}")
print(f"维度 {len(vector_a)} {len(vector_b)}")
print(vector_a[:10])
print(vector_b[:10])
# logger.debug(under_non_alpha_ratio('202水电费水电费水电费是的205月'))
# title = '母公司财务报表主要项目注释'

View File

@ -0,0 +1,236 @@
import numpy as np
from multiprocessing import Process, Array, RLock, cpu_count
from typing import List, Dict, Any, Tuple, Optional
import heapq
import ctypes
import json
import base64
class VectorStorage:
"""内存向量存储与检索类"""
def __init__(self, dim: int = 1024, max_vectors: int = 10000):
self.dim = dim
self.max_vectors = max_vectors
# 使用共享内存存储所有数据(向量和元数据)
# 我们将向量和元数据一起序列化存储
self.data_strings = Array(ctypes.c_char, max_vectors * 16384, lock=False) # 假设每条数据最大16384字节
self.data_offsets = Array(ctypes.c_int, max_vectors, lock=False) # 存储每条数据的偏移量
self.data_lengths = Array(ctypes.c_int, max_vectors, lock=False) # 存储每条数据的长度
# 使用共享计数器记录数据数量
self.data_count = Array(ctypes.c_int, 1, lock=True)
self.lock = RLock()
def get_data(self, index: int) -> Dict[str, Any]:
"""从共享内存中获取数据(包含向量和元数据)"""
if index >= self.data_count[0]:
return {}
offset = self.data_offsets[index]
length = self.data_lengths[index]
# 从共享内存中提取字节串
data_bytes = bytes(self.data_strings[offset:offset+length])
# 反序列化为字典
try:
data_dict = json.loads(data_bytes.decode('utf-8'))
# 将base64编码的向量转换回numpy数组
if 'vector_base64' in data_dict:
vector_bytes = base64.b64decode(data_dict['vector_base64'])
data_dict['vector'] = np.frombuffer(vector_bytes, dtype=np.float32)
del data_dict['vector_base64'] # 删除base64编码的向量
return data_dict
except:
return {
'table_num': -1,
'table_index': -1,
'measure_name': 'unknown',
'measure_value': 'unknown',
'measure_unit': 'unknown',
'file_id': 'unknown',
'vector': np.zeros(self.dim, dtype=np.float32)
}
def set_data(self, index: int, data: Dict[str, Any]) -> None:
"""将数据(包含向量和元数据)存储到共享内存中"""
# 复制数据字典,避免修改原始数据
data_copy = data.copy()
# 将numpy向量转换为base64编码的字符串
if 'vector' in data_copy and isinstance(data_copy['vector'], np.ndarray):
vector_bytes = data_copy['vector'].astype(np.float32).tobytes()
data_copy['vector_base64'] = base64.b64encode(vector_bytes).decode('utf-8')
del data_copy['vector'] # 删除原始向量
# 序列化为JSON字符串
data_str = json.dumps(data_copy)
data_bytes = data_str.encode('utf-8')
# 计算偏移量
offset = 0
if index > 0:
offset = self.data_offsets[index-1] + self.data_lengths[index-1]
# 检查是否有足够空间
if offset + len(data_bytes) > len(self.data_strings):
raise ValueError("Not enough space for data")
# 存储到共享内存
self.data_offsets[index] = offset
self.data_lengths[index] = len(data_bytes)
# 将字节数据复制到共享内存
for i, byte in enumerate(data_bytes):
self.data_strings[offset + i] = byte
def add_data(self, data_list: List[Dict[str, Any]]) -> None:
"""添加数据(包含向量和元数据)到存储"""
num_new = len(data_list)
with self.data_count.get_lock():
current_count = self.data_count[0]
if current_count + num_new > self.max_vectors:
raise ValueError("Exceeded maximum data storage capacity.")
# 将数据添加到共享内存
for i, data in enumerate(data_list):
self.set_data(current_count + i, data)
# 更新计数器
self.data_count[0] = current_count + num_new
def search_similar_vectors(self, query_vector: np.ndarray,
top_k: int = 10, similarity_threshold: float = 0.0) -> List[Tuple[float, Dict[str, Any]]]:
"""搜索相似向量"""
query_vector = query_vector.astype(np.float32)
current_count = self.data_count[0]
if current_count == 0:
return []
# 计算所有向量的相似度
similarities = []
for i in range(current_count):
data = self.get_data(i)
if 'vector' in data:
vector = data['vector']
# 计算余弦相似度
dot_product = np.dot(vector, query_vector)
norm_vector = np.linalg.norm(vector)
norm_query = np.linalg.norm(query_vector)
similarity = dot_product / (norm_vector * norm_query + 1e-10)
similarities.append((similarity, data))
# 筛选出相似度大于阈值的向量
above_threshold = [(sim, data) for sim, data in similarities if sim > similarity_threshold]
if len(above_threshold) == 0:
return []
# 按相似度降序排序
above_threshold.sort(key=lambda x: x[0], reverse=True)
# 返回top_k个结果
return above_threshold[:top_k]
def vectorize_financial_data(financial_data: Dict[str, Any]) -> np.ndarray:
"""将财务指标数据转化为1024维向量"""
# 实际应用中替换为您的向量化算法
vector = np.random.rand(1024).astype(np.float32) # 随机生成示例向量
return vector
def worker_process(data_chunk: List[Dict[str, Any]], storage: VectorStorage) -> None:
"""工作进程函数"""
data_list = []
for data in data_chunk:
# 生成向量
vector = vectorize_financial_data(data)
# 准备完整数据(包含向量和元数据)
full_data = {
'table_num': data.get('table_num', -1),
'table_index': data.get('table_index', -1),
'measure_name': data.get('measure_name', 'unknown'),
'measure_value': data.get('measure_value', 'unknown'),
'measure_unit': data.get('measure_unit', 'unknown'),
'file_id': data.get('file_id', 'unknown'),
'vector': vector # 将向量作为数据的一部分
}
data_list.append(full_data)
# 添加数据到存储
storage.add_data(data_list)
print(f"Processed {len(data_chunk)} financial data entries. Total data now: {storage.data_count[0]}")
if __name__ == '__main__':
# 初始化
dim = 1024
max_vectors = 100
shared_storage = VectorStorage(dim, max_vectors)
# 准备数据 - 模拟您的数据格式
financial_data = []
for i in range(100):
measure_data = {
'table_num': i // 10 + 1, # 模拟10个表
'table_index': i % 10,
'measure_name': f'measure_{i}',
'measure_value': i * 1.5,
'measure_unit': 'unit',
'file_id': f'file_{i // 20}' # 模拟5个文件
}
financial_data.append(measure_data)
# 分割数据
num_processes = 4
chunk_size = len(financial_data) // num_processes
data_chunks = [
financial_data[i*chunk_size:(i+1)*chunk_size]
for i in range(num_processes)
]
# 处理可能的不整除情况
if len(financial_data) % num_processes != 0:
data_chunks[-1].extend(financial_data[num_processes * chunk_size:])
# 启动多进程
processes = []
for chunk in data_chunks:
p = Process(target=worker_process, args=(chunk, shared_storage))
processes.append(p)
p.start()
# 等待所有进程完成
for p in processes:
p.join()
print(f"\nAll processes finished. Total data stored: {shared_storage.data_count[0]}")
# 示例查询
try:
# 创建一个查询向量
query_vector = np.random.rand(1024).astype(np.float32)
# 使用相似度阈值进行查询
similarity_threshold = 0.5
similar_items = shared_storage.search_similar_vectors(
query_vector, top_k=5, similarity_threshold=similarity_threshold
)
print(f"\nTop vectors with similarity > {similarity_threshold}:")
if similar_items:
for score, data in similar_items:
print(f"Score: {score:.4f}, Table: {data['table_num']}, Index: {data['table_index']}, Name: {data['measure_name']}")
else:
print("No vectors found above the similarity threshold.")
except Exception as e:
print(f"Error during query: {e}")
import traceback
traceback.print_exc()