Merge branch 'main' into wan21

This commit is contained in:
Jaret Burkett
2025-03-04 00:31:57 -07:00
11 changed files with 603 additions and 132 deletions

File diff suppressed because one or more lines are too long

309
scripts/update_sponsors.py Normal file
View File

@@ -0,0 +1,309 @@
import os
import requests
import json
from datetime import datetime
from dotenv import load_dotenv
# Load environment variables from .env file
env_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), ".env")
load_dotenv(dotenv_path=env_path)
# API credentials
PATREON_TOKEN = os.getenv("PATREON_ACCESS_TOKEN")
GITHUB_TOKEN = os.getenv("GITHUB_TOKEN")
GITHUB_USERNAME = os.getenv("GITHUB_USERNAME")
GITHUB_ORG = os.getenv("GITHUB_ORG") # Organization name (optional)
# Output file
README_PATH = "SUPPORTERS.md"
def fetch_patreon_supporters():
"""Fetch current Patreon supporters"""
print("Fetching Patreon supporters...")
headers = {
"Authorization": f"Bearer {PATREON_TOKEN}",
"Content-Type": "application/json"
}
url = "https://www.patreon.com/api/oauth2/v2/campaigns"
try:
# First get the campaign ID
campaign_response = requests.get(url, headers=headers)
campaign_response.raise_for_status()
campaign_data = campaign_response.json()
if not campaign_data.get('data'):
print("No campaigns found for this Patreon account")
return []
campaign_id = campaign_data['data'][0]['id']
# Now get the supporters for this campaign
members_url = f"https://www.patreon.com/api/oauth2/v2/campaigns/{campaign_id}/members"
params = {
"include": "user",
"fields[member]": "full_name,is_follower,patron_status", # Removed profile_url
"fields[user]": "image_url"
}
supporters = []
while members_url:
members_response = requests.get(members_url, headers=headers, params=params)
members_response.raise_for_status()
members_data = members_response.json()
# Process the response to extract active patrons
for member in members_data.get('data', []):
attributes = member.get('attributes', {})
# Only include active patrons
if attributes.get('patron_status') == 'active_patron':
name = attributes.get('full_name', 'Anonymous Supporter')
# Get user data which contains the profile image
user_id = member.get('relationships', {}).get('user', {}).get('data', {}).get('id')
profile_image = None
profile_url = None # Removed profile_url since it's not supported
if user_id:
for included in members_data.get('included', []):
if included.get('id') == user_id and included.get('type') == 'user':
profile_image = included.get('attributes', {}).get('image_url')
break
supporters.append({
'name': name,
'profile_image': profile_image,
'profile_url': profile_url, # This will be None
'platform': 'Patreon',
'amount': 0 # Placeholder, as Patreon API doesn't provide this in the current response
})
# Handle pagination
members_url = members_data.get('links', {}).get('next')
print(f"Found {len(supporters)} active Patreon supporters")
return supporters
except requests.exceptions.RequestException as e:
print(f"Error fetching Patreon data: {e}")
print(f"Response content: {e.response.content if hasattr(e, 'response') else 'No response content'}")
return []
def fetch_github_sponsors():
"""Fetch current GitHub sponsors for a user or organization"""
print("Fetching GitHub sponsors...")
headers = {
"Authorization": f"Bearer {GITHUB_TOKEN}",
"Accept": "application/vnd.github.v3+json"
}
# Determine if we're fetching for a user or an organization
entity_type = "organization" if GITHUB_ORG else "user"
entity_name = GITHUB_ORG if GITHUB_ORG else GITHUB_USERNAME
if not entity_name:
print("Error: Neither GITHUB_USERNAME nor GITHUB_ORG is set")
return []
# Different GraphQL query structure based on entity type
if entity_type == "user":
query = """
query {
user(login: "%s") {
sponsorshipsAsMaintainer(first: 100) {
nodes {
sponsorEntity {
... on User {
login
name
avatarUrl
url
}
... on Organization {
login
name
avatarUrl
url
}
}
tier {
monthlyPriceInDollars
}
isOneTimePayment
isActive
}
}
}
}
""" % entity_name
else: # organization
query = """
query {
organization(login: "%s") {
sponsorshipsAsMaintainer(first: 100) {
nodes {
sponsorEntity {
... on User {
login
name
avatarUrl
url
}
... on Organization {
login
name
avatarUrl
url
}
}
tier {
monthlyPriceInDollars
}
isOneTimePayment
isActive
}
}
}
}
""" % entity_name
try:
response = requests.post(
"https://api.github.com/graphql",
headers=headers,
json={"query": query}
)
response.raise_for_status()
data = response.json()
# Process the response - the path to the data differs based on entity type
if entity_type == "user":
sponsors_data = data.get('data', {}).get('user', {}).get('sponsorshipsAsMaintainer', {}).get('nodes', [])
else:
sponsors_data = data.get('data', {}).get('organization', {}).get('sponsorshipsAsMaintainer', {}).get('nodes', [])
sponsors = []
for sponsor in sponsors_data:
# Only include active sponsors
if sponsor.get('isActive'):
entity = sponsor.get('sponsorEntity', {})
name = entity.get('name') or entity.get('login', 'Anonymous Sponsor')
profile_image = entity.get('avatarUrl')
profile_url = entity.get('url')
amount = sponsor.get('tier', {}).get('monthlyPriceInDollars', 0)
sponsors.append({
'name': name,
'profile_image': profile_image,
'profile_url': profile_url,
'platform': 'GitHub Sponsors',
'amount': amount
})
print(f"Found {len(sponsors)} active GitHub sponsors for {entity_type} '{entity_name}'")
return sponsors
except requests.exceptions.RequestException as e:
print(f"Error fetching GitHub sponsors data: {e}")
return []
def generate_readme(supporters):
"""Generate a README.md file with supporter information"""
print(f"Generating {README_PATH}...")
# Sort supporters by amount (descending) and then by name
supporters.sort(key=lambda x: (-x['amount'], x['name'].lower()))
# Determine the proper footer links based on what's configured
github_entity = GITHUB_ORG if GITHUB_ORG else GITHUB_USERNAME
github_entity_type = "orgs" if GITHUB_ORG else "sponsors"
github_sponsor_url = f"https://github.com/{github_entity_type}/{github_entity}"
with open(README_PATH, "w", encoding="utf-8") as f:
f.write("## Support My Work\n\n")
f.write("If you enjoy my work, or use it for commercial purposes, please consider sponsoring me so I can continue to maintain it. Every bit helps! \n\n")
# Create appropriate call-to-action based on what's configured
cta_parts = []
if github_entity:
cta_parts.append(f"[Become a sponsor on GitHub]({github_sponsor_url})")
if PATREON_TOKEN:
cta_parts.append("[support me on Patreon](https://www.patreon.com/ostris)")
if cta_parts:
if GITHUB_ORG:
f.write(f"{' or '.join(cta_parts)}.\n\n")
f.write("Thank you to all my current supporters!\n\n")
f.write(f"_Last updated: {datetime.now().strftime('%Y-%m-%d')}_\n\n")
# Write GitHub Sponsors section
github_sponsors = [s for s in supporters if s['platform'] == 'GitHub Sponsors']
if github_sponsors:
f.write("### GitHub Sponsors\n\n")
for sponsor in github_sponsors:
if sponsor['profile_image']:
f.write(f"<a href=\"{sponsor['profile_url']}\" title=\"{sponsor['name']}\"><img src=\"{sponsor['profile_image']}\" width=\"50\" height=\"50\" alt=\"{sponsor['name']}\" style=\"border-radius:50%\"></a> ")
else:
f.write(f"[{sponsor['name']}]({sponsor['profile_url']}) ")
f.write("\n\n")
# Write Patreon section
patreon_supporters = [s for s in supporters if s['platform'] == 'Patreon']
if patreon_supporters:
f.write("### Patreon Supporters\n\n")
for supporter in patreon_supporters:
if supporter['profile_image']:
f.write(f"<a href=\"{supporter['profile_url']}\" title=\"{supporter['name']}\"><img src=\"{supporter['profile_image']}\" width=\"50\" height=\"50\" alt=\"{supporter['name']}\" style=\"border-radius:50%\"></a> ")
else:
f.write(f"[{supporter['name']}]({supporter['profile_url']}) ")
f.write("\n\n")
f.write("\n---\n\n")
print(f"Successfully generated {README_PATH} with {len(supporters)} supporters!")
def main():
"""Main function"""
print("Starting supporter data collection...")
# Check if required environment variables are set
missing_vars = []
if not GITHUB_TOKEN:
missing_vars.append("GITHUB_TOKEN")
# Either username or org is required for GitHub
if not GITHUB_USERNAME and not GITHUB_ORG:
missing_vars.append("GITHUB_USERNAME or GITHUB_ORG")
# Patreon token is optional but warn if missing
patreon_enabled = bool(PATREON_TOKEN)
if missing_vars:
print(f"Error: Missing required environment variables: {', '.join(missing_vars)}")
print("Please add them to your .env file")
return
if not patreon_enabled:
print("Warning: PATREON_ACCESS_TOKEN not set. Will only fetch GitHub sponsors.")
# Fetch data from both platforms
patreon_supporters = fetch_patreon_supporters() if PATREON_TOKEN else []
github_sponsors = fetch_github_sponsors()
# Combine supporters from both platforms
all_supporters = patreon_supporters + github_sponsors
if not all_supporters:
print("No supporters found on either platform")
return
# Generate README
generate_readme(all_supporters)
if __name__ == "__main__":
main()

