mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added ability to add models to finetune as plugins. Also added flux2 new arch via that method.
This commit is contained in:
@@ -1,12 +1,49 @@
|
||||
import os
|
||||
from typing import List
|
||||
from toolkit.models.base_model import BaseModel
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
from toolkit.config_modules import ModelConfig
|
||||
from toolkit.paths import TOOLKIT_ROOT
|
||||
import importlib
|
||||
import pkgutil
|
||||
|
||||
from toolkit.models.wan21 import Wan21
|
||||
from toolkit.models.cogview4 import CogView4
|
||||
|
||||
BUILT_IN_MODELS = [
|
||||
Wan21,
|
||||
CogView4,
|
||||
]
|
||||
|
||||
|
||||
def get_all_models() -> List[BaseModel]:
|
||||
extension_folders = ['extensions', 'extensions_built_in']
|
||||
|
||||
# This will hold the classes from all extension modules
|
||||
all_model_classes: List[BaseModel] = BUILT_IN_MODELS
|
||||
|
||||
# Iterate over all directories (i.e., packages) in the "extensions" directory
|
||||
for sub_dir in extension_folders:
|
||||
extensions_dir = os.path.join(TOOLKIT_ROOT, sub_dir)
|
||||
for (_, name, _) in pkgutil.iter_modules([extensions_dir]):
|
||||
try:
|
||||
# Import the module
|
||||
module = importlib.import_module(f"{sub_dir}.{name}")
|
||||
# Get the value of the AI_TOOLKIT_MODELS variable
|
||||
models = getattr(module, "AI_TOOLKIT_MODELS", None)
|
||||
# Check if the value is a list
|
||||
if isinstance(models, list):
|
||||
# Iterate over the list and add the classes to the main list
|
||||
all_model_classes.extend(models)
|
||||
except ImportError as e:
|
||||
print(f"Failed to import the {name} module. Error: {str(e)}")
|
||||
return all_model_classes
|
||||
|
||||
|
||||
def get_model_class(config: ModelConfig):
|
||||
if config.arch == "wan21":
|
||||
from toolkit.models.wan21 import Wan21
|
||||
return Wan21
|
||||
elif config.arch == "cogview4":
|
||||
from toolkit.models.cogview4 import CogView4
|
||||
return CogView4
|
||||
else:
|
||||
return StableDiffusion
|
||||
all_models = get_all_models()
|
||||
for ModelClass in all_models:
|
||||
if ModelClass.arch == config.arch:
|
||||
return ModelClass
|
||||
# default to the legacy model
|
||||
return StableDiffusion
|
||||
|
||||
Reference in New Issue
Block a user