"""TMOS deepfake detector: Gradio UI plus a FastAPI ``/predict`` endpoint.

On hosts with CUDA the TMOS classifier (LLaVA-1.5 base + LoRA adapter) is
loaded; if CUDA is missing or TMOS loading fails, a Hugging Face
image-classification model is loaded on CPU instead.
"""

import html
import io
import json
import math
import os
import sys
import threading
import time
import traceback
from collections import deque
from contextlib import nullcontext
from pathlib import Path

import gradio as gr
import torch
from dotenv import load_dotenv
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image, ImageOps
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoProcessor

ROOT_DIR = Path(__file__).resolve().parent
SCRIPTS_DIR = ROOT_DIR / "scripts"
# ``tmos_classifier`` (imported lazily in load_tmos_model) is expected in scripts/.
if str(SCRIPTS_DIR) not in sys.path:
    sys.path.insert(0, str(SCRIPTS_DIR))

load_dotenv()


def _env_flag(name: str, default: bool) -> bool:
    """Read a boolean environment variable.

    Unset -> ``default``; otherwise ``1``/``true``/``yes``/``on``
    (case-insensitive, surrounding whitespace ignored) mean True and any
    other value means False.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


HF_TOKEN = os.getenv("HF_TOKEN")
BASE_MODEL_ID = "llava-hf/llava-1.5-7b-hf"
ADAPTER_PATH = ROOT_DIR / "final-production-weights" / "best_model"
ADAPTER_REPO_ID = os.getenv("ADAPTER_REPO_ID", "Werrewulf/TMOS-DD")
ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "")
CPU_FALLBACK_MODEL_ID = os.getenv("CPU_FALLBACK_MODEL_ID", "DaMsTaR/Detecto-DeepFake_Image_Detector")
# For the default fallback model the output is flipped unless INVERT_FALLBACK_OUTPUT overrides it.
DEFAULT_INVERT_FALLBACK = CPU_FALLBACK_MODEL_ID.lower() == "damstar/detecto-deepfake_image_detector"
INVERT_FALLBACK_OUTPUT = _env_flag("INVERT_FALLBACK_OUTPUT", DEFAULT_INVERT_FALLBACK)

# LLaVA-1.5 chat format; the processor needs the ``<image>`` placeholder in the
# text to know where to splice the image features in.
TMOS_PROMPT = "USER: <image>\nIs this video real or produced by AI?\nASSISTANT:"
TARGET_IMAGE_SIZE = 336
THRESHOLD = 0.5

# LoRA target modules a TMOS adapter config must cover.
_TMOS_TARGET_MODULES = frozenset(
    {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"}
)

model = None
processor = None
inference_device = None
# "tmos" or "fallback" once a backend has loaded successfully; None before that.
active_backend: str | None = None
# Serialises model access between Gradio worker threads and API requests.
_inference_lock = threading.Lock()


def resolve_inference_device(model_obj) -> torch.device:
    """Return the device that model inputs should be moved to.

    When CUDA is available, the first CUDA entry of ``model_obj.hf_device_map``
    (if the model exposes one) wins, otherwise plain ``cuda``. Without CUDA the
    answer is always ``cpu``.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    device_map = getattr(model_obj, "hf_device_map", None)
    if isinstance(device_map, dict):
        for mapped in device_map.values():
            # Device maps may name GPUs by integer index (e.g. ``0``) or by
            # string (e.g. ``"cuda:0"``); "cpu"/"disk" entries are skipped.
            if isinstance(mapped, torch.device) and mapped.type == "cuda":
                return mapped
            if isinstance(mapped, int) and not isinstance(mapped, bool):
                return torch.device("cuda", mapped)
            if isinstance(mapped, str) and mapped.startswith("cuda"):
                return torch.device(mapped)
    return torch.device("cuda")


