#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Backfill tool for Shenwan (申万) industry heat data.

Reads locally exported heat-ranking CSV files and stores them into the
MySQL table ``sw_heat_daily``.  Intended for historical backfill or
re-import of already-generated daily rankings.
"""
import argparse
import os
import re

import pandas as pd


def extract_date_from_filename(filename):
    """Extract the calculation date embedded in a ranking file name.

    Example: ``heat_ranking_2026-03-04_5d.csv`` -> ``("20260304", "2026-03-04")``.

    :param filename: file name (or path) containing a YYYY-MM-DD date.
    :return: tuple ``(calc_date, date_str)`` where ``calc_date`` is
        ``YYYYMMDD`` and ``date_str`` is the original ``YYYY-MM-DD``.
    :raises ValueError: if no date pattern is found in *filename*.
    """
    # Match an ISO-style date: YYYY-MM-DD
    pattern = r'(\d{4}-\d{2}-\d{2})'
    match = re.search(pattern, filename)
    if match:
        date_str = match.group(1)                 # e.g. 2026-03-04
        calc_date = date_str.replace('-', '')     # e.g. 20260304
        return calc_date, date_str
    # FIX: original message printed the literal "(unknown)" instead of the
    # offending file name.
    raise ValueError(f"无法从文件名提取日期: {filename}")


def read_heat_csv(file_path):
    """Read a heat-ranking CSV into a DataFrame, validating its columns.

    :param file_path: path to the CSV file (UTF-8 with optional BOM).
    :return: ``pandas.DataFrame`` with at least the required heat columns.
    :raises FileNotFoundError: if *file_path* does not exist.
    :raises ValueError: if required columns are missing even after
        legacy column-name mapping.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"文件不存在: {file_path}")

    print(f"📂 读取文件: {file_path}")
    # utf-8-sig transparently strips a BOM written by Excel exports.
    df = pd.read_csv(file_path, encoding='utf-8-sig')

    # Columns that must be present (and are later inserted into MySQL).
    required_cols = ['code', 'name', 'total_return', 'avg_turnover',
                     'volume_ratio', 'momentum', 'total_return_score',
                     'avg_turnover_score', 'volume_ratio_score',
                     'momentum_score', 'heat_score']
    missing_cols = [col for col in required_cols if col not in df.columns]

    if missing_cols:
        # Best-effort compatibility with legacy (Chinese) column headers.
        col_mapping = {
            '行业代码': 'code',
            '行业名称': 'name',
            # Add further legacy mappings here as needed.
        }
        for old_col, new_col in col_mapping.items():
            if old_col in df.columns and new_col not in df.columns:
                df[new_col] = df[old_col]

        # Re-check after applying the mapping.
        still_missing = [col for col in required_cols if col not in df.columns]
        if still_missing:
            raise ValueError(f"CSV文件缺少必要列: {still_missing}")

    print(f"✅ 成功读取 {len(df)} 条行业数据")
    return df


def save_to_mysql(df, calc_date, db_config, dry_run=False):
    """Persist a heat DataFrame into the ``sw_heat_daily`` MySQL table.

    Existing rows for the same (calc_date, code) key are overwritten via
    ``ON DUPLICATE KEY UPDATE``.

    :param df: DataFrame containing the required heat columns.
    :param calc_date: target date string in ``YYYYMMDD`` format.
    :param db_config: keyword arguments forwarded to ``pymysql.connect``.
    :param dry_run: if True, only print a preview — nothing is written.
    :return: True on success (or dry run), False on database failure.
    """
    if dry_run:
        print("\n🔍 [DRY RUN] 预览模式,不实际写入数据库")
        print(f"📅 目标日期: {calc_date}")
        print("📊 数据预览:")
        print(df[['code', 'name', 'heat_score']].head())
        return True

    # Lazy import: lets --dry-run / tooling run without the MySQL driver
    # installed (the top-level import previously made even previews fail).
    import pymysql

    conn = None
    try:
        conn = pymysql.connect(**db_config)
        cursor = conn.cursor()

        # Check whether this date already has rows (informational only;
        # the upsert below handles the actual overwrite).
        cursor.execute(
            "SELECT COUNT(*) FROM sw_heat_daily WHERE calc_date = %s",
            (calc_date,)
        )
        existing = cursor.fetchone()[0]
        if existing > 0:
            print(f"⚠️ 日期 {calc_date} 已有 {existing} 条记录,将执行覆盖更新")

        insert_sql = """
            INSERT INTO sw_heat_daily
            (calc_date, code, name, total_return, avg_turnover,
             volume_ratio, momentum, total_return_score, avg_turnover_score,
             volume_ratio_score, momentum_score, heat_score)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
                name = VALUES(name),
                total_return = VALUES(total_return),
                avg_turnover = VALUES(avg_turnover),
                volume_ratio = VALUES(volume_ratio),
                momentum = VALUES(momentum),
                total_return_score = VALUES(total_return_score),
                avg_turnover_score = VALUES(avg_turnover_score),
                volume_ratio_score = VALUES(volume_ratio_score),
                momentum_score = VALUES(momentum_score),
                heat_score = VALUES(heat_score),
                created_at = CURRENT_TIMESTAMP;
        """

        # Coerce every value explicitly so numpy scalars don't leak into
        # the driver (pymysql expects native Python types).
        data_tuples = []
        for _, row in df.iterrows():
            data_tuples.append((
                calc_date,
                str(row['code']),
                str(row['name']),
                float(row['total_return']),
                float(row['avg_turnover']),
                float(row['volume_ratio']),
                float(row['momentum']),
                float(row['total_return_score']),
                float(row['avg_turnover_score']),
                float(row['volume_ratio_score']),
                float(row['momentum_score']),
                float(row['heat_score']),
            ))

        cursor.executemany(insert_sql, data_tuples)
        conn.commit()

        action = "覆盖更新" if existing > 0 else "新增"
        print(f"✅ MySQL{action}成功: {calc_date} 共 {len(data_tuples)} 条记录")
        return True
    except Exception as e:
        print(f"❌ MySQL操作失败: {e}")
        if conn:
            conn.rollback()
        return False
    finally:
        if conn:
            conn.close()


