mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Merge branch 'main' into wan21
This commit is contained in:
309
scripts/update_sponsors.py
Normal file
309
scripts/update_sponsors.py
Normal file
@@ -0,0 +1,309 @@
|
|||||||
|
import os
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# Locate the .env file one directory above this script's directory and load
# it before any credential below is read from the environment.
_script_dir = os.path.dirname(__file__)
env_path = os.path.join(os.path.dirname(_script_dir), ".env")
load_dotenv(dotenv_path=env_path)

# Credentials and account names; each is None when the variable is unset.
PATREON_TOKEN = os.environ.get("PATREON_ACCESS_TOKEN")
GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN")
GITHUB_USERNAME = os.environ.get("GITHUB_USERNAME")
# Organization name (optional)
GITHUB_ORG = os.environ.get("GITHUB_ORG")

# File the generated supporter list is written to
README_PATH = "SUPPORTERS.md"
|
||||||
|
|
||||||
|
def fetch_patreon_supporters():
    """Fetch the currently active patrons from the Patreon API (v2).

    Looks up the campaigns visible to ``PATREON_TOKEN``, takes the first one
    and walks its paginated ``members`` collection, keeping only members
    whose ``patron_status`` is ``active_patron``.

    Returns:
        A list of dicts with the keys ``name``, ``profile_image``,
        ``profile_url`` (always ``None``), ``platform`` (``'Patreon'``) and
        ``amount`` (always ``0``). An empty list is returned when the account
        has no campaign or when an HTTP request fails; failures are printed,
        not raised.
    """
    print("Fetching Patreon supporters...")

    headers = {
        "Authorization": f"Bearer {PATREON_TOKEN}",
        "Content-Type": "application/json"
    }

    url = "https://www.patreon.com/api/oauth2/v2/campaigns"

    # Without a timeout a stalled connection would hang the script forever.
    timeout_seconds = 30

    try:
        # First get the campaign ID
        campaign_response = requests.get(url, headers=headers, timeout=timeout_seconds)
        campaign_response.raise_for_status()
        campaign_data = campaign_response.json()

        if not campaign_data.get('data'):
            print("No campaigns found for this Patreon account")
            return []

        campaign_id = campaign_data['data'][0]['id']

        # Now get the supporters for this campaign
        members_url = f"https://www.patreon.com/api/oauth2/v2/campaigns/{campaign_id}/members"
        params = {
            "include": "user",
            # profile_url is deliberately not requested (not supported)
            "fields[member]": "full_name,is_follower,patron_status",
            "fields[user]": "image_url"
        }

        supporters = []
        while members_url:
            members_response = requests.get(
                members_url, headers=headers, params=params, timeout=timeout_seconds
            )
            members_response.raise_for_status()
            members_data = members_response.json()

            # Index the side-loaded user records once per page instead of
            # rescanning the whole list for every member. setdefault keeps
            # the first record for an id, like the original linear search.
            users_by_id = {}
            for included in members_data.get('included') or []:
                if included.get('type') == 'user':
                    users_by_id.setdefault(included.get('id'), included)

            for member in members_data.get('data') or []:
                attributes = member.get('attributes') or {}

                # Only include active patrons
                if attributes.get('patron_status') != 'active_patron':
                    continue

                # full_name may be present but null/empty; fall back so that
                # later sorting by name never sees None.
                name = attributes.get('full_name') or 'Anonymous Supporter'

                # Any level of the relationship chain may be null in the
                # JSON, so guard each step rather than chaining .get().
                relationships = member.get('relationships') or {}
                user_ref = (relationships.get('user') or {}).get('data') or {}
                user_id = user_ref.get('id')

                # The profile image lives on the included user record
                profile_image = None
                if user_id:
                    user = users_by_id.get(user_id)
                    if user is not None:
                        profile_image = (user.get('attributes') or {}).get('image_url')

                supporters.append({
                    'name': name,
                    'profile_image': profile_image,
                    'profile_url': None,  # not available from this API request
                    'platform': 'Patreon',
                    'amount': 0  # Placeholder, as Patreon API doesn't provide this in the current response
                })

            # Handle pagination: a missing/empty "next" link ends the loop
            members_url = (members_data.get('links') or {}).get('next')

        print(f"Found {len(supporters)} active Patreon supporters")
        return supporters

    except requests.exceptions.RequestException as e:
        print(f"Error fetching Patreon data: {e}")
        # e.response is None for connection-level failures. Compare against
        # None explicitly: a requests Response is falsy for 4xx/5xx statuses.
        response = getattr(e, 'response', None)
        content = response.content if response is not None else 'No response content'
        print(f"Response content: {content}")
        return []
