pdf_code/zzb_data/main.py

1173 lines
56 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import camelot
import re
from multiprocessing import Pool
import os, time, random
import json
from config import MILVUS_CLIENT,MYSQL_HOST,MYSQL_USER,MYSQL_PASSWORD,MYSQL_DB,MEASURE_COUNT,MYSQL_HOST_APP,MYSQL_USER_APP,MYSQL_PASSWORD_APP,MYSQL_DB_APP
from datetime import datetime
# 读取PDF
import PyPDF2
# 分析PDF的layout提取文本
from pdfminer.high_level import extract_pages
from pdfminer.layout import LTTextBoxHorizontal
import pdfplumber
import mysql.connector
import utils
from pymilvus import MilvusClient
import llm_service
import db_service
import pdf_title
import numpy as np
from multiprocessing import Process
from config import REDIS_HOST,REDIS_PORT,REDIS_PASSWORD
import redis
# Known issues (English summary of the notes kept in the string below):
# 1. When text and a table share a page and the text precedes the table, the text is not extracted.
# 2. The LLM extraction confuses 2023 operating revenue with main-business revenue; per-product
#    revenue and change ratios are extracted incorrectly.
# 3. Indicators located in tables get extracted as text.
# 4. The LLM groups semantically unrelated indicators together; vector similarity is being
#    considered to decide that instead.
'''
已知发现问题:
1.表格和文本提取错误,表格和文本内容在同一页,文本在前表格在后的,文本数据提取不出来
2.大模型抽取错抽取2023年营业收入主营业务收入、分产品的营业收入、变动比例被错误抽取
3.表格中的指标被抽取成文本中
4.大模型抽取指标时,语义完全不同的指标被放一起,考虑用向量相似度来判断
'''
# Processing pipeline:
# 1. get_table_range (multi-process) collects all tables plus their surrounding text into one list.
# 2. A single process merges tables split across pages into a new list of table objects.
# 3. The merged table list is fed (multi-process) into the original indicator-extraction flow.
# Regex alternation of indicator keywords; get_table_range keeps a table only if its text
# matches at least one of them.
STR_PATTERN = '营业收入|净利润|变动比例|损益|现金流量净额|现金净流量|现金流|每股收益|总资产|资产总额|收益率|货币资金|应收账款|存货|固定资产|在建工程|商誉|短期借款|应付账款|合同负债|长期借款|营业成本|销售费用|管理费用|财务费用|研发费用|研发投入|计入当期损益的政府补助'
# Blacklist alternation; get_table_range drops a table if its text matches any of these.
PATTERN = '品牌类型|分门店|销售渠道|行业名称|产品名称|地区名称|子公司名称|业绩快报|调整情况说明|调整年初资产负债表|主要子公司|分部|母公司资产负债表|显示服务|渠道|商品类型|合同分类|会计政策变更|地区分类|研发项目|分类产品|表头不合规的表格|内部控制评价|关联方|国内地区|国外地区|销售区域|存货库龄|外币|逾期60天以上|欧元|英镑|美元|日元'
# get_table_range drops a table in which this keyword occurs 5 or more times.
MUILT_PATTERN = '调整前'
#unit_pattern = re.compile(r'单位[|:]?(百万元|千万元|亿元|万元|千元|元)')
# Unit detection: "单位"/"单元"/"人民币", then at most 6 characters, then a currency unit
# (captured as group 2). The colon is intentionally not required; only the distance is bounded.
unit_pattern = re.compile(r'(单位|单元|人民币).{0,6}?(百万元|千万元|亿元|万元|千元|元).{0,3}?')#修改单位匹配规则,不限制冒号,只限制距离
# Left-hand header text of a value cell.
def get_col_num_info(array,row_num,col_num,x,y):
    """Return the concatenated left-hand header cells of row ``x``.

    The cells ``array[x][0:col_num]`` are stringified and joined. Cells longer than
    50 characters are skipped (presumably running text rather than a header). Every
    '%' is removed from the result. ``row_num`` and ``y`` are unused; they keep the
    signature parallel to ``get_row_num_info``.
    """
    header_cells = (str(array[x][col]) for col in range(col_num))
    joined = ''.join(cell for cell in header_cells if len(cell) <= 50)
    return joined.replace('%', '')
# Top header text of a value cell.
def get_row_num_info(array,row_num,col_num,x,y):
    """Return the concatenated header cells above the value area in column ``y``.

    The cells ``array[0:row_num][y]`` are stringified and joined. Cells longer than
    50 characters are skipped. ``col_num`` and ``x`` are unused; they keep the
    signature parallel to ``get_col_num_info``.
    """
    column_cells = (str(array[r][y]) for r in range(row_num))
    return ''.join(filter(lambda cell: len(cell) <= 50, column_cells))
def table_converter(table):
    """Serialize a table (a sequence of rows of cells) into plain text.

    Cells of a row are joined with ',' and rows are separated by '\\n'. Newlines
    inside a cell (wrapped text) become spaces, and ``None`` cells become the
    string 'None'.

    The original version never inserted the row separator but still dropped
    the "last newline" with ``[:-1]``. That concatenated all rows and cut off
    the final data character. Joining rows with '\\n' delivers the intended
    layout without a trailing newline.
    """
    def clean_cell(cell):
        if cell is None:
            return 'None'
        # Remove line breaks introduced by wrapped text inside a cell.
        return cell.replace('\n', ' ')

    return '\n'.join(','.join(clean_cell(cell) for cell in row) for row in table)
def safe_process_array(func, arr):
    """Apply ``func`` to ``arr`` and return the result.

    If ``func`` raises, the error is printed and the untouched input ``arr`` is
    returned, so a failing fix-up step does not stop the rest of the pipeline.
    """
    try:
        processed = func(arr)
    except Exception as err:
        print(f"这个函数出现了报错{func.__name__}: {err}")
        # Hand back the original array so later steps can still run.
        return arr
    return processed
# Fix for third-quarter-report balance sheets that camelot collapses into one column.
def process_array(arr, years=('2022', '2023', '2024'), keyword='项目'):
    """Split a table whose rows were merged into a single space-separated cell.

    It applies only when the first row has exactly one cell and that cell contains
    ``keyword`` and at least one of ``years``. In that case:

    * Spaces next to 年/月 and left of 日 are removed from the header, so a date such
      as "2024 年 9 月 30 日" stays one token.
    * The header is then split on whitespace into columns.
    * Every other single-cell row is split on whitespace, truncated or padded with ''
      to the header width.

    Args:
        arr: table as a list of rows or a 2-D numpy array; numpy input is converted
            with ``tolist()``.
        years: year strings of which at least one must appear in the header
            (a tuple instead of a list avoids the shared-mutable-default pitfall).
        keyword: text that must appear in the header.

    Returns:
        The table as a list of rows; rows are modified in place when the fix applies.
    """
    def ensure_columns(row, num_columns):
        # Pad ``row`` in place so it has at least ``num_columns`` cells.
        while len(row) < num_columns:
            row.append('')

    def is_valid_header(header, years, keyword):
        header_text = header.lower()  # lower-case for more robust matching
        return any(year in header_text for year in years) and keyword in header_text

    def clean_text(text):
        # Remove spaces around "年" and "月".
        text = re.sub(r'\s*(年|月)\s*', r'\1', text)
        # Remove spaces left of "日" but keep the character. The original replacement
        # '' also deleted "日", contrary to its stated intent.
        text = re.sub(r'\s*日', '日', text)
        return text

    arr = arr.tolist() if isinstance(arr, np.ndarray) else arr
    if not arr:
        return arr
    if len(arr[0]) == 1 and is_valid_header(arr[0][0], years, keyword):
        parts = clean_text(arr[0][0]).split()
        ensure_columns(arr[0], len(parts))
        for i, part in enumerate(parts):
            arr[0][i] = part
        header_columns = len(arr[0])
        for row in arr[1:]:
            if len(row) == 1:
                # Never produce more cells than the header has.
                parts = row[0].split()[:header_columns]
                ensure_columns(row, header_columns)
                for j in range(header_columns):
                    # Missing trailing values become empty cells.
                    row[j] = parts[j] if j < len(parts) else ''
    return arr
