Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Small tweaks, bug fixes, and future-proofing
@@ -361,6 +361,10 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                 if self.transformer_only and is_unet and hasattr(root_module, 'blocks'):
                     if "blocks" not in lora_name:
                         skip = True

+                if self.transformer_only and is_unet and hasattr(root_module, 'single_blocks'):
+                    if "single_blocks" not in lora_name and "double_blocks" not in lora_name:
+                        skip = True
+
                 if (is_linear or is_conv2d) and not skip:
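
For readers skimming the diff, here is a standalone sketch of the filtering rule this hunk adds. The function name and signature are hypothetical; in the repository the check runs inline inside LoRASpecialNetwork's module-creation loop and reads is_unet, root_module, and lora_name from the surrounding scope.

def should_skip_for_transformer_only(lora_name: str,
                                     has_blocks: bool,
                                     has_single_blocks: bool) -> bool:
    # Transformers exposing a `blocks` attribute: when training transformer
    # blocks only, skip any module whose LoRA name is outside the blocks.
    if has_blocks and "blocks" not in lora_name:
        return True
    # Flux-style transformers exposing `single_blocks`: skip anything that is
    # in neither single_blocks nor double_blocks.
    if has_single_blocks and "single_blocks" not in lora_name and "double_blocks" not in lora_name:
        return True
    return False

# Illustrative names only: a projection layer outside the blocks is skipped,
# while an attention layer inside a block is kept.
assert should_skip_for_transformer_only("lora_unet_time_embed_proj", True, False)
assert not should_skip_for_transformer_only("lora_unet_blocks_0_attn_q", True, False)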
@@ -149,7 +149,12 @@ def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
     pooled_embeds = None
     if prompt_embeds[0].pooled_embeds is not None:
         pooled_embeds = torch.cat([p.pooled_embeds for p in prompt_embeds], dim=0)
-    return PromptEmbeds([text_embeds, pooled_embeds])
+    attention_mask = None
+    if prompt_embeds[0].attention_mask is not None:
+        attention_mask = torch.cat([p.attention_mask for p in prompt_embeds], dim=0)
+    pe = PromptEmbeds([text_embeds, pooled_embeds])
+    pe.attention_mask = attention_mask
+    return pe


 def concat_prompt_pairs(prompt_pairs: list[EncodedPromptPair]):
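
To show the effect of the second hunk in isolation, here is a minimal, runnable sketch of the updated concat_prompt_embeds. The PromptEmbeds class below is a stripped-down stand-in for ai-toolkit's real class (which accepts more argument shapes and carries more state); the point is only that attention masks, when present, are now concatenated along the batch dimension and carried on the returned object instead of being dropped.

import torch

class PromptEmbeds:
    # Illustrative stand-in, not ai-toolkit's actual implementation.
    def __init__(self, args):
        self.text_embeds, self.pooled_embeds = args
        self.attention_mask = None

def concat_prompt_embeds(prompt_embeds):
    # Batch-concatenate the per-prompt text embeddings.
    text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
    pooled_embeds = None
    if prompt_embeds[0].pooled_embeds is not None:
        pooled_embeds = torch.cat([p.pooled_embeds for p in prompt_embeds], dim=0)
    # New in this commit: keep the attention mask through the concat.
    attention_mask = None
    if prompt_embeds[0].attention_mask is not None:
        attention_mask = torch.cat([p.attention_mask for p in prompt_embeds], dim=0)
    pe = PromptEmbeds([text_embeds, pooled_embeds])
    pe.attention_mask = attention_mask
    return pe

# Usage: two single-prompt embeds with masks become one batched embed whose
# mask covers both prompts.
a = PromptEmbeds([torch.randn(1, 77, 768), None])
a.attention_mask = torch.ones(1, 77, dtype=torch.long)
b = PromptEmbeds([torch.randn(1, 77, 768), None])
b.attention_mask = torch.ones(1, 77, dtype=torch.long)
batched = concat_prompt_embeds([a, b])
assert batched.text_embeds.shape[0] == 2 and batched.attention_mask.shape[0] == 2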