|
||||||
|
|
||||||
|
def fetch_github_sponsors():
    """Fetch current GitHub sponsors for a user or organization.

    Queries the GitHub GraphQL API for ``sponsorshipsAsMaintainer`` of
    ``GITHUB_ORG`` (preferred when set) or ``GITHUB_USERNAME``, following
    cursor pagination until every page has been read, and keeps only
    sponsorships whose ``isActive`` flag is true.

    Returns:
        A list of dicts with the keys ``name``, ``profile_image``,
        ``profile_url``, ``platform`` (``'GitHub Sponsors'``) and ``amount``
        (monthly tier price in dollars, ``0`` when unavailable). An empty
        list is returned when no account name is configured or when an HTTP
        request fails; failures are printed, not raised.
    """
    print("Fetching GitHub sponsors...")

    headers = {
        "Authorization": f"Bearer {GITHUB_TOKEN}",
        "Accept": "application/vnd.github.v3+json"
    }

    # Determine if we're fetching for a user or an organization
    entity_type = "organization" if GITHUB_ORG else "user"
    entity_name = GITHUB_ORG if GITHUB_ORG else GITHUB_USERNAME

    if not entity_name:
        print("Error: Neither GITHUB_USERNAME nor GITHUB_ORG is set")
        return []

    # The user and organization queries differ only in the root field, so a
    # single template serves both. The login and the page cursor are passed
    # as GraphQL variables rather than spliced into the query text.
    query = """
    query($login: String!, $cursor: String) {
      %s(login: $login) {
        sponsorshipsAsMaintainer(first: 100, after: $cursor) {
          pageInfo {
            hasNextPage
            endCursor
          }
          nodes {
            sponsorEntity {
              ... on User {
                login
                name
                avatarUrl
                url
              }
              ... on Organization {
                login
                name
                avatarUrl
                url
              }
            }
            tier {
              monthlyPriceInDollars
            }
            isOneTimePayment
            isActive
          }
        }
      }
    }
    """ % entity_type

    try:
        sponsors_data = []
        cursor = None
        while True:
            response = requests.post(
                "https://api.github.com/graphql",
                headers=headers,
                json={"query": query, "variables": {"login": entity_name, "cursor": cursor}},
                timeout=30,  # never hang forever on a stalled connection
            )
            response.raise_for_status()
            data = response.json()

            # GraphQL reports failures with HTTP 200 and an "errors" array.
            if data.get('errors'):
                print(f"GitHub GraphQL API returned errors: {data['errors']}")

            # "data" and the root entity are null when the lookup failed, so
            # every step is guarded instead of chaining .get() on None.
            root = (data.get('data') or {}).get(entity_type) or {}
            connection = root.get('sponsorshipsAsMaintainer') or {}
            sponsors_data.extend(connection.get('nodes') or [])

            page_info = connection.get('pageInfo') or {}
            cursor = page_info.get('endCursor')
            if not page_info.get('hasNextPage') or not cursor:
                break

        sponsors = []
        for sponsor in sponsors_data:
            # Only include active sponsors (list entries themselves may be null)
            if not sponsor or not sponsor.get('isActive'):
                continue

            # sponsorEntity and tier can be null in the response.
            entity = sponsor.get('sponsorEntity') or {}
            tier = sponsor.get('tier') or {}

            sponsors.append({
                'name': entity.get('name') or entity.get('login') or 'Anonymous Sponsor',
                'profile_image': entity.get('avatarUrl'),
                'profile_url': entity.get('url'),
                'platform': 'GitHub Sponsors',
                'amount': tier.get('monthlyPriceInDollars') or 0
            })

        print(f"Found {len(sponsors)} active GitHub sponsors for {entity_type} '{entity_name}'")
        return sponsors

    except requests.exceptions.RequestException as e:
        print(f"Error fetching GitHub sponsors data: {e}")
        return []
|
||||||
|
|
||||||
|
def _format_supporter_entry(entry):
    """Return the README snippet for one supporter, with a trailing space."""
    from html import escape  # stdlib; local import keeps this block self-contained

    name = str(entry.get('name') or '')
    image = entry.get('profile_image')
    link = entry.get('profile_url')

    if image:
        # Names and URLs come from external APIs, so escape them before
        # placing them inside HTML attributes.
        safe_name = escape(name, quote=True)
        safe_image = escape(str(image), quote=True)
        img_tag = (
            f'<img src="{safe_image}" width="50" height="50" '
            f'alt="{safe_name}" style="border-radius:50%">'
        )
        if link:
            safe_link = escape(str(link), quote=True)
            return f'<a href="{safe_link}" title="{safe_name}">{img_tag}</a> '
        # No profile page known (always the case for Patreon entries): show
        # the avatar without a link instead of emitting href="None".
        return f"{img_tag} "

    if link:
        return f"[{name}]({link}) "
    # Neither avatar nor link: plain name rather than "[name](None)".
    return f"{name} "


def generate_readme(supporters):
    """Write the supporters page to ``README_PATH``.

    The file contains a short introduction, an optional call-to-action, the
    date of generation, and one section per platform ("GitHub Sponsors",
    then "Patreon Supporters") listing avatars or names. A platform section
    is omitted when it has no entries.

    Args:
        supporters: List of dicts with the keys ``name``, ``profile_image``,
            ``profile_url``, ``platform`` and ``amount``. The list is sorted
            in place by amount (descending), then by name (case-insensitive).
    """
    print(f"Generating {README_PATH}...")

    # Sort supporters by amount (descending) and then by name. Missing/None
    # values are treated as 0 / "" so one incomplete record cannot abort
    # the whole run.
    supporters.sort(key=lambda s: (-(s['amount'] or 0), (s['name'] or '').lower()))

    # Determine the proper footer links based on what's configured
    github_entity = GITHUB_ORG if GITHUB_ORG else GITHUB_USERNAME
    github_entity_type = "orgs" if GITHUB_ORG else "sponsors"
    github_sponsor_url = f"https://github.com/{github_entity_type}/{github_entity}"

    with open(README_PATH, "w", encoding="utf-8") as f:
        f.write("## Support My Work\n\n")
        f.write("If you enjoy my work, or use it for commercial purposes, please consider sponsoring me so I can continue to maintain it. Every bit helps! \n\n")

        # Create appropriate call-to-action based on what's configured
        cta_parts = []
        if github_entity:
            cta_parts.append(f"[Become a sponsor on GitHub]({github_sponsor_url})")
        if PATREON_TOKEN:
            cta_parts.append("[support me on Patreon](https://www.patreon.com/ostris)")

        # NOTE(review): the call-to-action is only written when GITHUB_ORG is
        # set, so a GITHUB_USERNAME-only configuration gets none. Kept as-is;
        # confirm whether that gate is intentional.
        if cta_parts and GITHUB_ORG:
            f.write(f"{' or '.join(cta_parts)}.\n\n")

        f.write("Thank you to all my current supporters!\n\n")

        f.write(f"_Last updated: {datetime.now().strftime('%Y-%m-%d')}_\n\n")

        # One section per platform, in this fixed order; the key is the
        # 'platform' value of the entries, the value is the section heading.
        sections = (
            ('GitHub Sponsors', "### GitHub Sponsors\n\n"),
            ('Patreon', "### Patreon Supporters\n\n"),
        )
        for platform, heading in sections:
            entries = [s for s in supporters if s['platform'] == platform]
            if not entries:
                continue
            f.write(heading)
            f.write("".join(_format_supporter_entry(entry) for entry in entries))
            f.write("\n\n")

        f.write("\n---\n\n")

    print(f"Successfully generated {README_PATH} with {len(supporters)} supporters!")
