import gc
import os
from PIL import Image
import json
import random
import cv2
import einops
import gradio as gr
import numpy as np
import torch
from pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3
from torch.nn.functional import threshold, normalize, interpolate
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
from einops import rearrange, repeat
import argparse
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| # parse= argparse.ArgumentParser() | |
| # parseadd_argument('--pretrained_model', type=str, default='runwayml/stable-diffusion-v1-5') | |
| # parseadd_argument('--controlnet', type=str, default='controlnet') | |
| # parseadd_argument('--precision', type=str, default='fp32') | |
| # = parseparse_) | |
| # pretrained_model = pretrained_model | |
| pretrained_model = 'runwayml/stable-diffusion-v1-5' | |
| controlnet = 'models' | |
| # controlnet = 'checkpoint-34000/controlnet' | |
| precision = 'fp16' | |
# Check for different hardware architectures
if torch.cuda.is_available():
    device = "cuda"
    # # Check for xformers
    # try:
    #     import xformers
    #     enable_xformers = True
    # except ImportError:
    #     enable_xformers = False
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")
# Load models
if precision == 'fp32':
    torch_dtype = torch.float32
elif precision == 'fp16':
    torch_dtype = torch.float16
elif precision == 'bf16':
    torch_dtype = torch.bfloat16
else:
    raise ValueError(f"Invalid precision: {precision}")
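# Note: fp16 weights generally assume a CUDA/MPS device; when running on CPU,
# 'fp32' is the safer setting since some float16 ops are unsupported or very slow there.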
controlnet = ControlNetModel.from_pretrained(controlnet, torch_dtype=torch_dtype, use_safetensors=True)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    pretrained_model, controlnet=controlnet, torch_dtype=torch_dtype
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(device)
# Apply optimizations based on hardware
if device == "cuda":
    pipe = pipe.to(device)
    # if enable_xformers:
    #     pipe.enable_xformers_memory_efficient_attention()
    #     print("xformers optimization enabled")
    pipe.enable_attention_slicing()
elif device == "mps":
    pipe = pipe.to(device)
    pipe.enable_attention_slicing()
    print("Attention slicing enabled for Apple Silicon")
else:
    # CPU-specific optimizations
    pipe = pipe.to(device)
    # pipe.enable_sequential_cpu_offload()
    pipe.enable_attention_slicing()

# Disable the safety checker; generated images are not filtered
pipe.safety_checker = None
pipe.requires_safety_checker = False
feature_extractor = SegformerFeatureExtractor.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
segmodel = SegformerForSemanticSegmentation.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
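# Note: the segmentation model stays on the CPU (its inputs are never moved to `device`),
# which keeps GPU memory free for the diffusion pipeline at the cost of slower masking.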
def LGB_TO_RGB(gray_image, rgb_image):
    # Replace the luminance (L) channel of rgb_image with the grayscale input, so the
    # output keeps the original structure while taking its colors from the generated image.
    # gray_image: [H, W, 3], rgb_image: [H, W, 3]
    gray_image = cv2.cvtColor(gray_image, cv2.COLOR_RGB2GRAY)
    lab_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2LAB)
    lab_image[:, :, 0] = gray_image[:, :]
    return cv2.cvtColor(lab_image, cv2.COLOR_LAB2RGB)
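# process() overview:
#   1. Resize the input and build a grayscale control image.
#   2. Run the ControlNet pipeline to generate colorized candidates.
#   3. Re-impose the original luminance via LGB_TO_RGB.
#   4. Segment people with the Segformer model and keep color only inside the mask,
#      leaving the background grayscale.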
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, strength,
            guidance_scale, seed, eta, threshold, save_memory=False):
    with torch.no_grad():
        img = resize_image(input_image, image_resolution)
        H, W, C = img.shape

        # Grayscale copy of the input, kept as 3 channels for later blending
        gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        gray_img = cv2.cvtColor(gray_img, cv2.COLOR_GRAY2RGB)

        # Conditioning image for ControlNet: the input as a single-channel PIL image
        control = Image.fromarray(img)
        control = control.convert('L')

        if a_prompt:
            prompt = prompt + ', ' + a_prompt

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)
        generator = torch.Generator(device=device).manual_seed(seed)
        # Generate images
        output = pipe(
            num_images_per_prompt=num_samples,
            prompt=prompt,
            image=control,
            negative_prompt=n_prompt,
            num_inference_steps=ddim_steps,
            guidance_scale=guidance_scale,
            generator=generator,
            eta=eta,
            # The UI's "Control Strength" maps to the ControlNet conditioning scale
            controlnet_conditioning_scale=float(strength),
            output_type='np',
        ).images
        # output = einops.rearrange(output, 'b c h w -> b h w c')

        # With output_type='np' the pipeline returns float images in [0, 1]
        output = (output * 255.0).clip(0, 255).astype(np.uint8)
        results = [output[i] for i in range(num_samples)]
        results = [LGB_TO_RGB(gray_img, result) for result in results]
        # Convert each result into a person mask using the segmentation model
        masks = []
        for result in results:
            inputs = feature_extractor(images=result, return_tensors="pt")
            outputs = segmodel(**inputs)
            logits = outputs.logits
            logits = logits.squeeze(0)

            # Binarize the per-class logits, then merge all non-background classes into one mask
            thresholded = torch.zeros_like(logits)
            thresholded[logits > threshold] = 1
            mask = thresholded[1:, :, :].sum(dim=0)

            # Upsample the mask to the working resolution and binarize again
            mask = mask.unsqueeze(0).unsqueeze(0)
            mask = interpolate(mask, size=(H, W), mode='bilinear')
            mask = mask.detach().numpy()
            mask = np.squeeze(mask)
            mask = np.where(mask > threshold, 1, 0)
            masks.append(mask)

        # Where the mask is 0, fall back to the grayscale image; where it is 1, keep the colorized result
        final = [(gray_img * (1 - mask[:, :, None]) + result * mask[:, :, None]).astype(np.uint8)
                 for result, mask in zip(results, masks)]

        # Masks as 0/255 images for display
        mask_img = [(mask * 255).astype(np.uint8) for mask in masks]

        gc.collect()
        # Gallery order: grayscale input, colorized results, masks, masked composites
        return [gray_img] + results + mask_img + final
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Control Stable Diffusion with Gray Image")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(sources=['upload'], type="numpy")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=1, value=1, step=1, visible=False)
                # num_samples = 1
                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=512, value=512, step=64)
                strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                # guess_mode = gr.Checkbox(label='Guess Mode', value=False)
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=20, value=20, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=7.0, step=0.1)
                threshold = gr.Slider(label="Segmentation Threshold", minimum=0.1, maximum=0.9, value=0.5, step=0.05)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, value=-1, step=1)
                eta = gr.Number(label="eta (DDIM)", value=0.0)
                a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed, vivid colors')
                n_prompt = gr.Textbox(label="Negative Prompt",
                                      value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
        with gr.Column():
            # result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery")
    ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, strength, scale, seed,
           eta, threshold]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery], concurrency_limit=2)

block.queue(max_size=100)
block.launch(share=True)
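# launch(share=True) serves the app locally and also creates a temporary public Gradio link;
# on a Hugging Face Space the Space URL itself is the public endpoint, so the share link is unnecessary.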