From 2292f9a1008c4e052595404554db5172fc84fee5 Mon Sep 17 00:00:00 2001 From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com> Date: Sun, 18 Aug 2024 01:47:27 -0700 Subject: [PATCH] Technically Correct PhotoMaker V2 that can actually reproduce author results, and be used in serious research and writing academic papers --- .../forge_space_photo_maker_v2/forge_app.py | 446 ++++++++++++++++++ .../space_meta.json | 6 + 2 files changed, 452 insertions(+) create mode 100644 extensions-builtin/forge_space_photo_maker_v2/forge_app.py create mode 100644 extensions-builtin/forge_space_photo_maker_v2/space_meta.json diff --git a/extensions-builtin/forge_space_photo_maker_v2/forge_app.py b/extensions-builtin/forge_space_photo_maker_v2/forge_app.py new file mode 100644 index 00000000..306a150e --- /dev/null +++ b/extensions-builtin/forge_space_photo_maker_v2/forge_app.py @@ -0,0 +1,446 @@ +import spaces + +import torch +import torchvision.transforms.functional as TF +import numpy as np +import random +import os +import sys + +from diffusers.utils import load_image +from diffusers import EulerDiscreteScheduler, T2IAdapter + +from huggingface_hub import hf_hub_download + +import gradio as gr + +from pipeline_t2i_adapter import PhotoMakerStableDiffusionXLAdapterPipeline +from face_utils import FaceAnalysis2, analyze_faces + +from style_template import styles +from aspect_ratio_template import aspect_ratios + +# global variable +base_model_path = 'SG161222/RealVisXL_V4.0' +face_detector = FaceAnalysis2(providers=['CPUExecutionProvider'], allowed_modules=['detection', 'recognition']) +face_detector.prepare(ctx_id=0, det_size=(640, 640)) + +device = "cpu" + +MAX_SEED = np.iinfo(np.int32).max +STYLE_NAMES = list(styles.keys()) +DEFAULT_STYLE_NAME = "Photographic (Default)" +ASPECT_RATIO_LABELS = list(aspect_ratios) +DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0] + +enable_doodle_arg = False +photomaker_ckpt = hf_hub_download(repo_id="TencentARC/PhotoMaker-V2", filename="photomaker-v2.bin", repo_type="model") + +if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): + torch_dtype = torch.bfloat16 +else: + torch_dtype = torch.float16 + +if device == "mps": + torch_dtype = torch.float16 + + +with spaces.GPUObject() as gpu_object: + # load adapter + adapter = T2IAdapter.from_pretrained( + "TencentARC/t2i-adapter-sketch-sdxl-1.0", torch_dtype=torch_dtype, variant="fp16" + ).to(device) + + pipe = PhotoMakerStableDiffusionXLAdapterPipeline.from_pretrained( + base_model_path, + adapter=adapter, + torch_dtype=torch_dtype, + use_safetensors=True, + variant="fp16", + ).to(device) + + pipe.load_photomaker_adapter( + os.path.dirname(photomaker_ckpt), + subfolder="", + weight_name=os.path.basename(photomaker_ckpt), + trigger_word="img", + pm_version="v2", + ) + pipe.id_encoder.to(device) + + pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) + # pipe.set_adapters(["photomaker"], adapter_weights=[1.0]) + pipe.unet.to(spaces.gpu) + pipe.fuse_lora() + pipe.to(spaces.cpu) + + +spaces.automatically_move_pipeline_components(pipe) +spaces.change_attention_from_diffusers_to_forge(pipe.unet) +spaces.change_attention_from_diffusers_to_forge(pipe.vae) + + +@spaces.GPU(gpu_objects=[gpu_object], manual_load=True) +def generate_image( + upload_images, + prompt, + negative_prompt, + aspect_ratio_name, + style_name, + num_steps, + style_strength_ratio, + num_outputs, + guidance_scale, + seed, + use_doodle, + sketch_image, + adapter_conditioning_scale, + adapter_conditioning_factor, + progress=gr.Progress(track_tqdm=True) +): + if use_doodle: + sketch_image = sketch_image["composite"] + r, g, b, a = sketch_image.split() + sketch_image = a.convert("RGB") + sketch_image = TF.to_tensor(sketch_image) > 0.5 # Inversion + sketch_image = TF.to_pil_image(sketch_image.to(torch.float32)) + adapter_conditioning_scale = adapter_conditioning_scale + adapter_conditioning_factor = adapter_conditioning_factor + else: + adapter_conditioning_scale = 0. + adapter_conditioning_factor = 0. + sketch_image = None + + # check the trigger word + image_token_id = pipe.tokenizer.convert_tokens_to_ids(pipe.trigger_word) + input_ids = pipe.tokenizer.encode(prompt) + if image_token_id not in input_ids: + raise gr.Error(f"Cannot find the trigger word '{pipe.trigger_word}' in text prompt! Please refer to step 2️⃣") + + if input_ids.count(image_token_id) > 1: + raise gr.Error(f"Cannot use multiple trigger words '{pipe.trigger_word}' in text prompt!") + + # determine output dimensions by the aspect ratio + output_w, output_h = aspect_ratios[aspect_ratio_name] + print(f"[Debug] Generate image using aspect ratio [{aspect_ratio_name}] => {output_w} x {output_h}") + + # apply the style template + prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt) + + if upload_images is None: + raise gr.Error(f"Cannot find any input face image! Please refer to step 1️⃣") + + input_id_images = [] + for img in upload_images: + input_id_images.append(load_image(img)) + + id_embed_list = [] + + for img in input_id_images: + img = np.array(img) + img = img[:, :, ::-1] + faces = analyze_faces(face_detector, img) + if len(faces) > 0: + id_embed_list.append(torch.from_numpy((faces[0]['embedding']))) + + if len(id_embed_list) == 0: + raise gr.Error(f"No face detected, please update the input face image(s)") + + id_embeds = torch.stack(id_embed_list) + + generator = torch.Generator(device=device).manual_seed(seed) + + print("Start inference...") + print(f"[Debug] Seed: {seed}") + print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}") + start_merge_step = int(float(style_strength_ratio) / 100 * num_steps) + if start_merge_step > 30: + start_merge_step = 30 + print(start_merge_step) + images = pipe( + prompt=prompt, + width=output_w, + height=output_h, + input_id_images=input_id_images, + negative_prompt=negative_prompt, + num_images_per_prompt=num_outputs, + num_inference_steps=num_steps, + start_merge_step=start_merge_step, + generator=generator, + guidance_scale=guidance_scale, + id_embeds=id_embeds, + image=sketch_image, + adapter_conditioning_scale=adapter_conditioning_scale, + adapter_conditioning_factor=adapter_conditioning_factor, + ).images + return images, gr.update(visible=True) + + +def swap_to_gallery(images): + return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False) + + +def upload_example_to_gallery(images, prompt, style, negative_prompt): + return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False) + + +def remove_back_to_files(): + return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True) + + +def change_doodle_space(use_doodle): + if use_doodle: + return gr.update(visible=True) + else: + return gr.update(visible=False) + + +def remove_tips(): + return gr.update(visible=False) + + +def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: + if randomize_seed: + seed = random.randint(0, MAX_SEED) + return seed + + +def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]: + p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME]) + return p.replace("{prompt}", positive), n + ' ' + negative + + +def get_image_path_list(folder_name): + image_basename_list = os.listdir(folder_name) + image_path_list = sorted([os.path.join(folder_name, basename) for basename in image_basename_list]) + return image_path_list + + +def get_example(): + case = [ + [ + get_image_path_list(spaces.convert_root_path() + 'examples/scarletthead_woman'), + "instagram photo, portrait photo of a woman img, colorful, perfect face, natural skin, hard shadows, film grain", + "(No style)", + "(asymmetry, worst quality, low quality, illustration, 3d, 2d, painting, cartoons, sketch), open mouth", + ], + [ + get_image_path_list(spaces.convert_root_path() + 'examples/newton_man'), + "sci-fi, closeup portrait photo of a man img wearing the sunglasses in Iron man suit, face, slim body, high quality, film grain", + "(No style)", + "(asymmetry, worst quality, low quality, illustration, 3d, 2d, painting, cartoons, sketch), open mouth", + ], + ] + return case + + +### Description and style +logo = r""" +
