20260305 对近期改动做一次线上备份

This commit is contained in:
朱思南 2026-03-05 17:34:29 +08:00
parent 24b67815a0
commit d67e85ac87
5 changed files with 1042 additions and 33 deletions

244
backfill_mysql.py Normal file
View File

@ -0,0 +1,244 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
申万行业热度数据回溯补录工具
读取本地CSV文件存入MySQL sw_heat_daily表
用于历史数据补录或重新导入
"""
import pandas as pd
import pymysql
import os
import re
from datetime import datetime
import argparse
def extract_date_from_filename(filename):
    """Extract the trade date embedded in a CSV file name.

    Example: ``heat_ranking_2026-03-04_5d.csv`` -> (``'20260304'``, ``'2026-03-04'``).

    :param filename: file name (or path) containing a YYYY-MM-DD date
    :return: tuple of (compact ``YYYYMMDD`` string, original ``YYYY-MM-DD`` string)
    :raises ValueError: if no date pattern can be found in the name
    """
    # Match a YYYY-MM-DD date anywhere in the name
    pattern = r'(\d{4}-\d{2}-\d{2})'
    match = re.search(pattern, filename)
    if match:
        date_str = match.group(1)  # e.g. 2026-03-04
        # Compact to YYYYMMDD for the sw_heat_daily.calc_date column
        calc_date = date_str.replace('-', '')  # 20260304
        return calc_date, date_str
    else:
        # Bug fix: the message previously contained a literal placeholder
        # instead of interpolating the offending file name.
        raise ValueError(f"无法从文件名提取日期: {filename}")
def read_heat_csv(file_path):
    """Load a heat-ranking CSV file and validate its column set.

    :param file_path: path to a CSV produced by the heat tracker
    :return: pandas DataFrame containing all required factor/score columns
    :raises FileNotFoundError: if the file does not exist
    :raises ValueError: if required columns are still missing after the
        legacy-name compatibility pass
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"文件不存在: {file_path}")
    print(f"📂 读取文件: {file_path}")
    df = pd.read_csv(file_path, encoding='utf-8-sig')

    # Columns the downstream MySQL insert depends on.
    required_cols = [
        'code', 'name', 'total_return', 'avg_turnover',
        'volume_ratio', 'momentum', 'total_return_score',
        'avg_turnover_score', 'volume_ratio_score',
        'momentum_score', 'heat_score',
    ]
    if any(col not in df.columns for col in required_cols):
        # Backward compatibility with legacy Chinese column headers.
        legacy_names = {
            '行业代码': 'code',
            '行业名称': 'name',
            # add further mappings here when needed
        }
        for legacy, canonical in legacy_names.items():
            if legacy in df.columns and canonical not in df.columns:
                df[canonical] = df[legacy]
        # Validate again after the compatibility pass.
        still_missing = [col for col in required_cols if col not in df.columns]
        if still_missing:
            raise ValueError(f"CSV文件缺少必要列: {still_missing}")
    print(f"✅ 成功读取 {len(df)} 条行业数据")
    return df
def save_to_mysql(df, calc_date, db_config, dry_run=False):
    """Persist a heat-ranking DataFrame into the MySQL ``sw_heat_daily`` table.

    Performs an upsert: the table's unique key on (calc_date, code) makes
    ``ON DUPLICATE KEY UPDATE`` overwrite existing rows for the same date.

    :param df: DataFrame with one row per industry (see read_heat_csv)
    :param calc_date: date string in YYYYMMDD format used as the partition key
    :param db_config: pymysql connection kwargs
    :param dry_run: when True, only print a preview and skip all DB work
    :return: True on success (or dry run), False if the DB operation failed
    """
    if dry_run:
        # Preview mode: show target date and the head of the data, no DB I/O.
        print(f"\n🔍 [DRY RUN] 预览模式,不实际写入数据库")
        print(f"📅 目标日期: {calc_date}")
        print(f"📊 数据预览:")
        print(df[['code', 'name', 'heat_score']].head())
        return True
    conn = None
    try:
        conn = pymysql.connect(**db_config)
        cursor = conn.cursor()
        # Check whether rows already exist for this date (overwrite case).
        cursor.execute(
            "SELECT COUNT(*) FROM sw_heat_daily WHERE calc_date = %s",
            (calc_date,)
        )
        existing = cursor.fetchone()[0]
        if existing > 0:
            print(f"⚠️ 日期 {calc_date} 已有 {existing} 条记录,将执行覆盖更新")
        # Upsert SQL; the UPDATE branch refreshes every value column and
        # bumps created_at so re-imports are visible.
        insert_sql = """
        INSERT INTO sw_heat_daily
        (calc_date, code, name, total_return, avg_turnover, volume_ratio, momentum,
         total_return_score, avg_turnover_score, volume_ratio_score, momentum_score, heat_score)
        VALUES
        (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        ON DUPLICATE KEY UPDATE
        name = VALUES(name),
        total_return = VALUES(total_return),
        avg_turnover = VALUES(avg_turnover),
        volume_ratio = VALUES(volume_ratio),
        momentum = VALUES(momentum),
        total_return_score = VALUES(total_return_score),
        avg_turnover_score = VALUES(avg_turnover_score),
        volume_ratio_score = VALUES(volume_ratio_score),
        momentum_score = VALUES(momentum_score),
        heat_score = VALUES(heat_score),
        created_at = CURRENT_TIMESTAMP;
        """
        # Convert DataFrame rows into parameter tuples for executemany.
        data_tuples = []
        for _, row in df.iterrows():
            data_tuples.append((
                calc_date,
                str(row['code']),
                str(row['name']),
                float(row['total_return']),
                float(row['avg_turnover']),
                float(row['volume_ratio']),
                float(row['momentum']),
                float(row['total_return_score']),
                float(row['avg_turnover_score']),
                float(row['volume_ratio_score']),
                float(row['momentum_score']),
                float(row['heat_score'])
            ))
        # Batch insert in a single round trip.
        cursor.executemany(insert_sql, data_tuples)
        conn.commit()
        action = "覆盖更新" if existing > 0 else "新增"
        print(f"✅ MySQL{action}成功: {calc_date}{len(data_tuples)} 条记录")
        return True
    except Exception as e:
        print(f"❌ MySQL操作失败: {e}")
        if conn:
            conn.rollback()
        return False
    finally:
        if conn:
            conn.close()
