Experimental features and bug fixes

Jaret Burkett
2025-02-04 13:36:34 -07:00
parent e6180d1e1d
commit 216ab164ce
6 changed files with 26 additions and 16 deletions

View File

@@ -1839,7 +1839,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         did_first_flush = False
         for step in range(start_step_num, self.train_config.steps):
             if self.train_config.do_paramiter_swapping:
-                self.optimizer.swap_paramiters()
+                self.optimizer.optimizer.swap_paramiters()
             self.timer.start('train_loop')
             if self.train_config.do_random_cfg:
                 self.train_config.do_cfg = True
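The one-line fix above reaches through an extra `.optimizer` level, which implies the train process's optimizer is a wrapper object (for example an Accelerate-style wrapper that keeps the real optimizer on its `.optimizer` attribute) while `swap_paramiters()` lives on the inner Adafactor. A minimal sketch of that pattern, using stand-in class names rather than the toolkit's:

import torch

class SwappingOptimizer(torch.optim.SGD):
    # stand-in for the toolkit's Adafactor with do_paramiter_swapping enabled
    def swap_paramiters(self):
        print("swapping the active parameter subset")

class OptimizerWrapper:
    # stand-in for whatever wraps self.optimizer in the train loop;
    # it proxies step() but not the toolkit-specific swap_paramiters()
    def __init__(self, optimizer):
        self.optimizer = optimizer

    def step(self):
        self.optimizer.step()

params = [torch.nn.Parameter(torch.zeros(4))]
wrapped = OptimizerWrapper(SwappingOptimizer(params, lr=1e-3))
wrapped.optimizer.swap_paramiters()  # works; wrapped.swap_paramiters() would raise AttributeError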

View File

@@ -483,6 +483,7 @@ class ModelConfig:
         self.split_model_over_gpus = kwargs.get("split_model_over_gpus", False)
         if self.split_model_over_gpus and not self.is_flux:
             raise ValueError("split_model_over_gpus is only supported with flux models currently")
+        self.split_model_other_module_param_count_scale = kwargs.get("split_model_other_module_param_count_scale", 0.3)
 
 
 class EMAConfig:
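The new option is read with the usual `kwargs.get` fallback, so existing configs keep working and the scale defaults to 0.3. A tiny sketch of that lookup, with stand-in dicts in place of the model section of a job config:

# Stand-in dicts; only the keys relevant to this change are shown.
defaults = {"split_model_over_gpus": True}
explicit = {"split_model_over_gpus": True,
            "split_model_other_module_param_count_scale": 0.25}

for kwargs in (defaults, explicit):
    scale = kwargs.get("split_model_other_module_param_count_scale", 0.3)
    print(scale)  # 0.3 when omitted, 0.25 when set explicitly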

View File

@@ -250,13 +250,14 @@ class DiffusionFeatureExtractor3(nn.Module):
         bs = noise_pred.shape[0]
         noise_pred_chunks = torch.chunk(noise_pred, bs)
         timestep_chunks = torch.chunk(timesteps, bs)
+        noisy_latent_chunks = torch.chunk(noisy_latents, bs)
         stepped_chunks = []
         for idx in range(bs):
             model_output = noise_pred_chunks[idx]
             timestep = timestep_chunks[idx]
             scheduler._step_index = None
             scheduler._init_step_index(timestep)
-            sample = noisy_latents.to(torch.float32)
+            sample = noisy_latent_chunks[idx].to(torch.float32)
             sigma = scheduler.sigmas[scheduler.step_index]
             sigma_next = scheduler.sigmas[-1]  # use last sigma for final step
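This is the actual bug fix in the loop: `torch.chunk` splits the batch into per-sample tensors, and the old code fed the whole `noisy_latents` batch into every per-sample step instead of that sample's own latent. A self-contained sketch of the corrected per-sample stepping, with illustrative sigma values standing in for the scheduler's:

import torch

bs = 2
noise_pred = torch.randn(bs, 4, 8, 8)
noisy_latents = torch.randn(bs, 4, 8, 8)

noise_pred_chunks = torch.chunk(noise_pred, bs)        # bs tensors, each (1, 4, 8, 8)
noisy_latent_chunks = torch.chunk(noisy_latents, bs)

stepped_chunks = []
for idx in range(bs):
    model_output = noise_pred_chunks[idx]
    sample = noisy_latent_chunks[idx].to(torch.float32)  # this sample only, as in the fix
    sigma, sigma_next = 0.7, 0.0                         # illustrative, in place of scheduler.sigmas
    # Euler-style flow-matching step toward the final sigma
    stepped_chunks.append(sample + (sigma_next - sigma) * model_output)

stepped = torch.cat(stepped_chunks)
assert stepped.shape == noisy_latents.shape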

View File

@@ -117,16 +117,18 @@ def split_gpu_single_block_forward(
     return hidden_state_out
 
 
-def add_model_gpu_splitter_to_flux(transformer: FluxTransformer2DModel):
+def add_model_gpu_splitter_to_flux(
+    transformer: FluxTransformer2DModel,
+    # ~ 5 billion for all other params
+    other_module_params: Optional[int] = 5e9,
+    # since they are not trainable, multiply by smaller number
+    other_module_param_count_scale: Optional[float] = 0.3
+):
     gpu_id_list = [i for i in range(torch.cuda.device_count())]
     # if len(gpu_id_list) > 2:
     #     raise ValueError("Cannot split to more than 2 GPUs currently.")
-    # ~ 5 billion for all other params
-    other_module_params = 5e9
-    # since they are not trainable, multiply by smaller number
-    other_module_params *= 0.5
+    other_module_params *= other_module_param_count_scale
 
     # since we are not tuning the
     total_params = sum(p.numel() for p in transformer.parameters()) + other_module_params
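The splitter budgets GPUs by adding a discounted estimate of the frozen non-transformer modules (text encoders, VAE) to the transformer's own parameter count; the new arguments make that estimate and its discount configurable instead of hard-coding 5e9 * 0.5. A back-of-envelope version of the arithmetic, with made-up numbers:

import torch

transformer_params = 12e9               # roughly FLUX-scale, purely illustrative
other_module_params = 5e9               # default estimate for everything outside the transformer
other_module_param_count_scale = 0.3    # new default discount for frozen, untrained modules

total_params = transformer_params + other_module_params * other_module_param_count_scale
n_gpus = max(torch.cuda.device_count(), 1)   # fall back to 1 so this sketch also runs on CPU
print(f"~{total_params / n_gpus / 1e9:.1f}B parameters' worth of modules per GPU")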

View File

@@ -108,6 +108,7 @@ class Adafactor(torch.optim.Optimizer):
         warmup_init=False,
         do_paramiter_swapping=False,
         paramiter_swapping_factor=0.1,
+        stochastic_accumulation=True,
     ):
         if lr is not None and relative_step:
             raise ValueError(
@@ -136,13 +137,14 @@ class Adafactor(torch.optim.Optimizer):
         self.is_stochastic_rounding_accumulation = False
 
         # setup stochastic grad accum hooks
-        for group in self.param_groups:
-            for param in group['params']:
-                if param.requires_grad and param.dtype != torch.float32:
-                    self.is_stochastic_rounding_accumulation = True
-                    param.register_post_accumulate_grad_hook(
-                        stochastic_grad_accummulation
-                    )
+        if stochastic_accumulation:
+            for group in self.param_groups:
+                for param in group['params']:
+                    if param.requires_grad and param.dtype != torch.float32:
+                        self.is_stochastic_rounding_accumulation = True
+                        param.register_post_accumulate_grad_hook(
+                            stochastic_grad_accummulation
+                        )
 
         self.do_paramiter_swapping = do_paramiter_swapping
         self.paramiter_swapping_factor = paramiter_swapping_factor
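These hooks are the mechanism behind the stochastic-rounding accumulation: `Tensor.register_post_accumulate_grad_hook` (PyTorch 2.1+) fires once per parameter after its gradient has been accumulated in a backward pass, and registration only happens for trainable parameters stored below float32. The new `stochastic_accumulation` flag simply allows that registration to be skipped. A minimal, self-contained sketch of the hook machinery; the hook body here is a stand-in, not the toolkit's `stochastic_grad_accummulation`:

import torch

def post_accumulate_hook(param: torch.Tensor) -> None:
    # Stand-in hook: runs once per parameter after its .grad is accumulated.
    # The real hook presumably folds the low-precision gradient into a
    # higher-precision buffer with stochastic rounding instead of printing.
    print("hook fired:", tuple(param.shape), param.grad.dtype)

p = torch.nn.Parameter(torch.zeros(8, dtype=torch.bfloat16))
if p.requires_grad and p.dtype != torch.float32:
    p.register_post_accumulate_grad_hook(post_accumulate_hook)

loss = (p.float() ** 2).sum()
loss.backward()   # the hook fires here, after the gradient lands on the bf16 parameter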

View File

@@ -66,6 +66,7 @@ from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
 from toolkit.accelerator import get_accelerator, unwrap_model
 from typing import TYPE_CHECKING
 from toolkit.print import print_acc
+from diffusers import FluxFillPipeline
 
 if TYPE_CHECKING:
     from toolkit.lora_special import LoRASpecialNetwork
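The `FluxFillPipeline` import is the only visible trace of the "experimental features" in this file; the hunks shown do not include where it is used. For orientation, it is the diffusers pipeline for the FLUX Fill (inpainting/outpainting) models, loaded via the usual `from_pretrained` call. The checkpoint name below is an assumption based on the published Fill weights, not something taken from this diff, and whether the toolkit loads it this way is not shown here:

import torch
from diffusers import FluxFillPipeline

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev",   # assumed checkpoint, not referenced in the diff
    torch_dtype=torch.bfloat16,
)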
@@ -560,7 +561,10 @@ class StableDiffusion:
             )
             # hack in model gpu splitter
             if self.model_config.split_model_over_gpus:
-                add_model_gpu_splitter_to_flux(transformer)
+                add_model_gpu_splitter_to_flux(
+                    transformer,
+                    other_module_param_count_scale=self.model_config.split_model_other_module_param_count_scale
+                )
 
             if not self.low_vram:
                 # for low v ram, we leave it on the cpu. Quantizes slower, but allows training on primary gpu