Update train.py
Browse files
train.py
CHANGED
|
@@ -59,7 +59,7 @@ print(f"using {device} device")
|
|
| 59 |
|
| 60 |
|
| 61 |
|
| 62 |
-
class
|
| 63 |
def __init__(
|
| 64 |
self,
|
| 65 |
image_size=32,
|
|
@@ -89,7 +89,7 @@ class TensorMapperImageClassification(LiteTensorMapper):
|
|
| 89 |
out = self.classifier(embedding)
|
| 90 |
return out
|
| 91 |
|
| 92 |
-
model =
|
| 93 |
print(model)
|
| 94 |
|
| 95 |
|
|
|
|
| 59 |
|
| 60 |
|
| 61 |
|
| 62 |
+
class LiteTensorMapperImageClassification(LiteTensorMapper):
|
| 63 |
def __init__(
|
| 64 |
self,
|
| 65 |
image_size=32,
|
|
|
|
| 89 |
out = self.classifier(embedding)
|
| 90 |
return out
|
| 91 |
|
| 92 |
+
model = LiteTensorMapperImageClassification().to(device)
|
| 93 |
print(model)
|
| 94 |
|
| 95 |
|