milvus: partition-write changes, test version v1.1; functionality verified OK

吴登数 2024-12-03 12:33:40 +08:00
parent fde9ef3fef
commit 452d843a81
4 changed files with 106 additions and 87 deletions

View File — Mil_unit.py

@@ -3,7 +3,7 @@ from config import MILVUS_CLIENT
 import time
 from datetime import datetime, timedelta
 
-def create_partition_by_hour():
+def create_partition_by_hour(current_hour):
     # Connect to the Milvus server
     connections.connect("default",uri=MILVUS_CLIENT)
     # Get the collection
@@ -12,16 +12,18 @@ def create_partition_by_hour():
     # Get the current time
     now = datetime.now()
-    current_hour = now.strftime("%Y%m%d%H")
 
     # Create the partition for the current hour
     partition_name = f"partition_{current_hour}"
     if not collection.has_partition(partition_name):
         collection.create_partition(partition_name)
         print(f"Created partition: {partition_name}")
+        partition = collection.partition(partition_name)
+        partition.load()
 
-    # Drop the partition from one hour earlier
-    previous_hour = (now - timedelta(hours=1)).strftime("%Y%m%d%H")
+    # Drop the partition from two hours earlier
+    previous_hour = (now - timedelta(hours=2)).strftime("%Y%m%d%H")
     previous_partition_name = f"partition_{previous_hour}"
     if collection.has_partition(previous_partition_name):
@@ -30,24 +32,40 @@ def create_partition_by_hour():
         collection.drop_partition(previous_partition_name)
         print(f"Dropped partition: {previous_partition_name}")
-    partition = collection.partition(partition_name)
-    partition.load()
-    return collection, partition
 
-    # res = partition.search(
-    #     # collection_name="pdf_measure_v4", # Replace with the actual name of your collection
-    #     # Replace with your query vector
-    #     data=data,
-    #     limit=3, # Max. number of search results to return
-    #     anns_field="vector",
-    #     param={"metric_type": "COSINE", "params": {}}, # Search parameters
-    #     output_fields=["measure_name","measure_value","table_num","table_index","measure_unit"],
-    #     # filter=filter_str,
-    #     expr=query
-    # )
+    # data = []
+    # measure_data = {}
+    # vector = [0.61865162262130161] * 1536
+    # measure_data['vector'] = vector
+    # measure_data['table_num'] = int(2)
+    # measure_data['table_index'] = int(2)
+    # measure_data['measure_name'] = "234234"
+    # measure_data['measure_value'] = "23432"
+    # measure_data['measure_unit'] = "123423"
+    # measure_data['file_id'] = "100000"
+    #
+    # data.append(measure_data)
+    # res = client.insert(
+    #     collection_name=collection_name,
+    #     data=data,
+    #     partition_name=partition_name
+    # )
+    # filter_str = 'file_id == "'+"2122"+'"'
+    # res = client.search(
+    #     collection_name=collection_name, # Replace with the actual name of your collection
+    #     # Replace with your query vector
+    #     data=data,
+    #     limit=3, # Max. number of search results to return
+    #     search_params={"metric_type": "COSINE", "params": {}}, # Search parameters
+    #     output_fields=["measure_name", "measure_value", "table_num", "table_index", "measure_unit"],
+    #     filter=filter_str,
+    #     partition_name=partition_name
+    # )
+    # print(f"============================={res}")
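Pieced together from the three hunks above, the new helper reads roughly as follows. This is a sketch, not the committed file: the collection lookup sits outside the shown hunks, so the collection name ("pdf_measure_v4", used elsewhere in this commit) is an assumption, and the release-before-drop step is added here because Milvus normally refuses to drop a partition that is still loaded — the diff itself does not show that step.

    from datetime import datetime, timedelta
    from pymilvus import connections, Collection
    from config import MILVUS_CLIENT

    def create_partition_by_hour(current_hour):
        # Connect to the Milvus server
        connections.connect("default", uri=MILVUS_CLIENT)
        # Assumption: collection name; the lookup is outside the shown hunks
        collection = Collection("pdf_measure_v4")
        # Create and load the partition for the hour key passed by the caller
        partition_name = f"partition_{current_hour}"
        if not collection.has_partition(partition_name):
            collection.create_partition(partition_name)
            print(f"Created partition: {partition_name}")
            collection.partition(partition_name).load()
        # Drop the partition from two hours earlier
        previous_hour = (datetime.now() - timedelta(hours=2)).strftime("%Y%m%d%H")
        previous_partition_name = f"partition_{previous_hour}"
        if collection.has_partition(previous_partition_name):
            # Not in the diff: a loaded partition must be released before dropping
            collection.partition(previous_partition_name).release()
            collection.drop_partition(previous_partition_name)
            print(f"Dropped partition: {previous_partition_name}")

Note the design shift: the caller, not the helper, now owns the hour key, which is what lets run_job (next file) compute partition_name once and thread it through the whole pipeline.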

View File — FastAPI service entry point (filename not shown in this capture)

@@ -12,7 +12,9 @@ import config
 import requests
 import db_service
 import threading
-#import pdf_company_0824
+from Mil_unit import create_partition_by_hour
+from datetime import datetime, timedelta
+
 
 app = FastAPI()
 cpu_count = os.cpu_count()
@@ -91,6 +93,7 @@ def run_job():
     parser_start_time = time.time()
     processes = []
     time_dispatch_job = time.time()
+
     for job_info in page_list:
         p = Process(target=main.dispatch_job, args=(job_info,))
         processes.append(p)
@@ -120,11 +123,18 @@ def run_job():
         parser_start_time = time.time()
         print('Starting table-measure extraction, job ID:', file_id)
         time_start = time.time()
-        if db_service.file_type_check_v2(file_id) ==3:#check whether this is a Q3 report
-            main.start_table_measure_job(file_id)
-            #time_start_end = time.time()
-            #process_time = time_start_end - time_start
-            #db_service.process_time(file_id,'2',process_time)
+        # Get the current time
+        now = datetime.now()
+        current_hour = now.strftime("%Y%m%d%H")
+        partition_name = f"partition_{current_hour}"
+        # Create a new partition for this hour if one is needed
+        create_partition_by_hour(current_hour)
+        # Check whether this is a Q3 report
+        if db_service.file_type_check_v2(file_id) == 3:
+            main.start_table_measure_job(file_id,partition_name)
             time_start_end = time.time()
             process_time = time_start_end - time_start
             db_service.process_time(file_id,'2',process_time,time_start,time_start_end)
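The hour key and the partition name are derived once per job here and then threaded through every downstream call, so both sides of the pipeline must agree on the same naming convention. A minimal illustration (timestamp values hypothetical):

    from datetime import datetime

    current_hour = datetime.now().strftime("%Y%m%d%H")  # e.g. "2024120312"
    partition_name = f"partition_{current_hour}"        # e.g. "partition_2024120312"
    create_partition_by_hour(current_hour)              # ensure the partition exists and is loaded

One edge worth keeping in mind: the helper reaps partitions two hours old, so a job whose processing straddles more than two hour boundaries could in principle have its own partition dropped mid-run.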
@@ -134,21 +144,17 @@ def run_job():
             print('Starting measure normalization, job ID (test change):', file_id)
             time_update = time.time()
-            main.update_measure_data(file_id,file_path,parent_table_pages)
-            #time_update_end = time.time()
-            #process_time = time_update_end - time_update
-            #db_service.process_time(file_id,'3',process_time)
+            main.update_measure_data(file_id,file_path,parent_table_pages,partition_name)
             print('Normalization finished, job ID:', file_id)
             end_time = time.time()
             print(f"Job {file_id} finished in {(end_time - start_time):.2f} s.")
             time_update_end = time.time()
             process_time = time_update_end - time_update
             db_service.process_time(file_id,'3',process_time,time_update,time_update_end)
-        else:#not a Q3 report: follow the annual/semi-annual report path
-            main.start_table_measure_job(file_id)
-            #time_start_end = time.time()
-            #process_time = time_start_end - time_start
-            #db_service.process_time(file_id,'2',process_time)
+        # Not a Q3 report: follow the annual/semi-annual report path
+        else:
+            main.start_table_measure_job(file_id,partition_name)
             time_start_end = time.time()
             process_time = time_start_end - time_start
             db_service.process_time(file_id,'2',process_time,time_start,time_start_end)
@@ -158,10 +164,8 @@ def run_job():
             print('Starting measure normalization, job ID (test change):', file_id)
             time_update = time.time()
-            main.update_measure_data(file_id,file_path,parent_table_pages)
-            #time_update_end = time.time()
-            #process_time = time_update_end - time_update
-            #db_service.process_time(file_id,'3',process_time)
+            main.update_measure_data(file_id,file_path,parent_table_pages,partition_name)
             print('Normalization finished, job ID:', file_id)
             end_time = time.time()
             print(f"Job {file_id} finished in {(end_time - start_time):.2f} s.")
@@ -193,6 +197,7 @@ def run_job():
             print(f"{file_id} run failed: {e}")
         finally:
             print(f"Job {file_id} finished, status: {job_status}")
+
             #pdf_company_0824.name_code_fix(file_id,file_path)
             #print('Company name and code fill-in finished')
     else:
@@ -219,13 +224,13 @@ app.post("/parser/start",
 # Run the FastAPI application
 if __name__ == "__main__":
     # Start the service on a server
-    # import uvicorn
-    # uvicorn.run(app, host="0.0.0.0", port=config.PORT)
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=config.PORT)
     # Local debugging job
-    job_queue.put({
-        'file_path' : '1.pdf',
-        'file_id' : '2122'
-    })
-
-    run_job()
+    # job_queue.put({
+    #     'file_path' : '1.pdf',
+    #     'file_id' : '2122'
+    # })
+    #
+    # run_job()

View File — db_service.py

@@ -10,7 +10,7 @@ from pymilvus import MilvusClient
 import mysql.connector
 import threading
 import redis
-from Mil_unit import create_partition_by_hour
+
 
 measure_name_keywords = ["营业","季度","利润","归属于","扣非","经营","现金","活动","损益","收益","资产","费用","销售","管理","财务","研发","货币资金","应收账款","存货","固定资产","在建工程","商誉","短期借款","应付账款","合同负债","长期借款","营业成本"]
 # Parse the LLM-extracted measures and insert them into the database
 def parse_llm_measure_to_db(measure_info,type,conn,cursor):
@@ -271,14 +271,13 @@ def update_ori_measure(conn,cursor,file_id):
     end_time = time.time()
     print(f"Update written in {(end_time - start_time):.2f} s.")
 
-def insert_table_from_vector_mul_process(parent_table_pages,file_id,file_name,records,record_range,black_array):
+def insert_table_from_vector_mul_process(parent_table_pages,file_id,file_name,records,record_range,black_array,partition_name,):
     create_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     print('Run task %s (%s)...' % (record_range, os.getpid()))
     print(f"Inserting {len(records)} rows")
-    _,partition = create_partition_by_hour()
     conn = mysql.connector.connect(
         host = MYSQL_HOST,
         user = MYSQL_USER,
@@ -358,8 +357,9 @@ def insert_table_from_vector_mul_process(parent_table_pages,file_id,file_name,re
     record_start = record_range.split('-')[0]
     record_end = record_range.split('-')[1]
-    now = datetime.now()
-    current_hour = now.strftime("%Y%m%d%H")
+    client = MilvusClient(
+        uri=MILVUS_CLIENT,
+    )
 
     try:
         for index in range(int(record_start),int(record_end)):
@@ -372,24 +372,18 @@ def insert_table_from_vector_mul_process(parent_table_pages,file_id,file_name,re
             measure_vector = redis_service.read_from_redis(redis_client,ori_measure_id)
             measure_list = ast.literal_eval(measure_vector)
             data = [measure_list]
-            # data.append(measure_list)
             filter_str = 'file_id == "'+file_id+'"'
-            # Build the query condition
-            res = partition.search(
-                # collection_name="pdf_measure_v4", # Replace with the actual name of your collection
+            res = client.search(
+                collection_name="pdf_measure_v4", # Replace with the actual name of your collection
                 # Replace with your query vector
                 data=data,
                 limit=3, # Max. number of search results to return
-                anns_field="vector",
-                param={"metric_type": "COSINE", "params": {}}, # Search parameters
-                output_fields=["measure_name","measure_value","table_num","table_index","measure_unit"],
-                # filter=filter_str,
-                expr=filter_str
+                search_params={"metric_type": "COSINE", "params": {}}, # Search parameters
+                output_fields=["measure_name", "measure_value", "table_num", "table_index", "measure_unit"],
+                filter=filter_str,
+                partition_name=partition_name
             )
             # Convert the output to a formatted JSON string
             # for i in range(len(res[0])):
@@ -540,11 +534,12 @@ def insert_table_from_vector_mul_process(parent_table_pages,file_id,file_name,re
         print(e)
     finally:
         parent_table_pages = []
+        client.close()
         redis_client.close()
         cursor.close()
         conn.close()
 
-def insert_table_measure_from_vector_async_process(cursor,parent_table_pages,file_id,file_name):
+def insert_table_measure_from_vector_async_process(cursor,parent_table_pages,file_id,file_name, partition_name):
     select_year_select = f"""select report_type,year from report_check where id = {file_id}"""
     cursor.execute(select_year_select)
     record_select = cursor.fetchall()
@@ -585,7 +580,7 @@ def insert_table_measure_from_vector_async_process(cursor,parent_table_pages,fil
         records_range_parts = utils.get_range(len(records),MEASURE_COUNT)
         processes = []
         for record_range in records_range_parts:
-            p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array,))
+            p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array, partition_name))
             processes.append(p)
             p.start()
     elif report_type == 3:
@@ -599,7 +594,7 @@ def insert_table_measure_from_vector_async_process(cursor,parent_table_pages,fil
         records_range_parts = utils.get_range(len(records),MEASURE_COUNT)
         processes = []
         for record_range in records_range_parts:
-            p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array,))
+            p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array,partition_name))
             processes.append(p)
             p.start()
             # p.apply_async(insert_table_from_vector_mul, args=(parent_table_pages,file_id,file_name,records,record_range,))
@@ -614,7 +609,7 @@ def insert_table_measure_from_vector_async_process(cursor,parent_table_pages,fil
         records_range_parts = utils.get_range(len(records),MEASURE_COUNT)
         processes = []
         for record_range in records_range_parts:
-            p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array,))
+            p = Process(target=insert_table_from_vector_mul_process, args=(parent_table_pages,file_id,file_name,records,record_range,black_array,partition_name))
             processes.append(p)
             p.start()
@@ -653,6 +648,8 @@ def insert_table_measure_from_vector(conn,cursor,client,parent_table_pages,file_
     end_time = time.time()
     print(f"Vector config query took {(end_time - start_time):.2f} s.")
     start_time = time.time()
+
+
 
     try:
         for record in records:
@@ -717,7 +714,7 @@ def insert_table_measure_from_vector(conn,cursor,client,parent_table_pages,file_
     start_time = time.time()
 
-def insert_measure_data_to_milvus(milvus_partition,table_info,cursor,conn):
+def insert_measure_data_to_milvus(client,partition_name,table_info,cursor,conn):
     insert_query = '''
         INSERT INTO measure_parse_process
         (file_id, page_num, content)
@@ -798,9 +795,10 @@ def insert_measure_data_to_milvus(client,partition_name,table_info,cursor,conn):
             else:
                 pass#print(f"Badly formatted value: {measure_value}, or measure name not in the keyword list: {measure_name}")
-        res = milvus_partition.insert(
-            data=data
+        res = client.insert(
+            collection_name="pdf_measure_v4",
+            data=data,
+            partition_name=partition_name
         )
     except Exception as e:
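The two changes above are the pattern for the whole commit: ORM-style Partition objects (partition.search, milvus_partition.insert) give way to a single MilvusClient whose calls are scoped by partition name. A hedged sketch of the two scoped calls follows; the collection name comes from this diff, while the vectors, file_id, and hour key are hypothetical. One caveat to check against the deployed pymilvus version: MilvusClient.search documents the plural partition_names parameter (a list), while insert takes the singular partition_name, so the search calls in this commit may need partition_names=[partition_name].

    from pymilvus import MilvusClient
    from config import MILVUS_CLIENT

    client = MilvusClient(uri=MILVUS_CLIENT)

    # Insert scoped to one hour partition (singular keyword)
    client.insert(
        collection_name="pdf_measure_v4",
        data=[{"vector": [0.1] * 1536, "file_id": "2122"}],  # schema fields abridged
        partition_name="partition_2024120312",               # hypothetical hour key
    )

    # Search scoped to the same partition (plural keyword, takes a list)
    res = client.search(
        collection_name="pdf_measure_v4",
        data=[[0.1] * 1536],
        filter='file_id == "2122"',
        limit=3,
        search_params={"metric_type": "COSINE", "params": {}},
        output_fields=["measure_name", "measure_value", "measure_unit"],
        partition_names=["partition_2024120312"],
    )
    client.close()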

View File — main.py

@@ -21,6 +21,7 @@ import numpy as np
 from multiprocessing import Process
 from config import REDIS_HOST,REDIS_PORT,REDIS_PASSWORD
 import redis
+from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection,MilvusClient
@@ -692,21 +693,18 @@ def get_table_text_info(file_id,line_text,page_num,table_index):
     return table_info
 
 # Read the tables in the pdf and merge each measure with its header, e.g. "2022 Q1 operating revenue is xxxxx"
-def get_table_measure(file_id, pdf_tables, record_range):
+def get_table_measure(file_id, pdf_tables, record_range,partition_name,):
     """
     :return: tables in the pdf, with each measure merged with its header, e.g. "2022 Q1 operating revenue is xxxxx"
     """
     try:
         redis_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=6)
         conn = mysql.connector.connect(
             host = MYSQL_HOST,
             user = MYSQL_USER,
             password = MYSQL_PASSWORD,
             database = MYSQL_DB
         )
         # Create a cursor to execute SQL statements
         cursor = conn.cursor(buffered=True)
         conn_app = mysql.connector.connect(
@@ -715,22 +713,20 @@ def get_table_measure(file_id, pdf_tables, record_range):
             password = MYSQL_PASSWORD_APP,
             database = MYSQL_DB_APP
         )
         # Create a cursor to execute SQL statements
         cursor_app = conn_app.cursor(buffered=True)
         select_year_select = f"""select report_type,year from report_check where id = {file_id}"""
         cursor.execute(select_year_select)
         record_select = cursor.fetchall()
         report_type = record_select[0][0]
         report_year = record_select[0][1]
-        # Get the milvus connection
-        _, milvus_partition = create_partition_by_hour()
+        client = MilvusClient(
+            uri=MILVUS_CLIENT,
+        )
         print('Measure extraction task %s (%s)...' % (record_range, os.getpid()))
         start = time.time()
         record_start = record_range.split('-')[0]
         record_end = record_range.split('-')[1]
         for index in range(int(record_start),int(record_end)):
@@ -843,7 +839,7 @@ def get_table_measure(file_id, pdf_tables, record_range):
                 data_dict["page_num"] = f"{str(t['page_num'])}_{str(t['table_index'])}"
                 data_dict['file_id'] = file_id
                 measure_obj.append(data_dict)
-            db_service.insert_measure_data_to_milvus(milvus_partition,measure_obj,cursor_app,conn_app)
+            db_service.insert_measure_data_to_milvus(client,partition_name,measure_obj,cursor_app,conn_app)
         except Exception as e:
             print(f"Error while looping over table data; data: {t['data']}, index: {index}")
             print(f"Error: {e}")
@@ -885,7 +881,7 @@ def dispatch_job(job_info):
 
 # Measure normalization
-def update_measure_data(file_id,file_path,parent_table_pages):
+def update_measure_data(file_id,file_path,parent_table_pages,partition_name):
     conn = mysql.connector.connect(
         host = MYSQL_HOST,
         user = MYSQL_USER,
@@ -907,7 +903,7 @@ def update_measure_data(file_id,file_path,parent_table_pages):
     cursor_app = conn_app.cursor(buffered=True)
     print(f'Table-of-contents blacklist: {parent_table_pages}')
     db_service.delete_to_run(conn,cursor,file_id)
-    db_service.insert_table_measure_from_vector_async_process(cursor,parent_table_pages,file_id,file_path)
+    db_service.insert_table_measure_from_vector_async_process(cursor,parent_table_pages,file_id,file_path, partition_name)
     # # Measure normalization
     db_service.update_ori_measure(conn,cursor,file_id)
@@ -991,14 +987,13 @@ def merge_consecutive_arrays_v1(pdf_info):
         merged_objects.append(temp_array)
     return merged_objects
 
-def start_table_measure_job(file_id):
+def start_table_measure_job(file_id,partition_name):
     conn_app = mysql.connector.connect(
         host = MYSQL_HOST_APP,
         user = MYSQL_USER_APP,
         password = MYSQL_PASSWORD_APP,
         database = MYSQL_DB_APP
     )
     # Create a cursor to execute SQL statements
     cursor_app = conn_app.cursor(buffered=True)
@@ -1024,8 +1019,11 @@ def start_table_measure_job(file_id):
     records_range_parts = utils.get_range(len(pdf_tables),MEASURE_COUNT)
     print(f'records_range_parts page ranges: {records_range_parts}')
     processes = []
+
     for record_range in records_range_parts:
-        p = Process(target=get_table_measure, args=(file_id,pdf_tables,record_range,))
+        p = Process(target=get_table_measure, args=(file_id,pdf_tables,record_range,partition_name,))
         processes.append(p)
         p.start()
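Because the range workers are separate processes, the partition name travels as a plain Process argument, and each worker now opens (and closes) its own MilvusClient. gRPC connections generally do not survive fork(), which is presumably why this commit builds the client inside the workers instead of sharing a loaded Partition object across them. A minimal sketch of the pattern, with hypothetical ranges and hour key:

    from multiprocessing import Process
    from pymilvus import MilvusClient
    from config import MILVUS_CLIENT

    def worker(record_range, partition_name):
        # One client per forked worker, mirroring the hunks above
        client = MilvusClient(uri=MILVUS_CLIENT)
        try:
            print(f"range {record_range} -> {partition_name}")
            # ... search/insert scoped to partition_name, as in db_service.py ...
        finally:
            client.close()

    if __name__ == "__main__":
        partition_name = "partition_2024120312"  # hypothetical hour key
        processes = [Process(target=worker, args=(r, partition_name)) for r in ["0-10", "10-20"]]
        for p in processes:
            p.start()
        for p in processes:
            p.join()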