mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-04 18:19:49 +00:00
WIP Ilora
This commit is contained in:
@@ -376,7 +376,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# 3 just do mode for now?
|
||||
# if args.weighting_scheme == "sigma_sqrt":
|
||||
sigmas = self.sd.noise_scheduler.get_sigmas(timesteps, pred.ndim, dtype, self.device_torch)
|
||||
weighting = (sigmas ** -2.0).float()
|
||||
# weighting = (sigmas ** -2.0).float()
|
||||
weighting = torch.ones_like(sigmas)
|
||||
# elif args.weighting_scheme == "logit_normal":
|
||||
# # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
||||
# u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
|
||||
|
||||
@@ -188,6 +188,10 @@ class AdapterConfig:
|
||||
# trains with a scaler to easy channel bias but merges it in on save
|
||||
self.merge_scaler: bool = kwargs.get('merge_scaler', False)
|
||||
|
||||
# for ilora
|
||||
self.head_dim: int = kwargs.get('head_dim', 1024)
|
||||
self.num_heads: int = kwargs.get('num_heads', 1)
|
||||
|
||||
|
||||
class EmbeddingConfig:
|
||||
def __init__(self, **kwargs):
|
||||
|
||||
@@ -145,7 +145,8 @@ class CustomAdapter(torch.nn.Module):
|
||||
self.ilora_module = InstantLoRAModule(
|
||||
vision_tokens=vision_tokens,
|
||||
vision_hidden_size=vision_hidden_size,
|
||||
head_dim=1024,
|
||||
head_dim=self.config.head_dim,
|
||||
num_heads=self.config.num_heads,
|
||||
sd=self.sd_ref()
|
||||
)
|
||||
elif self.adapter_type == 'text_encoder':
|
||||
@@ -878,6 +879,11 @@ class CustomAdapter(torch.nn.Module):
|
||||
self.vision_encoder.gradient_checkpointing = True
|
||||
|
||||
def get_additional_save_metadata(self) -> Dict[str, Any]:
|
||||
additional = {}
|
||||
if self.config.type == 'ilora':
|
||||
return self.ilora_module.get_additional_save_metadata()
|
||||
return {}
|
||||
extra = self.ilora_module.get_additional_save_metadata()
|
||||
for k, v in extra.items():
|
||||
additional[k] = v
|
||||
additional['clip_layer'] = self.config.clip_layer
|
||||
additional['image_encoder_arch'] = self.config.head_dim
|
||||
return additional
|
||||
@@ -157,6 +157,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
use_bias: bool = False,
|
||||
is_lorm: bool = False,
|
||||
ignore_if_contains = None,
|
||||
only_if_contains = None,
|
||||
parameter_threshold: float = 0.0,
|
||||
attn_only: bool = False,
|
||||
target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE,
|
||||
@@ -186,6 +187,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
if ignore_if_contains is None:
|
||||
ignore_if_contains = []
|
||||
self.ignore_if_contains = ignore_if_contains
|
||||
|
||||
self.only_if_contains: Union[List, None] = only_if_contains
|
||||
|
||||
self.lora_dim = lora_dim
|
||||
self.alpha = alpha
|
||||
self.conv_lora_dim = conv_lora_dim
|
||||
@@ -250,6 +254,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
loras = []
|
||||
skipped = []
|
||||
attached_modules = []
|
||||
lora_shape_dict = {}
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
@@ -269,6 +274,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
lora_name = prefix + "." + name + "." + child_name
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
|
||||
if self.only_if_contains is not None and not any([word in lora_name for word in self.only_if_contains]):
|
||||
continue
|
||||
|
||||
dim = None
|
||||
alpha = None
|
||||
|
||||
@@ -316,6 +324,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
use_bias=use_bias,
|
||||
)
|
||||
loras.append(lora)
|
||||
lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)
|
||||
]
|
||||
return loras, skipped
|
||||
|
||||
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
|
||||
|
||||
@@ -181,6 +181,7 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
vision_hidden_size: int,
|
||||
vision_tokens: int,
|
||||
head_dim: int,
|
||||
num_heads: int, # number of heads in the resampler
|
||||
sd: 'StableDiffusion'
|
||||
):
|
||||
super(InstantLoRAModule, self).__init__()
|
||||
@@ -190,6 +191,7 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
self.vision_hidden_size = vision_hidden_size
|
||||
self.vision_tokens = vision_tokens
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = num_heads
|
||||
|
||||
# stores the projection vector. Grabbed by modules
|
||||
self.img_embeds: List[torch.Tensor] = None
|
||||
@@ -243,7 +245,7 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
depth=4,
|
||||
dim_head=64,
|
||||
heads=12,
|
||||
num_queries=1, # output tokens
|
||||
num_queries=num_heads, # output tokens
|
||||
embedding_dim=vision_hidden_size,
|
||||
max_seq_len=vision_tokens,
|
||||
output_dim=head_dim,
|
||||
@@ -261,25 +263,26 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
self.migrate_weight_mapping()
|
||||
|
||||
def migrate_weight_mapping(self):
|
||||
# changes the names of the modules to common ones
|
||||
keymap = self.sd_ref().network.get_keymap()
|
||||
save_keymap = {}
|
||||
if keymap is not None:
|
||||
for ldm_key, diffusers_key in keymap.items():
|
||||
# invert them
|
||||
save_keymap[diffusers_key] = ldm_key
|
||||
|
||||
new_keymap = {}
|
||||
for key, value in self.weight_mapping:
|
||||
if key in save_keymap:
|
||||
new_keymap[save_keymap[key]] = value
|
||||
else:
|
||||
print(f"Key {key} not found in keymap")
|
||||
new_keymap[key] = value
|
||||
self.weight_mapping = new_keymap
|
||||
else:
|
||||
print("No keymap found. Using default names")
|
||||
return
|
||||
return
|
||||
# # changes the names of the modules to common ones
|
||||
# keymap = self.sd_ref().network.get_keymap()
|
||||
# save_keymap = {}
|
||||
# if keymap is not None:
|
||||
# for ldm_key, diffusers_key in keymap.items():
|
||||
# # invert them
|
||||
# save_keymap[diffusers_key] = ldm_key
|
||||
#
|
||||
# new_keymap = {}
|
||||
# for key, value in self.weight_mapping:
|
||||
# if key in save_keymap:
|
||||
# new_keymap[save_keymap[key]] = value
|
||||
# else:
|
||||
# print(f"Key {key} not found in keymap")
|
||||
# new_keymap[key] = value
|
||||
# self.weight_mapping = new_keymap
|
||||
# else:
|
||||
# print("No keymap found. Using default names")
|
||||
# return
|
||||
|
||||
|
||||
def forward(self, img_embeds):
|
||||
@@ -291,7 +294,8 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
img_embeds = self.resampler(img_embeds)
|
||||
img_embeds = self.proj_module(img_embeds)
|
||||
if len(img_embeds.shape) == 3:
|
||||
img_embeds = img_embeds.squeeze(1)
|
||||
# merge the heads
|
||||
img_embeds = img_embeds.mean(dim=1)
|
||||
|
||||
self.img_embeds = []
|
||||
# get all the slices
|
||||
@@ -304,6 +308,11 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
def get_additional_save_metadata(self) -> Dict[str, Any]:
|
||||
# save the weight mapping
|
||||
return {
|
||||
"weight_mapping": self.weight_mapping
|
||||
"weight_mapping": self.weight_mapping,
|
||||
"num_heads": self.num_heads,
|
||||
"vision_hidden_size": self.vision_hidden_size,
|
||||
"head_dim": self.head_dim,
|
||||
"vision_tokens": self.vision_tokens,
|
||||
"output_size": self.output_size,
|
||||
}
|
||||
|
||||
|
||||
@@ -576,13 +576,11 @@ class StableDiffusion:
|
||||
if self.is_xl:
|
||||
pipeline = Pipe(
|
||||
vae=self.vae,
|
||||
transformer=self.unet,
|
||||
unet=self.unet,
|
||||
text_encoder=self.text_encoder[0],
|
||||
text_encoder_2=self.text_encoder[1],
|
||||
text_encoder_3=self.text_encoder[2],
|
||||
tokenizer=self.tokenizer[0],
|
||||
tokenizer_2=self.tokenizer[1],
|
||||
tokenizer_3=self.tokenizer[2],
|
||||
scheduler=noise_scheduler,
|
||||
**extra_args
|
||||
).to(self.device_torch)
|
||||
|
||||
Reference in New Issue
Block a user