diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index b15fe524..e5965dfb 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -1670,11 +1670,12 @@ class SDTrainer(BaseSDTrainProcess): ) else: - with self.timer('predict_unet'): - if unconditional_embeds is not None: - unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach() - if self.adapter and isinstance(self.adapter, CustomAdapter): + if unconditional_embeds is not None: + unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach() + if self.adapter and isinstance(self.adapter, CustomAdapter): + with self.timer('condition_noisy_latents'): noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch) + with self.timer('predict_unet'): noise_pred = self.predict_noise( noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), timesteps=timesteps, diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 30813a2e..6cbaa542 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -252,9 +252,7 @@ class BaseSDTrainProcess(BaseTrainProcess): test_image_paths = [] if self.adapter_config is not None and self.adapter_config.test_img_path is not None: - test_image_path_list = self.adapter_config.test_img_path.split(',') - test_image_path_list = [p.strip() for p in test_image_path_list] - test_image_path_list = [p for p in test_image_path_list if p != ''] + test_image_path_list = self.adapter_config.test_img_path # divide up images so they are evenly distributed across prompts for i in range(len(sample_config.prompts)): test_image_paths.append(test_image_path_list[i % len(test_image_path_list)]) diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index a4cd1969..974c01ff 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -165,7 +165,13 @@ class AdapterConfig: self.downscale_factor: int = kwargs.get('downscale_factor', 8) self.adapter_type: str = kwargs.get('adapter_type', 'full_adapter') self.image_dir: str = kwargs.get('image_dir', None) - self.test_img_path: str = kwargs.get('test_img_path', None) + self.test_img_path: List[str] = kwargs.get('test_img_path', None) + if self.test_img_path is not None: + if isinstance(self.test_img_path, str): + self.test_img_path = self.test_img_path.split(',') + self.test_img_path = [p.strip() for p in self.test_img_path] + self.test_img_path = [p for p in self.test_img_path if p != ''] + self.train: str = kwargs.get('train', False) self.image_encoder_path: str = kwargs.get('image_encoder_path', None) self.name_or_path = kwargs.get('name_or_path', None) @@ -244,6 +250,7 @@ class AdapterConfig: self.num_control_images: int = kwargs.get('num_control_images', 1) # decimal for how often the control is dropped out and replaced with noise 1.0 is 100% self.control_image_dropout: float = kwargs.get('control_image_dropout', 0.0) + self.has_inpainting_input: bool = kwargs.get('has_inpainting_input', False) class EmbeddingConfig: @@ -714,6 +721,9 @@ class DatasetConfig: self.flip_y: bool = kwargs.get('flip_y', False) self.augments: List[str] = kwargs.get('augments', []) self.control_path: Union[str,List[str]] = kwargs.get('control_path', None) # depth maps, etc + # inpaint images should be webp/png images with alpha channel. The alpha 0 (invisible) section will + # be the part conditioned to be inpainted. The alpha 1 (visible) section will be the part that is ignored + self.inpaint_path: Union[str,List[str]] = kwargs.get('inpaint_path', None) # instead of cropping ot match image, it will serve the full size control image (clip images ie for ip adapters) self.full_size_control_images: bool = kwargs.get('full_size_control_images', False) self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py index fca2ed2b..6dac3568 100644 --- a/toolkit/custom_adapter.py +++ b/toolkit/custom_adapter.py @@ -569,13 +569,56 @@ class CustomAdapter(torch.nn.Module): def condition_noisy_latents(self, latents: torch.Tensor, batch:DataLoaderBatchDTO): with torch.no_grad(): if self.adapter_type in ['control_lora']: + # inpainting input is 0-1 (bs, 4, h, w) on batch.inpaint_tensor + # 4th channel is the mask with 1 being keep area and 0 being area to inpaint. sd: StableDiffusion = self.sd_ref() - control_tensor = batch.control_tensor + inpainting_latent = None + if self.config.has_inpainting_input: + do_dropout = random.random() < self.config.control_image_dropout + if batch.inpaint_tensor is not None and not do_dropout: + # currently 0-1, we need rgb to be -1 to 1 before encoding with the vae + inpainting_tensor_rgba = batch.inpaint_tensor.to(latents.device, dtype=latents.dtype) + inpainting_tensor_mask = inpainting_tensor_rgba[:, 3:4, :, :] + inpainting_tensor_rgb = inpainting_tensor_rgba[:, :3, :, :] + # we need to make sure the inpaint area is black multiply the rgb channels by the mask + inpainting_tensor_rgb = inpainting_tensor_rgb * inpainting_tensor_mask + + # if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it + if inpainting_tensor_rgb.shape[2] != batch.tensor.shape[2] or inpainting_tensor_rgb.shape[3] != batch.tensor.shape[3]: + inpainting_tensor_rgb = F.interpolate(inpainting_tensor_rgb, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bilinear') + + # scale to -1 to 1 + inpainting_tensor_rgb = inpainting_tensor_rgb * 2 - 1 + + # encode it + inpainting_latent = sd.encode_images(inpainting_tensor_rgb).to(latents.device, latents.dtype) + + # resize the mask to match the new encoded size + inpainting_tensor_mask = F.interpolate(inpainting_tensor_mask, size=(inpainting_latent.shape[2], inpainting_latent.shape[3]), mode='bilinear') + inpainting_tensor_mask = inpainting_tensor_mask.to(latents.device, latents.dtype) + # mask needs to be 1 for inpaint area and 0 for area to leave alone. So flip it. + inpainting_tensor_mask = 1 - inpainting_tensor_mask + # leave the mask as 0-1 and concat on channel of latents + inpainting_latent = torch.cat((inpainting_latent, inpainting_tensor_mask), dim=1) + else: + # we have iinpainting but didnt get a control. or we are doing a dropout + # the input needs to be all zeros for the latents and all 1s for the mask + inpainting_latent = torch.zeros_like(latents) + # add ones for the mask since we are technically inpainting everything + inpainting_latent = torch.cat((inpainting_latent, torch.ones_like(inpainting_latent[:, :1, :, :])), dim=1) + + if self.config.num_control_images == 1: + # this is our only control + control_latent = inpainting_latent.to(latents.device, latents.dtype) + latents = torch.cat((latents, control_latent), dim=1) + return latents.detach() + + control_tensor = batch.control_tensor.to(latents.device, dtype=latents.dtype) if control_tensor is None: # concat random normal noise onto the latents # check dimension, this is before they are rearranged # it is latent_model_input = torch.cat([latents, control_image], dim=2) after rearranging - ctrl = torch.randn( + ctrl = torch.zeros( latents.shape[0], # bs latents.shape[1] * self.num_control_images, # ch latents.shape[2], @@ -583,6 +626,9 @@ class CustomAdapter(torch.nn.Module): device=latents.device, dtype=latents.dtype ) + if inpainting_latent is not None: + # inpainting always comes first + ctrl = torch.cat((inpainting_latent, ctrl), dim=1) latents = torch.cat((latents, ctrl), dim=1) return latents.detach() # if we have multiple control tensors, they come in like [bs, num_control_images, ch, h, w] @@ -622,6 +668,9 @@ class CustomAdapter(torch.nn.Module): control_latent_list.append(control_latent) # stack them on the channel dimension control_latent = torch.cat(control_latent_list, dim=1) + if inpainting_latent is not None: + # inpainting always comes first + control_latent = torch.cat((inpainting_latent, control_latent), dim=1) # concat it onto the latents latents = torch.cat((latents, control_latent), dim=1) return latents.detach() diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index 095fa014..bc5212ff 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -12,7 +12,7 @@ from PIL.ImageOps import exif_transpose from toolkit import image_utils from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \ ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \ - UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin + UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin if TYPE_CHECKING: @@ -34,6 +34,7 @@ class FileItemDTO( CaptionProcessingDTOMixin, ImageProcessingDTOMixin, ControlFileItemDTOMixin, + InpaintControlFileItemDTOMixin, ClipImageFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, @@ -108,6 +109,7 @@ class FileItemDTO( self.tensor = None self.cleanup_latent() self.cleanup_control() + self.cleanup_inpaint() self.cleanup_clip_image() self.cleanup_mask() self.cleanup_unconditional() @@ -154,6 +156,22 @@ class DataLoaderBatchDTO: else: control_tensors.append(x.control_tensor) self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors]) + + self.inpaint_tensor: Union[torch.Tensor, None] = None + if any([x.inpaint_tensor is not None for x in self.file_items]): + # find one to use as a base + base_inpaint_tensor = None + for x in self.file_items: + if x.inpaint_tensor is not None: + base_inpaint_tensor = x.inpaint_tensor + break + inpaint_tensors = [] + for x in self.file_items: + if x.inpaint_tensor is None: + inpaint_tensors.append(torch.zeros_like(base_inpaint_tensor)) + else: + inpaint_tensors.append(x.inpaint_tensor) + self.inpaint_tensor = torch.cat([x.unsqueeze(0) for x in inpaint_tensors]) self.loss_multiplier_list: List[float] = [x.loss_multiplier for x in self.file_items] diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 43120a7f..115a6ac2 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -635,6 +635,8 @@ class ImageProcessingDTOMixin: self.get_latent() if self.has_control_image: self.load_control_image() + if self.has_inpaint_image: + self.load_inpaint_image() if self.has_clip_image: self.load_clip_image() if self.has_mask_image: @@ -730,6 +732,8 @@ class ImageProcessingDTOMixin: if not only_load_latents: if self.has_control_image: self.load_control_image() + if self.has_inpaint_image: + self.load_inpaint_image() if self.has_clip_image: self.load_clip_image() if self.has_mask_image: @@ -738,6 +742,89 @@ class ImageProcessingDTOMixin: self.load_unconditional_image() +class InpaintControlFileItemDTOMixin: + def __init__(self: 'FileItemDTO', *args, **kwargs): + if hasattr(super(), '__init__'): + super().__init__(*args, **kwargs) + self.has_inpaint_image = False + self.inpaint_path: Union[str, None] = None + self.inpaint_tensor: Union[torch.Tensor, None] = None + dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) + if dataset_config.inpaint_path is not None: + # find the control image path + inpaint_path = dataset_config.inpaint_path + # we are using control images + img_path = kwargs.get('path', None) + img_ext_list = ['.png', '.webp'] + file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0] + + for ext in img_ext_list: + p = os.path.join(inpaint_path, file_name_no_ext + ext) + if os.path.exists(p): + self.inpaint_path = p + self.has_inpaint_image = True + break + + def load_inpaint_image(self: 'FileItemDTO'): + try: + # image must have alpha channel for inpaint + img = Image.open(self.inpaint_path) + # make sure has aplha + if img.mode != 'RGBA': + raise ValueError(f"Image must have alpha channel for inpaint: {self.inpaint_path}") + img = exif_transpose(img) + + w, h = img.size + if w > h and self.scale_to_width < self.scale_to_height: + # throw error, they should match + raise ValueError( + f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + elif h > w and self.scale_to_height < self.scale_to_width: + # throw error, they should match + raise ValueError( + f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + + if self.flip_x: + # do a flip + img = img.transpose(Image.FLIP_LEFT_RIGHT) + if self.flip_y: + # do a flip + img = img.transpose(Image.FLIP_TOP_BOTTOM) + + if self.dataset_config.buckets: + # scale and crop based on file item + img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC) + # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img) + # crop + img = img.crop(( + self.crop_x, + self.crop_y, + self.crop_x + self.crop_width, + self.crop_y + self.crop_height + )) + else: + raise Exception("Inpaint images not supported for non-bucket datasets") + + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + if self.aug_replay_spatial_transforms: + tensor = self.augment_spatial_control(img, transform=transform) + else: + tensor = transform(img) + + # is 0 to 1 with alpha + self.inpaint_tensor = tensor + + except Exception as e: + print_acc(f"Error: {e}") + print_acc(f"Error loading image: {self.inpaint_path}") + + + def cleanup_inpaint(self: 'FileItemDTO'): + self.inpaint_tensor = None + + class ControlFileItemDTOMixin: def __init__(self: 'FileItemDTO', *args, **kwargs): if hasattr(super(), '__init__'): @@ -786,7 +873,7 @@ class ControlFileItemDTOMixin: print_acc(f"Error: {e}") print_acc(f"Error loading image: {control_path}") - if self.full_size_control_images: + if not self.full_size_control_images: # we just scale them to 512x512: w, h = img.size img = img.resize((512, 512), Image.BICUBIC) diff --git a/toolkit/models/control_lora_adapter.py b/toolkit/models/control_lora_adapter.py index 23f9033d..3588302d 100644 --- a/toolkit/models/control_lora_adapter.py +++ b/toolkit/models/control_lora_adapter.py @@ -46,15 +46,24 @@ class ImgEmbedder(torch.nn.Module): cls, model: FluxTransformer2DModel, adapter: 'ControlLoraAdapter', - num_control_images=1 + num_control_images=1, + has_inpainting_input=False ): - if model.__class__.__name__ == 'FluxTransformer2DModel': + if model.__class__.__name__ == 'FluxTransformer2DModel': + num_adapter_in_channels = model.x_embedder.in_features * num_control_images + + if has_inpainting_input: + # inpainting has the mask before packing latents. it is normally 16 ch + 1ch mask + # packed it is 64ch + 4ch mask + # so we need to add 4 to the input channels + num_adapter_in_channels += 4 + x_embedder: torch.nn.Linear = model.x_embedder img_embedder = cls( adapter, orig_layer=x_embedder, - in_channels=x_embedder.in_features * num_control_images, - out_channels=x_embedder.out_features, + in_channels=num_adapter_in_channels, + out_channels=x_embedder.out_features, ) # hijack the forward method @@ -181,7 +190,8 @@ class ControlLoraAdapter(torch.nn.Module): self.x_embedder = ImgEmbedder.from_model( sd.unet, self, - num_control_images=config.num_control_images + num_control_images=config.num_control_images, + has_inpainting_input=config.has_inpainting_input ) self.x_embedder.to(self.device_torch) diff --git a/toolkit/pipelines.py b/toolkit/pipelines.py index 7f66a0f2..62fd8710 100644 --- a/toolkit/pipelines.py +++ b/toolkit/pipelines.py @@ -16,6 +16,10 @@ from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance from diffusers.image_processor import PipelineImageInput +from PIL import Image +import torch.nn.functional as F +from torchvision import transforms + if is_torch_xla_available(): @@ -1428,6 +1432,22 @@ class FluxWithCFGPipeline(FluxPipeline): class FluxAdvancedControlPipeline(FluxControlPipeline): + def __init__( + self, + scheduler, + vae, + text_encoder, + tokenizer, + text_encoder_2, + tokenizer_2, + transformer, + do_inpainting=False, + num_controls=1, + ): + self.do_inpainting = do_inpainting + self.num_controls = num_controls + super().__init__(scheduler, vae, text_encoder, tokenizer, text_encoder_2, tokenizer_2, transformer) + @torch.no_grad() def __call__( self, @@ -1581,6 +1601,17 @@ class FluxAdvancedControlPipeline(FluxControlPipeline): # 4. Prepare latent variables # num_channels_latents = self.transformer.config.in_channels // 8 num_channels_latents = 128 // 8 + + # pull mask off control image if there is one it is a pil image + mask = None + if control_image is not None and self.do_inpainting and control_image.mode == "RGBA": + control_img_array = np.array(control_image) + mask = control_img_array[:, :, 3:4] + # scale it to 0 - 1 + mask = mask / 255.0 + # multiply rgb by mask + control_img_array = control_img_array[:, :, :3] * mask + control_image = Image.fromarray(control_img_array.astype(np.uint8)) control_image = self.prepare_image( image=control_image, @@ -1593,14 +1624,28 @@ class FluxAdvancedControlPipeline(FluxControlPipeline): ) if control_image.ndim == 4: + num_control_channels = num_channels_latents control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator) control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + if mask is not None: + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + mask = transform(mask).to(device, dtype=control_image.dtype).unsqueeze(0) + # resize mask to match control image + mask = F.interpolate(mask, size=(control_image.shape[2], control_image.shape[3]), mode="bilinear", align_corners=False) + mask = mask.to(device) + # invert mask + mask = 1 - mask + control_image = torch.cat([control_image, mask], dim=1) + num_control_channels += 1 height_control_image, width_control_image = control_image.shape[2:] control_image = self._pack_latents( control_image, batch_size * num_images_per_prompt, - num_channels_latents, + num_control_channels, height_control_image, width_control_image, ) @@ -1642,9 +1687,6 @@ class FluxAdvancedControlPipeline(FluxControlPipeline): guidance = guidance.expand(latents.shape[0]) else: guidance = None - - # flux has 64 input channels. - total_controls = (self.transformer.config.in_channels // 64) - 1 # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -1652,7 +1694,16 @@ class FluxAdvancedControlPipeline(FluxControlPipeline): if self.interrupt: continue - control_image_list = [torch.zeros_like(latents) for _ in range(total_controls)] + control_image_list = [] + for idx in range(self.num_controls): + if idx == 0 and self.do_inpainting: + ctrl = torch.zeros_like(latents) + # do ones for mask and zeros for image + ctrl = torch.cat([ctrl, torch.ones_like(ctrl[:, :, :4])], dim=2) + control_image_list.append(ctrl) + else: + control_image_list.append(torch.zeros_like(latents)) + control_image_list[control_image_idx] = control_image latent_model_input = torch.cat([latents] + control_image_list, dim=2) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 98f5c269..5e3f58f4 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -1246,6 +1246,8 @@ class StableDiffusion: # see if it is a control lora if self.adapter.control_lora is not None: Pipe = FluxAdvancedControlPipeline + extra_args['do_inpainting'] = self.adapter.config.has_inpainting_input + extra_args['num_controls'] = self.adapter.config.num_control_images pipeline = Pipe( vae=self.vae, @@ -1257,6 +1259,7 @@ class StableDiffusion: scheduler=noise_scheduler, **extra_args ) + pipeline.watermark = None elif self.is_lumina2: pipeline = Lumina2Text2ImgPipeline( @@ -1355,7 +1358,14 @@ class StableDiffusion: extra = {} validation_image = None if self.adapter is not None and gen_config.adapter_image_path is not None: - validation_image = Image.open(gen_config.adapter_image_path).convert("RGB") + validation_image = Image.open(gen_config.adapter_image_path) + # if the name doesnt have .inpainting. in it, make sure it is rgb + if ".inpaint." not in gen_config.adapter_image_path: + validation_image = validation_image.convert("RGB") + else: + # make sure it has an alpha + if validation_image.mode != "RGBA": + raise ValueError("Inpainting images must have an alpha channel") if isinstance(self.adapter, T2IAdapter): # not sure why this is double?? validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))