WIP create new class to add new models more easily

This commit is contained in:
Jaret Burkett
2025-03-01 13:49:02 -07:00
parent 60539c0b0f
commit acc79956aa
7 changed files with 1624 additions and 13 deletions

View File

@@ -68,6 +68,8 @@ import transformers
import diffusers
import hashlib
from toolkit.util.get_model import get_model_class
def flush():
torch.cuda.empty_cache()
gc.collect()
@@ -1423,7 +1425,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
model_config_to_load.refiner_name_or_path = previous_refiner_save
self.load_training_state_from_metadata(previous_refiner_save)
self.sd = StableDiffusion(
ModelClass = get_model_class(self.model_config)
self.sd = ModelClass(
device=self.device,
model_config=model_config_to_load,
dtype=self.train_config.dtype,