# Third-quarter reports: tell apart the two "上年同期" (same period last year) columns that belong
# to "本报告期" (this reporting period) and to "年初至报告期末" (year to date).
def process_array_with_annual_comparison(arr, keywords=('本报告期', '年初至报告期末', '上年同期')):
    """Rename the first half of the duplicated "上年同期" header cells.

    It applies only when every keyword is an exact cell of the first row. Then the
    first ``n // 2`` of the ``n`` cells equal to '上年同期' are renamed to
    '三季报中无需识别的上年同期', so later stages ignore the this-period
    comparison columns. Nothing is renamed when there is only one such cell.

    Args:
        arr: table as a list of rows or a 2-D numpy array (converted with ``tolist()``).
        keywords: header cells that must all be present (exact match). A tuple
            replaces the former mutable list default.

    Returns:
        The table as a list of rows (first row modified in place when the rule applies).
    """
    def contains_all_keywords(header, keywords):
        return all(keyword in header for keyword in keywords)

    def split_and_replace_occurrences(header, target, replacement):
        # Positions of all cells equal to ``target``.
        indices = [i for i, x in enumerate(header) if x == target]
        if len(indices) > 1:
            for i in indices[:len(indices) // 2]:
                header[i] = replacement
        return header

    arr = arr.tolist() if isinstance(arr, np.ndarray) else arr
    if len(arr) > 0 and len(arr[0]) > 0:
        first_row = arr[0]
        if contains_all_keywords(first_row, keywords):
            arr[0] = split_and_replace_occurrences(first_row, '上年同期', '三季报中无需识别的上年同期')
    return arr
# Special handling of non-recurring gains and losses in third-quarter reports.
def process_array_with_grants(arr, keywords=('本报告期', '年初至报告期'), target='计入当期损益的政府补助', replacement='非经常性损益'):
    """Rename the '合计' (total) row of a non-recurring gains/losses table.

    The rename happens when both conditions hold:

    * every keyword appears as a substring of some cell in the first row;
    * some first-column cell contains ``target``, the government-grant line that
      identifies the table.

    Then every first-column cell exactly equal to '合计' becomes ``replacement``.

    Args:
        arr: table as a list of rows or a 2-D numpy array (converted with ``tolist()``).
        keywords: substrings that must each occur somewhere in the first row. A tuple
            replaces the former mutable list default.
        target: substring identifying the table in the first column.
        replacement: new label for the '合计' row.

    Returns:
        The table as a list of rows (modified in place when the rule applies).
    """
    def contains_all_keywords(header, keywords):
        # Substring match: each keyword must be found inside at least one header cell.
        return all(any(keyword in str(cell) for cell in header) for keyword in keywords)

    def contains_target_in_first_column(arr, target):
        return any(target in str(item[0]) for item in arr)

    def replace_in_first_column(arr, target, replacement):
        for row in arr:
            if row[0] == target:
                row[0] = replacement
        return arr

    arr = arr.tolist() if isinstance(arr, np.ndarray) else arr
    if len(arr) > 0 and len(arr[0]) > 0:
        if contains_all_keywords(arr[0], keywords) and contains_target_in_first_column(arr, target):
            arr = replace_in_first_column(arr, '合计', replacement)
    return arr
def get_table_range(file_path, file_id, pages, tables_range):
    """Extract, normalize and persist the tables found on ``pages`` of a PDF.

    For every table camelot finds:

    1. The fix-ups ``process_array``, ``process_array_with_annual_comparison`` and
       ``process_array_with_grants`` are applied, each wrapped in ``safe_process_array``.
    2. A series of heuristics repairs cells that camelot merged or duplicated: split
       "2022…2021" style header cells, split duplicated cells on their first space,
       unfold side-by-side duplicated tables, and split numbers glued together.
    3. A table is kept only if all three hold:
       * its text matches ``STR_PATTERN``;
       * it does not match ``PATTERN``;
       * it contains ``MUILT_PATTERN`` fewer than 5 times.

       For a kept table, its bbox is appended to ``tables_range[page_num]`` and its
       data is stored as 'parse_table' content through
       ``db_service.insert_pdf_parse_process``.

    Afterwards ``get_text_content`` is called for the same pages with the filled
    ``tables_range``, to collect the text around the tables.

    Args:
        file_path: path of the PDF file.
        file_id: report id written into every stored row.
        pages: page range passed to ``camelot.read_pdf``; ``get_text_content`` expects
            the "start-end" form.
        tables_range: dict page number -> list of
            ``{'top', 'buttom', 'table_index', 'page_num'}`` dicts; updated in place.

    Any exception while reading or processing the tables is printed and aborts the
    remaining tables of this range; ``get_text_content`` still runs.
    NOTE(review): if ``get_text_content`` raises, the connections below are never
    closed; a try/finally would fix that.
    """
    print('Run task %s (%s)...' % (f'解析表格{pages}', os.getpid()))
    start = time.time()
    # Connection to the main database; passed on to get_text_content.
    conn = mysql.connector.connect(
        host= MYSQL_HOST,
        user= MYSQL_USER,
        password= MYSQL_PASSWORD,
        database= MYSQL_DB
    )
    # Cursor used to execute SQL statements.
    cursor = conn.cursor(buffered=True)
    # Connection to the app database; parse results are written there.
    conn_app = mysql.connector.connect(
        host= MYSQL_HOST_APP,
        user= MYSQL_USER_APP,
        password= MYSQL_PASSWORD_APP,
        database= MYSQL_DB_APP
    )
    cursor_app = conn_app.cursor(buffered=True)
    redis_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=6)
    try:
        tables = camelot.read_pdf(file_path, pages=pages, strip_text=',\n', copy_text=['v','h'],shift_text = ['l'])
        for t in tables:
            # Index 3 of the bbox is used as the table's top edge, index 1 as its bottom edge.
            top = t._bbox[3]
            buttom = t._bbox[1]
            page_num = int(t.page)
            table_index = int(t.order)
            arr = np.array(t.data)
            arr = safe_process_array(process_array, arr) # balance sheets collapsed into one column
            arr = safe_process_array(process_array_with_annual_comparison, arr) # several "上年同期" columns
            arr = safe_process_array(process_array_with_grants, arr) # Q3 non-recurring gains/losses
            # 6-column header "项目", '' , "<2022…2021>": split the merged year cell into two halves.
            if len(arr[0]) == 6 and arr[0][0]== "项目" and arr[0][1] == '' and '2022' in arr[0][2] and '2021' in arr[0][2]:
                remaining_value = arr[0][2]#initial_value.replace("项目", "", 1)
                split_index = len(remaining_value) // 2
                arr[0][1] = remaining_value[:split_index]
                arr[0][2] = remaining_value[split_index:]
            # Same repair for a 3-column header containing both 2022 and 2023.
            if len(arr[0]) == 3 and arr[0][0]== "项目" and arr[0][1] == '' and '2022' in arr[0][2] and '2023' in arr[0][2]:
                remaining_value = arr[0][2]#initial_value.replace("项目", "", 1)
                split_index = len(remaining_value) // 2
                arr[0][1] = remaining_value[:split_index]
                arr[0][2] = remaining_value[split_index:]
            # 5-column header with two identical "同比" (YoY) cells: prefix them with the preceding cell.
            if len(arr[0]) == 5 and arr[0][0]== "项目" and arr[0][2] == arr[0][4] and '同比' in arr[0][2] and arr[0][1] != arr[0][3]:
                arr[0][2] = arr[0][1]+arr[0][2]
                arr[0][4] = arr[0][3]+arr[0][4]
            # 4 identical header cells containing both "项目" and "附注": rebuild as
            # [项目, 附注, <first half>, <second half>]. The generator variable `value` is unused,
            # so the second all() only checks arr[0][0].
            if len(arr[0]) == 4 and all(value == arr[0][0] for value in arr[0]) and all("项目" in arr[0][0] and "附注" in arr[0][0] for value in arr[0]):
                initial_value = arr[0][0].replace(' ','')
                project_value = "项目"
                note_value = "附注"
                remaining_value = initial_value.replace("项目", "", 1).replace("附注", "", 1)
                split_index = len(remaining_value) // 2
                first_half = remaining_value[:split_index]
                second_half = remaining_value[split_index:]
                # Check where "项目" sits in the original value.
                if "项目" in initial_value and first_half in initial_value and second_half in initial_value :
                    project_index = initial_value.index("项目")
                    year_index = initial_value.index(first_half)
                    year_index_2 = initial_value.index(second_half)
                    # If "项目" lies between the two halves, swap them.
                    if project_index > year_index and project_index < year_index_2:
                        first_half, second_half = second_half, first_half
                arr[0] = [project_value, note_value, first_half, second_half]
            # 3 identical header cells containing "项目": rebuild as [项目, <first half>, <second half>].
            # NOTE(review): the literal in .replace() below looks like it lost full-width
            # characters in this copy; verify against the original source.
            if len(arr[0]) == 3 and all(value == arr[0][0] for value in arr[0]) and all("项目" in arr[0][0] for value in arr[0]):
                initial_value = arr[0][0]
                project_value = "项目"
                remaining_value = initial_value.replace("项目", "", 1).replace("1-9 月) 1-9 月)","")
                split_index = len(remaining_value) // 2
                first_half = remaining_value[:split_index]
                second_half = remaining_value[split_index:]
                arr[0] = [project_value, first_half, second_half]
            # Spaces preserved by camelot are used here to split cells manually: when two adjacent
            # cells are identical and contain a space, split on the first space.
            for line in arr:
                if not line[0].replace('.', '', 1).isdigit() and any(line[i] == line[i+1] and ' ' in line[i] for i in range(1, len(line) - 1)):
                    for i in range(1, len(line) - 1):
                        if line[i] == line[i+1] and ' ' in line[i]:
                            split_value = line[i]
                            split_parts = split_value.split(' ', 1) # split on the first space
                            if len(split_parts) == 2: # make sure a split actually happened
                                first_half, second_half = split_parts
                                line[i] = first_half
                                line[i+1] = second_half
                            break
            # After this point no cell contains a space; arr is a numpy array from here on.
            arr = np.char.replace(arr, ' ', '')
            # Guard against two identical tables placed side by side: if the header's two halves
            # are equal, stack the right half's columns under the left half's.
            first_row = arr[0]
            if len(first_row) % 2 == 0 and all(cell.strip() for cell in first_row):
                mid_point = len(first_row) // 2
                if np.array_equal(first_row[:mid_point], first_row[mid_point:]):
                    new_arr = []
                    for i in range(mid_point):
                        new_row = np.concatenate([arr[:, i], arr[:, i + mid_point]])
                        new_arr.append(new_row)
                    arr = np.array(new_arr).T
            # Mark tables with an invalid header: every non-empty header cell is the same
            # blacklisted value. The marker text is listed in PATTERN, so such tables are dropped.
            # NOTE(review): first_row still refers to the header from before the side-by-side
            # unfolding above, so len(first_row) may not match the new width. Presumably the
            # numpy string dtype could also truncate the marker text.
            try:
                invalid_headers = ["上年年末余额"]
                non_empty_values = [value for value in first_row if value]
                if len(set(non_empty_values)) == 1 and non_empty_values[0] in invalid_headers:
                    arr[0] = ["表头不合规的表格"] * len(first_row)
            except Exception as e:
                print(f'在识别表头是否合规时出现了报错:{e}')
            # camelot sometimes puts "2023年度2022年度" into two identical adjacent header
            # cells: split the first such pair in half.
            if not arr[0][0].replace('.', '', 1).isdigit() and any(arr[0][i] == arr[0][i+1] and '2023' in arr[0][i] and '2022' in arr[0][i] for i in range(1, len(arr[0])-1)):
                for i in range(1, len(arr[0])-1):
                    if arr[0][i] == arr[0][i+1] and '2023' in arr[0][i] and '2022' in arr[0][i]:
                        split_value = arr[0][i]
                        split_index = len(split_value) // 2
                        first_half = split_value[:split_index]
                        second_half = split_value[split_index:]
                        arr[0][i] = first_half
                        arr[0][i+1] = second_half
                        break
            # Normalize the first adjacent header pair where 2023 and 2022 both appear.
            if not arr[0][0].replace('.', '', 1).isdigit():
                for i in range(1, len(arr[0]) - 1):
                    # Adjacent cells containing '2023' then '2022' (with '2023' at an earlier offset).
                    if (('2023' in arr[0][i] and '2022' in arr[0][i+1]) and
                        (arr[0][i].index('2023') < arr[0][i+1].index('2022'))):
                        arr[0][i] = '2023年'
                        arr[0][i+1] = '2022年'
                        break
            # Repair values that camelot probably glued together.
            for i, row in enumerate(arr):
                if len(row) >= 4:
                    # First column not numeric, 2nd and 4th empty, 3rd has three decimal points:
                    # three numbers with 2 decimals each were read as one cell.
                    if (not row[0].replace('.', '', 1).isdigit()) and (row[1] == '') and (len(row[2].split('.')) == 4 and len(row[2].rsplit('.', 1)[-1]) == 2) and (row[3] == ''):
                        split_values = row[2].split('.')
                        # Make sure it splits into exactly three values.
                        if len(split_values) == 4:
                            new_value1 = f"{split_values[0]}.{split_values[1][:2]}"
                            new_value2 = f"{split_values[1][2:]}.{split_values[2][:2]}"
                            new_value3 = f"{split_values[2][2:]}.{split_values[3]}"
                            row[1] = new_value1
                            row[2] = new_value2
                            row[3] = new_value3
                # First column not numeric, 2nd and 4th empty, 3rd and 5th each have two decimal
                # points: two pairs of numbers were glued together.
                if len(row) >= 5 and (not row[0].replace('.', '', 1).isdigit()) and (row[1] == '') and (len(row[2].split('.')) == 3) and (row[3] == '') and (len(row[4].split('.')) == 3) and len(row[2].rsplit('.', 1)[-1]) == 2 and len(row[4].rsplit('.', 1)[-1]) == 2:
                    split_value_3 = row[2].split('.')
                    split_value_5 = row[4].split('.')
                    if len(split_value_3) == 3:
                        new_value2 = f"{split_value_3[0]}.{split_value_3[1][:2]}"
                        new_value3 = f"{split_value_3[1][2:]}.{split_value_3[2]}"
                    if len(split_value_5) == 3:
                        new_value4 = f"{split_value_5[0]}.{split_value_5[1][:2]}"
                        new_value5 = f"{split_value_5[1][2:]}.{split_value_5[2]}"
                    row[1] = new_value2
                    row[2] = new_value3
                    row[3] = new_value4
                    row[4] = new_value5
                # First column not numeric, 2nd empty, 3rd has two decimal points, 4th is a
                # normal number: two numbers were glued into the 3rd cell.
                if len(row) >= 4 and (not row[0].replace('.', '', 1).isdigit()) and (row[1] == '') and (len(row[2].split('.')) == 3) and len(row[2].rsplit('.', 1)[-1]) == 2 and (row[3].replace('-', '', 1).replace('.', '', 1).isdigit()):
                    split_values = row[2].split('.')
                    if len(split_values) == 3:
                        new_value2 = f"{split_values[0]}.{split_values[1][:2]}"
                        new_value3 = f"{split_values[1][2:]}.{split_values[2]}"
                        row[1] = new_value2
                        row[2] = new_value3
                # First column not numeric and an empty cell is followed by a cell containing one
                # '%' that is not at the end: split after the '%' (first occurrence only).
                if not row[0].replace('.', '', 1).isdigit():
                    for i in range(1, len(row) - 1):
                        if row[i] == '' and '%' in row[i + 1] and len(row[i + 1].split('%')) == 2:
                            split_values = row[i + 1].split('%')
                            new_value1 = f"{split_values[0]}%"
                            new_value2 = f"{split_values[1]}"
                            row[i] = new_value1
                            row[i + 1] = new_value2
                            break
            new_data = arr.tolist() # stored in the database below
            new_data = utils.check_black_table_list(new_data)
            rows, cols = arr.shape
            if rows == 1 and cols == 1:
                continue
            arr_str = ''.join([''.join(map(str, row)) for row in arr])
            # Keep only tables that can contain indicators to extract.
            matches = re.findall(STR_PATTERN, arr_str)
            pattern = re.findall(PATTERN,arr_str)
            muilt_pattern = re.findall(MUILT_PATTERN,arr_str)
            if len(matches) > 0 and len(pattern) == 0 and len(muilt_pattern)<5:
                if not tables_range.get(page_num):
                    tables_range[page_num] = []
                tables_range[page_num].append({
                    'top' : top,
                    'buttom' : buttom,
                    'table_index' : table_index,
                    'page_num' : page_num,
                })
                # sort_num orders content by page, then top-down within the page.
                db_service.insert_pdf_parse_process({
                    'file_id': file_id,
                    'page_num' : page_num,
                    'page_count' : 100,
                    'type' : 'parse_table',
                    'content':{
                        'top' : top,
                        'buttom' : buttom,
                        'page_num' : page_num,
                        'table_index' : table_index,
                        "type" : "table",
                        "data" : new_data,
                        'sort_num' : page_num*1000 - top
                    }},conn_app,cursor_app)
    except Exception as e:
        print(f'camelot解析表格时出现了{e}')
    # Collect the text of the same pages, using the table bboxes gathered above.
    get_text_content(file_path, file_id, tables_range, pages, conn, cursor, redis_client, conn_app, cursor_app)
    cursor.close()
    conn.close()
    cursor_app.close()
    conn_app.close()
    redis_client.close()
    end = time.time()
    print('Task %s runs %0.2f seconds.' % (f'解析表格{pages}', (end - start)))
