265 lines
10 KiB
Python
265 lines
10 KiB
Python
|
|
import pika
|
|||
|
|
import json
|
|||
|
|
import logging
|
|||
|
|
import time
|
|||
|
|
import os
|
|||
|
|
import threading
|
|||
|
|
from concurrent.futures import ThreadPoolExecutor
|
|||
|
|
from queue import Queue
|
|||
|
|
from config import *
|
|||
|
|
from llm_process import send_mq, get_label
|
|||
|
|
|
|||
|
|
# Media authority scores keyed by outlet name, loaded once at import time
# from media_score.txt ("<media>\t<score>" per line). Malformed lines are
# reported and skipped so one bad row does not abort startup.
media_score = {}
with open("media_score.txt", "r", encoding="utf-8") as f:
    for raw_line in f:
        entry = raw_line.strip()
        if not entry:
            continue
        try:
            media, score = entry.split("\t")
            media_score[media.strip()] = int(score)
        except ValueError as e:
            print(f"解析错误: {e},行内容: {entry}")
            continue

# Idempotency store: ids of already-processed messages. Guarded by a lock
# because the batch worker threads read and mutate it concurrently.
processed_ids = set()
processed_ids_lock = threading.Lock()

# Hand-off buffer between the MQ callback (producer) and the batch loop (consumer).
message_queue = Queue()
BATCH_SIZE = 24        # messages handled per batch
MAX_WORKERS = 24       # thread-pool size used for a batch
MIN_BATCH_SIZE = 12    # minimum queued messages before an interval-triggered flush
PROCESS_INTERVAL = 10  # flush interval in seconds
|
|||
|
|
|
|||
|
|
def process_single_message(data):
    """Score one news message and attach the derived fields.

    Args:
        data: decoded message dict. Must contain "id" and "input_date";
            may contain "CN_content" and a nested "c" category structure
            carrying the source media name.

    Returns:
        (tagged_news, True) on success, (None, True) when the message is a
        duplicate and was skipped, or (None, False) on failure.
    """
    id_str = None  # bound before try so the except handler can reference it safely
    try:
        id_str = str(data["id"])
        input_date = data["input_date"]

        # Idempotency check: duplicates are skipped but count as handled.
        with processed_ids_lock:
            if id_str in processed_ids:
                print(f"跳过已处理的消息: {id_str}")
                return None, True
            # Mark as processed up front to block concurrent duplicates.
            processed_ids.add(id_str)
            # Cap memory: clearing also drops the id just added, so a very
            # late duplicate could be reprocessed (accepted trade-off).
            if len(processed_ids) > 10000:
                processed_ids.clear()

        content = data.get('CN_content', "").strip()

        # Dig the source media name out of the nested category structure,
        # falling back to "其他" ("other") at every missing level.
        source = "其他"
        category_data = data.get('c', [{}])
        if category_data:
            b_data = category_data[0].get('b', [{}])
            if b_data:
                d_data = b_data[0].get('d', [{}])
                if d_data:
                    source = d_data[0].get('sourcename', "其他")

        # Unknown media default to an authority score of 5.
        source_impact = media_score.get(source, 5)
        tagged_news = get_label(content, source)
        public_opinion_score = tagged_news.get("public_opinion_score", 30)  # content quality score
        China_factor = tagged_news.get("China_factor", 0.2)  # relevance to the Chinese stock market

        # Weighted overall news score (weights tuned offline), two decimals.
        news_score = round(source_impact * 0.04 + public_opinion_score * 0.25 + China_factor * 35, 2)

        industry_confidence = tagged_news.get("industry_confidence", [])
        industry_score = [round(c * news_score, 2) for c in industry_confidence]
        concept_confidence = tagged_news.get("concept_confidence", [])
        concept_score = [round(c * news_score, 2) for c in concept_confidence]

        # Fix: write the two-decimal confidences back into tagged_news —
        # previously the rounded lists were computed but never stored, so the
        # displayed confidences were left unrounded.
        tagged_news["industry_confidence"] = [round(c, 2) for c in industry_confidence]
        tagged_news["concept_confidence"] = [round(c, 2) for c in concept_confidence]

        tagged_news["source"] = source
        tagged_news["source_impact"] = source_impact
        tagged_news["industry_score"] = industry_score
        tagged_news["concept_score"] = concept_score
        tagged_news["news_score"] = news_score
        tagged_news["id"] = id_str

        print(tagged_news["id"], tagged_news["title"], tagged_news["news_score"], tagged_news["industry_label"], input_date)
        return tagged_news, True

    except Exception as e:
        print(f"处理消息时出错: {str(e)}")
        # Roll back the idempotency mark so the message can be retried.
        # id_str may still be None if data["id"] itself was missing.
        with processed_ids_lock:
            if id_str is not None and id_str in processed_ids:
                processed_ids.remove(id_str)
        return None, False
