diff --git a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py
index 79886ea4..a1b75d49 100644
--- a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py
+++ b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py
@@ -1,6 +1,6 @@
 from functools import partial
 import os
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Union, List
 from typing_extensions import Self
 import torch
 import yaml
@@ -134,13 +134,15 @@ class DualWanTransformer3DModel(torch.nn.Module):
             getattr(self, t_name).to(self.device_torch)
             torch.cuda.empty_cache()
             self._active_transformer_name = t_name
-
+
         if self.transformer.device != hidden_states.device:
             if self.low_vram:
                 # move other transformer to cpu
-                other_tname = 'transformer_1' if t_name == 'transformer_2' else 'transformer_2'
+                other_tname = (
+                    "transformer_1" if t_name == "transformer_2" else "transformer_2"
+                )
                 getattr(self, other_tname).to("cpu")
-
+
             self.transformer.to(hidden_states.device)

         return self.transformer(
@@ -184,11 +186,33 @@ class Wan2214bModel(Wan225bModel):
         self.target_lora_modules = ["DualWanTransformer3DModel"]
         self._wan_cache = None

+        self.is_multistage = True
+        # multistage boundaries split the models up when sampling timesteps
+        # for wan 2.2 14b. the timesteps are 1000-875 for transformer 1 and 875-0 for transformer 2
+        self.multistage_boundaries: List[float] = [0.875, 0.0]
+
+        self.train_high_noise = model_config.model_kwargs.get("train_high_noise", True)
+        self.train_low_noise = model_config.model_kwargs.get("train_low_noise", True)
+
+        self.trainable_multistage_boundaries: List[int] = []
+        if self.train_high_noise:
+            self.trainable_multistage_boundaries.append(0)
+        if self.train_low_noise:
+            self.trainable_multistage_boundaries.append(1)
+
+        if len(self.trainable_multistage_boundaries) == 0:
+            raise ValueError(
+                "At least one of train_high_noise or train_low_noise must be True in model.model_kwargs"
+            )
+
     @property
     def max_step_saves_to_keep_multiplier(self):
         # the cleanup mechanism checks this to see how many saves to keep
         # if we are training a LoRA, we need to set this to 2 so we keep both the high noise and low noise LoRAs at saves to keep
-        if self.network is not None:
+        if (
+            self.network is not None
+            and self.network.network_config.split_multistage_loras
+        ):
             return 2
         return 1
@@ -264,7 +288,7 @@ class Wan2214bModel(Wan225bModel):
         transformer_1.to(self.quantize_device, dtype=dtype)
         flush()

-        if self.model_config.quantize:
+        if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None: # todo handle two ARAs
             self.print_and_status_update("Quantizing Transformer 1")
             quantize_model(self, transformer_1)
@@ -289,7 +313,7 @@ class Wan2214bModel(Wan225bModel):
         transformer_2.to(self.quantize_device, dtype=dtype)
         flush()

-        if self.model_config.quantize:
+        if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None: # todo handle two ARAs
             self.print_and_status_update("Quantizing Transformer 2")
             quantize_model(self, transformer_2)
@@ -309,7 +333,13 @@ class Wan2214bModel(Wan225bModel):
             boundary_ratio=boundary_ratio_t2v,
             low_vram=self.model_config.low_vram,
         )
-
+
+        if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is not None:
+            # apply the accuracy recovery adapter to both transformers
+            self.print_and_status_update("Applying Accuracy Recovery Adapter to Transformers")
+            quantize_model(self, transformer)
+            flush()
+
         return transformer
     def get_generation_pipeline(self):
@@ -407,17 +437,20 @@ class Wan2214bModel(Wan225bModel):
             # just save as a combo lora
             save_file(state_dict, output_path, metadata=metadata)
             return
-
+
         # we need to build out both dictionaries for high and low noise LoRAs
         high_noise_lora = {}
         low_noise_lora = {}
