pdf_code/zzb_data_word/test_db.py

197 lines
8.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from datetime import datetime
import re,os,json
import utils
import ast
import time
import redis_service
from multiprocessing import Process
from config_p import MILVUS_CLIENT,MYSQL_HOST,MYSQL_USER,MYSQL_PASSWORD,MYSQL_DB,REDIS_HOST,REDIS_PORT,REDIS_PASSWORD,MEASURE_COUNT
from pymilvus import MilvusClient
import mysql.connector
import threading
import redis
# Chinese substrings of financial measure names. Not referenced by the code
# visible in this file -- presumably consumed by other modules; verify before
# removing.
measure_name_keywords = [
    "营业",
    "季度",
    "利润",
    "归属于",
    "扣非",
    "经营",
    "现金",
    "活动",
    "损益",
    "收益",
    "资产",
    "费用",
    "销售",
    "管理",
    "财务",
    "研发",
]
def insert_table_from_vector_mul_process(file_id):
    """Print vector-search candidates for the half-year measures of one report.

    Steps, in order:

    1. Read the report year for ``file_id`` from ``report_check``.
    2. Load the ``measure_config_half_year`` rows of that year whose
       ``measure_name`` matches a fixed list of keywords.
    3. Load the ``"<page>_<table index>"`` keys stored in
       ``measure_parser_info`` (``type='table_index'``) for this file; tables
       listed there are treated as blacklisted.
    4. For every configured measure, read its stored vector from Redis
       (db 6), run a COSINE search (top 5) against the ``pdf_measure_v4``
       Milvus collection restricted to this file, and print every hit that
       survives the filters as one comma-separated line.

    The function itself only issues SELECT statements, Redis reads and Milvus
    searches; results are printed, not stored.

    Args:
        file_id: Id of the report (``report_check.id``). It is also embedded
            as a quoted string in the Milvus filter expression.

    Returns:
        None.

    Raises:
        IndexError: If ``report_check`` has no row for ``file_id``.
        Exception: Connection errors and errors from the three initial SQL
            queries propagate to the caller. Errors raised inside the search
            loop are printed and stop the loop without being re-raised.
    """
    # Function-scope import so this block does not depend on a change to the
    # module's import section.
    from contextlib import ExitStack

    # ExitStack runs the registered close callbacks in reverse order even if
    # setup or one of the closes fails, so no connection is leaked.
    # Resulting close order: redis_client, cursor, conn, client.
    with ExitStack() as cleanup:
        client = MilvusClient(uri=MILVUS_CLIENT)
        cleanup.callback(client.close)

        conn = mysql.connector.connect(
            host=MYSQL_HOST,
            user=MYSQL_USER,
            password=MYSQL_PASSWORD,
            database=MYSQL_DB,
        )
        cleanup.callback(conn.close)

        # Buffered cursor used for every SQL statement below.
        cursor = conn.cursor(buffered=True)
        cleanup.callback(cursor.close)

        redis_client = redis.Redis(
            host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=6
        )
        cleanup.callback(redis_client.close)

        # Values are passed as query parameters instead of being formatted
        # into the SQL text.
        cursor.execute(
            "select report_type,year from report_check where id = %s",
            (file_id,),
        )
        report_rows = cursor.fetchall()
        if not report_rows:
            raise IndexError(f"no report_check row found for id {file_id!r}")
        report_year = report_rows[0][1]

        # str(): the original compared against a quoted string literal.
        cursor.execute(
            '''
            SELECT ori_measure_name,measure_name,distance,ori_measure_id,measure_id FROM measure_config_half_year
            where year = %s
            and measure_name rlike '货币资金|应收账款|存货|固定资产|在建工程|商誉|短期借款|应付账款|合同负债|长期借款|营业成本'
            ''',
            (str(report_year),),
        )
        records = cursor.fetchall()

        # Page number / table index keys of tables whose preceding text
        # matched a blacklist keyword. Kept as a list (not a set) because the
        # values come straight from the driver and may not be hashable.
        cursor.execute(
            '''
            select distinct content from measure_parser_info WHERE file_id = %s and type='table_index'
            ''',
            (str(file_id),),
        )
        table_index_array = [row[0] for row in cursor.fetchall()]

        # Loop-invariant Milvus filter restricting hits to this file.
        filter_str = f'file_id == "{file_id}"'

        try:
            for ori_measure_name, measure_name, _distance, ori_measure_id, _measure_id in records:
                # The stored vector is a Python-literal string; parse it
                # safely with literal_eval.
                measure_vector = redis_service.read_from_redis(redis_client, ori_measure_id)
                query_vector = ast.literal_eval(measure_vector)

                res = client.search(
                    collection_name="pdf_measure_v4",
                    data=[query_vector],
                    limit=5,  # max number of hits returned per query vector
                    search_params={"metric_type": "COSINE", "params": {}},
                    output_fields=["measure_name", "measure_value", "table_num", "table_index", "measure_unit"],
                    filter=filter_str,
                )

                for hit in res[0]:
                    vector_distance = float(hit["distance"])
                    entity = hit["entity"]
                    pdf_measure = entity["measure_name"]
                    table_num = entity["table_num"]
                    table_index = entity["table_index"]

                    # Skip hits reported on page 0; the cause of such hits
                    # is not yet understood.
                    if table_num == 0:
                        continue
                    # Skip hits located in a blacklisted table.
                    if f"{table_num}_{table_index}" in table_index_array:
                        continue
                    # Skip hits whose measure name contains a blacklist keyword.
                    if utils.check_pdf_measure_black_list(pdf_measure):
                        continue

                    # NOTE: only the three filters above are applied. The
                    # stricter matching rules that were commented out in this
                    # script (a 0.80 distance threshold, the period / season /
                    # non-recurring / percent / growth / change-rate /
                    # report-start flag comparisons from utils, the duplicate
                    # check against ori_measure_list, the measure-pair black
                    # list, and skipping pre-adjustment / pre-restatement
                    # rows) are not applied here.
                    print(f'{measure_name},{ori_measure_name},{pdf_measure},{vector_distance},{table_num},{table_index}')
        except Exception as e:
            # Best-effort debug run: report the first failure and stop.
            print(e)
if __name__ == "__main__":
    # Ad-hoc run against a single hard-coded file id.
    insert_table_from_vector_mul_process("1766")