def text_in_table(top, tables_range, page_num):
    """Return True if the y coordinate ``top`` lies strictly inside a table on ``page_num``.

    ``tables_range`` maps a page number to a list of dicts with 'top' and 'buttom'
    keys (the vertical extent of each table). Both bounds are exclusive.
    """
    page_tables = tables_range.get(page_num)
    if not page_tables:
        return False
    return any(table['buttom'] < top < table['top'] for table in page_tables)
def get_text_type(text: str):
    """Classify a text box as 'page_header', 'page_footer' or 'text'.

    All whitespace is removed first. Then, in order:

    * Text containing "年度报告" or "季度报告" is a 'page_header'.
    * A bare page number such as "12" or "12/200" is a 'page_footer'.
    * Any text shorter than 20 characters is a 'page_footer'.
    * Everything else is 'text'.

    NOTE(review): ``endswith('')`` is always true, so every short text counts as a
    footer. The literal probably lost a full-width character in this copy of the
    file; confirm the intended suffix.
    """
    compact = re.sub(r"\s", "", text)
    if re.search(r'年度报告|季度报告', compact.strip()):
        return 'page_header'
    is_page_number = re.compile(r'^\d+(/\d+)?$').match(compact.strip())
    if is_page_number:
        return 'page_footer'
    if len(compact) < 20 and compact.endswith(''):
        return 'page_footer'
    return 'text'
