155 lines
6.1 KiB
Python
155 lines
6.1 KiB
Python
#coding=utf-8
|
||
import sys,ast
|
||
from pdfminer.high_level import extract_text
|
||
from pdfminer.pdfparser import PDFParser
|
||
from pdfminer.pdfdocument import PDFDocument
|
||
from pdfminer.pdfpage import PDFPage
|
||
import utils
|
||
import mysql.connector
|
||
from pymilvus import connections,MilvusClient
|
||
import json
|
||
import db_service
|
||
import ast
|
||
import numpy as np
|
||
import config
|
||
import redis_service
|
||
from config import MILVUS_CLIENT,MYSQL_HOST,MYSQL_USER,MYSQL_PASSWORD,MYSQL_DB
|
||
import main
|
||
import redis
|
||
|
||
def measure_config_to_db(conn, cursor,
                         path='/Users/zhengfei/work/zzb_data/measure_config_all.txt'):
    """Load measure-name mappings from a text file into ``measure_config``.

    Each non-blank line of *path* is expected to be ``measure,ori_measure``
    (any further comma-separated fields are ignored). For every line one row
    ``(md5(measure), measure, md5(ori_measure), ori_measure)`` is inserted,
    with ids computed by ``utils.get_md5``. Existing rows are NOT checked, so
    loading the same file twice inserts the mappings again.

    Args:
        conn: open database connection; committed once after all inserts.
        cursor: cursor belonging to *conn*.
        path: mapping file, read as UTF-8. Defaults to the previously
            hard-coded location so existing callers keep working.

    Raises:
        ValueError: a non-blank line has fewer than two comma-separated fields.
    """
    insert_query = '''
    INSERT INTO measure_config
    (measure_id, measure_name, ori_measure_id, ori_measure_name)
    VALUES (%s, %s, %s, %s)
    '''

    rows = []
    # The file holds Chinese measure names; don't rely on the platform default encoding.
    with open(path, 'r', encoding='utf-8') as file:
        for line_no, line in enumerate(file, start=1):
            stripped = line.strip()
            if not stripped:
                # Blank lines would otherwise fail with IndexError on fields[1].
                continue
            fields = stripped.split(',')
            if len(fields) < 2:
                raise ValueError(
                    f"{path}:{line_no}: expected 'measure,ori_measure', got {stripped!r}"
                )
            measure, ori_measure = fields[0], fields[1]
            rows.append(
                (utils.get_md5(measure), measure, utils.get_md5(ori_measure), ori_measure)
            )

    if rows:
        # One batched round-trip instead of one INSERT per line.
        cursor.executemany(insert_query, rows)
    conn.commit()
|
||
|
||
def insert_measure_vector(conn, cursor, host='192.168.0.172', port=6379,
                          password='Xgf_redis', db=6):
    """Make sure every ``measure_config`` entry has an embedding cached in Redis.

    Reads ``(ori_measure_id, ori_measure_name)`` for all rows of
    ``measure_config``. For each id that is not yet a field of the Redis hash
    ``measure_config``, the name is embedded via ``utils.embed_with_str`` and
    the stringified embedding list is stored under that id.

    Both the Redis client and *conn* are closed before returning, even on error.

    Args:
        conn: database connection; closed by this function.
        cursor: cursor belonging to *conn*.
        host, port, password, db: Redis connection settings. The defaults are
            the previously hard-coded values.
            NOTE(review): credentials should come from config, not source.
    """
    redis_client = redis.Redis(host=host, port=port, password=password, db=db)
    try:
        select_query = '''
        SELECT ori_measure_id,ori_measure_name FROM measure_config
        '''
        cursor.execute(select_query)
        records = cursor.fetchall()

        for ori_measure_id, ori_measure_name in records:
            # Already cached: skip. Previously the cached value was read back
            # and rewritten unchanged, costing two extra round-trips.
            if redis_client.hexists('measure_config', ori_measure_id):
                continue
            print('新增指标', ori_measure_name)
            vector_obj = utils.embed_with_str(ori_measure_name)
            measure_vector = str(vector_obj.output["embeddings"][0]["embedding"])
            redis_client.hset('measure_config', ori_measure_id, measure_vector)
    finally:
        # Release both connections even if the query or embedding call fails.
        redis_client.close()
        conn.close()
|
||
|
||
def contains_financial_indicators(text):
    """Return True if *text* contains a number formatted like a financial figure.

    Recognised forms:

    * thousands-separated numbers, e.g. ``1,234`` or ``12,345,678.90``;
    * decimal numbers, optionally followed by ``%``, e.g. ``0.35`` or ``12.5%``;
    * integer percentages, e.g. ``30%``.

    Plain integers such as years (``2023``) do not count. The previous
    pattern used an unescaped ``.``, so it matched any three-digit run and
    strings like ``12a34``.
    """
    import re

    pattern = (
        r"\d{1,3}(?:,\d{3})+(?:\.\d+)?"  # thousands separators
        r"|\d+(?:\.\d+)+%?"              # decimals, optional percent sign
        r"|\d+%"                         # integer percentages
    )
    return re.search(pattern, text) is not None
