Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| import json | |
| import torch | |
| import numpy as np | |
| from tqdm import tqdm | |
| from loguru import logger | |
| from sklearn.metrics import average_precision_score | |
| from Detectors import CospyCalibrateDetector | |
| from Datasets import TrainDataset, TestDataset, EVAL_DATASET_LIST, EVAL_MODEL_LIST | |
| from utils import seed_torch | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
class Detector():
    """Training / inference wrapper around ``CospyCalibrateDetector``.

    Holds the model, an AdamW optimizer over the model's trainable
    parameters, a ``BCEWithLogitsLoss`` criterion and a step-decay
    learning-rate schedule (x0.9 every ``delr_freq`` epochs). Optionally
    warm-starts the model from the checkpoint at ``args.resume``.
    """

    def __init__(self, args):
        """Build model, optimizer and loss from parsed CLI ``args``.

        Reads ``args.device``, ``args.semantic_weights_path``,
        ``args.artifact_weights_path`` and ``args.resume``.
        """
        super(Detector, self).__init__()
        # Device
        self.device = args.device

        # ===== Initialize the model =====
        self.model = CospyCalibrateDetector(
            semantic_weights_path=args.semantic_weights_path,
            artifact_weights_path=args.artifact_weights_path
        )
        self.model.to(self.device)
        # Re-initialize the final fc layer's weights from N(0, 0.02).
        torch.nn.init.normal_(self.model.fc.weight.data, 0.0, 0.02)

        # ===== Optimizer =====
        _lr = 1e-1
        _beta1 = 0.9
        _weight_decay = 0.0
        # Only parameters the model leaves trainable are optimized.
        params = [p for p in self.model.parameters() if p.requires_grad]
        print(f'Trainable parameters: {len(params)}')
        self.optimizer = torch.optim.AdamW(params, lr=_lr, betas=(_beta1, 0.999), weight_decay=_weight_decay)
        # Current learning rate; updated by ``scheduler`` when it decays.
        self.lr = _lr

        # ===== Loss =====
        self.criterion = torch.nn.BCEWithLogitsLoss()

        # Scheduler: decay the learning rate every ``delr_freq`` epochs.
        self.delr_freq = 10

        # ===== Load a checkpoint if one was given =====
        if args.resume is not None:
            print(f"Loading checkpoint from {args.resume}")
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            state = torch.load(args.resume, map_location=self.device)
            # Accept both {'model': state_dict} and a bare state_dict.
            if isinstance(state, dict) and "model" in state:
                state = state["model"]
            # strict=False ignores mismatched keys; report them instead of hiding them.
            incompatible = self.model.load_state_dict(state, strict=False)
            missing = getattr(incompatible, "missing_keys", [])
            unexpected = getattr(incompatible, "unexpected_keys", [])
            if missing or unexpected:
                print(f"Warning: ignored {len(missing)} missing and {len(unexpected)} unexpected keys while loading checkpoint")
            print("Checkpoint loaded. Continue training...")
            self.model.to(self.device)

        self.model.train()

    def train_step(self, batch_data):
        """Run one optimization step on a single batch.

        Args:
            batch_data: ``(inputs, labels)`` pair. Labels are unsqueezed to
                ``(B, 1)`` to match the model output (assumes the model
                returns one logit per sample — TODO confirm).

        Returns:
            ``(loss, y_pred, y_true)``: the scalar loss as a float, sigmoid
            probabilities as a flat list, and the labels as a list.
        """
        inputs, labels = batch_data
        inputs, labels = inputs.to(self.device), labels.to(self.device)

        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        loss = self.criterion(outputs, labels.unsqueeze(1).float())
        loss.backward()
        self.optimizer.step()

        eval_loss = loss.item()
        y_pred = outputs.sigmoid().flatten().tolist()
        y_true = labels.tolist()
        return eval_loss, y_pred, y_true

    def scheduler(self, status_dict):
        """Decay the learning rate by 0.9 every ``delr_freq`` epochs.

        Args:
            status_dict: must contain ``'epoch'``; other keys are ignored.

        Returns:
            Always ``True`` (no early-stopping criterion is implemented).
        """
        epoch = status_dict['epoch']
        if epoch % self.delr_freq == 0 and epoch != 0:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] *= 0.9
                self.lr = param_group['lr']
        return True

    def predict(self, inputs):
        """Return sigmoid probabilities for ``inputs`` as a flat list.

        Runs under ``torch.no_grad()`` so no autograd graph is built. The
        caller is responsible for putting the model in eval mode.
        """
        inputs = inputs.to(self.device)
        with torch.no_grad():
            outputs = self.model(inputs)
        return outputs.sigmoid().flatten().tolist()
