mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Add ability to do more advanced sample prompt objects to prepart for a UI rework on control images and other things.
This commit is contained in:
@@ -301,25 +301,31 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
extra_args = {}
|
||||
if self.adapter_config is not None and self.adapter_config.test_img_path is not None:
|
||||
extra_args['adapter_image_path'] = test_image_paths[i]
|
||||
|
||||
sample_item = sample_config.samples[i]
|
||||
if sample_item.seed is not None:
|
||||
current_seed = sample_item.seed
|
||||
|
||||
gen_img_config_list.append(GenerateImageConfig(
|
||||
prompt=prompt, # it will autoparse the prompt
|
||||
width=sample_config.width,
|
||||
height=sample_config.height,
|
||||
negative_prompt=sample_config.neg,
|
||||
width=sample_item.width,
|
||||
height=sample_item.height,
|
||||
negative_prompt=sample_item.neg,
|
||||
seed=current_seed,
|
||||
guidance_scale=sample_config.guidance_scale,
|
||||
guidance_scale=sample_item.guidance_scale,
|
||||
guidance_rescale=sample_config.guidance_rescale,
|
||||
num_inference_steps=sample_config.sample_steps,
|
||||
network_multiplier=sample_config.network_multiplier,
|
||||
num_inference_steps=sample_item.sample_steps,
|
||||
network_multiplier=sample_item.network_multiplier,
|
||||
output_path=output_path,
|
||||
output_ext=sample_config.ext,
|
||||
adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
|
||||
refiner_start_at=sample_config.refiner_start_at,
|
||||
extra_values=sample_config.extra_values,
|
||||
logger=self.logger,
|
||||
num_frames=sample_config.num_frames,
|
||||
fps=sample_config.fps,
|
||||
num_frames=sample_item.num_frames,
|
||||
fps=sample_item.fps,
|
||||
ctrl_img=sample_item.ctrl_img,
|
||||
ctrl_idx=sample_item.ctrl_idx,
|
||||
**extra_args
|
||||
))
|
||||
|
||||
|
||||
@@ -37,6 +37,26 @@ class LoggingConfig:
|
||||
self.project_name: str = kwargs.get('project_name', 'ai-toolkit')
|
||||
self.run_name: str = kwargs.get('run_name', None)
|
||||
|
||||
class SampleItem:
|
||||
def __init__(
|
||||
self,
|
||||
sample_config: 'SampleConfig',
|
||||
**kwargs
|
||||
):
|
||||
# prompt should always be in the kwargs
|
||||
self.prompt = kwargs.get('prompt', None)
|
||||
self.width: int = kwargs.get('width', sample_config.width)
|
||||
self.height: int = kwargs.get('height', sample_config.height)
|
||||
self.neg: str = kwargs.get('neg', sample_config.neg)
|
||||
self.seed: Optional[int] = kwargs.get('seed', None) # if none, default to autogen seed
|
||||
self.guidance_scale: float = kwargs.get('guidance_scale', sample_config.guidance_scale)
|
||||
self.sample_steps: int = kwargs.get('sample_steps', sample_config.sample_steps)
|
||||
self.fps: int = kwargs.get('fps', sample_config.fps)
|
||||
self.num_frames: int = kwargs.get('num_frames', sample_config.num_frames)
|
||||
self.ctrl_img: Optional[str] = kwargs.get('ctrl_img', None)
|
||||
self.ctrl_idx: int = kwargs.get('ctrl_idx', 0)
|
||||
self.network_multiplier: float = kwargs.get('network_multiplier', sample_config.network_multiplier)
|
||||
|
||||
|
||||
class SampleConfig:
|
||||
def __init__(self, **kwargs):
|
||||
@@ -44,7 +64,6 @@ class SampleConfig:
|
||||
self.sample_every: int = kwargs.get('sample_every', 100)
|
||||
self.width: int = kwargs.get('width', 512)
|
||||
self.height: int = kwargs.get('height', 512)
|
||||
self.prompts: list[str] = kwargs.get('prompts', [])
|
||||
self.neg = kwargs.get('neg', False)
|
||||
self.seed = kwargs.get('seed', 0)
|
||||
self.walk_seed = kwargs.get('walk_seed', False)
|
||||
@@ -62,6 +81,23 @@ class SampleConfig:
|
||||
if self.num_frames > 1 and self.ext not in ['webp']:
|
||||
print("Changing sample extention to animated webp")
|
||||
self.ext = 'webp'
|
||||
|
||||
prompts: list[str] = kwargs.get('prompts', [])
|
||||
|
||||
self.samples: Optional[List[SampleItem]] = None
|
||||
# use the legacy prompts if it is passed that way to get samples object
|
||||
default_samples_kwargs = [
|
||||
{"prompt": x} for x in prompts
|
||||
]
|
||||
raw_samples = kwargs.get('samples', default_samples_kwargs)
|
||||
self.samples = [SampleItem(self, **item) for item in raw_samples]
|
||||
|
||||
@property
|
||||
def prompts(self):
|
||||
# for backwards compatibility as this is checked for length frequently
|
||||
return [sample.prompt for sample in self.samples if sample.prompt is not None]
|
||||
|
||||
|
||||
|
||||
|
||||
class LormModuleSettingsConfig:
|
||||
@@ -893,6 +929,7 @@ class GenerateImageConfig:
|
||||
refiner_start_at: float = 0.5, # start at this percentage of a step. 0.0 to 1.0 . 1.0 is the end
|
||||
extra_values: List[float] = None, # extra values to save with prompt file
|
||||
logger: Optional[EmptyLogger] = None,
|
||||
ctrl_img: Optional[str] = None, # control image for controlnet
|
||||
num_frames: int = 1,
|
||||
fps: int = 15,
|
||||
ctrl_idx: int = 0
|
||||
@@ -926,7 +963,7 @@ class GenerateImageConfig:
|
||||
self.extra_values = extra_values if extra_values is not None else []
|
||||
self.num_frames = num_frames
|
||||
self.fps = fps
|
||||
self.ctrl_img = None
|
||||
self.ctrl_img = ctrl_img
|
||||
self.ctrl_idx = ctrl_idx
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user