From c2424087d69d8cf72658a2472ba775cebfea17d5 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Tue, 6 Aug 2024 11:53:27 -0600
Subject: [PATCH] 8 bit training working on flux

---
 extensions_built_in/sd_trainer/SDTrainer.py | 21 +++++-------
 jobs/process/BaseSDTrainProcess.py          |  9 ++++-
 toolkit/config_modules.py                   |  2 +-
 toolkit/lora_special.py                     |  6 ++--
 toolkit/network_mixins.py                   | 37 ++++++++++++++++-----
 toolkit/stable_diffusion_model.py           | 36 ++++++++++++++++----
 toolkit/train_tools.py                      |  2 ++
 7 files changed, 82 insertions(+), 31 deletions(-)

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index f198e126..dacb23c7 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -1538,22 +1538,19 @@ class SDTrainer(BaseSDTrainProcess):
                 # flush()

                 if not self.is_grad_accumulation_step:
-                    # torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm) # fix this for multi params
-                    if isinstance(self.params[0], dict):
-                        for i in range(len(self.params)):
-                            torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
-                    else:
-                        torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
+                    if self.train_config.optimizer != 'adafactor':
+                        self.scaler.unscale_(self.optimizer)
+                        if isinstance(self.params[0], dict):
+                            for i in range(len(self.params)):
+                                torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
+                        else:
+                            torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
                     # only step if we are not accumulating
                     with self.timer('optimizer_step'):
-                        if self.is_bfloat:
-                            self.optimizer.step()
-                        else:
-                            # apply gradients
-                            self.optimizer.step()
-                        # self.scaler.update() # self.optimizer.step()
+                        self.scaler.step(self.optimizer)
+                        self.scaler.update()
                     self.optimizer.zero_grad(set_to_none=True)
                     if self.ema is not None:
                         with self.timer('ema_update'):
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 853e8177..b6ca5a73 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1353,7 +1353,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 **network_kwargs
             )

-            self.network.force_to(self.device_torch, dtype=dtype)
+
+            # todo switch everything to proper mixed precision like this
+            self.network.force_to(self.device_torch, dtype=torch.float32)
             # give network to sd so it can use it
             self.sd.network = self.network
             self.network._update_torch_multiplier()
@@ -1365,6 +1367,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 self.train_config.train_unet
             )

+            # we cannot merge in if quantized
+            if self.model_config.quantize:
+                # todo find a way around this
+                self.network.can_merge_in = False
+
             if is_lorm:
                 self.network.is_lorm = True
             # make sure it is on the right device
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 673b8605..e08e8195 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -520,7 +520,7 @@ class DatasetConfig:
         self.random_crop: bool = kwargs.get('random_crop', False)
         self.resolution: int = kwargs.get('resolution', 512)
         self.scale: float = kwargs.get('scale', 1.0)
-        self.buckets: bool = kwargs.get('buckets', False)
+        self.buckets: bool = kwargs.get('buckets', True)
         self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64)
         self.is_reg: bool = kwargs.get('is_reg', False)
         self.network_weight: float = float(kwargs.get('network_weight', 1.0))
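The SDTrainer.py change above swaps the bare optimizer.step() for the standard torch.cuda.amp.GradScaler sequence: unscale the gradients, clip them, then let the scaler perform the step and update its scale factor. A minimal sketch of that sequence, using a placeholder model, optimizer, and max_grad_norm rather than the trainer's own objects (assumes a CUDA device):

import torch

model = torch.nn.Linear(8, 8).cuda()                 # placeholder network
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()
max_grad_norm = 1.0

for _ in range(10):
    x = torch.randn(4, 8, device="cuda")
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(x).pow(2).mean()                # dummy loss
    scaler.scale(loss).backward()                    # backward on the scaled loss
    scaler.unscale_(optimizer)                       # bring grads back to true magnitude
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # clip real values
    scaler.step(optimizer)                           # skipped automatically if grads overflowed
    scaler.update()                                  # adjust the loss scale for the next step
    optimizer.zero_grad(set_to_none=True)

scaler.step() refuses to apply a step whose gradients contain infs or NaNs, which is why the scaler owns the step rather than calling optimizer.step() directly.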
diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index 92473dee..df1e807b 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -28,12 +28,14 @@ RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers
 # diffusers specific stuff
 LINEAR_MODULES = [
     'Linear',
-    'LoRACompatibleLinear'
+    'LoRACompatibleLinear',
+    'QLinear',
     # 'GroupNorm',
 ]
 CONV_MODULES = [
     'Conv2d',
-    'LoRACompatibleConv'
+    'LoRACompatibleConv',
+    'QConv2d',
 ]

 class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
index 94781e0c..620c9852 100644
--- a/toolkit/network_mixins.py
+++ b/toolkit/network_mixins.py
@@ -4,6 +4,7 @@ from collections import OrderedDict
 from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any, Literal

 import torch
+from optimum.quanto import QTensor
 from torch import nn
 import weakref
@@ -258,7 +259,12 @@ class ToolkitModuleMixin:
         # return self.dora_forward(x, *args, **kwargs)
         org_forwarded = self.org_forward(x, *args, **kwargs)
-        lora_output = self._call_forward(x)
+
+        if isinstance(x, QTensor):
+            x = x.dequantize()
+        # always cast to float32
+        lora_input = x.float()
+        lora_output = self._call_forward(lora_input)

         multiplier = self.network_ref().torch_multiplier

         lora_output_batch_size = lora_output.size(0)
@@ -269,6 +275,7 @@ class ToolkitModuleMixin:
             multiplier = multiplier.repeat_interleave(num_interleaves)

         scaled_lora_output = broadcast_and_multiply(lora_output, multiplier)
+        scaled_lora_output = scaled_lora_output.to(org_forwarded.dtype)

         if self.__class__.__name__ == "DoRAModule":
             # ref https://github.com/huggingface/peft/blob/1e6d1d73a0850223b0916052fd8d2382a90eae5a/src/peft/tuners/lora/layer.py#L417
@@ -320,8 +327,18 @@ class ToolkitModuleMixin:

         # extract weight from org_module
         org_sd = self.org_module[0].state_dict()
-        orig_dtype = org_sd["weight"].dtype
-        weight = org_sd["weight"].float()
+        # todo find a way to merge in weights when doing quantized model
+        if 'weight._data' in org_sd:
+            # quantized weight
+            return
+
+        weight_key = "weight"
+        if 'weight._data' in org_sd:
+            # quantized weight
+            weight_key = "weight._data"
+
+        orig_dtype = org_sd[weight_key].dtype
+        weight = org_sd[weight_key].float()
         multiplier = merge_weight
         scale = self.scale
@@ -348,7 +365,7 @@ class ToolkitModuleMixin:
             weight = weight + multiplier * conved * scale

         # set weight to org_module
-        org_sd["weight"] = weight.to(orig_dtype)
+        org_sd[weight_key] = weight.to(orig_dtype)
         self.org_module[0].load_state_dict(org_sd)

     def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None):
@@ -523,12 +540,16 @@ class ToolkitNetworkMixin:
         keymap = self.get_keymap(force_weight_mapping)
         keymap = {} if keymap is None else keymap

-        if os.path.splitext(file)[1] == ".safetensors":
-            from safetensors.torch import load_file
+        if isinstance(file, str):
+            if os.path.splitext(file)[1] == ".safetensors":
+                from safetensors.torch import load_file

-            weights_sd = load_file(file)
+                weights_sd = load_file(file)
+            else:
+                weights_sd = torch.load(file, map_location="cpu")
         else:
-            weights_sd = torch.load(file, map_location="cpu")
+            # probably a state dict
+            weights_sd = file

         load_sd = OrderedDict()
         for key, value in weights_sd.items():
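With QLinear and QConv2d now accepted as LoRA target modules, the forward path in network_mixins has to cope with quantized activations: a quanto QTensor input is dequantized and cast to float32 before the low-rank branch runs, and the branch output is cast back to the base layer's dtype before being added. A rough sketch of that pattern, not the toolkit's actual module code (lora_forward, lora_down, and lora_up are illustrative names), assuming optimum-quanto is installed:

import torch
from optimum.quanto import QTensor

def lora_forward(org_forward, lora_down, lora_up, x, multiplier=1.0, scale=1.0):
    org_out = org_forward(x)                  # base (possibly quantized) layer output
    if isinstance(x, QTensor):
        x = x.dequantize()                    # quanto tensors can't feed the float LoRA branch directly
    lora_in = x.float()                       # run the low-rank branch in float32
    lora_out = lora_up(lora_down(lora_in)) * multiplier * scale
    return org_out + lora_out.to(org_out.dtype)  # rejoin in the base output's dtype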
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index db9f0010..4c0a12f4 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -14,6 +14,7 @@ from PIL import Image
 from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN
 from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
 from safetensors.torch import save_file, load_file
+from torch import autocast
 from torch.nn import Parameter
 from torch.utils.checkpoint import checkpoint
 from tqdm import tqdm
@@ -54,7 +55,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjecti
 from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
 from toolkit.util.inverse_cfg import inverse_classifier_guidance

-from optimum.quanto import freeze, qfloat8, quantize
+from optimum.quanto import freeze, qfloat8, quantize, QTensor

 # tell it to shut up
 diffusers.logging.set_verbosity(diffusers.logging.ERROR)
@@ -474,6 +475,23 @@ class StableDiffusion:
                 transformer.to(self.device_torch, dtype=dtype)
                 flush()

+            if self.model_config.lora_path is not None:
+                # need the pipe to do this unfortunately for now
+                # we have to fuse in the weights before quantizing
+                pipe: FluxPipeline = FluxPipeline(
+                    scheduler=scheduler,
+                    text_encoder=None,
+                    tokenizer=None,
+                    text_encoder_2=None,
+                    tokenizer_2=None,
+                    vae=vae,
+                    transformer=transformer,
+                )
+                pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
+                pipe.fuse_lora()
+                # unfortunately, not an easier way with peft
+                pipe.unload_lora_weights()
+
             if self.model_config.quantize:
                 print("Quantizing transformer")
                 quantize(transformer, weights=qfloat8)
@@ -498,7 +516,7 @@ class StableDiffusion:
                 text_encoder.to(self.device_torch, dtype=dtype)

             print("making pipe")
-            pipe = FluxPipeline(
+            pipe: FluxPipeline = FluxPipeline(
                 scheduler=scheduler,
                 text_encoder=text_encoder,
                 tokenizer=tokenizer,
@@ -613,7 +631,7 @@ class StableDiffusion:
             self.unet.eval()

         # load any loras we have
-        if self.model_config.lora_path is not None:
+        if self.model_config.lora_path is not None and not self.is_flux:
             pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
             pipe.fuse_lora()
             # unfortunately, not an easier way with peft
@@ -1631,14 +1649,15 @@ class StableDiffusion:
                 width=width_latent, # 128
             )

-
+            cast_dtype = self.unet.dtype
+            # with torch.amp.autocast(device_type='cuda', dtype=cast_dtype):
             noise_pred = self.unet(
-                hidden_states=latent_model_input_packed.to(self.device_torch, self.torch_dtype), # [1, 4096, 64]
+                hidden_states=latent_model_input_packed.to(self.device_torch, cast_dtype), # [1, 4096, 64]
                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
                 # todo make sure this doesnt change
                 timestep=timestep / 1000, # timestep is 1000 scale
-                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), # [1, 512, 4096]
-                pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype), # [1, 768]
+                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype), # [1, 512, 4096]
+                pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, cast_dtype), # [1, 768]
                 txt_ids=text_ids, # [1, 512, 3]
                 img_ids=latent_image_ids, # [1, 4096, 3]
                 guidance=guidance,
@@ -1646,6 +1665,9 @@
                 **kwargs,
             )[0]

+            if isinstance(noise_pred, QTensor):
+                noise_pred = noise_pred.dequantize()
+
             # unpack latents
             noise_pred = self.pipeline._unpack_latents(
                 noise_pred,
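The Flux loading path above fuses any pre-existing LoRA into the transformer before the weights are converted to qfloat8, since merging into already-quantized weights is exactly what the network code now refuses to do. A hedged sketch of that ordering with optimum-quanto and a diffusers pipeline (prepare_transformer and its arguments are illustrative, not the repo's exact variables):

from optimum.quanto import freeze, qfloat8, quantize

def prepare_transformer(transformer, pipe=None, lora_path=None):
    if lora_path is not None and pipe is not None:
        # fuse the LoRA while the transformer is still in a float dtype
        pipe.load_lora_weights(lora_path, adapter_name="lora1")
        pipe.fuse_lora()
        pipe.unload_lora_weights()
    quantize(transformer, weights=qfloat8)  # swap Linear weights for qfloat8 tensors
    freeze(transformer)                     # bake the quantized weights in place
    return transformer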
diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py
index 6b906fff..b9059b7b 100644
--- a/toolkit/train_tools.py
+++ b/toolkit/train_tools.py
@@ -52,6 +52,8 @@ def get_torch_dtype(dtype_str):
         return torch.float16
     if dtype_str == "bf16" or dtype_str == "bfloat16":
         return torch.bfloat16
+    if dtype_str == "8bit" or dtype_str == "e4m3fn" or dtype_str == "float8":
+        return torch.float8_e4m3fn
     return dtype_str
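get_torch_dtype() now maps the strings "8bit", "e4m3fn", and "float8" to torch.float8_e4m3fn, the FP8 format used for the quantized weights. A quick sanity check of the mapping (assumes a PyTorch build with float8 support, 2.1 or newer):

import torch
from toolkit.train_tools import get_torch_dtype

assert get_torch_dtype("bf16") is torch.bfloat16
assert get_torch_dtype("float8") is torch.float8_e4m3fn
assert get_torch_dtype("e4m3fn") is torch.float8_e4m3fn
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0, the largest representable e4m3fn value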