View File

@@ -1,3 +0,0 @@
- only do ema on main device? shouldne be needed other than saving and sampling
- check when to unwrap model and what it does
- disable timer for non main local

View File

@@ -136,6 +136,15 @@ class NetworkConfig:
self.transformer_only = kwargs.get('transformer_only', True)
self.lokr_full_rank = kwargs.get('lokr_full_rank', False)
if self.lokr_full_rank and self.type.lower() == 'lokr':
self.linear = 9999999999
self.linear_alpha = 9999999999
self.conv = 9999999999
self.conv_alpha = 9999999999
# -1 automatically finds the largest factor
self.lokr_factor = kwargs.get('lokr_factor', -1)
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net']

View File

@@ -231,11 +231,17 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
if self.network_type.lower() == "dora":
self.module_class = DoRAModule
module_class = DoRAModule
elif self.network_type.lower() == "lokr":
self.module_class = LokrModule
module_class = LokrModule
self.network_config: NetworkConfig = kwargs.get("network_config", None)
self.peft_format = peft_format
# always do peft for flux only for now
if self.is_flux or self.is_v3 or self.is_lumina2:
# don't do peft format for lokr
if self.network_type.lower() != "lokr":
self.peft_format = True
if self.peft_format:
@@ -338,7 +344,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
if (is_linear or is_conv2d) and not skip:
if self.only_if_contains is not None and not any([word in clean_name for word in self.only_if_contains]):
if self.only_if_contains is not None:
if not any([word in clean_name for word in self.only_if_contains]) and not any([word in lora_name for word in self.only_if_contains]):
continue
dim = None
@@ -374,6 +381,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
skipped.append(lora_name)
continue
module_kwargs = {}
if self.network_type.lower() == "lokr":
module_kwargs["factor"] = self.network_config.lokr_factor
lora = module_class(
lora_name,
child_module,
@@ -386,10 +398,16 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
network=self,
parent=module,
use_bias=use_bias,
**module_kwargs
)
loras.append(lora)
lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)
]
if self.network_type.lower() == "lokr":
try:
lora_shape_dict[lora_name] = [list(lora.lokr_w1.weight.shape), list(lora.lokr_w2.weight.shape)]
except:
pass
else:
lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)]
return loras, skipped
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]

