From abf7cd221d17365d4a47b066233603ae2629c4cf Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 24 Sep 2023 06:51:54 -0600 Subject: [PATCH] allow setting adapter weight in prompts --- jobs/process/BaseSDTrainProcess.py | 1 + toolkit/config_modules.py | 5 +++++ toolkit/stable_diffusion_model.py | 2 +- 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index bbcb466b..276e72fc 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -198,6 +198,7 @@ class BaseSDTrainProcess(BaseTrainProcess): network_multiplier=sample_config.network_multiplier, output_path=output_path, output_ext=sample_config.ext, + adapter_conditioning_scale=sample_config.adapter_conditioning_scale, **extra_args )) diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 6e0dc2f1..a7adbe3e 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -37,6 +37,7 @@ class SampleConfig: self.network_multiplier = kwargs.get('network_multiplier', 1) self.guidance_rescale = kwargs.get('guidance_rescale', 0.0) self.ext: ImgExt = kwargs.get('format', 'jpg') + self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0) NetworkType = Literal['lora', 'locon'] @@ -289,6 +290,7 @@ class GenerateImageConfig: output_tail: str = '', # tail to add to output filename add_prompt_file: bool = False, # add a prompt file with generated image adapter_image_path: str = None, # path to adapter image + adapter_conditioning_scale: float = 1.0, # scale for adapter conditioning ): self.width: int = width self.height: int = height @@ -312,6 +314,7 @@ class GenerateImageConfig: self.output_tail: str = output_tail self.gen_time: int = int(time.time() * 1000) self.adapter_image_path: str = adapter_image_path + self.adapter_conditioning_scale: float = adapter_conditioning_scale # prompt string will override any settings above self._process_prompt_string() @@ -466,6 +469,8 @@ class GenerateImageConfig: self.network_multiplier = float(content) elif flag == 'gr': self.guidance_rescale = float(content) + elif flag == 'a': + self.adapter_conditioning_scale = float(content) def post_process_embeddings( self, diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 1296a856..9a827f82 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -366,7 +366,7 @@ class StableDiffusion: # not sure why this is double?? validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2)) extra['image'] = validation_image - extra['adapter_conditioning_scale'] = 1.0 + extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale if isinstance(self.adapter, IPAdapter): transform = transforms.Compose([ transforms.Resize(gen_config.width, interpolation=transforms.InterpolationMode.BILINEAR),