pdf_code/zzb_data_prod/test.py

155 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#coding=utf-8
import sys,ast
from pdfminer.high_level import extract_text
from pdfminer.pdfparser import PDFParser
from pdfminer.pdfdocument import PDFDocument
from pdfminer.pdfpage import PDFPage
import utils
import mysql.connector
from pymilvus import connections,MilvusClient
import json
import db_service
import ast
import numpy as np
import config
import redis_service
from config import MILVUS_CLIENT,MYSQL_HOST,MYSQL_USER,MYSQL_PASSWORD,MYSQL_DB
import main
import redis
def measure_config_to_db(conn, cursor,
                         path='/Users/zhengfei/work/zzb_data/measure_config_all.txt'):
    """Load measure-name mappings from a text file into ``measure_config``.

    Each usable line of the file holds at least two comma-separated fields:
    the measure name followed by the original measure name. Any further
    fields on the line are ignored. The two ID columns are filled with
    ``utils.get_md5`` of the respective name.

    No check is made against rows that already exist in the table.

    Args:
        conn: Open database connection; committed once after all rows
            have been inserted.
        cursor: Cursor on ``conn`` used to run the INSERT statements.
        path: Location of the mapping file. Defaults to the path that was
            previously hard-coded, so existing callers are unaffected.
    """
    insert_query = '''
        INSERT INTO measure_config
        (measure_id, measure_name, ori_measure_id, ori_measure_name)
        VALUES (%s, %s, %s, %s)
    '''
    # Explicit encoding: the names are Chinese text, and the platform default
    # encoding is not guaranteed to be UTF-8.
    with open(path, 'r', encoding='utf-8') as file:
        for line in file:
            config_list = line.strip().split(',')
            if len(config_list) < 2:
                # Blank or malformed line: there is no original name to map,
                # and indexing field 1 would raise IndexError.
                continue
            measure, ori_measure = config_list[0], config_list[1]
            data_to_insert = (
                utils.get_md5(measure),
                measure,
                utils.get_md5(ori_measure),
                ori_measure,
            )
            cursor.execute(insert_query, data_to_insert)
    conn.commit()
def insert_measure_vector(conn, cursor):
    """Store an embedding in Redis for every ``measure_config`` row lacking one.

    Reads ``(ori_measure_id, ori_measure_name)`` from ``measure_config``.
    For each ID not yet present as a field of the Redis hash
    ``measure_config``, the name is embedded with ``utils.embed_with_str``
    and the embedding is stored under that ID as its ``str()`` form.

    Args:
        conn: Open database connection. It is closed before this function
            returns (also on error), so the caller must not reuse it.
        cursor: Cursor on ``conn`` used to read ``measure_config``.
    """
    # NOTE(review): the Redis host and password are hard-coded here; consider
    # moving them next to the MySQL settings in the project config.
    redis_client = redis.Redis(host='192.168.0.172', port=6379, password='Xgf_redis', db=6)
    try:
        select_query = '''
            SELECT ori_measure_id,ori_measure_name FROM measure_config
        '''
        cursor.execute(select_query)
        records = cursor.fetchall()
        for record in records:
            ori_measure_id, ori_measure_name = record[0], record[1]
            if redis_client.hexists('measure_config', ori_measure_id):
                # Already cached; the stored value is not needed here.
                continue
            print('新增指标', ori_measure_name)
            vector_obj = utils.embed_with_str(ori_measure_name)
            measure_vector = str(vector_obj.output["embeddings"][0]["embedding"])
            redis_client.hset('measure_config', ori_measure_id, measure_vector)
    finally:
        # Release both resources even if the query, the embedding call or a
        # Redis write fails part-way through.
        try:
            redis_client.close()
        finally:
            conn.close()
def contains_financial_indicators(text):
    """Return True if ``text`` contains a financial-style number.

    Three forms are recognised anywhere in the text:

    * a thousands-separated number such as ``1,234,567.89``;
    * a decimal number such as ``12.5`` (optionally followed by ``%``);
    * an integer percentage such as ``15%``.

    The previous pattern used an unescaped ``.``, which matches any
    character, so strings like ``1a2`` or a bare year like ``2023`` were
    reported as indicators. The separator characters are now matched
    literally.

    Args:
        text: The string to inspect.

    Returns:
        bool: True when at least one of the forms above occurs in ``text``.
    """
    import re
    pattern = (
        r"\d{1,3}(?:,\d{3})+"   # thousands-separated number
        r"|\d+\.\d+"            # decimal number (a trailing % is irrelevant to search)
        r"|\d+%"                # integer percentage
    )
    return re.search(pattern, text) is not None
