mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 09:44:02 +00:00
Bug fixes and minor features
This commit is contained in:
@@ -672,6 +672,7 @@ class StableDiffusion:
|
||||
prompt_embeds=conditional_embeds,
|
||||
is_training=False,
|
||||
has_been_preprocessed=False,
|
||||
is_generating_samples=True,
|
||||
)
|
||||
unconditional_embeds = self.adapter.condition_encoded_embeds(
|
||||
tensors_0_1=validation_image,
|
||||
@@ -679,6 +680,7 @@ class StableDiffusion:
|
||||
is_training=False,
|
||||
has_been_preprocessed=False,
|
||||
is_unconditional=True,
|
||||
is_generating_samples=True,
|
||||
)
|
||||
|
||||
if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
|
||||
@@ -1324,6 +1326,20 @@ class StableDiffusion:
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
elif isinstance(self.text_encoder, T5EncoderModel):
|
||||
embeds, attention_mask = train_tools.encode_prompts_pixart(
|
||||
self.tokenizer,
|
||||
self.text_encoder,
|
||||
prompt,
|
||||
truncate=not long_prompts,
|
||||
max_length=77, # todo set this higher when not transfer learning
|
||||
dropout_prob=dropout_prob
|
||||
)
|
||||
return PromptEmbeds(
|
||||
embeds,
|
||||
# do we want attn mask here?
|
||||
# attention_mask=attention_mask,
|
||||
)
|
||||
else:
|
||||
return PromptEmbeds(
|
||||
train_tools.encode_prompts(
|
||||
|
||||
Reference in New Issue
Block a user