""" MuseTalk HTTP API Server v2 Optimized for repeated use of the same avatar. """ import os import cv2 import copy import torch import glob import shutil import pickle import numpy as np import subprocess import tempfile import hashlib import time from pathlib import Path from typing import Optional from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks from fastapi.responses import FileResponse, JSONResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from tqdm import tqdm from omegaconf import OmegaConf from transformers import WhisperModel import uvicorn # MuseTalk imports from musetalk.utils.blending import get_image from musetalk.utils.face_parsing import FaceParsing from musetalk.utils.audio_processor import AudioProcessor from musetalk.utils.utils import get_file_type, datagen, load_all_model from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder class MuseTalkServerV2: """Server optimized for pre-processed avatars.""" def __init__(self): self.device = None self.vae = None self.unet = None self.pe = None self.whisper = None self.audio_processor = None self.fp = None self.timesteps = None self.weight_dtype = None self.is_loaded = False # Avatar cache (in-memory) self.loaded_avatars = {} self.avatar_dir = Path("./avatars") # Config self.fps = 25 self.batch_size = 8 self.use_float16 = True self.version = "v15" self.extra_margin = 10 self.parsing_mode = "jaw" self.left_cheek_width = 90 self.right_cheek_width = 90 self.audio_padding_left = 2 self.audio_padding_right = 2 def load_models( self, gpu_id: int = 0, unet_model_path: str = "./models/musetalkV15/unet.pth", unet_config: str = "./models/musetalk/config.json", vae_type: str = "sd-vae", whisper_dir: str = "./models/whisper", use_float16: bool = True, version: str = "v15" ): if self.is_loaded: print("Models already loaded!") return print("=" * 50) print("Loading MuseTalk models into GPU memory...") print("=" * 50) start_time = time.time() self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu") print(f"Using device: {self.device}") print("Loading VAE, UNet, PE...") self.vae, self.unet, self.pe = load_all_model( unet_model_path=unet_model_path, vae_type=vae_type, unet_config=unet_config, device=self.device ) self.timesteps = torch.tensor([0], device=self.device) self.use_float16 = use_float16 if use_float16: print("Converting to float16...") self.pe = self.pe.half() self.vae.vae = self.vae.vae.half() self.unet.model = self.unet.model.half() self.pe = self.pe.to(self.device) self.vae.vae = self.vae.vae.to(self.device) self.unet.model = self.unet.model.to(self.device) print("Loading Whisper model...") self.audio_processor = AudioProcessor(feature_extractor_path=whisper_dir) self.weight_dtype = self.unet.model.dtype self.whisper = WhisperModel.from_pretrained(whisper_dir) self.whisper = self.whisper.to(device=self.device, dtype=self.weight_dtype).eval() self.whisper.requires_grad_(False) self.version = version if version == "v15": self.fp = FaceParsing( left_cheek_width=self.left_cheek_width, right_cheek_width=self.right_cheek_width ) else: self.fp = FaceParsing() self.is_loaded = True print(f"Models loaded in {time.time() - start_time:.2f}s") print("=" * 50) def load_avatar(self, avatar_name: str) -> dict: """Load a preprocessed avatar into memory.""" if avatar_name in self.loaded_avatars: return self.loaded_avatars[avatar_name] avatar_path = self.avatar_dir / avatar_name if not avatar_path.exists(): raise 

    def load_avatar(self, avatar_name: str) -> dict:
        """Load a preprocessed avatar into memory."""
        if avatar_name in self.loaded_avatars:
            return self.loaded_avatars[avatar_name]

        avatar_path = self.avatar_dir / avatar_name
        if not avatar_path.exists():
            raise FileNotFoundError(f"Avatar not found: {avatar_name}")

        print(f"Loading avatar '{avatar_name}' into memory...")
        t0 = time.time()

        avatar_data = {}

        # Load metadata
        with open(avatar_path / "metadata.pkl", "rb") as f:
            avatar_data["metadata"] = pickle.load(f)

        # Load coords
        with open(avatar_path / "coords.pkl", "rb") as f:
            avatar_data["coord_list"] = pickle.load(f)

        # Load frames
        with open(avatar_path / "frames.pkl", "rb") as f:
            avatar_data["frame_list"] = pickle.load(f)

        # Load latents and convert to GPU tensors
        with open(avatar_path / "latents.pkl", "rb") as f:
            latents_np = pickle.load(f)
        avatar_data["latent_list"] = [
            torch.from_numpy(lat).to(self.device) for lat in latents_np
        ]

        # Load crop info
        with open(avatar_path / "crop_info.pkl", "rb") as f:
            avatar_data["crop_info"] = pickle.load(f)

        # Load parsing data (optional)
        parsing_path = avatar_path / "parsing.pkl"
        if parsing_path.exists():
            with open(parsing_path, "rb") as f:
                avatar_data["parsing_data"] = pickle.load(f)

        self.loaded_avatars[avatar_name] = avatar_data
        print(f"Avatar loaded in {time.time() - t0:.2f}s")
        return avatar_data

    def unload_avatar(self, avatar_name: str):
        """Unload avatar from memory."""
        if avatar_name in self.loaded_avatars:
            del self.loaded_avatars[avatar_name]
            torch.cuda.empty_cache()

    @torch.no_grad()
    def generate_with_avatar(
        self,
        avatar_name: str,
        audio_path: str,
        output_path: str,
        fps: Optional[int] = None,
    ) -> dict:
        """Generate video using pre-processed avatar. Much faster!"""
        if not self.is_loaded:
            raise RuntimeError("Models not loaded!")

        fps = fps or self.fps
        timings = {}
        total_start = time.time()

        # Load avatar (cached in memory)
        t0 = time.time()
        avatar = self.load_avatar(avatar_name)
        timings["avatar_load"] = time.time() - t0

        coord_list = avatar["coord_list"]
        frame_list = avatar["frame_list"]
        input_latent_list = avatar["latent_list"]

        temp_dir = tempfile.mkdtemp()
        try:
            # 1. Extract audio features (only audio-dependent step that's heavy)
            t0 = time.time()
            whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
            whisper_chunks = self.audio_processor.get_whisper_chunk(
                whisper_input_features,
                self.device,
                self.weight_dtype,
                self.whisper,
                librosa_length,
                fps=fps,
                audio_padding_length_left=self.audio_padding_left,
                audio_padding_length_right=self.audio_padding_right,
            )
            timings["whisper_features"] = time.time() - t0

            # 2. Prepare cycled lists
            frame_list_cycle = frame_list + frame_list[::-1]
            coord_list_cycle = coord_list + coord_list[::-1]
            input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

            # 3. UNet inference
            t0 = time.time()
            gen = datagen(
                whisper_chunks=whisper_chunks,
                vae_encode_latents=input_latent_list_cycle,
                batch_size=self.batch_size,
                delay_frame=0,
                device=self.device,
            )
            res_frame_list = []
            for whisper_batch, latent_batch in gen:
                audio_feature_batch = self.pe(whisper_batch)
                latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
                pred_latents = self.unet.model(
                    latent_batch,
                    self.timesteps,
                    encoder_hidden_states=audio_feature_batch,
                ).sample
                recon = self.vae.decode_latents(pred_latents)
                for res_frame in recon:
                    res_frame_list.append(res_frame)
            timings["unet_inference"] = time.time() - t0
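
            # Note: each decoded frame is a face crop, not a full frame. Step 4
            # resizes it back to its source bbox and blends it into the original
            # frame, indexing frames/coords modulo the cycle length so the
            # avatar loops (forward, then reversed) when the audio outlasts the
            # source clip.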

            # 4. Face blending
            t0 = time.time()
            result_img_path = os.path.join(temp_dir, "results")
            os.makedirs(result_img_path, exist_ok=True)

            for i, res_frame in enumerate(res_frame_list):
                bbox = coord_list_cycle[i % len(coord_list_cycle)]
                ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
                x1, y1, x2, y2 = bbox
                if self.version == "v15":
                    y2 = y2 + self.extra_margin
                    y2 = min(y2, ori_frame.shape[0])
                try:
                    res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
                except cv2.error:
                    # Skip frames whose bbox is degenerate and cannot be resized
                    continue
                if self.version == "v15":
                    combine_frame = get_image(
                        ori_frame, res_frame, [x1, y1, x2, y2],
                        mode=self.parsing_mode, fp=self.fp,
                    )
                else:
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=self.fp)
                cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)
            timings["face_blending"] = time.time() - t0

            # 5. Encode video (subprocess with argument lists avoids shell
            # quoting issues; check=True fails loudly if ffmpeg errors)
            t0 = time.time()
            temp_vid = os.path.join(temp_dir, "temp.mp4")
            subprocess.run(
                [
                    "ffmpeg", "-y", "-v", "warning",
                    "-r", str(fps), "-f", "image2",
                    "-i", f"{result_img_path}/%08d.png",
                    "-vcodec", "libx264", "-vf", "format=yuv420p", "-crf", "18",
                    temp_vid,
                ],
                check=True,
            )
            subprocess.run(
                ["ffmpeg", "-y", "-v", "warning", "-i", audio_path, "-i", temp_vid, output_path],
                check=True,
            )
            timings["video_encoding"] = time.time() - t0
        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)

        timings["total"] = time.time() - total_start
        timings["frames_generated"] = len(res_frame_list)
        return timings


# Global server instance
server = MuseTalkServerV2()

# FastAPI app
app = FastAPI(
    title="MuseTalk API v2",
    description="Optimized API for repeated avatar usage",
    version="2.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup_event():
    server.load_models()


@app.get("/health")
async def health_check():
    return {
        "status": "ok" if server.is_loaded else "loading",
        "models_loaded": server.is_loaded,
        "device": str(server.device) if server.device else None,
        "loaded_avatars": list(server.loaded_avatars.keys()),
    }


@app.get("/avatars")
async def list_avatars():
    """List all available preprocessed avatars."""
    avatars = []
    if not server.avatar_dir.exists():
        return {"avatars": avatars}
    for p in server.avatar_dir.iterdir():
        if p.is_dir() and (p / "metadata.pkl").exists():
            with open(p / "metadata.pkl", "rb") as f:
                metadata = pickle.load(f)
            metadata["loaded"] = p.name in server.loaded_avatars
            avatars.append(metadata)
    return {"avatars": avatars}


@app.post("/avatars/{avatar_name}/load")
async def load_avatar(avatar_name: str):
    """Pre-load an avatar into GPU memory."""
    try:
        server.load_avatar(avatar_name)
        return {"status": "loaded", "avatar_name": avatar_name}
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e))


@app.post("/avatars/{avatar_name}/unload")
async def unload_avatar(avatar_name: str):
    """Unload an avatar from memory."""
    server.unload_avatar(avatar_name)
    return {"status": "unloaded", "avatar_name": avatar_name}


class GenerateWithAvatarRequest(BaseModel):
    avatar_name: str
    audio_path: str
    output_path: str
    fps: Optional[int] = 25
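
# Example request body for POST /generate/avatar. Paths are resolved on the
# server's filesystem; the values below are illustrative only:
#
#   {
#       "avatar_name": "demo_avatar",
#       "audio_path": "/data/audio/hello.wav",
#       "output_path": "/data/out/hello.mp4",
#       "fps": 25
#   }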
FAST!""" if not server.is_loaded: raise HTTPException(status_code=503, detail="Models not loaded") if not os.path.exists(request.audio_path): raise HTTPException(status_code=404, detail=f"Audio not found: {request.audio_path}") try: timings = server.generate_with_avatar( avatar_name=request.avatar_name, audio_path=request.audio_path, output_path=request.output_path, fps=request.fps ) return { "status": "success", "output_path": request.output_path, "timings": timings } except FileNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.post("/generate/avatar/upload") async def generate_with_avatar_upload( avatar_name: str = Form(...), audio: UploadFile = File(...), fps: int = Form(25) ): """Generate video from uploaded audio using pre-processed avatar.""" if not server.is_loaded: raise HTTPException(status_code=503, detail="Models not loaded") temp_dir = tempfile.mkdtemp() try: audio_path = os.path.join(temp_dir, audio.filename) output_path = os.path.join(temp_dir, "output.mp4") with open(audio_path, "wb") as f: f.write(await audio.read()) timings = server.generate_with_avatar( avatar_name=avatar_name, audio_path=audio_path, output_path=output_path, fps=fps ) return FileResponse( output_path, media_type="video/mp4", filename="result.mp4", headers={"X-Timings": str(timings)} ) except Exception as e: shutil.rmtree(temp_dir, ignore_errors=True) raise HTTPException(status_code=500, detail=str(e)) if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument("--port", type=int, default=8000) args = parser.parse_args() uvicorn.run(app, host=args.host, port=args.port)