diff --git a/extensions-builtin/forge_space_iclight/forge_app.py b/extensions-builtin/forge_space_iclight/forge_app.py new file mode 100644 index 00000000..bdfeac6a --- /dev/null +++ b/extensions-builtin/forge_space_iclight/forge_app.py @@ -0,0 +1,452 @@ +import spaces +import math +import gradio as gr +import numpy as np +import torch +import safetensors.torch as sf +import db_examples + +from PIL import Image +from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline +from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler +from diffusers.models.attention_processor import AttnProcessor2_0 +from transformers import CLIPTextModel, CLIPTokenizer +from briarmbg import BriaRMBG +from enum import Enum +# from torch.hub import download_url_to_file + + +with spaces.GPUObject() as gpu_object: + # 'stablediffusionapi/realistic-vision-v51' + # 'runwayml/stable-diffusion-v1-5' + sd15_name = 'stablediffusionapi/realistic-vision-v51' + tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae") + unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet") + rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4") + +spaces.automatically_move_to_gpu_when_forward(rmbg) + +# Change UNet + +with torch.no_grad(): + new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding) + new_conv_in.weight.zero_() + new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) + new_conv_in.bias = unet.conv_in.bias + unet.conv_in = new_conv_in + +unet_original_forward = unet.forward + + +def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs): + c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample) + c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0) + new_sample = torch.cat([sample, c_concat], dim=1) + kwargs['cross_attention_kwargs'] = {} + return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs) + + +unet.forward = hooked_unet_forward + +# Load + +model_path = spaces.convert_root_path() + 'models/iclight_sd15_fc.safetensors' +# download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors', dst=model_path) +sd_offset = sf.load_file(model_path) +sd_origin = unet.state_dict() +keys = sd_origin.keys() +sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()} +unet.load_state_dict(sd_merged, strict=True) +del sd_offset, sd_origin, sd_merged, keys + +# Device + +device = spaces.cpu +text_encoder = text_encoder.to(device=device, dtype=torch.float16) +vae = vae.to(device=device, dtype=torch.bfloat16) +unet = unet.to(device=device, dtype=torch.float16) +rmbg = rmbg.to(device=device, dtype=torch.float32) +device = spaces.gpu + +# SDP + +unet.set_attn_processor(AttnProcessor2_0()) +vae.set_attn_processor(AttnProcessor2_0()) + +# Samplers + +ddim_scheduler = DDIMScheduler( + num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1, +) + +euler_a_scheduler = EulerAncestralDiscreteScheduler( + num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + steps_offset=1 +) + +dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler( + num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + algorithm_type="sde-dpmsolver++", + use_karras_sigmas=True, + steps_offset=1 +) + +# Pipelines + +t2i_pipe = StableDiffusionPipeline( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=dpmpp_2m_sde_karras_scheduler, + safety_checker=None, + requires_safety_checker=False, + feature_extractor=None, + image_encoder=None +) + +i2i_pipe = StableDiffusionImg2ImgPipeline( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=dpmpp_2m_sde_karras_scheduler, + safety_checker=None, + requires_safety_checker=False, + feature_extractor=None, + image_encoder=None +) + +spaces.automatically_move_pipeline_components(t2i_pipe) + + +@torch.inference_mode() +def encode_prompt_inner(txt: str): + max_length = tokenizer.model_max_length + chunk_length = tokenizer.model_max_length - 2 + id_start = tokenizer.bos_token_id + id_end = tokenizer.eos_token_id + id_pad = id_end + + def pad(x, p, i): + return x[:i] if len(x) >= i else x + [p] * (i - len(x)) + + tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"] + chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)] + chunks = [pad(ck, id_pad, max_length) for ck in chunks] + + token_ids = torch.tensor(chunks).to(device=device, dtype=torch.int64) + conds = text_encoder(token_ids).last_hidden_state + + return conds + + +@torch.inference_mode() +def encode_prompt_pair(positive_prompt, negative_prompt): + c = encode_prompt_inner(positive_prompt) + uc = encode_prompt_inner(negative_prompt) + + c_len = float(len(c)) + uc_len = float(len(uc)) + max_count = max(c_len, uc_len) + c_repeat = int(math.ceil(max_count / c_len)) + uc_repeat = int(math.ceil(max_count / uc_len)) + max_chunk = max(len(c), len(uc)) + + c = torch.cat([c] * c_repeat, dim=0)[:max_chunk] + uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk] + + c = torch.cat([p[None, ...] for p in c], dim=1) + uc = torch.cat([p[None, ...] for p in uc], dim=1) + + return c, uc + + +@torch.inference_mode() +def pytorch2numpy(imgs, quant=True): + results = [] + for x in imgs: + y = x.movedim(0, -1) + + if quant: + y = y * 127.5 + 127.5 + y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8) + else: + y = y * 0.5 + 0.5 + y = y.detach().float().cpu().numpy().clip(0, 1).astype(np.float32) + + results.append(y) + return results + + +@torch.inference_mode() +def numpy2pytorch(imgs): + h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 # so that 127 must be strictly 0.0 + h = h.movedim(-1, 1) + return h + + +def resize_and_center_crop(image, target_width, target_height): + pil_image = Image.fromarray(image) + original_width, original_height = pil_image.size + scale_factor = max(target_width / original_width, target_height / original_height) + resized_width = int(round(original_width * scale_factor)) + resized_height = int(round(original_height * scale_factor)) + resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS) + left = (resized_width - target_width) / 2 + top = (resized_height - target_height) / 2 + right = (resized_width + target_width) / 2 + bottom = (resized_height + target_height) / 2 + cropped_image = resized_image.crop((left, top, right, bottom)) + return np.array(cropped_image) + + +def resize_without_crop(image, target_width, target_height): + pil_image = Image.fromarray(image) + resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS) + return np.array(resized_image) + + +@torch.inference_mode() +def run_rmbg(img, sigma=0.0): + H, W, C = img.shape + assert C == 3 + k = (256.0 / float(H * W)) ** 0.5 + feed = resize_without_crop(img, int(64 * round(W * k)), int(64 * round(H * k))) + feed = numpy2pytorch([feed]).to(device=device, dtype=torch.float32) + alpha = rmbg(feed)[0][0] + alpha = torch.nn.functional.interpolate(alpha, size=(H, W), mode="bilinear") + alpha = alpha.movedim(1, -1)[0] + alpha = alpha.detach().float().cpu().numpy().clip(0, 1) + result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha + return result.clip(0, 255).astype(np.uint8), alpha + + +@torch.inference_mode() +def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source): + bg_source = BGSource(bg_source) + input_bg = None + + if bg_source == BGSource.NONE: + pass + elif bg_source == BGSource.LEFT: + gradient = np.linspace(255, 0, image_width) + image = np.tile(gradient, (image_height, 1)) + input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8) + elif bg_source == BGSource.RIGHT: + gradient = np.linspace(0, 255, image_width) + image = np.tile(gradient, (image_height, 1)) + input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8) + elif bg_source == BGSource.TOP: + gradient = np.linspace(255, 0, image_height)[:, None] + image = np.tile(gradient, (1, image_width)) + input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8) + elif bg_source == BGSource.BOTTOM: + gradient = np.linspace(0, 255, image_height)[:, None] + image = np.tile(gradient, (1, image_width)) + input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8) + else: + raise 'Wrong initial latent!' + + rng = torch.Generator(device=device).manual_seed(int(seed)) + + fg = resize_and_center_crop(input_fg, image_width, image_height) + + concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype) + concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor + + conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt) + + if input_bg is None: + latents = t2i_pipe( + prompt_embeds=conds, + negative_prompt_embeds=unconds, + width=image_width, + height=image_height, + num_inference_steps=steps, + num_images_per_prompt=num_samples, + generator=rng, + output_type='latent', + guidance_scale=cfg, + cross_attention_kwargs={'concat_conds': concat_conds}, + ).images.to(vae.dtype) / vae.config.scaling_factor + else: + bg = resize_and_center_crop(input_bg, image_width, image_height) + bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype) + bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor + latents = i2i_pipe( + image=bg_latent, + strength=lowres_denoise, + prompt_embeds=conds, + negative_prompt_embeds=unconds, + width=image_width, + height=image_height, + num_inference_steps=int(round(steps / lowres_denoise)), + num_images_per_prompt=num_samples, + generator=rng, + output_type='latent', + guidance_scale=cfg, + cross_attention_kwargs={'concat_conds': concat_conds}, + ).images.to(vae.dtype) / vae.config.scaling_factor + + pixels = vae.decode(latents).sample + pixels = pytorch2numpy(pixels) + pixels = [resize_without_crop( + image=p, + target_width=int(round(image_width * highres_scale / 64.0) * 64), + target_height=int(round(image_height * highres_scale / 64.0) * 64)) + for p in pixels] + + pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype) + latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor + latents = latents.to(device=unet.device, dtype=unet.dtype) + + image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8 + + fg = resize_and_center_crop(input_fg, image_width, image_height) + concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype) + concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor + + latents = i2i_pipe( + image=latents, + strength=highres_denoise, + prompt_embeds=conds, + negative_prompt_embeds=unconds, + width=image_width, + height=image_height, + num_inference_steps=int(round(steps / highres_denoise)), + num_images_per_prompt=num_samples, + generator=rng, + output_type='latent', + guidance_scale=cfg, + cross_attention_kwargs={'concat_conds': concat_conds}, + ).images.to(vae.dtype) / vae.config.scaling_factor + + pixels = vae.decode(latents).sample + + return pytorch2numpy(pixels) + + +@spaces.GPU(gpu_objects=[gpu_object], manual_load=True) +@torch.inference_mode() +def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source): + input_fg, matting = run_rmbg(input_fg) + results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source) + return input_fg, results + + +quick_prompts = [ + 'sunshine from window', + 'neon light, city', + 'sunset over sea', + 'golden time', + 'sci-fi RGB glowing, cyberpunk', + 'natural lighting', + 'warm atmosphere, at home, bedroom', + 'magic lit', + 'evil, gothic, Yharnam', + 'light and shadow', + 'shadow from window', + 'soft studio lighting', + 'home atmosphere, cozy bedroom illumination', + 'neon, Wong Kar-wai, warm' +] +quick_prompts = [[x] for x in quick_prompts] + + +quick_subjects = [ + 'beautiful woman, detailed face', + 'handsome man, detailed face', +] +quick_subjects = [[x] for x in quick_subjects] + + +class BGSource(Enum): + NONE = "None" + LEFT = "Left Light" + RIGHT = "Right Light" + TOP = "Top Light" + BOTTOM = "Bottom Light" + + +block = gr.Blocks().queue() +with block: + with gr.Row(): + gr.Markdown("## IC-Light (Relighting with Foreground Condition)") + with gr.Row(): + gr.Markdown("See also https://github.com/lllyasviel/IC-Light for background-conditioned model and normal estimation") + with gr.Row(): + with gr.Column(): + with gr.Row(): + input_fg = gr.Image(sources='upload', type="numpy", label="Image", height=480) + output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480) + prompt = gr.Textbox(label="Prompt") + bg_source = gr.Radio(choices=[e.value for e in BGSource], + value=BGSource.NONE.value, + label="Lighting Preference (Initial Latent)", type='value') + example_quick_subjects = gr.Dataset(samples=quick_subjects, label='Subject Quick List', samples_per_page=1000, components=[prompt]) + example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Lighting Quick List', samples_per_page=1000, components=[prompt]) + relight_button = gr.Button(value="Relight") + + with gr.Group(): + with gr.Row(): + num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1) + seed = gr.Number(label="Seed", value=12345, precision=0) + + with gr.Row(): + image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64) + image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64) + + with gr.Accordion("Advanced options", open=False): + steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1) + cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=2, step=0.01) + lowres_denoise = gr.Slider(label="Lowres Denoise (for initial latent)", minimum=0.1, maximum=1.0, value=0.9, step=0.01) + highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01) + highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=1.0, value=0.5, step=0.01) + a_prompt = gr.Textbox(label="Added Prompt", value='best quality') + n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality') + with gr.Column(): + result_gallery = gr.Gallery(height=832, object_fit='contain', label='Outputs') + with gr.Row(): + dummy_image_for_outputs = gr.Image(visible=False, label='Result') + + examples = [] + + for ex in db_examples.foreground_conditioned_examples: + ex[0] = spaces.convert_root_path() + ex[0] + ex[-1] = spaces.convert_root_path() + ex[-1] + examples.append(ex) + + + gr.Examples( + fn=lambda *args: [[args[-1]], "imgs/dummy.png"], + examples=examples, + inputs=[ + input_fg, prompt, bg_source, image_width, image_height, seed, dummy_image_for_outputs + ], + outputs=[result_gallery, output_bg], + run_on_click=True, examples_per_page=1024 + ) + ips = [input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source] + relight_button.click(fn=process_relight, inputs=ips, outputs=[output_bg, result_gallery]) + example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False) + example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False) + + +demo = block + + +if __name__ == "__main__": + demo.launch() diff --git a/extensions-builtin/forge_space_iclight/space_meta.json b/extensions-builtin/forge_space_iclight/space_meta.json new file mode 100644 index 00000000..6bdf9dbe --- /dev/null +++ b/extensions-builtin/forge_space_iclight/space_meta.json @@ -0,0 +1,6 @@ +{ + "tag": "Image Processing: Illumination, Shading, and Relighting", + "title": "IC-Light: Imposing Consistent Light (Foreground Model)", + "repo_id": "lllyasviel/IC-Light", + "revision": "3072485a2b69261ab332fdf2e255d0eca35b323b" +}