# zzck_code/backfill_mysql.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
申万行业热度数据回溯补录工具
读取本地CSV文件存入MySQL sw_heat_daily表
用于历史数据补录或重新导入
"""
import pandas as pd
import pymysql
import os
import re
from datetime import datetime
import argparse
def extract_date_from_filename(filename):
    """
    Extract the trade date embedded in a file name.

    e.g. heat_ranking_2026-03-04_5d.csv -> ('20260304', '2026-03-04')

    :param filename: file name (or path) containing a YYYY-MM-DD date
    :return: tuple (calc_date 'YYYYMMDD', date_str 'YYYY-MM-DD')
    :raises ValueError: if no ISO-style date is present in the name
    """
    # Match an ISO date of the form YYYY-MM-DD anywhere in the name.
    match = re.search(r'(\d{4}-\d{2}-\d{2})', filename)
    if not match:
        # BUG FIX: the original message printed a literal "(unknown)"
        # placeholder instead of interpolating the offending file name.
        raise ValueError(f"无法从文件名提取日期: {filename}")
    date_str = match.group(1)               # e.g. '2026-03-04'
    calc_date = date_str.replace('-', '')   # e.g. '20260304'
    return calc_date, date_str
def read_heat_csv(file_path):
    """
    Load a heat-ranking CSV file and validate that it carries every
    scoring column the database table expects.

    :param file_path: path to the CSV file (utf-8-sig encoded)
    :return: pandas DataFrame containing at least the required columns
    :raises FileNotFoundError: when the path does not exist
    :raises ValueError: when required columns are still missing after
        applying the legacy column-name mapping
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"文件不存在: {file_path}")
    print(f"📂 读取文件: {file_path}")
    df = pd.read_csv(file_path, encoding='utf-8-sig')

    # Columns every export must provide.
    required_cols = ['code', 'name', 'total_return', 'avg_turnover',
                     'volume_ratio', 'momentum', 'total_return_score',
                     'avg_turnover_score', 'volume_ratio_score',
                     'momentum_score', 'heat_score']

    if any(col not in df.columns for col in required_cols):
        # Older exports used Chinese headers; copy those onto the
        # canonical English names before giving up.
        legacy_headers = {
            '行业代码': 'code',
            '行业名称': 'name',
            # add further legacy mappings here if they turn up
        }
        for legacy, canonical in legacy_headers.items():
            if legacy in df.columns and canonical not in df.columns:
                df[canonical] = df[legacy]
        still_missing = [c for c in required_cols if c not in df.columns]
        if still_missing:
            raise ValueError(f"CSV文件缺少必要列: {still_missing}")
    print(f"✅ 成功读取 {len(df)} 条行业数据")
    return df
def save_to_mysql(df, calc_date, db_config, dry_run=False):
    """
    Persist a heat-ranking DataFrame into the sw_heat_daily table.

    Rows for an already-existing (calc_date, code) pair are overwritten
    via the ON DUPLICATE KEY UPDATE clause (presumably a unique key on
    calc_date + code exists — TODO confirm against the table schema).

    :param df: DataFrame carrying the required heat-score columns
    :param calc_date: trade date as a 'YYYYMMDD' string
    :param db_config: keyword arguments for pymysql.connect
    :param dry_run: if True, only print a preview and skip the database
    :return: True on success (or in dry-run mode), False on a DB error
    """
    if dry_run:
        print(f"\n🔍 [DRY RUN] 预览模式,不实际写入数据库")
        print(f"📅 目标日期: {calc_date}")
        print(f"📊 数据预览:")
        print(df[['code', 'name', 'heat_score']].head())
        return True
    conn = None
    try:
        conn = pymysql.connect(**db_config)
        # FIX: the original leaked the cursor; pymysql cursors are
        # context managers, so close it deterministically.
        with conn.cursor() as cursor:
            # Warn when this date already has rows — they will be upserted.
            cursor.execute(
                "SELECT COUNT(*) FROM sw_heat_daily WHERE calc_date = %s",
                (calc_date,)
            )
            existing = cursor.fetchone()[0]
            if existing > 0:
                print(f"⚠️ 日期 {calc_date} 已有 {existing} 条记录,将执行覆盖更新")
            insert_sql = """
            INSERT INTO sw_heat_daily
            (calc_date, code, name, total_return, avg_turnover, volume_ratio, momentum,
            total_return_score, avg_turnover_score, volume_ratio_score, momentum_score, heat_score)
            VALUES
            (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
            name = VALUES(name),
            total_return = VALUES(total_return),
            avg_turnover = VALUES(avg_turnover),
            volume_ratio = VALUES(volume_ratio),
            momentum = VALUES(momentum),
            total_return_score = VALUES(total_return_score),
            avg_turnover_score = VALUES(avg_turnover_score),
            volume_ratio_score = VALUES(volume_ratio_score),
            momentum_score = VALUES(momentum_score),
            heat_score = VALUES(heat_score),
            created_at = CURRENT_TIMESTAMP;
            """
            # Coerce each row into plain Python types for the driver.
            data_tuples = [
                (
                    calc_date,
                    str(row['code']),
                    str(row['name']),
                    float(row['total_return']),
                    float(row['avg_turnover']),
                    float(row['volume_ratio']),
                    float(row['momentum']),
                    float(row['total_return_score']),
                    float(row['avg_turnover_score']),
                    float(row['volume_ratio_score']),
                    float(row['momentum_score']),
                    float(row['heat_score']),
                )
                for _, row in df.iterrows()
            ]
            cursor.executemany(insert_sql, data_tuples)
        conn.commit()
        action = "覆盖更新" if existing > 0 else "新增"
        print(f"✅ MySQL{action}成功: {calc_date}{len(data_tuples)} 条记录")
        return True
    except Exception as e:
        print(f"❌ MySQL操作失败: {e}")
        if conn:
            conn.rollback()
        return False
    finally:
        if conn:
            conn.close()
