mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Update generation script to handle latest models.
This commit is contained in:
@@ -10,10 +10,13 @@ from jobs.process.BaseProcess import BaseProcess
|
||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig
|
||||
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_model_hash_to_meta, \
|
||||
add_base_model_info_to_meta
|
||||
from toolkit.sampler import get_sampler
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
import random
|
||||
|
||||
from toolkit.util.get_model import get_model_class
|
||||
|
||||
|
||||
class GenerateConfig:
|
||||
|
||||
@@ -84,10 +87,32 @@ class GenerateProcess(BaseProcess):
|
||||
self.torch_dtype = get_torch_dtype(self.get_conf('dtype', 'float16'))
|
||||
|
||||
self.progress_bar = None
|
||||
self.sd = StableDiffusion(
|
||||
|
||||
ModelClass = get_model_class(self.model_config)
|
||||
# if the model class has get_train_scheduler static method
|
||||
if hasattr(ModelClass, 'get_train_scheduler'):
|
||||
sampler = ModelClass.get_train_scheduler()
|
||||
else:
|
||||
# get the noise scheduler
|
||||
arch = 'sd'
|
||||
if self.model_config.is_pixart:
|
||||
arch = 'pixart'
|
||||
if self.model_config.is_flux:
|
||||
arch = 'flux'
|
||||
if self.model_config.is_lumina2:
|
||||
arch = 'lumina2'
|
||||
sampler = get_sampler(
|
||||
self.train_config.noise_scheduler,
|
||||
{
|
||||
"prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
|
||||
},
|
||||
arch=arch,
|
||||
)
|
||||
self.sd = ModelClass(
|
||||
device=self.device,
|
||||
model_config=self.model_config,
|
||||
dtype=self.model_config.dtype,
|
||||
noise_scheduler=sampler,
|
||||
)
|
||||
|
||||
print(f"Using device {self.device}")
|
||||
|
||||
Reference in New Issue
Block a user