Added Differential Output Preservation Loss to trainer and ui

This commit is contained in:
Jaret Burkett
2025-02-25 20:12:36 -07:00
parent 259ded9602
commit f6e16e582a
6 changed files with 127 additions and 6 deletions

View File

@@ -62,6 +62,20 @@ class PromptEmbeds:
prompt_embeds.attention_mask = self.attention_mask.clone()
return prompt_embeds
def expand_to_batch(self, batch_size):
pe = self.clone()
current_batch_size = pe.text_embeds.shape[0]
if current_batch_size == batch_size:
return pe
if current_batch_size != 1:
raise Exception("Can only expand batch size for batch size 1")
pe.text_embeds = pe.text_embeds.expand(batch_size, -1)
if pe.pooled_embeds is not None:
pe.pooled_embeds = pe.pooled_embeds.expand(batch_size, -1)
if pe.attention_mask is not None:
pe.attention_mask = pe.attention_mask.expand(batch_size, -1)
return pe
class EncodedPromptPair:
def __init__(