def evaluate(y_pred, y_true):
    """Score probability predictions against binary labels.

    Args:
        y_pred: predicted probabilities of the positive class.
        y_true: ground-truth labels aligned with ``y_pred``.

    Returns:
        ``(ap, accuracy)``: sklearn average precision, and the fraction of
        samples whose thresholded prediction (``> 0.5``) equals the label.
    """
    ap = average_precision_score(y_true, y_pred)
    hard_labels = np.array(y_pred) > 0.5
    matches = hard_labels == y_true
    return ap, matches.mean()
def _predict_loader(detector, loader):
    """Run ``detector.predict`` over every batch of ``loader``.

    Returns:
        ``(y_pred, y_true)`` as flat Python lists, in loader order.
    """
    y_pred, y_true = [], []
    # Evaluation only: don't build an autograd graph.
    with torch.no_grad():
        for inputs, labels in loader:
            y_pred.extend(detector.predict(inputs))
            y_true.extend(labels.tolist())
    return y_pred, y_true


def _run_training(args, model_dir):
    """Build the detector and data, run the epoch loop and save checkpoints.

    Expects the file logger for ``model_dir`` to be attached already.
    """
    # Get the detector
    detector = Detector(args)

    # --- Resume checkpoint ---
    # NOTE(review): Detector.__init__ already loads ``args.resume``; if
    # ``<model_dir>/best_model.pth`` exists it overrides that checkpoint here.
    # Confirm this precedence is intended.
    start_epoch = 0
    if args.resume:
        resume_path = os.path.join(model_dir, "best_model.pth")
        if os.path.exists(resume_path):
            print(f"Resuming from {resume_path} ...")
            detector.model.load_weights(resume_path)
    detector.model.to(args.device)

    # Load the calibration dataset using the "val" split
    train_dataset = TrainDataset(data_path=args.calibration_dirpath,
                                 split="val",
                                 transform=detector.model.test_transform)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True)
    logger.info(f"Train size {len(train_dataset)}")

    # Train the detector
    best_acc = 0
    for epoch in range(start_epoch, args.epochs):
        # Set the model to training mode
        detector.model.train()
        time_start = time.time()
        for step_id, batch_data in enumerate(train_loader):
            eval_loss, y_pred, y_true = detector.train_step(batch_data)
            # Log every 100 batches; batch metrics are only computed when reported.
            if (step_id + 1) % 100 == 0:
                ap, accuracy = evaluate(y_pred, y_true)
                time_end = time.time()
                logger.info(f"Epoch {epoch} | Batch {step_id + 1}/{len(train_loader)} | Loss {eval_loss:.4f} | AP {ap*100:.2f}% | Accuracy {accuracy*100:.2f}% | Time {time_end-time_start:.2f}s")
                time_start = time.time()

        # Evaluate the model. NOTE: this is the same loader used for training,
        # so these are training-set metrics.
        detector.model.eval()
        y_pred, y_true = _predict_loader(detector, train_loader)
        ap, accuracy = evaluate(y_pred, y_true)
        logger.info(f"Epoch {epoch} | Total AP {ap*100:.2f}% | Total Accuracy {accuracy*100:.2f}%")

        # Schedule the training
        status_dict = {'epoch': epoch, 'AP': ap, 'Accuracy': accuracy}
        proceed = detector.scheduler(status_dict)
        if not proceed:
            logger.info("Early stopping")
            break

        # Save the best model so far (ties favour the later epoch)
        if accuracy >= best_acc:
            best_acc = accuracy
            detector.model.save_weights(f"{model_dir}/best_model.pth")
            logger.info(f"Best model saved with accuracy {best_acc*100:.2f}%")
        # Periodic snapshot every 5 epochs
        if epoch % 5 == 0:
            detector.model.save_weights(f"{model_dir}/epoch_{epoch}.pth")
            logger.info(f"Model saved at epoch {epoch}")

    # Save the final model
    detector.model.save_weights(f"{model_dir}/final_model.pth")
    logger.info("Final model saved")


def train(args):
    """Train the calibration detector and write checkpoints and a log.

    Trains on the "val" split of ``args.calibration_dirpath`` for
    ``args.epochs`` epochs. Writes these files under
    ``<args.ckpt>/cospy_calibrate``:

    - ``best_model.pth`` when epoch accuracy is at least the best so far;
    - ``epoch_{N}.pth`` every 5 epochs;
    - ``final_model.pth`` at the end;
    - ``training.log``, any previous log being deleted first.
    """
    # Set up the saving directory and the file logger exactly once.
    model_dir = os.path.join(args.ckpt, "cospy_calibrate")
    os.makedirs(model_dir, exist_ok=True)
    log_path = f"{model_dir}/training.log"
    # Start every run with a fresh log file.
    if os.path.exists(log_path):
        os.remove(log_path)
    logger_id = logger.add(
        log_path,
        format="{time:MM-DD at HH:mm:ss} | {level} | {module}:{line} | {message}",
        level="DEBUG",
    )
    try:
        _run_training(args, model_dir)
    finally:
        # Remove the file sink even if training fails, so it is not leaked.
        logger.remove(logger_id)