def batch_import(data_dir, db_config, pattern=None, dry_run=False):
    """Import every matching CSV file found in *data_dir*.

    :param data_dir: directory containing heat-ranking CSV files.
    :param db_config: keyword arguments forwarded to ``pymysql.connect``.
    :param pattern: regex that a file name must match; defaults to
        ``heat_ranking_YYYY-MM-DD_*.csv``.
    :param dry_run: forwarded to :func:`save_to_mysql`.
    """
    if pattern is None:
        pattern = r'heat_ranking_\d{4}-\d{2}-\d{2}_.*\.csv'

    # Sort so files are processed in chronological (file-name) order.
    files = [f for f in os.listdir(data_dir) if re.match(pattern, f)]
    files.sort()

    print(f"📁 发现 {len(files)} 个待导入文件")

    success_count = 0
    for filename in files:
        file_path = os.path.join(data_dir, filename)
        try:
            calc_date, _ = extract_date_from_filename(filename)
            df = read_heat_csv(file_path)
            if save_to_mysql(df, calc_date, db_config, dry_run):
                success_count += 1
        except Exception as e:
            # FIX: original message printed the literal "(unknown)"
            # instead of the failing file name.
            print(f"❌ 处理 {filename} 失败: {e}")
            continue

    print(f"\n📊 导入完成: {success_count}/{len(files)} 个文件成功")


def main():
    """CLI entry point: parse arguments and dispatch single/batch import."""
    # Database configuration.  Values can be overridden via environment
    # variables; the hard-coded fallbacks preserve previous behavior.
    # NOTE(security): credentials should not live in source control —
    # rotate and remove these fallbacks once env-var deployment is in place.
    config = {
        'host': os.environ.get('SW_DB_HOST', '10.127.2.207'),
        'user': os.environ.get('SW_DB_USER', 'financial_prod'),
        'password': os.environ.get('SW_DB_PASSWORD', 'mmTFncqmDal5HLRGY0BV'),
        'database': os.environ.get('SW_DB_NAME', 'reference'),
        'port': int(os.environ.get('SW_DB_PORT', '3306')),
        'charset': 'utf8mb4',
    }

    parser = argparse.ArgumentParser(description='申万行业热度数据回溯补录工具')
    parser.add_argument('file', nargs='?', help='单个CSV文件路径')
    parser.add_argument('--date', '-d', help='指定日期(YYYYMMDD),覆盖文件名中的日期')
    parser.add_argument('--batch', '-b', action='store_true', help='批量导入目录下所有文件')
    parser.add_argument('--dir', default='./sw_data', help='数据目录路径(默认: ./sw_data)')
    parser.add_argument('--dry-run', action='store_true', help='预览模式,不实际写入数据库')
    args = parser.parse_args()

    # Batch mode: import every matching file in the data directory.
    if args.batch:
        batch_import(args.dir, config, dry_run=args.dry_run)
        return

    # Single-file mode.
    if args.file:
        file_path = args.file
    else:
        # Fall back to a sample file for convenience.
        file_path = './sw_data/heat_ranking_2026-03-04_5d.csv'
        if not os.path.exists(file_path):
            print("请提供CSV文件路径,或使用 --batch 批量导入")
            parser.print_help()
            return

    try:
        df = read_heat_csv(file_path)

        # An explicit --date overrides the date embedded in the file name.
        if args.date:
            calc_date = args.date
            print(f"📅 使用指定日期: {calc_date}")
        else:
            calc_date, original_date = extract_date_from_filename(file_path)
            print(f"📅 从文件名提取日期: {original_date} -> {calc_date}")

        save_to_mysql(df, calc_date, config, dry_run=args.dry_run)
    except Exception as e:
        print(f"❌ 错误: {e}")
        raise


if __name__ == "__main__":
    main()