File size: 1,582 Bytes
f4ec716
65ad6d0
4d9cf23
 
9afe150
cdba4b4
9afe150
65ad6d0
9afe150
 
65ad6d0
 
 
 
 
 
 
 
 
9afe150
4d9cf23
 
 
f10a693
9afe150
 
f10a693
4d9cf23
 
 
 
 
 
9afe150
4d9cf23
 
 
f10a693
 
4d9cf23
2ffa87f
4d9cf23
a8f987b
4d9cf23
f10a693
4104437
4d9cf23
9afe150
cdba4b4
 
4d9cf23
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os
import torch
from PIL import Image
import gradio as gr
from evaluate import Detector, seed_torch

# --- Args configuration ---
class Args:
    """Static settings bundle that is instantiated and handed to ``Detector``.

    All values are class attributes, so they are readable both on the class
    and on any instance.
    """

    # Index of the GPU used to build the CUDA device string.
    gpu = 0
    # Use the selected CUDA device when torch reports CUDA, otherwise the CPU.
    if torch.cuda.is_available():
        device = f"cuda:{gpu}"
    else:
        device = "cpu"
    # Checkpoint locations, relative to the working directory
    # (presumably loaded by Detector -- confirm in evaluate.py).
    semantic_weights_path = "pretrained/semantic_weights.pth"
    artifact_weights_path = "pretrained/artifact_weights.pth"
    classifier_weights_path = "pretrained/classifier_weights.pth"
    batch_size = 32
    # Value passed to seed_torch at start-up.
    seed = 1024

# Instantiate the settings and apply the configured seed before the detector
# is constructed.
args = Args()
seed_torch(args.seed)

# --- Initialize the detector once ---
# Built at import time and reused by every request handled by gradio_scan.
detector = Detector(args)

# --- Scan a single image ---
def scan_image(detector, image):
    """Score one image with ``detector`` and attach a label.

    Args:
        detector: Object exposing ``model.test_transform``, ``device`` and
            ``predict`` (the ``Detector`` built at module level).
        image: Image object with a ``convert`` method (presumably a PIL
            image -- confirm against the caller), or ``None``.

    Returns:
        A ``(probability, label)`` pair. ``(0.0, "Invalid Image")`` when
        ``image`` is ``None``; otherwise the first element returned by
        ``detector.predict`` together with ``"AI-Generated"`` when it is
        above 0.5 and ``"Real"`` when it is not.
    """
    # Guard clause: nothing to score.
    if image is None:
        return 0.0, "Invalid Image"

    rgb = image.convert("RGB")
    # Transform, add a leading dimension with unsqueeze(0), then move the
    # result to the detector's device.
    batch = detector.model.test_transform(rgb).unsqueeze(0).to(detector.device)
    probability = detector.predict(batch)[0]

    if probability > 0.5:
        label = "AI-Generated"
    else:
        label = "Real"
    return probability, label

# --- Gradio callback for multiple images ---
def gradio_scan(images):
    """Scan every image submitted through the gallery input.

    Args:
        images: Value delivered by the ``gr.Gallery`` input. Gradio passes a
            list of ``(media, caption)`` tuples, or ``None`` when nothing was
            uploaded. A list of bare image objects is accepted as well.

    Returns:
        One ``Prediction: <prob>, Label: <label>`` line per image, joined
        with newlines. An empty string when no images were supplied.
    """
    # None (nothing uploaded) and an empty list both mean there is no work.
    if not images:
        return ""

    results = []
    for item in images:
        # Gallery entries arrive as (media, caption) pairs; only the media is
        # an image, so drop the caption before scanning.
        image = item[0] if isinstance(item, (tuple, list)) else item
        prob, label = scan_image(detector, image)
        results.append(f"Prediction: {prob:.3f}, Label: {label}")
    return "\n".join(results)

# --- Gradio interface ---
# Connects gradio_scan to a gallery input (PIL images, two columns) and a
# plain-text output.
iface = gr.Interface(
    fn=gradio_scan,
    inputs=gr.Gallery(label="Upload Images", columns=2, type="pil"),
    outputs="text",
    title="CO-SPY Scan Synthetic Images",
    description="Upload 1 hoặc nhiều ảnh và CO-SPY trả xác suất AI-generated hoặc Real."
)

# --- Launch ---
# Starts the app at import time; share=True requests a public share link and
# ssr_mode=False disables server-side rendering.
iface.launch(share=True, ssr_mode=False)