diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index 58d48b4f..c81141da 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -506,15 +506,55 @@ class BaseModel: unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype) else: ctrl_img = None + has_control_images = False + if gen_config.ctrl_img is not None or gen_config.ctrl_img_1 is not None or gen_config.ctrl_img_2 is not None or gen_config.ctrl_img_3 is not None: + has_control_images = True # load the control image if out model uses it in text encoding - if gen_config.ctrl_img is not None and self.encode_control_in_text_embeddings: - ctrl_img = Image.open(gen_config.ctrl_img).convert("RGB") - # convert to 0 to 1 tensor - ctrl_img = ( - TF.to_tensor(ctrl_img) - .unsqueeze(0) - .to(self.device_torch, dtype=self.torch_dtype) - ) + if has_control_images and self.encode_control_in_text_embeddings: + ctrl_img_list = [] + + if gen_config.ctrl_img is not None: + ctrl_img = Image.open(gen_config.ctrl_img).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img = ( + TF.to_tensor(ctrl_img) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + ) + ctrl_img_list.append(ctrl_img) + + if gen_config.ctrl_img_1 is not None: + ctrl_img_1 = Image.open(gen_config.ctrl_img_1).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_1 = ( + TF.to_tensor(ctrl_img_1) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_1) + if gen_config.ctrl_img_2 is not None: + ctrl_img_2 = Image.open(gen_config.ctrl_img_2).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_2 = ( + TF.to_tensor(ctrl_img_2) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_2) + if gen_config.ctrl_img_3 is not None: + ctrl_img_3 = Image.open(gen_config.ctrl_img_3).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_3 = ( + TF.to_tensor(ctrl_img_3) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_3) + + if self.has_multiple_control_images: + ctrl_img = ctrl_img_list + else: + ctrl_img = ctrl_img_list[0] if len(ctrl_img_list) > 0 else None # encode the prompt ourselves so we can do fun stuff with embeddings if isinstance(self.adapter, CustomAdapter): self.adapter.is_unconditional_run = False diff --git a/version.py b/version.py index fd833cdf..57b37388 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -VERSION = "0.6.2" \ No newline at end of file +VERSION = "0.6.3" \ No newline at end of file