8 bit training working on flux

Jaret Burkett
2024-08-06 11:53:27 -06:00
parent 272c8608c2
commit c2424087d6
7 changed files with 82 additions and 31 deletions

View File

@@ -1538,22 +1538,19 @@ class SDTrainer(BaseSDTrainProcess):
         # flush()
         if not self.is_grad_accumulation_step:
             # torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
             # fix this for multi params
-            if self.train_config.optimizer != 'adafactor':
-                self.scaler.unscale_(self.optimizer)
-                if isinstance(self.params[0], dict):
-                    for i in range(len(self.params)):
-                        torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
-                else:
-                    torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
+            if isinstance(self.params[0], dict):
+                for i in range(len(self.params)):
+                    torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
+            else:
+                torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
             # only step if we are not accumulating
             with self.timer('optimizer_step'):
-                if self.is_bfloat:
-                    self.optimizer.step()
-                else:
-                    # apply gradients
-                    self.scaler.step(self.optimizer)
-                    self.scaler.update()
+                # apply gradients
+                self.optimizer.step()
+                # self.scaler.update()
+                # self.optimizer.step()
             self.optimizer.zero_grad(set_to_none=True)
         if self.ema is not None:
             with self.timer('ema_update'):
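
The hunk above drops the GradScaler path: fp16 training needs loss scaling (and an unscale_ before clipping), but bf16 and 8-bit quantized training do not, so gradients can be clipped and the optimizer stepped directly. A minimal sketch of the resulting pattern; clip_and_step is a hypothetical helper, not a name from the repo:

import torch

def clip_and_step(params, optimizer, max_grad_norm=1.0):
    # params is either a flat list of tensors or a list of
    # param-group dicts ({'params': [...]}), as handled in the hunk above
    if isinstance(params[0], dict):
        for group in params:
            torch.nn.utils.clip_grad_norm_(group['params'], max_grad_norm)
    else:
        torch.nn.utils.clip_grad_norm_(params, max_grad_norm)
    # no scaler.unscale_/scaler.step: without loss scaling the
    # gradients are already in real units
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

model = torch.nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
model(torch.randn(4, 8)).mean().backward()
clip_and_step(list(model.parameters()), opt)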

View File

@@ -1353,7 +1353,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
             **network_kwargs
         )
-        self.network.force_to(self.device_torch, dtype=dtype)
+        # todo switch everything to proper mixed precision like this
+        self.network.force_to(self.device_torch, dtype=torch.float32)
         # give network to sd so it can use it
         self.sd.network = self.network
         self.network._update_torch_multiplier()
@@ -1365,6 +1367,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.train_config.train_unet
         )
+        # we cannot merge in if quantized
+        if self.model_config.quantize:
+            # todo find a way around this
+            self.network.can_merge_in = False
+
        if is_lorm:
            self.network.is_lorm = True
        # make sure it is on the right device
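
Keeping the LoRA network in float32 while the base model runs in bf16 (or quantized fp8) is a standard mixed-precision pattern: the tiny adapter costs little memory at full precision, and its gradients stay numerically stable. A toy sketch of the idea; LoRALinear is illustrative, not the repo's network class:

import torch

class LoRALinear(torch.nn.Module):
    """Toy LoRA wrapper kept in float32 on top of a lower-precision base layer."""
    def __init__(self, base: torch.nn.Linear, rank: int = 4):
        super().__init__()
        self.base = base
        # trainable adapter weights stay float32 even if the base is bf16,
        # mirroring network.force_to(..., dtype=torch.float32) above
        self.down = torch.nn.Linear(base.in_features, rank, bias=False).float()
        self.up = torch.nn.Linear(rank, base.out_features, bias=False).float()

    def forward(self, x):
        out = self.base(x)
        # run the adapter branch in float32, cast back to the base dtype
        lora = self.up(self.down(x.float()))
        return out + lora.to(out.dtype)

base = torch.nn.Linear(16, 16).to(torch.bfloat16)
layer = LoRALinear(base)
y = layer(torch.randn(2, 16, dtype=torch.bfloat16))
print(y.dtype)  # torch.bfloat16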

View File

@@ -520,7 +520,7 @@ class DatasetConfig:
         self.random_crop: bool = kwargs.get('random_crop', False)
         self.resolution: int = kwargs.get('resolution', 512)
         self.scale: float = kwargs.get('scale', 1.0)
-        self.buckets: bool = kwargs.get('buckets', False)
+        self.buckets: bool = kwargs.get('buckets', True)
         self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64)
         self.is_reg: bool = kwargs.get('is_reg', False)
         self.network_weight: float = float(kwargs.get('network_weight', 1.0))
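
Flipping buckets to default True makes aspect-ratio bucketing the norm: each image is scaled so its area is close to resolution squared, and its sides are snapped to multiples of bucket_tolerance so differently shaped images can batch together without heavy cropping. A rough sketch of that snapping logic, describing the general technique rather than this repo's exact math:

def bucket_size(width: int, height: int, resolution: int = 512, tolerance: int = 64):
    # scale so the area roughly matches resolution**2, keeping aspect ratio
    scale = (resolution * resolution / (width * height)) ** 0.5
    # snap each side down to a multiple of the bucket tolerance
    w = max(tolerance, int(width * scale) // tolerance * tolerance)
    h = max(tolerance, int(height * scale) // tolerance * tolerance)
    return w, h

print(bucket_size(1920, 1080))  # (640, 384) for a 512 bucket with 64px steps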

View File

@@ -28,12 +28,14 @@ RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers
 # diffusers specific stuff
 LINEAR_MODULES = [
     'Linear',
-    'LoRACompatibleLinear'
+    'LoRACompatibleLinear',
+    'QLinear',
     # 'GroupNorm',
 ]
 CONV_MODULES = [
     'Conv2d',
-    'LoRACompatibleConv'
+    'LoRACompatibleConv',
+    'QConv2d',
 ]

 class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
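
Adding 'QLinear' and 'QConv2d' lets LoRA target the modules that optimum.quanto swaps in for plain Linear/Conv2d layers; since matching is done by class name, the quantized wrappers have to be listed explicitly. A small sketch of that kind of matching; find_target_modules is illustrative:

import torch

LINEAR_MODULES = ['Linear', 'LoRACompatibleLinear', 'QLinear']
CONV_MODULES = ['Conv2d', 'LoRACompatibleConv', 'QConv2d']

def find_target_modules(model: torch.nn.Module):
    # match by class name so quantized wrappers (QLinear/QConv2d) are
    # picked up the same way as their float originals
    targets = []
    for name, module in model.named_modules():
        if module.__class__.__name__ in LINEAR_MODULES + CONV_MODULES:
            targets.append((name, module))
    return targets

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Conv2d(3, 3, 1))
for name, module in find_target_modules(model):
    print(name, module.__class__.__name__)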

View File

@@ -4,6 +4,7 @@ from collections import OrderedDict
 from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any, Literal

 import torch
+from optimum.quanto import QTensor
 from torch import nn
 import weakref
@@ -258,7 +259,12 @@ class ToolkitModuleMixin:
         # return self.dora_forward(x, *args, **kwargs)
         org_forwarded = self.org_forward(x, *args, **kwargs)
-        lora_output = self._call_forward(x)
+        if isinstance(x, QTensor):
+            x = x.dequantize()
+
+        # always cast to float32
+        lora_input = x.float()
+        lora_output = self._call_forward(lora_input)
         multiplier = self.network_ref().torch_multiplier
         lora_output_batch_size = lora_output.size(0)
@@ -269,6 +275,7 @@ class ToolkitModuleMixin:
             multiplier = multiplier.repeat_interleave(num_interleaves)
         scaled_lora_output = broadcast_and_multiply(lora_output, multiplier)
+        scaled_lora_output = scaled_lora_output.to(org_forwarded.dtype)

         if self.__class__.__name__ == "DoRAModule":
             # ref https://github.com/huggingface/peft/blob/1e6d1d73a0850223b0916052fd8d2382a90eae5a/src/peft/tuners/lora/layer.py#L417
@@ -320,8 +327,18 @@ class ToolkitModuleMixin:
         # extract weight from org_module
         org_sd = self.org_module[0].state_dict()
-        orig_dtype = org_sd["weight"].dtype
-        weight = org_sd["weight"].float()
+        # todo find a way to merge in weights when doing quantized model
+        if 'weight._data' in org_sd:
+            # quantized weight
+            return
+
+        weight_key = "weight"
+        if 'weight._data' in org_sd:
+            # quantized weight
+            weight_key = "weight._data"
+
+        orig_dtype = org_sd[weight_key].dtype
+        weight = org_sd[weight_key].float()
         multiplier = merge_weight
         scale = self.scale
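
Merging works by baking the adapter product into the base weight, W' = W + multiplier * (up @ down) * scale. That only makes sense on float weights, which is why the new code returns early when it sees quanto's packed 'weight._data' entry instead of a real float 'weight'. A minimal sketch of the linear-layer merge; the helper name is illustrative:

import torch

def merge_lora_into_weight(weight, up, down, scale=1.0, multiplier=1.0):
    # W' = W + multiplier * (up @ down) * scale; requires float weights,
    # so packed int8 quantized weights must be skipped or dequantized first
    return weight + multiplier * (up @ down) * scale

w = torch.randn(16, 16)
up, down = torch.randn(16, 4), torch.randn(4, 16)
merged = merge_lora_into_weight(w, up, down, scale=0.5)
print(merged.shape)  # torch.Size([16, 16])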
@@ -348,7 +365,7 @@ class ToolkitModuleMixin:
             weight = weight + multiplier * conved * scale
         # set weight to org_module
-        org_sd["weight"] = weight.to(orig_dtype)
+        org_sd[weight_key] = weight.to(orig_dtype)
         self.org_module[0].load_state_dict(org_sd)

     def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None):
@@ -523,12 +540,16 @@ class ToolkitNetworkMixin:
         keymap = self.get_keymap(force_weight_mapping)
         keymap = {} if keymap is None else keymap
-        if os.path.splitext(file)[1] == ".safetensors":
-            from safetensors.torch import load_file
-            weights_sd = load_file(file)
-        else:
-            weights_sd = torch.load(file, map_location="cpu")
+        if isinstance(file, str):
+            if os.path.splitext(file)[1] == ".safetensors":
+                from safetensors.torch import load_file
+                weights_sd = load_file(file)
+            else:
+                weights_sd = torch.load(file, map_location="cpu")
+        else:
+            # probably a state dict
+            weights_sd = file

         load_sd = OrderedDict()
         for key, value in weights_sd.items():
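
The load path now accepts an already-loaded state dict as well as a file path. A distilled, self-contained sketch of that branch:

import torch
from collections import OrderedDict

def load_network_weights(file):
    # `file` may be a .safetensors path, a torch checkpoint path,
    # or an already-loaded state dict (as in the hunk above)
    if isinstance(file, str):
        if file.endswith(".safetensors"):
            from safetensors.torch import load_file
            return load_file(file)
        return torch.load(file, map_location="cpu")
    # probably a state dict already
    return OrderedDict(file)

sd = load_network_weights({"lora_up.weight": torch.zeros(4, 4)})
print(list(sd.keys()))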

View File

@@ -14,6 +14,7 @@ from PIL import Image
 from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN
 from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
 from safetensors.torch import save_file, load_file
+from torch import autocast
 from torch.nn import Parameter
 from torch.utils.checkpoint import checkpoint
 from tqdm import tqdm
@@ -54,7 +55,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
 from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
 from toolkit.util.inverse_cfg import inverse_classifier_guidance
-from optimum.quanto import freeze, qfloat8, quantize
+from optimum.quanto import freeze, qfloat8, quantize, QTensor

 # tell it to shut up
 diffusers.logging.set_verbosity(diffusers.logging.ERROR)
@@ -474,6 +475,23 @@ class StableDiffusion:
             transformer.to(self.device_torch, dtype=dtype)
             flush()

+            if self.model_config.lora_path is not None:
+                # need the pipe to do this unfortunately for now
+                # we have to fuse in the weights before quantizing
+                pipe: FluxPipeline = FluxPipeline(
+                    scheduler=scheduler,
+                    text_encoder=None,
+                    tokenizer=None,
+                    text_encoder_2=None,
+                    tokenizer_2=None,
+                    vae=vae,
+                    transformer=transformer,
+                )
+                pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
+                pipe.fuse_lora()
+                # unfortunately, not an easier way with peft
+                pipe.unload_lora_weights()
+
             if self.model_config.quantize:
                 print("Quantizing transformer")
                 quantize(transformer, weights=qfloat8)
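
Order matters here: once the transformer's weights are quantized to qfloat8, a LoRA can no longer be fused into them (the same limitation that sets can_merge_in = False above), so any pre-baked LoRA has to be fused while the weights are still float. A sketch of the sequence, assuming a diffusers FluxPipeline-style pipe; prepare_transformer is an illustrative name, not the repo's method:

from optimum.quanto import freeze, qfloat8, quantize

def prepare_transformer(transformer, pipe, lora_path=None, do_quantize=True):
    if lora_path is not None:
        # fuse rewrites the float weights in place, then the adapter is dropped
        pipe.load_lora_weights(lora_path, adapter_name="lora1")
        pipe.fuse_lora()
        pipe.unload_lora_weights()
    if do_quantize:
        # only after fusing: replace Linear layers with qfloat8 equivalents
        quantize(transformer, weights=qfloat8)
        freeze(transformer)
    return transformer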
@@ -498,7 +516,7 @@ class StableDiffusion:
             text_encoder.to(self.device_torch, dtype=dtype)

             print("making pipe")
-            pipe = FluxPipeline(
+            pipe: FluxPipeline = FluxPipeline(
                 scheduler=scheduler,
                 text_encoder=text_encoder,
                 tokenizer=tokenizer,
@@ -613,7 +631,7 @@ class StableDiffusion:
                 self.unet.eval()

         # load any loras we have
-        if self.model_config.lora_path is not None:
+        if self.model_config.lora_path is not None and not self.is_flux:
             pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
             pipe.fuse_lora()
             # unfortunately, not an easier way with peft
@@ -1631,14 +1649,15 @@ class StableDiffusion:
                 width=width_latent,  # 128
             )

+            cast_dtype = self.unet.dtype
+            # with torch.amp.autocast(device_type='cuda', dtype=cast_dtype):
             noise_pred = self.unet(
-                hidden_states=latent_model_input_packed.to(self.device_torch, self.torch_dtype),  # [1, 4096, 64]
+                hidden_states=latent_model_input_packed.to(self.device_torch, cast_dtype),  # [1, 4096, 64]
                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                 # todo make sure this doesn't change
                 timestep=timestep / 1000,  # timestep is 1000 scale
-                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),  # [1, 512, 4096]
-                pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype),  # [1, 768]
+                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype),  # [1, 512, 4096]
+                pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, cast_dtype),  # [1, 768]
                 txt_ids=text_ids,  # [1, 512, 3]
                 img_ids=latent_image_ids,  # [1, 4096, 3]
                 guidance=guidance,
@@ -1646,6 +1665,9 @@ class StableDiffusion:
                 **kwargs,
             )[0]

+            if isinstance(noise_pred, QTensor):
+                noise_pred = noise_pred.dequantize()
+
             # unpack latents
             noise_pred = self.pipeline._unpack_latents(
                 noise_pred,
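
With quanto in the stack, a module's output can come back as a QTensor, while downstream ops like _unpack_latents expect a plain torch.Tensor; hence the dequantize guard above. A standalone sketch of the same guard, assuming optimum.quanto is installed:

import torch
from optimum.quanto import QTensor, qfloat8, quantize

model = torch.nn.Sequential(torch.nn.Linear(8, 8))
quantize(model, weights=qfloat8)  # swaps the Linear for a quantized QLinear

out = model(torch.randn(2, 8))
# dequantize only if quantization handed back a QTensor
if isinstance(out, QTensor):
    out = out.dequantize()
print(type(out), out.dtype)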

View File

@@ -52,6 +52,8 @@ def get_torch_dtype(dtype_str):
         return torch.float16
     if dtype_str == "bf16" or dtype_str == "bfloat16":
         return torch.bfloat16
+    if dtype_str == "8bit" or dtype_str == "e4m3fn" or dtype_str == "float8":
+        return torch.float8_e4m3fn
     return dtype_str
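
torch.float8_e4m3fn is PyTorch's 8-bit float with a 4-bit exponent and 3-bit mantissa ("fn" means finite only, no infinities), available since PyTorch 2.1; with the new mapping, any of "8bit", "e4m3fn", or "float8" in a config resolves to it. A quick demonstration of what the dtype buys you:

import torch

t = torch.zeros(4, dtype=torch.float8_e4m3fn)
print(t.dtype)           # torch.float8_e4m3fn
print(t.element_size())  # 1 byte per element, vs 2 for bf16/fp16 and 4 for fp32
# float8 is a storage format; upcast for ops that lack fp8 kernels
print(t.float().dtype)   # torch.float32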