import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision import transforms

from utils import data_augment


# Artifact Detector (extracts artifact features with an NPR-based ResNet encoder)
class ArtifactDetectorProGAN(torch.nn.Module):
    def __init__(self, dim_artifact=512, num_classes=1):
        super(ArtifactDetectorProGAN, self).__init__()
        # Load the artifact encoder based on NPR
        self.artifact_encoder = ResNet(Bottleneck, [3, 4, 6, 3])
        # Classifier
        self.fc = torch.nn.Linear(dim_artifact, num_classes)

        # Normalization
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        # Resolution
        self.loadSize = 256
        self.cropSize = 224
        # Data augmentation
        self.blur_prob = 0.0
        self.blur_sig = [0.0, 3.0]
        self.jpg_prob = 0.0
        self.jpg_method = ['cv2', 'pil']
        self.jpg_qual = list(range(70, 96))
        # Define the augmentation configuration
        self.aug_config = {
            "blur_prob": self.blur_prob,
            "blur_sig": self.blur_sig,
            "jpg_prob": self.jpg_prob,
            "jpg_method": self.jpg_method,
            "jpg_qual": self.jpg_qual,
        }

        # Pre-processing
        crop_func = transforms.RandomCrop(self.cropSize)
        flip_func = transforms.RandomHorizontalFlip()
        rz_func = transforms.Resize(self.loadSize)
        aug_func = transforms.Lambda(lambda x: data_augment(x, self.aug_config))
        self.train_transform = transforms.Compose([
            aug_func,
            rz_func,
            crop_func,
            flip_func,
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ])
        self.test_transform = transforms.Compose([
            rz_func,
            crop_func,
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ])

    def forward(self, x, return_feat=False):
        feat = self.artifact_encoder(x)
        out = self.fc(feat)
        if return_feat:
            return feat, out
        return out

    def save_weights(self, weights_path):
        save_params = {k: v.cpu() for k, v in self.state_dict().items()}
        torch.save(save_params, weights_path)

    def load_weights(self, weights_path):
        weights = torch.load(weights_path)
        self.load_state_dict(weights)


# Define the artifact encoder (based on NPR)
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1):
        super(ResNet, self).__init__()
        self.unfoldSize = 2
        self.unfoldIndex = 0
        assert self.unfoldSize > 1
        assert 0 <= self.unfoldIndex < self.unfoldSize * self.unfoldSize
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Note: fc1 is defined but not used in forward(); the encoder returns the
        # 512-d pooled features directly.
        self.fc1 = nn.Linear(512, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def interpolate(self, img, factor):
        # Down-sample and then up-sample by the same factor with nearest-neighbour
        # interpolation, so subtracting the result from the input isolates
        # resampling artifacts (the NPR signal).
        return F.interpolate(
            F.interpolate(img, scale_factor=factor, mode='nearest',
                          recompute_scale_factor=True),
            scale_factor=1 / factor, mode='nearest', recompute_scale_factor=True)

    def forward(self, x):
        # Artifact map: residual between the image and its down/up-sampled copy
        artifact = x - self.interpolate(x, 0.5)
        x = self.conv1(artifact * 2.0 / 3.0)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return x
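

# Minimal usage sketch (illustrative only): instantiates the detector and runs a
# forward pass on a random batch. The 224x224 input matches cropSize above; the
# checkpoint path below is a hypothetical placeholder, not a file in this repo.
if __name__ == "__main__":
    detector = ArtifactDetectorProGAN(dim_artifact=512, num_classes=1)
    detector.eval()

    # Dummy batch of 4 RGB crops (in practice produced by test_transform)
    dummy = torch.randn(4, 3, 224, 224)
    with torch.no_grad():
        feat, logits = detector(dummy, return_feat=True)
    print(feat.shape, logits.shape)  # expected: torch.Size([4, 512]) torch.Size([4, 1])

    # Round-trip the weights through the provided helpers (path is illustrative)
    # detector.save_weights("artifact_detector_progan.pth")
    # detector.load_weights("artifact_detector_progan.pth")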