# Read the text of a PDF page range, excluding text that belongs to tables.
def get_text_content(pdf_path,file_id,tables_range,pages,conn,cursor,redis_client, conn_app, cursor_app):
    """Extract and persist the text boxes of the pages in ``pages``.

    For every page in the inclusive 1-based range ``pages`` ("start-end"), the function:

    * Increments the Redis counter ``parsed_page_count_<file_id>``.
    * Collects every ``LTTextBoxHorizontal`` (whitespace removed) and stores the list
      with ``db_service.batch_insert_page_text_nocheck`` (when ``pdf_text_info`` had
      no rows for the file) or ``db_service.batch_insert_page_text``.
    * Handles text 5-150 units above a table recorded in ``tables_range`` (probably its
      title/unit line). It records parent-company pages, blacklisted tables and
      measures, and the table unit, and stores the text itself as 'parse_table' content.
    * Handles text near the page bottom (bbox bottom < 150) that is outside every table.
      It records the same kind of information against table 1 of the next page, for
      titles printed at a page end whose table starts on the following page.

    Args:
        pdf_path: path of the PDF.
        file_id: report id (interpolated into SQL and written into inserted rows).
        tables_range: dict page number -> list of table bbox dicts built by
            get_table_range. An entry gets ``unit_flag = True`` when a unit is found above it.
        pages: "start-end" page range string.
        conn, cursor: connection/cursor for the main database.
        redis_client: Redis client holding the progress counter.
        conn_app, cursor_app: connection/cursor for the app database.

    Returns:
        None. An error while processing one page is printed and that page is skipped;
        errors in the initial queries propagate.
    """
    page_start = pages.split('-')[0]
    page_end = pages.split('-')[1]
    print(f'pages的值为{pages}')
    # NOTE(review): file_id is interpolated directly into SQL; a parameterized query would be safer.
    select_year_select = f"""select report_type,year from report_check where id = {file_id}"""
    cursor.execute(select_year_select)
    record_select = cursor.fetchall()
    report_type = record_select[0][0]  # not used below
    report_year = record_select[0][1]
    # Decide once which helper stores the page texts (first run vs. re-run).
    select_pdf_text_check = f"""select count(1) from pdf_text_info where file_id = {file_id}"""
    cursor.execute(select_pdf_text_check)
    is_empty = cursor.fetchone()[0] == 0
    # Keyword lists configured per report year: title_list is checked against text above a
    # table, button_list against text at the bottom of a page.
    # NOTE(review): fetchone() returns None when no row matches report_year -> TypeError below.
    query = "SELECT title_list,button_list FROM table_title_list WHERE report_year = %s"
    cursor_dict = conn.cursor(dictionary=True)
    cursor_dict.execute(query, (report_year,))
    result = cursor_dict.fetchone()
    title_list = result['title_list']
    button_list = result['button_list']
    # extract_pages yields every page of the PDF; pages outside the range are skipped below.
    for pagenum, page in enumerate(extract_pages(pdf_path)):
        try:
            if pagenum+1 < int(page_start) or pagenum+1 > int(page_end):
                continue
            # Update the Redis counter of parsed pages.
            if not redis_client.exists(f'parsed_page_count_{file_id}'):
                redis_client.set(f'parsed_page_count_{file_id}', 0)
            redis_client.incr(f'parsed_page_count_{file_id}')
            # All layout objects of the page, paired with their top y coordinate.
            page_elements = [(element.y1, element) for element in page._objs]
            # Texts of this page; stored after the element loop.
            line_texts = []
            for i,component in enumerate(page_elements):
                element = component[1]
                # Only horizontal text boxes are considered.
                if isinstance(element, LTTextBoxHorizontal):
                    line_text = element.get_text().replace('\n','')
                    line_text = re.sub(r"\s", "", line_text)
                    # Every text box is kept for the text-error check.
                    line_texts.append(line_text)
                    element_top = element.bbox[3]
                    element_buttom = element.bbox[1]
                    # Texts containing these words qualify even when they lie inside a table.
                    out_table_list = ['母公司现金流量表','母公司利润表','母公司资产负债表','子公司']
                    # Is this text just above one of the tables on this page?
                    if tables_range.get(pagenum+1):
                        # NOTE: `range` shadows the builtin inside this loop.
                        for range in tables_range[pagenum+1]:
                            if element_top < range['top'] and element_top > range['buttom']:
                                # Text inside this table's bbox is ignored here (original note:
                                # parent-company tables keep being recognized inside the previous table).
                                pass
                            else:
                                # Text 5-150 units above this table's top, outside every table (or
                                # naming a parent-company/subsidiary statement): title/unit area.
                                if element_top - range['top'] < 150 and element_top - range['top'] > 5 and (not text_in_table(element_top, tables_range, pagenum+1) or any(word in line_text for word in out_table_list)):
                                    text_type = get_text_type(line_text)
                                    if text_type in ('page_header','page_footer'):
                                        break
                                    if pagenum ==44:
                                        print(f'line_text在第44页的值有{line_text}')
                                    # Per the original note, this flag affects a whole page and removes
                                    # many correct tables. Record the page to be filtered out.
                                    if len(re.findall('母公司|现金流量表补充', line_text)) > 0 :
                                        db_service.insert_measure_parser_info({
                                            'file_id': file_id,
                                            'content': pagenum+1,
                                            'type': 'parent_com',
                                        },conn_app,cursor_app)
                                    # The small area above each table holds its title and unit.
                                    table_info = {}
                                    # Title blacklisted for this report year: mark this table.
                                    if utils.check_table_title_black_list(line_text,title_list):
                                        db_service.insert_measure_parser_info({
                                            'file_id': file_id,
                                            'content': f"{range['page_num']}_{range['table_index']}",
                                            'type': 'table_index',
                                        },conn_app,cursor_app)
                                    if utils.check_table_title_black_list_measure(line_text):
                                        db_service.insert_measure_parser_info_measure({
                                            'file_id': file_id,
                                            'content': f"{range['page_num']}_{range['table_index']}",
                                            'type': 'measure_index',
                                        },conn_app,cursor_app,line_text)
                                    # A unit declaration ("单位:万元" etc.) applies to this table.
                                    if re.findall(unit_pattern, line_text):
                                        range['unit_flag'] = True
                                        table_info = get_table_unit_info(file_id,line_text,range['page_num'],range['table_index'])
                                        db_service.insert_table_unit_info_v1(table_info,conn,cursor)
                                    else:
                                        # Currently a no-op.
                                        if len(line_text) <= 5 or len(re.findall('单位|适用', line_text)) > 0 :
                                            pass
                                    # Store the text itself, ordered right before its table.
                                    if utils.check_line_text(line_text):
                                        db_service.insert_pdf_parse_process({
                                            'file_id': file_id,
                                            'page_num' : pagenum+1,
                                            'page_count' : 100,
                                            'type' : 'parse_table',
                                            'content':{
                                                'top' : element_top,
                                                'buttom' : element_buttom,
                                                'page_num' : range['page_num'],
                                                'table_index' : range['table_index'],
                                                "type" : text_type,
                                                'content' : line_text,
                                                'sort_num' : range['page_num']*1000 - element_top
                                            }},conn_app,cursor_app)
                                    # Attach the text to the first matching table only.
                                    break
                    # Table title at the bottom of the page (e.g. parent-company statements) whose
                    # table starts on the next page: attribute it to table 1 of page pagenum+2.
                    if element_buttom < 150 and not text_in_table(element_top, tables_range, pagenum+1):
                        text_type = get_text_type(line_text)
                        if text_type == 'page_footer':
                            continue
                        table_info = {}
                        # Record the (next) page to be filtered out.
                        if len(re.findall('母公司|现金流量表补充', line_text)) > 0:
                            db_service.insert_measure_parser_info({
                                'file_id': file_id,
                                'content': pagenum+2,
                                'type': 'parent_com',
                            },conn_app,cursor_app)
                        # Page-bottom text matching the button_list blacklist marks table 1 of the next page.
                        if utils.check_table_title_black_list_button(line_text,button_list):
                            db_service.insert_measure_parser_info({
                                'file_id': file_id,
                                'content': f"{pagenum+2}_1",
                                'type': 'table_index',
                            },conn_app,cursor_app)
                        if utils.check_table_title_black_list_measure(line_text):
                            db_service.insert_measure_parser_info_measure({
                                'file_id': file_id,
                                'content': f"{pagenum+2}_1",
                                'type': 'measure_index',
                            },conn_app,cursor_app,line_text)
                        # NOTE(review): uses insert_table_unit_info, while the branch above uses
                        # insert_table_unit_info_v1; confirm whether that difference is intended.
                        if re.findall(unit_pattern, line_text):
                            table_info = get_table_unit_info(file_id,line_text,pagenum+2,1)
                            db_service.insert_table_unit_info(table_info,conn,cursor)
                        if utils.check_line_text(line_text):
                            db_service.insert_pdf_parse_process({
                                'file_id': file_id,
                                'page_num' : pagenum+1,
                                'page_count' : 100,
                                'type' : 'parse_table',
                                'content':{
                                    'top' : element_top,
                                    'buttom' : element_buttom,
                                    'page_num' : pagenum+1,
                                    "type" : text_type,
                                    'content' : line_text,
                                    'sort_num' : (pagenum+1)*1000 - element_top
                                }},conn_app,cursor_app)
            # Persist all texts of this page.
            if is_empty:
                db_service.batch_insert_page_text_nocheck({
                    'file_id': file_id,
                    'page_num' : pagenum+1,
                    'text' : line_texts
                },conn,cursor)
            else:
                db_service.batch_insert_page_text({
                    'file_id': file_id,
                    'page_num' : pagenum+1,
                    'text' : line_texts
                },conn,cursor)
        except Exception as e:
            print(f'{pagenum}页处理异常')
            print(e)
