import argparse
import os

import torch
from diffusers import DDPMScheduler, StableDiffusionPipeline
from diffusers.utils import convert_state_dict_to_diffusers
from peft import LoraConfig, get_peft_model
from peft.utils import get_peft_model_state_dict
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class ImageDataset(Dataset):
    """Images from a single folder, resized to a square and returned as tensors.

    Each item is the output of ``transforms.ToTensor`` on an RGB image resized
    to ``(size, size)``: a float tensor of shape ``(3, size, size)`` with values
    in ``[0, 1]``.
    """

    def __init__(self, folder, size=512, extensions=(".png", ".jpg", ".jpeg")):
        """
        Args:
            folder: Directory scanned (non-recursively) for image files.
            size: Edge length, in pixels, of the square output images.
            extensions: Accepted file suffixes, matched case-insensitively.

        Raises:
            ValueError: If ``folder`` contains no matching image files.
        """
        self.files = self._list_images(folder, extensions)
        if not self.files:
            # Fail early with a clear message instead of a cryptic sampler error
            # when the DataLoader is later built with shuffle=True.
            raise ValueError(
                f"Nenhuma imagem encontrada em {folder!r} "
                f"(extensões aceitas: {', '.join(extensions)})"
            )
        self.transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _list_images(folder, extensions):
        """Return sorted paths of regular files in *folder* whose suffix is in *extensions*."""
        suffixes = tuple(ext.lower() for ext in extensions)
        candidates = (
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.lower().endswith(suffixes)
        )
        # Sorting makes the dataset order independent of os.listdir's arbitrary order.
        return sorted(path for path in candidates if os.path.isfile(path))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # The context manager closes the file handle promptly; convert() returns
        # an independent, fully loaded image that outlives the `with` block.
        with Image.open(self.files[idx]) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb)
|
|
def _encode_prompt(pipe, prompt, device):
    """Return the CLIP text-encoder hidden states for *prompt* (batch dimension 1)."""
    tokens = pipe.tokenizer(
        prompt,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        return pipe.text_encoder(tokens.input_ids.to(device))[0]


def _denoising_loss(pipe, noise_scheduler, pixel_values, encoder_hidden_states):
    """Compute the diffusion training loss for one batch of images.

    The images (expected in [-1, 1]) are encoded to VAE latents and noised at
    random timesteps. The UNet is then asked to predict the scheduler's
    training target: the noise for "epsilon" or the velocity for "v_prediction".
    """
    with torch.no_grad():
        latents = pipe.vae.encode(pixel_values).latent_dist.sample()
        latents = latents * pipe.vae.config.scaling_factor

    noise = torch.randn_like(latents)
    batch_size = latents.shape[0]
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (batch_size,), device=latents.device
    )
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    prediction_type = noise_scheduler.config.prediction_type
    if prediction_type == "epsilon":
        target = noise
    elif prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unsupported prediction type: {prediction_type}")

    with torch.autocast(device_type=latents.device.type, dtype=torch.float16):
        model_pred = pipe.unet(
            noisy_latents, timesteps, encoder_hidden_states.expand(batch_size, -1, -1)
        ).sample
    # Compute the loss in fp32 for numerical stability.
    return torch.nn.functional.mse_loss(model_pred.float(), target.float())


def _save_lora_weights(unet, output_dir):
    """Write only the UNet LoRA weights, in safetensors format, to ``<output_dir>/lora.safetensors``."""
    os.makedirs(output_dir, exist_ok=True)
    lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
    StableDiffusionPipeline.save_lora_weights(
        save_directory=output_dir,
        unet_lora_layers=lora_state_dict,
        weight_name="lora.safetensors",
        safe_serialization=True,
    )


def main(args):
    """Train LoRA adapters on the Stable Diffusion v1.5 UNet from a folder of images.

    Args:
        args: Namespace with ``images_dir``, ``output_dir``, ``learning_rate``,
            ``num_epochs``, ``rank`` and, optionally, ``prompt``. The prompt
            is the text used to condition every training image; it defaults
            to the empty prompt when absent.

    The result is ``<output_dir>/lora.safetensors``, which holds only the LoRA
    weights. Load it with
    ``pipe.load_lora_weights(output_dir, weight_name="lora.safetensors")``.
    """
    # NOTE(review): the runwayml/stable-diffusion-v1-5 repo was removed from the
    # Hugging Face Hub in 2024; the mirror is stable-diffusion-v1-5/stable-diffusion-v1-5.
    # Confirm which id is reachable in your environment.
    model_id = "runwayml/stable-diffusion-v1-5"
    device = "cuda"
    weight_dtype = torch.float16

    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=weight_dtype).to(device)
    noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
    unet = pipe.unet

    # Freeze the whole base model; only the LoRA matrices injected below train.
    for module in (unet, pipe.vae, pipe.text_encoder):
        module.requires_grad_(False)

    # No task_type: "CAUSAL_LM" would wrap the UNet as a causal language model,
    # whose forward signature (input_ids, ...) a UNet does not have.
    lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=16,
        target_modules=["to_q", "to_v"],
        lora_dropout=0.1,
        bias="none",
    )
    unet.add_adapter(lora_config)

    # Keep the trainable LoRA weights in fp32: AdamW updates applied directly to
    # fp16 weights lose precision. The frozen base stays in fp16.
    lora_params = [p for p in unet.parameters() if p.requires_grad]
    for param in lora_params:
        param.data = param.data.to(torch.float32)
    unet.train()  # from_pretrained returns eval mode; train mode enables lora_dropout

    dataset = ImageDataset(args.images_dir)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

    optimizer = torch.optim.AdamW(lora_params, lr=args.learning_rate)
    # Loss scaling keeps small fp16 activation gradients from underflowing.
    scaler = torch.cuda.amp.GradScaler()

    # The prompt is fixed, so encode it once rather than per step.
    encoder_hidden_states = _encode_prompt(pipe, getattr(args, "prompt", ""), device)

    for epoch in range(args.num_epochs):
        for images in dataloader:
            # ToTensor yields values in [0, 1]; the SD VAE expects inputs in [-1, 1].
            pixel_values = images.to(device, dtype=weight_dtype) * 2.0 - 1.0
            loss = _denoising_loss(pipe, noise_scheduler, pixel_values, encoder_hidden_states)
            optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        print(f"✅ Epoch {epoch+1}/{args.num_epochs} finalizado.")

    _save_lora_weights(unet, args.output_dir)
    print("✅ Treinamento concluído. Arquivo salvo em lora.safetensors")


def build_arg_parser():
    """Return the command-line parser for this training script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--images_dir", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--num_epochs", type=int, default=10)
    parser.add_argument("--rank", type=int, default=4)
    # Text that conditions every training image; empty means unconditional.
    parser.add_argument("--prompt", type=str, default="")
    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())