Bug fixes. added ability to use l1 loss. varous other tests and improvements

This commit is contained in:
Jaret Burkett
2024-01-31 06:30:54 -07:00
parent 92b9c71d44
commit 1ae1017748
9 changed files with 474 additions and 23 deletions

View File

@@ -841,6 +841,14 @@ class StableDiffusion:
else:
timestep = timestep.repeat(latents.shape[0], 0)
# handle t2i adapters
if 'down_intrablock_additional_residuals' in kwargs:
# go through each item and concat if doing cfg and it doesnt have the same shape
for idx, item in enumerate(kwargs['down_intrablock_additional_residuals']):
if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
kwargs['down_intrablock_additional_residuals'][idx] = torch.cat([item] * 2, dim=0)
def scale_model_input(model_input, timestep_tensor):
if is_input_scaled:
return model_input
@@ -1599,6 +1607,8 @@ class StableDiffusion:
training_modules = []
if device_state_preset in ['cache_latents']:
active_modules = ['vae']
if device_state_preset in ['cache_clip']:
active_modules = ['clip']
if device_state_preset in ['generate']:
active_modules = ['vae', 'unet', 'text_encoder', 'adapter', 'refiner_unet']