mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Improved lorm extraction and training
This commit is contained in:
@@ -6,6 +6,8 @@ from diffusers import UNet2DConditionModel
|
||||
from torch import Tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
from toolkit.config_modules import LoRMConfig
|
||||
|
||||
conv = nn.Conv2d
|
||||
lin = nn.Linear
|
||||
_size_2_t = Union[int, Tuple[int, int]]
|
||||
@@ -29,12 +31,13 @@ CONV_MODULES = [
|
||||
|
||||
UNET_TARGET_REPLACE_MODULE = [
|
||||
"Transformer2DModel",
|
||||
# "BasicTransformerBlock",
|
||||
# "ResnetBlock2D",
|
||||
"Downsample2D",
|
||||
"Upsample2D",
|
||||
]
|
||||
|
||||
LORM_TARGET_REPLACE_MODULE = UNET_TARGET_REPLACE_MODULE
|
||||
|
||||
UNET_TARGET_REPLACE_NAME = [
|
||||
"conv_in",
|
||||
"conv_out",
|
||||
@@ -279,13 +282,38 @@ def compute_optimal_bias(original_module, linear_down, linear_up, X):
|
||||
return optimal_bias
|
||||
|
||||
|
||||
def format_with_commas(n):
|
||||
return f"{n:,}"
|
||||
|
||||
|
||||
def print_lorm_extract_details(
|
||||
start_num_params: int,
|
||||
end_num_params: int,
|
||||
num_replaced: int,
|
||||
):
|
||||
start_formatted = format_with_commas(start_num_params)
|
||||
end_formatted = format_with_commas(end_num_params)
|
||||
num_replaced_formatted = format_with_commas(num_replaced)
|
||||
|
||||
width = max(len(start_formatted), len(end_formatted), len(num_replaced_formatted))
|
||||
|
||||
print(f"Convert UNet result:")
|
||||
print(f" - converted: {num_replaced:>{width},} modules")
|
||||
print(f" - start: {start_num_params:>{width},} params")
|
||||
print(f" - end: {end_num_params:>{width},} params")
|
||||
|
||||
|
||||
lorm_ignore_if_contains = [
|
||||
'proj_out', 'proj_in',
|
||||
]
|
||||
|
||||
lorm_parameter_threshold = 1000000
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_diffusers_unet_to_lorm(
|
||||
unet: UNet2DConditionModel,
|
||||
extract_mode: ExtractMode = "percentile",
|
||||
mode_param: Union[int, float] = 0.5,
|
||||
parameter_threshold: int = 500000,
|
||||
# parameter_threshold: int = 1500000
|
||||
config: LoRMConfig,
|
||||
):
|
||||
print('Converting UNet to LoRM UNet')
|
||||
start_num_params = count_parameters(unet)
|
||||
@@ -299,8 +327,6 @@ def convert_diffusers_unet_to_lorm(
|
||||
ignore_if_contains = [
|
||||
'proj_out', 'proj_in',
|
||||
]
|
||||
def format_with_commas(n):
|
||||
return f"{n:,}"
|
||||
|
||||
for name, module in named_modules:
|
||||
module_name = module.__class__.__name__
|
||||
@@ -311,6 +337,13 @@ def convert_diffusers_unet_to_lorm(
|
||||
combined_name = combined_name = f"{name}.{child_name}"
|
||||
# if child_module.__class__.__name__ in LINEAR_MODULES and child_module.bias is None:
|
||||
# pass
|
||||
|
||||
lorm_config = config.get_config_for_module(combined_name)
|
||||
|
||||
extract_mode = lorm_config.extract_mode
|
||||
extract_mode_param = lorm_config.extract_mode_param
|
||||
parameter_threshold = lorm_config.parameter_threshold
|
||||
|
||||
if any([word in child_name for word in ignore_if_contains]):
|
||||
pass
|
||||
|
||||
@@ -322,7 +355,7 @@ def convert_diffusers_unet_to_lorm(
|
||||
down_weight, up_weight, lora_dim, diff = extract_linear(
|
||||
weight=child_module.weight.clone().detach().float(),
|
||||
mode=extract_mode,
|
||||
mode_param=mode_param,
|
||||
mode_param=extract_mode_param,
|
||||
device=child_module.weight.device,
|
||||
)
|
||||
down_weight = down_weight.to(dtype=dtype)
|
||||
@@ -362,7 +395,7 @@ def convert_diffusers_unet_to_lorm(
|
||||
down_weight, up_weight, lora_dim, diff = extract_conv(
|
||||
weight=child_module.weight.clone().detach().float(),
|
||||
mode=extract_mode,
|
||||
mode_param=mode_param,
|
||||
mode_param=extract_mode_param,
|
||||
device=child_module.weight.device,
|
||||
)
|
||||
down_weight = down_weight.to(dtype=dtype)
|
||||
@@ -395,30 +428,25 @@ def convert_diffusers_unet_to_lorm(
|
||||
replace_module_by_path(unet, combined_name, new_module)
|
||||
converted_modules.append(new_module)
|
||||
num_replaced += 1
|
||||
layer_names_replaced.append(f"{combined_name} - {format_with_commas(count_parameters(child_module))}")
|
||||
layer_names_replaced.append(
|
||||
f"{combined_name} - {format_with_commas(count_parameters(child_module))}")
|
||||
|
||||
pbar.update(1)
|
||||
pbar.close()
|
||||
end_num_params = count_parameters(unet)
|
||||
|
||||
start_formatted = format_with_commas(start_num_params)
|
||||
end_formatted = format_with_commas(end_num_params)
|
||||
num_replaced_formatted = format_with_commas(num_replaced)
|
||||
|
||||
width = max(len(start_formatted), len(end_formatted), len(num_replaced_formatted))
|
||||
|
||||
def sorting_key(s):
|
||||
# Extract the number part, remove commas, and convert to integer
|
||||
return int(s.split("-")[1].strip().replace(",", ""))
|
||||
|
||||
sorted_layer_names_replaced = sorted(layer_names_replaced, key=sorting_key, reverse=True)
|
||||
|
||||
for layer_name in sorted_layer_names_replaced:
|
||||
print(layer_name)
|
||||
|
||||
print(f"Convert UNet result:")
|
||||
print(f" - converted: {num_replaced:>{width},} modules")
|
||||
print(f" - start: {start_num_params:>{width},} params")
|
||||
print(f" - end: {end_num_params:>{width},} params")
|
||||
print_lorm_extract_details(
|
||||
start_num_params=start_num_params,
|
||||
end_num_params=end_num_params,
|
||||
num_replaced=num_replaced,
|
||||
)
|
||||
|
||||
return converted_modules
|
||||
|
||||
Reference in New Issue
Block a user