import pika
import json
import logging
import time
import os
import threading
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from queue import Queue

from config import *
from llm_process import send_mq, get_label

# Global map: media source name -> authority score, loaded once at startup.
# File format: one "<media>\t<score>" pair per line; malformed lines are
# reported and skipped.
media_score = {}
with open("media_score.txt", "r", encoding="utf-8") as f:
    for raw_line in f:
        raw_line = raw_line.strip()
        if not raw_line:
            continue
        try:
            media, score = raw_line.split("\t")
            media_score[media.strip()] = int(score)
        except ValueError as e:
            print(f"解析错误: {e},行内容: {raw_line}")
            continue

# Idempotency store: IDs of messages already processed (thread-safe).
# A FIFO deque backs the set so that, once MAX_PROCESSED_IDS is exceeded,
# only the *oldest* entries are evicted.  (The previous implementation
# cleared the whole set, which momentarily disabled duplicate detection
# for every in-flight message.)
MAX_PROCESSED_IDS = 10000
processed_ids = set()
_processed_order = deque()              # insertion order, drives eviction
processed_ids_lock = threading.Lock()   # guards processed_ids / _processed_order

# Queue buffering incoming messages for batch processing.
message_queue = Queue()
BATCH_SIZE = 24        # messages per batch
MAX_WORKERS = 24       # thread-pool size for parallel message processing
MIN_BATCH_SIZE = 12    # minimum queue depth for the interval-based trigger
PROCESS_INTERVAL = 10  # seconds between interval-triggered batches


def _mark_processed(id_str):
    """Atomically record *id_str* as processed.

    Returns False if the ID was already recorded (duplicate message),
    True otherwise.  Evicts the oldest entries once the store grows
    beyond MAX_PROCESSED_IDS.
    """
    with processed_ids_lock:
        if id_str in processed_ids:
            return False
        processed_ids.add(id_str)
        _processed_order.append(id_str)
        while len(_processed_order) > MAX_PROCESSED_IDS:
            processed_ids.discard(_processed_order.popleft())
        return True


def _unmark_processed(id_str):
    """Remove *id_str* from the idempotency store (called on failure so the
    message can be retried later)."""
    with processed_ids_lock:
        processed_ids.discard(id_str)
        try:
            _processed_order.remove(id_str)
        except ValueError:
            pass  # already evicted by the FIFO bound


def _extract_source(data):
    """Walk the nested c -> b -> d structure of *data* for the source name.

    Falls back to "其他" ("other") whenever any level is missing or empty.
    NOTE(review): schema inferred from the original traversal; only the
    first element of each list is ever inspected.
    """
    source = "其他"
    category_data = data.get('c', [{}])
    if category_data:
        b_data = category_data[0].get('b', [{}])
        if b_data:
            d_data = b_data[0].get('d', [{}])
            if d_data:
                source = d_data[0].get('sourcename', "其他")
    return source


def process_single_message(data):
    """Deduplicate, score, and tag one message.

    Returns (tagged_news, True) on success, (None, True) for duplicates
    (processed before — nothing to send), and (None, False) on failure.
    """
    # Bind before the try block so the except handler can always reference
    # it — previously a failure in data["id"] raised NameError in the handler.
    id_str = None
    try:
        id_str = str(data["id"])
        input_date = data["input_date"]

        # Idempotency: mark first so concurrent workers cannot both process
        # the same ID; unmarked again below if processing fails.
        if not _mark_processed(id_str):
            print(f"跳过已处理的消息: {id_str}")
            return None, True

        content = data.get('CN_content', "").strip()
        source = _extract_source(data)
        source_impact = media_score.get(source, 5)

        tagged_news = get_label(content, source)
        public_opinion_score = tagged_news.get("public_opinion_score", 30)  # news-quality score
        China_factor = tagged_news.get("China_factor", 0.2)  # relevance to the Chinese stock market

        # Weighted overall score; weights as in the original formula.
        news_score = source_impact * 0.04 + public_opinion_score * 0.25 + China_factor * 35
        news_score = round(news_score, 2)

        industry_confidence = tagged_news.get("industry_confidence", [])
        industry_score = [round(x * news_score, 2) for x in industry_confidence]
        concept_confidence = tagged_news.get("concept_confidence", [])
        concept_score = [round(x * news_score, 2) for x in concept_confidence]

        # Displayed confidences are rounded to two decimals.
        industry_confidence = [round(x, 2) for x in industry_confidence]
        concept_confidence = [round(x, 2) for x in concept_confidence]

        tagged_news["source"] = source
        tagged_news["source_impact"] = source_impact
        tagged_news["industry_score"] = industry_score
        tagged_news["concept_score"] = concept_score
        tagged_news["news_score"] = news_score
        tagged_news["id"] = id_str
        tagged_news["industry_confidence"] = industry_confidence
        tagged_news["concept_confidence"] = concept_confidence

        print(tagged_news["id"], tagged_news["title"], tagged_news["news_score"],
              tagged_news["industry_label"], input_date)
        return tagged_news, True
    except Exception as e:
        print(f"处理消息时出错: {str(e)}")
        # Processing failed: release the ID so the message can be retried.
        if id_str is not None:
            _unmark_processed(id_str)
        return None, False


