# Source: pymarl3-feudal/plotsmac.py
# Snapshot: 2025-01-08 18:39:17 +08:00 (117 lines, 3.9 KiB, Python)
#
# Note from the source viewer: this file contains Unicode characters that might
# be confused with other characters (presumably the Chinese comments and UI
# strings). If this is intentional, the warning can safely be ignored.

import json
import matplotlib.pyplot as plt
import numpy as np
import tkinter as tk
from tkinter import ttk
from scipy.signal import savgol_filter
# Read several JSON metric files.
def _unwrap_scalar(item):
    """Return item['value'] if item is a dict with a 'value' key, else item unchanged."""
    if isinstance(item, dict) and 'value' in item:
        return item['value']
    return item


def load_data(file_paths):
    """Load JSON metric files and flatten dict-wrapped scalar values.

    Each file must contain a JSON object. For every list-valued entry, each
    element that is a dict carrying a ``'value'`` key is replaced by that
    value. This handles numpy.float64 objects that were serialised as dicts
    (presumably by Sacred's jsonpickle-based encoder; confirm). Every other
    value is kept unchanged.

    Args:
        file_paths: iterable of paths to JSON files.

    Returns:
        list[dict]: one processed dict per file, in the order of file_paths.
    """
    all_data = []
    for file_path in file_paths:
        # JSON is UTF-8 by specification; do not depend on the locale's codec
        # (e.g. cp950 on a Traditional-Chinese Windows install).
        with open(file_path, encoding='utf-8') as f:
            data = json.load(f)
        processed_data = {}
        for key, value in data.items():
            if isinstance(value, list):
                # Unwrap per element so lists mixing plain numbers and wrapped
                # scalars are handled, not just lists whose first item is wrapped.
                value = [_unwrap_scalar(item) for item in value]
            processed_data[key] = value
        all_data.append(processed_data)
    return all_data
def smooth(y, window_length=51, polyorder=3):
    """Smooth a 1-D series with a Savitzky-Golay filter.

    Args:
        y: 1-D sequence of numbers.
        window_length: filter window length, used as-is when ``y`` has at
            least this many points.
        polyorder: order of the local polynomial fit.

    Returns:
        numpy.ndarray with the same length as ``y``. If ``y`` is shorter than
        ``window_length``, the window shrinks to the largest odd length that
        fits. If that length is not greater than ``polyorder``, the data is
        returned unsmoothed as a float array.
    """
    n = len(y)
    if n >= window_length:
        return savgol_filter(y, window_length, polyorder)
    # savgol_filter (default mode='interp') requires window_length <= len(y)
    # and polyorder < window_length; short series (early training, sparsely
    # logged test metrics) would otherwise raise ValueError.
    window = n if n % 2 else n - 1
    if window <= polyorder:
        return np.asarray(y, dtype=float)
    return savgol_filter(y, window, polyorder)
# Plot the selected metrics.
def plot_data(data_list, keys, name_list, battle_name, smooth_window=2):
    """Draw one figure per metric key, overlaying a smoothed curve per run.

    Args:
        data_list: list of dicts (as returned by ``load_data``). For a metric
            ``key``, y values come from ``data[key]`` and x values from
            ``data[key + '_T']`` (the x axis is labelled "Time Steps").
        keys: metric names to plot; each gets its own figure.
        name_list: legend label for each entry of ``data_list`` (same order).
        battle_name: title used for every figure.
        smooth_window: currently unused; kept so existing callers still work.
            ``smooth`` is called with its own defaults.

    Runs missing either ``key`` or ``key + '_T'`` are skipped. A key that no
    run provides produces no figure. All figures are shown together by a
    single ``plt.show()`` call after they have been created.
    """
    for key in keys:
        x_key = key + '_T'
        runs = [(data, name) for data, name in zip(data_list, name_list)
                if key in data and x_key in data]
        if not runs:
            # Nothing to draw: avoid an empty figure and a legend warning.
            continue
        _, ax = plt.subplots()
        # Axes background colour.
        ax.set_facecolor('lightyellow')
        # Grid style.
        ax.grid(color='green', linestyle='--', linewidth=0.5)
        for data, name in runs:
            ax.plot(data[x_key], smooth(data[key]), label=name)
        ax.set_xlabel('Time Steps')
        ax.set_ylabel(key)
        ax.set_title(battle_name)
        ax.legend()
    # Show once, after all figures exist. The old per-key call blocked on each
    # figure in turn.
    plt.show()
