mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Added pixart sigma support, but it wont work until i address breaking changes with lora code in diffusers so it can be upgraded.
This commit is contained in:
@@ -33,8 +33,9 @@ class TEAdapterAttnProcessor(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, adapter=None,
|
||||
adapter_hidden_size=None):
|
||||
adapter_hidden_size=None, layer_name=None):
|
||||
super().__init__()
|
||||
self.layer_name = layer_name
|
||||
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
@@ -170,6 +171,7 @@ class TEAdapter(torch.nn.Module):
|
||||
self.sd_ref: weakref.ref = weakref.ref(sd)
|
||||
self.te_ref: weakref.ref = weakref.ref(te)
|
||||
self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
|
||||
self.adapter_modules = []
|
||||
|
||||
if self.adapter_ref().config.text_encoder_arch == "t5":
|
||||
self.token_size = self.te_ref().config.d_model
|
||||
@@ -239,9 +241,11 @@ class TEAdapter(torch.nn.Module):
|
||||
scale=1.0,
|
||||
num_tokens=self.adapter_ref().config.num_tokens,
|
||||
adapter=self,
|
||||
adapter_hidden_size=self.token_size
|
||||
adapter_hidden_size=self.token_size,
|
||||
layer_name=layer_name
|
||||
)
|
||||
attn_procs[name].load_state_dict(weights)
|
||||
self.adapter_modules.append(attn_procs[name])
|
||||
sd.unet.set_attn_processor(attn_procs)
|
||||
self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
|
||||
|
||||
@@ -262,7 +266,10 @@ class TEAdapter(torch.nn.Module):
|
||||
return_tensors="pt",
|
||||
).input_ids.to(te.device)
|
||||
outputs = te(input_ids=input_ids)
|
||||
return outputs.last_hidden_state
|
||||
outputs = outputs.last_hidden_state
|
||||
return outputs
|
||||
|
||||
|
||||
|
||||
def forward(self, input):
|
||||
return input
|
||||
|
||||
Reference in New Issue
Block a user