358 lines
14 KiB
Python
358 lines
14 KiB
Python
|
import camelot
|
|||
|
import re
|
|||
|
import os
|
|||
|
import json
|
|||
|
import numpy as np
|
|||
|
from datetime import datetime
|
|||
|
import logging
|
|||
|
logger = logging.getLogger(__name__)
|
|||
|
# 读取PDF
|
|||
|
import PyPDF2
|
|||
|
# 分析PDF的layout,提取文本
|
|||
|
from pdfminer.high_level import extract_pages
|
|||
|
from pdfminer.layout import LTTextContainer, LTRect
|
|||
|
import pdfplumber
|
|||
|
import mysql.connector
|
|||
|
|
|||
|
# 数据处理流程
|
|||
|
# 1. 解析pdf标题,获取多级标题名称及页码范围
|
|||
|
# 2. 基于规则选出需要解析的标题下内容
|
|||
|
# 3. 根据需要解析的标题及页码,获取所有表格内容,并转化成带语义的指标
|
|||
|
# 4. 根据需要解析的标题及页码,获取所有非表格类的正文
|
|||
|
# 5. 文本和表格指标调用大模型抽取原始指标
|
|||
|
# 6. 根据规则将原始指标转化为最终显示指标
|
|||
|
|
|||
|
# Regex alternation of financial keywords (operating revenue, net profit,
# change ratio, profit and loss, net cash flow, cash flow, earnings per share,
# total assets, rate of return). get_table_measure runs re.findall with it on
# the concatenated cell text of each table and skips tables with no match.
STR_PATTERN = '营业收入|净利润|变动比例|损益|现金流量净额|现金流|每股收益|总资产|资产总额|收益率'
|
|||
|
|
|||
|
def get_md5(str):
    """Return the hexadecimal MD5 digest of a text string.

    The text is encoded as UTF-8 before it is hashed.

    :param str: text to hash. The name shadows the builtin ``str``; it is
        kept as-is so the function signature stays unchanged.
    :return: the digest as a lowercase hexadecimal string.
    """
    import hashlib

    encoded = str.encode('utf-8')
    return hashlib.md5(encoded).hexdigest()
|
|||
|
|
|||
|
#获取指标的表头信息
|
|||
|
def get_num_info(array,row_num,col_num,x,y):
    """Build the header text that describes the table cell at (x, y).

    The result concatenates, in this order:

    1. the cells of row ``x`` in columns ``0 .. col_num - 1``;
    2. the cells of column ``y`` in rows ``0 .. row_num - 1``.

    Every cell is converted with ``str``; cells whose text is longer than
    50 characters are left out. All ``%`` characters are removed from the
    final text.

    :param array: two-dimensional table indexable as ``array[row][col]``.
    :param row_num: number of leading rows read from column ``y``.
    :param col_num: number of leading columns read from row ``x``.
    :param x: row index of the described cell.
    :param y: column index of the described cell.
    :return: the concatenated header text without ``%``.
    """
    max_cell_length = 50

    same_row_cells = [str(array[x][col]) for col in range(col_num)]
    same_col_cells = [str(array[row][y]) for row in range(row_num)]

    kept = [
        text
        for text in same_row_cells + same_col_cells
        if len(text) <= max_cell_length
    ]
    return "".join(kept).replace('%', '')
|
|||
|
|
|||
|
def get_parse_pages(page_dict):
    """Return the page selection that should be parsed.

    ``page_dict`` is not used; the function always selects every page.

    :param page_dict: ignored.
    :return: the string ``"all"``.
    """
    pages_to_parse = "all"
    return pages_to_parse
