mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
WIP create new class to add new models more easily
This commit is contained in:
@@ -68,6 +68,8 @@ import transformers
|
||||
import diffusers
|
||||
import hashlib
|
||||
|
||||
from toolkit.util.get_model import get_model_class
|
||||
|
||||
def flush():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
@@ -1423,7 +1425,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
model_config_to_load.refiner_name_or_path = previous_refiner_save
|
||||
self.load_training_state_from_metadata(previous_refiner_save)
|
||||
|
||||
self.sd = StableDiffusion(
|
||||
ModelClass = get_model_class(self.model_config)
|
||||
self.sd = ModelClass(
|
||||
device=self.device,
|
||||
model_config=model_config_to_load,
|
||||
dtype=self.train_config.dtype,
|
||||
|
||||
Reference in New Issue
Block a user