WIP Ilora

This commit is contained in:
Jaret Burkett
2024-06-14 09:31:01 -06:00
parent bd10d2d668
commit 37cebd9458
6 changed files with 57 additions and 29 deletions

View File

@@ -376,7 +376,8 @@ class SDTrainer(BaseSDTrainProcess):
# 3 just do mode for now?
# if args.weighting_scheme == "sigma_sqrt":
sigmas = self.sd.noise_scheduler.get_sigmas(timesteps, pred.ndim, dtype, self.device_torch)
weighting = (sigmas ** -2.0).float()
# weighting = (sigmas ** -2.0).float()
weighting = torch.ones_like(sigmas)
# elif args.weighting_scheme == "logit_normal":
# # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
# u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)

View File

@@ -188,6 +188,10 @@ class AdapterConfig:
# trains with a scaler to ease channel bias but merges it in on save
self.merge_scaler: bool = kwargs.get('merge_scaler', False)
# for ilora
self.head_dim: int = kwargs.get('head_dim', 1024)
self.num_heads: int = kwargs.get('num_heads', 1)
class EmbeddingConfig:
def __init__(self, **kwargs):

View File

@@ -145,7 +145,8 @@ class CustomAdapter(torch.nn.Module):
self.ilora_module = InstantLoRAModule(
vision_tokens=vision_tokens,
vision_hidden_size=vision_hidden_size,
head_dim=1024,
head_dim=self.config.head_dim,
num_heads=self.config.num_heads,
sd=self.sd_ref()
)
elif self.adapter_type == 'text_encoder':
@@ -878,6 +879,11 @@ class CustomAdapter(torch.nn.Module):
self.vision_encoder.gradient_checkpointing = True
def get_additional_save_metadata(self) -> Dict[str, Any]:
additional = {}
if self.config.type == 'ilora':
return self.ilora_module.get_additional_save_metadata()
return {}
extra = self.ilora_module.get_additional_save_metadata()
for k, v in extra.items():
additional[k] = v
additional['clip_layer'] = self.config.clip_layer
additional['image_encoder_arch'] = self.config.head_dim
return additional

View File

@@ -157,6 +157,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
use_bias: bool = False,
is_lorm: bool = False,
ignore_if_contains = None,
only_if_contains = None,
parameter_threshold: float = 0.0,
attn_only: bool = False,
target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE,
@@ -186,6 +187,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
if ignore_if_contains is None:
ignore_if_contains = []
self.ignore_if_contains = ignore_if_contains
self.only_if_contains: Union[List, None] = only_if_contains
self.lora_dim = lora_dim
self.alpha = alpha
self.conv_lora_dim = conv_lora_dim
@@ -250,6 +254,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
loras = []
skipped = []
attached_modules = []
lora_shape_dict = {}
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
@@ -269,6 +274,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
if self.only_if_contains is not None and not any([word in lora_name for word in self.only_if_contains]):
continue
dim = None
alpha = None
@@ -316,6 +324,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
use_bias=use_bias,
)
loras.append(lora)
lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)
]
return loras, skipped
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]

View File

@@ -181,6 +181,7 @@ class InstantLoRAModule(torch.nn.Module):
vision_hidden_size: int,
vision_tokens: int,
head_dim: int,
num_heads: int, # number of resampler output tokens (passed as num_queries); they are averaged together in forward
sd: 'StableDiffusion'
):
super(InstantLoRAModule, self).__init__()
@@ -190,6 +191,7 @@ class InstantLoRAModule(torch.nn.Module):
self.vision_hidden_size = vision_hidden_size
self.vision_tokens = vision_tokens
self.head_dim = head_dim
self.num_heads = num_heads
# stores the projection vector. Grabbed by modules
self.img_embeds: List[torch.Tensor] = None
@@ -243,7 +245,7 @@ class InstantLoRAModule(torch.nn.Module):
depth=4,
dim_head=64,
heads=12,
num_queries=1, # output tokens
num_queries=num_heads, # output tokens
embedding_dim=vision_hidden_size,
max_seq_len=vision_tokens,
output_dim=head_dim,
@@ -261,25 +263,26 @@ class InstantLoRAModule(torch.nn.Module):
self.migrate_weight_mapping()
def migrate_weight_mapping(self):
# changes the names of the modules to common ones
keymap = self.sd_ref().network.get_keymap()
save_keymap = {}
if keymap is not None:
for ldm_key, diffusers_key in keymap.items():
# invert them
save_keymap[diffusers_key] = ldm_key
new_keymap = {}
for key, value in self.weight_mapping:
if key in save_keymap:
new_keymap[save_keymap[key]] = value
else:
print(f"Key {key} not found in keymap")
new_keymap[key] = value
self.weight_mapping = new_keymap
else:
print("No keymap found. Using default names")
return
return
# # changes the names of the modules to common ones
# keymap = self.sd_ref().network.get_keymap()
# save_keymap = {}
# if keymap is not None:
# for ldm_key, diffusers_key in keymap.items():
# # invert them
# save_keymap[diffusers_key] = ldm_key
#
# new_keymap = {}
# for key, value in self.weight_mapping:
# if key in save_keymap:
# new_keymap[save_keymap[key]] = value
# else:
# print(f"Key {key} not found in keymap")
# new_keymap[key] = value
# self.weight_mapping = new_keymap
# else:
# print("No keymap found. Using default names")
# return
def forward(self, img_embeds):
@@ -291,7 +294,8 @@ class InstantLoRAModule(torch.nn.Module):
img_embeds = self.resampler(img_embeds)
img_embeds = self.proj_module(img_embeds)
if len(img_embeds.shape) == 3:
img_embeds = img_embeds.squeeze(1)
# merge the heads
img_embeds = img_embeds.mean(dim=1)
self.img_embeds = []
# get all the slices
@@ -304,6 +308,11 @@ class InstantLoRAModule(torch.nn.Module):
def get_additional_save_metadata(self) -> Dict[str, Any]:
# save the weight mapping
return {
"weight_mapping": self.weight_mapping
"weight_mapping": self.weight_mapping,
"num_heads": self.num_heads,
"vision_hidden_size": self.vision_hidden_size,
"head_dim": self.head_dim,
"vision_tokens": self.vision_tokens,
"output_size": self.output_size,
}

View File

@@ -576,13 +576,11 @@ class StableDiffusion:
if self.is_xl:
pipeline = Pipe(
vae=self.vae,
transformer=self.unet,
unet=self.unet,
text_encoder=self.text_encoder[0],
text_encoder_2=self.text_encoder[1],
text_encoder_3=self.text_encoder[2],
tokenizer=self.tokenizer[0],
tokenizer_2=self.tokenizer[1],
tokenizer_3=self.tokenizer[2],
scheduler=noise_scheduler,
**extra_args
).to(self.device_torch)