Update example.py
Browse files- example.py +2 -2
example.py
CHANGED
|
@@ -3,11 +3,11 @@ import torch
|
|
| 3 |
from model import SRResNet # Import your model class
|
| 4 |
|
| 5 |
# Load the pre-trained model
|
| 6 |
-
model =
|
| 7 |
model.load_state_dict(torch.load("best_netG.pth")) # Load model weights
|
| 8 |
|
| 9 |
# Create a random input tensor (e.g., for testing purposes)
|
| 10 |
-
input_tensor = torch.rand(1,
|
| 11 |
|
| 12 |
# Perform inference
|
| 13 |
output_tensor = model(input_tensor)
|
|
|
|
| 3 |
from model import SRResNet # Import your model class
|
| 4 |
|
| 5 |
# Load the pre-trained model
|
| 6 |
+
model = SRResNet(in_channels=12, out_channels=72, upscale=1)
|
| 7 |
model.load_state_dict(torch.load("best_netG.pth")) # Load model weights
|
| 8 |
|
| 9 |
# Create a random input tensor (e.g., for testing purposes)
|
| 10 |
+
input_tensor = torch.rand(1, 12, 128, 128) # Batch size 1, 3 channels, 128x128 image
|
| 11 |
|
| 12 |
# Perform inference
|
| 13 |
output_tensor = model(input_tensor)
|