import akshare as ak
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import os
import sys
import time
import pymysql
import warnings

warnings.filterwarnings('ignore')


class SWIndustryHeatTracker:
    """Heat tracker for the Shenwan (SW) level-1 industry indices.

    Daily data comes from akshare's ``index_analysis_daily_sw`` endpoint
    (official post-close data). The look-back window and the factor weights
    are configurable, and results can be written to the MySQL table
    ``sw_heat_daily``.
    """

    def __init__(self, data_dir="./sw_data", lookback_days=5, db_config=None):
        """Create a tracker.

        :param data_dir: directory holding the trading-calendar cache and the
            CSV ranking output; created if missing.
        :param lookback_days: number of most recent trading days the heat
            score covers.
        :param db_config: keyword arguments passed to ``pymysql.connect``;
            ``None`` disables MySQL storage.
        """
        self.data_dir = data_dir
        self.lookback_days = lookback_days
        self.db_config = db_config  # MySQL connection kwargs, or None
        os.makedirs(data_dir, exist_ok=True)

        # Trading calendar as a set of 'YYYY-MM-DD' strings.
        self.trading_dates = self._load_trading_calendar()

        # SW level-1 industry code -> industry name (31 industries).
        self.industry_map = {
            '801010': '农林牧渔', '801030': '基础化工', '801040': '钢铁',
            '801050': '有色金属', '801080': '电子', '801110': '家用电器',
            '801120': '食品饮料', '801130': '纺织服饰', '801140': '轻工制造',
            '801150': '医药生物', '801160': '公用事业', '801170': '交通运输',
            '801180': '房地产', '801200': '商贸零售', '801210': '社会服务',
            '801230': '综合', '801710': '建筑材料', '801720': '建筑装饰',
            '801730': '电力设备', '801740': '国防军工', '801750': '计算机',
            '801760': '传媒', '801770': '通信', '801780': '银行',
            '801790': '非银金融', '801880': '汽车', '801890': '机械设备',
            '801950': '煤炭', '801960': '石油石化', '801970': '环保',
            '801980': '美容护理',
        }

        # Weights of the four factors in the composite heat score (sum = 1.0).
        self.weights = {
            'return': 0.35,
            'turnover': 0.25,
            'volume_ratio': 0.25,
            'momentum': 0.15,
        }

        # NOTE: automatic table creation is disabled; call _init_database()
        # once manually if the sw_heat_daily table does not exist yet.
        # if self.db_config:
        #     self._init_database()

    def _init_database(self):
        """Create the ``sw_heat_daily`` table if it does not exist.

        :raises Exception: re-raises any connection or DDL error after
            printing it.
        """
        conn = None
        try:
            conn = pymysql.connect(**self.db_config)
            with conn.cursor() as cursor:
                create_table_sql = """
                CREATE TABLE IF NOT EXISTS sw_heat_daily (
                    id INT AUTO_INCREMENT PRIMARY KEY,
                    calc_date VARCHAR(8) NOT NULL COMMENT '计算日期(YYYYMMDD)',
                    code VARCHAR(10) NOT NULL COMMENT '行业代码',
                    name VARCHAR(20) NOT NULL COMMENT '行业名称',
                    total_return DECIMAL(10,4) COMMENT '累计涨跌幅(%)',
                    avg_turnover DECIMAL(10,4) COMMENT '平均换手率(%)',
                    volume_ratio DECIMAL(10,4) COMMENT '量比',
                    momentum DECIMAL(5,4) COMMENT '动量(上涨天数占比)',
                    total_return_score DECIMAL(5,4) COMMENT '收益因子得分(0-1)',
                    avg_turnover_score DECIMAL(5,4) COMMENT '换手因子得分(0-1)',
                    volume_ratio_score DECIMAL(5,4) COMMENT '量比因子得分(0-1)',
                    momentum_score DECIMAL(5,4) COMMENT '动量因子得分(0-1)',
                    heat_score DECIMAL(6,2) COMMENT '热度综合得分(0-100)',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    UNIQUE KEY uk_date_code (calc_date, code),
                    INDEX idx_date (calc_date),
                    INDEX idx_heat_score (calc_date, heat_score)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='申万一级行业热度日表';
                """
                cursor.execute(create_table_sql)
            conn.commit()
            print("✅ 数据库表 sw_heat_daily 初始化成功")
        except Exception as e:
            print(f"❌ 数据库初始化失败: {e}")
            raise
        finally:
            if conn:
                conn.close()

    def _load_trading_calendar(self):
        """Load the trading calendar as a set of 'YYYY-MM-DD' strings.

        Uses the CSV cache in ``data_dir``; the cache is refreshed from akshare
        when missing or older than 7 days. If a refresh of an existing cache
        fails, the stale cache is used instead of aborting.
        """
        cache_file = f"{self.data_dir}/trading_calendar.csv"
        if not os.path.exists(cache_file):
            return set(self._update_calendar_cache(cache_file))

        df = pd.read_csv(cache_file)
        dates = df['trade_date'].astype(str).tolist()
        cache_time = datetime.fromtimestamp(os.path.getmtime(cache_file))
        if (datetime.now() - cache_time).days > 7:
            try:
                # Use the refreshed list (previously the result was discarded
                # and the stale dates were returned).
                dates = self._update_calendar_cache(cache_file)
            except Exception as e:
                print(f"⚠️ 交易日历更新失败,使用本地缓存: {e}")
        return set(dates)

    def is_trading_day(self, date=None):
        """Return True if ``date`` is a trading day.

        :param date: 'YYYY-MM-DD' string or any object with ``strftime``
            (datetime, date, pandas Timestamp); defaults to now.
        """
        if date is None:
            date = datetime.now()
        if hasattr(date, 'strftime'):
            date = date.strftime('%Y-%m-%d')
        return date in self.trading_dates

    def _update_calendar_cache(self, cache_file):
        """Download the trading calendar, write it to ``cache_file`` and
        return the dates as a list of strings."""
        print("正在更新交易日历...")
        df = ak.tool_trade_date_hist_sina()
        df.to_csv(cache_file, index=False)
        return df['trade_date'].astype(str).tolist()

    def get_recent_trading_dates(self, n_days, end_date=None):
        """Return the last ``n_days`` trading dates up to ``end_date``.

        :param n_days: number of dates; a value <= 0 yields an empty list
            (``[-0:]`` would otherwise return every date).
        :param end_date: inclusive upper bound 'YYYY-MM-DD'; defaults to today.
        :return: ascending list of 'YYYY-MM-DD' strings.
        """
        if n_days <= 0:
            return []
        if end_date is None:
            end_date = datetime.now().strftime('%Y-%m-%d')
        valid_dates = [d for d in self.trading_dates if d <= end_date]
        return sorted(valid_dates)[-n_days:]

    def fetch_daily_data(self, trade_date):
        """Fetch the SW level-1 daily report for one trading day.

        :param trade_date: 'YYYY-MM-DD'.
        :return: DataFrame with English column names plus ``trade_date``, or
            None when the endpoint returns nothing or fails.
        """
        date_str = trade_date.replace('-', '')
        try:
            df = ak.index_analysis_daily_sw(
                symbol="一级行业",
                start_date=date_str,
                end_date=date_str,
            )
            if df.empty:
                print(f"⚠️ {trade_date} 无数据")
                return None
            df = df.rename(columns={
                '指数代码': 'code', '指数名称': 'name', '发布日期': 'date',
                '收盘指数': 'close', '涨跌幅': 'return', '换手率': 'turnover',
                '成交量': 'volume', '成交额': 'amount', '市盈率': 'pe',
                '市净率': 'pb',
            })
            df['trade_date'] = trade_date
            df['code'] = df['code'].astype(str)
            return df
        except Exception as e:
            print(f"❌ 获取 {trade_date} 数据失败: {e}")
            return None

    def fetch_multi_days_data(self, n_days=None):
        """Fetch daily data for the most recent ``n_days`` trading days.

        :param n_days: defaults to enough days for the heat score's recent
            window plus the equally long comparison window.
        :return: combined DataFrame sorted by (code, trade_date) with
            ``trade_date`` as datetime, or None if nothing was fetched.
        """
        if n_days is None:
            # The volume ratio compares the latest window against the preceding
            # window of the same length, so 2x lookback days are required;
            # keep the old +5 margin for small windows.
            n_days = max(self.lookback_days * 2, self.lookback_days + 5)
        dates = self.get_recent_trading_dates(n_days)
        if not dates:
            print("⚠️ 交易日历中无可用交易日")
            return None
        print(f"📅 获取数据区间: {dates[0]} 至 {dates[-1]}")

        all_data = []
        for date in dates:
            df = self.fetch_daily_data(date)
            if df is not None:
                all_data.append(df)
            time.sleep(0.5)  # throttle requests to the data source

        if not all_data:
            return None
        combined = pd.concat(all_data, ignore_index=True)
        combined['trade_date'] = pd.to_datetime(combined['trade_date'])
        return combined.sort_values(['code', 'trade_date'])

    def calculate_heat_score(self, df, lookback_days=None):
        """Compute the composite heat score per industry.

        Factors over the latest ``lookback_days`` trading days:
        total_return (sum of daily % changes), avg_turnover (mean turnover),
        volume_ratio (mean volume vs. the preceding window of equal length;
        1.0 when that window is unavailable) and momentum (share of up days).
        Each factor is turned into a percentile rank, weighted by
        ``self.weights`` and scaled to 0-100.

        :param df: long-format data with columns code, name, trade_date,
            return, turnover, volume, close.
        :param lookback_days: window length; defaults to ``self.lookback_days``.
        :return: DataFrame sorted by heat_score descending; empty when no
            industry has a full window of data.
        :raises ValueError: if ``lookback_days`` is not positive.
        """
        if lookback_days is None:
            lookback_days = self.lookback_days
        if lookback_days <= 0:
            raise ValueError("lookback_days must be positive")

        # Sort explicitly: unique() keeps first-appearance order, which is not
        # chronological when the first code is missing some dates.
        all_dates = sorted(df['trade_date'].drop_duplicates())
        latest_dates = all_dates[-lookback_days:]
        recent_df = df[df['trade_date'].isin(latest_dates)].copy()

        if len(all_dates) >= lookback_days * 2:
            hist_dates = all_dates[-lookback_days * 2:-lookback_days]
            hist_df = df[df['trade_date'].isin(hist_dates)]
        else:
            hist_df = None

        heat_data = []
        for code in recent_df['code'].unique():
            industry_data = recent_df[recent_df['code'] == code].sort_values('trade_date')
            # Skip industries missing days in the window.
            if len(industry_data) < lookback_days:
                continue

            total_return = industry_data['return'].sum()
            avg_turnover = industry_data['turnover'].mean()
            recent_volume = industry_data['volume'].mean()

            if hist_df is not None and code in hist_df['code'].values:
                hist_volume = hist_df.loc[hist_df['code'] == code, 'volume'].mean()
                volume_ratio = recent_volume / hist_volume if hist_volume > 0 else 1.0
            else:
                volume_ratio = 1.0

            up_days = (industry_data['return'] > 0).sum()
            heat_data.append({
                'code': code,
                'name': industry_data['name'].iloc[0],
                'total_return': total_return,
                'avg_turnover': avg_turnover,
                'volume_ratio': volume_ratio,
                'momentum': up_days / lookback_days,
                'latest_close': industry_data['close'].iloc[-1],
                'latest_return': industry_data['return'].iloc[-1],
                'up_days': up_days,
            })

        if not heat_data:
            return pd.DataFrame()

        heat_df = pd.DataFrame(heat_data)
        for col in ['total_return', 'avg_turnover', 'volume_ratio', 'momentum']:
            heat_df[f'{col}_score'] = heat_df[col].rank(pct=True)

        heat_df['heat_score'] = (
            heat_df['total_return_score'] * self.weights['return']
            + heat_df['avg_turnover_score'] * self.weights['turnover']
            + heat_df['volume_ratio_score'] * self.weights['volume_ratio']
            + heat_df['momentum_score'] * self.weights['momentum']
        ) * 100

        # include_lowest so a score of exactly 0 is labelled instead of NaN.
        heat_df['heat_level'] = pd.cut(
            heat_df['heat_score'],
            bins=[0, 20, 40, 60, 80, 100],
            labels=['极冷', '偏冷', '温和', '偏热', '极热'],
            include_lowest=True,
        )
        heat_df['rank'] = heat_df['heat_score'].rank(ascending=False, method='min').astype(int)
        return heat_df.sort_values('heat_score', ascending=False)

    def save_to_mysql(self, heat_df, calc_date):
        """Upsert heat rows into ``sw_heat_daily``.

        :param heat_df: output of ``calculate_heat_score``.
        :param calc_date: calculation date, 'YYYYMMDD'.
        :return: True on success, False if no DB is configured or on error
            (the transaction is rolled back).
        """
        if not self.db_config:
            print("⚠️ 未配置数据库,跳过MySQL存储")
            return False

        def _num(value):
            # pymysql cannot bind NaN; store SQL NULL instead.
            return None if pd.isna(value) else float(value)

        insert_sql = """
            INSERT INTO sw_heat_daily
            (calc_date, code, name, total_return, avg_turnover, volume_ratio, momentum,
             total_return_score, avg_turnover_score, volume_ratio_score, momentum_score, heat_score)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
                name = VALUES(name),
                total_return = VALUES(total_return),
                avg_turnover = VALUES(avg_turnover),
                volume_ratio = VALUES(volume_ratio),
                momentum = VALUES(momentum),
                total_return_score = VALUES(total_return_score),
                avg_turnover_score = VALUES(avg_turnover_score),
                volume_ratio_score = VALUES(volume_ratio_score),
                momentum_score = VALUES(momentum_score),
                heat_score = VALUES(heat_score),
                created_at = CURRENT_TIMESTAMP;
        """
        numeric_cols = [
            'total_return', 'avg_turnover', 'volume_ratio', 'momentum',
            'total_return_score', 'avg_turnover_score', 'volume_ratio_score',
            'momentum_score', 'heat_score',
        ]
        data_tuples = [
            (calc_date, row['code'], row['name'], *(_num(row[c]) for c in numeric_cols))
            for _, row in heat_df.iterrows()
        ]

        conn = None
        try:
            conn = pymysql.connect(**self.db_config)
            with conn.cursor() as cursor:
                cursor.executemany(insert_sql, data_tuples)
            conn.commit()
            print(f"✅ MySQL存储成功: {calc_date} 共 {len(data_tuples)} 条记录")
            return True
        except Exception as e:
            print(f"❌ MySQL存储失败: {e}")
            if conn:
                conn.rollback()
            return False
        finally:
            if conn:
                conn.close()

    @staticmethod
    def _format_row(row):
        """Format one ranking row for the console report."""
        return (f"{row['rank']:2d}. {row['name']:8s} | "
                f"热度:{row['heat_score']:5.1f}({row['heat_level']}) | "
                f"收益:{row['total_return']:+6.2f}% | "
                f"换手:{row['avg_turnover']:5.2f}% | "
                f"量比:{row['volume_ratio']:4.2f} | "
                f"动量:{row['momentum']:.0%}")

    def generate_report(self, save=True, to_mysql=True):
        """Fetch data, compute heat scores, print a report and persist it.

        :param save: write the ranking CSVs to ``data_dir``.
        :param to_mysql: upsert into MySQL when ``db_config`` is set.
        :return: the heat DataFrame, or None when data is unavailable.
        """
        now = datetime.now()
        calc_date = now.strftime('%Y%m%d')
        print("=" * 80)
        print(f"🔥 申万一级行业热度报告 ({self.lookback_days}日复合指标)")
        print(f"📅 计算日期: {calc_date}")
        print(f"⏰ 生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print("=" * 80)

        raw_data = self.fetch_multi_days_data()
        if raw_data is None:
            print("❌ 数据获取失败")
            return None

        # Rows are stored under calc_date; warn when today's data is missing
        # so an older trading day is not silently labelled as today.
        latest_trade = raw_data['trade_date'].max().strftime('%Y%m%d')
        if latest_trade != calc_date:
            print(f"⚠️ 最新数据日期 {latest_trade} 与计算日期 {calc_date} 不一致,今日数据可能尚未发布")

        heat_df = self.calculate_heat_score(raw_data)
        if heat_df.empty:
            print("❌ 数据不足,无法计算热度")
            return None

        print(f"\n🏆 热度 TOP 10 (近{self.lookback_days}日):")
        print("-" * 80)
        for _, row in heat_df.head(10).iterrows():
            print(self._format_row(row))

        print(f"\n❄️ 冷门板块 (后10名):")
        print("-" * 80)
        for _, row in heat_df.tail(10).iterrows():
            print(self._format_row(row))

        if save:
            today = now.strftime('%Y-%m-%d')
            # A dated snapshot plus a fixed-name "newest" copy.
            dated_file = f"{self.data_dir}/heat_ranking_{today}_{self.lookback_days}d.csv"
            newest_file = f"{self.data_dir}/heat_ranking_newest_{self.lookback_days}d.csv"
            for csv_file in (dated_file, newest_file):
                heat_df.to_csv(csv_file, index=False, encoding='utf-8-sig')
                print(f"💾 CSV排名已保存: {csv_file}")

        if to_mysql and self.db_config:
            self.save_to_mysql(heat_df, calc_date)

        return heat_df


# Example usage
if __name__ == "__main__":
    # Credentials come from the environment instead of being committed to
    # source control. The password previously hard-coded here is exposed in
    # version history and should be rotated.
    db_password = os.environ.get('SW_DB_PASSWORD')
    if db_password:
        config = {
            'host': os.environ.get('SW_DB_HOST', '10.127.2.207'),
            'user': os.environ.get('SW_DB_USER', 'financial_prod'),
            'password': db_password,
            'database': os.environ.get('SW_DB_NAME', 'reference'),
            'port': int(os.environ.get('SW_DB_PORT', '3306')),
            'charset': 'utf8mb4',
        }
    else:
        print("⚠️ 未设置环境变量 SW_DB_PASSWORD,将跳过MySQL存储")
        config = None

    tracker = SWIndustryHeatTracker(lookback_days=5, db_config=config)

    check_date = datetime.now().strftime('%Y-%m-%d')
    if not tracker.is_trading_day(check_date):
        print(f"\n⚠️ {check_date} 不是交易日,跳过计算\n\n")
        sys.exit(0)

    # Generate the report and store it in MySQL when configured.
    result = tracker.generate_report(save=True, to_mysql=True)