mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 06:29:48 +00:00
Bug fixes. added ability to use l1 loss. varous other tests and improvements
This commit is contained in:
@@ -841,6 +841,14 @@ class StableDiffusion:
|
||||
else:
|
||||
timestep = timestep.repeat(latents.shape[0], 0)
|
||||
|
||||
|
||||
# handle t2i adapters
|
||||
if 'down_intrablock_additional_residuals' in kwargs:
|
||||
# go through each item and concat if doing cfg and it doesnt have the same shape
|
||||
for idx, item in enumerate(kwargs['down_intrablock_additional_residuals']):
|
||||
if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
|
||||
kwargs['down_intrablock_additional_residuals'][idx] = torch.cat([item] * 2, dim=0)
|
||||
|
||||
def scale_model_input(model_input, timestep_tensor):
|
||||
if is_input_scaled:
|
||||
return model_input
|
||||
@@ -1599,6 +1607,8 @@ class StableDiffusion:
|
||||
training_modules = []
|
||||
if device_state_preset in ['cache_latents']:
|
||||
active_modules = ['vae']
|
||||
if device_state_preset in ['cache_clip']:
|
||||
active_modules = ['clip']
|
||||
if device_state_preset in ['generate']:
|
||||
active_modules = ['vae', 'unet', 'text_encoder', 'adapter', 'refiner_unet']
|
||||
|
||||
|
||||
Reference in New Issue
Block a user