Added caching for image sizes so we don't have to determine them every time.

This commit is contained in:
Jaret Burkett
2024-07-15 19:07:41 -06:00
parent e4558dff4b
commit 58dffd43a8
7 changed files with 90 additions and 34 deletions

View File

@@ -39,7 +39,8 @@ from transformers import (
AutoImageProcessor,
ConvNextModel,
ConvNextForImageClassification,
ConvNextImageProcessor
ConvNextImageProcessor,
UMT5EncoderModel, LlamaTokenizerFast
)
from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
@@ -165,6 +166,23 @@ class CustomAdapter(torch.nn.Module):
# self.te.to = lambda *args, **kwargs: None
self.tokenizer = T5Tokenizer.from_pretrained(self.config.text_encoder_path)
elif self.config.text_encoder_arch == 'pile-t5':
te_kwargs = {}
# te_kwargs['load_in_4bit'] = True
# te_kwargs['load_in_8bit'] = True
te_kwargs['device_map'] = "auto"
te_is_quantized = True
self.te = UMT5EncoderModel.from_pretrained(
self.config.text_encoder_path,
torch_dtype=torch_dtype,
**te_kwargs
)
# self.te.to = lambda *args, **kwargs: None
self.tokenizer = LlamaTokenizerFast.from_pretrained(self.config.text_encoder_path)
if self.tokenizer.pad_token is None:
self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
elif self.config.text_encoder_arch == 'clip':
self.te = CLIPTextModel.from_pretrained(self.config.text_encoder_path).to(self.sd_ref().unet.device,
dtype=torch_dtype)