def test(args):
    """Evaluate the best saved checkpoint on every test dataset/model pair.

    Loads ``<args.ckpt>/cospy_calibrate/best_model.pth``. For each
    ``dataset_name`` in ``EVAL_DATASET_LIST`` and each ``model_name`` in
    ``EVAL_MODEL_LIST``, it prints AP and accuracy. It then writes two files
    to the same directory:

    - ``result.json``: the per-pair metrics;
    - ``output.json``: the raw predictions and labels.
    """
    # Initialize the detector
    detector = Detector(args)

    # Load the best model
    weights_path = os.path.join(args.ckpt, "cospy_calibrate", "best_model.pth")
    detector.model.load_weights(weights_path)
    detector.model.to(args.device)
    detector.model.eval()

    # Set the pre-processing function
    test_transform = detector.model.test_transform

    # Set the saving directory
    save_dir = os.path.join(args.ckpt, "cospy_calibrate")
    save_result_path = os.path.join(save_dir, "result.json")
    save_output_path = os.path.join(save_dir, "output.json")

    # Begin the evaluation
    result_all = {}
    output_all = {}
    for dataset_name in EVAL_DATASET_LIST:
        result_all[dataset_name] = {}
        output_all[dataset_name] = {}
        for model_name in EVAL_MODEL_LIST:
            test_dataset = TestDataset(dataset=dataset_name, model=model_name, root_path=args.testset_dirpath, transform=test_transform)
            test_loader = torch.utils.data.DataLoader(test_dataset,
                                                      batch_size=args.batch_size,
                                                      shuffle=False,
                                                      num_workers=4,
                                                      pin_memory=True)

            # Inference only: no autograd graph is needed.
            y_pred, y_true = [], []
            with torch.no_grad():
                for images, labels in tqdm(test_loader, desc=f"Evaluating {dataset_name} {model_name}"):
                    y_pred.extend(detector.predict(images))
                    y_true.extend(labels.tolist())
            ap, accuracy = evaluate(y_pred, y_true)
            print(f"Evaluate on {dataset_name} {model_name} | Size {len(y_true)} | AP {ap*100:.2f}% | Accuracy {accuracy*100:.2f}%")

            result_all[dataset_name][model_name] = {"size": len(y_true), "AP": ap, "Accuracy": accuracy}
            output_all[dataset_name][model_name] = {"y_pred": y_pred, "y_true": y_true}

    # Save the results
    with open(save_result_path, "w") as f:
        json.dump(result_all, f, indent=4)
    with open(save_output_path, "w") as f:
        json.dump(output_all, f, indent=4)
if __name__ == "__main__":
    import argparse

    # Command-line options, registered in a fixed order (this order is also
    # the order shown by --help).
    parser = argparse.ArgumentParser("Deep Fake Detection")
    option_specs = (
        ("--gpu", dict(type=int, default=0, help="GPU ID")),
        ("--phase", dict(type=str, default="test", choices=["train", "test"], help="Phase of the experiment")),
        ("--semantic_weights_path", dict(type=str, default="ckpt/semantic/best_model.pth", help="Semantic weights path")),
        ("--artifact_weights_path", dict(type=str, default="ckpt/artifact/best_model.pth", help="Artifact weights path")),
        ("--calibration_dirpath", dict(type=str, default="data/train", help="Calibration directory")),
        ("--testset_dirpath", dict(type=str, default="data/test", help="Testset directory")),
        ("--ckpt", dict(type=str, default="ckpt", help="Checkpoint directory")),
        ("--epochs", dict(type=int, default=10, help="Number of epochs")),
        ("--batch_size", dict(type=int, default=32, help="Batch size")),
        ("--seed", dict(type=int, default=1024, help="Random seed")),
        ("--resume", dict(type=str, default=None, help="Path to checkpoint to resume training")),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Seed all RNGs for reproducibility.
    seed_torch(args.seed)

    # Use the requested CUDA device when available, otherwise the CPU.
    if torch.cuda.is_available():
        args.device = f"cuda:{args.gpu}"
    else:
        args.device = "cpu"

    # Dispatch to the selected phase.
    phase_runners = {"train": train, "test": test}
    runner = phase_runners.get(args.phase)
    if runner is None:
        raise ValueError("Unknown phase")
    runner(args)