mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 10:11:14 +00:00
Set gradient checkpointing on unet enabled by default. Helps out immensely with SDXL backprop spikes
This commit is contained in:
@@ -34,6 +34,8 @@ config:
|
|||||||
steps: 500
|
steps: 500
|
||||||
# I have had good results with 4e-4 to 1e-4 at 500 steps
|
# I have had good results with 4e-4 to 1e-4 at 500 steps
|
||||||
lr: 1e-4
|
lr: 1e-4
|
||||||
|
# enables gradient checkpoint, saves vram, leave it on
|
||||||
|
gradient_checkpointing: true
|
||||||
# train the unet. I recommend leaving this true
|
# train the unet. I recommend leaving this true
|
||||||
train_unet: true
|
train_unet: true
|
||||||
# train the text encoder. I don't recommend this unless you have a special use case
|
# train the text encoder. I don't recommend this unless you have a special use case
|
||||||
@@ -66,6 +68,7 @@ config:
|
|||||||
name_or_path: "runwayml/stable-diffusion-v1-5"
|
name_or_path: "runwayml/stable-diffusion-v1-5"
|
||||||
is_v2: false # for v2 models
|
is_v2: false # for v2 models
|
||||||
is_v_pred: false # for v-prediction models (most v2 models)
|
is_v_pred: false # for v-prediction models (most v2 models)
|
||||||
|
is_xl: false # for SDXL models
|
||||||
|
|
||||||
# saving config
|
# saving config
|
||||||
save:
|
save:
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ class TrainConfig:
|
|||||||
self.noise_offset = kwargs.get('noise_offset', 0.0)
|
self.noise_offset = kwargs.get('noise_offset', 0.0)
|
||||||
self.optimizer_params = kwargs.get('optimizer_params', {})
|
self.optimizer_params = kwargs.get('optimizer_params', {})
|
||||||
self.skip_first_sample = kwargs.get('skip_first_sample', False)
|
self.skip_first_sample = kwargs.get('skip_first_sample', False)
|
||||||
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', False)
|
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig:
|
class ModelConfig:
|
||||||
|
|||||||
Reference in New Issue
Block a user