|
||
|
||
def get_clean_text(text):
    """Strip parenthesised qualifiers from a measure name, print and return it.

    Every ASCII-parenthesised segment ``(...)`` is examined on its own:

    * a segment that does NOT mention ``归属于`` or ``扣非`` is deleted;
    * a segment that does mention one is kept.

    If at least one segment was kept, all punctuation (anything that is
    neither a word character nor whitespace) is then removed from the whole
    text. This drops the kept segments' parentheses while leaving their
    wording in place.

    Previously the no-keyword branch ran ``re.sub`` over the whole text,
    deleting the keyword segments too. The result therefore depended on
    segment order.

    Args:
        text: the string to clean.

    Returns:
        The cleaned text; it is also printed, as before.
    """
    import re

    segment_pattern = r"\([^)]*?\)"
    keyword_pattern = r"归属于|扣非"

    segments = re.findall(segment_pattern, text)
    keeps_keyword_segment = any(re.search(keyword_pattern, seg) for seg in segments)

    # Remove only the segments without keywords; keyword segments survive intact.
    text = re.sub(
        segment_pattern,
        lambda m: m.group(0) if re.search(keyword_pattern, m.group(0)) else "",
        text,
    )
    if keeps_keyword_segment:
        # Drop all punctuation (incl. the kept segments' parentheses).
        text = re.sub(r"[^\w\s]", "", text)

    print(text)
    return text
|
||
|
||
def insert_and_update(conn, cursor, client, parent_table_pages, file_id, path):
    """Run the two-stage measure pipeline for a single file.

    Stage 1 calls ``db_service.insert_table_measure_from_vector`` (look up
    measures via vector search). Stage 2 calls ``db_service.update_ori_measure``
    (normalise the original measures of *file_id*). The stages run in this
    order and return nothing.
    """
    service = db_service

    # Stage 1: find measures through vector similarity.
    service.insert_table_measure_from_vector(
        conn, cursor, client, parent_table_pages, file_id, path
    )

    # Stage 2: normalise the measures that were just inserted.
    service.update_ori_measure(conn, cursor, file_id)
|
||
|
||
def print_measure_data(cursor, client, file_id='64',
                       collection_name="pdf_measure_v4", limit=2):
    """Print the nearest vector-store hits for measures not yet extracted.

    Selects every ``measure_config`` row whose ``measure_id`` does not appear
    in ``ori_measure_list`` for *file_id*. For each row:

    1. Read the cached embedding via ``redis_service.read_from_redis``.
    2. Run a COSINE search in *collection_name*, restricted to *file_id*.
    3. Print one line per hit, in this format:
       ``measure_name:ori_measure_name:hit_measure_name:distance:measure_value:table_num:table_index``

    Rows with no cached embedding are reported and skipped.

    Args:
        cursor: database cursor.
        client: Milvus client exposing ``search``.
        file_id: file to inspect. Defaults to the previously hard-coded ``'64'``.
        collection_name: vector collection to search.
        limit: maximum hits printed per measure.
    """
    # NULLs are excluded inside the subquery: a single NULL measure_id makes
    # ``NOT IN`` evaluate to UNKNOWN for every row and return nothing.
    select_query = '''
    SELECT ori_measure_name,measure_name,ori_measure_id FROM measure_config
    where measure_id not in(select distinct measure_id from ori_measure_list
                            where file_id=%s and measure_id is not null)
    '''
    cursor.execute(select_query, (file_id,))
    records = cursor.fetchall()

    for ori_measure_name, measure_name, ori_measure_id in records:
        measure_vector = redis_service.read_from_redis(ori_measure_id)
        if not measure_vector:
            print(f"no cached vector for measure {ori_measure_id}, skipped")
            continue
        # The vector is stored as the str() of a Python list; literal_eval parses it safely.
        query_vector = ast.literal_eval(measure_vector)

        res = client.search(
            collection_name=collection_name,
            data=[query_vector],
            limit=limit,
            search_params={"metric_type": "COSINE", "params": {}},
            output_fields=["measure_name", "measure_value", "table_num", "table_index"],
            filter=f'file_id == "{file_id}"',
        )

        vector_str = f"{measure_name}:{ori_measure_name}"
        for hit in res[0]:
            entity = hit["entity"]
            print(
                f"{vector_str}:{entity['measure_name']}:{float(hit['distance'])}"
                f":{entity['measure_value']}:{entity['table_num']}:{entity['table_index']}"
            )
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: open a MySQL connection using the settings from
    # config and cache any missing measure embeddings in Redis.
    # insert_measure_vector closes the connection itself when it finishes.
    db_settings = {
        "host": MYSQL_HOST,
        "user": MYSQL_USER,
        "password": MYSQL_PASSWORD,
        "database": MYSQL_DB,
    }
    conn = mysql.connector.connect(**db_settings)
    cursor = conn.cursor()

    insert_measure_vector(conn, cursor)
|