Added the ability to split Flux across GPUs (experimental). Changed the way timestep scheduling works to prepare for more specific schedules.

This commit is contained in:
Jaret Burkett
2024-12-31 07:06:55 -07:00
parent 8ef07a9c36
commit 4723f23c0d
5 changed files with 182 additions and 7 deletions

View File

@@ -970,10 +970,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
                    self.train_config.linear_timesteps2,
                    self.train_config.timestep_type == 'linear',
                ])
                timestep_type = 'linear' if linear_timesteps else None
                if timestep_type is None:
                    timestep_type = self.train_config.timestep_type
                self.sd.noise_scheduler.set_train_timesteps(
                    num_train_timesteps,
                    device=self.device_torch,
-                   linear=linear_timesteps
+                   timestep_type=timestep_type
                )
            else:
                self.sd.noise_scheduler.set_timesteps(
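
In effect, the legacy linear flags still force 'linear', and anything else falls through to the configured timestep_type, which set_train_timesteps now receives directly. A minimal stand-in sketch of that precedence (the _Cfg class and resolve_timestep_type helper below are hypothetical, for illustration only):

# illustration of the resolution order above; _Cfg is a stand-in, not the real TrainConfig
from dataclasses import dataclass

@dataclass
class _Cfg:
    timestep_type: str = 'sigmoid'   # sigmoid, linear, lognorm_blend
    linear_timesteps: bool = False
    linear_timesteps2: bool = False

def resolve_timestep_type(cfg: _Cfg) -> str:
    legacy_linear = any([cfg.linear_timesteps, cfg.linear_timesteps2, cfg.timestep_type == 'linear'])
    return 'linear' if legacy_linear else cfg.timestep_type

assert resolve_timestep_type(_Cfg()) == 'sigmoid'
assert resolve_timestep_type(_Cfg(linear_timesteps=True)) == 'linear'
assert resolve_timestep_type(_Cfg(timestep_type='lognorm_blend')) == 'lognorm_blend'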

View File

@@ -386,7 +386,7 @@ class TrainConfig:
        # adds an additional loss to the network to encourage it to output a normalized standard deviation
        self.target_norm_std = kwargs.get('target_norm_std', None)
        self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
-       self.timestep_type = kwargs.get('timestep_type', 'sigmoid')  # sigmoid, linear
+       self.timestep_type = kwargs.get('timestep_type', 'sigmoid')  # sigmoid, linear, lognorm_blend
        self.linear_timesteps = kwargs.get('linear_timesteps', False)
        self.linear_timesteps2 = kwargs.get('linear_timesteps2', False)
        self.disable_sampling = kwargs.get('disable_sampling', False)
@@ -470,6 +470,11 @@ class ModelConfig:
        if self.ignore_if_contains is not None or self.only_if_contains is not None:
            if not self.is_flux:
                raise ValueError("ignore_if_contains and only_if_contains are only supported with flux models currently")

        # splits the model over the available GPUs (WIP)
        self.split_model_over_gpus = kwargs.get("split_model_over_gpus", False)
        if self.split_model_over_gpus and not self.is_flux:
            raise ValueError("split_model_over_gpus is only supported with flux models currently")

class EMAConfig:
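
As a usage note, both new options ride in through plain kwargs, so enabling them only requires passing the keys. A hedged sketch (the import path and checkpoint name are assumptions; split_model_over_gpus requires a flux model, as enforced above):

# hedged sketch: import path and name_or_path value are assumptions
from toolkit.config_modules import TrainConfig, ModelConfig

train_config = TrainConfig(timestep_type='lognorm_blend')  # sigmoid (default), linear, lognorm_blend
model_config = ModelConfig(
    name_or_path='black-forest-labs/FLUX.1-dev',  # illustrative flux checkpoint
    is_flux=True,                                 # split_model_over_gpus requires a flux model
    split_model_over_gpus=True,                   # experimental multi-gpu split
)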

View File

@@ -1,6 +1,9 @@
# forward that bypasses the guidance embedding so it can be avoided during training.
from functools import partial
from typing import Optional

import torch
from diffusers import FluxTransformer2DModel


def guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection):
@@ -33,3 +36,139 @@ def restore_flux_guidance(transformer):
        return
    transformer.time_text_embed.forward = transformer.time_text_embed._bfg_orig_forward
    del transformer.time_text_embed._bfg_orig_forward

def new_device_to(self: FluxTransformer2DModel, *args, **kwargs):
    # Pull the target device out of args/kwargs so the shared layers move together
    # while each transformer block keeps its own assigned device
    device_in_kwargs = 'device' in kwargs
    device_in_args = any(isinstance(arg, (str, torch.device)) for arg in args)

    device = None

    # Remove device from kwargs if present
    if device_in_kwargs:
        device = kwargs['device']
        del kwargs['device']

    # Only filter args if we detected a device argument
    if device_in_args:
        args = list(args)
        for idx, arg in enumerate(args):
            if isinstance(arg, (str, torch.device)):
                device = arg
                del args[idx]
                break  # stop once the device argument has been removed

    self.pos_embed = self.pos_embed.to(device, *args, **kwargs)
    self.time_text_embed = self.time_text_embed.to(device, *args, **kwargs)
    self.context_embedder = self.context_embedder.to(device, *args, **kwargs)
    self.x_embedder = self.x_embedder.to(device, *args, **kwargs)
    for block in self.transformer_blocks:
        block.to(block._split_device, *args, **kwargs)
    for block in self.single_transformer_blocks:
        block.to(block._split_device, *args, **kwargs)
    self.norm_out = self.norm_out.to(device, *args, **kwargs)
    self.proj_out = self.proj_out.to(device, *args, **kwargs)
    return self

def split_gpu_double_block_forward(
    self,
    hidden_states: torch.FloatTensor,
    encoder_hidden_states: torch.FloatTensor,
    temb: torch.FloatTensor,
    image_rotary_emb=None,
    joint_attention_kwargs=None,
):
    if hidden_states.device != self._split_device:
        hidden_states = hidden_states.to(self._split_device)
    if encoder_hidden_states.device != self._split_device:
        encoder_hidden_states = encoder_hidden_states.to(self._split_device)
    if temb.device != self._split_device:
        temb = temb.to(self._split_device)
    if image_rotary_emb is not None and image_rotary_emb[0].device != self._split_device:
        # is a tuple of tensors
        image_rotary_emb = tuple(t.to(self._split_device) for t in image_rotary_emb)
    return self._pre_gpu_split_forward(hidden_states, encoder_hidden_states, temb, image_rotary_emb, joint_attention_kwargs)

def split_gpu_single_block_forward(
    self,
    hidden_states: torch.FloatTensor,
    temb: torch.FloatTensor,
    image_rotary_emb=None,
    joint_attention_kwargs=None,
    **kwargs
):
    if hidden_states.device != self._split_device:
        hidden_states = hidden_states.to(device=self._split_device)
    if temb.device != self._split_device:
        temb = temb.to(device=self._split_device)
    if image_rotary_emb is not None and image_rotary_emb[0].device != self._split_device:
        # is a tuple of tensors
        image_rotary_emb = tuple(t.to(self._split_device) for t in image_rotary_emb)
    hidden_state_out = self._pre_gpu_split_forward(hidden_states, temb, image_rotary_emb, joint_attention_kwargs, **kwargs)
    # only the last block has _split_output_device set; it sends the result back to the primary gpu
    if hasattr(self, "_split_output_device"):
        return hidden_state_out.to(self._split_output_device)
    return hidden_state_out

def add_model_gpu_splitter_to_flux(transformer: FluxTransformer2DModel):
    gpu_id_list = list(range(torch.cuda.device_count()))
    # if len(gpu_id_list) > 2:
    #     raise ValueError("Cannot split to more than 2 GPUs currently.")

    # ~5 billion params for everything that is not a transformer block
    other_module_params = 5e9
    # they are not trainable, so weight them at half when balancing the split
    other_module_params *= 0.5
    total_params = sum(p.numel() for p in transformer.parameters()) + other_module_params
    params_per_gpu = total_params / len(gpu_id_list)

    current_gpu_idx = 0
    # text encoders, vae, and some non block layers will all be on gpu 0
    current_gpu_params = other_module_params
    for double_block in transformer.transformer_blocks:
        device = torch.device(f"cuda:{current_gpu_idx}")
        double_block._pre_gpu_split_forward = double_block.forward
        double_block.forward = partial(split_gpu_double_block_forward, double_block)
        double_block._split_device = device
        # add the params to the current gpu
        current_gpu_params += sum(p.numel() for p in double_block.parameters())
        # if the current gpu is over its per-gpu budget, move on to the next gpu
        if current_gpu_params > params_per_gpu:
            current_gpu_idx += 1
            current_gpu_params = 0
            if current_gpu_idx >= len(gpu_id_list):
                current_gpu_idx = gpu_id_list[-1]

    for single_block in transformer.single_transformer_blocks:
        device = torch.device(f"cuda:{current_gpu_idx}")
        single_block._pre_gpu_split_forward = single_block.forward
        single_block.forward = partial(split_gpu_single_block_forward, single_block)
        single_block._split_device = device
        # add the params to the current gpu
        current_gpu_params += sum(p.numel() for p in single_block.parameters())
        # if the current gpu is over its per-gpu budget, move on to the next gpu
        if current_gpu_params > params_per_gpu:
            current_gpu_idx += 1
            current_gpu_params = 0
            if current_gpu_idx >= len(gpu_id_list):
                current_gpu_idx = gpu_id_list[-1]

    # send the last block's output back to cuda:0, where the remaining non-block layers live
    transformer.single_transformer_blocks[-1]._split_output_device = torch.device("cuda:0")

    transformer._pre_gpu_split_to = transformer.to
    transformer.to = partial(new_device_to, transformer)
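
Taken together, the intended flow mirrors what stable_diffusion_model.py does further below: load the transformer, patch it, then call .to() as usual. A minimal sketch (the checkpoint path and dtype here are illustrative, not part of this commit):

# minimal usage sketch; the patched .to keeps each block on its assigned gpu
# while the shared (non-block) layers follow the requested device
import torch
from diffusers import FluxTransformer2DModel
from toolkit.models.flux import add_model_gpu_splitter_to_flux

transformer = FluxTransformer2DModel.from_pretrained(
    'black-forest-labs/FLUX.1-dev', subfolder='transformer', torch_dtype=torch.bfloat16
)
add_model_gpu_splitter_to_flux(transformer)  # assigns block._split_device by parameter budget
transformer.to(torch.device('cuda:0'))       # non-block layers go to cuda:0, blocks stay put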

View File

@@ -1,6 +1,6 @@
import math
from typing import Union
from torch.distributions import LogNormal
from diffusers import FlowMatchEulerDiscreteScheduler
import torch
@@ -89,12 +89,12 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        return sample

-   def set_train_timesteps(self, num_timesteps, device, linear=False):
-       if linear:
+   def set_train_timesteps(self, num_timesteps, device, timestep_type='linear'):
+       if timestep_type == 'linear':
            timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
            self.timesteps = timesteps
            return timesteps
-       else:
+       elif timestep_type == 'sigmoid':
            # distribute them closer to center. Inference distributes them as a bias toward first
            # Generate values from 0 to 1
            t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
@@ -108,3 +108,25 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
            self.timesteps = timesteps.to(device=device)
            return timesteps
        elif timestep_type == 'lognorm_blend':
            # distribute timesteps toward the center/early steps and blend in a linear tail
            alpha = 0.8
            lognormal = LogNormal(loc=0, scale=0.333)
            # Sample from the distribution
            t1 = lognormal.sample((int(num_timesteps * alpha),)).to(device)
            # Scale and reverse the values to go from 1000 to 0
            t1 = ((1 - t1 / t1.max()) * 1000)
            # add the remaining (1 - alpha) portion as a linear schedule
            t2 = torch.linspace(1000, 0, int(num_timesteps * (1 - alpha)), device=device)
            timesteps = torch.cat((t1, t2))
            # Sort the timesteps in descending order
            timesteps, _ = torch.sort(timesteps, descending=True)
            timesteps = timesteps.to(torch.int)
            # mirror the other branches: store and return the schedule
            self.timesteps = timesteps.to(device=device)
            return timesteps
        else:
            raise ValueError(f"Invalid timestep type: {timestep_type}")

View File

@@ -60,7 +60,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjecti
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
from huggingface_hub import hf_hub_download
-from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
+from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance
from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
from typing import TYPE_CHECKING
@@ -553,6 +553,10 @@ class StableDiffusion:
                # low_cpu_mem_usage=False,
                # device_map=None
            )

            # hack in the model gpu splitter
            if self.model_config.split_model_over_gpus:
                add_model_gpu_splitter_to_flux(transformer)

            if not self.low_vram:
                # for low_vram, we leave the transformer on the cpu. Quantizing is slower, but it allows training on the primary gpu
                transformer.to(torch.device(self.quantize_device), dtype=dtype)
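
Since the split is experimental, it can help to confirm where the parameters actually landed after loading; a small hypothetical helper (not part of the toolkit):

# hypothetical inspection helper: count parameters per device once the splitter
# and .to() have run; expect entries like cuda:0, cuda:1, ... when the split took effect
from collections import Counter
import torch.nn as nn

def report_param_devices(module: nn.Module) -> Counter:
    return Counter(str(p.device) for p in module.parameters())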