# Listing header captured with the source: 221 lines, 8.8 KiB, Python.
import os
import random
import re
from http import HTTPStatus

import dashscope
import mysql.connector
import PyPDF2
from pdfminer.high_level import extract_pages
from pdfminer.layout import LTTextBoxHorizontal

from config import MYSQL_HOST, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DB
# DashScope credentials. The DASHSCOPE_API_KEY environment variable takes
# precedence so each deployment can supply (and rotate) its own key without
# editing source.
# NOTE(review): the literal fallback is a secret committed to source control;
# it should be revoked and this fallback removed once every environment sets
# DASHSCOPE_API_KEY.
dashscope.api_key = os.getenv('DASHSCOPE_API_KEY') or 'sk-63c02fbb9b7d4b0494a3200bec1ae286'
def get_company_name(file_path, max_pages=2):
    """Ask the LLM for the company's full Chinese name found on a PDF's first pages.

    The text of every horizontal text box on the first ``max_pages`` pages is
    concatenated (in layout order) and passed to ``llm_service``.

    Args:
        file_path: Path of the PDF report to read.
        max_pages: Number of leading pages to scan. Defaults to 2 (pages 0 and
            1), the range the original implementation read. pdfminer treats 0
            as "no limit".

    Returns:
        The value returned by ``llm_service`` for the extracted text: the
        model's answer, or the sentinel ``"llm_error"`` if the request failed.
    """
    text_parts = []
    # Passing maxpages lets pdfminer stop after the pages we need instead of
    # laying out one extra page and abandoning the generator mid-iteration.
    for page in extract_pages(file_path, maxpages=max_pages):
        # LTPage is iterable over its top-level layout objects, so there is no
        # need to reach into the private ``_objs`` attribute.
        for element in page:
            # Only horizontal text boxes carry the text we want to send.
            if isinstance(element, LTTextBoxHorizontal):
                text_parts.append(element.get_text())
    return llm_service(''.join(text_parts))
def get_company_code(file_path, max_pages=2):
    """Ask the LLM for the 6-digit stock code(s) found on a PDF's first pages.

    The text of every horizontal text box on the first ``max_pages`` pages is
    concatenated (in layout order) and passed to ``llm_service_code``.

    Args:
        file_path: Path of the PDF report to read.
        max_pages: Number of leading pages to scan. Defaults to 2 (pages 0 and
            1), the range the original implementation read. pdfminer treats 0
            as "no limit".

    Returns:
        The value returned by ``llm_service_code`` for the extracted text: the
        model's answer, or the sentinel ``"llm_error"`` if the request failed.
    """
    text_parts = []
    # Passing maxpages lets pdfminer stop after the pages we need instead of
    # laying out one extra page and abandoning the generator mid-iteration.
    for page in extract_pages(file_path, maxpages=max_pages):
        # LTPage is iterable over its top-level layout objects, so there is no
        # need to reach into the private ``_objs`` attribute.
        for element in page:
            # Only horizontal text boxes carry the text we want to send.
            if isinstance(element, LTTextBoxHorizontal):
                text_parts.append(element.get_text())
    return llm_service_code(''.join(text_parts))
# NOTE: disabled draft that looked up the "公司简介" (company profile) page via the PDF outline and also scanned it for the stock code; kept for reference.
#获取公司简介的那一页
# def get_code_page(pdf_path):
# with open(pdf_path, 'rb') as file:
# reader = PyPDF2.PdfReader(file)
# outlines = reader.outline
# company_profile_page = None

# def find_page_from_outlines(outlines):
# nonlocal company_profile_page
# for item in outlines:
# if isinstance(item, list): # 如果是子目录,则递归
# find_page_from_outlines(item)
# else:
# title = item.title
# if title is not None and '公司简介' in title:
# # 获取页面的实际页码
# page_num = reader.get_destination_page_number(item)
# company_profile_page = page_num
# return
# # 处理没有标题的情况
# elif item.page is not None:
# page_num = reader.get_destination_page_number(item)
# if page_num is not None:
# pass

# find_page_from_outlines(outlines)

# return company_profile_page

# def get_company_code(file_path):
# line_text = ''
# # 我们从PDF中提取页面,page_numbers=[4,5,6]
# for pagenum, page in enumerate(extract_pages(file_path)):
# print(f'页码是{get_code_page(file_path)+1}')
# if pagenum > 1 and pagenum != get_code_page(file_path)+1:
# break
# # 找到所有的元素
# #print(pagenum)
# page_elements = [(element.y1, element) for element in page._objs]
# # 查找组成页面的元素
# # for i,component in enumerate(page_elements):
# # # 提取页面布局的元素
# # element = component[1]
# # # 检查该元素是否为文本元素
# # if isinstance(element, LTTextBoxHorizontal):
# # # 检查文本是否出现在表中
# # line_text += element.get_text()
# for _, element in page_elements:
# if isinstance(element, LTTextBoxHorizontal):
# # 提取文本并添加到 line_text
# line_text += element.get_text()

# return llm_service_code(line_text)
def llm_service(user_prompt):
    """Ask qwen-plus for the full Chinese company name contained in a report.

    Args:
        user_prompt: Raw report text; it is substituted for the
            ``<user_prompt>`` placeholder between the ``<数据报告>`` tags.

    Returns:
        The content of the first choice's message when DashScope answers with
        HTTP 200. Otherwise the request details are printed and the sentinel
        string ``"llm_error"`` is returned.
    """
    template = '''
从以下数据报告中提取公司全称,只需要提取中文公司全称,不要增加其他内容,如果提取不到公司全称,请返回-。
<数据报告>
<user_prompt>
</数据报告>
'''
    prompt_text = template.replace('<user_prompt>', user_prompt)

    call_options = dict(
        model='qwen-plus',
        prompt=prompt_text,
        seed=random.randint(1, 10000),
        top_p=0.8,
        result_format='message',
        enable_search=False,
        max_tokens=1500,
        temperature=0.85,
        repetition_penalty=1.0,
    )
    reply = dashscope.Generation.call(**call_options)

    # Guard clause: report the failure and hand back the sentinel.
    if reply.status_code != HTTPStatus.OK:
        print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
            reply.request_id, reply.status_code,
            reply.code, reply.message
        ))
        return "llm_error"

    return reply['output']['choices'][0]['message']['content']
