zzck_code/mqreceive_multithread.py

559 lines
27 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pika
import json
import logging
import time
import os
import re
import threading
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from config import *
import jieba
from llm_process import send_mq, get_label, get_translation, is_mostly_chinese, get_tdx_map, get_stock_map
from label_check import load_label_mapping, validate_tags
from check_sensitivity import SensitiveWordFilter
from news_clustering import NewsClusterer
from cal_etf_labels import *
import pandas as pd
import threading
import time
from datetime import datetime
from apscheduler.schedulers.background import BackgroundScheduler
# Global map of media outlet -> authority score, loaded once at import time.
# File format: one "<media>\t<score>" entry per line; malformed lines are
# reported and skipped so a bad row cannot abort startup.
media_score = {}
with open("media_score.txt", "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:  # skip blank lines
            continue
        try:
            media, score = line.split("\t")
            media_score[media.strip()] = int(score)
        except ValueError as e:
            print(f"解析错误: {e},行内容: {line}")
            continue
# Enforced level-1/level-2 label checks for industry and concept tags.
concept_mapping = load_label_mapping("concept_mapping.txt")
industry_mapping = load_label_mapping("industry_mapping.txt")
# Sensitive-word screening.
sensitive_filter = SensitiveWordFilter("sensitive_words.txt")
# Cluster-based scoring setup; silence pika's verbose logging first.
logging.getLogger("pika").setLevel(logging.WARNING)
JIEBA_DICT_PATH = '/root/zzck/news_distinct_task/dict.txt.big'  # jieba segmentation dictionary path
if os.path.exists(JIEBA_DICT_PATH):
    jieba.set_dictionary(JIEBA_DICT_PATH)
news_cluster = NewsClusterer()
news_cluster.load_clusters()
# Keyword bonus table: a title hit on <word> later adds <bonus> (scaled) to the
# news score. File format: "<word>,<bonus>" per line; malformed lines are
# reported and skipped.
words_bonus = {}
with open("bonus_words.txt", "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:  # skip blank lines
            continue
        try:
            word, bonus = line.split(",")
            words_bonus[word.strip()] = int(bonus)
        except ValueError as e:
            print(f"关键词加分项解析错误: {e},行内容: {line}")
            continue
print(words_bonus)
# Shenwan level-1 industry heat scores, read from CSV straight into a dict.
# --- 1. Initialization ---
sw_heat_file = "/root/zzck/industry_heat_task/sw_data/heat_ranking_newest_5d.csv"
def load_sw_map():
    """Read the heat-ranking CSV and return a {name: heat_score} dict.

    Returns an empty dict on any read/parse failure instead of None, so
    downstream ``sw_heat_map.get(...)`` lookups can never raise
    AttributeError when the file is missing at startup. An empty dict is
    still falsy, so ``update_sw_heat_job``'s ``if new_map:`` guard keeps
    the previous map on failed refreshes exactly as before.
    """
    try:
        df = pd.read_csv(sw_heat_file)
        # One-liner: index by industry name, project the score column.
        return df.set_index('name')['heat_score'].to_dict()
    except Exception as e:
        print(f"读取CSV失败: {e}")
        return {}
# Module-level cache of the heat map, refreshed daily by the scheduler.
sw_heat_map = load_sw_map()
# --- 2. Scheduled refresh task ---
def update_sw_heat_job():
    """Reload the Shenwan industry heat map and swap it in atomically."""
    global sw_heat_map
    print(f"[{datetime.now()}] 开始执行每日热度数据更新...")
    refreshed = load_sw_map()
    if not refreshed:
        return
    # Rebinding the global is atomic: in-flight readers keep the old dict,
    # subsequent lookups see the refreshed mapping.
    sw_heat_map = refreshed
    print("sw_heat_map 更新成功")
    # Show the first five entries as a sanity check.
    print(f"当前sw_heat_map示例: {list(sw_heat_map.items())[:5]}")
# --- 3. Start the background scheduler ---
scheduler = BackgroundScheduler()
# Refresh the Shenwan heat map every day at 18:25.
scheduler.add_job(update_sw_heat_job, 'cron', hour=18, minute=25)
scheduler.start()
# Idempotency store - IDs of already-processed messages (thread-safe via lock).
processed_ids = set()
processed_ids_lock = threading.Lock()  # synchronizes all access to processed_ids
# Local in-process queue feeding the batch processor.
message_queue = Queue()
BATCH_SIZE = 24        # messages processed per batch
MAX_WORKERS = 24       # thread-pool size for a batch
MIN_BATCH_SIZE = 12    # minimum queued messages before a time-triggered batch runs
PROCESS_INTERVAL = 10  # processing interval (seconds)
def process_single_message(data):
    """Run the full scoring/labeling/translation pipeline for one MQ message.

    Returns:
        (tagged_news, True)  on success;
        (None, True)         when the message ID was already processed (idempotent skip);
        (None, False)        on any processing failure (ID is un-marked for retry).
    """
    try:
        id_str = str(data["id"])
        input_date = data["input_date"]
        # print(id_str + "\t" + str(input_date))
        # Idempotency check: skip IDs we have already seen.
        with processed_ids_lock:
            if id_str in processed_ids:
                print(f"跳过已处理的消息: {id_str}")
                return None, True  # None = nothing to send; True = already handled
            # Mark as processed up-front to block concurrent duplicates.
            processed_ids.add(id_str)
            # Crude memory bound: wipe the whole set once it grows past 10k.
            if len(processed_ids) > 10000:
                processed_ids.clear()
        content = data.get('CN_content', "").strip()
        source = "其他"
        # Source name is nested at c[0].b[0].d[0].sourcename in the payload
        # (assumed schema — each level defaults defensively when absent).
        category_data = data.get('c', [{}])
        category = ""
        if category_data:
            category = category_data[0].get('category', '')
            b_data = category_data[0].get('b', [{}])
            if b_data:
                d_data = b_data[0].get('d', [{}])
                if d_data:
                    source = d_data[0].get('sourcename', "其他")
        source_impact = media_score.get(source, 5)  # unknown outlets default to 5
        tagged_news = get_label(content, source)
        # Force industry/concept labels through the level-1/level-2 mapping check.
        tagged_news = validate_tags(tagged_news, "industry_label", "industry_confidence", industry_mapping)
        tagged_news = validate_tags(tagged_news, "concept_label", "concept_confidence", concept_mapping)
        industry_confidence = tagged_news.get("industry_confidence", [])
        concept_confidence = tagged_news.get("concept_confidence", [])
        industry_confidence = list(map(lambda x: round(x, 2), industry_confidence))
        concept_confidence = list(map(lambda x: round(x, 2), concept_confidence))
        industry_label = tagged_news.get("industry_label", [])
        # Shenwan heat score: default 10 when labels or heat data are missing.
        sw_heat = 10
        if industry_label and industry_confidence:
            # Labels look like "level1-level2"; heat is keyed by level-1 name.
            first_industry = industry_label[0].split("-")[0]
            first_industry_confidence = industry_confidence[0]
            if first_industry_confidence >= 0.75:
                sw_heat = sw_heat_map.get(first_industry, 0)
        public_opinion_score = tagged_news.get("public_opinion_score", 30)  # news quality score
        China_factor = tagged_news.get("China_factor", 0.2)  # relevance to the Chinese stock market
        # Weighted composite score from authority, quality, China relevance and heat.
        news_score = source_impact * 0.03 + public_opinion_score * 0.3 + China_factor * 35 + sw_heat * 0.05
        industry_score = list(map(lambda x: round(x * news_score, 2), industry_confidence))
        concept_score = list(map(lambda x: round(x * news_score, 2), concept_confidence))
        # Keyword bonus: e.g. a whole-word, case-insensitive title hit on
        # "Google" adds bonus * 0.05 to the score (top hit only).
        ori_title = data.get('title_EN', '')
        hits = []
        for word, bonus in words_bonus.items():
            if re.search(rf'\b{re.escape(word)}\b', ori_title, re.IGNORECASE):
                hits.append((word, bonus))
        if hits:
            # Sort by bonus descending; only the highest-scoring hit counts.
            hits_sorted = sorted(hits, key=lambda x: x[1], reverse=True)
            print(f"原文标题{ori_title} 命中关键词及分数:", hits_sorted)
            add_score = hits_sorted[0][-1] * 0.05
            print(f"新闻{id_str} 当前分数news_score变动: {news_score} -> {news_score + add_score} ... 如果超出上限后续会调整...")
            news_score = min(99.0, news_score + add_score)  # cap at 99.0
        else:
            pass
        # Cluster-based experimental score (trial feature). Clusters are
        # reloaded only during minute 12 of any hour.
        if datetime.now().minute == 12:
            if not news_cluster.load_clusters():
                print(f"聚类文件加载失败 {datetime.now()}")
        CN_title = data.get('title_txt', '')
        cluster_evaluate = news_cluster.evaluate_news(title=CN_title, content=content)
        # sqrt-scaled weighted score; 20.0 is the fallback when no cluster matched.
        cluster_score = (cluster_evaluate['weighted_score'] ** 0.5) * 10 if cluster_evaluate else 20.0
        center_news_id = cluster_evaluate['center_news_id'] if cluster_evaluate else None
        news_score_exp = source_impact * 0.04 + cluster_score * 0.25 + China_factor * 35
        # Ads heuristic: contact-info trio in body means promotional copy -> 21% penalty.
        if all(keyword in content for keyword in ["邮箱", "电话", "网站"]):
            print("新闻判别为广告软文分数强制8折")
            news_score *= 0.79
            news_score_exp *= 0.79
        # Penalize high-scoring news that carries no valid labels at all.
        if news_score >= 80 and len(tagged_news.get("industry_label", [])) == 0 and len(tagged_news.get("concept_label", [])) == 0:
            print("新闻没有打上任何有效标签分数强制8折")
            news_score *= 0.79
            news_score_exp *= 0.79
        # Penalize high-scoring news whose title+content hits the sensitive-word filter.
        if news_score >= 80:
            title = data.get('title_txt', '')
            news_text = f"{title} {content}"
            result = sensitive_filter.detect(news_text)
            if result:
                print("新闻命中敏感词: ", ", ".join(list(result)), "; 分数强制8折")
                news_score *= 0.79
                news_score_exp *= 0.79
        # Selected (>=80) news must have a top label confidence of at least 0.7.
        if news_score >= 80:
            industry_high_confidence = industry_confidence[0] if len(industry_confidence) > 0 else 0
            concept_high_confidence = concept_confidence[0] if len(concept_confidence) > 0 else 0
            if max(industry_high_confidence, concept_high_confidence) < 0.7:
                news_score *= 0.79
                news_score_exp *= 0.79
                print(f"行业标签置信度为{industry_high_confidence}, 概念标签置信度为{concept_high_confidence}, 新闻得分强制8折.")
        news_score = round(news_score, 2)
        news_score_exp = round(news_score_exp, 2)
        cluster_score = round(cluster_score, 2)
        # Selected news (>=80) additionally gets an LLM translation plus
        # fine-grained classification fields.
        if news_score >= 80:
            ori_content = data.get('EN_content', '')
            ori_title = data.get('title_EN', '')
            for _ in range(3):
                # Translation is flaky; retry up to three times.
                translation = get_translation(ori_content, ori_title)
                llm_content = translation.get("llm_content", "") if translation else ""
                llm_title = translation.get("llm_title", "") if translation else ""
                if not is_mostly_chinese(llm_title, threshold=0.4):
                    print("新闻标题没有被翻译成中文, 会被舍弃, LLM标题为: ", llm_title)
                    llm_title = ""
                    continue
                if not is_mostly_chinese(llm_content, threshold=0.5):
                    print("新闻正文没有被翻译成中文, 会被舍弃, LLM正文为: ", llm_content[:30] + "...")
                    llm_content = ""
                    continue
                # NOTE(review): if every retry hits a `continue` above, the fields
                # extracted below stay unbound and the assignments after this loop
                # raise NameError (swallowed by the outer except) — confirm intended.
                # 2025-09-22: fine-grained fields — overseas event/macro, China macro,
                # industry/company news, reprint source, company names.
                overseas_event = translation.get("overseas_event", None) if translation else None
                overseas_macro = translation.get("overseas_macro", None) if translation else None
                china_macro = translation.get("china_macro", None) if translation else None
                industry_news = translation.get("industry_news", None) if translation else None
                company_news = translation.get("company_news", None) if translation else None
                reprint_source = translation.get("reprint_source", None) if translation else None
                if reprint_source == "":
                    reprint_source = None
                company_name = translation.get("company_name", None) if translation else None
                if company_name:
                    # Patch (2025-12-19): hard-cap extracted company names at 3.
                    company_name_list = company_name.strip().split(",")
                    if len(company_name_list) > 3:
                        company_name = ",".join(company_name_list[:3])
                if company_name == "":
                    company_name = None
                ## Industry name is redundant with earlier labels; dropped for now.
                # industry_name = translation.get("industry_name", None) if translation else None
                # 2025-12-17: TDX industry labels + confidences, validated against
                # the known TDX map and capped at 3 entries.
                tdx_industry_ori = translation.get("tdx_industry", None) if translation else None
                tdx_industry_confidence_ori = translation.get("tdx_industry_confidence", None) if translation else None
                tdx_map = get_tdx_map()
                tdx_industry = []
                tdx_industry_confidence = []
                if tdx_industry_ori and tdx_industry_confidence_ori and len(tdx_industry_confidence_ori) == len(tdx_industry_ori):
                    for industry, confidence in zip(tdx_industry_ori, tdx_industry_confidence_ori):
                        if industry in tdx_map:
                            tdx_industry.append(industry)
                            tdx_industry_confidence.append(float(confidence))
                            if len(tdx_industry) == 3:
                                break
                # 2025-12-26: extract up to 5 A-share 6-digit stock codes,
                # keeping only codes present in the known stock map.
                stock_codes_ori = translation.get("stock_codes", None) if translation else None
                if stock_codes_ori:
                    stock_codes_ori = list(set(stock_codes_ori))
                stock_codes = []
                stock_names = []
                stock_map = get_stock_map()
                if stock_codes_ori and isinstance(stock_codes_ori, list):
                    for stock_code in stock_codes_ori:
                        if stock_code in stock_map:
                            stock_name = stock_map.get(stock_code)
                            stock_codes.append(stock_code)
                            stock_names.append(stock_name)
                # 2026-01-28: review scores — political sensitivity and
                # "export-turned-domestic" relabeling.
                political_sensitivity = translation.get("political_sensitivity", None) if translation else None
                export_domestic = translation.get("export_domestic", None) if translation else None
                political_notes = translation.get("political_notes", None) if translation else None
                additional_notes = translation.get("additional_notes", None) if translation else None
                # 2026-02-26: LLM-generated abstract.
                llm_abstract = translation.get("llm_abstract", None) if translation else None
                # Accept the translation once the core fields look plausible.
                if llm_content and llm_title and len(llm_content) > 30 and len(llm_title) > 3 and overseas_event is not None and overseas_macro is not None:
                    break
            # Attach ETF labels for selected news.
            etf_labels = get_etf_labels(title=llm_title, industry_label=tagged_news.get("industry_label", []), industry_confidence=tagged_news.get("industry_confidence", []), concept_label=tagged_news.get("concept_label", []), concept_confidence=tagged_news.get("concept_confidence", []))
            # etf_name is star-imported from cal_etf_labels — presumably a
            # {label: display_name} mapping; verify against that module.
            etf_names = [etf_name[label] for label in etf_labels]
            tagged_news["etf_labels"] = etf_labels
            tagged_news["etf_names"] = etf_names
            tagged_news["title"] = llm_title
            tagged_news["rewrite_content"] = llm_content
            tagged_news["overseas_event"] = overseas_event
            tagged_news["overseas_macro"] = overseas_macro
            tagged_news["china_macro"] = china_macro
            tagged_news["industry_news"] = industry_news
            tagged_news["company_news"] = company_news
            tagged_news["reprint_source"] = reprint_source
            tagged_news["company_name"] = company_name
            tagged_news["tdx_industry"] = tdx_industry if tdx_industry else None
            tagged_news["tdx_industry_confidence"] = tdx_industry_confidence if tdx_industry_confidence else None
            tagged_news["stock_codes"] = stock_codes if stock_codes else None
            tagged_news["stock_names"] = stock_names if stock_names else None
            tagged_news["political_sensitivity"] = political_sensitivity if political_sensitivity else None
            tagged_news["export_domestic"] = export_domestic if export_domestic else None
            tagged_news["political_notes"] = political_notes if political_notes else None
            tagged_news["additional_notes"] = additional_notes if additional_notes else None
            tagged_news["llm_abstract"] = llm_abstract if llm_abstract else None
        # Common output fields (always attached regardless of score).
        tagged_news["source"] = source
        tagged_news["source_impact"] = source_impact
        tagged_news["industry_score"] = industry_score
        tagged_news["concept_score"] = concept_score
        tagged_news["news_score"] = news_score
        tagged_news["news_score_exp"] = news_score_exp
        tagged_news["center_news_id"] = center_news_id
        tagged_news["id"] = id_str
        tagged_news["cluster_score"] = cluster_score
        tagged_news["industry_confidence"] = industry_confidence
        tagged_news["concept_confidence"] = concept_confidence
        if news_score >= 80:
            # 2026-01-28: re-run the sensitive-word screen on the LLM translation
            # output, as a safety net.
            news_title_llm = tagged_news.get("title", "")
            news_content_llm = tagged_news.get("rewrite_content", "")
            news_text_llm = f"{news_title_llm} {news_content_llm}"
            result = sensitive_filter.detect(news_text_llm)
            if result:
                print("大模型翻译内容命中敏感词: ", ", ".join(list(result)), "; 分数强制8折")
                tagged_news["news_score"] = news_score * 0.79
            # export_domestic >= 80 or political_sensitivity >= 70 also demotes
            # the item out of the selected pool.
            # NOTE(review): these fields may be None here (set to None above when
            # falsy), and `None >= 80` raises TypeError, which the outer except
            # swallows and drops the message — confirm intended.
            if tagged_news.get("export_domestic", 0) >= 80 or tagged_news.get("political_sensitivity", 0) >= 70:
                print(f'本条资讯的出口转内销得分为{tagged_news.get("export_domestic", 0)}, 政治敏感得分为{tagged_news.get("political_sensitivity", 0)}, political_notes信息为【{tagged_news.get("political_notes", "")}】, 分数8折')
                tagged_news["news_score"] = news_score * 0.79
        #print(json.dumps(tagged_news, ensure_ascii=False))
        print(tagged_news["id"], tagged_news["news_score"], tagged_news["news_score_exp"], tagged_news["industry_label"], tagged_news["concept_label"], public_opinion_score, tagged_news["cluster_score"], tagged_news["center_news_id"], input_date)
        # Verbose diagnostics for selected news.
        if news_score >= 80:
            print("LLM translation: ", tagged_news.get("rewrite_content", "")[:100])
            print("LLM overseas_event score: ", tagged_news.get("overseas_event", None))
            print("LLM overseas_macro score: ", tagged_news.get("overseas_macro", None))
            print("LLM china_macro score: ", tagged_news.get("china_macro", None))
            if tagged_news.get("reprint_source", None):
                print("LLM reprint_source: ", tagged_news.get("reprint_source", None))
            if tagged_news.get("company_name", None):
                print("LLM company_name: ", tagged_news.get("company_name", None), tagged_news.get("company_news", None))
            print("ETF NAMES: ", tagged_news.get("etf_names", []))
            if tagged_news.get("tdx_industry", None):
                print("LLM_max tdx_industry: ", tagged_news.get("tdx_industry", None))
            if tagged_news.get("tdx_industry_confidence", None):
                print("LLM_max tdx_industry_confidence: ", tagged_news.get("tdx_industry_confidence", None))
            if tagged_news.get("stock_names", None):
                print("LLM_max stock codes & names: ", tagged_news.get("stock_codes", None), tagged_news.get("stock_names", None))
            if tagged_news.get("political_sensitivity", None):
                print("LLM_max political_sensitivity: ", tagged_news.get("political_sensitivity", None))
            if tagged_news.get("llm_abstract", None):
                print("LLM_max llm_abstract: ", tagged_news.get("llm_abstract", None)[:30])
        return tagged_news, True
    except Exception as e:
        print(f"处理消息时出错: {str(e)}")
        # NOTE(review): if data["id"] itself raised above, id_str is unbound here
        # and this handler raises NameError — confirm payloads always carry "id".
        # On failure, un-mark the ID so the message can be retried later.
        with processed_ids_lock:
            if id_str in processed_ids:
                processed_ids.remove(id_str)
        return None, False
def process_message_batch(batch):
    """Process a batch of messages in parallel and publish results to MQ.

    Fans the batch out over a thread pool running process_single_message,
    collects the non-None results, sends each to MQ via send_mq, and logs
    timing statistics.

    Fixes vs. previous version:
    - the docstring was placed after the first statement (so it was a bare
      string expression, not a docstring);
    - ``duration / len(batch)`` raised ZeroDivisionError on an empty batch.

    Args:
        batch: list of decoded message dicts.
    """
    start_time = time.time()
    results = []
    # Fan out over the thread pool; result() re-raises worker exceptions here.
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = [executor.submit(process_single_message, data) for data in batch]
        for future in futures:
            try:
                result, success = future.result()
                if result:
                    results.append(result)
            except Exception as e:
                print(f"处理消息时发生异常: {str(e)}")
    # Publish results sequentially; one failed send must not abort the rest.
    for result in results:
        try:
            send_mq(result)
        except Exception as e:
            print(f"发送消息到MQ失败: {str(e)}")
    duration = time.time() - start_time
    if batch:  # guard the per-message average against an empty batch
        print(f"批量处理 {len(batch)} 条消息, 耗时: {duration:.2f}s, "
              f"平均: {duration/len(batch):.3f}s/条")
def message_callback(ch, method, properties, body):
    """Consumer callback: decode the payload and enqueue it for batch processing.

    Only enqueues — actual processing and acking happen in the batch loop.
    Undecodable messages are nacked without requeueing.
    """
    try:
        payload = json.loads(body)
        # Keep the delivery tag with the payload so the batch loop can ack later.
        message_queue.put((payload, method.delivery_tag))
    except Exception as exc:
        print(f"消息处理失败: {str(exc)}")
        # Reject without requeue: a bad payload would just fail again.
        ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False)
def create_connection():
    """Build and return a blocking RabbitMQ connection to localhost.

    Uses credentials from config, a 600s heartbeat, and up to 3 connection
    attempts spaced 5 seconds apart.
    """
    params = pika.ConnectionParameters(
        host="localhost",
        credentials=pika.PlainCredentials(mq_user, mq_password),
        heartbeat=600,
        connection_attempts=3,
        retry_delay=5,  # seconds between attempts
    )
    return pika.BlockingConnection(params)
def _drain_and_ack(channel, max_items=None):
    """Drain up to max_items messages from the local queue, process and ack them.

    Pulls (data, delivery_tag) pairs off message_queue (all of them when
    max_items is None), runs process_message_batch on the batch, then acks
    every drained delivery tag.
    """
    batch = []
    delivery_tags = []
    while not message_queue.empty() and (max_items is None or len(batch) < max_items):
        data, delivery_tag = message_queue.get()
        batch.append(data)
        delivery_tags.append(delivery_tag)
    if batch:
        process_message_batch(batch)
        for tag in delivery_tags:
            channel.basic_ack(tag)

def start_consumer():
    """Start the batching MQ consumer; reconnects forever on failure.

    Fixes vs. previous version:
    - ``connection`` is initialized to None each cycle, so the
      KeyboardInterrupt/finally cleanup no longer raises NameError when
      create_connection() itself fails;
    - bare ``except:`` clauses narrowed to ``except Exception:``;
    - the duplicated drain/process/ack logic is factored into _drain_and_ack.
    """
    while True:
        connection = None  # defined before try so cleanup below is always safe
        try:
            connection = create_connection()
            channel = connection.channel()
            # Prefetch enough messages to keep a few batches in flight.
            channel.basic_qos(prefetch_count=BATCH_SIZE * 3)
            channel.exchange_declare(
                exchange="zzck_exchange",
                exchange_type="fanout"
            )
            # Declare and bind the work queue.
            res = channel.queue_declare(queue="to_ai")
            # res = channel.queue_declare(queue='', exclusive=True)
            mq_queue = res.method.queue
            channel.queue_bind(
                exchange="zzck_exchange",
                queue=mq_queue,
            )
            # Manual acks: messages are acked only after batch processing.
            channel.basic_consume(
                queue=mq_queue,
                on_message_callback=message_callback,
                auto_ack=False
            )
            print(f"消费者已启动,批量大小: {BATCH_SIZE}, 工作线程: {MAX_WORKERS}, 等待消息...")
            last_process_time = time.time()
            # Main loop: pump network events, then decide whether to run a batch.
            while True:
                connection.process_data_events(time_limit=0.1)  # non-blocking pump
                current_time = time.time()
                queue_size = message_queue.qsize()
                # Dual trigger: a full batch is ready, or enough time has
                # passed with at least MIN_BATCH_SIZE messages waiting.
                if queue_size >= BATCH_SIZE or \
                        (current_time - last_process_time >= PROCESS_INTERVAL and queue_size >= MIN_BATCH_SIZE):
                    _drain_and_ack(channel, max_items=BATCH_SIZE)
                    last_process_time = current_time
                # Safety valve: a small backlog must not wait forever.
                elif current_time - last_process_time >= PROCESS_INTERVAL * 5 and queue_size > 0:
                    _drain_and_ack(channel)  # drain everything
                    last_process_time = current_time
                # Leave the inner loop (and reconnect) if the connection died.
                if not connection or connection.is_closed:
                    break
        except pika.exceptions.ConnectionClosedByBroker:
            print("连接被代理关闭将在5秒后重试...")
            time.sleep(5)
        except pika.exceptions.AMQPConnectionError:
            print("连接失败将在10秒后重试...")
            time.sleep(10)
        except KeyboardInterrupt:
            print("消费者被用户中断")
            try:
                if connection and connection.is_open:
                    connection.close()
            except Exception:
                pass
            break
        except Exception as e:
            print(f"消费者异常: {str(e)}")
            print("将在15秒后重试...")
            time.sleep(15)
        finally:
            # Best-effort close; never let cleanup mask the original error.
            try:
                if connection and connection.is_open:
                    connection.close()
            except Exception:
                pass
if __name__ == "__main__":
start_consumer()