Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)

Commit: Merge branch 'main' into wan21

scripts/update_sponsors.py (new file, +309 lines)
@@ -0,0 +1,309 @@
import os
import requests
import json
from datetime import datetime
from dotenv import load_dotenv

# Load environment variables from .env file
env_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), ".env")
load_dotenv(dotenv_path=env_path)

# API credentials
PATREON_TOKEN = os.getenv("PATREON_ACCESS_TOKEN")
GITHUB_TOKEN = os.getenv("GITHUB_TOKEN")
GITHUB_USERNAME = os.getenv("GITHUB_USERNAME")
GITHUB_ORG = os.getenv("GITHUB_ORG")  # Organization name (optional)

# Output file
README_PATH = "SUPPORTERS.md"


def fetch_patreon_supporters():
    """Fetch current Patreon supporters."""
    print("Fetching Patreon supporters...")

    headers = {
        "Authorization": f"Bearer {PATREON_TOKEN}",
        "Content-Type": "application/json"
    }

    url = "https://www.patreon.com/api/oauth2/v2/campaigns"

    try:
        # First get the campaign ID
        campaign_response = requests.get(url, headers=headers)
        campaign_response.raise_for_status()
        campaign_data = campaign_response.json()

        if not campaign_data.get('data'):
            print("No campaigns found for this Patreon account")
            return []

        campaign_id = campaign_data['data'][0]['id']

        # Now get the supporters for this campaign
        members_url = f"https://www.patreon.com/api/oauth2/v2/campaigns/{campaign_id}/members"
        params = {
            "include": "user",
            "fields[member]": "full_name,is_follower,patron_status",  # Removed profile_url
            "fields[user]": "image_url"
        }

        supporters = []
        while members_url:
            members_response = requests.get(members_url, headers=headers, params=params)
            members_response.raise_for_status()
            members_data = members_response.json()

            # Process the response to extract active patrons
            for member in members_data.get('data', []):
                attributes = member.get('attributes', {})

                # Only include active patrons
                if attributes.get('patron_status') == 'active_patron':
                    name = attributes.get('full_name', 'Anonymous Supporter')

                    # Get the user data, which contains the profile image
                    user_id = member.get('relationships', {}).get('user', {}).get('data', {}).get('id')
                    profile_image = None
                    profile_url = None  # Removed profile_url since it's not supported

                    if user_id:
                        for included in members_data.get('included', []):
                            if included.get('id') == user_id and included.get('type') == 'user':
                                profile_image = included.get('attributes', {}).get('image_url')
                                break

                    supporters.append({
                        'name': name,
                        'profile_image': profile_image,
                        'profile_url': profile_url,  # This will be None
                        'platform': 'Patreon',
                        'amount': 0  # Placeholder, as the Patreon API doesn't provide this in the current response
                    })

            # Handle pagination
            members_url = members_data.get('links', {}).get('next')

        print(f"Found {len(supporters)} active Patreon supporters")
        return supporters

    except requests.exceptions.RequestException as e:
        print(f"Error fetching Patreon data: {e}")
        print(f"Response content: {e.response.content if getattr(e, 'response', None) is not None else 'No response content'}")
        return []


def fetch_github_sponsors():
    """Fetch current GitHub sponsors for a user or organization."""
    print("Fetching GitHub sponsors...")

    headers = {
        "Authorization": f"Bearer {GITHUB_TOKEN}",
        "Accept": "application/vnd.github.v3+json"
    }

    # Determine whether we're fetching for a user or an organization
    entity_type = "organization" if GITHUB_ORG else "user"
    entity_name = GITHUB_ORG if GITHUB_ORG else GITHUB_USERNAME

    if not entity_name:
        print("Error: Neither GITHUB_USERNAME nor GITHUB_ORG is set")
        return []

    # The GraphQL query structure differs based on the entity type
    if entity_type == "user":
        query = """
        query {
          user(login: "%s") {
            sponsorshipsAsMaintainer(first: 100) {
              nodes {
                sponsorEntity {
                  ... on User {
                    login
                    name
                    avatarUrl
                    url
                  }
                  ... on Organization {
                    login
                    name
                    avatarUrl
                    url
                  }
                }
                tier {
                  monthlyPriceInDollars
                }
                isOneTimePayment
                isActive
              }
            }
          }
        }
        """ % entity_name
    else:  # organization
        query = """
        query {
          organization(login: "%s") {
            sponsorshipsAsMaintainer(first: 100) {
              nodes {
                sponsorEntity {
                  ... on User {
                    login
                    name
                    avatarUrl
                    url
                  }
                  ... on Organization {
                    login
                    name
                    avatarUrl
                    url
                  }
                }
                tier {
                  monthlyPriceInDollars
                }
                isOneTimePayment
                isActive
              }
            }
          }
        }
        """ % entity_name

    try:
        response = requests.post(
            "https://api.github.com/graphql",
            headers=headers,
            json={"query": query}
        )
        response.raise_for_status()
        data = response.json()

        # Process the response - the path to the data differs based on entity type
        if entity_type == "user":
            sponsors_data = data.get('data', {}).get('user', {}).get('sponsorshipsAsMaintainer', {}).get('nodes', [])
        else:
            sponsors_data = data.get('data', {}).get('organization', {}).get('sponsorshipsAsMaintainer', {}).get('nodes', [])

        sponsors = []
        for sponsor in sponsors_data:
            # Only include active sponsors
            if sponsor.get('isActive'):
                entity = sponsor.get('sponsorEntity', {})
                name = entity.get('name') or entity.get('login', 'Anonymous Sponsor')
                profile_image = entity.get('avatarUrl')
                profile_url = entity.get('url')
                amount = (sponsor.get('tier') or {}).get('monthlyPriceInDollars', 0)

                sponsors.append({
                    'name': name,
                    'profile_image': profile_image,
                    'profile_url': profile_url,
                    'platform': 'GitHub Sponsors',
                    'amount': amount
                })

        print(f"Found {len(sponsors)} active GitHub sponsors for {entity_type} '{entity_name}'")
        return sponsors

    except requests.exceptions.RequestException as e:
        print(f"Error fetching GitHub sponsors data: {e}")
        return []


def generate_readme(supporters):
    """Generate a README.md file with supporter information."""
    print(f"Generating {README_PATH}...")

    # Sort supporters by amount (descending) and then by name
    supporters.sort(key=lambda x: (-x['amount'], x['name'].lower()))

    # Determine the proper footer links based on what's configured
    github_entity = GITHUB_ORG if GITHUB_ORG else GITHUB_USERNAME
    github_entity_type = "orgs" if GITHUB_ORG else "sponsors"
    github_sponsor_url = f"https://github.com/{github_entity_type}/{github_entity}"

    with open(README_PATH, "w", encoding="utf-8") as f:
        f.write("## Support My Work\n\n")
        f.write("If you enjoy my work, or use it for commercial purposes, please consider sponsoring me so I can continue to maintain it. Every bit helps!\n\n")
        # Create an appropriate call to action based on what's configured
        cta_parts = []
        if github_entity:
            cta_parts.append(f"[Become a sponsor on GitHub]({github_sponsor_url})")
        if PATREON_TOKEN:
            cta_parts.append("[support me on Patreon](https://www.patreon.com/ostris)")
        if cta_parts:
            f.write(f"{' or '.join(cta_parts)}.\n\n")

        f.write("Thank you to all my current supporters!\n\n")
f.write(f"_Last updated: {datetime.now().strftime('%Y-%m-%d')}_\n\n")
|
||||
|
||||
# Write GitHub Sponsors section
|
||||
github_sponsors = [s for s in supporters if s['platform'] == 'GitHub Sponsors']
|
||||
if github_sponsors:
|
||||
f.write("### GitHub Sponsors\n\n")
|
||||
for sponsor in github_sponsors:
|
||||
if sponsor['profile_image']:
|
||||
f.write(f"<a href=\"{sponsor['profile_url']}\" title=\"{sponsor['name']}\"><img src=\"{sponsor['profile_image']}\" width=\"50\" height=\"50\" alt=\"{sponsor['name']}\" style=\"border-radius:50%\"></a> ")
|
||||
else:
|
||||
f.write(f"[{sponsor['name']}]({sponsor['profile_url']}) ")
|
||||
f.write("\n\n")
|
||||
|
||||
# Write Patreon section
|
||||
patreon_supporters = [s for s in supporters if s['platform'] == 'Patreon']
|
||||
if patreon_supporters:
|
||||
f.write("### Patreon Supporters\n\n")
|
||||
for supporter in patreon_supporters:
|
||||
if supporter['profile_image']:
|
||||
f.write(f"<a href=\"{supporter['profile_url']}\" title=\"{supporter['name']}\"><img src=\"{supporter['profile_image']}\" width=\"50\" height=\"50\" alt=\"{supporter['name']}\" style=\"border-radius:50%\"></a> ")
|
||||
else:
|
||||
f.write(f"[{supporter['name']}]({supporter['profile_url']}) ")
|
||||
f.write("\n\n")
|
||||
|
||||
f.write("\n---\n\n")
|
||||
|
||||
|
||||
print(f"Successfully generated {README_PATH} with {len(supporters)} supporters!")
|
||||
|
||||
def main():
|
||||
"""Main function"""
|
||||
print("Starting supporter data collection...")
|
||||
|
||||
# Check if required environment variables are set
|
||||
missing_vars = []
|
||||
if not GITHUB_TOKEN:
|
||||
missing_vars.append("GITHUB_TOKEN")
|
||||
|
||||
# Either username or org is required for GitHub
|
||||
if not GITHUB_USERNAME and not GITHUB_ORG:
|
||||
missing_vars.append("GITHUB_USERNAME or GITHUB_ORG")
|
||||
|
||||
# Patreon token is optional but warn if missing
|
||||
patreon_enabled = bool(PATREON_TOKEN)
|
||||
|
||||
if missing_vars:
|
||||
print(f"Error: Missing required environment variables: {', '.join(missing_vars)}")
|
||||
print("Please add them to your .env file")
|
||||
return
|
||||
|
||||
if not patreon_enabled:
|
||||
print("Warning: PATREON_ACCESS_TOKEN not set. Will only fetch GitHub sponsors.")
|
||||
|
||||
# Fetch data from both platforms
|
||||
patreon_supporters = fetch_patreon_supporters() if PATREON_TOKEN else []
|
||||
github_sponsors = fetch_github_sponsors()
|
||||
|
||||
# Combine supporters from both platforms
|
||||
all_supporters = patreon_supporters + github_sponsors
|
||||
|
||||
if not all_supporters:
|
||||
print("No supporters found on either platform")
|
||||
return
|
||||
|
||||
# Generate README
|
||||
generate_readme(all_supporters)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
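For anyone who wants to exercise the new script locally, a minimal smoke test might look like the sketch below. The `.env` keys mirror the `os.getenv` calls above; the token values are placeholders, and `GITHUB_ORG` is optional.

# Minimal local smoke test for scripts/update_sponsors.py (a sketch; run it
# from the repository root so the script finds .env next to scripts/).
from pathlib import Path
import subprocess

Path(".env").write_text(
    "GITHUB_TOKEN=ghp_your_token_here\n"          # placeholder token
    "GITHUB_USERNAME=your-github-username\n"      # or set GITHUB_ORG instead
    "PATREON_ACCESS_TOKEN=your-patreon-token\n"   # optional; omit to skip Patreon
)
subprocess.run(["python", "scripts/update_sponsors.py"], check=True)
print(Path("SUPPORTERS.md").read_text()[:500])    # preview the generated file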
@@ -1,3 +0,0 @@
-- only do ema on main device? shouldn't be needed other than saving and sampling
-- check when to unwrap model and what it does
-- disable timer for non-main local
@@ -135,6 +135,15 @@ class NetworkConfig:
         self.conv = 4

         self.transformer_only = kwargs.get('transformer_only', True)

+        self.lokr_full_rank = kwargs.get('lokr_full_rank', False)
+        if self.lokr_full_rank and self.type.lower() == 'lokr':
+            self.linear = 9999999999
+            self.linear_alpha = 9999999999
+            self.conv = 9999999999
+            self.conv_alpha = 9999999999
+        # -1 automatically finds the largest factor
+        self.lokr_factor = kwargs.get('lokr_factor', -1)
+

 AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net']
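Worth noting: `lokr_full_rank` works by swapping the configured rank for an absurdly large sentinel, so the `lora_dim < max(...)/2` checks in `LokrModule` always fall through to the full-size `lokr_w2` branch. A stripped-down stand-in (illustrative, not the toolkit class) that shows just that sentinel logic:

# Stand-in sketch of the full-rank sentinel above (not the toolkit's
# NetworkConfig): a rank far larger than any layer dimension means the
# LoKr factor matrices are never rank-limited.
class NetworkConfigSketch:
    def __init__(self, **kwargs):
        self.type = kwargs.get('type', 'lora')
        self.linear = kwargs.get('linear', 4)
        self.lokr_full_rank = kwargs.get('lokr_full_rank', False)
        if self.lokr_full_rank and self.type.lower() == 'lokr':
            self.linear = 9999999999  # sentinel: always exceeds max(dim)/2
        self.lokr_factor = kwargs.get('lokr_factor', -1)  # -1 = auto factor

cfg = NetworkConfigSketch(type='lokr', lokr_full_rank=True)
print(cfg.linear)  # 9999999999 -> the full-size lokr_w2 branch is used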
@@ -231,12 +231,18 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         if self.network_type.lower() == "dora":
-            self.module_class = DoRAModule
+            module_class = DoRAModule
         elif self.network_type.lower() == "lokr":
-            self.module_class = LokrModule
+            module_class = LokrModule
         self.network_config: NetworkConfig = kwargs.get("network_config", None)

         self.peft_format = peft_format

         # always do peft for flux only for now
         if self.is_flux or self.is_v3 or self.is_lumina2:
-            self.peft_format = True
+            # don't do peft format for lokr
+            if self.network_type.lower() != "lokr":
+                self.peft_format = True

         if self.peft_format:
             # no alpha for peft
@@ -338,8 +344,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):

             if (is_linear or is_conv2d) and not skip:

-                if self.only_if_contains is not None and not any([word in clean_name for word in self.only_if_contains]):
-                    continue
+                if self.only_if_contains is not None:
+                    if not any([word in clean_name for word in self.only_if_contains]) and not any([word in lora_name for word in self.only_if_contains]):
+                        continue

                 dim = None
                 alpha = None
@@ -373,6 +380,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                         self.conv_lora_dim is not None or conv_block_dims is not None):
                     skipped.append(lora_name)
                     continue

+                module_kwargs = {}
+
+                if self.network_type.lower() == "lokr":
+                    module_kwargs["factor"] = self.network_config.lokr_factor
+
                 lora = module_class(
                     lora_name,
@@ -386,10 +398,16 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                     network=self,
                     parent=module,
                     use_bias=use_bias,
+                    **module_kwargs
                 )
                 loras.append(lora)
-                lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)
-                                              ]
+                if self.network_type.lower() == "lokr":
+                    try:
+                        lora_shape_dict[lora_name] = [list(lora.lokr_w1.weight.shape), list(lora.lokr_w2.weight.shape)]
+                    except Exception:
+                        pass
+                else:
+                    lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)]
         return loras, skipped

         text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
@@ -10,24 +10,23 @@ from toolkit.network_mixins import ToolkitModuleMixin

 from typing import TYPE_CHECKING, Union, List

 from optimum.quanto import QBytesTensor, QTensor

 if TYPE_CHECKING:
     from toolkit.lora_special import LoRASpecialNetwork

 # 4, build custom backward function
 # -

-def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
+def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
     '''
     Return a tuple of two values of the input dimension, decomposed by the number closest to factor;
     the second value is greater than or equal to the first.

     In LoRA with Kronecker product, the first value is for the weight scale and
     the second value is for the weight itself.

     Because of the non-commutative property, A⊗B ≠ B⊗A, the meaning of the two matrices is slightly different.

     examples)
     factor
         -1               2               4               8               16 ...
@@ -38,7 +37,7 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
     512  -> 32, 16    512  -> 256, 2    512  -> 128, 4    512  -> 64, 8     512  -> 32, 16
     1024 -> 32, 32    1024 -> 512, 2    1024 -> 256, 4    1024 -> 128, 8    1024 -> 64, 16
     '''

     if factor > 0 and (dimension % factor) == 0:
         m = factor
         n = dimension // factor
@@ -47,12 +46,12 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
         factor = dimension
     m, n = 1, dimension
     length = m + n
-    while m<n:
+    while m < n:
         new_m = m + 1
-        while dimension%new_m != 0:
+        while dimension % new_m != 0:
             new_m += 1
         new_n = dimension // new_m
-        if new_m + new_n > length or new_m>factor:
+        if new_m + new_n > length or new_m > factor:
             break
         else:
             m, n = new_m, new_n
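As a worked example of the search loop above for the auto case (`factor = -1` falls back to the dimension itself), tracing `dimension = 512` reproduces the `512 -> 32, 16` row of the docstring table:

# Worked trace of the factor = -1 search for dimension = 512, following the
# loop above: (m, n) starts at (1, 512), and each step moves m up to the next
# divisor of 512 while n shrinks to 512 // m.
dimension = 512
m, n = 1, dimension
length = m + n
factor = dimension  # factor == -1 falls back to the dimension itself
steps = []
while m < n:
    new_m = m + 1
    while dimension % new_m != 0:
        new_m += 1
    new_n = dimension // new_m
    if new_m + new_n > length or new_m > factor:
        break
    m, n = new_m, new_n
    steps.append((m, n))
print(steps)  # [(2, 256), (4, 128), (8, 64), (16, 32), (32, 16)]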
@@ -62,7 +61,8 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:


 def make_weight_cp(t, wa, wb):
-    rebuild2 = torch.einsum('i j k l, i p, j r -> p r k l', t, wa, wb)  # [c, d, k1, k2]
+    rebuild2 = torch.einsum('i j k l, i p, j r -> p r k l',
+                            t, wa, wb)  # [c, d, k1, k2]
     return rebuild2
@@ -71,31 +71,25 @@ def make_kron(w1, w2, scale):
     w1 = w1.unsqueeze(2).unsqueeze(2)
     w2 = w2.contiguous()
     rebuild = torch.kron(w1, w2)

     return rebuild*scale


 class LokrModule(ToolkitModuleMixin, nn.Module):
     """
     modified from kohya-ss/sd-scripts/networks/lora:LoRAModule
     and from KohakuBlueleaf/LyCORIS/lycoris:loha:LoHaModule
     and from KohakuBlueleaf/LyCORIS/lycoris:locon:LoconModule
     """

     def __init__(
-            self,
-            lora_name,
-            org_module: nn.Module,
-            multiplier=1.0,
-            lora_dim=4,
-            alpha=1,
-            dropout=0.,
-            rank_dropout=0.,
+        self,
+        lora_name,
+        org_module: nn.Module,
+        multiplier=1.0,
+        lora_dim=4,
+        alpha=1,
+        dropout=0.,
+        rank_dropout=0.,
         module_dropout=0.,
         use_cp=False,
-        decompose_both = False,
+        decompose_both=False,
         network: 'LoRASpecialNetwork' = None,
-        factor:int=-1,  # factorization factor
+        factor: int = -1,  # factorization factor
         **kwargs,
     ):
         """ if alpha == 0 or None, alpha is rank (no scaling). """
@@ -107,38 +101,49 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
         self.cp = False
         self.use_w1 = False
         self.use_w2 = False
+        self.can_merge_in = True

         self.shape = org_module.weight.shape
         if org_module.__class__.__name__ == 'Conv2d':
             in_dim = org_module.in_channels
             k_size = org_module.kernel_size
             out_dim = org_module.out_channels

             in_m, in_n = factorization(in_dim, factor)
             out_l, out_k = factorization(out_dim, factor)
-            shape = ((out_l, out_k), (in_m, in_n), *k_size)  # ((a, b), (c, d), *k_size)
-
-            self.cp = use_cp and k_size!=(1, 1)
+            # ((a, b), (c, d), *k_size)
+            shape = ((out_l, out_k), (in_m, in_n), *k_size)
+
+            self.cp = use_cp and k_size != (1, 1)
             if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2:
-                self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
-                self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
+                self.lokr_w1_a = nn.Parameter(
+                    torch.empty(shape[0][0], lora_dim))
+                self.lokr_w1_b = nn.Parameter(
+                    torch.empty(lora_dim, shape[1][0]))
             else:
                 self.use_w1 = True
-                self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0]))  # a*c, 1-mode
+                self.lokr_w1 = nn.Parameter(torch.empty(
+                    shape[0][0], shape[1][0]))  # a*c, 1-mode

             if lora_dim >= max(shape[0][1], shape[1][1])/2:
                 self.use_w2 = True
-                self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1], *k_size))
+                self.lokr_w2 = nn.Parameter(torch.empty(
+                    shape[0][1], shape[1][1], *k_size))
             elif self.cp:
-                self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, shape[2], shape[3]))
-                self.lokr_w2_a = nn.Parameter(torch.empty(lora_dim, shape[0][1]))  # b, 1-mode
-                self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]))  # d, 2-mode
-            else: # Conv2d not cp
+                self.lokr_t2 = nn.Parameter(torch.empty(
+                    lora_dim, lora_dim, shape[2], shape[3]))
+                self.lokr_w2_a = nn.Parameter(
+                    torch.empty(lora_dim, shape[0][1]))  # b, 1-mode
+                self.lokr_w2_b = nn.Parameter(
+                    torch.empty(lora_dim, shape[1][1]))  # d, 2-mode
+            else:  # Conv2d not cp
                 # bigger part. weight and LoRA. [b, dim] x [dim, d*k1*k2]
-                self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
-                self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]*shape[2]*shape[3]))
+                self.lokr_w2_a = nn.Parameter(
+                    torch.empty(shape[0][1], lora_dim))
+                self.lokr_w2_b = nn.Parameter(torch.empty(
+                    lora_dim, shape[1][1]*shape[2]*shape[3]))
                 # w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d*k1*k2)) = (a, b)⊗(c, d*k1*k2) = (ac, bd*k1*k2)

             self.op = F.conv2d
             self.extra_args = {
                 "stride": org_module.stride,
@@ -147,48 +152,55 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
                 "groups": org_module.groups
             }

-        else: # Linear
+        else:  # Linear
             in_dim = org_module.in_features
             out_dim = org_module.out_features

             in_m, in_n = factorization(in_dim, factor)
             out_l, out_k = factorization(out_dim, factor)
-            shape = ((out_l, out_k), (in_m, in_n))  # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d
+            # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d
+            shape = ((out_l, out_k), (in_m, in_n))

             # smaller part. weight scale
             if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2:
-                self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
-                self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
+                self.lokr_w1_a = nn.Parameter(
+                    torch.empty(shape[0][0], lora_dim))
+                self.lokr_w1_b = nn.Parameter(
+                    torch.empty(lora_dim, shape[1][0]))
             else:
                 self.use_w1 = True
-                self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0]))  # a*c, 1-mode
+                self.lokr_w1 = nn.Parameter(torch.empty(
+                    shape[0][0], shape[1][0]))  # a*c, 1-mode

             if lora_dim < max(shape[0][1], shape[1][1])/2:
                 # bigger part. weight and LoRA. [b, dim] x [dim, d]
-                self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
-                self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]))
+                self.lokr_w2_a = nn.Parameter(
+                    torch.empty(shape[0][1], lora_dim))
+                self.lokr_w2_b = nn.Parameter(
+                    torch.empty(lora_dim, shape[1][1]))
                 # w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d)) = (a, b)⊗(c, d) = (ac, bd)
             else:
                 self.use_w2 = True
-                self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1]))
+                self.lokr_w2 = nn.Parameter(
+                    torch.empty(shape[0][1], shape[1][1]))

             self.op = F.linear
             self.extra_args = {}

         self.dropout = dropout
         if dropout:
-            print("[WARN]LoHa/LoKr haven't implemented normal dropout yet.")
+            print("[WARN]LoKr haven't implemented normal dropout yet.")
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout

         if isinstance(alpha, torch.Tensor):
             alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
         alpha = lora_dim if alpha is None or alpha == 0 else alpha
         if self.use_w2 and self.use_w1:
-            #use scale = 1
+            # use scale = 1
             alpha = lora_dim
         self.scale = alpha / self.lora_dim
-        self.register_buffer('alpha', torch.tensor(alpha))  # 定数として扱える
+        self.register_buffer('alpha', torch.tensor(alpha))  # treat as constant

         if self.use_w2:
             torch.nn.init.constant_(self.lokr_w2, 0)
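The Kronecker bookkeeping in the comments above boils down to `kron(w1, w2)` having shape `(w1.rows * w2.rows, w1.cols * w2.cols)`, which is why factoring `out_dim` and `in_dim` lets two much smaller matrices rebuild a full-size weight delta. A quick shape check (a sketch; the `(48, 64)` and `(24, 32)` splits are hand-picked stand-ins for `factorization()` output):

# Shape check for the LoKr reconstruction: kron(w1, w2) rebuilds a
# (out_dim, in_dim) delta from two much smaller factors.
import torch

out_dim, in_dim = 3072, 768
out_l, out_k = 48, 64   # out_dim == out_l * out_k
in_m, in_n = 24, 32     # in_dim == in_m * in_n

w1 = torch.randn(out_l, in_m)   # small "scale" factor
w2 = torch.randn(out_k, in_n)   # large "weight" factor
delta = torch.kron(w1, w2)      # full-size update
assert delta.shape == (out_dim, in_dim)
print(delta.shape)              # torch.Size([3072, 768])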
@@ -197,7 +209,7 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
             torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5))
             torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5))
             torch.nn.init.constant_(self.lokr_w2_b, 0)

         if self.use_w1:
             torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5))
         else:
@@ -208,8 +220,8 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
         self.org_module = [org_module]
         weight = make_kron(
             self.lokr_w1 if self.use_w1 else self.lokr_w1_a@self.lokr_w1_b,
-            (self.lokr_w2 if self.use_w2
-                else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
+            (self.lokr_w2 if self.use_w2
+             else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
                 else self.lokr_w2_a@self.lokr_w2_b),
             torch.tensor(self.multiplier * self.scale)
         )
@@ -219,12 +231,12 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
     def apply_to(self):
         self.org_forward = self.org_module[0].forward
         self.org_module[0].forward = self.forward

-    def get_weight(self, orig_weight = None):
+    def get_weight(self, orig_weight=None):
         weight = make_kron(
             self.lokr_w1 if self.use_w1 else self.lokr_w1_a@self.lokr_w1_b,
-            (self.lokr_w2 if self.use_w2
-                else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
+            (self.lokr_w2 if self.use_w2
+             else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
                 else self.lokr_w2_a@self.lokr_w2_b),
             torch.tensor(self.scale)
         )
@@ -232,51 +244,88 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
         weight = weight.reshape(orig_weight.shape)
         if self.training and self.rank_dropout:
             drop = torch.rand(weight.size(0)) < self.rank_dropout
-            weight *= drop.view(-1, [1]*len(weight.shape[1:])).to(weight.device)
+            weight *= drop.view(-1, [1] *
+                                len(weight.shape[1:])).to(weight.device)
         return weight

     @torch.no_grad()
     def apply_max_norm(self, max_norm, device=None):
         orig_norm = self.get_weight().norm()
         norm = torch.clamp(orig_norm, max_norm/2)
         desired = torch.clamp(norm, max=max_norm)
         ratio = desired.cpu()/norm.cpu()

         scaled = ratio.item() != 1.0
         if scaled:
             modules = (4 - self.use_w1 - self.use_w2 + (not self.use_w2 and self.cp))
             if self.use_w1:
                 self.lokr_w1 *= ratio**(1/modules)
             else:
                 self.lokr_w1_a *= ratio**(1/modules)
                 self.lokr_w1_b *= ratio**(1/modules)

             if self.use_w2:
                 self.lokr_w2 *= ratio**(1/modules)
             else:
                 if self.cp:
                     self.lokr_t2 *= ratio**(1/modules)
                 self.lokr_w2_a *= ratio**(1/modules)
                 self.lokr_w2_b *= ratio**(1/modules)

         return scaled, orig_norm*ratio

-    def forward(self, x):
-        if self.module_dropout and self.training:
-            if torch.rand(1) < self.module_dropout:
-                return self.op(
-                    x,
-                    self.org_module[0].weight.data,
-                    None if self.org_module[0].bias is None else self.org_module[0].bias.data
-                )
-        weight = (
-            self.org_module[0].weight.data
-            + self.get_weight(self.org_module[0].weight.data) * self.multiplier
-        )
-        bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
-        return self.op(
-            x,
-            weight.view(self.shape),
-            bias,
-            **self.extra_args
-        )
+    def merge_in(self, merge_weight=1.0):
+        if not self.can_merge_in:
+            return
+
+        # extract weight from org_module
+        org_sd = self.org_module[0].state_dict()
+        # todo find a way to merge in weights when doing quantized model
+        if 'weight._data' in org_sd:
+            # quantized weight
+            return
+
+        weight_key = "weight"
+        if 'weight._data' in org_sd:
+            # quantized weight
+            weight_key = "weight._data"
+
+        orig_dtype = org_sd[weight_key].dtype
+        weight = org_sd[weight_key].float()
+
+        scale = self.scale
+        # handle trainable scaler method locon does
+        if hasattr(self, 'scalar'):
+            scale = scale * self.scalar
+
+        lokr_weight = self.get_weight(weight)
+
+        merged_weight = (
+            weight
+            + (lokr_weight * merge_weight).to(weight.device, dtype=weight.dtype)
+        )
+
+        # set weight to org_module
+        org_sd[weight_key] = merged_weight.to(orig_dtype)
+        self.org_module[0].load_state_dict(org_sd)
+
+    def get_orig_weight(self):
+        weight = self.org_module[0].weight
+        if isinstance(weight, QTensor) or isinstance(weight, QBytesTensor):
+            return weight.dequantize().data.detach()
+        else:
+            return weight.data.detach()
+
+    def get_orig_bias(self):
+        if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None:
+            if isinstance(self.org_module[0].bias, QTensor) or isinstance(self.org_module[0].bias, QBytesTensor):
+                return self.org_module[0].bias.dequantize().data.detach()
+            else:
+                return self.org_module[0].bias.data.detach()
+        return None
+
+    def _call_forward(self, x):
+        if isinstance(x, QTensor) or isinstance(x, QBytesTensor):
+            x = x.dequantize()
+
+        orig_dtype = x.dtype
+
+        orig_weight = self.get_orig_weight()
+        lokr_weight = self.get_weight(orig_weight).to(dtype=orig_weight.dtype)
+        multiplier = self.network_ref().torch_multiplier
+
+        if x.dtype != orig_weight.dtype:
+            x = x.to(dtype=orig_weight.dtype)
+
+        # we do not currently support split batch multipliers for lokr. Just do a mean
+        multiplier = torch.mean(multiplier)
+
+        weight = (
+            orig_weight
+            + lokr_weight * multiplier
+        )
+        bias = self.get_orig_bias()
+        if bias is not None:
+            bias = bias.to(weight.device, dtype=weight.dtype)
+        output = self.op(
+            x,
+            weight.view(self.shape),
+            bias,
+            **self.extra_args
+        )
+        return output.to(orig_dtype)
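A sketch of how `merge_in` might be driven from outside the module (the helper name here is illustrative, not a toolkit API): bake the LoKr deltas into the base weights before export. Modules whose original weights are quantized (a `weight._data` state-dict entry) are simply skipped by the early return above.

# Hedged usage sketch: merge all LoKr deltas into the base model weights.
# `network` is assumed to be a LoRASpecialNetwork holding LokrModule children.
def merge_network(network, merge_weight=1.0):
    for module in network.get_all_modules():
        if hasattr(module, "merge_in"):
            module.merge_in(merge_weight=merge_weight)  # no-op when quantized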
@@ -272,6 +272,9 @@ class ToolkitModuleMixin:
         # if self.__class__.__name__ == "DoRAModule":
         #     # return dora forward
         #     return self.dora_forward(x, *args, **kwargs)

+        if self.__class__.__name__ == "LokrModule":
+            return self._call_forward(x)
+
         org_forwarded = self.org_forward(x, *args, **kwargs)
@@ -540,6 +543,17 @@ class ToolkitNetworkMixin:
                 new_save_dict[new_key] = value

             save_dict = new_save_dict

+        if self.network_type.lower() == "lokr":
+            new_save_dict = {}
+            for key, value in save_dict.items():
+                # lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1
+                new_key = key
+                new_key = new_key.replace('lora_transformer_', 'lycoris_')
+                new_save_dict[new_key] = value
+
+            save_dict = new_save_dict
+
         if metadata is None:
             metadata = OrderedDict()
@@ -585,6 +599,10 @@ class ToolkitNetworkMixin:
             load_key = load_key.replace('.', '$$')
             load_key = load_key.replace('$$lora_down$$', '.lora_down.')
             load_key = load_key.replace('$$lora_up$$', '.lora_up.')

+            if self.network_type.lower() == "lokr":
+                # lycoris_transformer_blocks_7_attn_to_v.lokr_w1 to lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1
+                load_key = load_key.replace('lycoris_', 'lora_transformer_')
+
             load_sd[load_key] = value
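The save/load pair above is a simple inverse string mapping on state-dict keys; a tiny round-trip check:

# Round-trip sketch of the LoKr key renaming: saving maps the internal
# 'lora_transformer_' prefix to the LyCORIS-style 'lycoris_' prefix,
# and loading maps it back.
internal_key = "lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1"

saved_key = internal_key.replace("lora_transformer_", "lycoris_")
assert saved_key == "lycoris_transformer_blocks_7_attn_to_v.lokr_w1"

loaded_key = saved_key.replace("lycoris_", "lora_transformer_")
assert loaded_key == internal_key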
@@ -616,9 +634,22 @@ class ToolkitNetworkMixin:
         # without having to set it in every single module every time it changes
         multiplier = self._multiplier
         # get first module
-        first_module = self.get_all_modules()[0]
-        device = first_module.lora_down.weight.device
-        dtype = first_module.lora_down.weight.dtype
+        try:
+            first_module = self.get_all_modules()[0]
+        except IndexError:
+            raise ValueError("There are no lora modules in this network. Check your config and try again")
+
+        if hasattr(first_module, 'lora_down'):
+            device = first_module.lora_down.weight.device
+            dtype = first_module.lora_down.weight.dtype
+        elif hasattr(first_module, 'lokr_w1'):
+            device = first_module.lokr_w1.device
+            dtype = first_module.lokr_w1.dtype
+        elif hasattr(first_module, 'lokr_w1_a'):
+            device = first_module.lokr_w1_a.device
+            dtype = first_module.lokr_w1_a.dtype
+        else:
+            raise ValueError("Unknown module type")
         with torch.no_grad():
             tensor_multiplier = None
             if isinstance(multiplier, int) or isinstance(multiplier, float):
@@ -1385,7 +1385,8 @@ class StableDiffusion:
                 conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
                 self.adapter(conditional_clip_embeds)

-        if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
+        if self.adapter is not None and isinstance(self.adapter, CustomAdapter) \
+                and gen_config.adapter_image_path is not None:
             # handle conditioning the prompts
             gen_config.prompt = self.adapter.condition_prompt(
                 gen_config.prompt,
@@ -1439,7 +1440,7 @@ class StableDiffusion:
             conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False)
             unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True)

-        if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
+        if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None:
             conditional_embeds = self.adapter.condition_encoded_embeds(
                 tensors_0_1=validation_image,
                 prompt_embeds=conditional_embeds,
@@ -29,6 +29,8 @@ export const defaultJobConfig: JobConfig = {
       type: 'lora',
       linear: 16,
       linear_alpha: 16,
+      lokr_full_rank: true,
+      lokr_factor: -1
     },
     save: {
       dtype: 'bf16',
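With these defaults in place, a job config only needs its network block switched over to opt into LoKr. A hedged sketch of that block as a plain Python dict (keys mirror the `NetworkConfig` kwargs above and the TypeScript interface below; values are illustrative):

# Illustrative network block for a LoKr training job (not a full job config).
network_config = {
    "type": "lokr",
    "linear": 16,             # ignored once lokr_full_rank applies its sentinel
    "linear_alpha": 16,
    "lokr_full_rank": True,   # rank -> huge sentinel; see NetworkConfig above
    "lokr_factor": -1,        # -1 = automatically pick the largest factor
}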
@@ -227,8 +227,31 @@ export default function TrainingForm() {
           </div>
         </FormGroup>
       </Card>
       {jobConfig.config.process[0].network?.type && (
-        <Card title="LoRA Configuration">
+        <Card title="Target Configuration">
+          <SelectInput
+            label="Target Type"
+            value={jobConfig.config.process[0].network?.type ?? 'lora'}
+            onChange={value => setJobConfig(value, 'config.process[0].network.type')}
+            options={[
+              { value: 'lora', label: 'LoRA' },
+              { value: 'lokr', label: 'LoKr' },
+            ]}
+          />
+          {jobConfig.config.process[0].network?.type == 'lokr' && (
+            <SelectInput
+              label="LoKr Factor"
+              value={`${jobConfig.config.process[0].network?.lokr_factor ?? -1}`}
+              onChange={value => setJobConfig(parseInt(value), 'config.process[0].network.lokr_factor')}
+              options={[
+                { value: '-1', label: 'Auto' },
+                { value: '4', label: '4' },
+                { value: '8', label: '8' },
+                { value: '16', label: '16' },
+                { value: '32', label: '32' },
+              ]}
+            />
+          )}
+          {jobConfig.config.process[0].network?.type == 'lora' && (
             <NumberInput
               label="Linear Rank"
               value={jobConfig.config.process[0].network.linear}
@@ -242,8 +265,8 @@ export default function TrainingForm() {
               max={1024}
               required
             />
-        </Card>
-      )}
+          )}
+        </Card>
       <Card title="Save Configuration">
         <SelectInput
           label="Data Type"
@@ -397,7 +420,9 @@ export default function TrainingForm() {
                 label="DFE Loss Multiplier"
                 className="pt-2"
                 value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
-                onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')}
+                onChange={value =>
+                  setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
+                }
                 placeholder="eg. 1.0"
                 min={0}
               />
@@ -50,9 +50,11 @@ export interface GPUApiResponse {
  */

 export interface NetworkConfig {
-  type: 'lora';
+  type: string;
   linear: number;
   linear_alpha: number;
+  lokr_full_rank: boolean;
+  lokr_factor: number;
 }

 export interface SaveConfig {