|
||||||
|
|
||||||
|
def main():
    """Collect supporters from Patreon and GitHub and write the supporters file.

    Aborts with a printed error when the GitHub token or the GitHub account
    name (user or organization) is missing. Patreon is optional: without a
    token it is skipped with a warning. Nothing is written when no supporter
    is found on either platform.
    """
    print("Starting supporter data collection...")

    # (value, label) pairs for the settings that must be present; either a
    # username or an organization satisfies the second requirement.
    required_settings = (
        (GITHUB_TOKEN, "GITHUB_TOKEN"),
        (GITHUB_USERNAME or GITHUB_ORG, "GITHUB_USERNAME or GITHUB_ORG"),
    )
    missing_vars = [label for value, label in required_settings if not value]

    if missing_vars:
        print(f"Error: Missing required environment variables: {', '.join(missing_vars)}")
        print("Please add them to your .env file")
        return

    # Patreon first (when configured), then GitHub
    if PATREON_TOKEN:
        patrons = fetch_patreon_supporters()
    else:
        print("Warning: PATREON_ACCESS_TOKEN not set. Will only fetch GitHub sponsors.")
        patrons = []
    sponsors = fetch_github_sponsors()

    # Combine supporters from both platforms
    everyone = patrons + sponsors
    if not everyone:
        print("No supporters found on either platform")
        return

    generate_readme(everyone)


