mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Bug fixes and improvements to llmadapter
This commit is contained in:
@@ -219,6 +219,7 @@ class AdapterConfig:
|
|||||||
|
|
||||||
# for llm adapter
|
# for llm adapter
|
||||||
self.num_cloned_blocks: int = kwargs.get('num_cloned_blocks', 0)
|
self.num_cloned_blocks: int = kwargs.get('num_cloned_blocks', 0)
|
||||||
|
self.quantize_llm: bool = kwargs.get('quantize_llm', False)
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingConfig:
|
class EmbeddingConfig:
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ from transformers import (
|
|||||||
ConvNextModel,
|
ConvNextModel,
|
||||||
ConvNextForImageClassification,
|
ConvNextForImageClassification,
|
||||||
ConvNextImageProcessor,
|
ConvNextImageProcessor,
|
||||||
UMT5EncoderModel, LlamaTokenizerFast, AutoModel, AutoTokenizer
|
UMT5EncoderModel, LlamaTokenizerFast, AutoModel, AutoTokenizer, BitsAndBytesConfig
|
||||||
)
|
)
|
||||||
from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
|
from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
|
||||||
|
|
||||||
@@ -201,10 +201,26 @@ class CustomAdapter(torch.nn.Module):
|
|||||||
|
|
||||||
self.te_adapter = TEAdapter(self, self.sd_ref(), self.te, self.tokenizer)
|
self.te_adapter = TEAdapter(self, self.sd_ref(), self.te, self.tokenizer)
|
||||||
elif self.adapter_type == 'llm_adapter':
|
elif self.adapter_type == 'llm_adapter':
|
||||||
self.te = AutoModel.from_pretrained(self.config.text_encoder_path).to(
|
kwargs = {}
|
||||||
self.sd_ref().unet.device,
|
if self.config.quantize_llm:
|
||||||
dtype=torch_dtype
|
bnb_kwargs = {
|
||||||
)
|
'load_in_4bit': True,
|
||||||
|
'bnb_4bit_quant_type': "nf4",
|
||||||
|
'bnb_4bit_compute_dtype': torch.bfloat16
|
||||||
|
}
|
||||||
|
quantization_config = BitsAndBytesConfig(**bnb_kwargs)
|
||||||
|
kwargs['quantization_config'] = quantization_config
|
||||||
|
kwargs['torch_dtype'] = torch_dtype
|
||||||
|
self.te = AutoModel.from_pretrained(
|
||||||
|
self.config.text_encoder_path,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.te = AutoModel.from_pretrained(self.config.text_encoder_path).to(
|
||||||
|
self.sd_ref().unet.device,
|
||||||
|
dtype=torch_dtype,
|
||||||
|
)
|
||||||
|
self.te.to = lambda *args, **kwargs: None
|
||||||
self.te.eval()
|
self.te.eval()
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(self.config.text_encoder_path)
|
self.tokenizer = AutoTokenizer.from_pretrained(self.config.text_encoder_path)
|
||||||
self.llm_adapter = LLMAdapter(
|
self.llm_adapter = LLMAdapter(
|
||||||
@@ -423,6 +439,9 @@ class CustomAdapter(torch.nn.Module):
|
|||||||
|
|
||||||
if 'te_adapter' in state_dict:
|
if 'te_adapter' in state_dict:
|
||||||
self.te_adapter.load_state_dict(state_dict['te_adapter'], strict=strict)
|
self.te_adapter.load_state_dict(state_dict['te_adapter'], strict=strict)
|
||||||
|
|
||||||
|
if 'llm_adapter' in state_dict:
|
||||||
|
self.llm_adapter.load_state_dict(state_dict['llm_adapter'], strict=strict)
|
||||||
|
|
||||||
if 'te_augmenter' in state_dict:
|
if 'te_augmenter' in state_dict:
|
||||||
self.te_augmenter.load_state_dict(state_dict['te_augmenter'], strict=strict)
|
self.te_augmenter.load_state_dict(state_dict['te_augmenter'], strict=strict)
|
||||||
|
|||||||
@@ -147,13 +147,14 @@ class LLMAdapter(torch.nn.Module):
|
|||||||
prompt_attention_mask = text_inputs.attention_mask.to(device)
|
prompt_attention_mask = text_inputs.attention_mask.to(device)
|
||||||
|
|
||||||
# remove the system prompt from the input and attention mask
|
# remove the system prompt from the input and attention mask
|
||||||
text_input_ids = text_input_ids[:, self.system_prompt_length:]
|
|
||||||
prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:]
|
|
||||||
|
|
||||||
prompt_embeds = text_encoder(
|
prompt_embeds = text_encoder(
|
||||||
text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
|
text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
|
||||||
)
|
)
|
||||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||||
|
|
||||||
|
prompt_embeds = prompt_embeds[:, self.system_prompt_length:]
|
||||||
|
prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:]
|
||||||
|
|
||||||
dtype = text_encoder.dtype
|
dtype = text_encoder.dtype
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user