From 8ee804a1599ac60c8dd1bfb07c29fbcbd3e7c262 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Mon, 29 Jan 2024 23:17:57 -0800 Subject: [PATCH] lllllllite --- .../lib_controllllite/lib_controllllite.py | 271 ++++++++++++++++++ .../scripts/forge_controllllite.py | 38 +++ .../scripts/forge_ipadapter.py | 8 +- modules_forge/supported_controlnet.py | 2 +- 4 files changed, 314 insertions(+), 5 deletions(-) create mode 100644 extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py create mode 100644 extensions-builtin/sd_forge_controlllite/scripts/forge_controllllite.py diff --git a/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py b/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py new file mode 100644 index 00000000..73c45921 --- /dev/null +++ b/extensions-builtin/sd_forge_controlllite/lib_controllllite/lib_controllllite.py @@ -0,0 +1,271 @@ +import math +import torch +import os + +import ldm_patched.modules + + +def extra_options_to_module_prefix(extra_options): + # extra_options = {'transformer_index': 2, 'block_index': 8, 'original_shape': [2, 4, 128, 128], 'block': ('input', 7), 'n_heads': 20, 'dim_head': 64} + + # block is: [('input', 4), ('input', 5), ('input', 7), ('input', 8), ('middle', 0), + # ('output', 0), ('output', 1), ('output', 2), ('output', 3), ('output', 4), ('output', 5)] + # transformer_index is: [0, 1, 2, 3, 4, 5, 6, 7, 8], for each block + # block_index is: 0-1 or 0-9, depends on the block + # input 7 and 8, middle has 10 blocks + + # make module name from extra_options + block = extra_options["block"] + block_index = extra_options["block_index"] + if block[0] == "input": + module_pfx = f"lllite_unet_input_blocks_{block[1]}_1_transformer_blocks_{block_index}" + elif block[0] == "middle": + module_pfx = f"lllite_unet_middle_block_1_transformer_blocks_{block_index}" + elif block[0] == "output": + module_pfx = f"lllite_unet_output_blocks_{block[1]}_1_transformer_blocks_{block_index}" + else: + raise Exception("invalid block name") + return module_pfx + + +def load_control_net_lllite_patch(ctrl_sd, cond_image, multiplier, num_steps, start_percent, end_percent): + # calculate start and end step + start_step = math.floor(num_steps * start_percent) if start_percent > 0 else 0 + end_step = math.floor(num_steps * end_percent) if end_percent > 0 else num_steps + + # split each weights for each module + module_weights = {} + for key, value in ctrl_sd.items(): + fragments = key.split(".") + module_name = fragments[0] + weight_name = ".".join(fragments[1:]) + + if module_name not in module_weights: + module_weights[module_name] = {} + module_weights[module_name][weight_name] = value + + # load each module + modules = {} + for module_name, weights in module_weights.items(): + # ここの自動判定を何とかしたい + if "conditioning1.4.weight" in weights: + depth = 3 + elif weights["conditioning1.2.weight"].shape[-1] == 4: + depth = 2 + else: + depth = 1 + + module = LLLiteModule( + name=module_name, + is_conv2d=weights["down.0.weight"].ndim == 4, + in_dim=weights["down.0.weight"].shape[1], + depth=depth, + cond_emb_dim=weights["conditioning1.0.weight"].shape[0] * 2, + mlp_dim=weights["down.0.weight"].shape[0], + multiplier=multiplier, + num_steps=num_steps, + start_step=start_step, + end_step=end_step, + ) + info = module.load_state_dict(weights) + modules[module_name] = module + if len(modules) == 1: + module.is_first = True + + print(f"{len(modules)} modules") + + # cond imageをセットする + cond_image = cond_image.permute(0, 3, 1, 2) # b,h,w,3 -> b,3,h,w + cond_image = cond_image * 2.0 - 1.0 # 0-1 -> -1-+1 + + for module in modules.values(): + module.set_cond_image(cond_image) + + class control_net_lllite_patch: + def __init__(self, modules): + self.modules = modules + + def __call__(self, q, k, v, extra_options): + module_pfx = extra_options_to_module_prefix(extra_options) + + is_attn1 = q.shape[-1] == k.shape[-1] # self attention + if is_attn1: + module_pfx = module_pfx + "_attn1" + else: + module_pfx = module_pfx + "_attn2" + + module_pfx_to_q = module_pfx + "_to_q" + module_pfx_to_k = module_pfx + "_to_k" + module_pfx_to_v = module_pfx + "_to_v" + + if module_pfx_to_q in self.modules: + q = q + self.modules[module_pfx_to_q](q) + if module_pfx_to_k in self.modules: + k = k + self.modules[module_pfx_to_k](k) + if module_pfx_to_v in self.modules: + v = v + self.modules[module_pfx_to_v](v) + + return q, k, v + + def to(self, device): + for d in self.modules.keys(): + self.modules[d] = self.modules[d].to(device) + return self + + return control_net_lllite_patch(modules) + + +class LLLiteModule(torch.nn.Module): + def __init__( + self, + name: str, + is_conv2d: bool, + in_dim: int, + depth: int, + cond_emb_dim: int, + mlp_dim: int, + multiplier: int, + num_steps: int, + start_step: int, + end_step: int, + ): + super().__init__() + self.name = name + self.is_conv2d = is_conv2d + self.multiplier = multiplier + self.num_steps = num_steps + self.start_step = start_step + self.end_step = end_step + self.is_first = False + + modules = [] + modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size*2 + if depth == 1: + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0)) + elif depth == 2: + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0)) + elif depth == 3: + # kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4 + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0)) + + self.conditioning1 = torch.nn.Sequential(*modules) + + if self.is_conv2d: + self.down = torch.nn.Sequential( + torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0), + torch.nn.ReLU(inplace=True), + ) + self.mid = torch.nn.Sequential( + torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0), + torch.nn.ReLU(inplace=True), + ) + self.up = torch.nn.Sequential( + torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0), + ) + else: + self.down = torch.nn.Sequential( + torch.nn.Linear(in_dim, mlp_dim), + torch.nn.ReLU(inplace=True), + ) + self.mid = torch.nn.Sequential( + torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim), + torch.nn.ReLU(inplace=True), + ) + self.up = torch.nn.Sequential( + torch.nn.Linear(mlp_dim, in_dim), + ) + + self.depth = depth + self.cond_image = None + self.cond_emb = None + self.current_step = 0 + + # @torch.inference_mode() + def set_cond_image(self, cond_image): + # print("set_cond_image", self.name) + self.cond_image = cond_image + self.cond_emb = None + self.current_step = 0 + + def forward(self, x): + if self.num_steps > 0: + if self.current_step < self.start_step: + self.current_step += 1 + return torch.zeros_like(x) + elif self.current_step >= self.end_step: + if self.is_first and self.current_step == self.end_step: + print(f"end LLLite: step {self.current_step}") + self.current_step += 1 + if self.current_step >= self.num_steps: + self.current_step = 0 # reset + return torch.zeros_like(x) + else: + if self.is_first and self.current_step == self.start_step: + print(f"start LLLite: step {self.current_step}") + self.current_step += 1 + if self.current_step >= self.num_steps: + self.current_step = 0 # reset + + if self.cond_emb is None: + # print(f"cond_emb is None, {self.name}") + cx = self.conditioning1(self.cond_image.to(x.device, dtype=x.dtype)) + if not self.is_conv2d: + # reshape / b,c,h,w -> b,h*w,c + n, c, h, w = cx.shape + cx = cx.view(n, c, h * w).permute(0, 2, 1) + self.cond_emb = cx + + cx = self.cond_emb + # print(f"forward {self.name}, {cx.shape}, {x.shape}") + + # uncond/condでxはバッチサイズが2倍 + if x.shape[0] != cx.shape[0]: + if self.is_conv2d: + cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1, 1) + else: + # print("x.shape[0] != cx.shape[0]", x.shape[0], cx.shape[0]) + cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1) + + cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2) + cx = self.mid(cx) + cx = self.up(cx) + return cx * self.multiplier + + +class LLLiteLoader: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "model_name": None, + "cond_image": ("IMAGE",), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "steps": ("INT", {"default": 0, "min": 0, "max": 200, "step": 1}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}), + "end_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}), + } + } + + RETURN_TYPES = ("MODEL",) + FUNCTION = "load_lllite" + CATEGORY = "loaders" + + def load_lllite(self, model, state_dict, cond_image, strength, steps, start_percent, end_percent): + # cond_image is b,h,w,3, 0-1 + + model_lllite = model.clone() + patch = load_control_net_lllite_patch(state_dict, cond_image, strength, steps, start_percent, end_percent) + if patch is not None: + model_lllite.set_model_attn1_patch(patch) + model_lllite.set_model_attn2_patch(patch) + + return (model_lllite,) diff --git a/extensions-builtin/sd_forge_controlllite/scripts/forge_controllllite.py b/extensions-builtin/sd_forge_controlllite/scripts/forge_controllllite.py new file mode 100644 index 00000000..3f347264 --- /dev/null +++ b/extensions-builtin/sd_forge_controlllite/scripts/forge_controllllite.py @@ -0,0 +1,38 @@ +from modules_forge.shared import add_supported_control_model +from modules_forge.supported_controlnet import ControlModelPatcher +from lib_controllllite.lib_controllllite import LLLiteLoader + + +opLLLiteLoader = LLLiteLoader().load_lllite + + +class ControlLLLitePatcher(ControlModelPatcher): + @staticmethod + def try_build_from_state_dict(state_dict, ckpt_path): + if not any('lllite' in k for k in state_dict.keys()): + return None + return ControlLLLitePatcher(state_dict) + + def __init__(self, state_dict): + super().__init__() + self.state_dict = state_dict + return + + def process_before_every_sampling(self, process, cond, *args, **kwargs): + unet = process.sd_model.forge_objects.unet + + unet = opLLLiteLoader( + model=unet, + state_dict=self.state_dict, + cond_image=cond.movedim(1, -1), + strength=self.strength, + steps=process.steps, + start_percent=self.start_percent, + end_percent=self.end_percent + )[0] + + process.sd_model.forge_objects.unet = unet + return + + +add_supported_control_model(ControlLLLitePatcher) diff --git a/extensions-builtin/sd_forge_ipadapter/scripts/forge_ipadapter.py b/extensions-builtin/sd_forge_ipadapter/scripts/forge_ipadapter.py index c3398f41..30488a5e 100644 --- a/extensions-builtin/sd_forge_ipadapter/scripts/forge_ipadapter.py +++ b/extensions-builtin/sd_forge_ipadapter/scripts/forge_ipadapter.py @@ -25,9 +25,9 @@ class IPAdapterPatcher(ControlModelPatcher): return IPAdapterPatcher(model) - def __init__(self, model_patcher): - super().__init__(model_patcher) - self.ipadapter = model_patcher + def __init__(self, state_dict): + super().__init__() + self.ip_adapter = state_dict return def process_before_every_sampling(self, process, cond, *args, **kwargs): @@ -35,7 +35,7 @@ class IPAdapterPatcher(ControlModelPatcher): unet = process.sd_model.forge_objects.unet unet = opIPAdapterApply( - ipadapter=self.ipadapter, + ipadapter=self.ip_adapter, model=unet, weight=self.strength, clip_vision=clip_vision, diff --git a/modules_forge/supported_controlnet.py b/modules_forge/supported_controlnet.py index 1d210be4..67dc0488 100644 --- a/modules_forge/supported_controlnet.py +++ b/modules_forge/supported_controlnet.py @@ -13,7 +13,7 @@ class ControlModelPatcher: def try_build_from_state_dict(state_dict, ckpt_path): return None - def __init__(self, model_patcher): + def __init__(self, model_patcher=None): self.model_patcher = model_patcher self.strength = 1.0 self.start_percent = 0.0