mirror of https://github.com/ostris/ai-toolkit.git
Improvements to IP adapter weight adaptation. Bug fixes. Added masking to the direct guidance loss. Allow importing a file for random triggers. Handle bad meta images with improper sizing.
@@ -1087,7 +1087,9 @@ class SDTrainer(BaseSDTrainProcess):
                     pred_kwargs=pred_kwargs,
                     batch=batch,
                     noise=noise,
-                    unconditional_embeds=unconditional_embeds
+                    unconditional_embeds=unconditional_embeds,
+                    mask_multiplier=mask_multiplier,
+                    prior_pred=prior_pred,
                 )
 
             else:
@@ -392,7 +392,14 @@ class DatasetConfig:
         self.dataset_path: str = kwargs.get('dataset_path', None)
 
         self.default_caption: str = kwargs.get('default_caption', None)
-        self.random_triggers: List[str] = kwargs.get('random_triggers', [])
+        random_triggers = kwargs.get('random_triggers', [])
+        # if they are a string, load them from a file
+        if isinstance(random_triggers, str) and os.path.exists(random_triggers):
+            with open(random_triggers, 'r') as f:
+                random_triggers = f.read().splitlines()
+        # remove empty lines
+        random_triggers = [line for line in random_triggers if line.strip() != '']
+        self.random_triggers: List[str] = random_triggers
         self.caption_ext: str = kwargs.get('caption_ext', None)
         self.random_scale: bool = kwargs.get('random_scale', False)
         self.random_crop: bool = kwargs.get('random_crop', False)
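For reference, a minimal standalone sketch of the new option (the helper name load_random_triggers and the triggers.txt path are hypothetical, not part of the codebase): a string value is treated as a path to a text file with one trigger phrase per line, while a plain list passes through unchanged.

import os
from typing import List, Union

def load_random_triggers(random_triggers: Union[str, List[str]]) -> List[str]:
    # a string is treated as a path to a file with one trigger phrase per line
    if isinstance(random_triggers, str) and os.path.exists(random_triggers):
        with open(random_triggers, 'r') as f:
            random_triggers = f.read().splitlines()
    # drop blank lines; a list passes through unchanged
    return [t for t in random_triggers if t.strip() != '']

triggers = load_random_triggers('triggers.txt')  # hypothetical file, one trigger per line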
@@ -820,14 +820,24 @@ class MaskFileItemDTOMixin:
         if self.dataset_config.invert_mask:
             img = ImageOps.invert(img)
         w, h = img.size
+        fix_size = False
         if w > h and self.scale_to_width < self.scale_to_height:
             # throw error, they should match
-            raise ValueError(
-                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
+            print(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
+            fix_size = True
         elif h > w and self.scale_to_height < self.scale_to_width:
             # throw error, they should match
-            raise ValueError(
-                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
+            print(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
+            fix_size = True
 
+        if fix_size:
+            # swap all the sizes
+            self.scale_to_width, self.scale_to_height = self.scale_to_height, self.scale_to_width
+            self.crop_width, self.crop_height = self.crop_height, self.crop_width
+            self.crop_x, self.crop_y = self.crop_y, self.crop_x
+
         if self.flip_x:
             # do a flip
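The intent of the change above: a mask whose orientation disagrees with the precomputed scale/crop metadata no longer raises; the metadata is swapped so the mask can still be resized and cropped. A minimal sketch of the swap on hypothetical values:

# hypothetical metadata computed for a portrait image, paired with a landscape mask
scale_to_width, scale_to_height = 512, 768
crop_width, crop_height = 512, 704
crop_x, crop_y = 0, 32
w, h = 768, 512  # actual mask size

if w > h and scale_to_width < scale_to_height:
    # orientation mismatch: swap every width/height pair instead of raising
    scale_to_width, scale_to_height = scale_to_height, scale_to_width
    crop_width, crop_height = crop_height, crop_width
    crop_x, crop_y = crop_y, crop_x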
@@ -1052,8 +1062,14 @@ class PoiFileItemDTOMixin:
                 crop_bottom = initial_height
 
             poi_height = crop_bottom - poi_y
-            # now we have our random crop, but it may be smaller than resolution. Check and expand if needed
-            current_resolution = get_resolution(poi_width, poi_height)
+            try:
+                # now we have our random crop, but it may be smaller than resolution. Check and expand if needed
+                current_resolution = get_resolution(poi_width, poi_height)
+            except Exception as e:
+                print(f"Error: {e}")
+                print(f"Error getting resolution: {self.path}")
+                raise e
+                return False
             if current_resolution >= self.dataset_config.resolution:
                 # We can break now
                 break
@@ -194,6 +194,8 @@ def get_direct_guidance_loss(
         noise: torch.Tensor,
         sd: 'StableDiffusion',
         unconditional_embeds: Optional[PromptEmbeds] = None,
+        mask_multiplier=None,
+        prior_pred=None,
         **kwargs
 ):
     with torch.no_grad():
@@ -248,6 +250,8 @@ def get_direct_guidance_loss(
         noise.detach().float(),
         reduction="none"
     )
+    if mask_multiplier is not None:
+        guidance_loss = guidance_loss * mask_multiplier
 
     guidance_loss = guidance_loss.mean([1, 2, 3])
 
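A minimal sketch of how the new mask_multiplier behaves, with hypothetical shapes: it is broadcast over the elementwise loss before the spatial mean, so masked-out regions contribute nothing to the per-sample scalar loss.

import torch

guidance_loss = torch.rand(2, 4, 8, 8)     # elementwise loss (reduction="none"), hypothetical shape
mask_multiplier = torch.zeros(2, 1, 8, 8)  # 0 = ignore, 1 = keep
mask_multiplier[..., :4] = 1.0             # keep only the left half

guidance_loss = guidance_loss * mask_multiplier  # zero out ignored regions
guidance_loss = guidance_loss.mean([1, 2, 3])    # per-sample loss, shape (2,)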
@@ -489,6 +493,8 @@ def get_guidance_loss(
         noise: torch.Tensor,
         sd: 'StableDiffusion',
         unconditional_embeds: Optional[PromptEmbeds] = None,
+        mask_multiplier=None,
+        prior_pred=None,
         **kwargs
 ):
     # TODO add others and process individual batch items separately
@@ -549,6 +555,8 @@ def get_guidance_loss(
             noise,
             sd,
             unconditional_embeds=unconditional_embeds,
+            mask_multiplier=mask_multiplier,
+            prior_pred=prior_pred,
             **kwargs
         )
     else:
@@ -480,19 +480,36 @@ class IPAdapter(torch.nn.Module):
                 current_shape = current_img_proj_state_dict[key].shape
                 new_shape = value.shape
                 if current_shape != new_shape:
-                    # merge in what we can and leave the other values as they are
-                    if len(current_shape) == 1:
-                        current_img_proj_state_dict[key][:new_shape[0]] = value
-                    elif len(current_shape) == 2:
-                        current_img_proj_state_dict[key][:new_shape[0], :new_shape[1]] = value
-                    elif len(current_shape) == 3:
-                        current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value
-                    elif len(current_shape) == 4:
-                        current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2], :new_shape[3]] = value
-                    else:
-                        raise ValueError(f"unknown shape: {current_shape}")
-                    print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
+                    try:
+                        # merge in what we can and leave the other values as they are
+                        if len(current_shape) == 1:
+                            current_img_proj_state_dict[key][:new_shape[0]] = value
+                        elif len(current_shape) == 2:
+                            current_img_proj_state_dict[key][:new_shape[0], :new_shape[1]] = value
+                        elif len(current_shape) == 3:
+                            current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value
+                        elif len(current_shape) == 4:
+                            current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2], :new_shape[3]] = value
+                        else:
+                            raise ValueError(f"unknown shape: {current_shape}")
+                    except RuntimeError as e:
+                        print(e)
+                        print(f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. Trying other way")
+
+                        if len(current_shape) == 1:
+                            current_img_proj_state_dict[key][:current_shape[0]] = value[:current_shape[0]]
+                        elif len(current_shape) == 2:
+                            current_img_proj_state_dict[key][:current_shape[0], :current_shape[1]] = value[:current_shape[0], :current_shape[1]]
+                        elif len(current_shape) == 3:
+                            current_img_proj_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]]
+                        elif len(current_shape) == 4:
+                            current_img_proj_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2], :current_shape[3]] = value[:current_shape[0], :current_shape[1], :current_shape[2], :current_shape[3]]
+                        else:
+                            raise ValueError(f"unknown shape: {current_shape}")
+                    print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
                 else:
                     current_img_proj_state_dict[key] = value
             self.image_proj_model.load_state_dict(current_img_proj_state_dict)
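Both merge paths implement the same idea: when a checkpoint tensor and the freshly built module tensor disagree in shape, copy only the overlapping slice and leave the remaining values untouched. A minimal sketch on hypothetical 2D weights, first the smaller-into-larger case, then the fallback that crops a larger checkpoint tensor:

import torch

current = torch.zeros(8, 16)  # hypothetical weight in the freshly built module
loaded = torch.ones(4, 16)    # hypothetical weight from an older checkpoint

# checkpoint smaller than module: write it into the top-left corner, keep the rest
current[:loaded.shape[0], :loaded.shape[1]] = loaded

# fallback for a larger checkpoint tensor: crop it down to the module's shape
loaded_big = torch.ones(12, 32)
current[:, :] = loaded_big[:current.shape[0], :current.shape[1]]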
@@ -504,19 +521,36 @@ class IPAdapter(torch.nn.Module):
                 current_shape = current_ip_adapter_state_dict[key].shape
                 new_shape = value.shape
                 if current_shape != new_shape:
-                    # merge in what we can and leave the other values as they are
-                    if len(current_shape) == 1:
-                        current_ip_adapter_state_dict[key][:new_shape[0]] = value
-                    elif len(current_shape) == 2:
-                        current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1]] = value
-                    elif len(current_shape) == 3:
-                        current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value
-                    elif len(current_shape) == 4:
-                        current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2], :new_shape[3]] = value
-                    else:
-                        raise ValueError(f"unknown shape: {current_shape}")
-                    print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
+                    try:
+                        # merge in what we can and leave the other values as they are
+                        if len(current_shape) == 1:
+                            current_ip_adapter_state_dict[key][:new_shape[0]] = value
+                        elif len(current_shape) == 2:
+                            current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1]] = value
+                        elif len(current_shape) == 3:
+                            current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value
+                        elif len(current_shape) == 4:
+                            current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2], :new_shape[3]] = value
+                        else:
+                            raise ValueError(f"unknown shape: {current_shape}")
+                        print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
+                    except RuntimeError as e:
+                        print(e)
+                        print(f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. Trying other way")
+
+                        if(len(current_shape) == 1):
+                            current_ip_adapter_state_dict[key][:current_shape[0]] = value[:current_shape[0]]
+                        elif(len(current_shape) == 2):
+                            current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1]] = value[:current_shape[0], :current_shape[1]]
+                        elif(len(current_shape) == 3):
+                            current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]]
+                        elif(len(current_shape) == 4):
+                            current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2], :current_shape[3]] = value[:current_shape[0], :current_shape[1], :current_shape[2], :current_shape[3]]
+                        else:
+                            raise ValueError(f"unknown shape: {current_shape}")
+                        print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
+
                 else:
                     current_ip_adapter_state_dict[key] = value
             self.adapter_modules.load_state_dict(current_ip_adapter_state_dict)