if __name__ == "__main__":
    main()
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
- only do ema on main device? shouldne be needed other than saving and sampling
|
|
||||||
- check when to unwrap model and what it does
|
|
||||||
- disable timer for non main local
|
|
||||||
@@ -135,6 +135,15 @@ class NetworkConfig:
|
|||||||
self.conv = 4
|
self.conv = 4
|
||||||
|
|
||||||
self.transformer_only = kwargs.get('transformer_only', True)
|
self.transformer_only = kwargs.get('transformer_only', True)
|
||||||
|
|
||||||
|
self.lokr_full_rank = kwargs.get('lokr_full_rank', False)
|
||||||
|
if self.lokr_full_rank and self.type.lower() == 'lokr':
|
||||||
|
self.linear = 9999999999
|
||||||
|
self.linear_alpha = 9999999999
|
||||||
|
self.conv = 9999999999
|
||||||
|
self.conv_alpha = 9999999999
|
||||||
|
# -1 automatically finds the largest factor
|
||||||
|
self.lokr_factor = kwargs.get('lokr_factor', -1)
|
||||||
|
|
||||||
|
|
||||||
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net']
|
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net']
|
||||||
|
|||||||
@@ -231,12 +231,18 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
|||||||
if self.network_type.lower() == "dora":
|
if self.network_type.lower() == "dora":
|
||||||
self.module_class = DoRAModule
|
self.module_class = DoRAModule
|
||||||
module_class = DoRAModule
|
module_class = DoRAModule
|
||||||
|
elif self.network_type.lower() == "lokr":
|
||||||
|
self.module_class = LokrModule
|
||||||
|
module_class = LokrModule
|
||||||
|
self.network_config: NetworkConfig = kwargs.get("network_config", None)
|
||||||
|
|
||||||
self.peft_format = peft_format
|
self.peft_format = peft_format
|
||||||
|
|
||||||
# always do peft for flux only for now
|
# always do peft for flux only for now
|
||||||
if self.is_flux or self.is_v3 or self.is_lumina2:
|
if self.is_flux or self.is_v3 or self.is_lumina2:
|
||||||
self.peft_format = True
|
# don't do peft format for lokr
|
||||||
|
if self.network_type.lower() != "lokr":
|
||||||
|
self.peft_format = True
|
||||||
|
|
||||||
if self.peft_format:
|
if self.peft_format:
|
||||||
# no alpha for peft
|
# no alpha for peft
|
||||||
@@ -338,8 +344,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
|||||||
|
|
||||||
if (is_linear or is_conv2d) and not skip:
|
if (is_linear or is_conv2d) and not skip:
|
||||||
|
|
||||||
if self.only_if_contains is not None and not any([word in clean_name for word in self.only_if_contains]):
|
if self.only_if_contains is not None:
|
||||||
continue
|
if not any([word in clean_name for word in self.only_if_contains]) and not any([word in lora_name for word in self.only_if_contains]):
|
||||||
|
continue
|
||||||
|
|
||||||
dim = None
|
dim = None
|
||||||
alpha = None
|
alpha = None
|
||||||
@@ -373,6 +380,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
|||||||
self.conv_lora_dim is not None or conv_block_dims is not None):
|
self.conv_lora_dim is not None or conv_block_dims is not None):
|
||||||
skipped.append(lora_name)
|
skipped.append(lora_name)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
module_kwargs = {}
|
||||||
|
|
||||||
|
if self.network_type.lower() == "lokr":
|
||||||
|
module_kwargs["factor"] = self.network_config.lokr_factor
|
||||||
|
|
||||||
lora = module_class(
|
lora = module_class(
|
||||||
lora_name,
|
lora_name,
|
||||||
@@ -386,10 +398,16 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
|||||||
network=self,
|
network=self,
|
||||||
parent=module,
|
parent=module,
|
||||||
use_bias=use_bias,
|
use_bias=use_bias,
|
||||||
|
**module_kwargs
|
||||||
)
|
)
|
||||||
loras.append(lora)
|
loras.append(lora)
|
||||||
lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)
|
if self.network_type.lower() == "lokr":
|
||||||
]
|
try:
|
||||||
|
lora_shape_dict[lora_name] = [list(lora.lokr_w1.weight.shape), list(lora.lokr_w2.weight.shape)]
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)]
|
||||||
return loras, skipped
|
return loras, skipped
|
||||||
|
|
||||||
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
|
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
|
||||||
|
|||||||
@@ -10,24 +10,23 @@ from toolkit.network_mixins import ToolkitModuleMixin
|
|||||||
|
|
||||||
from typing import TYPE_CHECKING, Union, List
|
from typing import TYPE_CHECKING, Union, List
|
||||||
|
|
||||||
|
from optimum.quanto import QBytesTensor, QTensor
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
from toolkit.lora_special import LoRASpecialNetwork
|
from toolkit.lora_special import LoRASpecialNetwork
|
||||||
|
|
||||||
# 4, build custom backward function
|
|
||||||
# -
|
|
||||||
|
|
||||||
|
def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
|
||||||
def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
|
|
||||||
'''
|
'''
|
||||||
return a tuple of two value of input dimension decomposed by the number closest to factor
|
return a tuple of two value of input dimension decomposed by the number closest to factor
|
||||||
second value is higher or equal than first value.
|
second value is higher or equal than first value.
|
||||||
|
|
||||||
In LoRA with Kroneckor Product, first value is a value for weight scale.
|
In LoRA with Kroneckor Product, first value is a value for weight scale.
|
||||||
secon value is a value for weight.
|
secon value is a value for weight.
|
||||||
|
|
||||||
Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
|
Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
|
||||||
|
|
||||||
examples)
|
examples)
|
||||||
factor
|
factor
|
||||||
-1 2 4 8 16 ...
|
-1 2 4 8 16 ...
|
||||||
@@ -38,7 +37,7 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
|
|||||||
512 -> 32, 16 512 -> 256, 2 512 -> 128, 4 512 -> 64, 8 512 -> 32, 16
|
512 -> 32, 16 512 -> 256, 2 512 -> 128, 4 512 -> 64, 8 512 -> 32, 16
|
||||||
1024 -> 32, 32 1024 -> 512, 2 1024 -> 256, 4 1024 -> 128, 8 1024 -> 64, 16
|
1024 -> 32, 32 1024 -> 512, 2 1024 -> 256, 4 1024 -> 128, 8 1024 -> 64, 16
|
||||||
'''
|
'''
|
||||||
|
|
||||||
if factor > 0 and (dimension % factor) == 0:
|
if factor > 0 and (dimension % factor) == 0:
|
||||||
m = factor
|
m = factor
|
||||||
n = dimension // factor
|
n = dimension // factor
|
||||||
@@ -47,12 +46,12 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
|
|||||||
factor = dimension
|
factor = dimension
|
||||||
m, n = 1, dimension
|
m, n = 1, dimension
|
||||||
length = m + n
|
length = m + n
|
||||||
while m<n:
|
while m < n:
|
||||||
new_m = m + 1
|
new_m = m + 1
|
||||||
while dimension%new_m != 0:
|
while dimension % new_m != 0:
|
||||||
new_m += 1
|
new_m += 1
|
||||||
new_n = dimension // new_m
|
new_n = dimension // new_m
|
||||||
if new_m + new_n > length or new_m>factor:
|
if new_m + new_n > length or new_m > factor:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
m, n = new_m, new_n
|
m, n = new_m, new_n
|
||||||
@@ -62,7 +61,8 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
|
|||||||
|
|
||||||
|
|
||||||
def make_weight_cp(t, wa, wb):
|
def make_weight_cp(t, wa, wb):
|
||||||
rebuild2 = torch.einsum('i j k l, i p, j r -> p r k l', t, wa, wb) # [c, d, k1, k2]
|
rebuild2 = torch.einsum('i j k l, i p, j r -> p r k l',
|
||||||
|
t, wa, wb) # [c, d, k1, k2]
|
||||||
return rebuild2
|
return rebuild2
|
||||||
|
|
||||||
|
|
||||||
@@ -71,31 +71,25 @@ def make_kron(w1, w2, scale):
|
|||||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||||
w2 = w2.contiguous()
|
w2 = w2.contiguous()
|
||||||
rebuild = torch.kron(w1, w2)
|
rebuild = torch.kron(w1, w2)
|
||||||
|
|
||||||
return rebuild*scale
|
return rebuild*scale
|
||||||
|
|
||||||
|
|
||||||
class LokrModule(ToolkitModuleMixin, nn.Module):
|
class LokrModule(ToolkitModuleMixin, nn.Module):
|
||||||
"""
|
|
||||||
modifed from kohya-ss/sd-scripts/networks/lora:LoRAModule
|
|
||||||
and from KohakuBlueleaf/LyCORIS/lycoris:loha:LoHaModule
|
|
||||||
and from KohakuBlueleaf/LyCORIS/lycoris:locon:LoconModule
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
lora_name,
|
lora_name,
|
||||||
org_module: nn.Module,
|
org_module: nn.Module,
|
||||||
multiplier=1.0,
|
multiplier=1.0,
|
||||||
lora_dim=4,
|
lora_dim=4,
|
||||||
alpha=1,
|
alpha=1,
|
||||||
dropout=0.,
|
dropout=0.,
|
||||||
rank_dropout=0.,
|
rank_dropout=0.,
|
||||||
module_dropout=0.,
|
module_dropout=0.,
|
||||||
use_cp=False,
|
use_cp=False,
|
||||||
decompose_both = False,
|
decompose_both=False,
|
||||||
network: 'LoRASpecialNetwork' = None,
|
network: 'LoRASpecialNetwork' = None,
|
||||||
factor:int=-1, # factorization factor
|
factor: int = -1, # factorization factor
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
||||||
@@ -107,38 +101,49 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
|
|||||||
self.cp = False
|
self.cp = False
|
||||||
self.use_w1 = False
|
self.use_w1 = False
|
||||||
self.use_w2 = False
|
self.use_w2 = False
|
||||||
|
self.can_merge_in = True
|
||||||
|
|
||||||
self.shape = org_module.weight.shape
|
self.shape = org_module.weight.shape
|
||||||
if org_module.__class__.__name__ == 'Conv2d':
|
if org_module.__class__.__name__ == 'Conv2d':
|
||||||
in_dim = org_module.in_channels
|
in_dim = org_module.in_channels
|
||||||
k_size = org_module.kernel_size
|
k_size = org_module.kernel_size
|
||||||
out_dim = org_module.out_channels
|
out_dim = org_module.out_channels
|
||||||
|
|
||||||
in_m, in_n = factorization(in_dim, factor)
|
in_m, in_n = factorization(in_dim, factor)
|
||||||
out_l, out_k = factorization(out_dim, factor)
|
out_l, out_k = factorization(out_dim, factor)
|
||||||
shape = ((out_l, out_k), (in_m, in_n), *k_size) # ((a, b), (c, d), *k_size)
|
# ((a, b), (c, d), *k_size)
|
||||||
|
shape = ((out_l, out_k), (in_m, in_n), *k_size)
|
||||||
self.cp = use_cp and k_size!=(1, 1)
|
|
||||||
|
self.cp = use_cp and k_size != (1, 1)
|
||||||
if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2:
|
if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2:
|
||||||
self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
|
self.lokr_w1_a = nn.Parameter(
|
||||||
self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
|
torch.empty(shape[0][0], lora_dim))
|
||||||
|
self.lokr_w1_b = nn.Parameter(
|
||||||
|
torch.empty(lora_dim, shape[1][0]))
|
||||||
else:
|
else:
|
||||||
self.use_w1 = True
|
self.use_w1 = True
|
||||||
self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0])) # a*c, 1-mode
|
self.lokr_w1 = nn.Parameter(torch.empty(
|
||||||
|
shape[0][0], shape[1][0])) # a*c, 1-mode
|
||||||
|
|
||||||
if lora_dim >= max(shape[0][1], shape[1][1])/2:
|
if lora_dim >= max(shape[0][1], shape[1][1])/2:
|
||||||
self.use_w2 = True
|
self.use_w2 = True
|
||||||
self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1], *k_size))
|
self.lokr_w2 = nn.Parameter(torch.empty(
|
||||||
|
shape[0][1], shape[1][1], *k_size))
|
||||||
elif self.cp:
|
elif self.cp:
|
||||||
self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, shape[2], shape[3]))
|
self.lokr_t2 = nn.Parameter(torch.empty(
|
||||||
self.lokr_w2_a = nn.Parameter(torch.empty(lora_dim, shape[0][1])) # b, 1-mode
|
lora_dim, lora_dim, shape[2], shape[3]))
|
||||||
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1])) # d, 2-mode
|
self.lokr_w2_a = nn.Parameter(
|
||||||
else: # Conv2d not cp
|
torch.empty(lora_dim, shape[0][1])) # b, 1-mode
|
||||||
|
self.lokr_w2_b = nn.Parameter(
|
||||||
|
torch.empty(lora_dim, shape[1][1])) # d, 2-mode
|
||||||
|
else: # Conv2d not cp
|
||||||
# bigger part. weight and LoRA. [b, dim] x [dim, d*k1*k2]
|
# bigger part. weight and LoRA. [b, dim] x [dim, d*k1*k2]
|
||||||
self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
|
self.lokr_w2_a = nn.Parameter(
|
||||||
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]*shape[2]*shape[3]))
|
torch.empty(shape[0][1], lora_dim))
|
||||||
|
self.lokr_w2_b = nn.Parameter(torch.empty(
|
||||||
|
lora_dim, shape[1][1]*shape[2]*shape[3]))
|
||||||
# w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d*k1*k2)) = (a, b)⊗(c, d*k1*k2) = (ac, bd*k1*k2)
|
# w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d*k1*k2)) = (a, b)⊗(c, d*k1*k2) = (ac, bd*k1*k2)
|
||||||
|
|
||||||
self.op = F.conv2d
|
self.op = F.conv2d
|
||||||
self.extra_args = {
|
self.extra_args = {
|
||||||
"stride": org_module.stride,
|
"stride": org_module.stride,
|
||||||
@@ -147,48 +152,55 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
|
|||||||
"groups": org_module.groups
|
"groups": org_module.groups
|
||||||
}
|
}
|
||||||
|
|
||||||
else: # Linear
|
else: # Linear
|
||||||
in_dim = org_module.in_features
|
in_dim = org_module.in_features
|
||||||
out_dim = org_module.out_features
|
out_dim = org_module.out_features
|
||||||
|
|
||||||
in_m, in_n = factorization(in_dim, factor)
|
in_m, in_n = factorization(in_dim, factor)
|
||||||
out_l, out_k = factorization(out_dim, factor)
|
out_l, out_k = factorization(out_dim, factor)
|
||||||
shape = ((out_l, out_k), (in_m, in_n)) # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d
|
# ((a, b), (c, d)), out_dim = a*c, in_dim = b*d
|
||||||
|
shape = ((out_l, out_k), (in_m, in_n))
|
||||||
|
|
||||||
# smaller part. weight scale
|
# smaller part. weight scale
|
||||||
if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2:
|
if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2:
|
||||||
self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
|
self.lokr_w1_a = nn.Parameter(
|
||||||
self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
|
torch.empty(shape[0][0], lora_dim))
|
||||||
|
self.lokr_w1_b = nn.Parameter(
|
||||||
|
torch.empty(lora_dim, shape[1][0]))
|
||||||
else:
|
else:
|
||||||
self.use_w1 = True
|
self.use_w1 = True
|
||||||
self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0])) # a*c, 1-mode
|
self.lokr_w1 = nn.Parameter(torch.empty(
|
||||||
|
shape[0][0], shape[1][0])) # a*c, 1-mode
|
||||||
|
|
||||||
if lora_dim < max(shape[0][1], shape[1][1])/2:
|
if lora_dim < max(shape[0][1], shape[1][1])/2:
|
||||||
# bigger part. weight and LoRA. [b, dim] x [dim, d]
|
# bigger part. weight and LoRA. [b, dim] x [dim, d]
|
||||||
self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
|
self.lokr_w2_a = nn.Parameter(
|
||||||
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]))
|
torch.empty(shape[0][1], lora_dim))
|
||||||
|
self.lokr_w2_b = nn.Parameter(
|
||||||
|
torch.empty(lora_dim, shape[1][1]))
|
||||||
# w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d)) = (a, b)⊗(c, d) = (ac, bd)
|
# w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d)) = (a, b)⊗(c, d) = (ac, bd)
|
||||||
else:
|
else:
|
||||||
self.use_w2 = True
|
self.use_w2 = True
|
||||||
self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1]))
|
self.lokr_w2 = nn.Parameter(
|
||||||
|
torch.empty(shape[0][1], shape[1][1]))
|
||||||
|
|
||||||
self.op = F.linear
|
self.op = F.linear
|
||||||
self.extra_args = {}
|
self.extra_args = {}
|
||||||
|
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
if dropout:
|
if dropout:
|
||||||
print("[WARN]LoHa/LoKr haven't implemented normal dropout yet.")
|
print("[WARN]LoKr haven't implemented normal dropout yet.")
|
||||||
self.rank_dropout = rank_dropout
|
self.rank_dropout = rank_dropout
|
||||||
self.module_dropout = module_dropout
|
self.module_dropout = module_dropout
|
||||||
|
|
||||||
if isinstance(alpha, torch.Tensor):
|
if isinstance(alpha, torch.Tensor):
|
||||||
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
||||||
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
||||||
if self.use_w2 and self.use_w1:
|
if self.use_w2 and self.use_w1:
|
||||||
#use scale = 1
|
# use scale = 1
|
||||||
alpha = lora_dim
|
alpha = lora_dim
|
||||||
self.scale = alpha / self.lora_dim
|
self.scale = alpha / self.lora_dim
|
||||||
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
self.register_buffer('alpha', torch.tensor(alpha)) # treat as constant
|
||||||
|
|
||||||
if self.use_w2:
|
if self.use_w2:
|
||||||
torch.nn.init.constant_(self.lokr_w2, 0)
|
torch.nn.init.constant_(self.lokr_w2, 0)
|
||||||
@@ -197,7 +209,7 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
|
|||||||
torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5))
|
torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5))
|
||||||
torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5))
|
torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5))
|
||||||
torch.nn.init.constant_(self.lokr_w2_b, 0)
|
torch.nn.init.constant_(self.lokr_w2_b, 0)
|
||||||
|
|
||||||
if self.use_w1:
|
if self.use_w1:
|
||||||
torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5))
|
torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5))
|
||||||
else:
|
else:
|
||||||
@@ -208,8 +220,8 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
|
|||||||
self.org_module = [org_module]
|
self.org_module = [org_module]
|
||||||
weight = make_kron(
|
weight = make_kron(
|
||||||
self.lokr_w1 if self.use_w1 else self.lokr_w1_a@self.lokr_w1_b,
|
self.lokr_w1 if self.use_w1 else self.lokr_w1_a@self.lokr_w1_b,
|
||||||
(self.lokr_w2 if self.use_w2
|
(self.lokr_w2 if self.use_w2
|
||||||
else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
|
else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
|
||||||
else self.lokr_w2_a@self.lokr_w2_b),
|
else self.lokr_w2_a@self.lokr_w2_b),
|
||||||
torch.tensor(self.multiplier * self.scale)
|
torch.tensor(self.multiplier * self.scale)
|
||||||
)
|
)
|
||||||
@@ -219,12 +231,12 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
|
|||||||
def apply_to(self):
|
def apply_to(self):
|
||||||
self.org_forward = self.org_module[0].forward
|
self.org_forward = self.org_module[0].forward
|
||||||
self.org_module[0].forward = self.forward
|
self.org_module[0].forward = self.forward
|
||||||
|
|
||||||
def get_weight(self, orig_weight = None):
|
def get_weight(self, orig_weight=None):
|
||||||
weight = make_kron(
|
weight = make_kron(
|
||||||
self.lokr_w1 if self.use_w1 else self.lokr_w1_a@self.lokr_w1_b,
|
self.lokr_w1 if self.use_w1 else self.lokr_w1_a@self.lokr_w1_b,
|
||||||
(self.lokr_w2 if self.use_w2
|
(self.lokr_w2 if self.use_w2
|
||||||
else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
|
else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
|
||||||
else self.lokr_w2_a@self.lokr_w2_b),
|
else self.lokr_w2_a@self.lokr_w2_b),
|
||||||
torch.tensor(self.scale)
|
torch.tensor(self.scale)
|
||||||
)
|
)
|
||||||
@@ -232,51 +244,88 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
|
|||||||
weight = weight.reshape(orig_weight.shape)
|
weight = weight.reshape(orig_weight.shape)
|
||||||
if self.training and self.rank_dropout:
|
if self.training and self.rank_dropout:
|
||||||
drop = torch.rand(weight.size(0)) < self.rank_dropout
|
drop = torch.rand(weight.size(0)) < self.rank_dropout
|
||||||
weight *= drop.view(-1, [1]*len(weight.shape[1:])).to(weight.device)
|
weight *= drop.view(-1, [1] *
|
||||||
|
len(weight.shape[1:])).to(weight.device)
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def apply_max_norm(self, max_norm, device=None):
|
def merge_in(self, merge_weight=1.0):
|
||||||
orig_norm = self.get_weight().norm()
|
if not self.can_merge_in:
|
||||||
norm = torch.clamp(orig_norm, max_norm/2)
|
return
|
||||||
desired = torch.clamp(norm, max=max_norm)
|
|
||||||
ratio = desired.cpu()/norm.cpu()
|
|
||||||
|
|
||||||
scaled = ratio.item() != 1.0
|
|
||||||
if scaled:
|
|
||||||
modules = (4 - self.use_w1 - self.use_w2 + (not self.use_w2 and self.cp))
|
|
||||||
if self.use_w1:
|
|
||||||
self.lokr_w1 *= ratio**(1/modules)
|
|
||||||
else:
|
|
||||||
self.lokr_w1_a *= ratio**(1/modules)
|
|
||||||
self.lokr_w1_b *= ratio**(1/modules)
|
|
||||||
|
|
||||||
if self.use_w2:
|
|
||||||
self.lokr_w2 *= ratio**(1/modules)
|
|
||||||
else:
|
|
||||||
if self.cp:
|
|
||||||
self.lokr_t2 *= ratio**(1/modules)
|
|
||||||
self.lokr_w2_a *= ratio**(1/modules)
|
|
||||||
self.lokr_w2_b *= ratio**(1/modules)
|
|
||||||
|
|
||||||
return scaled, orig_norm*ratio
|
|
||||||
|
|
||||||
def forward(self, x):
|
# extract weight from org_module
|
||||||
if self.module_dropout and self.training:
|
org_sd = self.org_module[0].state_dict()
|
||||||
if torch.rand(1) < self.module_dropout:
|
# todo find a way to merge in weights when doing quantized model
|
||||||
return self.op(
|
if 'weight._data' in org_sd:
|
||||||
x,
|
# quantized weight
|
||||||
self.org_module[0].weight.data,
|
return
|
||||||
None if self.org_module[0].bias is None else self.org_module[0].bias.data
|
|
||||||
)
|
weight_key = "weight"
|
||||||
weight = (
|
if 'weight._data' in org_sd:
|
||||||
self.org_module[0].weight.data
|
# quantized weight
|
||||||
+ self.get_weight(self.org_module[0].weight.data) * self.multiplier
|
weight_key = "weight._data"
|
||||||
|
|
||||||
|
orig_dtype = org_sd[weight_key].dtype
|
||||||
|
weight = org_sd[weight_key].float()
|
||||||
|
|
||||||
|
scale = self.scale
|
||||||
|
# handle trainable scaler method locon does
|
||||||
|
if hasattr(self, 'scalar'):
|
||||||
|
scale = scale * self.scalar
|
||||||
|
|
||||||
|
lokr_weight = self.get_weight(weight)
|
||||||
|
|
||||||
|
merged_weight = (
|
||||||
|
weight
|
||||||
|
+ (lokr_weight * merge_weight).to(weight.device, dtype=weight.dtype)
|
||||||
)
|
)
|
||||||
bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
|
|
||||||
return self.op(
|
# set weight to org_module
|
||||||
x,
|
org_sd[weight_key] = merged_weight.to(orig_dtype)
|
||||||
|
self.org_module[0].load_state_dict(org_sd)
|
||||||
|
|
||||||
|
def get_orig_weight(self):
|
||||||
|
weight = self.org_module[0].weight
|
||||||
|
if isinstance(weight, QTensor) or isinstance(weight, QBytesTensor):
|
||||||
|
return weight.dequantize().data.detach()
|
||||||
|
else:
|
||||||
|
return weight.data.detach()
|
||||||
|
|
||||||
|
def get_orig_bias(self):
|
||||||
|
if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None:
|
||||||
|
if isinstance(self.org_module[0].bias, QTensor) or isinstance(self.org_module[0].bias, QBytesTensor):
|
||||||
|
return self.org_module[0].bias.dequantize().data.detach()
|
||||||
|
else:
|
||||||
|
return self.org_module[0].bias.data.detach()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _call_forward(self, x):
|
||||||
|
if isinstance(x, QTensor) or isinstance(x, QBytesTensor):
|
||||||
|
x = x.dequantize()
|
||||||
|
|
||||||
|
orig_dtype = x.dtype
|
||||||
|
|
||||||
|
orig_weight = self.get_orig_weight()
|
||||||
|
lokr_weight = self.get_weight(orig_weight).to(dtype=orig_weight.dtype)
|
||||||
|
multiplier = self.network_ref().torch_multiplier
|
||||||
|
|
||||||
|
if x.dtype != orig_weight.dtype:
|
||||||
|
x = x.to(dtype=orig_weight.dtype)
|
||||||
|
|
||||||
|
# we do not currently support split batch multipliers for lokr. Just do a mean
|
||||||
|
multiplier = torch.mean(multiplier)
|
||||||
|
|
||||||
|
weight = (
|
||||||
|
orig_weight
|
||||||
|
+ lokr_weight * multiplier
|
||||||
|
)
|
||||||
|
bias = self.get_orig_bias()
|
||||||
|
if bias is not None:
|
||||||
|
bias = bias.to(weight.device, dtype=weight.dtype)
|
||||||
|
output = self.op(
|
||||||
|
x,
|
||||||
weight.view(self.shape),
|
weight.view(self.shape),
|
||||||
bias,
|
bias,
|
||||||
**self.extra_args
|
**self.extra_args
|
||||||
)
|
)
|
||||||
|
return output.to(orig_dtype)
|
||||||
|
|||||||
@@ -272,6 +272,9 @@ class ToolkitModuleMixin:
|
|||||||
# if self.__class__.__name__ == "DoRAModule":
|
# if self.__class__.__name__ == "DoRAModule":
|
||||||
# # return dora forward
|
# # return dora forward
|
||||||
# return self.dora_forward(x, *args, **kwargs)
|
# return self.dora_forward(x, *args, **kwargs)
|
||||||
|
|
||||||
|
if self.__class__.__name__ == "LokrModule":
|
||||||
|
return self._call_forward(x)
|
||||||
|
|
||||||
org_forwarded = self.org_forward(x, *args, **kwargs)
|
org_forwarded = self.org_forward(x, *args, **kwargs)
|
||||||
|
|
||||||
@@ -540,6 +543,17 @@ class ToolkitNetworkMixin:
|
|||||||
new_save_dict[new_key] = value
|
new_save_dict[new_key] = value
|
||||||
|
|
||||||
save_dict = new_save_dict
|
save_dict = new_save_dict
|
||||||
|
|
||||||
|
|
||||||
|
if self.network_type.lower() == "lokr":
|
||||||
|
new_save_dict = {}
|
||||||
|
for key, value in save_dict.items():
|
||||||
|
# lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1
|
||||||
|
new_key = key
|
||||||
|
new_key = new_key.replace('lora_transformer_', 'lycoris_')
|
||||||
|
new_save_dict[new_key] = value
|
||||||
|
|
||||||
|
save_dict = new_save_dict
|
||||||
|
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = OrderedDict()
|
metadata = OrderedDict()
|
||||||
@@ -585,6 +599,10 @@ class ToolkitNetworkMixin:
|
|||||||
load_key = load_key.replace('.', '$$')
|
load_key = load_key.replace('.', '$$')
|
||||||
load_key = load_key.replace('$$lora_down$$', '.lora_down.')
|
load_key = load_key.replace('$$lora_down$$', '.lora_down.')
|
||||||
load_key = load_key.replace('$$lora_up$$', '.lora_up.')
|
load_key = load_key.replace('$$lora_up$$', '.lora_up.')
|
||||||
|
|
||||||
|
if self.network_type.lower() == "lokr":
|
||||||
|
# lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1
|
||||||
|
load_key = load_key.replace('lycoris_', 'lora_transformer_')
|
||||||
|
|
||||||
load_sd[load_key] = value
|
load_sd[load_key] = value
|
||||||
|
|
||||||
@@ -616,9 +634,22 @@ class ToolkitNetworkMixin:
|
|||||||
# without having to set it in every single module every time it changes
|
# without having to set it in every single module every time it changes
|
||||||
multiplier = self._multiplier
|
multiplier = self._multiplier
|
||||||
# get first module
|
# get first module
|
||||||
first_module = self.get_all_modules()[0]
|
try:
|
||||||
device = first_module.lora_down.weight.device
|
first_module = self.get_all_modules()[0]
|
||||||
dtype = first_module.lora_down.weight.dtype
|
except IndexError:
|
||||||
|
raise ValueError("There are not any lora modules in this network. Check your config and try again")
|
||||||
|
|
||||||
|
if hasattr(first_module, 'lora_down'):
|
||||||
|
device = first_module.lora_down.weight.device
|
||||||
|
dtype = first_module.lora_down.weight.dtype
|
||||||
|
elif hasattr(first_module, 'lokr_w1'):
|
||||||
|
device = first_module.lokr_w1.device
|
||||||
|
dtype = first_module.lokr_w1.dtype
|
||||||
|
elif hasattr(first_module, 'lokr_w1_a'):
|
||||||
|
device = first_module.lokr_w1_a.device
|
||||||
|
dtype = first_module.lokr_w1_a.dtype
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown module type")
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
tensor_multiplier = None
|
tensor_multiplier = None
|
||||||
if isinstance(multiplier, int) or isinstance(multiplier, float):
|
if isinstance(multiplier, int) or isinstance(multiplier, float):
|
||||||
|
|||||||
@@ -1385,7 +1385,8 @@ class StableDiffusion:
|
|||||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
|
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
|
||||||
self.adapter(conditional_clip_embeds)
|
self.adapter(conditional_clip_embeds)
|
||||||
|
|
||||||
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
|
if self.adapter is not None and isinstance(self.adapter, CustomAdapter) \
|
||||||
|
and gen_config.adapter_image_path is not None:
|
||||||
# handle condition the prompts
|
# handle condition the prompts
|
||||||
gen_config.prompt = self.adapter.condition_prompt(
|
gen_config.prompt = self.adapter.condition_prompt(
|
||||||
gen_config.prompt,
|
gen_config.prompt,
|
||||||
@@ -1439,7 +1440,7 @@ class StableDiffusion:
|
|||||||
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False)
|
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False)
|
||||||
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True)
|
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True)
|
||||||
|
|
||||||
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
|
if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None:
|
||||||
conditional_embeds = self.adapter.condition_encoded_embeds(
|
conditional_embeds = self.adapter.condition_encoded_embeds(
|
||||||
tensors_0_1=validation_image,
|
tensors_0_1=validation_image,
|
||||||
prompt_embeds=conditional_embeds,
|
prompt_embeds=conditional_embeds,
|
||||||
|
|||||||
@@ -29,6 +29,8 @@ export const defaultJobConfig: JobConfig = {
|
|||||||
type: 'lora',
|
type: 'lora',
|
||||||
linear: 16,
|
linear: 16,
|
||||||
linear_alpha: 16,
|
linear_alpha: 16,
|
||||||
|
lokr_full_rank: true,
|
||||||
|
lokr_factor: -1
|
||||||
},
|
},
|
||||||
save: {
|
save: {
|
||||||
dtype: 'bf16',
|
dtype: 'bf16',
|
||||||
|
|||||||
@@ -227,8 +227,31 @@ export default function TrainingForm() {
|
|||||||
</div>
|
</div>
|
||||||
</FormGroup>
|
</FormGroup>
|
||||||
</Card>
|
</Card>
|
||||||
{jobConfig.config.process[0].network?.type && (
|
<Card title="Target Configuration">
|
||||||
<Card title="LoRA Configuration">
|
<SelectInput
|
||||||
|
label="Target Type"
|
||||||
|
value={jobConfig.config.process[0].network?.type ?? 'lora'}
|
||||||
|
onChange={value => setJobConfig(value, 'config.process[0].network.type')}
|
||||||
|
options={[
|
||||||
|
{ value: 'lora', label: 'LoRA' },
|
||||||
|
{ value: 'lokr', label: 'LoKr' },
|
||||||
|
]}
|
||||||
|
/>
|
||||||
|
{jobConfig.config.process[0].network?.type == 'lokr' && (
|
||||||
|
<SelectInput
|
||||||
|
label="LoKr Factor"
|
||||||
|
value={ `${jobConfig.config.process[0].network?.lokr_factor ?? -1}`}
|
||||||
|
onChange={value => setJobConfig(parseInt(value), 'config.process[0].network.lokr_factor')}
|
||||||
|
options={[
|
||||||
|
{ value: '-1', label: 'Auto' },
|
||||||
|
{ value: '4', label: '4' },
|
||||||
|
{ value: '8', label: '8' },
|
||||||
|
{ value: '16', label: '16' },
|
||||||
|
{ value: '32', label: '32' },
|
||||||
|
]}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{jobConfig.config.process[0].network?.type == 'lora' && (
|
||||||
<NumberInput
|
<NumberInput
|
||||||
label="Linear Rank"
|
label="Linear Rank"
|
||||||
value={jobConfig.config.process[0].network.linear}
|
value={jobConfig.config.process[0].network.linear}
|
||||||
@@ -242,8 +265,8 @@ export default function TrainingForm() {
|
|||||||
max={1024}
|
max={1024}
|
||||||
required
|
required
|
||||||
/>
|
/>
|
||||||
</Card>
|
)}
|
||||||
)}
|
</Card>
|
||||||
<Card title="Save Configuration">
|
<Card title="Save Configuration">
|
||||||
<SelectInput
|
<SelectInput
|
||||||
label="Data Type"
|
label="Data Type"
|
||||||
@@ -397,7 +420,9 @@ export default function TrainingForm() {
|
|||||||
label="DFE Loss Multiplier"
|
label="DFE Loss Multiplier"
|
||||||
className="pt-2"
|
className="pt-2"
|
||||||
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
|
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
|
||||||
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')}
|
onChange={value =>
|
||||||
|
setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
|
||||||
|
}
|
||||||
placeholder="eg. 1.0"
|
placeholder="eg. 1.0"
|
||||||
min={0}
|
min={0}
|
||||||
/>
|
/>
|
||||||
|
|||||||
@@ -50,9 +50,11 @@ export interface GPUApiResponse {
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
export interface NetworkConfig {
|
export interface NetworkConfig {
|
||||||
type: 'lora';
|
type: string;
|
||||||
linear: number;
|
linear: number;
|
||||||
linear_alpha: number;
|
linear_alpha: number;
|
||||||
|
lokr_full_rank: boolean;
|
||||||
|
lokr_factor: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface SaveConfig {
|
export interface SaveConfig {
|
||||||
|
|||||||
Reference in New Issue
Block a user