mirror of https://github.com/ostris/ai-toolkit.git
Bug fixes and minor features
@@ -111,6 +111,7 @@ class CustomAdapter(torch.nn.Module):
         self.load_state_dict(loaded_state_dict, strict=False)
 
     def setup_adapter(self):
+        torch_dtype = get_torch_dtype(self.sd_ref().dtype)
         if self.adapter_type == 'photo_maker':
             sd = self.sd_ref()
             embed_dim = sd.unet.config['cross_attention_dim']
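The single added line resolves the owning model's dtype once at the top of setup_adapter, so the branches below can reuse the torch_dtype local instead of calling get_torch_dtype(self.sd_ref().dtype) inline. The helper itself is not part of this diff; the following is only a hypothetical sketch of what such a string-to-dtype resolver might look like, not the toolkit's implementation:

import torch

# Hypothetical sketch of a dtype resolver similar in spirit to get_torch_dtype.
# The toolkit's real helper is not shown in this diff and may differ.
def get_torch_dtype_sketch(dtype):
    if isinstance(dtype, torch.dtype):
        return dtype  # already a torch dtype, pass it through
    mapping = {
        'float32': torch.float32, 'fp32': torch.float32,
        'float16': torch.float16, 'fp16': torch.float16,
        'bfloat16': torch.bfloat16, 'bf16': torch.bfloat16,
    }
    return mapping[dtype]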
@@ -146,14 +147,23 @@ class CustomAdapter(torch.nn.Module):
             )
         elif self.adapter_type == 'text_encoder':
             if self.config.text_encoder_arch == 't5':
-                self.te = T5EncoderModel.from_pretrained(self.config.text_encoder_path).to(self.sd_ref().unet.device,
-                                                                                            dtype=get_torch_dtype(
-                                                                                                self.sd_ref().dtype))
+                te_kwargs = {}
+                # te_kwargs['load_in_4bit'] = True
+                # te_kwargs['load_in_8bit'] = True
+                te_kwargs['device_map'] = "auto"
+                te_is_quantized = True
+
+                self.te = T5EncoderModel.from_pretrained(
+                    self.config.text_encoder_path,
+                    torch_dtype=torch_dtype,
+                    **te_kwargs
+                )
+
+                # self.te.to = lambda *args, **kwargs: None
                 self.tokenizer = T5Tokenizer.from_pretrained(self.config.text_encoder_path)
             elif self.config.text_encoder_arch == 'clip':
                 self.te = CLIPTextModel.from_pretrained(self.config.text_encoder_path).to(self.sd_ref().unet.device,
-                                                                                           dtype=get_torch_dtype(
-                                                                                               self.sd_ref().dtype))
+                                                                                           dtype=torch_dtype)
                 self.tokenizer = CLIPTokenizer.from_pretrained(self.config.text_encoder_path)
             else:
                 raise ValueError(f"unknown text encoder arch: {self.config.text_encoder_arch}")
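This hunk replaces the load-then-.to(device, dtype) pattern for the T5 text encoder with a direct from_pretrained call that passes torch_dtype plus a te_kwargs dict, enabling device_map="auto" while leaving 4-bit/8-bit loading commented out; the CLIP branch simply reuses the torch_dtype local. Below is a standalone sketch of the same loading pattern, assuming the transformers and accelerate packages are installed and using a placeholder model id in place of self.config.text_encoder_path:

import torch
from transformers import T5EncoderModel, T5Tokenizer

# Placeholder model id for illustration only; the adapter reads its path from config.
text_encoder_path = "google/flan-t5-base"

te_kwargs = {"device_map": "auto"}   # let accelerate place weights on available devices
# te_kwargs["load_in_8bit"] = True   # optional bitsandbytes quantization, as in the commented lines above

te = T5EncoderModel.from_pretrained(
    text_encoder_path,
    torch_dtype=torch.bfloat16,
    **te_kwargs,
)
tokenizer = T5Tokenizer.from_pretrained(text_encoder_path)

# Encode a prompt with the standalone encoder.
tokens = tokenizer("a photo of a cat", return_tensors="pt")
with torch.no_grad():
    embeds = te(**tokens.to(te.device)).last_hidden_state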
@@ -531,7 +541,14 @@ class CustomAdapter(torch.nn.Module):
             has_been_preprocessed=False,
             is_unconditional=False,
             quad_count=4,
+            is_generating_samples=False,
     ) -> PromptEmbeds:
+        if self.adapter_type == 'text_encoder' and is_generating_samples:
+            # replace the prompt embed with ours
+            if is_unconditional:
+                return self.unconditional_embeds.clone()
+            return self.conditional_embeds.clone()
+
         if self.adapter_type == 'ilora':
             return prompt_embeds
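The last hunk threads a new is_generating_samples flag into the prompt-encoding path: when the text_encoder adapter is generating samples it returns clones of its cached conditional_embeds / unconditional_embeds instead of encoding the sample prompt, and the ilora adapter keeps returning prompt_embeds unchanged. The sketch below illustrates the cache-and-clone idea with hypothetical names; it is not the adapter's actual code:

import torch

class PromptEmbedCache:
    # Illustrative cache for conditional/unconditional prompt embeddings.

    def __init__(self):
        self.conditional_embeds = None
        self.unconditional_embeds = None

    def store(self, conditional: torch.Tensor, unconditional: torch.Tensor) -> None:
        # Detach so the cache does not keep the training graph alive.
        self.conditional_embeds = conditional.detach()
        self.unconditional_embeds = unconditional.detach()

    def fetch(self, is_unconditional: bool) -> torch.Tensor:
        cached = self.unconditional_embeds if is_unconditional else self.conditional_embeds
        # Clone so sampling code can modify the tensor without touching the cache.
        return cached.clone()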