mirror of https://github.com/ostris/ai-toolkit.git
Improvements for full tuning flux. Added debugging launch config for vscode
@@ -174,7 +174,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
             train_adapter=is_training_adapter,
             train_embedding=self.embed_config is not None,
             train_refiner=self.train_config.train_refiner,
-            unload_text_encoder=self.train_config.unload_text_encoder
+            unload_text_encoder=self.train_config.unload_text_encoder,
+            require_grads=False  # we ensure them later
         )

+        self.get_params_device_state_preset = get_train_sd_device_state_preset(
+            device=self.device_torch,
+            train_unet=self.train_config.train_unet,
+            train_text_encoder=self.train_config.train_text_encoder,
+            cached_latents=self.is_latents_cached,
+            train_lora=self.network_config is not None,
+            train_adapter=is_training_adapter,
+            train_embedding=self.embed_config is not None,
+            train_refiner=self.train_config.train_refiner,
+            unload_text_encoder=self.train_config.unload_text_encoder,
+            require_grads=True  # We check for grads when getting params
+        )

         # fine_tuning here is for training actual SD network, not LoRA, embeddings, etc. it is (Dreambooth, etc)
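For readers outside the codebase: the hunk above builds a second device-state preset that differs from the training preset only in require_grads. A minimal sketch of that pattern follows; DeviceStatePreset, the body of set_device_state, and the field names are hypothetical simplifications, not ai-toolkit's actual implementation of get_train_sd_device_state_preset.

from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class DeviceStatePreset:
    # Hypothetical stand-in for the preset object built in the diff above.
    train_unet: bool
    require_grads: bool


def set_device_state(unet: nn.Module, preset: DeviceStatePreset, device: torch.device) -> None:
    # Place the module, set train/eval mode, and force grads on only when
    # the preset asks for it.
    unet.to(device)
    unet.train(preset.train_unet)
    if preset.require_grads and preset.train_unet:
        for p in unet.parameters():
            p.requires_grad_(True)


unet = nn.Linear(8, 8)  # toy stand-in for the UNet
device = torch.device("cpu")

# The setup-time preset defers grads ("we ensure them later") ...
set_device_state(unet, DeviceStatePreset(train_unet=True, require_grads=False), device)
# ... while the param-collection preset guarantees they are live.
set_device_state(unet, DeviceStatePreset(train_unet=True, require_grads=True), device)
assert all(p.requires_grad for p in unet.parameters())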
@@ -575,9 +589,11 @@ class BaseSDTrainProcess(BaseTrainProcess):

     def ensure_params_requires_grad(self):
         # get param groups
-        for group in self.optimizer.param_groups:
+        # for group in self.optimizer.param_groups:
+        for group in self.params:
             for param in group['params']:
-                param.requires_grad = True
+                if isinstance(param, torch.nn.Parameter):  # Ensure it's a proper parameter
+                    param.requires_grad_(True)

     def setup_ema(self):
         if self.train_config.ema_config.use_ema:
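Standalone, the fixed helper behaves like the sketch below, assuming (as the diff does) that self.params is a list of optimizer-style group dicts keyed by 'params'. The isinstance guard matters because a group can carry plain tensors or buffers that should not be flipped on; only proper nn.Parameter leaves get re-enabled.

import torch
from torch import nn


def ensure_params_requires_grad(param_groups):
    # Force requires_grad on every proper nn.Parameter in optimizer-style
    # group dicts; skip anything else that may have landed in a group.
    for group in param_groups:
        for param in group["params"]:
            if isinstance(param, torch.nn.Parameter):
                param.requires_grad_(True)


model = nn.Linear(4, 4)
model.requires_grad_(False)  # frozen earlier, e.g. under a require_grads=False preset

groups = [{"params": list(model.parameters()), "lr": 1e-4}]
ensure_params_requires_grad(groups)
assert all(p.requires_grad for p in model.parameters())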
@@ -1487,7 +1503,7 @@ class BaseSDTrainProcess(BaseTrainProcess):

         else:  # no network, embedding or adapter
             # set the device state preset before getting params
-            self.sd.set_device_state(self.train_device_state_preset)
+            self.sd.set_device_state(self.get_params_device_state_preset)

             # params = self.get_params()
             if len(params) == 0:
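This one-line swap is the payoff of the first hunk: parameter collection now runs under the preset built with require_grads=True rather than the training preset that defers them, so the grad check performed while gathering params can actually pass.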
@@ -1521,6 +1537,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.start_step = self.step_num

         optimizer_type = self.train_config.optimizer.lower()
+
+        # ensure params require grad
+        self.ensure_params_requires_grad()
         optimizer = get_optimizer(self.params, optimizer_type, learning_rate=self.train_config.lr,
                                   optimizer_params=self.train_config.optimizer_params)
         self.optimizer = optimizer
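The ordering here is the point: ensure_params_requires_grad runs before get_optimizer. A sketch of why, using a hypothetical build_optimizer that, like many optimizer factories, keeps only params which currently require grad (whether ai-toolkit's get_optimizer filters this way is an assumption):

import torch
from torch import nn


def build_optimizer(param_groups, lr):
    # Hypothetical stand-in for get_optimizer: drops params that do not
    # currently require grad, so anything still frozen at construction
    # time would silently never be trained.
    trainable = [
        {**g, "params": [p for p in g["params"] if p.requires_grad]}
        for g in param_groups
    ]
    return torch.optim.AdamW(trainable, lr=lr)


model = nn.Linear(4, 4)
model.requires_grad_(False)  # still frozen from setup
groups = [{"params": list(model.parameters())}]

# Mirror the diff: re-enable grads first, then build the optimizer.
for g in groups:
    for p in g["params"]:
        if isinstance(p, torch.nn.Parameter):
            p.requires_grad_(True)

optimizer = build_optimizer(groups, lr=1e-4)
assert sum(len(g["params"]) for g in optimizer.param_groups) == 2  # weight + bias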