From 1755e58dd91ba3913005a674cce3504a367be988 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 5 Aug 2025 08:55:16 -0600 Subject: [PATCH] Update generation script to handle latest models. --- jobs/process/GenerateProcess.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/jobs/process/GenerateProcess.py b/jobs/process/GenerateProcess.py index 44fc6b28..6f67a322 100644 --- a/jobs/process/GenerateProcess.py +++ b/jobs/process/GenerateProcess.py @@ -10,10 +10,13 @@ from jobs.process.BaseProcess import BaseProcess from toolkit.config_modules import ModelConfig, GenerateImageConfig from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_model_hash_to_meta, \ add_base_model_info_to_meta +from toolkit.sampler import get_sampler from toolkit.stable_diffusion_model import StableDiffusion from toolkit.train_tools import get_torch_dtype import random +from toolkit.util.get_model import get_model_class + class GenerateConfig: @@ -84,10 +87,32 @@ class GenerateProcess(BaseProcess): self.torch_dtype = get_torch_dtype(self.get_conf('dtype', 'float16')) self.progress_bar = None - self.sd = StableDiffusion( + + ModelClass = get_model_class(self.model_config) + # if the model class has get_train_scheduler static method + if hasattr(ModelClass, 'get_train_scheduler'): + sampler = ModelClass.get_train_scheduler() + else: + # get the noise scheduler + arch = 'sd' + if self.model_config.is_pixart: + arch = 'pixart' + if self.model_config.is_flux: + arch = 'flux' + if self.model_config.is_lumina2: + arch = 'lumina2' + sampler = get_sampler( + self.train_config.noise_scheduler, + { + "prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon", + }, + arch=arch, + ) + self.sd = ModelClass( device=self.device, model_config=self.model_config, dtype=self.model_config.dtype, + noise_scheduler=sampler, ) print(f"Using device {self.device}")