def batch_import(data_dir, db_config, pattern=None, dry_run=False):
    """Import every CSV in *data_dir* whose name matches *pattern*.

    Files are processed in sorted (chronological) order; a failure in one
    file is reported and does not stop the remaining imports.

    :param data_dir: directory to scan for heat-ranking CSV files
    :param db_config: pymysql connection kwargs forwarded to save_to_mysql
    :param pattern: regex matched against file names
        (default: ``heat_ranking_YYYY-MM-DD_*.csv``)
    :param dry_run: forwarded to save_to_mysql; preview without writing
    """
    if pattern is None:
        pattern = r'heat_ranking_\d{4}-\d{2}-\d{2}_.*\.csv'
    # NOTE(review): re.match only anchors at the start, so names with extra
    # trailing suffixes (e.g. ".csv.bak") also match — confirm if intended.
    files = [f for f in os.listdir(data_dir) if re.match(pattern, f)]
    files.sort()
    print(f"📁 发现 {len(files)} 个待导入文件")
    success_count = 0
    for filename in files:
        file_path = os.path.join(data_dir, filename)
        try:
            calc_date, _ = extract_date_from_filename(filename)
            df = read_heat_csv(file_path)
            if save_to_mysql(df, calc_date, db_config, dry_run):
                success_count += 1
        except Exception as e:
            # Bug fix: interpolate the failing file name (the message was a
            # literal placeholder before, hiding which file broke).
            print(f"❌ 处理 {filename} 失败: {e}")
            continue
    print(f"\n📊 导入完成: {success_count}/{len(files)} 个文件成功")
def main():
    """CLI entry point: import one CSV, or a whole directory, into MySQL."""
    # Database configuration. SECURITY: credentials were hard-coded in source
    # control; they can now be overridden via environment variables (the old
    # values remain as defaults for backward compatibility). The exposed
    # password should be rotated.
    config = {
        'host': os.environ.get('SW_DB_HOST', '10.127.2.207'),
        'user': os.environ.get('SW_DB_USER', 'financial_prod'),
        'password': os.environ.get('SW_DB_PASSWORD', 'mmTFncqmDal5HLRGY0BV'),
        'database': os.environ.get('SW_DB_NAME', 'reference'),
        'port': int(os.environ.get('SW_DB_PORT', '3306')),
        'charset': 'utf8mb4'
    }
    # Command-line arguments
    parser = argparse.ArgumentParser(description='申万行业热度数据回溯补录工具')
    parser.add_argument('file', nargs='?', help='单个CSV文件路径')
    parser.add_argument('--date', '-d', help='指定日期(YYYYMMDD),覆盖文件名中的日期')
    parser.add_argument('--batch', '-b', action='store_true', help='批量导入目录下所有文件')
    parser.add_argument('--dir', default='./sw_data', help='数据目录路径(默认: ./sw_data)')
    parser.add_argument('--dry-run', action='store_true', help='预览模式,不实际写入数据库')
    args = parser.parse_args()
    # Batch mode: import the whole directory and exit.
    if args.batch:
        batch_import(args.dir, config, dry_run=args.dry_run)
        return
    # Single-file mode; fall back to a sample file for convenience.
    if args.file:
        file_path = args.file
    else:
        file_path = './sw_data/heat_ranking_2026-03-04_5d.csv'
    if not os.path.exists(file_path):
        print("请提供CSV文件路径或使用 --batch 批量导入")
        parser.print_help()
        return
    try:
        # Load the CSV.
        df = read_heat_csv(file_path)
        # Resolve the target date: an explicit --date wins over the filename.
        if args.date:
            calc_date = args.date
            print(f"📅 使用指定日期: {calc_date}")
        else:
            calc_date, original_date = extract_date_from_filename(file_path)
            print(f"📅 从文件名提取日期: {original_date} -> {calc_date}")
        # Write to MySQL (or preview in dry-run mode).
        save_to_mysql(df, calc_date, config, dry_run=args.dry_run)
    except Exception as e:
        print(f"❌ 错误: {e}")
        raise


if __name__ == "__main__":
    main()

405
industry_heat_task.py Normal file
View File

