| |
| |
| from __future__ import absolute_import |
|
|
| import os |
| import sys |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
|
|
# Names exported by ``from <this module> import *``; the plotting helper
# plot_overlap is intentionally not listed.
__all__ = ["Logger", "LoggerMonitor", "savefig"]
|
|
|
|
def savefig(fname, dpi=None, **kwargs):
    """Save the current matplotlib figure to ``fname``.

    Args:
        fname: Output path; matplotlib infers the file format from its
            extension unless ``format=`` is passed in ``kwargs``.
        dpi: Resolution in dots per inch; 150 is used when ``None``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``plt.savefig`` (e.g. ``bbox_inches="tight"`` so a legend drawn
            outside the axes is not clipped from the saved image).
    """
    dpi = 150 if dpi is None else dpi
    plt.savefig(fname, dpi=dpi, **kwargs)
|
|
|
|
def plot_overlap(logger, names=None):
    """Plot the named columns of ``logger`` onto the current axes.

    Args:
        logger: Object exposing ``names``, ``numbers`` (a mapping from name
            to a list of values) and ``title`` -- e.g. a ``Logger``.
        names: Columns to plot; defaults to every name in ``logger.names``.

    Returns:
        One legend label per plotted column, formatted as ``title(name)``.
    """
    names = logger.names if names is None else names
    for name in names:
        # dtype=float: values read back from a log file may be strings, which
        # matplotlib would otherwise draw as categories instead of numbers.
        values = np.asarray(logger.numbers[name], dtype=float)
        plt.plot(np.arange(len(values)), values)
    return [logger.title + "(" + name + ")" for name in names]
|
|
|
|
class Logger(object):
    """Save training process to log file with simple plot function.

    The log is a tab-separated text file: the first line holds the column
    names and each following line holds one row of values formatted with six
    decimal places. Every field, including the last one, is followed by a tab.

    Args:
        fpath: Path of the log file, or ``None`` to keep values in memory only
            (nothing is written to disk).
        title: Prefix used in plot legend labels; defaults to ``""``.
        resume: If true, load the names and values already stored in ``fpath``
            and append new rows to it instead of truncating it.
    """

    def __init__(self, fpath, title=None, resume=False):
        self.file = None
        self.resume = resume
        self.title = "" if title is None else title
        # Start empty so a file-less logger (fpath=None) still has the
        # attributes that set_names/append/plot read.
        self.names = []
        self.numbers = {}
        if fpath is None:
            return
        if resume:
            self._load(fpath)
            self.file = open(fpath, "a")
        else:
            self.file = open(fpath, "w")

    def _load(self, fpath):
        """Read the header line and every value row of an existing log file."""
        with open(fpath, "r") as f:
            self.names = f.readline().rstrip().split("\t")
            self.numbers = {name: [] for name in self.names}
            for line in f:
                line = line.rstrip()
                if not line:
                    # Blank lines carry no values; don't record empty fields.
                    continue
                for name, field in zip(self.names, line.split("\t")):
                    self.numbers[name].append(self._to_number(field))

    @staticmethod
    def _to_number(field):
        """Convert a logged field to float, keeping it as text if not numeric."""
        try:
            return float(field)
        except ValueError:
            # Keep unparseable text as-is rather than failing to load the log.
            return field

    def set_names(self, names):
        """Set the column names and write them as the log file's header line.

        When resuming from a file, the header and history already loaded from
        it are kept and nothing is written, so the file never gets a second
        header line in the middle of its data.
        """
        if self.resume and self.names:
            return
        self.names = names
        self.numbers = {name: [] for name in names}
        if self.file is not None:
            self.file.write("".join(name + "\t" for name in names) + "\n")
            self.file.flush()

    def append(self, numbers):
        """Append one row of values, one value per name in ``self.names``.

        Raises:
            AssertionError: if ``numbers`` and ``self.names`` differ in length.
        """
        # Explicit check instead of ``assert`` so it still runs under
        # ``python -O``; AssertionError is kept for existing callers.
        if len(self.names) != len(numbers):
            raise AssertionError("Numbers do not match names")
        # Format the whole row first so a bad value can't leave a partial row.
        row = "".join("{0:.6f}\t".format(num) for num in numbers) + "\n"
        if self.file is not None:
            self.file.write(row)
            self.file.flush()
        for name, num in zip(self.names, numbers):
            self.numbers[name].append(num)

    def plot(self, names=None):
        """Plot the named columns (default: all) against their row index."""
        names = self.names if names is None else names
        for name in names:
            # dtype=float so values are always plotted numerically.
            values = np.asarray(self.numbers[name], dtype=float)
            plt.plot(np.arange(len(values)), values)
        plt.legend([self.title + "(" + name + ")" for name in names])
        plt.grid(True)

    def close(self):
        """Close the underlying log file, if one was opened."""
        if self.file is not None:
            self.file.close()
|
|
|
|
class LoggerMonitor(object):
    """Load and visualize multiple logs."""

    def __init__(self, paths):
        """Load every log listed in ``paths``.

        Args:
            paths: Dictionary of ``{title: filepath}`` pairs; each title is
                used as the legend prefix for that log's curves.
        """
        self.loggers = []
        for title, path in paths.items():
            logger = Logger(path, title=title, resume=True)
            # Logger(resume=True) keeps the file open for appending; the
            # monitor only plots the loaded values, so release the handle
            # instead of leaking one per log.
            logger.close()
            self.loggers.append(logger)

    def plot(self, names=None):
        """Overlay the named columns (default: all) of every loaded log.

        A new figure is created; the curves go in the left cell of a 1x2
        subplot grid and the legend is anchored to the right of those axes.
        """
        plt.figure()
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text += plot_overlap(logger, names)
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)
        plt.grid(True)
|
|
|
|
if __name__ == "__main__":
    # Overlay one column from several training logs and save the figure.
    # Usage: python <this file> [TITLE=PATH ...]
    # Without arguments the example log paths below are used; they are
    # machine-specific, so pass TITLE=PATH pairs to plot your own logs.
    paths = {}
    for arg in sys.argv[1:]:
        title, sep, path = arg.partition("=")
        if not sep:
            sys.exit("expected TITLE=PATH, got {0!r}".format(arg))
        paths[title] = path
    if not paths:
        paths = {
            "resadvnet20": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt",
            "resadvnet32": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt",
            "resadvnet44": "/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt",
        }

    field = ["Valid Acc."]

    monitor = LoggerMonitor(paths)
    monitor.plot(names=field)
    savefig("test.eps")
|
|