def find_classifier_weight_tensor(model_obj):
    """Breadth-first search for a ``.classifier`` submodule exposing ``.weight``.

    Starting at ``model_obj``, follows the ``model``, ``base_model`` and
    ``module`` attributes (the wrapper chain used by PEFT / DataParallel-style
    wrappers). Returns the first classifier weight found, or None.
    """
    visited: set[int] = set()
    queue = deque([model_obj])
    while queue:
        current = queue.popleft()
        if current is None or id(current) in visited:
            continue
        visited.add(id(current))
        classifier = getattr(current, "classifier", None)
        if classifier is not None and hasattr(classifier, "weight"):
            return classifier.weight
        for attr in ("model", "base_model", "module"):
            nested = getattr(current, attr, None)
            if nested is not None:
                queue.append(nested)
    return None


def count_lora_layers(model_obj) -> int:
    """Count submodules that carry both ``lora_A`` and ``lora_B`` (i.e. injected LoRA layers)."""
    return sum(
        1
        for _, module in model_obj.named_modules()
        if hasattr(module, "lora_A") and hasattr(module, "lora_B")
    )


def is_tmos_adapter_config(cfg: dict) -> bool:
    """Return True if a PEFT adapter config matches the TMOS production layout.

    Requires ``"classifier"`` in ``modules_to_save``, LoRA rank ``r == 64`` and
    every module in ``_TMOS_TARGET_MODULES`` among ``target_modules``.
    """
    modules_to_save = cfg.get("modules_to_save") or []
    target_modules = set(cfg.get("target_modules") or [])
    return (
        "classifier" in modules_to_save
        and cfg.get("r") == 64
        and _TMOS_TARGET_MODULES.issubset(target_modules)
    )


def load_local_adapter_config(adapter_dir: Path) -> dict | None:
    """Parse ``adapter_dir/adapter_config.json``; return None if the file does not exist."""
    cfg_path = adapter_dir / "adapter_config.json"
    if not cfg_path.exists():
        return None
    with cfg_path.open("r", encoding="utf-8") as fp:
        return json.load(fp)


def load_remote_adapter_config(repo_id: str, subfolder: str) -> dict | None:
    """Fetch a PEFT adapter config from the Hub as a dict; return None on any failure.

    Failures are best-effort (the caller just tries the next candidate), but
    the reason is printed so auth/network problems are not mistaken for an
    incompatible adapter.
    """
    from peft import PeftConfig

    try:
        peft_cfg = PeftConfig.from_pretrained(repo_id, subfolder=subfolder, token=HF_TOKEN)
    except Exception as exc:
        print(f"Could not read adapter config from {repo_id}/{subfolder or '.'}: {type(exc).__name__}: {exc}")
        return None
    return peft_cfg.to_dict()


def select_torch_dtype() -> torch.dtype:
    """bfloat16 on CUDA when supported, float16 on other CUDA devices, float32 on CPU."""
    if torch.cuda.is_available():
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    return torch.float32


