Bug fixes and improvements to llmadapter

This commit is contained in:
Jaret Burkett
2025-02-15 07:18:07 -07:00
parent bd8d7dc081
commit 87e557cf1e
3 changed files with 28 additions and 7 deletions

View File

@@ -219,6 +219,7 @@ class AdapterConfig:
# for llm adapter
self.num_cloned_blocks: int = kwargs.get('num_cloned_blocks', 0)
self.quantize_llm: bool = kwargs.get('quantize_llm', False)
class EmbeddingConfig:

View File

@@ -44,7 +44,7 @@ from transformers import (
ConvNextModel,
ConvNextForImageClassification,
ConvNextImageProcessor,
UMT5EncoderModel, LlamaTokenizerFast, AutoModel, AutoTokenizer
UMT5EncoderModel, LlamaTokenizerFast, AutoModel, AutoTokenizer, BitsAndBytesConfig
)
from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
@@ -201,10 +201,26 @@ class CustomAdapter(torch.nn.Module):
self.te_adapter = TEAdapter(self, self.sd_ref(), self.te, self.tokenizer)
elif self.adapter_type == 'llm_adapter':
self.te = AutoModel.from_pretrained(self.config.text_encoder_path).to(
self.sd_ref().unet.device,
dtype=torch_dtype
)
kwargs = {}
if self.config.quantize_llm:
bnb_kwargs = {
'load_in_4bit': True,
'bnb_4bit_quant_type': "nf4",
'bnb_4bit_compute_dtype': torch.bfloat16
}
quantization_config = BitsAndBytesConfig(**bnb_kwargs)
kwargs['quantization_config'] = quantization_config
kwargs['torch_dtype'] = torch_dtype
self.te = AutoModel.from_pretrained(
self.config.text_encoder_path,
**kwargs
)
else:
self.te = AutoModel.from_pretrained(self.config.text_encoder_path).to(
self.sd_ref().unet.device,
dtype=torch_dtype,
)
self.te.to = lambda *args, **kwargs: None
self.te.eval()
self.tokenizer = AutoTokenizer.from_pretrained(self.config.text_encoder_path)
self.llm_adapter = LLMAdapter(
@@ -423,6 +439,9 @@ class CustomAdapter(torch.nn.Module):
if 'te_adapter' in state_dict:
self.te_adapter.load_state_dict(state_dict['te_adapter'], strict=strict)
if 'llm_adapter' in state_dict:
self.llm_adapter.load_state_dict(state_dict['llm_adapter'], strict=strict)
if 'te_augmenter' in state_dict:
self.te_augmenter.load_state_dict(state_dict['te_augmenter'], strict=strict)

View File

@@ -147,13 +147,14 @@ class LLMAdapter(torch.nn.Module):
prompt_attention_mask = text_inputs.attention_mask.to(device)
# remove the system prompt from the input and attention mask
text_input_ids = text_input_ids[:, self.system_prompt_length:]
prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:]
prompt_embeds = text_encoder(
text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
)
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds = prompt_embeds[:, self.system_prompt_length:]
prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:]
dtype = text_encoder.dtype