|
|||
|
|
|||
|
# 读取pdf文件中文本内容,不包括表格
|
|||
|
def get_text_content(pdf_path):
    """Extract the text of every page of a PDF, leaving out table contents.

    pdfminer supplies the layout elements of each page and pdfplumber
    supplies the table bounding boxes. Elements are visited from the top of
    the page to the bottom; text elements met while a table is being skipped
    are dropped, and text fragments of five characters or fewer (after
    whitespace removal) are dropped as well.

    :param pdf_path: path of the PDF file to read.
    :return: a list with one dict per page, in page order:
        ``{'page_num': <zero-based page index>, 'text': <page text>}``.
    """
    page_obj = []

    # Open the pdfplumber document once and close it when done, instead of
    # reopening (and never closing) it for every page.
    with pdfplumber.open(pdf_path) as pdf:
        for pagenum, page in enumerate(extract_pages(pdf_path)):
            text_parts = []
            # Index of the table currently being looked for on this page.
            table_num = 0
            first_element = True
            table_extraction_flag = False
            # Vertical extent of the current table; None until a table has
            # been located on this page, so stale values from a previous
            # page are never used.
            lower_side = None
            upper_side = None

            # Tables detected on the same page by pdfplumber.
            tables = pdf.pages[pagenum].find_tables()

            # Layout elements ordered from the top of the page downwards.
            page_elements = sorted(
                ((element.y1, element) for element in page),
                key=lambda item: item[0],
                reverse=True,
            )
            last_index = len(page_elements) - 1

            for i, (_, element) in enumerate(page_elements):
                if isinstance(element, LTTextContainer):
                    # Text met while a table is being skipped is dropped.
                    if not table_extraction_flag:
                        # Strip all whitespace, including line breaks.
                        line_text = re.sub(r'\s+', '', element.get_text())
                        if len(line_text) > 5:
                            text_parts.append(line_text)

                if isinstance(element, LTRect):
                    # The first rectangle met marks the start of a table.
                    if first_element and (table_num + 1) <= len(tables):
                        lower_side = page.bbox[3] - tables[table_num].bbox[3]
                        upper_side = element.y1
                        # Skip text until the table is left again.
                        table_extraction_flag = True
                        first_element = False

                    if lower_side is None:
                        # No table is known on this page, so there is
                        # nothing to compare this rectangle against.
                        continue

                    if element.y0 >= lower_side and element.y1 <= upper_side:
                        # Rectangle lies inside the current table.
                        continue

                    # The table ends when the next element is not a
                    # rectangle, or when the page has no further element.
                    if i == last_index or not isinstance(page_elements[i + 1][1], LTRect):
                        table_extraction_flag = False
                        first_element = True
                        table_num += 1

            page_obj.append({
                'page_num': pagenum,
                'text': ''.join(text_parts),
            })

    return page_obj
|
|||
|
|
|||
|
# 读取pdf中的表格,并将表格中指标和表头合并,eg: 2022年1季度营业收入为xxxxx
|
|||
|
def get_table_measure(file_path, page_num="all"):
    """Read the tables of a PDF and turn their values into labelled measures.

    Each kept value is rendered as ``"<header text>为<value>"``, where the
    header text comes from ``get_num_info``.

    A table is skipped when it has a single cell, when its concatenated text
    contains no ``STR_PATTERN`` keyword, or when no starting value cell is
    found. Cells equal to ``'-'`` or ``''`` and cells longer than 20
    characters are skipped.

    :param file_path: path of the PDF file.
    :param page_num: page selection passed to ``camelot.read_pdf`` as
        ``pages``.
    :return: a list of dicts ``{"measure_list": [...], "page_num":
        "<page>_<order>"}``, one per table that produced at least one measure.
    """
    numeric_cell = re.compile(r'^[+-]?(\d+(\.\d*)?|\.\d+)(%?)$')

    def first_value_cell(grid):
        """Return (row, col) of the first value cell, or (-1, -1) if none.

        A numeric cell qualifies when it is in the last column, or when the
        cell to its right is numeric or equal to '-'. Cells are scanned row
        by row, left to right.
        """
        n_rows, n_cols = grid.shape
        for r in range(n_rows):
            for c in range(n_cols):
                if not numeric_cell.match(str(grid[r, c])):
                    continue
                if c == n_cols - 1:
                    return r, c
                right_neighbour = str(grid[r, c + 1])
                if numeric_cell.match(right_neighbour) or right_neighbour == '-':
                    return r, c
        return -1, -1

    measure_obj = []
    tables = camelot.read_pdf(file_path, pages=page_num, strip_text=' ,\n', copy_text=['h'])

    for t in tables:
        arr = np.array(t.data)
        rows, cols = arr.shape
        if rows == 1 and cols == 1:
            continue

        # Only tables mentioning at least one keyword are of interest.
        flat_text = ''.join(''.join(map(str, row)) for row in arr)
        if not re.findall(STR_PATTERN, flat_text):
            continue

        row_num, col_num = first_value_cell(arr)
        if row_num == -1 or col_num == -1:
            continue

        # Everything from the first value cell down and to the right is
        # treated as a value; rows above and columns to the left are headers.
        measure_list = []
        for i in range(row_num, rows):
            for j in range(col_num, cols):
                cell = arr[i, j]
                if cell == '-' or cell == '' or len(cell) > 20:
                    continue
                num_info = get_num_info(arr, row_num, col_num, i, j)
                measure_list.append(f"{num_info}为{cell}")

        if measure_list:
            measure_obj.append({
                "measure_list": measure_list,
                "page_num": f"{str(t.page)}_{str(t.order)}",
            })

    return measure_obj
