# plot.py
# Bins the `cycle_long_r2` column of every segment CSV and plots the share of samples per bin.

  1. import os
  2. import glob
  3. import pandas as pd
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. from matplotlib.font_manager import FontProperties
  7. # ===================== 配置 =====================
  8. data_dir = r"E:\Greentech\models\uf-rl\datasets\processed\segments"
  9. target_col = "cycle_long_r2"
  10. # ===================== 中文字体设置 =====================
  11. # 注意:这里使用 SimHei 字体,可显示中文
  12. font = FontProperties(fname=r"C:\Windows\Fonts\simhei.ttf", size=12)
  13. # ===================== 读取所有 CSV =====================
  14. all_files = glob.glob(os.path.join(data_dir, "*.csv"))
  15. values = []
  16. for file in all_files:
  17. try:
  18. df = pd.read_csv(file)
  19. if target_col in df.columns:
  20. vals = df[target_col].dropna().values
  21. values.append(vals)
  22. except Exception as e:
  23. print(f"读取失败: {file}, 错误: {e}")
  24. # 合并所有数据
  25. if len(values) == 0:
  26. raise ValueError("未在任何 CSV 中找到有效的 cycle_long_R2 数据")
  27. data = np.concatenate(values)
  28. total_count = len(data)
  29. # ===================== 定义区间 =====================
  30. bins = [
  31. -np.inf,
  32. 0.0,
  33. 0.5,
  34. 0.6,
  35. 0.7,
  36. 0.8,
  37. 0.9,
  38. 1.0
  39. ]
  40. labels = [
  41. "<0",
  42. "0 – 0.5",
  43. "0.5 – 0.6",
  44. "0.6 – 0.7",
  45. "0.7 – 0.8",
  46. "0.8 – 0.9",
  47. "0.9 – 1.0"
  48. ]
  49. # ===================== 统计分布 =====================
  50. counts = pd.cut(
  51. data,
  52. bins=bins,
  53. labels=labels,
  54. right=True,
  55. include_lowest=True
  56. ).value_counts().sort_index()
  57. ratios = counts / total_count * 100
  58. # ===================== 输出结果 =====================
  59. result = pd.DataFrame({
  60. "样本数": counts,
  61. "占比 (%)": ratios.round(2)
  62. })
  63. print(f"\n总样本数: {total_count}\n")
  64. print(result)
  65. # ===================== 绘制柱状图 =====================
  66. plt.figure(figsize=(10, 6))
  67. plt.bar(labels, ratios, color='skyblue', edgecolor='black')
  68. plt.title("cycle_long_R2 数据分布柱状图", fontproperties=font)
  69. plt.xlabel("区间", fontproperties=font)
  70. plt.ylabel("占比 (%)", fontproperties=font)
  71. plt.ylim(0, 100)
  72. plt.grid(axis='y', linestyle='--', alpha=0.7)
  73. # 在柱子上显示百分比
  74. for i, v in enumerate(ratios):
  75. plt.text(i, v + 1, f"{v:.1f}%", ha='center', va='bottom', fontsize=10, fontproperties=font)
  76. plt.tight_layout()
  77. plt.show()