# CO-SPY / calibrate_combine.py
import os
import time
import json
import torch
import numpy as np
from tqdm import tqdm
from loguru import logger
from sklearn.metrics import average_precision_score
from Detectors import CospyCalibrateDetector
from Datasets import TrainDataset, TestDataset, EVAL_DATASET_LIST, EVAL_MODEL_LIST
from utils import seed_torch
import warnings
warnings.filterwarnings("ignore")
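
# ---------------------------------------------------------------------------
# Calibration stage of CO-SPY: train a small fusion head (the `fc` layer in
# CospyCalibrateDetector) on top of the pre-trained semantic and artifact
# detectors so their scores are combined into a single real/fake logit.
# The class below wraps the model, optimizer, loss, and train/predict steps.
# ---------------------------------------------------------------------------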
class Detector():
def __init__(self, args):
super(Detector, self).__init__()
# Device
self.device = args.device
        # ===== Initialize the model =====
self.model = CospyCalibrateDetector(
semantic_weights_path=args.semantic_weights_path,
artifact_weights_path=args.artifact_weights_path
)
self.model.to(self.device)
        # Initialize the fc layer if desired
torch.nn.init.normal_(self.model.fc.weight.data, 0.0, 0.02)
# ===== Optimizer =====
_lr = 1e-1
_beta1 = 0.9
_weight_decay = 0.0
params = [p for p in self.model.parameters() if p.requires_grad]
print(f'Trainable parameters: {len(params)}')
self.optimizer = torch.optim.AdamW(params, lr=_lr, betas=(_beta1, 0.999), weight_decay=_weight_decay)
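        # NOTE (assumption): with the semantic and artifact branches restored
        # from fixed checkpoints, the parameters left trainable are expected
        # to be the small fusion fc layer, which is why the unusually large
        # lr of 1e-1 is workable here.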
# ===== Loss =====
self.criterion = torch.nn.BCEWithLogitsLoss()
        # Scheduler: decay the lr every delr_freq epochs (see scheduler())
self.delr_freq = 10
        # ===== Load checkpoint if provided =====
if args.resume is not None:
print(f"Loading checkpoint from {args.resume}")
state = torch.load(args.resume, map_location=self.device)
            # Support both formats: {'model': state_dict} or a raw state_dict
            if isinstance(state, dict) and "model" in state:
                state = state["model"]
            # strict=False tolerates missing/unexpected keys in the checkpoint
            self.model.load_state_dict(state, strict=False)
print("Checkpoint loaded. Continue training...")
self.model.to(self.device)
self.model.train()
# Training function for the detector
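    # (Assuming image batches) each step consumes (inputs: [B, C, H, W],
    # labels: [B]) and returns the scalar loss plus per-sample sigmoid scores
    # and true labels for metric computation.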
def train_step(self, batch_data):
# Decompose the batch data
inputs, labels = batch_data
inputs, labels = inputs.to(self.device), labels.to(self.device)
self.optimizer.zero_grad()
outputs = self.model(inputs)
loss = self.criterion(outputs, labels.unsqueeze(1).float())
loss.backward()
self.optimizer.step()
eval_loss = loss.item()
y_pred = outputs.sigmoid().flatten().tolist()
y_true = labels.tolist()
return eval_loss, y_pred, y_true
    # Schedule the training: decay the learning rate by 0.9 every
    # self.delr_freq epochs. Always returns True; no early stopping is
    # actually implemented, even though train() checks the return value.
def scheduler(self, status_dict):
epoch = status_dict['epoch']
if epoch % self.delr_freq == 0 and epoch != 0:
for param_group in self.optimizer.param_groups:
param_group['lr'] *= 0.9
self.lr = param_group['lr']
return True
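
    # Example: with lr = 1e-1 and a 0.9 decay every delr_freq = 10 epochs,
    # a 30-epoch run decays at epochs 10 and 20, ending with
    # lr = 0.1 * 0.9**2 = 0.081.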
    # Prediction function (inference only, so gradients are disabled)
    @torch.no_grad()
    def predict(self, inputs):
        inputs = inputs.to(self.device)
        outputs = self.model(inputs)
        prediction = outputs.sigmoid().flatten().tolist()
        return prediction
def evaluate(y_pred, y_true):
    ap = average_precision_score(y_true, y_pred)
    accuracy = ((np.array(y_pred) > 0.5) == np.array(y_true)).mean()
    # Cast numpy scalars to plain floats so downstream json.dump works
    return float(ap), float(accuracy)
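
# Quick sanity check for evaluate() on hypothetical inputs:
#   evaluate([0.9, 0.2], [1, 0]) -> (1.0, 1.0)  # perfect ranking and accuracy
#   evaluate([0.4, 0.2], [1, 0]) -> (1.0, 0.5)  # right ranking, one threshold miss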
def train(args):
    # Set up the saving directory and log file first
model_dir = os.path.join(args.ckpt, "cospy_calibrate")
if not os.path.exists(model_dir):
os.makedirs(model_dir)
log_path = f"{model_dir}/training.log"
if os.path.exists(log_path):
os.remove(log_path)
logger_id = logger.add(
log_path,
format="{time:MM-DD at HH:mm:ss} | {level} | {module}:{line} | {message}",
level="DEBUG",
)
# Get the detector
detector = Detector(args)
    # --- Resume checkpoint ---
    # Detector.__init__ already loaded args.resume via torch.load; here we
    # additionally pick up the best checkpoint saved in model_dir, if any.
    # Note that start_epoch and best_acc are not restored from the checkpoint,
    # so training effectively restarts from epoch 0.
    start_epoch = 0
    best_acc = 0
    if args.resume:
        resume_path = os.path.join(model_dir, "best_model.pth")
        if os.path.exists(resume_path):
            print(f"Resuming from {resume_path} ...")
            detector.model.load_weights(resume_path)
            detector.model.to(args.device)
# Load the calibration dataset using the "val" split
train_dataset = TrainDataset(data_path=args.calibration_dirpath,
split="val",
transform=detector.model.test_transform)
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=4,
pin_memory=True)
logger.info(f"Train size {len(train_dataset)}")
# Train the detector
for epoch in range(start_epoch, args.epochs):
# Set the model to training mode
detector.model.train()
time_start = time.time()
for step_id, batch_data in enumerate(train_loader):
eval_loss, y_pred, y_true = detector.train_step(batch_data)
ap, accuracy = evaluate(y_pred, y_true)
# Log the training information
if (step_id + 1) % 100 == 0:
time_end = time.time()
logger.info(f"Epoch {epoch} | Batch {step_id + 1}/{len(train_loader)} | Loss {eval_loss:.4f} | AP {ap*100:.2f}% | Accuracy {accuracy*100:.2f}% | Time {time_end-time_start:.2f}s")
time_start = time.time()
        # Evaluate the model on the calibration set
        detector.model.eval()
        y_pred, y_true = [], []
        for inputs, labels in train_loader:
            y_pred.extend(detector.predict(inputs))
            y_true.extend(labels.tolist())
ap, accuracy = evaluate(y_pred, y_true)
logger.info(f"Epoch {epoch} | Total AP {ap*100:.2f}% | Total Accuracy {accuracy*100:.2f}%")
# Schedule the training
status_dict = {'epoch': epoch, 'AP': ap, 'Accuracy': accuracy}
proceed = detector.scheduler(status_dict)
if not proceed:
logger.info("Early stopping")
break
# Save the model
if accuracy >= best_acc:
best_acc = accuracy
detector.model.save_weights(f"{model_dir}/best_model.pth")
logger.info(f"Best model saved with accuracy {best_acc.mean()*100:.2f}%")
if epoch % 5 == 0:
detector.model.save_weights(f"{model_dir}/epoch_{epoch}.pth")
logger.info(f"Model saved at epoch {epoch}")
# Save the final model
detector.model.save_weights(f"{model_dir}/final_model.pth")
logger.info("Final model saved")
# Remove the logger
logger.remove(logger_id)
def test(args):
# Initialize the detector
detector = Detector(args)
# Load the [best/final] model
weights_path = os.path.join(args.ckpt, "cospy_calibrate", "best_model.pth")
detector.model.load_weights(weights_path)
detector.model.to(args.device)
detector.model.eval()
# Set the pre-processing function
test_transform = detector.model.test_transform
# Set the saving directory
save_dir = os.path.join(args.ckpt, "cospy_calibrate")
save_result_path = os.path.join(save_dir, "result.json")
save_output_path = os.path.join(save_dir, "output.json")
# Begin the evaluation
result_all = {}
output_all = {}
for dataset_name in EVAL_DATASET_LIST:
result_all[dataset_name] = {}
output_all[dataset_name] = {}
for model_name in EVAL_MODEL_LIST:
test_dataset = TestDataset(dataset=dataset_name, model=model_name, root_path=args.testset_dirpath, transform=test_transform)
test_loader = torch.utils.data.DataLoader(test_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=4,
pin_memory=True)
# Evaluate the model
y_pred, y_true = [], []
for (images, labels) in tqdm(test_loader, desc=f"Evaluating {dataset_name} {model_name}"):
y_pred.extend(detector.predict(images))
y_true.extend(labels.tolist())
ap, accuracy = evaluate(y_pred, y_true)
print(f"Evaluate on {dataset_name} {model_name} | Size {len(y_true)} | AP {ap*100:.2f}% | Accuracy {accuracy*100:.2f}%")
result_all[dataset_name][model_name] = {"size": len(y_true), "AP": ap, "Accuracy": accuracy}
output_all[dataset_name][model_name] = {"y_pred": y_pred, "y_true": y_true}
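    # result.json ends up shaped as
    #   {dataset_name: {model_name: {"size": int, "AP": float, "Accuracy": float}}}
    # while output.json keeps the raw y_pred / y_true lists per pair.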
# Save the results
with open(save_result_path, "w") as f:
json.dump(result_all, f, indent=4)
with open(save_output_path, "w") as f:
json.dump(output_all, f, indent=4)
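
# Example invocations (flag defaults come from the argparse setup below):
#   python calibrate_combine.py --phase train --gpu 0 --epochs 10
#   python calibrate_combine.py --phase test --batch_size 32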
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser("Deep Fake Detection")
parser.add_argument("--gpu", type=int, default=0, help="GPU ID")
parser.add_argument("--phase", type=str, default="test", choices=["train", "test"], help="Phase of the experiment")
parser.add_argument("--semantic_weights_path", type=str, default="ckpt/semantic/best_model.pth", help="Semantic weights path")
parser.add_argument("--artifact_weights_path", type=str, default="ckpt/artifact/best_model.pth", help="Artifact weights path")
parser.add_argument("--calibration_dirpath", type=str, default="data/train", help="Calibration directory")
parser.add_argument("--testset_dirpath", type=str, default="data/test", help="Testset directory")
parser.add_argument("--ckpt", type=str, default="ckpt", help="Checkpoint directory")
parser.add_argument("--epochs", type=int, default=10, help="Number of epochs")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
parser.add_argument("--seed", type=int, default=1024, help="Random seed")
parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint to resume training")
args = parser.parse_args()
# Set the random seed
seed_torch(args.seed)
# Set the GPU ID
args.device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
# Begin the experiment
if args.phase == "train":
train(args)
elif args.phase == "test":
test(args)
else:
raise ValueError("Unknown phase")