allow setting adapter weight in prompts

This commit is contained in:
Jaret Burkett
2023-09-24 06:51:54 -06:00
parent e5153d87c9
commit abf7cd221d
3 changed files with 7 additions and 1 deletions

View File

@@ -198,6 +198,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
network_multiplier=sample_config.network_multiplier, network_multiplier=sample_config.network_multiplier,
output_path=output_path, output_path=output_path,
output_ext=sample_config.ext, output_ext=sample_config.ext,
adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
**extra_args **extra_args
)) ))

View File

@@ -37,6 +37,7 @@ class SampleConfig:
self.network_multiplier = kwargs.get('network_multiplier', 1) self.network_multiplier = kwargs.get('network_multiplier', 1)
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0) self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
self.ext: ImgExt = kwargs.get('format', 'jpg') self.ext: ImgExt = kwargs.get('format', 'jpg')
self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
NetworkType = Literal['lora', 'locon'] NetworkType = Literal['lora', 'locon']
@@ -289,6 +290,7 @@ class GenerateImageConfig:
output_tail: str = '', # tail to add to output filename output_tail: str = '', # tail to add to output filename
add_prompt_file: bool = False, # add a prompt file with generated image add_prompt_file: bool = False, # add a prompt file with generated image
adapter_image_path: str = None, # path to adapter image adapter_image_path: str = None, # path to adapter image
adapter_conditioning_scale: float = 1.0, # scale for adapter conditioning
): ):
self.width: int = width self.width: int = width
self.height: int = height self.height: int = height
@@ -312,6 +314,7 @@ class GenerateImageConfig:
self.output_tail: str = output_tail self.output_tail: str = output_tail
self.gen_time: int = int(time.time() * 1000) self.gen_time: int = int(time.time() * 1000)
self.adapter_image_path: str = adapter_image_path self.adapter_image_path: str = adapter_image_path
self.adapter_conditioning_scale: float = adapter_conditioning_scale
# prompt string will override any settings above # prompt string will override any settings above
self._process_prompt_string() self._process_prompt_string()
@@ -466,6 +469,8 @@ class GenerateImageConfig:
self.network_multiplier = float(content) self.network_multiplier = float(content)
elif flag == 'gr': elif flag == 'gr':
self.guidance_rescale = float(content) self.guidance_rescale = float(content)
elif flag == 'a':
self.adapter_conditioning_scale = float(content)
def post_process_embeddings( def post_process_embeddings(
self, self,

View File

@@ -366,7 +366,7 @@ class StableDiffusion:
# not sure why this is double?? # not sure why this is double??
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2)) validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
extra['image'] = validation_image extra['image'] = validation_image
extra['adapter_conditioning_scale'] = 1.0 extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale
if isinstance(self.adapter, IPAdapter): if isinstance(self.adapter, IPAdapter):
transform = transforms.Compose([ transform = transforms.Compose([
transforms.Resize(gen_config.width, interpolation=transforms.InterpolationMode.BILINEAR), transforms.Resize(gen_config.width, interpolation=transforms.InterpolationMode.BILINEAR),