Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Fixed Dora implementation. Still highly experimental
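For orientation, the recomposed weight this DoRA implementation targets is the frozen weight plus the low-rank update, rescaled per output row to a learned magnitude, with the norm detached from the gradient graph (per the PEFT reference and DoRA paper cited in the diff below). A minimal sketch, not the toolkit's actual module; names and sizes are illustrative:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_in, d_out, rank = 16, 32, 4

base = nn.Linear(d_in, d_out)                     # frozen original layer (W, bias)
lora_down = nn.Linear(d_in, rank, bias=False)     # "lora_A": small random init
lora_up = nn.Linear(rank, d_out, bias=False)      # "lora_B": zero init
nn.init.zeros_(lora_up.weight)

delta_w = lora_up.weight @ lora_down.weight                      # (d_out, d_in)
weight_norm = torch.linalg.norm(base.weight + delta_w, dim=1).detach()
magnitude = nn.Parameter(torch.linalg.norm(base.weight, dim=1))  # learned per-row magnitude

x = torch.randn(2, d_in)
dora_weight = magnitude.view(-1, 1) * (base.weight + delta_w) / weight_norm.view(-1, 1)
out = F.linear(x, dora_weight, base.bias)
```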
.gitignore (vendored): 3 lines changed
@@ -172,4 +172,5 @@ cython_debug/
 /output/*
 !/output/.gitkeep
 /extensions/*
 !/extensions/example
+/temp
@@ -32,16 +32,30 @@ class GenerateConfig:
         self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
         self.ext = kwargs.get('ext', 'png')
         self.prompt_file = kwargs.get('prompt_file', False)
+        self.prompts_in_file = self.prompts
         if self.prompts is None:
             raise ValueError("Prompts must be set")
         if isinstance(self.prompts, str):
             if os.path.exists(self.prompts):
                 with open(self.prompts, 'r', encoding='utf-8') as f:
-                    self.prompts = f.read().splitlines()
-                    self.prompts = [p.strip() for p in self.prompts if len(p.strip()) > 0]
+                    self.prompts_in_file = f.read().splitlines()
+                    self.prompts_in_file = [p.strip() for p in self.prompts_in_file if len(p.strip()) > 0]
             else:
                 raise ValueError("Prompts file does not exist, put in list if you want to use a list of prompts")
 
+        self.random_prompts = kwargs.get('random_prompts', False)
+        self.max_random_per_prompt = kwargs.get('max_random_per_prompt', 1)
+        self.max_images = kwargs.get('max_prompts', 10000)
+
+        if self.random_prompts:
+            self.prompts = []
+            for i in range(self.max_images):
+                num_prompts = random.randint(1, self.max_random_per_prompt)
+                prompt_list = [random.choice(self.prompts_in_file) for _ in range(num_prompts)]
+                self.prompts.append(", ".join(prompt_list))
+        else:
+            self.prompts = self.prompts_in_file
+
         if kwargs.get('shuffle', False):
             # shuffle the prompts
             random.shuffle(self.prompts)
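For reference, the new `random_prompts` path composes each image prompt by sampling and joining entries from the prompt file. A standalone sketch of that behavior (illustrative prompt list and counts, not the toolkit config object; note the image cap is read from a kwarg named `max_prompts`):

```python
import random

prompts_in_file = ["a red fox", "studio lighting", "35mm film", "snowy forest"]
max_images = 5               # read from the 'max_prompts' kwarg in the diff above
max_random_per_prompt = 3

prompts = []
for _ in range(max_images):
    num_prompts = random.randint(1, max_random_per_prompt)
    prompts.append(", ".join(random.choice(prompts_in_file) for _ in range(num_prompts)))

print(prompts)  # e.g. ['studio lighting, a red fox', '35mm film', ...]
```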
@@ -78,6 +92,9 @@ class GenerateProcess(BaseProcess):
         print("Loading model...")
         self.sd.load_model()
 
+        print("Compiling model...")
+        self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True)
+
         print(f"Generating {len(self.generate_config.prompts)} images")
         # build prompt image configs
         prompt_image_configs = []
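The UNet is now compiled before sampling. The call is the standard `torch.compile` API (PyTorch 2.x); a minimal sketch with an arbitrary module standing in for `self.sd.unet`:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
compiled = torch.compile(model, mode="reduce-overhead", fullgraph=True)

x = torch.randn(8, 64)
y = compiled(x)  # first call triggers compilation; later calls reuse the compiled graph
```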
@@ -629,8 +629,23 @@ class ClipImageFileItemDTOMixin:
         # Convert RGB to BGR
         open_cv_image = open_cv_image[:, :, ::-1].copy()
 
-        # apply augmentations
-        augmented = self.clip_image_aug_transform(image=open_cv_image)["image"]
+        if self.clip_vision_is_quad:
+            # image is in a 2x2 gris. split, run augs, and recombine
+            # split
+            img1, img2 = np.hsplit(open_cv_image, 2)
+            img1_1, img1_2 = np.vsplit(img1, 2)
+            img2_1, img2_2 = np.vsplit(img2, 2)
+            # apply augmentations
+            img1_1 = self.clip_image_aug_transform(image=img1_1)["image"]
+            img1_2 = self.clip_image_aug_transform(image=img1_2)["image"]
+            img2_1 = self.clip_image_aug_transform(image=img2_1)["image"]
+            img2_2 = self.clip_image_aug_transform(image=img2_2)["image"]
+            # recombine
+            augmented = np.vstack((np.hstack((img1_1, img1_2)), np.hstack((img2_1, img2_2))))
+
+        else:
+            # apply augmentations
+            augmented = self.clip_image_aug_transform(image=open_cv_image)["image"]
 
         # convert back to RGB tensor
         augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
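The quad path treats the clip image as a 2x2 grid and augments each quadrant independently. The split/recombine geometry is plain numpy; a standalone sketch with a dummy flip standing in for `clip_image_aug_transform`:

```python
import numpy as np

def fake_aug(image):
    # stand-in for clip_image_aug_transform: flip each quadrant vertically
    return {"image": image[::-1, :, :]}

open_cv_image = np.arange(8 * 8 * 3, dtype=np.uint8).reshape(8, 8, 3)

img1, img2 = np.hsplit(open_cv_image, 2)   # left half, right half
img1_1, img1_2 = np.vsplit(img1, 2)        # top-left, bottom-left
img2_1, img2_2 = np.vsplit(img2, 2)        # top-right, bottom-right

quads = [fake_aug(image=q)["image"] for q in (img1_1, img1_2, img2_1, img2_2)]
augmented = np.vstack((np.hstack((quads[0], quads[1])), np.hstack((quads[2], quads[3]))))
assert augmented.shape == open_cv_image.shape
```

Note that the recombine above (mirroring the diff) places the two left quadrants side by side on the top row, so the quadrant layout comes back transposed relative to the original image. The shape is unchanged, and this is harmless if the four tiles are independent crops, but it may be worth double-checking if the grid layout matters downstream.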
@@ -22,6 +22,13 @@ CONV_MODULES = [
     'LoRACompatibleConv'
 ]
 
+def transpose(weight, fan_in_fan_out):
+    if not fan_in_fan_out:
+        return weight
+
+    if isinstance(weight, torch.nn.Parameter):
+        return torch.nn.Parameter(weight.T)
+    return weight.T
 
 class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
     # def __init__(self, d_in, d_out, rank=4, weight=None, bias=None):
@@ -65,15 +72,26 @@ class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
         self.module_dropout = module_dropout
         self.is_checkpointing = False
 
-        # m = Magnitude column-wise across output dimension
-        self.magnitude = nn.Parameter(self.get_orig_weight().norm(p=2, dim=0, keepdim=True))
-
         d_out = org_module.out_features
         d_in = org_module.in_features
 
         std_dev = 1 / torch.sqrt(torch.tensor(self.lora_dim).float())
-        self.lora_up = nn.Parameter(torch.randn(d_out, self.lora_dim) * std_dev)
-        self.lora_down = nn.Parameter(torch.zeros(self.lora_dim, d_in))
+        # self.lora_up = nn.Parameter(torch.randn(d_out, self.lora_dim) * std_dev) # lora_A
+        # self.lora_down = nn.Parameter(torch.zeros(self.lora_dim, d_in)) # lora_B
+        self.lora_up = nn.Linear(self.lora_dim, d_out, bias=False)  # lora_B
+        # self.lora_up.weight.data = torch.randn_like(self.lora_up.weight.data) * std_dev
+        self.lora_up.weight.data = torch.zeros_like(self.lora_up.weight.data)
+        # self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
+        # self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=False)
+        self.lora_down = nn.Linear(d_in, self.lora_dim, bias=False)  # lora_A
+        # self.lora_down.weight.data = torch.zeros_like(self.lora_down.weight.data)
+        self.lora_down.weight.data = torch.randn_like(self.lora_down.weight.data) * std_dev
+
+        # m = Magnitude column-wise across output dimension
+        weight = self.get_orig_weight()
+        lora_weight = self.lora_up.weight @ self.lora_down.weight
+        weight_norm = self._get_weight_norm(weight, lora_weight)
+        self.magnitude = nn.Parameter(weight_norm.detach().clone(), requires_grad=True)
 
     def apply_to(self):
         self.org_forward = self.org_module[0].forward
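Because `lora_up` is zero-initialized, `lora_up.weight @ lora_down.weight` is all zeros at construction, so `magnitude` starts out exactly equal to the row-wise norm of the frozen weight and the adapted forward initially matches the original layer. A quick standalone check (illustrative sizes, a plain `nn.Linear` standing in for `org_module`):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_in, d_out, lora_dim = 16, 32, 4

org_module = nn.Linear(d_in, d_out)
std_dev = 1 / torch.sqrt(torch.tensor(lora_dim).float())

lora_up = nn.Linear(lora_dim, d_out, bias=False)    # lora_B, zero init
lora_up.weight.data = torch.zeros_like(lora_up.weight.data)
lora_down = nn.Linear(d_in, lora_dim, bias=False)   # lora_A, small random init
lora_down.weight.data = torch.randn_like(lora_down.weight.data) * std_dev

lora_weight = lora_up.weight @ lora_down.weight     # (d_out, d_in), all zeros at init
weight_norm = torch.linalg.norm(org_module.weight + lora_weight, dim=1)
magnitude = nn.Parameter(weight_norm.detach().clone())

assert torch.allclose(magnitude, torch.linalg.norm(org_module.weight, dim=1))
```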
@@ -88,11 +106,33 @@ class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
             return self.org_module[0].bias.data.detach()
         return None
 
-    def dora_forward(self, x, *args, **kwargs):
-        lora = torch.matmul(self.lora_up, self.lora_down)
-        adapted = self.get_orig_weight() + lora
-        column_norm = adapted.norm(p=2, dim=0, keepdim=True)
-        norm_adapted = adapted / column_norm
-        calc_weights = self.magnitude * norm_adapted
-        return F.linear(x, calc_weights, self.get_orig_bias())
+    # def dora_forward(self, x, *args, **kwargs):
+    # lora = torch.matmul(self.lora_A, self.lora_B)
+    # adapted = self.get_orig_weight() + lora
+    # column_norm = adapted.norm(p=2, dim=0, keepdim=True)
+    # norm_adapted = adapted / column_norm
+    # calc_weights = self.magnitude * norm_adapted
+    # return F.linear(x, calc_weights, self.get_orig_bias())
 
+    def _get_weight_norm(self, weight, scaled_lora_weight) -> torch.Tensor:
+        # calculate L2 norm of weight matrix, column-wise
+        weight = weight + scaled_lora_weight.to(weight.device)
+        weight_norm = torch.linalg.norm(weight, dim=1)
+        return weight_norm
+
+    def apply_dora(self, x, scaled_lora_weight):
+        # ref https://github.com/huggingface/peft/blob/1e6d1d73a0850223b0916052fd8d2382a90eae5a/src/peft/tuners/lora/layer.py#L192
+        # lora weight is already scaled
+
+        # magnitude = self.lora_magnitude_vector[active_adapter]
+        weight = self.get_orig_weight()
+        weight_norm = self._get_weight_norm(weight, scaled_lora_weight)
+        # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353)
+        # "[...] we suggest treating ||V +∆V ||_c in
+        # Eq. (5) as a constant, thereby detaching it from the gradient
+        # graph. This means that while ||V + ∆V ||_c dynamically
+        # reflects the updates of ∆V , it won’t receive any gradient
+        # during backpropagation"
+        weight_norm = weight_norm.detach()
+        dora_weight = transpose(weight + scaled_lora_weight, False)
+        return (self.magnitude / weight_norm - 1).view(1, -1) * F.linear(x, dora_weight)
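A useful sanity check on `apply_dora`: added on top of the ordinary base + LoRA path, the `(magnitude / weight_norm - 1)` correction reproduces the weight recomposition `magnitude * (W + dW)` divided by the per-row norm from `_get_weight_norm`, ignoring dropout. A standalone verification with random tensors, not the toolkit module:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_in, d_out = 16, 32
x = torch.randn(2, d_in)

weight = torch.randn(d_out, d_in)                      # frozen W
bias = torch.randn(d_out)
scaled_lora_weight = 0.01 * torch.randn(d_out, d_in)   # already-scaled delta W

weight_norm = torch.linalg.norm(weight + scaled_lora_weight, dim=1).detach()
magnitude = torch.linalg.norm(weight, dim=1)           # stands in for the learned magnitude

# mixin path: base forward + lora output + apply_dora correction term
base_out = F.linear(x, weight, bias)
lora_out = F.linear(x, scaled_lora_weight)
dora_correction = (magnitude / weight_norm - 1).view(1, -1) * F.linear(x, weight + scaled_lora_weight)
combined = base_out + lora_out + dora_correction

# closed form: x @ (magnitude * (W + dW) / row_norm).T + b
closed = F.linear(x, (magnitude / weight_norm).view(-1, 1) * (weight + scaled_lora_weight), bias)
assert torch.allclose(combined, closed, atol=1e-5)
```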
@@ -52,8 +52,14 @@ def broadcast_and_multiply(tensor, multiplier):
     for _ in range(num_extra_dims):
         multiplier = multiplier.unsqueeze(-1)
 
-    # Multiplying the broadcasted tensor with the output tensor
-    result = tensor * multiplier
+    try:
+        # Multiplying the broadcasted tensor with the output tensor
+        result = tensor * multiplier
+    except RuntimeError as e:
+        print(e)
+        print(tensor.size())
+        print(multiplier.size())
+        raise e
 
     return result
 
@@ -248,9 +254,9 @@ class ToolkitModuleMixin:
             # network is not active, avoid doing anything
             return self.org_forward(x, *args, **kwargs)
 
-        if self.__class__.__name__ == "DoRAModule":
-            # return dora forward
-            return self.dora_forward(x, *args, **kwargs)
+        # if self.__class__.__name__ == "DoRAModule":
+        #     # return dora forward
+        #     return self.dora_forward(x, *args, **kwargs)
 
         org_forwarded = self.org_forward(x, *args, **kwargs)
         lora_output = self._call_forward(x)
@@ -263,7 +269,27 @@ class ToolkitModuleMixin:
             # todo check if this is correct, do we just concat when doing cfg?
             multiplier = multiplier.repeat_interleave(num_interleaves)
 
-        x = org_forwarded + broadcast_and_multiply(lora_output, multiplier)
+        scaled_lora_output = broadcast_and_multiply(lora_output, multiplier)
+
+        if self.__class__.__name__ == "DoRAModule":
+            # ref https://github.com/huggingface/peft/blob/1e6d1d73a0850223b0916052fd8d2382a90eae5a/src/peft/tuners/lora/layer.py#L417
+            # x = dropout(x)
+            # todo this wont match the dropout applied to the lora
+            if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity):
+                lx = self.dropout(x)
+            # normal dropout
+            elif self.dropout is not None and self.training:
+                lx = torch.nn.functional.dropout(x, p=self.dropout)
+            else:
+                lx = x
+            lora_weight = self.lora_up.weight @ self.lora_down.weight
+            # scale it here
+            # todo handle our batch split scalers for slider training. For now take the mean of them
+            scale = multiplier.mean()
+            scaled_lora_weight = lora_weight * scale
+            scaled_lora_output = scaled_lora_output + self.apply_dora(lx, scaled_lora_weight)
+
+        x = org_forwarded + scaled_lora_output
         return x
 
     def enable_gradient_checkpointing(self: Module):
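One behavior worth flagging in the DoRA branch: the per-sample multiplier that `broadcast_and_multiply` applies element-wise is collapsed to a single scalar with `.mean()` before it scales the weight delta (see the TODO above about batch-split scalers for slider training). With the usual +1/-1 slider pairs that mean is zero, so the scaled LoRA weight entering `apply_dora` vanishes in that case. A small illustration with made-up values:

```python
import torch

multiplier = torch.tensor([1.0, -1.0, 1.0, -1.0])   # e.g. slider-training batch split
lora_output = torch.randn(4, 8)

per_sample = lora_output * multiplier.view(-1, 1)    # roughly what broadcast_and_multiply does
scale = multiplier.mean()                            # what the DoRA branch uses

print(per_sample.shape, scale)                       # torch.Size([4, 8]) tensor(0.)
```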
@@ -413,12 +439,12 @@ class ToolkitNetworkMixin:
         new_keymap = {}
         for ldm_key, diffusers_key in keymap.items():
             ldm_key = ldm_key.replace('.alpha', '.magnitude')
-            ldm_key = ldm_key.replace('.lora_down.weight', '.lora_down')
-            ldm_key = ldm_key.replace('.lora_up.weight', '.lora_up')
+            # ldm_key = ldm_key.replace('.lora_down.weight', '.lora_down')
+            # ldm_key = ldm_key.replace('.lora_up.weight', '.lora_up')
 
             diffusers_key = diffusers_key.replace('.alpha', '.magnitude')
-            diffusers_key = diffusers_key.replace('.lora_down.weight', '.lora_down')
-            diffusers_key = diffusers_key.replace('.lora_up.weight', '.lora_up')
+            # diffusers_key = diffusers_key.replace('.lora_down.weight', '.lora_down')
+            # diffusers_key = diffusers_key.replace('.lora_up.weight', '.lora_up')
 
             new_keymap[ldm_key] = diffusers_key
 
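So for DoRA state dicts only the alpha entry is renamed now; `lora_down` / `lora_up` keep their `.weight` suffix since they are `nn.Linear` modules after this commit. A tiny before/after with made-up example keys:

```python
keymap = {
    "lora_unet_some_layer.alpha": "unet.some_layer.alpha",
    "lora_unet_some_layer.lora_down.weight": "unet.some_layer.lora_down.weight",
}

new_keymap = {}
for ldm_key, diffusers_key in keymap.items():
    ldm_key = ldm_key.replace('.alpha', '.magnitude')
    diffusers_key = diffusers_key.replace('.alpha', '.magnitude')
    new_keymap[ldm_key] = diffusers_key

print(new_keymap)
# {'lora_unet_some_layer.magnitude': 'unet.some_layer.magnitude',
#  'lora_unet_some_layer.lora_down.weight': 'unet.some_layer.lora_down.weight'}
```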
@@ -513,12 +539,8 @@ class ToolkitNetworkMixin:
         multiplier = self._multiplier
         # get first module
         first_module = self.get_all_modules()[0]
-        if self.network_type.lower() == 'dora':
-            device = first_module.lora_down.device
-            dtype = first_module.lora_down.dtype
-        else:
-            device = first_module.lora_down.weight.device
-            dtype = first_module.lora_down.weight.dtype
+        device = first_module.lora_down.weight.device
+        dtype = first_module.lora_down.weight.dtype
         with torch.no_grad():
             tensor_multiplier = None
             if isinstance(multiplier, int) or isinstance(multiplier, float):