|
|||
|
|
|
|||
|
|
def process_message_batch(batch):
    """Process a batch of messages in parallel and publish the results to MQ.

    Args:
        batch: list of decoded message dicts.
    """
    # Fix: the original placed this docstring after the first statement,
    # where it is a bare string expression, not a docstring.
    if not batch:
        return  # nothing to do; also avoids division by zero in the summary

    start_time = time.time()
    results = []

    # Fan the messages out across the thread pool.
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = [executor.submit(process_single_message, data) for data in batch]
        for future in futures:
            try:
                result, success = future.result()
                if result:
                    results.append(result)
            except Exception as e:
                print(f"处理消息时发生异常: {str(e)}")

    # Publish successfully processed results; a send failure only affects
    # that one result.
    for result in results:
        try:
            send_mq(result)
        except Exception as e:
            print(f"发送消息到MQ失败: {str(e)}")

    duration = time.time() - start_time
    print(f"批量处理 {len(batch)} 条消息, 耗时: {duration:.2f}s, "
          f"平均: {duration/len(batch):.3f}s/条")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def message_callback(ch, method, properties, body):
    """RabbitMQ delivery callback: decode the payload and enqueue it.

    Heavy processing happens in the batch worker; this only parses the JSON
    body and stores it, paired with its delivery tag, on the in-process
    queue so the worker can ack it after successful processing. Messages
    that cannot be handled are rejected without requeueing.
    """
    try:
        payload = json.loads(body)
        message_queue.put((payload, method.delivery_tag))
    except Exception as e:
        print(f"消息处理失败: {str(e)}")
        # Reject the delivery and do not requeue it.
        ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False)
|
|||
|
|
|
|||
|
|
def create_connection():
    """Open and return a blocking RabbitMQ connection to the local broker.

    Credentials come from the config module (mq_user / mq_password); the
    connection uses a long heartbeat and pika's built-in connect retries.
    """
    params = pika.ConnectionParameters(
        host="localhost",
        credentials=pika.PlainCredentials(mq_user, mq_password),
        heartbeat=600,           # keep long-idle connections alive
        connection_attempts=3,   # retry the initial connect up to 3 times
        retry_delay=5,           # seconds between connect attempts
    )
    return pika.BlockingConnection(params)
|
|||
|
|
|
|||
|
|
def start_consumer():
    """Run the batch-mode MQ consumer loop, reconnecting on failure.

    Consumes from the fanout exchange "zzck_exchange" via queue "to_ai",
    buffers deliveries through the module-level message_queue, and flushes
    them in batches driven by BATCH_SIZE / MIN_BATCH_SIZE / PROCESS_INTERVAL.
    Runs until interrupted by the user (Ctrl-C).
    """

    def drain_queue(limit=None):
        # Pull buffered (message, delivery_tag) pairs off message_queue;
        # drain everything when limit is None.
        batch, tags = [], []
        while not message_queue.empty() and (limit is None or len(batch) < limit):
            data, tag = message_queue.get()
            batch.append(data)
            tags.append(tag)
        return batch, tags

    def handle_batch(channel, batch, tags):
        # Process a drained batch, then ack every delivery in it.
        if not batch:
            return
        process_message_batch(batch)
        for tag in tags:
            channel.basic_ack(tag)

    while True:
        # Fix: bind connection before the try block so the except/finally
        # handlers never hit a NameError when create_connection() itself
        # raises on the first iteration.
        connection = None
        try:
            connection = create_connection()
            channel = connection.channel()

            # Prefetch enough messages to keep a full batch (plus slack) buffered.
            channel.basic_qos(prefetch_count=BATCH_SIZE * 3)

            channel.exchange_declare(
                exchange="zzck_exchange",
                exchange_type="fanout"
            )

            # Declare and bind the consumer queue.
            res = channel.queue_declare(queue="to_ai")
            mq_queue = res.method.queue
            channel.queue_bind(
                exchange="zzck_exchange",
                queue=mq_queue,
            )

            # Manual acks: a message is acked only after its batch succeeds.
            channel.basic_consume(
                queue=mq_queue,
                on_message_callback=message_callback,
                auto_ack=False
            )

            print(f"消费者已启动,批量大小: {BATCH_SIZE}, 工作线程: {MAX_WORKERS}, 等待消息...")

            last_process_time = time.time()
            # Main loop: interleave network I/O with batch flushing.
            while True:
                # Service network events without blocking the batching logic.
                connection.process_data_events(time_limit=0.1)

                current_time = time.time()
                queue_size = message_queue.qsize()

                # Dual trigger: a full batch, or enough messages waiting past
                # the processing interval.
                if queue_size >= BATCH_SIZE or \
                        (current_time - last_process_time >= PROCESS_INTERVAL and queue_size >= MIN_BATCH_SIZE):
                    batch, tags = drain_queue(BATCH_SIZE)
                    handle_batch(channel, batch, tags)
                    last_process_time = current_time

                # Safety valve: a trickle below MIN_BATCH_SIZE must still be
                # processed eventually, so flush everything after 5 intervals.
                elif current_time - last_process_time >= PROCESS_INTERVAL * 5 and queue_size > 0:
                    batch, tags = drain_queue()
                    handle_batch(channel, batch, tags)
                    last_process_time = current_time

                # Leave the inner loop (and reconnect) if the connection dropped.
                if not connection or connection.is_closed:
                    break

        except pika.exceptions.ConnectionClosedByBroker:
            print("连接被代理关闭,将在5秒后重试...")
            time.sleep(5)
        except pika.exceptions.AMQPConnectionError:
            print("连接失败,将在10秒后重试...")
            time.sleep(10)
        except KeyboardInterrupt:
            print("消费者被用户中断")
            try:
                if connection and connection.is_open:
                    connection.close()
            except Exception:
                pass
            break
        except Exception as e:
            print(f"消费者异常: {str(e)}")
            print("将在15秒后重试...")
            time.sleep(15)
        finally:
            # Best-effort close; narrowed from a bare except so that
            # KeyboardInterrupt/SystemExit are not swallowed here.
            try:
                if connection and connection.is_open:
                    connection.close()
            except Exception:
                pass
|
|||
|
|
|
|||
|
|
# Script entry point: run the consumer loop until interrupted.
if __name__ == "__main__":
    start_consumer()
|