def _load_first_compatible_remote_adapter(base_model, peft_kwargs: dict):
    """Attach the first remote adapter candidate that passes verification.

    Each candidate subfolder of ``ADAPTER_REPO_ID`` is tried in priority order.
    A candidate is rejected if its config is not TMOS-shaped, it attaches zero
    LoRA layers, no classifier head is found, or the classifier weights are
    unchanged from the base model.

    Returns:
        ``(merged_model, subfolder)`` for the accepted candidate.

    Raises:
        RuntimeError: if no candidate is accepted (chained to the last rejection).
    """
    from peft import PeftModel

    base_classifier_weight = find_classifier_weight_tensor(base_model)
    base_classifier_snapshot = None
    if base_classifier_weight is not None:
        base_classifier_snapshot = base_classifier_weight.detach().float().cpu().clone()

    # dict.fromkeys de-duplicates while keeping priority order, so an
    # ADAPTER_SUBFOLDER equal to one of the defaults is not tried twice.
    candidate_subfolders = list(
        dict.fromkeys([ADAPTER_SUBFOLDER, "multimodal", "multimodal/checkpoint-5", "llava"])
    )
    last_error = None
    for subfolder in candidate_subfolders:
        try:
            remote_cfg = load_remote_adapter_config(ADAPTER_REPO_ID, subfolder)
            if remote_cfg is None or not is_tmos_adapter_config(remote_cfg):
                raise ValueError("Adapter config is not TMOS-compatible.")

            current_kwargs = dict(peft_kwargs)
            if subfolder:
                current_kwargs["subfolder"] = subfolder
            # NOTE(review): PeftModel.from_pretrained injects adapter layers into
            # base_model in place; if a later check below rejects this candidate,
            # the next candidate is attached to an already-modified base model.
            loaded_model = PeftModel.from_pretrained(base_model, ADAPTER_REPO_ID, **current_kwargs)

            lora_layer_count = count_lora_layers(loaded_model)
            if lora_layer_count == 0:
                raise RuntimeError("Loaded adapter has zero LoRA layers attached.")

            loaded_classifier_weight = find_classifier_weight_tensor(loaded_model)
            if loaded_classifier_weight is None:
                raise RuntimeError("Classifier head not found after adapter load.")

            if base_classifier_snapshot is not None:
                # A TMOS adapter ships a trained classifier head; identical weights
                # mean the head was not actually loaded.
                classifier_delta = (
                    loaded_classifier_weight.detach().float().cpu() - base_classifier_snapshot
                ).abs().mean().item()
                if classifier_delta < 1e-8:
                    raise RuntimeError(
                        "Classifier weights did not change after loading adapter; adapter likely incompatible."
                    )

            merged_model = loaded_model.merge_and_unload()
        except Exception as exc:
            last_error = exc
            print(f"Adapter candidate '{subfolder or '.'}' rejected: {type(exc).__name__}: {exc}")
            continue
        print(
            f"Loaded TMOS adapter from repo subfolder: '{subfolder or '.'}' "
            f"(lora_layers={lora_layer_count})"
        )
        return merged_model, subfolder

    raise RuntimeError(
        "No TMOS-compatible adapter found in remote repo. Upload TMOS production weights with classifier head "
        "(modules_to_save=['classifier'], r=64, 7-target-module LoRA)."
    ) from last_error


def load_tmos_model():
    """Load the TMOS classifier with its LoRA adapter merged in, plus the LLaVA processor.

    A local adapter under ``ADAPTER_PATH`` is preferred; otherwise candidates
    from ``ADAPTER_REPO_ID`` are tried. Globals (``model``, ``processor``,
    ``inference_device``, ``active_backend``) are only assigned once everything
    has loaded.

    Raises:
        RuntimeError: without CUDA, for an incompatible local adapter, or when
            no remote adapter verifies. Loader errors propagate unchanged.
    """
    global model, processor, inference_device, active_backend

    if not torch.cuda.is_available():
        raise RuntimeError(
            "TMOS mode requires GPU hardware. CPU fallback should be used on CPU-only environments."
        )

    from peft import PeftModel
    from tmos_classifier import TMOSClassifier

    local_adapter_file = next(
        (
            candidate
            for candidate in (
                ADAPTER_PATH / "adapter_model.safetensors",
                ADAPTER_PATH / "adapter_model.bin",
            )
            if candidate.exists()
        ),
        None,
    )
    use_local = local_adapter_file is not None
    if use_local:
        adapter_source = str(ADAPTER_PATH)
        local_cfg = load_local_adapter_config(ADAPTER_PATH)
        if local_cfg is None or not is_tmos_adapter_config(local_cfg):
            raise RuntimeError(
                "Local adapter exists but is not TMOS-compatible. Expected modules_to_save=['classifier'], r=64, and TMOS target modules."
            )
    else:
        adapter_source = ADAPTER_REPO_ID

    dtype = select_torch_dtype()
    print(f"Loading TMOS-DD model from {adapter_source} with dtype={dtype}...")
    base_model = TMOSClassifier(
        base_model_id=BASE_MODEL_ID,
        torch_dtype=dtype,
        device_map="auto",
        token=HF_TOKEN,
    )

    peft_kwargs = {"is_trainable": False, "token": HF_TOKEN}
    selected_subfolder = ""
    if use_local:
        loaded_model = PeftModel.from_pretrained(base_model, adapter_source, **peft_kwargs)
        lora_layer_count = count_lora_layers(loaded_model)
        if lora_layer_count == 0:
            raise RuntimeError("Local adapter load produced zero LoRA layers attached.")
        merged_model = loaded_model.merge_and_unload()
        print(f"Loaded TMOS local adapter (lora_layers={lora_layer_count})")
    else:
        merged_model, selected_subfolder = _load_first_compatible_remote_adapter(base_model, peft_kwargs)

    merged_model.eval()
    tmos_processor = AutoProcessor.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN)
    tmos_processor.patch_size = 14
    tmos_processor.vision_feature_select_strategy = "default"

    model = merged_model
    processor = tmos_processor
    inference_device = resolve_inference_device(model)
    active_backend = "tmos"

    if use_local:
        print(f"TMOS-DD ready on {inference_device} using local production adapter.")
    else:
        print(f"TMOS-DD ready on {inference_device} using remote subfolder '{selected_subfolder or '.'}'.")


