From f965a1299f9a95ab6ed4bb72caa39b16583af9ea Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 24 Feb 2024 10:26:01 -0700 Subject: [PATCH] Fixed Dora implementation. Still highly experimental --- .gitignore | 3 +- jobs/process/GenerateProcess.py | 21 +++++++++-- toolkit/dataloader_mixins.py | 19 ++++++++-- toolkit/models/DoRA.py | 64 ++++++++++++++++++++++++++------- toolkit/network_mixins.py | 54 +++++++++++++++++++--------- 5 files changed, 128 insertions(+), 33 deletions(-) diff --git a/.gitignore b/.gitignore index 5a3ba0ed..edb8d50d 100644 --- a/.gitignore +++ b/.gitignore @@ -172,4 +172,5 @@ cython_debug/ /output/* !/output/.gitkeep /extensions/* -!/extensions/example \ No newline at end of file +!/extensions/example +/temp \ No newline at end of file diff --git a/jobs/process/GenerateProcess.py b/jobs/process/GenerateProcess.py index 13acca2d..f5fa22b0 100644 --- a/jobs/process/GenerateProcess.py +++ b/jobs/process/GenerateProcess.py @@ -32,16 +32,30 @@ class GenerateConfig: self.guidance_rescale = kwargs.get('guidance_rescale', 0.0) self.ext = kwargs.get('ext', 'png') self.prompt_file = kwargs.get('prompt_file', False) + self.prompts_in_file = self.prompts if self.prompts is None: raise ValueError("Prompts must be set") if isinstance(self.prompts, str): if os.path.exists(self.prompts): with open(self.prompts, 'r', encoding='utf-8') as f: - self.prompts = f.read().splitlines() - self.prompts = [p.strip() for p in self.prompts if len(p.strip()) > 0] + self.prompts_in_file = f.read().splitlines() + self.prompts_in_file = [p.strip() for p in self.prompts_in_file if len(p.strip()) > 0] else: raise ValueError("Prompts file does not exist, put in list if you want to use a list of prompts") + self.random_prompts = kwargs.get('random_prompts', False) + self.max_random_per_prompt = kwargs.get('max_random_per_prompt', 1) + self.max_images = kwargs.get('max_prompts', 10000) + + if self.random_prompts: + self.prompts = [] + for i in range(self.max_images): + num_prompts = random.randint(1, self.max_random_per_prompt) + prompt_list = [random.choice(self.prompts_in_file) for _ in range(num_prompts)] + self.prompts.append(", ".join(prompt_list)) + else: + self.prompts = self.prompts_in_file + if kwargs.get('shuffle', False): # shuffle the prompts random.shuffle(self.prompts) @@ -78,6 +92,9 @@ class GenerateProcess(BaseProcess): print("Loading model...") self.sd.load_model() + print("Compiling model...") + self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True) + print(f"Generating {len(self.generate_config.prompts)} images") # build prompt image configs prompt_image_configs = [] diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 29bf5147..3aef4d14 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -629,8 +629,23 @@ class ClipImageFileItemDTOMixin: # Convert RGB to BGR open_cv_image = open_cv_image[:, :, ::-1].copy() - # apply augmentations - augmented = self.clip_image_aug_transform(image=open_cv_image)["image"] + if self.clip_vision_is_quad: + # image is in a 2x2 gris. 
split, run augs, and recombine + # split + img1, img2 = np.hsplit(open_cv_image, 2) + img1_1, img1_2 = np.vsplit(img1, 2) + img2_1, img2_2 = np.vsplit(img2, 2) + # apply augmentations + img1_1 = self.clip_image_aug_transform(image=img1_1)["image"] + img1_2 = self.clip_image_aug_transform(image=img1_2)["image"] + img2_1 = self.clip_image_aug_transform(image=img2_1)["image"] + img2_2 = self.clip_image_aug_transform(image=img2_2)["image"] + # recombine + augmented = np.vstack((np.hstack((img1_1, img1_2)), np.hstack((img2_1, img2_2)))) + + else: + # apply augmentations + augmented = self.clip_image_aug_transform(image=open_cv_image)["image"] # convert back to RGB tensor augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB) diff --git a/toolkit/models/DoRA.py b/toolkit/models/DoRA.py index 0f010352..fb8f4838 100644 --- a/toolkit/models/DoRA.py +++ b/toolkit/models/DoRA.py @@ -22,6 +22,13 @@ CONV_MODULES = [ 'LoRACompatibleConv' ] +def transpose(weight, fan_in_fan_out): + if not fan_in_fan_out: + return weight + + if isinstance(weight, torch.nn.Parameter): + return torch.nn.Parameter(weight.T) + return weight.T class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module): # def __init__(self, d_in, d_out, rank=4, weight=None, bias=None): @@ -65,15 +72,26 @@ class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module): self.module_dropout = module_dropout self.is_checkpointing = False - # m = Magnitude column-wise across output dimension - self.magnitude = nn.Parameter(self.get_orig_weight().norm(p=2, dim=0, keepdim=True)) - d_out = org_module.out_features d_in = org_module.in_features std_dev = 1 / torch.sqrt(torch.tensor(self.lora_dim).float()) - self.lora_up = nn.Parameter(torch.randn(d_out, self.lora_dim) * std_dev) - self.lora_down = nn.Parameter(torch.zeros(self.lora_dim, d_in)) + # self.lora_up = nn.Parameter(torch.randn(d_out, self.lora_dim) * std_dev) # lora_A + # self.lora_down = nn.Parameter(torch.zeros(self.lora_dim, d_in)) # lora_B + self.lora_up = nn.Linear(self.lora_dim, d_out, bias=False) # lora_B + # self.lora_up.weight.data = torch.randn_like(self.lora_up.weight.data) * std_dev + self.lora_up.weight.data = torch.zeros_like(self.lora_up.weight.data) + # self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False) + # self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=False) + self.lora_down = nn.Linear(d_in, self.lora_dim, bias=False) # lora_A + # self.lora_down.weight.data = torch.zeros_like(self.lora_down.weight.data) + self.lora_down.weight.data = torch.randn_like(self.lora_down.weight.data) * std_dev + + # m = Magnitude column-wise across output dimension + weight = self.get_orig_weight() + lora_weight = self.lora_up.weight @ self.lora_down.weight + weight_norm = self._get_weight_norm(weight, lora_weight) + self.magnitude = nn.Parameter(weight_norm.detach().clone(), requires_grad=True) def apply_to(self): self.org_forward = self.org_module[0].forward @@ -88,11 +106,33 @@ class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module): return self.org_module[0].bias.data.detach() return None - def dora_forward(self, x, *args, **kwargs): - lora = torch.matmul(self.lora_up, self.lora_down) - adapted = self.get_orig_weight() + lora - column_norm = adapted.norm(p=2, dim=0, keepdim=True) - norm_adapted = adapted / column_norm - calc_weights = self.magnitude * norm_adapted - return F.linear(x, calc_weights, self.get_orig_bias()) + # def dora_forward(self, x, *args, **kwargs): + # lora = 
torch.matmul(self.lora_A, self.lora_B) + # adapted = self.get_orig_weight() + lora + # column_norm = adapted.norm(p=2, dim=0, keepdim=True) + # norm_adapted = adapted / column_norm + # calc_weights = self.magnitude * norm_adapted + # return F.linear(x, calc_weights, self.get_orig_bias()) + def _get_weight_norm(self, weight, scaled_lora_weight) -> torch.Tensor: + # calculate L2 norm of weight matrix, column-wise + weight = weight + scaled_lora_weight.to(weight.device) + weight_norm = torch.linalg.norm(weight, dim=1) + return weight_norm + + def apply_dora(self, x, scaled_lora_weight): + # ref https://github.com/huggingface/peft/blob/1e6d1d73a0850223b0916052fd8d2382a90eae5a/src/peft/tuners/lora/layer.py#L192 + # lora weight is already scaled + + # magnitude = self.lora_magnitude_vector[active_adapter] + weight = self.get_orig_weight() + weight_norm = self._get_weight_norm(weight, scaled_lora_weight) + # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353) + # "[...] we suggest treating ||V +∆V ||_c in + # Eq. (5) as a constant, thereby detaching it from the gradient + # graph. This means that while ||V + ∆V ||_c dynamically + # reflects the updates of ∆V , it won’t receive any gradient + # during backpropagation" + weight_norm = weight_norm.detach() + dora_weight = transpose(weight + scaled_lora_weight, False) + return (self.magnitude / weight_norm - 1).view(1, -1) * F.linear(x, dora_weight) diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index cb3fc747..1fc86fb4 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -52,8 +52,14 @@ def broadcast_and_multiply(tensor, multiplier): for _ in range(num_extra_dims): multiplier = multiplier.unsqueeze(-1) - # Multiplying the broadcasted tensor with the output tensor - result = tensor * multiplier + try: + # Multiplying the broadcasted tensor with the output tensor + result = tensor * multiplier + except RuntimeError as e: + print(e) + print(tensor.size()) + print(multiplier.size()) + raise e return result @@ -248,9 +254,9 @@ class ToolkitModuleMixin: # network is not active, avoid doing anything return self.org_forward(x, *args, **kwargs) - if self.__class__.__name__ == "DoRAModule": - # return dora forward - return self.dora_forward(x, *args, **kwargs) + # if self.__class__.__name__ == "DoRAModule": + # # return dora forward + # return self.dora_forward(x, *args, **kwargs) org_forwarded = self.org_forward(x, *args, **kwargs) lora_output = self._call_forward(x) @@ -263,7 +269,27 @@ class ToolkitModuleMixin: # todo check if this is correct, do we just concat when doing cfg? multiplier = multiplier.repeat_interleave(num_interleaves) - x = org_forwarded + broadcast_and_multiply(lora_output, multiplier) + scaled_lora_output = broadcast_and_multiply(lora_output, multiplier) + + if self.__class__.__name__ == "DoRAModule": + # ref https://github.com/huggingface/peft/blob/1e6d1d73a0850223b0916052fd8d2382a90eae5a/src/peft/tuners/lora/layer.py#L417 + # x = dropout(x) + # todo this wont match the dropout applied to the lora + if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity): + lx = self.dropout(x) + # normal dropout + elif self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(x, p=self.dropout) + else: + lx = x + lora_weight = self.lora_up.weight @ self.lora_down.weight + # scale it here + # todo handle our batch split scalers for slider training. 
For now take the mean of them + scale = multiplier.mean() + scaled_lora_weight = lora_weight * scale + scaled_lora_output = scaled_lora_output + self.apply_dora(lx, scaled_lora_weight) + + x = org_forwarded + scaled_lora_output return x def enable_gradient_checkpointing(self: Module): @@ -413,12 +439,12 @@ class ToolkitNetworkMixin: new_keymap = {} for ldm_key, diffusers_key in keymap.items(): ldm_key = ldm_key.replace('.alpha', '.magnitude') - ldm_key = ldm_key.replace('.lora_down.weight', '.lora_down') - ldm_key = ldm_key.replace('.lora_up.weight', '.lora_up') + # ldm_key = ldm_key.replace('.lora_down.weight', '.lora_down') + # ldm_key = ldm_key.replace('.lora_up.weight', '.lora_up') diffusers_key = diffusers_key.replace('.alpha', '.magnitude') - diffusers_key = diffusers_key.replace('.lora_down.weight', '.lora_down') - diffusers_key = diffusers_key.replace('.lora_up.weight', '.lora_up') + # diffusers_key = diffusers_key.replace('.lora_down.weight', '.lora_down') + # diffusers_key = diffusers_key.replace('.lora_up.weight', '.lora_up') new_keymap[ldm_key] = diffusers_key @@ -513,12 +539,8 @@ class ToolkitNetworkMixin: multiplier = self._multiplier # get first module first_module = self.get_all_modules()[0] - if self.network_type.lower() == 'dora': - device = first_module.lora_down.device - dtype = first_module.lora_down.dtype - else: - device = first_module.lora_down.weight.device - dtype = first_module.lora_down.weight.dtype + device = first_module.lora_down.weight.device + dtype = first_module.lora_down.weight.dtype with torch.no_grad(): tensor_multiplier = None if isinstance(multiplier, int) or isinstance(multiplier, float):
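To illustrate the new random-prompt behaviour added to GenerateConfig in jobs/process/GenerateProcess.py: when random_prompts is set, each generated prompt is built by joining up to max_random_per_prompt lines drawn from the prompt file, and the number of prompts is capped by the 'max_prompts' kwarg (stored as self.max_images). The sketch below is standalone and uses made-up prompt strings; it is not code from the patch.

import random

# stand-in for the lines read from the prompt file by GenerateConfig
prompts_in_file = ["a red fox", "studio lighting", "oil painting", "snowy forest"]

random_prompts = True        # kwargs['random_prompts']
max_random_per_prompt = 3    # kwargs['max_random_per_prompt']
max_images = 5               # filled from the 'max_prompts' kwarg in the diff

if random_prompts:
    prompts = []
    for _ in range(max_images):
        num = random.randint(1, max_random_per_prompt)
        prompts.append(", ".join(random.choice(prompts_in_file) for _ in range(num)))
else:
    prompts = prompts_in_file

print(prompts)  # e.g. ['oil painting, a red fox', 'snowy forest', ...]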
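The clip_vision_is_quad branch in toolkit/dataloader_mixins.py splits a 2x2 tiled image, augments each tile independently, and stitches the tiles back together. One thing worth double-checking in the diff: np.hstack((img1_1, img1_2)) places the top-left tile next to the bottom-left tile, so the recombined image has its off-diagonal tiles swapped relative to the input layout, which only matters if tile position is significant downstream. Below is a standalone sketch that keeps the original layout; augment_quad and the no-op aug are illustrative names, not part of the toolkit, and the aug callable just follows the albumentations-style call used in the diff.

import numpy as np

def augment_quad(image: np.ndarray, aug) -> np.ndarray:
    # split the 2x2 tiled image into its four tiles
    left, right = np.hsplit(image, 2)
    top_left, bottom_left = np.vsplit(left, 2)
    top_right, bottom_right = np.vsplit(right, 2)
    # augment each tile independently
    top_left = aug(image=top_left)["image"]
    bottom_left = aug(image=bottom_left)["image"]
    top_right = aug(image=top_right)["image"]
    bottom_right = aug(image=bottom_right)["image"]
    # recombine, preserving the original tile positions
    return np.vstack((
        np.hstack((top_left, top_right)),
        np.hstack((bottom_left, bottom_right)),
    ))

# usage with a no-op augmentation: the image should come back unchanged
identity_aug = lambda image: {"image": image}
quad = np.arange(4 * 4 * 3, dtype=np.uint8).reshape(4, 4, 3)
assert np.array_equal(augment_quad(quad, identity_aug), quad)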
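On the DoRA side, the new path in toolkit/network_mixins.py adds the apply_dora correction term on top of the base output and the scaled LoRA output. The sketch below (standalone, illustrative sizes, bias and dropout omitted, and assuming the LoRA branch and apply_dora see the same effective scale) checks that this sum equals running the layer with the merged DoRA weight m * (W + dW) / ||W + dW||, with the row-wise norm detached as suggested in section 4.3 of the DoRA paper; it is a verification sketch, not code from the patch.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_in, d_out, rank, scale = 8, 16, 4, 1.0          # illustrative sizes, not from the patch

W = torch.randn(d_out, d_in)                      # frozen base weight, shape (out, in)
lora_down = torch.randn(rank, d_in) / rank**0.5   # lora_A, random init as in the patch
lora_up = torch.zeros(d_out, rank)                # lora_B, zero init as in the patch
magnitude = torch.linalg.norm(W + lora_up @ lora_down, dim=1)  # == ||W|| row-wise at init

lora_up = torch.randn(d_out, rank) * 0.02         # pretend a few optimizer steps happened

x = torch.randn(2, d_in)
delta_w = (lora_up @ lora_down) * scale                        # scaled_lora_weight
weight_norm = torch.linalg.norm(W + delta_w, dim=1).detach()   # _get_weight_norm, detached

# what the patched forward computes: base output + scaled LoRA output + apply_dora term
out_patch = (
    F.linear(x, W)
    + F.linear(x, delta_w)
    + (magnitude / weight_norm - 1).view(1, -1) * F.linear(x, W + delta_w)
)

# reference: the merged DoRA weight W' = m * (W + dW) / ||W + dW||  (norm per output row)
W_dora = (magnitude / weight_norm).view(-1, 1) * (W + delta_w)
out_ref = F.linear(x, W_dora)

assert torch.allclose(out_patch, out_ref, atol=1e-5)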