Added training to the ui. Still testing, but everything seems to be working.

This commit is contained in:
Jaret Burkett
2025-08-16 05:51:37 -06:00
parent ca7bfa414b
commit 8ea2cf00f6
12 changed files with 268 additions and 39 deletions

View File

@@ -335,7 +335,7 @@ class TrainConfig:
self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0)
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 999)
self.batch_size: int = kwargs.get('batch_size', 1)
self.orig_batch_size: int = self.batch_size
self.dtype: str = kwargs.get('dtype', 'fp32')
@@ -515,6 +515,9 @@ class TrainConfig:
self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '')
if isinstance(self.guidance_loss_target, tuple):
self.guidance_loss_target = list(self.guidance_loss_target)
# for multi stage models, how often to switch the boundary
self.switch_boundary_every: int = kwargs.get('switch_boundary_every', 1)
ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex1', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21']

View File

@@ -172,6 +172,11 @@ class BaseModel:
self.sample_prompts_cache = None
self.accuracy_recovery_adapter: Union[None, 'LoRASpecialNetwork'] = None
self.is_multistage = False
# a list of multistage boundaries starting with train step 1000 to first idx
self.multistage_boundaries: List[float] = [0.0]
# a list of trainable multistage boundaries
self.trainable_multistage_boundaries: List[int] = [0]
# properties for old arch for backwards compatibility
@property
@@ -1502,3 +1507,7 @@ class BaseModel:
def get_base_model_version(self) -> str:
# override in child classes to get the base model version
return "unknown"
def get_model_to_train(self):
# called to get model to attach LoRAs to. Can be overridden in child classes
return self.unet

View File

@@ -211,6 +211,12 @@ class StableDiffusion:
self.sample_prompts_cache = None
self.is_multistage = False
# a list of multistage boundaries starting with train step 1000 to first idx
self.multistage_boundaries: List[float] = [0.0]
# a list of trainable multistage boundaries
self.trainable_multistage_boundaries: List[int] = [0]
# properties for old arch for backwards compatibility
@property
def is_xl(self):
@@ -3123,3 +3129,6 @@ class StableDiffusion:
if self.is_v2:
return 'sd_2.1'
return 'sd_1.5'
def get_model_to_train(self):
return self.unet