def load_cpu_fallback_model():
    """Load ``CPU_FALLBACK_MODEL_ID`` (processor + float32 classifier) on CPU and publish it globally."""
    global model, processor, inference_device, active_backend

    print(f"Loading CPU fallback model from {CPU_FALLBACK_MODEL_ID}...")
    fallback_processor = AutoImageProcessor.from_pretrained(CPU_FALLBACK_MODEL_ID, token=HF_TOKEN)
    fallback_model = AutoModelForImageClassification.from_pretrained(
        CPU_FALLBACK_MODEL_ID,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
        token=HF_TOKEN,
    )
    fallback_model.to("cpu").eval()

    model = fallback_model
    processor = fallback_processor
    inference_device = torch.device("cpu")
    active_backend = "fallback"
    print("CPU fallback classifier ready.")


def load_model_and_processor():
    """Return ``(model, processor, device)``, loading a backend on first use.

    With CUDA, TMOS is attempted first and the CPU fallback is used if it
    raises; without CUDA the fallback is loaded directly.
    """
    global model, processor, inference_device, active_backend

    if model is not None and processor is not None and inference_device is not None:
        return model, processor, inference_device

    if not torch.cuda.is_available():
        print("No GPU detected -> using CPU fallback")
        load_cpu_fallback_model()
        return model, processor, inference_device

    print("GPU detected -> loading TMOS")
    try:
        load_tmos_model()
        return model, processor, inference_device
    except Exception as exc:
        print(f"TMOS failed: {exc}")
        print("Falling back to CPU model...")

    # Loaded outside the except block so the failed TMOS frames (and any GPU
    # tensors they reference) can be released before the fallback loads.
    model = processor = inference_device = None
    active_backend = None
    torch.cuda.empty_cache()
    load_cpu_fallback_model()
    return model, processor, inference_device


def preprocess_image(image: Image.Image) -> Image.Image:
    """Upright (per EXIF orientation), convert to RGB and fit inside TARGET_IMAGE_SIZE², keeping aspect ratio."""
    try:
        image = ImageOps.exif_transpose(image)
    except Exception as exc:
        # Orientation correction is best-effort; a malformed EXIF block should not block detection.
        print(f"Skipping EXIF orientation fix: {type(exc).__name__}: {exc}")
    image = image.convert("RGB")
    return ImageOps.contain(image, (TARGET_IMAGE_SIZE, TARGET_IMAGE_SIZE), method=Image.Resampling.BICUBIC)


def confidence_card(prob_fake: float, label: str) -> str:
    """Render the verdict confidence and the real/fake split as an HTML card.

    Confidence is ``prob_fake`` for a "Fake" verdict and ``1 - prob_fake``
    otherwise; the accent colour is red for "Fake" and green for "Real".
    """
    confidence = prob_fake if label == "Fake" else 1.0 - prob_fake
    confidence_pct = confidence * 100.0
    fake_pct = prob_fake * 100.0
    real_pct = (1.0 - prob_fake) * 100.0
    accent = "#ef4444" if label == "Fake" else "#10b981"
    safe_label = html.escape(label)
    return f"""
<div style="border: 2px solid {accent}; border-radius: 12px; padding: 16px;">
  <div style="font-size: 0.9rem; opacity: 0.8;">Confidence</div>
  <div style="font-size: 2rem; font-weight: 700; color: {accent};">{confidence_pct:.2f}%</div>
  <div style="margin-bottom: 8px;">for {safe_label}</div>
  <div style="font-size: 0.9rem;">Real: {real_pct:.2f}% Fake: {fake_pct:.2f}%</div>
</div>
"""


