# Patch summary: per-sample overrides for sample image generation.
#
# NOTE(review): the physical line breaks of this patch appear to have been
# collapsed (hunk headers announce 25/31-line hunks, but the content sits on
# three lines), so it will presumably not apply as-is -- confirm against the
# original commit.
#
# jobs/process/BaseSDTrainProcess.py
#   * Looks up sample_config.samples[i] as `sample_item` and, when
#     sample_item.seed is not None, assigns it to current_seed before the
#     GenerateImageConfig is built.
#   * width, height, negative_prompt, guidance_scale, num_inference_steps,
#     network_multiplier, num_frames and fps are now read from `sample_item`
#     instead of `sample_config`; ctrl_img and ctrl_idx are newly passed.
#   * guidance_rescale, output_ext, adapter_conditioning_scale,
#     refiner_start_at and extra_values are still read from `sample_config`.
#   NOTE(review): whether an overridden current_seed carries over to later
#   samples depends on code outside this hunk -- verify against the loop.
#   NOTE(review): presumably `i` enumerates sample_config.prompts; that
#   property drops samples whose prompt is None, so prompts[i] and samples[i]
#   can disagree if any sample has no prompt -- confirm.
#
# toolkit/config_modules.py
#   * Adds class SampleItem(sample_config, **kwargs). width, height, neg,
#     guidance_scale, sample_steps, fps, num_frames and network_multiplier
#     fall back to the matching SampleConfig attribute when the key is absent;
#     prompt, seed and ctrl_img default to None; ctrl_idx defaults to 0.
#   * SampleConfig.__init__ no longer stores `self.prompts`. It builds
#     `self.samples` from kwargs['samples'], or, when that key is absent,
#     from one {"prompt": x} dict per entry of the legacy kwargs['prompts'].
#   * SampleConfig.prompts becomes a property returning the non-None prompts
#     of self.samples.
#     NOTE(review): no setter is added in this patch, so any code assigning
#     `sample_config.prompts = ...` would raise AttributeError -- check callers.
#     NOTE(review): an explicit `samples: null` in the config would make
#     kwargs.get('samples', ...) return None and the list comprehension fail
#     with TypeError -- confirm whether configs can carry a null here.
#   * GenerateImageConfig.__init__ gains `ctrl_img: Optional[str] = None` and
#     stores it in self.ctrl_img instead of always storing None.
#     NOTE(review): the new parameter is inserted before num_frames, so any
#     caller passing num_frames/fps/ctrl_idx positionally would shift by one --
#     confirm all call sites use keywords.
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 814796e6..5da603ef 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -301,25 +301,31 @@ class BaseSDTrainProcess(BaseTrainProcess): extra_args = {} if self.adapter_config is not None and self.adapter_config.test_img_path is not None: extra_args['adapter_image_path'] = test_image_paths[i] + + sample_item = sample_config.samples[i] + if sample_item.seed is not None: + current_seed = sample_item.seed gen_img_config_list.append(GenerateImageConfig( prompt=prompt, # it will autoparse the prompt - width=sample_config.width, - height=sample_config.height, - negative_prompt=sample_config.neg, + width=sample_item.width, + height=sample_item.height, + negative_prompt=sample_item.neg, seed=current_seed, - guidance_scale=sample_config.guidance_scale, + guidance_scale=sample_item.guidance_scale, guidance_rescale=sample_config.guidance_rescale, - num_inference_steps=sample_config.sample_steps, - network_multiplier=sample_config.network_multiplier, + num_inference_steps=sample_item.sample_steps, + network_multiplier=sample_item.network_multiplier, output_path=output_path, output_ext=sample_config.ext, adapter_conditioning_scale=sample_config.adapter_conditioning_scale, refiner_start_at=sample_config.refiner_start_at, extra_values=sample_config.extra_values, logger=self.logger, - num_frames=sample_config.num_frames, - fps=sample_config.fps, + num_frames=sample_item.num_frames, + fps=sample_item.fps, + ctrl_img=sample_item.ctrl_img, + ctrl_idx=sample_item.ctrl_idx, **extra_args )) diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index f05d0a6e..a8814f94 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -37,6 +37,26 @@ class LoggingConfig: self.project_name: str = kwargs.get('project_name', 'ai-toolkit') self.run_name: str = kwargs.get('run_name', None) +class SampleItem: + def __init__( + self, + 
sample_config: 'SampleConfig', + **kwargs + ): + # prompt should always be in the kwargs + self.prompt = kwargs.get('prompt', None) + self.width: int = kwargs.get('width', sample_config.width) + self.height: int = kwargs.get('height', sample_config.height) + self.neg: str = kwargs.get('neg', sample_config.neg) + self.seed: Optional[int] = kwargs.get('seed', None) # if none, default to autogen seed + self.guidance_scale: float = kwargs.get('guidance_scale', sample_config.guidance_scale) + self.sample_steps: int = kwargs.get('sample_steps', sample_config.sample_steps) + self.fps: int = kwargs.get('fps', sample_config.fps) + self.num_frames: int = kwargs.get('num_frames', sample_config.num_frames) + self.ctrl_img: Optional[str] = kwargs.get('ctrl_img', None) + self.ctrl_idx: int = kwargs.get('ctrl_idx', 0) + self.network_multiplier: float = kwargs.get('network_multiplier', sample_config.network_multiplier) + class SampleConfig: def __init__(self, **kwargs): @@ -44,7 +64,6 @@ class SampleConfig: self.sample_every: int = kwargs.get('sample_every', 100) self.width: int = kwargs.get('width', 512) self.height: int = kwargs.get('height', 512) - self.prompts: list[str] = kwargs.get('prompts', []) self.neg = kwargs.get('neg', False) self.seed = kwargs.get('seed', 0) self.walk_seed = kwargs.get('walk_seed', False) @@ -62,6 +81,23 @@ class SampleConfig: if self.num_frames > 1 and self.ext not in ['webp']: print("Changing sample extention to animated webp") self.ext = 'webp' + + prompts: list[str] = kwargs.get('prompts', []) + + self.samples: Optional[List[SampleItem]] = None + # use the legacy prompts if it is passed that way to get samples object + default_samples_kwargs = [ + {"prompt": x} for x in prompts + ] + raw_samples = kwargs.get('samples', default_samples_kwargs) + self.samples = [SampleItem(self, **item) for item in raw_samples] + + @property + def prompts(self): + # for backwards compatibility as this is checked for length frequently + return [sample.prompt for 
sample in self.samples if sample.prompt is not None] + + class LormModuleSettingsConfig: @@ -893,6 +929,7 @@ class GenerateImageConfig: refiner_start_at: float = 0.5, # start at this percentage of a step. 0.0 to 1.0 . 1.0 is the end extra_values: List[float] = None, # extra values to save with prompt file logger: Optional[EmptyLogger] = None, + ctrl_img: Optional[str] = None, # control image for controlnet num_frames: int = 1, fps: int = 15, ctrl_idx: int = 0 @@ -926,7 +963,7 @@ class GenerateImageConfig: self.extra_values = extra_values if extra_values is not None else [] self.num_frames = num_frames self.fps = fps - self.ctrl_img = None + self.ctrl_img = ctrl_img self.ctrl_idx = ctrl_idx