def get_text_content_disclosure(pdf_path,file_id,tables_range,pages,conn,cursor,redis_client, conn_app, cursor_app):
    """Store the text boxes of each page in ``pages`` for a disclosure document.

    For every page of the inclusive 1-based range ``pages`` ("start-end"):

    * increments the Redis counter ``parsed_page_count_<file_id>``;
    * collects the text of every ``LTTextBoxHorizontal`` with all whitespace removed;
    * stores the list with ``db_service.batch_insert_page_text_nocheck_disclosure``
      (when ``pdf_text_info_disclosure`` had no rows for the file) or
      ``db_service.batch_insert_page_text_disclosure``.

    Args:
        pdf_path: path of the PDF.
        file_id: document id.
        tables_range: unused; kept for signature parity with get_text_content.
        pages: "start-end" page range string.
        conn, cursor: main-database connection and cursor used for all queries.
        redis_client: Redis client holding the progress counter.
        conn_app, cursor_app: unused; kept for signature parity with get_text_content.

    Returns:
        None. An error on one page is printed and that page is skipped.
    """
    range_parts = pages.split('-')
    page_start, page_end = int(range_parts[0]), int(range_parts[1])
    print(f'pages的值为{pages}')
    # Parameterized query instead of f-string interpolation of file_id.
    cursor.execute(
        "select count(1) from pdf_text_info_disclosure where file_id = %s",
        (file_id,),
    )
    is_empty = cursor.fetchone()[0] == 0
    counter_key = f'parsed_page_count_{file_id}'
    for pagenum, page in enumerate(extract_pages(pdf_path)):
        page_no = pagenum + 1
        if page_no > page_end:
            # extract_pages is lazy: stop instead of laying out the rest of the document.
            break
        try:
            if page_no < page_start:
                continue
            # Update the Redis counter of parsed pages.
            if not redis_client.exists(counter_key):
                redis_client.set(counter_key, 0)
            redis_client.incr(counter_key)
            line_texts = [
                re.sub(r"\s", "", element.get_text().replace('\n', ''))
                for element in page._objs
                if isinstance(element, LTTextBoxHorizontal)
            ]
            page_record = {
                'file_id': file_id,
                'page_num': page_no,
                'text': line_texts,
            }
            if is_empty:
                db_service.batch_insert_page_text_nocheck_disclosure(page_record, conn, cursor)
            else:
                db_service.batch_insert_page_text_disclosure(page_record, conn, cursor)
        except Exception as e:
            print(f'{pagenum}页处理异常')
            print(e)
