mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 14:39:50 +00:00
Switched to new bucket system that matched sdxl trained buckets. Fixed requirements. Updated embeddings to work with sdxl. Added method to train lora with an embedding at the trigger. Still testing but works amazingly well from what I can see
This commit is contained in:
@@ -442,21 +442,25 @@ class StableDiffusion:
|
||||
return noise
|
||||
|
||||
def get_time_ids_from_latents(self, latents: torch.Tensor):
|
||||
bs, ch, h, w = list(latents.shape)
|
||||
|
||||
height = h * VAE_SCALE_FACTOR
|
||||
width = w * VAE_SCALE_FACTOR
|
||||
|
||||
dtype = latents.dtype
|
||||
|
||||
if self.is_xl:
|
||||
prompt_ids = train_tools.get_add_time_ids(
|
||||
height,
|
||||
width,
|
||||
dynamic_crops=False, # look into this
|
||||
dtype=dtype,
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
return prompt_ids
|
||||
bs, ch, h, w = list(latents.shape)
|
||||
|
||||
height = h * VAE_SCALE_FACTOR
|
||||
width = w * VAE_SCALE_FACTOR
|
||||
|
||||
dtype = latents.dtype
|
||||
# just do it without any cropping nonsense
|
||||
target_size = (height, width)
|
||||
original_size = (height, width)
|
||||
crops_coords_top_left = (0, 0)
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
add_time_ids = torch.tensor([add_time_ids])
|
||||
add_time_ids = add_time_ids.to(latents.device, dtype=dtype)
|
||||
|
||||
batch_time_ids = torch.cat(
|
||||
[add_time_ids for _ in range(bs)]
|
||||
)
|
||||
return batch_time_ids
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -682,7 +686,7 @@ class StableDiffusion:
|
||||
if self.vae.device == 'cpu':
|
||||
self.vae.to(self.device)
|
||||
latents = latents.to(device, dtype=dtype)
|
||||
latents = latents / 0.18215
|
||||
latents = latents / self.vae.config['scaling_factor']
|
||||
images = self.vae.decode(latents).sample
|
||||
images = images.to(device, dtype=dtype)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user