def score_fallback_logits(logits: torch.Tensor, id2label: dict) -> tuple[float, str]:
    """Turn fallback-classifier logits into ``(prob_fake, label)``.

    Class indices are bucketed as fake/real by keyword match on their
    ``id2label`` names (one name may land in both buckets). A two-class model
    whose names match nothing is assumed to be ``[real, fake]``. When only one
    bucket matches, that bucket's share of the full softmax is used directly.
    ``INVERT_FALLBACK_OUTPUT`` flips the probability before thresholding.

    Args:
        logits: 1-D tensor of per-class logits.
        id2label: index -> class name (int or str keys).
    """
    logits = logits.float()
    # Softmax over a single logit is always 1.0, so a one-output head is read
    # as a sigmoid score instead (assumes the lone logit scores its named class).
    probs = torch.sigmoid(logits) if logits.numel() == 1 else torch.softmax(logits, dim=0)
    num_classes = len(probs)

    fake_indices = []
    real_indices = []
    for idx in range(num_classes):
        name = str(id2label.get(idx, id2label.get(str(idx), ""))).lower()
        if any(key in name for key in ["fake", "deepfake", "ai", "synthetic"]):
            fake_indices.append(idx)
        if any(key in name for key in ["real", "authentic", "genuine"]):
            real_indices.append(idx)

    if num_classes == 2 and not fake_indices and not real_indices:
        fake_indices = [1]
        real_indices = [0]

    fake_prob = float(probs[fake_indices].sum().item()) if fake_indices else 0.0
    real_prob = float(probs[real_indices].sum().item()) if real_indices else 0.0

    if fake_indices and real_indices and fake_prob + real_prob > 0:
        prob_fake = fake_prob / (fake_prob + real_prob)
    elif fake_indices and not real_indices:
        prob_fake = fake_prob
    elif real_indices and not fake_indices:
        prob_fake = 1.0 - real_prob
    elif num_classes == 1:
        prob_fake = float(probs[0].item())
    elif num_classes > 1:
        prob_fake = float(probs[1].item())
    else:
        prob_fake = 0.5

    if INVERT_FALLBACK_OUTPUT:
        prob_fake = 1.0 - prob_fake

    label = "Fake" if prob_fake >= THRESHOLD else "Real"
    return prob_fake, label


def _predict_tmos(model_obj, processor_obj, device: torch.device, image: Image.Image) -> tuple[float, str]:
    """Run the TMOS classifier on one image; return ``(prob_fake, label)`` from the sigmoid of its logit."""
    inputs = processor_obj(text=TMOS_PROMPT, images=image, return_tensors="pt", padding=True)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    outputs = model_obj(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        attention_mask=inputs["attention_mask"],
    )
    logit = float(outputs["logit"].squeeze().detach().float().cpu().item())
    if not math.isfinite(logit):
        raise gr.Error("Model produced a non-finite logit (NaN/Inf). Please retry.")
    prob_fake = float(torch.sigmoid(torch.tensor(logit)).item())
    return prob_fake, ("Fake" if prob_fake >= THRESHOLD else "Real")


