"""Gradio demo that scores uploaded images with the CO-SPY synthetic-image detector."""
import os
import torch
from PIL import Image
import gradio as gr
from evaluate import Detector, seed_torch


# --- Runtime configuration ---
class Args:
    """Static configuration passed to ``Detector`` and ``seed_torch``."""

    gpu = 0  # GPU index, used only when CUDA is available
    device = f"cuda:{gpu}" if torch.cuda.is_available() else "cpu"
    semantic_weights_path = "pretrained/semantic_weights.pth"
    artifact_weights_path = "pretrained/artifact_weights.pth"
    classifier_weights_path = "pretrained/classifier_weights.pth"
    batch_size = 32
    seed = 1024


args = Args()
seed_torch(args.seed)

# --- Build the detector once at startup so every request reuses it ---
detector = Detector(args)


def _to_pil(item):
    """Return the image held by a Gallery entry.

    ``gr.Gallery`` used as an input passes each entry as an
    ``(image, caption)`` tuple rather than a bare image. Bare images (or
    None) are returned unchanged so direct callers keep working.
    """
    if isinstance(item, (tuple, list)):
        return item[0] if item else None
    return item


# --- Score a single image ---
def scan_image(detector, image):
    """Score one image with *detector*.

    Args:
        detector: object exposing ``model.test_transform``, ``device`` and
            ``predict`` (the ``evaluate.Detector`` built above).
        image: a PIL image, a Gallery ``(image, caption)`` tuple, or None.

    Returns:
        ``(probability, label)``. The label is ``"AI-Generated"`` when the
        probability exceeds 0.5 and ``"Real"`` otherwise. When no image is
        given, the result is ``(0.0, "Invalid Image")``.
    """
    image = _to_pil(image)
    if image is None:
        return 0.0, "Invalid Image"
    image = image.convert("RGB")
    batch = detector.model.test_transform(image)
    batch = batch.unsqueeze(0).to(detector.device)  # batch of one image
    # Inference only: skip autograd bookkeeping (harmless if predict already does this).
    with torch.no_grad():
        # float() normalizes whatever scalar-like value predict yields.
        probability = float(detector.predict(batch)[0])
    label = "AI-Generated" if probability > 0.5 else "Real"
    return probability, label


# --- Gradio callback for multiple images ---
def gradio_scan(images):
    """Score every uploaded image and report one result line per image.

    Args:
        images: value of the ``gr.Gallery`` input, i.e. a list of
            ``(image, caption)`` tuples, or None when nothing was uploaded.

    Returns:
        Newline-joined ``"Prediction: <p>, Label: <label>"`` lines, or a
        short notice when the gallery is empty.
    """
    if not images:
        # Gallery yields None when empty; iterating it would raise TypeError.
        return "No images uploaded."
    scored = (scan_image(detector, item) for item in images)
    return "\n".join(
        f"Prediction: {prob:.3f}, Label: {label}" for prob, label in scored
    )


# --- Gradio interface ---
iface = gr.Interface(
    fn=gradio_scan,
    inputs=gr.Gallery(label="Upload Images", columns=2, type="pil"),
    outputs="text",
    title="CO-SPY Scan Synthetic Images",
    description="Upload 1 hoặc nhiều ảnh và CO-SPY trả xác suất AI-generated hoặc Real.",
)

# --- Launch ---
iface.launch(share=True, ssr_mode=False)