Improvements for full tuning flux. Added debugging launch config for vscode

This commit is contained in:
Jaret Burkett
2024-10-29 04:54:08 -06:00
parent 3400882a80
commit 22cd40d7b9
6 changed files with 170 additions and 19 deletions

View File

@@ -41,6 +41,7 @@ def get_train_sd_device_state_preset(
train_embedding: bool = False,
train_refiner: bool = False,
unload_text_encoder: bool = False,
require_grads: bool = True,
):
preset = copy.deepcopy(empty_preset)
if not cached_latents:
@@ -48,27 +49,27 @@ def get_train_sd_device_state_preset(
if train_unet:
preset['unet']['training'] = True
preset['unet']['requires_grad'] = True
preset['unet']['requires_grad'] = require_grads
preset['unet']['device'] = device
else:
preset['unet']['device'] = device
if train_text_encoder:
preset['text_encoder']['training'] = True
preset['text_encoder']['requires_grad'] = True
preset['text_encoder']['requires_grad'] = require_grads
preset['text_encoder']['device'] = device
else:
preset['text_encoder']['device'] = device
if train_embedding:
preset['text_encoder']['training'] = True
preset['text_encoder']['requires_grad'] = True
preset['text_encoder']['requires_grad'] = require_grads
preset['text_encoder']['training'] = True
preset['unet']['training'] = True
if train_refiner:
preset['refiner_unet']['training'] = True
preset['refiner_unet']['requires_grad'] = True
preset['refiner_unet']['requires_grad'] = require_grads
preset['refiner_unet']['device'] = device
# if not training unet, move that to cpu
if not train_unet:
@@ -81,7 +82,7 @@ def get_train_sd_device_state_preset(
preset['refiner_unet']['requires_grad'] = False
if train_adapter:
preset['adapter']['requires_grad'] = True
preset['adapter']['requires_grad'] = require_grads
preset['adapter']['training'] = True
preset['adapter']['device'] = device
preset['unet']['training'] = True