|
|||
|
|
|||
|
# 文本和表格数据给大模型,返回大模型抽取原始指标列表
|
|||
|
def get_measure_from_llm(user_prompt):
    """Ask the LLM to extract the raw financial measures from report text.

    ``user_prompt`` is inserted into a fixed prompt in place of the
    ``<user_prompt>`` placeholder and sent to the ``qwen-turbo`` model
    through ``dashscope.Generation.call``.

    :param user_prompt: text or table measures taken from the report.
    :return: on success, the model answer split into lines (a list of
        strings); on any non-OK status the error is logged and the string
        ``"llm_error"`` is returned.
    """
    import random
    from http import HTTPStatus
    from dashscope import Generation

    prompt_template = '''
你是一个优秀的金融分析师,从给定的数据报告中自动提取以下10个关键财务指标。指标包括:
2023年营业收入
2023年合计营业收入
2023年调整后营业收入
2022年营业收入
2022年合计营业收入
2022年调整后营业收入
2023年营业收入变动比例
2023年营业收入比上年同期增减
2023年归属母公司净利润
2023年归属于上市公司股东的净利润
2023年归属母公司净利润变动比例
请确保只抽取这些指标,并且每个指标的输出格式为:指标名:指标值
所有的指标值必须从用户提供的信息中抽取,不允许自己生成,如果找不到相关指标,指标值显示为-
<数据报告>
<user_prompt>
</数据报告>
'''
    response = Generation.call(
        model='qwen-turbo',
        prompt=prompt_template.replace('<user_prompt>', user_prompt),
        seed=random.randint(1, 10000),
        top_p=0.1,
        result_format='message',
        enable_search=False,
        max_tokens=1500,
        temperature=0.85,
        repetition_penalty=1.0
    )

    if response.status_code != HTTPStatus.OK:
        logger.error(
            'Request id: %s, Status code: %s, error code: %s, error message: %s',
            response.request_id, response.status_code,
            response.code, response.message,
        )
        return "llm_error"

    answer = response['output']['choices'][0]['message']['content']
    return answer.split('\n')
