Merge branch 'main' into dev

Jaret Burkett
2025-06-10 10:26:11 -06:00
26 changed files with 2684 additions and 103 deletions

.vscode/launch.json vendored

@@ -40,5 +40,17 @@
"console": "integratedTerminal",
"justMyCode": false
},
{
"name": "Python: Debug Current File (cuda:1)",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"env": {
"CUDA_LAUNCH_BLOCKING": "1",
"CUDA_VISIBLE_DEVICES": "1"
},
"justMyCode": false
},
]
}
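A note on the added target: CUDA_VISIBLE_DEVICES="1" exposes only the machine's second physical GPU, which the debugged process then addresses as cuda:0, and CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous so errors surface at the offending call. A minimal check of that remapping (a sketch, not part of this commit; assumes a machine with at least two GPUs):

import os
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")  # the launch config sets this for the debuggee
import torch

print(torch.cuda.device_count())    # 1: only physical GPU 1 is visible
print(torch.cuda.current_device())  # 0: inside the process it is addressed as cuda:0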


@@ -414,3 +414,12 @@ To learn more about LoKr, read more about it at [KohakuBlueleaf/LyCORIS](https:/
Everything else should work the same including layer targeting.
## Updates
### June 10, 2025
- Decided to keep track of updates in the readme
- Added support for SDXL in the UI
- Added support for SD 1.5 in the UI
- Fixed UI Wan 2.1 14b name bug
- Added support for conv training in the UI for models that support it


@@ -438,7 +438,7 @@ class SDTrainer(BaseSDTrainProcess):
dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean")
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
elif self.dfe.version == 3:
elif self.dfe.version == 3 or self.dfe.version == 4:
dfe_loss = self.dfe(
noise=noise,
noise_pred=noise_pred,
@@ -501,15 +501,27 @@ class SDTrainer(BaseSDTrainProcess):
loss = wavelet_loss(pred, batch.latents, noise)
else:
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
do_weighted_timesteps = False
if self.sd.is_flow_matching:
if self.train_config.linear_timesteps or self.train_config.linear_timesteps2:
do_weighted_timesteps = True
if self.train_config.timestep_type == "weighted":
# use the noise scheduler to get the weights for the timesteps
do_weighted_timesteps = True
# handle linear timesteps and only adjust the weight of the timesteps
if self.sd.is_flow_matching and (self.train_config.linear_timesteps or self.train_config.linear_timesteps2):
if do_weighted_timesteps:
# calculate the weights for the timesteps
timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps(
timesteps,
v2=self.train_config.linear_timesteps2
v2=self.train_config.linear_timesteps2,
timestep_type=self.train_config.timestep_type
).to(loss.device, dtype=loss.dtype)
timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach()
if len(loss.shape) == 4:
timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach()
elif len(loss.shape) == 5:
timestep_weight = timestep_weight.view(-1, 1, 1, 1, 1).detach()
loss = loss * timestep_weight
if self.train_config.do_prior_divergence and prior_pred is not None:
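The added reshape lets the per-sample timestep weight broadcast over either an image-shaped loss (B, C, H, W) or a video-shaped loss (B, C, T, H, W). A standalone sketch of that broadcasting, with hypothetical shapes:

import torch

loss_img = torch.randn(2, 4, 64, 64)       # (B, C, H, W)
loss_vid = torch.randn(2, 4, 8, 64, 64)    # (B, C, T, H, W)
w = torch.tensor([0.5, 2.0])               # one weight per sample in the batch

# 4D loss: (B, 1, 1, 1) spreads the weight over every pixel of each sample
print((loss_img * w.view(-1, 1, 1, 1)).shape)     # torch.Size([2, 4, 64, 64])
# 5D loss: one extra singleton dim covers the time axis
print((loss_vid * w.view(-1, 1, 1, 1, 1)).shape)  # torch.Size([2, 4, 8, 64, 64])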
@@ -764,6 +776,7 @@ class SDTrainer(BaseSDTrainProcess):
conditional_embeds: Union[PromptEmbeds, None] = None,
unconditional_embeds: Union[PromptEmbeds, None] = None,
batch: Optional['DataLoaderBatchDTO'] = None,
is_primary_pred: bool = False,
**kwargs,
):
dtype = get_torch_dtype(self.train_config.dtype)
@@ -1553,6 +1566,7 @@ class SDTrainer(BaseSDTrainProcess):
conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeds=unconditional_embeds,
batch=batch,
is_primary_pred=True,
**pred_kwargs
)
self.after_unet_predict()


@@ -1116,6 +1116,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.train_config.linear_timesteps,
self.train_config.linear_timesteps2,
self.train_config.timestep_type == 'linear',
self.train_config.timestep_type == 'one_step',
])
timestep_type = 'linear' if linear_timesteps else None
@@ -1159,6 +1160,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
device=self.device_torch
)
timestep_indices = timestep_indices.long()
elif self.train_config.timestep_type == 'one_step':
timestep_indices = torch.zeros((batch_size,), device=self.device_torch, dtype=torch.long)
elif content_or_style in ['style', 'content']:
# this is from diffusers training code
# Cubic sampling for favoring later or earlier timesteps
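The new 'one_step' type always selects index 0, which on the linear 1000 to 1 schedule used for these modes means every sample trains at the highest-noise timestep. A minimal sketch (linear schedule assumed, as above):

import torch

batch_size = 4
timesteps = torch.linspace(1000, 1, 1000)                  # index 0 -> t = 1000
timestep_indices = torch.zeros((batch_size,), dtype=torch.long)
print(timesteps[timestep_indices])                         # tensor([1000., 1000., 1000., 1000.])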


@@ -18,7 +18,7 @@ from jobs.process import BaseTrainProcess
from toolkit.image_utils import show_tensors
from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
from toolkit.data_loader import ImageDataset
from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss, total_variation
from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss, total_variation, total_variation_deltas
from toolkit.metadata import get_meta_for_safetensors
from toolkit.optimizer import get_optimizer
from toolkit.style import get_style_model_and_losses
@@ -283,10 +283,33 @@ class TrainVAEProcess(BaseTrainProcess):
else:
return torch.tensor(0.0, device=self.device)
def get_ltv_loss(self, latent):
def get_ltv_loss(self, latent, images):
# loss to reduce the latent space variance
if self.ltv_weight > 0:
return total_variation(latent).mean()
with torch.no_grad():
images = images.to(latent.device, dtype=latent.dtype)
# resize down to latent size
images = torch.nn.functional.interpolate(images, size=(latent.shape[2], latent.shape[3]), mode='bilinear', align_corners=False)
# mean the color channel and then expand to latent size
images = images.mean(dim=1, keepdim=True)
images = images.repeat(1, latent.shape[1], 1, 1)
# normalize to a mean of 0 and std of 1
images_mean = images.mean(dim=(2, 3), keepdim=True)
images_std = images.std(dim=(2, 3), keepdim=True)
images = (images - images_mean) / (images_std + 1e-6)
# target the same std as the image for the latent space so it is not reduced toward 0
latent_tv = torch.abs(total_variation_deltas(latent))
images_tv = torch.abs(total_variation_deltas(images))
loss = torch.abs(latent_tv - images_tv) # keep it spatially aware
loss = loss.mean(dim=2, keepdim=True)
loss = loss.mean(dim=3, keepdim=True) # mean over height and width
loss = loss.mean(dim=1, keepdim=True) # mean over channels
loss = loss.mean()
return loss
else:
return torch.tensor(0.0, device=self.device)
@@ -733,7 +756,7 @@ class TrainVAEProcess(BaseTrainProcess):
mv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
if self.ltv_weight > 0:
ltv_loss = self.get_ltv_loss(latents) * self.ltv_weight
ltv_loss = self.get_ltv_loss(latents, batch) * self.ltv_weight
else:
ltv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
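The reworked LTV loss no longer simply minimizes latent total variation; it matches the latent's per-pixel TV deltas against a grayscale, latent-sized, normalized copy of the input, so the latent keeps roughly the image's local contrast instead of collapsing toward zero. A toy sketch of the same steps (hypothetical shapes; the helper is copied from the toolkit/losses.py change later in this diff):

import torch
import torch.nn.functional as F

def total_variation_deltas(x):
    # per-pixel TV deltas, as added to toolkit/losses.py in this commit
    dh = torch.zeros_like(x)
    dv = torch.zeros_like(x)
    dh[:, :, :, :-1] = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
    dv[:, :, :-1, :] = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
    return dh + dv

latent = torch.randn(1, 4, 8, 8)     # hypothetical latent
images = torch.rand(1, 3, 64, 64)    # hypothetical input batch

with torch.no_grad():
    images = F.interpolate(images, size=latent.shape[2:], mode='bilinear', align_corners=False)
    images = images.mean(dim=1, keepdim=True).repeat(1, latent.shape[1], 1, 1)  # grayscale, match channels
    images = (images - images.mean(dim=(2, 3), keepdim=True)) / (images.std(dim=(2, 3), keepdim=True) + 1e-6)

loss = torch.abs(total_variation_deltas(latent) - total_variation_deltas(images)).mean()
print(loss)  # zero only when the latent's local variation pattern matches the image's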


@@ -2,7 +2,7 @@ torchao==0.10.0
safetensors
git+https://github.com/jaretburkett/easy_dwpose.git
git+https://github.com/huggingface/diffusers@363d1ab7e24c5ed6c190abb00df66d9edb74383b
transformers==4.49.0
transformers==4.52.4
lycoris-lora==1.8.3
flatten_json
pyyaml


@@ -0,0 +1,228 @@
import gc
import os, sys
from tqdm import tqdm
import numpy as np
import json
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# set visible devices to 0
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# protect from formatting
if True:
import torch
from optimum.quanto import freeze, qfloat8, QTensor, qint4
from diffusers import FluxTransformer2DModel, FluxPipeline, AutoencoderKL, FlowMatchEulerDiscreteScheduler
from toolkit.util.quantize import quantize, get_qtype
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTextModel, CLIPTokenizer
from torchvision import transforms
qtype = "qfloat8"
dtype = torch.bfloat16
# base_model_path = "black-forest-labs/FLUX.1-dev"
base_model_path = "ostris/Flex.1-alpha"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Loading Transformer...")
prompt = "Photo of a man and a woman in a park, sunny day"
output_root = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "output")
output_path = os.path.join(output_root, "flex_timestep_weights.json")
img_output_path = os.path.join(output_root, "flex_timestep_weights.png")
quantization_type = get_qtype(qtype)
def flush():
torch.cuda.empty_cache()
gc.collect()
pil_to_tensor = transforms.ToTensor()
with torch.no_grad():
transformer = FluxTransformer2DModel.from_pretrained(
base_model_path,
subfolder='transformer',
torch_dtype=dtype
)
transformer.to(device, dtype=dtype)
print("Quantizing Transformer...")
quantize(transformer, weights=quantization_type)
freeze(transformer)
flush()
print("Loading Scheduler...")
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
print("Loading Autoencoder...")
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
vae.to(device, dtype=dtype)
flush()
print("Loading Text Encoder...")
tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype)
text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype)
text_encoder_2.to(device, dtype=dtype)
print("Quantizing Text Encoder...")
quantize(text_encoder_2, weights=get_qtype(qtype))
freeze(text_encoder_2)
flush()
print("Loading CLIP")
text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
text_encoder.to(device, dtype=dtype)
print("Making pipe")
pipe: FluxPipeline = FluxPipeline(
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=None,
tokenizer_2=tokenizer_2,
vae=vae,
transformer=None,
)
pipe.text_encoder_2 = text_encoder_2
pipe.transformer = transformer
pipe.to(device, dtype=dtype)
print("Encoding prompt...")
prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
prompt,
prompt_2=prompt,
device=device
)
generator = torch.manual_seed(42)
height = 1024
width = 1024
print("Generating image...")
# Fix a bug in diffusers/torch
def callback_on_step_end(pipe, i, t, callback_kwargs):
latents = callback_kwargs["latents"]
if latents.dtype != dtype:
latents = latents.to(dtype)
return {"latents": latents}
img = pipe(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
height=height,
width=width,
num_inference_steps=30,
guidance_scale=3.5,
generator=generator,
callback_on_step_end=callback_on_step_end,
).images[0]
img.save(img_output_path)
print(f"Image saved to {img_output_path}")
print("Encoding image...")
# img is a PIL image. convert it to a -1 to 1 tensor
img = pil_to_tensor(img)
img = img.unsqueeze(0) # add batch dimension
img = img * 2 - 1 # convert to -1 to 1 range
img = img.to(device, dtype=dtype)
latents = vae.encode(img).latent_dist.sample()
shift = vae.config['shift_factor'] if vae.config['shift_factor'] is not None else 0
latents = vae.config['scaling_factor'] * (latents - shift)
num_channels_latents = pipe.transformer.config.in_channels // 4
l_height = 2 * (int(height) // (pipe.vae_scale_factor * 2))
l_width = 2 * (int(width) // (pipe.vae_scale_factor * 2))
packed_latents = pipe._pack_latents(latents, 1, num_channels_latents, l_height, l_width)
packed_latents, latent_image_ids = pipe.prepare_latents(
1,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
packed_latents,
)
print("Calculating timestep weights...")
torch.manual_seed(8675309)
noise = torch.randn_like(packed_latents, device=device, dtype=dtype)
# Create linear timesteps from 1000 to 0
num_train_timesteps = 1000
timesteps_torch = torch.linspace(1000, 1, num_train_timesteps, device='cpu')
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
timestep_weights = torch.zeros(num_train_timesteps, dtype=torch.float32, device=device)
guidance = torch.full([1], 1.0, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
pbar = tqdm(range(num_train_timesteps), desc="loss: 0.000000 scaler: 0.0000")
for i in pbar:
timestep = timesteps[i:i+1].to(device)
t_01 = (timestep / 1000).to(device)
t_01 = t_01.reshape(-1, 1, 1)
noisy_latents = (1.0 - t_01) * packed_latents + t_01 * noise
noise_pred = pipe.transformer(
hidden_states=noisy_latents, # torch.Size([1, 4096, 64])
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
return_dict=False,
)[0]
target = noise - packed_latents
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float())
loss = loss
# determine scaler to multiply loss by to make it 1
scaler = 1.0 / (loss + 1e-6)
timestep_weights[i] = scaler
pbar.set_description(f"loss: {loss.item():.6f} scaler: {scaler.item():.4f}")
print("normalizing timestep weights...")
# normalize the timestep weights so they are a mean of 1.0
timestep_weights = timestep_weights / timestep_weights.mean()
timestep_weights = timestep_weights.cpu().numpy().tolist()
print("Saving timestep weights...")
with open(output_path, 'w') as f:
json.dump(timestep_weights, f)
print(f"Timestep weights saved to {output_path}")
print("Done!")
flush()
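The script writes a 1000-entry list of per-timestep loss scalers normalized to a mean of 1.0. The trainer's new 'weighted' timestep type consumes a scheme of this shape (see the default_weighing_scheme import in the scheduler change later in this diff). A hedged sketch of looking weights up by timestep, with the indexing convention assumed from this script (index 0 corresponds to t = 1000):

import json
import torch

with open("output/flex_timestep_weights.json") as f:
    weights = torch.tensor(json.load(f), dtype=torch.float32)

timesteps = torch.linspace(1000, 1, 1000)
sampled_t = torch.tensor([1000.0, 500.0, 1.0])   # hypothetical sampled timesteps
idx = torch.tensor([(timesteps == t).nonzero().item() for t in sampled_t])
print(weights[idx])                              # per-sample multipliers for the loss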


@@ -437,7 +437,7 @@ class TrainConfig:
# adds an additional loss to the network to encourage it output a normalized standard deviation
self.target_norm_std = kwargs.get('target_norm_std', None)
self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
self.timestep_type = kwargs.get('timestep_type', 'sigmoid') # sigmoid, linear, lognorm_blend, next_sample
self.timestep_type = kwargs.get('timestep_type', 'sigmoid') # sigmoid, linear, lognorm_blend, next_sample, weighted, one_step
self.next_sample_timesteps = kwargs.get('next_sample_timesteps', 8)
self.linear_timesteps = kwargs.get('linear_timesteps', False)
self.linear_timesteps2 = kwargs.get('linear_timesteps2', False)
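Since TrainConfig reads everything through kwargs.get, the two new values slot in like any other train-section key. A hedged sketch of the relevant kwargs (constructor usage assumed from the kwargs.get pattern above):

# 'weighted' applies precomputed per-timestep loss weights,
# 'one_step' always trains on the highest-noise timestep
train_kwargs = {
    "timestep_type": "weighted",  # or "one_step"
    "linear_timesteps": False,
    "linear_timesteps2": False,
}
# config = TrainConfig(**train_kwargs)  # assumption: TrainConfig.__init__ accepts these as **kwargs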


@@ -1773,6 +1773,97 @@ class LatentCachingMixin:
self.sd.restore_device_state()
class TextEmbeddingCachingMixin:
def __init__(self: 'AiToolkitDataset', **kwargs):
# if we have super, call it
if hasattr(super(), '__init__'):
super().__init__(**kwargs)
self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings
def cache_text_embeddings(self: 'AiToolkitDataset'):
with accelerator.main_process_first():
print_acc(f"Caching text_embeddings for {self.dataset_path}")
# cache all latents to disk
to_disk = self.is_caching_latents_to_disk
to_memory = self.is_caching_latents_to_memory
print_acc(" - Saving text embeddings to disk")
# move sd items to cpu except for vae
self.sd.set_device_state_preset('cache_latents')
# use tqdm to show progress
i = 0
for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'):
# set latent space version
if self.sd.model_config.latent_space_version is not None:
file_item.latent_space_version = self.sd.model_config.latent_space_version
elif self.sd.is_xl:
file_item.latent_space_version = 'sdxl'
elif self.sd.is_v3:
file_item.latent_space_version = 'sd3'
elif self.sd.is_auraflow:
file_item.latent_space_version = 'sdxl'
elif self.sd.is_flux:
file_item.latent_space_version = 'flux1'
elif self.sd.model_config.is_pixart_sigma:
file_item.latent_space_version = 'sdxl'
else:
file_item.latent_space_version = self.sd.model_config.arch
file_item.is_caching_to_disk = to_disk
file_item.is_caching_to_memory = to_memory
file_item.latent_load_device = self.sd.device
latent_path = file_item.get_latent_path(recalculate=True)
# check if it is saved to disk already
if os.path.exists(latent_path):
if to_memory:
# load it into memory
state_dict = load_file(latent_path, device='cpu')
file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
else:
# not saved to disk, calculate
# load the image first
file_item.load_and_process_image(self.transform, only_load_latents=True)
dtype = self.sd.torch_dtype
device = self.sd.device_torch
# add batch dimension
try:
imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
latent = self.sd.encode_images(imgs).squeeze(0)
except Exception as e:
print_acc(f"Error processing image: {file_item.path}")
print_acc(f"Error: {str(e)}")
raise e
# save_latent
if to_disk:
state_dict = OrderedDict([
('latent', latent.clone().detach().cpu()),
])
# metadata
meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
os.makedirs(os.path.dirname(latent_path), exist_ok=True)
save_file(state_dict, latent_path, metadata=meta)
if to_memory:
# keep it in memory
file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype)
del imgs
del latent
del file_item.tensor
# flush(garbage_collect=False)
file_item.is_latent_cached = True
i += 1
# flush every 100
# if i % 100 == 0:
# flush()
# restore device state
self.sd.restore_device_state()
class CLIPCachingMixin:
def __init__(self: 'AiToolkitDataset', **kwargs):
# if we have super, call it


@@ -137,7 +137,8 @@ class ExponentialMovingAverage:
update_param = False
if self.use_feedback:
param_float.add_(tmp)
# scale the feedback term by 10x
param_float.add_(tmp * 10)
update_param = True
if self.param_multiplier != 1.0:


@@ -7,7 +7,7 @@ import re
import sys
from typing import List, Optional, Dict, Type, Union
import torch
from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel
from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel, WanTransformer3DModel
from transformers import CLIPTextModel
from toolkit.models.lokr import LokrModule
@@ -522,6 +522,14 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
transformer.pos_embed = self.transformer_pos_embed
transformer.proj_out = self.transformer_proj_out
elif base_model is not None and base_model.arch == "wan21":
transformer: WanTransformer3DModel = unet
self.transformer_pos_embed = copy.deepcopy(transformer.patch_embedding)
self.transformer_proj_out = copy.deepcopy(transformer.proj_out)
transformer.patch_embedding = self.transformer_pos_embed
transformer.proj_out = self.transformer_proj_out
else:
unet: UNet2DConditionModel = unet
@@ -539,7 +547,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)
if self.full_train_in_out:
if self.is_pixart or self.is_auraflow or self.is_flux:
base_model = self.base_model_ref() if self.base_model_ref is not None else None
if self.is_pixart or self.is_auraflow or self.is_flux or (base_model is not None and base_model.arch == "wan21"):
all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
else:


@@ -13,6 +13,22 @@ def total_variation(image):
n_elements = image.shape[1] * image.shape[2] * image.shape[3]
return ((torch.sum(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) +
torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements)
def total_variation_deltas(image):
"""
Compute per-pixel total variation deltas.
Input:
- image: Tensor of shape (N, C, H, W)
Returns:
- Tensor with shape (N, C, H, W), padded to match input shape
"""
dh = torch.zeros_like(image)
dv = torch.zeros_like(image)
dh[:, :, :, :-1] = torch.abs(image[:, :, :, 1:] - image[:, :, :, :-1])
dv[:, :, :-1, :] = torch.abs(image[:, :, 1:, :] - image[:, :, :-1, :])
return dh + dv
class ComparativeTotalVariation(torch.nn.Module):
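For a single image, the mean over the per-pixel delta map equals the scalar returned by the existing total_variation helper, which gives a quick sanity check of the new function (assumes the repo root is on PYTHONPATH):

import torch
from toolkit.losses import total_variation, total_variation_deltas

x = torch.randn(1, 3, 16, 16)
print(total_variation(x))                # scalar TV normalized by C*H*W
print(total_variation_deltas(x).mean())  # same value for a batch of 1, up to float error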


@@ -1,3 +1,4 @@
import math
import torch
import os
from torch import nn
@@ -351,12 +352,252 @@ class DiffusionFeatureExtractor3(nn.Module):
return total_loss
class DiffusionFeatureExtractor4(nn.Module):
def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None):
super().__init__()
self.version = 4
if vae is None:
raise ValueError("vae must be provided for DFE4")
self.vae = vae
# image_encoder_path = "google/siglip-so400m-patch14-384"
image_encoder_path = "google/siglip2-so400m-patch16-naflex"
from transformers import Siglip2ImageProcessor, Siglip2VisionModel
try:
self.image_processor = Siglip2ImageProcessor.from_pretrained(
image_encoder_path)
except EnvironmentError:
self.image_processor = Siglip2ImageProcessor()
self.image_processor.max_num_patches = 1024
self.vision_encoder = Siglip2VisionModel.from_pretrained(
image_encoder_path,
ignore_mismatched_sizes=True
).to(device, dtype=dtype)
self.losses = {}
self.log_every = 100
self.step = 0
def _target_hw(self, h, w, patch, max_patches, eps: float = 1e-5):
def _snap(x, s):
x = math.ceil((x * s) / patch) * patch
return max(patch, int(x))
lo, hi = eps / 10, 1.0
while hi - lo >= eps:
mid = (lo + hi) / 2
th, tw = _snap(h, mid), _snap(w, mid)
if (th // patch) * (tw // patch) <= max_patches:
lo = mid
else:
hi = mid
return _snap(h, lo), _snap(w, lo)
def tensors_to_siglip_like_features(self, batch: torch.Tensor):
"""
Args:
batch: (bs, 3, H, W) tensor already in the desired value range
(e.g. [-1, 1] or [0, 1]); no extra rescale / normalize here.
Returns:
dict(
pixel_values (bs, L, P) where L = n_h*n_w, P = 3*patch*patch
pixel_attention_mask (bs, L) all-ones
spatial_shapes (bs, 2) holding (n_h, n_w) per sample
)
"""
if batch.ndim != 4:
raise ValueError("Expected (bs, 3, H, W) tensor")
bs, c, H, W = batch.shape
proc = self.image_processor
patch = proc.patch_size
max_patches = proc.max_num_patches
# One shared resize for the whole batch
tgt_h, tgt_w = self._target_hw(H, W, patch, max_patches)
batch = torch.nn.functional.interpolate(
batch, size=(tgt_h, tgt_w), mode="bilinear", align_corners=False
)
n_h, n_w = tgt_h // patch, tgt_w // patch
# flat_dim = c * patch * patch
num_p = n_h * n_w
# unfold → (bs, flat_dim, num_p) → (bs, num_p, flat_dim)
patches = (
torch.nn.functional.unfold(batch, kernel_size=patch, stride=patch)
.transpose(1, 2)
)
attn_mask = torch.ones(num_p, dtype=torch.long, device=batch.device)
spatial = torch.tensor((n_h, n_w), device=batch.device, dtype=torch.int32)
# repeat attn_mask for each batch element
attn_mask = attn_mask.unsqueeze(0).repeat(bs, 1)
spatial = spatial.unsqueeze(0).repeat(bs, 1)
return {
"pixel_values": patches, # (bs, num_patches, patch_dim)
"pixel_attention_mask": attn_mask, # (bs, num_patches)
"spatial_shapes": spatial
}
def get_siglip_features(self, tensors_0_1):
dtype = torch.bfloat16
device = self.vae.device
tensors_0_1 = torch.clamp(tensors_0_1, 0.0, 1.0)
mean = torch.tensor(self.image_processor.image_mean).to(
device, dtype=dtype
).detach()
std = torch.tensor(self.image_processor.image_std).to(
device, dtype=dtype
).detach()
# tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
encoder_kwargs = self.tensors_to_siglip_like_features(clip_image)
id_embeds = self.vision_encoder(
pixel_values=encoder_kwargs['pixel_values'],
pixel_attention_mask=encoder_kwargs['pixel_attention_mask'],
spatial_shapes=encoder_kwargs['spatial_shapes'],
output_hidden_states=True,
)
# embeds = id_embeds['hidden_states'][-2] # penultimate layer
image_embeds = id_embeds['pooler_output']
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
return image_embeds
def forward(
self,
noise,
noise_pred,
noisy_latents,
timesteps,
batch: DataLoaderBatchDTO,
scheduler: CustomFlowMatchEulerDiscreteScheduler,
clip_weight=1.0,
mse_weight=0.0,
model=None
):
dtype = torch.bfloat16
device = self.vae.device
tensors = batch.tensor.to(device, dtype=dtype)
is_video = False
# stack time for video models on the batch dimension
if len(noise_pred.shape) == 5:
# B, C, T, H, W = images.shape
# only take first time
noise = noise[:, :, 0, :, :]
noise_pred = noise_pred[:, :, 0, :, :]
noisy_latents = noisy_latents[:, :, 0, :, :]
is_video = True
if len(tensors.shape) == 5:
# batch is different
# (B, T, C, H, W)
# only take first time
tensors = tensors[:, 0, :, :, :]
if model is not None and hasattr(model, 'get_stepped_pred'):
stepped_latents = model.get_stepped_pred(noise_pred, noise)
else:
# stepped_latents = noise - noise_pred
# first we step the scheduler from current timestep to the very end for a full denoise
bs = noise_pred.shape[0]
noise_pred_chunks = torch.chunk(noise_pred, bs)
timestep_chunks = torch.chunk(timesteps, bs)
noisy_latent_chunks = torch.chunk(noisy_latents, bs)
stepped_chunks = []
for idx in range(bs):
model_output = noise_pred_chunks[idx]
timestep = timestep_chunks[idx]
scheduler._step_index = None
scheduler._init_step_index(timestep)
sample = noisy_latent_chunks[idx].to(torch.float32)
sigma = scheduler.sigmas[scheduler.step_index]
sigma_next = scheduler.sigmas[-1] # use last sigma for final step
prev_sample = sample + (sigma_next - sigma) * model_output
stepped_chunks.append(prev_sample)
stepped_latents = torch.cat(stepped_chunks, dim=0)
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
scaling_factor = self.vae.config['scaling_factor'] if 'scaling_factor' in self.vae.config else 1.0
shift_factor = self.vae.config['shift_factor'] if 'shift_factor' in self.vae.config else 0.0
latents = (latents / scaling_factor) + shift_factor
if is_video:
# if video, we need to unsqueeze the latents to match the vae input shape
latents = latents.unsqueeze(2)
tensors_n1p1 = self.vae.decode(latents).sample # -1 to 1
if is_video:
# if video, we need to squeeze the tensors to match the output shape
tensors_n1p1 = tensors_n1p1.squeeze(2)
pred_images = (tensors_n1p1 + 1) / 2 # 0 to 1
total_loss = 0
with torch.no_grad():
target_img = tensors.to(device, dtype=dtype)
# go from -1 to 1 to 0 to 1
target_img = (target_img + 1) / 2
if clip_weight > 0:
target_clip_output = self.get_siglip_features(target_img).detach()
if clip_weight > 0:
pred_clip_output = self.get_siglip_features(pred_images)
clip_loss = torch.nn.functional.mse_loss(
pred_clip_output.float(), target_clip_output.float()
) * clip_weight
if 'clip_loss' not in self.losses:
self.losses['clip_loss'] = clip_loss.item()
else:
self.losses['clip_loss'] += clip_loss.item()
total_loss += clip_loss
if mse_weight > 0:
mse_loss = torch.nn.functional.mse_loss(
pred_images.float(), target_img.float()
) * mse_weight
if 'mse_loss' not in self.losses:
self.losses['mse_loss'] = mse_loss.item()
else:
self.losses['mse_loss'] += mse_loss.item()
total_loss += mse_loss
if self.step % self.log_every == 0 and self.step > 0:
print(f"DFE losses:")
for key in self.losses:
self.losses[key] /= self.log_every
# print in 2.000e-01 format
print(f" - {key}: {self.losses[key]:.3e}")
self.losses[key] = 0.0
# total_loss += mse_loss
self.step += 1
return total_loss
def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
if model_path == "v3":
dfe = DiffusionFeatureExtractor3(vae=vae)
dfe.eval()
return dfe
if model_path == "v4":
dfe = DiffusionFeatureExtractor4(vae=vae)
dfe.eval()
return dfe
if not os.path.exists(model_path):
raise FileNotFoundError(f"Model file not found: {model_path}")
# if it ends with safetensors
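For context on the patch-grid sizing in DiffusionFeatureExtractor4 above: _target_hw bisects for the largest scale whose patch-aligned size still fits the processor's patch budget. A standalone sketch of the same search, assuming a patch size of 16 (implied by the patch16-naflex checkpoint name) and the max_num_patches of 1024 set in the fallback branch:

import math

def target_hw(h, w, patch=16, max_patches=1024, eps=1e-5):
    # mirrors DiffusionFeatureExtractor4._target_hw: largest scale whose snapped size fits the budget
    def snap(x, s):
        return max(patch, int(math.ceil((x * s) / patch) * patch))
    lo, hi = eps / 10, 1.0
    while hi - lo >= eps:
        mid = (lo + hi) / 2
        if (snap(h, mid) // patch) * (snap(w, mid) // patch) <= max_patches:
            lo = mid
        else:
            hi = mid
    return snap(h, lo), snap(w, lo)

print(target_hw(1024, 1024))  # (512, 512): a 32x32 grid, exactly 1024 patches
print(target_hw(1024, 768))   # a smaller aspect-preserving size that still fits the budget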


@@ -0,0 +1,865 @@
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import logging
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.activations import get_activation
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
import copy
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
CACHE_T = 2
class WanCausalConv3d(nn.Conv3d):
r"""
A custom 3D causal convolution layer with feature caching support.
This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
caching for efficient inference.
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the convolving kernel
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
) -> None:
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
)
# Set up causal padding
self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = F.pad(x, padding)
return super().forward(x)
class WanRMS_norm(nn.Module):
r"""
A custom RMS normalization layer.
Args:
dim (int): The number of dimensions to normalize over.
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
Default is True.
images (bool, optional): Whether the input represents image data. Default is True.
bias (bool, optional): Whether to include a learnable bias term. Default is False.
"""
def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class WanUpsample(nn.Upsample):
r"""
Perform upsampling while ensuring the output tensor has the same data type as the input.
Args:
x (torch.Tensor): Input tensor to be upsampled.
Returns:
torch.Tensor: Upsampled tensor with the same data type as the input.
"""
def forward(self, x):
return super().forward(x.float()).type_as(x)
class WanResample(nn.Module):
r"""
A custom resampling module for 2D and 3D data.
Args:
dim (int): The number of input/output channels.
mode (str): The resampling mode. Must be one of:
- 'none': No resampling (identity operation).
- 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
- 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
- 'downsample2d': 2D downsampling with zero-padding and convolution.
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
"""
def __init__(self, dim: int, mode: str) -> None:
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == "upsample2d":
self.resample = nn.Sequential(
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
)
self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == "downsample2d":
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == "downsample3d":
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == "upsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = "Rep"
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
# cache last frame of last two chunks
cache_x = torch.cat(
[feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
)
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
if feat_cache[idx] == "Rep":
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
x = self.resample(x)
x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
if self.mode == "downsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
class WanResidualBlock(nn.Module):
r"""
A custom residual block module.
Args:
in_dim (int): Number of input channels.
out_dim (int): Number of output channels.
dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
"""
def __init__(
self,
in_dim: int,
out_dim: int,
dropout: float = 0.0,
non_linearity: str = "silu",
) -> None:
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.nonlinearity = get_activation(non_linearity)
# layers
self.norm1 = WanRMS_norm(in_dim, images=False)
self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1)
self.norm2 = WanRMS_norm(out_dim, images=False)
self.dropout = nn.Dropout(dropout)
self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1)
self.conv_shortcut = WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
# Apply shortcut connection
h = self.conv_shortcut(x)
# First normalization and activation
x = self.norm1(x)
x = self.nonlinearity(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
# Second normalization and activation
x = self.norm2(x)
x = self.nonlinearity(x)
# Dropout
x = self.dropout(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv2(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv2(x)
# Add residual connection
return x + h
class WanAttentionBlock(nn.Module):
r"""
Causal self-attention with a single head.
Args:
dim (int): The number of channels in the input tensor.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = WanRMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
def forward(self, x):
identity = x
batch_size, channels, time, height, width = x.size()
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
x = self.norm(x)
# compute query, key, value
qkv = self.to_qkv(x)
qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
qkv = qkv.permute(0, 1, 3, 2).contiguous()
q, k, v = qkv.chunk(3, dim=-1)
# apply attention
x = F.scaled_dot_product_attention(q, k, v)
x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
# output projection
x = self.proj(x)
# Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
x = x.view(batch_size, time, channels, height, width)
x = x.permute(0, 2, 1, 3, 4)
return x + identity
class WanMidBlock(nn.Module):
"""
Middle block for WanVAE encoder and decoder.
Args:
dim (int): Number of input/output channels.
dropout (float): Dropout rate.
non_linearity (str): Type of non-linearity to use.
"""
def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
super().__init__()
self.dim = dim
# Create the components
resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)]
attentions = []
for _ in range(num_layers):
attentions.append(WanAttentionBlock(dim))
resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity))
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
# First residual block
x = self.resnets[0](x, feat_cache, feat_idx)
# Process through attention and residual blocks
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
x = attn(x)
x = resnet(x, feat_cache, feat_idx)
return x
class WanEncoder3d(nn.Module):
r"""
A 3D encoder module.
Args:
dim (int): The base number of channels in the first layer.
z_dim (int): The dimensionality of the latent space.
dim_mult (list of int): Multipliers for the number of channels in each block.
num_res_blocks (int): Number of residual blocks in each block.
attn_scales (list of float): Scales at which to apply attention mechanisms.
temperal_downsample (list of bool): Whether to downsample temporally in each block.
dropout (float): Dropout rate for the dropout layers.
non_linearity (str): Type of non-linearity to use.
"""
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
non_linearity: str = "silu",
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.nonlinearity = get_activation(non_linearity)
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
# downsample blocks
self.down_blocks = nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
self.down_blocks.append(WanAttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
self.down_blocks.append(WanResample(out_dim, mode=mode))
scale /= 2.0
# middle blocks
self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
# output blocks
self.norm_out = WanRMS_norm(out_dim, images=False)
self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunks
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_in(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_in(x)
## downsamples
for layer in self.down_blocks:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
x = self.mid_block(x, feat_cache, feat_idx)
## head
x = self.norm_out(x)
x = self.nonlinearity(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunks
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_out(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_out(x)
return x
class WanUpBlock(nn.Module):
"""
A block that handles upsampling for the WanVAE decoder.
Args:
in_dim (int): Input dimension
out_dim (int): Output dimension
num_res_blocks (int): Number of residual blocks
dropout (float): Dropout rate
upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
non_linearity (str): Type of non-linearity to use
"""
def __init__(
self,
in_dim: int,
out_dim: int,
num_res_blocks: int,
dropout: float = 0.0,
upsample_mode: Optional[str] = None,
non_linearity: str = "silu",
):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# Create layers list
resnets = []
# Add residual blocks and attention if needed
current_dim = in_dim
for _ in range(num_res_blocks + 1):
resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
current_dim = out_dim
self.resnets = nn.ModuleList(resnets)
# Add upsampling layer if needed
self.upsamplers = None
if upsample_mode is not None:
self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)])
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
"""
Forward pass through the upsampling block.
Args:
x (torch.Tensor): Input tensor
feat_cache (list, optional): Feature cache for causal convolutions
feat_idx (list, optional): Feature index for cache management
Returns:
torch.Tensor: Output tensor
"""
for resnet in self.resnets:
if feat_cache is not None:
x = resnet(x, feat_cache, feat_idx)
else:
x = resnet(x)
if self.upsamplers is not None:
if feat_cache is not None:
x = self.upsamplers[0](x, feat_cache, feat_idx)
else:
x = self.upsamplers[0](x)
return x
class WanDecoder3d(nn.Module):
r"""
A 3D decoder module.
Args:
dim (int): The base number of channels in the first layer.
z_dim (int): The dimensionality of the latent space.
dim_mult (list of int): Multipliers for the number of channels in each block.
num_res_blocks (int): Number of residual blocks in each block.
attn_scales (list of float): Scales at which to apply attention mechanisms.
temperal_upsample (list of bool): Whether to upsample temporally in each block.
dropout (float): Dropout rate for the dropout layers.
non_linearity (str): Type of non-linearity to use.
"""
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0,
non_linearity: str = "silu",
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
self.nonlinearity = get_activation(non_linearity)
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2 ** (len(dim_mult) - 2)
# init block
self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
# upsample blocks
self.up_blocks = nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i > 0:
in_dim = in_dim // 2
# Determine if we need upsampling
upsample_mode = None
if i != len(dim_mult) - 1:
upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
# Create and add the upsampling block
up_block = WanUpBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks,
dropout=dropout,
upsample_mode=upsample_mode,
non_linearity=non_linearity,
)
self.up_blocks.append(up_block)
# Update scale for next iteration
if upsample_mode is not None:
scale *= 2.0
# output blocks
self.norm_out = WanRMS_norm(out_dim, images=False)
self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunks
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_in(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_in(x)
## middle
if torch.is_grad_enabled() and self.gradient_checkpointing:
# middle
x = self._gradient_checkpointing_func(self.mid_block, x, feat_cache, feat_idx)
## upsamples
for up_block in self.up_blocks:
x = self._gradient_checkpointing_func(up_block, x, feat_cache, feat_idx)
else:
x = self.mid_block(x, feat_cache, feat_idx)
## upsamples
for up_block in self.up_blocks:
x = up_block(x, feat_cache, feat_idx)
## head
x = self.norm_out(x)
x = self.nonlinearity(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunks
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_out(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_out(x)
return x
class AutoencoderKLWan(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
Introduced in [Wan 2.1].
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving).
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
base_dim: int = 96,
z_dim: int = 16,
dim_mult: Tuple[int] = [1, 2, 4, 4],
num_res_blocks: int = 2,
attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True],
dropout: float = 0.0,
latents_mean: List[float] = [
-0.7571,
-0.7089,
-0.9113,
0.1075,
-0.1745,
0.9653,
-0.1517,
1.5508,
0.4134,
-0.0715,
0.5517,
-0.3632,
-0.1922,
-0.9497,
0.2503,
-0.2921,
],
latents_std: List[float] = [
2.8184,
1.4541,
2.3275,
2.6558,
1.2196,
1.7708,
2.6052,
2.0743,
3.2687,
2.1526,
2.8652,
1.5579,
1.6382,
1.1253,
2.8251,
1.9160,
],
) -> None:
super().__init__()
self.z_dim = z_dim
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
self.encoder = WanEncoder3d(
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
)
self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
self.decoder = WanDecoder3d(
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
)
def clear_cache(self):
def _count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, WanCausalConv3d):
count += 1
return count
self._conv_num = _count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
self._enc_conv_num = _count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
def _encode(self, x: torch.Tensor) -> torch.Tensor:
self.clear_cache()
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
out = torch.cat([out, out_], 2)
enc = self.quant_conv(out)
mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
enc = torch.cat([mu, logvar], dim=1)
self.clear_cache()
return enc
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
r"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
self.clear_cache()
iter_ = z.shape[2]
x = self.post_quant_conv(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
else:
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
out = torch.clamp(out, min=-1.0, max=1.0)
self.clear_cache()
if not return_dict:
return (out,)
return DecoderOutput(sample=out)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
"""
Args:
sample (`torch.Tensor`): Input sample.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, return_dict=return_dict)
return dec
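This file is a local copy of the Wan 2.1 video VAE, imported in the next file in place of the diffusers export. A rough usage sketch under the default config; the shapes in the comments are expectations from the 8x spatial and 4x causal temporal compression, not values taken from this diff:

import torch
# assumption about location: run next to this file, matching the relative
# `from .autoencoder_kl_wan import AutoencoderKLWan` added in the next file
from autoencoder_kl_wan import AutoencoderKLWan

vae = AutoencoderKLWan()              # default config: z_dim=16, base_dim=96
video = torch.randn(1, 3, 9, 64, 64)  # (B, C, T, H, W) with T = 1 + 4k frames
with torch.no_grad():
    latent = vae.encode(video).latent_dist.sample()
    recon = vae.decode(latent).sample  # clamped to [-1, 1]
print(latent.shape)  # expected (1, 16, 3, 8, 8): first frame plus two groups of four
print(recon.shape)   # expected (1, 3, 9, 64, 64)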


@@ -9,7 +9,8 @@ from toolkit.dequantize import patch_dequantization_on_save
from toolkit.models.base_model import BaseModel
from toolkit.prompt_utils import PromptEmbeds
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers import AutoencoderKLWan, WanPipeline, WanTransformer3DModel
from diffusers import WanPipeline, WanTransformer3DModel, AutoencoderKL
from .autoencoder_kl_wan import AutoencoderKLWan
import os
import sys


@@ -4,6 +4,7 @@ from torch.distributions import LogNormal
from diffusers import FlowMatchEulerDiscreteScheduler
import torch
import numpy as np
from toolkit.timestep_weighing.default_weighing_scheme import default_weighing_scheme
def calculate_shift(
@@ -47,20 +48,26 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
hbsmntw_weighing[num_timesteps //
2:] = hbsmntw_weighing[num_timesteps // 2:].max()
# Create linear timesteps from 1000 to 0
timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu')
# Create linear timesteps from 1000 to 1
timesteps = torch.linspace(1000, 1, num_timesteps, device='cpu')
self.linear_timesteps = timesteps
self.linear_timesteps_weights = bsmntw_weighing
self.linear_timesteps_weights2 = hbsmntw_weighing
pass
def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor:
def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False, timestep_type="linear") -> torch.Tensor:
# Get the indices of the timesteps
step_indices = [(self.timesteps == t).nonzero().item()
for t in timesteps]
# Get the weights for the timesteps
if timestep_type == "weighted":
weights = torch.tensor(
[default_weighing_scheme[i] for i in step_indices],
device=timesteps.device,
dtype=timesteps.dtype
)
if v2:
weights = self.linear_timesteps_weights2[step_indices].flatten()
else:
@@ -106,8 +113,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
patch_size=1
):
self.timestep_type = timestep_type
if timestep_type == 'linear':
timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
if timestep_type == 'linear' or timestep_type == 'weighted':
timesteps = torch.linspace(1000, 1, num_timesteps, device=device)
self.timesteps = timesteps
return timesteps
elif timestep_type == 'sigmoid':
@@ -198,7 +205,7 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
t1 = ((1 - t1/t1.max()) * 1000)
# add half of linear
t2 = torch.linspace(1000, 0, int(
t2 = torch.linspace(1000, 1, int(
num_timesteps * (1 - alpha)), device=device)
timesteps = torch.cat((t1, t2))


File diff suppressed because it is too large.

Binary file not shown (190 KiB).


@@ -33,7 +33,6 @@ export default function SimpleJob({
gpuList,
datasetOptions,
}: Props) {
const modelArch = useMemo(() => {
return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch;
}, [jobConfig.config.process[0].model.arch]);
@@ -104,8 +103,7 @@ export default function SimpleJob({
const newDataset = objectCopy(dataset);
newDataset.controls = controls;
return newDataset;
}
);
});
setJobConfig(datasets, 'config.process[0].datasets');
}}
options={
@@ -131,20 +129,22 @@ export default function SimpleJob({
placeholder=""
required
/>
<FormGroup label="Quantize">
<div className="grid grid-cols-2 gap-2">
<Checkbox
label="Transformer"
checked={jobConfig.config.process[0].model.quantize}
onChange={value => setJobConfig(value, 'config.process[0].model.quantize')}
/>
<Checkbox
label="Text Encoder"
checked={jobConfig.config.process[0].model.quantize_te}
onChange={value => setJobConfig(value, 'config.process[0].model.quantize_te')}
/>
</div>
</FormGroup>
{modelArch?.disableSections?.includes('model.quantize') ? null : (
<FormGroup label="Quantize">
<div className="grid grid-cols-2 gap-2">
<Checkbox
label="Transformer"
checked={jobConfig.config.process[0].model.quantize}
onChange={value => setJobConfig(value, 'config.process[0].model.quantize')}
/>
<Checkbox
label="Text Encoder"
checked={jobConfig.config.process[0].model.quantize_te}
onChange={value => setJobConfig(value, 'config.process[0].model.quantize_te')}
/>
</div>
</FormGroup>
)}
</Card>
<Card title="Target Configuration">
<SelectInput
@@ -171,19 +171,37 @@ export default function SimpleJob({
/>
)}
{jobConfig.config.process[0].network?.type == 'lora' && (
<NumberInput
label="Linear Rank"
value={jobConfig.config.process[0].network.linear}
onChange={value => {
console.log('onChange', value);
setJobConfig(value, 'config.process[0].network.linear');
setJobConfig(value, 'config.process[0].network.linear_alpha');
}}
placeholder="eg. 16"
min={0}
max={1024}
required
/>
<>
<NumberInput
label="Linear Rank"
value={jobConfig.config.process[0].network.linear}
onChange={value => {
console.log('onChange', value);
setJobConfig(value, 'config.process[0].network.linear');
setJobConfig(value, 'config.process[0].network.linear_alpha');
}}
placeholder="eg. 16"
min={0}
max={1024}
required
/>
{
modelArch?.disableSections?.includes('network.conv') ? null : (
<NumberInput
label="Conv Rank"
value={jobConfig.config.process[0].network.conv}
onChange={value => {
console.log('onChange', value);
setJobConfig(value, 'config.process[0].network.conv');
setJobConfig(value, 'config.process[0].network.conv_alpha');
}}
placeholder="eg. 16"
min={0}
max={1024}
/>
)
}
</>
)}
</Card>
<Card title="Save Configuration">
@@ -276,16 +294,19 @@ export default function SimpleJob({
/>
</div>
<div>
<SelectInput
label="Timestep Type"
value={jobConfig.config.process[0].train.timestep_type}
onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')}
options={[
{ value: 'sigmoid', label: 'Sigmoid' },
{ value: 'linear', label: 'Linear' },
{ value: 'shift', label: 'Shift' },
]}
/>
{modelArch?.disableSections?.includes('train.timestep_type') ? null : (
<SelectInput
label="Timestep Type"
value={jobConfig.config.process[0].train.timestep_type}
disabled={modelArch?.disableSections?.includes('train.timestep_type') || false}
onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')}
options={[
{ value: 'sigmoid', label: 'Sigmoid' },
{ value: 'linear', label: 'Linear' },
{ value: 'shift', label: 'Shift' },
]}
/>
)}
<SelectInput
label="Timestep Bias"
className="pt-2"
@@ -345,7 +366,7 @@ export default function SimpleJob({
/>
</FormGroup>
<NumberInput
label="DFE Loss Multiplier"
label="DOP Loss Multiplier"
className="pt-2"
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')}
@@ -353,7 +374,7 @@ export default function SimpleJob({
min={0}
/>
<TextInput
label="DFE Preservation Class"
label="DOP Preservation Class"
className="pt-2"
value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')}
@@ -466,10 +487,7 @@ export default function SimpleJob({
              // automatically add the controls for a new dataset
const controls = modelArch?.controls ?? [];
newDataset.controls = controls;
setJobConfig(
[...jobConfig.config.process[0].datasets, newDataset],
'config.process[0].datasets',
)
setJobConfig([...jobConfig.config.process[0].datasets, newDataset], 'config.process[0].datasets');
}}
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
>

View File

@@ -30,6 +30,8 @@ export const defaultJobConfig: JobConfig = {
type: 'lora',
linear: 32,
linear_alpha: 32,
conv: 16,
conv_alpha: 16,
lokr_full_rank: true,
lokr_factor: -1,
network_kwargs: {

View File

@@ -1,4 +1,3 @@
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
export interface ModelArch {
@@ -6,11 +5,14 @@ export interface ModelArch {
label: string;
controls?: Control[];
isVideoModel?: boolean;
defaults?: { [key: string]: [any, any] };
defaults?: { [key: string]: any };
disableSections?: DisableableSections[];
}
const defaultNameOrPath = '';
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
export const modelArchs: ModelArch[] = [
{
name: 'flux',
@@ -23,6 +25,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
disableSections: ['network.conv'],
},
{
name: 'flex1',
@@ -36,6 +39,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
disableSections: ['network.conv'],
},
{
name: 'flex2',
@@ -62,6 +66,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
disableSections: ['network.conv'],
},
{
name: 'chroma',
@@ -74,6 +79,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
disableSections: ['network.conv'],
},
{
name: 'wan21:1b',
@@ -89,6 +95,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.num_frames': [40, 1],
'config.process[0].sample.fps': [15, 1],
},
disableSections: ['network.conv'],
},
{
name: 'wan21:14b',
@@ -96,7 +103,7 @@ export const modelArchs: ModelArch[] = [
isVideoModel: true,
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-T2V-14B-Diffuserss', defaultNameOrPath],
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-T2V-14B-Diffusers', defaultNameOrPath],
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
@@ -104,6 +111,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.num_frames': [40, 1],
'config.process[0].sample.fps': [15, 1],
},
disableSections: ['network.conv'],
},
{
name: 'lumina2',
@@ -116,6 +124,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
disableSections: ['network.conv'],
},
{
name: 'hidream',
@@ -131,5 +140,37 @@ export const modelArchs: ModelArch[] = [
'config.process[0].train.timestep_type': ['shift', 'sigmoid'],
'config.process[0].network.network_kwargs.ignore_if_contains': [['ff_i.experts', 'ff_i.gate'], []],
},
disableSections: ['network.conv'],
},
];
{
name: 'sdxl',
label: 'SDXL',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['stabilityai/stable-diffusion-xl-base-1.0', defaultNameOrPath],
'config.process[0].model.quantize': [false, false],
'config.process[0].model.quantize_te': [false, false],
'config.process[0].sample.sampler': ['ddpm', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'],
'config.process[0].sample.guidance_scale': [6, 4],
},
disableSections: ['model.quantize', 'train.timestep_type'],
},
{
name: 'sd15',
label: 'SD 1.5',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['stable-diffusion-v1-5/stable-diffusion-v1-5', defaultNameOrPath],
'config.process[0].sample.sampler': ['ddpm', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'],
'config.process[0].sample.width': [512, 1024],
'config.process[0].sample.height': [512, 1024],
'config.process[0].sample.guidance_scale': [6, 4],
},
disableSections: ['model.quantize', 'train.timestep_type'],
},
].sort((a, b) => {
// Sort by label, case-insensitive
return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' })
}) as any;
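The `defaults` map above pairs each job-config path with a `[selected, unselected]` tuple: the first value is applied when the arch is picked, the second restores the baseline when it is cleared. As a minimal sketch of that idea (the `applyArchDefaults` and `setByPath` helpers below are hypothetical and not part of this change; the UI presumably routes such updates through its existing `setJobConfig` instead):

// Hypothetical helpers for illustration only; not part of this diff.
type ArchDefaults = { [path: string]: [unknown, unknown] };

// Write a value at a dotted/bracketed path such as 'config.process[0].model.quantize'.
function setByPath(obj: any, path: string, value: unknown): void {
  const keys = path.replace(/\[(\d+)\]/g, '.$1').split('.');
  let target = obj;
  for (const key of keys.slice(0, -1)) {
    if (target[key] === undefined) target[key] = {};
    target = target[key];
  }
  target[keys[keys.length - 1]] = value;
}

// Apply every [selected, unselected] tuple for the chosen arch.
function applyArchDefaults(jobConfig: any, defaults: ArchDefaults, selected: boolean): void {
  for (const [path, [onSelect, onDeselect]] of Object.entries(defaults)) {
    setByPath(jobConfig, path, selected ? onSelect : onDeselect);
  }
}

Under that reading, selecting 'sdxl' would set 'config.process[0].train.noise_scheduler' to 'ddpm', and deselecting it would restore 'flowmatch'.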

View File

@@ -11,11 +11,12 @@ const Sidebar = () => {
];
return (
<div className="flex flex-col w-64 bg-gray-900 text-gray-100">
<div className="p-4">
<h1 className="text-xl">
<img src="/ostris_logo.png" alt="Ostris AI Toolkit" className="w-auto h-8 mr-3 inline" />
Ostris - AI Toolkit
<div className="flex flex-col w-59 bg-gray-900 text-gray-100">
<div className="px-4 py-3">
<h1 className="text-l">
<img src="/ostris_logo.png" alt="Ostris AI Toolkit" className="w-auto h-7 mr-3 inline" />
<span className="font-bold uppercase">Ostris</span>
<span className='ml-2 uppercase text-gray-300'>AI-Toolkit</span>
</h1>
</div>
<nav className="flex-1">
@@ -33,31 +34,18 @@ const Sidebar = () => {
))}
</ul>
</nav>
<a href="https://patreon.com/ostris" target="_blank" rel="noreferrer" className="flex items-center space-x-2 p-4">
<a href="https://ostris.com/support" target="_blank" rel="noreferrer" className="flex items-center space-x-2 px-4 py-3">
<div className="min-w-[26px] min-h-[26px]">
<svg
viewBox="0 0 512 512"
xmlns="http://www.w3.org/2000/svg"
fillRule="evenodd"
clipRule="evenodd"
strokeLinejoin="round"
strokeMiterlimit="2"
>
<g transform="matrix(.47407 0 0 .47407 .383 .422)">
<clipPath id="prefix__a">
<path d="M0 0h1080v1080H0z"></path>
</clipPath>
<g clipPath="url(#prefix__a)">
<path
d="M1033.05 324.45c-.19-137.9-107.59-250.92-233.6-291.7-156.48-50.64-362.86-43.3-512.28 27.2-181.1 85.46-237.99 272.66-240.11 459.36-1.74 153.5 13.58 557.79 241.62 560.67 169.44 2.15 194.67-216.18 273.07-321.33 55.78-74.81 127.6-95.94 216.01-117.82 151.95-37.61 255.51-157.53 255.29-316.38z"
fillRule="nonzero"
fill="#ffffff"
></path>
</g>
<svg height="24" version="1.1" width="24" xmlns="http://www.w3.org/2000/svg">
<g transform="translate(0 -1028.4)">
<path
d="m7 1031.4c-1.5355 0-3.0784 0.5-4.25 1.7-2.3431 2.4-2.2788 6.1 0 8.5l9.25 9.8 9.25-9.8c2.279-2.4 2.343-6.1 0-8.5-2.343-2.3-6.157-2.3-8.5 0l-0.75 0.8-0.75-0.8c-1.172-1.2-2.7145-1.7-4.25-1.7z"
fill="#c0392b"
/>
</g>
</svg>
</div>
<div className="text-gray-500 text-md mb-2 flex-1 pt-2 pl-2">Support me on Patreon</div>
<div className="uppercase text-gray-500 text-sm mb-2 flex-1 pt-2 pl-0">Support AI-Toolkit</div>
</a>
</div>
);

View File

@@ -2,8 +2,8 @@
import React, { forwardRef } from 'react';
import classNames from 'classnames';
import dynamic from "next/dynamic";
const Select = dynamic(() => import("react-select"), { ssr: false });
import dynamic from 'next/dynamic';
const Select = dynamic(() => import('react-select'), { ssr: false });
const labelClasses = 'block text-xs mb-1 mt-2 text-gray-300';
const inputClasses =
@@ -42,7 +42,7 @@ export const TextInput = forwardRef<HTMLInputElement, TextInputProps>(
/>
</div>
);
}
},
);
// 👇 Helpful for debugging
@@ -114,6 +114,7 @@ export const NumberInput = (props: NumberInputProps) => {
export interface SelectInputProps extends InputProps {
value: string;
disabled?: boolean;
onChange: (value: string) => void;
options: { value: string; label: string }[];
}
@@ -122,11 +123,16 @@ export const SelectInput = (props: SelectInputProps) => {
const { label, value, onChange, options } = props;
const selectedOption = options.find(option => option.value === value);
return (
<div className={classNames(props.className)}>
<div
className={classNames(props.className, {
'opacity-30 cursor-not-allowed': props.disabled,
})}
>
{label && <label className={labelClasses}>{label}</label>}
<Select
value={selectedOption}
<Select
value={selectedOption}
options={options}
isDisabled={props.disabled}
className="aitk-react-select-container"
classNamePrefix="aitk-react-select"
onChange={selected => {

View File

@@ -53,6 +53,8 @@ export interface NetworkConfig {
type: string;
linear: number;
linear_alpha: number;
conv: number;
conv_alpha: number;
lokr_full_rank: boolean;
lokr_factor: number;
network_kwargs: {
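With `conv` and `conv_alpha` added to `NetworkConfig` (and defaulted to 16 in `defaultJobConfig` above), a LoRA network block now carries both a linear and a convolutional rank. A small illustrative object follows; the local type is a trimmed stand-in for the real `NetworkConfig`, which also includes the LoKr and `network_kwargs` fields:

// Trimmed stand-in type for illustration; the real NetworkConfig also
// includes lokr_full_rank, lokr_factor, and network_kwargs.
type LoraNetworkSketch = {
  type: string;
  linear: number;
  linear_alpha: number;
  conv: number;
  conv_alpha: number;
};

// Matches the defaults added in this change; the Conv Rank field is hidden
// in the UI for archs that list 'network.conv' in disableSections.
const exampleNetwork: LoraNetworkSketch = {
  type: 'lora',
  linear: 32,
  linear_alpha: 32,
  conv: 16,
  conv_alpha: 16,
};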

View File

@@ -1 +1 @@
VERSION = "0.2.9"
VERSION = "0.2.10"