def process_message_batch(batch):
    """Process *batch* in parallel via a thread pool, then publish results.

    Send failures are reported per message and do not abort the batch.
    """
    if not batch:
        return  # guard: also avoids ZeroDivisionError in the stats below
    start_time = time.time()
    results = []
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = [executor.submit(process_single_message, data) for data in batch]
        for future in futures:
            try:
                result, success = future.result()
                if result:
                    results.append(result)
            except Exception as e:
                print(f"处理消息时发生异常: {str(e)}")

    # Publish successfully tagged messages downstream.
    for result in results:
        try:
            send_mq(result)
        except Exception as e:
            print(f"发送消息到MQ失败: {str(e)}")

    duration = time.time() - start_time
    print(f"批量处理 {len(batch)} 条消息, 耗时: {duration:.2f}s, "
          f"平均: {duration/len(batch):.3f}s/条")


def message_callback(ch, method, properties, body):
    """pika consume callback: parse the body and enqueue it for batching.

    The delivery tag travels with the payload so the main loop can ack
    after the batch is processed; unparseable messages are nacked without
    requeue.
    """
    try:
        data = json.loads(body)
        message_queue.put((data, method.delivery_tag))
    except Exception as e:
        print(f"消息处理失败: {str(e)}")
        ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False)


def create_connection():
    """Create and return a blocking RabbitMQ connection to localhost.

    Credentials (mq_user / mq_password) come from config; pika retries the
    initial connect up to 3 times with a 5-second delay.
    """
    credentials = pika.PlainCredentials(mq_user, mq_password)
    return pika.BlockingConnection(
        pika.ConnectionParameters(
            host="localhost",
            credentials=credentials,
            heartbeat=600,
            connection_attempts=3,
            retry_delay=5,
        )
    )


def _drain_and_process(channel, limit=None):
    """Pull up to *limit* messages (all if None) off message_queue, process
    them as one batch, and ack them on *channel*.

    Returns True if any message was processed, False if the queue was empty.
    """
    batch = []
    delivery_tags = []
    while not message_queue.empty() and (limit is None or len(batch) < limit):
        data, delivery_tag = message_queue.get()
        batch.append(data)
        delivery_tags.append(delivery_tag)
    if not batch:
        return False
    process_message_batch(batch)
    # Ack only after the batch completed, so unacked messages are redelivered
    # if the consumer dies mid-batch.
    for tag in delivery_tags:
        channel.basic_ack(tag)
    return True


def start_consumer():
    """Run the batching MQ consumer loop, reconnecting on failure.

    Binds an anonymous-or-named queue "to_ai" to the fanout exchange
    "zzck_exchange" and consumes with manual acks.  Batches are triggered
    either by size (BATCH_SIZE) or by elapsed time (PROCESS_INTERVAL).
    """
    while True:
        # Bind before the try block so `finally` can always reference it —
        # previously a failure inside create_connection() raised NameError.
        connection = None
        try:
            connection = create_connection()
            channel = connection.channel()

            # Prefetch enough messages to keep a few batches in flight.
            channel.basic_qos(prefetch_count=BATCH_SIZE * 3)
            channel.exchange_declare(
                exchange="zzck_exchange",
                exchange_type="fanout",
            )
            res = channel.queue_declare(queue="to_ai")
            mq_queue = res.method.queue
            channel.queue_bind(exchange="zzck_exchange", queue=mq_queue)

            # Manual acks: messages are acked only after batch processing.
            channel.basic_consume(
                queue=mq_queue,
                on_message_callback=message_callback,
                auto_ack=False,
            )
            print(f"消费者已启动,批量大小: {BATCH_SIZE}, 工作线程: {MAX_WORKERS}, 等待消息...")

            last_process_time = time.time()
            while True:
                # Pump pika's I/O loop without blocking the batch logic.
                connection.process_data_events(time_limit=0.1)
                current_time = time.time()
                queue_size = message_queue.qsize()

                # Trigger 1: a full batch is ready, or a decent backlog has
                # waited at least PROCESS_INTERVAL seconds.
                if queue_size >= BATCH_SIZE or (
                    current_time - last_process_time >= PROCESS_INTERVAL
                    and queue_size >= MIN_BATCH_SIZE
                ):
                    if _drain_and_process(channel, limit=BATCH_SIZE):
                        last_process_time = current_time
                # Trigger 2: a small backlog has been stuck far too long —
                # flush everything so nothing waits forever.
                elif (current_time - last_process_time >= PROCESS_INTERVAL * 5
                      and queue_size > 0):
                    if _drain_and_process(channel):
                        last_process_time = current_time

                if not connection or connection.is_closed:
                    break  # fall through to the outer loop and reconnect
        except pika.exceptions.ConnectionClosedByBroker:
            print("连接被代理关闭,将在5秒后重试...")
            time.sleep(5)
        except pika.exceptions.AMQPConnectionError:
            print("连接失败,将在10秒后重试...")
            time.sleep(10)
        except KeyboardInterrupt:
            print("消费者被用户中断")
            break  # `finally` below closes the connection
        except Exception as e:
            print(f"消费者异常: {str(e)}")
            print("将在15秒后重试...")
            time.sleep(15)
        finally:
            try:
                if connection and connection.is_open:
                    connection.close()
            except Exception:
                pass  # best-effort close; nothing more to do on failure


if __name__ == "__main__":
    start_consumer()