diff --git a/zzb_data_prod/app.py b/zzb_data_prod/app.py index 71d4e2f..ac23805 100644 --- a/zzb_data_prod/app.py +++ b/zzb_data_prod/app.py @@ -15,6 +15,7 @@ import threading from Mil_unit import create_partition_by_hour from datetime import datetime, timedelta from log_config import logger +from vector_storage import VectorStorage app = FastAPI() cpu_count = 4 @@ -110,18 +111,14 @@ def run_job(): logger.info(f'开始表格指标抽取,任务ID:{file_id}') time_start = time.time() - - # 获取当前时间 - now = datetime.now() - current_hour = now.strftime("%Y%m%d%H") - partition_name = f"partition_{current_hour}" - # 判断是否创建新的分区 - create_partition_by_hour(current_hour) - time.sleep(10) + # 初始化向量存储类 + dim = 1024 + max_vectors = 5000 + shared_storage = VectorStorage(dim, max_vectors) # 判断是否为3季报 if db_service.file_type_check_v2(file_id) == 3: - main.start_table_measure_job(file_id,partition_name) + main.start_table_measure_job(file_id, shared_storage) time_start_end = time.time() process_time = time_start_end - time_start db_service.process_time(file_id,'2',process_time,time_start,time_start_end) @@ -131,7 +128,7 @@ def run_job(): logger.info(f'启动这个指标归一化任务ID-修改测试:{file_id}') time_update = time.time() - main.update_measure_data(file_id,file_path,parent_table_pages,partition_name) + main.update_measure_data(file_id,file_path,parent_table_pages,shared_storage) logger.info(f'归一化完成任务ID:{file_id}') end_time = time.time() @@ -141,7 +138,7 @@ def run_job(): db_service.process_time(file_id,'3',process_time,time_update,time_update_end) # 不是三季报就直接按照年报和半年报走 else: - main.start_table_measure_job(file_id,partition_name) + main.start_table_measure_job(file_id, shared_storage) time_start_end = time.time() process_time = time_start_end - time_start db_service.process_time(file_id,'2',process_time,time_start,time_start_end) @@ -151,7 +148,7 @@ def run_job(): logger.info(f'启动这个指标归一化任务ID-修改测试:{file_id}') time_update = time.time() - main.update_measure_data(file_id,file_path,parent_table_pages,partition_name) + main.update_measure_data(file_id,file_path,parent_table_pages, shared_storage) logger.info(f'归一化完成任务ID:{file_id}') end_time = time.time() diff --git a/zzb_data_prod/db_service.py b/zzb_data_prod/db_service.py index ece8be1..f3738b5 100644 --- a/zzb_data_prod/db_service.py +++ b/zzb_data_prod/db_service.py @@ -11,6 +11,7 @@ import mysql.connector import threading import redis from log_config import logger +from vector_storage import VectorStorage @@ -298,7 +299,7 @@ def update_ori_measure(conn,cursor,file_id): end_time = time.time() logger.info(f"更新数据写入 {(end_time - start_time):.2f} 秒。") -def insert_table_from_vector_mul_process(parent_table_pages,file_id,file_name,records,record_range,black_array,partition_name,): +def insert_table_from_vector_mul_process(parent_table_pages,file_id,file_name,records,record_range,black_array,vector_storage:VectorStorage,): create_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") logger.info(f'Run task {record_range} ({os.getpid()})...') @@ -389,9 +390,6 @@ def insert_table_from_vector_mul_process(parent_table_pages,file_id,file_name,re table_index_array = [] measure_index_array = [] - client = MilvusClient( - uri=MILVUS_CLIENT, - ) try: for index in range(int(record_start),int(record_end)): @@ -405,32 +403,17 @@ def insert_table_from_vector_mul_process(parent_table_pages,file_id,file_name,re measure_list = ast.literal_eval(measure_vector) - data = [measure_list] + filter_str = 'file_id == "'+file_id+'"' - res = client.search( - collection_name="pdf_measure_v4", # Replace with the actual name of your collection - # Replace with your query vector - data=data, - limit=3, # Max. number of search results to return - search_params={"metric_type": "COSINE", "params": {}}, # Search parameters - output_fields=["measure_name", "measure_value", "table_num", "table_index", "measure_unit"], - filter=filter_str, - partition_name=partition_name - ) - - # Convert the output to a formatted JSON string - # for i in range(len(res[0])): - - for i in range(len(res[0])): - - vector_distance = float(res[0][i]["distance"]) - pdf_measure = res[0][i]["entity"]["measure_name"] - measure_value = res[0][i]["entity"]["measure_value"] - table_num = res[0][i]["entity"]["table_num"] - table_index = res[0][i]["entity"]["table_index"] - unit = res[0][i]["entity"]["measure_unit"] - # if pdf_measure == '2023年6月30日货币资金合计': - # print(f'{pdf_measure} 的相似度是 {vector_distance},其值为 {measure_value},页码在 {table_num}') + res = vector_storage.search_similar_vectors(measure_vector, top_k=3, similarity_threshold=distance) + # 返回格式为 [(similarity, {metadata}), ...] + for i in range(len(res)): + vector_distance = float(res[i][0]) + pdf_measure = res[i][1]["measure_name"] + measure_value = float(res[i][1]["measure_value"]) + table_num = int(res[i][1]["table_num"]) + table_index = res[i][1]["table_index"] + unit = res[i][1]["measure_unit"] #先过滤页码为0的情况,暂时不知道原因 if table_num == 0: @@ -569,12 +552,11 @@ def insert_table_from_vector_mul_process(parent_table_pages,file_id,file_name,re logger.info(e) finally: parent_table_pages = [] - client.close() redis_client.close() cursor.close() conn.close() -def insert_table_measure_from_vector_async_process(cursor,parent_table_pages,file_id,file_name, partition_name): +def insert_table_measure_from_vector_async_process(cursor,parent_table_pages,file_id,file_name, vector_storage: VectorStorage): select_year_select = f"""select report_type,year from report_check where id = {file_id}""" cursor.execute(select_year_select) record_select = cursor.fetchall() @@ -619,7 +601,7 @@ def insert_table_measure_from_vector_async_process(cursor,parent_table_pages,fil records_range_parts = utils.get_range(len(records),MEASURE_COUNT) processes = [] for record_range in records_range_parts: - p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array, partition_name)) + p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array, vector_storage)) processes.append(p) p.start() elif report_type == 2: @@ -633,7 +615,7 @@ def insert_table_measure_from_vector_async_process(cursor,parent_table_pages,fil records_range_parts = utils.get_range(len(records),MEASURE_COUNT) processes = [] for record_range in records_range_parts: - p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array,partition_name)) + p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array,vector_storage)) processes.append(p) p.start() elif report_type == 3: @@ -647,7 +629,7 @@ def insert_table_measure_from_vector_async_process(cursor,parent_table_pages,fil records_range_parts = utils.get_range(len(records),MEASURE_COUNT) processes = [] for record_range in records_range_parts: - p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array,partition_name)) + p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array,vector_storage)) processes.append(p) p.start() # p.apply_async(insert_table_from_vector_mul, args=(parent_table_pages,file_id,file_name,records,record_range,)) @@ -662,7 +644,7 @@ def insert_table_measure_from_vector_async_process(cursor,parent_table_pages,fil records_range_parts = utils.get_range(len(records),MEASURE_COUNT) processes = [] for record_range in records_range_parts: - p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array,partition_name)) + p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array,vector_storage)) processes.append(p) p.start() @@ -767,7 +749,7 @@ def insert_table_measure_from_vector(conn,cursor,client,parent_table_pages,file_ start_time = time.time() -def insert_measure_data_to_milvus(client,partition_name,table_info,cursor,conn): +def insert_measure_data_to_milvus(vector_storage: VectorStorage,table_info,cursor,conn): insert_query = ''' INSERT INTO measure_parse_process (file_id, page_num, content) @@ -806,9 +788,10 @@ def insert_measure_data_to_milvus(client,partition_name,table_info,cursor,conn): measure_unit = measure['measure_unit'] if re.match(r'^[+-]?(\d+(\.\d*)?|\.\d+)(%?)$', measure_value) and any(key_word in measure_name for key_word in measure_name_keywords): - vector_obj = utils.embed_with_str(measure_name_1) - - vector = vector_obj.output["embeddings"][0]["embedding"] + # vector_obj = utils.embed_with_str(measure_name_1) + # vector = vector_obj.output["embeddings"][0]["embedding"] + + vector = utils.embed_with_str_local(measure_name_1) measure_data = {} measure_data['vector'] = vector measure_data['table_num'] = int(table_num) @@ -836,8 +819,10 @@ def insert_measure_data_to_milvus(client,partition_name,table_info,cursor,conn): if crease_type == '减少' or crease_type == '下降': measure_value = f'-{match.group(2)}' - vector_obj = utils.embed_with_str(measure_name_1) - vector = vector_obj.output["embeddings"][0]["embedding"] + # vector_obj = utils.embed_with_str(measure_name_1) + # vector = vector_obj.output["embeddings"][0]["embedding"] + + vector = utils.embed_with_str_local(measure_name_1) measure_data = {} measure_data['vector'] = vector measure_data['table_num'] = int(table_num) @@ -857,11 +842,8 @@ def insert_measure_data_to_milvus(client,partition_name,table_info,cursor,conn): else: pass#print(f"数据值的格式错误:{measure_value}。或者字段名不在名单内{measure_name}") - res = client.insert( - collection_name="pdf_measure_v4", - data=data, - partition_name=partition_name - ) + + vector_storage.add_data(data) logger.info(f"向量插入结束") except Exception as e: diff --git a/zzb_data_prod/main.py b/zzb_data_prod/main.py index 3bef28f..2d5ee00 100644 --- a/zzb_data_prod/main.py +++ b/zzb_data_prod/main.py @@ -3,7 +3,7 @@ import re from multiprocessing import Pool import os, time, random import json -from config import MILVUS_CLIENT,MYSQL_HOST,MYSQL_USER,MYSQL_PASSWORD,MYSQL_DB,MEASURE_COUNT,MYSQL_HOST_APP,MYSQL_USER_APP,MYSQL_PASSWORD_APP,MYSQL_DB_APP +from config import MYSQL_HOST,MYSQL_USER,MYSQL_PASSWORD,MYSQL_DB,MEASURE_COUNT,MYSQL_HOST_APP,MYSQL_USER_APP,MYSQL_PASSWORD_APP,MYSQL_DB_APP from datetime import datetime # 读取PDF import PyPDF2 @@ -21,8 +21,8 @@ import numpy as np from multiprocessing import Process from config import REDIS_HOST,REDIS_PORT,REDIS_PASSWORD import redis -from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection,MilvusClient from log_config import logger +from vector_storage import VectorStorage ''' 已知发现问题: @@ -715,7 +715,7 @@ def get_table_text_info(file_id,line_text,page_num,table_index): return table_info # 读取pdf中的表格,并将表格中指标和表头合并,eg: 2022年1季度营业收入为xxxxx -def get_table_measure(file_id, pdf_tables, record_range,partition_name,): +def get_table_measure(file_id, pdf_tables, record_range, vector_storage:VectorStorage,): """ :return: pdf中的表格,并将表格中指标和表头合并,eg: 2022年1季度营业收入为xxxxx """ @@ -743,10 +743,6 @@ def get_table_measure(file_id, pdf_tables, record_range,partition_name,): report_type = record_select[0][0] report_year = record_select[0][1] - client = MilvusClient( - uri=MILVUS_CLIENT, - ) - logger.info('提取指标任务 %s (%s)...' % (record_range, os.getpid())) start = time.time() record_start = record_range.split('-')[0] @@ -861,7 +857,7 @@ def get_table_measure(file_id, pdf_tables, record_range,partition_name,): data_dict["page_num"] = f"{str(t['page_num'])}_{str(t['table_index'])}" data_dict['file_id'] = file_id measure_obj.append(data_dict) - db_service.insert_measure_data_to_milvus(client,partition_name,measure_obj,cursor_app,conn_app) + db_service.insert_measure_data_to_milvus(vector_storage,measure_obj,cursor_app,conn_app) except Exception as e: logger.info(f"循环获取表格数据这里报错了,数据是{t['data']},位置在{index}") logger.info(f"错误是:{e}") @@ -903,7 +899,7 @@ def dispatch_job(job_info): #指标归一化处理 -def update_measure_data(file_id,file_path,parent_table_pages,partition_name): +def update_measure_data(file_id,file_path,parent_table_pages,vector_storage:VectorStorage): conn = mysql.connector.connect( host = MYSQL_HOST, user = MYSQL_USER, @@ -925,7 +921,7 @@ def update_measure_data(file_id,file_path,parent_table_pages,partition_name): cursor_app = conn_app.cursor(buffered=True) logger.info(f'目录黑名单为:{parent_table_pages}') db_service.delete_to_run(conn,cursor,file_id) - db_service.insert_table_measure_from_vector_async_process(cursor,parent_table_pages,file_id,file_path, partition_name) + db_service.insert_table_measure_from_vector_async_process(cursor,parent_table_pages,file_id,file_path, vector_storage) # #指标归一化处理 db_service.update_ori_measure(conn,cursor,file_id) @@ -1047,7 +1043,8 @@ def merge_consecutive_arrays_v1(pdf_info): merged_objects.append(temp_array) return merged_objects -def start_table_measure_job(file_id,partition_name): + +def start_table_measure_job(file_id, vector_storage: VectorStorage): conn_app = mysql.connector.connect( host = MYSQL_HOST_APP, user = MYSQL_USER_APP, @@ -1083,7 +1080,7 @@ def start_table_measure_job(file_id,partition_name): for record_range in records_range_parts: - p = Process(target=get_table_measure, args=(file_id,pdf_tables,record_range,partition_name,)) + p = Process(target=get_table_measure, args=(file_id,pdf_tables,record_range,vector_storage,)) processes.append(p) p.start() diff --git a/zzb_data_prod/utils.py b/zzb_data_prod/utils.py index 41f33fd..2d607cd 100644 --- a/zzb_data_prod/utils.py +++ b/zzb_data_prod/utils.py @@ -14,6 +14,18 @@ from config import api_key import logging logger = logging.getLogger(__name__) +# Requires transformers>=4.51.0 +import torch +import torch.nn.functional as F +from torch import Tensor +from modelscope import AutoTokenizer, AutoModel +from http import HTTPStatus + +dashscope.api_key = api_key +# Load the tokenizer and model +tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-0.6B', padding_side='left') +local_model = AutoModel.from_pretrained("Qwen/Qwen3-Embedding-0.6B") + dashscope.api_key = api_key @@ -23,6 +35,43 @@ def get_md5(str): m.update(str.encode('utf-8')) return m.hexdigest() +def embed_with_str_local(input): + # Tokenize the input text + max_length = 8192 + input_texts = [input] if isinstance(input, str) else input + + # Tokenize the input texts + batch_dict = tokenizer( + input_texts, + padding=True, + truncation=True, + max_length=max_length, + return_tensors="pt", + ) + batch_dict.to(local_model.device) + outputs = local_model(**batch_dict) + embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask']) + # normalize embeddings + embeddings = F.normalize(embeddings, p=2, dim=1) + + # # 补零到1536维 + # zeros = torch.zeros(embeddings.shape[0], 1536 - 1024).to(embeddings.device) + # embeddings = torch.cat([embeddings, zeros], dim=1) + + ret = embeddings.tolist()[0] + + return ret + +def last_token_pool(last_hidden_states: Tensor, + attention_mask: Tensor) -> Tensor: + left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) + if left_padding: + return last_hidden_states[:, -1] + else: + sequence_lengths = attention_mask.sum(dim=1) - 1 + batch_size = last_hidden_states.shape[0] + return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths] + def embed_with_str(input): retry = 0 max_retry = 5 @@ -532,7 +581,24 @@ def check_black_table_list(data): if __name__ == '__main__': - logger.debug(len('我是我')) + # logger.debug(len('我是我')) + vector_a = embed_with_str_local("扣除非经常性损益后归属于公司普通股股东的净利润加权平均净资产收益率变动比例") + vector_b = embed_with_str_local("扣非加权平均净资产收益率同比变动") + print(type(vector_a)) + similarity = cosine_similarity(vector_a, vector_b) + print(f"余弦相似度: {similarity}") + print(f"维度 {len(vector_a)} {len(vector_b)}") + print(vector_a[:10]) + print(vector_b[:10]) + + vector_a = embed_with_str("扣除非经常性损益后归属于公司普通股股东的净利润加权平均净资产收益率变动比例").output["embeddings"][0]["embedding"] + vector_b = embed_with_str("扣非加权平均净资产收益率同比变动").output["embeddings"][0]["embedding"] + print(type(vector_a)) + similarity = cosine_similarity(vector_a, vector_b) + print(f"余弦相似度: {similarity}") + print(f"维度 {len(vector_a)} {len(vector_b)}") + print(vector_a[:10]) + print(vector_b[:10]) # logger.debug(under_non_alpha_ratio('202水电费水电费水电费是的205月')) # title = '母公司财务报表主要项目注释' diff --git a/zzb_data_prod/vector_storage.py b/zzb_data_prod/vector_storage.py new file mode 100644 index 0000000..9da1ed2 --- /dev/null +++ b/zzb_data_prod/vector_storage.py @@ -0,0 +1,236 @@ +import numpy as np +from multiprocessing import Process, Array, RLock, cpu_count +from typing import List, Dict, Any, Tuple, Optional +import heapq +import ctypes +import json +import base64 + +class VectorStorage: + """内存向量存储与检索类""" + + def __init__(self, dim: int = 1024, max_vectors: int = 10000): + self.dim = dim + self.max_vectors = max_vectors + + # 使用共享内存存储所有数据(向量和元数据) + # 我们将向量和元数据一起序列化存储 + self.data_strings = Array(ctypes.c_char, max_vectors * 16384, lock=False) # 假设每条数据最大16384字节 + self.data_offsets = Array(ctypes.c_int, max_vectors, lock=False) # 存储每条数据的偏移量 + self.data_lengths = Array(ctypes.c_int, max_vectors, lock=False) # 存储每条数据的长度 + + # 使用共享计数器记录数据数量 + self.data_count = Array(ctypes.c_int, 1, lock=True) + + self.lock = RLock() + + def get_data(self, index: int) -> Dict[str, Any]: + """从共享内存中获取数据(包含向量和元数据)""" + if index >= self.data_count[0]: + return {} + + offset = self.data_offsets[index] + length = self.data_lengths[index] + + # 从共享内存中提取字节串 + data_bytes = bytes(self.data_strings[offset:offset+length]) + + # 反序列化为字典 + try: + data_dict = json.loads(data_bytes.decode('utf-8')) + + # 将base64编码的向量转换回numpy数组 + if 'vector_base64' in data_dict: + vector_bytes = base64.b64decode(data_dict['vector_base64']) + data_dict['vector'] = np.frombuffer(vector_bytes, dtype=np.float32) + del data_dict['vector_base64'] # 删除base64编码的向量 + + return data_dict + except: + return { + 'table_num': -1, + 'table_index': -1, + 'measure_name': 'unknown', + 'measure_value': 'unknown', + 'measure_unit': 'unknown', + 'file_id': 'unknown', + 'vector': np.zeros(self.dim, dtype=np.float32) + } + + def set_data(self, index: int, data: Dict[str, Any]) -> None: + """将数据(包含向量和元数据)存储到共享内存中""" + # 复制数据字典,避免修改原始数据 + data_copy = data.copy() + + # 将numpy向量转换为base64编码的字符串 + if 'vector' in data_copy and isinstance(data_copy['vector'], np.ndarray): + vector_bytes = data_copy['vector'].astype(np.float32).tobytes() + data_copy['vector_base64'] = base64.b64encode(vector_bytes).decode('utf-8') + del data_copy['vector'] # 删除原始向量 + + # 序列化为JSON字符串 + data_str = json.dumps(data_copy) + data_bytes = data_str.encode('utf-8') + + # 计算偏移量 + offset = 0 + if index > 0: + offset = self.data_offsets[index-1] + self.data_lengths[index-1] + + # 检查是否有足够空间 + if offset + len(data_bytes) > len(self.data_strings): + raise ValueError("Not enough space for data") + + # 存储到共享内存 + self.data_offsets[index] = offset + self.data_lengths[index] = len(data_bytes) + + # 将字节数据复制到共享内存 + for i, byte in enumerate(data_bytes): + self.data_strings[offset + i] = byte + + def add_data(self, data_list: List[Dict[str, Any]]) -> None: + """添加数据(包含向量和元数据)到存储""" + num_new = len(data_list) + + with self.data_count.get_lock(): + current_count = self.data_count[0] + if current_count + num_new > self.max_vectors: + raise ValueError("Exceeded maximum data storage capacity.") + + # 将数据添加到共享内存 + for i, data in enumerate(data_list): + self.set_data(current_count + i, data) + + # 更新计数器 + self.data_count[0] = current_count + num_new + + def search_similar_vectors(self, query_vector: np.ndarray, + top_k: int = 10, similarity_threshold: float = 0.0) -> List[Tuple[float, Dict[str, Any]]]: + """搜索相似向量""" + query_vector = query_vector.astype(np.float32) + current_count = self.data_count[0] + if current_count == 0: + return [] + + # 计算所有向量的相似度 + similarities = [] + for i in range(current_count): + data = self.get_data(i) + if 'vector' in data: + vector = data['vector'] + # 计算余弦相似度 + dot_product = np.dot(vector, query_vector) + norm_vector = np.linalg.norm(vector) + norm_query = np.linalg.norm(query_vector) + similarity = dot_product / (norm_vector * norm_query + 1e-10) + similarities.append((similarity, data)) + + # 筛选出相似度大于阈值的向量 + above_threshold = [(sim, data) for sim, data in similarities if sim > similarity_threshold] + + if len(above_threshold) == 0: + return [] + + # 按相似度降序排序 + above_threshold.sort(key=lambda x: x[0], reverse=True) + + # 返回top_k个结果 + return above_threshold[:top_k] + +def vectorize_financial_data(financial_data: Dict[str, Any]) -> np.ndarray: + """将财务指标数据转化为1024维向量""" + # 实际应用中替换为您的向量化算法 + vector = np.random.rand(1024).astype(np.float32) # 随机生成示例向量 + return vector + +def worker_process(data_chunk: List[Dict[str, Any]], storage: VectorStorage) -> None: + """工作进程函数""" + data_list = [] + + for data in data_chunk: + # 生成向量 + vector = vectorize_financial_data(data) + + # 准备完整数据(包含向量和元数据) + full_data = { + 'table_num': data.get('table_num', -1), + 'table_index': data.get('table_index', -1), + 'measure_name': data.get('measure_name', 'unknown'), + 'measure_value': data.get('measure_value', 'unknown'), + 'measure_unit': data.get('measure_unit', 'unknown'), + 'file_id': data.get('file_id', 'unknown'), + 'vector': vector # 将向量作为数据的一部分 + } + data_list.append(full_data) + + # 添加数据到存储 + storage.add_data(data_list) + print(f"Processed {len(data_chunk)} financial data entries. Total data now: {storage.data_count[0]}") + +if __name__ == '__main__': + # 初始化 + dim = 1024 + max_vectors = 100 + shared_storage = VectorStorage(dim, max_vectors) + + # 准备数据 - 模拟您的数据格式 + financial_data = [] + for i in range(100): + measure_data = { + 'table_num': i // 10 + 1, # 模拟10个表 + 'table_index': i % 10, + 'measure_name': f'measure_{i}', + 'measure_value': i * 1.5, + 'measure_unit': 'unit', + 'file_id': f'file_{i // 20}' # 模拟5个文件 + } + financial_data.append(measure_data) + + # 分割数据 + num_processes = 4 + chunk_size = len(financial_data) // num_processes + data_chunks = [ + financial_data[i*chunk_size:(i+1)*chunk_size] + for i in range(num_processes) + ] + + # 处理可能的不整除情况 + if len(financial_data) % num_processes != 0: + data_chunks[-1].extend(financial_data[num_processes * chunk_size:]) + + # 启动多进程 + processes = [] + for chunk in data_chunks: + p = Process(target=worker_process, args=(chunk, shared_storage)) + processes.append(p) + p.start() + + # 等待所有进程完成 + for p in processes: + p.join() + + print(f"\nAll processes finished. Total data stored: {shared_storage.data_count[0]}") + + # 示例查询 + try: + # 创建一个查询向量 + query_vector = np.random.rand(1024).astype(np.float32) + + # 使用相似度阈值进行查询 + similarity_threshold = 0.5 + similar_items = shared_storage.search_similar_vectors( + query_vector, top_k=5, similarity_threshold=similarity_threshold + ) + + print(f"\nTop vectors with similarity > {similarity_threshold}:") + if similar_items: + for score, data in similar_items: + print(f"Score: {score:.4f}, Table: {data['table_num']}, Index: {data['table_index']}, Name: {data['measure_name']}") + else: + print("No vectors found above the similarity threshold.") + + except Exception as e: + print(f"Error during query: {e}") + import traceback + traceback.print_exc() \ No newline at end of file