# CO-SPY / Detectors / cospy_calibrate_detector.py
# Uploaded by NghiTran1009: "Upload clean CO-SPY project" (commit cab012d)
import torch
from torchvision import transforms
from utils import data_augment
from .semantic_detector import SemanticDetector
from .artifact_detector import ArtifactDetector
# CO-SPY Calibrate Detector (Calibrate the integration of semantic and artifact detectors)
class CospyCalibrateDetector(torch.nn.Module):
    """Calibrated fusion of the CO-SPY semantic and artifact detectors.

    Both sub-detectors are loaded from their own weight files and frozen.
    Only the small linear head ``fc`` that combines their two scores is
    trainable, and only ``fc`` is written by :meth:`save_weights`.

    Args:
        semantic_weights_path: Weight file passed to
            ``SemanticDetector.load_weights``.
        artifact_weights_path: Weight file passed to
            ``ArtifactDetector.load_weights``.
        num_classes: Output size of the fusion head (default 1).
    """

    def __init__(self, semantic_weights_path, artifact_weights_path, num_classes=1):
        super().__init__()

        # Load the two pre-trained sub-detectors.
        self.sem = SemanticDetector()
        self.sem.load_weights(semantic_weights_path)
        self.art = ArtifactDetector()
        self.art.load_weights(artifact_weights_path)

        # Freeze them. requires_grad=False stops gradient updates, but
        # mode-dependent layers (BatchNorm running stats, Dropout) follow the
        # module's train/eval flag, so they are also pinned to eval mode here
        # and in ``train`` below.
        for param in self.sem.parameters():
            param.requires_grad = False
        for param in self.art.parameters():
            param.requires_grad = False
        self.sem.eval()
        self.art.eval()

        # Trainable fusion head over the two detector scores.
        # NOTE(review): in_features=2 assumes each sub-detector returns one
        # score per image, i.e. shape (batch, 1) -- confirm.
        self.fc = torch.nn.Linear(2, num_classes)

        # Per-branch input transforms applied inside ``forward``. Each branch
        # is normalized with its own statistics; only the artifact branch is
        # additionally resized (to ``self.art.cropSize``).
        self.sem_transform = transforms.Compose([
            transforms.Normalize(self.sem.mean, self.sem.std),
        ])
        self.art_transform = transforms.Compose([
            transforms.Resize(self.art.cropSize, antialias=False),
            transforms.Normalize(self.art.mean, self.art.std),
        ])

        # Resolution used by the pre-processing pipelines below.
        self.loadSize = 384
        self.cropSize = 384

        # Data augmentation settings (used by ``train_transform`` only).
        self.blur_prob = 0.0
        self.blur_sig = [0.0, 3.0]
        self.jpg_prob = 0.5
        self.jpg_method = ['cv2', 'pil']
        self.jpg_qual = list(range(70, 96))
        self.aug_config = {
            "blur_prob": self.blur_prob,
            "blur_sig": self.blur_sig,
            "jpg_prob": self.jpg_prob,
            "jpg_method": self.jpg_method,
            "jpg_qual": self.jpg_qual,
        }

        # Pre-processing pipelines. They are not used by ``forward``;
        # presumably the caller's dataset applies them -- confirm.
        rz_func = transforms.Resize(self.loadSize)
        flip_func = transforms.RandomHorizontalFlip()
        aug_func = transforms.Lambda(lambda x: data_augment(x, self.aug_config))
        self.train_transform = transforms.Compose([
            flip_func,
            aug_func,
            rz_func,
            transforms.RandomCrop(self.cropSize),
            transforms.ToTensor(),
        ])
        # Evaluation must be deterministic: take the centre window rather
        # than a random one, so repeated runs on the same image agree.
        self.test_transform = transforms.Compose([
            rz_func,
            transforms.CenterCrop(self.cropSize),
            transforms.ToTensor(),
        ])

    def train(self, mode=True):
        """Set train/eval mode while keeping the frozen sub-detectors in eval.

        Only ``fc`` (and any other non-frozen child) follows ``mode``. If
        ``sem``/``art`` contain BatchNorm or Dropout layers, leaving them in
        train mode would update running statistics or drop activations inside
        detectors that are meant to be frozen.

        Returns:
            ``self``, as ``torch.nn.Module.train`` does.
        """
        super().train(mode)
        self.sem.eval()
        self.art.eval()
        return self

    def forward(self, x):
        """Score a batch with both detectors and fuse the scores.

        Args:
            x: Image batch. NOTE(review): presumably a float tensor in [0, 1]
                as produced by ``train_transform``/``test_transform``
                (``ToTensor``) -- confirm.

        Returns:
            Output of the linear head ``fc``; no activation is applied here.
        """
        pred_sem = self.sem(self.sem_transform(x))
        pred_art = self.art(self.art_transform(x))
        # Concatenate the per-image scores along the feature axis.
        fused = torch.cat([pred_sem, pred_art], dim=1)
        return self.fc(fused)

    def save_weights(self, weights_path):
        """Save only the fusion head to ``weights_path``.

        The file holds plain (detached) CPU tensors under the keys
        ``"fc.weight"`` and ``"fc.bias"``. The frozen sub-detectors are not
        saved.
        """
        save_params = {
            "fc.weight": self.fc.weight.detach().cpu(),
            "fc.bias": self.fc.bias.detach().cpu(),
        }
        torch.save(save_params, weights_path)

    def load_weights(self, weights_path):
        """Load fusion-head weights written by :meth:`save_weights`.

        Tensors are loaded onto the CPU and then copied into the existing
        ``fc`` parameters. The module therefore keeps its current device and
        dtype, and optimizers holding those parameters stay valid.

        Raises:
            KeyError: if ``"fc.weight"`` or ``"fc.bias"`` is missing.
            RuntimeError: if a stored tensor's shape does not match ``fc``.
        """
        # NOTE(review): torch.load unpickles arbitrary objects, so only load
        # trusted files. On torch >= 1.13, ``weights_only=True`` would restrict
        # this, since the file contains only tensors.
        weights = torch.load(weights_path, map_location="cpu")
        self.fc.load_state_dict({
            "weight": weights["fc.weight"],
            "bias": weights["fc.bias"],
        })