+
+        only_train_high_noise = self.train_high_noise and not self.train_low_noise
+        only_train_low_noise = self.train_low_noise and not self.train_high_noise
         for key in state_dict:
-            if ".transformer_1." in key:
+            if ".transformer_1." in key or only_train_high_noise:
                 # this is a high noise LoRA
                 new_key = key.replace(".transformer_1.", ".")
                 high_noise_lora[new_key] = state_dict[key]
-            elif ".transformer_2." in key:
+            elif ".transformer_2." in key or only_train_low_noise:
                 # this is a low noise LoRA
                 new_key = key.replace(".transformer_2.", ".")
                 low_noise_lora[new_key] = state_dict[key]
@@ -439,11 +472,14 @@ class Wan2214bModel(Wan225bModel):
     def load_lora(self, file: str):
         # if it doesnt have high_noise or low_noise, it is a combo LoRA
-        if "_high_noise.safetensors" not in file and "_low_noise.safetensors" not in file:
-            # this is a combined LoRA, we need to split it up
+        if (
+            "_high_noise.safetensors" not in file
+            and "_low_noise.safetensors" not in file
+        ):
+            # this is a combined LoRA, we don't need to split it up
             sd = load_file(file)
             return sd
-
+
         # we may have been passed the high_noise or the low_noise LoRA path, but we need to load both
         high_noise_lora_path = file.replace(
             "_low_noise.safetensors", "_high_noise.safetensors"
         )
@@ -454,7 +490,7 @@ class Wan2214bModel(Wan225bModel):

         combined_dict = {}

-        if os.path.exists(high_noise_lora_path):
+        if os.path.exists(high_noise_lora_path) and self.train_high_noise:
             # load the high noise LoRA
             high_noise_lora = load_file(high_noise_lora_path)
             for key in high_noise_lora:
@@ -462,7 +498,7 @@ class Wan2214bModel(Wan225bModel):
                     "diffusion_model.", "diffusion_model.transformer_1."
                 )
                 combined_dict[new_key] = high_noise_lora[key]
-        if os.path.exists(low_noise_lora_path):
+        if os.path.exists(low_noise_lora_path) and self.train_low_noise:
             # load the low noise LoRA
             low_noise_lora = load_file(low_noise_lora_path)
             for key in low_noise_lora:
@@ -470,5 +506,35 @@ class Wan2214bModel(Wan225bModel):
                     "diffusion_model.", "diffusion_model.transformer_2."
                 )
                 combined_dict[new_key] = low_noise_lora[key]
+
+        # if we are not training both stages, we won't have transformer designations in the keys
+        if not self.train_high_noise and not self.train_low_noise:
+            new_dict = {}
+            for key in combined_dict:
+                if ".transformer_1." in key:
+                    new_key = key.replace(".transformer_1.", ".")
+                elif ".transformer_2." in key:
+                    new_key = key.replace(".transformer_2.", ".")
+                else:
+                    new_key = key
+                new_dict[new_key] = combined_dict[key]
+            combined_dict = new_dict

         return combined_dict
+
+    def get_model_to_train(self):
+        # todo, LoRAs won't load right unless they have the transformer_1 or transformer_2 in the key.
+        # called when setting up the LoRA. We only need to get the model for the stages we want to train.
+        if self.train_high_noise and self.train_low_noise:
+            # we are training both stages, return the unified model
+            return self.model
+        elif self.train_high_noise:
+            # we are only training the high noise stage, return transformer_1
+            return self.model.transformer_1
+        elif self.train_low_noise:
+            # we are only training the low noise stage, return transformer_2
+            return self.model.transformer_2
+        else:
+            raise ValueError(
+                "At least one of train_high_noise or train_low_noise must be True in model.model_kwargs"
+            )
diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 22fd465a..8ca654af 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -1862,7 +1862,20 @@ class SDTrainer(BaseSDTrainProcess):
         total_loss = None
         self.optimizer.zero_grad()
         for batch in batch_list:
