20260305 对近期改动做一次线上备份
This commit is contained in:
parent
24b67815a0
commit
d67e85ac87
|
|
@ -0,0 +1,244 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
申万行业热度数据回溯补录工具
|
||||
读取本地CSV文件,存入MySQL sw_heat_daily表
|
||||
用于历史数据补录或重新导入
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import pymysql
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
import argparse
|
||||
|
||||
|
||||
def extract_date_from_filename(filename):
    """Extract the trade date embedded in a heat-ranking CSV filename.

    e.g. heat_ranking_2026-03-04_5d.csv -> ('20260304', '2026-03-04')

    :param filename: file name (or path) expected to contain YYYY-MM-DD
    :return: tuple (calc_date 'YYYYMMDD', date_str 'YYYY-MM-DD')
    :raises ValueError: if no YYYY-MM-DD pattern is present in the name
    """
    # Match an ISO-style date: YYYY-MM-DD
    match = re.search(r'(\d{4}-\d{2}-\d{2})', filename)

    if not match:
        # BUG FIX: the message previously did not interpolate the offending
        # filename, which made failures impossible to trace in batch runs.
        raise ValueError(f"无法从文件名提取日期: {filename}")

    date_str = match.group(1)              # e.g. '2026-03-04'
    calc_date = date_str.replace('-', '')  # e.g. '20260304'
    return calc_date, date_str
