CO-SPY / app.py
NghiTran1009's picture
Update app.py
9afe150 verified
import os
import torch
from PIL import Image
import gradio as gr
from evaluate import Detector, seed_torch
# --- Configuration ---
class Args:
    """Static settings handed to ``Detector`` and used to seed the RNGs.

    Every setting is a class attribute, so ``Args()`` instances expose the
    same values as the class itself.
    """

    # Index of the CUDA device to use when a GPU is available.
    gpu: int = 0
    # Target device: cuda:<gpu> when CUDA is available, otherwise the CPU.
    device: str
    if torch.cuda.is_available():
        device = f"cuda:{gpu}"
    else:
        device = "cpu"

    # Weight checkpoints (relative paths, so resolved against the current
    # working directory); presumably loaded by Detector — not read directly here.
    semantic_weights_path: str = "pretrained/semantic_weights.pth"
    artifact_weights_path: str = "pretrained/artifact_weights.pth"
    classifier_weights_path: str = "pretrained/classifier_weights.pth"

    # Batch size; not read directly in this file (presumably used by Detector).
    batch_size: int = 32
    # Seed value for reproducibility (this file passes it to seed_torch()).
    seed: int = 1024
# Build the configuration and seed the RNGs before the detector is created.
# NOTE(review): the ordering presumably matters so that any random
# initialisation inside Detector is reproducible — confirm against
# seed_torch/Detector.
args = Args()
seed_torch(args.seed)
# --- Create the detector once, at import time. Every request handled by
# gradio_scan reuses this single module-level instance (presumably so the
# weight files configured in Args are loaded only once). ---
detector = Detector(args)
# --- Scan a single image ---
def scan_image(detector, image, threshold=0.5):
    """Classify one image as AI-generated or real.

    Args:
        detector: Object exposing ``model.test_transform``, ``device`` and
            ``predict`` (in this app, the module-level ``Detector``).
        image: A PIL-style image (anything with ``.convert("RGB")``), or
            ``None``.
        threshold: Scores strictly above this value are labelled
            "AI-Generated"; everything else is "Real". Defaults to 0.5.

    Returns:
        ``(probability, label)``. ``probability`` is always a plain Python
        ``float`` and ``label`` is "AI-Generated" or "Real". For a ``None``
        image the result is ``(0.0, "Invalid Image")``.
    """
    if image is None:
        return 0.0, "Invalid Image"

    # Normalise the colour mode (e.g. RGBA or grayscale uploads) before the
    # model's own test-time preprocessing.
    rgb = image.convert("RGB")
    # Apply test-time preprocessing, add a leading batch dimension of size 1,
    # and move the result to the detector's device.
    batch = detector.model.test_transform(rgb).unsqueeze(0).to(detector.device)
    # predict() presumably returns one score per batch item; take the only one.
    # float() gives callers a consistent numeric type, matching the 0.0 of the
    # None branch, whether predict yields a tensor, a numpy scalar or a float.
    probability = float(detector.predict(batch)[0])
    label = "AI-Generated" if probability > threshold else "Real"
    return probability, label
# --- Gradio callback for multiple images ---
def gradio_scan(images):
    """Scan every image uploaded to the gallery and report one line per image.

    Args:
        images: The value produced by the ``gr.Gallery`` input. In Gradio 4+
            this is a list of ``(image, caption)`` tuples, or ``None`` when
            nothing was uploaded; bare images are accepted too.

    Returns:
        A newline-separated report with one
        ``"Prediction: <p>, Label: <label>"`` line per image, or a short
        notice when no images were provided.
    """
    # An empty gallery arrives as None (or an empty list); iterating None
    # would raise TypeError, so answer with a notice instead.
    if not images:
        return "No images uploaded."

    results = []
    for item in images:
        # gr.Gallery hands over (image, caption) pairs rather than bare images;
        # unwrap them so scan_image receives the image object itself.
        image = item[0] if isinstance(item, (tuple, list)) else item
        prob, label = scan_image(detector, image)
        results.append(f"Prediction: {prob:.3f}, Label: {label}")
    return "\n".join(results)
# --- Gradio interface ---
# Two-column gallery input; uploads are converted to PIL images (type="pil")
# before being passed to the callback.
gallery_input = gr.Gallery(label="Upload Images", columns=2, type="pil")

iface = gr.Interface(
    fn=gradio_scan,
    inputs=gallery_input,
    outputs="text",
    title="CO-SPY Scan Synthetic Images",
    description="Upload 1 hoặc nhiều ảnh và CO-SPY trả xác suất AI-generated hoặc Real.",
)

# --- Launch ---
# share=True requests a public share link; ssr_mode=False turns off Gradio's
# server-side rendering.
iface.launch(share=True, ssr_mode=False)