mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
64 lines
2.0 KiB
Python
64 lines
2.0 KiB
Python
import torch
|
|
from toolkit.basic import flush
|
|
from typing import TYPE_CHECKING
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from toolkit.models.base_model import BaseModel
|
|
|
|
|
|
class FakeTextEncoder(torch.nn.Module):
|
|
def __init__(self, device, dtype):
|
|
super().__init__()
|
|
# register a dummy parameter to avoid errors in some cases
|
|
self.dummy_param = torch.nn.Parameter(torch.zeros(1))
|
|
self._device = device
|
|
self._dtype = dtype
|
|
|
|
def forward(self, *args, **kwargs):
|
|
raise NotImplementedError(
|
|
"This is a fake text encoder and should not be used for inference."
|
|
)
|
|
return None
|
|
|
|
@property
|
|
def device(self):
|
|
return self._device
|
|
|
|
@property
|
|
def dtype(self):
|
|
return self._dtype
|
|
|
|
def to(self, *args, **kwargs):
|
|
return self
|
|
|
|
|
|
def unload_text_encoder(model: "BaseModel"):
|
|
# unload the text encoder in a way that will work with all models and will not throw errors
|
|
# we need to make it appear as a text encoder module without actually having one so all
|
|
# to functions and what not will work.
|
|
|
|
if model.text_encoder is not None:
|
|
if isinstance(model.text_encoder, list):
|
|
text_encoder_list = []
|
|
pipe = model.pipeline
|
|
|
|
# the pipeline stores text encoders like text_encoder, text_encoder_2, text_encoder_3, etc.
|
|
if hasattr(pipe, "text_encoder"):
|
|
te = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
|
|
text_encoder_list.append(te)
|
|
pipe.text_encoder = te
|
|
|
|
i = 2
|
|
while hasattr(pipe, f"text_encoder_{i}"):
|
|
te = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
|
|
text_encoder_list.append(te)
|
|
setattr(pipe, f"text_encoder_{i}", te)
|
|
i += 1
|
|
model.text_encoder = text_encoder_list
|
|
else:
|
|
# only has a single text encoder
|
|
model.text_encoder = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
|
|
|
|
flush()
|