|
|||
|
|
|||
|
|
|||
|
# 解析大模型抽取的指标,并插入到数据库
|
|||
|
def parse_llm_measure_to_db(measure_info,type,conn,cursor):
    """Parse the measures returned by the LLM and insert them into the DB.

    Each line is expected to look like ``name:value``, optionally prefixed
    with a list number such as ``1.``. Line breaks and spaces are removed,
    lines without ``:`` are ignored, and ``+`` signs are removed from the
    value. One row per parsed line is inserted into ``ori_measure_list`` and
    the transaction is committed once at the end.

    A value containing ``-`` is stored as ``"-"`` (not found), except when it
    is a negative figure (a single leading ``-`` directly followed by a
    number, e.g. ``-15.10%``), which is stored unchanged. Previously every
    negative figure was discarded as "not found".

    :param measure_info: dict with ``'llm_measure'`` (list of lines, or a
        single string) and ``'page_num'``; ``'table_index'`` is read as well
        when ``type == 'table'``.
    :param type: source of the measures; ``'table'`` enables the table index.
        The name shadows the builtin and is kept for interface compatibility.
    :param conn: DB connection, used for ``commit``.
    :param cursor: DB cursor, used for ``execute``.
    :return: always the empty string.
    """
    create_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    insert_query = '''
INSERT INTO ori_measure_list
(file_id, file_name, type, page_number, table_index, ori_measure_id, ori_measure_name, ori_measure_value, create_time, update_time)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
'''
    # NOTE(review): file id and file name are hard-coded placeholders.
    file_id = '1111111'
    file_name = '科润智控.pdf'
    llm_measure = measure_info['llm_measure']
    page_num = measure_info['page_num']
    table_index = measure_info['table_index'] if type == 'table' else '0'

    # get_measure_from_llm returns the plain string "llm_error" on failure;
    # treat a string as lines of text instead of iterating its characters.
    if isinstance(llm_measure, str):
        llm_measure = llm_measure.split('\n')

    for measure_obj in llm_measure:
        measure_obj = measure_obj.replace('\n', '').replace('\r', '').replace(' ', '')
        if ':' not in measure_obj:
            continue

        parts = measure_obj.split(':')
        ori_measure_name = parts[0]
        ori_measure_value = parts[1].replace('+', '')

        # A '-' marks a missing value or an unevaluated expression, unless
        # the value is a negative figure (single leading minus + number).
        is_negative_number = re.fullmatch(r'-\.?\d[^-]*', ori_measure_value) is not None
        if '-' in ori_measure_value and not is_negative_number:
            ori_measure_value = "-"

        # Drop a leading list number such as "1." from the measure name.
        if '.' in ori_measure_name:
            ori_measure_name = ori_measure_name.split('.')[1]

        ori_measure_id = get_md5(ori_measure_name)
        data_to_insert = (
            file_id, file_name, type, int(page_num), int(table_index),
            ori_measure_id, ori_measure_name, ori_measure_value,
            create_time, create_time,
        )
        cursor.execute(insert_query, data_to_insert)
        logger.info(f"{type},{page_num},{table_index},{ori_measure_name},{ori_measure_value}")

    # Commit all inserted rows in one transaction.
    conn.commit()
    return ""
|
|||
|
|
|||
|
# 根据measure_config中的规则,更新原始指标的显示指标
|
|||
|
def update_ori_measure(conn,cursor):
    """Apply the ``measure_config`` rules to the raw measures.

    Every ``(measure_id, measure_name, ori_measure_id)`` row of
    ``measure_config`` is read, and for each of them the rows of
    ``ori_measure_list`` with the same ``ori_measure_id`` and a value other
    than ``'-'`` get that ``measure_id`` and ``measure_name``. All updates
    are committed together at the end.

    :param conn: DB connection, used for ``commit``.
    :param cursor: DB cursor, used for ``execute`` and ``fetchall``.
    :return: None.
    """
    select_query = '''
SELECT measure_id,measure_name,ori_measure_id FROM measure_config
'''
    update_query = '''
UPDATE ori_measure_list
SET measure_id = %s, measure_name = %s
WHERE ori_measure_id = %s and ori_measure_value !='-'
'''
    cursor.execute(select_query)
    config_rows = cursor.fetchall()

    for config_row in config_rows:
        measure_id, measure_name, ori_measure_id = config_row[0], config_row[1], config_row[2]
        cursor.execute(update_query, (measure_id, measure_name, ori_measure_id))

    conn.commit()
