| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133 |
- import pandas as pd
- from stock_processor import get_stock_data
- from kline_processor import KlineProcessor
- from validate_fractals import identify_fractals
- from find_strokes import find_strokes # 导入 find_strokes
- import os
- from multiprocessing import Pool, cpu_count
- from functools import partial
- import pymysql
# Database connection settings.
# Each value can be overridden through an environment variable so that
# credentials do not have to be edited into (or committed with) this file.
# The literals are fallbacks that keep existing setups working unchanged.
# NOTE(review): the fallback password looks like a live credential stored in
# source -- rotate it and drop the default once QMT_DB_PASSWORD is deployed.
db_config = {
    'host': os.environ.get('QMT_DB_HOST', 'localhost'),
    # Environment values are strings; the connection expects an integer port.
    'port': int(os.environ.get('QMT_DB_PORT', '3307')),
    'user': os.environ.get('QMT_DB_USER', 'root'),
    'password': os.environ.get('QMT_DB_PASSWORD', 'r6kEwqWU9!v3'),
    'database': os.environ.get('QMT_DB_NAME', 'qmt_stocks_whole'),
}
def get_stock_list(db_config):
    """
    Fetch the names of all stock tables from the database.

    Every table found in the schema named by ``db_config['database']`` is
    returned, i.e. each table in that schema is treated as one stock.

    :param db_config: dict with the keys ``host``, ``port``, ``user``,
        ``password`` and ``database``
    :return: list of table names; an empty list if anything goes wrong
        (the error is printed, never raised)
    """
    try:
        connection = pymysql.connect(
            host=db_config['host'],
            port=db_config['port'],
            user=db_config['user'],
            password=db_config['password'],
            database=db_config['database'],
        )
        try:
            # The cursor context manager closes the cursor even when the
            # query itself fails.
            with connection.cursor() as cursor:
                cursor.execute(
                    "SELECT table_name FROM information_schema.tables WHERE table_schema = %s",
                    (db_config['database'],),
                )
                stock_list = [row[0] for row in cursor.fetchall()]
        finally:
            # Release the connection on every path. This sits inside the
            # outer try so that a failing close() is still reported below
            # instead of escaping to the caller.
            connection.close()
        return stock_list
    except Exception as e:
        # Deliberate best-effort: callers treat an empty list as
        # "nothing to process".
        print(f"获取股票列表失败:{e}")
        return []
def process_stock(table_name, db_config, output_dir):
    """
    Run the whole pipeline for one stock table.

    Stages, in order:
    1. load the raw data with ``get_stock_data``
    2. resolve K-line inclusion relationships via ``KlineProcessor``
    3. mark preliminary top/bottom fractals with ``identify_fractals``
    4. validate the fractals and build strokes with ``find_strokes``
    5. export the annotated data with ``export_to_csv``

    Any exception is caught and printed, so a failure in one stock never
    propagates to the caller. The function always returns ``None``.

    :param table_name: name of the stock table to process
    :param db_config: database settings passed through to ``get_stock_data``
    :param output_dir: directory that receives the CSV file
    """
    try:
        raw_df = get_stock_data(table_name, db_config)
        if raw_df.empty:
            print(f"{table_name} 数据为空,跳过。")
            return

        # The processor exposes its processed frame as ``.df``.
        merged_df = KlineProcessor(raw_df).df

        # Preliminary fractals first; find_strokes then returns the fractals
        # it accepts together with the strokes.
        candidate_fractals = identify_fractals(merged_df)
        confirmed_fractals, stroke_spans = find_strokes(candidate_fractals, merged_df)

        export_to_csv(merged_df, confirmed_fractals, stroke_spans, output_dir, table_name)
        print(f"{table_name} 处理完成。")
    except Exception as e:
        print(f"处理 {table_name} 时出错:{e}")
def export_to_csv(df, fractals, strokes, output_dir, table_name):
    """
    Write the K-line data, annotated with fractals and strokes, to a CSV file.

    Three string columns are added to ``df`` in place (the caller's frame is
    modified): ``Fractal``, ``Stroke_Start`` and ``Stroke_End``. Rows that
    carry no mark hold an empty string. The file is written to
    ``<output_dir>/<table_name>_result.csv`` as UTF-8 with a BOM, without the
    index column.

    Marks are placed by index label. A label that is not present in
    ``df.index`` is skipped; writing to it would silently append a new,
    otherwise empty row to the exported data.

    :param df: DataFrame holding the K-line data
    :param fractals: iterable of ``(index, fractal_type)`` pairs,
        e.g. ``[(3, 'Top'), (7, 'Bottom')]``
    :param strokes: iterable of ``(start_index, end_index)`` pairs
    :param output_dir: target directory; created if it does not exist
    :param table_name: stock table name, used to build the file name
    """
    # Captured before the columns are added; adding columns does not change
    # the row labels.
    row_labels = df.index

    df['Fractal'] = ""
    for idx, fractal_type in fractals:
        if idx in row_labels:
            df.at[idx, 'Fractal'] = fractal_type

    # Start and end get separate columns so that a row which ends one stroke
    # and starts the next can carry both marks.
    df['Stroke_Start'] = ""
    df['Stroke_End'] = ""
    for i, (start, end) in enumerate(strokes, 1):
        if start in row_labels:
            df.at[start, 'Stroke_Start'] = f'Stroke{i}_Start'
        if end in row_labels:
            df.at[end, 'Stroke_End'] = f'Stroke{i}_End'

    # All rows are exported; no filtering is applied.
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"{table_name}_result.csv")
    df.to_csv(output_file, index=False, encoding='utf-8-sig')
    print(f"数据已成功导出到 {output_file}")
def main():
    """
    Entry point: process a batch of stock tables in parallel.

    Fetches the table list with ``get_stock_list``, keeps the first 20
    entries, and runs ``process_stock`` for each one in a process pool.
    CSV files are written to ``./output``. Returns early, after printing a
    message, when no table names could be fetched.
    """
    output_dir = "./output"  # directory that receives the exported files

    stock_list = get_stock_list(db_config)
    if not stock_list:
        print("未获取到任何股票表名,程序终止。")
        return

    # Only the first 20 tables are processed.
    # NOTE(review): the table query has no ORDER BY, so which 20 tables come
    # first is presumably up to the database -- confirm if a stable
    # selection matters.
    stock_list = stock_list[:20]
    print(f"共获取到 {len(stock_list)} 只股票,开始处理。")

    # Never start more workers than there are tasks; on machines with many
    # cores the surplus processes would be spawned only to sit idle.
    # stock_list is non-empty here, so the pool size is at least 1.
    pool_size = min(cpu_count(), len(stock_list))
    print(f"使用 {pool_size} 个进程进行并行处理。")

    worker = partial(process_stock, db_config=db_config, output_dir=output_dir)
    with Pool(pool_size) as pool:
        pool.map(worker, stock_list)
    print("前20只股票数据处理完成。")


# The guard keeps worker processes that re-import this module from running
# main() again.
if __name__ == "__main__":
    main()
|