def create_dynamic_window(data_list, keys, name_list, battle_name):
    """Open a Tk window with one checkbox per run to choose which runs are plotted.

    Toggling a checkbox closes all matplotlib figures and redraws ``keys`` for
    the checked runs via ``plot_data``. The select-all and deselect-all
    buttons update every checkbox and then redraw once. An initial plot of
    all runs is drawn, then this function blocks in ``root.mainloop()`` until
    the window is closed.

    Args:
        data_list: per-run metric dicts (as returned by ``load_data``).
        keys: metric names passed through to ``plot_data``.
        name_list: checkbox/legend label for each entry of ``data_list``.
        battle_name: figure title passed through to ``plot_data``.
    """
    root = tk.Tk()
    root.title("動態選擇要顯示的數據")
    # Left-hand panel holding the checkboxes and buttons.
    select_frame = ttk.Frame(root)
    select_frame.pack(side=tk.LEFT, fill=tk.Y, padx=10, pady=5)
    # One BooleanVar per run, parallel to data_list / name_list.
    check_vars = []
    # While True, trace callbacks skip redrawing (used during bulk updates).
    suspend_redraw = False

    def update_plot(*args):
        # *args absorbs the (name, index, mode) arguments passed by Tk traces.
        if suspend_redraw:
            return
        # Pair each run with its own checkbox; robust to duplicate names.
        selected = [(data, name)
                    for data, name, var in zip(data_list, name_list, check_vars)
                    if var.get()]
        # Close every existing figure before redrawing.
        plt.close('all')
        if selected:  # Draw only when at least one run is checked.
            filtered_data = [data for data, _ in selected]
            filtered_names = [name for _, name in selected]
            plot_data(filtered_data, keys, filtered_names, battle_name)

    for name in name_list:
        var = tk.BooleanVar(value=True)  # All runs selected by default.
        var.trace_add('write', update_plot)  # Redraw whenever the value changes.
        check_vars.append(var)
        cb = ttk.Checkbutton(select_frame, text=name, variable=var)
        cb.pack(anchor='w', padx=5, pady=2)

    def set_all(value):
        # Set every checkbox, then redraw once instead of once per checkbox.
        nonlocal suspend_redraw
        suspend_redraw = True
        try:
            for var in check_vars:
                var.set(value)
        finally:
            suspend_redraw = False
        update_plot()

    def select_all():
        set_all(True)

    def deselect_all():
        set_all(False)

    ttk.Button(select_frame, text="全選", command=select_all).pack(pady=5)
    ttk.Button(select_frame, text="取消全選", command=deselect_all).pack(pady=5)
    # Initial plot.
    update_plot()
    root.mainloop()
# User configuration: which runs to compare and which metrics to plot.
# Each file_paths entry is the info.json of one run; name_list gives the
# legend label for the run at the same position.
file_paths = ['results/sacred/10gen_protoss/feudal/5/info.json',
              'results/sacred/10gen_protoss/qmix/6/info.json',
              ]
name_list = ['feudal', 'qmix']
# zip() in the plotting code would silently drop or mislabel runs if these
# lists disagree, so fail early with a clear message.
if len(name_list) != len(file_paths):
    raise ValueError(
        f"name_list has {len(name_list)} labels but file_paths has "
        f"{len(file_paths)} files; they must match one-to-one")
data_list = load_data(file_paths)
selected_keys = ['test_battle_won_mean', 'return_mean', "worker_loss", "loss_td", "manager_loss"]
# Alternative metric set:
# selected_keys = ['battle_won_mean', 'loss']
battle_name = '5protoss'
# Open the interactive run-selection window.
# NOTE(review): consider an `if __name__ == "__main__":` guard so importing
# this module does not load files and open the GUI.
create_dynamic_window(data_list, selected_keys, name_list, battle_name)