+            if self.sd.is_multistage:
+                # handle multistage switching
+                if self.steps_this_boundary >= self.train_config.switch_boundary_every:
+                    # iterate to make sure we only train trainable_multistage_boundaries
+                    while True:
+                        self.steps_this_boundary = 0
+                        self.current_boundary_index += 1
+                        if self.current_boundary_index >= len(self.sd.multistage_boundaries):
+                            self.current_boundary_index = 0
+                        if self.current_boundary_index in self.sd.trainable_multistage_boundaries:
+                            # if this boundary is trainable, we can stop looking
+                            break
             loss = self.train_single_accumulation(batch)
+            self.steps_this_boundary += 1
             if total_loss is None:
                 total_loss = loss
             else:
@@ -1907,7 +1920,7 @@ class SDTrainer(BaseSDTrainProcess):
                 self.adapter.restore_embeddings()

         loss_dict = OrderedDict(
-            {'loss': loss.item()}
+            {'loss': (total_loss / len(batch_list)).item()}
         )

         self.end_of_training_loop()
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index daf3d33c..80e2e226 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -260,6 +260,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     torch.profiler.ProfilerActivity.CUDA,
                 ],
             )
+
+        self.current_boundary_index = 0
+        self.steps_this_boundary = 0

     def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
         # override in subclass
@@ -1171,6 +1174,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.sd.noise_scheduler.set_timesteps(
                 num_train_timesteps, device=self.device_torch
             )
+            if self.sd.is_multistage:
+                with self.timer('adjust_multistage_timesteps'):
+                    # get our current sample range
+                    boundaries = [1000] + self.sd.multistage_boundaries
+                    boundary_max, boundary_min = boundaries[self.current_boundary_index], boundaries[self.current_boundary_index + 1]
+                    lo = torch.searchsorted(self.sd.noise_scheduler.timesteps, -torch.tensor(boundary_max, device=self.sd.noise_scheduler.timesteps.device), right=False)
+                    hi = torch.searchsorted(self.sd.noise_scheduler.timesteps, -torch.tensor(boundary_min, device=self.sd.noise_scheduler.timesteps.device), right=True)
+                    first_idx = lo.item() if hi > lo else 0
+                    last_idx = (hi - 1).item() if hi > lo else 999
+
+                    min_noise_steps = first_idx
+                    max_noise_steps = last_idx
+
+                    # clip min max indices
+                    min_noise_steps = max(min_noise_steps, 0)
+                    max_noise_steps = min(max_noise_steps, num_train_timesteps - 1)
+
+
         with self.timer('prepare_timesteps_indices'):
             content_or_style = self.train_config.content_or_style
@@ -1209,11 +1230,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
                         0,
                         self.train_config.num_train_timesteps - 1,
                         min_noise_steps,
-                        max_noise_steps - 1
+                        max_noise_steps
                     )
                     timestep_indices = timestep_indices.long().clamp(
-                        min_noise_steps + 1,
-                        max_noise_steps - 1
+                        min_noise_steps,
+                        max_noise_steps
                     )

             elif content_or_style == 'balanced':
@@ -1226,7 +1247,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 if self.train_config.noise_scheduler == 'flowmatch':
                     # flowmatch uses indices, so we need to use indices
                     min_idx = 0
-                    max_idx = max_noise_steps - 1
+                    max_idx = max_noise_steps
                     timestep_indices = torch.randint(
                         min_idx,
                         max_idx,
@@ -1676,7 +1697,7 @@ class BaseSDTrainProcess(BaseTrainProcess):

                 self.network = NetworkClass(
                     text_encoder=text_encoder,
-                    unet=unet,
+                    unet=self.sd.get_model_to_train(),
                     lora_dim=self.network_config.linear,
                     multiplier=1.0,
                     alpha=self.network_config.linear_alpha,
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 81415a10..403abc54 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -335,7 +335,7 @@ class TrainConfig:
         self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
         self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
         self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0)
-        self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
+        self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 999)
         self.batch_size: int = kwargs.get('batch_size', 1)
         self.orig_batch_size: int = self.batch_size
         self.dtype: str = kwargs.get('dtype', 'fp32')
