From d67e85ac878069fc1167c6b59bf672c87781c98c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E6=80=9D=E5=8D=97?= <15083356+zhu-sinan@user.noreply.gitee.com> Date: Thu, 5 Mar 2026 17:34:29 +0800 Subject: [PATCH] =?UTF-8?q?20260305=20=E5=AF=B9=E8=BF=91=E6=9C=9F=E6=94=B9?= =?UTF-8?q?=E5=8A=A8=E5=81=9A=E4=B8=80=E6=AC=A1=E7=BA=BF=E4=B8=8A=E5=A4=87?= =?UTF-8?q?=E4=BB=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backfill_mysql.py | 244 +++++++++++++++++++++++ industry_heat_task.py | 405 +++++++++++++++++++++++++++++++++++++++ mqreceive.py | 10 +- mqreceive_multithread.py | 310 +++++++++++++++++++++++++++++- mqreceivefromllm.py | 106 ++++++++-- 5 files changed, 1042 insertions(+), 33 deletions(-) create mode 100644 backfill_mysql.py create mode 100644 industry_heat_task.py diff --git a/backfill_mysql.py b/backfill_mysql.py new file mode 100644 index 0000000..7ca1840 --- /dev/null +++ b/backfill_mysql.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +申万行业热度数据回溯补录工具 +读取本地CSV文件,存入MySQL sw_heat_daily表 +用于历史数据补录或重新导入 +""" + +import pandas as pd +import pymysql +import os +import re +from datetime import datetime +import argparse + + +def extract_date_from_filename(filename): + """ + 从文件名提取日期 + 例如: heat_ranking_2026-03-04_5d.csv -> 2026-03-04 -> 20260304 + """ + # 匹配日期格式 YYYY-MM-DD + pattern = r'(\d{4}-\d{2}-\d{2})' + match = re.search(pattern, filename) + + if match: + date_str = match.group(1) # 2026-03-04 + # 转换为 YYYYMMDD 格式 + calc_date = date_str.replace('-', '') # 20260304 + return calc_date, date_str + else: + raise ValueError(f"无法从文件名提取日期: {filename}") + + +def read_heat_csv(file_path): + """ + 读取热度排名CSV文件 + """ + if not os.path.exists(file_path): + raise FileNotFoundError(f"文件不存在: {file_path}") + + print(f"📂 读取文件: {file_path}") + df = pd.read_csv(file_path, encoding='utf-8-sig') + + # 检查必要列是否存在 + required_cols = ['code', 'name', 'total_return', 'avg_turnover', + 'volume_ratio', 'momentum', 'total_return_score', + 'avg_turnover_score', 'volume_ratio_score', + 'momentum_score', 'heat_score'] + + missing_cols = [col for col in required_cols if col not in df.columns] + if missing_cols: + # 尝试兼容旧版列名(如果有的话) + col_mapping = { + '行业代码': 'code', + '行业名称': 'name', + # 添加其他可能的映射 + } + for old_col, new_col in col_mapping.items(): + if old_col in df.columns and new_col not in df.columns: + df[new_col] = df[old_col] + + # 再次检查 + still_missing = [col for col in required_cols if col not in df.columns] + if still_missing: + raise ValueError(f"CSV文件缺少必要列: {still_missing}") + + print(f"✅ 成功读取 {len(df)} 条行业数据") + return df + + +def save_to_mysql(df, calc_date, db_config, dry_run=False): + """ + 将DataFrame存入MySQL + :param calc_date: YYYYMMDD 格式的日期字符串 + :param dry_run: 如果为True,只打印SQL不执行 + """ + if dry_run: + print(f"\n🔍 [DRY RUN] 预览模式,不实际写入数据库") + print(f"📅 目标日期: {calc_date}") + print(f"📊 数据预览:") + print(df[['code', 'name', 'heat_score']].head()) + return True + + conn = None + try: + conn = pymysql.connect(**db_config) + cursor = conn.cursor() + + # 检查该日期是否已有数据 + cursor.execute( + "SELECT COUNT(*) FROM sw_heat_daily WHERE calc_date = %s", + (calc_date,) + ) + existing = cursor.fetchone()[0] + + if existing > 0: + print(f"⚠️ 日期 {calc_date} 已有 {existing} 条记录,将执行覆盖更新") + + # 准备插入SQL + insert_sql = """ + INSERT INTO sw_heat_daily + (calc_date, code, name, total_return, avg_turnover, volume_ratio, momentum, + total_return_score, avg_turnover_score, volume_ratio_score, momentum_score, heat_score) + VALUES + (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + ON DUPLICATE KEY UPDATE + name = VALUES(name), + total_return = VALUES(total_return), + avg_turnover = VALUES(avg_turnover), + volume_ratio = VALUES(volume_ratio), + momentum = VALUES(momentum), + total_return_score = VALUES(total_return_score), + avg_turnover_score = VALUES(avg_turnover_score), + volume_ratio_score = VALUES(volume_ratio_score), + momentum_score = VALUES(momentum_score), + heat_score = VALUES(heat_score), + created_at = CURRENT_TIMESTAMP; + """ + + # 转换数据为列表元组 + data_tuples = [] + for _, row in df.iterrows(): + data_tuples.append(( + calc_date, + str(row['code']), + str(row['name']), + float(row['total_return']), + float(row['avg_turnover']), + float(row['volume_ratio']), + float(row['momentum']), + float(row['total_return_score']), + float(row['avg_turnover_score']), + float(row['volume_ratio_score']), + float(row['momentum_score']), + float(row['heat_score']) + )) + + # 执行批量插入 + cursor.executemany(insert_sql, data_tuples) + conn.commit() + + action = "覆盖更新" if existing > 0 else "新增" + print(f"✅ MySQL{action}成功: {calc_date} 共 {len(data_tuples)} 条记录") + return True + + except Exception as e: + print(f"❌ MySQL操作失败: {e}") + if conn: + conn.rollback() + return False + finally: + if conn: + conn.close() + + +def batch_import(data_dir, db_config, pattern=None, dry_run=False): + """ + 批量导入目录下所有符合格式的CSV文件 + """ + if pattern is None: + pattern = r'heat_ranking_\d{4}-\d{2}-\d{2}_.*\.csv' + + files = [f for f in os.listdir(data_dir) if re.match(pattern, f)] + files.sort() + + print(f"📁 发现 {len(files)} 个待导入文件") + + success_count = 0 + for filename in files: + file_path = os.path.join(data_dir, filename) + try: + calc_date, _ = extract_date_from_filename(filename) + df = read_heat_csv(file_path) + + if save_to_mysql(df, calc_date, db_config, dry_run): + success_count += 1 + + except Exception as e: + print(f"❌ 处理 {filename} 失败: {e}") + continue + + print(f"\n📊 导入完成: {success_count}/{len(files)} 个文件成功") + + +def main(): + # 数据库配置(请根据实际情况修改) + config = { + 'host': '10.127.2.207', + 'user': 'financial_prod', + 'password': 'mmTFncqmDal5HLRGY0BV', + 'database': 'reference', + 'port': 3306, + 'charset': 'utf8mb4' + } + + # 命令行参数解析 + parser = argparse.ArgumentParser(description='申万行业热度数据回溯补录工具') + parser.add_argument('file', nargs='?', help='单个CSV文件路径') + parser.add_argument('--date', '-d', help='指定日期(YYYYMMDD),覆盖文件名中的日期') + parser.add_argument('--batch', '-b', action='store_true', help='批量导入目录下所有文件') + parser.add_argument('--dir', default='./sw_data', help='数据目录路径(默认: ./sw_data)') + parser.add_argument('--dry-run', action='store_true', help='预览模式,不实际写入数据库') + + args = parser.parse_args() + + # 批量模式 + if args.batch: + batch_import(args.dir, config, dry_run=args.dry_run) + return + + # 单文件模式 + if args.file: + file_path = args.file + else: + # 默认示例文件 + file_path = './sw_data/heat_ranking_2026-03-04_5d.csv' + if not os.path.exists(file_path): + print("请提供CSV文件路径,或使用 --batch 批量导入") + parser.print_help() + return + + try: + # 读取CSV + df = read_heat_csv(file_path) + + # 确定日期 + if args.date: + calc_date = args.date # 命令行指定的日期 + print(f"📅 使用指定日期: {calc_date}") + else: + calc_date, original_date = extract_date_from_filename(file_path) + print(f"📅 从文件名提取日期: {original_date} -> {calc_date}") + + # 存入MySQL + save_to_mysql(df, calc_date, config, dry_run=args.dry_run) + + except Exception as e: + print(f"❌ 错误: {e}") + raise + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/industry_heat_task.py b/industry_heat_task.py new file mode 100644 index 0000000..9b0de75 --- /dev/null +++ b/industry_heat_task.py @@ -0,0 +1,405 @@ +import akshare as ak +import pandas as pd +import numpy as np +from datetime import datetime, timedelta +import os +import time +import pymysql +import warnings +warnings.filterwarnings('ignore') + +class SWIndustryHeatTracker: + """ + 申万一级行业热度追踪器 + 使用 index_analysis_daily_sw 接口(收盘后官方数据) + 支持自定义热度计算周期和权重,数据自动存入MySQL + """ + + def __init__(self, data_dir="./sw_data", lookback_days=5, db_config=None): + self.data_dir = data_dir + self.lookback_days = lookback_days + self.db_config = db_config # MySQL配置 + os.makedirs(data_dir, exist_ok=True) + + # 初始化交易日历 + self.trading_dates = self._load_trading_calendar() + + # 申万一级行业代码映射(31个行业) + self.industry_map = { + '801010': '农林牧渔', '801030': '基础化工', '801040': '钢铁', + '801050': '有色金属', '801080': '电子', '801110': '家用电器', + '801120': '食品饮料', '801130': '纺织服饰', '801140': '轻工制造', + '801150': '医药生物', '801160': '公用事业', '801170': '交通运输', + '801180': '房地产', '801200': '商贸零售', '801210': '社会服务', + '801230': '综合', '801710': '建筑材料', '801720': '建筑装饰', + '801730': '电力设备', '801740': '国防军工', '801750': '计算机', + '801760': '传媒', '801770': '通信', '801780': '银行', + '801790': '非银金融', '801880': '汽车', '801890': '机械设备', + '801950': '煤炭', '801960': '石油石化', '801970': '环保', + '801980': '美容护理' + } + + # 热度指标权重配置 + self.weights = { + 'return': 0.35, + 'turnover': 0.25, + 'volume_ratio': 0.25, + 'momentum': 0.15 + } + + # 初始化数据库表 + if self.db_config: + self._init_database() + + def _init_database(self): + """初始化MySQL数据库和表""" + conn = None + try: + conn = pymysql.connect(**self.db_config) + with conn.cursor() as cursor: + # 创建表(如果不存在) + create_table_sql = """ + CREATE TABLE IF NOT EXISTS sw_heat_daily ( + id INT AUTO_INCREMENT PRIMARY KEY, + calc_date VARCHAR(8) NOT NULL COMMENT '计算日期(YYYYMMDD)', + code VARCHAR(10) NOT NULL COMMENT '行业代码', + name VARCHAR(20) NOT NULL COMMENT '行业名称', + total_return DECIMAL(10,4) COMMENT '累计涨跌幅(%)', + avg_turnover DECIMAL(10,4) COMMENT '平均换手率(%)', + volume_ratio DECIMAL(10,4) COMMENT '量比', + momentum DECIMAL(5,4) COMMENT '动量(上涨天数占比)', + total_return_score DECIMAL(5,4) COMMENT '收益因子得分(0-1)', + avg_turnover_score DECIMAL(5,4) COMMENT '换手因子得分(0-1)', + volume_ratio_score DECIMAL(5,4) COMMENT '量比因子得分(0-1)', + momentum_score DECIMAL(5,4) COMMENT '动量因子得分(0-1)', + heat_score DECIMAL(6,2) COMMENT '热度综合得分(0-100)', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE KEY uk_date_code (calc_date, code), + INDEX idx_date (calc_date), + INDEX idx_heat_score (calc_date, heat_score) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='申万一级行业热度日表'; + """ + cursor.execute(create_table_sql) + conn.commit() + print("✅ 数据库表 sw_heat_daily 初始化成功") + except Exception as e: + print(f"❌ 数据库初始化失败: {e}") + raise + finally: + if conn: + conn.close() + + def _load_trading_calendar(self): + """加载交易日历""" + cache_file = f"{self.data_dir}/trading_calendar.csv" + + if os.path.exists(cache_file): + df = pd.read_csv(cache_file) + dates = df['trade_date'].astype(str).tolist() + cache_time = datetime.fromtimestamp(os.path.getmtime(cache_file)) + if (datetime.now() - cache_time).days > 7: + self._update_calendar_cache(cache_file) + else: + dates = self._update_calendar_cache(cache_file) + + return set(dates) + + def _update_calendar_cache(self, cache_file): + """更新交易日历缓存""" + print("正在更新交易日历...") + df = ak.tool_trade_date_hist_sina() + df.to_csv(cache_file, index=False) + return df['trade_date'].astype(str).tolist() + + def get_recent_trading_dates(self, n_days, end_date=None): + """获取最近N个交易日""" + if end_date is None: + end_date = datetime.now().strftime('%Y-%m-%d') + + valid_dates = [d for d in self.trading_dates if d <= end_date] + return sorted(valid_dates)[-n_days:] + + def fetch_daily_data(self, trade_date): + """获取指定交易日的申万行业日报数据""" + date_str = trade_date.replace('-', '') + + try: + df = ak.index_analysis_daily_sw( + symbol="一级行业", + start_date=date_str, + end_date=date_str + ) + + if df.empty: + print(f"⚠️ {trade_date} 无数据") + return None + + df = df.rename(columns={ + '指数代码': 'code', + '指数名称': 'name', + '发布日期': 'date', + '收盘指数': 'close', + '涨跌幅': 'return', + '换手率': 'turnover', + '成交量': 'volume', + '成交额': 'amount', + '市盈率': 'pe', + '市净率': 'pb' + }) + + df['trade_date'] = trade_date + df['code'] = df['code'].astype(str) + + return df + + except Exception as e: + print(f"❌ 获取 {trade_date} 数据失败: {e}") + return None + + def fetch_multi_days_data(self, n_days=None): + """获取最近N天的数据""" + if n_days is None: + n_days = self.lookback_days + 5 + + dates = self.get_recent_trading_dates(n_days) + print(f"📅 获取数据区间: {dates[0]} 至 {dates[-1]}") + + all_data = [] + for date in dates: + df = self.fetch_daily_data(date) + if df is not None: + all_data.append(df) + time.sleep(0.5) + + if not all_data: + return None + + combined = pd.concat(all_data, ignore_index=True) + combined['trade_date'] = pd.to_datetime(combined['trade_date']) + return combined.sort_values(['code', 'trade_date']) + + def calculate_heat_score(self, df, lookback_days=None): + """计算板块热度复合指标""" + if lookback_days is None: + lookback_days = self.lookback_days + + latest_dates = df['trade_date'].unique()[-lookback_days:] + recent_df = df[df['trade_date'].isin(latest_dates)].copy() + + all_dates = df['trade_date'].unique() + if len(all_dates) >= lookback_days * 2: + hist_dates = all_dates[-lookback_days*2:-lookback_days] + hist_df = df[df['trade_date'].isin(hist_dates)] + else: + hist_df = None + + heat_data = [] + + for code in recent_df['code'].unique(): + industry_data = recent_df[recent_df['code'] == code] + + if len(industry_data) < lookback_days: + continue + + name = industry_data['name'].iloc[0] + total_return = industry_data['return'].sum() + avg_turnover = industry_data['turnover'].mean() + recent_volume = industry_data['volume'].mean() + + if hist_df is not None and code in hist_df['code'].values: + hist_volume = hist_df[hist_df['code'] == code]['volume'].mean() + volume_ratio = recent_volume / hist_volume if hist_volume > 0 else 1.0 + else: + volume_ratio = 1.0 + + up_days = (industry_data['return'] > 0).sum() + momentum = up_days / lookback_days + + heat_data.append({ + 'code': code, + 'name': name, + 'total_return': total_return, + 'avg_turnover': avg_turnover, + 'volume_ratio': volume_ratio, + 'momentum': momentum, + 'latest_close': industry_data['close'].iloc[-1], + 'latest_return': industry_data['return'].iloc[-1], + 'up_days': up_days + }) + + heat_df = pd.DataFrame(heat_data) + + for col in ['total_return', 'avg_turnover', 'volume_ratio', 'momentum']: + heat_df[f'{col}_score'] = heat_df[col].rank(pct=True) + + heat_df['heat_score'] = ( + heat_df['total_return_score'] * self.weights['return'] + + heat_df['avg_turnover_score'] * self.weights['turnover'] + + heat_df['volume_ratio_score'] * self.weights['volume_ratio'] + + heat_df['momentum_score'] * self.weights['momentum'] + ) * 100 + + heat_df['heat_level'] = pd.cut( + heat_df['heat_score'], + bins=[0, 20, 40, 60, 80, 100], + labels=['极冷', '偏冷', '温和', '偏热', '极热'] + ) + + heat_df['rank'] = heat_df['heat_score'].rank(ascending=False, method='min').astype(int) + + return heat_df.sort_values('heat_score', ascending=False) + + def save_to_mysql(self, heat_df, calc_date): + """ + 将热度数据存入MySQL + :param heat_df: 热度数据DataFrame + :param calc_date: 计算日期,格式'YYYYMMDD' + """ + if not self.db_config: + print("⚠️ 未配置数据库,跳过MySQL存储") + return False + + conn = None + try: + conn = pymysql.connect(**self.db_config) + cursor = conn.cursor() + + # 准备插入数据 + insert_sql = """ + INSERT INTO sw_heat_daily + (calc_date, code, name, total_return, avg_turnover, volume_ratio, momentum, + total_return_score, avg_turnover_score, volume_ratio_score, momentum_score, heat_score) + VALUES + (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + ON DUPLICATE KEY UPDATE + name = VALUES(name), + total_return = VALUES(total_return), + avg_turnover = VALUES(avg_turnover), + volume_ratio = VALUES(volume_ratio), + momentum = VALUES(momentum), + total_return_score = VALUES(total_return_score), + avg_turnover_score = VALUES(avg_turnover_score), + volume_ratio_score = VALUES(volume_ratio_score), + momentum_score = VALUES(momentum_score), + heat_score = VALUES(heat_score), + created_at = CURRENT_TIMESTAMP; + """ + + # 转换数据为列表元组 + data_tuples = [] + for _, row in heat_df.iterrows(): + data_tuples.append(( + calc_date, + row['code'], + row['name'], + float(row['total_return']), + float(row['avg_turnover']), + float(row['volume_ratio']), + float(row['momentum']), + float(row['total_return_score']), + float(row['avg_turnover_score']), + float(row['volume_ratio_score']), + float(row['momentum_score']), + float(row['heat_score']) + )) + + # 批量插入 + cursor.executemany(insert_sql, data_tuples) + conn.commit() + + print(f"✅ MySQL存储成功: {calc_date} 共 {len(data_tuples)} 条记录") + return True + + except Exception as e: + print(f"❌ MySQL存储失败: {e}") + if conn: + conn.rollback() + return False + finally: + if conn: + conn.close() + + def generate_report(self, save=True, to_mysql=True): + """生成热度分析报告并存储""" + calc_date = datetime.now().strftime('%Y%m%d') # YYYYMMDD格式 + + print("="*80) + print(f"🔥 申万一级行业热度报告 ({self.lookback_days}日复合指标)") + print(f"📅 计算日期: {calc_date}") + print(f"⏰ 生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + print("="*80) + + # 获取数据 + raw_data = self.fetch_multi_days_data() + if raw_data is None: + print("❌ 数据获取失败") + return None + + # 计算热度 + heat_df = self.calculate_heat_score(raw_data) + + # 打印报告(省略,与之前相同)... + print(f"\n🏆 热度 TOP 10 (近{self.lookback_days}日):") + print("-"*80) + top10 = heat_df.head(10)[['rank', 'name', 'heat_score', 'heat_level', + 'total_return', 'avg_turnover', 'volume_ratio', 'momentum']] + + for _, row in top10.iterrows(): + print(f"{row['rank']:2d}. {row['name']:8s} | " + f"热度:{row['heat_score']:5.1f}({row['heat_level']}) | " + f"收益:{row['total_return']:+6.2f}% | " + f"换手:{row['avg_turnover']:5.2f}% | " + f"量比:{row['volume_ratio']:4.2f} | " + f"动量:{row['momentum']:.0%}") + + # 保存数据部分 + if save: + today = datetime.now().strftime('%Y-%m-%d') + + # # 1. 保存JSON(可选,用于调试) + # result_file = f"{self.data_dir}/heat_report_{today}_{self.lookback_days}d.json" + # heat_df.to_json(result_file, orient='records', force_ascii=False, indent=2) + # print(f"\n💾 JSON报告已保存: {result_file}") + + # 2. 保存CSV(可选,便于Excel查看) + csv_file = f"{self.data_dir}/heat_ranking_{today}_{self.lookback_days}d.csv" + heat_df.to_csv(csv_file, index=False, encoding='utf-8-sig') + csv_file = f"{self.data_dir}/heat_ranking_newest_{self.lookback_days}d.csv" + heat_df.to_csv(csv_file, index=False, encoding='utf-8-sig') + print(f"💾 CSV排名已保存: {csv_file}") + + # 3. 存入MySQL(新增) + if to_mysql and self.db_config: + self.save_to_mysql(heat_df, calc_date) + + return heat_df + + +# 使用示例 +if __name__ == "__main__": + # 数据库配置 + # config = { + # 'host': 'localhost', + # 'port': 3306, + # 'user': 'your_username', + # 'password': 'your_password', + # 'database': 'your_database', + # 'charset': 'utf8mb4' + # } + config = { + 'host': '10.127.2.207', + 'user': 'financial_prod', + 'password': 'mmTFncqmDal5HLRGY0BV', + 'database': 'reference', + 'port': 3306, + 'charset': 'utf8mb4' + } + + # 创建追踪器,传入数据库配置 + tracker = SWIndustryHeatTracker( + lookback_days=5, + db_config=config # 传入数据库配置 + ) + + # 生成报告并自动存入MySQL + result = tracker.generate_report(save=True, to_mysql=True) \ No newline at end of file diff --git a/mqreceive.py b/mqreceive.py index 99f13f1..bdbd691 100644 --- a/mqreceive.py +++ b/mqreceive.py @@ -24,19 +24,13 @@ with open("media_score.txt", "r", encoding="utf-8") as f: processed_ids = set() def message_callback(ch, method, properties, body): - """消息处理回调函数""" - - + """消息处理回调函数""" try: start_time = time.time() data = json.loads(body) id_str = str(data["id"]) input_date = data["input_date"] print(id_str, input_date) - - # ch.basic_ack(delivery_tag=method.delivery_tag) - # print(f"接收到消息: {id_str}") - # return # 幂等性检查:如果消息已处理过,直接确认并跳过 if id_str in processed_ids: @@ -116,7 +110,7 @@ def start_consumer(): connection = create_connection() channel = connection.channel() - # 设置QoS,限制每次只取一条消息 + # 设置QoS,限制每次只取50条消息 channel.basic_qos(prefetch_count=50) channel.exchange_declare( diff --git a/mqreceive_multithread.py b/mqreceive_multithread.py index 4f0bf5f..cd38c03 100644 --- a/mqreceive_multithread.py +++ b/mqreceive_multithread.py @@ -3,11 +3,24 @@ import json import logging import time import os +import re import threading +from datetime import datetime from concurrent.futures import ThreadPoolExecutor from queue import Queue from config import * -from llm_process import send_mq, get_label +import jieba +from llm_process import send_mq, get_label, get_translation, is_mostly_chinese, get_tdx_map, get_stock_map +from label_check import load_label_mapping, validate_tags +from check_sensitivity import SensitiveWordFilter +from news_clustering import NewsClusterer +from cal_etf_labels import * +import pandas as pd + +import threading +import time +from datetime import datetime +from apscheduler.schedulers.background import BackgroundScheduler # 声明一个全局变量,存媒体的权威度打分 media_score = {} @@ -23,6 +36,70 @@ with open("media_score.txt", "r", encoding="utf-8") as f: print(f"解析错误: {e},行内容: {line}") continue +# 行业及概念的一二级标签强制检查 +concept_mapping = load_label_mapping("concept_mapping.txt") +industry_mapping = load_label_mapping("industry_mapping.txt") + +# 敏感词检查 +sensitive_filter = SensitiveWordFilter("sensitive_words.txt") + +# 聚类打分 +logging.getLogger("pika").setLevel(logging.WARNING) +JIEBA_DICT_PATH = '/root/zzck/news_distinct_task/dict.txt.big' # jieba分词字典路径 +if os.path.exists(JIEBA_DICT_PATH): + jieba.set_dictionary(JIEBA_DICT_PATH) +news_cluster = NewsClusterer() +news_cluster.load_clusters() + +# 关键词加分 +words_bonus = {} +with open("bonus_words.txt", "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + try: + word, bonus = line.split(",") + words_bonus[word.strip()] = int(bonus) + except ValueError as e: + print(f"关键词加分项解析错误: {e},行内容: {line}") + continue +print(words_bonus) + +# 申万一级行业热度分, 从csv读入dataframe后直接转dict +# --- 1. 初始化 --- +sw_heat_file = "/root/zzck/industry_heat_task/sw_data/heat_ranking_newest_5d.csv" +def load_sw_map(): + """封装读取逻辑,增加异常处理防止程序因为文件错误崩溃""" + try: + df = pd.read_csv(sw_heat_file) + # 使用你最优雅的一行逻辑 + return df.set_index('name')['heat_score'].to_dict() + except Exception as e: + print(f"读取CSV失败: {e}") + return None + +# 全局变量初始化 +sw_heat_map = load_sw_map() + +# --- 2. 定义更新任务 --- +def update_sw_heat_job(): + global sw_heat_map + print(f"[{datetime.now()}] 开始执行每日热度数据更新...") + new_map = load_sw_map() + if new_map: + # 原子替换:这是线程安全的,正在运行的进程会自动切换到新引用 + sw_heat_map = new_map + print("sw_heat_map 更新成功") + print(f"当前sw_heat_map示例: {list(sw_heat_map.items())[:5]}") # 打印前5项检查 + +# --- 3. 启动定时调度器 --- +scheduler = BackgroundScheduler() +# 每天 18:10 执行更新 +scheduler.add_job(update_sw_heat_job, 'cron', hour=18, minute=10) +scheduler.start() + + # 幂等性存储 - 记录已处理消息ID (使用线程安全的集合) processed_ids = set() processed_ids_lock = threading.Lock() # 用于同步对processed_ids的访问 @@ -64,31 +141,246 @@ def process_single_message(data): source = d_data[0].get('sourcename', "其他") source_impact = media_score.get(source, 5) tagged_news = get_label(content, source) - public_opinion_score = tagged_news.get("public_opinion_score", 30) #资讯质量分 - China_factor = tagged_news.get("China_factor", 0.2) #中国股市相关度 - news_score = source_impact * 0.04 + public_opinion_score * 0.25 + China_factor * 35 - news_score = round(news_score, 2) + # deal with the label problems + tagged_news = validate_tags(tagged_news, "industry_label", "industry_confidence", industry_mapping) + tagged_news = validate_tags(tagged_news, "concept_label", "concept_confidence", concept_mapping) industry_confidence = tagged_news.get("industry_confidence", []) - industry_score = list(map(lambda x: round(x * news_score, 2), industry_confidence)) concept_confidence = tagged_news.get("concept_confidence", []) - concept_score = list(map(lambda x: round(x * news_score, 2), concept_confidence)) - # 确保最终展示的分数是两位小数 industry_confidence = list(map(lambda x: round(x, 2), industry_confidence)) concept_confidence = list(map(lambda x: round(x, 2), concept_confidence)) + + industry_label = tagged_news.get("industry_label", []) + sw_heat = 10 # 默认热度分,如果没有标签或热度数据缺失则使用默认值 + if industry_label and industry_confidence: + first_industry = industry_label[0].split("-")[0] + first_industry_confidence = industry_confidence[0] + if first_industry_confidence >= 0.75: + sw_heat = sw_heat_map.get(first_industry, 0) + + public_opinion_score = tagged_news.get("public_opinion_score", 30) #资讯质量分 + China_factor = tagged_news.get("China_factor", 0.2) #中国股市相关度 + news_score = source_impact * 0.03 + public_opinion_score * 0.3 + China_factor * 35 + sw_heat * 0.05 + + industry_score = list(map(lambda x: round(x * news_score, 2), industry_confidence)) + concept_score = list(map(lambda x: round(x * news_score, 2), concept_confidence)) + + # 增加一个关键词加分, 例如原文标题命中 Google 加 5 分, 不区分大小写 + ori_title = data.get('title_EN', '') + hits = [] + for word, bonus in words_bonus.items(): + if re.search(rf'\b{re.escape(word)}\b', ori_title, re.IGNORECASE): + hits.append((word, bonus)) + if hits: + # 按分数降序排序,取最高分项 + hits_sorted = sorted(hits, key=lambda x: x[1], reverse=True) + print(f"原文标题{ori_title} 命中关键词及分数:", hits_sorted) + add_score = hits_sorted[0][-1] * 0.05 + print(f"新闻{id_str} 当前分数news_score变动: {news_score} -> {news_score + add_score} ... 如果超出上限后续会调整...") + news_score = min(99.0, news_score + add_score) + else: + pass + + # 增加一个聚类打分,目前仅作试验 + if datetime.now().minute == 12: + if not news_cluster.load_clusters(): + print(f"聚类文件加载失败 {datetime.now()}") + CN_title = data.get('title_txt', '') + cluster_evaluate = news_cluster.evaluate_news(title=CN_title, content=content) + cluster_score = (cluster_evaluate['weighted_score'] ** 0.5) * 10 if cluster_evaluate else 20.0 + center_news_id = cluster_evaluate['center_news_id'] if cluster_evaluate else None + news_score_exp = source_impact * 0.04 + cluster_score * 0.25 + China_factor * 35 + + # if it's Ads, decrease the score + if all(keyword in content for keyword in ["邮箱", "电话", "网站"]): + print("新闻判别为广告软文,分数强制8折") + news_score *= 0.79 + news_score_exp *= 0.79 + + # if there is no labels at all, decrease the score + if news_score >= 80 and len(tagged_news.get("industry_label", [])) == 0 and len(tagged_news.get("concept_label", [])) == 0: + print("新闻没有打上任何有效标签,分数强制8折") + news_score *= 0.79 + news_score_exp *= 0.79 + + # if there are sensitive words in title and content, decrease the score + if news_score >= 80: + title = data.get('title_txt', '') + news_text = f"{title} {content}" + result = sensitive_filter.detect(news_text) + if result: + print("新闻命中敏感词: ", ", ".join(list(result)), "; 分数强制8折") + news_score *= 0.79 + news_score_exp *= 0.79 + + # make sure the selected news label confidence is higher than 0.7 + if news_score >= 80: + industry_high_confidence = industry_confidence[0] if len(industry_confidence) > 0 else 0 + concept_high_confidence = concept_confidence[0] if len(concept_confidence) > 0 else 0 + if max(industry_high_confidence, concept_high_confidence) < 0.7: + news_score *= 0.79 + news_score_exp *= 0.79 + print(f"行业标签置信度为{industry_high_confidence}, 概念标签置信度为{concept_high_confidence}, 新闻得分强制8折.") + news_score = round(news_score, 2) + news_score_exp = round(news_score_exp, 2) + cluster_score = round(cluster_score, 2) + + # if score >= 80, add translation by llm + if news_score >= 80: + ori_content = data.get('EN_content', '') + ori_title = data.get('title_EN', '') + for _ in range(3): + # 翻译功能存在不稳定, 至多重试三次 + translation = get_translation(ori_content, ori_title) + llm_content = translation.get("llm_content", "") if translation else "" + llm_title = translation.get("llm_title", "") if translation else "" + if not is_mostly_chinese(llm_title, threshold=0.4): + print("新闻标题没有被翻译成中文, 会被舍弃, LLM标题为: ", llm_title) + llm_title = "" + continue + if not is_mostly_chinese(llm_content, threshold=0.5): + print("新闻正文没有被翻译成中文, 会被舍弃, LLM正文为: ", llm_content[:30] + "...") + llm_content = "" + continue + # 20250922新增 海外事件 海外宏观 中国宏观 行业新闻 公司新闻 转载来源 公司名称 等精细处理字段 + overseas_event = translation.get("overseas_event", None) if translation else None + overseas_macro = translation.get("overseas_macro", None) if translation else None + china_macro = translation.get("china_macro", None) if translation else None + industry_news = translation.get("industry_news", None) if translation else None + company_news = translation.get("company_news", None) if translation else None + reprint_source = translation.get("reprint_source", None) if translation else None + if reprint_source == "": + reprint_source = None + company_name = translation.get("company_name", None) if translation else None + if company_name: + # 打个补丁, 让公司析出数强制不超过3 2025-12-19 + company_name_list = company_name.strip().split(",") + if len(company_name_list) > 3: + company_name = ",".join(company_name_list[:3]) + if company_name == "": + company_name = None + ## 行业和之前的标签冗余,暂时去掉 + # industry_name = translation.get("industry_name", None) if translation else None + + # 20251217新增 修改了翻译的一些规则; 新增通达信行业标签及置信度 + tdx_industry_ori = translation.get("tdx_industry", None) if translation else None + tdx_industry_confidence_ori = translation.get("tdx_industry_confidence", None) if translation else None + tdx_map = get_tdx_map() + tdx_industry = [] + tdx_industry_confidence = [] + if tdx_industry_ori and tdx_industry_confidence_ori and len(tdx_industry_confidence_ori) == len(tdx_industry_ori): + for industry, confidence in zip(tdx_industry_ori, tdx_industry_confidence_ori): + if industry in tdx_map: + tdx_industry.append(industry) + tdx_industry_confidence.append(float(confidence)) + if len(tdx_industry) == 3: + break + # 20251226 新增 增加了一项析出A股代码的操作, 列表形式返回A股上市公司6位代码,至多5个 + stock_codes_ori = translation.get("stock_codes", None) if translation else None + if stock_codes_ori: + stock_codes_ori = list(set(stock_codes_ori)) + stock_codes = [] + stock_names = [] + stock_map = get_stock_map() + if stock_codes_ori and isinstance(stock_codes_ori, list): + for stock_code in stock_codes_ori: + if stock_code in stock_map: + stock_name = stock_map.get(stock_code) + stock_codes.append(stock_code) + stock_names.append(stock_name) + + # 20260128 新增审核, 返回政治敏感度、出口转内销两项打分 + political_sensitivity = translation.get("political_sensitivity", None) if translation else None + export_domestic = translation.get("export_domestic", None) if translation else None + political_notes = translation.get("political_notes", None) if translation else None + additional_notes = translation.get("additional_notes", None) if translation else None + + # 20260226 新增大模型摘要 + llm_abstract = translation.get("llm_abstract", None) if translation else None + if llm_content and llm_title and len(llm_content) > 30 and len(llm_title) > 3 and overseas_event is not None and overseas_macro is not None: + break + + # 精选新闻添加etf + etf_labels = get_etf_labels(title=llm_title, industry_label=tagged_news.get("industry_label", []), industry_confidence=tagged_news.get("industry_confidence", []), concept_label=tagged_news.get("concept_label", []), concept_confidence=tagged_news.get("concept_confidence", [])) + etf_names = [etf_name[label] for label in etf_labels] + tagged_news["etf_labels"] = etf_labels + tagged_news["etf_names"] = etf_names + + tagged_news["title"] = llm_title + tagged_news["rewrite_content"] = llm_content + tagged_news["overseas_event"] = overseas_event + tagged_news["overseas_macro"] = overseas_macro + tagged_news["china_macro"] = china_macro + tagged_news["industry_news"] = industry_news + tagged_news["company_news"] = company_news + tagged_news["reprint_source"] = reprint_source + tagged_news["company_name"] = company_name + + tagged_news["tdx_industry"] = tdx_industry if tdx_industry else None + tagged_news["tdx_industry_confidence"] = tdx_industry_confidence if tdx_industry_confidence else None + + tagged_news["stock_codes"] = stock_codes if stock_codes else None + tagged_news["stock_names"] = stock_names if stock_names else None + + tagged_news["political_sensitivity"] = political_sensitivity if political_sensitivity else None + tagged_news["export_domestic"] = export_domestic if export_domestic else None + tagged_news["political_notes"] = political_notes if political_notes else None + tagged_news["additional_notes"] = additional_notes if additional_notes else None + + tagged_news["llm_abstract"] = llm_abstract if llm_abstract else None + tagged_news["source"] = source tagged_news["source_impact"] = source_impact tagged_news["industry_score"] = industry_score tagged_news["concept_score"] = concept_score tagged_news["news_score"] = news_score + tagged_news["news_score_exp"] = news_score_exp + tagged_news["center_news_id"] = center_news_id tagged_news["id"] = id_str + + tagged_news["cluster_score"] = cluster_score tagged_news["industry_confidence"] = industry_confidence tagged_news["concept_confidence"] = concept_confidence + + if news_score >= 80: + # 保险起见对大模型翻译内容再做一遍敏感词筛查 20260128 + news_title_llm = tagged_news.get("title", "") + news_content_llm = tagged_news.get("rewrite_content", "") + news_text_llm = f"{news_title_llm} {news_content_llm}" + result = sensitive_filter.detect(news_text_llm) + if result: + print("大模型翻译内容命中敏感词: ", ", ".join(list(result)), "; 分数强制8折") + tagged_news["news_score"] = news_score * 0.79 + + # export_domestic >= 80 或者 political_sensitivity >= 70 也直接移除精选池 + if tagged_news.get("export_domestic", 0) >= 80 or tagged_news.get("political_sensitivity", 0) >= 70: + print(f'本条资讯的出口转内销得分为{tagged_news.get("export_domestic", 0)}, 政治敏感得分为{tagged_news.get("political_sensitivity", 0)}, political_notes信息为【{tagged_news.get("political_notes", "")}】, 分数8折') + tagged_news["news_score"] = news_score * 0.79 #print(json.dumps(tagged_news, ensure_ascii=False)) - print(tagged_news["id"], tagged_news["title"], tagged_news["news_score"], tagged_news["industry_label"], input_date) + print(tagged_news["id"], tagged_news["news_score"], tagged_news["news_score_exp"], tagged_news["industry_label"], tagged_news["concept_label"], public_opinion_score, tagged_news["cluster_score"], tagged_news["center_news_id"], input_date) + if news_score >= 80: + print("LLM translation: ", tagged_news.get("rewrite_content", "")[:100]) + print("LLM overseas_event score: ", tagged_news.get("overseas_event", None)) + print("LLM overseas_macro score: ", tagged_news.get("overseas_macro", None)) + print("LLM china_macro score: ", tagged_news.get("china_macro", None)) + if tagged_news.get("reprint_source", None): + print("LLM reprint_source: ", tagged_news.get("reprint_source", None)) + if tagged_news.get("company_name", None): + print("LLM company_name: ", tagged_news.get("company_name", None), tagged_news.get("company_news", None)) + print("ETF NAMES: ", tagged_news.get("etf_names", [])) + if tagged_news.get("tdx_industry", None): + print("LLM_max tdx_industry: ", tagged_news.get("tdx_industry", None)) + if tagged_news.get("tdx_industry_confidence", None): + print("LLM_max tdx_industry_confidence: ", tagged_news.get("tdx_industry_confidence", None)) + if tagged_news.get("stock_names", None): + print("LLM_max stock codes & names: ", tagged_news.get("stock_codes", None), tagged_news.get("stock_names", None)) + if tagged_news.get("political_sensitivity", None): + print("LLM_max political_sensitivity: ", tagged_news.get("political_sensitivity", None)) + if tagged_news.get("llm_abstract", None): + print("LLM_max llm_abstract: ", tagged_news.get("llm_abstract", None)[:30]) return tagged_news, True except Exception as e: diff --git a/mqreceivefromllm.py b/mqreceivefromllm.py index 02574f3..4972cff 100644 --- a/mqreceivefromllm.py +++ b/mqreceivefromllm.py @@ -24,7 +24,6 @@ def message_callback(ch, method, properties, body): # 数据写入资讯精选表 write_to_news(data) - # 手动确认消息 ch.basic_ack(delivery_tag=method.delivery_tag) @@ -33,11 +32,20 @@ def message_callback(ch, method, properties, body): # 拒绝消息并重新入队 ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False) +def prepare_db_value(value): + """准备数据库值:空列表和None都转为None,其他情况JSON序列化""" + if value is None: + return None + if isinstance(value, list): + return json.dumps(value) if value else None + # 如果不是列表,保持原样(或者根据需求处理) + return value + def write_to_news(data): news_score = data.get('news_score', 0.0) if float(news_score) < 80: # 过滤掉news_score小于80的消息 return - + # 获取返回数据里面的 新闻id news_id = data.get('id', "") adr = jx_adr.replace("news_id", news_id) @@ -48,7 +56,67 @@ def write_to_news(data): return print(f"新闻id:{news_id} 得分:{news_score}, 调用精选接口成功") - + # 某些自建字段由python脚本自行更新 by 朱思南 + conn = pymysql.connect( + host=MYSQL_HOST_APP, + port=MYSQL_PORT_APP, + user=MYSQL_USER_APP, + password=MYSQL_PASSWORD_APP, + db=MYSQL_DB_APP, + charset='utf8mb4' + ) + try: + with conn.cursor() as cursor: + sql = '''UPDATE news + set overseas_event = %s, + overseas_macro = %s, + china_macro = %s, + industry_news = %s, + company_news = %s, + reprint_source = %s, + company_name = %s, + etf_labels = %s, + etf_names = %s, + tdx_industry = %s, + tdx_industry_confidence = %s, + stock_codes = %s, + stock_names = %s, + political_sensitivity = %s, + export_domestic = %s, + political_notes = %s, + additional_notes = %s, + llm_abstract = %s + WHERE newsinfo_id = %s ''' + values = (data.get("overseas_event", None), + data.get("overseas_macro", None), + data.get("china_macro", None), + data.get("industry_news", None), + data.get("company_news", None), + data.get("reprint_source", None), + data.get("company_name", None), + prepare_db_value(data.get("etf_labels", [])), + prepare_db_value(data.get("etf_names", [])), + prepare_db_value(data.get("tdx_industry", [])), + prepare_db_value(data.get("tdx_industry_confidence", [])), + prepare_db_value(data.get("stock_codes", [])), + prepare_db_value(data.get("stock_names", [])), + data.get("political_sensitivity", None), + data.get("export_domestic", None), + data.get("political_notes", None), + data.get("additional_notes", None), + data.get("llm_abstract", None), + news_id) + cursor.execute(sql, values) + if cursor.rowcount == 0: + print(f'warning: newsinfo_id={news_id} 不存在,更新 0 行') + conn.commit() + print(f"新闻id:{news_id} 海外事件置信度得分:{data.get('overseas_event', None)}, 关联ETF名称列表:{data.get('etf_names', None)}, 审核政治敏感度:{data.get('political_sensitivity', None)}. 自建字段写入精选news表成功!") + except Exception as e: + print(f"自建字段写入精选news表失败: {str(e)}") + finally: + if conn: + conn.close() + def write_to_es(data): """写入ES""" @@ -67,9 +135,9 @@ def write_to_es(data): "abstract": data.get('abstract', ""), "title": data.get('title', ""), "rewrite_content": data.get('rewrite_content', ""), - "industry_label": json.dumps(data.get('industry_label', [])), + "industry_label": data.get('industry_label', []), "industry_confidence": data.get('industry_confidence', []), - "industry_score": data.get('industry_score', ""), + "industry_score": data.get('industry_score', []), "concept_label": data.get('concept_label', []), "concept_confidence": data.get('concept_confidence', []), "concept_score": data.get('concept_score', []), @@ -104,29 +172,35 @@ def write_to_mysql(data): # 新增JSON结构解析逻辑 # 修改后的SQL语句 sql = """INSERT INTO news_tags - (abstract, title, rewrite_content, industry_label, industry_confidence, industry_score, concept_label, concept_confidence, concept_score, public_opinion_score, China_factor, source, source_impact, news_score, news_id) - VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s) """ + (abstract, title, rewrite_content, industry_label, industry_confidence, industry_score, concept_label, concept_confidence, concept_score, public_opinion_score, China_factor, source, source_impact, news_score, news_id, news_score_exp, cluster_score, center_news_id) + VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s) """ values = (data.get('abstract', ""), data.get('title', ""), data.get('rewrite_content', ""), - json.dumps(data.get('industry_label', [])), - json.dumps(data.get('industry_confidence', [])), - json.dumps(data.get('industry_score', [])), - json.dumps(data.get('concept_label', [])), - json.dumps(data.get('concept_confidence', [])), - json.dumps(data.get('concept_score', [])), + prepare_db_value(data.get('industry_label', [])), + prepare_db_value(data.get('industry_confidence', [])), + prepare_db_value(data.get('industry_score', [])), + prepare_db_value(data.get('concept_label', [])), + prepare_db_value(data.get('concept_confidence', [])), + prepare_db_value(data.get('concept_score', [])), data.get('public_opinion_score', 10), data.get('China_factor', 0.1), data.get('source', "其他"), data.get('source_impact', 5), data.get('news_score', 0.0), - data.get('id', "") + data.get('id', ""), + data.get('news_score_exp', 5.0), + data.get('cluster_score', 5.0), + data.get('center_news_id', None) ) cursor.execute(sql, values) conn.commit() - abstract = data.get('abstract', "") - print(f"{abstract}, 写入news_tags 表成功") + id = data.get('id', "") + industry_label = data.get('industry_label', []) + concept_label = data.get('concept_label', []) + overseas_event = data.get('overseas_event', None) + print(f"{id} {industry_label} {concept_label} {data.get('news_score', 0.0)} {data.get('news_score_exp', 5.0)} 写入news_tags 表成功") except Exception as e: print(f"写入news_tags失败: {str(e)}")