mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Improvements for full tuning flux. Added debugging launch config for vscode
This commit is contained in:
@@ -41,6 +41,7 @@ def get_train_sd_device_state_preset(
|
||||
train_embedding: bool = False,
|
||||
train_refiner: bool = False,
|
||||
unload_text_encoder: bool = False,
|
||||
require_grads: bool = True,
|
||||
):
|
||||
preset = copy.deepcopy(empty_preset)
|
||||
if not cached_latents:
|
||||
@@ -48,27 +49,27 @@ def get_train_sd_device_state_preset(
|
||||
|
||||
if train_unet:
|
||||
preset['unet']['training'] = True
|
||||
preset['unet']['requires_grad'] = True
|
||||
preset['unet']['requires_grad'] = require_grads
|
||||
preset['unet']['device'] = device
|
||||
else:
|
||||
preset['unet']['device'] = device
|
||||
|
||||
if train_text_encoder:
|
||||
preset['text_encoder']['training'] = True
|
||||
preset['text_encoder']['requires_grad'] = True
|
||||
preset['text_encoder']['requires_grad'] = require_grads
|
||||
preset['text_encoder']['device'] = device
|
||||
else:
|
||||
preset['text_encoder']['device'] = device
|
||||
|
||||
if train_embedding:
|
||||
preset['text_encoder']['training'] = True
|
||||
preset['text_encoder']['requires_grad'] = True
|
||||
preset['text_encoder']['requires_grad'] = require_grads
|
||||
preset['text_encoder']['training'] = True
|
||||
preset['unet']['training'] = True
|
||||
|
||||
if train_refiner:
|
||||
preset['refiner_unet']['training'] = True
|
||||
preset['refiner_unet']['requires_grad'] = True
|
||||
preset['refiner_unet']['requires_grad'] = require_grads
|
||||
preset['refiner_unet']['device'] = device
|
||||
# if not training unet, move that to cpu
|
||||
if not train_unet:
|
||||
@@ -81,7 +82,7 @@ def get_train_sd_device_state_preset(
|
||||
preset['refiner_unet']['requires_grad'] = False
|
||||
|
||||
if train_adapter:
|
||||
preset['adapter']['requires_grad'] = True
|
||||
preset['adapter']['requires_grad'] = require_grads
|
||||
preset['adapter']['training'] = True
|
||||
preset['adapter']['device'] = device
|
||||
preset['unet']['training'] = True
|
||||
|
||||
Reference in New Issue
Block a user