Improved LoRM extraction and training

Jaret Burkett
2023-10-28 08:21:59 -06:00
parent 0a79ac9604
commit 6f3e0d5af2
10 changed files with 559 additions and 196 deletions


@@ -6,6 +6,8 @@ from diffusers import UNet2DConditionModel
from torch import Tensor
from tqdm import tqdm
from toolkit.config_modules import LoRMConfig
conv = nn.Conv2d
lin = nn.Linear
_size_2_t = Union[int, Tuple[int, int]]
@@ -29,12 +31,13 @@ CONV_MODULES = [
UNET_TARGET_REPLACE_MODULE = [
"Transformer2DModel",
# "BasicTransformerBlock",
# "ResnetBlock2D",
"Downsample2D",
"Upsample2D",
]
LORM_TARGET_REPLACE_MODULE = UNET_TARGET_REPLACE_MODULE
UNET_TARGET_REPLACE_NAME = [
"conv_in",
"conv_out",
@@ -279,13 +282,38 @@ def compute_optimal_bias(original_module, linear_down, linear_up, X):
return optimal_bias
def format_with_commas(n):
return f"{n:,}"
def print_lorm_extract_details(
start_num_params: int,
end_num_params: int,
num_replaced: int,
):
start_formatted = format_with_commas(start_num_params)
end_formatted = format_with_commas(end_num_params)
num_replaced_formatted = format_with_commas(num_replaced)
width = max(len(start_formatted), len(end_formatted), len(num_replaced_formatted))
print(f"Convert UNet result:")
print(f" - converted: {num_replaced:>{width},} modules")
print(f" - start: {start_num_params:>{width},} params")
print(f" - end: {end_num_params:>{width},} params")
lorm_ignore_if_contains = [
'proj_out', 'proj_in',
]
lorm_parameter_threshold = 1000000
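# LoRMConfig.get_config_for_module drives the per-module settings used in the
# loop below. Its internals live in toolkit.config_modules and are not part of
# this diff, so this stand-in is an assumption: per-module settings that fall
# back to the removed signature defaults and module-level values like
# lorm_parameter_threshold above.
class _ModuleLoRMSettingsSketch:
    def __init__(
        self,
        extract_mode: str = "percentile",
        extract_mode_param: Union[int, float] = 0.5,
        parameter_threshold: int = lorm_parameter_threshold,
    ):
        self.extract_mode = extract_mode
        self.extract_mode_param = extract_mode_param
        self.parameter_threshold = parameter_threshold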
@torch.no_grad()
def convert_diffusers_unet_to_lorm(
unet: UNet2DConditionModel,
config: LoRMConfig,
):
print('Converting UNet to LoRM UNet')
start_num_params = count_parameters(unet)
@@ -299,8 +327,6 @@ def convert_diffusers_unet_to_lorm(
ignore_if_contains = [
'proj_out', 'proj_in',
]
for name, module in named_modules:
module_name = module.__class__.__name__
@@ -311,6 +337,13 @@ def convert_diffusers_unet_to_lorm(
combined_name = f"{name}.{child_name}"
# if child_module.__class__.__name__ in LINEAR_MODULES and child_module.bias is None:
# pass
lorm_config = config.get_config_for_module(combined_name)
extract_mode = lorm_config.extract_mode
extract_mode_param = lorm_config.extract_mode_param
parameter_threshold = lorm_config.parameter_threshold
if any([word in child_name for word in ignore_if_contains]):
pass
@@ -322,7 +355,7 @@ def convert_diffusers_unet_to_lorm(
down_weight, up_weight, lora_dim, diff = extract_linear(
weight=child_module.weight.clone().detach().float(),
mode=extract_mode,
mode_param=extract_mode_param,
device=child_module.weight.device,
)
down_weight = down_weight.to(dtype=dtype)
@@ -362,7 +395,7 @@ def convert_diffusers_unet_to_lorm(
down_weight, up_weight, lora_dim, diff = extract_conv(
weight=child_module.weight.clone().detach().float(),
mode=extract_mode,
mode_param=extract_mode_param,
device=child_module.weight.device,
)
down_weight = down_weight.to(dtype=dtype)
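# extract_linear and extract_conv are defined earlier in this file; their
# bodies are outside this diff. As a rough sketch of the SVD split they
# perform (a fixed-rank cut is assumed here; the real code also selects the
# rank via mode and mode_param, e.g. the "percentile" mode used above):
def _svd_split_sketch(weight: Tensor, rank: int):
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
    up_weight = U[:, :rank] * S[:rank]       # (out_features, rank), singular values folded in
    down_weight = Vh[:rank, :]               # (rank, in_features)
    diff = weight - up_weight @ down_weight  # reconstruction error left behind
    return down_weight, up_weight, rank, diff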
@@ -395,30 +428,25 @@ def convert_diffusers_unet_to_lorm(
replace_module_by_path(unet, combined_name, new_module)
converted_modules.append(new_module)
num_replaced += 1
layer_names_replaced.append(
f"{combined_name} - {format_with_commas(count_parameters(child_module))}")
pbar.update(1)
pbar.close()
end_num_params = count_parameters(unet)
def sorting_key(s):
# Extract the number part, remove commas, and convert to integer
return int(s.split("-")[1].strip().replace(",", ""))
sorted_layer_names_replaced = sorted(layer_names_replaced, key=sorting_key, reverse=True)
for layer_name in sorted_layer_names_replaced:
print(layer_name)
print(f"Convert UNet result:")
print(f" - converted: {num_replaced:>{width},} modules")
print(f" - start: {start_num_params:>{width},} params")
print(f" - end: {end_num_params:>{width},} params")
print_lorm_extract_details(
start_num_params=start_num_params,
end_num_params=end_num_params,
num_replaced=num_replaced,
)
return converted_modules
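# A minimal end-to-end sketch (the LoRMConfig construction is an assumption;
# see toolkit.config_modules for its real fields):
if __name__ == "__main__":
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    )
    lorm_config = LoRMConfig()  # defaults, plus optional per-module overrides
    converted = convert_diffusers_unet_to_lorm(unet, config=lorm_config)
    print(f"replaced {len(converted)} modules")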