mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Bug fixes and improvements to llmadapter
This commit is contained in:
@@ -219,6 +219,7 @@ class AdapterConfig:
|
||||
|
||||
# for llm adapter
|
||||
self.num_cloned_blocks: int = kwargs.get('num_cloned_blocks', 0)
|
||||
self.quantize_llm: bool = kwargs.get('quantize_llm', False)
|
||||
|
||||
|
||||
class EmbeddingConfig:
|
||||
|
||||
@@ -44,7 +44,7 @@ from transformers import (
|
||||
ConvNextModel,
|
||||
ConvNextForImageClassification,
|
||||
ConvNextImageProcessor,
|
||||
UMT5EncoderModel, LlamaTokenizerFast, AutoModel, AutoTokenizer
|
||||
UMT5EncoderModel, LlamaTokenizerFast, AutoModel, AutoTokenizer, BitsAndBytesConfig
|
||||
)
|
||||
from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
|
||||
|
||||
@@ -201,10 +201,26 @@ class CustomAdapter(torch.nn.Module):
|
||||
|
||||
self.te_adapter = TEAdapter(self, self.sd_ref(), self.te, self.tokenizer)
|
||||
elif self.adapter_type == 'llm_adapter':
|
||||
self.te = AutoModel.from_pretrained(self.config.text_encoder_path).to(
|
||||
self.sd_ref().unet.device,
|
||||
dtype=torch_dtype
|
||||
)
|
||||
kwargs = {}
|
||||
if self.config.quantize_llm:
|
||||
bnb_kwargs = {
|
||||
'load_in_4bit': True,
|
||||
'bnb_4bit_quant_type': "nf4",
|
||||
'bnb_4bit_compute_dtype': torch.bfloat16
|
||||
}
|
||||
quantization_config = BitsAndBytesConfig(**bnb_kwargs)
|
||||
kwargs['quantization_config'] = quantization_config
|
||||
kwargs['torch_dtype'] = torch_dtype
|
||||
self.te = AutoModel.from_pretrained(
|
||||
self.config.text_encoder_path,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
self.te = AutoModel.from_pretrained(self.config.text_encoder_path).to(
|
||||
self.sd_ref().unet.device,
|
||||
dtype=torch_dtype,
|
||||
)
|
||||
self.te.to = lambda *args, **kwargs: None
|
||||
self.te.eval()
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.config.text_encoder_path)
|
||||
self.llm_adapter = LLMAdapter(
|
||||
@@ -423,6 +439,9 @@ class CustomAdapter(torch.nn.Module):
|
||||
|
||||
if 'te_adapter' in state_dict:
|
||||
self.te_adapter.load_state_dict(state_dict['te_adapter'], strict=strict)
|
||||
|
||||
if 'llm_adapter' in state_dict:
|
||||
self.llm_adapter.load_state_dict(state_dict['llm_adapter'], strict=strict)
|
||||
|
||||
if 'te_augmenter' in state_dict:
|
||||
self.te_augmenter.load_state_dict(state_dict['te_augmenter'], strict=strict)
|
||||
|
||||
@@ -147,13 +147,14 @@ class LLMAdapter(torch.nn.Module):
|
||||
prompt_attention_mask = text_inputs.attention_mask.to(device)
|
||||
|
||||
# remove the system prompt from the input and attention mask
|
||||
text_input_ids = text_input_ids[:, self.system_prompt_length:]
|
||||
prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:]
|
||||
|
||||
prompt_embeds = text_encoder(
|
||||
text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
|
||||
)
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
|
||||
prompt_embeds = prompt_embeds[:, self.system_prompt_length:]
|
||||
prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:]
|
||||
|
||||
dtype = text_encoder.dtype
|
||||
|
||||
|
||||
Reference in New Issue
Block a user