# zzck_code/backfill_mysql.py
# NOTE(review): the original export carried web-viewer chrome here
# ("Raw / Blame / History", line count, a Unicode-confusables warning);
# it was not part of the program and has been reduced to this comment.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
申万行业热度数据回溯补录工具
读取本地CSV文件存入MySQL sw_heat_daily表
用于历史数据补录或重新导入
"""
import pandas as pd
import pymysql
import os
import re
from datetime import datetime
import argparse
def extract_date_from_filename(filename):
    """Extract the calculation date embedded in a file name.

    Example: heat_ranking_2026-03-04_5d.csv -> ('20260304', '2026-03-04')

    :param filename: file name (or path) containing a YYYY-MM-DD date
    :return: tuple of (calc_date 'YYYYMMDD', date_str 'YYYY-MM-DD')
    :raises ValueError: if no YYYY-MM-DD date is found in the name
    """
    # Match the first YYYY-MM-DD occurrence anywhere in the name.
    match = re.search(r'(\d{4}-\d{2}-\d{2})', filename)
    if not match:
        # Bug fix: the original message contained a literal "(unknown)"
        # where the offending file name should have been interpolated.
        raise ValueError(f"无法从文件名提取日期: {filename}")
    date_str = match.group(1)              # e.g. '2026-03-04'
    calc_date = date_str.replace('-', '')  # e.g. '20260304'
    return calc_date, date_str
def read_heat_csv(file_path):
    """Load a heat-ranking CSV into a DataFrame and validate its columns.

    Legacy exports with Chinese headers ('行业代码'/'行业名称') are aliased
    to the modern column names before validation.

    :param file_path: path to a CSV file encoded as utf-8-sig
    :return: pandas.DataFrame containing all required heat columns
    :raises FileNotFoundError: if file_path does not exist
    :raises ValueError: if required columns are missing after aliasing
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"文件不存在: {file_path}")
    print(f"📂 读取文件: {file_path}")
    df = pd.read_csv(file_path, encoding='utf-8-sig')

    required_cols = ['code', 'name', 'total_return', 'avg_turnover',
                     'volume_ratio', 'momentum', 'total_return_score',
                     'avg_turnover_score', 'volume_ratio_score',
                     'momentum_score', 'heat_score']

    if any(col not in df.columns for col in required_cols):
        # Older files used Chinese headers; copy them under the new names.
        legacy_aliases = {
            '行业代码': 'code',
            '行业名称': 'name',
        }
        for legacy, modern in legacy_aliases.items():
            if legacy in df.columns and modern not in df.columns:
                df[modern] = df[legacy]
        # Re-validate after aliasing; anything still absent is fatal.
        still_missing = [col for col in required_cols if col not in df.columns]
        if still_missing:
            raise ValueError(f"CSV文件缺少必要列: {still_missing}")

    print(f"✅ 成功读取 {len(df)} 条行业数据")
    return df
def save_to_mysql(df, calc_date, db_config, dry_run=False):
    """Write one day's heat ranking into the MySQL sw_heat_daily table.

    Rows already present for the same key are overwritten via
    ON DUPLICATE KEY UPDATE (assumes a unique key covering
    calc_date + code — TODO confirm against the table DDL).

    :param df: DataFrame containing the required heat columns
    :param calc_date: YYYYMMDD date string stored with every row
    :param db_config: keyword arguments forwarded to pymysql.connect
    :param dry_run: if True, print a preview and perform no database work
    :return: True on success (including dry runs), False on DB failure
    """
    if dry_run:
        print(f"\n🔍 [DRY RUN] 预览模式,不实际写入数据库")
        print(f"📅 目标日期: {calc_date}")
        print(f"📊 数据预览:")
        print(df[['code', 'name', 'heat_score']].head())
        return True

    conn = None
    try:
        conn = pymysql.connect(**db_config)
        cursor = conn.cursor()

        # Warn when the date already holds rows; the upsert below will
        # overwrite them rather than create duplicates.
        cursor.execute(
            "SELECT COUNT(*) FROM sw_heat_daily WHERE calc_date = %s",
            (calc_date,)
        )
        existing = cursor.fetchone()[0]
        if existing > 0:
            print(f"⚠️ 日期 {calc_date} 已有 {existing} 条记录,将执行覆盖更新")

        insert_sql = """
        INSERT INTO sw_heat_daily
        (calc_date, code, name, total_return, avg_turnover, volume_ratio, momentum,
         total_return_score, avg_turnover_score, volume_ratio_score, momentum_score, heat_score)
        VALUES
        (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        ON DUPLICATE KEY UPDATE
        name = VALUES(name),
        total_return = VALUES(total_return),
        avg_turnover = VALUES(avg_turnover),
        volume_ratio = VALUES(volume_ratio),
        momentum = VALUES(momentum),
        total_return_score = VALUES(total_return_score),
        avg_turnover_score = VALUES(avg_turnover_score),
        volume_ratio_score = VALUES(volume_ratio_score),
        momentum_score = VALUES(momentum_score),
        heat_score = VALUES(heat_score),
        created_at = CURRENT_TIMESTAMP;
        """

        # Coerce every cell to a plain Python type so pymysql can bind it
        # (pandas yields numpy scalars, which drivers don't always accept).
        data_tuples = [
            (
                calc_date,
                str(row['code']),
                str(row['name']),
                float(row['total_return']),
                float(row['avg_turnover']),
                float(row['volume_ratio']),
                float(row['momentum']),
                float(row['total_return_score']),
                float(row['avg_turnover_score']),
                float(row['volume_ratio_score']),
                float(row['momentum_score']),
                float(row['heat_score']),
            )
            for _, row in df.iterrows()
        ]

        cursor.executemany(insert_sql, data_tuples)
        conn.commit()
        action = "覆盖更新" if existing > 0 else "新增"
        # Bug fix: the date and the record count were concatenated with no
        # separator ("20260304100 条记录"); add one for a readable message.
        print(f"✅ MySQL{action}成功: {calc_date} 共 {len(data_tuples)} 条记录")
        return True
    except Exception as e:
        print(f"❌ MySQL操作失败: {e}")
        if conn:
            conn.rollback()
        return False
    finally:
        if conn:
            conn.close()
