Commit cab012d
1 Parent(s): f1aeecb

Upload clean CO-SPY project

Changed files:
- .gitignore +0 -0
- Datasets/__init__.py +11 -0
- Datasets/dataset.py +105 -0
- Datasets/flickr.py +30 -0
- Datasets/mscoco.py +34 -0
- Detectors/__init__.py +6 -0
- Detectors/artifact_detector.py +82 -0
- Detectors/artifact_extractor.py +162 -0
- Detectors/cospy_calibrate_detector.py +96 -0
- Detectors/cospy_detector.py +124 -0
- Detectors/semantic_detector.py +86 -0
- LICENSE +21 -0
- ProGANDetectors/__init__.py +5 -0
- ProGANDetectors/artifact_detector.py +196 -0
- ProGANDetectors/cospy_calibrate_detector.py +96 -0
- ProGANDetectors/semantic_detector.py +82 -0
- __pycache__/utils.cpython-311.pyc +0 -0
- calibrate_combine.py +297 -0
- data/in_the_wild/README.md +14 -0
- data/in_the_wild/urls/flux.txt +0 -0
- data/in_the_wild/urls/lexica.txt +0 -0
- data/test/README.md +13 -0
- data/train/download.sh +16 -0
- environment.yml +105 -0
- evaluate.py +227 -0
- pretrained/classifer_weights.pth +3 -0
- pretrained/classifier_weights.pth +3 -0
- pretrained/semantic_weights.pth +3 -0
- requirements.txt +8 -0
- train.py +271 -0
- train_single.py +293 -0
- utils.py +162 -0
.gitignore
ADDED
Binary file (16 Bytes).
Datasets/__init__.py
ADDED
@@ -0,0 +1,11 @@
from .dataset import TrainDataset, TestDataset

# List of evaluated real datasets
EVAL_DATASET_LIST = [
    "real"
]
# List of evaluated generative models
EVAL_MODEL_LIST = [
    "stable_diffusion"
]
__all__ = ["TrainDataset", "TestDataset", "EVAL_DATASET_LIST", "EVAL_MODEL_LIST"]
Datasets/dataset.py
ADDED
@@ -0,0 +1,105 @@
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset

from utils import get_list, png_to_jpeg
from .mscoco import MSCOCO2017
from .flickr import Flickr30k


class TrainDataset(Dataset):
    def __init__(self, data_path, split="train", transform=None, add_jpeg=True):
        assert split in ["train", "val"]

        # Load the dataset for training
        real_list = get_list(os.path.join(data_path, "mscoco2017", f"{split}2017"))
        fake_list = get_list(os.path.join(data_path, "stable-diffusion-v1-4", f"{split}2017"))

        # Setting the labels for the dataset
        self.labels_dict = {}
        for i in real_list:
            self.labels_dict[i] = 0
        for i in fake_list:
            self.labels_dict[i] = 1

        # Construct the entire dataset
        self.total_list = real_list + fake_list
        np.random.shuffle(self.total_list)

        # JPEG compression
        self.add_jpeg = add_jpeg

        # Transformations
        self.transform = transform

    def __len__(self):
        return len(self.total_list)

    def __getitem__(self, idx):
        img_path = self.total_list[idx]
        label = self.labels_dict[img_path]
        image = Image.open(img_path).convert("RGB")

        # Add JPEG compression
        if self.add_jpeg:
            image = png_to_jpeg(image, quality=95)

        # Apply the transformation
        if self.transform is not None:
            image = self.transform(image)
        return image, label


class TestDataset(Dataset):
    def __init__(self, dataset, model, root_path, transform=None, add_jpeg=True):
        fake_dir = os.path.join(root_path, dataset, model)
        self.fake = sorted([
            os.path.join(fake_dir, i)
            for i in os.listdir(fake_dir)
            if i.lower().endswith((".png", ".jpg", ".jpeg"))
        ])

        real_dir = os.path.join(root_path, dataset, "real")
        if not os.path.exists(real_dir):
            raise ValueError(f"Real images directory not found: {real_dir}")

        self.real = sorted([
            os.path.join(real_dir, i)
            for i in os.listdir(real_dir)
            if i.lower().endswith((".png", ".jpg", ".jpeg"))
        ])

        self.image_idx = list(range(len(self.real) + len(self.fake)))
        self.labels = [0] * len(self.real) + [1] * len(self.fake)
        self.image_paths = self.real + self.fake

        self.add_jpeg = add_jpeg
        self.transform = transform

    def __len__(self):
        return len(self.image_idx)

    def __getitem__(self, idx):
        if idx < len(self.real):
            img_path = self.real[idx]
        else:
            img_path = self.fake[idx - len(self.real)]

        # ---- FIX: Skip corrupted / unreadable images ----
        try:
            image = Image.open(img_path).convert("RGB")
        except Exception:
            print("Corrupted image:", img_path)
            # fall back to the next image instead
            return self.__getitem__((idx + 1) % len(self))

        if self.add_jpeg:
            image = png_to_jpeg(image, quality=95)

        if self.transform is not None:
            image = self.transform(image)

        label = self.labels[idx]
        return image, label, img_path
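A minimal usage sketch for the two dataset classes above, not part of the commit. The directory layout, the transform pipeline, and the dataset/model names are assumptions (the names mirror EVAL_DATASET_LIST and EVAL_MODEL_LIST from Datasets/__init__.py; train.py and evaluate.py are the authoritative versions):

# Sketch: wiring TrainDataset / TestDataset into DataLoaders (paths and transform are assumptions)
from torch.utils.data import DataLoader
from torchvision import transforms
from Datasets import TrainDataset, TestDataset

transform = transforms.Compose([
    transforms.Resize(384),
    transforms.CenterCrop(384),
    transforms.ToTensor(),
])

train_set = TrainDataset(data_path="data/train", split="train", transform=transform)
test_set = TestDataset(dataset="real", model="stable_diffusion", root_path="data/test", transform=transform)

train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False, num_workers=4)

images, labels = next(iter(train_loader))        # (B, 3, 384, 384), (B,)
images, labels, paths = next(iter(test_loader))  # TestDataset also returns the image paths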
Datasets/flickr.py
ADDED
@@ -0,0 +1,30 @@
import os
import json
import torch
import numpy as np
from PIL import Image
import datasets as ds


class Flickr30k(torch.utils.data.Dataset):
    def __init__(self, split='test', transform=None):
        # Split [test: 31014]
        self.dataset = ds.load_dataset("nlphuji/flickr30k")[split]

        # Preprocess the images
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        example = self.dataset[idx]
        # PIL RGB image
        image = example['image']
        if self.transform:
            image = self.transform(image)
        # A list of valid captions
        caption_list = example['caption']
        # Randomly select a caption
        caption = np.random.choice(caption_list)
        return image, caption
Datasets/mscoco.py
ADDED
@@ -0,0 +1,34 @@
import os
import json
import torch
import numpy as np
from PIL import Image
import datasets as ds


class MSCOCO2017(torch.utils.data.Dataset):
    def __init__(self, split='train', transform=None):
        # Split [train: 118287, val: 5000]
        self.dataset = ds.load_dataset(
            "shunk031/MSCOCO",
            year=2017,
            coco_task="captions"
        )[split]

        # Preprocess the images
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        example = self.dataset[idx]
        # PIL RGB image
        image = example['image'].convert('RGB')
        if self.transform:
            image = self.transform(image)
        # A list of valid captions
        caption_list = example['annotations']['caption']
        # Randomly select a caption
        caption = np.random.choice(caption_list)
        return image, caption
Detectors/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .artifact_detector import ArtifactDetector
from .semantic_detector import SemanticDetector
from .cospy_calibrate_detector import CospyCalibrateDetector
from .cospy_detector import CospyDetector, LabelSmoothingBCEWithLogits

__all__ = ["ArtifactDetector", "SemanticDetector", "CospyCalibrateDetector", "CospyDetector", "LabelSmoothingBCEWithLogits"]
Detectors/artifact_detector.py
ADDED
@@ -0,0 +1,82 @@
import torch
from diffusers import StableDiffusionPipeline
from .artifact_extractor import VAEReconEncoder
from torchvision import transforms
from utils import data_augment


# Artifact Detector (Extract artifact features using VAE)
class ArtifactDetector(torch.nn.Module):
    def __init__(self, dim_artifact=512, num_classes=1):
        super(ArtifactDetector, self).__init__()
        # Load the pre-trained VAE
        model_id = "CompVis/stable-diffusion-v1-4"
        vae = StableDiffusionPipeline.from_pretrained(model_id).vae
        # Freeze the VAE visual encoder
        vae.requires_grad_(False)
        self.artifact_encoder = VAEReconEncoder(vae)

        # Classifier
        self.fc = torch.nn.Linear(dim_artifact, num_classes)

        # Normalization
        self.mean = [0.0, 0.0, 0.0]
        self.std = [1.0, 1.0, 1.0]

        # Resolution
        self.loadSize = 256
        self.cropSize = 224

        # Data augmentation
        self.blur_prob = 0.0
        self.blur_sig = [0.0, 3.0]
        self.jpg_prob = 0.5
        self.jpg_method = ['cv2', 'pil']
        self.jpg_qual = list(range(70, 96))

        # Define the augmentation configuration
        self.aug_config = {
            "blur_prob": self.blur_prob,
            "blur_sig": self.blur_sig,
            "jpg_prob": self.jpg_prob,
            "jpg_method": self.jpg_method,
            "jpg_qual": self.jpg_qual,
        }

        # Pre-processing
        crop_func = transforms.RandomCrop(self.cropSize)
        flip_func = transforms.RandomHorizontalFlip()
        rz_func = transforms.Resize(self.loadSize)
        aug_func = transforms.Lambda(lambda x: data_augment(x, self.aug_config))

        self.train_transform = transforms.Compose([
            aug_func,
            rz_func,
            crop_func,
            flip_func,
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ])

        self.test_transform = transforms.Compose([
            rz_func,
            crop_func,
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ])

    def forward(self, x, return_feat=False):
        feat = self.artifact_encoder(x)
        out = self.fc(feat)
        if return_feat:
            return feat, out
        return out

    def save_weights(self, weights_path):
        save_params = {k: v.cpu() for k, v in self.state_dict().items()}
        torch.save(save_params, weights_path)

    def load_weights(self, weights_path):
        weights = torch.load(weights_path)
        self.load_state_dict(weights)
Detectors/artifact_extractor.py
ADDED
@@ -0,0 +1,162 @@
import torch
import torch.nn as nn
from torch.nn import functional as F


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class VAEReconEncoder(nn.Module):
    def __init__(self, vae, block=Bottleneck):
        super(VAEReconEncoder, self).__init__()

        # Define the ResNet model
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # ResNet-50 is [3, 4, 6, 3]
        self.layer1 = self._make_layer(block, 64, 3)
        self.layer2 = self._make_layer(block, 128, 4, stride=2)
        # self.layer3 = self._make_layer(block, 256, 6, stride=2)
        # self.layer4 = self._make_layer(block, 512, 3, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Kaiming initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Load the VAE model
        self.vae = vae

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def reconstruct(self, x):
        with torch.no_grad():
            # `.sample()` means to sample a latent vector from the distribution
            # `.mean` means to use the mean of the distribution
            latent = self.vae.encode(x).latent_dist.mean
            decoded = self.vae.decode(latent).sample
        return decoded

    def forward(self, x):
        # Reconstruct
        x_recon = self.reconstruct(x)
        # Compute the artifacts
        x = x - x_recon

        # Scale the artifacts
        x = x / 7. * 100.

        # Forward pass
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        return x
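To make the artifact signal explicit: the encoder above classifies not the image itself but the difference between the image and its VAE reconstruction, rescaled by the hard-coded factor 100/7. A standalone sketch of that residual computation, not part of the commit (the model id matches artifact_detector.py; the dummy input range is an assumption):

# Sketch: the reconstruction residual used as the "artifact" signal (illustrative)
import torch
from diffusers import StableDiffusionPipeline

vae = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").vae
vae.eval()

x = torch.rand(1, 3, 224, 224)  # dummy image batch

with torch.no_grad():
    latent = vae.encode(x).latent_dist.mean   # deterministic: posterior mean, not a sample
    x_recon = vae.decode(latent).sample       # decode back to image space
    artifact = (x - x_recon) / 7.0 * 100.0    # same scaling as VAEReconEncoder.forward

print(artifact.shape, artifact.abs().mean())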
Detectors/cospy_calibrate_detector.py
ADDED
@@ -0,0 +1,96 @@
import torch
from torchvision import transforms
from utils import data_augment
from .semantic_detector import SemanticDetector
from .artifact_detector import ArtifactDetector


# CO-SPY Calibrate Detector (Calibrate the integration of semantic and artifact detectors)
class CospyCalibrateDetector(torch.nn.Module):
    def __init__(self, semantic_weights_path, artifact_weights_path, num_classes=1):
        super(CospyCalibrateDetector, self).__init__()

        # Load the semantic detector
        self.sem = SemanticDetector()
        self.sem.load_weights(semantic_weights_path)

        # Load the artifact detector
        self.art = ArtifactDetector()
        self.art.load_weights(artifact_weights_path)

        # Freeze the two pre-trained models
        for param in self.sem.parameters():
            param.requires_grad = False
        for param in self.art.parameters():
            param.requires_grad = False

        # Classifier
        self.fc = torch.nn.Linear(2, num_classes)

        # Transformations inside the forward function
        # Including the normalization and resizing (only for the artifact detector)
        self.sem_transform = transforms.Compose([
            transforms.Normalize(self.sem.mean, self.sem.std)
        ])
        self.art_transform = transforms.Compose([
            transforms.Resize(self.art.cropSize, antialias=False),
            transforms.Normalize(self.art.mean, self.art.std)
        ])

        # Resolution
        self.loadSize = 384
        self.cropSize = 384

        # Data augmentation
        self.blur_prob = 0.0
        self.blur_sig = [0.0, 3.0]
        self.jpg_prob = 0.5
        self.jpg_method = ['cv2', 'pil']
        self.jpg_qual = list(range(70, 96))

        # Define the augmentation configuration
        self.aug_config = {
            "blur_prob": self.blur_prob,
            "blur_sig": self.blur_sig,
            "jpg_prob": self.jpg_prob,
            "jpg_method": self.jpg_method,
            "jpg_qual": self.jpg_qual,
        }

        # Pre-processing
        crop_func = transforms.RandomCrop(self.cropSize)
        flip_func = transforms.RandomHorizontalFlip()
        rz_func = transforms.Resize(self.loadSize)
        aug_func = transforms.Lambda(lambda x: data_augment(x, self.aug_config))

        self.train_transform = transforms.Compose([
            flip_func,
            aug_func,
            rz_func,
            crop_func,
            transforms.ToTensor(),
        ])

        self.test_transform = transforms.Compose([
            rz_func,
            crop_func,
            transforms.ToTensor(),
        ])

    def forward(self, x):
        x_sem = self.sem_transform(x)
        x_art = self.art_transform(x)
        pred_sem = self.sem(x_sem)
        pred_art = self.art(x_art)
        x = torch.cat([pred_sem, pred_art], dim=1)
        x = self.fc(x)
        return x

    def save_weights(self, weights_path):
        save_params = {"fc.weight": self.fc.weight.cpu(), "fc.bias": self.fc.bias.cpu()}
        torch.save(save_params, weights_path)

    def load_weights(self, weights_path):
        weights = torch.load(weights_path)
        self.fc.weight.data = weights["fc.weight"]
        self.fc.bias.data = weights["fc.bias"]
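A rough inference sketch for the calibrated detector, not part of the commit. Which file under pretrained/ corresponds to which branch is an assumption, as is the calibration checkpoint name and the 0.5 decision threshold:

# Sketch: scoring a single image with CospyCalibrateDetector (checkpoint names are assumptions)
import torch
from PIL import Image
from Detectors import CospyCalibrateDetector

model = CospyCalibrateDetector(
    semantic_weights_path="pretrained/semantic_weights.pth",
    artifact_weights_path="pretrained/artifact_weights.pth",   # hypothetical filename
)
model.load_weights("pretrained/calibrate_weights.pth")          # hypothetical calibration checkpoint
model.eval()

image = Image.open("example.png").convert("RGB")
x = model.test_transform(image).unsqueeze(0)   # (1, 3, 384, 384)

with torch.no_grad():
    prob_fake = model(x).sigmoid().item()
print("fake" if prob_fake > 0.5 else "real", prob_fake)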
Detectors/cospy_detector.py
ADDED
@@ -0,0 +1,124 @@
import torch
import random
from torchvision import transforms
from utils import data_augment, weights2cpu
from .semantic_detector import SemanticDetector
from .artifact_detector import ArtifactDetector


# CO-SPY Detector
class CospyDetector(torch.nn.Module):
    def __init__(self, num_classes=1):
        super(CospyDetector, self).__init__()

        # Load the semantic detector
        self.sem = SemanticDetector()
        self.sem_dim = self.sem.fc.in_features

        # Load the artifact detector
        self.art = ArtifactDetector()
        self.art_dim = self.art.fc.in_features

        # Classifier
        self.fc = torch.nn.Linear(self.sem_dim + self.art_dim, num_classes)

        # Transformations inside the forward function
        # Including the normalization and resizing (only for the artifact detector)
        self.sem_transform = transforms.Compose([
            transforms.Normalize(self.sem.mean, self.sem.std)
        ])
        self.art_transform = transforms.Compose([
            transforms.Resize(self.art.cropSize, antialias=False),
            transforms.Normalize(self.art.mean, self.art.std)
        ])

        # Resolution
        self.loadSize = 384
        self.cropSize = 384

        # Data augmentation
        self.blur_prob = 0.0
        self.blur_sig = [0.0, 3.0]
        self.jpg_prob = 0.5
        self.jpg_method = ['cv2', 'pil']
        self.jpg_qual = list(range(70, 96))

        # Define the augmentation configuration
        self.aug_config = {
            "blur_prob": self.blur_prob,
            "blur_sig": self.blur_sig,
            "jpg_prob": self.jpg_prob,
            "jpg_method": self.jpg_method,
            "jpg_qual": self.jpg_qual,
        }

        # Pre-processing
        crop_func = transforms.RandomCrop(self.cropSize)
        flip_func = transforms.RandomHorizontalFlip()
        rz_func = transforms.Resize(self.loadSize)
        aug_func = transforms.Lambda(lambda x: data_augment(x, self.aug_config))

        self.train_transform = transforms.Compose([
            flip_func,
            aug_func,
            rz_func,
            crop_func,
            transforms.ToTensor(),
        ])

        self.test_transform = transforms.Compose([
            rz_func,
            crop_func,
            transforms.ToTensor(),
        ])

    def forward(self, x, dropout_rate=0.3):
        x_sem = self.sem_transform(x)
        x_art = self.art_transform(x)

        # Forward pass
        sem_feat, sem_coeff = self.sem(x_sem, return_feat=True)
        art_feat, art_coeff = self.art(x_art, return_feat=True)

        # Dropout (only while the module is in training mode)
        if self.training:
            # Random dropout
            if random.random() < dropout_rate:
                # Randomly select a feature to drop
                idx_drop = random.randint(0, 1)
                if idx_drop == 0:
                    sem_coeff = torch.zeros_like(sem_coeff)
                else:
                    art_coeff = torch.zeros_like(art_coeff)

        # Concatenate the features
        x = torch.cat([sem_coeff * sem_feat, art_coeff * art_feat], dim=1)
        x = self.fc(x)

        return x

    def save_weights(self, weights_path):
        save_params = {
            "sem_fc": weights2cpu(self.sem.fc.state_dict()),
            "art_fc": weights2cpu(self.art.fc.state_dict()),
            "art_encoder": weights2cpu(self.art.artifact_encoder.state_dict()),
            "classifier": weights2cpu(self.fc.state_dict()),
        }
        torch.save(save_params, weights_path)

    def load_weights(self, weights_path):
        weights = torch.load(weights_path)
        self.sem.fc.load_state_dict(weights["sem_fc"])
        self.art.fc.load_state_dict(weights["art_fc"])
        self.art.artifact_encoder.load_state_dict(weights["art_encoder"])
        self.fc.load_state_dict(weights["classifier"])


# Define the label smoothing loss
class LabelSmoothingBCEWithLogits(torch.nn.Module):
    def __init__(self, smoothing=0.1):
        super(LabelSmoothingBCEWithLogits, self).__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        target = target.float() * (1.0 - self.smoothing) + 0.5 * self.smoothing
        loss = torch.nn.functional.binary_cross_entropy_with_logits(pred, target, reduction='mean')
        return loss
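A compact sketch of one optimization step with CospyDetector and the label-smoothing loss defined above, not part of the commit; the optimizer settings and dummy batch are assumptions, and train.py in this commit is the authoritative training loop:

# Sketch: one training step for CospyDetector (illustrative only)
import torch
from Detectors import CospyDetector, LabelSmoothingBCEWithLogits

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CospyDetector().to(device)
criterion = LabelSmoothingBCEWithLogits(smoothing=0.1)
optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-4)

model.train()
images = torch.rand(4, 3, 384, 384, device=device)    # dummy batch in [0, 1]
labels = torch.randint(0, 2, (4, 1), device=device)   # 0 = real, 1 = fake

logits = model(images)            # (4, 1)
loss = criterion(logits, labels)  # smoothed BCE-with-logits
optimizer.zero_grad()
loss.backward()
optimizer.step()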
Detectors/semantic_detector.py
ADDED
@@ -0,0 +1,86 @@
import torch
import open_clip
from torchvision import transforms
from utils import data_augment


# Semantic Detector (Extract semantic features using CLIP)
class SemanticDetector(torch.nn.Module):
    def __init__(self, dim_clip=1152, num_classes=1):
        super(SemanticDetector, self).__init__()

        # Get the pre-trained CLIP
        model_name = "ViT-SO400M-14-SigLIP-384"
        version = "webli"
        self.clip, _, _ = open_clip.create_model_and_transforms(model_name, pretrained=version)
        # Freeze the CLIP visual encoder
        self.clip.requires_grad_(False)

        # Classifier
        self.fc = torch.nn.Linear(dim_clip, num_classes)

        # Normalization
        self.mean = [0.5, 0.5, 0.5]
        self.std = [0.5, 0.5, 0.5]

        # Resolution
        self.loadSize = 384
        self.cropSize = 384

        # Data augmentation
        self.blur_prob = 0.5
        self.blur_sig = [0.0, 3.0]
        self.jpg_prob = 0.5
        self.jpg_method = ['cv2', 'pil']
        self.jpg_qual = list(range(30, 101))

        # Define the augmentation configuration
        self.aug_config = {
            "blur_prob": self.blur_prob,
            "blur_sig": self.blur_sig,
            "jpg_prob": self.jpg_prob,
            "jpg_method": self.jpg_method,
            "jpg_qual": self.jpg_qual,
        }

        # Pre-processing
        crop_func = transforms.RandomCrop(self.cropSize)
        flip_func = transforms.RandomHorizontalFlip()
        rz_func = transforms.Resize(self.loadSize)
        aug_func = transforms.Lambda(lambda x: data_augment(x, self.aug_config))

        self.train_transform = transforms.Compose([
            rz_func,
            aug_func,
            crop_func,
            flip_func,
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ])

        self.test_transform = transforms.Compose([
            rz_func,
            crop_func,
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ])

    def forward(self, x, return_feat=False):
        device = next(self.fc.parameters()).device  # device of the fc layer
        x = x.to(device)                            # make sure the input is on the same device
        feat = self.clip.encode_image(x)
        feat = feat.to(device)                      # make sure feat is on the same device as fc
        out = self.fc(feat)
        if return_feat:
            return feat, out
        return out

    def save_weights(self, weights_path):
        save_params = {"fc.weight": self.fc.weight.cpu(), "fc.bias": self.fc.bias.cpu()}
        torch.save(save_params, weights_path)

    def load_weights(self, weights_path):
        device = next(self.fc.parameters()).device  # current device of the model
        weights = torch.load(weights_path, map_location=device)
        self.fc.weight.data = weights["fc.weight"].to(device)
        self.fc.bias.data = weights["fc.bias"].to(device)
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Siyuan Cheng

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
ProGANDetectors/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .artifact_detector import ArtifactDetectorProGAN
from .semantic_detector import SemanticDetectorProGAN
from .cospy_calibrate_detector import CospyCalibrateDetectorProGAN

__all__ = ["ArtifactDetectorProGAN", "SemanticDetectorProGAN", "CospyCalibrateDetectorProGAN"]
ProGANDetectors/artifact_detector.py
ADDED
@@ -0,0 +1,196 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision import transforms
from utils import data_augment


# Artifact Detector (Extract artifact features using an NPR-style residual)
class ArtifactDetectorProGAN(torch.nn.Module):
    def __init__(self, dim_artifact=512, num_classes=1):
        super(ArtifactDetectorProGAN, self).__init__()
        # Load the artifact encoder based on NPR
        self.artifact_encoder = ResNet(Bottleneck, [3, 4, 6, 3])

        # Classifier
        self.fc = torch.nn.Linear(dim_artifact, num_classes)

        # Normalization
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

        # Resolution
        self.loadSize = 256
        self.cropSize = 224

        # Data augmentation
        self.blur_prob = 0.0
        self.blur_sig = [0.0, 3.0]
        self.jpg_prob = 0.0
        self.jpg_method = ['cv2', 'pil']
        self.jpg_qual = list(range(70, 96))

        # Define the augmentation configuration
        self.aug_config = {
            "blur_prob": self.blur_prob,
            "blur_sig": self.blur_sig,
            "jpg_prob": self.jpg_prob,
            "jpg_method": self.jpg_method,
            "jpg_qual": self.jpg_qual,
        }

        # Pre-processing
        crop_func = transforms.RandomCrop(self.cropSize)
        flip_func = transforms.RandomHorizontalFlip()
        rz_func = transforms.Resize(self.loadSize)
        aug_func = transforms.Lambda(lambda x: data_augment(x, self.aug_config))

        self.train_transform = transforms.Compose([
            aug_func,
            rz_func,
            crop_func,
            flip_func,
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ])

        self.test_transform = transforms.Compose([
            rz_func,
            crop_func,
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ])

    def forward(self, x, return_feat=False):
        feat = self.artifact_encoder(x)
        out = self.fc(feat)
        if return_feat:
            return feat, out
        return out

    def save_weights(self, weights_path):
        save_params = {k: v.cpu() for k, v in self.state_dict().items()}
        torch.save(save_params, weights_path)

    def load_weights(self, weights_path):
        weights = torch.load(weights_path)
        self.load_state_dict(weights)


# Define the artifact encoder (based on NPR)
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1):
        super(ResNet, self).__init__()

        self.unfoldSize = 2
        self.unfoldIndex = 0
        assert self.unfoldSize > 1
        assert -1 < self.unfoldIndex and self.unfoldIndex < self.unfoldSize*self.unfoldSize
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(512, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def interpolate(self, img, factor):
        return F.interpolate(
            F.interpolate(img,
                          scale_factor=factor,
                          mode='nearest',
                          recompute_scale_factor=True),
            scale_factor=1 / factor,
            mode='nearest',
            recompute_scale_factor=True)

    def forward(self, x):
        artifact = x - self.interpolate(x, 0.5)

        x = self.conv1(artifact * 2.0 / 3.0)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        return x
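The NPR-style artifact above is simply the difference between the image and a nearest-neighbour down/up-sampled copy of itself. A standalone sketch of that residual, not part of the commit (shapes are illustrative):

# Sketch: the NPR-style residual computed in ResNet.forward above (illustrative)
import torch
import torch.nn.functional as F

def npr_residual(img, factor=0.5):
    # down-sample then up-sample with nearest neighbour, as in ResNet.interpolate
    down = F.interpolate(img, scale_factor=factor, mode='nearest', recompute_scale_factor=True)
    up = F.interpolate(down, scale_factor=1 / factor, mode='nearest', recompute_scale_factor=True)
    return img - up

x = torch.rand(1, 3, 224, 224)
artifact = npr_residual(x)   # what the encoder actually sees (scaled by 2/3 in forward)
print(artifact.shape)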
ProGANDetectors/cospy_calibrate_detector.py
ADDED
@@ -0,0 +1,96 @@
import torch
from torchvision import transforms
from utils import data_augment
from .semantic_detector import SemanticDetectorProGAN
from .artifact_detector import ArtifactDetectorProGAN


# CO-SPY Calibrate Detector (Calibrate the integration of semantic and artifact detectors)
class CospyCalibrateDetectorProGAN(torch.nn.Module):
    def __init__(self, semantic_weights_path, artifact_weights_path, num_classes=1):
        super(CospyCalibrateDetectorProGAN, self).__init__()

        # Load the semantic detector
        self.sem = SemanticDetectorProGAN()
        self.sem.load_weights(semantic_weights_path)

        # Load the artifact detector
        self.art = ArtifactDetectorProGAN()
        self.art.load_weights(artifact_weights_path)

        # Freeze the two pre-trained models
        for param in self.sem.parameters():
            param.requires_grad = False
        for param in self.art.parameters():
            param.requires_grad = False

        # Classifier
        self.fc = torch.nn.Linear(2, num_classes)

        # Transformations inside the forward function
        # Including the normalization and resizing (only for the artifact detector)
        self.sem_transform = transforms.Compose([
            transforms.Normalize(self.sem.mean, self.sem.std)
        ])
        self.art_transform = transforms.Compose([
            transforms.Resize(self.art.cropSize, antialias=False),
            transforms.Normalize(self.art.mean, self.art.std)
        ])

        # Resolution
        self.loadSize = 256
        self.cropSize = 224

        # Data augmentation
        self.blur_prob = 0.0
        self.blur_sig = [0.0, 3.0]
        self.jpg_prob = 0.0
        self.jpg_method = ['cv2', 'pil']
        self.jpg_qual = list(range(70, 96))

        # Define the augmentation configuration
        self.aug_config = {
            "blur_prob": self.blur_prob,
            "blur_sig": self.blur_sig,
            "jpg_prob": self.jpg_prob,
            "jpg_method": self.jpg_method,
            "jpg_qual": self.jpg_qual,
        }

        # Pre-processing
        crop_func = transforms.RandomCrop(self.cropSize)
        flip_func = transforms.RandomHorizontalFlip()
        rz_func = transforms.Resize(self.loadSize)
        aug_func = transforms.Lambda(lambda x: data_augment(x, self.aug_config))

        self.train_transform = transforms.Compose([
            flip_func,
            aug_func,
            rz_func,
            crop_func,
            transforms.ToTensor(),
        ])

        self.test_transform = transforms.Compose([
            rz_func,
            crop_func,
            transforms.ToTensor(),
        ])

    def forward(self, x):
        x_sem = self.sem_transform(x)
        x_art = self.art_transform(x)
        pred_sem = self.sem(x_sem)
        pred_art = self.art(x_art)
        x = torch.cat([pred_sem, pred_art], dim=1)
        x = self.fc(x)
        return x

    def save_weights(self, weights_path):
        save_params = {"fc.weight": self.fc.weight.cpu(), "fc.bias": self.fc.bias.cpu()}
        torch.save(save_params, weights_path)

    def load_weights(self, weights_path):
        weights = torch.load(weights_path)
        self.fc.weight.data = weights["fc.weight"]
        self.fc.bias.data = weights["fc.bias"]
ProGANDetectors/semantic_detector.py
ADDED
@@ -0,0 +1,82 @@
import torch
from transformers import CLIPModel
from torchvision import transforms
from utils import data_augment


# Semantic Detector (Extract semantic features using CLIP)
class SemanticDetectorProGAN(torch.nn.Module):
    def __init__(self, dim_clip=768, num_classes=1):
        super(SemanticDetectorProGAN, self).__init__()

        # Get the pre-trained CLIP
        model_name = "openai/clip-vit-large-patch14"
        self.clip = CLIPModel.from_pretrained(model_name)

        # Freeze the CLIP visual encoder
        self.clip.requires_grad_(False)

        # Classifier
        self.fc = torch.nn.Linear(dim_clip, num_classes)

        # Normalization
        self.mean = [0.48145466, 0.4578275, 0.40821073]
        self.std = [0.26862954, 0.26130258, 0.27577711]

        # Resolution
        self.loadSize = 256
        self.cropSize = 224

        # Data augmentation
        self.blur_prob = 0.5
        self.blur_sig = [0.0, 3.0]
        self.jpg_prob = 0.5
        self.jpg_method = ['cv2', 'pil']
        self.jpg_qual = list(range(30, 101))

        # Define the augmentation configuration
        self.aug_config = {
            "blur_prob": self.blur_prob,
            "blur_sig": self.blur_sig,
            "jpg_prob": self.jpg_prob,
            "jpg_method": self.jpg_method,
            "jpg_qual": self.jpg_qual,
        }

        # Pre-processing
        crop_func = transforms.RandomCrop(self.cropSize)
        flip_func = transforms.RandomHorizontalFlip()
        rz_func = transforms.Resize(self.loadSize)
        aug_func = transforms.Lambda(lambda x: data_augment(x, self.aug_config))

        self.train_transform = transforms.Compose([
            rz_func,
            aug_func,
            crop_func,
            flip_func,
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ])

        self.test_transform = transforms.Compose([
            rz_func,
            crop_func,
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ])

    def forward(self, x, return_feat=False):
        feat = self.clip.get_image_features(x)
        out = self.fc(feat)
        if return_feat:
            return feat, out
        return out

    def save_weights(self, weights_path):
        save_params = {"fc.weight": self.fc.weight.cpu(), "fc.bias": self.fc.bias.cpu()}
        torch.save(save_params, weights_path)

    def load_weights(self, weights_path):
        weights = torch.load(weights_path)
        self.fc.weight.data = weights["fc.weight"]
        self.fc.bias.data = weights["fc.bias"]
__pycache__/utils.cpython-311.pyc
ADDED
Binary file (9.5 kB).
calibrate_combine.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import json
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from loguru import logger
|
| 8 |
+
from sklearn.metrics import average_precision_score
|
| 9 |
+
|
| 10 |
+
from Detectors import CospyCalibrateDetector
|
| 11 |
+
from Datasets import TrainDataset, TestDataset, EVAL_DATASET_LIST, EVAL_MODEL_LIST
|
| 12 |
+
from utils import seed_torch
|
| 13 |
+
|
| 14 |
+
import warnings
|
| 15 |
+
warnings.filterwarnings("ignore")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Detector():
|
| 19 |
+
def __init__(self, args):
|
| 20 |
+
super(Detector, self).__init__()
|
| 21 |
+
|
| 22 |
+
# Device
|
| 23 |
+
self.device = args.device
|
| 24 |
+
|
| 25 |
+
# ===== Khởi tạo model =====
|
| 26 |
+
self.model = CospyCalibrateDetector(
|
| 27 |
+
semantic_weights_path=args.semantic_weights_path,
|
| 28 |
+
artifact_weights_path=args.artifact_weights_path
|
| 29 |
+
)
|
| 30 |
+
self.model.to(self.device)
|
| 31 |
+
|
| 32 |
+
# Khởi tạo fc layer nếu muốn
|
| 33 |
+
torch.nn.init.normal_(self.model.fc.weight.data, 0.0, 0.02)
|
| 34 |
+
|
| 35 |
+
# ===== Optimizer =====
|
| 36 |
+
_lr = 1e-1
|
| 37 |
+
_beta1 = 0.9
|
| 38 |
+
_weight_decay = 0.0
|
| 39 |
+
params = [p for p in self.model.parameters() if p.requires_grad]
|
| 40 |
+
print(f'Trainable parameters: {len(params)}')
|
| 41 |
+
self.optimizer = torch.optim.AdamW(params, lr=_lr, betas=(_beta1, 0.999), weight_decay=_weight_decay)
|
| 42 |
+
|
| 43 |
+
# ===== Loss =====
|
| 44 |
+
self.criterion = torch.nn.BCEWithLogitsLoss()
|
| 45 |
+
|
| 46 |
+
# Scheduler
|
| 47 |
+
self.delr_freq = 10
|
| 48 |
+
|
| 49 |
+
# ===== Load checkpoint nếu có =====
|
| 50 |
+
if args.resume is not None:
|
| 51 |
+
print(f"Loading checkpoint from {args.resume}")
|
| 52 |
+
state = torch.load(args.resume, map_location=self.device)
|
| 53 |
+
|
| 54 |
+
# hỗ trợ cả 2 dạng: {'model': state_dict} hoặc state_dict trực tiếp
|
| 55 |
+
if isinstance(state, dict) and "model" in state:
|
| 56 |
+
state = state["model"]
|
| 57 |
+
|
| 58 |
+
self.model.load_state_dict(state, strict=False)
|
| 59 |
+
print("Checkpoint loaded. Continue training...")
|
| 60 |
+
|
| 61 |
+
self.model.to(self.device)
|
| 62 |
+
self.model.train()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# Training function for the detector
|
| 67 |
+
def train_step(self, batch_data):
|
| 68 |
+
# Decompose the batch data
|
| 69 |
+
inputs, labels = batch_data
|
| 70 |
+
inputs, labels = inputs.to(self.device), labels.to(self.device)
|
| 71 |
+
|
| 72 |
+
self.optimizer.zero_grad()
|
| 73 |
+
outputs = self.model(inputs)
|
| 74 |
+
loss = self.criterion(outputs, labels.unsqueeze(1).float())
|
| 75 |
+
loss.backward()
|
| 76 |
+
self.optimizer.step()
|
| 77 |
+
|
| 78 |
+
eval_loss = loss.item()
|
| 79 |
+
y_pred = outputs.sigmoid().flatten().tolist()
|
| 80 |
+
y_true = labels.tolist()
|
| 81 |
+
return eval_loss, y_pred, y_true
|
| 82 |
+
|
| 83 |
+
# Schedule the training
|
| 84 |
+
# Early stopping / learning rate adjustment
|
| 85 |
+
def scheduler(self, status_dict):
|
| 86 |
+
epoch = status_dict['epoch']
|
| 87 |
+
if epoch % self.delr_freq == 0 and epoch != 0:
|
| 88 |
+
for param_group in self.optimizer.param_groups:
|
| 89 |
+
param_group['lr'] *= 0.9
|
| 90 |
+
self.lr = param_group['lr']
|
| 91 |
+
return True
|
| 92 |
+
|
| 93 |
+
# Prediction function
|
| 94 |
+
def predict(self, inputs):
|
| 95 |
+
inputs = inputs.to(self.device)
|
| 96 |
+
outputs = self.model(inputs)
|
| 97 |
+
prediction = outputs.sigmoid().flatten().tolist()
|
| 98 |
+
return prediction
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def evaluate(y_pred, y_true):
|
| 102 |
+
ap = average_precision_score(y_true, y_pred)
|
| 103 |
+
accuracy = ((np.array(y_pred) > 0.5) == y_true).mean()
|
| 104 |
+
return ap, accuracy
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def train(args):
|
| 108 |
+
# Set the saving directory first
|
| 109 |
+
model_dir = os.path.join(args.ckpt, "cospy_calibrate")
|
| 110 |
+
if not os.path.exists(model_dir):
|
| 111 |
+
os.makedirs(model_dir)
|
| 112 |
+
|
| 113 |
+
log_path = f"{model_dir}/training.log"
|
| 114 |
+
if os.path.exists(log_path):
|
| 115 |
+
os.remove(log_path)
|
| 116 |
+
|
| 117 |
+
logger_id = logger.add(
|
| 118 |
+
log_path,
|
| 119 |
+
format="{time:MM-DD at HH:mm:ss} | {level} | {module}:{line} | {message}",
|
| 120 |
+
level="DEBUG",
|
| 121 |
+
)
|
| 122 |
+
# Get the detector
|
| 123 |
+
detector = Detector(args)
|
| 124 |
+
# --- Resume checkpoint ---
|
| 125 |
+
start_epoch = 0
|
| 126 |
+
best_acc = 0
|
| 127 |
+
|
| 128 |
+
if args.resume:
|
| 129 |
+
resume_path = os.path.join(model_dir, "best_model.pth")
|
| 130 |
+
if os.path.exists(resume_path):
|
| 131 |
+
print(f"Resuming from {resume_path} ...")
|
| 132 |
+
detector.model.load_weights(resume_path)
|
| 133 |
+
detector.model.to(args.device)
|
| 134 |
+
|
| 135 |
+
# Load the calibration dataset using the "val" split
|
| 136 |
+
train_dataset = TrainDataset(data_path=args.calibration_dirpath,
|
| 137 |
+
split="val",
|
| 138 |
+
transform=detector.model.test_transform)
|
| 139 |
+
|
| 140 |
+
train_loader = torch.utils.data.DataLoader(train_dataset,
|
| 141 |
+
batch_size=args.batch_size,
|
| 142 |
+
shuffle=True,
|
| 143 |
+
num_workers=4,
|
| 144 |
+
pin_memory=True)
|
| 145 |
+
|
| 146 |
+
logger.info(f"Train size {len(train_dataset)}")
|
| 147 |
+
|
| 148 |
+
# Set the saving directory
|
| 149 |
+
model_dir = os.path.join(args.ckpt, "cospy_calibrate")
|
| 150 |
+
if not os.path.exists(model_dir):
|
| 151 |
+
os.makedirs(model_dir)
|
| 152 |
+
log_path = f"{model_dir}/training.log"
|
| 153 |
+
if os.path.exists(log_path):
|
| 154 |
+
os.remove(log_path)
|
| 155 |
+
|
| 156 |
+
logger_id = logger.add(
|
| 157 |
+
log_path,
|
| 158 |
+
format="{time:MM-DD at HH:mm:ss} | {level} | {module}:{line} | {message}",
|
| 159 |
+
level="DEBUG",
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
# Train the detector
|
| 163 |
+
best_acc = 0
|
| 164 |
+
for epoch in range(start_epoch, args.epochs):
|
| 165 |
+
# Set the model to training mode
|
| 166 |
+
detector.model.train()
|
| 167 |
+
time_start = time.time()
|
| 168 |
+
for step_id, batch_data in enumerate(train_loader):
|
| 169 |
+
eval_loss, y_pred, y_true = detector.train_step(batch_data)
|
| 170 |
+
ap, accuracy = evaluate(y_pred, y_true)
|
| 171 |
+
|
| 172 |
+
# Log the training information
|
| 173 |
+
if (step_id + 1) % 100 == 0:
|
| 174 |
+
time_end = time.time()
|
| 175 |
+
logger.info(f"Epoch {epoch} | Batch {step_id + 1}/{len(train_loader)} | Loss {eval_loss:.4f} | AP {ap*100:.2f}% | Accuracy {accuracy*100:.2f}% | Time {time_end-time_start:.2f}s")
|
| 176 |
+
time_start = time.time()
|
| 177 |
+
|
| 178 |
+
# Evaluate the model
|
| 179 |
+
detector.model.eval()
|
| 180 |
+
y_pred, y_true = [], []
|
| 181 |
+
for inputs in train_loader:
|
| 182 |
+
inputs, labels = inputs
|
| 183 |
+
y_pred.extend(detector.predict(inputs))
|
| 184 |
+
y_true.extend(labels.tolist())
|
| 185 |
+
|
| 186 |
+
ap, accuracy = evaluate(y_pred, y_true)
|
| 187 |
+
logger.info(f"Epoch {epoch} | Total AP {ap*100:.2f}% | Total Accuracy {accuracy*100:.2f}%")
|
| 188 |
+
|
| 189 |
+
# Schedule the training
|
| 190 |
+
status_dict = {'epoch': epoch, 'AP': ap, 'Accuracy': accuracy}
|
| 191 |
+
proceed = detector.scheduler(status_dict)
|
| 192 |
+
if not proceed:
|
| 193 |
+
logger.info("Early stopping")
|
| 194 |
+
break
|
| 195 |
+
|
| 196 |
+
# Save the model
|
| 197 |
+
if accuracy >= best_acc:
|
| 198 |
+
best_acc = accuracy
|
| 199 |
+
detector.model.save_weights(f"{model_dir}/best_model.pth")
|
| 200 |
+
logger.info(f"Best model saved with accuracy {best_acc.mean()*100:.2f}%")
|
| 201 |
+
|
| 202 |
+
if epoch % 5 == 0:
|
| 203 |
+
detector.model.save_weights(f"{model_dir}/epoch_{epoch}.pth")
|
| 204 |
+
logger.info(f"Model saved at epoch {epoch}")
|
| 205 |
+
|
| 206 |
+
# Save the final model
|
| 207 |
+
detector.model.save_weights(f"{model_dir}/final_model.pth")
|
| 208 |
+
logger.info("Final model saved")
|
| 209 |
+
|
| 210 |
+
# Remove the logger
|
| 211 |
+
logger.remove(logger_id)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def test(args):
|
| 215 |
+
# Initialize the detector
|
| 216 |
+
detector = Detector(args)
|
| 217 |
+
|
| 218 |
+
# Load the [best/final] model
|
| 219 |
+
weights_path = os.path.join(args.ckpt, "cospy_calibrate", "best_model.pth")
|
| 220 |
+
|
| 221 |
+
detector.model.load_weights(weights_path)
|
| 222 |
+
detector.model.to(args.device)
|
| 223 |
+
detector.model.eval()
|
| 224 |
+
|
| 225 |
+
# Set the pre-processing function
|
| 226 |
+
test_transform = detector.model.test_transform
|
| 227 |
+
|
| 228 |
+
# Set the saving directory
|
| 229 |
+
save_dir = os.path.join(args.ckpt, "cospy_calibrate")
|
| 230 |
+
save_result_path = os.path.join(save_dir, "result.json")
|
| 231 |
+
save_output_path = os.path.join(save_dir, "output.json")
|
| 232 |
+
|
| 233 |
+
# Begin the evaluation
|
| 234 |
+
result_all = {}
|
| 235 |
+
output_all = {}
|
| 236 |
+
for dataset_name in EVAL_DATASET_LIST:
|
| 237 |
+
result_all[dataset_name] = {}
|
| 238 |
+
output_all[dataset_name] = {}
|
| 239 |
+
for model_name in EVAL_MODEL_LIST:
|
| 240 |
+
test_dataset = TestDataset(dataset=dataset_name, model=model_name, root_path=args.testset_dirpath, transform=test_transform)
|
| 241 |
+
test_loader = torch.utils.data.DataLoader(test_dataset,
|
| 242 |
+
batch_size=args.batch_size,
|
| 243 |
+
shuffle=False,
|
| 244 |
+
num_workers=4,
|
| 245 |
+
pin_memory=True)
|
| 246 |
+
|
| 247 |
+
# Evaluate the model
|
| 248 |
+
y_pred, y_true = [], []
|
| 249 |
+
for (images, labels) in tqdm(test_loader, desc=f"Evaluating {dataset_name} {model_name}"):
|
| 250 |
+
y_pred.extend(detector.predict(images))
|
| 251 |
+
y_true.extend(labels.tolist())
|
| 252 |
+
|
| 253 |
+
ap, accuracy = evaluate(y_pred, y_true)
|
| 254 |
+
print(f"Evaluate on {dataset_name} {model_name} | Size {len(y_true)} | AP {ap*100:.2f}% | Accuracy {accuracy*100:.2f}%")
|
| 255 |
+
|
| 256 |
+
result_all[dataset_name][model_name] = {"size": len(y_true), "AP": ap, "Accuracy": accuracy}
|
| 257 |
+
output_all[dataset_name][model_name] = {"y_pred": y_pred, "y_true": y_true}
|
| 258 |
+
|
| 259 |
+
# Save the results
|
| 260 |
+
with open(save_result_path, "w") as f:
|
| 261 |
+
json.dump(result_all, f, indent=4)
|
| 262 |
+
|
| 263 |
+
with open(save_output_path, "w") as f:
|
| 264 |
+
json.dump(output_all, f, indent=4)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
if __name__ == "__main__":
|
| 268 |
+
import argparse
|
| 269 |
+
|
| 270 |
+
parser = argparse.ArgumentParser("Deep Fake Detection")
|
| 271 |
+
parser.add_argument("--gpu", type=int, default=0, help="GPU ID")
|
| 272 |
+
parser.add_argument("--phase", type=str, default="test", choices=["train", "test"], help="Phase of the experiment")
|
| 273 |
+
parser.add_argument("--semantic_weights_path", type=str, default="ckpt/semantic/best_model.pth", help="Semantic weights path")
|
| 274 |
+
parser.add_argument("--artifact_weights_path", type=str, default="ckpt/artifact/best_model.pth", help="Artifact weights path")
|
| 275 |
+
parser.add_argument("--calibration_dirpath", type=str, default="data/train", help="Calibration directory")
|
| 276 |
+
parser.add_argument("--testset_dirpath", type=str, default="data/test", help="Testset directory")
|
| 277 |
+
parser.add_argument("--ckpt", type=str, default="ckpt", help="Checkpoint directory")
|
| 278 |
+
parser.add_argument("--epochs", type=int, default=10, help="Number of epochs")
|
| 279 |
+
parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
|
| 280 |
+
parser.add_argument("--seed", type=int, default=1024, help="Random seed")
|
| 281 |
+
parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint to resume training")
|
| 282 |
+
|
| 283 |
+
args = parser.parse_args()
|
| 284 |
+
|
| 285 |
+
# Set the random seed
|
| 286 |
+
seed_torch(args.seed)
|
| 287 |
+
|
| 288 |
+
# Set the GPU ID
|
| 289 |
+
args.device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
|
| 290 |
+
|
| 291 |
+
# Begin the experiment
|
| 292 |
+
if args.phase == "train":
|
| 293 |
+
train(args)
|
| 294 |
+
elif args.phase == "test":
|
| 295 |
+
test(args)
|
| 296 |
+
else:
|
| 297 |
+
raise ValueError("Unknown phase")
|
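For reference, the `--resume` loader above accepts either a plain state_dict or a dict wrapped under a `"model"` key. A minimal sketch of writing both layouts is shown below; it uses a tiny stand-in module, since constructing the real `CospyCalibrateDetector` needs the semantic/artifact weight files.

```python
# Minimal sketch of the two checkpoint layouts accepted by --resume in calibrate_combine.py.
# Assumption: the real detector is a standard torch.nn.Module (its load_state_dict call
# suggests so); a small stand-in module is used here so the snippet runs on its own.
import torch

model = torch.nn.Linear(2, 1)  # stand-in for the calibration head

torch.save(model.state_dict(), "resume_plain.pth")               # plain state_dict
torch.save({"model": model.state_dict()}, "resume_wrapped.pth")  # {'model': state_dict} wrapper

# Either file can then be passed via --resume; the loader unwraps the "model" key when present.
```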
data/in_the_wild/README.md
ADDED
|
@@ -0,0 +1,14 @@
|
| 1 |
+
Our in-the-wild evaluation dataset is constructed from five sources:
|
| 2 |
+
|
| 3 |
+
(1) Civitai [https://civitai.com/]
|
| 4 |
+
|
| 5 |
+
(2) DALL-E 3 [https://huggingface.co/datasets/ProGamerGov/synthetic-dataset-1m-dalle3-high-quality-captions]
|
| 6 |
+
|
| 7 |
+
(3) instavibe.ai [https://www.instavibe.ai/discover]
|
| 8 |
+
|
| 9 |
+
(4) Lexica [https://lexica.art/]
|
| 10 |
+
|
| 11 |
+
(5) Midjourney-v6 [https://huggingface.co/datasets/terminusresearch/midjourney-v6-520k-raw]
|
| 12 |
+
|
| 13 |
+
Data from sources (1), (2), and (5) can be accessed and downloaded directly.
|
| 14 |
+
For sources (3) and (4), we provide the image URLs used in our dataset under the `./urls` directory for your convenience.
|
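A minimal download sketch for the URL lists under `./urls` could look like the following; it assumes one image URL per line in `flux.txt` / `lexica.txt` and uses an illustrative `downloads/` output folder, so adjust both to the actual file layout.

```python
# Minimal sketch: fetch the images listed in data/in_the_wild/urls/*.txt.
# Assumptions: one URL per line; the output directory name is illustrative only.
import os
import requests

def download_url_list(list_path, out_dir):
    os.makedirs(out_dir, exist_ok=True)
    with open(list_path) as f:
        urls = [line.strip() for line in f if line.strip()]
    for i, url in enumerate(urls):
        resp = requests.get(url, timeout=30)
        if resp.ok:
            with open(os.path.join(out_dir, f"{i:06d}.jpg"), "wb") as out:
                out.write(resp.content)

download_url_list("data/in_the_wild/urls/lexica.txt", "downloads/lexica")
```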
data/in_the_wild/urls/flux.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
data/in_the_wild/urls/lexica.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
data/test/README.md
ADDED
|
@@ -0,0 +1,13 @@
|
| 1 |
+
Please download the test samples from [Co-Spy-Bench](https://huggingface.co/datasets/ruojiruoli/Co-Spy-Bench) and place them in this directory.
|
| 2 |
+
|
| 3 |
+
For real images:
|
| 4 |
+
|
| 5 |
+
* **CC3M**, **MSCOCO**, **TextCaps**, **Flickr**, and **SBU** are used.
|
| 6 |
+
* For **MSCOCO** and **Flickr**, refer to `Datasets/mscoco.py` and `Datasets/flickr.py` for instructions on downloading via HuggingFace Datasets.
|
| 7 |
+
* For the remaining datasets, download from their original sources:
|
| 8 |
+
|
| 9 |
+
* [CC3M](https://ai.google.com/research/ConceptualCaptions/download)
|
| 10 |
+
* [TextCaps](https://textvqa.org/textcaps/dataset/)
|
| 11 |
+
* [SBU](https://huggingface.co/datasets/vicenteor/sbu_captions)
|
| 12 |
+
|
| 13 |
+
Example test samples are also available on [Google Drive](https://drive.google.com/file/d/1JaaIGItyDYprr4_k0C_90MGIIRVQpmIP/view?usp=sharing). Please ensure their use complies with the original licenses.
|
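One way to pull the benchmark into this directory is `huggingface_hub` (already pinned in `environment.yml`); the sketch below assumes the dataset repo is publicly accessible and that its folder layout matches what the evaluation scripts expect.

```python
# Minimal sketch: download Co-Spy-Bench into data/test/.
# Assumption: no gating/token is required and the repo layout needs no re-arranging.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="ruojiruoli/Co-Spy-Bench",
    repo_type="dataset",
    local_dir="data/test",
)
```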
data/train/download.sh
ADDED
|
@@ -0,0 +1,16 @@
|
| 1 |
+
# Download and unzip the synthetic training dataset from DRCT
|
| 2 |
+
# Reference: https://icml.cc/virtual/2024/poster/33086
|
| 3 |
+
# Data source: https://github.com/beibuwandeluori/DRCT
|
| 4 |
+
wget --no-check-certificate https://modelscope.cn/datasets/BokingChen/DRCT-2M/resolve/master/images/stable-diffusion-v1-4.zip
|
| 5 |
+
unzip stable-diffusion-v1-4.zip
|
| 6 |
+
|
| 7 |
+
# Download the real training dataset from MSCOCO2017
|
| 8 |
+
# Reference: https://arxiv.org/pdf/1405.0312
|
| 9 |
+
# Data source: https://cocodataset.org/#download
|
| 10 |
+
mkdir mscoco2017
|
| 11 |
+
cd mscoco2017
|
| 12 |
+
wget http://images.cocodataset.org/zips/train2017.zip
|
| 13 |
+
wget http://images.cocodataset.org/zips/val2017.zip
|
| 14 |
+
unzip train2017.zip
|
| 15 |
+
unzip val2017.zip
|
| 16 |
+
cd ..
|
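After the script finishes, a quick check such as the sketch below (run from `data/train/`) confirms the expected folders are in place; it assumes `stable-diffusion-v1-4.zip` unpacks into a `stable-diffusion-v1-4/` directory.

```python
# Minimal sanity check, run from data/train/ after download.sh completes.
# Assumption: the DRCT zip extracts to a stable-diffusion-v1-4/ folder next to mscoco2017/.
import os

for d in ["mscoco2017/train2017", "mscoco2017/val2017", "stable-diffusion-v1-4"]:
    print(f"{d}: {'ok' if os.path.isdir(d) else 'MISSING'}")
```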
environment.yml
ADDED
|
@@ -0,0 +1,105 @@
|
| 1 |
+
name: cospy
|
| 2 |
+
channels:
|
| 3 |
+
- defaults
|
| 4 |
+
dependencies:
|
| 5 |
+
- _libgcc_mutex=0.1=main
|
| 6 |
+
- _openmp_mutex=5.1=1_gnu
|
| 7 |
+
- ca-certificates=2025.2.25=h06a4308_0
|
| 8 |
+
- ld_impl_linux-64=2.40=h12ee557_0
|
| 9 |
+
- libffi=3.4.4=h6a678d5_1
|
| 10 |
+
- libgcc-ng=11.2.0=h1234567_1
|
| 11 |
+
- libgomp=11.2.0=h1234567_1
|
| 12 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
| 13 |
+
- ncurses=6.4=h6a678d5_0
|
| 14 |
+
- openssl=3.0.16=h5eee18b_0
|
| 15 |
+
- pip=24.2=py38h06a4308_0
|
| 16 |
+
- python=3.8.18=h955ad1f_0
|
| 17 |
+
- readline=8.2=h5eee18b_0
|
| 18 |
+
- setuptools=75.1.0=py38h06a4308_0
|
| 19 |
+
- sqlite=3.45.3=h5eee18b_0
|
| 20 |
+
- tk=8.6.14=h39e8969_0
|
| 21 |
+
- wheel=0.44.0=py38h06a4308_0
|
| 22 |
+
- xz=5.6.4=h5eee18b_1
|
| 23 |
+
- zlib=1.2.13=h5eee18b_1
|
| 24 |
+
- pip:
|
| 25 |
+
- accelerate==1.0.1
|
| 26 |
+
- aiohappyeyeballs==2.4.4
|
| 27 |
+
- aiohttp==3.10.11
|
| 28 |
+
- aiosignal==1.3.1
|
| 29 |
+
- async-timeout==5.0.1
|
| 30 |
+
- attrs==25.3.0
|
| 31 |
+
- certifi==2025.1.31
|
| 32 |
+
- charset-normalizer==3.4.1
|
| 33 |
+
- contourpy==1.1.1
|
| 34 |
+
- cycler==0.12.1
|
| 35 |
+
- datasets==3.1.0
|
| 36 |
+
- diffusers==0.32.2
|
| 37 |
+
- dill==0.3.8
|
| 38 |
+
- filelock==3.16.1
|
| 39 |
+
- fonttools==4.56.0
|
| 40 |
+
- frozenlist==1.5.0
|
| 41 |
+
- fsspec==2024.9.0
|
| 42 |
+
- ftfy==6.2.3
|
| 43 |
+
- huggingface-hub==0.29.3
|
| 44 |
+
- idna==3.10
|
| 45 |
+
- importlib-metadata==8.5.0
|
| 46 |
+
- importlib-resources==6.4.5
|
| 47 |
+
- jinja2==3.1.6
|
| 48 |
+
- joblib==1.4.2
|
| 49 |
+
- kiwisolver==1.4.7
|
| 50 |
+
- loguru==0.7.3
|
| 51 |
+
- markupsafe==2.1.5
|
| 52 |
+
- matplotlib==3.7.5
|
| 53 |
+
- mpmath==1.3.0
|
| 54 |
+
- multidict==6.1.0
|
| 55 |
+
- multiprocess==0.70.16
|
| 56 |
+
- networkx==3.1
|
| 57 |
+
- numpy==1.24.4
|
| 58 |
+
- nvidia-cublas-cu12==12.1.3.1
|
| 59 |
+
- nvidia-cuda-cupti-cu12==12.1.105
|
| 60 |
+
- nvidia-cuda-nvrtc-cu12==12.1.105
|
| 61 |
+
- nvidia-cuda-runtime-cu12==12.1.105
|
| 62 |
+
- nvidia-cudnn-cu12==9.1.0.70
|
| 63 |
+
- nvidia-cufft-cu12==11.0.2.54
|
| 64 |
+
- nvidia-curand-cu12==10.3.2.106
|
| 65 |
+
- nvidia-cusolver-cu12==11.4.5.107
|
| 66 |
+
- nvidia-cusparse-cu12==12.1.0.106
|
| 67 |
+
- nvidia-nccl-cu12==2.20.5
|
| 68 |
+
- nvidia-nvjitlink-cu12==12.8.93
|
| 69 |
+
- nvidia-nvtx-cu12==12.1.105
|
| 70 |
+
- open-clip-torch==2.31.0
|
| 71 |
+
- opencv-python==4.11.0.86
|
| 72 |
+
- packaging==24.2
|
| 73 |
+
- pandas==2.0.3
|
| 74 |
+
- pillow==10.4.0
|
| 75 |
+
- propcache==0.2.0
|
| 76 |
+
- psutil==7.0.0
|
| 77 |
+
- pyarrow==17.0.0
|
| 78 |
+
- pycocotools==2.0.7
|
| 79 |
+
- pyparsing==3.1.4
|
| 80 |
+
- python-dateutil==2.9.0.post0
|
| 81 |
+
- pytz==2025.1
|
| 82 |
+
- pyyaml==6.0.2
|
| 83 |
+
- regex==2024.11.6
|
| 84 |
+
- requests==2.32.3
|
| 85 |
+
- safetensors==0.5.3
|
| 86 |
+
- scikit-learn==1.3.2
|
| 87 |
+
- scipy==1.10.1
|
| 88 |
+
- six==1.17.0
|
| 89 |
+
- sympy==1.13.3
|
| 90 |
+
- threadpoolctl==3.5.0
|
| 91 |
+
- timm==1.0.15
|
| 92 |
+
- tokenizers==0.20.3
|
| 93 |
+
- torch==2.4.1
|
| 94 |
+
- torchvision==0.19.1
|
| 95 |
+
- tqdm==4.67.1
|
| 96 |
+
- transformers==4.46.3
|
| 97 |
+
- triton==3.0.0
|
| 98 |
+
- typing-extensions==4.12.2
|
| 99 |
+
- tzdata==2025.1
|
| 100 |
+
- urllib3==2.2.3
|
| 101 |
+
- wcwidth==0.2.13
|
| 102 |
+
- xxhash==3.5.0
|
| 103 |
+
- yarl==1.15.2
|
| 104 |
+
- zipp==3.20.2
|
| 105 |
+
prefix: /connect4/cheng535-new/anaconda3/envs/cospy
|
evaluate.py
ADDED
|
@@ -0,0 +1,227 @@
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from sklearn.metrics import average_precision_score
|
| 8 |
+
import csv
|
| 9 |
+
|
| 10 |
+
from Detectors import CospyCalibrateDetector
|
| 11 |
+
from Datasets import TestDataset, EVAL_DATASET_LIST, EVAL_MODEL_LIST
|
| 12 |
+
from utils import seed_torch
|
| 13 |
+
from sklearn.metrics import (
|
| 14 |
+
accuracy_score, log_loss, average_precision_score, f1_score,
|
| 15 |
+
roc_auc_score, balanced_accuracy_score, confusion_matrix, recall_score
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import warnings
|
| 20 |
+
warnings.filterwarnings("ignore")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Detector():
|
| 24 |
+
def __init__(self, args):
|
| 25 |
+
super(Detector, self).__init__()
|
| 26 |
+
|
| 27 |
+
# Device
|
| 28 |
+
self.device = args.device
|
| 29 |
+
|
| 30 |
+
# Initialize the detector
|
| 31 |
+
self.model = CospyCalibrateDetector(
|
| 32 |
+
semantic_weights_path=args.semantic_weights_path,
|
| 33 |
+
artifact_weights_path=args.artifact_weights_path)
|
| 34 |
+
|
| 35 |
+
# Load the pre-trained weights
|
| 36 |
+
self.model.load_weights(args.classifier_weights_path)
|
| 37 |
+
self.model.eval()
|
| 38 |
+
|
| 39 |
+
# Put the model on the device
|
| 40 |
+
self.model.to(self.device)
|
| 41 |
+
|
| 42 |
+
# Prediction function
|
| 43 |
+
def predict(self, inputs):
|
| 44 |
+
inputs = inputs.to(self.device)
|
| 45 |
+
outputs = self.model(inputs)
|
| 46 |
+
prediction = outputs.sigmoid().flatten().tolist()
|
| 47 |
+
return prediction
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def expected_calibration_error(y_true, y_prob, n_bins=10):
|
| 51 |
+
"""Tính ECE (Expected Calibration Error)"""
|
| 52 |
+
y_true = np.array(y_true)
|
| 53 |
+
y_prob = np.array(y_prob)
|
| 54 |
+
bins = np.linspace(0.0, 1.0, n_bins + 1)
|
| 55 |
+
ece = 0.0
|
| 56 |
+
for i in range(n_bins):
|
| 57 |
+
mask = (y_prob > bins[i]) & (y_prob <= bins[i+1])
|
| 58 |
+
if np.sum(mask) > 0:
|
| 59 |
+
prob_mean = y_prob[mask].mean()
|
| 60 |
+
acc = y_true[mask].mean()
|
| 61 |
+
ece += np.sum(mask) / len(y_true) * abs(acc - prob_mean)
|
| 62 |
+
return ece
|
| 63 |
+
|
| 64 |
+
def evaluate(y_pred, y_true):
|
| 65 |
+
y_pred = np.array(y_pred)
|
| 66 |
+
y_true = np.array(y_true)
|
| 67 |
+
pred_label = y_pred > 0.5
|
| 68 |
+
|
| 69 |
+
# Metrics
|
| 70 |
+
acc = accuracy_score(y_true, pred_label)
|
| 71 |
+
nll = log_loss(y_true, y_pred, eps=1e-7)
|
| 72 |
+
ap = average_precision_score(y_true, y_pred)
|
| 73 |
+
ece = expected_calibration_error(y_true, y_pred)
|
| 74 |
+
f1 = f1_score(y_true, pred_label)
|
| 75 |
+
try:
|
| 76 |
+
auc = roc_auc_score(y_true, y_pred)
|
| 77 |
+
except:
|
| 78 |
+
auc = float('nan')
|
| 79 |
+
bacc = balanced_accuracy_score(y_true, pred_label)
|
| 80 |
+
tn, fp, fn, tp = confusion_matrix(y_true, pred_label).ravel()
|
| 81 |
+
fnr = fn / (fn + tp) if (fn + tp) > 0 else float('nan')
|
| 82 |
+
recall_total = recall_score(y_true, pred_label)  # overall recall
|
| 83 |
+
|
| 84 |
+
return {
|
| 85 |
+
"ACC": acc,
|
| 86 |
+
"NLL": nll,
|
| 87 |
+
"AP": ap,
|
| 88 |
+
"ECE": ece,
|
| 89 |
+
"F1": f1,
|
| 90 |
+
"AUC": auc,
|
| 91 |
+
"bAcc": bacc,
|
| 92 |
+
"FNR": fnr,
|
| 93 |
+
"Recall": recall_total
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def test(args):
|
| 99 |
+
# Initialize the detector
|
| 100 |
+
detector = Detector(args)
|
| 101 |
+
|
| 102 |
+
# Set the saving directory
|
| 103 |
+
if not os.path.exists(args.save_dir):
|
| 104 |
+
os.makedirs(args.save_dir)
|
| 105 |
+
save_result_path = os.path.join(args.save_dir, "result.json")
|
| 106 |
+
save_output_path = os.path.join(args.save_dir, "output.json")
|
| 107 |
+
|
| 108 |
+
# Begin the evaluation
|
| 109 |
+
result_all = {}
|
| 110 |
+
output_all = {}
|
| 111 |
+
for dataset_name in EVAL_DATASET_LIST:
|
| 112 |
+
result_all[dataset_name] = {}
|
| 113 |
+
output_all[dataset_name] = {}
|
| 114 |
+
for model_name in EVAL_MODEL_LIST:
|
| 115 |
+
test_dataset = TestDataset(dataset=dataset_name, model=model_name, root_path=args.testset_dirpath, transform=detector.model.test_transform)
|
| 116 |
+
test_loader = torch.utils.data.DataLoader(test_dataset,
|
| 117 |
+
batch_size=args.batch_size,
|
| 118 |
+
shuffle=False,
|
| 119 |
+
num_workers=4,
|
| 120 |
+
pin_memory=True)
|
| 121 |
+
|
| 122 |
+
# Evaluate the model
|
| 123 |
+
y_pred, y_true = [], []
|
| 124 |
+
for images, labels, _ in tqdm(test_loader, desc=f"Evaluating {dataset_name} {model_name}"):
|
| 125 |
+
y_pred.extend(detector.predict(images))
|
| 126 |
+
y_true.extend(labels.tolist())
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
metrics = evaluate(y_pred, y_true)
|
| 130 |
+
print(f"Evaluate on {dataset_name} {model_name} | Size {len(y_true)} | "
|
| 131 |
+
f"ACC {metrics['ACC']*100:.2f}% | Recall {metrics['Recall']*100:.2f}% | "
|
| 132 |
+
f"NLL {metrics['NLL']:.4f} | AP {metrics['AP']*100:.2f}% | "
|
| 133 |
+
f"ECE {metrics['ECE']:.4f} | F1 {metrics['F1']*100:.2f}% | "
|
| 134 |
+
f"AUC {metrics['AUC']*100:.2f}% | bAcc {metrics['bAcc']*100:.2f}% | "
|
| 135 |
+
f"FNR {metrics['FNR']*100:.2f}%")
|
| 136 |
+
|
| 137 |
+
result_all[dataset_name][model_name] = {"size": len(y_true), **metrics}
|
| 138 |
+
csv_dir = os.path.join(args.save_dir, "csv_outputs")
|
| 139 |
+
os.makedirs(csv_dir, exist_ok=True)
|
| 140 |
+
|
| 141 |
+
csv_path = os.path.join(csv_dir, f"{dataset_name}_{model_name}.csv")
|
| 142 |
+
|
| 143 |
+
with open(csv_path, mode="w", newline="", encoding="utf-8") as f:
|
| 144 |
+
writer = csv.writer(f)
|
| 145 |
+
writer.writerow(["path_to_image", "true_label", "pred_percentage", "pred_label"])
|
| 146 |
+
|
| 147 |
+
idx = 0
|
| 148 |
+
for img_path in test_dataset.image_paths:
|
| 149 |
+
pred_score = float(y_pred[idx])
|
| 150 |
+
pred_label = 1 if pred_score > 0.5 else 0
|
| 151 |
+
true_label = int(y_true[idx])
|
| 152 |
+
|
| 153 |
+
writer.writerow([
|
| 154 |
+
img_path,
|
| 155 |
+
true_label,
|
| 156 |
+
pred_score,
|
| 157 |
+
pred_label
|
| 158 |
+
])
|
| 159 |
+
idx += 1
|
| 160 |
+
|
| 161 |
+
print(f"[CSV SAVED] {csv_path}")
|
| 162 |
+
output_all[dataset_name][model_name] = {"y_pred": y_pred, "y_true": y_true}
|
| 163 |
+
|
| 164 |
+
# Save the results
|
| 165 |
+
with open(save_result_path, "w") as f:
|
| 166 |
+
json.dump(result_all, f, indent=4)
|
| 167 |
+
|
| 168 |
+
with open(save_output_path, "w") as f:
|
| 169 |
+
json.dump(output_all, f, indent=4)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def scan(args):
|
| 173 |
+
# Initialize the detector
|
| 174 |
+
detector = Detector(args)
|
| 175 |
+
|
| 176 |
+
# Define the pre-processing function
|
| 177 |
+
test_transform = detector.model.test_transform
|
| 178 |
+
|
| 179 |
+
# Load the image
|
| 180 |
+
image_filepath = input("Please enter the image filepath for scanning: ")
|
| 181 |
+
while not os.path.exists(image_filepath):
|
| 182 |
+
print(f"Image file not found: {image_filepath}")
|
| 183 |
+
image_filepath = input("Please enter the image filepath for scanning: ")
|
| 184 |
+
|
| 185 |
+
image = Image.open(image_filepath).convert("RGB")
|
| 186 |
+
image = test_transform(image)
|
| 187 |
+
image = image.unsqueeze(0)
|
| 188 |
+
image = image.to(args.device)
|
| 189 |
+
|
| 190 |
+
# Make the prediction
|
| 191 |
+
prediction = detector.predict(image)[0]
|
| 192 |
+
|
| 193 |
+
if prediction > 0.5:
|
| 194 |
+
print(f"CO-SPY Prediction: {prediction:.3f} - AI-Generated")
|
| 195 |
+
else:
|
| 196 |
+
print(f"CO-SPY Prediction: {prediction:.3f} - Real")
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
if __name__ == "__main__":
|
| 200 |
+
import argparse
|
| 201 |
+
|
| 202 |
+
parser = argparse.ArgumentParser("Deep Fake Detection")
|
| 203 |
+
parser.add_argument("--gpu", type=int, default=0, help="GPU ID")
|
| 204 |
+
parser.add_argument("--phase", type=str, default="scan", choices=["scan", "test"], help="Phase of the experiment")
|
| 205 |
+
parser.add_argument("--semantic_weights_path", type=str, default="pretrained/semantic_weights.pth", help="Semantic weights path")
|
| 206 |
+
parser.add_argument("--artifact_weights_path", type=str, default="pretrained/artifact_weights.pth", help="Artifact weights path")
|
| 207 |
+
parser.add_argument("--classifier_weights_path", type=str, default="pretrained/classifier_weights.pth", help="Classifier weights path")
|
| 208 |
+
parser.add_argument("--testset_dirpath", type=str, default="data/test", help="Testset directory")
|
| 209 |
+
parser.add_argument("--save_dir", type=str, default="test_results", help="Save directory")
|
| 210 |
+
parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
|
| 211 |
+
parser.add_argument("--seed", type=int, default=1024, help="Random seed")
|
| 212 |
+
|
| 213 |
+
args = parser.parse_args()
|
| 214 |
+
|
| 215 |
+
# Set the random seed
|
| 216 |
+
seed_torch(args.seed)
|
| 217 |
+
|
| 218 |
+
# Set the GPU ID
|
| 219 |
+
args.device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
|
| 220 |
+
|
| 221 |
+
# Begin the experiment
|
| 222 |
+
if args.phase == "scan":
|
| 223 |
+
scan(args)
|
| 224 |
+
elif args.phase == "test":
|
| 225 |
+
test(args)
|
| 226 |
+
else:
|
| 227 |
+
raise ValueError("Unknown phase")
|
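To make the `expected_calibration_error` metric above concrete, here is a small worked example on toy predictions: each non-empty confidence bin contributes its size-weighted gap between mean predicted probability and observed accuracy.

```python
# Toy worked example using the same 10-bin logic as expected_calibration_error above.
import numpy as np

y_true = np.array([1, 0, 1, 1, 0, 1])
y_prob = np.array([0.95, 0.10, 0.80, 0.55, 0.35, 0.65])

bins = np.linspace(0.0, 1.0, 11)
ece = 0.0
for i in range(10):
    mask = (y_prob > bins[i]) & (y_prob <= bins[i + 1])
    if mask.sum() > 0:
        # size-weighted |accuracy - mean confidence| for this bin
        ece += mask.sum() / len(y_true) * abs(y_true[mask].mean() - y_prob[mask].mean())

print(f"ECE = {ece:.3f}")
```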
pretrained/classifer_weights.pth
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4b3cd62721dca4183bfd37790c12ccaf964f71fe7c6bbf4d97eda5f44c6bafab
|
| 3 |
+
size 1456
|
pretrained/classifier_weights.pth
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0bf7e0efb68cf57718742ec3c944640856fd86ddaf1bb219e6cacdc280f781dc
|
| 3 |
+
size 1450
|
pretrained/semantic_weights.pth
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a3e7c4cf6534e7fac0f2f898d3764d5aa892653dab96ed1316fd123fa4e0a17c
|
| 3 |
+
size 6064
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
opencv-python-headless
|
| 4 |
+
numpy
|
| 5 |
+
Pillow
|
| 6 |
+
streamlit
|
| 7 |
+
tqdm
|
| 8 |
+
einops
|
train.py
ADDED
|
@@ -0,0 +1,271 @@
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import json
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from loguru import logger
|
| 8 |
+
from sklearn.metrics import average_precision_score
|
| 9 |
+
|
| 10 |
+
from utils import seed_torch
|
| 11 |
+
from Detectors import CospyDetector, LabelSmoothingBCEWithLogits
|
| 12 |
+
from Datasets import TrainDataset, TestDataset, EVAL_DATASET_LIST, EVAL_MODEL_LIST
|
| 13 |
+
|
| 14 |
+
import warnings
|
| 15 |
+
warnings.filterwarnings("ignore")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Detector():
|
| 19 |
+
def __init__(self, args):
|
| 20 |
+
super(Detector, self).__init__()
|
| 21 |
+
|
| 22 |
+
# Device
|
| 23 |
+
self.device = args.device
|
| 24 |
+
|
| 25 |
+
# Get the detector
|
| 26 |
+
self.model = CospyDetector()
|
| 27 |
+
|
| 28 |
+
# Put the model on the device
|
| 29 |
+
self.model.to(self.device)
|
| 30 |
+
|
| 31 |
+
# Initialize the fc layer
|
| 32 |
+
torch.nn.init.normal_(self.model.sem.fc.weight.data, 0.0, 0.02)
|
| 33 |
+
torch.nn.init.normal_(self.model.art.fc.weight.data, 0.0, 0.02)
|
| 34 |
+
torch.nn.init.normal_(self.model.fc.weight.data, 0.0, 0.02)
|
| 35 |
+
|
| 36 |
+
# Optimizer
|
| 37 |
+
_lr = 1e-4
|
| 38 |
+
_beta1 = 0.9
|
| 39 |
+
_weight_decay = 0.0
|
| 40 |
+
params = []
|
| 41 |
+
for name, param in self.model.named_parameters():
|
| 42 |
+
if param.requires_grad:
|
| 43 |
+
params.append(param)
|
| 44 |
+
print(f"Trainable parameters: {len(params)}")
|
| 45 |
+
|
| 46 |
+
self.optimizer = torch.optim.AdamW(params, lr=_lr, betas=(_beta1, 0.999), weight_decay=_weight_decay)
|
| 47 |
+
|
| 48 |
+
# Loss function
|
| 49 |
+
if args.no_label_smooth:
|
| 50 |
+
self.criterion = torch.nn.BCEWithLogitsLoss()
|
| 51 |
+
else:
|
| 52 |
+
self.criterion = LabelSmoothingBCEWithLogits(smoothing=0.1)
|
| 53 |
+
|
| 54 |
+
# Scheduler
|
| 55 |
+
self.delr_freq = 10
|
| 56 |
+
|
| 57 |
+
# Training function for the detector
|
| 58 |
+
def train_step(self, batch_data):
|
| 59 |
+
# Decompose the batch data
|
| 60 |
+
inputs, labels = batch_data
|
| 61 |
+
inputs, labels = inputs.to(self.device), labels.to(self.device)
|
| 62 |
+
|
| 63 |
+
self.optimizer.zero_grad()
|
| 64 |
+
|
| 65 |
+
outputs = self.model(inputs)
|
| 66 |
+
|
| 67 |
+
loss = self.criterion(outputs, labels.unsqueeze(1).float())
|
| 68 |
+
loss.backward()
|
| 69 |
+
self.optimizer.step()
|
| 70 |
+
|
| 71 |
+
eval_loss = loss.item()
|
| 72 |
+
y_pred = outputs.sigmoid().flatten().tolist()
|
| 73 |
+
y_true = labels.tolist()
|
| 74 |
+
return eval_loss, y_pred, y_true
|
| 75 |
+
|
| 76 |
+
# Schedule the training
|
| 77 |
+
# Early stopping / learning rate adjustment
|
| 78 |
+
def scheduler(self, status_dict):
|
| 79 |
+
epoch = status_dict["epoch"]
|
| 80 |
+
if epoch % self.delr_freq == 0 and epoch != 0:
|
| 81 |
+
for param_group in self.optimizer.param_groups:
|
| 82 |
+
param_group["lr"] *= 0.9
|
| 83 |
+
self.lr = param_group["lr"]
|
| 84 |
+
return True
|
| 85 |
+
|
| 86 |
+
# Prediction function
|
| 87 |
+
def predict(self, inputs):
|
| 88 |
+
inputs = inputs.to(self.device)
|
| 89 |
+
outputs = self.model(inputs)
|
| 90 |
+
prediction = outputs.sigmoid().flatten().tolist()
|
| 91 |
+
return prediction
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def evaluate(y_pred, y_true):
|
| 95 |
+
ap = average_precision_score(y_true, y_pred)
|
| 96 |
+
accuracy = ((np.array(y_pred) > 0.5) == y_true).mean()
|
| 97 |
+
return ap, accuracy
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def train(args):
|
| 101 |
+
# Get the detector
|
| 102 |
+
detector = Detector(args)
|
| 103 |
+
|
| 104 |
+
# Load the dataset
|
| 105 |
+
train_dataset = TrainDataset(data_path=args.trainset_dirpath,
|
| 106 |
+
split="train",
|
| 107 |
+
transform=detector.model.train_transform)
|
| 108 |
+
train_loader = torch.utils.data.DataLoader(train_dataset,
|
| 109 |
+
batch_size=args.batch_size,
|
| 110 |
+
shuffle=True,
|
| 111 |
+
num_workers=4,
|
| 112 |
+
pin_memory=True)
|
| 113 |
+
|
| 114 |
+
test_dataset = TrainDataset(data_path=args.trainset_dirpath,
|
| 115 |
+
split="val",
|
| 116 |
+
transform=detector.model.test_transform)
|
| 117 |
+
test_loader = torch.utils.data.DataLoader(test_dataset,
|
| 118 |
+
batch_size=args.batch_size,
|
| 119 |
+
shuffle=False,
|
| 120 |
+
num_workers=4,
|
| 121 |
+
pin_memory=True)
|
| 122 |
+
|
| 123 |
+
logger.info(f"Train size {len(train_dataset)} | Test size {len(test_dataset)}")
|
| 124 |
+
|
| 125 |
+
# Set the saving directory
|
| 126 |
+
model_dir = os.path.join(args.ckpt, "cospy")
|
| 127 |
+
if not os.path.exists(model_dir):
|
| 128 |
+
os.makedirs(model_dir)
|
| 129 |
+
log_path = f"{model_dir}/training.log"
|
| 130 |
+
if os.path.exists(log_path):
|
| 131 |
+
os.remove(log_path)
|
| 132 |
+
|
| 133 |
+
logger_id = logger.add(
|
| 134 |
+
log_path,
|
| 135 |
+
format="{time:MM-DD at HH:mm:ss} | {level} | {module}:{line} | {message}",
|
| 136 |
+
level="DEBUG",
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
# Train the detector
|
| 140 |
+
best_acc = 0
|
| 141 |
+
for epoch in range(args.epochs):
|
| 142 |
+
# Set the model to training mode
|
| 143 |
+
detector.model.train()
|
| 144 |
+
time_start = time.time()
|
| 145 |
+
for step_id, batch_data in enumerate(train_loader):
|
| 146 |
+
eval_loss, y_pred, y_true = detector.train_step(batch_data)
|
| 147 |
+
ap, accuracy = evaluate(y_pred, y_true)
|
| 148 |
+
|
| 149 |
+
# Log the training information
|
| 150 |
+
if (step_id + 1) % 100 == 0:
|
| 151 |
+
time_end = time.time()
|
| 152 |
+
logger.info(f"Epoch {epoch} | Batch {step_id + 1}/{len(train_loader)} | Loss {eval_loss:.4f} | AP {ap*100:.2f}% | Accuracy {accuracy*100:.2f}% | Time {time_end-time_start:.2f}s")
|
| 153 |
+
time_start = time.time()
|
| 154 |
+
|
| 155 |
+
# Evaluate the model
|
| 156 |
+
detector.model.eval()
|
| 157 |
+
y_pred, y_true = [], []
|
| 158 |
+
for (images, labels) in test_loader:
|
| 159 |
+
y_pred.extend(detector.predict(images))
|
| 160 |
+
y_true.extend(labels.tolist())
|
| 161 |
+
|
| 162 |
+
ap, accuracy = evaluate(y_pred, y_true)
|
| 163 |
+
logger.info(f"Epoch {epoch} | Test AP {ap*100:.2f}% | Test Accuracy {accuracy*100:.2f}%")
|
| 164 |
+
|
| 165 |
+
# Schedule the training
|
| 166 |
+
status_dict = {"epoch": epoch, "AP": ap, "Accuracy": accuracy}
|
| 167 |
+
proceed = detector.scheduler(status_dict)
|
| 168 |
+
if not proceed:
|
| 169 |
+
logger.info("Early stopping")
|
| 170 |
+
break
|
| 171 |
+
|
| 172 |
+
# Save the model
|
| 173 |
+
if accuracy >= best_acc:
|
| 174 |
+
best_acc = accuracy
|
| 175 |
+
detector.model.save_weights(f"{model_dir}/best_model.pth")
|
| 176 |
+
logger.info(f"Best model saved with accuracy {best_acc.mean()*100:.2f}%")
|
| 177 |
+
|
| 178 |
+
if epoch % 5 == 0:
|
| 179 |
+
detector.model.save_weights(f"{model_dir}/epoch_{epoch}.pth")
|
| 180 |
+
logger.info(f"Model saved at epoch {epoch}")
|
| 181 |
+
|
| 182 |
+
# Save the final model
|
| 183 |
+
detector.model.save_weights(f"{model_dir}/final_model.pth")
|
| 184 |
+
logger.info("Final model saved")
|
| 185 |
+
|
| 186 |
+
# Remove the logger
|
| 187 |
+
logger.remove(logger_id)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def test(args):
|
| 191 |
+
# Initialize the detector
|
| 192 |
+
detector = Detector(args)
|
| 193 |
+
|
| 194 |
+
# Load the [best/final] model
|
| 195 |
+
weights_path = os.path.join(args.ckpt, "cospy", "best_model.pth")
|
| 196 |
+
|
| 197 |
+
detector.model.load_weights(weights_path)
|
| 198 |
+
detector.model.to(args.device)
|
| 199 |
+
detector.model.eval()
|
| 200 |
+
|
| 201 |
+
# Set the pre-processing function
|
| 202 |
+
test_transform = detector.model.test_transform
|
| 203 |
+
|
| 204 |
+
# Set the saving directory
|
| 205 |
+
save_dir = os.path.join(args.ckpt, "cospy")
|
| 206 |
+
save_result_path = os.path.join(save_dir, "result.json")
|
| 207 |
+
save_output_path = os.path.join(save_dir, "output.json")
|
| 208 |
+
|
| 209 |
+
# Begin the evaluation
|
| 210 |
+
result_all = {}
|
| 211 |
+
output_all = {}
|
| 212 |
+
for dataset_name in EVAL_DATASET_LIST:
|
| 213 |
+
result_all[dataset_name] = {}
|
| 214 |
+
output_all[dataset_name] = {}
|
| 215 |
+
for model_name in EVAL_MODEL_LIST:
|
| 216 |
+
test_dataset = TestDataset(dataset=dataset_name, model=model_name, root_path=args.testset_dirpath, transform=test_transform)
|
| 217 |
+
test_loader = torch.utils.data.DataLoader(test_dataset,
|
| 218 |
+
batch_size=args.batch_size,
|
| 219 |
+
shuffle=False,
|
| 220 |
+
num_workers=4,
|
| 221 |
+
pin_memory=True)
|
| 222 |
+
|
| 223 |
+
# Evaluate the model
|
| 224 |
+
y_pred, y_true = [], []
|
| 225 |
+
for (images, labels) in tqdm(test_loader, desc=f"Evaluating {dataset_name} {model_name}"):
|
| 226 |
+
y_pred.extend(detector.predict(images))
|
| 227 |
+
y_true.extend(labels.tolist())
|
| 228 |
+
|
| 229 |
+
ap, accuracy = evaluate(y_pred, y_true)
|
| 230 |
+
print(f"Evaluate on {dataset_name} {model_name} | Size {len(y_true)} | AP {ap*100:.2f}% | Accuracy {accuracy*100:.2f}%")
|
| 231 |
+
|
| 232 |
+
result_all[dataset_name][model_name] = {"size": len(y_true), "AP": ap, "Accuracy": accuracy}
|
| 233 |
+
output_all[dataset_name][model_name] = {"y_pred": y_pred, "y_true": y_true}
|
| 234 |
+
|
| 235 |
+
# Save the results
|
| 236 |
+
with open(save_result_path, "w") as f:
|
| 237 |
+
json.dump(result_all, f, indent=4)
|
| 238 |
+
|
| 239 |
+
with open(save_output_path, "w") as f:
|
| 240 |
+
json.dump(output_all, f, indent=4)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
if __name__ == "__main__":
|
| 244 |
+
import argparse
|
| 245 |
+
|
| 246 |
+
parser = argparse.ArgumentParser("Deep Fake Detection")
|
| 247 |
+
parser.add_argument("--gpu", type=int, default=0, help="GPU ID")
|
| 248 |
+
parser.add_argument("--phase", type=str, default="test", choices=["train", "test"], help="Phase of the experiment")
|
| 249 |
+
parser.add_argument("--no_label_smooth", action="store_true", help="Whether to use label smoothing")
|
| 250 |
+
parser.add_argument("--trainset_dirpath", type=str, default="data/train", help="Trainset directory")
|
| 251 |
+
parser.add_argument("--testset_dirpath", type=str, default="data/test", help="Testset directory")
|
| 252 |
+
parser.add_argument("--ckpt", type=str, default="ckpt", help="Checkpoint directory")
|
| 253 |
+
parser.add_argument("--epochs", type=int, default=10, help="Number of epochs")
|
| 254 |
+
parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
|
| 255 |
+
parser.add_argument("--seed", type=int, default=1024, help="Random seed")
|
| 256 |
+
|
| 257 |
+
args = parser.parse_args()
|
| 258 |
+
|
| 259 |
+
# Set the random seed
|
| 260 |
+
seed_torch(args.seed)
|
| 261 |
+
|
| 262 |
+
# Set the GPU ID
|
| 263 |
+
args.device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
|
| 264 |
+
|
| 265 |
+
# Begin the experiment
|
| 266 |
+
if args.phase == "train":
|
| 267 |
+
train(args)
|
| 268 |
+
elif args.phase == "test":
|
| 269 |
+
test(args)
|
| 270 |
+
else:
|
| 271 |
+
raise ValueError("Unknown phase")
|
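`train.py` defaults to `LabelSmoothingBCEWithLogits(smoothing=0.1)` from `Detectors` (defined elsewhere in the repo). A common formulation of BCE label smoothing looks roughly like the sketch below; the repo's own class may differ in detail.

```python
# A minimal sketch of one common BCE label-smoothing scheme.
# Assumption: this approximates, but is not necessarily identical to,
# Detectors.LabelSmoothingBCEWithLogits used in train.py.
import torch
import torch.nn as nn

class SmoothedBCEWithLogits(nn.Module):
    def __init__(self, smoothing: float = 0.1):
        super().__init__()
        self.smoothing = smoothing
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, logits, targets):
        # Pull hard 0/1 targets toward 0.5: 0 -> smoothing/2, 1 -> 1 - smoothing/2.
        smoothed = targets * (1.0 - self.smoothing) + 0.5 * self.smoothing
        return self.bce(logits, smoothed)

# Usage on a dummy batch
criterion = SmoothedBCEWithLogits(smoothing=0.1)
loss = criterion(torch.randn(4, 1), torch.tensor([[0.0], [1.0], [1.0], [0.0]]))
```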
train_single.py
ADDED
|
@@ -0,0 +1,293 @@
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import json
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from loguru import logger
|
| 8 |
+
from sklearn.metrics import average_precision_score
|
| 9 |
+
|
| 10 |
+
from utils import seed_torch
|
| 11 |
+
from Detectors import ArtifactDetector, SemanticDetector
|
| 12 |
+
from Datasets import TrainDataset, TestDataset, EVAL_DATASET_LIST, EVAL_MODEL_LIST
|
| 13 |
+
|
| 14 |
+
import warnings
|
| 15 |
+
warnings.filterwarnings("ignore")
|
| 16 |
+
|
| 17 |
+
class Detector():
|
| 18 |
+
def __init__(self, args):
|
| 19 |
+
super(Detector, self).__init__()
|
| 20 |
+
|
| 21 |
+
# Device
|
| 22 |
+
self.device = args.device
|
| 23 |
+
|
| 24 |
+
# Get the detector
|
| 25 |
+
if args.detector == "artifact":
|
| 26 |
+
self.model = ArtifactDetector()
|
| 27 |
+
elif args.detector == "semantic":
|
| 28 |
+
self.model = SemanticDetector()
|
| 29 |
+
else:
|
| 30 |
+
raise ValueError("Unknown detector")
|
| 31 |
+
|
| 32 |
+
# Put the model on the device
|
| 33 |
+
self.model.to(self.device)
|
| 34 |
+
|
| 35 |
+
# Initialize the fc layer
|
| 36 |
+
torch.nn.init.normal_(self.model.fc.weight.data, 0.0, 0.02)
|
| 37 |
+
|
| 38 |
+
# Optimizer
|
| 39 |
+
_lr = 1e-4
|
| 40 |
+
_beta1 = 0.9
|
| 41 |
+
_weight_decay = 0.0
|
| 42 |
+
params = [p for p in self.model.parameters() if p.requires_grad]
|
| 43 |
+
print(f"Trainable parameters: {len(params)}")
|
| 44 |
+
|
| 45 |
+
self.optimizer = torch.optim.AdamW(params, lr=_lr, betas=(_beta1, 0.999), weight_decay=_weight_decay)
|
| 46 |
+
|
| 47 |
+
# Loss function
|
| 48 |
+
self.criterion = torch.nn.BCEWithLogitsLoss()
|
| 49 |
+
|
| 50 |
+
# Scheduler
|
| 51 |
+
self.delr_freq = 10
|
| 52 |
+
|
| 53 |
+
# Resume info
|
| 54 |
+
self.start_epoch = 0
|
| 55 |
+
self.best_acc = 0.0
|
| 56 |
+
|
| 57 |
+
def train_step(self, batch_data):
|
| 58 |
+
inputs, labels = batch_data
|
| 59 |
+
inputs, labels = inputs.to(self.device), labels.to(self.device)
|
| 60 |
+
|
| 61 |
+
self.optimizer.zero_grad()
|
| 62 |
+
outputs = self.model(inputs)
|
| 63 |
+
loss = self.criterion(outputs, labels.unsqueeze(1).float())
|
| 64 |
+
loss.backward()
|
| 65 |
+
self.optimizer.step()
|
| 66 |
+
|
| 67 |
+
eval_loss = loss.item()
|
| 68 |
+
y_pred = outputs.sigmoid().flatten().tolist()
|
| 69 |
+
y_true = labels.tolist()
|
| 70 |
+
return eval_loss, y_pred, y_true
|
| 71 |
+
|
| 72 |
+
def scheduler(self, status_dict):
|
| 73 |
+
epoch = status_dict["epoch"]
|
| 74 |
+
if epoch % self.delr_freq == 0 and epoch != 0:
|
| 75 |
+
for param_group in self.optimizer.param_groups:
|
| 76 |
+
param_group["lr"] *= 0.9
|
| 77 |
+
self.lr = param_group["lr"]
|
| 78 |
+
return True
|
| 79 |
+
|
| 80 |
+
def predict(self, inputs):
|
| 81 |
+
inputs = inputs.to(self.device)
|
| 82 |
+
outputs = self.model(inputs)
|
| 83 |
+
return outputs.sigmoid().flatten().tolist()
|
| 84 |
+
|
| 85 |
+
# --- Checkpoint functions ---
|
| 86 |
+
def save_checkpoint(self, path, epoch, best_acc):
|
| 87 |
+
torch.save({
|
| 88 |
+
"epoch": epoch,
|
| 89 |
+
"best_acc": best_acc,
|
| 90 |
+
"model_state": self.model.state_dict(),
|
| 91 |
+
"optimizer_state": self.optimizer.state_dict()
|
| 92 |
+
}, path)
|
| 93 |
+
|
| 94 |
+
def load_checkpoint(self, path):
|
| 95 |
+
if os.path.exists(path):
|
| 96 |
+
ckpt = torch.load(path, map_location=self.device)
|
| 97 |
+
self.model.load_state_dict(ckpt["model_state"])
|
| 98 |
+
self.optimizer.load_state_dict(ckpt["optimizer_state"])
|
| 99 |
+
self.start_epoch = ckpt.get("epoch", 0) + 1
|
| 100 |
+
self.best_acc = ckpt.get("best_acc", 0.0)
|
| 101 |
+
print(f"[INFO] Loaded checkpoint '{path}' (start_epoch={self.start_epoch}, best_acc={self.best_acc})")
|
| 102 |
+
else:
|
| 103 |
+
print(f"[WARNING] Checkpoint not found: {path}")
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def evaluate(y_pred, y_true):
|
| 107 |
+
ap = average_precision_score(y_true, y_pred)
|
| 108 |
+
accuracy = ((np.array(y_pred) > 0.5) == y_true).mean()
|
| 109 |
+
return ap, accuracy
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def train(args):
|
| 113 |
+
# Get the detector
|
| 114 |
+
detector = Detector(args)
|
| 115 |
+
|
| 116 |
+
# --- Resume checkpoint ---
|
| 117 |
+
start_epoch = 0
|
| 118 |
+
best_acc = 0
|
| 119 |
+
if args.resume != "":
|
| 120 |
+
if os.path.exists(args.resume):
|
| 121 |
+
print(f"[INFO] Loading checkpoint from {args.resume}")
|
| 122 |
+
ckpt = torch.load(args.resume, map_location=args.device)
|
| 123 |
+
detector.model.load_weights(args.resume)
|
| 124 |
+
# If the optimizer & best_acc were also saved, load them here
|
| 125 |
+
if "best_acc" in ckpt:
|
| 126 |
+
best_acc = ckpt["best_acc"]
|
| 127 |
+
if "epoch" in ckpt:
|
| 128 |
+
start_epoch = ckpt["epoch"] + 1
|
| 129 |
+
else:
|
| 130 |
+
print(f"[WARNING] Resume checkpoint not found: {args.resume}")
|
| 131 |
+
|
| 132 |
+
# Load datasets
|
| 133 |
+
train_dataset = TrainDataset(data_path=args.trainset_dirpath,
|
| 134 |
+
split="train",
|
| 135 |
+
transform=detector.model.train_transform)
|
| 136 |
+
train_loader = torch.utils.data.DataLoader(train_dataset,
|
| 137 |
+
batch_size=args.batch_size,
|
| 138 |
+
shuffle=True,
|
| 139 |
+
num_workers=4,
|
| 140 |
+
pin_memory=True)
|
| 141 |
+
|
| 142 |
+
test_dataset = TrainDataset(data_path=args.trainset_dirpath,
|
| 143 |
+
split="val",
|
| 144 |
+
transform=detector.model.test_transform)
|
| 145 |
+
test_loader = torch.utils.data.DataLoader(test_dataset,
|
| 146 |
+
batch_size=args.batch_size,
|
| 147 |
+
shuffle=False,
|
| 148 |
+
num_workers=4,
|
| 149 |
+
pin_memory=True)
|
| 150 |
+
|
| 151 |
+
logger.info(f"Train size {len(train_dataset)} | Test size {len(test_dataset)}")
|
| 152 |
+
|
| 153 |
+
# Set saving directory
|
| 154 |
+
model_dir = os.path.join(args.ckpt, args.detector)
|
| 155 |
+
os.makedirs(model_dir, exist_ok=True)
|
| 156 |
+
log_path = f"{model_dir}/training.log"
|
| 157 |
+
if os.path.exists(log_path):
|
| 158 |
+
os.remove(log_path)
|
| 159 |
+
logger_id = logger.add(log_path, format="{time:MM-DD at HH:mm:ss} | {level} | {module}:{line} | {message}", level="DEBUG")
|
| 160 |
+
|
| 161 |
+
# Train loop
|
| 162 |
+
for epoch in range(start_epoch, args.epochs):
|
| 163 |
+
detector.model.train()
|
| 164 |
+
time_start = time.time()
|
| 165 |
+
for step_id, batch_data in enumerate(train_loader):
|
| 166 |
+
eval_loss, y_pred, y_true = detector.train_step(batch_data)
|
| 167 |
+
ap, accuracy = evaluate(y_pred, y_true)
|
| 168 |
+
|
| 169 |
+
if (step_id + 1) % 100 == 0:
|
| 170 |
+
time_end = time.time()
|
| 171 |
+
logger.info(f"Epoch {epoch} | Batch {step_id + 1}/{len(train_loader)} | Loss {eval_loss:.4f} | AP {ap*100:.2f}% | Accuracy {accuracy*100:.2f}% | Time {time_end-time_start:.2f}s")
|
| 172 |
+
time_start = time.time()
|
| 173 |
+
|
| 174 |
+
# Evaluate
|
| 175 |
+
detector.model.eval()
|
| 176 |
+
y_pred, y_true = [], []
|
| 177 |
+
for (images, labels) in test_loader:
|
| 178 |
+
y_pred.extend(detector.predict(images))
|
| 179 |
+
y_true.extend(labels.tolist())
|
| 180 |
+
ap, accuracy = evaluate(y_pred, y_true)
|
| 181 |
+
logger.info(f"Epoch {epoch} | Test AP {ap*100:.2f}% | Test Accuracy {accuracy*100:.2f}%")
|
| 182 |
+
|
| 183 |
+
# Save best model
|
| 184 |
+
if accuracy >= best_acc:
|
| 185 |
+
best_acc = accuracy
|
| 186 |
+
detector.model.save_weights(f"{model_dir}/best_model.pth")
|
| 187 |
+
torch.save({"epoch": epoch, "best_acc": best_acc}, f"{model_dir}/best_model_meta.pth")
|
| 188 |
+
logger.info(f"Best model saved with accuracy {best_acc*100:.2f}%")
|
| 189 |
+
|
| 190 |
+
# Save periodic checkpoints
|
| 191 |
+
if epoch % 5 == 0:
|
| 192 |
+
detector.model.save_weights(f"{model_dir}/epoch_{epoch}.pth")
|
| 193 |
+
logger.info(f"Model saved at epoch {epoch}")
|
| 194 |
+
|
| 195 |
+
# Save final model
|
| 196 |
+
detector.model.save_weights(f"{model_dir}/final_model.pth")
|
| 197 |
+
logger.info("Final model saved")
|
| 198 |
+
logger.remove(logger_id)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def test(args):
|
| 203 |
+
# Initialize the detector
|
| 204 |
+
detector = Detector(args)
|
| 205 |
+
# --- Load checkpoint if resume is provided ---
|
| 206 |
+
if args.resume != "":
|
| 207 |
+
ckpt_path = args.resume
|
| 208 |
+
if os.path.exists(ckpt_path):
|
| 209 |
+
print(f"[INFO] Loading checkpoint from {ckpt_path}")
|
| 210 |
+
detector.model.load_weights(ckpt_path)
|
| 211 |
+
else:
|
| 212 |
+
print(f"[WARNING] Resume checkpoint not found: {ckpt_path}")
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# Load the [best/final] model
|
| 216 |
+
weights_path = os.path.join(args.ckpt, args.detector, "best_model.pth")
|
| 217 |
+
|
| 218 |
+
detector.model.load_weights(weights_path)
|
| 219 |
+
detector.model.to(args.device)
|
| 220 |
+
detector.model.eval()
|
| 221 |
+
|
| 222 |
+
# Set the pre-processing function
|
| 223 |
+
test_transform = detector.model.test_transform
|
| 224 |
+
|
| 225 |
+
# Set the saving directory
|
| 226 |
+
save_dir = os.path.join(args.ckpt, args.detector)
|
| 227 |
+
save_result_path = os.path.join(save_dir, "result.json")
|
| 228 |
+
save_output_path = os.path.join(save_dir, "output.json")
|
| 229 |
+
|
| 230 |
+
# Begin the evaluation
|
| 231 |
+
result_all = {}
|
| 232 |
+
output_all = {}
|
| 233 |
+
for dataset_name in EVAL_DATASET_LIST:
|
| 234 |
+
result_all[dataset_name] = {}
|
| 235 |
+
output_all[dataset_name] = {}
|
| 236 |
+
for model_name in EVAL_MODEL_LIST:
|
| 237 |
+
test_dataset = TestDataset(dataset=dataset_name, model=model_name, root_path=args.testset_dirpath, transform=test_transform)
|
| 238 |
+
test_loader = torch.utils.data.DataLoader(test_dataset,
|
| 239 |
+
batch_size=args.batch_size,
|
| 240 |
+
shuffle=False,
|
| 241 |
+
num_workers=4,
|
| 242 |
+
pin_memory=True)
|
| 243 |
+
|
| 244 |
+
# Evaluate the model
|
| 245 |
+
y_pred, y_true = [], []
|
| 246 |
+
for (images, labels) in tqdm(test_loader, desc=f"Evaluating {dataset_name} {model_name}"):
|
| 247 |
+
y_pred.extend(detector.predict(images))
|
| 248 |
+
y_true.extend(labels.tolist())
|
| 249 |
+
|
| 250 |
+
ap, accuracy = evaluate(y_pred, y_true)
|
| 251 |
+
print(f"Evaluate on {dataset_name} {model_name} | Size {len(y_true)} | AP {ap*100:.2f}% | Accuracy {accuracy*100:.2f}%")
|
| 252 |
+
|
| 253 |
+
result_all[dataset_name][model_name] = {"size": len(y_true), "AP": ap, "Accuracy": accuracy}
|
| 254 |
+
output_all[dataset_name][model_name] = {"y_pred": y_pred, "y_true": y_true}
|
| 255 |
+
|
| 256 |
+
# Save the results
|
| 257 |
+
with open(save_result_path, "w") as f:
|
| 258 |
+
json.dump(result_all, f, indent=4)
|
| 259 |
+
|
| 260 |
+
with open(save_output_path, "w") as f:
|
| 261 |
+
json.dump(output_all, f, indent=4)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
if __name__ == "__main__":
|
| 265 |
+
import argparse
|
| 266 |
+
|
| 267 |
+
parser = argparse.ArgumentParser("Deep Fake Detection")
|
| 268 |
+
parser.add_argument("--gpu", type=int, default=0, help="GPU ID")
|
| 269 |
+
parser.add_argument("--phase", type=str, default="test", choices=["train", "test"], help="Phase of the experiment")
|
| 270 |
+
parser.add_argument("--detector", type=str, default="artifact", choices=["artifact", "semantic"], help="Detector to use")
|
| 271 |
+
parser.add_argument("--trainset_dirpath", type=str, default="data/train", help="Trainset directory")
|
| 272 |
+
parser.add_argument("--testset_dirpath", type=str, default="data/test", help="Testset directory")
|
| 273 |
+
parser.add_argument("--ckpt", type=str, default="ckpt", help="Checkpoint directory")
|
| 274 |
+
parser.add_argument("--epochs", type=int, default=10, help="Number of epochs")
|
| 275 |
+
parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
|
| 276 |
+
parser.add_argument("--seed", type=int, default=1024, help="Random seed")
|
| 277 |
+
parser.add_argument("--resume", type=str, default="", help="Path to checkpoint to resume training")
|
| 278 |
+
|
| 279 |
+
args = parser.parse_args()
|
| 280 |
+
|
| 281 |
+
# Set the random seed
|
| 282 |
+
seed_torch(args.seed)
|
| 283 |
+
|
| 284 |
+
# Set the GPU ID
|
| 285 |
+
args.device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
|
| 286 |
+
|
| 287 |
+
# Begin the experiment
|
| 288 |
+
if args.phase == "train":
|
| 289 |
+
train(args)
|
| 290 |
+
elif args.phase == "test":
|
| 291 |
+
test(args)
|
| 292 |
+
else:
|
| 293 |
+
raise ValueError("Unknown phase")
|
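`train_single.py`'s `Detector` also defines `save_checkpoint`/`load_checkpoint` helpers that bundle model, optimizer, epoch, and best accuracy in one file, even though the training loop above saves weights and metadata separately. A minimal round-trip of that checkpoint format (with a stand-in module so it runs standalone):

```python
# Minimal sketch of the full-checkpoint format written by Detector.save_checkpoint.
# Assumption: a stand-in module replaces ArtifactDetector/SemanticDetector so the
# snippet is self-contained; the key names match the helpers in train_single.py.
import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters())

torch.save({
    "epoch": 4,
    "best_acc": 0.91,
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
}, "checkpoint.pth")

ckpt = torch.load("checkpoint.pth", map_location="cpu")
model.load_state_dict(ckpt["model_state"])
optimizer.load_state_dict(ckpt["optimizer_state"])
start_epoch = ckpt.get("epoch", 0) + 1   # same resume convention as load_checkpoint
best_acc = ckpt.get("best_acc", 0.0)
```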
utils.py
ADDED
|
@@ -0,0 +1,162 @@
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import torch
|
| 4 |
+
import pickle
|
| 5 |
+
import random
|
| 6 |
+
import numpy as np
|
| 7 |
+
from io import BytesIO
|
| 8 |
+
from PIL import Image, ImageFile
|
| 9 |
+
import torchvision.transforms.functional as TF
|
| 10 |
+
from scipy.ndimage.filters import gaussian_filter
|
| 11 |
+
|
| 12 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Set random seed
|
| 16 |
+
def seed_torch(seed):
|
| 17 |
+
random.seed(seed)
|
| 18 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
| 19 |
+
np.random.seed(seed)
|
| 20 |
+
torch.manual_seed(seed)
|
| 21 |
+
torch.cuda.manual_seed(seed)
|
| 22 |
+
torch.cuda.manual_seed_all(seed)
|
| 23 |
+
torch.backends.cudnn.benchmark = False
|
| 24 |
+
torch.backends.cudnn.deterministic = True
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Load dataset
|
| 28 |
+
def recursively_read(rootdir, must_contain, exts=["png", "PNG", "jpg", "JPG", "jpeg", "JPEG"]):
|
| 29 |
+
out = []
|
| 30 |
+
for r, d, f in os.walk(rootdir):
|
| 31 |
+
for file in f:
|
| 32 |
+
if (file.split('.')[-1] in exts) and (must_contain in os.path.join(r, file)):
|
| 33 |
+
out.append(os.path.join(r, file))
|
| 34 |
+
return out
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_list(path, must_contain=''):
|
| 38 |
+
if ".pickle" in path:
|
| 39 |
+
with open(path, 'rb') as f:
|
| 40 |
+
image_list = pickle.load(f)
|
| 41 |
+
image_list = [item for item in image_list if must_contain in item]
|
| 42 |
+
else:
|
| 43 |
+
image_list = recursively_read(path, must_contain)
|
| 44 |
+
return image_list
|
| 45 |
+
|
| 46 |
+
|
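# Example (illustrative arguments, not project defaults): get_list("data/train", must_contain="train")
# returns every png/jpg/jpeg path under data/train whose full path contains "train"; a path ending
# in .pickle is instead loaded as a pre-computed list of image paths.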

# Data augmentation techniques
def data_augment(img, aug_config):
    img = np.array(img)
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
        img = np.repeat(img, 3, axis=2)

    if random.random() < aug_config["blur_prob"]:
        sig = sample_continuous(aug_config["blur_sig"])
        gaussian_blur(img, sig)

    if random.random() < aug_config["jpg_prob"]:
        method = sample_discrete(aug_config["jpg_method"])
        qual = sample_discrete(aug_config["jpg_qual"])
        img = jpeg_from_key(img, qual, method)

    return Image.fromarray(img)

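# Sketch of the aug_config dictionary consumed by data_augment above and tensor_data_augment
# below (the keys come from the code; the values are illustrative, not the project's defaults):
#     aug_config = {
#         "blur_prob": 0.1, "blur_sig": [0.0, 3.0],
#         "jpg_prob": 0.1, "jpg_method": ["cv2", "pil"], "jpg_qual": [30, 100],
#     }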

# Data augmentation techniques (tensor-batch variant)
def tensor_data_augment(images, aug_config):
    device = images.device
    images = images.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = np.uint8(images * 255.)
    outputs = []
    for img in images:
        if random.random() < aug_config["blur_prob"]:
            sig = sample_continuous(aug_config["blur_sig"])
            gaussian_blur(img, sig)

        if random.random() < aug_config["jpg_prob"]:
            method = sample_discrete(aug_config["jpg_method"])
            qual = sample_discrete(aug_config["jpg_qual"])
            img = jpeg_from_key(img, qual, method)
        outputs.append(img)
    outputs = np.stack(outputs)
    outputs = torch.from_numpy(outputs).to(device).permute(0, 3, 1, 2).float() / 255.
    return outputs

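# Note: tensor_data_augment expects a float image batch in NCHW layout with values in [0, 1];
# it round-trips through uint8 numpy arrays on the CPU and returns a tensor with the same
# layout and value range on the original device.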

# Sample continuous or discrete values
def sample_continuous(s):
    if len(s) == 1:
        return s[0]
    if len(s) == 2:
        rg = s[1] - s[0]
        return random.random() * rg + s[0]
    raise ValueError("Length of iterable s should be 1 or 2.")


def sample_discrete(s):
    if len(s) == 1:
        return s[0]
    return random.choice(s)


# Gaussian blur (applied in place, channel by channel)
def gaussian_blur(img, sigma):
    gaussian_filter(img[:, :, 0], output=img[:, :, 0], sigma=sigma)
    gaussian_filter(img[:, :, 1], output=img[:, :, 1], sigma=sigma)
    gaussian_filter(img[:, :, 2], output=img[:, :, 2], sigma=sigma)


# JPEG compression
def cv2_jpg(img, compress_val):
    img_cv2 = img[:, :, ::-1]
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), compress_val]
    result, encimg = cv2.imencode('.jpg', img_cv2, encode_param)
    decimg = cv2.imdecode(encimg, 1)
    return decimg[:, :, ::-1]


def pil_jpg(img, compress_val):
    out = BytesIO()
    img = Image.fromarray(img)
    img.save(out, format='jpeg', quality=compress_val)
    img = Image.open(out)
    # Load from memory before the BytesIO buffer closes
    img = np.array(img)
    out.close()
    return img


def png_to_jpeg(img, quality=95):
    # Convert a PNG image to JPEG
    # Input: PIL image
    # Output: PIL image
    out = BytesIO()
    img.save(out, format='jpeg', quality=quality)
    img = np.array(Image.open(out))
    # Load from memory before the BytesIO buffer closes
    out.close()
    img = Image.fromarray(img)
    return img


def jpeg_from_key(img, compress_val, key):
    jpeg_dict = {'cv2': cv2_jpg, 'pil': pil_jpg}
    method = jpeg_dict[key]
    return method(img, compress_val)


# Custom resize function
def custom_resize(img, rz_interp, loadSize):
    rz_dict = {'bilinear': Image.BILINEAR,
               'bicubic': Image.BICUBIC,
               'lanczos': Image.LANCZOS,
               'nearest': Image.NEAREST}
    interp = sample_discrete(rz_interp)
    return TF.resize(img, loadSize, interpolation=rz_dict[interp])


def weights2cpu(weights):
    for key in weights:
        weights[key] = weights[key].cpu()
    return weights
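# Usage sketch (illustrative, not part of the original file; "model" and the paths are hypothetical):
#     img = Image.open("example.jpg").convert("RGB")
#     img = custom_resize(img, rz_interp=["bilinear"], loadSize=256)
#     img = data_augment(img, aug_config)   # aug_config as sketched above
#     torch.save(weights2cpu(model.state_dict()), "ckpt/model.pth")   # move weights to CPU before saving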