| import os
|
| import torch
|
| import data_setup, engine, model_builder, utils
|
| from torchvision import transforms, models
|
| import argparse
|
|
|
# Command-line interface for the training run.
#
# Every option declares its fallback through ``default=`` instead of a
# post-hoc ``value if value else fallback`` check. That way an explicit falsy
# value such as ``-e 0`` or ``-lr 0`` is honoured rather than replaced, and no
# hyperparameter is ever left as ``None`` when its flag is omitted.
# ArgumentDefaultsHelpFormatter appends each default to the ``--help`` text.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument("-e", "--num_epochs", help="an integer to perform number of epochs", type=int, default=10)

# 32 is a conventional starting batch size; previously omitting -b left this
# as None, which disables batching if it reaches a torch DataLoader.
parser.add_argument("-b", "--batch_size", help="an integer of number of element per batch", type=int, default=32)

parser.add_argument("-lr", "--learning_rate", help="a float for the learning rate", type=float, default=0.001)

# Data locations were hard-coded; expose them as optional flags that keep the
# original paths as defaults.
parser.add_argument("--train_dir", help="directory containing the training images", default="data/pizza_sushi_steak/train")

parser.add_argument("--test_dir", help="directory containing the test images", default="data/pizza_sushi_steak/test")

# Parse the real command line only when this file is executed as a script
# (``python train.py`` / ``python -m ...`` both set __name__ to "__main__").
# When it is merely imported, sys.argv belongs to the importing program and
# parse_args() would exit on unrecognised arguments, so use the defaults.
args = parser.parse_args(None if __name__ == "__main__" else [])

# Module-level hyperparameters read by main().
NUM_EPOCHS = args.num_epochs

BATCH_SIZE = args.batch_size

LEARNING_RATE = args.learning_rate

# Module-level data directories read by main().
train_dir = args.train_dir

test_dir = args.test_dir
|
|
|
def main():
    """Run one end-to-end training job and hand the model to ``utils.save_model``.

    Reads the module-level settings ``NUM_EPOCHS``, ``BATCH_SIZE``,
    ``LEARNING_RATE``, ``train_dir`` and ``test_dir``. Builds the dataloaders
    and the model through the project's ``data_setup`` / ``model_builder``
    helpers. Trains via ``engine.train`` with cross-entropy loss and Adam.
    Finally passes the model to ``utils.save_model`` with target directory
    "models" and file name "tinyfood-effnet.pt".
    """
    # Use the GPU whenever PyTorch reports one; otherwise fall back to CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # Preprocessing: fixed 224x224 resize, PIL -> tensor, per-channel
    # normalization. These are the widely used ImageNet RGB mean/std values.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    preprocess = transforms.Compose([
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    # num_workers=0 is presumably forwarded to torch's DataLoader, where it
    # means loading in the main process -- confirm in data_setup.
    train_loader, test_loader, class_names = data_setup.create_dataloaders(
        train_dir=train_dir,
        test_dir=test_dir,
        transform=preprocess,
        batch_size=BATCH_SIZE,
        num_workers=0,
    )

    # Size the output layer from the class list returned by the dataloader factory.
    num_classes = len(class_names)
    model = model_builder.create_model_baseline_effnetb0(out_feats=num_classes, device=device)

    criterion = torch.nn.CrossEntropyLoss()
    adam = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

    engine.train(
        model=model,
        train_dataloader=train_loader,
        test_dataloader=test_loader,
        loss_fn=criterion,
        optimizer=adam,
        epochs=NUM_EPOCHS,
        device=device,
    )

    utils.save_model(model=model, target_dir="models", model_name="tinyfood-effnet.pt")
|
|
|
if __name__ == '__main__':
    # Entry point: start training only when executed directly, not on import.
    main()
|
|
|