pdf_code/zzb_data_prod/test.py

155 lines
6.1 KiB
Python
Raw Normal View History

2024-11-29 15:58:06 +08:00
#coding=utf-8
import sys,ast
from pdfminer.high_level import extract_text
from pdfminer.pdfparser import PDFParser
from pdfminer.pdfdocument import PDFDocument
from pdfminer.pdfpage import PDFPage
import utils
import mysql.connector
from pymilvus import connections,MilvusClient
import json
import db_service
import ast
import numpy as np
import config
import redis_service
from config import MILVUS_CLIENT,MYSQL_HOST,MYSQL_USER,MYSQL_PASSWORD,MYSQL_DB
import main
import redis
def measure_config_to_db(conn, cursor, path='/Users/zhengfei/work/zzb_data/measure_config_all.txt'):
    """Load measure-name mappings from a text file into the ``measure_config`` table.

    Each line of *path* is read as ``measure_name,ori_measure_name``. Both
    names are hashed with ``utils.get_md5`` to build their ids, and one row
    per line is inserted. All inserts are committed together at the end.

    Args:
        conn: Open DB connection; used only for the final ``commit()``.
        cursor: Cursor on ``conn`` used to execute the inserts.
        path: Mapping file to read. Defaults to the previously hard-coded
            location so existing callers keep working.

    Lines that are blank are skipped silently. Lines without at least two
    comma-separated fields are reported and skipped, instead of aborting
    the whole import with ``IndexError``.
    """
    insert_query = '''
    INSERT INTO measure_config
    (measure_id, measure_name, ori_measure_id, ori_measure_name)
    VALUES (%s, %s, %s, %s)
    '''
    # The file holds Chinese measure names. Decode explicitly instead of
    # relying on the platform default. "utf-8-sig" also drops a leading BOM
    # that would otherwise end up inside the first measure name (and its md5).
    with open(path, 'r', encoding='utf-8-sig') as file:
        for line_no, line in enumerate(file, start=1):
            stripped = line.strip()
            if not stripped:
                continue
            config_list = stripped.split(',')
            if len(config_list) < 2:
                print(f'skip malformed line {line_no}: {stripped}')
                continue
            measure, ori_measure = config_list[0], config_list[1]
            data_to_insert = (
                utils.get_md5(measure),
                measure,
                utils.get_md5(ori_measure),
                ori_measure,
            )
            cursor.execute(insert_query, data_to_insert)
    conn.commit()
def insert_measure_vector(conn, cursor):
    """Make sure every measure in ``measure_config`` has an embedding cached in Redis.

    For each ``(ori_measure_id, ori_measure_name)`` row, the Redis hash
    ``measure_config`` is checked for the id. Missing entries are embedded via
    ``utils.embed_with_str`` and stored as the ``str()`` of the embedding list.

    Both the Redis client and *conn* are closed when this function finishes,
    including when an exception is raised part-way through.
    """
    # NOTE(review): Redis host/password are hard-coded here; consider moving
    # them into ``config`` alongside the MySQL settings.
    redis_client = redis.Redis(host='192.168.0.172', port=6379, password='Xgf_redis', db=6)
    select_query = '''
    SELECT ori_measure_id,ori_measure_name FROM measure_config
    '''
    try:
        cursor.execute(select_query)
        for ori_measure_id, ori_measure_name in cursor.fetchall():
            # Already cached: nothing to do. The old code fetched the value
            # with hget here but never used it.
            if redis_client.hexists('measure_config', ori_measure_id):
                continue
            print('新增指标', ori_measure_name)
            vector_obj = utils.embed_with_str(ori_measure_name)
            measure_vector = str(vector_obj.output["embeddings"][0]["embedding"])
            redis_client.hset('measure_config', ori_measure_id, measure_vector)
    finally:
        redis_client.close()
        conn.close()
def contains_financial_indicators(text):
    """Return True if *text* contains a number formatted like a financial figure.

    Recognised forms:

    * thousands-separated numbers, optionally with decimals and ``%``:
      ``1,234`` / ``1,234.56``
    * decimal numbers, optionally followed by ``%``: ``12.5`` / ``12.5%``
    * integer percentages: ``15%``

    Bare integers such as ``2024`` and separators like ``1-2`` do not count.
    """
    import re
    # The decimal point must be escaped. An unescaped "." matches any
    # character, which previously made "2024" or "1-2" look like indicators.
    thousands = r"\d{1,3}(?:,\d{3})+(?:\.\d+)?%?"
    decimal = r"\d+(?:\.\d+)+%?"
    percent = r"\d+%"
    return re.search(f"{thousands}|{decimal}|{percent}", text) is not None