def get_table_unit_info(file_id,line_text,page_num,table_index):
    """Build the unit record of a table from the text found above it.

    The returned dict always has 'file_id', 'page_num' and 'table_index'. It also
    has 'unit' (group 2 of ``unit_pattern``, e.g. '万元') when ``line_text``
    declares a unit.
    """
    table_info = {'file_id': file_id}
    unit_match = unit_pattern.search(line_text)
    if unit_match:
        table_info['unit'] = unit_match.group(2)
    table_info.update(page_num=page_num, table_index=table_index)
    return table_info
def get_table_text_info(file_id,line_text,page_num,table_index):
    """Return a record that links ``line_text`` to a table of the report."""
    return dict(
        file_id=file_id,
        text_info=line_text,
        page_num=page_num,
        table_index=table_index,
    )
# Read the PDF tables and combine each value with its headers, e.g. "2022 Q1 operating revenue = xxxxx".
def get_table_measure(file_id, pdf_tables, record_range):
    """Turn the numeric cells of ``pdf_tables[start:end]`` into named measures and store them.

    ``record_range`` is a "start-end" string; ``end`` is exclusive (it is passed to ``range``).
    For each table the function works as follows:

    * The first numeric-looking cell outside the first row and first column marks the
      top-left corner of the value area.
    * Every value in that area gets two names: its top header (``get_row_num_info``)
      combined with its left header (``get_col_num_info``), in both orders.
    * Values that are blacklisted, have no period, or whose header periods disagree
      are skipped.
    * The remaining measures are written with
      ``db_service.insert_measure_data_to_milvus``, and the Redis counter
      ``parsed_measure_count_<file_id>`` is incremented once per processed table.

    Args:
        file_id: report id.
        pdf_tables: list of merged table dicts; the keys used are 'data', 'page_num'
            and 'table_index'.
        record_range: "start-end" index range into ``pdf_tables``.

    Errors on a single table are printed and that table is skipped.
    NOTE(review): if setup fails before ``record_start`` / ``client`` / ``cursor`` exist,
    the outer ``except`` and the ``finally`` clause reference unbound names and raise
    NameError/UnboundLocalError. The diagnostic loop in the ``except`` clause can hit the
    same problem with ``arr``.
    """
    try:
        redis_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=6)
        conn = mysql.connector.connect(
            host = MYSQL_HOST,
            user = MYSQL_USER,
            password = MYSQL_PASSWORD,
            database = MYSQL_DB
        )
        # Cursor for the main database.
        cursor = conn.cursor(buffered=True)
        conn_app = mysql.connector.connect(
            host = MYSQL_HOST_APP,
            user = MYSQL_USER_APP,
            password = MYSQL_PASSWORD_APP,
            database = MYSQL_DB_APP
        )
        # Cursor for the app database.
        cursor_app = conn_app.cursor(buffered=True)
        # NOTE(review): file_id is interpolated directly into SQL.
        select_year_select = f"""select report_type,year from report_check where id = {file_id}"""
        cursor.execute(select_year_select)
        record_select = cursor.fetchall()
        report_type = record_select[0][0]  # not used below
        report_year = record_select[0][1]
        client = MilvusClient(
            uri= MILVUS_CLIENT
        )
        print('提取指标任务 %s (%s)...' % (record_range, os.getpid()))
        start = time.time()
        record_start = record_range.split('-')[0]
        record_end = record_range.split('-')[1]
        for index in range(int(record_start),int(record_end)):
            t = pdf_tables[index]
            measure_obj =[]
            data_dict = {}
            measure_list = []
            try:
                arr = np.array(t['data'])
                rows, cols = arr.shape
                if rows == 1 and cols == 1:
                    continue
                row_num , col_num = -1 , -1
                # Nested scan for the first numeric cell (rows first, then columns).
                for i in range(rows):
                    for j in range(cols):
                        if j == 0 or i == 0: # never take a number from the first row/column
                            continue
                        measure_value_config = str(arr[i, j]).replace('(','').replace(')','')
                        # Both branches record (i, j) and break, so in effect the first numeric
                        # cell wins. The elif is only reached after the numeric match already
                        # succeeded, so its '-' alternative never changes the outcome.
                        if re.match(r'^[+-]?(\d+(\.\d*)?|\.\d+)(%?)$', measure_value_config):
                            if j == cols-1:
                                row_num , col_num = i , j
                                break
                            elif (re.match(r'^[+-]?(\d+(\.\d*)?|\.\d+)(%?)$', measure_value_config)
                                or measure_value_config == '-'):
                                row_num , col_num = i , j
                                break
                    else:
                        # Inner loop found nothing: try the next row.
                        continue
                    # Inner loop broke out: stop scanning.
                    break
                # Walk the value area and turn each value into a named measure.
                if row_num != -1 and col_num != -1:
                    for i in range(row_num,arr.shape[0]):
                        for j in range(col_num,arr.shape[1]):
                            measure_value = str(arr[i, j]).replace('%','').replace('(','-').replace(')','')
                            if measure_value == '-' or measure_value == '' or len(measure_value) > 20:
                                continue
                            else:
                                # row_num_info: header cells above the value (top header);
                                # col_num_info: header cells left of the value (left header).
                                row_num_info = get_row_num_info(arr,row_num,col_num,i,j)
                                col_num_info = get_col_num_info(arr,row_num,col_num,i,j)
                                # An empty top header is treated as a truncated table and skipped
                                # (the original note mentions an R&D-investment exception that is
                                # not visible here).
                                # NOTE(review): this tuple lists '' twice; one entry probably lost a
                                # full-width character in this copy of the file.
                                if row_num_info in ('','-',')',''):
                                    continue
                                # When both the total and the net of non-recurring gains/losses
                                # appear, keep the net amount.
                                if col_num_info == '非经常性损益合计':
                                    continue
                                if utils.check_pdf_measure_black_list(f"{col_num_info}{row_num_info}"):
                                    continue
                                # Skip measures without a period.
                                if utils.check_pdf_measure(f"{col_num_info}{row_num_info}"):
                                    continue
                                # Skip when the top and left headers name different periods
                                # ('c_n' is accepted on either side).
                                row_period = utils.get_period_type_other(row_num_info, report_year)
                                col_period = utils.get_period_type_other(col_num_info, report_year)
                                if(row_period != col_period and row_period != 'c_n' and col_period != 'c_n'):
                                    continue
                                # NOTE(review): the "" key presumably lost a character in this copy;
                                # as written, re.search('') matches any text.
                                units_mapping = {
                                    "百万元": "百万元",
                                    "千万元": "千万元",
                                    "亿元": "亿元",
                                    "万元": "万元",
                                    "千元": "千元",
                                    "": "",
                                    "元/股": ""
                                }
                                row_num_info = row_num_info.replace('%','增减')
                                num_info = utils.get_clean_text(f"{row_num_info}{col_num_info}")
                                num_info_bak = utils.get_clean_text(f"{col_num_info}{row_num_info}")
                                measure_unit = ''
                                combined_info = f"{row_num_info} {col_num_info}"
                                if utils.get_percent_flag(row_num_info) == '1':
                                    measure_unit = ''
                                else:
                                    # Take the first unit mentioned in the headers. A unit counts when
                                    # it is bracketed, or when it appears and the first row says "单位".
                                    # NOTE(review): the first alternative of this regex seems to have lost
                                    # full-width brackets in this copy; as written it needs a literal
                                    # backslash and '|', so it effectively never matches.
                                    for unit in units_mapping:
                                        if re.search(rf'\\s*{unit}(\s*人民币)?\s*\|\(\s*{unit}(\s*人民币)?\s*\)', combined_info) or (re.search(rf'{unit}', combined_info) and any(re.search('单位', item) for item in arr[0])):
                                            measure_unit = units_mapping[unit]
                                            break
                                # Store the value under both header orders.
                                measure_list.append({
                                    'measure_name': num_info,
                                    'measure_value': measure_value,
                                    'measure_unit':measure_unit,
                                })
                                measure_list.append({
                                    'measure_name': num_info_bak,
                                    'measure_value': measure_value,
                                    'measure_unit':measure_unit,
                                })
                # Progress counter: one increment per processed table.
                if not redis_client.exists(f'parsed_measure_count_{file_id}'):
                    redis_client.set(f'parsed_measure_count_{file_id}', 0)
                redis_client.incr(f'parsed_measure_count_{file_id}')
                if len(measure_list) > 0:
                    data_dict["measure_list"] = measure_list
                    data_dict["page_num"] = f"{str(t['page_num'])}_{str(t['table_index'])}"
                    data_dict['file_id'] = file_id
                    measure_obj.append(data_dict)
                    db_service.insert_measure_data_to_milvus(client,measure_obj,cursor_app,conn_app)
            except Exception as e:
                print(f"循环获取表格数据这里报错了,数据是{t['data']},位置在{index}")
                print(f"错误是:{e}")
        end = time.time()
        print('提取指标 %s runs %0.2f seconds.' % (record_range, (end - start)))
    except Exception as e:
        print(f'这个错误是{e},所在的位置是{record_start}-{record_end}')
        # Diagnostic pass: try to build each table array again and print failures.
        record_start = record_range.split('-')[0]
        record_end = record_range.split('-')[1]
        for index in range(int(record_start),int(record_end)):
            t = pdf_tables[index]
            measure_obj =[]
            data_dict = {}
            measure_list = []
            try:
                arr = np.array(t['data'])
            except Exception as e:
                print(f'这个错误是{e}的arr的值是{arr}')
    finally:
        redis_client.close()
        client.close()
        cursor.close()
        conn.close()
        cursor_app.close()
        conn_app.close()
