Added refiner fine tuning. Works, but needs some polish.

This commit is contained in:
Jaret Burkett
2023-11-05 17:15:03 -07:00
parent 8a9e8f708f
commit 93ea955d7c
14 changed files with 4541 additions and 128 deletions

View File

@@ -382,8 +382,8 @@ class SDTrainer(BaseSDTrainProcess):
adapter_strength_max = 1.0
else:
# training with assistance, we want it low
adapter_strength_min = 0.5
adapter_strength_max = 0.8
adapter_strength_min = 0.4
adapter_strength_max = 0.7
# adapter_strength_min = 0.9
# adapter_strength_max = 1.1
@@ -431,6 +431,8 @@ class SDTrainer(BaseSDTrainProcess):
# make the batch splits
if self.train_config.single_item_batching:
if self.model_config.refiner_name_or_path is not None:
raise ValueError("Single item batching is not supported when training the refiner")
batch_size = noisy_latents.shape[0]
# chunk/split everything
noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
@@ -452,7 +454,6 @@ class SDTrainer(BaseSDTrainProcess):
prompt_2_list = [[prompt] for prompt in prompts_2]
else:
# put it all in an array
noisy_latents_list = [noisy_latents]
noise_list = [noise]
timesteps_list = [timesteps]
@@ -603,8 +604,13 @@ class SDTrainer(BaseSDTrainProcess):
# apply gradients
self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)
with self.timer('scheduler_step'):
self.lr_scheduler.step()
else:
# gradient accumulation. Just a place for breakpoint
pass
# TODO Should we only step scheduler on grad step? If so, need to recalculate last step
with self.timer('scheduler_step'):
self.lr_scheduler.step()
if self.embedding is not None:
with self.timer('restore_embeddings'):