@@ -515,6 +515,9 @@ class TrainConfig:
         self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '')
         if isinstance(self.guidance_loss_target, tuple):
             self.guidance_loss_target = list(self.guidance_loss_target)
+
+        # for multi stage models, how often to switch the boundary
+        self.switch_boundary_every: int = kwargs.get('switch_boundary_every', 1)


 ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex1', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21']
diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py
index 412e04d4..d446c25e 100644
--- a/toolkit/models/base_model.py
+++ b/toolkit/models/base_model.py
@@ -172,6 +172,11 @@ class BaseModel:
         self.sample_prompts_cache = None

         self.accuracy_recovery_adapter: Union[None, 'LoRASpecialNetwork'] = None
+        self.is_multistage = False
+        # a list of multistage boundaries starting with train step 1000 to first idx
+        self.multistage_boundaries: List[float] = [0.0]
+        # a list of trainable multistage boundaries
+        self.trainable_multistage_boundaries: List[int] = [0]

     # properties for old arch for backwards compatibility
     @property
@@ -1502,3 +1507,7 @@ class BaseModel:
     def get_base_model_version(self) -> str:
         # override in child classes to get the base model version
         return "unknown"
+
+    def get_model_to_train(self):
+        # called to get model to attach LoRAs to. Can be overridden in child classes
+        return self.unet
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 183cbb8d..86908884 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -211,6 +211,12 @@ class StableDiffusion:
         self.sample_prompts_cache = None

+        self.is_multistage = False
+        # a list of multistage boundaries starting with train step 1000 to first idx
+        self.multistage_boundaries: List[float] = [0.0]
+        # a list of trainable multistage boundaries
+        self.trainable_multistage_boundaries: List[int] = [0]
+
     # properties for old arch for backwards compatibility
     @property
     def is_xl(self):
@@ -3123,3 +3129,6 @@ class StableDiffusion:
         if self.is_v2:
             return 'sd_2.1'
         return 'sd_1.5'
+
+    def get_model_to_train(self):
+        return self.unet
diff --git a/ui/src/app/jobs/new/SimpleJob.tsx b/ui/src/app/jobs/new/SimpleJob.tsx
index b2e31294..c9c4d0f8 100644
--- a/ui/src/app/jobs/new/SimpleJob.tsx
+++ b/ui/src/app/jobs/new/SimpleJob.tsx
@@ -40,10 +40,25 @@ export default function SimpleJob({
   const isVideoModel = !!(modelArch?.group === 'video');

-  let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-5 gap-6';
+  const numTopCards = useMemo(() => {
+    let count = 4; // job settings, model config, target config, save config
+    if (modelArch?.additionalSections?.includes('model.multistage')) {
+      count += 1; // add multistage card
+    }
+    if (!modelArch?.disableSections?.includes('model.quantize')) {
+      count += 1; // add quantization card
+    }
+    return count;
+
+  }, [modelArch]);

-  if (modelArch?.disableSections?.includes('model.quantize')) {
-    topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
+  let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
+
+  if (numTopCards == 5) {
+    topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-5 gap-6';
+  }
+  if (numTopCards == 6) {
+    topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-6 gap-6';
   }

   const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => {
@@ -91,7 +106,7 @@ export default function SimpleJob({
       <>
       {/* Model Configuration Section */}
          )}
+        {modelArch?.additionalSections?.includes('model.multistage') && (
+              setJobConfig(value, 'config.process[0].model.model_kwargs.train_high_noise')}
+            />
+              setJobConfig(value, 'config.process[0].model.model_kwargs.train_low_noise')}
+            />
+              setJobConfig(value, 'config.process[0].train.switch_boundary_every')}
+              placeholder="eg. 1"
+              docKey={'train.switch_boundary_every'}
+              min={1}
+              required
+            />
+          )}
        )}