mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 11:41:35 +00:00
Small tweaks and bug fixes and future proofing
This commit is contained in:
@@ -424,7 +424,9 @@ class Flex2(BaseModel):
|
|||||||
if self.random_blur_mask:
|
if self.random_blur_mask:
|
||||||
# blur the mask
|
# blur the mask
|
||||||
# Give it a channel dim of 1
|
# Give it a channel dim of 1
|
||||||
inpainting_tensor_mask = inpainting_tensor_mask.unsqueeze(1)
|
if len(inpainting_tensor_mask.shape) == 3:
|
||||||
|
# if it is 3d, add a channel dim
|
||||||
|
inpainting_tensor_mask = inpainting_tensor_mask.unsqueeze(1)
|
||||||
# we are at latent size, so keep kernel smaller
|
# we are at latent size, so keep kernel smaller
|
||||||
inpainting_tensor_mask = random_blur(
|
inpainting_tensor_mask = random_blur(
|
||||||
inpainting_tensor_mask,
|
inpainting_tensor_mask,
|
||||||
@@ -432,8 +434,6 @@ class Flex2(BaseModel):
|
|||||||
max_kernel_size=8,
|
max_kernel_size=8,
|
||||||
p=0.5
|
p=0.5
|
||||||
)
|
)
|
||||||
# remove the channel dim
|
|
||||||
inpainting_tensor_mask = inpainting_tensor_mask.squeeze(1)
|
|
||||||
|
|
||||||
do_mask_invert = False
|
do_mask_invert = False
|
||||||
if self.invert_inpaint_mask_chance > 0.0:
|
if self.invert_inpaint_mask_chance > 0.0:
|
||||||
|
|||||||
@@ -246,7 +246,7 @@ def generate_readme(supporters):
|
|||||||
f.write("### GitHub Sponsors\n\n")
|
f.write("### GitHub Sponsors\n\n")
|
||||||
for sponsor in github_sponsors:
|
for sponsor in github_sponsors:
|
||||||
if sponsor['profile_image']:
|
if sponsor['profile_image']:
|
||||||
f.write(f"<a href=\"{sponsor['profile_url']}\" title=\"{sponsor['name']}\"><img src=\"{sponsor['profile_image']}\" width=\"50\" height=\"50\" alt=\"{sponsor['name']}\" style=\"border-radius:50%\"></a> ")
|
f.write(f"<a href=\"{sponsor['profile_url']}\" title=\"{sponsor['name']}\"><img src=\"{sponsor['profile_image']}\" width=\"50\" height=\"50\" alt=\"{sponsor['name']}\" style=\"border-radius:50%;display:inline-block;\"></a> ")
|
||||||
else:
|
else:
|
||||||
f.write(f"[{sponsor['name']}]({sponsor['profile_url']}) ")
|
f.write(f"[{sponsor['name']}]({sponsor['profile_url']}) ")
|
||||||
f.write("\n\n")
|
f.write("\n\n")
|
||||||
@@ -257,7 +257,7 @@ def generate_readme(supporters):
|
|||||||
f.write("### Patreon Supporters\n\n")
|
f.write("### Patreon Supporters\n\n")
|
||||||
for supporter in patreon_supporters:
|
for supporter in patreon_supporters:
|
||||||
if supporter['profile_image']:
|
if supporter['profile_image']:
|
||||||
f.write(f"<a href=\"{supporter['profile_url']}\" title=\"{supporter['name']}\"><img src=\"{supporter['profile_image']}\" width=\"50\" height=\"50\" alt=\"{supporter['name']}\" style=\"border-radius:50%\"></a> ")
|
f.write(f"<a href=\"{supporter['profile_url']}\" title=\"{supporter['name']}\"><img src=\"{supporter['profile_image']}\" width=\"50\" height=\"50\" alt=\"{supporter['name']}\" style=\"border-radius:50%;display:inline-block;\"></a> ")
|
||||||
else:
|
else:
|
||||||
f.write(f"[{supporter['name']}]({supporter['profile_url']}) ")
|
f.write(f"[{supporter['name']}]({supporter['profile_url']}) ")
|
||||||
f.write("\n\n")
|
f.write("\n\n")
|
||||||
|
|||||||
@@ -361,6 +361,10 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
|||||||
if self.transformer_only and is_unet and hasattr(root_module, 'blocks'):
|
if self.transformer_only and is_unet and hasattr(root_module, 'blocks'):
|
||||||
if "blocks" not in lora_name:
|
if "blocks" not in lora_name:
|
||||||
skip = True
|
skip = True
|
||||||
|
|
||||||
|
if self.transformer_only and is_unet and hasattr(root_module, 'single_blocks'):
|
||||||
|
if "single_blocks" not in lora_name and "double_blocks" not in lora_name:
|
||||||
|
skip = True
|
||||||
|
|
||||||
if (is_linear or is_conv2d) and not skip:
|
if (is_linear or is_conv2d) and not skip:
|
||||||
|
|
||||||
|
|||||||
@@ -149,7 +149,12 @@ def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
|
|||||||
pooled_embeds = None
|
pooled_embeds = None
|
||||||
if prompt_embeds[0].pooled_embeds is not None:
|
if prompt_embeds[0].pooled_embeds is not None:
|
||||||
pooled_embeds = torch.cat([p.pooled_embeds for p in prompt_embeds], dim=0)
|
pooled_embeds = torch.cat([p.pooled_embeds for p in prompt_embeds], dim=0)
|
||||||
return PromptEmbeds([text_embeds, pooled_embeds])
|
attention_mask = None
|
||||||
|
if prompt_embeds[0].attention_mask is not None:
|
||||||
|
attention_mask = torch.cat([p.attention_mask for p in prompt_embeds], dim=0)
|
||||||
|
pe = PromptEmbeds([text_embeds, pooled_embeds])
|
||||||
|
pe.attention_mask = attention_mask
|
||||||
|
return pe
|
||||||
|
|
||||||
|
|
||||||
def concat_prompt_pairs(prompt_pairs: list[EncodedPromptPair]):
|
def concat_prompt_pairs(prompt_pairs: list[EncodedPromptPair]):
|
||||||
|
|||||||
Reference in New Issue
Block a user