# Multi-process job dispatch: the job type decides whether tables or body text are processed.
def dispatch_job(job_info):
    """Worker entry point: run the job described by ``job_info``.

    ``job_info`` must provide 'type', 'path', 'file_id', 'page_num' and
    'tables_range'. Only ``type == 'table'`` jobs do anything: they call
    ``get_table_range``. Any exception, including a missing key, is printed and
    swallowed so one bad job does not kill the worker.
    """
    try:
        job_type, path, file_id, page_num, tables_range = (
            job_info[key] for key in ('type', 'path', 'file_id', 'page_num', 'tables_range')
        )
        if job_type != 'table':
            return
        get_table_range(path, file_id, page_num, tables_range)
    except Exception as e:
        print(e)
def dispatch_disclosure(job_info):
    """Worker entry point for disclosure documents: store the page texts of a page range.

    ``job_info`` must provide 'type', 'path', 'file_id', 'page_num' (a "start-end"
    page range) and 'tables_range'. Only ``type == 'table'`` jobs call
    ``get_text_content_disclosure``.

    The MySQL and Redis connections opened here are now always closed afterwards;
    previously they were never closed. Any exception is printed and swallowed so one
    failing job does not kill the worker.
    """
    conn = cursor = conn_app = cursor_app = redis_client = None
    try:
        job_type = job_info['type']
        path = job_info['path']
        file_id = job_info['file_id']
        page_num = job_info['page_num']
        tables_range = job_info['tables_range']
        conn = mysql.connector.connect(
            host= MYSQL_HOST,
            user= MYSQL_USER,
            password= MYSQL_PASSWORD,
            database= MYSQL_DB
        )
        # Cursor used to execute SQL statements.
        cursor = conn.cursor(buffered=True)
        conn_app = mysql.connector.connect(
            host= MYSQL_HOST_APP,
            user= MYSQL_USER_APP,
            password= MYSQL_PASSWORD_APP,
            database= MYSQL_DB_APP
        )
        cursor_app = conn_app.cursor(buffered=True)
        redis_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=6)
        if job_type == 'table':
            get_text_content_disclosure(path,file_id,tables_range,page_num,conn,cursor,redis_client, conn_app, cursor_app)
    except Exception as e:
        print(e)
    finally:
        # Release in reverse order of acquisition; a failing close must not skip the others.
        for resource in (redis_client, cursor_app, conn_app, cursor, conn):
            if resource is None:
                continue
            try:
                resource.close()
            except Exception as close_error:
                print(close_error)
