244 lines
8.2 KiB
Python
244 lines
8.2 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
"""
|
|||
|
|
申万行业热度数据回溯补录工具
|
|||
|
|
读取本地CSV文件,存入MySQL sw_heat_daily表
|
|||
|
|
用于历史数据补录或重新导入
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import pandas as pd
|
|||
|
|
import pymysql
|
|||
|
|
import os
|
|||
|
|
import re
|
|||
|
|
from datetime import datetime
|
|||
|
|
import argparse
|
|||
|
|
|
|||
|
|
|
|||
|
|
def extract_date_from_filename(filename):
    """
    Extract the calculation date embedded in a heat-ranking CSV file name.

    Only the final path component is inspected, so a date that appears in a
    parent directory name is never mistaken for the file's own date.

    Example: heat_ranking_2026-03-04_5d.csv -> ("20260304", "2026-03-04")

    :param filename: file name or full path containing a YYYY-MM-DD date
    :return: tuple ``(calc_date, date_str)``; ``calc_date`` is the date as
        YYYYMMDD and ``date_str`` is the matched YYYY-MM-DD text
    :raises ValueError: if the file name contains no YYYY-MM-DD text, or the
        matched digits are not a real calendar date
    """
    # Callers pass full paths; restrict the search to the file name itself so
    # a dated directory (e.g. backup/2025-01-01/...) cannot shadow it.
    base_name = os.path.basename(filename)
    match = re.search(r'(\d{4}-\d{2}-\d{2})', base_name)
    if match is None:
        raise ValueError(f"无法从文件名提取日期: {filename}")

    date_str = match.group(1)  # e.g. 2026-03-04
    try:
        # Reject digit runs that merely look like a date (e.g. 2026-13-45).
        datetime.strptime(date_str, '%Y-%m-%d')
    except ValueError as exc:
        raise ValueError(f"无法从文件名提取日期: {filename}") from exc

    # Drop the dashes to obtain the YYYYMMDD form, e.g. 20260304.
    return date_str.replace('-', ''), date_str
|
|||
|
|
|
|||
|
|
|
|||
|
|
def read_heat_csv(file_path):
    """
    Load a heat-ranking CSV file into a DataFrame and verify its columns.

    The file is read as UTF-8 with an optional BOM. If any required column is
    absent, legacy Chinese headers are copied to their current names (the
    legacy columns are kept) before the check is repeated.

    :param file_path: path of the CSV file to read
    :return: the loaded DataFrame, containing every required column
    :raises FileNotFoundError: if ``file_path`` does not exist
    :raises ValueError: if required columns are still missing after the
        legacy-header fallback
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"文件不存在: {file_path}")

    print(f"📂 读取文件: {file_path}")
    frame = pd.read_csv(file_path, encoding='utf-8-sig')

    # Columns every heat-ranking file must provide.
    expected = (
        'code', 'name', 'total_return', 'avg_turnover',
        'volume_ratio', 'momentum', 'total_return_score',
        'avg_turnover_score', 'volume_ratio_score',
        'momentum_score', 'heat_score',
    )
    # Headers used by older exports, mapped to their current names.
    legacy_headers = {
        '行业代码': 'code',
        '行业名称': 'name',
    }

    def absent_columns():
        # Recomputed on demand because the frame may gain columns below.
        return [column for column in expected if column not in frame.columns]

    if absent_columns():
        # Fall back to legacy headers only when something is actually missing.
        for legacy, current in legacy_headers.items():
            if legacy in frame.columns and current not in frame.columns:
                frame[current] = frame[legacy]

        unresolved = absent_columns()
        if unresolved:
            raise ValueError(f"CSV文件缺少必要列: {unresolved}")

    print(f"✅ 成功读取 {len(frame)} 条行业数据")
    return frame
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _to_db_float(value):
    """Convert a CSV cell to ``float``, mapping missing values (NaN/None) to SQL NULL."""
    return None if pd.isna(value) else float(value)


