mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added support for caching text embeddings. This is just initial support and will probably fail for some models. Still needs to be optimized.
@@ -482,6 +482,8 @@ class TrainConfig:
        # will cache a blank prompt or the trigger word, and unload the text encoder to cpu
        # will make training faster and use less vram
        self.unload_text_encoder = kwargs.get('unload_text_encoder', False)
        # will toggle all datasets to cache text embeddings
        self.cache_text_embeddings: bool = kwargs.get('cache_text_embeddings', False)
        # for swapping which parameters are trained during training
        self.do_paramiter_swapping = kwargs.get('do_paramiter_swapping', False)
        # 0.1 is 10% of the parameters active at a time lower is less vram, higher is more
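A hedged usage sketch of the two new flags, based only on the kwargs read above; how TrainConfig is actually instantiated from a job config is not part of this diff, so the direct construction below is illustrative only.

    # hypothetical direct construction; in practice these keys come from the job's train config section
    train_config = TrainConfig(
        unload_text_encoder=True,     # cache a blank/trigger-word prompt, then move the text encoder to cpu
        cache_text_embeddings=True,   # toggle all datasets to cache text embeddings
    )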
@@ -1189,6 +1191,7 @@ def validate_configs(
    train_config: TrainConfig,
    model_config: ModelConfig,
    save_config: SaveConfig,
    dataset_configs: List[DatasetConfig]
):
    if model_config.is_flux:
        if save_config.save_format != 'diffusers':
@@ -1200,3 +1203,18 @@ def validate_configs(
    if train_config.bypass_guidance_embedding and train_config.do_guidance_loss:
        raise ValueError("Cannot bypass guidance embedding and do guidance loss at the same time. "
                         "Please set bypass_guidance_embedding to False or do_guidance_loss to False.")

    # see if any datasets are caching text embeddings
    is_caching_text_embeddings = any(dataset.cache_text_embeddings for dataset in dataset_configs)
    if is_caching_text_embeddings:

        # check if they are doing differential output preservation
        if train_config.diff_output_preservation:
            raise ValueError("Cannot use differential output preservation with caching text embeddings. Please set diff_output_preservation to False.")

        # make sure they are all cached
        for dataset in dataset_configs:
            if not dataset.cache_text_embeddings:
                raise ValueError("All datasets must have cache_text_embeddings set to True when caching text embeddings is enabled.")
@@ -558,6 +558,8 @@ class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin
            self.cache_latents_all_latents()
        if self.is_caching_clip_vision_to_disk:
            self.cache_clip_vision_to_disk()
        if self.is_caching_text_embeddings:
            self.cache_text_embeddings()
        if self.is_generating_controls:
            # always do this last
            self.setup_controls()
@@ -13,8 +13,8 @@ from toolkit import image_utils
from toolkit.basic import get_quick_signature_string
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
    ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
    UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin
    UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin, TextEmbeddingFileItemDTOMixin
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds

if TYPE_CHECKING:
    from toolkit.config_modules import DatasetConfig
@@ -32,6 +32,7 @@ def print_once(msg):

class FileItemDTO(
    LatentCachingFileItemDTOMixin,
    TextEmbeddingFileItemDTOMixin,
    CaptionProcessingDTOMixin,
    ImageProcessingDTOMixin,
    ControlFileItemDTOMixin,
@@ -124,6 +125,7 @@ class FileItemDTO(
    def cleanup(self):
        self.tensor = None
        self.cleanup_latent()
        self.cleanup_text_embedding()
        self.cleanup_control()
        self.cleanup_inpaint()
        self.cleanup_clip_image()
@@ -136,6 +138,7 @@ class DataLoaderBatchDTO:
        try:
            self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
            is_latents_cached = self.file_items[0].is_latent_cached
            is_text_embedding_cached = self.file_items[0].is_text_embedding_cached
            self.tensor: Union[torch.Tensor, None] = None
            self.latents: Union[torch.Tensor, None] = None
            self.control_tensor: Union[torch.Tensor, None] = None
@@ -156,6 +159,7 @@ class DataLoaderBatchDTO:
            if is_latents_cached:
                self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
            self.control_tensor: Union[torch.Tensor, None] = None
            self.prompt_embeds: Union[PromptEmbeds, None] = None
            # if self.file_items[0].control_tensor is not None:
            # if any have a control tensor, we concatenate them
            if any([x.control_tensor is not None for x in self.file_items]):
@@ -268,6 +272,22 @@ class DataLoaderBatchDTO:
                        self.clip_image_embeds_unconditional.append(x.clip_image_embeds_unconditional)
                    else:
                        raise Exception("clip_image_embeds_unconditional is None for some file items")

            if any([x.prompt_embeds is not None for x in self.file_items]):
                # find one to use as a base
                base_prompt_embeds = None
                for x in self.file_items:
                    if x.prompt_embeds is not None:
                        base_prompt_embeds = x.prompt_embeds
                        break
                prompt_embeds_list = []
                for x in self.file_items:
                    if x.prompt_embeds is None:
                        prompt_embeds_list.append(base_prompt_embeds)
                    else:
                        prompt_embeds_list.append(x.prompt_embeds)
                self.prompt_embeds = concat_prompt_embeds(prompt_embeds_list)

        except Exception as e:
            print(e)
@@ -29,6 +29,7 @@ from PIL.ImageOps import exif_transpose
import albumentations as A
from toolkit.print import print_acc
from toolkit.accelerator import get_accelerator
from toolkit.prompt_utils import PromptEmbeds

from toolkit.train_tools import get_torch_dtype

@@ -301,7 +302,7 @@ class CaptionProcessingDTOMixin:
        self.extra_values: List[float] = dataset_config.extra_values

    # todo allow for loading from sd-scripts style dict
    def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]):
    def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]=None):
        if self.raw_caption is not None:
            # we already loaded it
            pass
@@ -635,6 +636,9 @@ class ImageProcessingDTOMixin:
        if self.dataset_config.num_frames > 1:
            self.load_and_process_video(transform, only_load_latents)
            return
        # handle get_prompt_embedding
        if self.is_text_embedding_cached:
            self.load_prompt_embedding()
        # if we are caching latents, just do that
        if self.is_latent_cached:
            self.get_latent()
@@ -1773,6 +1777,61 @@ class LatentCachingMixin:
            self.sd.restore_device_state()


class TextEmbeddingFileItemDTOMixin:
    def __init__(self, *args, **kwargs):
        # if we have super, call it
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        self.prompt_embeds: Union[PromptEmbeds, None] = None
        self._text_embedding_path: Union[str, None] = None
        self.is_text_embedding_cached = False
        self.text_embedding_load_device = 'cpu'
        self.text_embedding_space_version = 'sd1'
        self.text_embedding_version = 1

    def get_text_embedding_info_dict(self: 'FileItemDTO'):
        # make sure the caption is loaded here
        # TODO: we need a way to cache all the other features like trigger words, DOP, etc. For now, we need to throw an error if not compatible.
        if self.caption is None:
            self.load_caption()
        # throw an error if [trigger] is in the caption, as we cannot inject it while caching
        if '[trigger]' in self.caption:
            raise Exception("Error: [trigger] in caption is not supported when caching text embeddings. Please remove it from the caption.")
        item = OrderedDict([
            ("caption", self.caption),
            ("text_embedding_space_version", self.text_embedding_space_version),
            ("text_embedding_version", self.text_embedding_version),
        ])
        return item

    def get_text_embedding_path(self: 'FileItemDTO', recalculate=False):
        if self._text_embedding_path is not None and not recalculate:
            return self._text_embedding_path
        else:
            # we store text embeddings in a '_t_e_cache' folder next to the image
            img_dir = os.path.dirname(self.path)
            te_dir = os.path.join(img_dir, '_t_e_cache')
            hash_dict = self.get_text_embedding_info_dict()
            filename_no_ext = os.path.splitext(os.path.basename(self.path))[0]
            # get base64 hash of md5 checksum of hash_dict
            hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
            hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
            hash_str = hash_str.replace('=', '')
            self._text_embedding_path = os.path.join(te_dir, f'{filename_no_ext}_{hash_str}.safetensors')

        return self._text_embedding_path

    def cleanup_text_embedding(self):
        if self.prompt_embeds is not None:
            # we are caching on disk, don't save in memory
            self.prompt_embeds = None

    def load_prompt_embedding(self, device=None):
        if not self.is_text_embedding_cached:
            return
        if self.prompt_embeds is None:
            # load it from disk
            self.prompt_embeds = PromptEmbeds.load(self.get_text_embedding_path())


class TextEmbeddingCachingMixin:
    def __init__(self: 'AiToolkitDataset', **kwargs):
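A standalone sketch of the cache-path scheme implemented by get_text_embedding_path above. The image path and caption below are illustrative values only; the hashing steps mirror the code in the mixin.

    import base64, hashlib, json, os
    from collections import OrderedDict

    image_path = '/data/images/dog_001.png'            # hypothetical example path
    info = OrderedDict([
        ("caption", "a photo of a dog"),               # hypothetical caption
        ("text_embedding_space_version", "sd1"),
        ("text_embedding_version", 1),
    ])
    # urlsafe base64 of the md5 digest of the JSON-serialized info dict, with '=' padding stripped
    hash_input = json.dumps(info, sort_keys=True).encode('utf-8')
    hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii').replace('=', '')
    filename_no_ext = os.path.splitext(os.path.basename(image_path))[0]
    cache_path = os.path.join(os.path.dirname(image_path), '_t_e_cache', f'{filename_no_ext}_{hash_str}.safetensors')
    # e.g. /data/images/_t_e_cache/dog_001_<hash>.safetensors

Because the caption and version fields feed the hash, changing a caption produces a new cache file rather than silently reusing a stale embedding.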
@@ -1780,90 +1839,36 @@ class TextEmbeddingCachingMixin:
        if hasattr(super(), '__init__'):
            super().__init__(**kwargs)
        self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings
        if self.is_caching_text_embeddings:
            raise Exception("Error: caching text embeddings is a WIP and is not supported yet. Please set cache_text_embeddings to False in the dataset config")

    def cache_text_embeddings(self: 'AiToolkitDataset'):

        with accelerator.main_process_first():
            print_acc(f"Caching text_embeddings for {self.dataset_path}")
            # cache all latents to disk
            to_disk = self.is_caching_latents_to_disk
            to_memory = self.is_caching_latents_to_memory
            print_acc(" - Saving text embeddings to disk")
            # move sd items to cpu except for vae
            self.sd.set_device_state_preset('cache_latents')

            did_move = False

            # use tqdm to show progress
            i = 0
            for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'):
                # set latent space version
                if self.sd.model_config.latent_space_version is not None:
                    file_item.latent_space_version = self.sd.model_config.latent_space_version
                elif self.sd.is_xl:
                    file_item.latent_space_version = 'sdxl'
                elif self.sd.is_v3:
                    file_item.latent_space_version = 'sd3'
                elif self.sd.is_auraflow:
                    file_item.latent_space_version = 'sdxl'
                elif self.sd.is_flux:
                    file_item.latent_space_version = 'flux1'
                elif self.sd.model_config.is_pixart_sigma:
                    file_item.latent_space_version = 'sdxl'
                else:
                    file_item.latent_space_version = self.sd.model_config.arch
                file_item.is_caching_to_disk = to_disk
                file_item.is_caching_to_memory = to_memory
            for file_item in tqdm(self.file_list, desc='Caching text embeddings to disk'):
                file_item.text_embedding_space_version = self.sd.model_config.arch
                file_item.latent_load_device = self.sd.device

                latent_path = file_item.get_latent_path(recalculate=True)
                # check if it is saved to disk already
                if os.path.exists(latent_path):
                    if to_memory:
                        # load it into memory
                        state_dict = load_file(latent_path, device='cpu')
                        file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
                else:
                    # not saved to disk, calculate
                    # load the image first
                    file_item.load_and_process_image(self.transform, only_load_latents=True)
                    dtype = self.sd.torch_dtype
                    device = self.sd.device_torch
                    # add batch dimension
                    try:
                        imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
                        latent = self.sd.encode_images(imgs).squeeze(0)
                    except Exception as e:
                        print_acc(f"Error processing image: {file_item.path}")
                        print_acc(f"Error: {str(e)}")
                        raise e
                    # save_latent
                    if to_disk:
                        state_dict = OrderedDict([
                            ('latent', latent.clone().detach().cpu()),
                        ])
                        # metadata
                        meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
                        os.makedirs(os.path.dirname(latent_path), exist_ok=True)
                        save_file(state_dict, latent_path, metadata=meta)

                    if to_memory:
                        # keep it in memory
                        file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype)

                    del imgs
                    del latent
                    del file_item.tensor

                # flush(garbage_collect=False)
                file_item.is_latent_cached = True
                text_embedding_path = file_item.get_text_embedding_path(recalculate=True)
                # only process if not saved to disk
                if not os.path.exists(text_embedding_path):
                    # load if not loaded
                    if not did_move:
                        self.sd.set_device_state_preset('cache_text_encoder')
                        did_move = True
                    prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption)
                    # save it
                    prompt_embeds.save(text_embedding_path)
                    del prompt_embeds
                file_item.is_text_embedding_cached = True
                i += 1
                # flush every 100
                # if i % 100 == 0:
                #     flush()

            # restore device state
            self.sd.restore_device_state()
            # if did_move:
            #     self.sd.restore_device_state()


class CLIPCachingMixin:
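Since this hunk mixes removed latent-caching code with the new loop, here is a simplified per-item sketch of what the added caching path boils down to. It is not self-contained: `sd` and `file_list` stand in for self.sd and self.file_list above, and error handling is omitted.

    import os

    did_move = False
    for file_item in file_list:
        file_item.text_embedding_space_version = sd.model_config.arch
        te_path = file_item.get_text_embedding_path(recalculate=True)
        if not os.path.exists(te_path):
            if not did_move:
                # move the text encoder onto the device only if something actually needs encoding
                sd.set_device_state_preset('cache_text_encoder')
                did_move = True
            prompt_embeds = sd.encode_prompt(file_item.caption)
            prompt_embeds.save(te_path)   # PromptEmbeds.save writes a safetensors file
            del prompt_embeds
        file_item.is_text_embedding_cached = True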
@@ -168,6 +168,8 @@ class BaseModel:
        self._after_sample_img_hooks = []
        self._status_update_hooks = []
        self.is_transformer = False

        self.sample_prompts_cache = None

    # properties for old arch for backwards compatibility
    @property
@@ -484,19 +486,23 @@ class BaseModel:
                quad_count=4
            )

            # encode the prompt ourselves so we can do fun stuff with embeddings
            if isinstance(self.adapter, CustomAdapter):
                self.adapter.is_unconditional_run = False
            conditional_embeds = self.encode_prompt(
                gen_config.prompt, gen_config.prompt_2, force_all=True)
            if self.sample_prompts_cache is not None:
                conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype)
                unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype)
            else:
                # encode the prompt ourselves so we can do fun stuff with embeddings
                if isinstance(self.adapter, CustomAdapter):
                    self.adapter.is_unconditional_run = False
                conditional_embeds = self.encode_prompt(
                    gen_config.prompt, gen_config.prompt_2, force_all=True)

                if isinstance(self.adapter, CustomAdapter):
                    self.adapter.is_unconditional_run = True
                unconditional_embeds = self.encode_prompt(
                    gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
                )
                if isinstance(self.adapter, CustomAdapter):
                    self.adapter.is_unconditional_run = False
            if isinstance(self.adapter, CustomAdapter):
                self.adapter.is_unconditional_run = True
            unconditional_embeds = self.encode_prompt(
                gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
            )
            if isinstance(self.adapter, CustomAdapter):
                self.adapter.is_unconditional_run = False

            # allow any manipulations to take place to embeddings
            gen_config.post_process_embeddings(
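The diff does not show where sample_prompts_cache is populated; the shape below is inferred only from how it is read above (indexed by the sample prompt index i, with 'conditional' and 'unconditional' PromptEmbeds that are moved to the device at sample time). Names are hypothetical.

    # inferred shape, not shown in this diff
    self.sample_prompts_cache = [
        {'conditional': cond, 'unconditional': uncond}   # cond/uncond are pre-encoded PromptEmbeds
        for cond, uncond in encoded_sample_prompts       # hypothetical iterable of pre-encoded prompt pairs
    ]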
@@ -92,6 +92,56 @@ class PromptEmbeds:
            pe.attention_mask = pe.attention_mask.expand(batch_size, -1)
        return pe

    def save(self, path: str):
        """
        Save the prompt embeds to a file.
        :param path: The path to save the prompt embeds.
        """
        pe = self.clone()
        state_dict = {}
        if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple):
            for i, text_embed in enumerate(pe.text_embeds):
                state_dict[f"text_embed_{i}"] = text_embed.cpu()
        else:
            state_dict["text_embed"] = pe.text_embeds.cpu()

        if pe.pooled_embeds is not None:
            state_dict["pooled_embed"] = pe.pooled_embeds.cpu()
        if pe.attention_mask is not None:
            state_dict["attention_mask"] = pe.attention_mask.cpu()
        os.makedirs(os.path.dirname(path), exist_ok=True)
        save_file(state_dict, path)

    @classmethod
    def load(cls, path: str) -> 'PromptEmbeds':
        """
        Load the prompt embeds from a file.
        :param path: The path to load the prompt embeds from.
        :return: An instance of PromptEmbeds.
        """
        state_dict = load_file(path, device='cpu')
        text_embeds = []
        pooled_embeds = None
        attention_mask = None
        for key in sorted(state_dict.keys()):
            if key.startswith("text_embed_"):
                text_embeds.append(state_dict[key])
            elif key == "text_embed":
                text_embeds.append(state_dict[key])
            elif key == "pooled_embed":
                pooled_embeds = state_dict[key]
            elif key == "attention_mask":
                attention_mask = state_dict[key]
        pe = cls(None)
        pe.text_embeds = text_embeds
        if len(text_embeds) == 1:
            pe.text_embeds = text_embeds[0]
        if pooled_embeds is not None:
            pe.pooled_embeds = pooled_embeds
        if attention_mask is not None:
            pe.attention_mask = attention_mask
        return pe


class EncodedPromptPair:
    def __init__(
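A hedged round-trip sketch of the new save/load pair. `sd` stands for a loaded model wrapper and the prompt and path are illustrative; the .to(...) call mirrors how the sampling code above consumes cached embeds.

    embeds = sd.encode_prompt("a photo of a dog")                        # hypothetical prompt
    embeds.save("/tmp/_t_e_cache/example.safetensors")                   # tensors are cloned to CPU and written with save_file

    restored = PromptEmbeds.load("/tmp/_t_e_cache/example.safetensors")  # tensors come back on CPU
    restored = restored.to(sd.device_torch, dtype=sd.torch_dtype)        # move to device/dtype before use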
@@ -209,6 +209,8 @@ class StableDiffusion:
        # todo update this based on the model
        self.is_transformer = False

        self.sample_prompts_cache = None

    # properties for old arch for backwards compatibility
    @property
    def is_xl(self):
@@ -1426,18 +1428,22 @@ class StableDiffusion:
                quad_count=4
            )

            # encode the prompt ourselves so we can do fun stuff with embeddings
            if isinstance(self.adapter, CustomAdapter):
                self.adapter.is_unconditional_run = False
            conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
            if self.sample_prompts_cache is not None:
                conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype)
                unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype)
            else:
                # encode the prompt ourselves so we can do fun stuff with embeddings
                if isinstance(self.adapter, CustomAdapter):
                    self.adapter.is_unconditional_run = False
                conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)

                if isinstance(self.adapter, CustomAdapter):
                    self.adapter.is_unconditional_run = True
                unconditional_embeds = self.encode_prompt(
                    gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
                )
                if isinstance(self.adapter, CustomAdapter):
                    self.adapter.is_unconditional_run = False
            if isinstance(self.adapter, CustomAdapter):
                self.adapter.is_unconditional_run = True
            unconditional_embeds = self.encode_prompt(
                gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
            )
            if isinstance(self.adapter, CustomAdapter):
                self.adapter.is_unconditional_run = False

            # allow any manipulations to take place to embeddings
            gen_config.post_process_embeddings(
toolkit/unloader.py (new file, 63 lines)
@@ -0,0 +1,63 @@
import torch
from toolkit.basic import flush
from typing import TYPE_CHECKING


if TYPE_CHECKING:
    from toolkit.models.base_model import BaseModel


class FakeTextEncoder(torch.nn.Module):
    def __init__(self, device, dtype):
        super().__init__()
        # register a dummy parameter to avoid errors in some cases
        self.dummy_param = torch.nn.Parameter(torch.zeros(1))
        self._device = device
        self._dtype = dtype

    def forward(self, *args, **kwargs):
        raise NotImplementedError(
            "This is a fake text encoder and should not be used for inference."
        )

    @property
    def device(self):
        return self._device

    @property
    def dtype(self):
        return self._dtype

    def to(self, *args, **kwargs):
        return self


def unload_text_encoder(model: "BaseModel"):
    # unload the text encoder in a way that will work with all models and will not throw errors
    # we need to make it appear as a text encoder module without actually having one so all
    # to functions and what not will work.

    if model.text_encoder is not None:
        if isinstance(model.text_encoder, list):
            text_encoder_list = []
            pipe = model.pipeline

            # the pipeline stores text encoders like text_encoder, text_encoder_2, text_encoder_3, etc.
            if hasattr(pipe, "text_encoder"):
                te = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
                text_encoder_list.append(te)
                pipe.text_encoder = te

            i = 2
            while hasattr(pipe, f"text_encoder_{i}"):
                te = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
                text_encoder_list.append(te)
                setattr(pipe, f"text_encoder_{i}", te)
                i += 1
            model.text_encoder = text_encoder_list
        else:
            # only has a single text encoder
            model.text_encoder = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)

    flush()
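A hedged usage sketch of the new helper; exactly when the trainer calls it is not shown in this diff, but the idea is that once all prompts are cached to disk the real text encoder is no longer needed and can be swapped for lightweight stubs to free VRAM.

    from toolkit.unloader import unload_text_encoder

    # `model` is a loaded BaseModel/StableDiffusion wrapper (assumed, not shown here)
    unload_text_encoder(model)
    # model.text_encoder (and pipe.text_encoder, text_encoder_2, ...) now point to FakeTextEncoder
    # stubs that keep device/dtype attributes working but raise if a forward pass is attempted.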