import os
from functools import partial
from multiprocessing import Pool, cpu_count

import pandas as pd
import pymysql

from find_strokes import find_strokes
from kline_processor import KlineProcessor
from stock_processor import get_stock_data
from validate_fractals import identify_fractals

# Database configuration.
# NOTE(review): the password is hard-coded in source; consider loading it from
# the environment or a secrets store instead of committing it.
db_config = {
    'host': 'localhost',
    'port': 3307,
    'user': 'root',
    'password': 'r6kEwqWU9!v3',
    'database': 'qmt_stocks_whole'
}


def get_stock_list(db_config):
    """
    Fetch the list of stock table names from the database.

    Every table listed in ``information_schema.tables`` for the configured
    schema is returned. On any failure the error is printed and an empty
    list is returned instead of raising.

    :param db_config: dict with 'host', 'port', 'user', 'password', 'database'
    :return: list of table names (empty on failure)
    """
    try:
        connection = pymysql.connect(
            host=db_config['host'],
            port=db_config['port'],
            user=db_config['user'],
            password=db_config['password'],
            database=db_config['database']
        )
        try:
            # The cursor context manager closes the cursor even if the query
            # raises; the surrounding finally does the same for the connection.
            with connection.cursor() as cursor:
                cursor.execute(
                    "SELECT table_name FROM information_schema.tables WHERE table_schema = %s",
                    (db_config['database'],)
                )
                stock_list = [row[0] for row in cursor.fetchall()]
        finally:
            connection.close()
        return stock_list
    except Exception as e:
        # Deliberate best-effort: the caller treats an empty list as
        # "nothing to do" and terminates.
        print(f"获取股票列表失败:{e}")
        return []


def process_stock(table_name, db_config, output_dir):
    """
    Run the full pipeline for a single stock table:

    1. Load the raw data via ``get_stock_data``
    2. Build a ``KlineProcessor`` and take its ``df``
    3. Identify preliminary fractals via ``identify_fractals``
    4. Validate fractals and build strokes via ``find_strokes``
    5. Export the result to CSV

    Errors are printed rather than raised so that one failing table does not
    abort the whole worker pool.

    :param table_name: name of the stock table to process
    :param db_config: database configuration passed to ``get_stock_data``
    :param output_dir: directory the CSV result is written to
    :return: None
    """
    try:
        # Step 1: load the raw data.
        df = get_stock_data(table_name, db_config)
        if df.empty:
            print(f"{table_name} 数据为空,跳过。")
            return

        # Step 2: K-line preprocessing (containment handling per the original
        # comments); the processed frame is exposed as ``processor.df``.
        processor = KlineProcessor(df)
        cleaned_df = processor.df

        # Step 3: preliminary fractals only.
        fractals = identify_fractals(cleaned_df)

        # Step 4: validity check and stroke generation.
        valid_fractals, strokes = find_strokes(fractals, cleaned_df)

        # Step 5: export the result.
        export_to_csv(cleaned_df, valid_fractals, strokes, output_dir, table_name)

        print(f"{table_name} 处理完成。")
    except Exception as e:
        # Per-task boundary: report the failure and let the pool continue.
        print(f"处理 {table_name} 时出错:{e}")


def export_to_csv(df, fractals, strokes, output_dir, table_name):
    """
    Export the data frame with fractal and stroke markers to a CSV file.

    Adds the columns 'Fractal', 'Stroke_Start' and 'Stroke_End' to ``df``
    in place, then writes ``<output_dir>/<table_name>_result.csv``.

    :param df: data frame with the K-line data (modified in place)
    :param fractals: iterable of (index, fractal_type) pairs,
        e.g. [(index, 'Top'), (index, 'Bottom')]
    :param strokes: iterable of (start_index, end_index) pairs
    :param output_dir: directory for the exported file (created if missing)
    :param table_name: stock table name, used to build the file name
    :return: None
    """
    # Fractal marker column.
    df['Fractal'] = ""
    for idx, fractal_type in fractals:
        df.at[idx, 'Fractal'] = fractal_type

    # Two separate columns mark where each stroke starts and ends.
    df['Stroke_Start'] = ""
    df['Stroke_End'] = ""

    # Strokes are numbered from 1.
    # NOTE(review): the `< len(df)` guard presumably assumes df has a default
    # 0..n-1 index, since `df.at` is label-based - confirm against callers.
    for i, (start, end) in enumerate(strokes, 1):
        if start < len(df):
            df.at[start, 'Stroke_Start'] = f'Stroke{i}_Start'
        if end < len(df):
            df.at[end, 'Stroke_End'] = f'Stroke{i}_End'

    # Export every row; no filtering is applied.
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"{table_name}_result.csv")
    df.to_csv(output_file, index=False, encoding='utf-8-sig')
    print(f"数据已成功导出到 {output_file}")


def main():
    """
    Process the first 20 stock tables in parallel and export each to CSV.

    Terminates early with a message when no table names can be fetched.
    """
    output_dir = "./output"  # directory for the exported files

    stock_list = get_stock_list(db_config)
    if not stock_list:
        print("未获取到任何股票表名,程序终止。")
        return

    # Only the first 20 tables are processed.
    stock_list = stock_list[:20]
    print(f"共获取到 {len(stock_list)} 只股票,开始处理。")

    # Never start more workers than there are tasks; stock_list is non-empty
    # here, so the pool size is at least 1.
    pool_size = min(cpu_count(), len(stock_list))
    print(f"使用 {pool_size} 个进程进行并行处理。")

    worker = partial(process_stock, db_config=db_config, output_dir=output_dir)
    with Pool(pool_size) as pool:
        pool.map(worker, stock_list)

    print("前20只股票数据处理完成。")


if __name__ == "__main__":
    main()