def get_clean_text(text):
    """Strip bracketed qualifiers from a measure name, print it and return it.

    Bracketed groups use half-width ``()`` or full-width ``（）`` brackets and
    are not nested. They are handled as follows:

    * a group that contains neither "归属于" nor "扣非" is removed entirely;
    * if at least one group contains one of those keywords, the group is
      kept and all punctuation is removed from the resulting text, so only
      word characters and whitespace remain.

    Each group is decided independently, so the result no longer depends on
    the order in which the groups appear.
    """
    import re
    # The original pattern literal was corrupted into an unterminated string.
    # Match one non-nested bracketed group in either bracket width.
    paren_pattern = r"[(（][^()（）]*[)）]"
    keyword_pattern = r"归属于|扣非"

    has_keyword_group = False
    for group in re.findall(paren_pattern, text):
        if re.search(keyword_pattern, group):
            has_keyword_group = True
        else:
            # Remove only this group. Substituting the whole pattern here
            # would also delete groups that carry the keywords.
            text = text.replace(group, "", 1)
    if has_keyword_group:
        text = re.sub(r"[^\w\s]", "", text)
    print(text)
    return text
def insert_and_update(conn, cursor, client, parent_table_pages, file_id, path):
    """Run the two measure-processing steps for one file, in order.

    First, measures are queried via vectors through
    ``db_service.insert_table_measure_from_vector``. Then the stored
    measures are normalised through ``db_service.update_ori_measure``.
    Returns None.
    """
    # Step 1: vector-based measure lookup for the file's table pages.
    db_service.insert_table_measure_from_vector(
        conn, cursor, client, parent_table_pages, file_id, path
    )
    # Step 2: normalise the measures recorded for this file.
    db_service.update_ori_measure(conn, cursor, file_id)
def print_measure_data(cursor, client, file_id='64', collection_name="pdf_measure_v4", limit=2):
    """Print the nearest vector hits for measures not yet extracted from a file.

    Selects rows of ``measure_config`` whose ``measure_id`` has no row in
    ``ori_measure_list`` for *file_id*. For each row, the cached embedding is
    read with ``redis_service.read_from_redis``. The Milvus collection
    *collection_name* is then searched (COSINE), restricted to *file_id*,
    and one line per hit is printed:

        measure_name:ori_measure_name:hit_name:distance:value:table_num:table_index

    Args:
        cursor: DB cursor used for the SELECT.
        client: Milvus client exposing ``search``.
        file_id: File to inspect. Defaults to the previously hard-coded '64'.
        collection_name: Milvus collection to search.
        limit: Maximum number of hits printed per measure.

    Measures with no cached vector are reported and skipped.
    """
    select_query = '''
    SELECT ori_measure_name,measure_name,ori_measure_id FROM measure_config
    where measure_id not in(select distinct measure_id from ori_measure_list where file_id=%s)
    '''
    cursor.execute(select_query, (file_id,))
    for ori_measure_name, measure_name, ori_measure_id in cursor.fetchall():
        measure_vector = redis_service.read_from_redis(ori_measure_id)
        # NOTE(review): assumes a falsy return means "not cached"; confirm
        # against redis_service.read_from_redis.
        if not measure_vector:
            print(f"{measure_name}:{ori_measure_name}: no cached vector, skipped")
            continue
        res = client.search(
            collection_name=collection_name,
            data=[ast.literal_eval(measure_vector)],
            limit=limit,
            search_params={"metric_type": "COSINE", "params": {}},
            output_fields=["measure_name", "measure_value", "table_num", "table_index"],
            filter=f'file_id == "{file_id}"',
        )
        vector_str = f"{measure_name}:{ori_measure_name}"
        for hit in res[0]:
            entity = hit["entity"]
            # The f-string formats non-str values too; "+" concatenation
            # raised TypeError for a numeric measure_value.
            print(
                f"{vector_str}:{entity['measure_name']}:{float(hit['distance'])}"
                f":{entity['measure_value']}:{entity['table_num']}:{entity['table_index']}"
            )
if __name__ == "__main__":
    # Connect to MySQL with the settings imported from config, then refresh
    # the Redis embedding cache. insert_measure_vector() closes the
    # connection itself when it is done.
    mysql_settings = {
        "host": MYSQL_HOST,
        "user": MYSQL_USER,
        "password": MYSQL_PASSWORD,
        "database": MYSQL_DB,
    }
    conn = mysql.connector.connect(**mysql_settings)
    cursor = conn.cursor()
    insert_measure_vector(conn, cursor)