109 lines
3.8 KiB
Python
109 lines
3.8 KiB
Python
import os
import random
import re
from http import HTTPStatus

import dashscope
import mysql.connector
from pdfminer.high_level import extract_pages
from pdfminer.layout import LTTextBoxHorizontal

from config import MYSQL_HOST, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DB
|
|
|
|
# Read the DashScope API key from the environment instead of committing a
# secret to source control. If the variable is unset this assigns None.
# NOTE(review): the key that used to be hard-coded here has been exposed and
# should be revoked/rotated; set DASHSCOPE_API_KEY in the deployment env.
dashscope.api_key = os.environ.get('DASHSCOPE_API_KEY')
|
|
|
|
def get_company_name(file_path, max_pages=2):
    """Ask the LLM for the company's full name using the first pages of a PDF.

    The text of every ``LTTextBoxHorizontal`` element on the first
    ``max_pages`` pages is concatenated (in the order pdfminer lists the
    page's layout children) and passed to ``llm_service``.

    Args:
        file_path: Path of the PDF report to read.
        max_pages: Number of leading pages to scan. The default of 2 matches
            the original behaviour (pages 0 and 1). Passed straight to
            pdfminer's ``maxpages``, where 0 means "no limit".

    Returns:
        Whatever ``llm_service`` returns for the extracted text.
    """
    # maxpages makes pdfminer stop parsing after the leading pages, rather
    # than parsing one extra page and breaking out of the loop by hand.
    text_parts = [
        element.get_text()
        # An LTPage is iterable over its direct layout children, so there is
        # no need to reach into the private ``_objs`` list.
        for page in extract_pages(file_path, maxpages=max_pages)
        for element in page
        if isinstance(element, LTTextBoxHorizontal)
    ]
    # join() avoids the quadratic cost of repeated string concatenation.
    return llm_service(''.join(text_parts))
|
|
|
|
def llm_service(user_prompt):
    """Send report text to the qwen-plus model and return its answer.

    *user_prompt* is spliced into a fixed Chinese instruction template. The
    template asks for the full Chinese company name only, or ``-`` when none
    can be found. The template is sent through ``dashscope.Generation.call``.

    Returns:
        The content of the first choice's message when the call returns
        HTTP 200. Otherwise it prints the request id, status and error details
        and returns the string ``"llm_error"``.
    """
    template = '''
从以下数据报告中提取公司全称,只需要提取中文公司全称,不要增加其他内容,如果提取不到公司全称,请返回-。
<数据报告>
<user_prompt>
</数据报告>
'''
    full_prompt = template.replace('<user_prompt>', user_prompt)

    # NOTE(review): a random seed together with temperature 0.85 makes the
    # extraction vary between runs; confirm that is intended for this task.
    generation_options = dict(
        model='qwen-plus',
        prompt=full_prompt,
        seed=random.randint(1, 10000),
        top_p=0.8,
        result_format='message',
        enable_search=False,
        max_tokens=1500,
        temperature=0.85,
        repetition_penalty=1.0,
    )
    reply = dashscope.Generation.call(**generation_options)

    if reply.status_code != HTTPStatus.OK:
        print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
            reply.request_id, reply.status_code,
            reply.code, reply.message
        ))
        return "llm_error"

    return reply['output']['choices'][0]['message']['content']
|
|
|
|
def update_company_name(file_id, company_name, cursor, conn):
    """Set ``report_check.c_name`` for the row with id *file_id* and commit.

    Args:
        file_id: Id of the ``report_check`` row to update.
        company_name: Value to store in ``c_name``. It comes from LLM output,
            so it is passed as a bound parameter instead of being formatted
            into the SQL text. Otherwise a name containing a quote would break
            the statement or inject SQL.
        cursor: DB-API cursor used to execute the UPDATE.
        conn: Connection whose ``commit()`` is called after the UPDATE.
    """
    # %s placeholders are mysql.connector's parameter markers; the driver
    # escapes the values, so no manual quoting is needed.
    update_sql = 'UPDATE report_check SET c_name = %s WHERE id = %s'
    cursor.execute(update_sql, (company_name, file_id))
    conn.commit()
|
|
|
|
def normalize_company_name(llm_output):
    """Return the first non-blank line of *llm_output*, stripped of whitespace.

    The model's answer may carry extra lines or surrounding whitespace. The
    previous ``splitlines(True)`` kept the trailing newline, which was then
    written to the database. Returns ``''`` when *llm_output* has no
    non-blank line. Sentinel answers such as ``"-"`` and ``"llm_error"``
    pass through unchanged.
    """
    for line in llm_output.splitlines():
        stripped = line.strip()
        if stripped:
            return stripped
    return ''


def main(report_root='/usr/local/zhanglei/financial'):
    """Fill ``c_name`` for every ``report_check`` row where it is NULL.

    For each such row, the PDF at ``<report_root>/<file_path>`` is passed to
    ``get_company_name``. The first line of the answer is stored via
    ``update_company_name``, unless the LLM call failed (``"llm_error"``).
    A failure on one row is printed and the loop moves on to the next row.

    Args:
        report_root: Directory that the ``file_path`` column is relative to.
    """
    conn = mysql.connector.connect(
        host=MYSQL_HOST,
        user=MYSQL_USER,
        password=MYSQL_PASSWORD,
        database=MYSQL_DB,
    )
    # try/finally so the cursor and connection are released even if the
    # initial SELECT raises.
    try:
        # Create a cursor object to execute SQL statements
        cursor = conn.cursor()
        try:
            cursor.execute('SELECT id,file_path FROM report_check where c_name is null')
            data_list = cursor.fetchall()

            for file_id, relative_path in data_list:
                file_path = f'{report_root}/{relative_path}'
                try:
                    print(f'财报{file_id}开始解析')
                    company_name = normalize_company_name(get_company_name(file_path))
                    if company_name != "llm_error":
                        update_company_name(file_id, company_name, cursor, conn)
                except Exception as e:
                    # Batch boundary: report the failing row and continue.
                    print(f'财报{file_id}解析失败', e)
        finally:
            cursor.close()
    finally:
        conn.close()


if __name__ == '__main__':
    main()
|