From 290393f7ae7d2acc4df52b234a069a02ef1c72bc Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Thu, 11 Jan 2024 12:22:16 -0700 Subject: [PATCH] Improvements to ip weight adaptation. Bug fixes. Added masking to direct guidance loss. Allow importing a file for random triggers. Handle bad meta images with improper sizing. --- extensions_built_in/sd_trainer/SDTrainer.py | 4 +- toolkit/config_modules.py | 9 ++- toolkit/dataloader_mixins.py | 28 +++++-- toolkit/guidance.py | 8 ++ toolkit/ip_adapter.py | 86 ++++++++++++++------- 5 files changed, 101 insertions(+), 34 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 42fd7cc4..f18e95c7 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -1087,7 +1087,9 @@ class SDTrainer(BaseSDTrainProcess): pred_kwargs=pred_kwargs, batch=batch, noise=noise, - unconditional_embeds=unconditional_embeds + unconditional_embeds=unconditional_embeds, + mask_multiplier=mask_multiplier, + prior_pred=prior_pred, ) else: diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 066ca6b1..0b5732cb 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -392,7 +392,14 @@ class DatasetConfig: self.dataset_path: str = kwargs.get('dataset_path', None) self.default_caption: str = kwargs.get('default_caption', None) - self.random_triggers: List[str] = kwargs.get('random_triggers', []) + random_triggers = kwargs.get('random_triggers', []) + # if they are a string, load them from a file + if isinstance(random_triggers, str) and os.path.exists(random_triggers): + with open(random_triggers, 'r') as f: + random_triggers = f.read().splitlines() + # remove empty lines + random_triggers = [line for line in random_triggers if line.strip() != ''] + self.random_triggers: List[str] = random_triggers self.caption_ext: str = kwargs.get('caption_ext', None) self.random_scale: bool = 
kwargs.get('random_scale', False) self.random_crop: bool = kwargs.get('random_crop', False) diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 0197ed36..fc40b69f 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -820,14 +820,24 @@ class MaskFileItemDTOMixin: if self.dataset_config.invert_mask: img = ImageOps.invert(img) w, h = img.size + fix_size = False if w > h and self.scale_to_width < self.scale_to_height: # throw error, they should match - raise ValueError( - f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + print(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + fix_size = True elif h > w and self.scale_to_height < self.scale_to_width: # throw error, they should match - raise ValueError( - f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + print(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + fix_size = True + + if fix_size: + # swap all the sizes + self.scale_to_width, self.scale_to_height = self.scale_to_height, self.scale_to_width + self.crop_width, self.crop_height = self.crop_height, self.crop_width + self.crop_x, self.crop_y = self.crop_y, self.crop_x + + + if self.flip_x: # do a flip @@ -1052,8 +1062,14 @@ class PoiFileItemDTOMixin: crop_bottom = initial_height poi_height = crop_bottom - poi_y - # now we have our random crop, but it may be smaller than resolution. Check and expand if needed - current_resolution = get_resolution(poi_width, poi_height) + try: + # now we have our random crop, but it may be smaller than resolution. 
Check and expand if needed + current_resolution = get_resolution(poi_width, poi_height) + except Exception as e: + print(f"Error: {e}") + print(f"Error getting resolution: {self.path}") + raise e + return False if current_resolution >= self.dataset_config.resolution: # We can break now break diff --git a/toolkit/guidance.py b/toolkit/guidance.py index c83a77f7..1fe3a95c 100644 --- a/toolkit/guidance.py +++ b/toolkit/guidance.py @@ -194,6 +194,8 @@ def get_direct_guidance_loss( noise: torch.Tensor, sd: 'StableDiffusion', unconditional_embeds: Optional[PromptEmbeds] = None, + mask_multiplier=None, + prior_pred=None, **kwargs ): with torch.no_grad(): @@ -248,6 +250,8 @@ def get_direct_guidance_loss( noise.detach().float(), reduction="none" ) + if mask_multiplier is not None: + guidance_loss = guidance_loss * mask_multiplier guidance_loss = guidance_loss.mean([1, 2, 3]) @@ -489,6 +493,8 @@ def get_guidance_loss( noise: torch.Tensor, sd: 'StableDiffusion', unconditional_embeds: Optional[PromptEmbeds] = None, + mask_multiplier=None, + prior_pred=None, **kwargs ): # TODO add others and process individual batch items separately @@ -549,6 +555,8 @@ def get_guidance_loss( noise, sd, unconditional_embeds=unconditional_embeds, + mask_multiplier=mask_multiplier, + prior_pred=prior_pred, **kwargs ) else: diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py index 9ebecfb6..5d58a627 100644 --- a/toolkit/ip_adapter.py +++ b/toolkit/ip_adapter.py @@ -480,19 +480,36 @@ class IPAdapter(torch.nn.Module): current_shape = current_img_proj_state_dict[key].shape new_shape = value.shape if current_shape != new_shape: - # merge in what we can and leave the other values as they are - if len(current_shape) == 1: - current_img_proj_state_dict[key][:new_shape[0]] = value - elif len(current_shape) == 2: - current_img_proj_state_dict[key][:new_shape[0], :new_shape[1]] = value - elif len(current_shape) == 3: - current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = 
value - elif len(current_shape) == 4: - current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2], - :new_shape[3]] = value - else: - raise ValueError(f"unknown shape: {current_shape}") - print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}") + try: + # merge in what we can and leave the other values as they are + if len(current_shape) == 1: + current_img_proj_state_dict[key][:new_shape[0]] = value + elif len(current_shape) == 2: + current_img_proj_state_dict[key][:new_shape[0], :new_shape[1]] = value + elif len(current_shape) == 3: + current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value + elif len(current_shape) == 4: + current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2], + :new_shape[3]] = value + else: + raise ValueError(f"unknown shape: {current_shape}") + except RuntimeError as e: + print(e) + print(f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. Trying other way") + + if len(current_shape) == 1: + current_img_proj_state_dict[key][:current_shape[0]] = value[:current_shape[0]] + elif len(current_shape) == 2: + current_img_proj_state_dict[key][:current_shape[0], :current_shape[1]] = value[:current_shape[0], :current_shape[1]] + elif len(current_shape) == 3: + current_img_proj_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]] + elif len(current_shape) == 4: + current_img_proj_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2], + :current_shape[3]] = value[:current_shape[0], :current_shape[1], :current_shape[2], + :current_shape[3]] + else: + raise ValueError(f"unknown shape: {current_shape}") + print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}") else: current_img_proj_state_dict[key] = value self.image_proj_model.load_state_dict(current_img_proj_state_dict) @@ -504,19 +521,36 @@ class 
IPAdapter(torch.nn.Module): current_shape = current_ip_adapter_state_dict[key].shape new_shape = value.shape if current_shape != new_shape: - # merge in what we can and leave the other values as they are - if len(current_shape) == 1: - current_ip_adapter_state_dict[key][:new_shape[0]] = value - elif len(current_shape) == 2: - current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1]] = value - elif len(current_shape) == 3: - current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value - elif len(current_shape) == 4: - current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2], - :new_shape[3]] = value - else: - raise ValueError(f"unknown shape: {current_shape}") - print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}") + try: + # merge in what we can and leave the other values as they are + if len(current_shape) == 1: + current_ip_adapter_state_dict[key][:new_shape[0]] = value + elif len(current_shape) == 2: + current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1]] = value + elif len(current_shape) == 3: + current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value + elif len(current_shape) == 4: + current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2], + :new_shape[3]] = value + else: + raise ValueError(f"unknown shape: {current_shape}") + print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}") + except RuntimeError as e: + print(e) + print(f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. 
Trying other way") + + if(len(current_shape) == 1): + current_ip_adapter_state_dict[key][:current_shape[0]] = value[:current_shape[0]] + elif(len(current_shape) == 2): + current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1]] = value[:current_shape[0], :current_shape[1]] + elif(len(current_shape) == 3): + current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]] + elif(len(current_shape) == 4): + current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2], :current_shape[3]] = value[:current_shape[0], :current_shape[1], :current_shape[2], :current_shape[3]] + else: + raise ValueError(f"unknown shape: {current_shape}") + print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}") + else: current_ip_adapter_state_dict[key] = value self.adapter_modules.load_state_dict(current_ip_adapter_state_dict)