mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Numerous fixes for time sampling. Still not perfect
This commit is contained in:
@@ -584,8 +584,13 @@ def encode_prompts_xl(
|
||||
def text_encode(text_encoder: 'CLIPTextModel', tokens, truncate: bool = True, max_length=None):
|
||||
if max_length is None and not truncate:
|
||||
raise ValueError("max_length must be set if truncate is True")
|
||||
|
||||
tokens = tokens.to(text_encoder.device)
|
||||
try:
|
||||
tokens = tokens.to(text_encoder.device)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("tokens.device", tokens.device)
|
||||
print("text_encoder.device", text_encoder.device)
|
||||
raise e
|
||||
|
||||
if truncate:
|
||||
return text_encoder(tokens)[0]
|
||||
@@ -771,8 +776,8 @@ def apply_snr_weight(
|
||||
):
|
||||
# will get it from noise scheduler if exist or will calculate it if not
|
||||
all_snr = get_all_snr(noise_scheduler, loss.device)
|
||||
|
||||
snr = torch.stack([all_snr[t] for t in timesteps])
|
||||
step_indices = [(noise_scheduler.timesteps == t).nonzero().item() for t in timesteps]
|
||||
snr = torch.stack([all_snr[t] for t in step_indices])
|
||||
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
|
||||
if fixed:
|
||||
snr_weight = gamma_over_snr.float().to(loss.device) # directly using gamma over snr
|
||||
|
||||
Reference in New Issue
Block a user