From 452d843a81b1f0c4027e9aa97b42047c78e6fe9d Mon Sep 17 00:00:00 2001
From: 吴登数 <11035577+wu-dengshu@user.noreply.gitee.com>
Date: Tue, 3 Dec 2024 12:33:40 +0800
Subject: [PATCH] milvus: rework partition writes, test version v1.1,
 functionality verified
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
 zzb_data_prod/Mil_unit.py   | 52 +++++++++++++++++++++-----------
 zzb_data_prod/app.py        | 59 ++++++++++++++++++++-----------------
 zzb_data_prod/db_service.py | 54 ++++++++++++++++-----------------
 zzb_data_prod/main.py       | 28 ++++++++----------
 4 files changed, 106 insertions(+), 87 deletions(-)

diff --git a/zzb_data_prod/Mil_unit.py b/zzb_data_prod/Mil_unit.py
index ba33782..869bc5b 100644
--- a/zzb_data_prod/Mil_unit.py
+++ b/zzb_data_prod/Mil_unit.py
@@ -3,7 +3,7 @@ from config import MILVUS_CLIENT
 import time
 from datetime import datetime, timedelta
 
-def create_partition_by_hour():
+def create_partition_by_hour(current_hour):
     # Connect to the Milvus server
     connections.connect("default", uri=MILVUS_CLIENT)
     # Get the collection
@@ -12,16 +12,18 @@ def create_partition_by_hour():
 
     # Get the current time
     now = datetime.now()
-    current_hour = now.strftime("%Y%m%d%H")
+
 
     # Create the partition for the current hour
     partition_name = f"partition_{current_hour}"
 
     if not collection.has_partition(partition_name):
         collection.create_partition(partition_name)
         print(f"Created partition: {partition_name}")
+        partition = collection.partition(partition_name)
+        partition.load()
 
-    # Drop the partition from one hour ago
-    previous_hour = (now - timedelta(hours=1)).strftime("%Y%m%d%H")
+    # Drop the partition from two hours ago
+    previous_hour = (now - timedelta(hours=2)).strftime("%Y%m%d%H")
     previous_partition_name = f"partition_{previous_hour}"
 
     if collection.has_partition(previous_partition_name):
@@ -30,24 +32,40 @@ def create_partition_by_hour():
         collection.drop_partition(previous_partition_name)
         print(f"Dropped partition: {previous_partition_name}")
 
-    partition = collection.partition(partition_name)
-    partition.load()
-
-    return collection, partition
-# res = partition.search(
-#     # collection_name="pdf_measure_v4", # Replace with the actual name of your collection
-#     # Replace with your query vector
+
+
+# data = []
+# measure_data = {}
+# vector = [0.61865162262130161] * 1536
+# measure_data['vector'] = vector
+# measure_data['table_num'] = int(2)
+# measure_data['table_index'] = int(2)
+# measure_data['measure_name'] = "234234"
+# measure_data['measure_value'] = "23432"
+# measure_data['measure_unit'] = "123423"
+# measure_data['file_id'] = "100000"
+#
+# data.append(measure_data)
+# res = client.insert(
+#     collection_name=collection_name,
 #     data=data,
-#     limit=3, # Max. number of search results to return
-#     anns_field="vector",
-#     param={"metric_type": "COSINE", "params": {}}, # Search parameters
-#     output_fields=["measure_name","measure_value","table_num","table_index","measure_unit"],
-#     # filter=filter_str,
-#     expr=query
+#     partition_name=partition_name
 # )
+# filter_str = 'file_id == "'+"2122"+'"'
+# res = client.search(
+#     collection_name=collection_name,  # Replace with the actual name of your collection
+#     # Replace with your query vector
+#     data=data,
+#     limit=3,  # Max. number of search results to return
+#     search_params={"metric_type": "COSINE", "params": {}},  # Search parameters
+#     output_fields=["measure_name", "measure_value", "table_num", "table_index", "measure_unit"],
+#     filter=filter_str,
+#     partition_names=[partition_name]
+# )
+# print(f"============================={res}")
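
For context, a minimal sketch of how the reworked helper is meant to be driven now that the caller owns the hour key (hypothetical standalone snippet; the helper name, the %Y%m%d%H key format, and the two-hour retention come from this patch):

    from datetime import datetime
    from Mil_unit import create_partition_by_hour

    # The caller computes the hour key; the helper creates and loads
    # partition_<YYYYMMDDHH> and drops the partition from two hours ago.
    current_hour = datetime.now().strftime("%Y%m%d%H")
    create_partition_by_hour(current_hour)
    partition_name = f"partition_{current_hour}"  # handed down to insert/search calls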
diff --git a/zzb_data_prod/app.py b/zzb_data_prod/app.py
index e4c21db..24d58af 100644
--- a/zzb_data_prod/app.py
+++ b/zzb_data_prod/app.py
@@ -12,7 +12,9 @@ import config
 import requests
 import db_service
 import threading
-#import pdf_company_0824
+from Mil_unit import create_partition_by_hour
+from datetime import datetime, timedelta
+
 
 app = FastAPI()
 cpu_count = os.cpu_count()
@@ -91,6 +93,7 @@ def run_job():
             parser_start_time = time.time()
             processes = []
             time_dispatch_job = time.time()
+
             for job_info in page_list:
                 p = Process(target=main.dispatch_job, args=(job_info,))
                 processes.append(p)
@@ -120,11 +123,18 @@ def run_job():
             parser_start_time = time.time()
             print('Starting table measure extraction, task ID:', file_id)
             time_start = time.time()
-            if db_service.file_type_check_v2(file_id) ==3:  # Check whether this is a Q3 report
-                main.start_table_measure_job(file_id)
-                #time_start_end = time.time()
-                #process_time = time_start_end - time_start
-                #db_service.process_time(file_id,'2',process_time)
+
+
+            # Get the current time
+            now = datetime.now()
+            current_hour = now.strftime("%Y%m%d%H")
+            partition_name = f"partition_{current_hour}"
+            # Create a new partition if needed
+            create_partition_by_hour(current_hour)
+            # Check whether this is a Q3 report
+
+            if db_service.file_type_check_v2(file_id) == 3:
+                main.start_table_measure_job(file_id,partition_name)
                 time_start_end = time.time()
                 process_time = time_start_end - time_start
                 db_service.process_time(file_id,'2',process_time,time_start,time_start_end)
@@ -134,21 +144,17 @@ def run_job():
 
                 print('Starting measure normalization, task ID (modified test):', file_id)
                 time_update = time.time()
-                main.update_measure_data(file_id,file_path,parent_table_pages)
-                #time_update_end = time.time()
-                #process_time = time_update_end - time_update
-                #db_service.process_time(file_id,'3',process_time)
+                main.update_measure_data(file_id,file_path,parent_table_pages,partition_name)
+
                 print('Normalization finished, task ID:', file_id)
                 end_time = time.time()
                 print(f"Task {file_id} finished in {(end_time - start_time):.2f} s.")
                 time_update_end = time.time()
                 process_time = time_update_end - time_update
                 db_service.process_time(file_id,'3',process_time,time_update,time_update_end)
-            else:  # Otherwise follow the annual / semi-annual report path
-                main.start_table_measure_job(file_id)
-                #time_start_end = time.time()
-                #process_time = time_start_end - time_start
-                #db_service.process_time(file_id,'2',process_time)
+            # Otherwise follow the annual / semi-annual report path
+            else:
+                main.start_table_measure_job(file_id,partition_name)
                 time_start_end = time.time()
                 process_time = time_start_end - time_start
                 db_service.process_time(file_id,'2',process_time,time_start,time_start_end)
@@ -158,10 +164,8 @@ def run_job():
 
                 print('Starting measure normalization, task ID (modified test):', file_id)
                 time_update = time.time()
-                main.update_measure_data(file_id,file_path,parent_table_pages)
-                #time_update_end = time.time()
-                #process_time = time_update_end - time_update
-                #db_service.process_time(file_id,'3',process_time)
+                main.update_measure_data(file_id,file_path,parent_table_pages,partition_name)
+
                 print('Normalization finished, task ID:', file_id)
                 end_time = time.time()
                 print(f"Task {file_id} finished in {(end_time - start_time):.2f} s.")
@@ -193,6 +197,7 @@ def run_job():
             print(f"{file_id} failed: {e}")
         finally:
             print(f"Task {file_id} finished, status: {job_status}")
+
             #pdf_company_0824.name_code_fix(file_id,file_path)
             #print('Company name and code filled in')
     else:
@@ -219,13 +224,13 @@ app.post("/parser/start",
 # Run the FastAPI application
 if __name__ == "__main__":
     # Start the service on the server
-    # import uvicorn
-    # uvicorn.run(app, host="0.0.0.0", port=config.PORT)
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=config.PORT)
 
     # Local debugging job
-    job_queue.put({
-        'file_path' : '1.pdf',
-        'file_id' : '2122'
-    })
-
-    run_job()
+    # job_queue.put({
+    #     'file_path' : '1.pdf',
+    #     'file_id' : '2122'
+    # })
+    #
+    # run_job()
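
With the server entry point re-enabled, a hypothetical smoke test against the running service (the request schema of /parser/start is not shown in this patch; the field names mirror the commented-out local-debug job above, and the port is an assumption standing in for config.PORT):

    import requests

    # Hypothetical request; the actual schema of /parser/start is defined elsewhere in app.py.
    resp = requests.post(
        "http://127.0.0.1:8000/parser/start",  # port assumed; the app binds to config.PORT
        json={"file_id": "2122", "file_path": "1.pdf"},
    )
    print(resp.status_code, resp.text)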
diff --git a/zzb_data_prod/db_service.py b/zzb_data_prod/db_service.py
index 29fb0ad..baf8a37 100644
--- a/zzb_data_prod/db_service.py
+++ b/zzb_data_prod/db_service.py
@@ -10,7 +10,7 @@ from pymilvus import MilvusClient
 import mysql.connector
 import threading
 import redis
-from Mil_unit import create_partition_by_hour
+
 measure_name_keywords = ["营业","季度","利润","归属于","扣非","经营","现金","活动","损益","收益","资产","费用","销售","管理","财务","研发","货币资金","应收账款","存货","固定资产","在建工程","商誉","短期借款","应付账款","合同负债","长期借款","营业成本"]
 # Parse the measures extracted by the LLM and insert them into the database
 def parse_llm_measure_to_db(measure_info,type,conn,cursor):
@@ -271,14 +271,13 @@ def update_ori_measure(conn,cursor,file_id):
     end_time = time.time()
     print(f"Update data written in {(end_time - start_time):.2f} s.")
 
-def insert_table_from_vector_mul_process(parent_table_pages,file_id,file_name,records,record_range,black_array):
+def insert_table_from_vector_mul_process(parent_table_pages,file_id,file_name,records,record_range,black_array,partition_name):
     create_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     print('Run task %s (%s)...' % (record_range, os.getpid()))
     print(f"Inserting {len(records)} records")
-    _,partition = create_partition_by_hour()
-
+
     conn = mysql.connector.connect(
         host = MYSQL_HOST,
         user = MYSQL_USER,
@@ -358,8 +357,9 @@ def insert_table_from_vector_mul_process(parent_table_pages,file_id,file_name,re
     record_start = record_range.split('-')[0]
     record_end = record_range.split('-')[1]
 
-    now = datetime.now()
-    current_hour = now.strftime("%Y%m%d%H")
+    client = MilvusClient(
+        uri=MILVUS_CLIENT,
+    )
 
     try:
         for index in range(int(record_start),int(record_end)):
@@ -372,24 +372,18 @@ def insert_table_from_vector_mul_process(parent_table_pages,file_id,file_name,re
             measure_vector = redis_service.read_from_redis(redis_client,ori_measure_id)
             measure_list = ast.literal_eval(measure_vector)
             data = [measure_list]
-            # data.append(measure_list)
             filter_str = 'file_id == "'+file_id+'"'
-
-
-            # Define the query condition
-
-
-            res = partition.search(
-                # collection_name="pdf_measure_v4", # Replace with the actual name of your collection
+            res = client.search(
+                collection_name="pdf_measure_v4",  # Replace with the actual name of your collection
                 # Replace with your query vector
                 data=data,
-                limit=3, # Max. number of search results to return
-                anns_field="vector",
-                param={"metric_type": "COSINE", "params": {}}, # Search parameters
-                output_fields=["measure_name","measure_value","table_num","table_index","measure_unit"],
-                # filter=filter_str,
-                expr=filter_str
+                limit=3,  # Max. number of search results to return
+                search_params={"metric_type": "COSINE", "params": {}},  # Search parameters
+                output_fields=["measure_name", "measure_value", "table_num", "table_index", "measure_unit"],
+                filter=filter_str,
+                partition_names=[partition_name]
+            )
 
             # Convert the output to a formatted JSON string
             # for i in range(len(res[0])):
@@ -540,11 +534,12 @@ def insert_table_from_vector_mul_process(parent_table_pages,file_id,file_name,re
         print(e)
     finally:
         parent_table_pages = []
+        client.close()
         redis_client.close()
         cursor.close()
         conn.close()
 
-def insert_table_measure_from_vector_async_process(cursor,parent_table_pages,file_id,file_name):
+def insert_table_measure_from_vector_async_process(cursor,parent_table_pages,file_id,file_name, partition_name):
     select_year_select = f"""select report_type,year from report_check where id = {file_id}"""
     cursor.execute(select_year_select)
     record_select = cursor.fetchall()
@@ -585,7 +580,7 @@ def insert_table_measure_from_vector_async_process(cursor,parent_table_pages,fil
         records_range_parts = utils.get_range(len(records),MEASURE_COUNT)
         processes = []
         for record_range in records_range_parts:
-            p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array,))
+            p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array, partition_name))
             processes.append(p)
             p.start()
     elif report_type == 3:
@@ -599,7 +594,7 @@ def insert_table_measure_from_vector_async_process(cursor,parent_table_pages,fil
         records_range_parts = utils.get_range(len(records),MEASURE_COUNT)
         processes = []
         for record_range in records_range_parts:
-            p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array,))
+            p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array,partition_name))
             processes.append(p)
             p.start()
         # p.apply_async(insert_table_from_vector_mul, args=(parent_table_pages,file_id,file_name,records,record_range,))
@@ -614,7 +609,7 @@ def insert_table_measure_from_vector_async_process(cursor,parent_table_pages,fil
         records_range_parts = utils.get_range(len(records),MEASURE_COUNT)
         processes = []
         for record_range in records_range_parts:
-            p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array,))
+            p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array,partition_name))
            processes.append(p)
             p.start()
 
@@ -653,6 +648,8 @@ def insert_table_measure_from_vector(conn,cursor,client,parent_table_pages,file_
     end_time = time.time()
     print(f"Vector config data queried in {(end_time - start_time):.2f} s.")
     start_time = time.time()
+
+
     try:
         for record in records:
@@ -717,7 +714,7 @@ def insert_table_measure_from_vector(conn,cursor,client,parent_table_pages,file_
     start_time = time.time()
 
-def insert_measure_data_to_milvus(milvus_partition,table_info,cursor,conn):
+def insert_measure_data_to_milvus(client,partition_name,table_info,cursor,conn):
     insert_query = '''
     INSERT INTO measure_parse_process
     (file_id, page_num, content)
@@ -798,9 +795,10 @@ def insert_measure_data_to_milvus(milvus_partition,table_info,cursor,conn):
                 else:
                     pass  # print(f"Malformed measure value: {measure_value}, or measure name not in the allow-list: {measure_name}")
 
-        res = milvus_partition.insert(
-            data=data
+        res = client.insert(
+            collection_name="pdf_measure_v4",
+            data=data,
+            partition_name=partition_name
         )
 
     except Exception as e:
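
Taken together, the two db_service changes make every write and read partition-scoped. A minimal round-trip sketch under the schema implied by the test block in Mil_unit.py (1536-dim vectors, COSINE metric; the partition name and field values are illustrative):

    from pymilvus import MilvusClient
    from config import MILVUS_CLIENT

    client = MilvusClient(uri=MILVUS_CLIENT)
    partition_name = "partition_2024120312"  # illustrative hour key

    # Write into the hourly partition...
    client.insert(
        collection_name="pdf_measure_v4",
        data=[{"vector": [0.0] * 1536, "table_num": 2, "table_index": 2,
               "measure_name": "demo", "measure_value": "42",
               "measure_unit": "元", "file_id": "2122"}],
        partition_name=partition_name,
    )
    # ...and read back from that partition only.
    res = client.search(
        collection_name="pdf_measure_v4",
        data=[[0.0] * 1536],
        limit=3,
        search_params={"metric_type": "COSINE", "params": {}},
        output_fields=["measure_name", "measure_value"],
        filter='file_id == "2122"',
        partition_names=[partition_name],
    )
    client.close()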
diff --git a/zzb_data_prod/main.py b/zzb_data_prod/main.py
index 068825f..7d78293 100644
--- a/zzb_data_prod/main.py
+++ b/zzb_data_prod/main.py
@@ -21,6 +21,7 @@ import numpy as np
 from multiprocessing import Process
 from config import REDIS_HOST,REDIS_PORT,REDIS_PASSWORD
 import redis
+from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, MilvusClient
 
@@ -692,21 +693,18 @@ def get_table_text_info(file_id,line_text,page_num,table_index):
     return table_info
 
 # Read the tables in the PDF and merge each measure with its headers, e.g. "2022 Q1 operating revenue is xxxxx"
-def get_table_measure(file_id, pdf_tables, record_range):
+def get_table_measure(file_id, pdf_tables, record_range, partition_name):
     """
     :return: tables from the PDF, with each measure merged with its headers, e.g. "2022 Q1 operating revenue is xxxxx"
     """
     try:
-
         redis_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=6)
-
         conn = mysql.connector.connect(
             host = MYSQL_HOST,
             user = MYSQL_USER,
             password = MYSQL_PASSWORD,
             database = MYSQL_DB
         )
-        # Create a cursor to execute SQL statements
         cursor = conn.cursor(buffered=True)
 
         conn_app = mysql.connector.connect(
@@ -715,22 +713,20 @@ def get_table_measure(file_id, pdf_tables, record_range):
             password = MYSQL_PASSWORD_APP,
             database = MYSQL_DB_APP
         )
-        # Create a cursor to execute SQL statements
         cursor_app = conn_app.cursor(buffered=True)
-
         select_year_select = f"""select report_type,year from report_check where id = {file_id}"""
         cursor.execute(select_year_select)
         record_select = cursor.fetchall()
         report_type = record_select[0][0]
         report_year = record_select[0][1]
 
-        # Get the Milvus connection
-        _, milvus_partition = create_partition_by_hour()
+        client = MilvusClient(
+            uri=MILVUS_CLIENT,
+        )
 
         print('Measure extraction task %s (%s)...' % (record_range, os.getpid()))
         start = time.time()
-
         record_start = record_range.split('-')[0]
         record_end = record_range.split('-')[1]
         for index in range(int(record_start),int(record_end)):
@@ -843,7 +839,7 @@ def get_table_measure(file_id, pdf_tables, record_range):
                         data_dict["page_num"] = f"{str(t['page_num'])}_{str(t['table_index'])}"
                         data_dict['file_id'] = file_id
                         measure_obj.append(data_dict)
-                    db_service.insert_measure_data_to_milvus(milvus_partition,measure_obj,cursor_app,conn_app)
+                    db_service.insert_measure_data_to_milvus(client,partition_name,measure_obj,cursor_app,conn_app)
                 except Exception as e:
                     print(f"Error while iterating the table data; data: {t['data']}, at index {index}")
                     print(f"Error: {e}")
@@ -885,7 +881,7 @@ def dispatch_job(job_info):
 
 # Measure normalization
-def update_measure_data(file_id,file_path,parent_table_pages):
+def update_measure_data(file_id,file_path,parent_table_pages,partition_name):
     conn = mysql.connector.connect(
         host = MYSQL_HOST,
         user = MYSQL_USER,
@@ -907,7 +903,7 @@ def update_measure_data(file_id,file_path,parent_table_pages):
     cursor_app = conn_app.cursor(buffered=True)
     print(f'TOC blacklist: {parent_table_pages}')
     db_service.delete_to_run(conn,cursor,file_id)
-    db_service.insert_table_measure_from_vector_async_process(cursor,parent_table_pages,file_id,file_path)
+    db_service.insert_table_measure_from_vector_async_process(cursor,parent_table_pages,file_id,file_path, partition_name)
 
     # Measure normalization
     db_service.update_ori_measure(conn,cursor,file_id)
@@ -991,14 +987,13 @@ def merge_consecutive_arrays_v1(pdf_info):
         merged_objects.append(temp_array)
     return merged_objects
 
-def start_table_measure_job(file_id):
+def start_table_measure_job(file_id,partition_name):
     conn_app = mysql.connector.connect(
         host = MYSQL_HOST_APP,
         user = MYSQL_USER_APP,
         password = MYSQL_PASSWORD_APP,
         database = MYSQL_DB_APP
     )
-    # Create a cursor to execute SQL statements
     cursor_app = conn_app.cursor(buffered=True)
 
@@ -1024,8 +1019,11 @@ def start_table_measure_job(file_id):
     records_range_parts = utils.get_range(len(pdf_tables),MEASURE_COUNT)
     print(f'records_range_parts page ranges: {records_range_parts}')
     processes = []
+
+
+
     for record_range in records_range_parts:
-        p = Process(target=get_table_measure, args=(file_id,pdf_tables,record_range,))
+        p = Process(target=get_table_measure, args=(file_id,pdf_tables,record_range,partition_name,))
         processes.append(p)
         p.start()
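
For review, the whole partition_name data flow introduced by this patch, condensed into one comment sketch (all names come from the diffs above; bodies elided):

    # app.run_job:
    #     current_hour = datetime.now().strftime("%Y%m%d%H")
    #     create_partition_by_hour(current_hour)            # create/load hour N, drop hour N-2
    #     partition_name = f"partition_{current_hour}"
    #     main.start_table_measure_job(file_id, partition_name)
    #         main.get_table_measure(..., partition_name)   # per-process MilvusClient
    #             db_service.insert_measure_data_to_milvus(client, partition_name, ...)   # insert
    #     main.update_measure_data(file_id, file_path, parent_table_pages, partition_name)
    #         db_service.insert_table_measure_from_vector_async_process(..., partition_name)
    #             db_service.insert_table_from_vector_mul_process(..., partition_name)    # partition-scoped search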