def save_to_mysql(df, calc_date, db_config, dry_run=False):
    """
    Upsert one day's heat-ranking rows into the ``sw_heat_daily`` table.

    Rows are written with ``INSERT ... ON DUPLICATE KEY UPDATE``, so a row
    that collides with an existing key is overwritten and its ``created_at``
    is reset to the current timestamp. Missing numeric values (NaN) are
    stored as NULL. Everything runs in a single transaction that is rolled
    back on any error.

    :param df: DataFrame holding the required heat-ranking columns
    :param calc_date: date string in YYYYMMDD format, stored with every row
    :param db_config: keyword arguments forwarded to ``pymysql.connect``
    :param dry_run: if True, only print a preview of the data and do not
        touch the database
    :return: True on success (or dry run); False if any step failed, in which
        case the error is printed rather than raised
    """
    if dry_run:
        print(f"\n🔍 [DRY RUN] 预览模式,不实际写入数据库")
        print(f"📅 目标日期: {calc_date}")
        print(f"📊 数据预览:")
        print(df[['code', 'name', 'heat_score']].head())
        return True

    # Numeric columns, in the same order as the placeholders in the SQL below.
    metric_cols = [
        'total_return', 'avg_turnover', 'volume_ratio', 'momentum',
        'total_return_score', 'avg_turnover_score', 'volume_ratio_score',
        'momentum_score', 'heat_score',
    ]

    insert_sql = """
        INSERT INTO sw_heat_daily
        (calc_date, code, name, total_return, avg_turnover, volume_ratio, momentum,
        total_return_score, avg_turnover_score, volume_ratio_score, momentum_score, heat_score)
        VALUES
        (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        ON DUPLICATE KEY UPDATE
        name = VALUES(name),
        total_return = VALUES(total_return),
        avg_turnover = VALUES(avg_turnover),
        volume_ratio = VALUES(volume_ratio),
        momentum = VALUES(momentum),
        total_return_score = VALUES(total_return_score),
        avg_turnover_score = VALUES(avg_turnover_score),
        volume_ratio_score = VALUES(volume_ratio_score),
        momentum_score = VALUES(momentum_score),
        heat_score = VALUES(heat_score),
        created_at = CURRENT_TIMESTAMP;
    """

    conn = None
    try:
        # Build the parameter tuples before connecting so malformed data
        # fails without opening a connection. itertuples keeps each column's
        # own dtype, so an integer ``code`` is not upcast to a float string.
        # NaN metrics become None because PyMySQL refuses to send NaN.
        records = df[['code', 'name', *metric_cols]].itertuples(index=False, name=None)
        data_tuples = [
            (calc_date, str(code), str(name), *(_to_db_float(v) for v in metrics))
            for code, name, *metrics in records
        ]

        conn = pymysql.connect(**db_config)
        with conn.cursor() as cursor:
            # Check whether this date already has rows, to report the action taken.
            cursor.execute(
                "SELECT COUNT(*) FROM sw_heat_daily WHERE calc_date = %s",
                (calc_date,)
            )
            existing = cursor.fetchone()[0]
            if existing > 0:
                print(f"⚠️ 日期 {calc_date} 已有 {existing} 条记录,将执行覆盖更新")

            # Batch upsert.
            cursor.executemany(insert_sql, data_tuples)
        conn.commit()

        action = "覆盖更新" if existing > 0 else "新增"
        print(f"✅ MySQL{action}成功: {calc_date} 共 {len(data_tuples)} 条记录")
        return True

    except Exception as e:
        # Deliberate best-effort boundary: report, undo partial work and let
        # the caller decide what to do with the False result.
        print(f"❌ MySQL操作失败: {e}")
        if conn:
            conn.rollback()
        return False
    finally:
        if conn:
            conn.close()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def batch_import(data_dir, db_config, pattern=None, dry_run=False):
    """
    Import every matching CSV file found directly inside ``data_dir``.

    Files are processed in sorted name order. A failure in one file is
    printed and does not stop the remaining files; a summary line with the
    success count is printed at the end.

    :param data_dir: directory to scan (not recursive)
    :param db_config: keyword arguments forwarded to ``pymysql.connect``
    :param pattern: regular expression applied with ``re.match`` to each file
        name; defaults to ``heat_ranking_YYYY-MM-DD_*.csv``
    :param dry_run: if True, preview each file without writing to the database
    """
    if pattern is None:
        # Anchored at the end so leftovers such as "*.csv.bak" or "*.csv.tmp"
        # are not picked up and imported as if they were real data files.
        pattern = r'heat_ranking_\d{4}-\d{2}-\d{2}_.*\.csv$'

    # Compile once; the same pattern is tested against every directory entry.
    matcher = re.compile(pattern)
    files = sorted(f for f in os.listdir(data_dir) if matcher.match(f))

    print(f"📁 发现 {len(files)} 个待导入文件")

    success_count = 0
    for filename in files:
        file_path = os.path.join(data_dir, filename)
        try:
            calc_date, _ = extract_date_from_filename(filename)
            df = read_heat_csv(file_path)
            # save_to_mysql reports its own errors and returns False on failure.
            if save_to_mysql(df, calc_date, db_config, dry_run):
                success_count += 1
        except Exception as e:
            # Keep going: one bad file must not abort the whole batch.
            print(f"❌ 处理 {filename} 失败: {e}")

    print(f"\n📊 导入完成: {success_count}/{len(files)} 个文件成功")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _parse_calc_date(value):
    """argparse ``type=`` hook: validate ``--date`` and normalise it to YYYYMMDD."""
    text = value.strip()
    # Also accept the dashed form used in file names.
    if re.fullmatch(r'[0-9]{4}-[0-9]{2}-[0-9]{2}', text):
        text = text.replace('-', '')
    if re.fullmatch(r'[0-9]{8}', text):
        try:
            datetime.strptime(text, '%Y%m%d')
        except ValueError:
            pass  # eight digits, but not a real calendar date
        else:
            return text
    raise argparse.ArgumentTypeError(f"日期格式无效, 需为 YYYYMMDD: {value}")


