269 lines
10 KiB
Python
269 lines
10 KiB
Python
|
import os
import random
import re
from http import HTTPStatus

import dashscope
import mysql.connector
import PyPDF2
from pdfminer.high_level import extract_pages
from pdfminer.layout import LTTextBoxHorizontal

from config import MYSQL_HOST, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DB
|
|||
|
# Prefer the DASHSCOPE_API_KEY environment variable so the key does not have to
# live in source control; the hard-coded fallback keeps existing deployments working.
# TODO(security): this key has been committed to the repository - rotate it and
# remove the fallback once every deployment sets DASHSCOPE_API_KEY.
dashscope.api_key = os.environ.get('DASHSCOPE_API_KEY', 'sk-63c02fbb9b7d4b0494a3200bec1ae286')
|
|||
|
|
|||
|
def get_company_name(file_path):
    """Ask the LLM for the company's full Chinese name found in a report PDF.

    Only the text boxes on the first two pages are collected and sent to
    ``llm_service``.

    Args:
        file_path: Path of the report PDF.

    Returns:
        The raw answer of ``llm_service``: the extracted name, ``'-'`` when the
        model finds none (as its prompt requests), or ``'llm_error'`` when the
        API response status is not OK.
    """
    text_parts = []
    # maxpages=2 makes pdfminer stop after the second page instead of laying
    # out a third page only to break out of the loop.
    for page in extract_pages(file_path, maxpages=2):
        # Iterating an LTPage yields its top-level layout objects.
        for element in page:
            # Keep only horizontal text boxes (skip figures, lines, etc.).
            if isinstance(element, LTTextBoxHorizontal):
                text_parts.append(element.get_text())
    return llm_service(''.join(text_parts))
|
|||
|
def get_company_code(file_path):
    """Ask the LLM for the six-digit stock code(s) found in a report PDF.

    Only the text boxes on the first two pages are collected and sent to
    ``llm_service_code``.

    Args:
        file_path: Path of the report PDF.

    Returns:
        The raw answer of ``llm_service_code``: the code(s) (comma-separated
        when several, as its prompt requests), ``'-'`` when none is found, or
        ``'llm_error'`` when the API response status is not OK.
    """
    text_parts = []
    # maxpages=2 makes pdfminer stop after the second page instead of laying
    # out a third page only to break out of the loop.
    for page in extract_pages(file_path, maxpages=2):
        # Iterating an LTPage yields its top-level layout objects.
        for element in page:
            # Keep only horizontal text boxes (skip figures, lines, etc.).
            if isinstance(element, LTTextBoxHorizontal):
                text_parts.append(element.get_text())
    return llm_service_code(''.join(text_parts))