@ -0,0 +1,405 @@
import akshare as ak
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import os
import time
import pymysql
import warnings
warnings.filterwarnings('ignore')
class SWIndustryHeatTracker:
    """
    Shenwan (SW) level-1 industry heat tracker.

    Pulls official post-close industry data via akshare's
    ``index_analysis_daily_sw`` endpoint, computes a composite heat score
    over a configurable lookback window with configurable factor weights,
    and can automatically persist daily results into MySQL.
    """
    def __init__(self, data_dir="./sw_data", lookback_days=5, db_config=None):
        """
        :param data_dir: local directory for caches and CSV output (auto-created)
        :param lookback_days: number of trading days in the heat window
        :param db_config: pymysql connection kwargs; None disables MySQL storage
        """
        self.data_dir = data_dir
        self.lookback_days = lookback_days
        self.db_config = db_config  # MySQL configuration (optional)
        os.makedirs(data_dir, exist_ok=True)
        # Trading calendar as a set of date strings, cached on disk
        self.trading_dates = self._load_trading_calendar()
        # SW level-1 industry code -> name mapping (31 industries)
        self.industry_map = {
            '801010': '农林牧渔', '801030': '基础化工', '801040': '钢铁',
            '801050': '有色金属', '801080': '电子', '801110': '家用电器',
            '801120': '食品饮料', '801130': '纺织服饰', '801140': '轻工制造',
            '801150': '医药生物', '801160': '公用事业', '801170': '交通运输',
            '801180': '房地产', '801200': '商贸零售', '801210': '社会服务',
            '801230': '综合', '801710': '建筑材料', '801720': '建筑装饰',
            '801730': '电力设备', '801740': '国防军工', '801750': '计算机',
            '801760': '传媒', '801770': '通信', '801780': '银行',
            '801790': '非银金融', '801880': '汽车', '801890': '机械设备',
            '801950': '煤炭', '801960': '石油石化', '801970': '环保',
            '801980': '美容护理'
        }
        # Weights of the heat-score factors (sum to 1.0)
        self.weights = {
            'return': 0.35,
            'turnover': 0.25,
            'volume_ratio': 0.25,
            'momentum': 0.15
        }
        # Create the DB table eagerly when a database is configured
        if self.db_config:
            self._init_database()

    def _init_database(self):
        """Initialize MySQL: create the ``sw_heat_daily`` table if absent.

        :raises Exception: re-raises any connection/DDL failure after logging.
        """
        conn = None
        try:
            conn = pymysql.connect(**self.db_config)
            with conn.cursor() as cursor:
                # Idempotent DDL — safe to run on every startup
                create_table_sql = """
                CREATE TABLE IF NOT EXISTS sw_heat_daily (
                    id INT AUTO_INCREMENT PRIMARY KEY,
                    calc_date VARCHAR(8) NOT NULL COMMENT '计算日期(YYYYMMDD)',
                    code VARCHAR(10) NOT NULL COMMENT '行业代码',
                    name VARCHAR(20) NOT NULL COMMENT '行业名称',
                    total_return DECIMAL(10,4) COMMENT '累计涨跌幅(%)',
                    avg_turnover DECIMAL(10,4) COMMENT '平均换手率(%)',
                    volume_ratio DECIMAL(10,4) COMMENT '量比',
                    momentum DECIMAL(5,4) COMMENT '动量(上涨天数占比)',
                    total_return_score DECIMAL(5,4) COMMENT '收益因子得分(0-1)',
                    avg_turnover_score DECIMAL(5,4) COMMENT '换手因子得分(0-1)',
                    volume_ratio_score DECIMAL(5,4) COMMENT '量比因子得分(0-1)',
                    momentum_score DECIMAL(5,4) COMMENT '动量因子得分(0-1)',
                    heat_score DECIMAL(6,2) COMMENT '热度综合得分(0-100)',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    UNIQUE KEY uk_date_code (calc_date, code),
                    INDEX idx_date (calc_date),
                    INDEX idx_heat_score (calc_date, heat_score)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='申万一级行业热度日表';
                """
                cursor.execute(create_table_sql)
            conn.commit()
            print("✅ 数据库表 sw_heat_daily 初始化成功")
        except Exception as e:
            print(f"❌ 数据库初始化失败: {e}")
            raise
        finally:
            if conn:
                conn.close()

    def _load_trading_calendar(self):
        """Load the trading calendar, refreshing the on-disk cache when stale.

        :return: set of trade-date strings
        """
        cache_file = f"{self.data_dir}/trading_calendar.csv"
        if os.path.exists(cache_file):
            df = pd.read_csv(cache_file)
            dates = df['trade_date'].astype(str).tolist()
            cache_time = datetime.fromtimestamp(os.path.getmtime(cache_file))
            if (datetime.now() - cache_time).days > 7:
                # NOTE(review): the refreshed list returned by
                # _update_calendar_cache is discarded here, so this run still
                # uses the stale cache — confirm whether that is intended.
                self._update_calendar_cache(cache_file)
        else:
            dates = self._update_calendar_cache(cache_file)
        return set(dates)

    def _update_calendar_cache(self, cache_file):
        """Re-download the trading calendar and rewrite the cache file.

        :param cache_file: path of the CSV cache to (re)write
        :return: list of trade-date strings
        """
        print("正在更新交易日历...")
        df = ak.tool_trade_date_hist_sina()
        df.to_csv(cache_file, index=False)
        return df['trade_date'].astype(str).tolist()

    def get_recent_trading_dates(self, n_days, end_date=None):
        """Return the most recent *n_days* trading dates up to *end_date*.

        :param n_days: number of trading dates to return
        :param end_date: inclusive upper bound ('YYYY-MM-DD'); defaults to today
        :return: sorted list of date strings
        """
        if end_date is None:
            end_date = datetime.now().strftime('%Y-%m-%d')
        # Lexicographic comparison is valid for the shared YYYY-MM-DD format
        valid_dates = [d for d in self.trading_dates if d <= end_date]
        return sorted(valid_dates)[-n_days:]

    def fetch_daily_data(self, trade_date):
        """Fetch the SW level-1 industry daily report for one trading date.

        :param trade_date: date string 'YYYY-MM-DD'
        :return: DataFrame with normalized column names, or None on
            empty response / error
        """
        date_str = trade_date.replace('-', '')
        try:
            df = ak.index_analysis_daily_sw(
                symbol="一级行业",
                start_date=date_str,
                end_date=date_str
            )
            if df.empty:
                print(f"⚠️ {trade_date} 无数据")
                return None
            # Normalize the Chinese column names returned by akshare
            df = df.rename(columns={
                '指数代码': 'code',
                '指数名称': 'name',
                '发布日期': 'date',
                '收盘指数': 'close',
                '涨跌幅': 'return',
                '换手率': 'turnover',
                '成交量': 'volume',
                '成交额': 'amount',
                '市盈率': 'pe',
                '市净率': 'pb'
            })
            df['trade_date'] = trade_date
            df['code'] = df['code'].astype(str)
            return df
        except Exception as e:
            print(f"❌ 获取 {trade_date} 数据失败: {e}")
            return None

    def fetch_multi_days_data(self, n_days=None):
        """Fetch and concatenate the last *n_days* of industry data.

        Defaults to lookback_days + 5 extra days so the volume-ratio
        baseline window has history to compare against.

        :return: DataFrame sorted by (code, trade_date), or None if
            nothing could be fetched
        """
        if n_days is None:
            n_days = self.lookback_days + 5
        dates = self.get_recent_trading_dates(n_days)
        print(f"📅 获取数据区间: {dates[0]}{dates[-1]}")
        all_data = []
        for date in dates:
            df = self.fetch_daily_data(date)
            if df is not None:
                all_data.append(df)
            time.sleep(0.5)  # throttle requests between daily fetches
        if not all_data:
            return None
        combined = pd.concat(all_data, ignore_index=True)
        combined['trade_date'] = pd.to_datetime(combined['trade_date'])
        return combined.sort_values(['code', 'trade_date'])

    def calculate_heat_score(self, df, lookback_days=None):
        """Compute the composite heat metrics per industry.

        Factors: cumulative return, mean turnover, volume ratio versus the
        preceding window of equal length, and momentum (share of up days).
        Each factor becomes a 0-1 percentile rank, blended with
        ``self.weights`` and scaled to 0-100.

        :param df: multi-day DataFrame from fetch_multi_days_data
        :param lookback_days: window length; defaults to self.lookback_days
        :return: DataFrame sorted by heat_score descending
        """
        if lookback_days is None:
            lookback_days = self.lookback_days
        latest_dates = df['trade_date'].unique()[-lookback_days:]
        recent_df = df[df['trade_date'].isin(latest_dates)].copy()
        # The previous window of equal length is the volume baseline
        all_dates = df['trade_date'].unique()
        if len(all_dates) >= lookback_days * 2:
            hist_dates = all_dates[-lookback_days*2:-lookback_days]
            hist_df = df[df['trade_date'].isin(hist_dates)]
        else:
            hist_df = None
        heat_data = []
        for code in recent_df['code'].unique():
            industry_data = recent_df[recent_df['code'] == code]
            # Skip industries with incomplete data in the window
            if len(industry_data) < lookback_days:
                continue
            name = industry_data['name'].iloc[0]
            total_return = industry_data['return'].sum()
            avg_turnover = industry_data['turnover'].mean()
            recent_volume = industry_data['volume'].mean()
            # Volume ratio: recent average volume vs the baseline window
            if hist_df is not None and code in hist_df['code'].values:
                hist_volume = hist_df[hist_df['code'] == code]['volume'].mean()
                volume_ratio = recent_volume / hist_volume if hist_volume > 0 else 1.0
            else:
                volume_ratio = 1.0
            up_days = (industry_data['return'] > 0).sum()
            momentum = up_days / lookback_days
            heat_data.append({
                'code': code,
                'name': name,
                'total_return': total_return,
                'avg_turnover': avg_turnover,
                'volume_ratio': volume_ratio,
                'momentum': momentum,
                'latest_close': industry_data['close'].iloc[-1],
                'latest_return': industry_data['return'].iloc[-1],
                'up_days': up_days
            })
        heat_df = pd.DataFrame(heat_data)
        # Percentile-rank each raw factor to a 0-1 score
        for col in ['total_return', 'avg_turnover', 'volume_ratio', 'momentum']:
            heat_df[f'{col}_score'] = heat_df[col].rank(pct=True)
        # Weighted blend, scaled to 0-100
        heat_df['heat_score'] = (
            heat_df['total_return_score'] * self.weights['return'] +
            heat_df['avg_turnover_score'] * self.weights['turnover'] +
            heat_df['volume_ratio_score'] * self.weights['volume_ratio'] +
            heat_df['momentum_score'] * self.weights['momentum']
        ) * 100
        # Discrete heat buckets for reporting
        heat_df['heat_level'] = pd.cut(
            heat_df['heat_score'],
            bins=[0, 20, 40, 60, 80, 100],
            labels=['极冷', '偏冷', '温和', '偏热', '极热']
        )
        heat_df['rank'] = heat_df['heat_score'].rank(ascending=False, method='min').astype(int)
        return heat_df.sort_values('heat_score', ascending=False)

    def save_to_mysql(self, heat_df, calc_date):
        """
        Store the heat DataFrame into MySQL (upsert on (calc_date, code)).

        :param heat_df: heat DataFrame produced by calculate_heat_score
        :param calc_date: calculation date string, 'YYYYMMDD'
        :return: True on success, False when no DB is configured or on error
        """
        if not self.db_config:
            print("⚠️ 未配置数据库跳过MySQL存储")
            return False
        conn = None
        try:
            conn = pymysql.connect(**self.db_config)
            cursor = conn.cursor()
            # Upsert SQL; the unique key (calc_date, code) drives the UPDATE branch
            insert_sql = """
            INSERT INTO sw_heat_daily
            (calc_date, code, name, total_return, avg_turnover, volume_ratio, momentum,
             total_return_score, avg_turnover_score, volume_ratio_score, momentum_score, heat_score)
            VALUES
            (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
            name = VALUES(name),
            total_return = VALUES(total_return),
            avg_turnover = VALUES(avg_turnover),
            volume_ratio = VALUES(volume_ratio),
            momentum = VALUES(momentum),
            total_return_score = VALUES(total_return_score),
            avg_turnover_score = VALUES(avg_turnover_score),
            volume_ratio_score = VALUES(volume_ratio_score),
            momentum_score = VALUES(momentum_score),
            heat_score = VALUES(heat_score),
            created_at = CURRENT_TIMESTAMP;
            """
            # Convert DataFrame rows into parameter tuples for executemany
            data_tuples = []
            for _, row in heat_df.iterrows():
                data_tuples.append((
                    calc_date,
                    row['code'],
                    row['name'],
                    float(row['total_return']),
                    float(row['avg_turnover']),
                    float(row['volume_ratio']),
                    float(row['momentum']),
                    float(row['total_return_score']),
                    float(row['avg_turnover_score']),
                    float(row['volume_ratio_score']),
                    float(row['momentum_score']),
                    float(row['heat_score'])
                ))
            # Batch insert in a single round trip
            cursor.executemany(insert_sql, data_tuples)
            conn.commit()
            print(f"✅ MySQL存储成功: {calc_date}{len(data_tuples)} 条记录")
            return True
        except Exception as e:
            print(f"❌ MySQL存储失败: {e}")
            if conn:
                conn.rollback()
            return False
        finally:
            if conn:
                conn.close()

    def generate_report(self, save=True, to_mysql=True):
        """Generate the heat report, print a TOP-10 summary, and persist it.

        :param save: write CSV snapshots into data_dir
        :param to_mysql: also store into MySQL (only with save=True and
            a configured db_config)
        :return: heat DataFrame, or None when data fetching failed
        """
        calc_date = datetime.now().strftime('%Y%m%d')  # YYYYMMDD format
        print("="*80)
        print(f"🔥 申万一级行业热度报告 ({self.lookback_days}日复合指标)")
        print(f"📅 计算日期: {calc_date}")
        print(f"⏰ 生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print("="*80)
        # Fetch raw multi-day data
        raw_data = self.fetch_multi_days_data()
        if raw_data is None:
            print("❌ 数据获取失败")
            return None
        # Compute heat scores
        heat_df = self.calculate_heat_score(raw_data)
        # Print the TOP-10 table
        print(f"\n🏆 热度 TOP 10 (近{self.lookback_days}日):")
        print("-"*80)
        top10 = heat_df.head(10)[['rank', 'name', 'heat_score', 'heat_level',
                                  'total_return', 'avg_turnover', 'volume_ratio', 'momentum']]
        for _, row in top10.iterrows():
            print(f"{row['rank']:2d}. {row['name']:8s} | "
                  f"热度:{row['heat_score']:5.1f}({row['heat_level']}) | "
                  f"收益:{row['total_return']:+6.2f}% | "
                  f"换手:{row['avg_turnover']:5.2f}% | "
                  f"量比:{row['volume_ratio']:4.2f} | "
                  f"动量:{row['momentum']:.0%}")
        # Persist outputs
        if save:
            today = datetime.now().strftime('%Y-%m-%d')
            # 1) JSON report (optional, kept for debugging)
            # result_file = f"{self.data_dir}/heat_report_{today}_{self.lookback_days}d.json"
            # heat_df.to_json(result_file, orient='records', force_ascii=False, indent=2)
            # print(f"\n💾 JSON报告已保存: {result_file}")
            # 2) Dated CSV snapshot plus a "newest" copy for downstream readers
            csv_file = f"{self.data_dir}/heat_ranking_{today}_{self.lookback_days}d.csv"
            heat_df.to_csv(csv_file, index=False, encoding='utf-8-sig')
            csv_file = f"{self.data_dir}/heat_ranking_newest_{self.lookback_days}d.csv"
            heat_df.to_csv(csv_file, index=False, encoding='utf-8-sig')
            print(f"💾 CSV排名已保存: {csv_file}")
            # 3) MySQL storage (new)
            if to_mysql and self.db_config:
                self.save_to_mysql(heat_df, calc_date)
        return heat_df
# Usage example
if __name__ == "__main__":
    # Database configuration template (fill in for other environments):
    # config = {
    #     'host': 'localhost',
    #     'port': 3306,
    #     'user': 'your_username',
    #     'password': 'your_password',
    #     'database': 'your_database',
    #     'charset': 'utf8mb4'
    # }
    # NOTE(review): production credentials are hard-coded here — move them to
    # environment variables / a secrets store and rotate the exposed password.
    config = {
        'host': '10.127.2.207',
        'user': 'financial_prod',
        'password': 'mmTFncqmDal5HLRGY0BV',
        'database': 'reference',
        'port': 3306,
        'charset': 'utf8mb4'
    }
    # Build the tracker; passing db_config enables MySQL persistence
    tracker = SWIndustryHeatTracker(
        lookback_days=5,
        db_config=config  # MySQL connection settings
    )
    # Generate the report and store it automatically (CSV + MySQL)
    result = tracker.generate_report(save=True, to_mysql=True)

View File

@ -24,19 +24,13 @@ with open("media_score.txt", "r", encoding="utf-8") as f:
processed_ids = set() processed_ids = set()
def message_callback(ch, method, properties, body): def message_callback(ch, method, properties, body):
"""消息处理回调函数""" """消息处理回调函数"""
try: try:
start_time = time.time() start_time = time.time()
data = json.loads(body) data = json.loads(body)
id_str = str(data["id"]) id_str = str(data["id"])
input_date = data["input_date"] input_date = data["input_date"]
print(id_str, input_date) print(id_str, input_date)
# ch.basic_ack(delivery_tag=method.delivery_tag)
# print(f"接收到消息: {id_str}")
# return
# 幂等性检查:如果消息已处理过,直接确认并跳过 # 幂等性检查:如果消息已处理过,直接确认并跳过
if id_str in processed_ids: if id_str in processed_ids:
@ -116,7 +110,7 @@ def start_consumer():
connection = create_connection() connection = create_connection()
channel = connection.channel() channel = connection.channel()
# 设置QoS限制每次只取条消息 # 设置QoS限制每次只取50条消息
channel.basic_qos(prefetch_count=50) channel.basic_qos(prefetch_count=50)
channel.exchange_declare( channel.exchange_declare(

View File

@ -3,11 +3,24 @@ import json
import logging import logging
import time import time
import os import os
import re
import threading import threading
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from queue import Queue from queue import Queue
from config import * from config import *
from llm_process import send_mq, get_label import jieba
from llm_process import send_mq, get_label, get_translation, is_mostly_chinese, get_tdx_map, get_stock_map
from label_check import load_label_mapping, validate_tags
from check_sensitivity import SensitiveWordFilter
from news_clustering import NewsClusterer
from cal_etf_labels import *
import pandas as pd
import threading
import time
from datetime import datetime
from apscheduler.schedulers.background import BackgroundScheduler
# 声明一个全局变量,存媒体的权威度打分 # 声明一个全局变量,存媒体的权威度打分
media_score = {} media_score = {}
@ -23,6 +36,70 @@ with open("media_score.txt", "r", encoding="utf-8") as f:
print(f"解析错误: {e},行内容: {line}") print(f"解析错误: {e},行内容: {line}")
continue continue
# 行业及概念的一二级标签强制检查
concept_mapping = load_label_mapping("concept_mapping.txt")
industry_mapping = load_label_mapping("industry_mapping.txt")
# 敏感词检查
sensitive_filter = SensitiveWordFilter("sensitive_words.txt")
# 聚类打分
logging.getLogger("pika").setLevel(logging.WARNING)
JIEBA_DICT_PATH = '/root/zzck/news_distinct_task/dict.txt.big' # jieba分词字典路径
if os.path.exists(JIEBA_DICT_PATH):
jieba.set_dictionary(JIEBA_DICT_PATH)
news_cluster = NewsClusterer()
news_cluster.load_clusters()
# 关键词加分
words_bonus = {}
with open("bonus_words.txt", "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
word, bonus = line.split(",")
words_bonus[word.strip()] = int(bonus)
except ValueError as e:
print(f"关键词加分项解析错误: {e},行内容: {line}")
continue
print(words_bonus)
# 申万一级行业热度分, 从csv读入dataframe后直接转dict
# --- 1. 初始化 ---
sw_heat_file = "/root/zzck/industry_heat_task/sw_data/heat_ranking_newest_5d.csv"
def load_sw_map():
"""封装读取逻辑,增加异常处理防止程序因为文件错误崩溃"""
try:
df = pd.read_csv(sw_heat_file)
# 使用你最优雅的一行逻辑
return df.set_index('name')['heat_score'].to_dict()
except Exception as e:
print(f"读取CSV失败: {e}")
return None
# 全局变量初始化
sw_heat_map = load_sw_map()
# --- 2. 定义更新任务 ---
def update_sw_heat_job():
global sw_heat_map
print(f"[{datetime.now()}] 开始执行每日热度数据更新...")
new_map = load_sw_map()
if new_map:
# 原子替换:这是线程安全的,正在运行的进程会自动切换到新引用
sw_heat_map = new_map
print("sw_heat_map 更新成功")
print(f"当前sw_heat_map示例: {list(sw_heat_map.items())[:5]}") # 打印前5项检查
# --- 3. 启动定时调度器 ---
scheduler = BackgroundScheduler()
# 每天 18:10 执行更新
scheduler.add_job(update_sw_heat_job, 'cron', hour=18, minute=10)
scheduler.start()
# 幂等性存储 - 记录已处理消息ID (使用线程安全的集合) # 幂等性存储 - 记录已处理消息ID (使用线程安全的集合)
processed_ids = set() processed_ids = set()
processed_ids_lock = threading.Lock() # 用于同步对processed_ids的访问 processed_ids_lock = threading.Lock() # 用于同步对processed_ids的访问
@ -64,31 +141,246 @@ def process_single_message(data):
source = d_data[0].get('sourcename', "其他") source = d_data[0].get('sourcename', "其他")
source_impact = media_score.get(source, 5) source_impact = media_score.get(source, 5)
tagged_news = get_label(content, source) tagged_news = get_label(content, source)
public_opinion_score = tagged_news.get("public_opinion_score", 30) #资讯质量分 # deal with the label problems
China_factor = tagged_news.get("China_factor", 0.2) #中国股市相关度 tagged_news = validate_tags(tagged_news, "industry_label", "industry_confidence", industry_mapping)
news_score = source_impact * 0.04 + public_opinion_score * 0.25 + China_factor * 35 tagged_news = validate_tags(tagged_news, "concept_label", "concept_confidence", concept_mapping)
news_score = round(news_score, 2)
industry_confidence = tagged_news.get("industry_confidence", []) industry_confidence = tagged_news.get("industry_confidence", [])
industry_score = list(map(lambda x: round(x * news_score, 2), industry_confidence))
concept_confidence = tagged_news.get("concept_confidence", []) concept_confidence = tagged_news.get("concept_confidence", [])
concept_score = list(map(lambda x: round(x * news_score, 2), concept_confidence))
# 确保最终展示的分数是两位小数
industry_confidence = list(map(lambda x: round(x, 2), industry_confidence)) industry_confidence = list(map(lambda x: round(x, 2), industry_confidence))
concept_confidence = list(map(lambda x: round(x, 2), concept_confidence)) concept_confidence = list(map(lambda x: round(x, 2), concept_confidence))
industry_label = tagged_news.get("industry_label", [])
sw_heat = 10 # 默认热度分,如果没有标签或热度数据缺失则使用默认值
if industry_label and industry_confidence:
first_industry = industry_label[0].split("-")[0]
first_industry_confidence = industry_confidence[0]
if first_industry_confidence >= 0.75:
sw_heat = sw_heat_map.get(first_industry, 0)
public_opinion_score = tagged_news.get("public_opinion_score", 30) #资讯质量分
China_factor = tagged_news.get("China_factor", 0.2) #中国股市相关度
news_score = source_impact * 0.03 + public_opinion_score * 0.3 + China_factor * 35 + sw_heat * 0.05
industry_score = list(map(lambda x: round(x * news_score, 2), industry_confidence))
concept_score = list(map(lambda x: round(x * news_score, 2), concept_confidence))
# 增加一个关键词加分, 例如原文标题命中 Google 加 5 分, 不区分大小写
ori_title = data.get('title_EN', '')
hits = []
for word, bonus in words_bonus.items():
if re.search(rf'\b{re.escape(word)}\b', ori_title, re.IGNORECASE):
hits.append((word, bonus))
if hits:
# 按分数降序排序,取最高分项
hits_sorted = sorted(hits, key=lambda x: x[1], reverse=True)
print(f"原文标题{ori_title} 命中关键词及分数:", hits_sorted)
add_score = hits_sorted[0][-1] * 0.05
print(f"新闻{id_str} 当前分数news_score变动: {news_score} -> {news_score + add_score} ... 如果超出上限后续会调整...")
news_score = min(99.0, news_score + add_score)
else:
pass
# 增加一个聚类打分,目前仅作试验
if datetime.now().minute == 12:
if not news_cluster.load_clusters():
print(f"聚类文件加载失败 {datetime.now()}")
CN_title = data.get('title_txt', '')
cluster_evaluate = news_cluster.evaluate_news(title=CN_title, content=content)
cluster_score = (cluster_evaluate['weighted_score'] ** 0.5) * 10 if cluster_evaluate else 20.0
center_news_id = cluster_evaluate['center_news_id'] if cluster_evaluate else None
news_score_exp = source_impact * 0.04 + cluster_score * 0.25 + China_factor * 35
# if it's Ads, decrease the score
if all(keyword in content for keyword in ["邮箱", "电话", "网站"]):
print("新闻判别为广告软文分数强制8折")
news_score *= 0.79
news_score_exp *= 0.79
# if there is no labels at all, decrease the score
if news_score >= 80 and len(tagged_news.get("industry_label", [])) == 0 and len(tagged_news.get("concept_label", [])) == 0:
print("新闻没有打上任何有效标签分数强制8折")
news_score *= 0.79
news_score_exp *= 0.79
# if there are sensitive words in title and content, decrease the score
if news_score >= 80:
title = data.get('title_txt', '')
news_text = f"{title} {content}"
result = sensitive_filter.detect(news_text)
if result:
print("新闻命中敏感词: ", ", ".join(list(result)), "; 分数强制8折")
news_score *= 0.79
news_score_exp *= 0.79
# make sure the selected news label confidence is higher than 0.7
if news_score >= 80:
industry_high_confidence = industry_confidence[0] if len(industry_confidence) > 0 else 0
concept_high_confidence = concept_confidence[0] if len(concept_confidence) > 0 else 0
if max(industry_high_confidence, concept_high_confidence) < 0.7:
news_score *= 0.79
news_score_exp *= 0.79
print(f"行业标签置信度为{industry_high_confidence}, 概念标签置信度为{concept_high_confidence}, 新闻得分强制8折.")
news_score = round(news_score, 2)
news_score_exp = round(news_score_exp, 2)
cluster_score = round(cluster_score, 2)
# if score >= 80, add translation by llm
if news_score >= 80:
ori_content = data.get('EN_content', '')
ori_title = data.get('title_EN', '')
for _ in range(3):
# 翻译功能存在不稳定, 至多重试三次
translation = get_translation(ori_content, ori_title)
llm_content = translation.get("llm_content", "") if translation else ""
llm_title = translation.get("llm_title", "") if translation else ""
if not is_mostly_chinese(llm_title, threshold=0.4):
print("新闻标题没有被翻译成中文, 会被舍弃, LLM标题为: ", llm_title)
llm_title = ""
continue
if not is_mostly_chinese(llm_content, threshold=0.5):
print("新闻正文没有被翻译成中文, 会被舍弃, LLM正文为: ", llm_content[:30] + "...")
llm_content = ""
continue
# 20250922新增 海外事件 海外宏观 中国宏观 行业新闻 公司新闻 转载来源 公司名称 等精细处理字段
overseas_event = translation.get("overseas_event", None) if translation else None
overseas_macro = translation.get("overseas_macro", None) if translation else None
china_macro = translation.get("china_macro", None) if translation else None
industry_news = translation.get("industry_news", None) if translation else None
company_news = translation.get("company_news", None) if translation else None
reprint_source = translation.get("reprint_source", None) if translation else None
if reprint_source == "":
reprint_source = None
company_name = translation.get("company_name", None) if translation else None
if company_name:
# 打个补丁, 让公司析出数强制不超过3 2025-12-19
company_name_list = company_name.strip().split(",")
if len(company_name_list) > 3:
company_name = ",".join(company_name_list[:3])
if company_name == "":
company_name = None
## 行业和之前的标签冗余,暂时去掉
# industry_name = translation.get("industry_name", None) if translation else None
# 20251217新增 修改了翻译的一些规则; 新增通达信行业标签及置信度
tdx_industry_ori = translation.get("tdx_industry", None) if translation else None
tdx_industry_confidence_ori = translation.get("tdx_industry_confidence", None) if translation else None
tdx_map = get_tdx_map()
tdx_industry = []
tdx_industry_confidence = []
if tdx_industry_ori and tdx_industry_confidence_ori and len(tdx_industry_confidence_ori) == len(tdx_industry_ori):
for industry, confidence in zip(tdx_industry_ori, tdx_industry_confidence_ori):
if industry in tdx_map:
tdx_industry.append(industry)
tdx_industry_confidence.append(float(confidence))
if len(tdx_industry) == 3:
break
# 20251226 新增 增加了一项析出A股代码的操作, 列表形式返回A股上市公司6位代码至多5个
stock_codes_ori = translation.get("stock_codes", None) if translation else None
if stock_codes_ori:
stock_codes_ori = list(set(stock_codes_ori))
stock_codes = []
stock_names = []
stock_map = get_stock_map()
if stock_codes_ori and isinstance(stock_codes_ori, list):
for stock_code in stock_codes_ori:
if stock_code in stock_map:
stock_name = stock_map.get(stock_code)
stock_codes.append(stock_code)
stock_names.append(stock_name)
# 20260128 新增审核, 返回政治敏感度、出口转内销两项打分
political_sensitivity = translation.get("political_sensitivity", None) if translation else None
export_domestic = translation.get("export_domestic", None) if translation else None
political_notes = translation.get("political_notes", None) if translation else None
additional_notes = translation.get("additional_notes", None) if translation else None
# 20260226 新增大模型摘要
llm_abstract = translation.get("llm_abstract", None) if translation else None
if llm_content and llm_title and len(llm_content) > 30 and len(llm_title) > 3 and overseas_event is not None and overseas_macro is not None:
break
# 精选新闻添加etf
etf_labels = get_etf_labels(title=llm_title, industry_label=tagged_news.get("industry_label", []), industry_confidence=tagged_news.get("industry_confidence", []), concept_label=tagged_news.get("concept_label", []), concept_confidence=tagged_news.get("concept_confidence", []))
etf_names = [etf_name[label] for label in etf_labels]
tagged_news["etf_labels"] = etf_labels
tagged_news["etf_names"] = etf_names
tagged_news["title"] = llm_title
tagged_news["rewrite_content"] = llm_content
tagged_news["overseas_event"] = overseas_event
tagged_news["overseas_macro"] = overseas_macro
tagged_news["china_macro"] = china_macro
tagged_news["industry_news"] = industry_news
tagged_news["company_news"] = company_news
tagged_news["reprint_source"] = reprint_source
tagged_news["company_name"] = company_name
tagged_news["tdx_industry"] = tdx_industry if tdx_industry else None
tagged_news["tdx_industry_confidence"] = tdx_industry_confidence if tdx_industry_confidence else None
tagged_news["stock_codes"] = stock_codes if stock_codes else None
tagged_news["stock_names"] = stock_names if stock_names else None
tagged_news["political_sensitivity"] = political_sensitivity if political_sensitivity else None
tagged_news["export_domestic"] = export_domestic if export_domestic else None
tagged_news["political_notes"] = political_notes if political_notes else None
tagged_news["additional_notes"] = additional_notes if additional_notes else None
tagged_news["llm_abstract"] = llm_abstract if llm_abstract else None
tagged_news["source"] = source tagged_news["source"] = source
tagged_news["source_impact"] = source_impact tagged_news["source_impact"] = source_impact
tagged_news["industry_score"] = industry_score tagged_news["industry_score"] = industry_score
tagged_news["concept_score"] = concept_score tagged_news["concept_score"] = concept_score
tagged_news["news_score"] = news_score tagged_news["news_score"] = news_score
tagged_news["news_score_exp"] = news_score_exp
tagged_news["center_news_id"] = center_news_id
tagged_news["id"] = id_str tagged_news["id"] = id_str
tagged_news["cluster_score"] = cluster_score
tagged_news["industry_confidence"] = industry_confidence tagged_news["industry_confidence"] = industry_confidence
tagged_news["concept_confidence"] = concept_confidence tagged_news["concept_confidence"] = concept_confidence
if news_score >= 80:
# 保险起见对大模型翻译内容再做一遍敏感词筛查 20260128
news_title_llm = tagged_news.get("title", "")
news_content_llm = tagged_news.get("rewrite_content", "")
news_text_llm = f"{news_title_llm} {news_content_llm}"
result = sensitive_filter.detect(news_text_llm)
if result:
print("大模型翻译内容命中敏感词: ", ", ".join(list(result)), "; 分数强制8折")
tagged_news["news_score"] = news_score * 0.79
# export_domestic >= 80 或者 political_sensitivity >= 70 也直接移除精选池
if tagged_news.get("export_domestic", 0) >= 80 or tagged_news.get("political_sensitivity", 0) >= 70:
print(f'本条资讯的出口转内销得分为{tagged_news.get("export_domestic", 0)}, 政治敏感得分为{tagged_news.get("political_sensitivity", 0)}, political_notes信息为【{tagged_news.get("political_notes", "")}】, 分数8折')
tagged_news["news_score"] = news_score * 0.79
#print(json.dumps(tagged_news, ensure_ascii=False)) #print(json.dumps(tagged_news, ensure_ascii=False))
print(tagged_news["id"], tagged_news["title"], tagged_news["news_score"], tagged_news["industry_label"], input_date) print(tagged_news["id"], tagged_news["news_score"], tagged_news["news_score_exp"], tagged_news["industry_label"], tagged_news["concept_label"], public_opinion_score, tagged_news["cluster_score"], tagged_news["center_news_id"], input_date)
if news_score >= 80:
print("LLM translation: ", tagged_news.get("rewrite_content", "")[:100])
print("LLM overseas_event score: ", tagged_news.get("overseas_event", None))
print("LLM overseas_macro score: ", tagged_news.get("overseas_macro", None))
print("LLM china_macro score: ", tagged_news.get("china_macro", None))
if tagged_news.get("reprint_source", None):
print("LLM reprint_source: ", tagged_news.get("reprint_source", None))
if tagged_news.get("company_name", None):
print("LLM company_name: ", tagged_news.get("company_name", None), tagged_news.get("company_news", None))
print("ETF NAMES: ", tagged_news.get("etf_names", []))
if tagged_news.get("tdx_industry", None):
print("LLM_max tdx_industry: ", tagged_news.get("tdx_industry", None))
if tagged_news.get("tdx_industry_confidence", None):
print("LLM_max tdx_industry_confidence: ", tagged_news.get("tdx_industry_confidence", None))
if tagged_news.get("stock_names", None):
print("LLM_max stock codes & names: ", tagged_news.get("stock_codes", None), tagged_news.get("stock_names", None))
if tagged_news.get("political_sensitivity", None):
print("LLM_max political_sensitivity: ", tagged_news.get("political_sensitivity", None))
if tagged_news.get("llm_abstract", None):
print("LLM_max llm_abstract: ", tagged_news.get("llm_abstract", None)[:30])
return tagged_news, True return tagged_news, True
except Exception as e: except Exception as e:

View File

@ -24,7 +24,6 @@ def message_callback(ch, method, properties, body):
# 数据写入资讯精选表 # 数据写入资讯精选表
write_to_news(data) write_to_news(data)
# 手动确认消息 # 手动确认消息
ch.basic_ack(delivery_tag=method.delivery_tag) ch.basic_ack(delivery_tag=method.delivery_tag)
@ -33,11 +32,20 @@ def message_callback(ch, method, properties, body):
# 拒绝消息并重新入队 # 拒绝消息并重新入队
ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False) ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False)
def prepare_db_value(value):
    """Normalize a field value for MySQL storage.

    Mapping:
      * ``None``            -> ``None`` (stored as SQL NULL)
      * empty list ``[]``   -> ``None`` (stored as SQL NULL)
      * non-empty list      -> JSON-encoded string
      * anything else       -> returned unchanged
    """
    if value is None:
        return None
    if not isinstance(value, list):
        # Non-list values (str/int/float/...) are passed through as-is.
        return value
    return json.dumps(value) if value else None
def write_to_news(data): def write_to_news(data):
news_score = data.get('news_score', 0.0) news_score = data.get('news_score', 0.0)
if float(news_score) < 80: # 过滤掉news_score小于80的消息 if float(news_score) < 80: # 过滤掉news_score小于80的消息
return return
# 获取返回数据里面的 新闻id # 获取返回数据里面的 新闻id
news_id = data.get('id', "") news_id = data.get('id', "")
adr = jx_adr.replace("news_id", news_id) adr = jx_adr.replace("news_id", news_id)
@ -48,7 +56,67 @@ def write_to_news(data):
return return
print(f"新闻id:{news_id} 得分:{news_score}, 调用精选接口成功") print(f"新闻id:{news_id} 得分:{news_score}, 调用精选接口成功")
# Self-built (LLM-derived) fields are written back to the curated `news` table
# by this script itself, rather than by the upstream service. by 朱思南
# NOTE(review): pymysql and the MYSQL_*_APP connection constants come from
# earlier in the file (not visible in this chunk) — confirm they are defined.
conn = pymysql.connect(
    host=MYSQL_HOST_APP,
    port=MYSQL_PORT_APP,
    user=MYSQL_USER_APP,
    password=MYSQL_PASSWORD_APP,
    db=MYSQL_DB_APP,
    charset='utf8mb4'
)
try:
    with conn.cursor() as cursor:
        # Update the LLM-derived columns on the existing row keyed by
        # newsinfo_id; this assumes the row was already created by the
        # curation API call above (news_id comes from that flow).
        sql = '''UPDATE news
set overseas_event = %s,
overseas_macro = %s,
china_macro = %s,
industry_news = %s,
company_news = %s,
reprint_source = %s,
company_name = %s,
etf_labels = %s,
etf_names = %s,
tdx_industry = %s,
tdx_industry_confidence = %s,
stock_codes = %s,
stock_names = %s,
political_sensitivity = %s,
export_domestic = %s,
political_notes = %s,
additional_notes = %s,
llm_abstract = %s
WHERE newsinfo_id = %s '''
        # Scalar fields are bound directly; list-valued fields go through
        # prepare_db_value (non-empty list -> JSON string, empty/None -> NULL).
        values = (data.get("overseas_event", None),
                  data.get("overseas_macro", None),
                  data.get("china_macro", None),
                  data.get("industry_news", None),
                  data.get("company_news", None),
                  data.get("reprint_source", None),
                  data.get("company_name", None),
                  prepare_db_value(data.get("etf_labels", [])),
                  prepare_db_value(data.get("etf_names", [])),
                  prepare_db_value(data.get("tdx_industry", [])),
                  prepare_db_value(data.get("tdx_industry_confidence", [])),
                  prepare_db_value(data.get("stock_codes", [])),
                  prepare_db_value(data.get("stock_names", [])),
                  data.get("political_sensitivity", None),
                  data.get("export_domestic", None),
                  data.get("political_notes", None),
                  data.get("additional_notes", None),
                  data.get("llm_abstract", None),
                  news_id)
        cursor.execute(sql, values)
        # rowcount == 0 means no row matched newsinfo_id: warn but continue,
        # so a missing row does not abort the message-processing pipeline.
        if cursor.rowcount == 0:
            print(f'warning: newsinfo_id={news_id} 不存在,更新 0 行')
        conn.commit()
        print(f"新闻id:{news_id} 海外事件置信度得分:{data.get('overseas_event', None)}, 关联ETF名称列表:{data.get('etf_names', None)}, 审核政治敏感度:{data.get('political_sensitivity', None)}. 自建字段写入精选news表成功!")
except Exception as e:
    # Best-effort write-back: failures are logged, not propagated.
    print(f"自建字段写入精选news表失败: {str(e)}")
finally:
    # Always release the connection, even on failure.
    if conn:
        conn.close()
def write_to_es(data): def write_to_es(data):
"""写入ES""" """写入ES"""
@ -67,9 +135,9 @@ def write_to_es(data):
"abstract": data.get('abstract', ""), "abstract": data.get('abstract', ""),
"title": data.get('title', ""), "title": data.get('title', ""),
"rewrite_content": data.get('rewrite_content', ""), "rewrite_content": data.get('rewrite_content', ""),
"industry_label": json.dumps(data.get('industry_label', [])), "industry_label": data.get('industry_label', []),
"industry_confidence": data.get('industry_confidence', []), "industry_confidence": data.get('industry_confidence', []),
"industry_score": data.get('industry_score', ""), "industry_score": data.get('industry_score', []),
"concept_label": data.get('concept_label', []), "concept_label": data.get('concept_label', []),
"concept_confidence": data.get('concept_confidence', []), "concept_confidence": data.get('concept_confidence', []),
"concept_score": data.get('concept_score', []), "concept_score": data.get('concept_score', []),
@ -104,29 +172,35 @@ def write_to_mysql(data):
# 新增JSON结构解析逻辑 # 新增JSON结构解析逻辑
# 修改后的SQL语句 # 修改后的SQL语句
sql = """INSERT INTO news_tags sql = """INSERT INTO news_tags
(abstract, title, rewrite_content, industry_label, industry_confidence, industry_score, concept_label, concept_confidence, concept_score, public_opinion_score, China_factor, source, source_impact, news_score, news_id) (abstract, title, rewrite_content, industry_label, industry_confidence, industry_score, concept_label, concept_confidence, concept_score, public_opinion_score, China_factor, source, source_impact, news_score, news_id, news_score_exp, cluster_score, center_news_id)
VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s) """ VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s) """
values = (data.get('abstract', ""), values = (data.get('abstract', ""),
data.get('title', ""), data.get('title', ""),
data.get('rewrite_content', ""), data.get('rewrite_content', ""),
json.dumps(data.get('industry_label', [])), prepare_db_value(data.get('industry_label', [])),
json.dumps(data.get('industry_confidence', [])), prepare_db_value(data.get('industry_confidence', [])),
json.dumps(data.get('industry_score', [])), prepare_db_value(data.get('industry_score', [])),
json.dumps(data.get('concept_label', [])), prepare_db_value(data.get('concept_label', [])),
json.dumps(data.get('concept_confidence', [])), prepare_db_value(data.get('concept_confidence', [])),
json.dumps(data.get('concept_score', [])), prepare_db_value(data.get('concept_score', [])),
data.get('public_opinion_score', 10), data.get('public_opinion_score', 10),
data.get('China_factor', 0.1), data.get('China_factor', 0.1),
data.get('source', "其他"), data.get('source', "其他"),
data.get('source_impact', 5), data.get('source_impact', 5),
data.get('news_score', 0.0), data.get('news_score', 0.0),
data.get('id', "") data.get('id', ""),
data.get('news_score_exp', 5.0),
data.get('cluster_score', 5.0),
data.get('center_news_id', None)
) )
cursor.execute(sql, values) cursor.execute(sql, values)
conn.commit() conn.commit()
abstract = data.get('abstract', "") id = data.get('id', "")
print(f"{abstract}, 写入news_tags 表成功") industry_label = data.get('industry_label', [])
concept_label = data.get('concept_label', [])
overseas_event = data.get('overseas_event', None)
print(f"{id} {industry_label} {concept_label} {data.get('news_score', 0.0)} {data.get('news_score_exp', 5.0)} 写入news_tags 表成功")
except Exception as e: except Exception as e:
print(f"写入news_tags失败: {str(e)}") print(f"写入news_tags失败: {str(e)}")