68 lines
2.2 KiB
Python
68 lines
2.2 KiB
Python
from collections import defaultdict
|
|
import logging
|
|
import numpy as np
|
|
import torch as th
|
|
|
|
class Logger:
    """Collect scalar statistics and forward them to optional sinks.

    Every stat is buffered in ``self.stats`` as ``(t, value)`` pairs keyed by
    name. When enabled through the ``setup_*`` methods, each stat is also
    pushed to a tensorboard logging callable and/or appended to a sacred
    info dict. ``print_recent_stats`` writes a summary to ``console_logger``
    and then clears the buffer.
    """

    def __init__(self, console_logger):
        """Create a logger with every optional sink disabled.

        Args:
            console_logger: Object exposing ``info(str)``; receives the text
                built by ``print_recent_stats``.
        """
        self.console_logger = console_logger

        self.use_tb = False
        self.use_sacred = False
        # Initialised here but not read anywhere in this class.
        self.use_hdf = False

        # name -> list of (t, value); ``list`` (rather than a lambda) as the
        # factory keeps the mapping picklable.
        self.stats = defaultdict(list)

    def setup_tb(self, directory_name):
        """Enable tensorboard output configured with ``directory_name``."""
        # Import here so it doesn't have to be installed if you don't use it
        from tensorboard_logger import configure, log_value

        configure(directory_name)
        self.tb_logger = log_value
        self.use_tb = True

    def setup_sacred(self, sacred_run_dict):
        """Enable sacred output, writing into ``sacred_run_dict.info``."""
        self.sacred_info = sacred_run_dict.info
        self.use_sacred = True

    def log_stat(self, key, value, t, to_sacred=True):
        """Record ``value`` for ``key`` at step ``t`` and fan out to sinks.

        Args:
            key: Name of the statistic.
            value: Value to record.
            t: Step / timestamp stored alongside the value.
            to_sacred: When False, skip the sacred sink for this call even
                if sacred logging is enabled.
        """
        self.stats[key].append((t, value))

        if self.use_tb:
            self.tb_logger(key, value, t)

        if self.use_sacred and to_sacred:
            # Steps are kept in a parallel "<key>_T" list next to the values.
            t_key = "{}_T".format(key)
            if key in self.sacred_info:
                self.sacred_info[t_key].append(t)
                self.sacred_info[key].append(value)
            else:
                self.sacred_info[t_key] = [t]
                self.sacred_info[key] = [value]

    def print_recent_stats(self):
        """Log a summary of the buffered stats, then clear the buffer.

        The header shows the latest ``"episode"`` entry as ``(t_env,
        Episode)``; if no ``"episode"`` stat has been logged since the last
        reset, both fields show ``N/A`` instead of raising ``IndexError``.
        Every other stat is reported as the mean of its last 5 values
        (last 1 for ``"epsilon"``), four entries per line, in key order.
        """
        # .get() instead of indexing: indexing the defaultdict would insert
        # an empty "episode" series and then fail on [-1].
        episode_series = self.stats.get("episode")
        header_args = episode_series[-1] if episode_series else ("N/A", "N/A")
        parts = [
            "Recent Stats | t_env: {:>10} | Episode: {:>8}\n".format(*header_args)
        ]

        shown = 0
        for key, series in sorted(self.stats.items()):
            if key == "episode":
                continue
            shown += 1
            window = 5 if key != "epsilon" else 1
            recent = th.tensor([float(entry[1]) for entry in series[-window:]])
            item = "{:.4f}".format(th.mean(recent))
            parts.append("{:<25}{:>8}".format(key + ":", item))
            # Break the line after every fourth entry, tab-separate otherwise.
            parts.append("\n" if shown % 4 == 0 else "\t")

        self.console_logger.info("".join(parts))

        # Reset stats to avoid accumulating logs in memory
        self.stats = defaultdict(list)
|
|
|
|
|
|
# set up a custom logger
|
|
def get_logger():
    """Return the root logger reset to a single console handler.

    Any handlers already attached to the root logger are dropped and
    replaced by one ``StreamHandler`` whose records are formatted as
    ``[LEVEL HH:MM:SS] name message``. The logger level is set to DEBUG.
    """
    root = logging.getLogger()
    # Discard previously attached handlers so repeated calls do not stack
    # additional console handlers.
    root.handlers = []

    console = logging.StreamHandler()
    console.setFormatter(
        logging.Formatter(
            fmt='[%(levelname)s %(asctime)s] %(name)s %(message)s',
            datefmt='%H:%M:%S',
        )
    )
    root.addHandler(console)
    root.setLevel('DEBUG')

    return root
|
|
|