Bug fixes and minor features

This commit is contained in:
Jaret Burkett
2024-04-25 06:14:31 -06:00
parent 5a70b7f38d
commit 5da3613e0b
12 changed files with 218 additions and 31 deletions

View File

@@ -111,6 +111,7 @@ class CustomAdapter(torch.nn.Module):
self.load_state_dict(loaded_state_dict, strict=False)
def setup_adapter(self):
torch_dtype = get_torch_dtype(self.sd_ref().dtype)
if self.adapter_type == 'photo_maker':
sd = self.sd_ref()
embed_dim = sd.unet.config['cross_attention_dim']
@@ -146,14 +147,23 @@ class CustomAdapter(torch.nn.Module):
)
elif self.adapter_type == 'text_encoder':
if self.config.text_encoder_arch == 't5':
self.te = T5EncoderModel.from_pretrained(self.config.text_encoder_path).to(self.sd_ref().unet.device,
dtype=get_torch_dtype(
self.sd_ref().dtype))
te_kwargs = {}
# te_kwargs['load_in_4bit'] = True
# te_kwargs['load_in_8bit'] = True
te_kwargs['device_map'] = "auto"
te_is_quantized = True
self.te = T5EncoderModel.from_pretrained(
self.config.text_encoder_path,
torch_dtype=torch_dtype,
**te_kwargs
)
# self.te.to = lambda *args, **kwargs: None
self.tokenizer = T5Tokenizer.from_pretrained(self.config.text_encoder_path)
elif self.config.text_encoder_arch == 'clip':
self.te = CLIPTextModel.from_pretrained(self.config.text_encoder_path).to(self.sd_ref().unet.device,
dtype=get_torch_dtype(
self.sd_ref().dtype))
dtype=torch_dtype)
self.tokenizer = CLIPTokenizer.from_pretrained(self.config.text_encoder_path)
else:
raise ValueError(f"unknown text encoder arch: {self.config.text_encoder_arch}")
@@ -531,7 +541,14 @@ class CustomAdapter(torch.nn.Module):
has_been_preprocessed=False,
is_unconditional=False,
quad_count=4,
is_generating_samples=False,
) -> PromptEmbeds:
if self.adapter_type == 'text_encoder' and is_generating_samples:
# replace the prompt embed with ours
if is_unconditional:
return self.unconditional_embeds.clone()
return self.conditional_embeds.clone()
if self.adapter_type == 'ilora':
return prompt_embeds