Merge branch 'main' into wan21

Jaret Burkett
2025-03-04 00:31:57 -07:00
11 changed files with 603 additions and 132 deletions

File diff suppressed because one or more lines are too long

scripts/update_sponsors.py (new file, 309 lines)

@@ -0,0 +1,309 @@
import os
import requests
import json
from datetime import datetime
from dotenv import load_dotenv
# Load environment variables from .env file
env_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), ".env")
load_dotenv(dotenv_path=env_path)
# API credentials
PATREON_TOKEN = os.getenv("PATREON_ACCESS_TOKEN")
GITHUB_TOKEN = os.getenv("GITHUB_TOKEN")
GITHUB_USERNAME = os.getenv("GITHUB_USERNAME")
GITHUB_ORG = os.getenv("GITHUB_ORG") # Organization name (optional)
# Output file
README_PATH = "SUPPORTERS.md"
def fetch_patreon_supporters():
"""Fetch current Patreon supporters"""
print("Fetching Patreon supporters...")
headers = {
"Authorization": f"Bearer {PATREON_TOKEN}",
"Content-Type": "application/json"
}
url = "https://www.patreon.com/api/oauth2/v2/campaigns"
try:
# First get the campaign ID
campaign_response = requests.get(url, headers=headers)
campaign_response.raise_for_status()
campaign_data = campaign_response.json()
if not campaign_data.get('data'):
print("No campaigns found for this Patreon account")
return []
campaign_id = campaign_data['data'][0]['id']
# Now get the supporters for this campaign
members_url = f"https://www.patreon.com/api/oauth2/v2/campaigns/{campaign_id}/members"
params = {
"include": "user",
"fields[member]": "full_name,is_follower,patron_status", # Removed profile_url
"fields[user]": "image_url"
}
supporters = []
while members_url:
members_response = requests.get(members_url, headers=headers, params=params)
members_response.raise_for_status()
members_data = members_response.json()
# Process the response to extract active patrons
for member in members_data.get('data', []):
attributes = member.get('attributes', {})
# Only include active patrons
if attributes.get('patron_status') == 'active_patron':
name = attributes.get('full_name', 'Anonymous Supporter')
# Get user data which contains the profile image
user_id = member.get('relationships', {}).get('user', {}).get('data', {}).get('id')
profile_image = None
profile_url = None # Removed profile_url since it's not supported
if user_id:
for included in members_data.get('included', []):
if included.get('id') == user_id and included.get('type') == 'user':
profile_image = included.get('attributes', {}).get('image_url')
break
supporters.append({
'name': name,
'profile_image': profile_image,
'profile_url': profile_url, # This will be None
'platform': 'Patreon',
'amount': 0 # Placeholder, as Patreon API doesn't provide this in the current response
})
# Handle pagination
members_url = members_data.get('links', {}).get('next')
print(f"Found {len(supporters)} active Patreon supporters")
return supporters
except requests.exceptions.RequestException as e:
print(f"Error fetching Patreon data: {e}")
print(f"Response content: {e.response.content if hasattr(e, 'response') else 'No response content'}")
return []
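For reference, a hedged sketch of one Patreon members page as consumed by the loop above (field names come from the request parameters and the parsing code; all values are placeholders):

# Hedged sketch of a Patreon members page (placeholder values).
members_page = {
    "data": [
        {
            "attributes": {"full_name": "Jane Doe", "patron_status": "active_patron"},
            "relationships": {"user": {"data": {"id": "12345"}}},
        }
    ],
    "included": [
        {"id": "12345", "type": "user", "attributes": {"image_url": "https://example.com/avatar.jpg"}},
    ],
    "links": {"next": None},  # a missing/None next link ends the pagination loop
}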
def fetch_github_sponsors():
"""Fetch current GitHub sponsors for a user or organization"""
print("Fetching GitHub sponsors...")
headers = {
"Authorization": f"Bearer {GITHUB_TOKEN}",
"Accept": "application/vnd.github.v3+json"
}
# Determine if we're fetching for a user or an organization
entity_type = "organization" if GITHUB_ORG else "user"
entity_name = GITHUB_ORG if GITHUB_ORG else GITHUB_USERNAME
if not entity_name:
print("Error: Neither GITHUB_USERNAME nor GITHUB_ORG is set")
return []
# Different GraphQL query structure based on entity type
if entity_type == "user":
query = """
query {
user(login: "%s") {
sponsorshipsAsMaintainer(first: 100) {
nodes {
sponsorEntity {
... on User {
login
name
avatarUrl
url
}
... on Organization {
login
name
avatarUrl
url
}
}
tier {
monthlyPriceInDollars
}
isOneTimePayment
isActive
}
}
}
}
""" % entity_name
else: # organization
query = """
query {
organization(login: "%s") {
sponsorshipsAsMaintainer(first: 100) {
nodes {
sponsorEntity {
... on User {
login
name
avatarUrl
url
}
... on Organization {
login
name
avatarUrl
url
}
}
tier {
monthlyPriceInDollars
}
isOneTimePayment
isActive
}
}
}
}
""" % entity_name
try:
response = requests.post(
"https://api.github.com/graphql",
headers=headers,
json={"query": query}
)
response.raise_for_status()
data = response.json()
# Process the response - the path to the data differs based on entity type
if entity_type == "user":
sponsors_data = data.get('data', {}).get('user', {}).get('sponsorshipsAsMaintainer', {}).get('nodes', [])
else:
sponsors_data = data.get('data', {}).get('organization', {}).get('sponsorshipsAsMaintainer', {}).get('nodes', [])
sponsors = []
for sponsor in sponsors_data:
# Only include active sponsors
if sponsor.get('isActive'):
entity = sponsor.get('sponsorEntity', {})
name = entity.get('name') or entity.get('login', 'Anonymous Sponsor')
profile_image = entity.get('avatarUrl')
profile_url = entity.get('url')
amount = sponsor.get('tier', {}).get('monthlyPriceInDollars', 0)
sponsors.append({
'name': name,
'profile_image': profile_image,
'profile_url': profile_url,
'platform': 'GitHub Sponsors',
'amount': amount
})
print(f"Found {len(sponsors)} active GitHub sponsors for {entity_type} '{entity_name}'")
return sponsors
except requests.exceptions.RequestException as e:
print(f"Error fetching GitHub sponsors data: {e}")
return []
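To make the parsing above concrete, a hedged sketch of the response shape for a user query (the structure mirrors the GraphQL query above; all values are placeholders):

# Hedged sketch of the GraphQL response consumed above (placeholder values).
example_response = {
    "data": {
        "user": {
            "sponsorshipsAsMaintainer": {
                "nodes": [
                    {
                        "sponsorEntity": {
                            "login": "octocat",
                            "name": "The Octocat",
                            "avatarUrl": "https://example.com/avatar.png",
                            "url": "https://github.com/octocat",
                        },
                        "tier": {"monthlyPriceInDollars": 5},
                        "isOneTimePayment": False,
                        "isActive": True,  # only active sponsorships are kept
                    }
                ]
            }
        }
    }
}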
def generate_readme(supporters):
"""Generate a README.md file with supporter information"""
print(f"Generating {README_PATH}...")
# Sort supporters by amount (descending) and then by name
supporters.sort(key=lambda x: (-x['amount'], x['name'].lower()))
# Determine the proper footer links based on what's configured
github_entity = GITHUB_ORG if GITHUB_ORG else GITHUB_USERNAME
github_sponsor_url = f"https://github.com/sponsors/{github_entity}"
with open(README_PATH, "w", encoding="utf-8") as f:
f.write("## Support My Work\n\n")
f.write("If you enjoy my work, or use it for commercial purposes, please consider sponsoring me so I can continue to maintain it. Every bit helps! \n\n")
# Create appropriate call-to-action based on what's configured
cta_parts = []
if github_entity:
cta_parts.append(f"[Become a sponsor on GitHub]({github_sponsor_url})")
if PATREON_TOKEN:
cta_parts.append("[support me on Patreon](https://www.patreon.com/ostris)")
if cta_parts:
f.write(f"{' or '.join(cta_parts)}.\n\n")
f.write("Thank you to all my current supporters!\n\n")
f.write(f"_Last updated: {datetime.now().strftime('%Y-%m-%d')}_\n\n")
# Write GitHub Sponsors section
github_sponsors = [s for s in supporters if s['platform'] == 'GitHub Sponsors']
if github_sponsors:
f.write("### GitHub Sponsors\n\n")
for sponsor in github_sponsors:
if sponsor['profile_image']:
f.write(f"<a href=\"{sponsor['profile_url']}\" title=\"{sponsor['name']}\"><img src=\"{sponsor['profile_image']}\" width=\"50\" height=\"50\" alt=\"{sponsor['name']}\" style=\"border-radius:50%\"></a> ")
else:
f.write(f"[{sponsor['name']}]({sponsor['profile_url']}) ")
f.write("\n\n")
# Write Patreon section
patreon_supporters = [s for s in supporters if s['platform'] == 'Patreon']
if patreon_supporters:
f.write("### Patreon Supporters\n\n")
for supporter in patreon_supporters:
if supporter['profile_image']:
f.write(f"<a href=\"{supporter['profile_url']}\" title=\"{supporter['name']}\"><img src=\"{supporter['profile_image']}\" width=\"50\" height=\"50\" alt=\"{supporter['name']}\" style=\"border-radius:50%\"></a> ")
elif supporter['profile_url']:
f.write(f"[{supporter['name']}]({supporter['profile_url']}) ")
else:
f.write(f"{supporter['name']} ")
f.write("\n\n")
f.write("\n---\n\n")
print(f"Successfully generated {README_PATH} with {len(supporters)} supporters!")
def main():
"""Main function"""
print("Starting supporter data collection...")
# Check if required environment variables are set
missing_vars = []
if not GITHUB_TOKEN:
missing_vars.append("GITHUB_TOKEN")
# Either username or org is required for GitHub
if not GITHUB_USERNAME and not GITHUB_ORG:
missing_vars.append("GITHUB_USERNAME or GITHUB_ORG")
# Patreon token is optional but warn if missing
patreon_enabled = bool(PATREON_TOKEN)
if missing_vars:
print(f"Error: Missing required environment variables: {', '.join(missing_vars)}")
print("Please add them to your .env file")
return
if not patreon_enabled:
print("Warning: PATREON_ACCESS_TOKEN not set. Will only fetch GitHub sponsors.")
# Fetch data from both platforms
patreon_supporters = fetch_patreon_supporters() if PATREON_TOKEN else []
github_sponsors = fetch_github_sponsors()
# Combine supporters from both platforms
all_supporters = patreon_supporters + github_sponsors
if not all_supporters:
print("No supporters found on either platform")
return
# Generate README
generate_readme(all_supporters)
if __name__ == "__main__":
main()
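For reference, a minimal .env sketch for running this script (variable names come from the code above; values are placeholders and the token-scope note is an assumption):

# .env in the repository root (one level above scripts/, per env_path above)
GITHUB_TOKEN=ghp_placeholder            # needs access to read sponsorships
GITHUB_USERNAME=your-github-username
# GITHUB_ORG=your-org                   # optional; takes precedence over GITHUB_USERNAME
# PATREON_ACCESS_TOKEN=patreon_token    # optional; omit to skip Patreon

With that in place, running python scripts/update_sponsors.py regenerates SUPPORTERS.md.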


@@ -1,3 +0,0 @@
- only do ema on main device? shouldn't be needed other than saving and sampling
- check when to unwrap model and what it does
- disable timer for non main local


@@ -135,6 +135,15 @@ class NetworkConfig:
self.conv = 4
self.transformer_only = kwargs.get('transformer_only', True)
self.lokr_full_rank = kwargs.get('lokr_full_rank', False)
if self.lokr_full_rank and self.type.lower() == 'lokr':
self.linear = 9999999999
self.linear_alpha = 9999999999
self.conv = 9999999999
self.conv_alpha = 9999999999
# -1 automatically finds the largest factor
self.lokr_factor = kwargs.get('lokr_factor', -1)
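For context, a hedged sketch of the network kwargs this block consumes (key names come from the kwargs.get calls above; how the dict reaches NetworkConfig is an assumption):

# Assumed usage; NetworkConfig reads these keys via kwargs.get(...) above.
network_kwargs = {
    "type": "lokr",
    "lokr_full_rank": True,  # replaces linear/conv ranks with the huge sentinel above
    "lokr_factor": -1,       # -1 lets factorization() pick the largest factor
}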
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net']


@@ -231,12 +231,18 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
if self.network_type.lower() == "dora":
self.module_class = DoRAModule
module_class = DoRAModule
elif self.network_type.lower() == "lokr":
self.module_class = LokrModule
module_class = LokrModule
self.network_config: NetworkConfig = kwargs.get("network_config", None)
self.peft_format = peft_format
# always use the peft format for flux, sd3, and lumina2 (for now)
if self.is_flux or self.is_v3 or self.is_lumina2:
self.peft_format = True
# don't use the peft format for lokr
if self.network_type.lower() != "lokr":
self.peft_format = True
if self.peft_format:
# no alpha for peft
@@ -338,8 +344,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
if (is_linear or is_conv2d) and not skip:
if self.only_if_contains is not None and not any([word in clean_name for word in self.only_if_contains]):
continue
if self.only_if_contains is not None:
if not any([word in clean_name for word in self.only_if_contains]) and not any([word in lora_name for word in self.only_if_contains]):
continue
dim = None
alpha = None
@@ -373,6 +380,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
self.conv_lora_dim is not None or conv_block_dims is not None):
skipped.append(lora_name)
continue
module_kwargs = {}
if self.network_type.lower() == "lokr":
module_kwargs["factor"] = self.network_config.lokr_factor
lora = module_class(
lora_name,
@@ -386,10 +398,16 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
network=self,
parent=module,
use_bias=use_bias,
**module_kwargs
)
loras.append(lora)
lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)]
if self.network_type.lower() == "lokr":
try:
lora_shape_dict[lora_name] = [list(lora.lokr_w1.weight.shape), list(lora.lokr_w2.weight.shape)]
except Exception:
pass
else:
lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)]
return loras, skipped
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]


@@ -10,24 +10,23 @@ from toolkit.network_mixins import ToolkitModuleMixin
from typing import TYPE_CHECKING, Union, List
from optimum.quanto import QBytesTensor, QTensor
if TYPE_CHECKING:
from toolkit.lora_special import LoRASpecialNetwork
# 4, build custom backward function
# -
def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
'''
return a tuple of two values of the input dimension decomposed by the number closest to factor;
the second value is greater than or equal to the first value.
In LoRA with Kronecker Product, the first value is a value for weight scale.
The second value is a value for weight.
Because of the non-commutative property, A⊗B ≠ B⊗A. The meaning of the two matrices is slightly different.
examples)
factor
-1 2 4 8 16 ...
@@ -38,7 +37,7 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
512 -> 32, 16 512 -> 256, 2 512 -> 128, 4 512 -> 64, 8 512 -> 32, 16
1024 -> 32, 32 1024 -> 512, 2 1024 -> 256, 4 1024 -> 128, 8 1024 -> 64, 16
'''
if factor > 0 and (dimension % factor) == 0:
m = factor
n = dimension // factor
@@ -47,12 +46,12 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
factor = dimension
m, n = 1, dimension
length = m + n
while m<n:
while m < n:
new_m = m + 1
while dimension%new_m != 0:
while dimension % new_m != 0:
new_m += 1
new_n = dimension // new_m
if new_m + new_n > length or new_m>factor:
if new_m + new_n > length or new_m > factor:
break
else:
m, n = new_m, new_n
@@ -62,7 +61,8 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
def make_weight_cp(t, wa, wb):
rebuild2 = torch.einsum('i j k l, i p, j r -> p r k l', t, wa, wb) # [c, d, k1, k2]
rebuild2 = torch.einsum('i j k l, i p, j r -> p r k l',
t, wa, wb) # [c, d, k1, k2]
return rebuild2
@@ -71,31 +71,25 @@ def make_kron(w1, w2, scale):
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
rebuild = torch.kron(w1, w2)
return rebuild*scale
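As a quick illustration of what make_kron produces for a Linear layer, a minimal standalone sketch (shapes are arbitrary; this is not toolkit code):

import torch

# kron of an (a, b) factor and a (c, d) factor yields an (a*c, b*d) weight,
# which is why factorization() splits the in/out dimensions into pairs.
w1 = torch.randn(4, 8)    # small "scale" factor
w2 = torch.randn(16, 32)  # large "weight" factor
rebuild = torch.kron(w1, w2) * 0.5
assert rebuild.shape == (4 * 16, 8 * 32)  # (64, 256)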
class LokrModule(ToolkitModuleMixin, nn.Module):
"""
modified from kohya-ss/sd-scripts/networks/lora:LoRAModule
and from KohakuBlueleaf/LyCORIS/lycoris:loha:LoHaModule
and from KohakuBlueleaf/LyCORIS/lycoris:locon:LoconModule
"""
def __init__(
self,
lora_name,
org_module: nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=0.,
rank_dropout=0.,
self,
lora_name,
org_module: nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=0.,
rank_dropout=0.,
module_dropout=0.,
use_cp=False,
decompose_both = False,
decompose_both=False,
network: 'LoRASpecialNetwork' = None,
factor:int=-1, # factorization factor
factor: int = -1, # factorization factor
**kwargs,
):
""" if alpha == 0 or None, alpha is rank (no scaling). """
@@ -107,38 +101,49 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
self.cp = False
self.use_w1 = False
self.use_w2 = False
self.can_merge_in = True
self.shape = org_module.weight.shape
if org_module.__class__.__name__ == 'Conv2d':
in_dim = org_module.in_channels
k_size = org_module.kernel_size
out_dim = org_module.out_channels
in_m, in_n = factorization(in_dim, factor)
out_l, out_k = factorization(out_dim, factor)
shape = ((out_l, out_k), (in_m, in_n), *k_size) # ((a, b), (c, d), *k_size)
self.cp = use_cp and k_size!=(1, 1)
# ((a, b), (c, d), *k_size)
shape = ((out_l, out_k), (in_m, in_n), *k_size)
self.cp = use_cp and k_size != (1, 1)
if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2:
self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
self.lokr_w1_a = nn.Parameter(
torch.empty(shape[0][0], lora_dim))
self.lokr_w1_b = nn.Parameter(
torch.empty(lora_dim, shape[1][0]))
else:
self.use_w1 = True
self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0])) # a*c, 1-mode
self.lokr_w1 = nn.Parameter(torch.empty(
shape[0][0], shape[1][0])) # a*c, 1-mode
if lora_dim >= max(shape[0][1], shape[1][1])/2:
self.use_w2 = True
self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1], *k_size))
self.lokr_w2 = nn.Parameter(torch.empty(
shape[0][1], shape[1][1], *k_size))
elif self.cp:
self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, shape[2], shape[3]))
self.lokr_w2_a = nn.Parameter(torch.empty(lora_dim, shape[0][1])) # b, 1-mode
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1])) # d, 2-mode
else: # Conv2d not cp
self.lokr_t2 = nn.Parameter(torch.empty(
lora_dim, lora_dim, shape[2], shape[3]))
self.lokr_w2_a = nn.Parameter(
torch.empty(lora_dim, shape[0][1])) # b, 1-mode
self.lokr_w2_b = nn.Parameter(
torch.empty(lora_dim, shape[1][1])) # d, 2-mode
else: # Conv2d not cp
# bigger part. weight and LoRA. [b, dim] x [dim, d*k1*k2]
self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]*shape[2]*shape[3]))
self.lokr_w2_a = nn.Parameter(
torch.empty(shape[0][1], lora_dim))
self.lokr_w2_b = nn.Parameter(torch.empty(
lora_dim, shape[1][1]*shape[2]*shape[3]))
# w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d*k1*k2)) = (a, b)⊗(c, d*k1*k2) = (ac, bd*k1*k2)
self.op = F.conv2d
self.extra_args = {
"stride": org_module.stride,
@@ -147,48 +152,55 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
"groups": org_module.groups
}
else: # Linear
else: # Linear
in_dim = org_module.in_features
out_dim = org_module.out_features
in_m, in_n = factorization(in_dim, factor)
out_l, out_k = factorization(out_dim, factor)
shape = ((out_l, out_k), (in_m, in_n)) # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d
# ((a, b), (c, d)), out_dim = a*c, in_dim = b*d
shape = ((out_l, out_k), (in_m, in_n))
# smaller part. weight scale
if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2:
self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
self.lokr_w1_a = nn.Parameter(
torch.empty(shape[0][0], lora_dim))
self.lokr_w1_b = nn.Parameter(
torch.empty(lora_dim, shape[1][0]))
else:
self.use_w1 = True
self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0])) # a*c, 1-mode
self.lokr_w1 = nn.Parameter(torch.empty(
shape[0][0], shape[1][0])) # a*c, 1-mode
if lora_dim < max(shape[0][1], shape[1][1])/2:
# bigger part. weight and LoRA. [b, dim] x [dim, d]
self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]))
self.lokr_w2_a = nn.Parameter(
torch.empty(shape[0][1], lora_dim))
self.lokr_w2_b = nn.Parameter(
torch.empty(lora_dim, shape[1][1]))
# w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d)) = (a, b)⊗(c, d) = (ac, bd)
else:
self.use_w2 = True
self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1]))
self.lokr_w2 = nn.Parameter(
torch.empty(shape[0][1], shape[1][1]))
self.op = F.linear
self.extra_args = {}
self.dropout = dropout
if dropout:
print("[WARN]LoHa/LoKr haven't implemented normal dropout yet.")
print("[WARN]LoKr haven't implemented normal dropout yet.")
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
if isinstance(alpha, torch.Tensor):
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = lora_dim if alpha is None or alpha == 0 else alpha
if self.use_w2 and self.use_w1:
#use scale = 1
# use scale = 1
alpha = lora_dim
self.scale = alpha / self.lora_dim
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
self.register_buffer('alpha', torch.tensor(alpha)) # treat as constant
if self.use_w2:
torch.nn.init.constant_(self.lokr_w2, 0)
@@ -197,7 +209,7 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5))
torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5))
torch.nn.init.constant_(self.lokr_w2_b, 0)
if self.use_w1:
torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5))
else:
@@ -208,8 +220,8 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
self.org_module = [org_module]
weight = make_kron(
self.lokr_w1 if self.use_w1 else self.lokr_w1_a@self.lokr_w1_b,
(self.lokr_w2 if self.use_w2
else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
(self.lokr_w2 if self.use_w2
else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
else self.lokr_w2_a@self.lokr_w2_b),
torch.tensor(self.multiplier * self.scale)
)
@@ -219,12 +231,12 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
def apply_to(self):
self.org_forward = self.org_module[0].forward
self.org_module[0].forward = self.forward
def get_weight(self, orig_weight = None):
def get_weight(self, orig_weight=None):
weight = make_kron(
self.lokr_w1 if self.use_w1 else self.lokr_w1_a@self.lokr_w1_b,
(self.lokr_w2 if self.use_w2
else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
(self.lokr_w2 if self.use_w2
else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
else self.lokr_w2_a@self.lokr_w2_b),
torch.tensor(self.scale)
)
@@ -232,51 +244,88 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
weight = weight.reshape(orig_weight.shape)
if self.training and self.rank_dropout:
drop = torch.rand(weight.size(0)) < self.rank_dropout
weight *= drop.view(-1, [1]*len(weight.shape[1:])).to(weight.device)
weight *= drop.view(-1, *([1] * len(weight.shape[1:]))).to(weight.device)
return weight
@torch.no_grad()
def apply_max_norm(self, max_norm, device=None):
orig_norm = self.get_weight().norm()
norm = torch.clamp(orig_norm, max_norm/2)
desired = torch.clamp(norm, max=max_norm)
ratio = desired.cpu()/norm.cpu()
scaled = ratio.item() != 1.0
if scaled:
modules = (4 - self.use_w1 - self.use_w2 + (not self.use_w2 and self.cp))
if self.use_w1:
self.lokr_w1 *= ratio**(1/modules)
else:
self.lokr_w1_a *= ratio**(1/modules)
self.lokr_w1_b *= ratio**(1/modules)
if self.use_w2:
self.lokr_w2 *= ratio**(1/modules)
else:
if self.cp:
self.lokr_t2 *= ratio**(1/modules)
self.lokr_w2_a *= ratio**(1/modules)
self.lokr_w2_b *= ratio**(1/modules)
return scaled, orig_norm*ratio
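The ratio**(1/modules) spreading above works because the Kronecker product is bilinear: scaling each of the `modules` factor matrices by the modules-th root of ratio scales the rebuilt weight by ratio overall. A minimal standalone check:

import torch

w1, w2 = torch.randn(4, 8), torch.randn(16, 32)
ratio, modules = 0.5, 2  # e.g. use_w1 and use_w2 -> two factor matrices
scaled = torch.kron(w1 * ratio ** (1 / modules), w2 * ratio ** (1 / modules))
assert torch.allclose(scaled, torch.kron(w1, w2) * ratio)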
def merge_in(self, merge_weight=1.0):
if not self.can_merge_in:
return
def forward(self, x):
if self.module_dropout and self.training:
if torch.rand(1) < self.module_dropout:
return self.op(
x,
self.org_module[0].weight.data,
None if self.org_module[0].bias is None else self.org_module[0].bias.data
)
weight = (
self.org_module[0].weight.data
+ self.get_weight(self.org_module[0].weight.data) * self.multiplier
# extract weight from org_module
org_sd = self.org_module[0].state_dict()
# todo find a way to merge in weights when doing quantized model
if 'weight._data' in org_sd:
# quantized weight
return
weight_key = "weight"
if 'weight._data' in org_sd:
# quantized weight
weight_key = "weight._data"
orig_dtype = org_sd[weight_key].dtype
weight = org_sd[weight_key].float()
scale = self.scale
# handle trainable scaler method locon does
if hasattr(self, 'scalar'):
scale = scale * self.scalar
lokr_weight = self.get_weight(weight)
merged_weight = (
weight
+ (lokr_weight * merge_weight).to(weight.device, dtype=weight.dtype)
)
bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
return self.op(
x,
# set weight to org_module
org_sd[weight_key] = merged_weight.to(orig_dtype)
self.org_module[0].load_state_dict(org_sd)
def get_orig_weight(self):
weight = self.org_module[0].weight
if isinstance(weight, QTensor) or isinstance(weight, QBytesTensor):
return weight.dequantize().data.detach()
else:
return weight.data.detach()
def get_orig_bias(self):
if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None:
if isinstance(self.org_module[0].bias, QTensor) or isinstance(self.org_module[0].bias, QBytesTensor):
return self.org_module[0].bias.dequantize().data.detach()
else:
return self.org_module[0].bias.data.detach()
return None
def _call_forward(self, x):
if isinstance(x, QTensor) or isinstance(x, QBytesTensor):
x = x.dequantize()
orig_dtype = x.dtype
orig_weight = self.get_orig_weight()
lokr_weight = self.get_weight(orig_weight).to(dtype=orig_weight.dtype)
multiplier = self.network_ref().torch_multiplier
if x.dtype != orig_weight.dtype:
x = x.to(dtype=orig_weight.dtype)
# we do not currently support split batch multipliers for lokr. Just do a mean
multiplier = torch.mean(multiplier)
weight = (
orig_weight
+ lokr_weight * multiplier
)
bias = self.get_orig_bias()
if bias is not None:
bias = bias.to(weight.device, dtype=weight.dtype)
output = self.op(
x,
weight.view(self.shape),
bias,
**self.extra_args
)
)
return output.to(orig_dtype)
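Conceptually, _call_forward computes the original layer op with the frozen weight plus the scaled LoKr delta. A hedged standalone sketch for the Linear case (shapes arbitrary; not toolkit code):

import torch
import torch.nn.functional as F

x = torch.randn(2, 256)             # batch, in_features = 256
orig_weight = torch.randn(64, 256)  # frozen base weight (out, in)
w1 = torch.randn(4, 8)              # LoKr factors; kron gives (64, 256)
w2 = torch.zeros(16, 32)            # zero-init, so the delta starts at zero
multiplier = 1.0                    # network multiplier (mean over batch, as above)
delta = torch.kron(w1, w2)
y = F.linear(x, orig_weight + delta * multiplier)  # self.op for Linear layers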


@@ -272,6 +272,9 @@ class ToolkitModuleMixin:
# if self.__class__.__name__ == "DoRAModule":
# # return dora forward
# return self.dora_forward(x, *args, **kwargs)
if self.__class__.__name__ == "LokrModule":
return self._call_forward(x)
org_forwarded = self.org_forward(x, *args, **kwargs)
@@ -540,6 +543,17 @@ class ToolkitNetworkMixin:
new_save_dict[new_key] = value
save_dict = new_save_dict
if self.network_type.lower() == "lokr":
new_save_dict = {}
for key, value in save_dict.items():
# lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1
new_key = key
new_key = new_key.replace('lora_transformer_', 'lycoris_')
new_save_dict[new_key] = value
save_dict = new_save_dict
if metadata is None:
metadata = OrderedDict()
@@ -585,6 +599,10 @@ class ToolkitNetworkMixin:
load_key = load_key.replace('.', '$$')
load_key = load_key.replace('$$lora_down$$', '.lora_down.')
load_key = load_key.replace('$$lora_up$$', '.lora_up.')
if self.network_type.lower() == "lokr":
# lycoris_transformer_blocks_7_attn_to_v.lokr_w1 back to lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1
load_key = load_key.replace('lycoris_', 'lora_transformer_')
load_sd[load_key] = value
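The save/load remapping above is a plain string rename and round-trips cleanly; a minimal sketch using the key from the comments:

save_key = "lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1"
saved = save_key.replace("lora_transformer_", "lycoris_")  # applied on save
assert saved == "lycoris_transformer_blocks_7_attn_to_v.lokr_w1"
loaded = saved.replace("lycoris_", "lora_transformer_")    # applied on load
assert loaded == save_key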
@@ -616,9 +634,22 @@ class ToolkitNetworkMixin:
# without having to set it in every single module every time it changes
multiplier = self._multiplier
# get first module
first_module = self.get_all_modules()[0]
device = first_module.lora_down.weight.device
dtype = first_module.lora_down.weight.dtype
try:
first_module = self.get_all_modules()[0]
except IndexError:
raise ValueError("There are not any lora modules in this network. Check your config and try again")
if hasattr(first_module, 'lora_down'):
device = first_module.lora_down.weight.device
dtype = first_module.lora_down.weight.dtype
elif hasattr(first_module, 'lokr_w1'):
device = first_module.lokr_w1.device
dtype = first_module.lokr_w1.dtype
elif hasattr(first_module, 'lokr_w1_a'):
device = first_module.lokr_w1_a.device
dtype = first_module.lokr_w1_a.dtype
else:
raise ValueError("Unknown module type")
with torch.no_grad():
tensor_multiplier = None
if isinstance(multiplier, int) or isinstance(multiplier, float):


@@ -1385,7 +1385,8 @@ class StableDiffusion:
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
self.adapter(conditional_clip_embeds)
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
if self.adapter is not None and isinstance(self.adapter, CustomAdapter) \
and gen_config.adapter_image_path is not None:
# handle condition the prompts
gen_config.prompt = self.adapter.condition_prompt(
gen_config.prompt,
@@ -1439,7 +1440,7 @@ class StableDiffusion:
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False)
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True)
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None:
conditional_embeds = self.adapter.condition_encoded_embeds(
tensors_0_1=validation_image,
prompt_embeds=conditional_embeds,


@@ -29,6 +29,8 @@ export const defaultJobConfig: JobConfig = {
type: 'lora',
linear: 16,
linear_alpha: 16,
lokr_full_rank: true,
lokr_factor: -1
},
save: {
dtype: 'bf16',


@@ -227,8 +227,31 @@ export default function TrainingForm() {
</div>
</FormGroup>
</Card>
{jobConfig.config.process[0].network?.type && (
<Card title="LoRA Configuration">
<Card title="Target Configuration">
<SelectInput
label="Target Type"
value={jobConfig.config.process[0].network?.type ?? 'lora'}
onChange={value => setJobConfig(value, 'config.process[0].network.type')}
options={[
{ value: 'lora', label: 'LoRA' },
{ value: 'lokr', label: 'LoKr' },
]}
/>
{jobConfig.config.process[0].network?.type === 'lokr' && (
<SelectInput
label="LoKr Factor"
value={`${jobConfig.config.process[0].network?.lokr_factor ?? -1}`}
onChange={value => setJobConfig(parseInt(value), 'config.process[0].network.lokr_factor')}
options={[
{ value: '-1', label: 'Auto' },
{ value: '4', label: '4' },
{ value: '8', label: '8' },
{ value: '16', label: '16' },
{ value: '32', label: '32' },
]}
/>
)}
{jobConfig.config.process[0].network?.type === 'lora' && (
<NumberInput
label="Linear Rank"
value={jobConfig.config.process[0].network.linear}
@@ -242,8 +265,8 @@ export default function TrainingForm() {
max={1024}
required
/>
</Card>
)}
)}
</Card>
<Card title="Save Configuration">
<SelectInput
label="Data Type"
@@ -397,7 +420,9 @@ export default function TrainingForm() {
label="DFE Loss Multiplier"
className="pt-2"
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')}
onChange={value =>
setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
}
placeholder="eg. 1.0"
min={0}
/>


@@ -50,9 +50,11 @@ export interface GPUApiResponse {
*/
export interface NetworkConfig {
type: 'lora';
type: string;
linear: number;
linear_alpha: number;
lokr_full_rank: boolean;
lokr_factor: number;
}
export interface SaveConfig {