def llm_service_code(user_prompt):
    """Ask qwen-plus for the 6-digit stock code(s) contained in a report.

    Args:
        user_prompt: Raw report text; it is substituted for the
            ``<user_prompt>`` placeholder between the ``<数据报告>`` tags.

    Returns:
        The content of the first choice's message when DashScope answers with
        HTTP 200 (the prompt asks for comma-separated codes or ``-``).
        Otherwise the request details are printed and the sentinel string
        ``"llm_error"`` is returned.
    """
    filled_prompt = '''
从以下数据报告中提取6位数字的股票代码,只需要提取股票代码,如果有多个则以','隔开,不要增加其他内容,如果提取不到股票代码,请返回-,不要返回其他任何内容。
<数据报告>
<user_prompt>
</数据报告>
'''.replace('<user_prompt>', user_prompt)

    answer = dashscope.Generation.call(
        model='qwen-plus', prompt=filled_prompt,
        seed=random.randint(1, 10000),
        top_p=0.8, result_format='message', enable_search=False,
        max_tokens=1500, temperature=0.85, repetition_penalty=1.0,
    )

    succeeded = answer.status_code == HTTPStatus.OK
    if succeeded:
        return answer['output']['choices'][0]['message']['content']

    details = (answer.request_id, answer.status_code, answer.code, answer.message)
    print('Request id: %s, Status code: %s, error code: %s, error message: %s' % details)
    return "llm_error"
def update_company_name(file_id, company_name, company_code, cursor, conn):
    """Write the extracted company name and stock code into one report row.

    Args:
        file_id: Value of ``report_check.id`` identifying the row to update.
        company_name: Text stored in ``c_name``.
        company_code: Text stored in ``c_code``.
        cursor: Open DB-API cursor used to run the UPDATE.
        conn: Connection owning *cursor*; it is committed after the UPDATE.
    """
    # The values originate from LLM output over PDF text, so they are bound as
    # query parameters. Interpolating them into the SQL string broke on any
    # name containing a quote and allowed SQL injection.
    update_sql = '''
    UPDATE report_check
    SET c_name = %s, c_code = %s
    WHERE id = %s
    '''
    cursor.execute(update_sql, (company_name, company_code, file_id))
    conn.commit()
# Canonical shape of a stock-code answer: one or more 6-digit codes joined by ','.
_STOCK_CODE_RE = re.compile(r'\d{6}(?:,\d{6})*')


def first_answer_line(answer):
    """Return the first non-blank line of an LLM answer, stripped of whitespace.

    Args:
        answer: Raw text returned by ``get_company_name``/``get_company_code``.

    Returns:
        The first line containing non-whitespace characters, stripped, or
        ``''`` when there is none.
    """
    # splitlines() without keepends, plus strip(): keeping the terminator
    # (splitlines(True)) stored values such as '600000\n' in the database.
    for line in answer.splitlines():
        stripped = line.strip()
        if stripped:
            return stripped
    return ''


def normalize_company_code(answer):
    """Reduce an LLM stock-code answer to ``'dddddd[,dddddd...]'`` or ``'-'``.

    Only the first non-blank line is considered. Whitespace inside it is
    removed, so ``'600000, 900901'`` is accepted. Anything that is not a
    comma-separated list of 6-digit codes becomes ``'-'``, the placeholder the
    prompt asks the model to use when no code is found.
    """
    candidate = re.sub(r'\s+', '', first_answer_line(answer))
    # fullmatch, unlike match() with '$', rejects a trailing newline.
    return candidate if _STOCK_CODE_RE.fullmatch(candidate) else '-'


if __name__ == '__main__':
    conn = mysql.connector.connect(
        host=MYSQL_HOST,
        user=MYSQL_USER,
        password=MYSQL_PASSWORD,
        database=MYSQL_DB
    )
    # Create a cursor to execute SQL statements.
    cursor = conn.cursor()
    try:
        # Only reports whose stock code has not been filled in yet.
        data_query = '''
        SELECT id,file_path FROM report_check where c_code is null
        '''
        cursor.execute(data_query)
        data_list = cursor.fetchall()

        for file_id, stored_path in data_list:
            try:
                # Production location of the uploaded reports.
                file_path = f'/usr/local/zhanglei/financial{stored_path}'
                # Test-environment location:
                # file_path_1 = f'/root/pdf_parser/pdf/{stored_path}'
                # file_path = file_path_1.replace('/upload/file/','')
                print(f'财报{file_id}开始解析')

                company_name = first_answer_line(get_company_name(file_path))
                code_answer = first_answer_line(get_company_code(file_path))

                # If either LLM call failed, leave the row untouched (c_code
                # stays NULL) so the next run retries it. Previously an error
                # wrote "llm_error"/'-' and the row was never picked up again.
                if company_name == "llm_error" or code_answer == "llm_error":
                    continue

                company_code = normalize_company_code(code_answer)
                update_company_name(file_id, company_name, company_code, cursor, conn)
            except Exception as e:
                # Best effort per report: log and move on to the next one.
                print('财报解析失败', e)
    finally:
        cursor.close()
        conn.close()