|
|||
|
|
|||
|
if __name__ == "__main__":
    # Script entry point: extract the table measures of one report, send them
    # to the LLM and log the page-3 results. The DB steps are disabled.
    import sys

    # Without a configured handler/level the INFO records below are dropped.
    logging.basicConfig(level=logging.INFO)

    start_time = datetime.now()
    # The logging API needs a %s placeholder for each extra argument.
    logger.info("开始时间: %s", start_time.strftime("%Y-%m-%d %H:%M:%S"))

    # The PDF path may be given as the first command-line argument; the
    # original hard-coded path remains the default.
    path = sys.argv[1] if len(sys.argv) > 1 else "/Users/zhengfei/Desktop/科润智控1.pdf"
    table_info = get_table_measure(path)
    # text_info = get_text_content(path)

    # Disabled: connect to the MySQL database and create a cursor.
    # NOTE(review): the plaintext host/user/password that used to be written
    # here were replaced by environment lookups; the old password was stored
    # in source and should be rotated.
    # conn = mysql.connector.connect(
    #     host=os.environ["FINANCIAL_DB_HOST"],
    #     user=os.environ["FINANCIAL_DB_USER"],
    #     password=os.environ["FINANCIAL_DB_PASSWORD"],
    #     database="financial_report"
    # )
    # cursor = conn.cursor()

    for table_obj in table_info:
        # Only needed by the disabled DB step below.
        table_measure_obj = {}
        # 'page_num' has the form "<page>_<table order>".
        table_page_num = table_obj['page_num'].split("_")[0]
        table_index = table_obj['page_num'].split("_")[1]
        table_measure = ','.join(table_obj['measure_list'])
        if table_page_num == '3':
            logger.info(f"第{table_page_num}页表格指标为:{table_measure}")
        table_llm_measure = get_measure_from_llm(table_measure)
        if table_page_num == '3':
            logger.info(f"第{table_page_num}页表格llm指标为:{table_llm_measure}")
        # Disabled: store the LLM measures of this table.
        # table_measure_obj['page_num'] = table_page_num
        # table_measure_obj['table_index'] = table_index
        # table_measure_obj['llm_measure'] = table_llm_measure
        # parse_llm_measure_to_db(table_measure_obj,'table',conn,cursor)

    # Disabled: run the same extraction on the non-table page text.
    # for text_obj in text_info:
    #     text_measure_obj = {}
    #     text_page_num = text_obj['page_num']
    #     text = text_obj['text']
    #     if len(text) > 10:
    #         text_llm_measure = get_measure_from_llm(text)
    #         text_measure_obj['page_num'] = text_page_num
    #         text_measure_obj['llm_measure'] = text_llm_measure
    #         parse_llm_measure_to_db(text_measure_obj,'text',conn,cursor)
    #         print(text_llm_measure)

    # Disabled: map raw measures to display measures and release the DB.
    # update_ori_measure(conn,cursor)
    # cursor.close()
    # conn.close()

    # Sample of LLM output lines, kept for reference:
    # measure_info =['1. 2023年营业收入: 983698831.48', '2. 2023年营业收入变动比例: 15.10%', '3. 2023年归属母公司净利润: - (未在报告中找到)', '4. 2023年归属母公司净利润变动比例: - (未在报告中找到)', '5. 2023年毛利率: (营业收入 - 主营业务成本) / 营业收入 = (983698831.48 - 793604607.43) / 983698831.48', '6. 2022年毛利率: (主营业务收入 - 主营业务成本) / 主营业务收入 = (854682261.31 - 690932741.27) / 854682261.31', '7. 2023年主营业务收入: 983698831.48', '8. 2022年主营业务收入: 854682261.31']

    end_time = datetime.now()
    logger.info("结束时间: %s", end_time.strftime("%Y-%m-%d %H:%M:%S"))