zzck_code/industry_heat_task.py

405 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import akshare as ak
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import os
import time
import pymysql
import warnings
warnings.filterwarnings('ignore')
class SWIndustryHeatTracker:
    """
    Heat tracker for Shenwan (SW) level-1 industries.

    Data comes from akshare's ``index_analysis_daily_sw`` endpoint (official
    post-close figures). The look-back window and factor weights are
    configurable, and results can be written to CSV and upserted into MySQL.
    """

    # Columns of the DataFrame returned by calculate_heat_score; also used to
    # build an empty result when no industry has a full window.
    _RESULT_COLUMNS = [
        'code', 'name', 'total_return', 'avg_turnover', 'volume_ratio',
        'momentum', 'latest_close', 'latest_return', 'up_days',
        'total_return_score', 'avg_turnover_score', 'volume_ratio_score',
        'momentum_score', 'heat_score', 'heat_level', 'rank',
    ]

    def __init__(self, data_dir="./sw_data", lookback_days=5, db_config=None):
        """
        :param data_dir: directory holding the trading-calendar cache and the
            CSV reports; created if it does not exist.
        :param lookback_days: number of most recent trading days scored.
        :param db_config: keyword arguments passed to ``pymysql.connect``;
            when falsy, MySQL persistence is disabled.
        :raises Exception: re-raised from ``_init_database`` when the table
            cannot be created, or from akshare when no calendar can be loaded.
        """
        self.data_dir = data_dir
        self.lookback_days = lookback_days
        self.db_config = db_config  # pymysql.connect kwargs, or None
        os.makedirs(data_dir, exist_ok=True)
        # Set of trading-day strings, cached on disk.
        self.trading_dates = self._load_trading_calendar()
        # SW level-1 industry code -> name (31 industries). Not read by the
        # methods of this class; kept as a lookup table for callers.
        self.industry_map = {
            '801010': '农林牧渔', '801030': '基础化工', '801040': '钢铁',
            '801050': '有色金属', '801080': '电子', '801110': '家用电器',
            '801120': '食品饮料', '801130': '纺织服饰', '801140': '轻工制造',
            '801150': '医药生物', '801160': '公用事业', '801170': '交通运输',
            '801180': '房地产', '801200': '商贸零售', '801210': '社会服务',
            '801230': '综合', '801710': '建筑材料', '801720': '建筑装饰',
            '801730': '电力设备', '801740': '国防军工', '801750': '计算机',
            '801760': '传媒', '801770': '通信', '801780': '银行',
            '801790': '非银金融', '801880': '汽车', '801890': '机械设备',
            '801950': '煤炭', '801960': '石油石化', '801970': '环保',
            '801980': '美容护理'
        }
        # Factor weights of the composite heat score. They sum to 1.0, so the
        # score lies in (0, 100] after scaling percentile ranks by 100.
        self.weights = {
            'return': 0.35,
            'turnover': 0.25,
            'volume_ratio': 0.25,
            'momentum': 0.15
        }
        if self.db_config:
            self._init_database()

    def _init_database(self):
        """Create the ``sw_heat_daily`` table if it does not already exist.

        :raises Exception: any connection or DDL error is printed and re-raised.
        """
        conn = None
        try:
            conn = pymysql.connect(**self.db_config)
            with conn.cursor() as cursor:
                create_table_sql = """
                CREATE TABLE IF NOT EXISTS sw_heat_daily (
                    id INT AUTO_INCREMENT PRIMARY KEY,
                    calc_date VARCHAR(8) NOT NULL COMMENT '计算日期(YYYYMMDD)',
                    code VARCHAR(10) NOT NULL COMMENT '行业代码',
                    name VARCHAR(20) NOT NULL COMMENT '行业名称',
                    total_return DECIMAL(10,4) COMMENT '累计涨跌幅(%)',
                    avg_turnover DECIMAL(10,4) COMMENT '平均换手率(%)',
                    volume_ratio DECIMAL(10,4) COMMENT '量比',
                    momentum DECIMAL(5,4) COMMENT '动量(上涨天数占比)',
                    total_return_score DECIMAL(5,4) COMMENT '收益因子得分(0-1)',
                    avg_turnover_score DECIMAL(5,4) COMMENT '换手因子得分(0-1)',
                    volume_ratio_score DECIMAL(5,4) COMMENT '量比因子得分(0-1)',
                    momentum_score DECIMAL(5,4) COMMENT '动量因子得分(0-1)',
                    heat_score DECIMAL(6,2) COMMENT '热度综合得分(0-100)',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    UNIQUE KEY uk_date_code (calc_date, code),
                    INDEX idx_date (calc_date),
                    INDEX idx_heat_score (calc_date, heat_score)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='申万一级行业热度日表';
                """
                cursor.execute(create_table_sql)
            conn.commit()
            print("✅ 数据库表 sw_heat_daily 初始化成功")
        except Exception as e:
            print(f"❌ 数据库初始化失败: {e}")
            raise
        finally:
            if conn:
                conn.close()

    def _load_trading_calendar(self):
        """Return the trading calendar as a set of date strings.

        A CSV cache in ``data_dir`` is used when present and at most 7 days
        old; otherwise the calendar is downloaded again. If refreshing a stale
        cache fails, the stale cache is used rather than aborting.
        """
        cache_file = os.path.join(self.data_dir, "trading_calendar.csv")
        if not os.path.exists(cache_file):
            return set(self._update_calendar_cache(cache_file))
        cache_time = datetime.fromtimestamp(os.path.getmtime(cache_file))
        if (datetime.now() - cache_time).days > 7:
            try:
                # Return the freshly downloaded calendar. The previous code
                # refreshed the file but kept returning the stale contents.
                return set(self._update_calendar_cache(cache_file))
            except Exception as e:
                print(f"⚠️ 交易日历更新失败, 使用旧缓存: {e}")
        df = pd.read_csv(cache_file)
        return set(df['trade_date'].astype(str))

    def _update_calendar_cache(self, cache_file):
        """Download the trading calendar from akshare, write it to ``cache_file``
        and return the dates as a list of strings."""
        print("正在更新交易日历...")
        df = ak.tool_trade_date_hist_sina()
        df.to_csv(cache_file, index=False)
        return df['trade_date'].astype(str).tolist()

    def get_recent_trading_dates(self, n_days, end_date=None):
        """Return the last ``n_days`` trading days on or before ``end_date``, oldest first.

        :param n_days: number of trading days wanted; values <= 0 yield [].
        :param end_date: inclusive upper bound as 'YYYY-MM-DD'; defaults to today.
        """
        if n_days <= 0:
            # seq[-0:] would return the whole calendar instead of nothing.
            return []
        if end_date is None:
            end_date = datetime.now().strftime('%Y-%m-%d')
        # Plain string comparison; this assumes the calendar strings are ISO
        # 'YYYY-MM-DD', which sort chronologically.
        valid_dates = sorted(d for d in self.trading_dates if d <= end_date)
        return valid_dates[-n_days:]

    def fetch_daily_data(self, trade_date):
        """Fetch the SW level-1 industry daily snapshot for one trading day.

        :param trade_date: date as 'YYYY-MM-DD' (dashes are stripped for the API).
        :return: DataFrame with English column names plus ``trade_date``, or
            None when the API returns nothing or raises.
        """
        date_str = trade_date.replace('-', '')
        try:
            df = ak.index_analysis_daily_sw(
                symbol="一级行业",
                start_date=date_str,
                end_date=date_str
            )
            if df is None or df.empty:
                print(f"⚠️ {trade_date} 无数据")
                return None
            df = df.rename(columns={
                '指数代码': 'code',
                '指数名称': 'name',
                '发布日期': 'date',
                '收盘指数': 'close',
                '涨跌幅': 'return',
                '换手率': 'turnover',
                '成交量': 'volume',
                '成交额': 'amount',
                '市盈率': 'pe',
                '市净率': 'pb'
            })
            df['trade_date'] = trade_date
            df['code'] = df['code'].astype(str)
            return df
        except Exception as e:
            print(f"❌ 获取 {trade_date} 数据失败: {e}")
            return None

    def fetch_multi_days_data(self, n_days=None):
        """Fetch and concatenate daily data for the most recent ``n_days`` trading days.

        :param n_days: number of trading days; by default enough sessions for
            both the scoring window and the preceding window that the volume
            ratio compares against.
        :return: DataFrame sorted by (code, trade_date) with ``trade_date`` as
            datetime64, or None when no trading day or no data is available.
        """
        if n_days is None:
            # calculate_heat_score needs 2 * lookback_days sessions for the
            # volume ratio; the old "+5" default fell short for windows > 5.
            n_days = max(self.lookback_days * 2, self.lookback_days + 5)
        dates = self.get_recent_trading_dates(n_days)
        if not dates:
            print("❌ 交易日历中无可用交易日")
            return None
        print(f"📅 获取数据区间: {dates[0]} ~ {dates[-1]}")
        all_data = []
        for date in dates:
            df = self.fetch_daily_data(date)
            if df is not None:
                all_data.append(df)
            time.sleep(0.5)  # pause between API calls (presumably rate limiting)
        if not all_data:
            return None
        combined = pd.concat(all_data, ignore_index=True)
        combined['trade_date'] = pd.to_datetime(combined['trade_date'])
        return combined.sort_values(['code', 'trade_date'])

    def calculate_heat_score(self, df, lookback_days=None):
        """
        Compute the composite heat score for every industry.

        For each industry that has a full ``lookback_days`` of rows in the
        latest window, four raw factors are derived:

        * ``total_return`` - sum of daily ``return`` over the window;
        * ``avg_turnover`` - mean daily ``turnover``;
        * ``volume_ratio`` - mean ``volume`` in the window divided by the mean
          over the preceding window of equal length (1.0 when that history is
          unavailable or not positive);
        * ``momentum`` - fraction of window days with ``return > 0``.

        Each factor becomes a cross-sectional percentile rank in (0, 1]. The
        ranks are combined with ``self.weights``, scaled to 0-100 and bucketed
        into five heat levels.

        :param df: long-format daily data with at least ``code``, ``name``,
            ``trade_date``, ``return``, ``turnover``, ``volume`` and ``close``.
        :param lookback_days: window length; defaults to ``self.lookback_days``.
        :return: DataFrame sorted by ``heat_score`` descending. It is empty
            (with the result columns) when no industry has a full window.
        """
        if lookback_days is None:
            lookback_days = self.lookback_days
        # Sort explicitly. unique() keeps first-appearance order, and the input
        # is ordered by code before date, so the raw order is not chronological
        # when the first industry is missing an early date.
        all_dates = df['trade_date'].drop_duplicates().sort_values().to_numpy()
        recent_df = df[df['trade_date'].isin(all_dates[-lookback_days:])]
        if len(all_dates) >= lookback_days * 2:
            hist_dates = all_dates[-lookback_days * 2:-lookback_days]
            hist_df = df[df['trade_date'].isin(hist_dates)]
        else:
            hist_df = None

        heat_data = []
        for code in recent_df['code'].unique():
            industry_data = recent_df[recent_df['code'] == code].sort_values('trade_date')
            # Require a full window; industries with missing days are skipped.
            if len(industry_data) < lookback_days:
                continue
            recent_volume = industry_data['volume'].mean()
            volume_ratio = 1.0
            if hist_df is not None:
                hist_volume = hist_df.loc[hist_df['code'] == code, 'volume'].mean()
                # NaN (no history for this code) compares False here as well.
                if hist_volume > 0:
                    volume_ratio = recent_volume / hist_volume
            up_days = (industry_data['return'] > 0).sum()
            heat_data.append({
                'code': code,
                'name': industry_data['name'].iloc[0],
                'total_return': industry_data['return'].sum(),
                'avg_turnover': industry_data['turnover'].mean(),
                'volume_ratio': volume_ratio,
                'momentum': up_days / lookback_days,
                'latest_close': industry_data['close'].iloc[-1],
                'latest_return': industry_data['return'].iloc[-1],
                'up_days': up_days
            })

        if not heat_data:
            return pd.DataFrame(columns=self._RESULT_COLUMNS)

        heat_df = pd.DataFrame(heat_data)
        for col in ['total_return', 'avg_turnover', 'volume_ratio', 'momentum']:
            heat_df[f'{col}_score'] = heat_df[col].rank(pct=True)
        heat_df['heat_score'] = (
            heat_df['total_return_score'] * self.weights['return'] +
            heat_df['avg_turnover_score'] * self.weights['turnover'] +
            heat_df['volume_ratio_score'] * self.weights['volume_ratio'] +
            heat_df['momentum_score'] * self.weights['momentum']
        ) * 100
        heat_df['heat_level'] = pd.cut(
            heat_df['heat_score'],
            bins=[0, 20, 40, 60, 80, 100],
            labels=['极冷', '偏冷', '温和', '偏热', '极热']
        )
        heat_df['rank'] = heat_df['heat_score'].rank(ascending=False, method='min').astype(int)
        return heat_df.sort_values('heat_score', ascending=False)

    @staticmethod
    def _to_db_float(value):
        """Convert a numeric cell to float, mapping None/NaN to None (SQL NULL)."""
        if value is None or pd.isna(value):
            return None
        return float(value)

    def save_to_mysql(self, heat_df, calc_date):
        """
        Upsert heat rows into ``sw_heat_daily``.

        Rows are keyed by (calc_date, code). Re-running for the same date
        overwrites the values and refreshes ``created_at``.

        :param heat_df: DataFrame returned by ``calculate_heat_score``.
        :param calc_date: calculation date, format 'YYYYMMDD'.
        :return: True on success. False when no DB is configured or on
            failure; on failure the transaction is rolled back.
        """
        if not self.db_config:
            print("⚠️ 未配置数据库跳过MySQL存储")
            return False
        insert_sql = """
        INSERT INTO sw_heat_daily
        (calc_date, code, name, total_return, avg_turnover, volume_ratio, momentum,
         total_return_score, avg_turnover_score, volume_ratio_score, momentum_score, heat_score)
        VALUES
        (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        ON DUPLICATE KEY UPDATE
        name = VALUES(name),
        total_return = VALUES(total_return),
        avg_turnover = VALUES(avg_turnover),
        volume_ratio = VALUES(volume_ratio),
        momentum = VALUES(momentum),
        total_return_score = VALUES(total_return_score),
        avg_turnover_score = VALUES(avg_turnover_score),
        volume_ratio_score = VALUES(volume_ratio_score),
        momentum_score = VALUES(momentum_score),
        heat_score = VALUES(heat_score),
        created_at = CURRENT_TIMESTAMP;
        """
        numeric_cols = [
            'total_return', 'avg_turnover', 'volume_ratio', 'momentum',
            'total_return_score', 'avg_turnover_score', 'volume_ratio_score',
            'momentum_score', 'heat_score',
        ]
        conn = None
        try:
            # NaN would be rendered as the bare token `nan` and rejected by
            # MySQL, so missing values are sent as NULL instead.
            data_tuples = [
                (calc_date, row['code'], row['name'])
                + tuple(self._to_db_float(row[col]) for col in numeric_cols)
                for _, row in heat_df.iterrows()
            ]
            conn = pymysql.connect(**self.db_config)
            with conn.cursor() as cursor:
                cursor.executemany(insert_sql, data_tuples)
            conn.commit()
            print(f"✅ MySQL存储成功: {calc_date}, {len(data_tuples)} 条记录")
            return True
        except Exception as e:
            print(f"❌ MySQL存储失败: {e}")
            if conn:
                conn.rollback()
            return False
        finally:
            if conn:
                conn.close()

    def _print_top10(self, heat_df):
        """Print the ten hottest industries, one formatted line each."""
        print(f"\n🏆 热度 TOP 10 (近{self.lookback_days}日):")
        print("-" * 80)
        for _, row in heat_df.head(10).iterrows():
            print(f"{row['rank']:2d}. {row['name']:8s} | "
                  f"热度:{row['heat_score']:5.1f}({row['heat_level']}) | "
                  f"收益:{row['total_return']:+6.2f}% | "
                  f"换手:{row['avg_turnover']:5.2f}% | "
                  f"量比:{row['volume_ratio']:4.2f} | "
                  f"动量:{row['momentum']:.0%}")

    def _save_csv(self, heat_df):
        """Write the ranking to a dated CSV and to a fixed "newest" CSV."""
        today = datetime.now().strftime('%Y-%m-%d')
        suffix = f"{self.lookback_days}d.csv"
        for csv_file in (
            os.path.join(self.data_dir, f"heat_ranking_{today}_{suffix}"),
            os.path.join(self.data_dir, f"heat_ranking_newest_{suffix}"),
        ):
            # utf-8-sig writes a BOM so Excel opens the Chinese names correctly.
            heat_df.to_csv(csv_file, index=False, encoding='utf-8-sig')
            print(f"💾 CSV排名已保存: {csv_file}")

    def generate_report(self, save=True, to_mysql=True):
        """Fetch data, score industries, print the top 10 and persist the ranking.

        :param save: write the ranking to the dated and "newest" CSV files.
        :param to_mysql: upsert the ranking into MySQL. This is independent of
            ``save`` and has no effect when ``db_config`` is not set.
        :return: the heat DataFrame, or None when no data or no score is available.
        """
        # NOTE(review): calc_date is the wall-clock run date, not the latest
        # trade date in the data. A run on a non-trading day stores the last
        # session's scores under that day's date. Confirm this is intended.
        calc_date = datetime.now().strftime('%Y%m%d')
        print("=" * 80)
        print(f"🔥 申万一级行业热度报告 ({self.lookback_days}日复合指标)")
        print(f"📅 计算日期: {calc_date}")
        print(f"⏰ 生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print("=" * 80)

        raw_data = self.fetch_multi_days_data()
        if raw_data is None:
            print("❌ 数据获取失败")
            return None
        heat_df = self.calculate_heat_score(raw_data)
        if heat_df.empty:
            print("❌ 无足够数据计算热度")
            return None

        self._print_top10(heat_df)
        if save:
            self._save_csv(heat_df)
        if to_mysql and self.db_config:
            self.save_to_mysql(heat_df, calc_date)
        return heat_df
# Usage example
def load_db_config_from_env(environ=None):
    """Build ``pymysql.connect`` keyword arguments from environment variables.

    Reads SW_HEAT_DB_HOST, SW_HEAT_DB_USER, SW_HEAT_DB_PASSWORD, SW_HEAT_DB_NAME
    and the optional SW_HEAT_DB_PORT (default 3306). Credentials are kept out
    of source control this way.

    :param environ: mapping to read from; defaults to ``os.environ``.
    :return: config dict, or None when any required variable is missing or empty.
    """
    env = os.environ if environ is None else environ
    required = {
        'host': 'SW_HEAT_DB_HOST',
        'user': 'SW_HEAT_DB_USER',
        'password': 'SW_HEAT_DB_PASSWORD',
        'database': 'SW_HEAT_DB_NAME',
    }
    config = {field: env.get(var) for field, var in required.items()}
    if not all(config.values()):
        return None
    config['port'] = int(env.get('SW_HEAT_DB_PORT', 3306))
    config['charset'] = 'utf8mb4'
    return config


if __name__ == "__main__":
    # Credentials come from the environment and must never be committed.
    # NOTE(review): the password that was previously hard-coded here remains
    # in version-control history and should be rotated.
    config = load_db_config_from_env()
    if config is None:
        print("⚠️ 未设置 SW_HEAT_DB_* 环境变量, 跳过MySQL存储")
    # Create the tracker; with config None, MySQL storage is disabled.
    tracker = SWIndustryHeatTracker(
        lookback_days=5,
        db_config=config
    )
    # Generate the report and store it (CSV, plus MySQL when configured).
    result = tracker.generate_report(save=True, to_mysql=True)