import pika
import json
import logging
import time
import os
import re
import threading
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from config import *
import jieba
from llm_process import send_mq, get_label, get_translation, is_mostly_chinese, get_tdx_map, get_stock_map
from label_check import load_label_mapping, validate_tags
from check_sensitivity import SensitiveWordFilter
from news_clustering import NewsClusterer
from cal_etf_labels import *
import pandas as pd
from apscheduler.schedulers.background import BackgroundScheduler
# NOTE: the original file imported threading/time/datetime twice; the
# duplicates were removed — the set of bound names is unchanged.

# Global authority score per media source, loaded from a tab-separated file
# ("<media>\t<int score>" per line). Malformed lines are reported and skipped.
media_score = {}
with open("media_score.txt", "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        try:
            media, score = line.split("\t")
            media_score[media.strip()] = int(score)
        except ValueError as e:
            print(f"解析错误: {e},行内容: {line}")
            continue

# Mandatory level-1/level-2 consistency checks for industry and concept labels.
concept_mapping = load_label_mapping("concept_mapping.txt")
industry_mapping = load_label_mapping("industry_mapping.txt")

# Sensitive-word screening.
sensitive_filter = SensitiveWordFilter("sensitive_words.txt")

# Clustering-based scoring; silence pika's chatty INFO logging first.
logging.getLogger("pika").setLevel(logging.WARNING)
JIEBA_DICT_PATH = '/root/zzck/news_distinct_task/dict.txt.big'  # jieba dictionary path
if os.path.exists(JIEBA_DICT_PATH):
    jieba.set_dictionary(JIEBA_DICT_PATH)
news_cluster = NewsClusterer()
news_cluster.load_clusters()

# Keyword bonus table ("<word>,<int bonus>" per line), e.g. a title hit on
# "Google" adds 5 points. Malformed lines are reported and skipped.
words_bonus = {}
with open("bonus_words.txt", "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        try:
            word, bonus = line.split(",")
            words_bonus[word.strip()] = int(bonus)
        except ValueError as e:
            print(f"关键词加分项解析错误: {e},行内容: {line}")
            continue
print(words_bonus)

# --- 1. Shenwan level-1 industry heat scores: CSV -> dict (loaded below)
sw_heat_file = "/root/zzck/industry_heat_task/sw_data/heat_ranking_newest_5d.csv"


def load_sw_map():
    """Load the Shenwan level-1 industry heat CSV into a {name: heat_score} dict.

    Wraps the read in a try/except so a missing or malformed file cannot
    crash the process; returns None on any failure.
    """
    try:
        df = pd.read_csv(sw_heat_file)
        # Index by industry name and take the heat_score column as a dict.
        return df.set_index('name')['heat_score'].to_dict()
    except Exception as e:
        print(f"读取CSV失败: {e}")
        return None


# Global heat map, populated once at startup (may be None if the CSV failed).
sw_heat_map = load_sw_map()


# --- 2. Daily refresh job ---
def update_sw_heat_job():
    """Reload the heat CSV and atomically swap in the new map on success."""
    global sw_heat_map
    print(f"[{datetime.now()}] 开始执行每日热度数据更新...")
    new_map = load_sw_map()
    if new_map:
        # Atomic rebind of the module-global reference; concurrent readers
        # simply pick up the new dict on their next access.
        sw_heat_map = new_map
        print("sw_heat_map 更新成功")
        print(f"当前sw_heat_map示例: {list(sw_heat_map.items())[:5]}")  # sample first 5 entries


# --- 3. Start the background scheduler: refresh every day at 18:25 ---
scheduler = BackgroundScheduler()
scheduler.add_job(update_sw_heat_job, 'cron', hour=18, minute=25)
scheduler.start()

# Idempotency store - IDs of already-processed messages (lock-protected).
processed_ids = set()
processed_ids_lock = threading.Lock()  # guards all access to processed_ids

# Queue feeding the batch processor.
message_queue = Queue()
BATCH_SIZE = 24        # messages per batch
MAX_WORKERS = 24       # thread-pool size
MIN_BATCH_SIZE = 12    # minimum batch size for the interval-based trigger
PROCESS_INTERVAL = 10  # processing interval (seconds)


def process_single_message(data):
    """Score, label and (for high scores) translate one news message.

    Returns a (result, handled) pair: result is the enriched tagged_news dict
    to publish, or None when there is nothing to send; handled is False only
    when processing raised.
    """
    try:
        id_str = str(data["id"])
        input_date = data["input_date"]
        # print(id_str + "\t" + str(input_date))
        # Idempotency check: claim the ID up-front to block duplicates.
        with processed_ids_lock:
            if id_str in processed_ids:
                print(f"跳过已处理的消息: {id_str}")
                return None, True  # None => nothing to send, True => already handled
            # Mark as processed first, so a concurrent duplicate is rejected.
            processed_ids.add(id_str)
            # NOTE(review): clear() wipes the entire dedup history at once,
            # so messages seen just before the flush can be reprocessed.
            if len(processed_ids) > 10000:
                processed_ids.clear()
        content = data.get('CN_content', "").strip()

        # Dig the source name out of the nested c[0].b[0].d[0] structure;
        # fall back to "其他" ("other") at every missing level.
        source = "其他"
        category_data = data.get('c', [{}])
        category = ""  # NOTE(review): extracted but never used below
        if category_data:
            category = category_data[0].get('category', '')
            b_data = category_data[0].get('b', [{}])
            if b_data:
                d_data = b_data[0].get('d', [{}])
                if d_data:
                    source = d_data[0].get('sourcename', "其他")
        source_impact = media_score.get(source, 5)  # default authority score 5

        tagged_news = get_label(content, source)
        # Deal with the label problems: validate against the known mappings.
        tagged_news = validate_tags(tagged_news, "industry_label", "industry_confidence", industry_mapping)
        tagged_news = validate_tags(tagged_news, "concept_label", "concept_confidence", concept_mapping)
        industry_confidence = tagged_news.get("industry_confidence", [])
        concept_confidence = tagged_news.get("concept_confidence", [])
        industry_confidence = list(map(lambda x: round(x, 2), industry_confidence))
        concept_confidence = list(map(lambda x: round(x, 2), concept_confidence))
        industry_label = tagged_news.get("industry_label", [])

        sw_heat = 10  # default heat score when labels/heat data are missing
        if industry_label and industry_confidence:
            first_industry = industry_label[0].split("-")[0]
            first_industry_confidence = industry_confidence[0]
            # Only trust the heat lookup when the top label is confident enough.
            # NOTE(review): if sw_heat_map is None (CSV never loaded) this
            # .get() would raise — confirm the startup load cannot fail silently.
            if first_industry_confidence >= 0.75:
                sw_heat = sw_heat_map.get(first_industry, 0)

        public_opinion_score = tagged_news.get("public_opinion_score", 30)  # news quality score
        China_factor = tagged_news.get("China_factor", 0.2)  # relevance to the Chinese stock market
        # Weighted blend of authority, quality, China relevance and industry heat.
        news_score = source_impact * 0.03 + public_opinion_score * 0.3 + China_factor * 35 + sw_heat * 0.05
        industry_score = list(map(lambda x: round(x * news_score, 2), industry_confidence))
        concept_score = list(map(lambda x: round(x * news_score, 2), concept_confidence))

        # Keyword bonus: e.g. a case-insensitive whole-word title hit on
        # "Google" adds its configured bonus (scaled by 0.05).
        ori_title = data.get('title_EN', '')
        hits = []
        for word, bonus in words_bonus.items():
            if re.search(rf'\b{re.escape(word)}\b', ori_title, re.IGNORECASE):
                hits.append((word, bonus))
        if hits:
            # Sort descending by bonus and keep only the highest-scoring hit.
            hits_sorted = sorted(hits, key=lambda x: x[1], reverse=True)
            print(f"原文标题{ori_title} 命中关键词及分数:", hits_sorted)
            add_score = hits_sorted[0][-1] * 0.05
            print(f"新闻{id_str} 当前分数news_score变动: {news_score} -> {news_score + add_score} ... 如果超出上限后续会调整...")
            news_score = min(99.0, news_score + add_score)
        else:
            pass

        # Clustering score (experimental). Reload cluster data roughly once
        # per hour (whenever the wall-clock minute is 12).
        if datetime.now().minute == 12:
            if not news_cluster.load_clusters():
                print(f"聚类文件加载失败 {datetime.now()}")
        CN_title = data.get('title_txt', '')
        cluster_evaluate = news_cluster.evaluate_news(title=CN_title, content=content)
        cluster_score = (cluster_evaluate['weighted_score'] ** 0.5) * 10 if cluster_evaluate else 20.0
        center_news_id = cluster_evaluate['center_news_id'] if cluster_evaluate else None
        # Experimental alternative score based on clustering instead of quality.
        news_score_exp = source_impact * 0.04 + cluster_score * 0.25 + China_factor * 35

        # If it's an ad (mentions email + phone + website), cut the score by ~20%.
        if all(keyword in content for keyword in ["邮箱", "电话", "网站"]):
            print("新闻判别为广告软文,分数强制8折")
            news_score *= 0.79
            news_score_exp *= 0.79
        # If there are no labels at all, cut a high score by ~20%.
        if news_score >= 80 and len(tagged_news.get("industry_label", [])) == 0 and len(tagged_news.get("concept_label", [])) == 0:
            print("新闻没有打上任何有效标签,分数强制8折")
            news_score *= 0.79
            news_score_exp *= 0.79
        # If title/content hit sensitive words, cut a high score by ~20%.
        if news_score >= 80:
            title = data.get('title_txt', '')
            news_text = f"{title} {content}"
            result = sensitive_filter.detect(news_text)
            if result:
                print("新闻命中敏感词: ", ", ".join(list(result)), "; 分数强制8折")
                news_score *= 0.79
                news_score_exp *= 0.79
        # Require the best label confidence of a selected item to be >= 0.7.
        if news_score >= 80:
            industry_high_confidence = industry_confidence[0] if len(industry_confidence) > 0 else 0
            concept_high_confidence = concept_confidence[0] if len(concept_confidence) > 0 else 0
            if max(industry_high_confidence, concept_high_confidence) < 0.7:
                news_score *= 0.79
                news_score_exp *= 0.79
                print(f"行业标签置信度为{industry_high_confidence}, 概念标签置信度为{concept_high_confidence}, 新闻得分强制8折.")

        news_score = round(news_score, 2)
        news_score_exp = round(news_score_exp, 2)
        cluster_score = round(cluster_score, 2)

        # For scores >= 80, add an LLM translation plus the fine-grained fields.
        if news_score >= 80:
            ori_content = data.get('EN_content', '')
            ori_title = data.get('title_EN', '')
            # NOTE(review): if all 3 attempts `continue` before the field
            # assignments below, names like overseas_event stay unbound and the
            # later tagged_news[...] writes raise NameError (swallowed by the
            # outer except) — confirm whether that drop is intended.
            for _ in range(3):  # translation is flaky; retry up to three times
                translation = get_translation(ori_content, ori_title)
                llm_content = translation.get("llm_content", "") if translation else ""
                llm_title = translation.get("llm_title", "") if translation else ""
                if not is_mostly_chinese(llm_title, threshold=0.4):
                    print("新闻标题没有被翻译成中文, 会被舍弃, LLM标题为: ", llm_title)
                    llm_title = ""
                    continue
                if not is_mostly_chinese(llm_content, threshold=0.5):
                    print("新闻正文没有被翻译成中文, 会被舍弃, LLM正文为: ", llm_content[:30] + "...")
                    llm_content = ""
                    continue
                # 2025-09-22: fine-grained fields — overseas event/macro, China
                # macro, industry/company news, reprint source, company names.
                overseas_event = translation.get("overseas_event", None) if translation else None
                overseas_macro = translation.get("overseas_macro", None) if translation else None
                china_macro = translation.get("china_macro", None) if translation else None
                industry_news = translation.get("industry_news", None) if translation else None
                company_news = translation.get("company_news", None) if translation else None
                reprint_source = translation.get("reprint_source", None) if translation else None
                if reprint_source == "":
                    reprint_source = None
                company_name = translation.get("company_name", None) if translation else None
                if company_name:
                    # Patch (2025-12-19): cap the extracted company count at 3.
                    company_name_list = company_name.strip().split(",")
                    if len(company_name_list) > 3:
                        company_name = ",".join(company_name_list[:3])
                if company_name == "":
                    company_name = None
                ## Industry duplicates the earlier labels; disabled for now.
                # industry_name = translation.get("industry_name", None) if translation else None
                # 2025-12-17: revised translation rules; added TDX industry
                # labels with confidences (keep at most 3 known labels).
                tdx_industry_ori = translation.get("tdx_industry", None) if translation else None
                tdx_industry_confidence_ori = translation.get("tdx_industry_confidence", None) if translation else None
                tdx_map = get_tdx_map()
                tdx_industry = []
                tdx_industry_confidence = []
                if tdx_industry_ori and tdx_industry_confidence_ori and len(tdx_industry_confidence_ori) == len(tdx_industry_ori):
                    for industry, confidence in zip(tdx_industry_ori, tdx_industry_confidence_ori):
                        if industry in tdx_map:
                            tdx_industry.append(industry)
                            tdx_industry_confidence.append(float(confidence))
                            if len(tdx_industry) == 3:
                                break
                # 2025-12-26: extract A-share codes (list of 6-digit codes, up
                # to 5); keep only codes present in the known stock map.
                stock_codes_ori = translation.get("stock_codes", None) if translation else None
                if stock_codes_ori:
                    stock_codes_ori = list(set(stock_codes_ori))
                stock_codes = []
                stock_names = []
                stock_map = get_stock_map()
                if stock_codes_ori and isinstance(stock_codes_ori, list):
                    for stock_code in stock_codes_ori:
                        if stock_code in stock_map:
                            stock_name = stock_map.get(stock_code)
                            stock_codes.append(stock_code)
                            stock_names.append(stock_name)
                # 2026-01-28: review scores — political sensitivity and
                # "export turned domestic" plus free-text notes.
                political_sensitivity = translation.get("political_sensitivity", None) if translation else None
                export_domestic = translation.get("export_domestic", None) if translation else None
                political_notes = translation.get("political_notes", None) if translation else None
                additional_notes = translation.get("additional_notes", None) if translation else None
                # 2026-02-26: LLM-generated abstract.
                llm_abstract = translation.get("llm_abstract", None) if translation else None
                # Accept this attempt only if title/content are substantial and
                # the key event/macro fields came back.
                if llm_content and llm_title and len(llm_content) > 30 and len(llm_title) > 3 and overseas_event is not None and overseas_macro is not None:
                    break
            # Attach ETF labels to the selected news.
            etf_labels = get_etf_labels(title=llm_title, industry_label=tagged_news.get("industry_label", []), industry_confidence=tagged_news.get("industry_confidence", []), concept_label=tagged_news.get("concept_label", []), concept_confidence=tagged_news.get("concept_confidence", []))
            # presumably etf_name is a label->name dict from cal_etf_labels — TODO confirm
            etf_names = [etf_name[label] for label in etf_labels]
            tagged_news["etf_labels"] = etf_labels
            tagged_news["etf_names"] = etf_names
            tagged_news["title"] = llm_title
            tagged_news["rewrite_content"] = llm_content
            tagged_news["overseas_event"] = overseas_event
            tagged_news["overseas_macro"] = overseas_macro
            tagged_news["china_macro"] = china_macro
            tagged_news["industry_news"] = industry_news
            tagged_news["company_news"] = company_news
            tagged_news["reprint_source"] = reprint_source
            tagged_news["company_name"] = company_name
            tagged_news["tdx_industry"] = tdx_industry if tdx_industry else None
            tagged_news["tdx_industry_confidence"] = tdx_industry_confidence if tdx_industry_confidence else None
            tagged_news["stock_codes"] = stock_codes if stock_codes else None
            tagged_news["stock_names"] = stock_names if stock_names else None
            # NOTE(review): falsy scores (0) are normalised to None here, which
            # also makes the `>= 80` comparisons below see None — verify.
            tagged_news["political_sensitivity"] = political_sensitivity if political_sensitivity else None
            tagged_news["export_domestic"] = export_domestic if export_domestic else None
            tagged_news["political_notes"] = political_notes if political_notes else None
            tagged_news["additional_notes"] = additional_notes if additional_notes else None
            tagged_news["llm_abstract"] = llm_abstract if llm_abstract else None

        # Fields attached for every message, selected or not.
        tagged_news["source"] = source
        tagged_news["source_impact"] = source_impact
        tagged_news["industry_score"] = industry_score
        tagged_news["concept_score"] = concept_score
        tagged_news["news_score"] = news_score
        tagged_news["news_score_exp"] = news_score_exp
        tagged_news["center_news_id"] = center_news_id
        tagged_news["id"] = id_str
        tagged_news["cluster_score"] = cluster_score
        tagged_news["industry_confidence"] = industry_confidence
        tagged_news["concept_confidence"] = concept_confidence

        if news_score >= 80:
            # 2026-01-28: belt-and-braces — re-screen the LLM translation for
            # sensitive words as well.
            news_title_llm = tagged_news.get("title", "")
            news_content_llm = tagged_news.get("rewrite_content", "")
            news_text_llm = f"{news_title_llm} {news_content_llm}"
            result = sensitive_filter.detect(news_text_llm)
            if result:
                print("大模型翻译内容命中敏感词: ", ", ".join(list(result)), "; 分数强制8折")
                tagged_news["news_score"] = news_score * 0.79
            # export_domestic >= 80 or political_sensitivity >= 70 also removes
            # the item from the curated pool (via the same 0.79 discount).
            # NOTE(review): if either value is None the >= comparison raises
            # TypeError (caught by the outer except, dropping the message) —
            # confirm these are always numeric when set.
            if tagged_news.get("export_domestic", 0) >= 80 or tagged_news.get("political_sensitivity", 0) >= 70:
                print(f'本条资讯的出口转内销得分为{tagged_news.get("export_domestic", 0)}, 政治敏感得分为{tagged_news.get("political_sensitivity", 0)}, political_notes信息为【{tagged_news.get("political_notes", "")}】, 分数8折')
                tagged_news["news_score"] = news_score * 0.79

        # print(json.dumps(tagged_news, ensure_ascii=False))
        print(tagged_news["id"], tagged_news["news_score"], tagged_news["news_score_exp"], tagged_news["industry_label"], tagged_news["concept_label"], public_opinion_score, tagged_news["cluster_score"], tagged_news["center_news_id"], input_date)
        if news_score >= 80:
            print("LLM translation: ", tagged_news.get("rewrite_content", "")[:100])
            print("LLM overseas_event score: ", tagged_news.get("overseas_event", None))
            print("LLM overseas_macro score: ", tagged_news.get("overseas_macro", None))
            print("LLM china_macro score: ", tagged_news.get("china_macro", None))
            if tagged_news.get("reprint_source", None):
                print("LLM reprint_source: ", tagged_news.get("reprint_source", None))
            if tagged_news.get("company_name", None):
                print("LLM company_name: ", tagged_news.get("company_name", None), tagged_news.get("company_news", None))
            print("ETF NAMES: ", tagged_news.get("etf_names", []))
            if tagged_news.get("tdx_industry", None):
                print("LLM_max tdx_industry: ", tagged_news.get("tdx_industry", None))
            if tagged_news.get("tdx_industry_confidence", None):
                print("LLM_max tdx_industry_confidence: ", tagged_news.get("tdx_industry_confidence", None))
            if tagged_news.get("stock_names", None):
                print("LLM_max stock codes & names: ", tagged_news.get("stock_codes", None), tagged_news.get("stock_names", None))
            if tagged_news.get("political_sensitivity", None):
                print("LLM_max political_sensitivity: ", tagged_news.get("political_sensitivity", None))
            if tagged_news.get("llm_abstract", None):
                print("LLM_max llm_abstract: ", tagged_news.get("llm_abstract", None)[:30])
        return tagged_news, True
    except Exception as e:
        print(f"处理消息时出错: {str(e)}")
        # Processing failed: un-claim the ID so the message can be retried.
        # NOTE(review): if the exception fired before id_str was assigned
        # (e.g. missing "id" key), referencing id_str here raises NameError.
        with processed_ids_lock:
            if id_str in processed_ids:
                processed_ids.remove(id_str)
        return None, False


def process_message_batch(batch):
    """Process one batch of messages in parallel and publish results to MQ."""
    start_time = time.time()
    """并行处理一批消息"""  # NOTE(review): misplaced docstring — a no-op string statement
    results = []
    # Fan the batch out over the thread pool.
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = []
        for data in batch:
            futures.append(executor.submit(process_single_message, data))
        for future in futures:
            try:
                result, success = future.result()
                if result:
                    results.append(result)
            except Exception as e:
                print(f"处理消息时发生异常: {str(e)}")
    # Publish the successful results to MQ (failures are logged, not retried).
    for result in results:
        try:
            send_mq(result)
        except Exception as e:
            print(f"发送消息到MQ失败: {str(e)}")
    duration = time.time() - start_time
    print(f"批量处理 {len(batch)} 条消息, 耗时: {duration:.2f}s, "
          f"平均: {duration/len(batch):.3f}s/条")


def message_callback(ch, method, properties, body):
    """RabbitMQ delivery callback: parse and enqueue only (no processing here)."""
    try:
        data = json.loads(body)
        # Keep the delivery_tag with the payload so the batch loop can ack it.
        message_queue.put((data, method.delivery_tag))
    except Exception as e:
        print(f"消息处理失败: {str(e)}")
        # Reject the message without requeueing it.
        ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False)


def create_connection():
    """Create and return a blocking RabbitMQ connection (localhost)."""
    credentials = pika.PlainCredentials(mq_user, mq_password)
    return pika.BlockingConnection(
        pika.ConnectionParameters(
            host="localhost",
            credentials=credentials,
            heartbeat=600,
            connection_attempts=3,
            retry_delay=5  # retry delay: 5 seconds
        )
    )


def start_consumer():
    """Run the MQ consumer loop (batched variant); reconnects on failure."""
    while True:
        try:
            connection = create_connection()
            channel = connection.channel()
            # QoS: prefetch enough deliveries for a few batches.
            channel.basic_qos(prefetch_count=BATCH_SIZE * 3)
            channel.exchange_declare(
                exchange="zzck_exchange",
                exchange_type="fanout"
            )
            # Declare and bind the queue.
            res = channel.queue_declare(queue="to_ai")
            # res = channel.queue_declare(queue='', exclusive=True)
            mq_queue = res.method.queue
            channel.queue_bind(
                exchange="zzck_exchange",
                queue=mq_queue,
            )
            # Start consuming with manual acknowledgement.
            channel.basic_consume(
                queue=mq_queue,
                on_message_callback=message_callback,
                auto_ack=False
            )
            print(f"消费者已启动,批量大小: {BATCH_SIZE}, 工作线程: {MAX_WORKERS}, 等待消息...")
            last_process_time = time.time()
            # Main loop.
            while True:
                # Pump network events without blocking for long.
                connection.process_data_events(time_limit=0.1)
                current_time = time.time()
                queue_size = message_queue.qsize()
                # Dual trigger: a full batch is ready, OR the interval elapsed
                # with at least MIN_BATCH_SIZE messages waiting.
                if queue_size >= BATCH_SIZE or \
                        (current_time - last_process_time >= PROCESS_INTERVAL and queue_size >= MIN_BATCH_SIZE):
                    batch = []
                    delivery_tags = []
                    # Drain up to BATCH_SIZE messages from the queue.
                    while not message_queue.empty() and len(batch) < BATCH_SIZE:
                        data, delivery_tag = message_queue.get()
                        batch.append(data)
                        delivery_tags.append(delivery_tag)
                    if batch:
                        # Process, then ack each delivery.
                        process_message_batch(batch)
                        for tag in delivery_tags:
                            channel.basic_ack(tag)
                        last_process_time = current_time
                # Safety valve: a small backlog must not wait forever.
                elif current_time - last_process_time >= PROCESS_INTERVAL * 5 and queue_size > 0:
                    # Flush everything that is left.
                    batch = []
                    delivery_tags = []
                    while not message_queue.empty():
                        data, delivery_tag = message_queue.get()
                        batch.append(data)
                        delivery_tags.append(delivery_tag)
                    if batch:
                        process_message_batch(batch)
                        for tag in delivery_tags:
                            channel.basic_ack(tag)
                        last_process_time = current_time
                # Break out to reconnect if the connection dropped.
                if not connection or connection.is_closed:
                    break
        except pika.exceptions.ConnectionClosedByBroker:
            print("连接被代理关闭,将在5秒后重试...")
            time.sleep(5)
        except pika.exceptions.AMQPConnectionError:
            print("连接失败,将在10秒后重试...")
            time.sleep(10)
        except KeyboardInterrupt:
            print("消费者被用户中断")
            try:
                if connection and connection.is_open:
                    connection.close()
            except:
                pass
            break
        except Exception as e:
            print(f"消费者异常: {str(e)}")
            print("将在15秒后重试...")
            time.sleep(15)
        finally:
            # Best-effort close; ignore errors on an already-dead connection.
            try:
                if connection and connection.is_open:
                    connection.close()
            except:
                pass


if __name__ == "__main__":
    start_consumer()