|
||||
|
||||
|
||||
def read_heat_csv(file_path):
    """Load a heat-ranking CSV and verify every required column is present.

    Legacy Chinese column headers are mapped onto their modern English
    names before the final validation, so older exports still load.

    :param file_path: path to the CSV file
    :return: DataFrame with at least the required heat columns
    :raises FileNotFoundError: if the file does not exist
    :raises ValueError: if required columns are still missing after mapping
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"文件不存在: {file_path}")

    print(f"📂 读取文件: {file_path}")
    # utf-8-sig strips the BOM that Excel-produced CSVs often carry
    df = pd.read_csv(file_path, encoding='utf-8-sig')

    required_cols = ['code', 'name', 'total_return', 'avg_turnover',
                     'volume_ratio', 'momentum', 'total_return_score',
                     'avg_turnover_score', 'volume_ratio_score',
                     'momentum_score', 'heat_score']

    if any(col not in df.columns for col in required_cols):
        # Backfill from legacy (Chinese) headers where possible
        legacy_names = {
            '行业代码': 'code',
            '行业名称': 'name',
            # extend with further legacy mappings as needed
        }
        for legacy, modern in legacy_names.items():
            if legacy in df.columns and modern not in df.columns:
                df[modern] = df[legacy]

        # Validate again after the legacy rename pass
        still_missing = [col for col in required_cols if col not in df.columns]
        if still_missing:
            raise ValueError(f"CSV文件缺少必要列: {still_missing}")

    print(f"✅ 成功读取 {len(df)} 条行业数据")
    return df
|
||||
|
||||
|
||||
def save_to_mysql(df, calc_date, db_config, dry_run=False):
    """Persist a heat DataFrame into the MySQL table sw_heat_daily.

    Uses INSERT ... ON DUPLICATE KEY UPDATE, so re-importing the same
    (calc_date, code) pair overwrites the previous row.

    :param df: DataFrame containing the required heat columns
    :param calc_date: date string in YYYYMMDD format
    :param db_config: kwargs passed straight to pymysql.connect
    :param dry_run: when True, only preview the data; nothing is written
    :return: True on success (or dry run), False on any DB failure
    """
    if dry_run:
        print(f"\n🔍 [DRY RUN] 预览模式,不实际写入数据库")
        print(f"📅 目标日期: {calc_date}")
        print(f"📊 数据预览:")
        print(df[['code', 'name', 'heat_score']].head())
        return True

    conn = None
    try:
        conn = pymysql.connect(**db_config)
        # FIX: the cursor was previously never closed; the context manager
        # guarantees closure (consistent with SWIndustryHeatTracker._init_database).
        with conn.cursor() as cursor:
            # Check whether this date already has rows (informational only;
            # the upsert below handles the actual overwrite)
            cursor.execute(
                "SELECT COUNT(*) FROM sw_heat_daily WHERE calc_date = %s",
                (calc_date,)
            )
            existing = cursor.fetchone()[0]

            if existing > 0:
                print(f"⚠️ 日期 {calc_date} 已有 {existing} 条记录,将执行覆盖更新")

            insert_sql = """
            INSERT INTO sw_heat_daily
            (calc_date, code, name, total_return, avg_turnover, volume_ratio, momentum,
             total_return_score, avg_turnover_score, volume_ratio_score, momentum_score, heat_score)
            VALUES
            (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
                name = VALUES(name),
                total_return = VALUES(total_return),
                avg_turnover = VALUES(avg_turnover),
                volume_ratio = VALUES(volume_ratio),
                momentum = VALUES(momentum),
                total_return_score = VALUES(total_return_score),
                avg_turnover_score = VALUES(avg_turnover_score),
                volume_ratio_score = VALUES(volume_ratio_score),
                momentum_score = VALUES(momentum_score),
                heat_score = VALUES(heat_score),
                created_at = CURRENT_TIMESTAMP;
            """

            # Convert rows to plain tuples; float() guards against numpy
            # scalar types that some MySQL drivers reject
            data_tuples = []
            for _, row in df.iterrows():
                data_tuples.append((
                    calc_date,
                    str(row['code']),
                    str(row['name']),
                    float(row['total_return']),
                    float(row['avg_turnover']),
                    float(row['volume_ratio']),
                    float(row['momentum']),
                    float(row['total_return_score']),
                    float(row['avg_turnover_score']),
                    float(row['volume_ratio_score']),
                    float(row['momentum_score']),
                    float(row['heat_score'])
                ))

            # Bulk upsert in a single round trip
            cursor.executemany(insert_sql, data_tuples)
        conn.commit()

        action = "覆盖更新" if existing > 0 else "新增"
        print(f"✅ MySQL{action}成功: {calc_date} 共 {len(data_tuples)} 条记录")
        return True

    except Exception as e:
        print(f"❌ MySQL操作失败: {e}")
        if conn:
            conn.rollback()
        return False
    finally:
        if conn:
            conn.close()
|
||||
|
||||
|
||||
def batch_import(data_dir, db_config, pattern=None, dry_run=False):
    """Import every matching heat-ranking CSV found in *data_dir*.

    Files are processed in sorted (chronological) order; a failure in one
    file is logged and does not abort the remainder of the batch.

    :param data_dir: directory to scan for heat-ranking CSV files
    :param db_config: pymysql.connect kwargs forwarded to save_to_mysql
    :param pattern: regex matched against filenames; defaults to the
        standard heat_ranking_YYYY-MM-DD_*.csv naming scheme
    :param dry_run: forwarded to save_to_mysql (preview only)
    """
    if pattern is None:
        pattern = r'heat_ranking_\d{4}-\d{2}-\d{2}_.*\.csv'

    # Sorted filenames are chronological because the date is zero-padded
    files = sorted(f for f in os.listdir(data_dir) if re.match(pattern, f))

    print(f"📁 发现 {len(files)} 个待导入文件")

    success_count = 0
    for filename in files:
        file_path = os.path.join(data_dir, filename)
        try:
            calc_date, _ = extract_date_from_filename(filename)
            df = read_heat_csv(file_path)

            if save_to_mysql(df, calc_date, db_config, dry_run):
                success_count += 1

        except Exception as e:
            # BUG FIX: the log line previously printed "(unknown)" instead of
            # interpolating the failing filename
            print(f"❌ 处理 {filename} 失败: {e}")
            continue

    print(f"\n📊 导入完成: {success_count}/{len(files)} 个文件成功")
|
||||
|
||||
|
||||
def main():
    """CLI entry point: import a single CSV, or a whole directory, into MySQL.

    Modes (mutually exclusive in practice):
      * ``--batch``: import every matching file under ``--dir``
      * positional ``file``: import that one CSV
      * no arguments: fall back to a default example file if it exists
    """
    # Database configuration.
    # SECURITY NOTE(review): production credentials are hard-coded here;
    # consider loading them from environment variables or a config file.
    config = {
        'host': '10.127.2.207',
        'user': 'financial_prod',
        'password': 'mmTFncqmDal5HLRGY0BV',
        'database': 'reference',
        'port': 3306,
        'charset': 'utf8mb4'
    }

    # Command-line argument parsing
    parser = argparse.ArgumentParser(description='申万行业热度数据回溯补录工具')
    parser.add_argument('file', nargs='?', help='单个CSV文件路径')
    parser.add_argument('--date', '-d', help='指定日期(YYYYMMDD),覆盖文件名中的日期')
    parser.add_argument('--batch', '-b', action='store_true', help='批量导入目录下所有文件')
    parser.add_argument('--dir', default='./sw_data', help='数据目录路径(默认: ./sw_data)')
    parser.add_argument('--dry-run', action='store_true', help='预览模式,不实际写入数据库')

    args = parser.parse_args()

    # Batch mode: import every matching CSV in the directory
    if args.batch:
        batch_import(args.dir, config, dry_run=args.dry_run)
        return

    # Single-file mode
    if args.file:
        file_path = args.file
    else:
        # Fall back to a default example file
        file_path = './sw_data/heat_ranking_2026-03-04_5d.csv'
        if not os.path.exists(file_path):
            print("请提供CSV文件路径,或使用 --batch 批量导入")
            parser.print_help()
            return

    try:
        # Read the CSV
        df = read_heat_csv(file_path)

        # Resolve the target date: explicit --date wins over the filename
        if args.date:
            calc_date = args.date  # date supplied on the command line
            print(f"📅 使用指定日期: {calc_date}")
        else:
            calc_date, original_date = extract_date_from_filename(file_path)
            print(f"📅 从文件名提取日期: {original_date} -> {calc_date}")

        # Persist to MySQL (or just preview under --dry-run)
        save_to_mysql(df, calc_date, config, dry_run=args.dry_run)

    except Exception as e:
        print(f"❌ 错误: {e}")
        raise


if __name__ == "__main__":
    main()
|
||||
|
|
@ -0,0 +1,405 @@
|
|||
import akshare as ak
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
import os
|
||||
import time
|
||||
import pymysql
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
class SWIndustryHeatTracker:
    """
    Shenwan (申万) level-1 industry heat tracker.

    Pulls official post-close daily data via akshare's
    ``index_analysis_daily_sw`` interface, computes a composite heat score
    per industry over a configurable lookback window, and optionally
    persists results to MySQL and CSV.
    """

    def __init__(self, data_dir="./sw_data", lookback_days=5, db_config=None):
        """
        :param data_dir: directory for cached calendar and report output
        :param lookback_days: window (trading days) for the heat score
        :param db_config: pymysql.connect kwargs; None disables MySQL storage
        """
        self.data_dir = data_dir
        self.lookback_days = lookback_days
        self.db_config = db_config  # MySQL configuration (optional)
        os.makedirs(data_dir, exist_ok=True)

        # Trading calendar as a set of date strings
        self.trading_dates = self._load_trading_calendar()

        # Shenwan level-1 industry code -> Chinese name (31 industries)
        self.industry_map = {
            '801010': '农林牧渔', '801030': '基础化工', '801040': '钢铁',
            '801050': '有色金属', '801080': '电子', '801110': '家用电器',
            '801120': '食品饮料', '801130': '纺织服饰', '801140': '轻工制造',
            '801150': '医药生物', '801160': '公用事业', '801170': '交通运输',
            '801180': '房地产', '801200': '商贸零售', '801210': '社会服务',
            '801230': '综合', '801710': '建筑材料', '801720': '建筑装饰',
            '801730': '电力设备', '801740': '国防军工', '801750': '计算机',
            '801760': '传媒', '801770': '通信', '801780': '银行',
            '801790': '非银金融', '801880': '汽车', '801890': '机械设备',
            '801950': '煤炭', '801960': '石油石化', '801970': '环保',
            '801980': '美容护理'
        }

        # Weights of the four heat factors (sum to 1.0)
        self.weights = {
            'return': 0.35,
            'turnover': 0.25,
            'volume_ratio': 0.25,
            'momentum': 0.15
        }

        # Create the MySQL table up front when a DB is configured
        if self.db_config:
            self._init_database()

    def _init_database(self):
        """Create the sw_heat_daily table if it does not exist yet."""
        conn = None
        try:
            conn = pymysql.connect(**self.db_config)
            with conn.cursor() as cursor:
                create_table_sql = """
                CREATE TABLE IF NOT EXISTS sw_heat_daily (
                    id INT AUTO_INCREMENT PRIMARY KEY,
                    calc_date VARCHAR(8) NOT NULL COMMENT '计算日期(YYYYMMDD)',
                    code VARCHAR(10) NOT NULL COMMENT '行业代码',
                    name VARCHAR(20) NOT NULL COMMENT '行业名称',
                    total_return DECIMAL(10,4) COMMENT '累计涨跌幅(%)',
                    avg_turnover DECIMAL(10,4) COMMENT '平均换手率(%)',
                    volume_ratio DECIMAL(10,4) COMMENT '量比',
                    momentum DECIMAL(5,4) COMMENT '动量(上涨天数占比)',
                    total_return_score DECIMAL(5,4) COMMENT '收益因子得分(0-1)',
                    avg_turnover_score DECIMAL(5,4) COMMENT '换手因子得分(0-1)',
                    volume_ratio_score DECIMAL(5,4) COMMENT '量比因子得分(0-1)',
                    momentum_score DECIMAL(5,4) COMMENT '动量因子得分(0-1)',
                    heat_score DECIMAL(6,2) COMMENT '热度综合得分(0-100)',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    UNIQUE KEY uk_date_code (calc_date, code),
                    INDEX idx_date (calc_date),
                    INDEX idx_heat_score (calc_date, heat_score)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='申万一级行业热度日表';
                """
                cursor.execute(create_table_sql)
            conn.commit()
            print("✅ 数据库表 sw_heat_daily 初始化成功")
        except Exception as e:
            print(f"❌ 数据库初始化失败: {e}")
            raise
        finally:
            if conn:
                conn.close()

    def _load_trading_calendar(self):
        """Load the trading calendar, refreshing the on-disk cache weekly.

        :return: set of trade-date strings
        """
        cache_file = f"{self.data_dir}/trading_calendar.csv"

        if os.path.exists(cache_file):
            df = pd.read_csv(cache_file)
            dates = df['trade_date'].astype(str).tolist()
            cache_time = datetime.fromtimestamp(os.path.getmtime(cache_file))
            if (datetime.now() - cache_time).days > 7:
                # BUG FIX: the refreshed calendar was previously discarded
                # and the stale list returned; use the fresh dates instead.
                dates = self._update_calendar_cache(cache_file)
        else:
            dates = self._update_calendar_cache(cache_file)

        return set(dates)

    def _update_calendar_cache(self, cache_file):
        """Fetch the trading calendar from sina and rewrite the cache file.

        :return: list of trade-date strings
        """
        print("正在更新交易日历...")
        df = ak.tool_trade_date_hist_sina()
        df.to_csv(cache_file, index=False)
        return df['trade_date'].astype(str).tolist()

    def get_recent_trading_dates(self, n_days, end_date=None):
        """Return the latest *n_days* trading dates not after *end_date*.

        Dates are compared as strings; this is valid because both sides use
        the same zero-padded format, so lexicographic == chronological order.
        """
        if end_date is None:
            end_date = datetime.now().strftime('%Y-%m-%d')

        valid_dates = [d for d in self.trading_dates if d <= end_date]
        return sorted(valid_dates)[-n_days:]

    def fetch_daily_data(self, trade_date):
        """Fetch the Shenwan level-1 industry daily report for one date.

        :param trade_date: 'YYYY-MM-DD' string
        :return: DataFrame with normalized English column names, or None on
            failure / no data
        """
        date_str = trade_date.replace('-', '')

        try:
            df = ak.index_analysis_daily_sw(
                symbol="一级行业",
                start_date=date_str,
                end_date=date_str
            )

            if df.empty:
                print(f"⚠️ {trade_date} 无数据")
                return None

            # Normalize the Chinese column headers returned by akshare
            df = df.rename(columns={
                '指数代码': 'code',
                '指数名称': 'name',
                '发布日期': 'date',
                '收盘指数': 'close',
                '涨跌幅': 'return',
                '换手率': 'turnover',
                '成交量': 'volume',
                '成交额': 'amount',
                '市盈率': 'pe',
                '市净率': 'pb'
            })

            df['trade_date'] = trade_date
            df['code'] = df['code'].astype(str)

            return df

        except Exception as e:
            print(f"❌ 获取 {trade_date} 数据失败: {e}")
            return None

    def fetch_multi_days_data(self, n_days=None):
        """Fetch and concatenate the most recent *n_days* of daily data.

        Defaults to lookback_days + 5 so the historical window needed for
        the volume-ratio baseline is available.

        :return: DataFrame sorted by (code, trade_date), or None if every
            fetch failed
        """
        if n_days is None:
            n_days = self.lookback_days + 5

        dates = self.get_recent_trading_dates(n_days)
        print(f"📅 获取数据区间: {dates[0]} 至 {dates[-1]}")

        all_data = []
        for date in dates:
            df = self.fetch_daily_data(date)
            if df is not None:
                all_data.append(df)
            time.sleep(0.5)  # throttle requests to the upstream API

        if not all_data:
            return None

        combined = pd.concat(all_data, ignore_index=True)
        combined['trade_date'] = pd.to_datetime(combined['trade_date'])
        return combined.sort_values(['code', 'trade_date'])

    def calculate_heat_score(self, df, lookback_days=None):
        """Compute the composite heat score for each industry.

        Factors (each percentile-ranked across industries, then weighted):
        cumulative return, average turnover, volume ratio vs the preceding
        window, and momentum (fraction of up days).

        :param df: multi-day DataFrame from fetch_multi_days_data
        :param lookback_days: window length; defaults to self.lookback_days
        :return: DataFrame sorted by heat_score descending
        """
        if lookback_days is None:
            lookback_days = self.lookback_days

        # Restrict to the most recent lookback window
        latest_dates = df['trade_date'].unique()[-lookback_days:]
        recent_df = df[df['trade_date'].isin(latest_dates)].copy()

        # Preceding window of equal length serves as the volume baseline
        all_dates = df['trade_date'].unique()
        if len(all_dates) >= lookback_days * 2:
            hist_dates = all_dates[-lookback_days*2:-lookback_days]
            hist_df = df[df['trade_date'].isin(hist_dates)]
        else:
            hist_df = None

        heat_data = []

        for code in recent_df['code'].unique():
            industry_data = recent_df[recent_df['code'] == code]

            # Skip industries with incomplete data in the window
            if len(industry_data) < lookback_days:
                continue

            name = industry_data['name'].iloc[0]
            total_return = industry_data['return'].sum()
            avg_turnover = industry_data['turnover'].mean()
            recent_volume = industry_data['volume'].mean()

            # Volume ratio: recent average vs the preceding window; falls
            # back to the neutral value 1.0 when no baseline is available
            if hist_df is not None and code in hist_df['code'].values:
                hist_volume = hist_df[hist_df['code'] == code]['volume'].mean()
                volume_ratio = recent_volume / hist_volume if hist_volume > 0 else 1.0
            else:
                volume_ratio = 1.0

            up_days = (industry_data['return'] > 0).sum()
            momentum = up_days / lookback_days

            heat_data.append({
                'code': code,
                'name': name,
                'total_return': total_return,
                'avg_turnover': avg_turnover,
                'volume_ratio': volume_ratio,
                'momentum': momentum,
                'latest_close': industry_data['close'].iloc[-1],
                'latest_return': industry_data['return'].iloc[-1],
                'up_days': up_days
            })

        heat_df = pd.DataFrame(heat_data)

        # Percentile-rank each factor cross-sectionally to [0, 1]
        for col in ['total_return', 'avg_turnover', 'volume_ratio', 'momentum']:
            heat_df[f'{col}_score'] = heat_df[col].rank(pct=True)

        # Weighted sum of factor scores, scaled to 0-100
        heat_df['heat_score'] = (
            heat_df['total_return_score'] * self.weights['return'] +
            heat_df['avg_turnover_score'] * self.weights['turnover'] +
            heat_df['volume_ratio_score'] * self.weights['volume_ratio'] +
            heat_df['momentum_score'] * self.weights['momentum']
        ) * 100

        # Discretize into five human-readable heat levels
        heat_df['heat_level'] = pd.cut(
            heat_df['heat_score'],
            bins=[0, 20, 40, 60, 80, 100],
            labels=['极冷', '偏冷', '温和', '偏热', '极热']
        )

        heat_df['rank'] = heat_df['heat_score'].rank(ascending=False, method='min').astype(int)

        return heat_df.sort_values('heat_score', ascending=False)

    def save_to_mysql(self, heat_df, calc_date):
        """
        Persist the heat DataFrame to MySQL (upsert on calc_date + code).

        :param heat_df: heat DataFrame from calculate_heat_score
        :param calc_date: calculation date in 'YYYYMMDD' format
        :return: True on success, False when skipped or on failure
        """
        if not self.db_config:
            print("⚠️ 未配置数据库,跳过MySQL存储")
            return False

        conn = None
        try:
            conn = pymysql.connect(**self.db_config)
            # FIX: use a context manager so the cursor is always closed
            with conn.cursor() as cursor:
                insert_sql = """
                INSERT INTO sw_heat_daily
                (calc_date, code, name, total_return, avg_turnover, volume_ratio, momentum,
                 total_return_score, avg_turnover_score, volume_ratio_score, momentum_score, heat_score)
                VALUES
                (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                ON DUPLICATE KEY UPDATE
                    name = VALUES(name),
                    total_return = VALUES(total_return),
                    avg_turnover = VALUES(avg_turnover),
                    volume_ratio = VALUES(volume_ratio),
                    momentum = VALUES(momentum),
                    total_return_score = VALUES(total_return_score),
                    avg_turnover_score = VALUES(avg_turnover_score),
                    volume_ratio_score = VALUES(volume_ratio_score),
                    momentum_score = VALUES(momentum_score),
                    heat_score = VALUES(heat_score),
                    created_at = CURRENT_TIMESTAMP;
                """

                # Convert rows to plain tuples; float() guards against
                # numpy scalar types
                data_tuples = []
                for _, row in heat_df.iterrows():
                    data_tuples.append((
                        calc_date,
                        row['code'],
                        row['name'],
                        float(row['total_return']),
                        float(row['avg_turnover']),
                        float(row['volume_ratio']),
                        float(row['momentum']),
                        float(row['total_return_score']),
                        float(row['avg_turnover_score']),
                        float(row['volume_ratio_score']),
                        float(row['momentum_score']),
                        float(row['heat_score'])
                    ))

                # Bulk upsert
                cursor.executemany(insert_sql, data_tuples)
            conn.commit()

            print(f"✅ MySQL存储成功: {calc_date} 共 {len(data_tuples)} 条记录")
            return True

        except Exception as e:
            print(f"❌ MySQL存储失败: {e}")
            if conn:
                conn.rollback()
            return False
        finally:
            if conn:
                conn.close()

    def generate_report(self, save=True, to_mysql=True):
        """Generate the heat report, print the top 10, and store results.

        :param save: write CSV output (dated file plus a 'newest' snapshot)
        :param to_mysql: also persist to MySQL when db_config is set
        :return: heat DataFrame, or None if data could not be fetched
        """
        calc_date = datetime.now().strftime('%Y%m%d')  # YYYYMMDD format

        print("="*80)
        print(f"🔥 申万一级行业热度报告 ({self.lookback_days}日复合指标)")
        print(f"📅 计算日期: {calc_date}")
        print(f"⏰ 生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print("="*80)

        # Fetch raw multi-day data
        raw_data = self.fetch_multi_days_data()
        if raw_data is None:
            print("❌ 数据获取失败")
            return None

        # Compute heat scores
        heat_df = self.calculate_heat_score(raw_data)

        # Print the top-10 summary
        print(f"\n🏆 热度 TOP 10 (近{self.lookback_days}日):")
        print("-"*80)
        top10 = heat_df.head(10)[['rank', 'name', 'heat_score', 'heat_level',
                                  'total_return', 'avg_turnover', 'volume_ratio', 'momentum']]

        for _, row in top10.iterrows():
            print(f"{row['rank']:2d}. {row['name']:8s} | "
                  f"热度:{row['heat_score']:5.1f}({row['heat_level']}) | "
                  f"收益:{row['total_return']:+6.2f}% | "
                  f"换手:{row['avg_turnover']:5.2f}% | "
                  f"量比:{row['volume_ratio']:4.2f} | "
                  f"动量:{row['momentum']:.0%}")

        # Persist results
        if save:
            today = datetime.now().strftime('%Y-%m-%d')

            # 1. Dated CSV plus a fixed-name "newest" snapshot that
            #    downstream consumers (e.g. the MQ scorer) read directly
            csv_file = f"{self.data_dir}/heat_ranking_{today}_{self.lookback_days}d.csv"
            heat_df.to_csv(csv_file, index=False, encoding='utf-8-sig')
            csv_file = f"{self.data_dir}/heat_ranking_newest_{self.lookback_days}d.csv"
            heat_df.to_csv(csv_file, index=False, encoding='utf-8-sig')
            print(f"💾 CSV排名已保存: {csv_file}")

            # 2. MySQL storage
            if to_mysql and self.db_config:
                self.save_to_mysql(heat_df, calc_date)

        return heat_df
|
||||
|
||||
|
||||
# Usage example
if __name__ == "__main__":
    # SECURITY NOTE(review): production database credentials are hard-coded
    # below; move them to environment variables or a secrets store.
    # Template configuration:
    # config = {
    #     'host': 'localhost',
    #     'port': 3306,
    #     'user': 'your_username',
    #     'password': 'your_password',
    #     'database': 'your_database',
    #     'charset': 'utf8mb4'
    # }
    config = {
        'host': '10.127.2.207',
        'user': 'financial_prod',
        'password': 'mmTFncqmDal5HLRGY0BV',
        'database': 'reference',
        'port': 3306,
        'charset': 'utf8mb4'
    }

    # Build the tracker; passing db_config enables MySQL persistence
    tracker = SWIndustryHeatTracker(
        lookback_days=5,
        db_config=config  # pass in the database configuration
    )

    # Generate the report and automatically store it (CSV + MySQL)
    result = tracker.generate_report(save=True, to_mysql=True)
|
||||
10
mqreceive.py
10
mqreceive.py
|
|
@ -24,19 +24,13 @@ with open("media_score.txt", "r", encoding="utf-8") as f:
|
|||
processed_ids = set()
|
||||
|
||||
def message_callback(ch, method, properties, body):
|
||||
"""消息处理回调函数"""
|
||||
|
||||
|
||||
"""消息处理回调函数"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
data = json.loads(body)
|
||||
id_str = str(data["id"])
|
||||
input_date = data["input_date"]
|
||||
print(id_str, input_date)
|
||||
|
||||
# ch.basic_ack(delivery_tag=method.delivery_tag)
|
||||
# print(f"接收到消息: {id_str}")
|
||||
# return
|
||||
|
||||
# 幂等性检查:如果消息已处理过,直接确认并跳过
|
||||
if id_str in processed_ids:
|
||||
|
|
@ -116,7 +110,7 @@ def start_consumer():
|
|||
connection = create_connection()
|
||||
channel = connection.channel()
|
||||
|
||||
# 设置QoS,限制每次只取一条消息
|
||||
# 设置QoS,限制每次只取50条消息
|
||||
channel.basic_qos(prefetch_count=50)
|
||||
|
||||
channel.exchange_declare(
|
||||
|
|
|
|||
|
|
@ -3,11 +3,24 @@ import json
|
|||
import logging
|
||||
import time
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from queue import Queue
|
||||
from config import *
|
||||
from llm_process import send_mq, get_label
|
||||
import jieba
|
||||
from llm_process import send_mq, get_label, get_translation, is_mostly_chinese, get_tdx_map, get_stock_map
|
||||
from label_check import load_label_mapping, validate_tags
|
||||
from check_sensitivity import SensitiveWordFilter
|
||||
from news_clustering import NewsClusterer
|
||||
from cal_etf_labels import *
|
||||
import pandas as pd
|
||||
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
|
||||
# 声明一个全局变量,存媒体的权威度打分
|
||||
media_score = {}
|
||||
|
|
@ -23,6 +36,70 @@ with open("media_score.txt", "r", encoding="utf-8") as f:
|
|||
print(f"解析错误: {e},行内容: {line}")
|
||||
continue
|
||||
|
||||
# Mandatory validation mappings for level-1/level-2 industry and concept labels
concept_mapping = load_label_mapping("concept_mapping.txt")
industry_mapping = load_label_mapping("industry_mapping.txt")

# Sensitive-word screening
sensitive_filter = SensitiveWordFilter("sensitive_words.txt")

# Clustering-based scoring setup
logging.getLogger("pika").setLevel(logging.WARNING)
JIEBA_DICT_PATH = '/root/zzck/news_distinct_task/dict.txt.big'  # jieba dictionary path
if os.path.exists(JIEBA_DICT_PATH):
    jieba.set_dictionary(JIEBA_DICT_PATH)
news_cluster = NewsClusterer()
news_cluster.load_clusters()

# Keyword bonus table: each line of bonus_words.txt is "word,bonus";
# malformed lines are logged and skipped
words_bonus = {}
with open("bonus_words.txt", "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        try:
            word, bonus = line.split(",")
            words_bonus[word.strip()] = int(bonus)
        except ValueError as e:
            print(f"关键词加分项解析错误: {e},行内容: {line}")
            continue
print(words_bonus)
|
||||
|
||||
# Shenwan level-1 industry heat scores: loaded from the "newest" CSV snapshot
# written by the heat-tracker job.
# --- 1. Initialization ---
sw_heat_file = "/root/zzck/industry_heat_task/sw_data/heat_ranking_newest_5d.csv"


def load_sw_map():
    """Load {industry name -> heat_score} from the latest heat CSV.

    Robustness fix: returns an empty dict (instead of None) on any read
    failure, so downstream ``sw_heat_map.get(...)`` lookups can never raise
    AttributeError. An empty dict is still falsy, so callers that check
    ``if new_map:`` behave exactly as before.
    """
    try:
        df = pd.read_csv(sw_heat_file)
        return df.set_index('name')['heat_score'].to_dict()
    except Exception as e:
        print(f"读取CSV失败: {e}")
        return {}
|
||||
|
||||
# 全局变量初始化
|
||||
sw_heat_map = load_sw_map()
|
||||
|
||||
# --- 2. Define the daily refresh job ---
def update_sw_heat_job():
    """Reload the Shenwan heat map from disk and swap it in atomically."""
    global sw_heat_map
    print(f"[{datetime.now()}] 开始执行每日热度数据更新...")
    new_map = load_sw_map()
    if new_map:
        # Atomic rebind: in-flight readers keep the old dict reference,
        # subsequent lookups see the new one (safe under the GIL)
        sw_heat_map = new_map
        print("sw_heat_map 更新成功")
        print(f"当前sw_heat_map示例: {list(sw_heat_map.items())[:5]}")  # show first 5 entries as a sanity check
|
||||
|
||||
# --- 3. Start the background scheduler ---
scheduler = BackgroundScheduler()
# Refresh the heat map every day at 18:10
scheduler.add_job(update_sw_heat_job, 'cron', hour=18, minute=10)
scheduler.start()


# Idempotency store - IDs of already-processed messages (access guarded by a lock)
processed_ids = set()
processed_ids_lock = threading.Lock()  # synchronizes access to processed_ids
|
||||
|
|
@ -64,31 +141,246 @@ def process_single_message(data):
|
|||
source = d_data[0].get('sourcename', "其他")
|
||||
source_impact = media_score.get(source, 5)
|
||||
tagged_news = get_label(content, source)
|
||||
public_opinion_score = tagged_news.get("public_opinion_score", 30) #资讯质量分
|
||||
China_factor = tagged_news.get("China_factor", 0.2) #中国股市相关度
|
||||
news_score = source_impact * 0.04 + public_opinion_score * 0.25 + China_factor * 35
|
||||
news_score = round(news_score, 2)
|
||||
# deal with the label problems
|
||||
tagged_news = validate_tags(tagged_news, "industry_label", "industry_confidence", industry_mapping)
|
||||
tagged_news = validate_tags(tagged_news, "concept_label", "concept_confidence", concept_mapping)
|
||||
|
||||
industry_confidence = tagged_news.get("industry_confidence", [])
|
||||
industry_score = list(map(lambda x: round(x * news_score, 2), industry_confidence))
|
||||
concept_confidence = tagged_news.get("concept_confidence", [])
|
||||
concept_score = list(map(lambda x: round(x * news_score, 2), concept_confidence))
|
||||
|
||||
# 确保最终展示的分数是两位小数
|
||||
industry_confidence = list(map(lambda x: round(x, 2), industry_confidence))
|
||||
concept_confidence = list(map(lambda x: round(x, 2), concept_confidence))
|
||||
|
||||
industry_label = tagged_news.get("industry_label", [])
|
||||
sw_heat = 10 # 默认热度分,如果没有标签或热度数据缺失则使用默认值
|
||||
if industry_label and industry_confidence:
|
||||
first_industry = industry_label[0].split("-")[0]
|
||||
first_industry_confidence = industry_confidence[0]
|
||||
if first_industry_confidence >= 0.75:
|
||||
sw_heat = sw_heat_map.get(first_industry, 0)
|
||||
|
||||
public_opinion_score = tagged_news.get("public_opinion_score", 30) #资讯质量分
|
||||
China_factor = tagged_news.get("China_factor", 0.2) #中国股市相关度
|
||||
news_score = source_impact * 0.03 + public_opinion_score * 0.3 + China_factor * 35 + sw_heat * 0.05
|
||||
|
||||
|
||||
industry_score = list(map(lambda x: round(x * news_score, 2), industry_confidence))
|
||||
concept_score = list(map(lambda x: round(x * news_score, 2), concept_confidence))
|
||||
|
||||
# 增加一个关键词加分, 例如原文标题命中 Google 加 5 分, 不区分大小写
|
||||
ori_title = data.get('title_EN', '')
|
||||
hits = []
|
||||
for word, bonus in words_bonus.items():
|
||||
if re.search(rf'\b{re.escape(word)}\b', ori_title, re.IGNORECASE):
|
||||
hits.append((word, bonus))
|
||||
if hits:
|
||||
# 按分数降序排序,取最高分项
|
||||
hits_sorted = sorted(hits, key=lambda x: x[1], reverse=True)
|
||||
print(f"原文标题{ori_title} 命中关键词及分数:", hits_sorted)
|
||||
add_score = hits_sorted[0][-1] * 0.05
|
||||
print(f"新闻{id_str} 当前分数news_score变动: {news_score} -> {news_score + add_score} ... 如果超出上限后续会调整...")
|
||||
news_score = min(99.0, news_score + add_score)
|
||||
else:
|
||||
pass
|
||||
|
||||
# 增加一个聚类打分,目前仅作试验
|
||||
if datetime.now().minute == 12:
|
||||
if not news_cluster.load_clusters():
|
||||
print(f"聚类文件加载失败 {datetime.now()}")
|
||||
CN_title = data.get('title_txt', '')
|
||||
cluster_evaluate = news_cluster.evaluate_news(title=CN_title, content=content)
|
||||
cluster_score = (cluster_evaluate['weighted_score'] ** 0.5) * 10 if cluster_evaluate else 20.0
|
||||
center_news_id = cluster_evaluate['center_news_id'] if cluster_evaluate else None
|
||||
news_score_exp = source_impact * 0.04 + cluster_score * 0.25 + China_factor * 35
|
||||
|
||||
# if it's Ads, decrease the score
|
||||
if all(keyword in content for keyword in ["邮箱", "电话", "网站"]):
|
||||
print("新闻判别为广告软文,分数强制8折")
|
||||
news_score *= 0.79
|
||||
news_score_exp *= 0.79
|
||||
|
||||
# if there is no labels at all, decrease the score
|
||||
if news_score >= 80 and len(tagged_news.get("industry_label", [])) == 0 and len(tagged_news.get("concept_label", [])) == 0:
|
||||
print("新闻没有打上任何有效标签,分数强制8折")
|
||||
news_score *= 0.79
|
||||
news_score_exp *= 0.79
|
||||
|
||||
# if there are sensitive words in title and content, decrease the score
|
||||
if news_score >= 80:
|
||||
title = data.get('title_txt', '')
|
||||
news_text = f"{title} {content}"
|
||||
result = sensitive_filter.detect(news_text)
|
||||
if result:
|
||||
print("新闻命中敏感词: ", ", ".join(list(result)), "; 分数强制8折")
|
||||
news_score *= 0.79
|
||||
news_score_exp *= 0.79
|
||||
|
||||
# make sure the selected news label confidence is higher than 0.7
|
||||
if news_score >= 80:
|
||||
industry_high_confidence = industry_confidence[0] if len(industry_confidence) > 0 else 0
|
||||
concept_high_confidence = concept_confidence[0] if len(concept_confidence) > 0 else 0
|
||||
if max(industry_high_confidence, concept_high_confidence) < 0.7:
|
||||
news_score *= 0.79
|
||||
news_score_exp *= 0.79
|
||||
print(f"行业标签置信度为{industry_high_confidence}, 概念标签置信度为{concept_high_confidence}, 新闻得分强制8折.")
|
||||
news_score = round(news_score, 2)
|
||||
news_score_exp = round(news_score_exp, 2)
|
||||
cluster_score = round(cluster_score, 2)
|
||||
|
||||
# if score >= 80, add translation by llm
|
||||
if news_score >= 80:
|
||||
ori_content = data.get('EN_content', '')
|
||||
ori_title = data.get('title_EN', '')
|
||||
for _ in range(3):
|
||||
# 翻译功能存在不稳定, 至多重试三次
|
||||
translation = get_translation(ori_content, ori_title)
|
||||
llm_content = translation.get("llm_content", "") if translation else ""
|
||||
llm_title = translation.get("llm_title", "") if translation else ""
|
||||
if not is_mostly_chinese(llm_title, threshold=0.4):
|
||||
print("新闻标题没有被翻译成中文, 会被舍弃, LLM标题为: ", llm_title)
|
||||
llm_title = ""
|
||||
continue
|
||||
if not is_mostly_chinese(llm_content, threshold=0.5):
|
||||
print("新闻正文没有被翻译成中文, 会被舍弃, LLM正文为: ", llm_content[:30] + "...")
|
||||
llm_content = ""
|
||||
continue
|
||||
# 20250922新增 海外事件 海外宏观 中国宏观 行业新闻 公司新闻 转载来源 公司名称 等精细处理字段
|
||||
overseas_event = translation.get("overseas_event", None) if translation else None
|
||||
overseas_macro = translation.get("overseas_macro", None) if translation else None
|
||||
china_macro = translation.get("china_macro", None) if translation else None
|
||||
industry_news = translation.get("industry_news", None) if translation else None
|
||||
company_news = translation.get("company_news", None) if translation else None
|
||||
reprint_source = translation.get("reprint_source", None) if translation else None
|
||||
if reprint_source == "":
|
||||
reprint_source = None
|
||||
company_name = translation.get("company_name", None) if translation else None
|
||||
if company_name:
|
||||
# 打个补丁, 让公司析出数强制不超过3 2025-12-19
|
||||
company_name_list = company_name.strip().split(",")
|
||||
if len(company_name_list) > 3:
|
||||
company_name = ",".join(company_name_list[:3])
|
||||
if company_name == "":
|
||||
company_name = None
|
||||
## 行业和之前的标签冗余,暂时去掉
|
||||
# industry_name = translation.get("industry_name", None) if translation else None
|
||||
|
||||
# 20251217新增 修改了翻译的一些规则; 新增通达信行业标签及置信度
|
||||
tdx_industry_ori = translation.get("tdx_industry", None) if translation else None
|
||||
tdx_industry_confidence_ori = translation.get("tdx_industry_confidence", None) if translation else None
|
||||
tdx_map = get_tdx_map()
|
||||
tdx_industry = []
|
||||
tdx_industry_confidence = []
|
||||
if tdx_industry_ori and tdx_industry_confidence_ori and len(tdx_industry_confidence_ori) == len(tdx_industry_ori):
|
||||
for industry, confidence in zip(tdx_industry_ori, tdx_industry_confidence_ori):
|
||||
if industry in tdx_map:
|
||||
tdx_industry.append(industry)
|
||||
tdx_industry_confidence.append(float(confidence))
|
||||
if len(tdx_industry) == 3:
|
||||
break
|
||||
# 20251226 新增 增加了一项析出A股代码的操作, 列表形式返回A股上市公司6位代码,至多5个
|
||||
stock_codes_ori = translation.get("stock_codes", None) if translation else None
|
||||
if stock_codes_ori:
|
||||
stock_codes_ori = list(set(stock_codes_ori))
|
||||
stock_codes = []
|
||||
stock_names = []
|
||||
stock_map = get_stock_map()
|
||||
if stock_codes_ori and isinstance(stock_codes_ori, list):
|
||||
for stock_code in stock_codes_ori:
|
||||
if stock_code in stock_map:
|
||||
stock_name = stock_map.get(stock_code)
|
||||
stock_codes.append(stock_code)
|
||||
stock_names.append(stock_name)
|
||||
|
||||
# 20260128 新增审核, 返回政治敏感度、出口转内销两项打分
|
||||
political_sensitivity = translation.get("political_sensitivity", None) if translation else None
|
||||
export_domestic = translation.get("export_domestic", None) if translation else None
|
||||
political_notes = translation.get("political_notes", None) if translation else None
|
||||
additional_notes = translation.get("additional_notes", None) if translation else None
|
||||
|
||||
# 20260226 新增大模型摘要
|
||||
llm_abstract = translation.get("llm_abstract", None) if translation else None
|
||||
if llm_content and llm_title and len(llm_content) > 30 and len(llm_title) > 3 and overseas_event is not None and overseas_macro is not None:
|
||||
break
|
||||
|
||||
# 精选新闻添加etf
|
||||
etf_labels = get_etf_labels(title=llm_title, industry_label=tagged_news.get("industry_label", []), industry_confidence=tagged_news.get("industry_confidence", []), concept_label=tagged_news.get("concept_label", []), concept_confidence=tagged_news.get("concept_confidence", []))
|
||||
etf_names = [etf_name[label] for label in etf_labels]
|
||||
tagged_news["etf_labels"] = etf_labels
|
||||
tagged_news["etf_names"] = etf_names
|
||||
|
||||
tagged_news["title"] = llm_title
|
||||
tagged_news["rewrite_content"] = llm_content
|
||||
tagged_news["overseas_event"] = overseas_event
|
||||
tagged_news["overseas_macro"] = overseas_macro
|
||||
tagged_news["china_macro"] = china_macro
|
||||
tagged_news["industry_news"] = industry_news
|
||||
tagged_news["company_news"] = company_news
|
||||
tagged_news["reprint_source"] = reprint_source
|
||||
tagged_news["company_name"] = company_name
|
||||
|
||||
tagged_news["tdx_industry"] = tdx_industry if tdx_industry else None
|
||||
tagged_news["tdx_industry_confidence"] = tdx_industry_confidence if tdx_industry_confidence else None
|
||||
|
||||
tagged_news["stock_codes"] = stock_codes if stock_codes else None
|
||||
tagged_news["stock_names"] = stock_names if stock_names else None
|
||||
|
||||
tagged_news["political_sensitivity"] = political_sensitivity if political_sensitivity else None
|
||||
tagged_news["export_domestic"] = export_domestic if export_domestic else None
|
||||
tagged_news["political_notes"] = political_notes if political_notes else None
|
||||
tagged_news["additional_notes"] = additional_notes if additional_notes else None
|
||||
|
||||
tagged_news["llm_abstract"] = llm_abstract if llm_abstract else None
|
||||
|
||||
tagged_news["source"] = source
|
||||
tagged_news["source_impact"] = source_impact
|
||||
tagged_news["industry_score"] = industry_score
|
||||
tagged_news["concept_score"] = concept_score
|
||||
tagged_news["news_score"] = news_score
|
||||
tagged_news["news_score_exp"] = news_score_exp
|
||||
tagged_news["center_news_id"] = center_news_id
|
||||
tagged_news["id"] = id_str
|
||||
|
||||
tagged_news["cluster_score"] = cluster_score
|
||||
tagged_news["industry_confidence"] = industry_confidence
|
||||
tagged_news["concept_confidence"] = concept_confidence
|
||||
|
||||
if news_score >= 80:
|
||||
# 保险起见对大模型翻译内容再做一遍敏感词筛查 20260128
|
||||
news_title_llm = tagged_news.get("title", "")
|
||||
news_content_llm = tagged_news.get("rewrite_content", "")
|
||||
news_text_llm = f"{news_title_llm} {news_content_llm}"
|
||||
result = sensitive_filter.detect(news_text_llm)
|
||||
if result:
|
||||
print("大模型翻译内容命中敏感词: ", ", ".join(list(result)), "; 分数强制8折")
|
||||
tagged_news["news_score"] = news_score * 0.79
|
||||
|
||||
# export_domestic >= 80 或者 political_sensitivity >= 70 也直接移除精选池
|
||||
if tagged_news.get("export_domestic", 0) >= 80 or tagged_news.get("political_sensitivity", 0) >= 70:
|
||||
print(f'本条资讯的出口转内销得分为{tagged_news.get("export_domestic", 0)}, 政治敏感得分为{tagged_news.get("political_sensitivity", 0)}, political_notes信息为【{tagged_news.get("political_notes", "")}】, 分数8折')
|
||||
tagged_news["news_score"] = news_score * 0.79
|
||||
|
||||
#print(json.dumps(tagged_news, ensure_ascii=False))
|
||||
print(tagged_news["id"], tagged_news["title"], tagged_news["news_score"], tagged_news["industry_label"], input_date)
|
||||
print(tagged_news["id"], tagged_news["news_score"], tagged_news["news_score_exp"], tagged_news["industry_label"], tagged_news["concept_label"], public_opinion_score, tagged_news["cluster_score"], tagged_news["center_news_id"], input_date)
|
||||
if news_score >= 80:
|
||||
print("LLM translation: ", tagged_news.get("rewrite_content", "")[:100])
|
||||
print("LLM overseas_event score: ", tagged_news.get("overseas_event", None))
|
||||
print("LLM overseas_macro score: ", tagged_news.get("overseas_macro", None))
|
||||
print("LLM china_macro score: ", tagged_news.get("china_macro", None))
|
||||
if tagged_news.get("reprint_source", None):
|
||||
print("LLM reprint_source: ", tagged_news.get("reprint_source", None))
|
||||
if tagged_news.get("company_name", None):
|
||||
print("LLM company_name: ", tagged_news.get("company_name", None), tagged_news.get("company_news", None))
|
||||
print("ETF NAMES: ", tagged_news.get("etf_names", []))
|
||||
if tagged_news.get("tdx_industry", None):
|
||||
print("LLM_max tdx_industry: ", tagged_news.get("tdx_industry", None))
|
||||
if tagged_news.get("tdx_industry_confidence", None):
|
||||
print("LLM_max tdx_industry_confidence: ", tagged_news.get("tdx_industry_confidence", None))
|
||||
if tagged_news.get("stock_names", None):
|
||||
print("LLM_max stock codes & names: ", tagged_news.get("stock_codes", None), tagged_news.get("stock_names", None))
|
||||
if tagged_news.get("political_sensitivity", None):
|
||||
print("LLM_max political_sensitivity: ", tagged_news.get("political_sensitivity", None))
|
||||
if tagged_news.get("llm_abstract", None):
|
||||
print("LLM_max llm_abstract: ", tagged_news.get("llm_abstract", None)[:30])
|
||||
return tagged_news, True
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@ def message_callback(ch, method, properties, body):
|
|||
# 数据写入资讯精选表
|
||||
write_to_news(data)
|
||||
|
||||
|
||||
|
||||
# 手动确认消息
|
||||
ch.basic_ack(delivery_tag=method.delivery_tag)
|
||||
|
|
@ -33,11 +32,20 @@ def message_callback(ch, method, properties, body):
|
|||
# 拒绝消息并重新入队
|
||||
ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False)
|
||||
|
||||
def prepare_db_value(value):
    """Normalize a field value before writing it to the database.

    ``None`` and empty lists both map to ``None`` (stored as SQL NULL);
    non-empty lists are JSON-serialized into a string; any other value
    is passed through unchanged.
    """
    if isinstance(value, list):
        # An empty list carries no information, so treat it like NULL.
        return json.dumps(value) if value else None
    # None falls through here and is returned as-is, as are scalars
    # (strings, numbers) that the DB driver can bind directly.
    return value
|
||||
|
||||
def write_to_news(data):
|
||||
news_score = data.get('news_score', 0.0)
|
||||
if float(news_score) < 80: # 过滤掉news_score小于80的消息
|
||||
return
|
||||
|
||||
|
||||
# 获取返回数据里面的 新闻id
|
||||
news_id = data.get('id', "")
|
||||
adr = jx_adr.replace("news_id", news_id)
|
||||
|
|
@ -48,7 +56,67 @@ def write_to_news(data):
|
|||
return
|
||||
print(f"新闻id:{news_id} 得分:{news_score}, 调用精选接口成功")
|
||||
|
||||
|
||||
# 某些自建字段由python脚本自行更新 by 朱思南
|
||||
conn = pymysql.connect(
|
||||
host=MYSQL_HOST_APP,
|
||||
port=MYSQL_PORT_APP,
|
||||
user=MYSQL_USER_APP,
|
||||
password=MYSQL_PASSWORD_APP,
|
||||
db=MYSQL_DB_APP,
|
||||
charset='utf8mb4'
|
||||
)
|
||||
try:
|
||||
with conn.cursor() as cursor:
|
||||
sql = '''UPDATE news
|
||||
set overseas_event = %s,
|
||||
overseas_macro = %s,
|
||||
china_macro = %s,
|
||||
industry_news = %s,
|
||||
company_news = %s,
|
||||
reprint_source = %s,
|
||||
company_name = %s,
|
||||
etf_labels = %s,
|
||||
etf_names = %s,
|
||||
tdx_industry = %s,
|
||||
tdx_industry_confidence = %s,
|
||||
stock_codes = %s,
|
||||
stock_names = %s,
|
||||
political_sensitivity = %s,
|
||||
export_domestic = %s,
|
||||
political_notes = %s,
|
||||
additional_notes = %s,
|
||||
llm_abstract = %s
|
||||
WHERE newsinfo_id = %s '''
|
||||
values = (data.get("overseas_event", None),
|
||||
data.get("overseas_macro", None),
|
||||
data.get("china_macro", None),
|
||||
data.get("industry_news", None),
|
||||
data.get("company_news", None),
|
||||
data.get("reprint_source", None),
|
||||
data.get("company_name", None),
|
||||
prepare_db_value(data.get("etf_labels", [])),
|
||||
prepare_db_value(data.get("etf_names", [])),
|
||||
prepare_db_value(data.get("tdx_industry", [])),
|
||||
prepare_db_value(data.get("tdx_industry_confidence", [])),
|
||||
prepare_db_value(data.get("stock_codes", [])),
|
||||
prepare_db_value(data.get("stock_names", [])),
|
||||
data.get("political_sensitivity", None),
|
||||
data.get("export_domestic", None),
|
||||
data.get("political_notes", None),
|
||||
data.get("additional_notes", None),
|
||||
data.get("llm_abstract", None),
|
||||
news_id)
|
||||
cursor.execute(sql, values)
|
||||
if cursor.rowcount == 0:
|
||||
print(f'warning: newsinfo_id={news_id} 不存在,更新 0 行')
|
||||
conn.commit()
|
||||
print(f"新闻id:{news_id} 海外事件置信度得分:{data.get('overseas_event', None)}, 关联ETF名称列表:{data.get('etf_names', None)}, 审核政治敏感度:{data.get('political_sensitivity', None)}. 自建字段写入精选news表成功!")
|
||||
except Exception as e:
|
||||
print(f"自建字段写入精选news表失败: {str(e)}")
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
|
||||
def write_to_es(data):
|
||||
"""写入ES"""
|
||||
|
|
@ -67,9 +135,9 @@ def write_to_es(data):
|
|||
"abstract": data.get('abstract', ""),
|
||||
"title": data.get('title', ""),
|
||||
"rewrite_content": data.get('rewrite_content', ""),
|
||||
"industry_label": json.dumps(data.get('industry_label', [])),
|
||||
"industry_label": data.get('industry_label', []),
|
||||
"industry_confidence": data.get('industry_confidence', []),
|
||||
"industry_score": data.get('industry_score', ""),
|
||||
"industry_score": data.get('industry_score', []),
|
||||
"concept_label": data.get('concept_label', []),
|
||||
"concept_confidence": data.get('concept_confidence', []),
|
||||
"concept_score": data.get('concept_score', []),
|
||||
|
|
@ -104,29 +172,35 @@ def write_to_mysql(data):
|
|||
# 新增JSON结构解析逻辑
|
||||
# 修改后的SQL语句
|
||||
sql = """INSERT INTO news_tags
|
||||
(abstract, title, rewrite_content, industry_label, industry_confidence, industry_score, concept_label, concept_confidence, concept_score, public_opinion_score, China_factor, source, source_impact, news_score, news_id)
|
||||
VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s) """
|
||||
(abstract, title, rewrite_content, industry_label, industry_confidence, industry_score, concept_label, concept_confidence, concept_score, public_opinion_score, China_factor, source, source_impact, news_score, news_id, news_score_exp, cluster_score, center_news_id)
|
||||
VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s) """
|
||||
|
||||
values = (data.get('abstract', ""),
|
||||
data.get('title', ""),
|
||||
data.get('rewrite_content', ""),
|
||||
json.dumps(data.get('industry_label', [])),
|
||||
json.dumps(data.get('industry_confidence', [])),
|
||||
json.dumps(data.get('industry_score', [])),
|
||||
json.dumps(data.get('concept_label', [])),
|
||||
json.dumps(data.get('concept_confidence', [])),
|
||||
json.dumps(data.get('concept_score', [])),
|
||||
prepare_db_value(data.get('industry_label', [])),
|
||||
prepare_db_value(data.get('industry_confidence', [])),
|
||||
prepare_db_value(data.get('industry_score', [])),
|
||||
prepare_db_value(data.get('concept_label', [])),
|
||||
prepare_db_value(data.get('concept_confidence', [])),
|
||||
prepare_db_value(data.get('concept_score', [])),
|
||||
data.get('public_opinion_score', 10),
|
||||
data.get('China_factor', 0.1),
|
||||
data.get('source', "其他"),
|
||||
data.get('source_impact', 5),
|
||||
data.get('news_score', 0.0),
|
||||
data.get('id', "")
|
||||
data.get('id', ""),
|
||||
data.get('news_score_exp', 5.0),
|
||||
data.get('cluster_score', 5.0),
|
||||
data.get('center_news_id', None)
|
||||
)
|
||||
cursor.execute(sql, values)
|
||||
conn.commit()
|
||||
abstract = data.get('abstract', "")
|
||||
print(f"{abstract}, 写入news_tags 表成功")
|
||||
id = data.get('id', "")
|
||||
industry_label = data.get('industry_label', [])
|
||||
concept_label = data.get('concept_label', [])
|
||||
overseas_event = data.get('overseas_event', None)
|
||||
print(f"{id} {industry_label} {concept_label} {data.get('news_score', 0.0)} {data.get('news_score_exp', 5.0)} 写入news_tags 表成功")
|
||||
|
||||
except Exception as e:
|
||||
print(f"写入news_tags失败: {str(e)}")
|
||||
|
|
|
|||
Loading…
Reference in New Issue