pdf_code/zzb_data/pdf_company.py

221 lines
8.8 KiB
Python
Raw Normal View History

2024-10-31 15:35:27 +08:00
import os
import random
import re
from http import HTTPStatus

import dashscope
import mysql.connector
import PyPDF2
from pdfminer.high_level import extract_pages
from pdfminer.layout import LTTextBoxHorizontal

from config import MYSQL_HOST,MYSQL_USER,MYSQL_PASSWORD,MYSQL_DB
# DashScope credentials. The DASHSCOPE_API_KEY environment variable takes
# precedence so the key can be rotated without editing code.
# NOTE(review): the literal fallback is a secret committed to source control;
# revoke it and drop the fallback once every deployment sets the env var.
dashscope.api_key = os.environ.get(
    'DASHSCOPE_API_KEY', 'sk-63c02fbb9b7d4b0494a3200bec1ae286'
)
def _extract_leading_text(file_path, max_pages=2):
    """Concatenate the text of the horizontal text boxes on a PDF's first pages.

    Args:
        file_path: Path of the PDF to read.
        max_pages: How many leading pages to scan (2 by default, i.e. the
            first and second page).

    Returns:
        The text of every ``LTTextBoxHorizontal`` element, in the order
        pdfminer yields them, joined into one string ('' if there are none).
    """
    chunks = []
    # maxpages makes pdfminer stop after the requested pages instead of
    # laying out an extra page only to break out of the loop.
    for page in extract_pages(file_path, maxpages=max_pages):
        # Iterating an LTPage yields its top-level layout elements; only the
        # horizontal text boxes carry the text we feed to the LLM.
        chunks.extend(
            element.get_text()
            for element in page
            if isinstance(element, LTTextBoxHorizontal)
        )
    return ''.join(chunks)


def get_company_name(file_path):
    """Ask the LLM for the company's full Chinese name from the PDF's first two pages.

    Returns:
        The raw answer from ``llm_service`` ("-" is requested when no name is
        found), or "llm_error" if the LLM request failed.
    """
    return llm_service(_extract_leading_text(file_path))


def get_company_code(file_path):
    """Ask the LLM for the 6-digit stock code(s) from the PDF's first two pages.

    Returns:
        The raw answer from ``llm_service_code`` (comma-separated codes, or
        "-" when none are found), or "llm_error" if the LLM request failed.
    """
    return llm_service_code(_extract_leading_text(file_path))
#获取公司简介的那一页
# def get_code_page(pdf_path):
# with open(pdf_path, 'rb') as file:
# reader = PyPDF2.PdfReader(file)
# outlines = reader.outline
# company_profile_page = None
# def find_page_from_outlines(outlines):
# nonlocal company_profile_page
# for item in outlines:
# if isinstance(item, list): # 如果是子目录,则递归
# find_page_from_outlines(item)
# else:
# title = item.title
# if title is not None and '公司简介' in title:
# # 获取页面的实际页码
# page_num = reader.get_destination_page_number(item)
# company_profile_page = page_num
# return
# # 处理没有标题的情况
# elif item.page is not None:
# page_num = reader.get_destination_page_number(item)
# if page_num is not None:
# pass
# find_page_from_outlines(outlines)
# return company_profile_page
# def get_company_code(file_path):
# line_text = ''
# # 我们从PDF中提取页面,page_numbers=[4,5,6]
# for pagenum, page in enumerate(extract_pages(file_path)):
# print(f'页码是{get_code_page(file_path)+1}')
# if pagenum > 1 and pagenum != get_code_page(file_path)+1:
# break
# # 找到所有的元素
# #print(pagenum)
# page_elements = [(element.y1, element) for element in page._objs]
# # 查找组成页面的元素
# # for i,component in enumerate(page_elements):
# # # 提取页面布局的元素
# # element = component[1]
# # # 检查该元素是否为文本元素
# # if isinstance(element, LTTextBoxHorizontal):
# # # 检查文本是否出现在表中
# # line_text += element.get_text()
# for _, element in page_elements:
# if isinstance(element, LTTextBoxHorizontal):
# # 提取文本并添加到 line_text
# line_text += element.get_text()
# return llm_service_code(line_text)
# Prompt template asking qwen for the company's full Chinese name only; the
# model is told to answer "-" when it cannot find one. <user_prompt> is
# replaced by the report text.
_COMPANY_NAME_PROMPT = '''
从以下数据报告中提取公司全称只需要提取中文公司全称不要增加其他内容如果提取不到公司全称请返回-
<数据报告>
<user_prompt>
</数据报告>
'''

# Prompt template asking qwen for 6-digit stock code(s), comma separated, or
# "-" when none is found. <user_prompt> is replaced by the report text.
_STOCK_CODE_PROMPT = '''
从以下数据报告中提取6位数字的股票代码只需要提取股票代码如果有多个则以','隔开不要增加其他内容如果提取不到股票代码请返回-,不要返回其他任何内容
<数据报告>
<user_prompt>
</数据报告>
'''


