Added ability to add models to finetune as plugins. Also added flux2 new arch via that method.

This commit is contained in:
Jaret Burkett
2025-03-27 16:07:00 -06:00
parent e9e30104d3
commit 5365200da1
12 changed files with 936 additions and 1058 deletions

View File

@@ -1,12 +1,49 @@
import os
from typing import List
from toolkit.models.base_model import BaseModel
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.config_modules import ModelConfig
from toolkit.paths import TOOLKIT_ROOT
import importlib
import pkgutil
from toolkit.models.wan21 import Wan21
from toolkit.models.cogview4 import CogView4
BUILT_IN_MODELS = [
Wan21,
CogView4,
]
def get_all_models() -> List[BaseModel]:
extension_folders = ['extensions', 'extensions_built_in']
# This will hold the classes from all extension modules
all_model_classes: List[BaseModel] = BUILT_IN_MODELS
# Iterate over all directories (i.e., packages) in the "extensions" directory
for sub_dir in extension_folders:
extensions_dir = os.path.join(TOOLKIT_ROOT, sub_dir)
for (_, name, _) in pkgutil.iter_modules([extensions_dir]):
try:
# Import the module
module = importlib.import_module(f"{sub_dir}.{name}")
# Get the value of the AI_TOOLKIT_MODELS variable
models = getattr(module, "AI_TOOLKIT_MODELS", None)
# Check if the value is a list
if isinstance(models, list):
# Iterate over the list and add the classes to the main list
all_model_classes.extend(models)
except ImportError as e:
print(f"Failed to import the {name} module. Error: {str(e)}")
return all_model_classes
def get_model_class(config: ModelConfig):
if config.arch == "wan21":
from toolkit.models.wan21 import Wan21
return Wan21
elif config.arch == "cogview4":
from toolkit.models.cogview4 import CogView4
return CogView4
else:
return StableDiffusion
all_models = get_all_models()
for ModelClass in all_models:
if ModelClass.arch == config.arch:
return ModelClass
# default to the legacy model
return StableDiffusion