Files
ai-toolkit/toolkit/util/get_model.py
2025-03-04 18:43:52 -07:00

12 lines
401 B
Python

from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.config_modules import ModelConfig
def get_model_class(config: ModelConfig):
    """Pick the model class that corresponds to ``config.arch``.

    Architecture-specific classes are imported inside the branch that
    selects them, so this function only imports a given architecture's
    module when that architecture is actually requested.

    Args:
        config: Model configuration; only its ``arch`` attribute is read.

    Returns:
        ``Wan21`` when ``arch`` is ``"wan21"``, ``CogView4`` when it is
        ``"cogview4"``, and ``StableDiffusion`` for any other value.
    """
    arch = config.arch

    if arch == "wan21":
        from toolkit.models.wan21 import Wan21 as model_class
    elif arch == "cogview4":
        from toolkit.models.cogview4 import CogView4 as model_class
    else:
        # Any architecture without a dedicated class uses the default.
        model_class = StableDiffusion

    return model_class