| """ |
| Easy inference script for Fake Image Detection |
| Usage: python inference.py --image path/to/image.jpg |
| """ |
|
|
| import torch |
| from torchvision import transforms |
| from PIL import Image |
| import pickle |
| import json |
| import argparse |
| from huggingface_hub import hf_hub_download |
| from model import EnhancedFreqVAE, EdgeNormalizingFlow, SemanticDeepSVDD, Ensemble |
|
|
|
|
def _resolve_device(device):
    """Return *device*, or 'cpu' if a CUDA device was requested but CUDA is unavailable."""
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        print("⚠ CUDA requested but not available; falling back to CPU.")
        return 'cpu'
    return device


def _load_state_dict_model(model, repo_id, filename, device):
    """Download a plain state_dict checkpoint into *model*, move it to *device*, set eval mode."""
    path = hf_hub_download(repo_id=repo_id, filename=filename)
    model.load_state_dict(torch.load(path, map_location=device))
    model.to(device)
    model.eval()
    return model


def _load_pickle(repo_id, filename):
    """Download *filename* from *repo_id* and return the unpickled object."""
    path = hf_hub_download(repo_id=repo_id, filename=filename)
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # This is only acceptable because the files come from a repository we
    # trust; never point `repo_id` at an untrusted source.
    with open(path, 'rb') as f:
        return pickle.load(f)


def load_models(device='cuda', *, repo_id="ash12321/fake-image-detection-ensemble"):
    """Download every ensemble component from Hugging Face and build the Ensemble.

    Args:
        device: Torch device (string or ``torch.device``) for the neural
            models. If a CUDA device is requested but CUDA is not available,
            'cpu' is used instead.
        repo_id: Hugging Face Hub repository holding ``config.json``, the
            ``.pth`` checkpoints and the ``.pkl`` models.

    Returns:
        ``(ensemble, device)``: the configured ``Ensemble`` and the device that
        was actually used (may differ from the requested one after fallback).

    Raises:
        KeyError: if ``config.json`` lacks ``weights``/``norms``/``thresh``, or
            the SVDD checkpoint lacks ``model``/``center``.
    """
    device = _resolve_device(device)

    print("📥 Downloading models from Hugging Face...")

    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    print("Loading Frequency VAE...")
    freq_vae = _load_state_dict_model(EnhancedFreqVAE(), repo_id, "freq_vae.pth", device)

    print("Loading Edge Flow...")
    edge_flow = _load_state_dict_model(EdgeNormalizingFlow(), repo_id, "edge_flow.pth", device)

    print("Loading Semantic SVDD...")
    semantic_svdd = SemanticDeepSVDD()
    svdd_path = hf_hub_download(repo_id=repo_id, filename="semantic_svdd.pth")
    # This checkpoint is a dict: weights under 'model' plus a separate 'center'
    # that is attached as a plain attribute. NOTE(review): `.to(device)` below
    # does not move plain attributes; if 'center' is a tensor it is already on
    # `device` thanks to map_location.
    checkpoint = torch.load(svdd_path, map_location=device)
    semantic_svdd.load_state_dict(checkpoint['model'])
    semantic_svdd.center = checkpoint['center']
    semantic_svdd.to(device)
    semantic_svdd.eval()

    print("Loading traditional ML models...")
    # Each pickled model lives in "<name>.pkl" on the Hub.
    pickled = {
        name: _load_pickle(repo_id, f"{name}.pkl")
        for name in ('texture_ocsvm', 'color_model', 'stat', 'iforest', 'lof', 'gmm')
    }

    # Keep the original key order: Ensemble may rely on it (e.g. to align
    # with config['weights']).
    models_dict = {
        'freq_vae': freq_vae,
        'texture_ocsvm': pickled['texture_ocsvm'],
        'color_model': pickled['color_model'],
        'edge_flow': edge_flow,
        'semantic_svdd': semantic_svdd,
        'stat': pickled['stat'],
        'iforest': pickled['iforest'],
        'lof': pickled['lof'],
        'gmm': pickled['gmm'],
    }

    ensemble = Ensemble(models_dict)
    ensemble.wts = config['weights']
    ensemble.norms = config['norms']
    ensemble.thresh = config['thresh']

    print("✓ All models loaded!\n")
    return ensemble, device
|
|
|
|
def predict_image(image_path, ensemble, device):
    """Classify a single image as real or fake using the loaded ensemble.

    Preprocessing: convert to RGB, resize to 256x256 with Lanczos resampling,
    convert to a tensor and normalize with the standard ImageNet mean/std.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        ensemble: Object whose ``predict(tensor, device)`` returns
            ``(is_fake, score, individual_scores)``, e.g. from ``load_models``.
        device: Device forwarded unchanged to ``ensemble.predict``.

    Returns:
        dict with ``prediction`` ('FAKE' or 'REAL'), ``confidence``
        (``abs(score)``), ``anomaly_score`` (the raw ensemble score) and
        ``individual_scores`` (passed through from the ensemble).
    """
    # Image.open is lazy and keeps the file handle open; close it promptly.
    with Image.open(image_path) as src:
        # Convert *before* resizing: Pillow ignores LANCZOS for palette ("P")
        # and bilevel ("1") images and silently uses nearest-neighbour.
        # convert() also returns a fully loaded copy, so `src` can be closed.
        img = src.convert('RGB').resize((256, 256), Image.LANCZOS)

    tfm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    img_tensor = tfm(img)

    is_fake, score, individual_scores = ensemble.predict(img_tensor, device)

    return {
        'prediction': 'FAKE' if is_fake else 'REAL',
        'confidence': abs(score),
        'anomaly_score': score,
        'individual_scores': individual_scores,
    }
|
|
|
|
if __name__ == "__main__":
    import os

    parser = argparse.ArgumentParser(description='Detect fake images')
    parser.add_argument('--image', type=str, required=True, help='Path to image')
    parser.add_argument('--device', type=str, default='cuda', help='Device (cuda/cpu)')
    args = parser.parse_args()

    # Fail fast on a bad path *before* downloading every model from the Hub.
    if not os.path.isfile(args.image):
        parser.error(f"image file not found: {args.image}")

    # Honour the requested device only when CUDA is present; otherwise use CPU.
    device = args.device if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}\n")

    ensemble, device = load_models(device)

    print(f"Analyzing: {args.image}")
    result = predict_image(args.image, ensemble, device)

    separator = "=" * 50
    print("\n" + separator)
    print("RESULT")
    print(separator)
    print(f"Prediction: {result['prediction']}")
    print(f"Confidence: {result['confidence']:.4f}")
    print(f"Anomaly Score: {result['anomaly_score']:.4f}")
    print("\nIndividual Model Scores:")
    for model_name, model_score in result['individual_scores'].items():
        print(f" {model_name}: {model_score:.4f}")
    print(separator)
|
|