def batch_import(data_dir, db_config, pattern=None, dry_run=False):
    """
    Import every matching heat-ranking CSV file found in *data_dir*.

    Files that fail to parse or store are skipped with an error message
    so one bad file cannot abort the whole batch.

    :param data_dir: directory containing heat_ranking_*.csv exports
    :param db_config: keyword arguments for pymysql.connect
    :param pattern: regex a file name must match; defaults to the
        standard heat_ranking_YYYY-MM-DD_*.csv naming scheme
    :param dry_run: forwarded to save_to_mysql (preview only)
    """
    if pattern is None:
        pattern = r'heat_ranking_\d{4}-\d{2}-\d{2}_.*\.csv'
    # Sort so files are imported in chronological (lexicographic) order.
    files = sorted(f for f in os.listdir(data_dir) if re.match(pattern, f))
    print(f"📁 发现 {len(files)} 个待导入文件")
    success_count = 0
    for filename in files:
        file_path = os.path.join(data_dir, filename)
        try:
            calc_date, _ = extract_date_from_filename(filename)
            df = read_heat_csv(file_path)
            if save_to_mysql(df, calc_date, db_config, dry_run):
                success_count += 1
        except Exception as e:
            # BUG FIX: interpolate the failing file name — the original
            # message printed a literal "(unknown)" placeholder.
            print(f"❌ 处理 {filename} 失败: {e}")
            continue
    print(f"\n📊 导入完成: {success_count}/{len(files)} 个文件成功")
def main():
    """
    CLI entry point: import one CSV (default) or a whole directory
    (--batch) of SW-industry heat rankings into MySQL.
    """
    # SECURITY NOTE(review): production credentials are hardcoded in
    # source. Move them to environment variables or a config file and
    # rotate this password — it is exposed to anyone with repo access.
    config = {
        'host': '10.127.2.207',
        'user': 'financial_prod',
        'password': 'mmTFncqmDal5HLRGY0BV',
        'database': 'reference',
        'port': 3306,
        'charset': 'utf8mb4'
    }

    parser = argparse.ArgumentParser(description='申万行业热度数据回溯补录工具')
    parser.add_argument('file', nargs='?', help='单个CSV文件路径')
    parser.add_argument('--date', '-d', help='指定日期(YYYYMMDD),覆盖文件名中的日期')
    parser.add_argument('--batch', '-b', action='store_true', help='批量导入目录下所有文件')
    parser.add_argument('--dir', default='./sw_data', help='数据目录路径(默认: ./sw_data)')
    parser.add_argument('--dry-run', action='store_true', help='预览模式,不实际写入数据库')
    args = parser.parse_args()

    # Batch mode: import every matching file under --dir and exit.
    if args.batch:
        batch_import(args.dir, config, dry_run=args.dry_run)
        return

    # Single-file mode; falls back to a sample path when no file given.
    if args.file:
        file_path = args.file
    else:
        file_path = './sw_data/heat_ranking_2026-03-04_5d.csv'
        if not os.path.exists(file_path):
            print("请提供CSV文件路径或使用 --batch 批量导入")
            parser.print_help()
            return

    try:
        df = read_heat_csv(file_path)
        # An explicit --date overrides whatever the file name says.
        if args.date:
            calc_date = args.date
            print(f"📅 使用指定日期: {calc_date}")
        else:
            calc_date, original_date = extract_date_from_filename(file_path)
            print(f"📅 从文件名提取日期: {original_date} -> {calc_date}")
        save_to_mysql(df, calc_date, config, dry_run=args.dry_run)
    except Exception as e:
        print(f"❌ 错误: {e}")
        raise


if __name__ == "__main__":
    main()