ai-toolkit/toolkit/unloader.py

import torch

from toolkit.basic import flush

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from toolkit.models.base_model import BaseModel


class FakeTextEncoder(torch.nn.Module):
    def __init__(self, device, dtype):
        super().__init__()
        # register a dummy parameter so code that inspects the module's
        # parameters (e.g. to infer its device or dtype) does not fail
        # on an empty module
        self.dummy_param = torch.nn.Parameter(torch.zeros(1))
        self._device = device
        self._dtype = dtype

    def forward(self, *args, **kwargs):
        raise NotImplementedError(
            "This is a fake text encoder and should not be used for inference."
        )

    @property
    def device(self):
        return self._device

    @property
    def dtype(self):
        return self._dtype

    def to(self, *args, **kwargs):
        # no-op: there is nothing real to move between devices or dtypes
        return self
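
# The class above mimics just enough of a real text-encoder module's surface
# (device, dtype, .to(), and at least one parameter) for pipeline code to
# treat it as loaded. A sketch of the contract it satisfies (the values below
# are illustrative, not part of this module):
#
#   te = FakeTextEncoder(device=torch.device("cuda"), dtype=torch.bfloat16)
#   te.device      # reports the device the real encoder would have used
#   te.to("cpu")   # no-op that returns te, so chained .to() calls still work
#   te(input_ids)  # raises NotImplementedError rather than failing silently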


def unload_text_encoder(model: "BaseModel"):
    # Unload the text encoder in a way that works across all models without
    # throwing errors: the pipeline keeps something that looks like a text
    # encoder module, without actually holding one, so .to() calls and other
    # module-level operations keep working.
    if model.text_encoder is not None:
        if isinstance(model.text_encoder, list):
            text_encoder_list = []
            pipe = model.pipeline
            # the pipeline stores text encoders as text_encoder, text_encoder_2,
            # text_encoder_3, etc.
            if hasattr(pipe, "text_encoder"):
                te = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
                text_encoder_list.append(te)
                pipe.text_encoder.to('cpu')
                pipe.text_encoder = te
            i = 2
            while hasattr(pipe, f"text_encoder_{i}"):
                te = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
                text_encoder_list.append(te)
                # move the real encoder off the GPU before dropping the
                # reference, mirroring the text_encoder branch above
                getattr(pipe, f"text_encoder_{i}").to('cpu')
                setattr(pipe, f"text_encoder_{i}", te)
                i += 1
            model.text_encoder = text_encoder_list
        else:
            # only a single text encoder
            model.text_encoder = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
    flush()
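

if __name__ == "__main__":
    # Minimal self-check (a sketch, not part of the toolkit; assumes the
    # toolkit imports above resolve): exercises the FakeTextEncoder contract
    # without needing a real model or pipeline.
    te = FakeTextEncoder(device=torch.device("cpu"), dtype=torch.float32)
    assert te.device == torch.device("cpu")
    assert te.dtype == torch.float32
    assert te.to("cuda") is te  # .to() is a no-op that returns the same module
    try:
        te(torch.zeros(1, 8))
    except NotImplementedError:
        print("forward correctly refuses to run")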