mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Update generation script to handle latest models.
This commit is contained in:
@@ -10,10 +10,13 @@ from jobs.process.BaseProcess import BaseProcess
|
|||||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig
|
from toolkit.config_modules import ModelConfig, GenerateImageConfig
|
||||||
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_model_hash_to_meta, \
|
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_model_hash_to_meta, \
|
||||||
add_base_model_info_to_meta
|
add_base_model_info_to_meta
|
||||||
|
from toolkit.sampler import get_sampler
|
||||||
from toolkit.stable_diffusion_model import StableDiffusion
|
from toolkit.stable_diffusion_model import StableDiffusion
|
||||||
from toolkit.train_tools import get_torch_dtype
|
from toolkit.train_tools import get_torch_dtype
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
from toolkit.util.get_model import get_model_class
|
||||||
|
|
||||||
|
|
||||||
class GenerateConfig:
|
class GenerateConfig:
|
||||||
|
|
||||||
@@ -84,10 +87,32 @@ class GenerateProcess(BaseProcess):
|
|||||||
self.torch_dtype = get_torch_dtype(self.get_conf('dtype', 'float16'))
|
self.torch_dtype = get_torch_dtype(self.get_conf('dtype', 'float16'))
|
||||||
|
|
||||||
self.progress_bar = None
|
self.progress_bar = None
|
||||||
self.sd = StableDiffusion(
|
|
||||||
|
ModelClass = get_model_class(self.model_config)
|
||||||
|
# if the model class has get_train_scheduler static method
|
||||||
|
if hasattr(ModelClass, 'get_train_scheduler'):
|
||||||
|
sampler = ModelClass.get_train_scheduler()
|
||||||
|
else:
|
||||||
|
# get the noise scheduler
|
||||||
|
arch = 'sd'
|
||||||
|
if self.model_config.is_pixart:
|
||||||
|
arch = 'pixart'
|
||||||
|
if self.model_config.is_flux:
|
||||||
|
arch = 'flux'
|
||||||
|
if self.model_config.is_lumina2:
|
||||||
|
arch = 'lumina2'
|
||||||
|
sampler = get_sampler(
|
||||||
|
self.train_config.noise_scheduler,
|
||||||
|
{
|
||||||
|
"prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
|
||||||
|
},
|
||||||
|
arch=arch,
|
||||||
|
)
|
||||||
|
self.sd = ModelClass(
|
||||||
device=self.device,
|
device=self.device,
|
||||||
model_config=self.model_config,
|
model_config=self.model_config,
|
||||||
dtype=self.model_config.dtype,
|
dtype=self.model_config.dtype,
|
||||||
|
noise_scheduler=sampler,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Using device {self.device}")
|
print(f"Using device {self.device}")
|
||||||
|
|||||||
Reference in New Issue
Block a user