mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Experimental features and bug fixes
@@ -1839,7 +1839,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         did_first_flush = False
         for step in range(start_step_num, self.train_config.steps):
             if self.train_config.do_paramiter_swapping:
-                self.optimizer.swap_paramiters()
+                self.optimizer.optimizer.swap_paramiters()
             self.timer.start('train_loop')
             if self.train_config.do_random_cfg:
                 self.train_config.do_cfg = True
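The change above suggests `self.optimizer` is a wrapper rather than the Adafactor instance itself, so `swap_paramiters()` has to be called one level deeper. A minimal sketch of that shape, assuming the wrapper is something like accelerate's `AcceleratedOptimizer` (which keeps the real optimizer in its `.optimizer` attribute); `SwapOpt` below is a hypothetical stand-in for the toolkit's optimizer, not its actual class:

    import torch
    from accelerate import Accelerator

    class SwapOpt(torch.optim.SGD):
        # stand-in for the toolkit's Adafactor with parameter swapping
        def swap_paramiters(self):
            print("swapping offloaded parameter groups")

    model = torch.nn.Linear(8, 8)
    opt = Accelerator().prepare(SwapOpt(model.parameters(), lr=1e-3))

    # opt.swap_paramiters()          # the wrapper does not proxy custom methods
    opt.optimizer.swap_paramiters()  # the fix: reach the wrapped optimizer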
@@ -483,6 +483,7 @@ class ModelConfig:
         self.split_model_over_gpus = kwargs.get("split_model_over_gpus", False)
         if self.split_model_over_gpus and not self.is_flux:
             raise ValueError("split_model_over_gpus is only supported with flux models currently")
+        self.split_model_other_module_param_count_scale = kwargs.get("split_model_other_module_param_count_scale", 0.3)


 class EMAConfig:
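The new `split_model_other_module_param_count_scale` option rides through `ModelConfig`'s kwargs with a default of 0.3 and is consumed by the GPU splitter further down. A hypothetical construction sketch; the other field names here (`name_or_path`, `is_flux` as a direct kwarg) are assumptions about the surrounding config, not taken from this diff:

    model_config = ModelConfig(
        name_or_path="black-forest-labs/FLUX.1-dev",  # assumed field
        is_flux=True,                 # splitting is flux-only per the check above
        split_model_over_gpus=True,
        # weight given to the non-transformer modules when balancing GPUs
        split_model_other_module_param_count_scale=0.3,
    )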
@@ -250,13 +250,14 @@ class DiffusionFeatureExtractor3(nn.Module):
         bs = noise_pred.shape[0]
         noise_pred_chunks = torch.chunk(noise_pred, bs)
         timestep_chunks = torch.chunk(timesteps, bs)
+        noisy_latent_chunks = torch.chunk(noisy_latents, bs)
         stepped_chunks = []
         for idx in range(bs):
             model_output = noise_pred_chunks[idx]
             timestep = timestep_chunks[idx]
             scheduler._step_index = None
             scheduler._init_step_index(timestep)
-            sample = noisy_latents.to(torch.float32)
+            sample = noisy_latent_chunks[idx].to(torch.float32)

             sigma = scheduler.sigmas[scheduler.step_index]
             sigma_next = scheduler.sigmas[-1]  # use last sigma for final step
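The second change in this hunk is the actual bug fix: `sample` was previously rebuilt from the full `noisy_latents` batch on every loop iteration, so each per-sample scheduler step consumed a batch-sized tensor instead of the one sample matching its `noise_pred` and `timestep` chunks. A small self-contained check of the chunking behavior the fix relies on:

    import torch

    bs = 4
    noisy_latents = torch.randn(bs, 16, 32, 32)
    noisy_latent_chunks = torch.chunk(noisy_latents, bs)

    # one chunk per sample, aligned with the noise_pred / timestep chunks
    assert len(noisy_latent_chunks) == bs
    assert noisy_latent_chunks[0].shape == (1, 16, 32, 32)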
@@ -117,16 +117,18 @@ def split_gpu_single_block_forward(
     return hidden_state_out


-def add_model_gpu_splitter_to_flux(transformer: FluxTransformer2DModel):
+def add_model_gpu_splitter_to_flux(
+    transformer: FluxTransformer2DModel,
+    # ~ 5 billion for all other params
+    other_module_params: Optional[int] = 5e9,
+    # since they are not trainable, multiply by smaller number
+    other_module_param_count_scale: Optional[float] = 0.3
+):
     gpu_id_list = [i for i in range(torch.cuda.device_count())]

     # if len(gpu_id_list) > 2:
     #     raise ValueError("Cannot split to more than 2 GPUs currently.")

-    # ~ 5 billion for all other params
-    other_module_params = 5e9
-    # since they are not trainable, multiply by smaller number
-    other_module_params *= 0.5
+    other_module_params *= other_module_param_count_scale

     # since we are not tuning the
     total_params = sum(p.numel() for p in transformer.parameters()) + other_module_params
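For context on the numbers: the splitter balances transformer blocks across GPUs by parameter count, and the non-transformer modules (text encoders, VAE, roughly 5e9 parameters) are down-weighted because they are frozen. The old code hard-coded that weight at 0.5; the new signature exposes it, defaulting to 0.3. A rough sketch of the balancing arithmetic, using a hypothetical `plan_split` helper rather than the toolkit's actual assignment code, and assuming the frozen modules share a device with the first blocks:

    import torch

    def plan_split(block_param_counts, other_module_params=5e9,
                   other_module_param_count_scale=0.3):
        other = other_module_params * other_module_param_count_scale
        total = sum(block_param_counts) + other
        n_gpus = max(torch.cuda.device_count(), 1)
        per_gpu = total / n_gpus
        assignments, gpu, used = [], 0, other  # GPU 0 also hosts the other modules
        for count in block_param_counts:
            if used + count > per_gpu and gpu < n_gpus - 1:
                gpu, used = gpu + 1, 0
            assignments.append(gpu)
            used += count
        return assignments

Lowering the scale from 0.5 to 0.3 shrinks the budget reserved next to the frozen modules, shifting more transformer blocks onto that first GPU, and the new kwarg lets users tune the balance per machine.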
@@ -108,6 +108,7 @@ class Adafactor(torch.optim.Optimizer):
         warmup_init=False,
         do_paramiter_swapping=False,
         paramiter_swapping_factor=0.1,
+        stochastic_accumulation=True,
     ):
         if lr is not None and relative_step:
             raise ValueError(
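`stochastic_accumulation` defaults to True, so existing configs keep the old behavior and the new flag is purely an opt-out. A usage sketch, assuming this Adafactor class is importable, and noting that a manual `lr` requires `relative_step=False` per the check that follows:

    import torch

    model = torch.nn.Linear(8, 8).to(torch.bfloat16)
    optimizer = Adafactor(              # the toolkit's Adafactor from this hunk
        model.parameters(),
        lr=1e-4,
        relative_step=False,            # manual lr cannot combine with relative_step
        stochastic_accumulation=False,  # opt out of the grad-accumulation hooks
    )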
@@ -136,13 +137,14 @@ class Adafactor(torch.optim.Optimizer):
         self.is_stochastic_rounding_accumulation = False

         # setup stochastic grad accum hooks
-        for group in self.param_groups:
-            for param in group['params']:
-                if param.requires_grad and param.dtype != torch.float32:
-                    self.is_stochastic_rounding_accumulation = True
-                    param.register_post_accumulate_grad_hook(
-                        stochastic_grad_accummulation
-                    )
+        if stochastic_accumulation:
+            for group in self.param_groups:
+                for param in group['params']:
+                    if param.requires_grad and param.dtype != torch.float32:
+                        self.is_stochastic_rounding_accumulation = True
+                        param.register_post_accumulate_grad_hook(
+                            stochastic_grad_accummulation
+                        )

         self.do_paramiter_swapping = do_paramiter_swapping
         self.paramiter_swapping_factor = paramiter_swapping_factor
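The hooks being guarded use PyTorch's `Tensor.register_post_accumulate_grad_hook` (added in PyTorch 2.1), which fires after a backward pass finishes accumulating into `param.grad`; that is the point where `stochastic_grad_accummulation` can apply stochastic rounding for non-fp32 parameters. A toy demonstration of just the hook plumbing, with a print standing in for the real rounding logic:

    import torch

    p = torch.nn.Parameter(torch.zeros(3, dtype=torch.bfloat16))

    def demo_hook(param):  # stand-in for stochastic_grad_accummulation
        print("post-accumulate:", param.grad.dtype)

    p.register_post_accumulate_grad_hook(demo_hook)
    p.float().sum().backward()  # hook fires once the grad lands on p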
@@ -66,6 +66,7 @@ from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
 from toolkit.accelerator import get_accelerator, unwrap_model
 from typing import TYPE_CHECKING
 from toolkit.print import print_acc
+from diffusers import FluxFillPipeline

 if TYPE_CHECKING:
     from toolkit.lora_special import LoRASpecialNetwork
@@ -560,7 +561,10 @@ class StableDiffusion:
             )
             # hack in model gpu splitter
             if self.model_config.split_model_over_gpus:
-                add_model_gpu_splitter_to_flux(transformer)
+                add_model_gpu_splitter_to_flux(
+                    transformer,
+                    other_module_param_count_scale=self.model_config.split_model_other_module_param_count_scale
+                )

         if not self.low_vram:
             # for low v ram, we leave it on the cpu. Quantizes slower, but allows training on primary gpu