mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2026-04-27 17:51:50 +00:00
SD3 lora support
This commit is contained in:
@@ -175,6 +175,9 @@ class VectorEmbedder(nn.Module):
|
||||
#################################################################################
|
||||
|
||||
|
||||
class QkvLinear(torch.nn.Linear):
    """Linear layer type used for the fused query/key/value projection.

    Adds no behavior on top of ``torch.nn.Linear``; it exists only as a
    distinct type so such layers can be told apart from ordinary linear
    layers (presumably by the LoRA code -- confirm against the caller).
    """
|
||||
|
||||
def split_qkv(qkv, head_dim):
    """Split a fused qkv tensor into separate query, key and value tensors.

    The input is reshaped to ``(shape[0], shape[1], 3, -1, head_dim)`` and
    the axis of size 3 is moved to the front, so each returned tensor has
    shape ``(shape[0], shape[1], -1, head_dim)``.

    Args:
        qkv: 3-D tensor whose last axis holds the three projections back to
            back (assumes a (batch, tokens, 3 * dim) layout -- TODO confirm).
        head_dim: size of the trailing axis of each returned tensor.

    Returns:
        A ``(q, k, v)`` tuple of tensors.
    """
    dim0, dim1 = qkv.shape[0], qkv.shape[1]
    stacked = qkv.reshape(dim0, dim1, 3, -1, head_dim).movedim(2, 0)
    query, key, value = stacked[0], stacked[1], stacked[2]
    return query, key, value
|
||||
@@ -202,7 +205,7 @@ class SelfAttention(nn.Module):
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.qkv = QkvLinear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
if not pre_only:
|
||||
self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
|
||||
assert attn_mode in self.ATTENTION_MODES
|
||||
|
||||
@@ -67,6 +67,7 @@ class BaseModel(torch.nn.Module):
|
||||
}
|
||||
self.diffusion_model = MMDiT(input_size=None, pos_embed_scaling_factor=None, pos_embed_offset=None, pos_embed_max_size=pos_embed_max_size, patch_size=patch_size, in_channels=16, depth=depth, num_patches=num_patches, adm_in_channels=adm_in_channels, context_embedder_config=context_embedder_config, device=device, dtype=dtype)
|
||||
self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
|
||||
self.depth = depth
|
||||
|
||||
def apply_model(self, x, sigma, c_crossattn=None, y=None):
|
||||
dtype = self.get_dtype()
|
||||
|
||||
@@ -82,3 +82,15 @@ class SD3Inferencer(torch.nn.Module):
|
||||
|
||||
def fix_dimensions(self, width, height):
    """Round ``width`` and ``height`` down to a multiple of 16.

    Returns:
        A ``(width, height)`` tuple of the rounded values.
    """
    multiple = 16
    fixed_width = (width // multiple) * multiple
    fixed_height = (height // multiple) * multiple
    return fixed_width, fixed_height
|
||||
|
||||
def diffusers_weight_mapping(self):
    """Yield ``(diffusers_name, internal_name)`` pairs for attention layers.

    For every block index in ``range(self.model.depth)`` eight pairs are
    produced: the q/k/v/out projections of the ``x_block`` attention and the
    q/k/v/out projections of the ``context_block`` attention.

    The internal names are flattened, underscore-only identifiers.  The
    ``context_block`` q/k/v targets previously contained a stray ``.``
    (``context_block.attn_qkv_*``), unlike every sibling entry; they now use
    ``context_block_attn_qkv_*`` so all targets follow one naming scheme.

    Yields:
        Tuples of two strings: the diffusers-style key and the matching
        internal layer name.
    """
    for i in range(self.model.depth):
        source = f"transformer.transformer_blocks.{i}.attn"
        target = f"diffusion_model_joint_blocks_{i}"

        # Attention of the x (latent) stream.
        yield f"{source}.to_q", f"{target}_x_block_attn_qkv_q_proj"
        yield f"{source}.to_k", f"{target}_x_block_attn_qkv_k_proj"
        yield f"{source}.to_v", f"{target}_x_block_attn_qkv_v_proj"
        yield f"{source}.to_out.0", f"{target}_x_block_attn_proj"

        # Attention of the context stream.
        yield f"{source}.add_q_proj", f"{target}_context_block_attn_qkv_q_proj"
        yield f"{source}.add_k_proj", f"{target}_context_block_attn_qkv_k_proj"
        yield f"{source}.add_v_proj", f"{target}_context_block_attn_qkv_v_proj"
        yield f"{source}.add_out_proj.0", f"{target}_context_block_attn_proj"
|
||||
|
||||
Reference in New Issue
Block a user