Upload folder using huggingface_hub
Browse files- 2026-04-09-200000-pergame-specialized-v6/__pycache__/predict.cpython-311.pyc +0 -0
- 2026-04-09-200000-pergame-specialized-v6/model_config.json +18 -0
- 2026-04-09-200000-pergame-specialized-v6/pole_position_model.pt +3 -0
- 2026-04-09-200000-pergame-specialized-v6/pong_model.pt +3 -0
- 2026-04-09-200000-pergame-specialized-v6/predict.py +199 -0
- 2026-04-09-200000-pergame-specialized-v6/sonic_model.pt +3 -0
- 2026-04-09-200000-pergame-specialized-v6/train.log +227 -0
2026-04-09-200000-pergame-specialized-v6/__pycache__/predict.cpython-311.pyc
ADDED
|
Binary file (13.9 kB). View file
|
|
|
2026-04-09-200000-pergame-specialized-v6/model_config.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"channels": [
|
| 3 |
+
24,
|
| 4 |
+
48,
|
| 5 |
+
96
|
| 6 |
+
],
|
| 7 |
+
"context_len": 8,
|
| 8 |
+
"games": [
|
| 9 |
+
"pong",
|
| 10 |
+
"sonic",
|
| 11 |
+
"pole_position"
|
| 12 |
+
],
|
| 13 |
+
"param_counts": {
|
| 14 |
+
"pong": 2087515,
|
| 15 |
+
"sonic": 2087515,
|
| 16 |
+
"pole_position": 2087515
|
| 17 |
+
}
|
| 18 |
+
}
|
2026-04-09-200000-pergame-specialized-v6/pole_position_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:24b023b30c48afdde451c234b1ac8cd7ed69a6fbd01e5c5542cd61fa95c4fcc0
|
| 3 |
+
size 4227186
|
2026-04-09-200000-pergame-specialized-v6/pong_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ab63f17574963c5ed92407a1069785f86e8fed8f6965f1861337e4854c11ed7f
|
| 3 |
+
size 4225638
|
2026-04-09-200000-pergame-specialized-v6/predict.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Predict interface for per-game specialized models."""
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ResBlock(nn.Module):
    """Residual block: two 3x3 conv + GroupNorm stages around a skip connection.

    Both convolutions map ``channels -> channels`` with padding 1, so the
    output has the same shape as the input.

    Args:
        channels: Number of input (and output) feature channels.
        groups: Number of groups for both GroupNorm layers.
    """

    def __init__(self, channels, groups=8):
        super().__init__()
        # Attribute names double as state_dict keys for the saved weights,
        # so they must stay exactly as they are.
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.gn1 = nn.GroupNorm(groups, channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.gn2 = nn.GroupNorm(groups, channels)

    def forward(self, x):
        """Return ``silu(gn2(conv2(silu(gn1(conv1(x))))) + x)``."""
        residual = x
        x = F.silu(self.gn1(self.conv1(x)))
        # No activation before the addition; SiLU is applied to the sum.
        x = self.gn2(self.conv2(x))
        return F.silu(x + residual)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class SEBlock(nn.Module):
    """Squeeze-and-excitation style channel gate.

    Averages each channel over its spatial dimensions, passes the result
    through a two-layer bottleneck MLP and rescales the input channels by
    the resulting sigmoid gate.

    Args:
        channels: Number of feature channels of the input.
        reduction: Divisor for the bottleneck width; the hidden width is
            ``max(channels // reduction, 4)``.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        # Floor of 4 keeps the bottleneck usable for small channel counts.
        mid = max(channels // reduction, 4)
        self.fc1 = nn.Linear(channels, mid)
        self.fc2 = nn.Linear(mid, channels)

    def forward(self, x):
        """Scale a 4-D input ``(b, c, h, w)`` by a per-channel gate in (0, 1)."""
        b, c, _, _ = x.shape
        y = x.mean(dim=[2, 3])  # squeeze: (b, c)
        y = F.silu(self.fc1(y))
        y = torch.sigmoid(self.fc2(y))
        # Broadcast the (b, c) gate over the spatial dimensions.
        return x * y.view(b, c, 1, 1)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class EncoderLevel(nn.Module):
    """One encoder stage: channel projection, two ResBlocks, SE gate, 2x max-pool.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of channels produced by this stage.
        groups: Upper bound on the GroupNorm group count; the effective
            count is ``min(groups, out_ch)``.
    """

    def __init__(self, in_ch, out_ch, groups=8):
        super().__init__()
        # Never ask for more groups than there are channels.
        g = min(groups, out_ch)
        self.proj = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.GroupNorm(g, out_ch),
            nn.SiLU(inplace=True),
        )
        self.res1 = ResBlock(out_ch, g)
        self.res2 = ResBlock(out_ch, g)
        self.se = SEBlock(out_ch)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return ``(features, pooled)``.

        ``features`` is the full-resolution output of this stage (used by the
        caller as the skip connection); ``pooled`` is the same tensor after
        2x2 max-pooling and feeds the next stage.
        """
        x = self.proj(x)
        x = self.res1(x)
        x = self.res2(x)
        x = self.se(x)
        return x, self.pool(x)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class DecoderLevel(nn.Module):
    """One decoder stage: 2x transposed-conv upsample, skip fusion, refinement.

    Args:
        in_ch: Channels of the tensor coming from the deeper level.
        skip_ch: Channels of the skip tensor from the matching encoder level.
        out_ch: Channels produced by this stage.
        groups: Upper bound on the GroupNorm group count; the effective
            count is ``min(groups, out_ch)``.
    """

    def __init__(self, in_ch, skip_ch, out_ch, groups=8):
        super().__init__()
        g = min(groups, out_ch)
        # Kernel 2 / stride 2 doubles the spatial resolution.
        self.upconv = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
        self.proj = nn.Sequential(
            nn.Conv2d(out_ch + skip_ch, out_ch, 3, padding=1),
            nn.GroupNorm(g, out_ch),
            nn.SiLU(inplace=True),
        )
        self.res1 = ResBlock(out_ch, g)
        self.res2 = ResBlock(out_ch, g)
        self.se = SEBlock(out_ch)

    def forward(self, x, skip):
        """Upsample ``x``, concatenate ``skip`` on the channel axis and refine.

        ``skip`` must have the same spatial size as the upsampled ``x`` for
        the concatenation to succeed.
        """
        x = self.upconv(x)
        x = torch.cat([x, skip], dim=1)
        x = self.proj(x)
        x = self.res1(x)
        x = self.res2(x)
        x = self.se(x)
        return x
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class PerGameUNet(nn.Module):
    """U-Net built from EncoderLevel / DecoderLevel stages with an SE bottleneck.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        channels: Channel width of each encoder level, shallow to deep. The
            decoder mirrors it in reverse and the bottleneck uses twice the
            deepest width.

    The output has the same spatial size as the input. Each level halves the
    resolution on the way down and doubles it on the way up, so the input
    height and width need to be divisible by ``2 ** len(channels)`` for the
    skip concatenations to line up.
    """

    # The default is a tuple rather than a list so that no mutable object is
    # shared between calls; any indexable sequence of ints is accepted.
    def __init__(self, in_channels=24, out_channels=3, channels=(24, 48, 96)):
        super().__init__()
        self.encoders = nn.ModuleList()
        prev_ch = in_channels
        for ch in channels:
            self.encoders.append(EncoderLevel(prev_ch, ch))
            prev_ch = ch

        bottleneck_ch = channels[-1] * 2
        g = min(8, bottleneck_ch)
        self.bottleneck_proj = nn.Sequential(
            nn.Conv2d(channels[-1], bottleneck_ch, 3, padding=1),
            nn.GroupNorm(g, bottleneck_ch),
            nn.SiLU(inplace=True),
        )
        self.bottleneck_res = ResBlock(bottleneck_ch, g)
        self.bottleneck_se = SEBlock(bottleneck_ch)

        self.decoders = nn.ModuleList()
        rev_channels = list(reversed(channels))
        prev_ch = bottleneck_ch
        for ch in rev_channels:
            # Skip and output widths both equal the mirrored encoder width.
            self.decoders.append(DecoderLevel(prev_ch, ch, ch))
            prev_ch = ch

        # 1x1 convolution maps the shallowest decoder width to the output.
        self.out_conv = nn.Conv2d(channels[0], out_channels, 1)

    def forward(self, x):
        """Run encoder, bottleneck and decoder; return the ``out_conv`` result."""
        skips = []
        for enc in self.encoders:
            skip, x = enc(x)
            skips.append(skip)
        x = self.bottleneck_proj(x)
        x = self.bottleneck_res(x)
        x = self.bottleneck_se(x)
        # Deepest skip pairs with the first decoder level.
        for dec, skip in zip(self.decoders, reversed(skips)):
            x = dec(x, skip)
        return self.out_conv(x)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _load_single_model(model_dir, game, channels, device):
    """Build a PerGameUNet and load ``<model_dir>/<game>_model.pt`` into it.

    Args:
        model_dir: Directory containing the per-game weight files.
        game: Game name used to form the weight file name.
        channels: Encoder channel widths passed to ``PerGameUNet``.
        device: Torch device the model is moved to.

    Returns:
        The model on ``device`` in eval mode.
    """
    weights_path = os.path.join(model_dir, f"{game}_model.pt")
    model = PerGameUNet(in_channels=24, out_channels=3, channels=channels)
    # weights_only=True avoids unpickling arbitrary objects from the file.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    # Cast every tensor to float32 before loading
    # (presumably the checkpoint is stored in reduced precision - TODO confirm).
    state_dict = {k: v.float() for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model = model.to(device).eval()
    return model
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _detect_game(frame):
    """Detect which game a single frame (64, 64, 3) uint8 belongs to.

    Classification is by mean pixel value only:
    below 30 is "pong", above 80 is "pole_position", anything else
    (including exactly 30 or 80) is "sonic".
    """
    mean_val = frame.astype(np.float32).mean()
    # Pong: very dark background (mean ~2)
    if mean_val < 30:
        return "pong"
    # Pole Position: bright frames (mean ~113), dominated by sky blue
    if mean_val > 80:
        return "pole_position"
    # Sonic: moderate brightness (mean ~54)
    return "sonic"
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def load_model(model_dir: str):
    """Load every per-game model listed in ``<model_dir>/model_config.json``.

    The config must provide ``"channels"``, ``"context_len"`` and ``"games"``.
    CUDA is used when available, otherwise the CPU.

    Returns:
        A dict with keys ``"models"`` (game name -> model), ``"device"``,
        ``"context_len"`` and ``"detected_game"`` (initialised to ``None``).
    """
    config_path = os.path.join(model_dir, "model_config.json")
    # JSON is UTF-8 by specification; do not rely on the platform default.
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    channels = config["channels"]
    context_len = config["context_len"]

    models = {}
    for game in config["games"]:
        models[game] = _load_single_model(model_dir, game, channels, device)

    return {
        "models": models,
        "device": device,
        "context_len": context_len,
        "detected_game": None,
    }
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def predict_next_frame(model_dict, context_frames: np.ndarray) -> np.ndarray:
    """Predict the frame that follows ``context_frames``.

    Args:
        model_dict: Dict as returned by ``load_model`` (uses the ``"models"``,
            ``"device"`` and ``"context_len"`` entries).
        context_frames: Non-empty array of channel-last frames with values in
            0..255, oldest first (presumably ``(T, 64, 64, 3)`` uint8 -
            confirm against the caller).

    Returns:
        The predicted frame as a channel-last uint8 array.
    """
    device = model_dict["device"]
    context_len = model_dict["context_len"]

    # Detect game from first context frame each call
    game = _detect_game(context_frames[0])
    model = model_dict["models"][game]

    if len(context_frames) >= context_len:
        # Keep only the most recent context_len frames.
        frames = context_frames[-context_len:]
    else:
        # Too few frames: left-pad by repeating the earliest one.
        pad_count = context_len - len(context_frames)
        padding = np.repeat(context_frames[:1], pad_count, axis=0)
        frames = np.concatenate([padding, context_frames], axis=0)

    frames_f = frames.astype(np.float32) / 255.0
    frames_t = torch.from_numpy(frames_f).permute(0, 3, 1, 2)
    # Stack all frames along the channel axis. The spatial size is taken from
    # the input instead of being hard-coded (identical for 64x64 frames).
    height, width = frames_t.shape[-2:]
    input_t = frames_t.reshape(1, -1, height, width).to(device)
    last_frame = frames_t[-1:].to(device)

    with torch.no_grad():
        # The network output is added to the most recent frame.
        residual = model(input_t)
        pred = torch.clamp(last_frame + residual, 0, 1)

    pred_np = pred[0].cpu().numpy()
    pred_np = np.transpose(pred_np, (1, 2, 0))
    pred_np = (pred_np * 255).clip(0, 255).astype(np.uint8)
    return pred_np
|
2026-04-09-200000-pergame-specialized-v6/sonic_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cb01d9400efd7d18efdd740fc802a6da3a5bc35c86c77d77e301d49de95f8da0
|
| 3 |
+
size 4225810
|
2026-04-09-200000-pergame-specialized-v6/train.log
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Per-game specialized training with channels [24, 48, 96]
|
| 2 |
+
Device: cuda
|
| 3 |
+
|
| 4 |
+
=== Training pong model ===
|
| 5 |
+
Params: 2,087,515, Train samples: 8194, Val samples: 964
|
| 6 |
+
[pong] Epoch 1/50 P1 | steps=4 | train=0.181056 | val=0.387214 | lr=2.99e-04 | 34s
|
| 7 |
+
-> Saved best (val=0.387214)
|
| 8 |
+
[pong] Epoch 2/50 P1 | steps=4 | train=0.139606 | val=0.348065 | lr=2.95e-04 | 34s
|
| 9 |
+
-> Saved best (val=0.348065)
|
| 10 |
+
[pong] Epoch 3/50 P1 | steps=4 | train=0.125677 | val=0.346101 | lr=2.90e-04 | 33s
|
| 11 |
+
-> Saved best (val=0.346101)
|
| 12 |
+
[pong] Epoch 4/50 P1 | steps=4 | train=0.113426 | val=0.332581 | lr=2.82e-04 | 33s
|
| 13 |
+
-> Saved best (val=0.332581)
|
| 14 |
+
[pong] Epoch 5/50 P1 | steps=4 | train=0.096369 | val=0.287765 | lr=2.72e-04 | 33s
|
| 15 |
+
-> Saved best (val=0.287765)
|
| 16 |
+
[pong] Epoch 6/50 P1 | steps=4 | train=0.079409 | val=0.273727 | lr=2.61e-04 | 34s
|
| 17 |
+
-> Saved best (val=0.273727)
|
| 18 |
+
[pong] Epoch 7/50 P1 | steps=4 | train=0.064860 | val=0.250092 | lr=2.47e-04 | 30s
|
| 19 |
+
-> Saved best (val=0.250092)
|
| 20 |
+
[pong] Epoch 8/50 P1 | steps=4 | train=0.054571 | val=0.247202 | lr=2.33e-04 | 33s
|
| 21 |
+
-> Saved best (val=0.247202)
|
| 22 |
+
[pong] Epoch 9/50 P1 | steps=4 | train=0.046480 | val=0.231186 | lr=2.17e-04 | 33s
|
| 23 |
+
-> Saved best (val=0.231186)
|
| 24 |
+
[pong] Epoch 10/50 P1 | steps=4 | train=0.041027 | val=0.200864 | lr=2.00e-04 | 33s
|
| 25 |
+
-> Saved best (val=0.200864)
|
| 26 |
+
[pong] Epoch 11/50 P1 | steps=4 | train=0.036763 | val=0.198458 | lr=1.82e-04 | 33s
|
| 27 |
+
-> Saved best (val=0.198458)
|
| 28 |
+
[pong] Epoch 12/50 P1 | steps=4 | train=0.033568 | val=0.190238 | lr=1.64e-04 | 33s
|
| 29 |
+
-> Saved best (val=0.190238)
|
| 30 |
+
[pong] Epoch 13/50 P1 | steps=4 | train=0.030780 | val=0.193181 | lr=1.46e-04 | 33s
|
| 31 |
+
[pong] Epoch 14/50 P1 | steps=4 | train=0.028308 | val=0.176619 | lr=1.28e-04 | 33s
|
| 32 |
+
-> Saved best (val=0.176619)
|
| 33 |
+
[pong] Epoch 15/50 P1 | steps=4 | train=0.027104 | val=0.175077 | lr=1.10e-04 | 32s
|
| 34 |
+
-> Saved best (val=0.175077)
|
| 35 |
+
[pong] Epoch 16/50 P1 | steps=4 | train=0.025855 | val=0.170153 | lr=9.33e-05 | 33s
|
| 36 |
+
-> Saved best (val=0.170153)
|
| 37 |
+
[pong] Epoch 17/50 P1 | steps=4 | train=0.024338 | val=0.168914 | lr=7.73e-05 | 33s
|
| 38 |
+
-> Saved best (val=0.168914)
|
| 39 |
+
[pong] Epoch 18/50 P1 | steps=4 | train=0.023429 | val=0.163598 | lr=6.26e-05 | 33s
|
| 40 |
+
-> Saved best (val=0.163598)
|
| 41 |
+
[pong] Epoch 19/50 P1 | steps=4 | train=0.022348 | val=0.166410 | lr=4.93e-05 | 33s
|
| 42 |
+
[pong] Epoch 20/50 P1 | steps=4 | train=0.021361 | val=0.161002 | lr=3.77e-05 | 33s
|
| 43 |
+
-> Saved best (val=0.161002)
|
| 44 |
+
[pong] Epoch 21/50 P1 | steps=4 | train=0.020563 | val=0.162222 | lr=2.79e-05 | 34s
|
| 45 |
+
[pong] Epoch 22/50 P1 | steps=4 | train=0.019955 | val=0.160925 | lr=2.02e-05 | 33s
|
| 46 |
+
-> Saved best (val=0.160925)
|
| 47 |
+
[pong] Epoch 23/50 P1 | steps=4 | train=0.019431 | val=0.159964 | lr=1.46e-05 | 33s
|
| 48 |
+
-> Saved best (val=0.159964)
|
| 49 |
+
[pong] Epoch 24/50 P1 | steps=4 | train=0.019006 | val=0.159872 | lr=1.11e-05 | 32s
|
| 50 |
+
-> Saved best (val=0.159872)
|
| 51 |
+
[pong] Epoch 25/50 P1 | steps=4 | train=0.018644 | val=0.159735 | lr=1.00e-05 | 33s
|
| 52 |
+
-> Saved best (val=0.159735)
|
| 53 |
+
[pong] Epoch 26/50 P2 | steps=8 | train=0.030633 | val=0.158183 | lr=9.96e-05 | 93s
|
| 54 |
+
-> Saved best (val=0.158183)
|
| 55 |
+
[pong] Epoch 27/50 P2 | steps=8 | train=0.026713 | val=0.156725 | lr=9.84e-05 | 92s
|
| 56 |
+
-> Saved best (val=0.156725)
|
| 57 |
+
[pong] Epoch 28/50 P2 | steps=8 | train=0.024779 | val=0.148382 | lr=9.65e-05 | 92s
|
| 58 |
+
-> Saved best (val=0.148382)
|
| 59 |
+
[pong] Epoch 29/50 P2 | steps=8 | train=0.023135 | val=0.148092 | lr=9.39e-05 | 95s
|
| 60 |
+
-> Saved best (val=0.148092)
|
| 61 |
+
[pong] Epoch 30/50 P2 | steps=8 | train=0.022060 | val=0.147614 | lr=9.05e-05 | 93s
|
| 62 |
+
-> Saved best (val=0.147614)
|
| 63 |
+
[pong] Epoch 31/50 P2 | steps=8 | train=0.020363 | val=0.147755 | lr=8.66e-05 | 95s
|
| 64 |
+
[pong] Epoch 32/50 P2 | steps=8 | train=0.019286 | val=0.144701 | lr=8.21e-05 | 94s
|
| 65 |
+
-> Saved best (val=0.144701)
|
| 66 |
+
[pong] Epoch 33/50 P2 | steps=8 | train=0.018276 | val=0.145386 | lr=7.70e-05 | 94s
|
| 67 |
+
[pong] Epoch 34/50 P2 | steps=8 | train=0.017180 | val=0.145127 | lr=7.16e-05 | 94s
|
| 68 |
+
[pong] Epoch 35/50 P2 | steps=8 | train=0.016489 | val=0.146141 | lr=6.58e-05 | 94s
|
| 69 |
+
[pong] Epoch 36/50 P2 | steps=8 | train=0.015510 | val=0.145791 | lr=5.98e-05 | 97s
|
| 70 |
+
[pong] Epoch 37/50 P2 | steps=8 | train=0.014867 | val=0.142461 | lr=5.36e-05 | 92s
|
| 71 |
+
-> Saved best (val=0.142461)
|
| 72 |
+
[pong] Epoch 38/50 P2 | steps=8 | train=0.014174 | val=0.143379 | lr=4.74e-05 | 93s
|
| 73 |
+
[pong] Epoch 39/50 P2 | steps=8 | train=0.013558 | val=0.140337 | lr=4.12e-05 | 95s
|
| 74 |
+
-> Saved best (val=0.140337)
|
| 75 |
+
[pong] Epoch 40/50 P2 | steps=8 | train=0.012977 | val=0.141929 | lr=3.52e-05 | 92s
|
| 76 |
+
[pong] Epoch 41/50 P2 | steps=8 | train=0.012432 | val=0.139980 | lr=2.94e-05 | 94s
|
| 77 |
+
-> Saved best (val=0.139980)
|
| 78 |
+
[pong] Epoch 42/50 P2 | steps=8 | train=0.012051 | val=0.140125 | lr=2.40e-05 | 92s
|
| 79 |
+
[pong] Epoch 43/50 P2 | steps=8 | train=0.011630 | val=0.138873 | lr=1.89e-05 | 93s
|
| 80 |
+
-> Saved best (val=0.138873)
|
| 81 |
+
[pong] Epoch 44/50 P2 | steps=8 | train=0.011273 | val=0.137333 | lr=1.44e-05 | 95s
|
| 82 |
+
-> Saved best (val=0.137333)
|
| 83 |
+
[pong] Epoch 45/50 P2 | steps=8 | train=0.010989 | val=0.139273 | lr=1.05e-05 | 95s
|
| 84 |
+
[pong] Epoch 46/50 P2 | steps=8 | train=0.010751 | val=0.138294 | lr=7.12e-06 | 94s
|
| 85 |
+
[pong] Epoch 47/50 P2 | steps=8 | train=0.010587 | val=0.137346 | lr=4.48e-06 | 93s
|
| 86 |
+
[pong] Epoch 48/50 P2 | steps=8 | train=0.010506 | val=0.137741 | lr=2.56e-06 | 95s
|
| 87 |
+
[pong] Epoch 49/50 P2 | steps=8 | train=0.010427 | val=0.136487 | lr=1.39e-06 | 93s
|
| 88 |
+
-> Saved best (val=0.136487)
|
| 89 |
+
[pong] Epoch 50/50 P2 | steps=8 | train=0.010349 | val=0.136725 | lr=1.00e-06 | 94s
|
| 90 |
+
|
| 91 |
+
=== Training sonic model ===
|
| 92 |
+
Params: 2,087,515, Train samples: 30848, Val samples: 3856
|
| 93 |
+
[sonic] Epoch 1/50 P1 | steps=4 | train=0.125317 | val=0.215988 | lr=2.99e-04 | 123s
|
| 94 |
+
-> Saved best (val=0.215988)
|
| 95 |
+
[sonic] Epoch 2/50 P1 | steps=4 | train=0.108658 | val=0.203349 | lr=2.95e-04 | 122s
|
| 96 |
+
-> Saved best (val=0.203349)
|
| 97 |
+
[sonic] Epoch 3/50 P1 | steps=4 | train=0.102061 | val=0.193975 | lr=2.90e-04 | 122s
|
| 98 |
+
-> Saved best (val=0.193975)
|
| 99 |
+
[sonic] Epoch 4/50 P1 | steps=4 | train=0.097724 | val=0.189329 | lr=2.82e-04 | 122s
|
| 100 |
+
-> Saved best (val=0.189329)
|
| 101 |
+
[sonic] Epoch 5/50 P1 | steps=4 | train=0.094953 | val=0.186686 | lr=2.72e-04 | 121s
|
| 102 |
+
-> Saved best (val=0.186686)
|
| 103 |
+
[sonic] Epoch 6/50 P1 | steps=4 | train=0.091985 | val=0.186693 | lr=2.61e-04 | 121s
|
| 104 |
+
[sonic] Epoch 7/50 P1 | steps=4 | train=0.089359 | val=0.179815 | lr=2.47e-04 | 121s
|
| 105 |
+
-> Saved best (val=0.179815)
|
| 106 |
+
[sonic] Epoch 8/50 P1 | steps=4 | train=0.087181 | val=0.179115 | lr=2.33e-04 | 121s
|
| 107 |
+
-> Saved best (val=0.179115)
|
| 108 |
+
[sonic] Epoch 9/50 P1 | steps=4 | train=0.085216 | val=0.178042 | lr=2.17e-04 | 122s
|
| 109 |
+
-> Saved best (val=0.178042)
|
| 110 |
+
[sonic] Epoch 10/50 P1 | steps=4 | train=0.083539 | val=0.181086 | lr=2.00e-04 | 121s
|
| 111 |
+
[sonic] Epoch 11/50 P1 | steps=4 | train=0.081816 | val=0.173978 | lr=1.82e-04 | 121s
|
| 112 |
+
-> Saved best (val=0.173978)
|
| 113 |
+
[sonic] Epoch 12/50 P1 | steps=4 | train=0.080440 | val=0.170419 | lr=1.64e-04 | 121s
|
| 114 |
+
-> Saved best (val=0.170419)
|
| 115 |
+
[sonic] Epoch 13/50 P1 | steps=4 | train=0.078925 | val=0.174053 | lr=1.46e-04 | 122s
|
| 116 |
+
[sonic] Epoch 14/50 P1 | steps=4 | train=0.077366 | val=0.173196 | lr=1.28e-04 | 124s
|
| 117 |
+
[sonic] Epoch 15/50 P1 | steps=4 | train=0.075744 | val=0.170515 | lr=1.10e-04 | 122s
|
| 118 |
+
[sonic] Epoch 16/50 P1 | steps=4 | train=0.074594 | val=0.172440 | lr=9.33e-05 | 123s
|
| 119 |
+
[sonic] Epoch 17/50 P1 | steps=4 | train=0.073175 | val=0.171735 | lr=7.73e-05 | 123s
|
| 120 |
+
[sonic] Epoch 18/50 P1 | steps=4 | train=0.071953 | val=0.170494 | lr=6.26e-05 | 122s
|
| 121 |
+
[sonic] Epoch 19/50 P1 | steps=4 | train=0.070892 | val=0.169376 | lr=4.93e-05 | 122s
|
| 122 |
+
-> Saved best (val=0.169376)
|
| 123 |
+
[sonic] Epoch 20/50 P1 | steps=4 | train=0.069842 | val=0.171050 | lr=3.77e-05 | 122s
|
| 124 |
+
[sonic] Epoch 21/50 P1 | steps=4 | train=0.069004 | val=0.175268 | lr=2.79e-05 | 123s
|
| 125 |
+
[sonic] Epoch 22/50 P1 | steps=4 | train=0.068189 | val=0.174244 | lr=2.02e-05 | 122s
|
| 126 |
+
[sonic] Epoch 23/50 P1 | steps=4 | train=0.067582 | val=0.173629 | lr=1.46e-05 | 121s
|
| 127 |
+
[sonic] Epoch 24/50 P1 | steps=4 | train=0.067093 | val=0.173393 | lr=1.11e-05 | 121s
|
| 128 |
+
[sonic] Epoch 25/50 P1 | steps=4 | train=0.066864 | val=0.175190 | lr=1.00e-05 | 121s
|
| 129 |
+
[sonic] Epoch 26/50 P2 | steps=8 | train=0.099375 | val=0.172702 | lr=9.96e-05 | 349s
|
| 130 |
+
[sonic] Epoch 27/50 P2 | steps=8 | train=0.098544 | val=0.169787 | lr=9.84e-05 | 348s
|
| 131 |
+
[sonic] Epoch 28/50 P2 | steps=8 | train=0.097575 | val=0.174494 | lr=9.65e-05 | 347s
|
| 132 |
+
[sonic] Epoch 29/50 P2 | steps=8 | train=0.096509 | val=0.173821 | lr=9.39e-05 | 350s
|
| 133 |
+
[sonic] Epoch 30/50 P2 | steps=8 | train=0.095506 | val=0.181292 | lr=9.05e-05 | 347s
|
| 134 |
+
[sonic] Epoch 31/50 P2 | steps=8 | train=0.094344 | val=0.181127 | lr=8.66e-05 | 374s
|
| 135 |
+
[sonic] Epoch 32/50 P2 | steps=8 | train=0.092923 | val=0.183132 | lr=8.21e-05 | 380s
|
| 136 |
+
[sonic] Epoch 33/50 P2 | steps=8 | train=0.092020 | val=0.172820 | lr=7.70e-05 | 382s
|
| 137 |
+
[sonic] Epoch 34/50 P2 | steps=8 | train=0.090893 | val=0.178188 | lr=7.16e-05 | 369s
|
| 138 |
+
[sonic] Epoch 35/50 P2 | steps=8 | train=0.089567 | val=0.178557 | lr=6.58e-05 | 376s
|
| 139 |
+
[sonic] Epoch 36/50 P2 | steps=8 | train=0.088448 | val=0.180559 | lr=5.98e-05 | 375s
|
| 140 |
+
[sonic] Epoch 37/50 P2 | steps=8 | train=0.087275 | val=0.183680 | lr=5.36e-05 | 376s
|
| 141 |
+
[sonic] Epoch 38/50 P2 | steps=8 | train=0.086070 | val=0.180474 | lr=4.74e-05 | 380s
|
| 142 |
+
[sonic] Epoch 39/50 P2 | steps=8 | train=0.084930 | val=0.180976 | lr=4.12e-05 | 381s
|
| 143 |
+
[sonic] Epoch 40/50 P2 | steps=8 | train=0.083936 | val=0.188466 | lr=3.52e-05 | 382s
|
| 144 |
+
[sonic] Epoch 41/50 P2 | steps=8 | train=0.082753 | val=0.183426 | lr=2.94e-05 | 378s
|
| 145 |
+
[sonic] Epoch 42/50 P2 | steps=8 | train=0.081880 | val=0.184469 | lr=2.40e-05 | 379s
|
| 146 |
+
[sonic] Epoch 43/50 P2 | steps=8 | train=0.080913 | val=0.187971 | lr=1.89e-05 | 386s
|
| 147 |
+
[sonic] Epoch 44/50 P2 | steps=8 | train=0.080052 | val=0.184644 | lr=1.44e-05 | 364s
|
| 148 |
+
[sonic] Epoch 45/50 P2 | steps=8 | train=0.079292 | val=0.185277 | lr=1.05e-05 | 345s
|
| 149 |
+
[sonic] Epoch 46/50 P2 | steps=8 | train=0.078618 | val=0.190683 | lr=7.12e-06 | 349s
|
| 150 |
+
[sonic] Epoch 47/50 P2 | steps=8 | train=0.078161 | val=0.187349 | lr=4.48e-06 | 346s
|
| 151 |
+
[sonic] Epoch 48/50 P2 | steps=8 | train=0.077754 | val=0.186002 | lr=2.56e-06 | 351s
|
| 152 |
+
[sonic] Epoch 49/50 P2 | steps=8 | train=0.077498 | val=0.187672 | lr=1.39e-06 | 346s
|
| 153 |
+
[sonic] Epoch 50/50 P2 | steps=8 | train=0.077320 | val=0.187627 | lr=1.00e-06 | 347s
|
| 154 |
+
|
| 155 |
+
=== Training pole_position model ===
|
| 156 |
+
Params: 2,087,515, Train samples: 4097, Val samples: 482
|
| 157 |
+
[pole_position] Epoch 1/50 P1 | steps=4 | train=0.089510 | val=0.099064 | lr=2.99e-04 | 16s
|
| 158 |
+
-> Saved best (val=0.099064)
|
| 159 |
+
[pole_position] Epoch 2/50 P1 | steps=4 | train=0.058853 | val=0.087623 | lr=2.95e-04 | 17s
|
| 160 |
+
-> Saved best (val=0.087623)
|
| 161 |
+
[pole_position] Epoch 3/50 P1 | steps=4 | train=0.054709 | val=0.082491 | lr=2.90e-04 | 17s
|
| 162 |
+
-> Saved best (val=0.082491)
|
| 163 |
+
[pole_position] Epoch 4/50 P1 | steps=4 | train=0.051302 | val=0.078103 | lr=2.82e-04 | 17s
|
| 164 |
+
-> Saved best (val=0.078103)
|
| 165 |
+
[pole_position] Epoch 5/50 P1 | steps=4 | train=0.048063 | val=0.075398 | lr=2.72e-04 | 17s
|
| 166 |
+
-> Saved best (val=0.075398)
|
| 167 |
+
[pole_position] Epoch 6/50 P1 | steps=4 | train=0.045211 | val=0.073670 | lr=2.61e-04 | 17s
|
| 168 |
+
-> Saved best (val=0.073670)
|
| 169 |
+
[pole_position] Epoch 7/50 P1 | steps=4 | train=0.043285 | val=0.066788 | lr=2.47e-04 | 17s
|
| 170 |
+
-> Saved best (val=0.066788)
|
| 171 |
+
[pole_position] Epoch 8/50 P1 | steps=4 | train=0.041317 | val=0.065624 | lr=2.33e-04 | 17s
|
| 172 |
+
-> Saved best (val=0.065624)
|
| 173 |
+
[pole_position] Epoch 9/50 P1 | steps=4 | train=0.039757 | val=0.063329 | lr=2.17e-04 | 16s
|
| 174 |
+
-> Saved best (val=0.063329)
|
| 175 |
+
[pole_position] Epoch 10/50 P1 | steps=4 | train=0.038379 | val=0.064602 | lr=2.00e-04 | 17s
|
| 176 |
+
[pole_position] Epoch 11/50 P1 | steps=4 | train=0.037215 | val=0.063741 | lr=1.82e-04 | 17s
|
| 177 |
+
[pole_position] Epoch 12/50 P1 | steps=4 | train=0.036174 | val=0.059305 | lr=1.64e-04 | 16s
|
| 178 |
+
-> Saved best (val=0.059305)
|
| 179 |
+
[pole_position] Epoch 13/50 P1 | steps=4 | train=0.035457 | val=0.063606 | lr=1.46e-04 | 17s
|
| 180 |
+
[pole_position] Epoch 14/50 P1 | steps=4 | train=0.034620 | val=0.059135 | lr=1.28e-04 | 16s
|
| 181 |
+
-> Saved best (val=0.059135)
|
| 182 |
+
[pole_position] Epoch 15/50 P1 | steps=4 | train=0.034073 | val=0.057301 | lr=1.10e-04 | 15s
|
| 183 |
+
-> Saved best (val=0.057301)
|
| 184 |
+
[pole_position] Epoch 16/50 P1 | steps=4 | train=0.033230 | val=0.058048 | lr=9.33e-05 | 17s
|
| 185 |
+
[pole_position] Epoch 17/50 P1 | steps=4 | train=0.032583 | val=0.056697 | lr=7.73e-05 | 17s
|
| 186 |
+
-> Saved best (val=0.056697)
|
| 187 |
+
[pole_position] Epoch 18/50 P1 | steps=4 | train=0.032016 | val=0.055377 | lr=6.26e-05 | 17s
|
| 188 |
+
-> Saved best (val=0.055377)
|
| 189 |
+
[pole_position] Epoch 19/50 P1 | steps=4 | train=0.031563 | val=0.054685 | lr=4.93e-05 | 17s
|
| 190 |
+
-> Saved best (val=0.054685)
|
| 191 |
+
[pole_position] Epoch 20/50 P1 | steps=4 | train=0.031261 | val=0.055312 | lr=3.77e-05 | 18s
|
| 192 |
+
[pole_position] Epoch 21/50 P1 | steps=4 | train=0.030835 | val=0.054009 | lr=2.79e-05 | 17s
|
| 193 |
+
-> Saved best (val=0.054009)
|
| 194 |
+
[pole_position] Epoch 22/50 P1 | steps=4 | train=0.030547 | val=0.054350 | lr=2.02e-05 | 17s
|
| 195 |
+
[pole_position] Epoch 23/50 P1 | steps=4 | train=0.030317 | val=0.053623 | lr=1.46e-05 | 17s
|
| 196 |
+
-> Saved best (val=0.053623)
|
| 197 |
+
[pole_position] Epoch 24/50 P1 | steps=4 | train=0.030126 | val=0.053729 | lr=1.11e-05 | 17s
|
| 198 |
+
[pole_position] Epoch 25/50 P1 | steps=4 | train=0.029983 | val=0.053887 | lr=1.00e-05 | 17s
|
| 199 |
+
[pole_position] Epoch 26/50 P2 | steps=8 | train=0.044514 | val=0.056167 | lr=9.96e-05 | 50s
|
| 200 |
+
[pole_position] Epoch 27/50 P2 | steps=8 | train=0.043991 | val=0.055933 | lr=9.84e-05 | 51s
|
| 201 |
+
[pole_position] Epoch 28/50 P2 | steps=8 | train=0.043431 | val=0.055243 | lr=9.65e-05 | 51s
|
| 202 |
+
[pole_position] Epoch 29/50 P2 | steps=8 | train=0.042928 | val=0.055930 | lr=9.39e-05 | 50s
|
| 203 |
+
[pole_position] Epoch 30/50 P2 | steps=8 | train=0.042540 | val=0.055295 | lr=9.05e-05 | 50s
|
| 204 |
+
[pole_position] Epoch 31/50 P2 | steps=8 | train=0.041979 | val=0.053855 | lr=8.66e-05 | 50s
|
| 205 |
+
[pole_position] Epoch 32/50 P2 | steps=8 | train=0.041306 | val=0.054584 | lr=8.21e-05 | 51s
|
| 206 |
+
[pole_position] Epoch 33/50 P2 | steps=8 | train=0.040972 | val=0.055035 | lr=7.70e-05 | 51s
|
| 207 |
+
[pole_position] Epoch 34/50 P2 | steps=8 | train=0.040348 | val=0.056113 | lr=7.16e-05 | 52s
|
| 208 |
+
[pole_position] Epoch 35/50 P2 | steps=8 | train=0.039655 | val=0.053253 | lr=6.58e-05 | 51s
|
| 209 |
+
-> Saved best (val=0.053253)
|
| 210 |
+
[pole_position] Epoch 36/50 P2 | steps=8 | train=0.039092 | val=0.054627 | lr=5.98e-05 | 50s
|
| 211 |
+
[pole_position] Epoch 37/50 P2 | steps=8 | train=0.038342 | val=0.056780 | lr=5.36e-05 | 51s
|
| 212 |
+
[pole_position] Epoch 38/50 P2 | steps=8 | train=0.037674 | val=0.056801 | lr=4.74e-05 | 50s
|
| 213 |
+
[pole_position] Epoch 39/50 P2 | steps=8 | train=0.036766 | val=0.052438 | lr=4.12e-05 | 50s
|
| 214 |
+
-> Saved best (val=0.052438)
|
| 215 |
+
[pole_position] Epoch 40/50 P2 | steps=8 | train=0.036030 | val=0.056021 | lr=3.52e-05 | 50s
|
| 216 |
+
[pole_position] Epoch 41/50 P2 | steps=8 | train=0.035279 | val=0.057277 | lr=2.94e-05 | 48s
|
| 217 |
+
[pole_position] Epoch 42/50 P2 | steps=8 | train=0.034573 | val=0.054459 | lr=2.40e-05 | 47s
|
| 218 |
+
[pole_position] Epoch 43/50 P2 | steps=8 | train=0.033854 | val=0.054140 | lr=1.89e-05 | 47s
|
| 219 |
+
[pole_position] Epoch 44/50 P2 | steps=8 | train=0.033173 | val=0.053559 | lr=1.44e-05 | 47s
|
| 220 |
+
[pole_position] Epoch 45/50 P2 | steps=8 | train=0.032674 | val=0.053773 | lr=1.05e-05 | 48s
|
| 221 |
+
[pole_position] Epoch 46/50 P2 | steps=8 | train=0.032193 | val=0.054777 | lr=7.12e-06 | 46s
|
| 222 |
+
[pole_position] Epoch 47/50 P2 | steps=8 | train=0.031758 | val=0.054334 | lr=4.48e-06 | 47s
|
| 223 |
+
[pole_position] Epoch 48/50 P2 | steps=8 | train=0.031504 | val=0.054399 | lr=2.56e-06 | 47s
|
| 224 |
+
[pole_position] Epoch 49/50 P2 | steps=8 | train=0.031345 | val=0.053941 | lr=1.39e-06 | 46s
|
| 225 |
+
[pole_position] Epoch 50/50 P2 | steps=8 | train=0.031216 | val=0.054575 | lr=1.00e-06 | 47s
|
| 226 |
+
|
| 227 |
+
All games trained.
|