fix img2img inpaint model

This commit is contained in:
layerdiffusion
2024-08-05 11:36:08 -07:00
parent 6d8522b014
commit d77582aa5a

View File

@@ -343,24 +343,24 @@ class StableDiffusionProcessing:
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)
# image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)
return image_conditioning
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
source_image = devices.cond_cast_float(source_image)
if self.sd_model.cond_stage_key == "edit":
return self.edit_image_conditioning(source_image)
# if self.sd_model.cond_stage_key == "edit":
# return self.edit_image_conditioning(source_image)
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
if self.sd_model.is_inpaint:
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)
if self.sampler.conditioning_key == "crossattn-adm":
return self.unclip_image_conditioning(source_image)
if self.sampler.model_wrap.inner_model.is_sdxl_inpaint:
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
# if self.sampler.conditioning_key == "crossattn-adm":
# return self.unclip_image_conditioning(source_image)
#
# if self.sampler.model_wrap.inner_model.is_sdxl_inpaint:
# return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
# Dummy zero conditioning if we're not using inpainting or depth model.
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)