find_strokes.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. def find_strokes(fractals, df, min_interval=5):
  2. """
  3. 在此函数中对初步分型(fractals)进行有效性判断和筛选,
  4. 包括:
  5. 1. 去除连续同类型分型中较不极端的分型。
  6. 2. 确保顶底交替和满足最小间隔要求 min_interval。
  7. 3. 最终根据有效分型构造笔(strokes)。
  8. :param fractals: 初步的分型点列表 [(index, 'Top'), (index, 'Bottom')]
  9. :param df: 去除包含关系的K线数据,包含High、Low
  10. :param min_interval: 顶底之间的最小间隔K线数
  11. :return: (valid_fractals, strokes)
  12. valid_fractals: 有效分型列表 [(index, 'Top'), (index, 'Bottom')]
  13. strokes: 笔的列表 [(start_index, end_index)]
  14. """
  15. if not fractals:
  16. return [], []
  17. # 1. 先对分型列表排序(按index)
  18. fractals = sorted(fractals, key=lambda x: x[0])
  19. valid_fractals = [fractals[0]] # 将第一个分型作为初始有效分型
  20. # 开始进行有效性筛选
  21. for i in range(1, len(fractals)):
  22. current_idx, current_type = fractals[i]
  23. last_idx, last_type = valid_fractals[-1]
  24. # 如果类型相同,保留更加极端的分型
  25. if current_type == last_type:
  26. if current_type == 'Top':
  27. # 保留更高的顶
  28. if df.loc[current_idx, 'High'] > df.loc[last_idx, 'High']:
  29. valid_fractals[-1] = (current_idx, current_type)
  30. # 否则保持不变(跳过当前分型)
  31. else: # 'Bottom'
  32. # 保留更低的底
  33. if df.loc[current_idx, 'Low'] < df.loc[last_idx, 'Low']:
  34. valid_fractals[-1] = (current_idx, current_type)
  35. # 否则保持不变
  36. else:
  37. # 类型不同,检查间隔
  38. if current_idx - last_idx - 1 >= min_interval:
  39. valid_fractals.append((current_idx, current_type))
  40. # 如果间隔不够,不加入
  41. # 有了有效分型列表后,再构造笔
  42. strokes = []
  43. for i in range(len(valid_fractals) - 1):
  44. start_idx, start_type = valid_fractals[i]
  45. end_idx, end_type = valid_fractals[i + 1]
  46. # 此时有效分型保证类型交替(如果需要进一步确保,也可再加判断)
  47. strokes.append((start_idx, end_idx))
  48. return valid_fractals, strokes