| import torch |
| from torch import nn |
| from torch.utils.data import DataLoader |
|
|
# Hyperparameters
# NOTE(review): the tuple reads as (height, width, channels); PyTorch conv
# layers consume (N, C, H, W) tensors, so HWC image data must be permuted
# before it reaches the models. Adjust the values based on your data.
image_size = (224, 224, 3)
# Number of passes over the training data; the training loop iterates
# `range(epochs)`, so this must be defined. Placeholder value -- tune it.
epochs = 100
|
|
# Define the Generator Network
class Generator(nn.Module):
    """Map a 3-channel input image to a 3-channel shadow map in [-1, 1].

    Three 3x3 same-padding convolutions (3 -> 32 -> 64 -> 128 channels) with
    ReLU, followed by a 3x3 projection back to 3 channels and a tanh.
    Stride 1 / padding 1 everywhere, so the spatial size is preserved.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        # Bridge conv1's 32 output channels up to the 128 channels that
        # conv_final consumes; without these the forward pass cannot run.
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        # nn.Conv2d has no `activation` argument; the tanh is applied in
        # forward() instead.
        self.conv_final = nn.Conv2d(128, 3, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Run the generator.

        Args:
            x: input batch with 3 channels on dim 1 (conv1's in_channels),
                presumably (N, 3, H, W) car images -- confirm with the loader.

        Returns:
            Tensor with 3 channels and the same spatial size as ``x``,
            with values in [-1, 1].
        """
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        # Tanh bounds the predicted shadow intensity to [-1, 1].
        return torch.tanh(self.conv_final(x))
|
|
# Define the Discriminator Network
class Discriminator(nn.Module):
    """Score (car, shadow) pairs with a single sigmoid output per pair.

    The two inputs are concatenated along the channel axis (6 channels in
    total), passed through three 3x3 convolutions (6 -> 32 -> 64 -> 128) with
    LeakyReLU, globally average-pooled and projected to one logit, then
    squashed by a sigmoid. The output has shape (N, 1) for any H x W.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(6, 32, kernel_size=3, stride=1, padding=1)
        # Stride-2 layers downsample to keep memory bounded at 128 channels.
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.act = nn.LeakyReLU(0.2)
        # Global average pooling collapses any spatial size to 1x1, so the
        # linear layer always receives a 128-dim feature vector. Applying
        # nn.Linear directly to a 4-D conv output would act on the width axis.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(128, 1)  # final layer; sigmoid applied in forward()

    def forward(self, car, shadow):
        """Return a (N, 1) score in [0, 1] for each (car, shadow) pair.

        Args:
            car: batch whose channels, together with ``shadow``'s, sum to the
                6 that conv1 expects -- presumably (N, 3, H, W) each.
            shadow: batch with the same N, H, W as ``car``.
        """
        # Concatenate car and shadow features along the channel dimension.
        x = torch.cat([car, shadow], dim=1)
        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        x = self.act(self.conv3(x))
        x = torch.flatten(self.pool(x), 1)  # (N, 128, 1, 1) -> (N, 128)
        return torch.sigmoid(self.linear(x))
|
|
# Create data loaders for training and validation data.
# TODO: replace this placeholder with a real loader, e.g.
#     train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
# NOTE(review): the loop below assumes each batch is a (car, shadow) pair of
# (N, 3, H, W) tensors -- confirm against the actual dataset.
train_loader = []

# Create the models
generator = Generator()
discriminator = Discriminator()

# Define loss function and optimizer.
# BCELoss expects probabilities, matching the Discriminator's sigmoid output.
criterion = nn.BCELoss()
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)

# Training loop
for epoch in range(epochs):
    d_loss_total = 0.0
    g_loss_total = 0.0
    num_batches = 0
    for car, real_shadow in train_loader:
        fake_shadow = generator(car)

        # Train the Discriminator: real pairs -> 1, generated pairs -> 0.
        # detach() keeps this step from back-propagating into the generator.
        d_optimizer.zero_grad()
        d_real = discriminator(car, real_shadow)
        d_fake = discriminator(car, fake_shadow.detach())
        d_loss = (criterion(d_real, torch.ones_like(d_real))
                  + criterion(d_fake, torch.zeros_like(d_fake)))
        d_loss.backward()
        d_optimizer.step()

        # Train the Generator: push the discriminator to score fakes as real.
        # (Gradients this leaves on the discriminator are cleared by the
        # d_optimizer.zero_grad() call at the start of the next batch.)
        g_optimizer.zero_grad()
        d_fake_for_g = discriminator(car, fake_shadow)
        g_loss = criterion(d_fake_for_g, torch.ones_like(d_fake_for_g))
        g_loss.backward()
        g_optimizer.step()

        d_loss_total += d_loss.item()
        g_loss_total += g_loss.item()
        num_batches += 1

    if num_batches == 0:
        # Avoid silently "training" (and saving) on an empty loader.
        print("warning: train_loader yielded no batches; stopping training")
        break

    # Print training progress (mean losses over the epoch).
    print(f"epoch {epoch + 1}/{epochs}  "
          f"d_loss={d_loss_total / num_batches:.4f}  "
          f"g_loss={g_loss_total / num_batches:.4f}")

# Save the trained generator
torch.save(generator.state_dict(), 'generator.pt')
|
|