def _call_qwen(prompt_template, user_prompt):
    """Fill *prompt_template* with *user_prompt* and send it to qwen-plus.

    Args:
        prompt_template: Template containing a ``<user_prompt>`` placeholder.
        user_prompt: Text substituted into the placeholder.

    Returns:
        The model's reply text, or the sentinel "llm_error" when DashScope
        answers with a non-200 status (request details are printed).
        Exceptions raised by the SDK itself propagate to the caller.
    """
    prompt = prompt_template.replace('<user_prompt>', user_prompt)
    # NOTE(review): a random seed combined with temperature 0.85 makes repeated
    # extractions of the same report non-deterministic; consider a fixed seed
    # and a lower temperature for this extraction task.
    response = dashscope.Generation.call(
        model='qwen-plus',
        prompt=prompt,
        seed=random.randint(1, 10000),
        top_p=0.8,
        result_format='message',
        enable_search=False,
        max_tokens=1500,
        temperature=0.85,
        repetition_penalty=1.0,
    )
    if response.status_code != HTTPStatus.OK:
        print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
            response.request_id, response.status_code,
            response.code, response.message
        ))
        return "llm_error"
    return response['output']['choices'][0]['message']['content']


def llm_service(user_prompt):
    """Extract the company's full Chinese name from report text *user_prompt*.

    Returns:
        The model's answer ("-" is requested when nothing is found), or
        "llm_error" if the request failed.
    """
    return _call_qwen(_COMPANY_NAME_PROMPT, user_prompt)


def llm_service_code(user_prompt):
    """Extract 6-digit stock code(s) from report text *user_prompt*.

    Returns:
        The model's answer (codes separated by ',', or "-" when none are
        found), or "llm_error" if the request failed.
    """
    return _call_qwen(_STOCK_CODE_PROMPT, user_prompt)
def update_company_name(file_id, company_name, company_code, cursor, conn):
    """Store the extracted name and stock code on one report_check row and commit.

    Args:
        file_id: ``report_check.id`` of the row to update.
        company_name: Value written to ``c_name``.
        company_code: Value written to ``c_code``.
        cursor: Open DB-API cursor used to run the UPDATE.
        conn: Connection that owns *cursor*; committed after the UPDATE.
    """
    # Values are bound as parameters rather than formatted into the SQL: they
    # come from LLM output over PDF text, so a quote character would break the
    # statement (and string interpolation allows SQL injection).
    update_sql = '''
UPDATE report_check
SET c_name = %s, c_code = %s
WHERE id = %s
'''
    cursor.execute(update_sql, (company_name, company_code, file_id))
    conn.commit()
# Accepted stock-code answer: one or more 6-digit codes separated by commas.
# Used with fullmatch so a trailing newline is rejected (re's `$` would allow it).
STOCK_CODE_PATTERN = re.compile(r'\d{6}(?:,\d{6})*')


def first_line(text):
    """Return the first non-blank line of *text* with surrounding whitespace removed.

    Only the first line of an LLM answer is kept; anything after it is
    discarded. Returns '' for empty or whitespace-only input.
    """
    lines = text.strip().splitlines()
    return lines[0].strip() if lines else ''


def main():
    """Fill c_name/c_code for every report_check row whose c_code is still NULL."""
    conn = mysql.connector.connect(
        host=MYSQL_HOST,
        user=MYSQL_USER,
        password=MYSQL_PASSWORD,
        database=MYSQL_DB,
    )
    cursor = conn.cursor()
    try:
        data_query = '''
    SELECT id,file_path FROM report_check where c_code is null
    '''
        cursor.execute(data_query)
        for file_id, relative_path in cursor.fetchall():
            try:
                # Production location. The test environment used
                # f'/root/pdf_parser/pdf/{relative_path}' with '/upload/file/' removed.
                file_path = f'/usr/local/zhanglei/financial{relative_path}'
                print(f'财报{file_id}开始解析')
                company_name = first_line(get_company_name(file_path))
                company_code = first_line(get_company_code(file_path))
                # If either LLM call failed, leave the row untouched: writing the
                # "llm_error" sentinel would corrupt the data, and keeping c_code
                # NULL lets the next run pick the report up again.
                if company_name == "llm_error" or company_code == "llm_error":
                    print(f'财报{file_id}大模型调用失败,跳过')
                    continue
                if not STOCK_CODE_PATTERN.fullmatch(company_code):
                    company_code = '-'
                update_company_name(file_id, company_name, company_code, cursor, conn)
            except Exception as e:
                # Best effort per report: log and move on to the next one.
                print('财报解析失败', e)
    finally:
        cursor.close()
        conn.close()


if __name__ == '__main__':
    main()