def _predict_fallback(model_obj, processor_obj, device: torch.device, image: Image.Image) -> tuple[float, str]:
    """Run the fallback image classifier on one image and score it with ``score_fallback_logits``."""
    inputs = processor_obj(images=image, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    outputs = model_obj(**inputs)
    logits = outputs.logits.squeeze(0).detach().float().cpu()
    id2label = getattr(model_obj.config, "id2label", {}) or {}
    return score_fallback_logits(logits, id2label)


_UPLOAD_PROMPT_HTML = (
    '<div style="color: #ef4444;">Please upload an image before running detection.</div>'
)


def infer_image(image: Image.Image):
    """Classify one image for the Gradio UI (also reused by ``/predict``).

    Returns a 6-tuple ``(processed_image, label, prob_fake, confidence_pct,
    latency_ms, card_html)``. On failure, ``label`` holds the error message,
    the numeric fields and image are None, and ``card_html`` shows the error;
    exceptions are never propagated.
    """
    if image is None:
        return None, "Error: please upload an image.", None, None, None, _UPLOAD_PROMPT_HTML
    try:
        with _inference_lock:
            model_obj, processor_obj, device = load_model_and_processor()
            prepared_image = preprocess_image(image)
            autocast_context = (
                torch.autocast(device_type="cuda", dtype=select_torch_dtype())
                if device.type == "cuda"
                else nullcontext()
            )
            predict_fn = _predict_tmos if active_backend == "tmos" else _predict_fallback

            start_time = time.perf_counter()
            with torch.inference_mode(), autocast_context:
                prob_fake, label = predict_fn(model_obj, processor_obj, device, prepared_image)
            if device.type == "cuda":
                torch.cuda.synchronize()
            elapsed_ms = (time.perf_counter() - start_time) * 1000.0

        if not math.isfinite(prob_fake):
            raise gr.Error("Model produced a non-finite probability (NaN/Inf). Please retry.")

        confidence = prob_fake if label == "Fake" else 1.0 - prob_fake
        return (
            prepared_image,
            label,
            round(prob_fake, 6),
            round(confidence * 100.0, 2),
            round(elapsed_ms, 2),
            confidence_card(prob_fake, label),
        )
    except Exception as exc:
        traceback.print_exc()
        err = f"Inference failed: {type(exc).__name__}: {exc}"
        # Exception text is escaped because it is rendered as raw HTML.
        err_html = (
            '<div style="color: #ef4444;">'
            f"<strong>Inference error</strong><br>{html.escape(err)}"
            "</div>"
        )
        return None, err, None, None, None, err_html


api = FastAPI()
# NOTE(review): allow_origins=["*"] together with allow_credentials=True lets any
# site send credentialed requests; tighten origins if cookies/auth are ever added.
api.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@api.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image file.

    Responds 200 with ``verdict``/``confidence_percent``/``p_fake``/``latency_ms``,
    400 with ``error`` if the upload cannot be decoded as an image, and 500 with
    ``error`` if inference fails.
    """
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        return JSONResponse({"error": f"Invalid image: {exc}"}, status_code=400)
    except Exception as exc:
        return JSONResponse({"error": str(exc)}, status_code=500)

    try:
        # Inference is blocking (and may hold the GPU for a while); run it off
        # the event loop so other requests and the mounted Gradio app stay responsive.
        _, label, prob_fake, confidence, latency, _ = await run_in_threadpool(infer_image, image)
    except Exception as exc:
        return JSONResponse({"error": str(exc)}, status_code=500)

    if prob_fake is None:
        # infer_image reports failures in-band: the label slot carries the error message.
        return JSONResponse({"error": label}, status_code=500)

    return JSONResponse(
        {
            "verdict": label,
            "confidence_percent": confidence,
            "p_fake": prob_fake,
            "latency_ms": latency,
        }
    )


load_model_and_processor()

with gr.Blocks(title="TMOS Deepfake Detector", theme=gr.themes.Soft()) as demo:
    # Reflect the backend that actually loaded (TMOS may have fallen back to CPU).
    device_label = "GPU (TMOS Model)" if active_backend == "tmos" else "CPU Fallback Model"
    gr.Markdown(
        f"# TMOS Deepfake Detector\n"
        f"**Running on:** {device_label}\n\n"
        f"> Warning: runs on free infrastructure, so startup and inference may take time."
    )
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload image")
        with gr.Column():
            prediction_output = gr.Textbox(label="Prediction", interactive=False)
            probability_output = gr.Number(label="P(fake)", interactive=False, precision=6)
            confidence_output = gr.Number(label="Confidence (%)", interactive=False, precision=2)
            latency_output = gr.Number(label="Latency (ms)", interactive=False, precision=2)
    preview_output = gr.Image(label="Processed image passed to the model", interactive=False)
    confidence_html = gr.HTML()
    detect_button = gr.Button("Run detection", variant="primary")
    detect_button.click(
        fn=infer_image,
        inputs=image_input,
        outputs=[
            preview_output,
            prediction_output,
            probability_output,
            confidence_output,
            latency_output,
            confidence_html,
        ],
    )

demo.queue(default_concurrency_limit=1, max_size=8)
app = gr.mount_gradio_app(api, demo, path="/")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))