|
|||
|
#获取公司简介的那一页
|
|||
|
# def get_code_page(pdf_path):
|
|||
|
# with open(pdf_path, 'rb') as file:
|
|||
|
# reader = PyPDF2.PdfReader(file)
|
|||
|
# outlines = reader.outline
|
|||
|
# company_profile_page = None
|
|||
|
|
|||
|
# def find_page_from_outlines(outlines):
|
|||
|
# nonlocal company_profile_page
|
|||
|
# for item in outlines:
|
|||
|
# if isinstance(item, list): # 如果是子目录,则递归
|
|||
|
# find_page_from_outlines(item)
|
|||
|
# else:
|
|||
|
# title = item.title
|
|||
|
# if title is not None and '公司简介' in title:
|
|||
|
# # 获取页面的实际页码
|
|||
|
# page_num = reader.get_destination_page_number(item)
|
|||
|
# company_profile_page = page_num
|
|||
|
# return
|
|||
|
# # 处理没有标题的情况
|
|||
|
# elif item.page is not None:
|
|||
|
# page_num = reader.get_destination_page_number(item)
|
|||
|
# if page_num is not None:
|
|||
|
# pass
|
|||
|
|
|||
|
# find_page_from_outlines(outlines)
|
|||
|
|
|||
|
# return company_profile_page
|
|||
|
|
|||
|
# def get_company_code(file_path):
|
|||
|
# line_text = ''
|
|||
|
# # 我们从PDF中提取页面,page_numbers=[4,5,6]
|
|||
|
# for pagenum, page in enumerate(extract_pages(file_path)):
|
|||
|
# print(f'页码是{get_code_page(file_path)+1}')
|
|||
|
# if pagenum > 1 and pagenum != get_code_page(file_path)+1:
|
|||
|
# break
|
|||
|
# # 找到所有的元素
|
|||
|
# #print(pagenum)
|
|||
|
# page_elements = [(element.y1, element) for element in page._objs]
|
|||
|
# # 查找组成页面的元素
|
|||
|
# # for i,component in enumerate(page_elements):
|
|||
|
# # # 提取页面布局的元素
|
|||
|
# # element = component[1]
|
|||
|
# # # 检查该元素是否为文本元素
|
|||
|
# # if isinstance(element, LTTextBoxHorizontal):
|
|||
|
# # # 检查文本是否出现在表中
|
|||
|
# # line_text += element.get_text()
|
|||
|
# for _, element in page_elements:
|
|||
|
# if isinstance(element, LTTextBoxHorizontal):
|
|||
|
# # 提取文本并添加到 line_text
|
|||
|
# line_text += element.get_text()
|
|||
|
|
|||
|
# return llm_service_code(line_text)
|
|||
|
def llm_service(user_prompt):
    """Extract the company's full Chinese name from report text via qwen-plus.

    Args:
        user_prompt: Report text substituted into the extraction prompt.

    Returns:
        The model's message content (the prompt asks for '-' when no name is
        found), or the string "llm_error" when the response status is not OK.
    """
    template = '''
从以下数据报告中提取公司全称,只需要提取中文公司全称,不要增加其他内容,如果提取不到公司全称,请返回-,不要返回其他任何内容。
<数据报告>
<user_prompt>
</数据报告>
'''
    prompt_text = template.replace('<user_prompt>', user_prompt)
    generation_options = {
        'model': 'qwen-plus',
        'prompt': prompt_text,
        'seed': random.randint(1, 10000),
        'top_p': 0.8,
        'result_format': 'message',
        'enable_search': False,
        'max_tokens': 1500,
        'temperature': 0.85,
        'repetition_penalty': 1.0,
    }
    response = dashscope.Generation.call(**generation_options)

    # Guard clause: log the request details and hand back the error sentinel.
    if response.status_code != HTTPStatus.OK:
        print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
            response.request_id, response.status_code,
            response.code, response.message
        ))
        return "llm_error"

    first_choice = response['output']['choices'][0]
    return first_choice['message']['content']
|
|||
|
def llm_service_code(user_prompt):
    """Extract six-digit stock code(s) from report text via qwen-plus.

    Args:
        user_prompt: Report text substituted into the extraction prompt.

    Returns:
        The model's message content (the prompt asks for comma-separated codes,
        or '-' when none is found), or the string "llm_error" when the
        response status is not OK.
    """
    template = '''
从以下数据报告中提取6位数字的股票代码,只需要提取股票代码,如果有多个则以','隔开,不要增加其他内容,如果提取不到股票代码,请返回-,不要返回其他任何内容。
<数据报告>
<user_prompt>
</数据报告>
'''
    prompt_text = template.replace('<user_prompt>', user_prompt)
    generation_options = {
        'model': 'qwen-plus',
        'prompt': prompt_text,
        'seed': random.randint(1, 10000),
        'top_p': 0.8,
        'result_format': 'message',
        'enable_search': False,
        'max_tokens': 1500,
        'temperature': 0.85,
        'repetition_penalty': 1.0,
    }
    response = dashscope.Generation.call(**generation_options)

    # Guard clause: log the request details and hand back the error sentinel.
    if response.status_code != HTTPStatus.OK:
        print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
            response.request_id, response.status_code,
            response.code, response.message
        ))
        return "llm_error"

    first_choice = response['output']['choices'][0]
    return first_choice['message']['content']
