|
|
| import torch
|
| import torch.nn as nn
|
| import matplotlib.pyplot as plt
|
| import numpy as np
|
| from PIL import Image
|
| import os
|
|
|
|
|
| class Generator(nn.Module):
|
| def __init__(self):
|
| super(Generator, self).__init__()
|
|
|
| self.label_emb = nn.Embedding(10, 10)
|
|
|
| self.model = nn.Sequential(
|
| nn.Linear(100 + 10, 256),
|
| nn.LeakyReLU(0.2),
|
| nn.BatchNorm1d(256),
|
|
|
| nn.Linear(256, 512),
|
| nn.LeakyReLU(0.2),
|
| nn.BatchNorm1d(512),
|
|
|
| nn.Linear(512, 1024),
|
| nn.LeakyReLU(0.2),
|
| nn.BatchNorm1d(1024),
|
|
|
| nn.Linear(1024, 784),
|
| nn.Tanh()
|
| )
|
|
|
| def forward(self, noise, labels):
|
| label_embedding = self.label_emb(labels)
|
| gen_input = torch.cat((noise, label_embedding), -1)
|
| img = self.model(gen_input)
|
| img = img.view(img.size(0), 1, 28, 28)
|
| return img
|
|
|
|
|
| class SevenGenInference:
|
| def __init__(self, model_path='models/7gen_generator.pth'):
|
| self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| self.latent_dim = 100
|
|
|
|
|
| self.generator = Generator().to(self.device)
|
| self.generator.load_state_dict(torch.load(model_path, map_location=self.device))
|
| self.generator.eval()
|
|
|
| print(f"🚀 7Gen yüklendi! Cihaz: {self.device}")
|
|
|
| def generate_digit(self, digit, count=5):
|
| """Belirli bir rakamdan istenen sayıda üret"""
|
| with torch.no_grad():
|
| z = torch.randn(count, self.latent_dim).to(self.device)
|
| labels = torch.full((count,), digit).to(self.device)
|
|
|
| images = self.generator(z, labels)
|
| images = (images + 1) / 2
|
|
|
| return images.cpu()
|
|
|
| def visualize_digits(self, digit, count=5, save_path=None):
|
| """Üretilen rakamları görselleştir"""
|
| images = self.generate_digit(digit, count)
|
|
|
| fig, axes = plt.subplots(1, count, figsize=(2*count, 2))
|
| if count == 1:
|
| axes = [axes]
|
|
|
| for i, ax in enumerate(axes):
|
| ax.imshow(images[i][0], cmap='gray')
|
| ax.axis('off')
|
| ax.set_title(f'Digit: {digit}')
|
|
|
| plt.tight_layout()
|
|
|
| if save_path:
|
| plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
| print(f"💾 Görsel kaydedildi: {save_path}")
|
|
|
| plt.show()
|
|
|
| def generate_grid(self, samples_per_digit=10, save_path=None):
|
| """Her rakamdan örneklerle 10x10 grid oluştur"""
|
| all_images = []
|
|
|
| for digit in range(10):
|
| images = self.generate_digit(digit, samples_per_digit)
|
| all_images.append(images)
|
|
|
| all_images = torch.cat(all_images, dim=0)
|
|
|
| fig, axes = plt.subplots(10, samples_per_digit, figsize=(15, 15))
|
|
|
| for i in range(10):
|
| for j in range(samples_per_digit):
|
| idx = i * samples_per_digit + j
|
| axes[i, j].imshow(all_images[idx][0], cmap='gray')
|
| axes[i, j].axis('off')
|
|
|
| if j == 0:
|
| axes[i, j].set_ylabel(f'{i}', rotation=0, size=20, labelpad=20)
|
|
|
| plt.suptitle('7Gen - Üretilen Rakamlar', size=20)
|
| plt.tight_layout()
|
|
|
| if save_path:
|
| plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
| print(f"💾 Grid kaydedildi: {save_path}")
|
|
|
| plt.show()
|
|
|
| def save_as_png(self, digit, count=1, output_dir='output'):
|
| """Tekil PNG dosyaları olarak kaydet"""
|
| os.makedirs(output_dir, exist_ok=True)
|
|
|
| images = self.generate_digit(digit, count)
|
|
|
| for i in range(count):
|
| img = images[i][0].numpy()
|
| img = (img * 255).astype(np.uint8)
|
|
|
| pil_img = Image.fromarray(img)
|
| filename = f"{output_dir}/digit_{digit}_{i+1}.png"
|
| pil_img.save(filename)
|
|
|
| print(f"💾 Kaydedildi: {filename}")
|
|
|
| def interactive_generate(self):
|
| """İnteraktif kullanım"""
|
| print("\n🎮 7Gen İnteraktif Mod")
|
| print("Çıkmak için 'q' yazın")
|
|
|
| while True:
|
| try:
|
| digit_input = input("\nHangi rakamı üretmek istersin? (0-9): ")
|
|
|
| if digit_input.lower() == 'q':
|
| print("👋 Görüşürüz!")
|
| break
|
|
|
| digit = int(digit_input)
|
| if 0 <= digit <= 9:
|
| count = int(input("Kaç tane üreteyim? (1-20): "))
|
| if 1 <= count <= 20:
|
| self.visualize_digits(digit, count)
|
| else:
|
| print("❌ 1-20 arası bir sayı gir!")
|
| else:
|
| print("❌ 0-9 arası bir rakam gir!")
|
|
|
| except ValueError:
|
| print("❌ Geçerli bir sayı gir!")
|
| except KeyboardInterrupt:
|
| print("\n👋 Görüşürüz!")
|
| break
|
|
|
|
|
| if __name__ == "__main__":
|
|
|
| seven_gen = SevenGenInference()
|
|
|
|
|
| print("\n📝 Örnek kullanımlar:")
|
| print("1. Tekil rakam üret")
|
| seven_gen.visualize_digits(digit=7, count=5)
|
|
|
| print("\n2. Grid oluştur")
|
| seven_gen.generate_grid(samples_per_digit=10, save_path='7gen_showcase.png')
|
|
|
| print("\n3. PNG olarak kaydet")
|
| seven_gen.save_as_png(digit=5, count=3, output_dir='output')
|
|
|
| print("\n4. İnteraktif mod")
|
| seven_gen.interactive_generate() |