diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 87984717..7815dc5a 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -970,10 +970,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     self.train_config.linear_timesteps2,
                     self.train_config.timestep_type == 'linear',
                 ])
+
+                timestep_type = 'linear' if linear_timesteps else None
+                if timestep_type is None:
+                    timestep_type = self.train_config.timestep_type
+
                 self.sd.noise_scheduler.set_train_timesteps(
                     num_train_timesteps,
                     device=self.device_torch,
-                    linear=linear_timesteps
+                    timestep_type=timestep_type
                 )
             else:
                 self.sd.noise_scheduler.set_timesteps(
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 1e7215bf..ae6da3de 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -386,7 +386,7 @@ class TrainConfig:
         # adds an additional loss to the network to encourage it output a normalized standard deviation
         self.target_norm_std = kwargs.get('target_norm_std', None)
         self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
-        self.timestep_type = kwargs.get('timestep_type', 'sigmoid')  # sigmoid, linear
+        self.timestep_type = kwargs.get('timestep_type', 'sigmoid')  # sigmoid, linear, lognorm_blend
         self.linear_timesteps = kwargs.get('linear_timesteps', False)
         self.linear_timesteps2 = kwargs.get('linear_timesteps2', False)
         self.disable_sampling = kwargs.get('disable_sampling', False)
@@ -470,6 +470,11 @@ class ModelConfig:
         if self.ignore_if_contains is not None or self.only_if_contains is not None:
             if not self.is_flux:
                 raise ValueError("ignore_if_contains and only_if_contains are only supported with flux models currently")
+
+        # splits the model over the available gpus WIP
+        self.split_model_over_gpus = kwargs.get("split_model_over_gpus", False)
+        if self.split_model_over_gpus and not self.is_flux:
+            raise ValueError("split_model_over_gpus is only supported with flux models currently")
 
 
 class EMAConfig:
diff --git a/toolkit/models/flux.py b/toolkit/models/flux.py
index 48ce8786..829283e4 100644
--- a/toolkit/models/flux.py
+++ b/toolkit/models/flux.py
@@ -1,6 +1,9 @@
 # forward that bypasses the guidance embedding so it can be avoided during training.
 from functools import partial
+from typing import Optional
+import torch
+from diffusers import FluxTransformer2DModel
 
 
 def guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection):
     timesteps_proj = self.time_proj(timestep)
@@ -33,3 +36,139 @@ def restore_flux_guidance(transformer):
         return
     transformer.time_text_embed.forward = transformer.time_text_embed._bfg_orig_forward
     del transformer.time_text_embed._bfg_orig_forward
+
+def new_device_to(self: FluxTransformer2DModel, *args, **kwargs):
+    # Store original device if provided in args or kwargs
+    device_in_kwargs = 'device' in kwargs
+    device_in_args = any(isinstance(arg, (str, torch.device)) for arg in args)
+
+    device = None
+    # Remove device from kwargs if present
+    if device_in_kwargs:
+        device = kwargs['device']
+        del kwargs['device']
+
+    # Only filter args if we detected a device argument
+    if device_in_args:
+        args = list(args)
+        for idx, arg in enumerate(args):
+            if isinstance(arg, (str, torch.device)):
+                device = arg
+                del args[idx]
+
+    self.pos_embed = self.pos_embed.to(device, *args, **kwargs)
+    self.time_text_embed = self.time_text_embed.to(device, *args, **kwargs)
+    self.context_embedder = self.context_embedder.to(device, *args, **kwargs)
+    self.x_embedder = self.x_embedder.to(device, *args, **kwargs)
+    for block in self.transformer_blocks:
+        block.to(block._split_device, *args, **kwargs)
+    for block in self.single_transformer_blocks:
+        block.to(block._split_device, *args, **kwargs)
+
+    self.norm_out = self.norm_out.to(device, *args, **kwargs)
+    self.proj_out = self.proj_out.to(device, *args, **kwargs)
+
+
+
+    return self
+
+
+
+
+def split_gpu_double_block_forward(
+    self,
+    hidden_states: torch.FloatTensor,
+    encoder_hidden_states: torch.FloatTensor,
+    temb: torch.FloatTensor,
+    image_rotary_emb=None,
+    joint_attention_kwargs=None,
+):
+    if hidden_states.device != self._split_device:
+        hidden_states = hidden_states.to(self._split_device)
+    if encoder_hidden_states.device != self._split_device:
+        encoder_hidden_states = encoder_hidden_states.to(self._split_device)
+    if temb.device != self._split_device:
+        temb = temb.to(self._split_device)
+    if image_rotary_emb is not None and image_rotary_emb[0].device != self._split_device:
+        # is a tuple of tensors
+        image_rotary_emb = tuple([t.to(self._split_device) for t in image_rotary_emb])
+    return self._pre_gpu_split_forward(hidden_states, encoder_hidden_states, temb, image_rotary_emb, joint_attention_kwargs)
+
+
+def split_gpu_single_block_forward(
+    self,
+    hidden_states: torch.FloatTensor,
+    temb: torch.FloatTensor,
+    image_rotary_emb=None,
+    joint_attention_kwargs=None,
+    **kwargs
+):
+    if hidden_states.device != self._split_device:
+        hidden_states = hidden_states.to(device=self._split_device)
+    if temb.device != self._split_device:
+        temb = temb.to(device=self._split_device)
+    if image_rotary_emb is not None and image_rotary_emb[0].device != self._split_device:
+        # is a tuple of tensors
+        image_rotary_emb = tuple([t.to(self._split_device) for t in image_rotary_emb])
+
+    hidden_state_out = self._pre_gpu_split_forward(hidden_states, temb, image_rotary_emb, joint_attention_kwargs, **kwargs)
+    if hasattr(self, "_split_output_device"):
+        return hidden_state_out.to(self._split_output_device)
+    return hidden_state_out
+
+
+def add_model_gpu_splitter_to_flux(transformer: FluxTransformer2DModel):
+    gpu_id_list = [i for i in range(torch.cuda.device_count())]
+
+    # if len(gpu_id_list) > 2:
+    #     raise ValueError("Cannot split to more than 2 GPUs currently.")
+
+    # ~ 5 billion for all other params
+    other_module_params = 5e9
+    # since they are not trainable, multiply by smaller number
+    other_module_params *= 0.5
+
+    # total params to balance = transformer params plus the estimated non-block params
+    total_params = sum(p.numel() for p in transformer.parameters()) + other_module_params
+
+    params_per_gpu = total_params / len(gpu_id_list)
+
+    current_gpu_idx = 0
+    # text encoders, vae, and some non block layers will all be on gpu 0
+    current_gpu_params = other_module_params
+
+    for double_block in transformer.transformer_blocks:
+        device = torch.device(f"cuda:{current_gpu_idx}")
+        double_block._pre_gpu_split_forward = double_block.forward
+        double_block.forward = partial(
+            split_gpu_double_block_forward, double_block)
+        double_block._split_device = device
+        # add the params to the current gpu
+        current_gpu_params += sum(p.numel() for p in double_block.parameters())
+        # if the current gpu params are greater than the params per gpu, move to next gpu
+        if current_gpu_params > params_per_gpu:
+            current_gpu_idx += 1
+            current_gpu_params = 0
+            if current_gpu_idx >= len(gpu_id_list):
+                current_gpu_idx = gpu_id_list[-1]
+
+    for single_block in transformer.single_transformer_blocks:
+        device = torch.device(f"cuda:{current_gpu_idx}")
+        single_block._pre_gpu_split_forward = single_block.forward
+        single_block.forward = partial(
+            split_gpu_single_block_forward, single_block)
+        single_block._split_device = device
+        # add the params to the current gpu
+        current_gpu_params += sum(p.numel() for p in single_block.parameters())
+        # if the current gpu params are greater than the params per gpu, move to next gpu
+        if current_gpu_params > params_per_gpu:
+            current_gpu_idx += 1
+            current_gpu_params = 0
+            if current_gpu_idx >= len(gpu_id_list):
+                current_gpu_idx = gpu_id_list[-1]
+
+    # add output device to last layer
+    transformer.single_transformer_blocks[-1]._split_output_device = torch.device("cuda:0")
+
+    transformer._pre_gpu_split_to = transformer.to
+    transformer.to = partial(new_device_to, transformer)
diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py
index 440eb4fa..c8c190a5 100644
--- a/toolkit/samplers/custom_flowmatch_sampler.py
+++ b/toolkit/samplers/custom_flowmatch_sampler.py
@@ -1,6 +1,6 @@
 import math
 from typing import Union
-
+from torch.distributions import LogNormal
 from diffusers import FlowMatchEulerDiscreteScheduler
 import torch
 
@@ -89,12 +89,12 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
     def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
         return sample
 
-    def set_train_timesteps(self, num_timesteps, device, linear=False):
-        if linear:
+    def set_train_timesteps(self, num_timesteps, device, timestep_type='linear'):
+        if timestep_type == 'linear':
             timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
             self.timesteps = timesteps
             return timesteps
-        else:
+        elif timestep_type == 'sigmoid':
             # distribute them closer to center. Inference distributes them as a bias toward first
             # Generate values from 0 to 1
             t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
@@ -108,3 +108,27 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
             self.timesteps = timesteps.to(device=device)
 
             return timesteps
+        elif timestep_type == 'lognorm_blend':
+            # distribute timesteps toward the center/early steps and blend in a linear tail
+            alpha = 0.8
+
+            lognormal = LogNormal(loc=0, scale=0.333)
+
+            # Sample from the distribution
+            t1 = lognormal.sample((int(num_timesteps * alpha),)).to(device)
+
+            # Scale and reverse the values to go from 1000 to 0
+            t1 = ((1 - t1/t1.max()) * 1000)
+
+            # add the remaining (1 - alpha) portion as a linear schedule
+            t2 = torch.linspace(1000, 0, int(num_timesteps * (1 - alpha)), device=device)
+            timesteps = torch.cat((t1, t2))
+
+            # Sort the timesteps in descending order
+            timesteps, _ = torch.sort(timesteps, descending=True)
+
+            timesteps = timesteps.to(torch.int)
+            self.timesteps = timesteps
+            return timesteps
+        else:
+            raise ValueError(f"Invalid timestep type: {timestep_type}")
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 23439fbc..aaa30898 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -60,7 +60,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjecti
 from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
 from huggingface_hub import hf_hub_download
 
-from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
+from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance
 from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
 
 from typing import TYPE_CHECKING
@@ -553,6 +553,10 @@ class StableDiffusion:
                 # low_cpu_mem_usage=False,
                 # device_map=None
             )
+            # hack in model gpu splitter
+            if self.model_config.split_model_over_gpus:
+                add_model_gpu_splitter_to_flux(transformer)
+
             if not self.low_vram:
                 # for low v ram, we leave it on the cpu. Quantizes slower, but allows training on primary gpu
                 transformer.to(torch.device(self.quantize_device), dtype=dtype)
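
Usage note (not part of the patch): in the toolkit the splitter is enabled through the new `split_model_over_gpus` model config flag, as wired up in stable_diffusion_model.py above, but it can also be exercised directly. The sketch below is illustrative only; it assumes a machine with two or more CUDA devices, and uses the public black-forest-labs/FLUX.1-dev checkpoint and bfloat16 as stand-ins for whatever the trainer actually loads.

import torch
from diffusers import FluxTransformer2DModel
from toolkit.models.flux import add_model_gpu_splitter_to_flux

# load only the transformer; any FluxTransformer2DModel checkpoint works here
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# patch each double/single block so it runs on the GPU it was assigned to
add_model_gpu_splitter_to_flux(transformer)

# .to() is now new_device_to: the non-block layers move to cuda:0, every block moves
# to its assigned _split_device, and the last single block copies its output back to cuda:0
transformer.to("cuda:0")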
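
For reference, here is a minimal standalone sketch of what the new `lognorm_blend` timestep type produces. The helper name `lognorm_blend_timesteps` and the printout are hypothetical, not part of the toolkit; the body just repeats the math from `set_train_timesteps` above so the distribution can be inspected without constructing the scheduler. With the default `alpha = 0.8`, roughly 80% of the timesteps come from the log-normal portion and 20% from the linear tail.

import torch
from torch.distributions import LogNormal


def lognorm_blend_timesteps(num_timesteps: int, alpha: float = 0.8, device: str = "cpu") -> torch.Tensor:
    # log-normal portion: samples cluster near 1.0, so after the remap below
    # most timesteps land in the middle-to-upper part of the 0-1000 range
    lognormal = LogNormal(loc=0, scale=0.333)
    t1 = lognormal.sample((int(num_timesteps * alpha),)).to(device)
    t1 = (1 - t1 / t1.max()) * 1000

    # remaining (1 - alpha) portion is a plain linear schedule from 1000 to 0
    t2 = torch.linspace(1000, 0, int(num_timesteps * (1 - alpha)), device=device)

    timesteps, _ = torch.sort(torch.cat((t1, t2)), descending=True)
    return timesteps.to(torch.int)


if __name__ == "__main__":
    ts = lognorm_blend_timesteps(1000)
    # quick look at where the sampled timesteps concentrate
    print(ts.min().item(), ts.float().median().item(), ts.max().item())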