mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Added refiner fine tuning. Works, but needs some polish.
This commit is contained in:
@@ -382,8 +382,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
adapter_strength_max = 1.0
|
||||
else:
|
||||
# training with assistance, we want it low
|
||||
adapter_strength_min = 0.5
|
||||
adapter_strength_max = 0.8
|
||||
adapter_strength_min = 0.4
|
||||
adapter_strength_max = 0.7
|
||||
# adapter_strength_min = 0.9
|
||||
# adapter_strength_max = 1.1
|
||||
|
||||
@@ -431,6 +431,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
# make the batch splits
|
||||
if self.train_config.single_item_batching:
|
||||
if self.model_config.refiner_name_or_path is not None:
|
||||
raise ValueError("Single item batching is not supported when training the refiner")
|
||||
batch_size = noisy_latents.shape[0]
|
||||
# chunk/split everything
|
||||
noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
|
||||
@@ -452,7 +454,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
prompt_2_list = [[prompt] for prompt in prompts_2]
|
||||
|
||||
else:
|
||||
# put it all in an array
|
||||
noisy_latents_list = [noisy_latents]
|
||||
noise_list = [noise]
|
||||
timesteps_list = [timesteps]
|
||||
@@ -603,8 +604,13 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# apply gradients
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
with self.timer('scheduler_step'):
|
||||
self.lr_scheduler.step()
|
||||
else:
|
||||
# gradient accumulation. Just a place for a breakpoint
|
||||
pass
|
||||
|
||||
# TODO Should we only step scheduler on grad step? If so, need to recalculate last step
|
||||
with self.timer('scheduler_step'):
|
||||
self.lr_scheduler.step()
|
||||
|
||||
if self.embedding is not None:
|
||||
with self.timer('restore_embeddings'):
|
||||
|
||||
Reference in New Issue
Block a user