mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-27 00:49:47 +00:00
12 lines
401 B
Python
12 lines
401 B
Python
from toolkit.stable_diffusion_model import StableDiffusion
|
|
from toolkit.config_modules import ModelConfig
|
|
|
|
def get_model_class(config: ModelConfig):
    """Return the model class that matches ``config.arch``.

    The architecture-specific classes are imported inside the branch that
    selects them, so each of those modules is only imported when its
    architecture is requested.

    Args:
        config: Model configuration; only its ``arch`` attribute is read here.

    Returns:
        ``Wan21`` when ``arch`` is ``"wan21"``, ``CogView4`` when it is
        ``"cogview4"``, and ``StableDiffusion`` for every other value.
    """
    arch = config.arch

    if arch == "wan21":
        from toolkit.models.wan21 import Wan21

        return Wan21

    if arch == "cogview4":
        from toolkit.models.cogview4 import CogView4

        return CogView4

    # No dedicated class for this architecture: fall back to StableDiffusion.
    return StableDiffusion