mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-04 01:59:48 +00:00
8 bit training working on flux
This commit is contained in:
@@ -1538,22 +1538,19 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# flush()
|
||||
|
||||
if not self.is_grad_accumulation_step:
|
||||
# torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
|
||||
# fix this for multi params
|
||||
if isinstance(self.params[0], dict):
|
||||
for i in range(len(self.params)):
|
||||
torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
|
||||
else:
|
||||
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
|
||||
if self.train_config.optimizer != 'adafactor':
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
if isinstance(self.params[0], dict):
|
||||
for i in range(len(self.params)):
|
||||
torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
|
||||
else:
|
||||
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
|
||||
# only step if we are not accumulating
|
||||
with self.timer('optimizer_step'):
|
||||
if self.is_bfloat:
|
||||
self.optimizer.step()
|
||||
else:
|
||||
# apply gradients
|
||||
self.optimizer.step()
|
||||
# self.scaler.update()
|
||||
# self.optimizer.step()
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
if self.ema is not None:
|
||||
with self.timer('ema_update'):
|
||||
|
||||
@@ -1353,7 +1353,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
**network_kwargs
|
||||
)
|
||||
|
||||
self.network.force_to(self.device_torch, dtype=dtype)
|
||||
|
||||
# todo switch everything to proper mixed precision like this
|
||||
self.network.force_to(self.device_torch, dtype=torch.float32)
|
||||
# give network to sd so it can use it
|
||||
self.sd.network = self.network
|
||||
self.network._update_torch_multiplier()
|
||||
@@ -1365,6 +1367,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.train_config.train_unet
|
||||
)
|
||||
|
||||
# we cannot merge in if quantized
|
||||
if self.model_config.quantize:
|
||||
# todo find a way around this
|
||||
self.network.can_merge_in = False
|
||||
|
||||
if is_lorm:
|
||||
self.network.is_lorm = True
|
||||
# make sure it is on the right device
|
||||
|
||||
@@ -520,7 +520,7 @@ class DatasetConfig:
|
||||
self.random_crop: bool = kwargs.get('random_crop', False)
|
||||
self.resolution: int = kwargs.get('resolution', 512)
|
||||
self.scale: float = kwargs.get('scale', 1.0)
|
||||
self.buckets: bool = kwargs.get('buckets', False)
|
||||
self.buckets: bool = kwargs.get('buckets', True)
|
||||
self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64)
|
||||
self.is_reg: bool = kwargs.get('is_reg', False)
|
||||
self.network_weight: float = float(kwargs.get('network_weight', 1.0))
|
||||
|
||||
@@ -28,12 +28,14 @@ RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers
|
||||
# diffusers specific stuff
|
||||
LINEAR_MODULES = [
|
||||
'Linear',
|
||||
'LoRACompatibleLinear'
|
||||
'LoRACompatibleLinear',
|
||||
'QLinear',
|
||||
# 'GroupNorm',
|
||||
]
|
||||
CONV_MODULES = [
|
||||
'Conv2d',
|
||||
'LoRACompatibleConv'
|
||||
'LoRACompatibleConv',
|
||||
'QConv2d',
|
||||
]
|
||||
|
||||
class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
|
||||
|
||||
@@ -4,6 +4,7 @@ from collections import OrderedDict
|
||||
from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any, Literal
|
||||
|
||||
import torch
|
||||
from optimum.quanto import QTensor
|
||||
from torch import nn
|
||||
import weakref
|
||||
|
||||
@@ -258,7 +259,12 @@ class ToolkitModuleMixin:
|
||||
# return self.dora_forward(x, *args, **kwargs)
|
||||
|
||||
org_forwarded = self.org_forward(x, *args, **kwargs)
|
||||
lora_output = self._call_forward(x)
|
||||
|
||||
if isinstance(x, QTensor):
|
||||
x = x.dequantize()
|
||||
# always cast to float32
|
||||
lora_input = x.float()
|
||||
lora_output = self._call_forward(lora_input)
|
||||
multiplier = self.network_ref().torch_multiplier
|
||||
|
||||
lora_output_batch_size = lora_output.size(0)
|
||||
@@ -269,6 +275,7 @@ class ToolkitModuleMixin:
|
||||
multiplier = multiplier.repeat_interleave(num_interleaves)
|
||||
|
||||
scaled_lora_output = broadcast_and_multiply(lora_output, multiplier)
|
||||
scaled_lora_output = scaled_lora_output.to(org_forwarded.dtype)
|
||||
|
||||
if self.__class__.__name__ == "DoRAModule":
|
||||
# ref https://github.com/huggingface/peft/blob/1e6d1d73a0850223b0916052fd8d2382a90eae5a/src/peft/tuners/lora/layer.py#L417
|
||||
@@ -320,8 +327,18 @@ class ToolkitModuleMixin:
|
||||
|
||||
# extract weight from org_module
|
||||
org_sd = self.org_module[0].state_dict()
|
||||
orig_dtype = org_sd["weight"].dtype
|
||||
weight = org_sd["weight"].float()
|
||||
# todo find a way to merge in weights when doing quantized model
|
||||
if 'weight._data' in org_sd:
|
||||
# quantized weight
|
||||
return
|
||||
|
||||
weight_key = "weight"
|
||||
if 'weight._data' in org_sd:
|
||||
# quantized weight
|
||||
weight_key = "weight._data"
|
||||
|
||||
orig_dtype = org_sd[weight_key].dtype
|
||||
weight = org_sd[weight_key].float()
|
||||
|
||||
multiplier = merge_weight
|
||||
scale = self.scale
|
||||
@@ -348,7 +365,7 @@ class ToolkitModuleMixin:
|
||||
weight = weight + multiplier * conved * scale
|
||||
|
||||
# set weight to org_module
|
||||
org_sd["weight"] = weight.to(orig_dtype)
|
||||
org_sd[weight_key] = weight.to(orig_dtype)
|
||||
self.org_module[0].load_state_dict(org_sd)
|
||||
|
||||
def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None):
|
||||
@@ -523,12 +540,16 @@ class ToolkitNetworkMixin:
|
||||
keymap = self.get_keymap(force_weight_mapping)
|
||||
keymap = {} if keymap is None else keymap
|
||||
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
if isinstance(file, str):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
weights_sd = load_file(file)
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
# probably a state dict
|
||||
weights_sd = file
|
||||
|
||||
load_sd = OrderedDict()
|
||||
for key, value in weights_sd.items():
|
||||
|
||||
@@ -14,6 +14,7 @@ from PIL import Image
|
||||
from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
|
||||
from safetensors.torch import save_file, load_file
|
||||
from torch import autocast
|
||||
from torch.nn import Parameter
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from tqdm import tqdm
|
||||
@@ -54,7 +55,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjecti
|
||||
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
|
||||
from toolkit.util.inverse_cfg import inverse_classifier_guidance
|
||||
|
||||
from optimum.quanto import freeze, qfloat8, quantize
|
||||
from optimum.quanto import freeze, qfloat8, quantize, QTensor
|
||||
|
||||
# tell it to shut up
|
||||
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
|
||||
@@ -474,6 +475,23 @@ class StableDiffusion:
|
||||
transformer.to(self.device_torch, dtype=dtype)
|
||||
flush()
|
||||
|
||||
if self.model_config.lora_path is not None:
|
||||
# need the pipe to do this unfortunately for now
|
||||
# we have to fuse in the weights before quantizing
|
||||
pipe: FluxPipeline = FluxPipeline(
|
||||
scheduler=scheduler,
|
||||
text_encoder=None,
|
||||
tokenizer=None,
|
||||
text_encoder_2=None,
|
||||
tokenizer_2=None,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
)
|
||||
pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
|
||||
pipe.fuse_lora()
|
||||
# unfortunately, there is no easier way with peft
|
||||
pipe.unload_lora_weights()
|
||||
|
||||
if self.model_config.quantize:
|
||||
print("Quantizing transformer")
|
||||
quantize(transformer, weights=qfloat8)
|
||||
@@ -498,7 +516,7 @@ class StableDiffusion:
|
||||
text_encoder.to(self.device_torch, dtype=dtype)
|
||||
|
||||
print("making pipe")
|
||||
pipe = FluxPipeline(
|
||||
pipe: FluxPipeline = FluxPipeline(
|
||||
scheduler=scheduler,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
@@ -613,7 +631,7 @@ class StableDiffusion:
|
||||
self.unet.eval()
|
||||
|
||||
# load any loras we have
|
||||
if self.model_config.lora_path is not None:
|
||||
if self.model_config.lora_path is not None and not self.is_flux:
|
||||
pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
|
||||
pipe.fuse_lora()
|
||||
# unfortunately, there is no easier way with peft
|
||||
@@ -1631,14 +1649,15 @@ class StableDiffusion:
|
||||
width=width_latent, # 128
|
||||
)
|
||||
|
||||
|
||||
cast_dtype = self.unet.dtype
|
||||
# with torch.amp.autocast(device_type='cuda', dtype=cast_dtype):
|
||||
noise_pred = self.unet(
|
||||
hidden_states=latent_model_input_packed.to(self.device_torch, self.torch_dtype), # [1, 4096, 64]
|
||||
hidden_states=latent_model_input_packed.to(self.device_torch, cast_dtype), # [1, 4096, 64]
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
# todo make sure this doesn't change
|
||||
timestep=timestep / 1000, # timestep is 1000 scale
|
||||
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), # [1, 512, 4096]
|
||||
pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype), # [1, 768]
|
||||
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype), # [1, 512, 4096]
|
||||
pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, cast_dtype), # [1, 768]
|
||||
txt_ids=text_ids, # [1, 512, 3]
|
||||
img_ids=latent_image_ids, # [1, 4096, 3]
|
||||
guidance=guidance,
|
||||
@@ -1646,6 +1665,9 @@ class StableDiffusion:
|
||||
**kwargs,
|
||||
)[0]
|
||||
|
||||
if isinstance(noise_pred, QTensor):
|
||||
noise_pred = noise_pred.dequantize()
|
||||
|
||||
# unpack latents
|
||||
noise_pred = self.pipeline._unpack_latents(
|
||||
noise_pred,
|
||||
|
||||
@@ -52,6 +52,8 @@ def get_torch_dtype(dtype_str):
|
||||
return torch.float16
|
||||
if dtype_str == "bf16" or dtype_str == "bfloat16":
|
||||
return torch.bfloat16
|
||||
if dtype_str == "8bit" or dtype_str == "e4m3fn" or dtype_str == "float8":
|
||||
return torch.float8_e4m3fn
|
||||
return dtype_str
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user