import torch
import open_clip
from torchvision import transforms
from utils import data_augment
# Semantic Detector (Extract semantic features using CLIP)
class SemanticDetector(torch.nn.Module):
    def __init__(self, dim_clip=1152, num_classes=1):
        super(SemanticDetector, self).__init__()
        # Get the pre-trained CLIP
        model_name = "ViT-SO400M-14-SigLIP-384"
        version = "webli"
        self.clip, _, _ = open_clip.create_model_and_transforms(model_name, pretrained=version)
        # Freeze the CLIP visual encoder
        self.clip.requires_grad_(False)
        # Classifier
        self.fc = torch.nn.Linear(dim_clip, num_classes)
        # Normalization
        self.mean = [0.5, 0.5, 0.5]
        self.std = [0.5, 0.5, 0.5]
        # Resolution
        self.loadSize = 384
        self.cropSize = 384
        # Data augmentation
        self.blur_prob = 0.5
        self.blur_sig = [0.0, 3.0]
        self.jpg_prob = 0.5
        self.jpg_method = ['cv2', 'pil']
        self.jpg_qual = list(range(30, 101))
        # Define the augmentation configuration
        self.aug_config = {
            "blur_prob": self.blur_prob,
            "blur_sig": self.blur_sig,
            "jpg_prob": self.jpg_prob,
            "jpg_method": self.jpg_method,
            "jpg_qual": self.jpg_qual,
        }
        # Pre-processing
        crop_func = transforms.RandomCrop(self.cropSize)
        flip_func = transforms.RandomHorizontalFlip()
        rz_func = transforms.Resize(self.loadSize)
        aug_func = transforms.Lambda(lambda x: data_augment(x, self.aug_config))
        self.train_transform = transforms.Compose([
            rz_func,
            aug_func,
            crop_func,
            flip_func,
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ])
        self.test_transform = transforms.Compose([
            rz_func,
            crop_func,
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ])

    def forward(self, x, return_feat=False):
        device = next(self.fc.parameters()).device  # device of the classifier head
        x = x.to(device)  # make sure the input is on the same device
        feat = self.clip.encode_image(x)
        feat = feat.to(device)  # keep the features on the same device as fc
        out = self.fc(feat)
        if return_feat:
            return feat, out
        return out

    def save_weights(self, weights_path):
        save_params = {"fc.weight": self.fc.weight.cpu(), "fc.bias": self.fc.bias.cpu()}
        torch.save(save_params, weights_path)

    def load_weights(self, weights_path):
        device = next(self.fc.parameters()).device  # current device of the model
        weights = torch.load(weights_path, map_location=device)
        self.fc.weight.data = weights["fc.weight"].to(device)
        self.fc.bias.data = weights["fc.bias"].to(device)
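

# Minimal usage sketch: build the detector, preprocess one image with
# test_transform, and turn the single logit into a sigmoid score.
# The image path below is a hypothetical placeholder, and running this
# assumes the SigLIP checkpoint can be downloaded via open_clip.
if __name__ == "__main__":
    from PIL import Image

    detector = SemanticDetector().eval()
    img = Image.open("example.jpg").convert("RGB")  # hypothetical input image
    x = detector.test_transform(img).unsqueeze(0)   # shape: (1, 3, 384, 384)
    with torch.no_grad():
        feat, logit = detector(x, return_feat=True)
    score = torch.sigmoid(logit).item()             # score derived from the single logit
    print(f"feature shape: {tuple(feat.shape)}, score: {score:.4f}")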