# main.py (4.4 KB)
import os
import traceback
from functools import partial
from multiprocessing import Pool, cpu_count

import pandas as pd
import pymysql

from find_strokes import find_strokes  # import find_strokes
from kline_processor import KlineProcessor
from stock_processor import get_stock_data
from validate_fractals import identify_fractals
  10. # 数据库配置
  11. db_config = {
  12. 'host': 'localhost',
  13. 'port': 3307,
  14. 'user': 'root',
  15. 'password': 'r6kEwqWU9!v3',
  16. 'database': 'qmt_stocks_whole'
  17. }
  18. def get_stock_list(db_config):
  19. """
  20. 从数据库中获取股票列表。
  21. :param db_config: 数据库配置信息
  22. :return: 股票表名列表
  23. """
  24. try:
  25. connection = pymysql.connect(
  26. host=db_config['host'],
  27. port=db_config['port'],
  28. user=db_config['user'],
  29. password=db_config['password'],
  30. database=db_config['database']
  31. )
  32. cursor = connection.cursor()
  33. # 查询所有股票表名,假设所有股票表在 'qmt_stocks_whole' 数据库中
  34. cursor.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = %s", (db_config['database'],))
  35. results = cursor.fetchall()
  36. stock_list = [row[0] for row in results]
  37. cursor.close()
  38. connection.close()
  39. return stock_list
  40. except Exception as e:
  41. print(f"获取股票列表失败:{e}")
  42. return []
  43. def process_stock(table_name, db_config, output_dir):
  44. """
  45. 处理单只股票数据的完整流程:
  46. 1. 调取元数据
  47. 2. 处理包含关系
  48. 3. 标记初步的顶底分型
  49. 4. 验证分型
  50. 5. 生成笔
  51. """
  52. cleaned_df, fractals, strokes = None, None, None
  53. try:
  54. # 1. 调取元数据
  55. df = get_stock_data(table_name, db_config)
  56. if df.empty:
  57. print(f"{table_name} 数据为空,跳过。")
  58. return
  59. # 2. 处理包含关系
  60. processor = KlineProcessor(df)
  61. cleaned_df = processor.df
  62. # 第3步:获取初步分型(此时validate_fractals仅返回初步的fractals)
  63. fractals = identify_fractals(cleaned_df)
  64. # 第4步:使用find_strokes进行有效性判断和生成笔
  65. valid_fractals, strokes = find_strokes(fractals, cleaned_df)
  66. # 第5步:导出结果
  67. export_to_csv(cleaned_df, valid_fractals, strokes, output_dir, table_name)
  68. print(f"{table_name} 处理完成。")
  69. except Exception as e:
  70. print(f"处理 {table_name} 时出错:{e}")
  71. def export_to_csv(df, fractals, strokes, output_dir, table_name):
  72. """
  73. 将数据框、分型信息和笔信息导出为 CSV 文件。
  74. 仅标记有效的顶底分型和笔的起止点。
  75. :param df: 数据框,包含 K 线数据
  76. :param fractals: 分型点列表,格式为 [(index, 'Top'), (index, 'Bottom')]
  77. :param strokes: 笔的起止点列表 [(start_index, end_index)]
  78. :param output_dir: 导出文件的目录
  79. :param table_name: 股票表名
  80. """
  81. # 添加分型列
  82. df['Fractal'] = ""
  83. for idx, fractal_type in fractals:
  84. df.at[idx, 'Fractal'] = fractal_type
  85. # 添加笔列,使用两个独立的列标记起点和终点
  86. df['Stroke_Start'] = ""
  87. df['Stroke_End'] = ""
  88. # 标记笔的起点和终点
  89. for i, (start, end) in enumerate(strokes, 1):
  90. if start < len(df):
  91. df.at[start, 'Stroke_Start'] = f'Stroke{i}_Start'
  92. if end < len(df):
  93. df.at[end, 'Stroke_End'] = f'Stroke{i}_End'
  94. # 导出为 CSV,不进行数据过滤
  95. os.makedirs(output_dir, exist_ok=True) # 确保目录存在
  96. output_file = os.path.join(output_dir, f"{table_name}_result.csv")
  97. df.to_csv(output_file, index=False, encoding='utf-8-sig')
  98. print(f"数据已成功导出到 {output_file}")
  99. def main():
  100. output_dir = "./output" # 导出文件的目录
  101. # 获取股票列表
  102. stock_list = get_stock_list(db_config)
  103. if not stock_list:
  104. print("未获取到任何股票表名,程序终止。")
  105. return
  106. # 限制处理前20个股票
  107. stock_list = stock_list[:20]
  108. print(f"共获取到 {len(stock_list)} 只股票,开始处理。")
  109. # 使用多进程
  110. pool_size = cpu_count()
  111. print(f"使用 {pool_size} 个进程进行并行处理。")
  112. with Pool(pool_size) as pool:
  113. pool.map(partial(process_stock, db_config=db_config, output_dir=output_dir), stock_list)
  114. print("前20只股票数据处理完成。")
  115. if __name__ == "__main__":
  116. main()