mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Add ability to do more advanced sample prompt objects to prepart for a UI rework on control images and other things.
This commit is contained in:
@@ -301,25 +301,31 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
extra_args = {}
|
extra_args = {}
|
||||||
if self.adapter_config is not None and self.adapter_config.test_img_path is not None:
|
if self.adapter_config is not None and self.adapter_config.test_img_path is not None:
|
||||||
extra_args['adapter_image_path'] = test_image_paths[i]
|
extra_args['adapter_image_path'] = test_image_paths[i]
|
||||||
|
|
||||||
|
sample_item = sample_config.samples[i]
|
||||||
|
if sample_item.seed is not None:
|
||||||
|
current_seed = sample_item.seed
|
||||||
|
|
||||||
gen_img_config_list.append(GenerateImageConfig(
|
gen_img_config_list.append(GenerateImageConfig(
|
||||||
prompt=prompt, # it will autoparse the prompt
|
prompt=prompt, # it will autoparse the prompt
|
||||||
width=sample_config.width,
|
width=sample_item.width,
|
||||||
height=sample_config.height,
|
height=sample_item.height,
|
||||||
negative_prompt=sample_config.neg,
|
negative_prompt=sample_item.neg,
|
||||||
seed=current_seed,
|
seed=current_seed,
|
||||||
guidance_scale=sample_config.guidance_scale,
|
guidance_scale=sample_item.guidance_scale,
|
||||||
guidance_rescale=sample_config.guidance_rescale,
|
guidance_rescale=sample_config.guidance_rescale,
|
||||||
num_inference_steps=sample_config.sample_steps,
|
num_inference_steps=sample_item.sample_steps,
|
||||||
network_multiplier=sample_config.network_multiplier,
|
network_multiplier=sample_item.network_multiplier,
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
output_ext=sample_config.ext,
|
output_ext=sample_config.ext,
|
||||||
adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
|
adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
|
||||||
refiner_start_at=sample_config.refiner_start_at,
|
refiner_start_at=sample_config.refiner_start_at,
|
||||||
extra_values=sample_config.extra_values,
|
extra_values=sample_config.extra_values,
|
||||||
logger=self.logger,
|
logger=self.logger,
|
||||||
num_frames=sample_config.num_frames,
|
num_frames=sample_item.num_frames,
|
||||||
fps=sample_config.fps,
|
fps=sample_item.fps,
|
||||||
|
ctrl_img=sample_item.ctrl_img,
|
||||||
|
ctrl_idx=sample_item.ctrl_idx,
|
||||||
**extra_args
|
**extra_args
|
||||||
))
|
))
|
||||||
|
|
||||||
|
|||||||
@@ -37,6 +37,26 @@ class LoggingConfig:
|
|||||||
self.project_name: str = kwargs.get('project_name', 'ai-toolkit')
|
self.project_name: str = kwargs.get('project_name', 'ai-toolkit')
|
||||||
self.run_name: str = kwargs.get('run_name', None)
|
self.run_name: str = kwargs.get('run_name', None)
|
||||||
|
|
||||||
|
class SampleItem:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sample_config: 'SampleConfig',
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
# prompt should always be in the kwargs
|
||||||
|
self.prompt = kwargs.get('prompt', None)
|
||||||
|
self.width: int = kwargs.get('width', sample_config.width)
|
||||||
|
self.height: int = kwargs.get('height', sample_config.height)
|
||||||
|
self.neg: str = kwargs.get('neg', sample_config.neg)
|
||||||
|
self.seed: Optional[int] = kwargs.get('seed', None) # if none, default to autogen seed
|
||||||
|
self.guidance_scale: float = kwargs.get('guidance_scale', sample_config.guidance_scale)
|
||||||
|
self.sample_steps: int = kwargs.get('sample_steps', sample_config.sample_steps)
|
||||||
|
self.fps: int = kwargs.get('fps', sample_config.fps)
|
||||||
|
self.num_frames: int = kwargs.get('num_frames', sample_config.num_frames)
|
||||||
|
self.ctrl_img: Optional[str] = kwargs.get('ctrl_img', None)
|
||||||
|
self.ctrl_idx: int = kwargs.get('ctrl_idx', 0)
|
||||||
|
self.network_multiplier: float = kwargs.get('network_multiplier', sample_config.network_multiplier)
|
||||||
|
|
||||||
|
|
||||||
class SampleConfig:
|
class SampleConfig:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
@@ -44,7 +64,6 @@ class SampleConfig:
|
|||||||
self.sample_every: int = kwargs.get('sample_every', 100)
|
self.sample_every: int = kwargs.get('sample_every', 100)
|
||||||
self.width: int = kwargs.get('width', 512)
|
self.width: int = kwargs.get('width', 512)
|
||||||
self.height: int = kwargs.get('height', 512)
|
self.height: int = kwargs.get('height', 512)
|
||||||
self.prompts: list[str] = kwargs.get('prompts', [])
|
|
||||||
self.neg = kwargs.get('neg', False)
|
self.neg = kwargs.get('neg', False)
|
||||||
self.seed = kwargs.get('seed', 0)
|
self.seed = kwargs.get('seed', 0)
|
||||||
self.walk_seed = kwargs.get('walk_seed', False)
|
self.walk_seed = kwargs.get('walk_seed', False)
|
||||||
@@ -62,6 +81,23 @@ class SampleConfig:
|
|||||||
if self.num_frames > 1 and self.ext not in ['webp']:
|
if self.num_frames > 1 and self.ext not in ['webp']:
|
||||||
print("Changing sample extention to animated webp")
|
print("Changing sample extention to animated webp")
|
||||||
self.ext = 'webp'
|
self.ext = 'webp'
|
||||||
|
|
||||||
|
prompts: list[str] = kwargs.get('prompts', [])
|
||||||
|
|
||||||
|
self.samples: Optional[List[SampleItem]] = None
|
||||||
|
# use the legacy prompts if it is passed that way to get samples object
|
||||||
|
default_samples_kwargs = [
|
||||||
|
{"prompt": x} for x in prompts
|
||||||
|
]
|
||||||
|
raw_samples = kwargs.get('samples', default_samples_kwargs)
|
||||||
|
self.samples = [SampleItem(self, **item) for item in raw_samples]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prompts(self):
|
||||||
|
# for backwards compatibility as this is checked for length frequently
|
||||||
|
return [sample.prompt for sample in self.samples if sample.prompt is not None]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class LormModuleSettingsConfig:
|
class LormModuleSettingsConfig:
|
||||||
@@ -893,6 +929,7 @@ class GenerateImageConfig:
|
|||||||
refiner_start_at: float = 0.5, # start at this percentage of a step. 0.0 to 1.0 . 1.0 is the end
|
refiner_start_at: float = 0.5, # start at this percentage of a step. 0.0 to 1.0 . 1.0 is the end
|
||||||
extra_values: List[float] = None, # extra values to save with prompt file
|
extra_values: List[float] = None, # extra values to save with prompt file
|
||||||
logger: Optional[EmptyLogger] = None,
|
logger: Optional[EmptyLogger] = None,
|
||||||
|
ctrl_img: Optional[str] = None, # control image for controlnet
|
||||||
num_frames: int = 1,
|
num_frames: int = 1,
|
||||||
fps: int = 15,
|
fps: int = 15,
|
||||||
ctrl_idx: int = 0
|
ctrl_idx: int = 0
|
||||||
@@ -926,7 +963,7 @@ class GenerateImageConfig:
|
|||||||
self.extra_values = extra_values if extra_values is not None else []
|
self.extra_values = extra_values if extra_values is not None else []
|
||||||
self.num_frames = num_frames
|
self.num_frames = num_frames
|
||||||
self.fps = fps
|
self.fps = fps
|
||||||
self.ctrl_img = None
|
self.ctrl_img = ctrl_img
|
||||||
self.ctrl_idx = ctrl_idx
|
self.ctrl_idx = ctrl_idx
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user