def batch_import(data_dir, db_config, pattern=None, dry_run=False):
    """Import every heat-ranking CSV in a directory, in file-name order.

    Per-file failures are reported and skipped so one bad file does not
    abort the whole batch.

    :param data_dir: directory to scan for CSV files
    :param db_config: keyword arguments forwarded to pymysql.connect
    :param pattern: regex the file name must match
                    (default: heat_ranking_YYYY-MM-DD_*.csv)
    :param dry_run: forwarded to save_to_mysql; preview only when True
    """
    if pattern is None:
        pattern = r'heat_ranking_\d{4}-\d{2}-\d{2}_.*\.csv'
    files = sorted(f for f in os.listdir(data_dir) if re.match(pattern, f))
    print(f"📁 发现 {len(files)} 个待导入文件")

    success_count = 0
    for filename in files:
        file_path = os.path.join(data_dir, filename)
        try:
            calc_date, _ = extract_date_from_filename(filename)
            df = read_heat_csv(file_path)
            if save_to_mysql(df, calc_date, db_config, dry_run):
                success_count += 1
        except Exception as e:
            # Bug fix: the original message contained a literal "(unknown)"
            # where the failing file name should have been interpolated.
            print(f"❌ 处理 {filename} 失败: {e}")
            continue
    print(f"\n📊 导入完成: {success_count}/{len(files)} 个文件成功")
def main():
    """CLI entry point: import a single CSV or batch-import a directory."""
    # SECURITY NOTE(review): production credentials were hard-coded here.
    # Environment variables now take precedence so secrets can be moved out
    # of the source; the inline defaults are kept for backward compatibility
    # but should be rotated and removed.
    config = {
        'host': os.environ.get('SW_DB_HOST', '10.127.2.207'),
        'user': os.environ.get('SW_DB_USER', 'financial_prod'),
        'password': os.environ.get('SW_DB_PASSWORD', 'mmTFncqmDal5HLRGY0BV'),
        'database': os.environ.get('SW_DB_NAME', 'reference'),
        'port': int(os.environ.get('SW_DB_PORT', '3306')),
        'charset': 'utf8mb4'
    }

    parser = argparse.ArgumentParser(description='申万行业热度数据回溯补录工具')
    parser.add_argument('file', nargs='?', help='单个CSV文件路径')
    parser.add_argument('--date', '-d', help='指定日期(YYYYMMDD),覆盖文件名中的日期')
    parser.add_argument('--batch', '-b', action='store_true', help='批量导入目录下所有文件')
    parser.add_argument('--dir', default='./sw_data', help='数据目录路径(默认: ./sw_data)')
    parser.add_argument('--dry-run', action='store_true', help='预览模式,不实际写入数据库')
    args = parser.parse_args()

    # Batch mode imports everything under --dir and exits.
    if args.batch:
        batch_import(args.dir, config, dry_run=args.dry_run)
        return

    # Single-file mode; fall back to the sample file when no path is given.
    if args.file:
        file_path = args.file
    else:
        file_path = './sw_data/heat_ranking_2026-03-04_5d.csv'
    if not os.path.exists(file_path):
        print("请提供CSV文件路径或使用 --batch 批量导入")
        parser.print_help()
        return

    try:
        df = read_heat_csv(file_path)
        # An explicit --date wins over the date embedded in the file name.
        if args.date:
            calc_date = args.date
            print(f"📅 使用指定日期: {calc_date}")
        else:
            calc_date, original_date = extract_date_from_filename(file_path)
            print(f"📅 从文件名提取日期: {original_date} -> {calc_date}")
        save_to_mysql(df, calc_date, config, dry_run=args.dry_run)
    except Exception as e:
        print(f"❌ 错误: {e}")
        raise


if __name__ == "__main__":
    main()