def main():
    """
    Command-line entry point of the heat-data backfill tool.

    Modes:
      * ``--batch``: import every matching file under ``--dir``.
      * single file: import the given CSV (or the bundled sample path when no
        file is given and that sample exists), using ``--date`` when supplied
        and otherwise the date embedded in the file name.

    Database settings can be overridden with the environment variables
    ``SW_HEAT_DB_HOST``, ``SW_HEAT_DB_USER``, ``SW_HEAT_DB_PASSWORD``,
    ``SW_HEAT_DB_NAME`` and ``SW_HEAT_DB_PORT``.

    :raises SystemExit: with status 1 when the single-file import fails, so
        shell callers can detect the failure
    """
    # Database configuration. Environment variables take precedence.
    # NOTE(review): the literal fallbacks keep existing invocations working,
    # but credentials do not belong in source control - move them to the
    # environment / a secrets store and rotate this password.
    env = os.environ
    config = {
        'host': env.get('SW_HEAT_DB_HOST', '10.127.2.207'),
        'user': env.get('SW_HEAT_DB_USER', 'financial_prod'),
        'password': env.get('SW_HEAT_DB_PASSWORD', 'mmTFncqmDal5HLRGY0BV'),
        'database': env.get('SW_HEAT_DB_NAME', 'reference'),
        'port': int(env.get('SW_HEAT_DB_PORT', '3306')),
        'charset': 'utf8mb4',
    }

    # Command-line arguments.
    parser = argparse.ArgumentParser(description='申万行业热度数据回溯补录工具')
    parser.add_argument('file', nargs='?', help='单个CSV文件路径')
    # Validated up front so a malformed date never reaches the calc_date column.
    parser.add_argument('--date', '-d', type=_parse_calc_date,
                        help='指定日期(YYYYMMDD),覆盖文件名中的日期')
    parser.add_argument('--batch', '-b', action='store_true', help='批量导入目录下所有文件')
    parser.add_argument('--dir', default='./sw_data', help='数据目录路径(默认: ./sw_data)')
    parser.add_argument('--dry-run', action='store_true', help='预览模式,不实际写入数据库')

    args = parser.parse_args()

    # Batch mode.
    if args.batch:
        batch_import(args.dir, config, dry_run=args.dry_run)
        return

    # Single-file mode.
    if args.file:
        file_path = args.file
    else:
        # Fall back to the sample file; without it there is nothing to do.
        file_path = './sw_data/heat_ranking_2026-03-04_5d.csv'
        if not os.path.exists(file_path):
            print("请提供CSV文件路径,或使用 --batch 批量导入")
            parser.print_help()
            return

    try:
        df = read_heat_csv(file_path)

        # Decide which date the rows are stored under.
        if args.date:
            calc_date = args.date  # explicitly requested on the command line
            print(f"📅 使用指定日期: {calc_date}")
        else:
            calc_date, original_date = extract_date_from_filename(file_path)
            print(f"📅 从文件名提取日期: {original_date} -> {calc_date}")

        saved = save_to_mysql(df, calc_date, config, dry_run=args.dry_run)

    except Exception as e:
        print(f"❌ 错误: {e}")
        raise

    # save_to_mysql reports database errors by returning False; surface that
    # as a non-zero exit status instead of silently exiting 0.
    if not saved:
        raise SystemExit(1)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|