# Measure normalization.
def update_measure_data(file_id,file_path,parent_table_pages):
    """Run the measure-normalization steps for one report on the main database.

    The steps run in this order:

    1. ``db_service.delete_to_run``
    2. ``db_service.insert_table_measure_from_vector_async_process``
    3. ``db_service.update_ori_measure``
    4. ``db_service.update_ori_measure_name``, which makes each measure of the same
       table on the same page appear only once when displayed.

    Args:
        file_id: report id.
        file_path: path of the PDF; passed through to step 2.
        parent_table_pages: page blacklist; printed and passed through to step 2.

    The connection is now closed even when a step raises; the exception still
    propagates. The app-database connection that was opened but never used has
    been removed. It was only referenced by a disabled ``delete_database`` call.
    """
    conn = mysql.connector.connect(
        host = MYSQL_HOST,
        user = MYSQL_USER,
        password = MYSQL_PASSWORD,
        database = MYSQL_DB
    )
    cursor = None
    try:
        cursor = conn.cursor(buffered=True)
        print(f'目录黑名单为:{parent_table_pages}')
        db_service.delete_to_run(conn,cursor,file_id)
        db_service.insert_table_measure_from_vector_async_process(cursor,parent_table_pages,file_id,file_path)
        # Measure normalization.
        db_service.update_ori_measure(conn,cursor,file_id)
        # Show each measure of the same table on the same page only once.
        db_service.update_ori_measure_name(conn,cursor,file_id)
    finally:
        if cursor is not None:
            cursor.close()
        conn.close()
def merge_consecutive_arrays(pdf_info):
    """Merge runs of consecutive tables that have the same number of columns.

    ``pdf_info`` is a list of parse-content dicts already sorted into reading
    order. Entries with ``type == 'table'`` carry a 'data' list of rows. The
    merge rules are:

    * A table whose first row has as many cells as the pending table's first row
      is appended to it; this joins tables split across pages.
    * A table with a different column count closes the pending run and starts a
      new one. Previously that table was silently dropped.
    * Any non-table entry closes the pending run.

    The first table of each run is the accumulator, so its 'data' list is
    extended in place. Entries that raise while being inspected (e.g. missing
    keys or empty data) are reported and skipped.

    Returns:
        The list of merged table dicts, in order.
    """
    merged_objects = []
    temp_array = {}
    for info_obj in pdf_info:
        try:
            if info_obj['type'] != 'table':
                # Text (or any non-table) ends the current run.
                if temp_array:
                    merged_objects.append(temp_array)
                    temp_array = {}
                continue
            if not temp_array.get('page_num'):
                # No pending table: start a new run.
                temp_array = info_obj
            elif len(temp_array['data'][0]) == len(info_obj['data'][0]):
                # Same column count: treat as the continuation of the pending table.
                temp_array['data'].extend(info_obj['data'])
            else:
                # Different layout: flush the pending table and start a run with this one.
                merged_objects.append(temp_array)
                temp_array = info_obj
        except Exception as e:
            print(f"解析数据错误: {e}")
    if temp_array:
        merged_objects.append(temp_array)
    return merged_objects
def merge_consecutive_arrays_v1(pdf_info):
    """Merge runs of consecutive tables whose data have identical dimensions.

    Two tables match only when they have the same number of rows and every pair
    of corresponding rows has the same length. A matching table's rows are
    appended to the pending table; a non-matching table flushes the pending one
    and starts a new run. Non-table entries flush the pending table. The first
    table of a run is extended in place. Entries that raise are reported and
    skipped.
    """
    def is_same_dimension(data1, data2):
        return len(data1) == len(data2) and all(
            len(left) == len(right) for left, right in zip(data1, data2)
        )

    merged = []
    pending = {}
    for item in pdf_info:
        try:
            if item['type'] != 'table':
                if pending:
                    merged.append(pending)
                    pending = {}
                continue
            if not pending:
                pending = item
            elif is_same_dimension(pending['data'], item['data']):
                pending['data'].extend(item['data'])
            else:
                merged.append(pending)
                pending = item
        except Exception as e:
            print(f"解析数据错误: {e}")
    if pending:
        merged.append(pending)
    return merged
def start_table_measure_job(file_id):
    """Load the parsed tables of a report, merge them and extract measures in parallel.

    The steps are:

    1. Read every 'parse_table' row of ``pdf_parse_process`` for ``file_id`` from the
       app database.
    2. Sort the contents by 'sort_num' and merge consecutive tables with
       ``merge_consecutive_arrays``.
    3. Store the number of merged tables in the Redis key ``measure_count_<file_id>``.
    4. Start one ``get_table_measure`` process per range returned by
       ``utils.get_range(len(pdf_tables), MEASURE_COUNT)`` and wait for all of them.

    The query is now parameterized, and the database and Redis handles are released
    even if a step fails.
    """
    conn_app = mysql.connector.connect(
        host = MYSQL_HOST_APP,
        user = MYSQL_USER_APP,
        password = MYSQL_PASSWORD_APP,
        database = MYSQL_DB_APP
    )
    try:
        cursor_app = conn_app.cursor(buffered=True)
        try:
            cursor_app.execute(
                "select content from pdf_parse_process WHERE file_id = %s and type='parse_table'",
                (file_id,),
            )
            records = cursor_app.fetchall()
        finally:
            cursor_app.close()
    finally:
        conn_app.close()
    # NOTE(security): eval() executes the stored text as Python code. The content is
    # written by this pipeline, but anyone able to write pdf_parse_process can run code
    # here. ast.literal_eval would be safer only if the stored reprs are pure literals
    # (they may contain non-literal reprs, e.g. numpy scalars); verify before switching.
    pdf_info = [eval(record[0]) for record in records]
    sorted_pdf_info = sorted(pdf_info, key=lambda k: k['sort_num'])
    pdf_tables = merge_consecutive_arrays(sorted_pdf_info)
    redis_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=6)
    try:
        redis_client.set(f'measure_count_{file_id}', len(pdf_tables))
    finally:
        redis_client.close()
    records_range_parts = utils.get_range(len(pdf_tables),MEASURE_COUNT)
    print(f'records_range_part识别页码的值为{records_range_parts}')
    processes = []
    for record_range in records_range_parts:
        p = Process(target=get_table_measure, args=(file_id,pdf_tables,record_range,))
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
if __name__ == "__main__":
    # Ad-hoc debugging entry point: re-run measure extraction for pages 41-43 of one report.
    # NOTE(review): this reads pdf_parse_process from the main database (MYSQL_DB), while
    # start_table_measure_job reads it from the app database; confirm which one is intended.
    file_id = '1778'
    conn = mysql.connector.connect(
        host = MYSQL_HOST,
        user = MYSQL_USER,
        password = MYSQL_PASSWORD,
        database = MYSQL_DB
    )
    cursor = conn.cursor(buffered=True)
    try:
        cursor.execute(
            "select content from pdf_parse_process WHERE file_id = %s and type='parse_table' "
            "and page_num in(41,42,43)",
            (file_id,),
        )
        records = cursor.fetchall()
    finally:
        cursor.close()
        conn.close()
    # NOTE(security): eval() executes the stored text as Python code; only use on trusted data.
    pdf_info = [eval(record[0]) for record in records]
    sorted_pdf_info = sorted(pdf_info, key=lambda k: k['sort_num'])
    pdf_tables = merge_consecutive_arrays(sorted_pdf_info)
    get_table_measure(file_id,pdf_tables,'0-2')