Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00

Added training to the UI. Still testing, but everything seems to be working.
@@ -1,6 +1,6 @@
from functools import partial
import os
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Union, List
from typing_extensions import Self
import torch
import yaml
@@ -134,13 +134,15 @@ class DualWanTransformer3DModel(torch.nn.Module):
getattr(self, t_name).to(self.device_torch)
torch.cuda.empty_cache()
self._active_transformer_name = t_name

if self.transformer.device != hidden_states.device:
if self.low_vram:
# move other transformer to cpu
other_tname = 'transformer_1' if t_name == 'transformer_2' else 'transformer_2'
other_tname = (
"transformer_1" if t_name == "transformer_2" else "transformer_2"
)
getattr(self, other_tname).to("cpu")

self.transformer.to(hidden_states.device)

return self.transformer(
@@ -184,11 +186,33 @@ class Wan2214bModel(Wan225bModel):
self.target_lora_modules = ["DualWanTransformer3DModel"]
self._wan_cache = None

self.is_multistage = True
# multistage boundaries split the models up when sampling timesteps
# for Wan 2.2 14B, the timesteps are 1000-875 for transformer 1 and 875-0 for transformer 2
self.multistage_boundaries: List[float] = [0.875, 0.0]

self.train_high_noise = model_config.model_kwargs.get("train_high_noise", True)
self.train_low_noise = model_config.model_kwargs.get("train_low_noise", True)

self.trainable_multistage_boundaries: List[int] = []
if self.train_high_noise:
self.trainable_multistage_boundaries.append(0)
if self.train_low_noise:
self.trainable_multistage_boundaries.append(1)

if len(self.trainable_multistage_boundaries) == 0:
raise ValueError(
"At least one of train_high_noise or train_low_noise must be True in model.model_kwargs"
)

@property
def max_step_saves_to_keep_multiplier(self):
# the cleanup mechanism checks this to see how many saves to keep
# if we are training a LoRA, set this to 2 so we keep both the high noise and low noise LoRAs when cleaning up old saves
if self.network is not None:
if (
self.network is not None
and self.network.network_config.split_multistage_loras
):
return 2
return 1
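
The boundary list above routes each sampled timestep to a stage: anything at or above 0.875 of the schedule goes to transformer_1, the rest to transformer_2. A minimal sketch of that mapping, assuming 1000 training timesteps (the helper below is illustrative, not a function in the toolkit):

# Illustrative only: map a raw timestep to a multistage index, assuming
# multistage_boundaries = [0.875, 0.0] and 1000 training timesteps.
def stage_for_timestep(timestep, boundaries=(0.875, 0.0), num_train_timesteps=1000):
    t = timestep / num_train_timesteps  # normalize to 0..1
    for idx, boundary in enumerate(boundaries):
        if t >= boundary:
            return idx  # first boundary the timestep clears wins
    return len(boundaries) - 1

# timesteps 1000-875 -> index 0 (high noise, transformer_1)
# timesteps 875-0    -> index 1 (low noise, transformer_2)
assert stage_for_timestep(999) == 0
assert stage_for_timestep(500) == 1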

@@ -264,7 +288,7 @@ class Wan2214bModel(Wan225bModel):
transformer_1.to(self.quantize_device, dtype=dtype)
flush()

if self.model_config.quantize:
if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None:
# todo handle two ARAs
self.print_and_status_update("Quantizing Transformer 1")
quantize_model(self, transformer_1)
@@ -289,7 +313,7 @@ class Wan2214bModel(Wan225bModel):
transformer_2.to(self.quantize_device, dtype=dtype)
flush()

if self.model_config.quantize:
if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None:
# todo handle two ARAs
self.print_and_status_update("Quantizing Transformer 2")
quantize_model(self, transformer_2)
@@ -309,7 +333,13 @@ class Wan2214bModel(Wan225bModel):
boundary_ratio=boundary_ratio_t2v,
low_vram=self.model_config.low_vram,
)

if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is not None:
# apply the accuracy recovery adapter to both transformers
self.print_and_status_update("Applying Accuracy Recovery Adapter to Transformers")
quantize_model(self, transformer)
flush()

return transformer

def get_generation_pipeline(self):
@@ -407,17 +437,20 @@ class Wan2214bModel(Wan225bModel):
# just save as a combo lora
save_file(state_dict, output_path, metadata=metadata)
return

# we need to build out both dictionaries for high and low noise LoRAs
high_noise_lora = {}
low_noise_lora = {}

only_train_high_noise = self.train_high_noise and not self.train_low_noise
only_train_low_noise = self.train_low_noise and not self.train_high_noise

for key in state_dict:
if ".transformer_1." in key:
if ".transformer_1." in key or only_train_high_noise:
# this is a high noise LoRA
new_key = key.replace(".transformer_1.", ".")
high_noise_lora[new_key] = state_dict[key]
elif ".transformer_2." in key:
elif ".transformer_2." in key or only_train_low_noise:
# this is a low noise LoRA
new_key = key.replace(".transformer_2.", ".")
low_noise_lora[new_key] = state_dict[key]
@@ -439,11 +472,14 @@ class Wan2214bModel(Wan225bModel):

def load_lora(self, file: str):
# if it doesn't have high_noise or low_noise, it is a combo LoRA
if "_high_noise.safetensors" not in file and "_low_noise.safetensors" not in file:
# this is a combined LoRA, we need to split it up
if (
"_high_noise.safetensors" not in file
and "_low_noise.safetensors" not in file
):
# this is a combined LoRA, we don't need to split it up
sd = load_file(file)
return sd

# we may have been passed the high_noise or the low_noise LoRA path, but we need to load both
high_noise_lora_path = file.replace(
"_low_noise.safetensors", "_high_noise.safetensors"
@@ -454,7 +490,7 @@ class Wan2214bModel(Wan225bModel):

combined_dict = {}

if os.path.exists(high_noise_lora_path):
if os.path.exists(high_noise_lora_path) and self.train_high_noise:
# load the high noise LoRA
high_noise_lora = load_file(high_noise_lora_path)
for key in high_noise_lora:
@@ -462,7 +498,7 @@ class Wan2214bModel(Wan225bModel):
"diffusion_model.", "diffusion_model.transformer_1."
)
combined_dict[new_key] = high_noise_lora[key]
if os.path.exists(low_noise_lora_path):
if os.path.exists(low_noise_lora_path) and self.train_low_noise:
# load the low noise LoRA
low_noise_lora = load_file(low_noise_lora_path)
for key in low_noise_lora:
@@ -470,5 +506,35 @@ class Wan2214bModel(Wan225bModel):
"diffusion_model.", "diffusion_model.transformer_2."
)
combined_dict[new_key] = low_noise_lora[key]

# if we are not training both stages, we won't have transformer designations in the keys
if not self.train_high_noise and not self.train_low_noise:
new_dict = {}
for key in combined_dict:
if ".transformer_1." in key:
new_key = key.replace(".transformer_1.", ".")
elif ".transformer_2." in key:
new_key = key.replace(".transformer_2.", ".")
else:
new_key = key
new_dict[new_key] = combined_dict[key]
combined_dict = new_dict

return combined_dict

def get_model_to_train(self):
# TODO: LoRAs won't load right unless they have transformer_1 or transformer_2 in the key.
# called when setting up the LoRA. We only need to get the model for the stages we want to train.
if self.train_high_noise and self.train_low_noise:
# we are training both stages, return the unified model
return self.model
elif self.train_high_noise:
# we are only training the high noise stage, return transformer_1
return self.model.transformer_1
elif self.train_low_noise:
# we are only training the low noise stage, return transformer_2
return self.model.transformer_2
else:
raise ValueError(
"At least one of train_high_noise or train_low_noise must be True in model.model_kwargs"
)

@@ -1862,7 +1862,20 @@ class SDTrainer(BaseSDTrainProcess):
total_loss = None
self.optimizer.zero_grad()
for batch in batch_list:
if self.sd.is_multistage:
# handle multistage switching
if self.steps_this_boundary >= self.train_config.switch_boundary_every:
# iterate to make sure we only train trainable_multistage_boundaries
while True:
self.steps_this_boundary = 0
self.current_boundary_index += 1
if self.current_boundary_index >= len(self.sd.multistage_boundaries):
self.current_boundary_index = 0
if self.current_boundary_index in self.sd.trainable_multistage_boundaries:
# if this boundary is trainable, we can stop looking
break
loss = self.train_single_accumulation(batch)
self.steps_this_boundary += 1
if total_loss is None:
total_loss = loss
else:
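
The switching block above rotates to the next stage once steps_this_boundary reaches switch_boundary_every, skipping any stage not listed in trainable_multistage_boundaries. The same rotation pulled out as a standalone sketch (illustrative only; the trainer keeps this state on the process object):

# Illustrative only: advance to the next trainable multistage boundary,
# wrapping around when the end of the boundary list is reached.
def next_trainable_boundary(current_index, num_boundaries, trainable_indices):
    index = current_index
    while True:
        index += 1
        if index >= num_boundaries:
            index = 0
        if index in trainable_indices:
            return index  # stop at the first trainable stage

# Both Wan 2.2 stages trainable: alternate 0 -> 1 -> 0 -> ...
assert next_trainable_boundary(0, 2, [0, 1]) == 1
# Only the high noise stage trainable: always come back to 0.
assert next_trainable_boundary(0, 2, [0]) == 0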
@@ -1907,7 +1920,7 @@ class SDTrainer(BaseSDTrainProcess):
self.adapter.restore_embeddings()

loss_dict = OrderedDict(
{'loss': loss.item()}
{'loss': (total_loss / len(batch_list)).item()}
)

self.end_of_training_loop()

@@ -260,6 +260,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
torch.profiler.ProfilerActivity.CUDA,
],
)

self.current_boundary_index = 0
self.steps_this_boundary = 0

def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
# override in subclass
@@ -1171,6 +1174,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.sd.noise_scheduler.set_timesteps(
num_train_timesteps, device=self.device_torch
)
if self.sd.is_multistage:
with self.timer('adjust_multistage_timesteps'):
# get our current sample range
boundaries = [1000] + self.sd.multistage_boundaries
boundary_max, boundary_min = boundaries[self.current_boundary_index], boundaries[self.current_boundary_index + 1]
lo = torch.searchsorted(self.sd.noise_scheduler.timesteps, -torch.tensor(boundary_max, device=self.sd.noise_scheduler.timesteps.device), right=False)
hi = torch.searchsorted(self.sd.noise_scheduler.timesteps, -torch.tensor(boundary_min, device=self.sd.noise_scheduler.timesteps.device), right=True)
first_idx = lo.item() if hi > lo else 0
last_idx = (hi - 1).item() if hi > lo else 999

min_noise_steps = first_idx
max_noise_steps = last_idx

# clip min/max indices
min_noise_steps = max(min_noise_steps, 0)
max_noise_steps = min(max_noise_steps, num_train_timesteps - 1)
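
For context on the searchsorted calls above: torch.searchsorted requires an ascending sequence, while flowmatch-style schedules run from high to low. One common way to locate an index window in such a schedule is to negate both the schedule and the query values; a small self-contained sketch with illustrative values, not the exact tensors used in the trainer:

import torch

# Illustrative only: find the index window covering the high noise range
# [875, 1000] in a descending schedule. searchsorted needs ascending input,
# so both the schedule and the query values are negated.
timesteps = torch.linspace(1000, 1, steps=1000)  # descending, 1000 -> 1
ascending = -timesteps

boundary_max, boundary_min = 1000.0, 875.0
lo = torch.searchsorted(ascending, -torch.tensor(boundary_max), right=False)
hi = torch.searchsorted(ascending, -torch.tensor(boundary_min), right=True)

first_idx, last_idx = lo.item(), (hi - 1).item()
# indices 0..last_idx hold timesteps from 1000 down to 875
print(first_idx, last_idx, timesteps[last_idx].item())  # 0 125 875.0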

with self.timer('prepare_timesteps_indices'):

content_or_style = self.train_config.content_or_style
@@ -1209,11 +1230,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
0,
self.train_config.num_train_timesteps - 1,
min_noise_steps,
max_noise_steps - 1
max_noise_steps
)
timestep_indices = timestep_indices.long().clamp(
min_noise_steps + 1,
max_noise_steps - 1
min_noise_steps,
max_noise_steps
)

elif content_or_style == 'balanced':
@@ -1226,7 +1247,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.train_config.noise_scheduler == 'flowmatch':
# flowmatch uses indices, so we need to use indices
min_idx = 0
max_idx = max_noise_steps - 1
max_idx = max_noise_steps
timestep_indices = torch.randint(
min_idx,
max_idx,
@@ -1676,7 +1697,7 @@ class BaseSDTrainProcess(BaseTrainProcess):

self.network = NetworkClass(
text_encoder=text_encoder,
unet=unet,
unet=self.sd.get_model_to_train(),
lora_dim=self.network_config.linear,
multiplier=1.0,
alpha=self.network_config.linear_alpha,

@@ -335,7 +335,7 @@ class TrainConfig:
self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0)
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 999)
self.batch_size: int = kwargs.get('batch_size', 1)
self.orig_batch_size: int = self.batch_size
self.dtype: str = kwargs.get('dtype', 'fp32')
@@ -515,6 +515,9 @@ class TrainConfig:
self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '')
if isinstance(self.guidance_loss_target, tuple):
self.guidance_loss_target = list(self.guidance_loss_target)

# for multi stage models, how often to switch the boundary
self.switch_boundary_every: int = kwargs.get('switch_boundary_every', 1)

ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex1', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21']

@@ -172,6 +172,11 @@ class BaseModel:
self.sample_prompts_cache = None

self.accuracy_recovery_adapter: Union[None, 'LoRASpecialNetwork'] = None
self.is_multistage = False
# a list of multistage boundaries starting with train step 1000 to first idx
self.multistage_boundaries: List[float] = [0.0]
# a list of trainable multistage boundaries
self.trainable_multistage_boundaries: List[int] = [0]

# properties for old arch for backwards compatibility
@property
@@ -1502,3 +1507,7 @@ class BaseModel:
def get_base_model_version(self) -> str:
# override in child classes to get the base model version
return "unknown"

def get_model_to_train(self):
# called to get model to attach LoRAs to. Can be overridden in child classes
return self.unet

@@ -211,6 +211,12 @@ class StableDiffusion:

self.sample_prompts_cache = None

self.is_multistage = False
# a list of multistage boundaries starting with train step 1000 to first idx
self.multistage_boundaries: List[float] = [0.0]
# a list of trainable multistage boundaries
self.trainable_multistage_boundaries: List[int] = [0]

# properties for old arch for backwards compatibility
@property
def is_xl(self):
@@ -3123,3 +3129,6 @@ class StableDiffusion:
if self.is_v2:
return 'sd_2.1'
return 'sd_1.5'

def get_model_to_train(self):
return self.unet

@@ -40,10 +40,25 @@ export default function SimpleJob({

const isVideoModel = !!(modelArch?.group === 'video');

let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-5 gap-6';
const numTopCards = useMemo(() => {
let count = 4; // job settings, model config, target config, save config
if (modelArch?.additionalSections?.includes('model.multistage')) {
count += 1; // add multistage card
}
if (!modelArch?.disableSections?.includes('model.quantize')) {
count += 1; // add quantization card
}
return count;

}, [modelArch]);

if (modelArch?.disableSections?.includes('model.quantize')) {
topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';

if (numTopCards == 5) {
topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-5 gap-6';
}
if (numTopCards == 6) {
topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-6 gap-6';
}

const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => {
@@ -91,7 +106,7 @@ export default function SimpleJob({
<>
<form onSubmit={handleSubmit} className="space-y-8">
<div className={topBarClass}>
<Card title="Job Settings">
<Card title="Job">
<TextInput
label="Training Name"
value={jobConfig.config.name}
@@ -124,7 +139,7 @@ export default function SimpleJob({
</Card>

{/* Model Configuration Section */}
<Card title="Model Configuration">
<Card title="Model">
<SelectInput
label="Model Architecture"
value={jobConfig.config.process[0].model.arch}
@@ -239,7 +254,32 @@ export default function SimpleJob({
/>
</Card>
)}
<Card title="Target Configuration">
{modelArch?.additionalSections?.includes('model.multistage') && (
<Card title="Multistage">
<FormGroup label="Stages to Train" docKey={'model.multistage'}>
<Checkbox
label="High Noise"
checked={jobConfig.config.process[0].model.model_kwargs?.train_high_noise || false}
onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.train_high_noise')}
/>
<Checkbox
label="Low Noise"
checked={jobConfig.config.process[0].model.model_kwargs?.train_low_noise || false}
onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.train_low_noise')}
/>
</FormGroup>
<NumberInput
label="Switch Every"
value={jobConfig.config.process[0].train.switch_boundary_every}
onChange={value => setJobConfig(value, 'config.process[0].train.switch_boundary_every')}
placeholder="e.g. 1"
docKey={'train.switch_boundary_every'}
min={1}
required
/>
</Card>
)}
<Card title="Target">
<SelectInput
label="Target Type"
value={jobConfig.config.process[0].network?.type ?? 'lora'}
@@ -295,7 +335,7 @@ export default function SimpleJob({
</>
)}
</Card>
<Card title="Save Configuration">
<Card title="Save">
<SelectInput
label="Data Type"
value={jobConfig.config.process[0].save.dtype}
@@ -325,7 +365,7 @@ export default function SimpleJob({
</Card>
</div>
<div>
<Card title="Training Configuration">
<Card title="Training">
<div className="grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6">
<div>
<NumberInput
@@ -645,7 +685,7 @@ export default function SimpleJob({
</Card>
</div>
<div>
<Card title="Sample Configuration">
<Card title="Sample">
<div
className={
isVideoModel

@@ -78,6 +78,7 @@ export const defaultJobConfig: JobConfig = {
diff_output_preservation: false,
diff_output_preservation_multiplier: 1.0,
diff_output_preservation_class: 'person',
switch_boundary_every: 1,
},
model: {
name_or_path: 'ostris/Flex.1-alpha',

@@ -3,7 +3,13 @@ import { GroupedSelectOption, SelectOption } from '@/types';
type Control = 'depth' | 'line' | 'pose' | 'inpaint';

type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
type AdditionalSections = 'datasets.control_path' | 'datasets.do_i2v' | 'sample.ctrl_img' | 'datasets.num_frames' | 'model.low_vram';
type AdditionalSections =
| 'datasets.control_path'
| 'datasets.do_i2v'
| 'sample.ctrl_img'
| 'datasets.num_frames'
| 'model.multistage'
| 'model.low_vram';
type ModelGroup = 'image' | 'video';

export interface ModelArch {
@@ -121,7 +127,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [15, 1],
'config.process[0].sample.fps': [16, 1],
},
disableSections: ['network.conv'],
additionalSections: ['datasets.num_frames', 'model.low_vram'],
@@ -139,7 +145,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [15, 1],
'config.process[0].sample.fps': [16, 1],
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
},
disableSections: ['network.conv'],
@@ -158,7 +164,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [15, 1],
'config.process[0].sample.fps': [16, 1],
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
},
disableSections: ['network.conv'],
@@ -177,11 +183,41 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [15, 1],
'config.process[0].sample.fps': [16, 1],
},
disableSections: ['network.conv'],
additionalSections: ['datasets.num_frames', 'model.low_vram'],
},
{
name: 'wan22_14b:t2v',
label: 'Wan 2.2 (14B)',
group: 'video',
isVideoModel: true,
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16', defaultNameOrPath],
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [16, 1],
'config.process[0].model.low_vram': [true, false],
'config.process[0].train.timestep_type': ['linear', 'sigmoid'],
'config.process[0].model.model_kwargs': [
{
train_high_noise: true,
train_low_noise: true,
},
{},
],
},
disableSections: ['network.conv'],
additionalSections: ['datasets.num_frames', 'model.low_vram', 'model.multistage'],
// accuracyRecoveryAdapters: {
//   '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint3.safetensors',
// },
},
{
name: 'wan22_5b',
label: 'Wan 2.2 TI2V (5B)',

@@ -283,7 +283,7 @@ export const FormGroup: React.FC<FormGroupProps> = props => {
return (
<div className={classNames(className)}>
{label && (
<label className={labelClasses}>
<label className={classNames(labelClasses, 'mb-2')}>
{label}{' '}
{doc && (
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
@@ -292,7 +292,7 @@ export const FormGroup: React.FC<FormGroupProps> = props => {
)}
</label>
)}
<div className="px-4 space-y-2">{children}</div>
<div className="space-y-2">{children}</div>
</div>
);
};

@@ -111,6 +111,36 @@ const docs: { [key: string]: ConfigDoc } = {
</>
),
},
'model.multistage': {
title: 'Stages to Train',
description: (
<>
Some models have multi-stage networks that are trained and used separately in the denoising process. Most
commonly there are 2 stages: one for high noise and one for low noise. You can choose to train both stages at
once or train them separately. If trained at the same time, the trainer will alternate between training each
model every so many steps and will output 2 different LoRAs. If you choose to train only one stage, the
trainer will only train that stage and output a single LoRA.
</>
),
},
'train.switch_boundary_every': {
title: 'Switch Boundary Every',
description: (
<>
When training a model with multiple stages, this setting controls how often the trainer will switch between
training each stage.
<br />
<br />
With low VRAM settings, the model not currently being trained is unloaded from the GPU to save memory. This
takes some time, so it is best to alternate less often when using low VRAM; a setting like 10 or 20 is
recommended there.
<br />
<br />
The swap happens at the batch level, meaning it swaps between gradient accumulation steps. To train both
stages in a single optimizer step, set it to switch every 1 step and set gradient accumulation to 2.
</>
),
},
};

export const getDoc = (key: string | null | undefined): ConfigDoc | null => {

@@ -119,6 +119,7 @@ export interface TrainConfig {
diff_output_preservation: boolean;
diff_output_preservation_multiplier: number;
diff_output_preservation_class: string;
switch_boundary_every: number;
}

export interface QuantizeKwargsConfig {