mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 02:01:29 +00:00
Added caching of image sizes so we don't recompute them every time.
This commit is contained in:
@@ -39,7 +39,8 @@ from transformers import (
|
||||
AutoImageProcessor,
|
||||
ConvNextModel,
|
||||
ConvNextForImageClassification,
|
||||
ConvNextImageProcessor
|
||||
ConvNextImageProcessor,
|
||||
UMT5EncoderModel, LlamaTokenizerFast
|
||||
)
|
||||
from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
|
||||
|
||||
@@ -165,6 +166,23 @@ class CustomAdapter(torch.nn.Module):
|
||||
|
||||
# self.te.to = lambda *args, **kwargs: None
|
||||
self.tokenizer = T5Tokenizer.from_pretrained(self.config.text_encoder_path)
|
||||
elif self.config.text_encoder_arch == 'pile-t5':
|
||||
te_kwargs = {}
|
||||
# te_kwargs['load_in_4bit'] = True
|
||||
# te_kwargs['load_in_8bit'] = True
|
||||
te_kwargs['device_map'] = "auto"
|
||||
te_is_quantized = True
|
||||
|
||||
self.te = UMT5EncoderModel.from_pretrained(
|
||||
self.config.text_encoder_path,
|
||||
torch_dtype=torch_dtype,
|
||||
**te_kwargs
|
||||
)
|
||||
|
||||
# self.te.to = lambda *args, **kwargs: None
|
||||
self.tokenizer = LlamaTokenizerFast.from_pretrained(self.config.text_encoder_path)
|
||||
if self.tokenizer.pad_token is None:
|
||||
self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
||||
elif self.config.text_encoder_arch == 'clip':
|
||||
self.te = CLIPTextModel.from_pretrained(self.config.text_encoder_path).to(self.sd_ref().unet.device,
|
||||
dtype=torch_dtype)
|
||||
|
||||
Reference in New Issue
Block a user