mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Bug fixes. allow for random negative prompts
This commit is contained in:
@@ -171,11 +171,19 @@ class TEAdapter(torch.nn.Module):
|
||||
self.te_ref: weakref.ref = weakref.ref(te)
|
||||
self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
|
||||
|
||||
self.token_size = self.te_ref().config.d_model
|
||||
if self.adapter_ref().config.text_encoder_arch == "t5":
|
||||
self.token_size = self.te_ref().config.d_model
|
||||
else:
|
||||
self.token_size = self.te_ref().config.hidden_size
|
||||
|
||||
# init adapter modules
|
||||
attn_procs = {}
|
||||
unet_sd = sd.unet.state_dict()
|
||||
attn_dict_map = {
|
||||
|
||||
}
|
||||
module_idx = 0
|
||||
attn_processors_list = list(sd.unet.attn_processors.keys())
|
||||
for name in sd.unet.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
|
||||
if name.startswith("mid_block"):
|
||||
|
||||
Reference in New Issue
Block a user