Added multistage training to the UI. Still testing, but everything seems to be working.

Jaret Burkett
2025-08-16 05:51:37 -06:00
parent ca7bfa414b
commit 8ea2cf00f6
12 changed files with 268 additions and 39 deletions

View File

@@ -1,6 +1,6 @@
from functools import partial
import os
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Union, List
from typing_extensions import Self
import torch
import yaml
@@ -134,13 +134,15 @@ class DualWanTransformer3DModel(torch.nn.Module):
getattr(self, t_name).to(self.device_torch)
torch.cuda.empty_cache()
self._active_transformer_name = t_name
if self.transformer.device != hidden_states.device:
if self.low_vram:
# move other transformer to cpu
other_tname = 'transformer_1' if t_name == 'transformer_2' else 'transformer_2'
other_tname = (
"transformer_1" if t_name == "transformer_2" else "transformer_2"
)
getattr(self, other_tname).to("cpu")
self.transformer.to(hidden_states.device)
return self.transformer(
@@ -184,11 +186,33 @@ class Wan2214bModel(Wan225bModel):
self.target_lora_modules = ["DualWanTransformer3DModel"]
self._wan_cache = None
self.is_multistage = True
# multistage boundaries split the timestep range between the models when sampling timesteps.
# For Wan 2.2 14B, timesteps 1000-875 go to transformer_1 (high noise) and 875-0 go to transformer_2 (low noise).
self.multistage_boundaries: List[float] = [0.875, 0.0]
self.train_high_noise = model_config.model_kwargs.get("train_high_noise", True)
self.train_low_noise = model_config.model_kwargs.get("train_low_noise", True)
self.trainable_multistage_boundaries: List[int] = []
if self.train_high_noise:
self.trainable_multistage_boundaries.append(0)
if self.train_low_noise:
self.trainable_multistage_boundaries.append(1)
if len(self.trainable_multistage_boundaries) == 0:
raise ValueError(
"At least one of train_high_noise or train_low_noise must be True in model.model_kwargs"
)
@property
def max_step_saves_to_keep_multiplier(self):
# the cleanup mechanism checks this to see how many saves to keep
# if we are training a split multistage LoRA, return 2 so both the high noise and low noise LoRAs are kept at each save
if self.network is not None:
if (
self.network is not None
and self.network.network_config.split_multistage_loras
):
return 2
return 1
@@ -264,7 +288,7 @@ class Wan2214bModel(Wan225bModel):
transformer_1.to(self.quantize_device, dtype=dtype)
flush()
if self.model_config.quantize:
if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None:
# todo handle two ARAs
self.print_and_status_update("Quantizing Transformer 1")
quantize_model(self, transformer_1)
@@ -289,7 +313,7 @@ class Wan2214bModel(Wan225bModel):
transformer_2.to(self.quantize_device, dtype=dtype)
flush()
if self.model_config.quantize:
if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None:
# todo handle two ARAs
self.print_and_status_update("Quantizing Transformer 2")
quantize_model(self, transformer_2)
@@ -309,7 +333,13 @@ class Wan2214bModel(Wan225bModel):
boundary_ratio=boundary_ratio_t2v,
low_vram=self.model_config.low_vram,
)
if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is not None:
# apply the accuracy recovery adapter to both transformers
self.print_and_status_update("Applying Accuracy Recovery Adapter to Transformers")
quantize_model(self, transformer)
flush()
return transformer
def get_generation_pipeline(self):
@@ -407,17 +437,20 @@ class Wan2214bModel(Wan225bModel):
# just save as a combo lora
save_file(state_dict, output_path, metadata=metadata)
return
# we need to build out both dictionaries for high and low noise LoRAs
high_noise_lora = {}
low_noise_lora = {}
only_train_high_noise = self.train_high_noise and not self.train_low_noise
only_train_low_noise = self.train_low_noise and not self.train_high_noise
for key in state_dict:
if ".transformer_1." in key:
if ".transformer_1." in key or only_train_high_noise:
# this is a high noise LoRA
new_key = key.replace(".transformer_1.", ".")
high_noise_lora[new_key] = state_dict[key]
elif ".transformer_2." in key:
elif ".transformer_2." in key or only_train_low_noise:
# this is a low noise LoRA
new_key = key.replace(".transformer_2.", ".")
low_noise_lora[new_key] = state_dict[key]
@@ -439,11 +472,14 @@ class Wan2214bModel(Wan225bModel):
def load_lora(self, file: str):
# if it doesn't have high_noise or low_noise, it is a combo LoRA
if "_high_noise.safetensors" not in file and "_low_noise.safetensors" not in file:
# this is a combined LoRA, we need to split it up
if (
"_high_noise.safetensors" not in file
and "_low_noise.safetensors" not in file
):
# this is a combined LoRA, we don't need to split it up
sd = load_file(file)
return sd
# we may have been passed the high_noise or the low_noise LoRA path, but we need to load both
high_noise_lora_path = file.replace(
"_low_noise.safetensors", "_high_noise.safetensors"
@@ -454,7 +490,7 @@ class Wan2214bModel(Wan225bModel):
combined_dict = {}
if os.path.exists(high_noise_lora_path):
if os.path.exists(high_noise_lora_path) and self.train_high_noise:
# load the high noise LoRA
high_noise_lora = load_file(high_noise_lora_path)
for key in high_noise_lora:
@@ -462,7 +498,7 @@ class Wan2214bModel(Wan225bModel):
"diffusion_model.", "diffusion_model.transformer_1."
)
combined_dict[new_key] = high_noise_lora[key]
if os.path.exists(low_noise_lora_path):
if os.path.exists(low_noise_lora_path) and self.train_low_noise:
# load the low noise LoRA
low_noise_lora = load_file(low_noise_lora_path)
for key in low_noise_lora:
@@ -470,5 +506,35 @@ class Wan2214bModel(Wan225bModel):
"diffusion_model.", "diffusion_model.transformer_2."
)
combined_dict[new_key] = low_noise_lora[key]
# if we are not training both stages, we won't have transformer designations in the keys
if not (self.train_high_noise and self.train_low_noise):
new_dict = {}
for key in combined_dict:
if ".transformer_1." in key:
new_key = key.replace(".transformer_1.", ".")
elif ".transformer_2." in key:
new_key = key.replace(".transformer_2.", ".")
else:
new_key = key
new_dict[new_key] = combined_dict[key]
combined_dict = new_dict
return combined_dict
def get_model_to_train(self):
# TODO: LoRAs won't load right unless they have transformer_1 or transformer_2 in the key.
# called when setting up the LoRA. We only need to get the model for the stages we want to train.
if self.train_high_noise and self.train_low_noise:
# we are training both stages, return the unified model
return self.model
elif self.train_high_noise:
# we are only training the high noise stage, return transformer_1
return self.model.transformer_1
elif self.train_low_noise:
# we are only training the low noise stage, return transformer_2
return self.model.transformer_2
else:
raise ValueError(
"At least one of train_high_noise or train_low_noise must be True in model.model_kwargs"
)
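For reference, the 0.875 boundary above splits the denoising range between the two Wan 2.2 14B experts. A minimal sketch of that routing, assuming the usual 1000 training timesteps; it illustrates the idea only and is not the DualWanTransformer3DModel dispatch code:

def pick_transformer(timestep: float, boundary: float = 0.875, num_train_timesteps: int = 1000) -> str:
    # transformer_1 (high noise) handles timesteps from 1000 down to boundary * 1000;
    # transformer_2 (low noise) handles the rest down to 0
    return "transformer_1" if timestep >= boundary * num_train_timesteps else "transformer_2"

assert pick_transformer(950.0) == "transformer_1"  # high noise expert
assert pick_transformer(400.0) == "transformer_2"  # low noise expert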

View File

@@ -1862,7 +1862,20 @@ class SDTrainer(BaseSDTrainProcess):
total_loss = None
self.optimizer.zero_grad()
for batch in batch_list:
if self.sd.is_multistage:
# handle multistage switching
if self.steps_this_boundary >= self.train_config.switch_boundary_every:
# iterate to make sure we only train trainable_multistage_boundaries
while True:
self.steps_this_boundary = 0
self.current_boundary_index += 1
if self.current_boundary_index >= len(self.sd.multistage_boundaries):
self.current_boundary_index = 0
if self.current_boundary_index in self.sd.trainable_multistage_boundaries:
# if this boundary is trainable, we can stop looking
break
loss = self.train_single_accumulation(batch)
self.steps_this_boundary += 1
if total_loss is None:
total_loss = loss
else:
@@ -1907,7 +1920,7 @@ class SDTrainer(BaseSDTrainProcess):
self.adapter.restore_embeddings()
loss_dict = OrderedDict(
{'loss': loss.item()}
{'loss': (total_loss / len(batch_list)).item()}
)
self.end_of_training_loop()
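To illustrate the switching cadence introduced above: a minimal sketch, assuming switch_boundary_every=1, two accumulated batches per optimizer step, and both stages trainable. Variable names here are illustrative, not the trainer's attributes:

switch_boundary_every = 1
trainable_boundaries = [0, 1]  # 0 = high noise, 1 = low noise
num_boundaries = 2
boundary_index, steps_this_boundary = 0, 0

for batch_num in range(4):  # two optimizer steps, two accumulated batches each
    if steps_this_boundary >= switch_boundary_every:
        while True:
            steps_this_boundary = 0
            boundary_index = (boundary_index + 1) % num_boundaries
            if boundary_index in trainable_boundaries:
                break
    print(batch_num, "trains boundary", boundary_index)  # boundary index goes 0, 1, 0, 1
    steps_this_boundary += 1

With gradient accumulation of 2 and a switch every step, each optimizer step sees one high noise batch and one low noise batch, matching the switch_boundary_every docs added later in this commit.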

View File

@@ -260,6 +260,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
torch.profiler.ProfilerActivity.CUDA,
],
)
self.current_boundary_index = 0
self.steps_this_boundary = 0
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
# override in subclass
@@ -1171,6 +1174,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.sd.noise_scheduler.set_timesteps(
num_train_timesteps, device=self.device_torch
)
if self.sd.is_multistage:
with self.timer('adjust_multistage_timesteps'):
# get our current sample range
boundaries = [1000] + self.sd.multistage_boundaries
boundary_max, boundary_min = boundaries[self.current_boundary_index], boundaries[self.current_boundary_index + 1]
lo = torch.searchsorted(self.sd.noise_scheduler.timesteps, -torch.tensor(boundary_max, device=self.sd.noise_scheduler.timesteps.device), right=False)
hi = torch.searchsorted(self.sd.noise_scheduler.timesteps, -torch.tensor(boundary_min, device=self.sd.noise_scheduler.timesteps.device), right=True)
first_idx = lo.item() if hi > lo else 0
last_idx = (hi - 1).item() if hi > lo else 999
min_noise_steps = first_idx
max_noise_steps = last_idx
# clip min/max indices
min_noise_steps = max(min_noise_steps, 0)
max_noise_steps = min(max_noise_steps, num_train_timesteps - 1)
with self.timer('prepare_timesteps_indices'):
content_or_style = self.train_config.content_or_style
@@ -1209,11 +1230,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
0,
self.train_config.num_train_timesteps - 1,
min_noise_steps,
max_noise_steps - 1
max_noise_steps
)
timestep_indices = timestep_indices.long().clamp(
min_noise_steps + 1,
max_noise_steps - 1
min_noise_steps,
max_noise_steps
)
elif content_or_style == 'balanced':
@@ -1226,7 +1247,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.train_config.noise_scheduler == 'flowmatch':
# flowmatch uses indices, so we need to use indices
min_idx = 0
max_idx = max_noise_steps - 1
max_idx = max_noise_steps
timestep_indices = torch.randint(
min_idx,
max_idx,
@@ -1676,7 +1697,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.network = NetworkClass(
text_encoder=text_encoder,
unet=unet,
unet=self.sd.get_model_to_train(),
lora_dim=self.network_config.linear,
multiplier=1.0,
alpha=self.network_config.linear_alpha,
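The searchsorted-based boundary lookup in the adjust_multistage_timesteps block above maps a stage's boundary range onto indices of the scheduler's timestep tensor. A minimal sketch of that mapping, assuming a descending flowmatch timestep tensor and the Wan 2.2 14B boundaries expressed as absolute timesteps; this illustrates the idea and is not the trainer's exact code:

import torch

timesteps = torch.linspace(1000, 1, 1000)  # flowmatch schedules descend from ~1000 to ~1
boundaries = [1000.0, 875.0, 0.0]          # upper edge followed by the stage boundaries
idx = 0                                     # 0 = high noise stage, 1 = low noise stage
t_max, t_min = boundaries[idx], boundaries[idx + 1]

# searchsorted expects an ascending sequence, so negate both the tensor and the values
lo = torch.searchsorted(-timesteps, -torch.tensor(t_max), right=False)
hi = torch.searchsorted(-timesteps, -torch.tensor(t_min), right=True)
min_noise_steps, max_noise_steps = lo.item(), hi.item() - 1
# idx 0 -> indices covering timesteps 1000..875, idx 1 -> indices covering 875..1

The resulting min/max indices are then clamped to [0, num_train_timesteps - 1] before timestep indices are sampled, as in the hunk above.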

View File

@@ -335,7 +335,7 @@ class TrainConfig:
self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0)
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 999)
self.batch_size: int = kwargs.get('batch_size', 1)
self.orig_batch_size: int = self.batch_size
self.dtype: str = kwargs.get('dtype', 'fp32')
@@ -515,6 +515,9 @@ class TrainConfig:
self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '')
if isinstance(self.guidance_loss_target, tuple):
self.guidance_loss_target = list(self.guidance_loss_target)
# for multistage models, how often to switch the boundary being trained
self.switch_boundary_every: int = kwargs.get('switch_boundary_every', 1)
ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex1', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21']

View File

@@ -172,6 +172,11 @@ class BaseModel:
self.sample_prompts_cache = None
self.accuracy_recovery_adapter: Union[None, 'LoRASpecialNetwork'] = None
self.is_multistage = False
# a list of multistage boundaries (fractions of the timestep range), ordered from train step 1000 down to 0
self.multistage_boundaries: List[float] = [0.0]
# a list of trainable multistage boundaries
self.trainable_multistage_boundaries: List[int] = [0]
# properties for old arch for backwards compatibility
@property
@@ -1502,3 +1507,7 @@ class BaseModel:
def get_base_model_version(self) -> str:
# override in child classes to get the base model version
return "unknown"
def get_model_to_train(self):
# called to get model to attach LoRAs to. Can be overridden in child classes
return self.unet

View File

@@ -211,6 +211,12 @@ class StableDiffusion:
self.sample_prompts_cache = None
self.is_multistage = False
# a list of multistage boundaries (fractions of the timestep range), ordered from train step 1000 down to 0
self.multistage_boundaries: List[float] = [0.0]
# a list of trainable multistage boundaries
self.trainable_multistage_boundaries: List[int] = [0]
# properties for old arch for backwards compatibility
@property
def is_xl(self):
@@ -3123,3 +3129,6 @@ class StableDiffusion:
if self.is_v2:
return 'sd_2.1'
return 'sd_1.5'
def get_model_to_train(self):
return self.unet

View File

@@ -40,10 +40,25 @@ export default function SimpleJob({
const isVideoModel = !!(modelArch?.group === 'video');
let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-5 gap-6';
const numTopCards = useMemo(() => {
let count = 4; // job settings, model config, target config, save config
if (modelArch?.additionalSections?.includes('model.multistage')) {
count += 1; // add multistage card
}
if (!modelArch?.disableSections?.includes('model.quantize')) {
count += 1; // add quantization card
}
return count;
}, [modelArch]);
if (modelArch?.disableSections?.includes('model.quantize')) {
topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
if (numTopCards == 5) {
topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-5 gap-6';
}
if (numTopCards == 6) {
topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-6 gap-6';
}
const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => {
@@ -91,7 +106,7 @@ export default function SimpleJob({
<>
<form onSubmit={handleSubmit} className="space-y-8">
<div className={topBarClass}>
<Card title="Job Settings">
<Card title="Job">
<TextInput
label="Training Name"
value={jobConfig.config.name}
@@ -124,7 +139,7 @@ export default function SimpleJob({
</Card>
{/* Model Configuration Section */}
<Card title="Model Configuration">
<Card title="Model">
<SelectInput
label="Model Architecture"
value={jobConfig.config.process[0].model.arch}
@@ -239,7 +254,32 @@ export default function SimpleJob({
/>
</Card>
)}
<Card title="Target Configuration">
{modelArch?.additionalSections?.includes('model.multistage') && (
<Card title="Multistage">
<FormGroup label="Stages to Train" docKey={'model.multistage'}>
<Checkbox
label="High Noise"
checked={jobConfig.config.process[0].model.model_kwargs?.train_high_noise || false}
onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.train_high_noise')}
/>
<Checkbox
label="Low Noise"
checked={jobConfig.config.process[0].model.model_kwargs?.train_low_noise || false}
onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.train_low_noise')}
/>
</FormGroup>
<NumberInput
label="Switch Every"
value={jobConfig.config.process[0].train.switch_boundary_every}
onChange={value => setJobConfig(value, 'config.process[0].train.switch_boundary_every')}
placeholder="eg. 1"
docKey={'train.switch_boundary_every'}
min={1}
required
/>
</Card>
)}
<Card title="Target">
<SelectInput
label="Target Type"
value={jobConfig.config.process[0].network?.type ?? 'lora'}
@@ -295,7 +335,7 @@ export default function SimpleJob({
</>
)}
</Card>
<Card title="Save Configuration">
<Card title="Save">
<SelectInput
label="Data Type"
value={jobConfig.config.process[0].save.dtype}
@@ -325,7 +365,7 @@ export default function SimpleJob({
</Card>
</div>
<div>
<Card title="Training Configuration">
<Card title="Training">
<div className="grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6">
<div>
<NumberInput
@@ -645,7 +685,7 @@ export default function SimpleJob({
</Card>
</div>
<div>
<Card title="Sample Configuration">
<Card title="Sample">
<div
className={
isVideoModel

View File

@@ -78,6 +78,7 @@ export const defaultJobConfig: JobConfig = {
diff_output_preservation: false,
diff_output_preservation_multiplier: 1.0,
diff_output_preservation_class: 'person',
switch_boundary_every: 1,
},
model: {
name_or_path: 'ostris/Flex.1-alpha',

View File

@@ -3,7 +3,13 @@ import { GroupedSelectOption, SelectOption } from '@/types';
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
type AdditionalSections = 'datasets.control_path' | 'datasets.do_i2v' | 'sample.ctrl_img' | 'datasets.num_frames' | 'model.low_vram';
type AdditionalSections =
| 'datasets.control_path'
| 'datasets.do_i2v'
| 'sample.ctrl_img'
| 'datasets.num_frames'
| 'model.multistage'
| 'model.low_vram';
type ModelGroup = 'image' | 'video';
export interface ModelArch {
@@ -121,7 +127,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [15, 1],
'config.process[0].sample.fps': [16, 1],
},
disableSections: ['network.conv'],
additionalSections: ['datasets.num_frames', 'model.low_vram'],
@@ -139,7 +145,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [15, 1],
'config.process[0].sample.fps': [16, 1],
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
},
disableSections: ['network.conv'],
@@ -158,7 +164,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [15, 1],
'config.process[0].sample.fps': [16, 1],
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
},
disableSections: ['network.conv'],
@@ -177,11 +183,41 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [15, 1],
'config.process[0].sample.fps': [16, 1],
},
disableSections: ['network.conv'],
additionalSections: ['datasets.num_frames', 'model.low_vram'],
},
{
name: 'wan22_14b:t2v',
label: 'Wan 2.2 (14B)',
group: 'video',
isVideoModel: true,
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16', defaultNameOrPath],
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [16, 1],
'config.process[0].model.low_vram': [true, false],
'config.process[0].train.timestep_type': ['linear', 'sigmoid'],
'config.process[0].model.model_kwargs': [
{
train_high_noise: true,
train_low_noise: true,
},
{},
],
},
disableSections: ['network.conv'],
additionalSections: ['datasets.num_frames', 'model.low_vram', 'model.multistage'],
// accuracyRecoveryAdapters: {
// '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint3.safetensors',
// },
},
{
name: 'wan22_5b',
label: 'Wan 2.2 TI2V (5B)',

View File

@@ -283,7 +283,7 @@ export const FormGroup: React.FC<FormGroupProps> = props => {
return (
<div className={classNames(className)}>
{label && (
<label className={labelClasses}>
<label className={classNames(labelClasses, 'mb-2')}>
{label}{' '}
{doc && (
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
@@ -292,7 +292,7 @@ export const FormGroup: React.FC<FormGroupProps> = props => {
)}
</label>
)}
<div className="px-4 space-y-2">{children}</div>
<div className="space-y-2">{children}</div>
</div>
);
};

View File

@@ -111,6 +111,36 @@ const docs: { [key: string]: ConfigDoc } = {
</>
),
},
'model.multistage': {
title: 'Stages to Train',
description: (
<>
Some models have multistage networks that are trained and used separately in the denoising process. Most
commonly there are 2 stages: one for high noise and one for low noise. You can choose to train both stages at
once or train them separately. If trained at the same time, the trainer will alternate between training each
model every so many steps and will output 2 different LoRAs. If you choose to train only one stage, the
trainer will only train that stage and output a single LoRA.
</>
),
},
'train.switch_boundary_every': {
title: 'Switch Boundary Every',
description: (
<>
When training a model with multiple stages, this setting controls how often the trainer will switch between
training each stage.
<br />
<br />
For low VRAM setups, the model not being trained will be unloaded from the GPU to save memory. This takes
some time, so it is recommended to alternate less often when using low VRAM; a setting like 10 or 20 works
well in that case.
<br />
<br />
The swap happens at the batch level, meaning it can switch between gradient accumulation steps. To train both
stages in a single optimizer step, set the switch to every 1 step and set gradient accumulation to 2.
</>
),
},
};
export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
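Taken together, the multistage options documented above correspond to a small set of config keys added elsewhere in this commit. A minimal sketch of the relevant fragment with illustrative values; the key names come from this commit, everything else about the config is assumed:

# config.process[0].model.model_kwargs
model_kwargs = {
    "train_high_noise": True,  # train the high noise expert (transformer_1)
    "train_low_noise": True,   # train the low noise expert (transformer_2)
}
# config.process[0].train
train = {
    "switch_boundary_every": 10,  # switch stages less often when low_vram offloading is enabled
}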

View File

@@ -119,6 +119,7 @@ export interface TrainConfig {
diff_output_preservation: boolean;
diff_output_preservation_multiplier: number;
diff_output_preservation_class: string;
switch_boundary_every: number;
}
export interface QuantizeKwargsConfig {