NghiTran1009 committed on
Commit cab012d · 1 Parent(s): f1aeecb

Upload clean CO-SPY project
.gitignore ADDED
Binary file (16 Bytes).
 
Datasets/__init__.py ADDED
@@ -0,0 +1,11 @@
+ from .dataset import TrainDataset, TestDataset
+
+ # List of evaluated real datasets
+ EVAL_DATASET_LIST = [
+     "real"
+ ]
+ # List of evaluated generative models
+ EVAL_MODEL_LIST = [
+     "stable_diffusion"
+ ]
+ __all__ = ["TrainDataset", "TestDataset", "EVAL_DATASET_LIST", "EVAL_MODEL_LIST"]
Datasets/dataset.py ADDED
@@ -0,0 +1,105 @@
+ import os
+ import numpy as np
+ from PIL import Image
+ from tqdm import tqdm
+ from torch.utils.data import Dataset
+
+ from utils import get_list, png_to_jpeg
+ from .mscoco import MSCOCO2017
+ from .flickr import Flickr30k
+
+
+ class TrainDataset(Dataset):
+     def __init__(self, data_path, split="train", transform=None, add_jpeg=True):
+         assert split in ["train", "val"]
+
+         # Load the dataset for training
+         real_list = get_list(os.path.join(data_path, "mscoco2017", f"{split}2017"))
+         fake_list = get_list(os.path.join(data_path, "stable-diffusion-v1-4", f"{split}2017"))
+
+         # Set the labels for the dataset
+         self.labels_dict = {}
+         for i in real_list:
+             self.labels_dict[i] = 0
+         for i in fake_list:
+             self.labels_dict[i] = 1
+
+         # Construct the entire dataset
+         self.total_list = real_list + fake_list
+         np.random.shuffle(self.total_list)
+
+         # JPEG compression
+         self.add_jpeg = add_jpeg
+
+         # Transformations
+         self.transform = transform
+
+     def __len__(self):
+         return len(self.total_list)
+
+     def __getitem__(self, idx):
+         img_path = self.total_list[idx]
+         label = self.labels_dict[img_path]
+         image = Image.open(img_path).convert("RGB")
+
+         # Add JPEG compression
+         if self.add_jpeg:
+             image = png_to_jpeg(image, quality=95)
+
+         # Apply the transformation
+         if self.transform is not None:
+             image = self.transform(image)
+         return image, label
+
+
+ class TestDataset(Dataset):
+     def __init__(self, dataset, model, root_path, transform=None, add_jpeg=True):
+         fake_dir = os.path.join(root_path, dataset, model)
+         self.fake = sorted([
+             os.path.join(fake_dir, i)
+             for i in os.listdir(fake_dir)
+             if i.lower().endswith((".png", ".jpg", ".jpeg"))
+         ])
+
+         real_dir = os.path.join(root_path, dataset, "real")
+         if not os.path.exists(real_dir):
+             raise ValueError(f"Real images directory not found: {real_dir}")
+
+         self.real = sorted([
+             os.path.join(real_dir, i)
+             for i in os.listdir(real_dir)
+             if i.lower().endswith((".png", ".jpg", ".jpeg"))
+         ])
+
+         self.image_idx = list(range(len(self.real) + len(self.fake)))
+         self.labels = [0] * len(self.real) + [1] * len(self.fake)
+         self.image_paths = self.real + self.fake
+
+         self.add_jpeg = add_jpeg
+         self.transform = transform
+
+     def __len__(self):
+         return len(self.image_idx)
+
+     def __getitem__(self, idx):
+         if idx < len(self.real):
+             img_path = self.real[idx]
+         else:
+             img_path = self.fake[idx - len(self.real)]
+
+         # ---- FIX: Skip corrupted / unreadable images ----
+         try:
+             image = Image.open(img_path).convert("RGB")
+         except Exception:
+             print("Corrupted image:", img_path)
+             # Fall back to the next image instead
+             return self.__getitem__((idx + 1) % len(self))
+
+         if self.add_jpeg:
+             image = png_to_jpeg(image, quality=95)
+
+         if self.transform is not None:
+             image = self.transform(image)
+
+         label = self.labels[idx]
+         return image, label, img_path
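
A minimal usage sketch (not part of the commit) for these datasets; the paths follow `data/train/download.sh` and `data/test/README.md`, and the stand-in transform is an assumption (in practice a detector's `train_transform`/`test_transform` is passed):

    import torch
    from torchvision import transforms
    from Datasets import TrainDataset, TestDataset

    # ToTensor() stands in here so the batches collate; a detector's transform
    # (e.g. detector.model.train_transform) is what the training scripts use.
    to_tensor = transforms.Compose([transforms.Resize((384, 384)), transforms.ToTensor()])

    train_set = TrainDataset(data_path="data/train", split="train", transform=to_tensor)
    test_set = TestDataset(dataset="real", model="stable_diffusion",
                           root_path="data/test", transform=to_tensor)

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
    images, labels = next(iter(train_loader))     # TrainDataset yields (image, label)
    image, label, path = test_set[0]              # TestDataset also returns the image path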
Datasets/flickr.py ADDED
@@ -0,0 +1,30 @@
+ import os
+ import json
+ import torch
+ import numpy as np
+ from PIL import Image
+ import datasets as ds
+
+
+ class Flickr30k(torch.utils.data.Dataset):
+     def __init__(self, split='test', transform=None):
+         # Split [test: 31014]
+         self.dataset = ds.load_dataset("nlphuji/flickr30k")[split]
+
+         # Preprocess the images
+         self.transform = transform
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         example = self.dataset[idx]
+         # PIL RGB image
+         image = example['image']
+         if self.transform:
+             image = self.transform(image)
+         # A list of valid captions
+         caption_list = example['caption']
+         # Randomly select a caption
+         caption = np.random.choice(caption_list)
+         return image, caption
Datasets/mscoco.py ADDED
@@ -0,0 +1,34 @@
+ import os
+ import json
+ import torch
+ import numpy as np
+ from PIL import Image
+ import datasets as ds
+
+
+ class MSCOCO2017(torch.utils.data.Dataset):
+     def __init__(self, split='train', transform=None):
+         # Split [train: 118287, val: 5000]
+         self.dataset = ds.load_dataset(
+             "shunk031/MSCOCO",
+             year=2017,
+             coco_task="captions"
+         )[split]
+
+         # Preprocess the images
+         self.transform = transform
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         example = self.dataset[idx]
+         # PIL RGB image
+         image = example['image'].convert('RGB')
+         if self.transform:
+             image = self.transform(image)
+         # A list of valid captions
+         caption_list = example['annotations']['caption']
+         # Randomly select a caption
+         caption = np.random.choice(caption_list)
+         return image, caption
Detectors/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .artifact_detector import ArtifactDetector
+ from .semantic_detector import SemanticDetector
+ from .cospy_calibrate_detector import CospyCalibrateDetector
+ from .cospy_detector import CospyDetector, LabelSmoothingBCEWithLogits
+
+ __all__ = ["ArtifactDetector", "SemanticDetector", "CospyCalibrateDetector", "CospyDetector", "LabelSmoothingBCEWithLogits"]
Detectors/artifact_detector.py ADDED
@@ -0,0 +1,82 @@
+ import torch
+ from diffusers import StableDiffusionPipeline
+ from .artifact_extractor import VAEReconEncoder
+ from torchvision import transforms
+ from utils import data_augment
+
+
+ # Artifact Detector (Extract artifact features using VAE)
+ class ArtifactDetector(torch.nn.Module):
+     def __init__(self, dim_artifact=512, num_classes=1):
+         super(ArtifactDetector, self).__init__()
+         # Load the pre-trained VAE
+         model_id = "CompVis/stable-diffusion-v1-4"
+         vae = StableDiffusionPipeline.from_pretrained(model_id).vae
+         # Freeze the VAE visual encoder
+         vae.requires_grad_(False)
+         self.artifact_encoder = VAEReconEncoder(vae)
+
+         # Classifier
+         self.fc = torch.nn.Linear(dim_artifact, num_classes)
+
+         # Normalization
+         self.mean = [0.0, 0.0, 0.0]
+         self.std = [1.0, 1.0, 1.0]
+
+         # Resolution
+         self.loadSize = 256
+         self.cropSize = 224
+
+         # Data augmentation
+         self.blur_prob = 0.0
+         self.blur_sig = [0.0, 3.0]
+         self.jpg_prob = 0.5
+         self.jpg_method = ['cv2', 'pil']
+         self.jpg_qual = list(range(70, 96))
+
+         # Define the augmentation configuration
+         self.aug_config = {
+             "blur_prob": self.blur_prob,
+             "blur_sig": self.blur_sig,
+             "jpg_prob": self.jpg_prob,
+             "jpg_method": self.jpg_method,
+             "jpg_qual": self.jpg_qual,
+         }
+
+         # Pre-processing
+         crop_func = transforms.RandomCrop(self.cropSize)
+         flip_func = transforms.RandomHorizontalFlip()
+         rz_func = transforms.Resize(self.loadSize)
+         aug_func = transforms.Lambda(lambda x: data_augment(x, self.aug_config))
+
+         self.train_transform = transforms.Compose([
+             aug_func,
+             rz_func,
+             crop_func,
+             flip_func,
+             transforms.ToTensor(),
+             transforms.Normalize(mean=self.mean, std=self.std),
+         ])
+
+         self.test_transform = transforms.Compose([
+             rz_func,
+             crop_func,
+             transforms.ToTensor(),
+             transforms.Normalize(mean=self.mean, std=self.std),
+         ])
+
+     def forward(self, x, return_feat=False):
+         feat = self.artifact_encoder(x)
+         out = self.fc(feat)
+         if return_feat:
+             return feat, out
+         return out
+
+     def save_weights(self, weights_path):
+         save_params = {k: v.cpu() for k, v in self.state_dict().items()}
+         torch.save(save_params, weights_path)
+
+     def load_weights(self, weights_path):
+         weights = torch.load(weights_path)
+         self.load_state_dict(weights)
Detectors/artifact_extractor.py ADDED
@@ -0,0 +1,162 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+
+ def conv3x3(in_planes, out_planes, stride=1):
+     """3x3 convolution with padding"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+ def conv1x1(in_planes, out_planes, stride=1):
+     """1x1 convolution"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None):
+         super(BasicBlock, self).__init__()
+         self.conv1 = conv3x3(inplanes, planes, stride)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = conv3x3(planes, planes)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         identity = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+
+         if self.downsample is not None:
+             identity = self.downsample(x)
+
+         out += identity
+         out = self.relu(out)
+
+         return out
+
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None):
+         super(Bottleneck, self).__init__()
+         self.conv1 = conv1x1(inplanes, planes)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.conv2 = conv3x3(planes, planes, stride)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.conv3 = conv1x1(planes, planes * self.expansion)
+         self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         identity = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out = self.conv3(out)
+         out = self.bn3(out)
+
+         if self.downsample is not None:
+             identity = self.downsample(x)
+
+         out += identity
+         out = self.relu(out)
+
+         return out
+
+
+ class VAEReconEncoder(nn.Module):
+     def __init__(self, vae, block=Bottleneck):
+         super(VAEReconEncoder, self).__init__()
+
+         # Define the ResNet model
+         self.inplanes = 64
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
+         # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+         # ResNet-50 is [3, 4, 6, 3]
+         self.layer1 = self._make_layer(block, 64, 3)
+         self.layer2 = self._make_layer(block, 128, 4, stride=2)
+         # self.layer3 = self._make_layer(block, 256, 6, stride=2)
+         # self.layer4 = self._make_layer(block, 512, 3, stride=2)
+
+         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+
+         # Kaiming initialization
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+             elif isinstance(m, nn.BatchNorm2d):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+
+         # Load the VAE model
+         self.vae = vae
+
+     def _make_layer(self, block, planes, blocks, stride=1):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 conv1x1(self.inplanes, planes * block.expansion, stride),
+                 nn.BatchNorm2d(planes * block.expansion),
+             )
+
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, downsample))
+         self.inplanes = planes * block.expansion
+         for _ in range(1, blocks):
+             layers.append(block(self.inplanes, planes))
+
+         return nn.Sequential(*layers)
+
+     def reconstruct(self, x):
+         with torch.no_grad():
+             # `.sample()` means to sample a latent vector from the distribution
+             # `.mean` means to use the mean of the distribution
+             latent = self.vae.encode(x).latent_dist.mean
+             decoded = self.vae.decode(latent).sample
+         return decoded
+
+     def forward(self, x):
+         # Reconstruct
+         x_recon = self.reconstruct(x)
+         # Compute the artifacts
+         x = x - x_recon
+
+         # Scale the artifacts
+         x = x / 7. * 100.
+
+         # Forward pass
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.maxpool(x)
+
+         x = self.layer1(x)
+         x = self.layer2(x)
+
+         x = self.avgpool(x)
+         x = x.view(x.size(0), -1)
+
+         return x
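
The encoder above classifies the residual between an image and its VAE reconstruction rather than the raw pixels. A minimal standalone sketch of that residual computation (assuming the SD v1.4 VAE downloads successfully; the example batch is a placeholder, and the /7·100 scaling matches the code above):

    import torch
    from diffusers import StableDiffusionPipeline

    vae = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").vae.eval()
    x = torch.rand(1, 3, 224, 224)                       # a batch of RGB images in [0, 1]
    with torch.no_grad():
        latent = vae.encode(x).latent_dist.mean          # deterministic latent (no sampling)
        recon = vae.decode(latent).sample                # reconstructed image
    artifact = (x - recon) / 7.0 * 100.0                 # amplified reconstruction residual
    # `artifact` is what the truncated ResNet above consumes.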
Detectors/cospy_calibrate_detector.py ADDED
@@ -0,0 +1,96 @@
+ import torch
+ from torchvision import transforms
+ from utils import data_augment
+ from .semantic_detector import SemanticDetector
+ from .artifact_detector import ArtifactDetector
+
+
+ # CO-SPY Calibrate Detector (Calibrate the integration of semantic and artifact detectors)
+ class CospyCalibrateDetector(torch.nn.Module):
+     def __init__(self, semantic_weights_path, artifact_weights_path, num_classes=1):
+         super(CospyCalibrateDetector, self).__init__()
+
+         # Load the semantic detector
+         self.sem = SemanticDetector()
+         self.sem.load_weights(semantic_weights_path)
+
+         # Load the artifact detector
+         self.art = ArtifactDetector()
+         self.art.load_weights(artifact_weights_path)
+
+         # Freeze the two pre-trained models
+         for param in self.sem.parameters():
+             param.requires_grad = False
+         for param in self.art.parameters():
+             param.requires_grad = False
+
+         # Classifier
+         self.fc = torch.nn.Linear(2, num_classes)
+
+         # Transformations inside the forward function
+         # Including the normalization and resizing (only for the artifact detector)
+         self.sem_transform = transforms.Compose([
+             transforms.Normalize(self.sem.mean, self.sem.std)
+         ])
+         self.art_transform = transforms.Compose([
+             transforms.Resize(self.art.cropSize, antialias=False),
+             transforms.Normalize(self.art.mean, self.art.std)
+         ])
+
+         # Resolution
+         self.loadSize = 384
+         self.cropSize = 384
+
+         # Data augmentation
+         self.blur_prob = 0.0
+         self.blur_sig = [0.0, 3.0]
+         self.jpg_prob = 0.5
+         self.jpg_method = ['cv2', 'pil']
+         self.jpg_qual = list(range(70, 96))
+
+         # Define the augmentation configuration
+         self.aug_config = {
+             "blur_prob": self.blur_prob,
+             "blur_sig": self.blur_sig,
+             "jpg_prob": self.jpg_prob,
+             "jpg_method": self.jpg_method,
+             "jpg_qual": self.jpg_qual,
+         }
+
+         # Pre-processing
+         crop_func = transforms.RandomCrop(self.cropSize)
+         flip_func = transforms.RandomHorizontalFlip()
+         rz_func = transforms.Resize(self.loadSize)
+         aug_func = transforms.Lambda(lambda x: data_augment(x, self.aug_config))
+
+         self.train_transform = transforms.Compose([
+             flip_func,
+             aug_func,
+             rz_func,
+             crop_func,
+             transforms.ToTensor(),
+         ])
+
+         self.test_transform = transforms.Compose([
+             rz_func,
+             crop_func,
+             transforms.ToTensor(),
+         ])
+
+     def forward(self, x):
+         x_sem = self.sem_transform(x)
+         x_art = self.art_transform(x)
+         pred_sem = self.sem(x_sem)
+         pred_art = self.art(x_art)
+         x = torch.cat([pred_sem, pred_art], dim=1)
+         x = self.fc(x)
+         return x
+
+     def save_weights(self, weights_path):
+         save_params = {"fc.weight": self.fc.weight.cpu(), "fc.bias": self.fc.bias.cpu()}
+         torch.save(save_params, weights_path)
+
+     def load_weights(self, weights_path):
+         weights = torch.load(weights_path)
+         self.fc.weight.data = weights["fc.weight"]
+         self.fc.bias.data = weights["fc.bias"]
Detectors/cospy_detector.py ADDED
@@ -0,0 +1,124 @@
+ import torch
+ import random
+ from torchvision import transforms
+ from utils import data_augment, weights2cpu
+ from .semantic_detector import SemanticDetector
+ from .artifact_detector import ArtifactDetector
+
+
+ # CO-SPY Detector
+ class CospyDetector(torch.nn.Module):
+     def __init__(self, num_classes=1):
+         super(CospyDetector, self).__init__()
+
+         # Load the semantic detector
+         self.sem = SemanticDetector()
+         self.sem_dim = self.sem.fc.in_features
+
+         # Load the artifact detector
+         self.art = ArtifactDetector()
+         self.art_dim = self.art.fc.in_features
+
+         # Classifier
+         self.fc = torch.nn.Linear(self.sem_dim + self.art_dim, num_classes)
+
+         # Transformations inside the forward function
+         # Including the normalization and resizing (only for the artifact detector)
+         self.sem_transform = transforms.Compose([
+             transforms.Normalize(self.sem.mean, self.sem.std)
+         ])
+         self.art_transform = transforms.Compose([
+             transforms.Resize(self.art.cropSize, antialias=False),
+             transforms.Normalize(self.art.mean, self.art.std)
+         ])
+
+         # Resolution
+         self.loadSize = 384
+         self.cropSize = 384
+
+         # Data augmentation
+         self.blur_prob = 0.0
+         self.blur_sig = [0.0, 3.0]
+         self.jpg_prob = 0.5
+         self.jpg_method = ['cv2', 'pil']
+         self.jpg_qual = list(range(70, 96))
+
+         # Define the augmentation configuration
+         self.aug_config = {
+             "blur_prob": self.blur_prob,
+             "blur_sig": self.blur_sig,
+             "jpg_prob": self.jpg_prob,
+             "jpg_method": self.jpg_method,
+             "jpg_qual": self.jpg_qual,
+         }
+
+         # Pre-processing
+         crop_func = transforms.RandomCrop(self.cropSize)
+         flip_func = transforms.RandomHorizontalFlip()
+         rz_func = transforms.Resize(self.loadSize)
+         aug_func = transforms.Lambda(lambda x: data_augment(x, self.aug_config))
+
+         self.train_transform = transforms.Compose([
+             flip_func,
+             aug_func,
+             rz_func,
+             crop_func,
+             transforms.ToTensor(),
+         ])
+
+         self.test_transform = transforms.Compose([
+             rz_func,
+             crop_func,
+             transforms.ToTensor(),
+         ])
+
+     def forward(self, x, dropout_rate=0.3):
+         x_sem = self.sem_transform(x)
+         x_art = self.art_transform(x)
+
+         # Forward pass
+         sem_feat, sem_coeff = self.sem(x_sem, return_feat=True)
+         art_feat, art_coeff = self.art(x_art, return_feat=True)
+
+         # Apply the random branch dropout only in training mode (`self.training`)
+         if self.training:
+             # Random dropout
+             if random.random() < dropout_rate:
+                 # Randomly select a feature to drop
+                 idx_drop = random.randint(0, 1)
+                 if idx_drop == 0:
+                     sem_coeff = torch.zeros_like(sem_coeff)
+                 else:
+                     art_coeff = torch.zeros_like(art_coeff)
+
+         # Concatenate the features
+         x = torch.cat([sem_coeff * sem_feat, art_coeff * art_feat], dim=1)
+         x = self.fc(x)
+
+         return x
+
+     def save_weights(self, weights_path):
+         save_params = {
+             "sem_fc": weights2cpu(self.sem.fc.state_dict()),
+             "art_fc": weights2cpu(self.art.fc.state_dict()),
+             "art_encoder": weights2cpu(self.art.artifact_encoder.state_dict()),
+             "classifier": weights2cpu(self.fc.state_dict()),
+         }
+         torch.save(save_params, weights_path)
+
+     def load_weights(self, weights_path):
+         weights = torch.load(weights_path)
+         self.sem.fc.load_state_dict(weights["sem_fc"])
+         self.art.fc.load_state_dict(weights["art_fc"])
+         self.art.artifact_encoder.load_state_dict(weights["art_encoder"])
+         self.fc.load_state_dict(weights["classifier"])
+
+
+ # Define the label smoothing loss
+ class LabelSmoothingBCEWithLogits(torch.nn.Module):
+     def __init__(self, smoothing=0.1):
+         super(LabelSmoothingBCEWithLogits, self).__init__()
+         self.smoothing = smoothing
+
+     def forward(self, pred, target):
+         target = target.float() * (1.0 - self.smoothing) + 0.5 * self.smoothing
+         loss = torch.nn.functional.binary_cross_entropy_with_logits(pred, target, reduction='mean')
+         return loss
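
A short training-step sketch (an assumption, not part of this commit) showing how `CospyDetector` pairs with `LabelSmoothingBCEWithLogits`; the batch, learning rate, and optimizer choice are placeholders:

    import torch
    from Detectors import CospyDetector, LabelSmoothingBCEWithLogits

    model = CospyDetector().train()               # downloads SD-VAE and SigLIP weights on first use
    criterion = LabelSmoothingBCEWithLogits(smoothing=0.1)
    optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-4)

    images = torch.rand(2, 3, 384, 384)           # stand-in batch, as if from train_transform
    labels = torch.randint(0, 2, (2,))
    logits = model(images)                        # branch dropout is active in train mode
    loss = criterion(logits, labels.unsqueeze(1).float())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()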
Detectors/semantic_detector.py ADDED
@@ -0,0 +1,86 @@
+ import torch
+ import open_clip
+ from torchvision import transforms
+ from utils import data_augment
+
+
+ # Semantic Detector (Extract semantic features using CLIP)
+ class SemanticDetector(torch.nn.Module):
+     def __init__(self, dim_clip=1152, num_classes=1):
+         super(SemanticDetector, self).__init__()
+
+         # Get the pre-trained CLIP
+         model_name = "ViT-SO400M-14-SigLIP-384"
+         version = "webli"
+         self.clip, _, _ = open_clip.create_model_and_transforms(model_name, pretrained=version)
+         # Freeze the CLIP visual encoder
+         self.clip.requires_grad_(False)
+
+         # Classifier
+         self.fc = torch.nn.Linear(dim_clip, num_classes)
+
+         # Normalization
+         self.mean = [0.5, 0.5, 0.5]
+         self.std = [0.5, 0.5, 0.5]
+
+         # Resolution
+         self.loadSize = 384
+         self.cropSize = 384
+
+         # Data augmentation
+         self.blur_prob = 0.5
+         self.blur_sig = [0.0, 3.0]
+         self.jpg_prob = 0.5
+         self.jpg_method = ['cv2', 'pil']
+         self.jpg_qual = list(range(30, 101))
+
+         # Define the augmentation configuration
+         self.aug_config = {
+             "blur_prob": self.blur_prob,
+             "blur_sig": self.blur_sig,
+             "jpg_prob": self.jpg_prob,
+             "jpg_method": self.jpg_method,
+             "jpg_qual": self.jpg_qual,
+         }
+
+         # Pre-processing
+         crop_func = transforms.RandomCrop(self.cropSize)
+         flip_func = transforms.RandomHorizontalFlip()
+         rz_func = transforms.Resize(self.loadSize)
+         aug_func = transforms.Lambda(lambda x: data_augment(x, self.aug_config))
+
+         self.train_transform = transforms.Compose([
+             rz_func,
+             aug_func,
+             crop_func,
+             flip_func,
+             transforms.ToTensor(),
+             transforms.Normalize(mean=self.mean, std=self.std),
+         ])
+
+         self.test_transform = transforms.Compose([
+             rz_func,
+             crop_func,
+             transforms.ToTensor(),
+             transforms.Normalize(mean=self.mean, std=self.std),
+         ])
+
+     def forward(self, x, return_feat=False):
+         device = next(self.fc.parameters()).device  # device of the fc layer
+         x = x.to(device)                            # make sure the input is on the same device
+         feat = self.clip.encode_image(x)
+         feat = feat.to(device)                      # make sure feat is on the same device as fc
+         out = self.fc(feat)
+         if return_feat:
+             return feat, out
+         return out
+
+     def save_weights(self, weights_path):
+         save_params = {"fc.weight": self.fc.weight.cpu(), "fc.bias": self.fc.bias.cpu()}
+         torch.save(save_params, weights_path)
+
+     def load_weights(self, weights_path):
+         device = next(self.fc.parameters()).device  # current device of the model
+         weights = torch.load(weights_path, map_location=device)
+         self.fc.weight.data = weights["fc.weight"].to(device)
+         self.fc.bias.data = weights["fc.bias"].to(device)
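
A minimal single-image inference sketch (assumed, not part of the commit) for the semantic branch; `example.jpg` is a placeholder path and the SigLIP weights download via open_clip on first use:

    import torch
    from PIL import Image
    from Detectors import SemanticDetector

    detector = SemanticDetector().eval()
    img = Image.open("example.jpg").convert("RGB")     # placeholder path
    x = detector.test_transform(img).unsqueeze(0)      # resize, crop, normalize
    with torch.no_grad():
        prob_fake = torch.sigmoid(detector(x)).item()  # closer to 1 => predicted synthetic
    print(f"P(fake) = {prob_fake:.3f}")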
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Siyuan Cheng
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
ProGANDetectors/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .artifact_detector import ArtifactDetectorProGAN
+ from .semantic_detector import SemanticDetectorProGAN
+ from .cospy_calibrate_detector import CospyCalibrateDetectorProGAN
+
+ __all__ = ["ArtifactDetectorProGAN", "SemanticDetectorProGAN", "CospyCalibrateDetectorProGAN"]
ProGANDetectors/artifact_detector.py ADDED
@@ -0,0 +1,196 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from torchvision import transforms
+ from utils import data_augment
+
+
+ # Artifact Detector (Extract artifact features using NPR)
+ class ArtifactDetectorProGAN(torch.nn.Module):
+     def __init__(self, dim_artifact=512, num_classes=1):
+         super(ArtifactDetectorProGAN, self).__init__()
+         # Load the artifact encoder based on NPR
+         self.artifact_encoder = ResNet(Bottleneck, [3, 4, 6, 3])
+
+         # Classifier
+         self.fc = torch.nn.Linear(dim_artifact, num_classes)
+
+         # Normalization
+         self.mean = [0.485, 0.456, 0.406]
+         self.std = [0.229, 0.224, 0.225]
+
+         # Resolution
+         self.loadSize = 256
+         self.cropSize = 224
+
+         # Data augmentation
+         self.blur_prob = 0.0
+         self.blur_sig = [0.0, 3.0]
+         self.jpg_prob = 0.0
+         self.jpg_method = ['cv2', 'pil']
+         self.jpg_qual = list(range(70, 96))
+
+         # Define the augmentation configuration
+         self.aug_config = {
+             "blur_prob": self.blur_prob,
+             "blur_sig": self.blur_sig,
+             "jpg_prob": self.jpg_prob,
+             "jpg_method": self.jpg_method,
+             "jpg_qual": self.jpg_qual,
+         }
+
+         # Pre-processing
+         crop_func = transforms.RandomCrop(self.cropSize)
+         flip_func = transforms.RandomHorizontalFlip()
+         rz_func = transforms.Resize(self.loadSize)
+         aug_func = transforms.Lambda(lambda x: data_augment(x, self.aug_config))
+
+         self.train_transform = transforms.Compose([
+             aug_func,
+             rz_func,
+             crop_func,
+             flip_func,
+             transforms.ToTensor(),
+             transforms.Normalize(mean=self.mean, std=self.std),
+         ])
+
+         self.test_transform = transforms.Compose([
+             rz_func,
+             crop_func,
+             transforms.ToTensor(),
+             transforms.Normalize(mean=self.mean, std=self.std),
+         ])
+
+     def forward(self, x, return_feat=False):
+         feat = self.artifact_encoder(x)
+         out = self.fc(feat)
+         if return_feat:
+             return feat, out
+         return out
+
+     def save_weights(self, weights_path):
+         save_params = {k: v.cpu() for k, v in self.state_dict().items()}
+         torch.save(save_params, weights_path)
+
+     def load_weights(self, weights_path):
+         weights = torch.load(weights_path)
+         self.load_state_dict(weights)
+
+
+ # Define the artifact encoder (based on NPR)
+ def conv1x1(in_planes, out_planes, stride=1):
+     """1x1 convolution"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+ def conv3x3(in_planes, out_planes, stride=1):
+     """3x3 convolution with padding"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                      padding=1, bias=False)
+
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None):
+         super(Bottleneck, self).__init__()
+         self.conv1 = conv1x1(inplanes, planes)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.conv2 = conv3x3(planes, planes, stride)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.conv3 = conv1x1(planes, planes * self.expansion)
+         self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         identity = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out = self.conv3(out)
+         out = self.bn3(out)
+
+         if self.downsample is not None:
+             identity = self.downsample(x)
+
+         out += identity
+         out = self.relu(out)
+
+         return out
+
+
+ class ResNet(nn.Module):
+
+     def __init__(self, block, layers, num_classes=1):
+         super(ResNet, self).__init__()
+
+         self.unfoldSize = 2
+         self.unfoldIndex = 0
+         assert self.unfoldSize > 1
+         assert -1 < self.unfoldIndex and self.unfoldIndex < self.unfoldSize * self.unfoldSize
+         self.inplanes = 64
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         self.layer1 = self._make_layer(block, 64, layers[0])
+         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+         self.fc1 = nn.Linear(512, num_classes)
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+             elif isinstance(m, nn.BatchNorm2d):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+
+     def _make_layer(self, block, planes, blocks, stride=1):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 conv1x1(self.inplanes, planes * block.expansion, stride),
+                 nn.BatchNorm2d(planes * block.expansion),
+             )
+
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, downsample))
+         self.inplanes = planes * block.expansion
+         for _ in range(1, blocks):
+             layers.append(block(self.inplanes, planes))
+
+         return nn.Sequential(*layers)
+
+     def interpolate(self, img, factor):
+         return F.interpolate(
+             F.interpolate(img,
+                           scale_factor=factor,
+                           mode='nearest',
+                           recompute_scale_factor=True),
+             scale_factor=1 / factor,
+             mode='nearest',
+             recompute_scale_factor=True)
+
+     def forward(self, x):
+         # NPR: the residual between an image and its down/up-sampled copy
+         artifact = x - self.interpolate(x, 0.5)
+
+         x = self.conv1(artifact * 2.0 / 3.0)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.maxpool(x)
+
+         x = self.layer1(x)
+         x = self.layer2(x)
+
+         x = self.avgpool(x)
+         x = x.view(x.size(0), -1)
+
+         return x
ProGANDetectors/cospy_calibrate_detector.py ADDED
@@ -0,0 +1,96 @@
+ import torch
+ from torchvision import transforms
+ from utils import data_augment
+ from .semantic_detector import SemanticDetectorProGAN
+ from .artifact_detector import ArtifactDetectorProGAN
+
+
+ # CO-SPY Calibrate Detector (Calibrate the integration of semantic and artifact detectors)
+ class CospyCalibrateDetectorProGAN(torch.nn.Module):
+     def __init__(self, semantic_weights_path, artifact_weights_path, num_classes=1):
+         super(CospyCalibrateDetectorProGAN, self).__init__()
+
+         # Load the semantic detector
+         self.sem = SemanticDetectorProGAN()
+         self.sem.load_weights(semantic_weights_path)
+
+         # Load the artifact detector
+         self.art = ArtifactDetectorProGAN()
+         self.art.load_weights(artifact_weights_path)
+
+         # Freeze the two pre-trained models
+         for param in self.sem.parameters():
+             param.requires_grad = False
+         for param in self.art.parameters():
+             param.requires_grad = False
+
+         # Classifier
+         self.fc = torch.nn.Linear(2, num_classes)
+
+         # Transformations inside the forward function
+         # Including the normalization and resizing (only for the artifact detector)
+         self.sem_transform = transforms.Compose([
+             transforms.Normalize(self.sem.mean, self.sem.std)
+         ])
+         self.art_transform = transforms.Compose([
+             transforms.Resize(self.art.cropSize, antialias=False),
+             transforms.Normalize(self.art.mean, self.art.std)
+         ])
+
+         # Resolution
+         self.loadSize = 256
+         self.cropSize = 224
+
+         # Data augmentation
+         self.blur_prob = 0.0
+         self.blur_sig = [0.0, 3.0]
+         self.jpg_prob = 0.0
+         self.jpg_method = ['cv2', 'pil']
+         self.jpg_qual = list(range(70, 96))
+
+         # Define the augmentation configuration
+         self.aug_config = {
+             "blur_prob": self.blur_prob,
+             "blur_sig": self.blur_sig,
+             "jpg_prob": self.jpg_prob,
+             "jpg_method": self.jpg_method,
+             "jpg_qual": self.jpg_qual,
+         }
+
+         # Pre-processing
+         crop_func = transforms.RandomCrop(self.cropSize)
+         flip_func = transforms.RandomHorizontalFlip()
+         rz_func = transforms.Resize(self.loadSize)
+         aug_func = transforms.Lambda(lambda x: data_augment(x, self.aug_config))
+
+         self.train_transform = transforms.Compose([
+             flip_func,
+             aug_func,
+             rz_func,
+             crop_func,
+             transforms.ToTensor(),
+         ])
+
+         self.test_transform = transforms.Compose([
+             rz_func,
+             crop_func,
+             transforms.ToTensor(),
+         ])
+
+     def forward(self, x):
+         x_sem = self.sem_transform(x)
+         x_art = self.art_transform(x)
+         pred_sem = self.sem(x_sem)
+         pred_art = self.art(x_art)
+         x = torch.cat([pred_sem, pred_art], dim=1)
+         x = self.fc(x)
+         return x
+
+     def save_weights(self, weights_path):
+         save_params = {"fc.weight": self.fc.weight.cpu(), "fc.bias": self.fc.bias.cpu()}
+         torch.save(save_params, weights_path)
+
+     def load_weights(self, weights_path):
+         weights = torch.load(weights_path)
+         self.fc.weight.data = weights["fc.weight"]
+         self.fc.bias.data = weights["fc.bias"]
ProGANDetectors/semantic_detector.py ADDED
@@ -0,0 +1,82 @@
+ import torch
+ from transformers import CLIPModel
+ from torchvision import transforms
+ from utils import data_augment
+
+
+ # Semantic Detector (Extract semantic features using CLIP)
+ class SemanticDetectorProGAN(torch.nn.Module):
+     def __init__(self, dim_clip=768, num_classes=1):
+         super(SemanticDetectorProGAN, self).__init__()
+
+         # Get the pre-trained CLIP
+         model_name = "openai/clip-vit-large-patch14"
+         self.clip = CLIPModel.from_pretrained(model_name)
+
+         # Freeze the CLIP visual encoder
+         self.clip.requires_grad_(False)
+
+         # Classifier
+         self.fc = torch.nn.Linear(dim_clip, num_classes)
+
+         # Normalization
+         self.mean = [0.48145466, 0.4578275, 0.40821073]
+         self.std = [0.26862954, 0.26130258, 0.27577711]
+
+         # Resolution
+         self.loadSize = 256
+         self.cropSize = 224
+
+         # Data augmentation
+         self.blur_prob = 0.5
+         self.blur_sig = [0.0, 3.0]
+         self.jpg_prob = 0.5
+         self.jpg_method = ['cv2', 'pil']
+         self.jpg_qual = list(range(30, 101))
+
+         # Define the augmentation configuration
+         self.aug_config = {
+             "blur_prob": self.blur_prob,
+             "blur_sig": self.blur_sig,
+             "jpg_prob": self.jpg_prob,
+             "jpg_method": self.jpg_method,
+             "jpg_qual": self.jpg_qual,
+         }
+
+         # Pre-processing
+         crop_func = transforms.RandomCrop(self.cropSize)
+         flip_func = transforms.RandomHorizontalFlip()
+         rz_func = transforms.Resize(self.loadSize)
+         aug_func = transforms.Lambda(lambda x: data_augment(x, self.aug_config))
+
+         self.train_transform = transforms.Compose([
+             rz_func,
+             aug_func,
+             crop_func,
+             flip_func,
+             transforms.ToTensor(),
+             transforms.Normalize(mean=self.mean, std=self.std),
+         ])
+
+         self.test_transform = transforms.Compose([
+             rz_func,
+             crop_func,
+             transforms.ToTensor(),
+             transforms.Normalize(mean=self.mean, std=self.std),
+         ])
+
+     def forward(self, x, return_feat=False):
+         feat = self.clip.get_image_features(x)
+         out = self.fc(feat)
+         if return_feat:
+             return feat, out
+         return out
+
+     def save_weights(self, weights_path):
+         save_params = {"fc.weight": self.fc.weight.cpu(), "fc.bias": self.fc.bias.cpu()}
+         torch.save(save_params, weights_path)
+
+     def load_weights(self, weights_path):
+         weights = torch.load(weights_path)
+         self.fc.weight.data = weights["fc.weight"]
+         self.fc.bias.data = weights["fc.bias"]
__pycache__/utils.cpython-311.pyc ADDED
Binary file (9.5 kB).
 
calibrate_combine.py ADDED
@@ -0,0 +1,297 @@
+ import os
+ import time
+ import json
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ from loguru import logger
+ from sklearn.metrics import average_precision_score
+
+ from Detectors import CospyCalibrateDetector
+ from Datasets import TrainDataset, TestDataset, EVAL_DATASET_LIST, EVAL_MODEL_LIST
+ from utils import seed_torch
+
+ import warnings
+ warnings.filterwarnings("ignore")
+
+
+ class Detector():
+     def __init__(self, args):
+         super(Detector, self).__init__()
+
+         # Device
+         self.device = args.device
+
+         # ===== Initialize the model =====
+         self.model = CospyCalibrateDetector(
+             semantic_weights_path=args.semantic_weights_path,
+             artifact_weights_path=args.artifact_weights_path
+         )
+         self.model.to(self.device)
+
+         # Optionally (re-)initialize the fc layer
+         torch.nn.init.normal_(self.model.fc.weight.data, 0.0, 0.02)
+
+         # ===== Optimizer =====
+         _lr = 1e-1
+         _beta1 = 0.9
+         _weight_decay = 0.0
+         params = [p for p in self.model.parameters() if p.requires_grad]
+         print(f'Trainable parameters: {len(params)}')
+         self.optimizer = torch.optim.AdamW(params, lr=_lr, betas=(_beta1, 0.999), weight_decay=_weight_decay)
+
+         # ===== Loss =====
+         self.criterion = torch.nn.BCEWithLogitsLoss()
+
+         # Scheduler
+         self.delr_freq = 10
+
+         # ===== Load a checkpoint if provided =====
+         if args.resume is not None:
+             print(f"Loading checkpoint from {args.resume}")
+             state = torch.load(args.resume, map_location=self.device)
+
+             # Support both formats: {'model': state_dict} or a bare state_dict
+             if isinstance(state, dict) and "model" in state:
+                 state = state["model"]
+
+             self.model.load_state_dict(state, strict=False)
+             print("Checkpoint loaded. Continue training...")
+
+         self.model.to(self.device)
+         self.model.train()
+
+     # Training function for the detector
+     def train_step(self, batch_data):
+         # Decompose the batch data
+         inputs, labels = batch_data
+         inputs, labels = inputs.to(self.device), labels.to(self.device)
+
+         self.optimizer.zero_grad()
+         outputs = self.model(inputs)
+         loss = self.criterion(outputs, labels.unsqueeze(1).float())
+         loss.backward()
+         self.optimizer.step()
+
+         eval_loss = loss.item()
+         y_pred = outputs.sigmoid().flatten().tolist()
+         y_true = labels.tolist()
+         return eval_loss, y_pred, y_true
+
+     # Schedule the training
+     # Early stopping / learning rate adjustment
+     def scheduler(self, status_dict):
+         epoch = status_dict['epoch']
+         if epoch % self.delr_freq == 0 and epoch != 0:
+             for param_group in self.optimizer.param_groups:
+                 param_group['lr'] *= 0.9
+                 self.lr = param_group['lr']
+         return True
+
+     # Prediction function
+     def predict(self, inputs):
+         inputs = inputs.to(self.device)
+         outputs = self.model(inputs)
+         prediction = outputs.sigmoid().flatten().tolist()
+         return prediction
+
+
+ def evaluate(y_pred, y_true):
+     ap = average_precision_score(y_true, y_pred)
+     accuracy = ((np.array(y_pred) > 0.5) == y_true).mean()
+     return ap, accuracy
+
+
+ def train(args):
+     # Set the saving directory first
+     model_dir = os.path.join(args.ckpt, "cospy_calibrate")
+     if not os.path.exists(model_dir):
+         os.makedirs(model_dir)
+
+     log_path = f"{model_dir}/training.log"
+     if os.path.exists(log_path):
+         os.remove(log_path)
+
+     logger_id = logger.add(
+         log_path,
+         format="{time:MM-DD at HH:mm:ss} | {level} | {module}:{line} | {message}",
+         level="DEBUG",
+     )
+     # Get the detector
+     detector = Detector(args)
+     # --- Resume checkpoint ---
+     start_epoch = 0
+     best_acc = 0
+
+     if args.resume:
+         resume_path = os.path.join(model_dir, "best_model.pth")
+         if os.path.exists(resume_path):
+             print(f"Resuming from {resume_path} ...")
+             detector.model.load_weights(resume_path)
+             detector.model.to(args.device)
+
+     # Load the calibration dataset using the "val" split
+     train_dataset = TrainDataset(data_path=args.calibration_dirpath,
+                                  split="val",
+                                  transform=detector.model.test_transform)
+
+     train_loader = torch.utils.data.DataLoader(train_dataset,
+                                                batch_size=args.batch_size,
+                                                shuffle=True,
+                                                num_workers=4,
+                                                pin_memory=True)
+
+     logger.info(f"Train size {len(train_dataset)}")
+
+     # Train the detector
+     for epoch in range(start_epoch, args.epochs):
+         # Set the model to training mode
+         detector.model.train()
+         time_start = time.time()
+         for step_id, batch_data in enumerate(train_loader):
+             eval_loss, y_pred, y_true = detector.train_step(batch_data)
+             ap, accuracy = evaluate(y_pred, y_true)
+
+             # Log the training information
+             if (step_id + 1) % 100 == 0:
+                 time_end = time.time()
+                 logger.info(f"Epoch {epoch} | Batch {step_id + 1}/{len(train_loader)} | Loss {eval_loss:.4f} | AP {ap*100:.2f}% | Accuracy {accuracy*100:.2f}% | Time {time_end-time_start:.2f}s")
+                 time_start = time.time()
+
+         # Evaluate the model
+         detector.model.eval()
+         y_pred, y_true = [], []
+         for inputs in train_loader:
+             inputs, labels = inputs
+             y_pred.extend(detector.predict(inputs))
+             y_true.extend(labels.tolist())
+
+         ap, accuracy = evaluate(y_pred, y_true)
+         logger.info(f"Epoch {epoch} | Total AP {ap*100:.2f}% | Total Accuracy {accuracy*100:.2f}%")
+
+         # Schedule the training
+         status_dict = {'epoch': epoch, 'AP': ap, 'Accuracy': accuracy}
+         proceed = detector.scheduler(status_dict)
+         if not proceed:
+             logger.info("Early stopping")
+             break
+
+         # Save the model
+         if accuracy >= best_acc:
+             best_acc = accuracy
+             detector.model.save_weights(f"{model_dir}/best_model.pth")
+             logger.info(f"Best model saved with accuracy {best_acc*100:.2f}%")
+
+         if epoch % 5 == 0:
+             detector.model.save_weights(f"{model_dir}/epoch_{epoch}.pth")
+             logger.info(f"Model saved at epoch {epoch}")
+
+     # Save the final model
+     detector.model.save_weights(f"{model_dir}/final_model.pth")
+     logger.info("Final model saved")
+
+     # Remove the logger
+     logger.remove(logger_id)
+
+
+ def test(args):
+     # Initialize the detector
+     detector = Detector(args)
+
+     # Load the [best/final] model
+     weights_path = os.path.join(args.ckpt, "cospy_calibrate", "best_model.pth")
+
+     detector.model.load_weights(weights_path)
+     detector.model.to(args.device)
+     detector.model.eval()
+
+     # Set the pre-processing function
+     test_transform = detector.model.test_transform
+
+     # Set the saving directory
+     save_dir = os.path.join(args.ckpt, "cospy_calibrate")
+     save_result_path = os.path.join(save_dir, "result.json")
+     save_output_path = os.path.join(save_dir, "output.json")
+
+     # Begin the evaluation
+     result_all = {}
+     output_all = {}
+     for dataset_name in EVAL_DATASET_LIST:
+         result_all[dataset_name] = {}
+         output_all[dataset_name] = {}
+         for model_name in EVAL_MODEL_LIST:
+             test_dataset = TestDataset(dataset=dataset_name, model=model_name, root_path=args.testset_dirpath, transform=test_transform)
+             test_loader = torch.utils.data.DataLoader(test_dataset,
+                                                       batch_size=args.batch_size,
+                                                       shuffle=False,
+                                                       num_workers=4,
+                                                       pin_memory=True)
+
+             # Evaluate the model (TestDataset also yields the image path)
+             y_pred, y_true = [], []
+             for images, labels, _ in tqdm(test_loader, desc=f"Evaluating {dataset_name} {model_name}"):
+                 y_pred.extend(detector.predict(images))
+                 y_true.extend(labels.tolist())
+
+             ap, accuracy = evaluate(y_pred, y_true)
+             print(f"Evaluate on {dataset_name} {model_name} | Size {len(y_true)} | AP {ap*100:.2f}% | Accuracy {accuracy*100:.2f}%")
+
+             result_all[dataset_name][model_name] = {"size": len(y_true), "AP": ap, "Accuracy": accuracy}
+             output_all[dataset_name][model_name] = {"y_pred": y_pred, "y_true": y_true}
+
+     # Save the results
+     with open(save_result_path, "w") as f:
+         json.dump(result_all, f, indent=4)
+
+     with open(save_output_path, "w") as f:
+         json.dump(output_all, f, indent=4)
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser("Deep Fake Detection")
+     parser.add_argument("--gpu", type=int, default=0, help="GPU ID")
+     parser.add_argument("--phase", type=str, default="test", choices=["train", "test"], help="Phase of the experiment")
+     parser.add_argument("--semantic_weights_path", type=str, default="ckpt/semantic/best_model.pth", help="Semantic weights path")
+     parser.add_argument("--artifact_weights_path", type=str, default="ckpt/artifact/best_model.pth", help="Artifact weights path")
+     parser.add_argument("--calibration_dirpath", type=str, default="data/train", help="Calibration directory")
+     parser.add_argument("--testset_dirpath", type=str, default="data/test", help="Testset directory")
+     parser.add_argument("--ckpt", type=str, default="ckpt", help="Checkpoint directory")
+     parser.add_argument("--epochs", type=int, default=10, help="Number of epochs")
+     parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
+     parser.add_argument("--seed", type=int, default=1024, help="Random seed")
+     parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint to resume training")
+
+     args = parser.parse_args()
+
+     # Set the random seed
+     seed_torch(args.seed)
+
+     # Set the GPU ID
+     args.device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
+
+     # Begin the experiment
+     if args.phase == "train":
+         train(args)
+     elif args.phase == "test":
+         test(args)
+     else:
+         raise ValueError("Unknown phase")
data/in_the_wild/README.md ADDED
@@ -0,0 +1,14 @@
+ Our in-the-wild evaluation dataset is constructed from five sources:
+
+ (1) Civitai [https://civitai.com/]
+
+ (2) DALL-E 3 [https://huggingface.co/datasets/ProGamerGov/synthetic-dataset-1m-dalle3-high-quality-captions]
+
+ (3) instavibe.ai [https://www.instavibe.ai/discover]
+
+ (4) Lexica [https://lexica.art/]
+
+ (5) Midjourney-v6 [https://huggingface.co/datasets/terminusresearch/midjourney-v6-520k-raw]
+
+ Data from sources (1), (2), (5) can be easily accessed and downloaded.
+ For sources (3) and (4), we provide the image URLs used in our dataset under the `./urls` directory for your convenience.
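
A hedged download sketch (not part of the commit) for the URL lists under `./urls`; the output directory, file naming, and error handling are assumptions:

    import os
    import requests

    os.makedirs("data/in_the_wild/lexica", exist_ok=True)
    with open("data/in_the_wild/urls/lexica.txt") as f:
        urls = [line.strip() for line in f if line.strip()]

    for i, url in enumerate(urls):
        try:
            resp = requests.get(url, timeout=10)
            resp.raise_for_status()
            with open(f"data/in_the_wild/lexica/{i:06d}.jpg", "wb") as out:
                out.write(resp.content)
        except Exception as e:
            print(f"Skipping {url}: {e}")  # some in-the-wild URLs may have expired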
data/in_the_wild/urls/flux.txt ADDED
The diff for this file is too large to render.
 
data/in_the_wild/urls/lexica.txt ADDED
The diff for this file is too large to render.
 
data/test/README.md ADDED
@@ -0,0 +1,13 @@
+ Please download the test samples from [Co-Spy-Bench](https://huggingface.co/datasets/ruojiruoli/Co-Spy-Bench) and place them in this directory.
+
+ For real images:
+
+ * **CC3M**, **MSCOCO**, **TextCaps**, **Flickr**, and **SBU** are used.
+ * For **MSCOCO** and **Flickr**, refer to `Datasets/mscoco.py` and `Datasets/flickr.py` for instructions on downloading via HuggingFace Datasets.
+ * For the remaining datasets, download from their original sources:
+
+   * [CC3M](https://ai.google.com/research/ConceptualCaptions/download)
+   * [TextCaps](https://textvqa.org/textcaps/dataset/)
+   * [SBU](https://huggingface.co/datasets/vicenteor/sbu_captions)
+
+ Example test samples are also available on [Google Drive](https://drive.google.com/file/d/1JaaIGItyDYprr4_k0C_90MGIIRVQpmIP/view?usp=sharing). Please ensure their use complies with the original licenses.
data/train/download.sh ADDED
@@ -0,0 +1,16 @@
+ # Download and unzip the synthetic training dataset from DRCT
+ # Reference: https://icml.cc/virtual/2024/poster/33086
+ # Data source: https://github.com/beibuwandeluori/DRCT
+ wget --no-check-certificate https://modelscope.cn/datasets/BokingChen/DRCT-2M/resolve/master/images/stable-diffusion-v1-4.zip
+ unzip stable-diffusion-v1-4.zip
+
+ # Download the real training dataset from MSCOCO2017
+ # Reference: https://arxiv.org/pdf/1405.0312
+ # Data source: https://cocodataset.org/#download
+ mkdir mscoco2017
+ cd mscoco2017
+ wget http://images.cocodataset.org/zips/train2017.zip
+ wget http://images.cocodataset.org/zips/val2017.zip
+ unzip train2017.zip
+ unzip val2017.zip
+ cd ..
environment.yml ADDED
@@ -0,0 +1,105 @@
+ name: cospy
+ channels:
+   - defaults
+ dependencies:
+   - _libgcc_mutex=0.1=main
+   - _openmp_mutex=5.1=1_gnu
+   - ca-certificates=2025.2.25=h06a4308_0
+   - ld_impl_linux-64=2.40=h12ee557_0
+   - libffi=3.4.4=h6a678d5_1
+   - libgcc-ng=11.2.0=h1234567_1
+   - libgomp=11.2.0=h1234567_1
+   - libstdcxx-ng=11.2.0=h1234567_1
+   - ncurses=6.4=h6a678d5_0
+   - openssl=3.0.16=h5eee18b_0
+   - pip=24.2=py38h06a4308_0
+   - python=3.8.18=h955ad1f_0
+   - readline=8.2=h5eee18b_0
+   - setuptools=75.1.0=py38h06a4308_0
+   - sqlite=3.45.3=h5eee18b_0
+   - tk=8.6.14=h39e8969_0
+   - wheel=0.44.0=py38h06a4308_0
+   - xz=5.6.4=h5eee18b_1
+   - zlib=1.2.13=h5eee18b_1
+   - pip:
+       - accelerate==1.0.1
+       - aiohappyeyeballs==2.4.4
+       - aiohttp==3.10.11
+       - aiosignal==1.3.1
+       - async-timeout==5.0.1
+       - attrs==25.3.0
+       - certifi==2025.1.31
+       - charset-normalizer==3.4.1
+       - contourpy==1.1.1
+       - cycler==0.12.1
+       - datasets==3.1.0
+       - diffusers==0.32.2
+       - dill==0.3.8
+       - filelock==3.16.1
+       - fonttools==4.56.0
+       - frozenlist==1.5.0
+       - fsspec==2024.9.0
+       - ftfy==6.2.3
+       - huggingface-hub==0.29.3
+       - idna==3.10
+       - importlib-metadata==8.5.0
+       - importlib-resources==6.4.5
+       - jinja2==3.1.6
+       - joblib==1.4.2
+       - kiwisolver==1.4.7
+       - loguru==0.7.3
+       - markupsafe==2.1.5
+       - matplotlib==3.7.5
+       - mpmath==1.3.0
+       - multidict==6.1.0
+       - multiprocess==0.70.16
+       - networkx==3.1
+       - numpy==1.24.4
+       - nvidia-cublas-cu12==12.1.3.1
+       - nvidia-cuda-cupti-cu12==12.1.105
+       - nvidia-cuda-nvrtc-cu12==12.1.105
+       - nvidia-cuda-runtime-cu12==12.1.105
+       - nvidia-cudnn-cu12==9.1.0.70
+       - nvidia-cufft-cu12==11.0.2.54
+       - nvidia-curand-cu12==10.3.2.106
+       - nvidia-cusolver-cu12==11.4.5.107
+       - nvidia-cusparse-cu12==12.1.0.106
+       - nvidia-nccl-cu12==2.20.5
+       - nvidia-nvjitlink-cu12==12.8.93
+       - nvidia-nvtx-cu12==12.1.105
+       - open-clip-torch==2.31.0
+       - opencv-python==4.11.0.86
+       - packaging==24.2
+       - pandas==2.0.3
+       - pillow==10.4.0
+       - propcache==0.2.0
+       - psutil==7.0.0
+       - pyarrow==17.0.0
+       - pycocotools==2.0.7
+       - pyparsing==3.1.4
+       - python-dateutil==2.9.0.post0
+       - pytz==2025.1
+       - pyyaml==6.0.2
+       - regex==2024.11.6
+       - requests==2.32.3
+       - safetensors==0.5.3
+       - scikit-learn==1.3.2
+       - scipy==1.10.1
+       - six==1.17.0
+       - sympy==1.13.3
+       - threadpoolctl==3.5.0
+       - timm==1.0.15
+       - tokenizers==0.20.3
+       - torch==2.4.1
+       - torchvision==0.19.1
+       - tqdm==4.67.1
+       - transformers==4.46.3
+       - triton==3.0.0
+       - typing-extensions==4.12.2
+       - tzdata==2025.1
+       - urllib3==2.2.3
+       - wcwidth==0.2.13
+       - xxhash==3.5.0
+       - yarl==1.15.2
+       - zipp==3.20.2
+ prefix: /connect4/cheng535-new/anaconda3/envs/cospy
evaluate.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import os
+ import json
+ import csv
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ from PIL import Image
+
+ from Detectors import CospyCalibrateDetector
+ from Datasets import TestDataset, EVAL_DATASET_LIST, EVAL_MODEL_LIST
+ from utils import seed_torch
+ from sklearn.metrics import (
+     accuracy_score, log_loss, average_precision_score, f1_score,
+     roc_auc_score, balanced_accuracy_score, confusion_matrix, recall_score
+ )
+
+ import warnings
+ warnings.filterwarnings("ignore")
+
+
+ class Detector():
+     def __init__(self, args):
+         super(Detector, self).__init__()
+
+         # Device
+         self.device = args.device
+
+         # Initialize the detector
+         self.model = CospyCalibrateDetector(
+             semantic_weights_path=args.semantic_weights_path,
+             artifact_weights_path=args.artifact_weights_path)
+
+         # Load the pre-trained weights
+         self.model.load_weights(args.classifier_weights_path)
+         self.model.eval()
+
+         # Put the model on the device
+         self.model.to(self.device)
+
+     # Prediction function
+     def predict(self, inputs):
+         inputs = inputs.to(self.device)
+         outputs = self.model(inputs)
+         prediction = outputs.sigmoid().flatten().tolist()
+         return prediction
+
+
+ def expected_calibration_error(y_true, y_prob, n_bins=10):
+     """Compute the Expected Calibration Error (ECE)."""
+     y_true = np.array(y_true)
+     y_prob = np.array(y_prob)
+     bins = np.linspace(0.0, 1.0, n_bins + 1)
+     ece = 0.0
+     for i in range(n_bins):
+         mask = (y_prob > bins[i]) & (y_prob <= bins[i + 1])
+         if np.sum(mask) > 0:
+             prob_mean = y_prob[mask].mean()
+             acc = y_true[mask].mean()
+             ece += np.sum(mask) / len(y_true) * abs(acc - prob_mean)
+     return ece
+
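A quick worked check of the binning logic above, using toy scores (illustrative values only): with n_bins=10 each score lands in its own bin, so the ECE is just the size-weighted average gap between per-bin confidence and per-bin accuracy.

    # Toy check of expected_calibration_error.
    y_true = [0, 0, 1, 1]
    y_prob = [0.1, 0.3, 0.6, 0.9]
    # Per-bin |accuracy - confidence| gaps: 0.1, 0.3, 0.4, 0.1, each weighted 1/4.
    print(round(expected_calibration_error(y_true, y_prob), 3))  # 0.225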
+ def evaluate(y_pred, y_true):
+     y_pred = np.array(y_pred)
+     y_true = np.array(y_true)
+     pred_label = y_pred > 0.5
+
+     # Metrics
+     acc = accuracy_score(y_true, pred_label)
+     nll = log_loss(y_true, y_pred, eps=1e-7)
+     ap = average_precision_score(y_true, y_pred)
+     ece = expected_calibration_error(y_true, y_pred)
+     f1 = f1_score(y_true, pred_label)
+     try:
+         auc = roc_auc_score(y_true, y_pred)
+     except ValueError:
+         # roc_auc_score raises when only one class is present
+         auc = float('nan')
+     bacc = balanced_accuracy_score(y_true, pred_label)
+     tn, fp, fn, tp = confusion_matrix(y_true, pred_label).ravel()
+     fnr = fn / (fn + tp) if (fn + tp) > 0 else float('nan')
+     recall_total = recall_score(y_true, pred_label)  # overall recall
+
+     return {
+         "ACC": acc,
+         "NLL": nll,
+         "AP": ap,
+         "ECE": ece,
+         "F1": f1,
+         "AUC": auc,
+         "bAcc": bacc,
+         "FNR": fnr,
+         "Recall": recall_total
+     }
+
+
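And a one-line sanity check of evaluate() itself on the same toy scores (every score falls on the correct side of the 0.5 threshold, so the threshold-based metrics are all 1.0):

    m = evaluate(y_pred=[0.1, 0.3, 0.6, 0.9], y_true=[0, 0, 1, 1])
    print(m["ACC"], m["F1"], m["AUC"])  # 1.0 1.0 1.0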
+ def test(args):
+     # Initialize the detector
+     detector = Detector(args)
+
+     # Set the saving directory
+     if not os.path.exists(args.save_dir):
+         os.makedirs(args.save_dir)
+     save_result_path = os.path.join(args.save_dir, "result.json")
+     save_output_path = os.path.join(args.save_dir, "output.json")
+
+     # Begin the evaluation
+     result_all = {}
+     output_all = {}
+     for dataset_name in EVAL_DATASET_LIST:
+         result_all[dataset_name] = {}
+         output_all[dataset_name] = {}
+         for model_name in EVAL_MODEL_LIST:
+             test_dataset = TestDataset(dataset=dataset_name, model=model_name, root_path=args.testset_dirpath, transform=detector.model.test_transform)
+             test_loader = torch.utils.data.DataLoader(test_dataset,
+                                                       batch_size=args.batch_size,
+                                                       shuffle=False,
+                                                       num_workers=4,
+                                                       pin_memory=True)
+
+             # Evaluate the model
+             y_pred, y_true = [], []
+             for images, labels, _ in tqdm(test_loader, desc=f"Evaluating {dataset_name} {model_name}"):
+                 y_pred.extend(detector.predict(images))
+                 y_true.extend(labels.tolist())
+
+             metrics = evaluate(y_pred, y_true)
+             print(f"Evaluate on {dataset_name} {model_name} | Size {len(y_true)} | "
+                   f"ACC {metrics['ACC']*100:.2f}% | Recall {metrics['Recall']*100:.2f}% | "
+                   f"NLL {metrics['NLL']:.4f} | AP {metrics['AP']*100:.2f}% | "
+                   f"ECE {metrics['ECE']:.4f} | F1 {metrics['F1']*100:.2f}% | "
+                   f"AUC {metrics['AUC']*100:.2f}% | bAcc {metrics['bAcc']*100:.2f}% | "
+                   f"FNR {metrics['FNR']*100:.2f}%")
+
+             result_all[dataset_name][model_name] = {"size": len(y_true), **metrics}
+
+             # Dump per-image predictions to CSV
+             csv_dir = os.path.join(args.save_dir, "csv_outputs")
+             os.makedirs(csv_dir, exist_ok=True)
+             csv_path = os.path.join(csv_dir, f"{dataset_name}_{model_name}.csv")
+
+             with open(csv_path, mode="w", newline="", encoding="utf-8") as f:
+                 writer = csv.writer(f)
+                 writer.writerow(["path_to_image", "true_label", "pred_percentage", "pred_label"])
+                 for idx, img_path in enumerate(test_dataset.image_paths):
+                     pred_score = float(y_pred[idx])
+                     pred_label = 1 if pred_score > 0.5 else 0
+                     true_label = int(y_true[idx])
+                     writer.writerow([img_path, true_label, pred_score, pred_label])
+
+             print(f"[CSV SAVED] {csv_path}")
+             output_all[dataset_name][model_name] = {"y_pred": y_pred, "y_true": y_true}
+
+     # Save the results
+     with open(save_result_path, "w") as f:
+         json.dump(result_all, f, indent=4)
+
+     with open(save_output_path, "w") as f:
+         json.dump(output_all, f, indent=4)
+
+
+ def scan(args):
+     # Initialize the detector
+     detector = Detector(args)
+
+     # Define the pre-processing function
+     test_transform = detector.model.test_transform
+
+     # Load the image (re-prompt until an existing file is given)
+     image_filepath = input("Please enter the image filepath for scanning: ")
+     while not os.path.exists(image_filepath):
+         print(f"Image file not found: {image_filepath}")
+         image_filepath = input("Please enter the image filepath for scanning: ")
+
+     image = Image.open(image_filepath).convert("RGB")
+     image = test_transform(image)
+     image = image.unsqueeze(0)
+     image = image.to(args.device)
+
+     # Make the prediction
+     prediction = detector.predict(image)[0]
+
+     if prediction > 0.5:
+         print(f"CO-SPY Prediction: {prediction:.3f} - AI-Generated")
+     else:
+         print(f"CO-SPY Prediction: {prediction:.3f} - Real")
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser("Deep Fake Detection")
+     parser.add_argument("--gpu", type=int, default=0, help="GPU ID")
+     parser.add_argument("--phase", type=str, default="scan", choices=["scan", "test"], help="Phase of the experiment")
+     parser.add_argument("--semantic_weights_path", type=str, default="pretrained/semantic_weights.pth", help="Semantic weights path")
+     parser.add_argument("--artifact_weights_path", type=str, default="pretrained/artifact_weights.pth", help="Artifact weights path")
+     parser.add_argument("--classifier_weights_path", type=str, default="pretrained/classifier_weights.pth", help="Classifier weights path")
+     parser.add_argument("--testset_dirpath", type=str, default="data/test", help="Testset directory")
+     parser.add_argument("--save_dir", type=str, default="test_results", help="Save directory")
+     parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
+     parser.add_argument("--seed", type=int, default=1024, help="Random seed")
+
+     args = parser.parse_args()
+
+     # Set the random seed
+     seed_torch(args.seed)
+
+     # Set the GPU ID
+     args.device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
+
+     # Begin the experiment
+     if args.phase == "scan":
+         scan(args)
+     elif args.phase == "test":
+         test(args)
+     else:
+         raise ValueError("Unknown phase")
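For downstream analysis, a short sketch of reading the artifacts that test() writes (paths assume the default --save_dir of test_results; the nesting follows result_all above):

    import json

    with open("test_results/result.json") as f:
        result_all = json.load(f)
    for dataset_name, models in result_all.items():
        for model_name, metrics in models.items():
            print(dataset_name, model_name,
                  f"ACC={metrics['ACC']:.4f}", f"AP={metrics['AP']:.4f}")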
pretrained/classifer_weights.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4b3cd62721dca4183bfd37790c12ccaf964f71fe7c6bbf4d97eda5f44c6bafab
+ size 1456
pretrained/classifier_weights.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0bf7e0efb68cf57718742ec3c944640856fd86ddaf1bb219e6cacdc280f781dc
+ size 1450
pretrained/semantic_weights.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a3e7c4cf6534e7fac0f2f898d3764d5aa892653dab96ed1316fd123fa4e0a17c
+ size 6064
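The three .pth entries above are Git LFS pointer files (a few hundred bytes each), not the weights themselves; torch.load on a pointer fails with an opaque error. A small guard, assuming the repository layout above:

    from pathlib import Path

    def assert_not_lfs_pointer(path):
        # LFS pointers are tiny text files that start with this header.
        head = Path(path).read_bytes()[:40]
        if head.startswith(b"version https://git-lfs.github.com"):
            raise RuntimeError(
                f"{path} is a Git LFS pointer; fetch the real weights first "
                "(e.g. `git lfs pull`)."
            )

    assert_not_lfs_pointer("pretrained/classifier_weights.pth")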
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch
+ torchvision
+ opencv-python-headless
+ numpy
+ Pillow
+ streamlit
+ tqdm
+ einops
train.py ADDED
@@ -0,0 +1,271 @@
+ import os
+ import time
+ import json
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ from loguru import logger
+ from sklearn.metrics import average_precision_score
+
+ from utils import seed_torch
+ from Detectors import CospyDetector, LabelSmoothingBCEWithLogits
+ from Datasets import TrainDataset, TestDataset, EVAL_DATASET_LIST, EVAL_MODEL_LIST
+
+ import warnings
+ warnings.filterwarnings("ignore")
+
+
+ class Detector():
+     def __init__(self, args):
+         super(Detector, self).__init__()
+
+         # Device
+         self.device = args.device
+
+         # Get the detector
+         self.model = CospyDetector()
+
+         # Put the model on the device
+         self.model.to(self.device)
+
+         # Initialize the fc layers
+         torch.nn.init.normal_(self.model.sem.fc.weight.data, 0.0, 0.02)
+         torch.nn.init.normal_(self.model.art.fc.weight.data, 0.0, 0.02)
+         torch.nn.init.normal_(self.model.fc.weight.data, 0.0, 0.02)
+
+         # Optimizer
+         _lr = 1e-4
+         _beta1 = 0.9
+         _weight_decay = 0.0
+         params = []
+         for name, param in self.model.named_parameters():
+             if param.requires_grad:
+                 params.append(param)
+         print(f"Trainable parameters: {len(params)}")
+
+         self.optimizer = torch.optim.AdamW(params, lr=_lr, betas=(_beta1, 0.999), weight_decay=_weight_decay)
+
+         # Loss function
+         if args.no_label_smooth:
+             self.criterion = torch.nn.BCEWithLogitsLoss()
+         else:
+             self.criterion = LabelSmoothingBCEWithLogits(smoothing=0.1)
+
+         # Scheduler
+         self.delr_freq = 10
+
+     # Training function for the detector
+     def train_step(self, batch_data):
+         # Decompose the batch data
+         inputs, labels = batch_data
+         inputs, labels = inputs.to(self.device), labels.to(self.device)
+
+         self.optimizer.zero_grad()
+
+         outputs = self.model(inputs)
+
+         loss = self.criterion(outputs, labels.unsqueeze(1).float())
+         loss.backward()
+         self.optimizer.step()
+
+         eval_loss = loss.item()
+         y_pred = outputs.sigmoid().flatten().tolist()
+         y_true = labels.tolist()
+         return eval_loss, y_pred, y_true
+
+     # Schedule the training:
+     # decay the learning rate every `delr_freq` epochs (always proceeds)
+     def scheduler(self, status_dict):
+         epoch = status_dict["epoch"]
+         if epoch % self.delr_freq == 0 and epoch != 0:
+             for param_group in self.optimizer.param_groups:
+                 param_group["lr"] *= 0.9
+                 self.lr = param_group["lr"]
+         return True
+
+     # Prediction function
+     def predict(self, inputs):
+         inputs = inputs.to(self.device)
+         outputs = self.model(inputs)
+         prediction = outputs.sigmoid().flatten().tolist()
+         return prediction
+
+
+ def evaluate(y_pred, y_true):
+     ap = average_precision_score(y_true, y_pred)
+     accuracy = ((np.array(y_pred) > 0.5) == y_true).mean()
+     return ap, accuracy
+
+
+ def train(args):
+     # Get the detector
+     detector = Detector(args)
+
+     # Load the dataset
+     train_dataset = TrainDataset(data_path=args.trainset_dirpath,
+                                  split="train",
+                                  transform=detector.model.train_transform)
+     train_loader = torch.utils.data.DataLoader(train_dataset,
+                                                batch_size=args.batch_size,
+                                                shuffle=True,
+                                                num_workers=4,
+                                                pin_memory=True)
+
+     test_dataset = TrainDataset(data_path=args.trainset_dirpath,
+                                 split="val",
+                                 transform=detector.model.test_transform)
+     test_loader = torch.utils.data.DataLoader(test_dataset,
+                                               batch_size=args.batch_size,
+                                               shuffle=False,
+                                               num_workers=4,
+                                               pin_memory=True)
+
+     logger.info(f"Train size {len(train_dataset)} | Test size {len(test_dataset)}")
+
+     # Set the saving directory
+     model_dir = os.path.join(args.ckpt, "cospy")
+     if not os.path.exists(model_dir):
+         os.makedirs(model_dir)
+     log_path = f"{model_dir}/training.log"
+     if os.path.exists(log_path):
+         os.remove(log_path)
+
+     logger_id = logger.add(
+         log_path,
+         format="{time:MM-DD at HH:mm:ss} | {level} | {module}:{line} | {message}",
+         level="DEBUG",
+     )
+
+     # Train the detector
+     best_acc = 0
+     for epoch in range(args.epochs):
+         # Set the model to training mode
+         detector.model.train()
+         time_start = time.time()
+         for step_id, batch_data in enumerate(train_loader):
+             eval_loss, y_pred, y_true = detector.train_step(batch_data)
+             ap, accuracy = evaluate(y_pred, y_true)
+
+             # Log the training information
+             if (step_id + 1) % 100 == 0:
+                 time_end = time.time()
+                 logger.info(f"Epoch {epoch} | Batch {step_id + 1}/{len(train_loader)} | Loss {eval_loss:.4f} | AP {ap*100:.2f}% | Accuracy {accuracy*100:.2f}% | Time {time_end-time_start:.2f}s")
+                 time_start = time.time()
+
+         # Evaluate the model
+         detector.model.eval()
+         y_pred, y_true = [], []
+         for (images, labels) in test_loader:
+             y_pred.extend(detector.predict(images))
+             y_true.extend(labels.tolist())
+
+         ap, accuracy = evaluate(y_pred, y_true)
+         logger.info(f"Epoch {epoch} | Test AP {ap*100:.2f}% | Test Accuracy {accuracy*100:.2f}%")
+
+         # Schedule the training
+         status_dict = {"epoch": epoch, "AP": ap, "Accuracy": accuracy}
+         proceed = detector.scheduler(status_dict)
+         if not proceed:
+             logger.info("Early stopping")
+             break
+
+         # Save the model
+         if accuracy >= best_acc:
+             best_acc = accuracy
+             detector.model.save_weights(f"{model_dir}/best_model.pth")
+             logger.info(f"Best model saved with accuracy {best_acc*100:.2f}%")
+
+         if epoch % 5 == 0:
+             detector.model.save_weights(f"{model_dir}/epoch_{epoch}.pth")
+             logger.info(f"Model saved at epoch {epoch}")
+
+     # Save the final model
+     detector.model.save_weights(f"{model_dir}/final_model.pth")
+     logger.info("Final model saved")
+
+     # Remove the logger
+     logger.remove(logger_id)
+
+
+ def test(args):
+     # Initialize the detector
+     detector = Detector(args)
+
+     # Load the [best/final] model
+     weights_path = os.path.join(args.ckpt, "cospy", "best_model.pth")
+
+     detector.model.load_weights(weights_path)
+     detector.model.to(args.device)
+     detector.model.eval()
+
+     # Set the pre-processing function
+     test_transform = detector.model.test_transform
+
+     # Set the saving directory
+     save_dir = os.path.join(args.ckpt, "cospy")
+     save_result_path = os.path.join(save_dir, "result.json")
+     save_output_path = os.path.join(save_dir, "output.json")
+
+     # Begin the evaluation
+     result_all = {}
+     output_all = {}
+     for dataset_name in EVAL_DATASET_LIST:
+         result_all[dataset_name] = {}
+         output_all[dataset_name] = {}
+         for model_name in EVAL_MODEL_LIST:
+             test_dataset = TestDataset(dataset=dataset_name, model=model_name, root_path=args.testset_dirpath, transform=test_transform)
+             test_loader = torch.utils.data.DataLoader(test_dataset,
+                                                       batch_size=args.batch_size,
+                                                       shuffle=False,
+                                                       num_workers=4,
+                                                       pin_memory=True)
+
+             # Evaluate the model
+             y_pred, y_true = [], []
+             for (images, labels) in tqdm(test_loader, desc=f"Evaluating {dataset_name} {model_name}"):
+                 y_pred.extend(detector.predict(images))
+                 y_true.extend(labels.tolist())
+
+             ap, accuracy = evaluate(y_pred, y_true)
+             print(f"Evaluate on {dataset_name} {model_name} | Size {len(y_true)} | AP {ap*100:.2f}% | Accuracy {accuracy*100:.2f}%")
+
+             result_all[dataset_name][model_name] = {"size": len(y_true), "AP": ap, "Accuracy": accuracy}
+             output_all[dataset_name][model_name] = {"y_pred": y_pred, "y_true": y_true}
+
+     # Save the results
+     with open(save_result_path, "w") as f:
+         json.dump(result_all, f, indent=4)
+
+     with open(save_output_path, "w") as f:
+         json.dump(output_all, f, indent=4)
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser("Deep Fake Detection")
+     parser.add_argument("--gpu", type=int, default=0, help="GPU ID")
+     parser.add_argument("--phase", type=str, default="test", choices=["train", "test"], help="Phase of the experiment")
+     parser.add_argument("--no_label_smooth", action="store_true", help="Disable label smoothing")
+     parser.add_argument("--trainset_dirpath", type=str, default="data/train", help="Trainset directory")
+     parser.add_argument("--testset_dirpath", type=str, default="data/test", help="Testset directory")
+     parser.add_argument("--ckpt", type=str, default="ckpt", help="Checkpoint directory")
+     parser.add_argument("--epochs", type=int, default=10, help="Number of epochs")
+     parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
+     parser.add_argument("--seed", type=int, default=1024, help="Random seed")
+
+     args = parser.parse_args()
+
+     # Set the random seed
+     seed_torch(args.seed)
+
+     # Set the GPU ID
+     args.device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
+
+     # Begin the experiment
+     if args.phase == "train":
+         train(args)
+     elif args.phase == "test":
+         test(args)
+     else:
+         raise ValueError("Unknown phase")
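LabelSmoothingBCEWithLogits comes from the Detectors package, which is not part of this diff. A minimal sketch of what a label-smoothing BCE loss with smoothing=0.1 typically does (an assumption; the actual Detectors implementation may differ) is to pull the hard targets {0, 1} in to {0.05, 0.95}, so the model is never pushed toward fully saturated logits:

    import torch
    import torch.nn.functional as F

    class LabelSmoothingBCEWithLogitsSketch(torch.nn.Module):
        """Hypothetical stand-in for Detectors.LabelSmoothingBCEWithLogits."""
        def __init__(self, smoothing=0.1):
            super().__init__()
            self.smoothing = smoothing

        def forward(self, logits, targets):
            # Hard targets {0, 1} become {smoothing/2, 1 - smoothing/2}.
            targets = targets * (1.0 - self.smoothing) + 0.5 * self.smoothing
            return F.binary_cross_entropy_with_logits(logits, targets)

    criterion = LabelSmoothingBCEWithLogitsSketch(smoothing=0.1)
    loss = criterion(torch.zeros(4, 1), torch.tensor([[0.], [1.], [0.], [1.]]))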
train_single.py ADDED
@@ -0,0 +1,293 @@
+ import os
+ import time
+ import json
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ from loguru import logger
+ from sklearn.metrics import average_precision_score
+
+ from utils import seed_torch
+ from Detectors import ArtifactDetector, SemanticDetector
+ from Datasets import TrainDataset, TestDataset, EVAL_DATASET_LIST, EVAL_MODEL_LIST
+
+ import warnings
+ warnings.filterwarnings("ignore")
+
+
+ class Detector():
+     def __init__(self, args):
+         super(Detector, self).__init__()
+
+         # Device
+         self.device = args.device
+
+         # Get the detector
+         if args.detector == "artifact":
+             self.model = ArtifactDetector()
+         elif args.detector == "semantic":
+             self.model = SemanticDetector()
+         else:
+             raise ValueError("Unknown detector")
+
+         # Put the model on the device
+         self.model.to(self.device)
+
+         # Initialize the fc layer
+         torch.nn.init.normal_(self.model.fc.weight.data, 0.0, 0.02)
+
+         # Optimizer
+         _lr = 1e-4
+         _beta1 = 0.9
+         _weight_decay = 0.0
+         params = [p for p in self.model.parameters() if p.requires_grad]
+         print(f"Trainable parameters: {len(params)}")
+
+         self.optimizer = torch.optim.AdamW(params, lr=_lr, betas=(_beta1, 0.999), weight_decay=_weight_decay)
+
+         # Loss function
+         self.criterion = torch.nn.BCEWithLogitsLoss()
+
+         # Scheduler
+         self.delr_freq = 10
+
+         # Resume info
+         self.start_epoch = 0
+         self.best_acc = 0.0
+
+     def train_step(self, batch_data):
+         inputs, labels = batch_data
+         inputs, labels = inputs.to(self.device), labels.to(self.device)
+
+         self.optimizer.zero_grad()
+         outputs = self.model(inputs)
+         loss = self.criterion(outputs, labels.unsqueeze(1).float())
+         loss.backward()
+         self.optimizer.step()
+
+         eval_loss = loss.item()
+         y_pred = outputs.sigmoid().flatten().tolist()
+         y_true = labels.tolist()
+         return eval_loss, y_pred, y_true
+
+     def scheduler(self, status_dict):
+         epoch = status_dict["epoch"]
+         if epoch % self.delr_freq == 0 and epoch != 0:
+             for param_group in self.optimizer.param_groups:
+                 param_group["lr"] *= 0.9
+                 self.lr = param_group["lr"]
+         return True
+
+     def predict(self, inputs):
+         inputs = inputs.to(self.device)
+         outputs = self.model(inputs)
+         return outputs.sigmoid().flatten().tolist()
+
+     # --- Checkpoint functions ---
+     def save_checkpoint(self, path, epoch, best_acc):
+         torch.save({
+             "epoch": epoch,
+             "best_acc": best_acc,
+             "model_state": self.model.state_dict(),
+             "optimizer_state": self.optimizer.state_dict()
+         }, path)
+
+     def load_checkpoint(self, path):
+         if os.path.exists(path):
+             ckpt = torch.load(path, map_location=self.device)
+             self.model.load_state_dict(ckpt["model_state"])
+             self.optimizer.load_state_dict(ckpt["optimizer_state"])
+             self.start_epoch = ckpt.get("epoch", 0) + 1
+             self.best_acc = ckpt.get("best_acc", 0.0)
+             print(f"[INFO] Loaded checkpoint '{path}' (start_epoch={self.start_epoch}, best_acc={self.best_acc})")
+         else:
+             print(f"[WARNING] Checkpoint not found: {path}")
+
+
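A minimal round-trip of the checkpoint helpers above (hypothetical usage; `args` stands for the parsed command-line namespace defined at the bottom of this file, and the path is illustrative):

    detector = Detector(args)
    detector.save_checkpoint("ckpt/artifact/resume.pth", epoch=3, best_acc=0.91)

    restored = Detector(args)
    restored.load_checkpoint("ckpt/artifact/resume.pth")
    # restored.start_epoch == 4, restored.best_acc == 0.91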
+ def evaluate(y_pred, y_true):
+     ap = average_precision_score(y_true, y_pred)
+     accuracy = ((np.array(y_pred) > 0.5) == y_true).mean()
+     return ap, accuracy
+
+
+ def train(args):
+     # Get the detector
+     detector = Detector(args)
+
+     # --- Resume checkpoint ---
+     start_epoch = 0
+     best_acc = 0
+     if args.resume != "":
+         if os.path.exists(args.resume):
+             print(f"[INFO] Loading checkpoint from {args.resume}")
+             ckpt = torch.load(args.resume, map_location=args.device)
+             detector.model.load_weights(args.resume)
+             # If the optimizer state & best_acc were also saved, restore them here
+             if "best_acc" in ckpt:
+                 best_acc = ckpt["best_acc"]
+             if "epoch" in ckpt:
+                 start_epoch = ckpt["epoch"] + 1
+         else:
+             print(f"[WARNING] Resume checkpoint not found: {args.resume}")
+
+     # Load datasets
+     train_dataset = TrainDataset(data_path=args.trainset_dirpath,
+                                  split="train",
+                                  transform=detector.model.train_transform)
+     train_loader = torch.utils.data.DataLoader(train_dataset,
+                                                batch_size=args.batch_size,
+                                                shuffle=True,
+                                                num_workers=4,
+                                                pin_memory=True)
+
+     test_dataset = TrainDataset(data_path=args.trainset_dirpath,
+                                 split="val",
+                                 transform=detector.model.test_transform)
+     test_loader = torch.utils.data.DataLoader(test_dataset,
+                                               batch_size=args.batch_size,
+                                               shuffle=False,
+                                               num_workers=4,
+                                               pin_memory=True)
+
+     logger.info(f"Train size {len(train_dataset)} | Test size {len(test_dataset)}")
+
+     # Set saving directory
+     model_dir = os.path.join(args.ckpt, args.detector)
+     os.makedirs(model_dir, exist_ok=True)
+     log_path = f"{model_dir}/training.log"
+     if os.path.exists(log_path):
+         os.remove(log_path)
+     logger_id = logger.add(log_path, format="{time:MM-DD at HH:mm:ss} | {level} | {module}:{line} | {message}", level="DEBUG")
+
+     # Train loop
+     for epoch in range(start_epoch, args.epochs):
+         detector.model.train()
+         time_start = time.time()
+         for step_id, batch_data in enumerate(train_loader):
+             eval_loss, y_pred, y_true = detector.train_step(batch_data)
+             ap, accuracy = evaluate(y_pred, y_true)
+
+             if (step_id + 1) % 100 == 0:
+                 time_end = time.time()
+                 logger.info(f"Epoch {epoch} | Batch {step_id + 1}/{len(train_loader)} | Loss {eval_loss:.4f} | AP {ap*100:.2f}% | Accuracy {accuracy*100:.2f}% | Time {time_end-time_start:.2f}s")
+                 time_start = time.time()
+
+         # Evaluate
+         detector.model.eval()
+         y_pred, y_true = [], []
+         for (images, labels) in test_loader:
+             y_pred.extend(detector.predict(images))
+             y_true.extend(labels.tolist())
+         ap, accuracy = evaluate(y_pred, y_true)
+         logger.info(f"Epoch {epoch} | Test AP {ap*100:.2f}% | Test Accuracy {accuracy*100:.2f}%")
+
+         # Save best model
+         if accuracy >= best_acc:
+             best_acc = accuracy
+             detector.model.save_weights(f"{model_dir}/best_model.pth")
+             torch.save({"epoch": epoch, "best_acc": best_acc}, f"{model_dir}/best_model_meta.pth")
+             logger.info(f"Best model saved with accuracy {best_acc*100:.2f}%")
+
+         # Save periodic checkpoints
+         if epoch % 5 == 0:
+             detector.model.save_weights(f"{model_dir}/epoch_{epoch}.pth")
+             logger.info(f"Model saved at epoch {epoch}")
+
+     # Save final model
+     detector.model.save_weights(f"{model_dir}/final_model.pth")
+     logger.info("Final model saved")
+     logger.remove(logger_id)
+
+
+ def test(args):
+     # Initialize the detector
+     detector = Detector(args)
+
+     # Load the checkpoint from --resume if provided,
+     # otherwise fall back to the saved best model
+     weights_path = os.path.join(args.ckpt, args.detector, "best_model.pth")
+     if args.resume != "":
+         if os.path.exists(args.resume):
+             print(f"[INFO] Loading checkpoint from {args.resume}")
+             weights_path = args.resume
+         else:
+             print(f"[WARNING] Resume checkpoint not found: {args.resume}")
+
+     detector.model.load_weights(weights_path)
+     detector.model.to(args.device)
+     detector.model.eval()
+
+     # Set the pre-processing function
+     test_transform = detector.model.test_transform
+
+     # Set the saving directory
+     save_dir = os.path.join(args.ckpt, args.detector)
+     save_result_path = os.path.join(save_dir, "result.json")
+     save_output_path = os.path.join(save_dir, "output.json")
+
+     # Begin the evaluation
+     result_all = {}
+     output_all = {}
+     for dataset_name in EVAL_DATASET_LIST:
+         result_all[dataset_name] = {}
+         output_all[dataset_name] = {}
+         for model_name in EVAL_MODEL_LIST:
+             test_dataset = TestDataset(dataset=dataset_name, model=model_name, root_path=args.testset_dirpath, transform=test_transform)
+             test_loader = torch.utils.data.DataLoader(test_dataset,
+                                                       batch_size=args.batch_size,
+                                                       shuffle=False,
+                                                       num_workers=4,
+                                                       pin_memory=True)
+
+             # Evaluate the model
+             y_pred, y_true = [], []
+             for (images, labels) in tqdm(test_loader, desc=f"Evaluating {dataset_name} {model_name}"):
+                 y_pred.extend(detector.predict(images))
+                 y_true.extend(labels.tolist())
+
+             ap, accuracy = evaluate(y_pred, y_true)
+             print(f"Evaluate on {dataset_name} {model_name} | Size {len(y_true)} | AP {ap*100:.2f}% | Accuracy {accuracy*100:.2f}%")
+
+             result_all[dataset_name][model_name] = {"size": len(y_true), "AP": ap, "Accuracy": accuracy}
+             output_all[dataset_name][model_name] = {"y_pred": y_pred, "y_true": y_true}
+
+     # Save the results
+     with open(save_result_path, "w") as f:
+         json.dump(result_all, f, indent=4)
+
+     with open(save_output_path, "w") as f:
+         json.dump(output_all, f, indent=4)
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser("Deep Fake Detection")
+     parser.add_argument("--gpu", type=int, default=0, help="GPU ID")
+     parser.add_argument("--phase", type=str, default="test", choices=["train", "test"], help="Phase of the experiment")
+     parser.add_argument("--detector", type=str, default="artifact", choices=["artifact", "semantic"], help="Detector to use")
+     parser.add_argument("--trainset_dirpath", type=str, default="data/train", help="Trainset directory")
+     parser.add_argument("--testset_dirpath", type=str, default="data/test", help="Testset directory")
+     parser.add_argument("--ckpt", type=str, default="ckpt", help="Checkpoint directory")
+     parser.add_argument("--epochs", type=int, default=10, help="Number of epochs")
+     parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
+     parser.add_argument("--seed", type=int, default=1024, help="Random seed")
+     parser.add_argument("--resume", type=str, default="", help="Path to checkpoint to resume training")
+
+     args = parser.parse_args()
+
+     # Set the random seed
+     seed_torch(args.seed)
+
+     # Set the GPU ID
+     args.device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
+
+     # Begin the experiment
+     if args.phase == "train":
+         train(args)
+     elif args.phase == "test":
+         test(args)
+     else:
+         raise ValueError("Unknown phase")
utils.py ADDED
@@ -0,0 +1,162 @@
+ import os
+ import cv2
+ import torch
+ import pickle
+ import random
+ import numpy as np
+ from io import BytesIO
+ from PIL import Image, ImageFile
+ import torchvision.transforms.functional as TF
+ from scipy.ndimage import gaussian_filter
+
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+
+ # Set random seed
+ def seed_torch(seed):
+     random.seed(seed)
+     os.environ['PYTHONHASHSEED'] = str(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.benchmark = False
+     torch.backends.cudnn.deterministic = True
+
+
+ # Load dataset
+ def recursively_read(rootdir, must_contain, exts=["png", "PNG", "jpg", "JPG", "jpeg", "JPEG"]):
+     out = []
+     for r, d, f in os.walk(rootdir):
+         for file in f:
+             # Use the last extension so filenames with extra dots are handled correctly
+             if (file.split('.')[-1] in exts) and (must_contain in os.path.join(r, file)):
+                 out.append(os.path.join(r, file))
+     return out
+
+
+ def get_list(path, must_contain=''):
+     if ".pickle" in path:
+         with open(path, 'rb') as f:
+             image_list = pickle.load(f)
+         image_list = [item for item in image_list if must_contain in item]
+     else:
+         image_list = recursively_read(path, must_contain)
+     return image_list
+
+
+ # Data augmentation techniques
+ def data_augment(img, aug_config):
+     img = np.array(img)
+     if img.ndim == 2:
+         img = np.expand_dims(img, axis=2)
+         img = np.repeat(img, 3, axis=2)
+
+     if random.random() < aug_config["blur_prob"]:
+         sig = sample_continuous(aug_config["blur_sig"])
+         gaussian_blur(img, sig)  # blurs in place
+
+     if random.random() < aug_config["jpg_prob"]:
+         method = sample_discrete(aug_config["jpg_method"])
+         qual = sample_discrete(aug_config["jpg_qual"])
+         img = jpeg_from_key(img, qual, method)
+
+     return Image.fromarray(img)
+
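data_augment (and tensor_data_augment below) expect an aug_config dict with the keys used above. The values here are illustrative placeholders, not settings shipped with this repository:

    # Hypothetical augmentation config; probabilities and ranges are placeholders.
    aug_config = {
        "blur_prob": 0.1,                  # chance of applying Gaussian blur
        "blur_sig": [0.0, 3.0],            # sigma range for sample_continuous
        "jpg_prob": 0.1,                   # chance of JPEG re-compression
        "jpg_method": ["cv2", "pil"],      # keys understood by jpeg_from_key
        "jpg_qual": list(range(30, 101)),  # qualities for sample_discrete
    }

    from PIL import Image
    augmented = data_augment(Image.new("RGB", (224, 224), "gray"), aug_config)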
+ # Data augmentation techniques (tensor version)
+ def tensor_data_augment(images, aug_config):
+     device = images.device
+     images = images.detach().cpu().permute(0, 2, 3, 1).numpy()
+     images = np.uint8(images * 255.)
+     outputs = []
+     for img in images:
+         if random.random() < aug_config["blur_prob"]:
+             sig = sample_continuous(aug_config["blur_sig"])
+             gaussian_blur(img, sig)
+
+         if random.random() < aug_config["jpg_prob"]:
+             method = sample_discrete(aug_config["jpg_method"])
+             qual = sample_discrete(aug_config["jpg_qual"])
+             img = jpeg_from_key(img, qual, method)
+         outputs.append(img)
+     outputs = np.stack(outputs)
+     outputs = torch.from_numpy(outputs).to(device).permute(0, 3, 1, 2).float() / 255.
+     return outputs
+
+
+ # Sample continuous or discrete values
+ def sample_continuous(s):
+     if len(s) == 1:
+         return s[0]
+     if len(s) == 2:
+         rg = s[1] - s[0]
+         return random.random() * rg + s[0]
+     raise ValueError("Length of iterable s should be 1 or 2.")
+
+
+ def sample_discrete(s):
+     if len(s) == 1:
+         return s[0]
+     return random.choice(s)
+
+
+ # Gaussian blur (in place, channel by channel)
+ def gaussian_blur(img, sigma):
+     gaussian_filter(img[:, :, 0], output=img[:, :, 0], sigma=sigma)
+     gaussian_filter(img[:, :, 1], output=img[:, :, 1], sigma=sigma)
+     gaussian_filter(img[:, :, 2], output=img[:, :, 2], sigma=sigma)
+
+
+ # JPEG compression
+ def cv2_jpg(img, compress_val):
+     img_cv2 = img[:, :, ::-1]
+     encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), compress_val]
+     result, encimg = cv2.imencode('.jpg', img_cv2, encode_param)
+     decimg = cv2.imdecode(encimg, 1)
+     return decimg[:, :, ::-1]
+
+
+ def pil_jpg(img, compress_val):
+     out = BytesIO()
+     img = Image.fromarray(img)
+     img.save(out, format='jpeg', quality=compress_val)
+     img = Image.open(out)
+     # Load from memory before the BytesIO buffer closes
+     img = np.array(img)
+     out.close()
+     return img
+
+
+ def png_to_jpeg(img, quality=95):
+     # Convert the PNG image to JPEG
+     # Input: PIL image
+     # Output: PIL image
+     out = BytesIO()
+     img.save(out, format='jpeg', quality=quality)
+     img = np.array(Image.open(out))
+     # Load from memory before the BytesIO buffer closes
+     out.close()
+     img = Image.fromarray(img)
+     return img
+
+
+ def jpeg_from_key(img, compress_val, key):
+     jpeg_dict = {'cv2': cv2_jpg, 'pil': pil_jpg}
+     method = jpeg_dict[key]
+     return method(img, compress_val)
+
+
+ # Custom resize function
+ def custom_resize(img, rz_interp, loadSize):
+     rz_dict = {'bilinear': Image.BILINEAR,
+                'bicubic': Image.BICUBIC,
+                'lanczos': Image.LANCZOS,
+                'nearest': Image.NEAREST}
+     interp = sample_discrete(rz_interp)
+     return TF.resize(img, loadSize, interpolation=rz_dict[interp])
+
+
+ def weights2cpu(weights):
+     for key in weights:
+         weights[key] = weights[key].cpu()
+     return weights