|
|||
|
def update_company_name(file_id, company_name, company_code, cursor, conn):
    """Store the extracted company name and stock code on one report_check row.

    Args:
        file_id: ``id`` of the report_check row to update.
        company_name: Value written to the ``c_name`` column.
        company_code: Value written to the ``c_code`` column.
        cursor: Open DB-API cursor used to execute the statement.
        conn: Connection owning ``cursor``; committed after the update.
    """
    # Bind the values as parameters instead of interpolating them: the name is
    # LLM output derived from PDF text, so a quote in it would otherwise break
    # (or inject into) the SQL statement.
    update_sql = 'UPDATE report_check SET c_name = %s, c_code = %s WHERE id = %s'
    cursor.execute(update_sql, (company_name, company_code, file_id))
    conn.commit()
|
|||
|
# Sentinel returned by llm_service / llm_service_code on a non-OK API response.
_LLM_ERROR = "llm_error"
# One six-digit stock code, or several separated by commas (e.g. "600000,900901").
_STOCK_CODE_PATTERN = re.compile(r'\d{6}(?:,\d{6})*')


def _first_line(answer):
    """Return the first line of an LLM answer with surrounding whitespace removed."""
    lines = answer.strip().splitlines()
    return lines[0].strip() if lines else ''


def name_code_fix(file_id, file_path):
    """Extract company name and stock code from one report PDF and save them.

    Opens its own MySQL connection, asks the LLM for the name and code, and
    normalizes both before writing them to report_check:

    * a code that is not six digits (optionally comma-separated) becomes '-';
    * a name longer than 15 characters, or the '-' "not found" answer, becomes ''.

    If either LLM call returns the error sentinel, nothing is written. Any
    other exception is printed and swallowed so a batch caller can continue.

    Args:
        file_id: ``id`` of the report_check row to update.
        file_path: Path of the report PDF.
    """
    conn = mysql.connector.connect(
        host=MYSQL_HOST,
        user=MYSQL_USER,
        password=MYSQL_PASSWORD,
        database=MYSQL_DB,
    )
    cursor = conn.cursor()
    try:
        # Strip before taking the first line: splitlines(True) used to keep the
        # trailing newline, which ended up stored in the database.
        company_name = _first_line(get_company_name(file_path))
        company_code = _first_line(get_company_code(file_path))

        # Do not persist results when an LLM call failed: the row keeps its
        # current values so a later run can retry it (the __main__ batch
        # selects rows WHERE c_code IS NULL).
        if company_name == _LLM_ERROR or company_code == _LLM_ERROR:
            return

        if not _STOCK_CODE_PATTERN.fullmatch(company_code):
            company_code = '-'
        if len(company_name) > 15 or company_name == '-':
            company_name = ''

        update_company_name(file_id, company_name, company_code, cursor, conn)
    except Exception as e:
        # Top-level boundary for one report: report the failure and move on.
        print('财报解析失败', e)
    finally:
        cursor.close()
        conn.close()
|
|||
|
|
|||
|
|
|||
|
|
|||
|
if __name__ == '__main__':
    # Batch mode: process every report whose stock code has not been filled in.
    conn = mysql.connector.connect(
        host=MYSQL_HOST,
        user=MYSQL_USER,
        password=MYSQL_PASSWORD,
        database=MYSQL_DB,
    )
    cursor = conn.cursor()
    try:
        cursor.execute('SELECT id, file_path FROM report_check WHERE c_code IS NULL')
        data_list = cursor.fetchall()
    finally:
        # Release this connection before the per-report PDF parsing and LLM
        # calls; name_code_fix opens its own connection for each update.
        cursor.close()
        conn.close()

    for file_id, relative_path in data_list:
        # Production location of the uploaded reports. For the test environment use:
        #   f'/root/pdf_parser/pdf/{relative_path}'.replace('/upload/file/', '')
        file_path = f'/usr/local/zhanglei/financial{relative_path}'
        print(f'财报{file_id}开始解析')
        try:
            # Shares extraction, validation and normalization with the
            # single-report entry point instead of duplicating it here.
            name_code_fix(file_id, file_path)
        except Exception as e:
            # e.g. the database connection inside name_code_fix failed; keep going.
            print('财报解析失败', e)
|