Improvements for full tuning Flux. Added a debugging launch config for VS Code

Jaret Burkett
2024-10-29 04:54:08 -06:00
parent 3400882a80
commit 22cd40d7b9
6 changed files with 170 additions and 19 deletions


@@ -174,7 +174,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
             train_adapter=is_training_adapter,
             train_embedding=self.embed_config is not None,
             train_refiner=self.train_config.train_refiner,
-            unload_text_encoder=self.train_config.unload_text_encoder
+            unload_text_encoder=self.train_config.unload_text_encoder,
+            require_grads=False  # we ensure them later
         )
+        self.get_params_device_state_preset = get_train_sd_device_state_preset(
+            device=self.device_torch,
+            train_unet=self.train_config.train_unet,
+            train_text_encoder=self.train_config.train_text_encoder,
+            cached_latents=self.is_latents_cached,
+            train_lora=self.network_config is not None,
+            train_adapter=is_training_adapter,
+            train_embedding=self.embed_config is not None,
+            train_refiner=self.train_config.train_refiner,
+            unload_text_encoder=self.train_config.unload_text_encoder,
+            require_grads=True  # We check for grads when getting params
+        )
         # fine_tuning here is for training actual SD network, not LoRA, embeddings, etc. it is (Dreambooth, etc)
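This hunk splits the device-state preset in two: the existing training preset now passes require_grads=False (grads are re-enabled later), while a new get_params_device_state_preset passes require_grads=True for the moment the trainable params are collected. The helper's actual return structure is not shown in this diff, so the sketch below only illustrates the idea with an assumed layout, not the repo's implementation:

```python
# Illustrative only: get_train_sd_device_state_preset is not shown in this
# diff, so the layout below is an assumed stand-in for the idea.
import torch


def build_device_state_preset(device: torch.device, train_unet: bool,
                              train_text_encoder: bool, require_grads: bool) -> dict:
    # each entry says where a component should live and whether its params
    # should require grad while this preset is active
    return {
        "unet": {
            "device": device if train_unet else torch.device("cpu"),
            "requires_grad": require_grads and train_unet,
        },
        "text_encoder": {
            "device": device if train_text_encoder else torch.device("cpu"),
            "requires_grad": require_grads and train_text_encoder,
        },
    }


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# preset used during normal training steps: grads are ensured later
train_preset = build_device_state_preset(device, True, False, require_grads=False)
# preset used while collecting params: grads must already be on
get_params_preset = build_device_state_preset(device, True, False, require_grads=True)
```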
@@ -575,9 +589,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
     def ensure_params_requires_grad(self):
         # get param groups
-        for group in self.optimizer.param_groups:
+        # for group in self.optimizer.param_groups:
+        for group in self.params:
             for param in group['params']:
-                param.requires_grad = True
+                if isinstance(param, torch.nn.Parameter):  # Ensure it's a proper parameter
+                    param.requires_grad_(True)

     def setup_ema(self):
         if self.train_config.ema_config.use_ema:
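The rewritten ensure_params_requires_grad walks the trainer's own param groups rather than the optimizer's (the optimizer may not exist yet when it is called) and only flips the flag on real torch.nn.Parameter objects. A minimal, self-contained illustration of why that guard helps; the group contents here are assumptions, not the trainer's actual self.params:

```python
# Float Parameters can have their grad flag toggled, while integer tensors
# (e.g. quantized weights) cannot -- calling requires_grad_(True) on them
# raises a RuntimeError, hence the isinstance guard.
import torch

params = [
    {"params": [
        torch.nn.Parameter(torch.zeros(4)),   # regular trainable weight
        torch.zeros(4, dtype=torch.int8),     # quantized-style plain tensor
    ]},
]

for group in params:
    for param in group["params"]:
        if isinstance(param, torch.nn.Parameter):  # skip plain tensors/buffers
            param.requires_grad_(True)

print([type(p).__name__ for p in params[0]["params"]])  # ['Parameter', 'Tensor']
```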
@@ -1487,7 +1503,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         else:  # no network, embedding or adapter
             # set the device state preset before getting params
-            self.sd.set_device_state(self.train_device_state_preset)
+            self.sd.set_device_state(self.get_params_device_state_preset)

             # params = self.get_params()
             if len(params) == 0:
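Here the full fine-tune path switches to the grads-required preset before params are gathered. The ordering matters if param collection filters on requires_grad: modules left frozen by the training preset would otherwise yield an empty param list. A rough sketch of that interaction, with the preset application reduced to an assumed helper rather than the repo's real set_device_state:

```python
# Assumed helper; the real set_device_state / preset objects in this repo may
# differ. The point: apply the grads-required preset before gathering params,
# otherwise a requires_grad filter finds nothing.
import torch

unet = torch.nn.Linear(8, 8)


def apply_preset(module: torch.nn.Module, requires_grad: bool) -> None:
    for p in module.parameters():
        if p.is_floating_point():  # int/quantized tensors cannot require grad
            p.requires_grad_(requires_grad)


apply_preset(unet, requires_grad=False)                        # training preset
print(len([p for p in unet.parameters() if p.requires_grad]))  # 0 -- nothing collected

apply_preset(unet, requires_grad=True)                         # get-params preset
print(len([p for p in unet.parameters() if p.requires_grad]))  # 2 -- weight and bias
```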
@@ -1521,6 +1537,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.start_step = self.step_num

         optimizer_type = self.train_config.optimizer.lower()
+
+        # ensure params require grad
+        self.ensure_params_requires_grad()
         optimizer = get_optimizer(self.params, optimizer_type, learning_rate=self.train_config.lr,
                                   optimizer_params=self.train_config.optimizer_params)
         self.optimizer = optimizer
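Finally, ensure_params_requires_grad() runs right before the optimizer is built, so every collected param already requires grad when training starts. A self-contained sketch of that ordering using plain torch.optim (this repo's get_optimizer is not reproduced here):

```python
# If params are still frozen when the first step runs, backward produces no
# grads and the optimizer step is a no-op; re-enabling grads first avoids that.
import torch

model = torch.nn.Linear(4, 1)
for p in model.parameters():
    p.requires_grad_(False)  # e.g. left frozen by an earlier device-state preset

# ensure params require grad, as the trainer does just before get_optimizer(...)
for p in model.parameters():
    if isinstance(p, torch.nn.Parameter):
        p.requires_grad_(True)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss = model(torch.randn(2, 4)).mean()
loss.backward()
optimizer.step()
print(model.weight.grad is not None)  # True -- grads flow and the step updates weights
```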