pdf_code/zzb_data_word/test/zzb.py

356 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import camelot
import re
import os
import json
import numpy as np
from datetime import datetime
# 读取PDF
import PyPDF2
# 分析PDF的layout提取文本
from pdfminer.high_level import extract_pages
from pdfminer.layout import LTTextContainer, LTRect
import pdfplumber
import mysql.connector
# 数据处理流程
# 1. 解析pdf标题获取多级标题名称及页码范围
# 2. 基于规则选出需要解析的标题下内容
# 3. 根据需要解析的标题及页码,获取所有表格内容,并转化成带语义的指标
# 4. 根据需要解析的标题及页码,获取所有非表格类的正文
# 5. 文本和表格指标调用大模型抽取原始指标
# 6. 根据规则讲原始指标转化为最终显示指标
STR_PATTERN = '营业收入|净利润|变动比例|损益|现金流量净额|现金流|每股收益|总资产|资产总额|收益率'
def get_md5(str):
import hashlib
m = hashlib.md5()
m.update(str.encode('utf-8'))
return m.hexdigest()
#获取指标的表头信息
def get_num_info(array,row_num,col_num,x,y):
num_info=""
for j in range(col_num):
if len(str(array[x][j])) > 50:
continue
num_info += str(array[x][j])
for i in range(row_num):
if len(str(array[i][y])) > 50:
continue
num_info += str(array[i][y])
return num_info.replace('%','')
def get_parse_pages(page_dict):
"""
:return: 返回一个存储需要解析的页码文本
"""
return "all"
# 读取pdf文件中文本内容不包括表格
def get_text_content(pdf_path):
"""
:return: 返回pdf文件中文本内容不包括表格
"""
page_obj = []
# 我们从PDF中提取页面
for pagenum, page in enumerate(extract_pages(pdf_path)):
page_text = ''
text_obj = {}
# 初始化检查表的数量
table_num = 0
first_element= True
table_extraction_flag= False
# # 打开pdf文件
pdf = pdfplumber.open(pdf_path)
# 查找已检查的页面
page_tables = pdf.pages[pagenum]
# 找出本页上的表格数目
tables = page_tables.find_tables()
# 找到所有的元素
page_elements = [(element.y1, element) for element in page._objs]
# 对页面中出现的所有元素进行排序
page_elements.sort(key=lambda a: a[0], reverse=True)
# 查找组成页面的元素
for i,component in enumerate(page_elements):
# 提取PDF中元素顶部的位置
pos= component[0]
# 提取页面布局的元素
element = component[1]
# 检查该元素是否为文本元素
if isinstance(element, LTTextContainer):
# 检查文本是否出现在表中
if table_extraction_flag == False:
# 使用该函数提取每个文本元素的文本和格式
line_text = element.get_text().replace('\s+', '').replace('\n', '').replace('\r', '')
# 将每行的文本追加到页文本
if len(line_text) > 5:
page_text += line_text
# 附加每一行包含文本的格式
else:
# 省略表中出现的文本
pass
# 检查表的元素
if isinstance(element, LTRect):
# 如果第一个矩形元素
if first_element == True and (table_num+1) <= len(tables):
# 找到表格的边界框
lower_side = page.bbox[3] - tables[table_num].bbox[3]
upper_side = element.y1
# 将标志设置为True以再次避免该内容
table_extraction_flag = True
# 让它成为另一个元素
first_element = False
# 检查我们是否已经从页面中提取了表
if element.y0 >= lower_side and element.y1 <= upper_side:
pass
elif not isinstance(page_elements[i+1][1], LTRect):
table_extraction_flag = False
first_element = True
table_num+=1
text_obj['page_num'] = pagenum
text_obj['text'] = page_text
page_obj.append(text_obj)
# 打印提取的文本
# print(page_obj)
return page_obj
# 读取pdf中的表格,并将表格中指标和表头合并eg: 2022年1季度营业收入为xxxxx
def get_table_measure(file_path, page_num="all"):
"""
:return: pdf中的表格,并将表格中指标和表头合并eg: 2022年1季度营业收入为xxxxx
"""
measure_obj = []
tables = camelot.read_pdf(file_path, pages=page_num, strip_text=' ,\n', copy_text=['h'])
for t in tables:
data_dict = {}
measure_list = []
arr = np.array(t.data)
rows, cols = arr.shape
if rows == 1 and cols == 1:
continue
arr_str = ''.join([''.join(map(str, row)) for row in arr])
matches = re.findall(STR_PATTERN, arr_str)
if len(matches) > 0:
arr = np.array(t.data)
rows, cols = arr.shape
row_num , col_num = -1 , -1
# 使用嵌套循环遍历数组,获取第一个数值位置
for i in range(rows):
for j in range(cols):
if re.match(r'^[+-]?(\d+(\.\d*)?|\.\d+)(%?)$', str(arr[i, j])):
if j == cols-1:
row_num , col_num = i , j
break
elif (re.match(r'^[+-]?(\d+(\.\d*)?|\.\d+)(%?)$', str(arr[i, j+1]))
or str(arr[i, j+1]) == '-'):
row_num , col_num = i , j
break
else:
continue
break
# 遍历数值二维数组,转成带语义的指标
if row_num != -1 and col_num != -1:
for i in range(row_num,arr.shape[0]):
for j in range(col_num,arr.shape[1]):
if arr[i, j] == '-' or arr[i, j] == '' or len(arr[i, j]) > 20:
continue
else:
num_info = get_num_info(arr,row_num,col_num,i,j)
measure_list.append(f"{num_info}{arr[i, j]}")
# print(f"{num_info}为{arr[i, j]}")
else:
pass
if len(measure_list) > 0:
data_dict["measure_list"] = measure_list
data_dict["page_num"] = f"{str(t.page)}_{str(t.order)}"
measure_obj.append(data_dict)
# print(measure_obj)
return measure_obj
# 文本和表格数据给大模型,返回大模型抽取原始指标列表
def get_measure_from_llm(user_prompt):
"""
:return: 文本和表格数据给大模型,返回大模型抽取原始指标列表
"""
import random
from http import HTTPStatus
from dashscope import Generation
llm_measure_list = []
system_prompt = '''
你是一个优秀的金融分析师从给定的数据报告中自动提取以下10个关键财务指标。指标包括
2023年营业收入
2023年合计营业收入
2023年调整后营业收入
2022年营业收入
2022年合计营业收入
2022年调整后营业收入
2023年营业收入变动比例
2023年营业收入比上年同期增减
2023年归属母公司净利润
2023年归属于上市公司股东的净利润
2023年归属母公司净利润变动比例
请确保只抽取这些指标,并且每个指标的输出格式为:指标名:指标值
所有的指标值必须从用户提供的信息中抽取,不允许自己生成,如果找不到相关指标,指标值显示为-
<数据报告>
<user_prompt>
</数据报告>
'''
system_prompt = system_prompt.replace('<user_prompt>', user_prompt)
response = Generation.call(
model='qwen-turbo',
prompt = system_prompt,
seed=random.randint(1, 10000),
top_p=0.1,
result_format='message',
enable_search=False,
max_tokens=1500,
temperature=0.85,
repetition_penalty=1.0
)
if response.status_code == HTTPStatus.OK:
result = response['output']['choices'][0]['message']['content']
llm_measure_list = result.split('\n')
return llm_measure_list
else:
print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
response.request_id, response.status_code,
response.code, response.message
))
return "llm_error"
# 解析大模型抽取的指标,并插入到数据库
def parse_llm_measure_to_db(measure_info,type,conn,cursor):
create_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# 执行SQL语句插入数据
insert_query = '''
INSERT INTO ori_measure_list
(file_id, file_name, type, page_number, table_index, ori_measure_id, ori_measure_name, ori_measure_value, create_time, update_time)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
'''
file_id = '1111111'
file_name = '科润智控.pdf'
llm_measure = measure_info['llm_measure']
page_num = measure_info['page_num']
table_index = '0'
if type == 'table':
table_index = measure_info['table_index']
for measure_obj in llm_measure:
measure_obj = measure_obj.replace('\n', '').replace('\r', '').replace(' ', '')
if ':' in measure_obj:
ori_measure_name = measure_obj.split(':')[0]
ori_measure_value = measure_obj.split(':')[1].replace('+', '')
if '-' in ori_measure_value:
ori_measure_value = "-"
if '.' in ori_measure_name:
ori_measure_name = ori_measure_name.split('.')[1]
ori_measure_id = get_md5(ori_measure_name)
data_to_insert = (file_id, file_name, type, int(page_num), int(table_index), ori_measure_id, ori_measure_name, ori_measure_value, create_time, create_time)
cursor.execute(insert_query, data_to_insert)
print(f"{type},{page_num},{table_index},{ori_measure_name},{ori_measure_value}")
# 提交事务
conn.commit()
return ""
# 根据measure_config中的规则更新原始指标的显示指标
def update_ori_measure(conn,cursor):
# 执行SQL语句更新数据
update_query = '''
UPDATE ori_measure_list
SET measure_id = %s, measure_name = %s
WHERE ori_measure_id = %s and ori_measure_value !='-'
'''
# 执行SQL语句更新数据
select_query = '''
SELECT measure_id,measure_name,ori_measure_id FROM measure_config
'''
cursor.execute(select_query)
records = cursor.fetchall()
for record in records:
data_to_update = (record[0], record[1], record[2])
cursor.execute(update_query, data_to_update)
conn.commit()
if __name__ == "__main__":
start_time = datetime.now()
print("开始时间:", start_time.strftime("%Y-%m-%d %H:%M:%S"))
path = "/Users/zhengfei/Desktop/科润智控1.pdf"
table_info = get_table_measure(path)
# text_info = get_text_content(path)
# # 数据库连接对象
# # 连接到MySQL数据库
# conn = mysql.connector.connect(
# host="121.37.185.246",
# user="financial",
# password="financial_8000",
# database="financial_report"
# )
# 创建一个cursor对象来执行SQL语句
# cursor = conn.cursor()
for table_obj in table_info:
table_measure_obj = {}
table_page_num = table_obj['page_num'].split("_")[0]
table_index = table_obj['page_num'].split("_")[1]
table_measure = ','.join(table_obj['measure_list'])
if table_page_num == '3':
print(f"{table_page_num}页表格指标为:{table_measure}")
table_llm_measure = get_measure_from_llm(table_measure)
if table_page_num == '3':
print(f"{table_page_num}页表格llm指标为{table_llm_measure}")
# table_measure_obj['page_num'] = table_page_num
# table_measure_obj['table_index'] = table_index
# table_measure_obj['llm_measure'] = table_llm_measure
# parse_llm_measure_to_db(table_measure_obj,'table',conn,cursor)
# for text_obj in text_info:
# text_measure_obj = {}
# text_page_num = text_obj['page_num']
# text = text_obj['text']
# if len(text) > 10:
# text_llm_measure = get_measure_from_llm(text)
# text_measure_obj['page_num'] = text_page_num
# text_measure_obj['llm_measure'] = text_llm_measure
# parse_llm_measure_to_db(text_measure_obj,'text',conn,cursor)
# print(text_llm_measure)
# update_ori_measure(conn,cursor)
# cursor.close()
# conn.close()
# measure_info =['1. 2023年营业收入: 983698831.48', '2. 2023年营业收入变动比例: 15.10%', '3. 2023年归属母公司净利润: - (未在报告中找到)', '4. 2023年归属母公司净利润变动比例: - (未在报告中找到)', '5. 2023年毛利率: (营业收入 - 主营业务成本) / 营业收入 = (983698831.48 - 793604607.43) / 983698831.48', '6. 2022年毛利率: (主营业务收入 - 主营业务成本) / 主营业务收入 = (854682261.31 - 690932741.27) / 854682261.31', '7. 2023年主营业务收入: 983698831.48', '8. 2022年主营业务收入: 854682261.31']
# parse_llm_measure_to_db(measure_info)
# get_measure_from_llm()
end_time = datetime.now()
print("结束时间:", end_time.strftime("%Y-%m-%d %H:%M:%S"))
#print(pdf_data)