| |
|
|
| """ |
| Misc download and visualization helper functions and class wrappers. |
| """ |
|
|
| import sys |
| import time |
| import torch |
| from visdom import Visdom |
|
|
|
|
| def reporthook(count, block_size, total_size): |
| global start_time |
| if count == 0: |
| start_time = time.time() |
| return |
| duration = time.time() - start_time |
| progress_size = int(count * block_size) |
| speed = int(progress_size / (1024 * duration)) |
| percent = min(int(count * block_size * 100 / total_size), 100) |
| sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" % |
| (percent, progress_size / (1024 * 1024), speed, duration)) |
| sys.stdout.flush() |
|
|
|
|
| class VisdomWrapper(Visdom): |
| def __init__(self, *args, env=None, **kwargs): |
| Visdom.__init__(self, *args, **kwargs) |
| self.env = env |
| self.plots = {} |
|
|
| def init_line_plot(self, name, |
| X=torch.zeros((1,)).cpu(), |
| Y=torch.zeros((1,)).cpu(), **opts): |
| self.plots[name] = self.line(X=X, Y=Y, env=self.env, opts=opts) |
|
|
| def plot_line(self, name, **kwargs): |
| self.line(win=self.plots[name], env=self.env, **kwargs) |
|
|