From 87e557cf1ed47ceceb1db91d8865c1f306e6e58f Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sat, 15 Feb 2025 07:18:07 -0700
Subject: [PATCH] Bug fixes and improvements to llm_adapter

- add a quantize_llm adapter config flag that loads the LLM text
  encoder in 4-bit (bitsandbytes NF4) via BitsAndBytesConfig, and no-op
  later .to() calls so the quantized model is not moved
- load llm_adapter weights when restoring from a saved state dict
- strip the system prompt from the encoder output embeddings instead of
  the input ids, so the text encoder attends over the full prompt
---
 toolkit/config_modules.py     |  1 +
 toolkit/custom_adapter.py     | 29 ++++++++++++++++++++++++-----
 toolkit/models/llm_adapter.py |  5 +++--
 3 files changed, 28 insertions(+), 7 deletions(-)

diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index bee186e7..44091b55 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -219,6 +219,7 @@ class AdapterConfig:
 
         # for llm adapter
         self.num_cloned_blocks: int = kwargs.get('num_cloned_blocks', 0)
+        self.quantize_llm: bool = kwargs.get('quantize_llm', False)
 
 
 class EmbeddingConfig:
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index c2c03952..5f9391a4 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -44,7 +44,7 @@ from transformers import (
     ConvNextModel,
     ConvNextForImageClassification,
     ConvNextImageProcessor,
-    UMT5EncoderModel, LlamaTokenizerFast, AutoModel, AutoTokenizer
+    UMT5EncoderModel, LlamaTokenizerFast, AutoModel, AutoTokenizer, BitsAndBytesConfig
 )
 
 from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
@@ -201,10 +201,26 @@ class CustomAdapter(torch.nn.Module):
             self.te_adapter = TEAdapter(self, self.sd_ref(), self.te, self.tokenizer)
 
         elif self.adapter_type == 'llm_adapter':
-            self.te = AutoModel.from_pretrained(self.config.text_encoder_path).to(
-                self.sd_ref().unet.device,
-                dtype=torch_dtype
-            )
+            kwargs = {}
+            if self.config.quantize_llm:
+                bnb_kwargs = {
+                    'load_in_4bit': True,
+                    'bnb_4bit_quant_type': "nf4",
+                    'bnb_4bit_compute_dtype': torch.bfloat16
+                }
+                quantization_config = BitsAndBytesConfig(**bnb_kwargs)
+                kwargs['quantization_config'] = quantization_config
+                kwargs['torch_dtype'] = torch_dtype
+                self.te = AutoModel.from_pretrained(
+                    self.config.text_encoder_path,
+                    **kwargs
+                )
+            else:
+                self.te = AutoModel.from_pretrained(self.config.text_encoder_path).to(
+                    self.sd_ref().unet.device,
+                    dtype=torch_dtype,
+                )
+            self.te.to = lambda *args, **kwargs: None
             self.te.eval()
             self.tokenizer = AutoTokenizer.from_pretrained(self.config.text_encoder_path)
             self.llm_adapter = LLMAdapter(
@@ -423,6 +439,9 @@ class CustomAdapter(torch.nn.Module):
 
         if 'te_adapter' in state_dict:
             self.te_adapter.load_state_dict(state_dict['te_adapter'], strict=strict)
+
+        if 'llm_adapter' in state_dict:
+            self.llm_adapter.load_state_dict(state_dict['llm_adapter'], strict=strict)
 
         if 'te_augmenter' in state_dict:
             self.te_augmenter.load_state_dict(state_dict['te_augmenter'], strict=strict)
diff --git a/toolkit/models/llm_adapter.py b/toolkit/models/llm_adapter.py
index 97ed2455..c35f6ae7 100644
--- a/toolkit/models/llm_adapter.py
+++ b/toolkit/models/llm_adapter.py
@@ -147,13 +147,14 @@ class LLMAdapter(torch.nn.Module):
         prompt_attention_mask = text_inputs.attention_mask.to(device)
 
         # remove the system prompt from the input and attention mask
-        text_input_ids = text_input_ids[:, self.system_prompt_length:]
-        prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:]
         prompt_embeds = text_encoder(
             text_input_ids,
             attention_mask=prompt_attention_mask,
             output_hidden_states=True
        )
         prompt_embeds = prompt_embeds.hidden_states[-2]
+
+        prompt_embeds = prompt_embeds[:, self.system_prompt_length:]
+        prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:]
 
         dtype = text_encoder.dtype
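
Usage sketch (not part of the patch): a minimal standalone version of the
quantize_llm load path above. The model path is a hypothetical placeholder,
and transformers plus bitsandbytes are assumed to be installed with a CUDA
device available:

    import torch
    from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig

    text_encoder_path = "meta-llama/Llama-3.1-8B"  # hypothetical example path

    # 4-bit NF4 quantization, computing in bfloat16 (mirrors bnb_kwargs above)
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    te = AutoModel.from_pretrained(
        text_encoder_path,
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
    )
    te.eval()  # the adapter uses the frozen encoder for inference only
    tokenizer = AutoTokenizer.from_pretrained(text_encoder_path)

A bitsandbytes-quantized model cannot be moved between devices after
loading, which is why the patch also replaces self.te.to with a no-op.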