View File

@@ -10,13 +10,12 @@ from toolkit.network_mixins import ToolkitModuleMixin
from typing import TYPE_CHECKING, Union, List
from optimum.quanto import QBytesTensor, QTensor
if TYPE_CHECKING:
from toolkit.lora_special import LoRASpecialNetwork
# 4, build custom backward function
# -
def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
'''
@@ -62,7 +61,8 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
def make_weight_cp(t, wa, wb):
rebuild2 = torch.einsum('i j k l, i p, j r -> p r k l', t, wa, wb) # [c, d, k1, k2]
rebuild2 = torch.einsum('i j k l, i p, j r -> p r k l',
t, wa, wb) # [c, d, k1, k2]
return rebuild2
@@ -76,12 +76,6 @@ def make_kron(w1, w2, scale):
class LokrModule(ToolkitModuleMixin, nn.Module):
"""
modifed from kohya-ss/sd-scripts/networks/lora:LoRAModule
and from KohakuBlueleaf/LyCORIS/lycoris:loha:LoHaModule
and from KohakuBlueleaf/LyCORIS/lycoris:locon:LoconModule
"""
def __init__(
self,
lora_name,
@@ -107,6 +101,7 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
self.cp = False
self.use_w1 = False
self.use_w2 = False
self.can_merge_in = True
self.shape = org_module.weight.shape
if org_module.__class__.__name__ == 'Conv2d':
@@ -116,27 +111,37 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
in_m, in_n = factorization(in_dim, factor)
out_l, out_k = factorization(out_dim, factor)
shape = ((out_l, out_k), (in_m, in_n), *k_size) # ((a, b), (c, d), *k_size)
# ((a, b), (c, d), *k_size)
shape = ((out_l, out_k), (in_m, in_n), *k_size)
self.cp = use_cp and k_size != (1, 1)
if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2:
self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
self.lokr_w1_a = nn.Parameter(
torch.empty(shape[0][0], lora_dim))
self.lokr_w1_b = nn.Parameter(
torch.empty(lora_dim, shape[1][0]))
else:
self.use_w1 = True
self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0])) # a*c, 1-mode
self.lokr_w1 = nn.Parameter(torch.empty(
shape[0][0], shape[1][0])) # a*c, 1-mode
if lora_dim >= max(shape[0][1], shape[1][1])/2:
self.use_w2 = True
self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1], *k_size))
self.lokr_w2 = nn.Parameter(torch.empty(
shape[0][1], shape[1][1], *k_size))
elif self.cp:
self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, shape[2], shape[3]))
self.lokr_w2_a = nn.Parameter(torch.empty(lora_dim, shape[0][1])) # b, 1-mode
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1])) # d, 2-mode
self.lokr_t2 = nn.Parameter(torch.empty(
lora_dim, lora_dim, shape[2], shape[3]))
self.lokr_w2_a = nn.Parameter(
torch.empty(lora_dim, shape[0][1])) # b, 1-mode
self.lokr_w2_b = nn.Parameter(
torch.empty(lora_dim, shape[1][1])) # d, 2-mode
else: # Conv2d not cp
# bigger part. weight and LoRA. [b, dim] x [dim, d*k1*k2]
self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]*shape[2]*shape[3]))
self.lokr_w2_a = nn.Parameter(
torch.empty(shape[0][1], lora_dim))
self.lokr_w2_b = nn.Parameter(torch.empty(
lora_dim, shape[1][1]*shape[2]*shape[3]))
# w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d*k1*k2)) = (a, b)⊗(c, d*k1*k2) = (ac, bd*k1*k2)
self.op = F.conv2d
@@ -153,31 +158,38 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
in_m, in_n = factorization(in_dim, factor)
out_l, out_k = factorization(out_dim, factor)
shape = ((out_l, out_k), (in_m, in_n)) # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d
# ((a, b), (c, d)), out_dim = a*c, in_dim = b*d
shape = ((out_l, out_k), (in_m, in_n))
# smaller part. weight scale
if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2:
self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
self.lokr_w1_a = nn.Parameter(
torch.empty(shape[0][0], lora_dim))
self.lokr_w1_b = nn.Parameter(
torch.empty(lora_dim, shape[1][0]))
else:
self.use_w1 = True
self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0])) # a*c, 1-mode
self.lokr_w1 = nn.Parameter(torch.empty(
shape[0][0], shape[1][0])) # a*c, 1-mode
if lora_dim < max(shape[0][1], shape[1][1])/2:
# bigger part. weight and LoRA. [b, dim] x [dim, d]
self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]))
self.lokr_w2_a = nn.Parameter(
torch.empty(shape[0][1], lora_dim))
self.lokr_w2_b = nn.Parameter(
torch.empty(lora_dim, shape[1][1]))
# w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d)) = (a, b)⊗(c, d) = (ac, bd)
else:
self.use_w2 = True
self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1]))
self.lokr_w2 = nn.Parameter(
torch.empty(shape[0][1], shape[1][1]))
self.op = F.linear
self.extra_args = {}
self.dropout = dropout
if dropout:
print("[WARN]LoHa/LoKr haven't implemented normal dropout yet.")
print("[WARN]LoKr haven't implemented normal dropout yet.")
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
@@ -188,7 +200,7 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
# use scale = 1
alpha = lora_dim
self.scale = alpha / self.lora_dim
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
self.register_buffer('alpha', torch.tensor(alpha)) # treat as constant
if self.use_w2:
torch.nn.init.constant_(self.lokr_w2, 0)
@@ -232,51 +244,88 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
weight = weight.reshape(orig_weight.shape)
if self.training and self.rank_dropout:
drop = torch.rand(weight.size(0)) < self.rank_dropout
weight *= drop.view(-1, [1]*len(weight.shape[1:])).to(weight.device)
weight *= drop.view(-1, [1] *
len(weight.shape[1:])).to(weight.device)
return weight
@torch.no_grad()
def apply_max_norm(self, max_norm, device=None):
orig_norm = self.get_weight().norm()
norm = torch.clamp(orig_norm, max_norm/2)
desired = torch.clamp(norm, max=max_norm)
ratio = desired.cpu()/norm.cpu()
def merge_in(self, merge_weight=1.0):
if not self.can_merge_in:
return
scaled = ratio.item() != 1.0
if scaled:
modules = (4 - self.use_w1 - self.use_w2 + (not self.use_w2 and self.cp))
if self.use_w1:
self.lokr_w1 *= ratio**(1/modules)
else:
self.lokr_w1_a *= ratio**(1/modules)
self.lokr_w1_b *= ratio**(1/modules)
# extract weight from org_module
org_sd = self.org_module[0].state_dict()
# todo find a way to merge in weights when doing quantized model
if 'weight._data' in org_sd:
# quantized weight
return
if self.use_w2:
self.lokr_w2 *= ratio**(1/modules)
else:
if self.cp:
self.lokr_t2 *= ratio**(1/modules)
self.lokr_w2_a *= ratio**(1/modules)
self.lokr_w2_b *= ratio**(1/modules)
weight_key = "weight"
if 'weight._data' in org_sd:
# quantized weight
weight_key = "weight._data"
return scaled, orig_norm*ratio
orig_dtype = org_sd[weight_key].dtype
weight = org_sd[weight_key].float()
def forward(self, x):
if self.module_dropout and self.training:
if torch.rand(1) < self.module_dropout:
return self.op(
x,
self.org_module[0].weight.data,
None if self.org_module[0].bias is None else self.org_module[0].bias.data
scale = self.scale
# handle trainable scaler method locon does
if hasattr(self, 'scalar'):
scale = scale * self.scalar
lokr_weight = self.get_weight(weight)
merged_weight = (
weight
+ (lokr_weight * merge_weight).to(weight.device, dtype=weight.dtype)
)
# set weight to org_module
org_sd[weight_key] = merged_weight.to(orig_dtype)
self.org_module[0].load_state_dict(org_sd)
def get_orig_weight(self):
weight = self.org_module[0].weight
if isinstance(weight, QTensor) or isinstance(weight, QBytesTensor):
return weight.dequantize().data.detach()
else:
return weight.data.detach()
def get_orig_bias(self):
if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None:
if isinstance(self.org_module[0].bias, QTensor) or isinstance(self.org_module[0].bias, QBytesTensor):
return self.org_module[0].bias.dequantize().data.detach()
else:
return self.org_module[0].bias.data.detach()
return None
def _call_forward(self, x):
if isinstance(x, QTensor) or isinstance(x, QBytesTensor):
x = x.dequantize()
orig_dtype = x.dtype
orig_weight = self.get_orig_weight()
lokr_weight = self.get_weight(orig_weight).to(dtype=orig_weight.dtype)
multiplier = self.network_ref().torch_multiplier
if x.dtype != orig_weight.dtype:
x = x.to(dtype=orig_weight.dtype)
# we do not currently support split batch multipliers for lokr. Just do a mean
multiplier = torch.mean(multiplier)
weight = (
self.org_module[0].weight.data
+ self.get_weight(self.org_module[0].weight.data) * self.multiplier
orig_weight
+ lokr_weight * multiplier
)
bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
return self.op(
bias = self.get_orig_bias()
if bias is not None:
bias = bias.to(weight.device, dtype=weight.dtype)
output = self.op(
x,
weight.view(self.shape),
bias,
**self.extra_args
)
return output.to(orig_dtype)

View File

@@ -273,6 +273,9 @@ class ToolkitModuleMixin:
# # return dora forward
# return self.dora_forward(x, *args, **kwargs)
if self.__class__.__name__ == "LokrModule":
return self._call_forward(x)
org_forwarded = self.org_forward(x, *args, **kwargs)
if isinstance(x, QTensor):
@@ -541,6 +544,17 @@ class ToolkitNetworkMixin:
save_dict = new_save_dict
if self.network_type.lower() == "lokr":
new_save_dict = {}
for key, value in save_dict.items():
# lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1
new_key = key
new_key = new_key.replace('lora_transformer_', 'lycoris_')
new_save_dict[new_key] = value
save_dict = new_save_dict
if metadata is None:
metadata = OrderedDict()
metadata = add_model_hash_to_meta(state_dict, metadata)
@@ -586,6 +600,10 @@ class ToolkitNetworkMixin:
load_key = load_key.replace('$$lora_down$$', '.lora_down.')
load_key = load_key.replace('$$lora_up$$', '.lora_up.')
if self.network_type.lower() == "lokr":
# lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1
load_key = load_key.replace('lycoris_', 'lora_transformer_')
load_sd[load_key] = value
# extract extra items from state dict
@@ -616,9 +634,22 @@ class ToolkitNetworkMixin:
# without having to set it in every single module every time it changes
multiplier = self._multiplier
# get first module
try:
first_module = self.get_all_modules()[0]
except IndexError:
raise ValueError("There are not any lora modules in this network. Check your config and try again")
if hasattr(first_module, 'lora_down'):
device = first_module.lora_down.weight.device
dtype = first_module.lora_down.weight.dtype
elif hasattr(first_module, 'lokr_w1'):
device = first_module.lokr_w1.device
dtype = first_module.lokr_w1.dtype
elif hasattr(first_module, 'lokr_w1_a'):
device = first_module.lokr_w1_a.device
dtype = first_module.lokr_w1_a.dtype
else:
raise ValueError("Unknown module type")
with torch.no_grad():
tensor_multiplier = None
if isinstance(multiplier, int) or isinstance(multiplier, float):

View File

@@ -1385,7 +1385,8 @@ class StableDiffusion:
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
self.adapter(conditional_clip_embeds)
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
if self.adapter is not None and isinstance(self.adapter, CustomAdapter) \
and gen_config.adapter_image_path is not None:
# handle condition the prompts
gen_config.prompt = self.adapter.condition_prompt(
gen_config.prompt,
@@ -1439,7 +1440,7 @@ class StableDiffusion:
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False)
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True)
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None:
conditional_embeds = self.adapter.condition_encoded_embeds(
tensors_0_1=validation_image,
prompt_embeds=conditional_embeds,

View File

@@ -29,6 +29,8 @@ export const defaultJobConfig: JobConfig = {
type: 'lora',
linear: 16,
linear_alpha: 16,
lokr_full_rank: true,
lokr_factor: -1
},
save: {
dtype: 'bf16',

View File

@@ -227,8 +227,31 @@ export default function TrainingForm() {
</div>
</FormGroup>
</Card>
{jobConfig.config.process[0].network?.type && (
<Card title="LoRA Configuration">
<Card title="Target Configuration">
<SelectInput
label="Target Type"
value={jobConfig.config.process[0].network?.type ?? 'lora'}
onChange={value => setJobConfig(value, 'config.process[0].network.type')}
options={[
{ value: 'lora', label: 'LoRA' },
{ value: 'lokr', label: 'LoKr' },
]}
/>
{jobConfig.config.process[0].network?.type == 'lokr' && (
<SelectInput
label="LoKr Factor"
value={ `${jobConfig.config.process[0].network?.lokr_factor ?? -1}`}
onChange={value => setJobConfig(parseInt(value), 'config.process[0].network.lokr_factor')}
options={[
{ value: '-1', label: 'Auto' },
{ value: '4', label: '4' },
{ value: '8', label: '8' },
{ value: '16', label: '16' },
{ value: '32', label: '32' },
]}
/>
)}
{jobConfig.config.process[0].network?.type == 'lora' && (
<NumberInput
label="Linear Rank"
value={jobConfig.config.process[0].network.linear}
@@ -242,8 +265,8 @@ export default function TrainingForm() {
max={1024}
required
/>
</Card>
)}
</Card>
<Card title="Save Configuration">
<SelectInput
label="Data Type"
@@ -397,7 +420,9 @@ export default function TrainingForm() {
label="DFE Loss Multiplier"
className="pt-2"
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')}
onChange={value =>
setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
}
placeholder="eg. 1.0"
min={0}
/>

View File

@@ -50,9 +50,11 @@ export interface GPUApiResponse {
*/
export interface NetworkConfig {
type: 'lora';
type: string;
linear: number;
linear_alpha: number;
lokr_full_rank: boolean;
lokr_factor: number;
}
export interface SaveConfig {