mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Added ability to split up flux across gpus (experimental). Changed the way timestep scheduling works to prep for more specific schedules.
@@ -970,10 +970,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 self.train_config.linear_timesteps2,
+                self.train_config.timestep_type == 'linear',
             ])
 
+            timestep_type = 'linear' if linear_timesteps else None
+            if timestep_type is None:
+                timestep_type = self.train_config.timestep_type
+
             self.sd.noise_scheduler.set_train_timesteps(
                 num_train_timesteps,
                 device=self.device_torch,
-                linear=linear_timesteps
+                timestep_type=timestep_type
             )
         else:
             self.sd.noise_scheduler.set_timesteps(
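
In short, the old boolean flags now just select the new string option. A minimal hedged sketch of that resolution logic, using only names visible in the hunk above (the surrounding method is assumed, not shown in this diff):

def resolve_timestep_type(train_config):
    # legacy flags (or an explicit 'linear' setting) force the linear schedule
    linear_timesteps = any([
        train_config.linear_timesteps,
        train_config.linear_timesteps2,
        train_config.timestep_type == 'linear',
    ])
    # otherwise the configured timestep_type ('sigmoid' by default) is passed through
    return 'linear' if linear_timesteps else train_config.timestep_type
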
@@ -386,7 +386,7 @@ class TrainConfig:
         # adds an additional loss to the network to encourage it to output a normalized standard deviation
         self.target_norm_std = kwargs.get('target_norm_std', None)
         self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
-        self.timestep_type = kwargs.get('timestep_type', 'sigmoid')  # sigmoid, linear
+        self.timestep_type = kwargs.get('timestep_type', 'sigmoid')  # sigmoid, linear, lognorm_blend
         self.linear_timesteps = kwargs.get('linear_timesteps', False)
         self.linear_timesteps2 = kwargs.get('linear_timesteps2', False)
         self.disable_sampling = kwargs.get('disable_sampling', False)
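
A hedged example of how these options might be supplied when building a TrainConfig (the constructor call itself is an assumption; per the kwargs.get pattern above, unspecified keys fall back to their defaults):

train_config = TrainConfig(
    timestep_type='lognorm_blend',  # 'sigmoid' (default), 'linear', or 'lognorm_blend'
    linear_timesteps=False,         # legacy flag; True still forces the linear schedule
)
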
@@ -470,6 +470,11 @@ class ModelConfig:
         if self.ignore_if_contains is not None or self.only_if_contains is not None:
             if not self.is_flux:
                 raise ValueError("ignore_if_contains and only_if_contains are only supported with flux models currently")
+
+        # splits the model over the available gpus WIP
+        self.split_model_over_gpus = kwargs.get("split_model_over_gpus", False)
+        if self.split_model_over_gpus and not self.is_flux:
+            raise ValueError("split_model_over_gpus is only supported with flux models currently")
 
 
 class EMAConfig:
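
And a matching hedged sketch for the new model option (again an assumed constructor call; is_flux is taken to be a plain kwarg here, and the guard above raises unless it is set):

model_config = ModelConfig(
    # ... other model fields elided ...
    is_flux=True,
    split_model_over_gpus=True,  # WIP: spread the flux transformer blocks across all visible GPUs
)
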
@@ -1,6 +1,9 @@
# forward that bypasses the guidance embedding so it can be avoided during training.
from functools import partial
from typing import Optional
import torch
from diffusers import FluxTransformer2DModel


def guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection):

@@ -33,3 +36,139 @@ def restore_flux_guidance(transformer):
        return
    transformer.time_text_embed.forward = transformer.time_text_embed._bfg_orig_forward
    del transformer.time_text_embed._bfg_orig_forward


def new_device_to(self: FluxTransformer2DModel, *args, **kwargs):
    # Store original device if provided in args or kwargs
    device_in_kwargs = 'device' in kwargs
    device_in_args = any(isinstance(arg, (str, torch.device)) for arg in args)

    device = None
    # Remove device from kwargs if present
    if device_in_kwargs:
        device = kwargs['device']
        del kwargs['device']

    # Only filter args if we detected a device argument
    if device_in_args:
        args = list(args)
        for idx, arg in enumerate(args):
            if isinstance(arg, (str, torch.device)):
                device = arg
                del args[idx]

    self.pos_embed = self.pos_embed.to(device, *args, **kwargs)
    self.time_text_embed = self.time_text_embed.to(device, *args, **kwargs)
    self.context_embedder = self.context_embedder.to(device, *args, **kwargs)
    self.x_embedder = self.x_embedder.to(device, *args, **kwargs)
    for block in self.transformer_blocks:
        block.to(block._split_device, *args, **kwargs)
    for block in self.single_transformer_blocks:
        block.to(block._split_device, *args, **kwargs)

    self.norm_out = self.norm_out.to(device, *args, **kwargs)
    self.proj_out = self.proj_out.to(device, *args, **kwargs)

    return self


def split_gpu_double_block_forward(
    self,
    hidden_states: torch.FloatTensor,
    encoder_hidden_states: torch.FloatTensor,
    temb: torch.FloatTensor,
    image_rotary_emb=None,
    joint_attention_kwargs=None,
):
    if hidden_states.device != self._split_device:
        hidden_states = hidden_states.to(self._split_device)
    if encoder_hidden_states.device != self._split_device:
        encoder_hidden_states = encoder_hidden_states.to(self._split_device)
    if temb.device != self._split_device:
        temb = temb.to(self._split_device)
    if image_rotary_emb is not None and image_rotary_emb[0].device != self._split_device:
        # is a tuple of tensors
        image_rotary_emb = tuple([t.to(self._split_device) for t in image_rotary_emb])
    return self._pre_gpu_split_forward(hidden_states, encoder_hidden_states, temb, image_rotary_emb, joint_attention_kwargs)
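
This wrapper (and the single-stream variant that follows) is installed per block by add_model_gpu_splitter_to_flux further down using functools.partial, which bakes the block in as the wrapper's self while keeping the original forward reachable. A toy, self-contained illustration of that save-and-wrap pattern (not code from this commit):

from functools import partial

import torch
import torch.nn as nn


def logging_forward(self, x):
    # `self` is the module baked in by partial(); delegate to the saved original forward
    out = self._orig_forward(x)
    print(f"{self.__class__.__name__} produced output on {out.device}")
    return out


layer = nn.Linear(4, 4)
layer._orig_forward = layer.forward           # same idea as _pre_gpu_split_forward
layer.forward = partial(logging_forward, layer)
layer(torch.randn(2, 4))                      # nn.Module.__call__ now hits the wrapper
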


def split_gpu_single_block_forward(
    self,
    hidden_states: torch.FloatTensor,
    temb: torch.FloatTensor,
    image_rotary_emb=None,
    joint_attention_kwargs=None,
    **kwargs
):
    if hidden_states.device != self._split_device:
        hidden_states = hidden_states.to(device=self._split_device)
    if temb.device != self._split_device:
        temb = temb.to(device=self._split_device)
    if image_rotary_emb is not None and image_rotary_emb[0].device != self._split_device:
        # is a tuple of tensors
        image_rotary_emb = tuple([t.to(self._split_device) for t in image_rotary_emb])

    hidden_state_out = self._pre_gpu_split_forward(hidden_states, temb, image_rotary_emb, joint_attention_kwargs, **kwargs)
    # the last single block carries _split_output_device and sends its output back to the primary gpu
    if hasattr(self, "_split_output_device"):
        return hidden_state_out.to(self._split_output_device)
    return hidden_state_out


def add_model_gpu_splitter_to_flux(transformer: FluxTransformer2DModel):
    gpu_id_list = [i for i in range(torch.cuda.device_count())]

    # if len(gpu_id_list) > 2:
    #     raise ValueError("Cannot split to more than 2 GPUs currently.")

    # ~ 5 billion for all other params
    other_module_params = 5e9
    # since they are not trainable, multiply by smaller number
    other_module_params *= 0.5

    # since we are not tuning the non-block modules, include them at this reduced weight
    total_params = sum(p.numel() for p in transformer.parameters()) + other_module_params

    params_per_gpu = total_params / len(gpu_id_list)

    current_gpu_idx = 0
    # text encoders, vae, and some non block layers will all be on gpu 0
    current_gpu_params = other_module_params

    for double_block in transformer.transformer_blocks:
        device = torch.device(f"cuda:{current_gpu_idx}")
        double_block._pre_gpu_split_forward = double_block.forward
        double_block.forward = partial(
            split_gpu_double_block_forward, double_block)
        double_block._split_device = device
        # add the params to the current gpu
        current_gpu_params += sum(p.numel() for p in double_block.parameters())
        # if the current gpu params are greater than the params per gpu, move to next gpu
        if current_gpu_params > params_per_gpu:
            current_gpu_idx += 1
            current_gpu_params = 0
            if current_gpu_idx >= len(gpu_id_list):
                current_gpu_idx = gpu_id_list[-1]

    for single_block in transformer.single_transformer_blocks:
        device = torch.device(f"cuda:{current_gpu_idx}")
        single_block._pre_gpu_split_forward = single_block.forward
        single_block.forward = partial(
            split_gpu_single_block_forward, single_block)
        single_block._split_device = device
        # add the params to the current gpu
        current_gpu_params += sum(p.numel() for p in single_block.parameters())
        # if the current gpu params are greater than the params per gpu, move to next gpu
        if current_gpu_params > params_per_gpu:
            current_gpu_idx += 1
            current_gpu_params = 0
            if current_gpu_idx >= len(gpu_id_list):
                current_gpu_idx = gpu_id_list[-1]

    # add output device to last layer
    transformer.single_transformer_blocks[-1]._split_output_device = torch.device("cuda:0")

    transformer._pre_gpu_split_to = transformer.to
    transformer.to = partial(new_device_to, transformer)
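
Putting the pieces together, a hedged end-to-end sketch of how the splitter might be applied on its own (the checkpoint id and dtype are assumptions; in this commit the actual call site is the StableDiffusion hook shown further down):

import torch
from diffusers import FluxTransformer2DModel

from toolkit.models.flux import add_model_gpu_splitter_to_flux

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",   # assumed checkpoint with a diffusers "transformer" subfolder
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
add_model_gpu_splitter_to_flux(transformer)

# .to() is now new_device_to: embedders, norms, and projections go to the device given here,
# while every transformer block moves to the cuda:N its _split_device was set to above.
transformer.to(torch.device("cuda:0"), dtype=torch.bfloat16)
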
@@ -1,6 +1,6 @@
import math
from typing import Union

from torch.distributions import LogNormal
from diffusers import FlowMatchEulerDiscreteScheduler
import torch

@@ -89,12 +89,12 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
     def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
         return sample
 
-    def set_train_timesteps(self, num_timesteps, device, linear=False):
-        if linear:
+    def set_train_timesteps(self, num_timesteps, device, timestep_type='linear'):
+        if timestep_type == 'linear':
             timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
             self.timesteps = timesteps
             return timesteps
-        else:
+        elif timestep_type == 'sigmoid':
             # distribute them closer to center. Inference distributes them as a bias toward first
             # Generate values from 0 to 1
             t = torch.sigmoid(torch.randn((num_timesteps,), device=device))

@@ -108,3 +108,25 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
            self.timesteps = timesteps.to(device=device)

            return timesteps
        elif timestep_type == 'lognorm_blend':
            # distribute timesteps toward the center/early and blend in linear
            alpha = 0.8

            lognormal = LogNormal(loc=0, scale=0.333)

            # Sample from the distribution
            t1 = lognormal.sample((int(num_timesteps * alpha),)).to(device)

            # Scale and reverse the values to go from 1000 to 0
            t1 = ((1 - t1 / t1.max()) * 1000)

            # blend in the remaining (1 - alpha) portion as a plain linear schedule
            t2 = torch.linspace(1000, 0, int(num_timesteps * (1 - alpha)), device=device)
            timesteps = torch.cat((t1, t2))

            # Sort the timesteps in descending order
            timesteps, _ = torch.sort(timesteps, descending=True)

            timesteps = timesteps.to(torch.int)
        else:
            raise ValueError(f"Invalid timestep type: {timestep_type}")
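
With the changes above, training code picks a schedule by name instead of a boolean. For 'lognorm_blend' with alpha = 0.8, roughly 80% of the steps come from the reversed log-normal (massed toward the center/early part of the 0-1000 range, per the comment above) and the rest from a plain linspace. A brief hedged usage sketch (scheduler construction elided):

# timestep_type may be 'linear', 'sigmoid' (the previous default behavior), or 'lognorm_blend';
# anything else hits the ValueError above.
scheduler.set_train_timesteps(
    1000,
    device=torch.device("cuda:0"),
    timestep_type="lognorm_blend",
)
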
@@ -60,7 +60,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
 
 from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
 from huggingface_hub import hf_hub_download
-from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
+from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance
 
 from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
 from typing import TYPE_CHECKING

@@ -553,6 +553,10 @@ class StableDiffusion:
                 # low_cpu_mem_usage=False,
                 # device_map=None
             )
+            # hack in model gpu splitter
+            if self.model_config.split_model_over_gpus:
+                add_model_gpu_splitter_to_flux(transformer)
+
             if not self.low_vram:
                 # for low VRAM, we leave it on the CPU. Quantizes slower, but allows training on the primary gpu
                 transformer.to(torch.device(self.quantize_device), dtype=dtype)
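
A hedged way to sanity-check the split after loading with split_model_over_gpus enabled (attribute names come from the splitter above; the surrounding setup is assumed):

for i, block in enumerate(transformer.transformer_blocks):
    print(f"double block {i}: assigned={block._split_device}, actual={next(block.parameters()).device}")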