Improvements to IP adapter weight adaptation. Bug fixes. Added masking to the direct guidance loss. Allow importing a file for random triggers. Handle bad meta images with improper sizing.

Jaret Burkett
2024-01-11 12:22:16 -07:00
parent b2a54c8f36
commit 290393f7ae
5 changed files with 101 additions and 34 deletions

@@ -1087,7 +1087,9 @@ class SDTrainer(BaseSDTrainProcess):
                     pred_kwargs=pred_kwargs,
                     batch=batch,
                     noise=noise,
-                    unconditional_embeds=unconditional_embeds
+                    unconditional_embeds=unconditional_embeds,
+                    mask_multiplier=mask_multiplier,
+                    prior_pred=prior_pred,
                 )
             else:

@@ -392,7 +392,14 @@ class DatasetConfig:
         self.dataset_path: str = kwargs.get('dataset_path', None)
         self.default_caption: str = kwargs.get('default_caption', None)
-        self.random_triggers: List[str] = kwargs.get('random_triggers', [])
+        random_triggers = kwargs.get('random_triggers', [])
+        # if they are a string, load them from a file
+        if isinstance(random_triggers, str) and os.path.exists(random_triggers):
+            with open(random_triggers, 'r') as f:
+                random_triggers = f.read().splitlines()
+        # remove empty lines
+        random_triggers = [line for line in random_triggers if line.strip() != '']
+        self.random_triggers: List[str] = random_triggers
         self.caption_ext: str = kwargs.get('caption_ext', None)
         self.random_scale: bool = kwargs.get('random_scale', False)
         self.random_crop: bool = kwargs.get('random_crop', False)
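
With this change, random_triggers accepts either a list of trigger strings or a path to a newline-delimited text file; empty lines are dropped either way. A minimal usage sketch; the file name is hypothetical and the keyword-argument construction is assumed from the kwargs.get calls above:

    # /data/triggers.txt (hypothetical): one trigger phrase per line
    # import path for DatasetConfig assumed; adjust to wherever it lives
    config = DatasetConfig(
        dataset_path='/data/images',
        random_triggers='/data/triggers.txt',  # previously this had to be a list
        caption_ext='txt',
    )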

@@ -820,14 +820,24 @@ class MaskFileItemDTOMixin:
         if self.dataset_config.invert_mask:
             img = ImageOps.invert(img)
         w, h = img.size
+        fix_size = False
         if w > h and self.scale_to_width < self.scale_to_height:
-            # throw error, they should match
-            raise ValueError(
-                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
+            print(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
+            fix_size = True
         elif h > w and self.scale_to_height < self.scale_to_width:
-            # throw error, they should match
-            raise ValueError(
-                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
+            print(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
+            fix_size = True
+        if fix_size:
+            # swap all the sizes
+            self.scale_to_width, self.scale_to_height = self.scale_to_height, self.scale_to_width
+            self.crop_width, self.crop_height = self.crop_height, self.crop_width
+            self.crop_x, self.crop_y = self.crop_y, self.crop_x
         if self.flip_x:
             # do a flip
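
Rather than raising when a mask's orientation disagrees with the precomputed scale/crop targets, the item now logs the mismatch and transposes its own geometry. A minimal illustration with hypothetical values:

    # a 768x512 (landscape) mask arrives where portrait targets were computed
    w, h = 768, 512
    scale_to_width, scale_to_height = 512, 768   # mismatch: w > h but width target < height target
    # the fix swaps every geometry field (scale, crop size, crop offset) instead of aborting
    scale_to_width, scale_to_height = scale_to_height, scale_to_width  # now 768, 512
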
@@ -1052,8 +1062,14 @@ class PoiFileItemDTOMixin:
                 crop_bottom = initial_height
             poi_height = crop_bottom - poi_y
-            # now we have our random crop, but it may be smaller than resolution. Check and expand if needed
-            current_resolution = get_resolution(poi_width, poi_height)
+            try:
+                # now we have our random crop, but it may be smaller than resolution. Check and expand if needed
+                current_resolution = get_resolution(poi_width, poi_height)
+            except Exception as e:
+                print(f"Error: {e}")
+                print(f"Error getting resolution: {self.path}")
+                raise e
             if current_resolution >= self.dataset_config.resolution:
                 # We can break now
                 break
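
The try/except here exists only to name the offending file before the resolution error propagates. An adjacent option, shown as a self-contained sketch and not what this commit does, is to chain the file path onto the exception itself:

    # names and values hypothetical; the stand-in always fails to show the chain
    def get_resolution(w: int, h: int) -> int:
        raise ValueError("bad geometry")

    path = "/data/images/broken_mask.png"
    try:
        current_resolution = get_resolution(0, -32)
    except Exception as e:
        # exception chaining keeps the original traceback attached to the new error
        raise RuntimeError(f"Error getting resolution: {path}") from e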

@@ -194,6 +194,8 @@ def get_direct_guidance_loss(
         noise: torch.Tensor,
         sd: 'StableDiffusion',
         unconditional_embeds: Optional[PromptEmbeds] = None,
+        mask_multiplier=None,
+        prior_pred=None,
         **kwargs
 ):
     with torch.no_grad():
@@ -248,6 +250,8 @@ def get_direct_guidance_loss(
         noise.detach().float(),
         reduction="none"
     )
+    if mask_multiplier is not None:
+        guidance_loss = guidance_loss * mask_multiplier
     guidance_loss = guidance_loss.mean([1, 2, 3])
@@ -489,6 +493,8 @@ def get_guidance_loss(
         noise: torch.Tensor,
         sd: 'StableDiffusion',
         unconditional_embeds: Optional[PromptEmbeds] = None,
+        mask_multiplier=None,
+        prior_pred=None,
         **kwargs
 ):
     # TODO add others and process individual batch items separately
@@ -549,6 +555,8 @@ def get_guidance_loss(
             noise,
             sd,
             unconditional_embeds=unconditional_embeds,
+            mask_multiplier=mask_multiplier,
+            prior_pred=prior_pred,
             **kwargs
         )
     else:
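
In isolation, mask_multiplier scales the per-pixel loss before the spatial mean, so regions outside the mask contribute nothing. A self-contained sketch, assuming [B, C, H, W] tensors and a multiplier broadcastable over channels:

    import torch
    import torch.nn.functional as F

    pred = torch.randn(2, 4, 64, 64)
    noise = torch.randn(2, 4, 64, 64)
    mask_multiplier = torch.zeros(2, 1, 64, 64)
    mask_multiplier[:, :, 16:48, 16:48] = 1.0                   # weight only the masked region

    guidance_loss = F.mse_loss(pred, noise, reduction="none")   # per-pixel loss
    guidance_loss = guidance_loss * mask_multiplier             # broadcasts over channels
    guidance_loss = guidance_loss.mean([1, 2, 3])               # per-sample mean, as in the diff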

@@ -480,19 +480,36 @@ class IPAdapter(torch.nn.Module):
                 current_shape = current_img_proj_state_dict[key].shape
                 new_shape = value.shape
                 if current_shape != new_shape:
-                    # merge in what we can and leave the other values as they are
-                    if len(current_shape) == 1:
-                        current_img_proj_state_dict[key][:new_shape[0]] = value
-                    elif len(current_shape) == 2:
-                        current_img_proj_state_dict[key][:new_shape[0], :new_shape[1]] = value
-                    elif len(current_shape) == 3:
-                        current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value
-                    elif len(current_shape) == 4:
-                        current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2], :new_shape[3]] = value
-                    else:
-                        raise ValueError(f"unknown shape: {current_shape}")
-                    print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
+                    try:
+                        # merge in what we can and leave the other values as they are
+                        if len(current_shape) == 1:
+                            current_img_proj_state_dict[key][:new_shape[0]] = value
+                        elif len(current_shape) == 2:
+                            current_img_proj_state_dict[key][:new_shape[0], :new_shape[1]] = value
+                        elif len(current_shape) == 3:
+                            current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value
+                        elif len(current_shape) == 4:
+                            current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2], :new_shape[3]] = value
+                        else:
+                            raise ValueError(f"unknown shape: {current_shape}")
+                    except RuntimeError as e:
+                        print(e)
+                        print(f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. Trying other way")
+                        if len(current_shape) == 1:
+                            current_img_proj_state_dict[key][:current_shape[0]] = value[:current_shape[0]]
+                        elif len(current_shape) == 2:
+                            current_img_proj_state_dict[key][:current_shape[0], :current_shape[1]] = value[:current_shape[0], :current_shape[1]]
+                        elif len(current_shape) == 3:
+                            current_img_proj_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]]
+                        elif len(current_shape) == 4:
+                            current_img_proj_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2], :current_shape[3]] = value[:current_shape[0], :current_shape[1], :current_shape[2], :current_shape[3]]
+                        else:
+                            raise ValueError(f"unknown shape: {current_shape}")
+                    print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
                 else:
                     current_img_proj_state_dict[key] = value
         self.image_proj_model.load_state_dict(current_img_proj_state_dict)
@@ -504,19 +521,36 @@ class IPAdapter(torch.nn.Module):
                 current_shape = current_ip_adapter_state_dict[key].shape
                 new_shape = value.shape
                 if current_shape != new_shape:
-                    # merge in what we can and leave the other values as they are
-                    if len(current_shape) == 1:
-                        current_ip_adapter_state_dict[key][:new_shape[0]] = value
-                    elif len(current_shape) == 2:
-                        current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1]] = value
-                    elif len(current_shape) == 3:
-                        current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value
-                    elif len(current_shape) == 4:
-                        current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2], :new_shape[3]] = value
-                    else:
-                        raise ValueError(f"unknown shape: {current_shape}")
-                    print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
+                    try:
+                        # merge in what we can and leave the other values as they are
+                        if len(current_shape) == 1:
+                            current_ip_adapter_state_dict[key][:new_shape[0]] = value
+                        elif len(current_shape) == 2:
+                            current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1]] = value
+                        elif len(current_shape) == 3:
+                            current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value
+                        elif len(current_shape) == 4:
+                            current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2], :new_shape[3]] = value
+                        else:
+                            raise ValueError(f"unknown shape: {current_shape}")
+                        print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
+                    except RuntimeError as e:
+                        print(e)
+                        print(f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. Trying other way")
+                        if len(current_shape) == 1:
+                            current_ip_adapter_state_dict[key][:current_shape[0]] = value[:current_shape[0]]
+                        elif len(current_shape) == 2:
+                            current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1]] = value[:current_shape[0], :current_shape[1]]
+                        elif len(current_shape) == 3:
+                            current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]]
+                        elif len(current_shape) == 4:
+                            current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2], :current_shape[3]] = value[:current_shape[0], :current_shape[1], :current_shape[2], :current_shape[3]]
+                        else:
+                            raise ValueError(f"unknown shape: {current_shape}")
+                        print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
                 else:
                     current_ip_adapter_state_dict[key] = value
         self.adapter_modules.load_state_dict(current_ip_adapter_state_dict)
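
Both the primary and fallback paths implement the same idea, specialized per tensor rank: copy whichever sub-block the two shapes have in common and leave the rest of the existing weights untouched. A rank-agnostic sketch of the technique (merge_overlap is a hypothetical helper, not part of this codebase):

    import torch

    def merge_overlap(dst: torch.Tensor, src: torch.Tensor) -> None:
        # slice every dimension to the smaller of the two sizes,
        # then copy that overlapping block in place
        idx = tuple(slice(0, min(d, s)) for d, s in zip(dst.shape, src.shape))
        dst[idx] = src[idx]

    # e.g. adapting projection weights when the token count changes
    dst = torch.zeros(16, 768)
    src = torch.randn(77, 768)
    merge_overlap(dst, src)   # copies src[:16, :] into dst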