Switched to new bucket system that matched sdxl trained buckets. Fixed requirements. Updated embeddings to work with sdxl. Added method to train lora with an embedding at the trigger. Still testing but works amazingly well from what I can see

This commit is contained in:
Jaret Burkett
2023-09-07 13:06:18 -06:00
parent 436bf0c6a3
commit 3feb663a51
10 changed files with 208 additions and 140 deletions

View File

@@ -442,21 +442,25 @@ class StableDiffusion:
return noise
def get_time_ids_from_latents(self, latents: torch.Tensor):
bs, ch, h, w = list(latents.shape)
height = h * VAE_SCALE_FACTOR
width = w * VAE_SCALE_FACTOR
dtype = latents.dtype
if self.is_xl:
prompt_ids = train_tools.get_add_time_ids(
height,
width,
dynamic_crops=False, # look into this
dtype=dtype,
).to(self.device_torch, dtype=dtype)
return prompt_ids
bs, ch, h, w = list(latents.shape)
height = h * VAE_SCALE_FACTOR
width = w * VAE_SCALE_FACTOR
dtype = latents.dtype
# just do it without any cropping nonsense
target_size = (height, width)
original_size = (height, width)
crops_coords_top_left = (0, 0)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids])
add_time_ids = add_time_ids.to(latents.device, dtype=dtype)
batch_time_ids = torch.cat(
[add_time_ids for _ in range(bs)]
)
return batch_time_ids
else:
return None
@@ -682,7 +686,7 @@ class StableDiffusion:
if self.vae.device == 'cpu':
self.vae.to(self.device)
latents = latents.to(device, dtype=dtype)
latents = latents / 0.18215
latents = latents / self.vae.config['scaling_factor']
images = self.vae.decode(latents).sample
images = images.to(device, dtype=dtype)