Numerous fixes for time sampling. Still not perfect.

This commit is contained in:
Jaret Burkett
2023-11-28 07:34:43 -07:00
parent d7e55b6ad4
commit 792a5e37e2
7 changed files with 160 additions and 91 deletions

View File

@@ -584,8 +584,13 @@ def encode_prompts_xl(
def text_encode(text_encoder: 'CLIPTextModel', tokens, truncate: bool = True, max_length=None):
if max_length is None and not truncate:
raise ValueError("max_length must be set if truncate is True")
tokens = tokens.to(text_encoder.device)
try:
tokens = tokens.to(text_encoder.device)
except Exception as e:
print(e)
print("tokens.device", tokens.device)
print("text_encoder.device", text_encoder.device)
raise e
if truncate:
return text_encoder(tokens)[0]
@@ -771,8 +776,8 @@ def apply_snr_weight(
):
# will get it from noise scheduler if exist or will calculate it if not
all_snr = get_all_snr(noise_scheduler, loss.device)
snr = torch.stack([all_snr[t] for t in timesteps])
step_indices = [(noise_scheduler.timesteps == t).nonzero().item() for t in timesteps]
snr = torch.stack([all_snr[t] for t in step_indices])
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
if fixed:
snr_weight = gamma_over_snr.float().to(loss.device) # directly using gamma over snr