mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Added training to the ui. Still testing, but everything seems to be working.
This commit is contained in:
@@ -335,7 +335,7 @@ class TrainConfig:
|
||||
self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
|
||||
self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
|
||||
self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0)
|
||||
- self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
|
||||
+ self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 999)
|
||||
self.batch_size: int = kwargs.get('batch_size', 1)
|
||||
self.orig_batch_size: int = self.batch_size
|
||||
self.dtype: str = kwargs.get('dtype', 'fp32')
|
||||
@@ -515,6 +515,9 @@ class TrainConfig:
|
||||
self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '')
|
||||
if isinstance(self.guidance_loss_target, tuple):
|
||||
self.guidance_loss_target = list(self.guidance_loss_target)
|
||||
|
||||
# for multi stage models, how often to switch the boundary
|
||||
self.switch_boundary_every: int = kwargs.get('switch_boundary_every', 1)
|
||||
|
||||
|
||||
ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex1', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21']
|
||||
|
||||
@@ -172,6 +172,11 @@ class BaseModel:
|
||||
self.sample_prompts_cache = None
|
||||
|
||||
self.accuracy_recovery_adapter: Union[None, 'LoRASpecialNetwork'] = None
|
||||
self.is_multistage = False
|
||||
# a list of multistage boundaries starting with train step 1000 to first idx
|
||||
self.multistage_boundaries: List[float] = [0.0]
|
||||
# a list of trainable multistage boundaries
|
||||
self.trainable_multistage_boundaries: List[int] = [0]
|
||||
|
||||
# properties for old arch for backwards compatibility
|
||||
@property
|
||||
@@ -1502,3 +1507,7 @@ class BaseModel:
|
||||
def get_base_model_version(self) -> str:
|
||||
# override in child classes to get the base model version
|
||||
return "unknown"
|
||||
|
||||
def get_model_to_train(self):
|
||||
# called to get model to attach LoRAs to. Can be overridden in child classes
|
||||
return self.unet
|
||||
|
||||
@@ -211,6 +211,12 @@ class StableDiffusion:
|
||||
|
||||
self.sample_prompts_cache = None
|
||||
|
||||
self.is_multistage = False
|
||||
# a list of multistage boundaries starting with train step 1000 to first idx
|
||||
self.multistage_boundaries: List[float] = [0.0]
|
||||
# a list of trainable multistage boundaries
|
||||
self.trainable_multistage_boundaries: List[int] = [0]
|
||||
|
||||
# properties for old arch for backwards compatibility
|
||||
@property
|
||||
def is_xl(self):
|
||||
@@ -3123,3 +3129,6 @@ class StableDiffusion:
|
||||
if self.is_v2:
|
||||
return 'sd_2.1'
|
||||
return 'sd_1.5'
|
||||
|
||||
def get_model_to_train(self):
|
||||
return self.unet
|
||||
|
||||
Reference in New Issue
Block a user