def get_clean_text(text):
    """Clean parenthesised annotations out of ``text``, print and return it.

    Every parenthesised span is examined:

    * a span that does not contain one of the keywords ``归属于`` / ``扣非``
      is removed from the text together with its parentheses;
    * if any span does contain a keyword, all characters that are neither
      word characters nor whitespace are then stripped from the whole
      text, which drops the parentheses but keeps the span's content.

    Keyword-less spans are removed first so the outcome does not depend on
    the order in which the spans appear.

    NOTE(review): the pattern literal in the original was garbled (it was
    not a valid string literal). It has been reconstructed to match spans
    delimited by ASCII or full-width parentheses - confirm against the
    intended input.

    Args:
        text: The string to clean.

    Returns:
        str: The cleaned text (the same value that is printed).
    """
    import re
    pattern = r"[(（][^)）]*?[)）]"
    keyword_span_found = False
    for match in re.findall(pattern, text):
        if re.search(r"归属于|扣非", match):
            keyword_span_found = True
        else:
            # Remove only this span, not every parenthesised span.
            text = text.replace(match, "")
    if keyword_span_found:
        text = re.sub(r"[^\w\s]", "", text)
    print(text)
    return text
def insert_and_update(conn, cursor, client, parent_table_pages, file_id, path):
    """Run the two measure-processing steps for one file, in order.

    Step 1 hands every argument to
    ``db_service.insert_table_measure_from_vector`` (look up measures by
    vector). Step 2 calls ``db_service.update_ori_measure`` for the same
    ``file_id`` (measure normalisation).
    """
    # Step 1: look up measures by vector.
    db_service.insert_table_measure_from_vector(
        conn,
        cursor,
        client,
        parent_table_pages,
        file_id,
        path,
    )

    # Step 2: measure normalisation.
    db_service.update_ori_measure(conn, cursor, file_id)
def print_measure_data(cursor, client, file_id='64',
                       collection_name="pdf_measure_v4", limit=2):
    """Print the closest Milvus hits for measures not yet linked to a file.

    Selects every ``measure_config`` row whose ``measure_id`` does not
    appear in ``ori_measure_list`` for ``file_id``. For each such row the
    stored vector is read through ``redis_service.read_from_redis`` and
    used for a cosine search restricted to the same ``file_id``. One
    colon-separated line is printed per hit::

        measure_name:ori_measure_name:hit_measure_name:distance:value:table_num:table_index

    Args:
        cursor: Database cursor used to read ``measure_config``.
        client: Milvus client exposing ``search``.
        file_id: File identifier used in both the SQL sub-query and the
            Milvus filter. Defaults to the previously hard-coded ``'64'``.
        collection_name: Milvus collection to search. Defaults to the
            previously hard-coded name.
        limit: Maximum number of hits per measure. Defaults to the
            previously hard-coded ``2``.
    """
    select_query = '''
        SELECT ori_measure_name,measure_name,ori_measure_id FROM measure_config
        where measure_id not in(select distinct measure_id from ori_measure_list where file_id=%s)
    '''
    cursor.execute(select_query, (file_id,))
    records = cursor.fetchall()
    search_filter = 'file_id == "{}"'.format(file_id)
    for record in records:
        ori_measure_name, measure_name, ori_measure_id = record[0], record[1], record[2]
        # Assumes read_from_redis returns the vector as a Python-literal
        # string (e.g. "[0.1, 0.2, ...]") - TODO confirm against redis_service.
        measure_vector = redis_service.read_from_redis(ori_measure_id)
        measure_list = ast.literal_eval(measure_vector)
        res = client.search(
            collection_name=collection_name,
            data=[measure_list],
            limit=limit,
            search_params={"metric_type": "COSINE", "params": {}},
            output_fields=["measure_name", "measure_value", "table_num", "table_index"],
            filter=search_filter,
        )
        vector_str = measure_name + ":" + ori_measure_name
        # A single query vector was sent, so only res[0] carries hits.
        for hit in res[0]:
            entity = hit["entity"]
            print(":".join([
                vector_str,
                entity["measure_name"],
                str(float(hit["distance"])),
                entity["measure_value"],
                str(entity["table_num"]),
                str(entity["table_index"]),
            ]))
if __name__ == "__main__":
    # Script entry point: connect to MySQL with the settings imported from
    # config, then populate the measure vectors.
    _mysql_settings = {
        "host": MYSQL_HOST,
        "user": MYSQL_USER,
        "password": MYSQL_PASSWORD,
        "database": MYSQL_DB,
    }
    conn = mysql.connector.connect(**_mysql_settings)
    cursor = conn.cursor()
    # insert_measure_vector closes conn itself, so no close is issued here.
    insert_measure_vector(conn, cursor)