# zzck/mqreceivefromllm.py

# Consume tagged news messages from RabbitMQ and write them to the databases.
import datetime
import json
import time

import pika
import pymysql
import requests
from elasticsearch import Elasticsearch

from config import *  # supplies the MQ, MySQL, ES and jx_adr settings
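
# Illustrative shape of the message body this consumer expects. The field
# names are inferred from the handlers below, not from a published schema:
#
#     {
#         "id": "abc123", "title": "...", "abstract": "...",
#         "rewrite_content": "...",
#         "industry_label": [...], "industry_confidence": [...], "industry_score": [...],
#         "concept_label": [...], "concept_confidence": [...], "concept_score": [...],
#         "public_opinion_score": 10, "China_factor": 0.1,
#         "source": "...", "source_impact": 5, "news_score": 85.0
#     }
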
def message_callback(ch, method, properties, body):
    """Message-handling callback."""
    try:
        data = json.loads(body)
        news_score = data.get('news_score', -1)
        if news_score < 0:
            ch.basic_ack(delivery_tag=method.delivery_tag)
            return
        # Business logic: persist the message to MySQL
        write_to_mysql(data)
        # Mirror the data into Elasticsearch
        write_to_es(data)
        # Push qualifying items to the featured-news endpoint
        write_to_news(data)
        # Acknowledge manually only after all writes have run
        ch.basic_ack(delivery_tag=method.delivery_tag)
    except Exception as e:
        print(f"Message handling failed: {e}")
        # Reject without requeueing; the message is dropped unless the queue
        # has a dead-letter exchange configured
        ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False)
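
# Note: with requeue=False a failed message is discarded unless the queue has
# a dead-letter exchange. If dropped messages should be kept for inspection,
# the queue could be declared with a DLX argument (sketch; the exchange name
# "zzck_llm_dlx" is an assumption, not taken from this file):
#
#     channel.queue_declare(
#         queue="from_ai_to_mysql",
#         arguments={"x-dead-letter-exchange": "zzck_llm_dlx"},
#     )
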
def write_to_news(data):
    """Call the featured-news API for high-scoring items."""
    news_score = data.get('news_score', 0.0)
    if float(news_score) < 80:  # skip messages with news_score below 80
        return
    news_id = data.get('id', "")
    # jx_adr (from config) is a URL template containing the literal
    # placeholder "news_id"
    adr = jx_adr.replace("news_id", news_id)
    print(f"Endpoint URL: {adr}")
    response = requests.get(adr, timeout=10)
    if response.status_code != 200:
        print(f"News id: {news_id}, score: {news_score}, featured API call failed, status code: {response.status_code}")
        return
    print(f"News id: {news_id}, score: {news_score}, featured API call succeeded")
def write_to_es(data):
    """Write the tag payload into Elasticsearch."""
    # A new client is built per message; see the note after this function
    # for a shared-client alternative.
    es = Elasticsearch(
        [f"http://{ES_HOST}:{ES_PORT}"],  # scheme is included in the host URL
        basic_auth=(ES_USER, ES_PASSWORD)
    )
    news_id = data.get('id', "")
    es.update(
        index="news_info",
        id=news_id,
        doc={
            "news_tags": {
                "id": news_id,
                "abstract": data.get('abstract', ""),
                "title": data.get('title', ""),
                "rewrite_content": data.get('rewrite_content', ""),
                "industry_label": data.get('industry_label', []),
                "industry_confidence": data.get('industry_confidence', []),
                "industry_score": data.get('industry_score', []),
                "concept_label": data.get('concept_label', []),
                "concept_confidence": data.get('concept_confidence', []),
                "concept_score": data.get('concept_score', []),
                "public_opinion_score": data.get('public_opinion_score', 10),
                "China_factor": data.get('China_factor', 0.1),
                "source": data.get('source', "其他"),  # "其他" = "Other"
                "source_impact": data.get('source_impact', 5),
                "news_score": data.get('news_score', 0.0),
                "news_id": news_id,
                "deleted": '0',
                "create_time": datetime.datetime.now(),
                "update_time": datetime.datetime.now()
            }
        }
    )
    print(f"news_id: {news_id}, score: {data.get('news_score', 0.0)}, written to ES")
def write_to_mysql(data):
    """Insert the tag payload into the news_tags table."""
    conn = pymysql.connect(
        host=MYSQL_HOST_APP,
        port=MYSQL_PORT_APP,
        user=MYSQL_USER_APP,
        password=MYSQL_PASSWORD_APP,
        db=MYSQL_DB_APP,
        charset='utf8mb4'
    )
    try:
        with conn.cursor() as cursor:
            # Flatten the JSON payload into the news_tags columns;
            # list-valued fields are stored as JSON strings
            sql = """INSERT INTO news_tags
            (abstract, title, rewrite_content, industry_label, industry_confidence, industry_score, concept_label, concept_confidence, concept_score, public_opinion_score, China_factor, source, source_impact, news_score, news_id)
            VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)"""
            values = (
                data.get('abstract', ""),
                data.get('title', ""),
                data.get('rewrite_content', ""),
                json.dumps(data.get('industry_label', [])),
                json.dumps(data.get('industry_confidence', [])),
                json.dumps(data.get('industry_score', [])),
                json.dumps(data.get('concept_label', [])),
                json.dumps(data.get('concept_confidence', [])),
                json.dumps(data.get('concept_score', [])),
                data.get('public_opinion_score', 10),
                data.get('China_factor', 0.1),
                data.get('source', "其他"),  # "其他" = "Other"
                data.get('source_impact', 5),
                data.get('news_score', 0.0),
                data.get('id', "")
            )
            cursor.execute(sql, values)
        conn.commit()
        news_id = data.get('id', "")
        industry_label = data.get('industry_label', [])
        concept_label = data.get('concept_label', [])
        print(f"{news_id} {industry_label} {concept_label}, written to news_tags")
    except Exception as e:
        # The exception is swallowed, so the caller still acks the message
        print(f"Failed to write news_tags: {e}")
    finally:
        conn.close()
    return True  # return value is not checked by the caller
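
# If a message is ever redelivered, the INSERT above creates a duplicate row.
# Assuming news_id carries a UNIQUE index (not confirmed by this file), an
# upsert would make the write idempotent, e.g.:
#
#     sql = """INSERT INTO news_tags (..., news_id) VALUES (..., %s)
#              ON DUPLICATE KEY UPDATE news_score = VALUES(news_score)"""
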
def create_connection():
    """Create and return a RabbitMQ connection."""
    credentials = pika.PlainCredentials(mq_user, mq_password)
    return pika.BlockingConnection(
        pika.ConnectionParameters(
            host="localhost",  # broker host is hardcoded here
            credentials=credentials,
            heartbeat=600,
            connection_attempts=3,
            retry_delay=5  # wait 5 seconds between connection attempts
        )
    )
def start_consumer():
    """Start the MQ consumer."""
    while True:  # loop instead of recursion to avoid recursion-depth issues
        connection = None  # ensure the name exists for the finally block
        try:
            connection = create_connection()
            channel = connection.channel()
            # QoS: fetch only one message at a time
            channel.basic_qos(prefetch_count=1)
            channel.exchange_declare(
                exchange="zzck_llm_exchange",
                exchange_type="fanout"
            )
            # Declare the queue (note: not durable; pass durable=True if
            # messages must survive a broker restart)
            res = channel.queue_declare(
                queue="from_ai_to_mysql"
            )
            mq_queue = res.method.queue
            channel.queue_bind(
                exchange="zzck_llm_exchange",
                queue=mq_queue,
            )
            # Start consuming with automatic acknowledgement disabled
            channel.basic_consume(
                queue=mq_queue,
                on_message_callback=message_callback,
                auto_ack=False  # manual ack happens in message_callback
            )
            print("Consumer started, waiting for messages...")
            channel.start_consuming()
        except pika.exceptions.ConnectionClosedByBroker:
            # The broker closed the connection, possibly a transient error
            print("Connection closed by broker, retrying in 5 seconds...")
            time.sleep(5)
        except pika.exceptions.AMQPConnectionError:
            # Connection-level error
            print("Connection failed, retrying in 10 seconds...")
            time.sleep(10)
        except KeyboardInterrupt:
            print("Consumer interrupted by user")
            break  # the finally block below closes the connection
        except Exception as e:
            print(f"Consumer error: {e}")
            print("Retrying in 15 seconds...")
            time.sleep(15)
        finally:
            try:
                if connection and connection.is_open:
                    connection.close()
            except Exception:
                pass
if __name__ == "__main__":
    start_consumer()
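
# For local testing, a matching publisher might look like this (sketch; it
# assumes the same broker, credentials, and exchange declared above):
#
#     import json
#     import pika
#     creds = pika.PlainCredentials(mq_user, mq_password)
#     conn = pika.BlockingConnection(
#         pika.ConnectionParameters(host="localhost", credentials=creds))
#     ch = conn.channel()
#     ch.exchange_declare(exchange="zzck_llm_exchange", exchange_type="fanout")
#     ch.basic_publish(exchange="zzck_llm_exchange", routing_key="",
#                      body=json.dumps({"id": "test-1", "news_score": 85}))
#     conn.close()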