mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 06:29:48 +00:00
allow setting adapter weight in prompts
This commit is contained in:
@@ -198,6 +198,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
network_multiplier=sample_config.network_multiplier,
|
||||
output_path=output_path,
|
||||
output_ext=sample_config.ext,
|
||||
adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
|
||||
**extra_args
|
||||
))
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ class SampleConfig:
|
||||
self.network_multiplier = kwargs.get('network_multiplier', 1)
|
||||
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
|
||||
self.ext: ImgExt = kwargs.get('format', 'jpg')
|
||||
self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
|
||||
|
||||
|
||||
NetworkType = Literal['lora', 'locon']
|
||||
@@ -289,6 +290,7 @@ class GenerateImageConfig:
|
||||
output_tail: str = '', # tail to add to output filename
|
||||
add_prompt_file: bool = False, # add a prompt file with generated image
|
||||
adapter_image_path: str = None, # path to adapter image
|
||||
adapter_conditioning_scale: float = 1.0, # scale for adapter conditioning
|
||||
):
|
||||
self.width: int = width
|
||||
self.height: int = height
|
||||
@@ -312,6 +314,7 @@ class GenerateImageConfig:
|
||||
self.output_tail: str = output_tail
|
||||
self.gen_time: int = int(time.time() * 1000)
|
||||
self.adapter_image_path: str = adapter_image_path
|
||||
self.adapter_conditioning_scale: float = adapter_conditioning_scale
|
||||
|
||||
# prompt string will override any settings above
|
||||
self._process_prompt_string()
|
||||
@@ -466,6 +469,8 @@ class GenerateImageConfig:
|
||||
self.network_multiplier = float(content)
|
||||
elif flag == 'gr':
|
||||
self.guidance_rescale = float(content)
|
||||
elif flag == 'a':
|
||||
self.adapter_conditioning_scale = float(content)
|
||||
|
||||
def post_process_embeddings(
|
||||
self,
|
||||
|
||||
@@ -366,7 +366,7 @@ class StableDiffusion:
|
||||
# not sure why this is double??
|
||||
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
|
||||
extra['image'] = validation_image
|
||||
extra['adapter_conditioning_scale'] = 1.0
|
||||
extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale
|
||||
if isinstance(self.adapter, IPAdapter):
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize(gen_config.width, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
|
||||
Reference in New Issue
Block a user