Added PixArt Sigma support, but it won't work until I address breaking changes with the LoRA code in diffusers so it can be upgraded.

Jaret Burkett
2024-04-20 10:46:56 -06:00
parent 377b81ee3e
commit 5a70b7f38d
5 changed files with 603 additions and 18 deletions


@@ -33,8 +33,9 @@ class TEAdapterAttnProcessor(nn.Module):
     """
     def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, adapter=None,
-                 adapter_hidden_size=None):
+                 adapter_hidden_size=None, layer_name=None):
         super().__init__()
+        self.layer_name = layer_name
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
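The hasattr guard above is the usual feature check for PyTorch 2.0's fused attention kernel. A minimal sketch of what the check protects, with F being torch.nn.functional as in the diff; the tensor shapes are illustrative, not taken from this code:

import torch
import torch.nn.functional as F

# Same feature check as in the hunk above.
if not hasattr(F, "scaled_dot_product_attention"):
    raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

# Fused attention: softmax(q @ k.transpose(-2, -1) / sqrt(head_dim)) @ v.
# Shapes are (batch, heads, seq_len, head_dim), chosen only for illustration.
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)
out = F.scaled_dot_product_attention(q, k, v)  # -> (1, 8, 16, 64)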
@@ -170,6 +171,7 @@ class TEAdapter(torch.nn.Module):
         self.sd_ref: weakref.ref = weakref.ref(sd)
         self.te_ref: weakref.ref = weakref.ref(te)
         self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
+        self.adapter_modules = []
         if self.adapter_ref().config.text_encoder_arch == "t5":
             self.token_size = self.te_ref().config.d_model
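The adapter holds weak references (sd_ref, te_ref, tokenizer_ref) rather than the objects themselves, so attaching the adapter to the model does not create a reference cycle that keeps the whole pipeline alive. A minimal sketch of the pattern; the Pipeline class is hypothetical:

import weakref

class Pipeline:
    pass

pipe = Pipeline()
pipe_ref = weakref.ref(pipe)  # does not keep `pipe` alive
assert pipe_ref() is pipe     # call the ref to dereference, as in self.te_ref()

del pipe
# Under CPython's reference counting the target is collected immediately,
# and a dead reference returns None.
assert pipe_ref() is None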
@@ -239,9 +241,11 @@ class TEAdapter(torch.nn.Module):
                     scale=1.0,
                     num_tokens=self.adapter_ref().config.num_tokens,
                     adapter=self,
-                    adapter_hidden_size=self.token_size
+                    adapter_hidden_size=self.token_size,
+                    layer_name=layer_name
                 )
                 attn_procs[name].load_state_dict(weights)
+                self.adapter_modules.append(attn_procs[name])
         sd.unet.set_attn_processor(attn_procs)
         self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
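Wrapping the processors in torch.nn.ModuleList (rather than keeping the plain Python list built during the loop) is what makes their weights visible to PyTorch: a plain list attribute does not register parameters, a ModuleList does, so state_dict() and the optimizer can see them. A minimal sketch of the difference; the Linear layers stand in for the attention processors:

import torch

class AdapterPlainList(torch.nn.Module):
    def __init__(self, procs):
        super().__init__()
        self.adapter_modules = procs  # plain list: parameters NOT registered

class AdapterModuleList(torch.nn.Module):
    def __init__(self, procs):
        super().__init__()
        self.adapter_modules = torch.nn.ModuleList(procs)  # parameters registered

procs = [torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)]
print(sum(p.numel() for p in AdapterPlainList(procs).parameters()))   # 0
print(sum(p.numel() for p in AdapterModuleList(procs).parameters()))  # 40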
@@ -262,7 +266,10 @@ class TEAdapter(torch.nn.Module):
             return_tensors="pt",
         ).input_ids.to(te.device)
         outputs = te(input_ids=input_ids)
-        return outputs.last_hidden_state
+        outputs = outputs.last_hidden_state
+        return outputs
+
+    def forward(self, input):
+        return input
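For the t5 text-encoder path, the hunk above tokenizes the prompt, runs the encoder, and returns last_hidden_state, whose width is te.config.d_model (the token_size set in __init__). A minimal sketch of the same path with Hugging Face transformers; the checkpoint name and the padding arguments are assumptions for illustration:

import torch
from transformers import AutoTokenizer, T5EncoderModel

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # assumed checkpoint
te = T5EncoderModel.from_pretrained("t5-small")

input_ids = tokenizer(
    "a photo of a cat",
    padding="max_length",   # padding/truncation settings are assumptions
    max_length=77,
    truncation=True,
    return_tensors="pt",
).input_ids.to(te.device)

with torch.no_grad():
    outputs = te(input_ids=input_ids)

embeds = outputs.last_hidden_state
print(embeds.shape)  # torch.Size([1, 77, 512]); 512 == te.config.d_model for t5-small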