diff --git a/.vscode/launch.json b/.vscode/launch.json index 02d5cacf..abbc2f3d 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -40,5 +40,17 @@ "console": "integratedTerminal", "justMyCode": false }, + { + "name": "Python: Debug Current File (cuda:1)", + "type": "python", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "env": { + "CUDA_LAUNCH_BLOCKING": "1", + "CUDA_VISIBLE_DEVICES": "1" + }, + "justMyCode": false + }, ] } \ No newline at end of file diff --git a/README.md b/README.md index 4509fbad..5e68605c 100644 --- a/README.md +++ b/README.md @@ -414,3 +414,12 @@ To learn more about LoKr, read more about it at [KohakuBlueleaf/LyCORIS](https:/ Everything else should work the same including layer targeting. + +## Updates + +### June 10, 2024 +- Decided to keep track of updates in the readme +- Added support for SDXL in the UI +- Added support for SD 1.5 in the UI +- Fixed Wan 2.1 14b name bug in the UI +- Added support for conv training in the UI for models that support it \ No newline at end of file diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 7468e0fc..5d0105f2 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -438,7 +438,7 @@ class SDTrainer(BaseSDTrainProcess): dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean") additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0 - elif self.dfe.version == 3: + elif self.dfe.version == 3 or self.dfe.version == 4: dfe_loss = self.dfe( noise=noise, noise_pred=noise_pred, @@ -501,15 +501,27 @@ class SDTrainer(BaseSDTrainProcess): loss = wavelet_loss(pred, batch.latents, noise) else: loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none") + + do_weighted_timesteps = False + if self.sd.is_flow_matching: + if self.train_config.linear_timesteps or self.train_config.linear_timesteps2: + do_weighted_timesteps = True + if self.train_config.timestep_type == "weighted": + # use the noise scheduler to get the weights for the timesteps + do_weighted_timesteps = True # handle linear timesteps and only adjust the weight of the timesteps - if self.sd.is_flow_matching and (self.train_config.linear_timesteps or self.train_config.linear_timesteps2): + if do_weighted_timesteps: # calculate the weights for the timesteps timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps( timesteps, - v2=self.train_config.linear_timesteps2 + v2=self.train_config.linear_timesteps2, + timestep_type=self.train_config.timestep_type ).to(loss.device, dtype=loss.dtype) - timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach() + if len(loss.shape) == 4: + timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach() + elif len(loss.shape) == 5: + timestep_weight = timestep_weight.view(-1, 1, 1, 1, 1).detach() loss = loss * timestep_weight if self.train_config.do_prior_divergence and prior_pred is not None: @@ -764,6 +776,7 @@ class SDTrainer(BaseSDTrainProcess): conditional_embeds: Union[PromptEmbeds, None] = None, unconditional_embeds: Union[PromptEmbeds, None] = None, batch: Optional['DataLoaderBatchDTO'] = None, + is_primary_pred: bool = False, **kwargs, ): dtype = get_torch_dtype(self.train_config.dtype) @@ -1553,6 +1566,7 @@ class SDTrainer(BaseSDTrainProcess): conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype), unconditional_embeds=unconditional_embeds,
batch=batch, + is_primary_pred=True, **pred_kwargs ) self.after_unet_predict() diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index b4f768d9..064e89fc 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1116,6 +1116,7 @@ class BaseSDTrainProcess(BaseTrainProcess): self.train_config.linear_timesteps, self.train_config.linear_timesteps2, self.train_config.timestep_type == 'linear', + self.train_config.timestep_type == 'one_step', ]) timestep_type = 'linear' if linear_timesteps else None @@ -1159,6 +1160,8 @@ class BaseSDTrainProcess(BaseTrainProcess): device=self.device_torch ) timestep_indices = timestep_indices.long() + elif self.train_config.timestep_type == 'one_step': + timestep_indices = torch.zeros((batch_size,), device=self.device_torch, dtype=torch.long) elif content_or_style in ['style', 'content']: # this is from diffusers training code # Cubic sampling for favoring later or earlier timesteps diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py index 0b6d5fab..4530b7cd 100644 --- a/jobs/process/TrainVAEProcess.py +++ b/jobs/process/TrainVAEProcess.py @@ -18,7 +18,7 @@ from jobs.process import BaseTrainProcess from toolkit.image_utils import show_tensors from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm from toolkit.data_loader import ImageDataset -from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss, total_variation +from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss, total_variation, total_variation_deltas from toolkit.metadata import get_meta_for_safetensors from toolkit.optimizer import get_optimizer from toolkit.style import get_style_model_and_losses @@ -283,10 +283,33 @@ class TrainVAEProcess(BaseTrainProcess): else: return torch.tensor(0.0, device=self.device) - def get_ltv_loss(self, latent): + def get_ltv_loss(self, latent, images): # loss to reduce the latent space variance if self.ltv_weight > 0: - return total_variation(latent).mean() + with torch.no_grad(): + images = images.to(latent.device, dtype=latent.dtype) + # resize down to latent size + images = torch.nn.functional.interpolate(images, size=(latent.shape[2], latent.shape[3]), mode='bilinear', align_corners=False) + + # mean the color channel and then expand to latent size + images = images.mean(dim=1, keepdim=True) + images = images.repeat(1, latent.shape[1], 1, 1) + + # normalize to a mean of 0 and std of 1 + images_mean = images.mean(dim=(2, 3), keepdim=True) + images_std = images.std(dim=(2, 3), keepdim=True) + images = (images - images_mean) / (images_std + 1e-6) + + # now we target the same std of the image for the latent space as to not reduce to 0 + + latent_tv = torch.abs(total_variation_deltas(latent)) + images_tv = torch.abs(total_variation_deltas(images)) + loss = torch.abs(latent_tv - images_tv) # keep it spatially aware + loss = loss.mean(dim=2, keepdim=True) + loss = loss.mean(dim=3, keepdim=True) # mean over height and width + loss = loss.mean(dim=1, keepdim=True) # mean over channels + loss = loss.mean() + return loss else: return torch.tensor(0.0, device=self.device) @@ -733,7 +756,7 @@ class TrainVAEProcess(BaseTrainProcess): mv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype) if self.ltv_weight > 0: - ltv_loss = self.get_ltv_loss(latents) * self.ltv_weight + ltv_loss = self.get_ltv_loss(latents, batch) * self.ltv_weight else: ltv_loss = torch.tensor(0.0, 
device=self.device, dtype=self.torch_dtype) diff --git a/requirements.txt b/requirements.txt index 264836a5..96ddce8e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ torchao==0.10.0 safetensors git+https://github.com/jaretburkett/easy_dwpose.git git+https://github.com/huggingface/diffusers@363d1ab7e24c5ed6c190abb00df66d9edb74383b -transformers==4.49.0 +transformers==4.52.4 lycoris-lora==1.8.3 flatten_json pyyaml diff --git a/scripts/calculate_timestep_weighing_flex.py b/scripts/calculate_timestep_weighing_flex.py new file mode 100644 index 00000000..05a21766 --- /dev/null +++ b/scripts/calculate_timestep_weighing_flex.py @@ -0,0 +1,228 @@ +import gc +import os, sys +from tqdm import tqdm +import numpy as np +import json +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# set visible devices to 0 +# os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +# protect from formatting +if True: + import torch + from optimum.quanto import freeze, qfloat8, QTensor, qint4 + from diffusers import FluxTransformer2DModel, FluxPipeline, AutoencoderKL, FlowMatchEulerDiscreteScheduler + from toolkit.util.quantize import quantize, get_qtype + from transformers import T5EncoderModel, T5TokenizerFast, CLIPTextModel, CLIPTokenizer + from torchvision import transforms + +qtype = "qfloat8" +dtype = torch.bfloat16 +# base_model_path = "black-forest-labs/FLUX.1-dev" +base_model_path = "ostris/Flex.1-alpha" +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +print("Loading Transformer...") +prompt = "Photo of a man and a woman in a park, sunny day" + +output_root = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "output") +output_path = os.path.join(output_root, "flex_timestep_weights.json") +img_output_path = os.path.join(output_root, "flex_timestep_weights.png") + +quantization_type = get_qtype(qtype) + +def flush(): + torch.cuda.empty_cache() + gc.collect() + +pil_to_tensor = transforms.ToTensor() + +with torch.no_grad(): + transformer = FluxTransformer2DModel.from_pretrained( + base_model_path, + subfolder='transformer', + torch_dtype=dtype + ) + + transformer.to(device, dtype=dtype) + + print("Quantizing Transformer...") + quantize(transformer, weights=quantization_type) + freeze(transformer) + flush() + + print("Loading Scheduler...") + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") + + print("Loading Autoencoder...") + vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) + + vae.to(device, dtype=dtype) + + flush() + print("Loading Text Encoder...") + tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype) + text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype) + text_encoder_2.to(device, dtype=dtype) + + print("Quantizing Text Encoder...") + quantize(text_encoder_2, weights=get_qtype(qtype)) + freeze(text_encoder_2) + flush() + + print("Loading CLIP") + text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype) + tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype) + text_encoder.to(device, dtype=dtype) + + print("Making pipe") + + pipe: FluxPipeline = FluxPipeline( + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=None, + tokenizer_2=tokenizer_2, + vae=vae, + transformer=None, + ) + 
pipe.text_encoder_2 = text_encoder_2 + pipe.transformer = transformer + + pipe.to(device, dtype=dtype) + + print("Encoding prompt...") + + prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt( + prompt, + prompt_2=prompt, + device=device + ) + + + generator = torch.manual_seed(42) + + height = 1024 + width = 1024 + + print("Generating image...") + + # Fix a bug in diffusers/torch + def callback_on_step_end(pipe, i, t, callback_kwargs): + latents = callback_kwargs["latents"] + if latents.dtype != dtype: + latents = latents.to(dtype) + return {"latents": latents} + img = pipe( + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + height=height, + width=height, + num_inference_steps=30, + guidance_scale=3.5, + generator=generator, + callback_on_step_end=callback_on_step_end, + ).images[0] + + img.save(img_output_path) + print(f"Image saved to {img_output_path}") + + print("Encoding image...") + # img is a PIL image. convert it to a -1 to 1 tensor + img = pil_to_tensor(img) + img = img.unsqueeze(0) # add batch dimension + img = img * 2 - 1 # convert to -1 to 1 range + img = img.to(device, dtype=dtype) + latents = vae.encode(img).latent_dist.sample() + + shift = vae.config['shift_factor'] if vae.config['shift_factor'] is not None else 0 + latents = vae.config['scaling_factor'] * (latents - shift) + + num_channels_latents = pipe.transformer.config.in_channels // 4 + + l_height = 2 * (int(height) // (pipe.vae_scale_factor * 2)) + l_width = 2 * (int(width) // (pipe.vae_scale_factor * 2)) + packed_latents = pipe._pack_latents(latents, 1, num_channels_latents, l_height, l_width) + + packed_latents, latent_image_ids = pipe.prepare_latents( + 1, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + packed_latents, + ) + + print("Calculating timestep weights...") + + torch.manual_seed(8675309) + noise = torch.randn_like(packed_latents, device=device, dtype=dtype) + + # Create linear timesteps from 1000 to 0 + num_train_timesteps = 1000 + timesteps_torch = torch.linspace(1000, 1, num_train_timesteps, device='cpu') + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + timestep_weights = torch.zeros(num_train_timesteps, dtype=torch.float32, device=device) + + guidance = torch.full([1], 1.0, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + + pbar = tqdm(range(num_train_timesteps), desc="loss: 0.000000 scaler: 0.0000") + for i in pbar: + timestep = timesteps[i:i+1].to(device) + t_01 = (timestep / 1000).to(device) + t_01 = t_01.reshape(-1, 1, 1) + noisy_latents = (1.0 - t_01) * packed_latents + t_01 * noise + + noise_pred = pipe.transformer( + hidden_states=noisy_latents, # torch.Size([1, 4096, 64]) + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + + target = noise - packed_latents + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float()) + loss = loss + + # determine scaler to multiply loss by to make it 1 + scaler = 1.0 / (loss + 1e-6) + + timestep_weights[i] = scaler + pbar.set_description(f"loss: {loss.item():.6f} scaler: {scaler.item():.4f}") + + print("normalizing timestep weights...") + # normalize the timestep weights so they are a mean of 1.0 + timestep_weights = timestep_weights / 
timestep_weights.mean() + timestep_weights = timestep_weights.cpu().numpy().tolist() + + print("Saving timestep weights...") + + with open(output_path, 'w') as f: + json.dump(timestep_weights, f) + + +print(f"Timestep weights saved to {output_path}") +print("Done!") +flush() + + + + + + + + + + + + \ No newline at end of file diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index aa3d84a6..0e8b8af0 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -437,7 +437,7 @@ class TrainConfig: # adds an additional loss to the network to encourage it output a normalized standard deviation self.target_norm_std = kwargs.get('target_norm_std', None) self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0) - self.timestep_type = kwargs.get('timestep_type', 'sigmoid') # sigmoid, linear, lognorm_blend, next_sample + self.timestep_type = kwargs.get('timestep_type', 'sigmoid') # sigmoid, linear, lognorm_blend, next_sample, weighted, one_step self.next_sample_timesteps = kwargs.get('next_sample_timesteps', 8) self.linear_timesteps = kwargs.get('linear_timesteps', False) self.linear_timesteps2 = kwargs.get('linear_timesteps2', False) diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 7985bfe7..5f419ff0 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -1773,6 +1773,97 @@ class LatentCachingMixin: self.sd.restore_device_state() + +class TextEmbeddingCachingMixin: + def __init__(self: 'AiToolkitDataset', **kwargs): + # if we have super, call it + if hasattr(super(), '__init__'): + super().__init__(**kwargs) + self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings + + def cache_text_embeddings(self: 'AiToolkitDataset'): + + with accelerator.main_process_first(): + print_acc(f"Caching text_embeddings for {self.dataset_path}") + # cache all latents to disk + to_disk = self.is_caching_latents_to_disk + to_memory = self.is_caching_latents_to_memory + print_acc(" - Saving text embeddings to disk") + # move sd items to cpu except for vae + self.sd.set_device_state_preset('cache_latents') + + # use tqdm to show progress + i = 0 + for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'): + # set latent space version + if self.sd.model_config.latent_space_version is not None: + file_item.latent_space_version = self.sd.model_config.latent_space_version + elif self.sd.is_xl: + file_item.latent_space_version = 'sdxl' + elif self.sd.is_v3: + file_item.latent_space_version = 'sd3' + elif self.sd.is_auraflow: + file_item.latent_space_version = 'sdxl' + elif self.sd.is_flux: + file_item.latent_space_version = 'flux1' + elif self.sd.model_config.is_pixart_sigma: + file_item.latent_space_version = 'sdxl' + else: + file_item.latent_space_version = self.sd.model_config.arch + file_item.is_caching_to_disk = to_disk + file_item.is_caching_to_memory = to_memory + file_item.latent_load_device = self.sd.device + + latent_path = file_item.get_latent_path(recalculate=True) + # check if it is saved to disk already + if os.path.exists(latent_path): + if to_memory: + # load it into memory + state_dict = load_file(latent_path, device='cpu') + file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype) + else: + # not saved to disk, calculate + # load the image first + file_item.load_and_process_image(self.transform, only_load_latents=True) + dtype = self.sd.torch_dtype + device = self.sd.device_torch + # add batch dimension + try: + imgs = 
file_item.tensor.unsqueeze(0).to(device, dtype=dtype) + latent = self.sd.encode_images(imgs).squeeze(0) + except Exception as e: + print_acc(f"Error processing image: {file_item.path}") + print_acc(f"Error: {str(e)}") + raise e + # save_latent + if to_disk: + state_dict = OrderedDict([ + ('latent', latent.clone().detach().cpu()), + ]) + # metadata + meta = get_meta_for_safetensors(file_item.get_latent_info_dict()) + os.makedirs(os.path.dirname(latent_path), exist_ok=True) + save_file(state_dict, latent_path, metadata=meta) + + if to_memory: + # keep it in memory + file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype) + + del imgs + del latent + del file_item.tensor + + # flush(garbage_collect=False) + file_item.is_latent_cached = True + i += 1 + # flush every 100 + # if i % 100 == 0: + # flush() + + # restore device state + self.sd.restore_device_state() + + class CLIPCachingMixin: def __init__(self: 'AiToolkitDataset', **kwargs): # if we have super, call it diff --git a/toolkit/ema.py b/toolkit/ema.py index e3b3a7ea..b34554bb 100644 --- a/toolkit/ema.py +++ b/toolkit/ema.py @@ -137,7 +137,8 @@ class ExponentialMovingAverage: update_param = False if self.use_feedback: - param_float.add_(tmp) + # make feedback 10x decay + param_float.add_(tmp * 10) update_param = True if self.param_multiplier != 1.0: diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index b0c2d7f4..cd443735 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -7,7 +7,7 @@ import re import sys from typing import List, Optional, Dict, Type, Union import torch -from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel +from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel, WanTransformer3DModel from transformers import CLIPTextModel from toolkit.models.lokr import LokrModule @@ -522,6 +522,14 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): transformer.pos_embed = self.transformer_pos_embed transformer.proj_out = self.transformer_proj_out + + elif base_model is not None and base_model.arch == "wan21": + transformer: WanTransformer3DModel = unet + self.transformer_pos_embed = copy.deepcopy(transformer.patch_embedding) + self.transformer_proj_out = copy.deepcopy(transformer.proj_out) + + transformer.patch_embedding = self.transformer_pos_embed + transformer.proj_out = self.transformer_proj_out else: unet: UNet2DConditionModel = unet @@ -539,7 +547,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr) if self.full_train_in_out: - if self.is_pixart or self.is_auraflow or self.is_flux: + base_model = self.base_model_ref() if self.base_model_ref is not None else None + if self.is_pixart or self.is_auraflow or self.is_flux or (base_model is not None and base_model.arch == "wan21"): all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())}) all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())}) else: diff --git a/toolkit/losses.py b/toolkit/losses.py index eeea3571..fef9310d 100644 --- a/toolkit/losses.py +++ b/toolkit/losses.py @@ -13,6 +13,22 @@ def total_variation(image): n_elements = image.shape[1] * image.shape[2] * image.shape[3] return ((torch.sum(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) + torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements) + +def total_variation_deltas(image): + """ + 
Compute per-pixel total variation deltas. + Input: + - image: Tensor of shape (N, C, H, W) + Returns: + - Tensor with shape (N, C, H, W), padded to match input shape + """ + dh = torch.zeros_like(image) + dv = torch.zeros_like(image) + + dh[:, :, :, :-1] = torch.abs(image[:, :, :, 1:] - image[:, :, :, :-1]) + dv[:, :, :-1, :] = torch.abs(image[:, :, 1:, :] - image[:, :, :-1, :]) + + return dh + dv class ComparativeTotalVariation(torch.nn.Module): diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py index 17b259e6..01d7f278 100644 --- a/toolkit/models/diffusion_feature_extraction.py +++ b/toolkit/models/diffusion_feature_extraction.py @@ -1,3 +1,4 @@ +import math import torch import os from torch import nn @@ -351,12 +352,252 @@ class DiffusionFeatureExtractor3(nn.Module): return total_loss +class DiffusionFeatureExtractor4(nn.Module): + def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None): + super().__init__() + self.version = 4 + if vae is None: + raise ValueError("vae must be provided for DFE4") + self.vae = vae + # image_encoder_path = "google/siglip-so400m-patch14-384" + image_encoder_path = "google/siglip2-so400m-patch16-naflex" + from transformers import Siglip2ImageProcessor, Siglip2VisionModel + try: + self.image_processor = Siglip2ImageProcessor.from_pretrained( + image_encoder_path) + except EnvironmentError: + self.image_processor = Siglip2ImageProcessor() + + self.image_processor.max_num_patches = 1024 + + self.vision_encoder = Siglip2VisionModel.from_pretrained( + image_encoder_path, + ignore_mismatched_sizes=True + ).to(device, dtype=dtype) + + self.losses = {} + self.log_every = 100 + self.step = 0 + + def _target_hw(self, h, w, patch, max_patches, eps: float = 1e-5): + def _snap(x, s): + x = math.ceil((x * s) / patch) * patch + return max(patch, int(x)) + + lo, hi = eps / 10, 1.0 + while hi - lo >= eps: + mid = (lo + hi) / 2 + th, tw = _snap(h, mid), _snap(w, mid) + if (th // patch) * (tw // patch) <= max_patches: + lo = mid + else: + hi = mid + return _snap(h, lo), _snap(w, lo) + + + def tensors_to_siglip_like_features(self, batch: torch.Tensor): + """ + Args: + batch: (bs, 3, H, W) tensor already in the desired value range + (e.g. [-1, 1] or [0, 1]); no extra rescale / normalize here. 
+ + Returns: + dict( + pixel_values – (bs, L, P) where L = n_h*n_w, P = 3*patch*patch + pixel_attention_mask– (L,) all-ones + spatial_shapes – (n_h, n_w) + ) + """ + if batch.ndim != 4: + raise ValueError("Expected (bs, 3, H, W) tensor") + + bs, c, H, W = batch.shape + proc = self.image_processor + patch = proc.patch_size + max_patches = proc.max_num_patches + + # One shared resize for the whole batch + tgt_h, tgt_w = self._target_hw(H, W, patch, max_patches) + batch = torch.nn.functional.interpolate( + batch, size=(tgt_h, tgt_w), mode="bilinear", align_corners=False + ) + + n_h, n_w = tgt_h // patch, tgt_w // patch + # flat_dim = c * patch * patch + num_p = n_h * n_w + + # unfold → (bs, flat_dim, num_p) → (bs, num_p, flat_dim) + patches = ( + torch.nn.functional.unfold(batch, kernel_size=patch, stride=patch) + .transpose(1, 2) + ) + + attn_mask = torch.ones(num_p, dtype=torch.long, device=batch.device) + spatial = torch.tensor((n_h, n_w), device=batch.device, dtype=torch.int32) + + # repeat attn_mask for each batch element + attn_mask = attn_mask.unsqueeze(0).repeat(bs, 1) + spatial = spatial.unsqueeze(0).repeat(bs, 1) + + return { + "pixel_values": patches, # (bs, num_patches, patch_dim) + "pixel_attention_mask": attn_mask, # (num_patches,) + "spatial_shapes": spatial + } + + def get_siglip_features(self, tensors_0_1): + dtype = torch.bfloat16 + device = self.vae.device + + tensors_0_1 = torch.clamp(tensors_0_1, 0.0, 1.0) + + mean = torch.tensor(self.image_processor.image_mean).to( + device, dtype=dtype + ).detach() + std = torch.tensor(self.image_processor.image_std).to( + device, dtype=dtype + ).detach() + # tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0 + clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1]) + + encoder_kwargs = self.tensors_to_siglip_like_features(clip_image) + id_embeds = self.vision_encoder( + pixel_values=encoder_kwargs['pixel_values'], + pixel_attention_mask=encoder_kwargs['pixel_attention_mask'], + spatial_shapes=encoder_kwargs['spatial_shapes'], + output_hidden_states=True, + ) + + # embeds = id_embeds['hidden_states'][-2] # penultimate layer + image_embeds = id_embeds['pooler_output'] + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + return image_embeds + + def forward( + self, + noise, + noise_pred, + noisy_latents, + timesteps, + batch: DataLoaderBatchDTO, + scheduler: CustomFlowMatchEulerDiscreteScheduler, + clip_weight=1.0, + mse_weight=0.0, + model=None + ): + dtype = torch.bfloat16 + device = self.vae.device + tensors = batch.tensor.to(device, dtype=dtype) + is_video = False + # stack time for video models on the batch dimension + if len(noise_pred.shape) == 5: + # B, C, T, H, W = images.shape + # only take first time + noise = noise[:, :, 0, :, :] + noise_pred = noise_pred[:, :, 0, :, :] + noisy_latents = noisy_latents[:, :, 0, :, :] + is_video = True + + if len(tensors.shape) == 5: + # batch is different + # (B, T, C, H, W) + # only take first time + tensors = tensors[:, 0, :, :, :] + + if model is not None and hasattr(model, 'get_stepped_pred'): + stepped_latents = model.get_stepped_pred(noise_pred, noise) + else: + # stepped_latents = noise - noise_pred + # first we step the scheduler from current timestep to the very end for a full denoise + bs = noise_pred.shape[0] + noise_pred_chunks = torch.chunk(noise_pred, bs) + timestep_chunks = torch.chunk(timesteps, bs) + noisy_latent_chunks = torch.chunk(noisy_latents, bs) + stepped_chunks = [] + for idx in range(bs): + 
model_output = noise_pred_chunks[idx] + timestep = timestep_chunks[idx] + scheduler._step_index = None + scheduler._init_step_index(timestep) + sample = noisy_latent_chunks[idx].to(torch.float32) + + sigma = scheduler.sigmas[scheduler.step_index] + sigma_next = scheduler.sigmas[-1] # use last sigma for final step + prev_sample = sample + (sigma_next - sigma) * model_output + stepped_chunks.append(prev_sample) + + stepped_latents = torch.cat(stepped_chunks, dim=0) + + latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype) + + scaling_factor = self.vae.config['scaling_factor'] if 'scaling_factor' in self.vae.config else 1.0 + shift_factor = self.vae.config['shift_factor'] if 'shift_factor' in self.vae.config else 0.0 + latents = (latents / scaling_factor) + shift_factor + if is_video: + # if video, we need to unsqueeze the latents to match the vae input shape + latents = latents.unsqueeze(2) + tensors_n1p1 = self.vae.decode(latents).sample # -1 to 1 + + if is_video: + # if video, we need to squeeze the tensors to match the output shape + tensors_n1p1 = tensors_n1p1.squeeze(2) + + pred_images = (tensors_n1p1 + 1) / 2 # 0 to 1 + + total_loss = 0 + + with torch.no_grad(): + target_img = tensors.to(device, dtype=dtype) + # go from -1 to 1 to 0 to 1 + target_img = (target_img + 1) / 2 + if clip_weight > 0: + target_clip_output = self.get_siglip_features(target_img).detach() + if clip_weight > 0: + pred_clip_output = self.get_siglip_features(pred_images) + clip_loss = torch.nn.functional.mse_loss( + pred_clip_output.float(), target_clip_output.float() + ) * clip_weight + + if 'clip_loss' not in self.losses: + self.losses['clip_loss'] = clip_loss.item() + else: + self.losses['clip_loss'] += clip_loss.item() + + total_loss += clip_loss + if mse_weight > 0: + mse_loss = torch.nn.functional.mse_loss( + pred_images.float(), target_img.float() + ) * mse_weight + + if 'mse_loss' not in self.losses: + self.losses['mse_loss'] = mse_loss.item() + else: + self.losses['mse_loss'] += mse_loss.item() + + total_loss += mse_loss + + if self.step % self.log_every == 0 and self.step > 0: + print(f"DFE losses:") + for key in self.losses: + self.losses[key] /= self.log_every + # print in 2.000e-01 format + print(f" - {key}: {self.losses[key]:.3e}") + self.losses[key] = 0.0 + + # total_loss += mse_loss + self.step += 1 + + return total_loss def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor: if model_path == "v3": dfe = DiffusionFeatureExtractor3(vae=vae) dfe.eval() return dfe + if model_path == "v4": + dfe = DiffusionFeatureExtractor4(vae=vae) + dfe.eval() + return dfe if not os.path.exists(model_path): raise FileNotFoundError(f"Model file not found: {model_path}") # if it ende with safetensors diff --git a/toolkit/models/wan21/autoencoder_kl_wan.py b/toolkit/models/wan21/autoencoder_kl_wan.py new file mode 100644 index 00000000..4f5b6ebd --- /dev/null +++ b/toolkit/models/wan21/autoencoder_kl_wan.py @@ -0,0 +1,865 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import logging +from diffusers.utils.accelerate_utils import apply_forward_hook +from diffusers.models.activations import get_activation +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution +import copy + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +CACHE_T = 2 + + +class WanCausalConv3d(nn.Conv3d): + r""" + A custom 3D causal convolution layer with feature caching support. + + This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature + caching for efficient inference. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + # Set up causal padding + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + return super().forward(x) + + +class WanRMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class WanUpsample(nn.Upsample): + r""" + Perform upsampling while ensuring the output tensor has the same data type as the input. + + Args: + x (torch.Tensor): Input tensor to be upsampled. + + Returns: + torch.Tensor: Upsampled tensor with the same data type as the input. 
+ """ + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class WanResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). + - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. + """ + + def __init__(self, dim: int, mode: str) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + +class WanResidualBlock(nn.Module): + r""" + A custom residual block module. + + Args: + in_dim (int): Number of input channels. + out_dim (int): Number of output channels. + dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0. + non_linearity (str, optional): Type of non-linearity to use. Default is "silu". 
+ """ + + def __init__( + self, + in_dim: int, + out_dim: int, + dropout: float = 0.0, + non_linearity: str = "silu", + ) -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.nonlinearity = get_activation(non_linearity) + + # layers + self.norm1 = WanRMS_norm(in_dim, images=False) + self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1) + self.norm2 = WanRMS_norm(out_dim, images=False) + self.dropout = nn.Dropout(dropout) + self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1) + self.conv_shortcut = WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # Apply shortcut connection + h = self.conv_shortcut(x) + + # First normalization and activation + x = self.norm1(x) + x = self.nonlinearity(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + # Second normalization and activation + x = self.norm2(x) + x = self.nonlinearity(x) + + # Dropout + x = self.dropout(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv2(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv2(x) + + # Add residual connection + return x + h + + +class WanAttentionBlock(nn.Module): + r""" + Causal self-attention with a single head. + + Args: + dim (int): The number of channels in the input tensor. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = WanRMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + def forward(self, x): + identity = x + batch_size, channels, time, height, width = x.size() + + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width) + x = self.norm(x) + + # compute query, key, value + qkv = self.to_qkv(x) + qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + qkv = qkv.permute(0, 1, 3, 2).contiguous() + q, k, v = qkv.chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention(q, k, v) + + x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width) + + # output projection + x = self.proj(x) + + # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w] + x = x.view(batch_size, time, channels, height, width) + x = x.permute(0, 2, 1, 3, 4) + + return x + identity + + +class WanMidBlock(nn.Module): + """ + Middle block for WanVAE encoder and decoder. + + Args: + dim (int): Number of input/output channels. + dropout (float): Dropout rate. + non_linearity (str): Type of non-linearity to use. 
+ """ + + def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1): + super().__init__() + self.dim = dim + + # Create the components + resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)] + attentions = [] + for _ in range(num_layers): + attentions.append(WanAttentionBlock(dim)) + resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity)) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # First residual block + x = self.resnets[0](x, feat_cache, feat_idx) + + # Process through attention and residual blocks + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + x = attn(x) + + x = resnet(x, feat_cache, feat_idx) + + return x + + +class WanEncoder3d(nn.Module): + r""" + A 3D encoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_downsample (list of bool): Whether to downsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + non_linearity: str = "silu", + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + self.down_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + self.down_blocks.append(WanAttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(WanResample(out_dim, mode=mode)) + scale /= 2.0 + + # middle blocks + self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1) + + # output blocks + self.norm_out = WanRMS_norm(out_dim, images=False) + self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## downsamples + for layer in self.down_blocks: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## 
middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class WanUpBlock(nn.Module): + """ + A block that handles upsampling for the WanVAE decoder. + + Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d') + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + upsample_mode: Optional[str] = None, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # Create layers list + resnets = [] + # Add residual blocks and attention if needed + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity)) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + self.upsamplers = None + if upsample_mode is not None: + self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)]) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + """ + Forward pass through the upsampling block. + + Args: + x (torch.Tensor): Input tensor + feat_cache (list, optional): Feature cache for causal convolutions + feat_idx (list, optional): Feature index for cache management + + Returns: + torch.Tensor: Output tensor + """ + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache, feat_idx) + else: + x = resnet(x) + + if self.upsamplers is not None: + if feat_cache is not None: + x = self.upsamplers[0](x, feat_cache, feat_idx) + else: + x = self.upsamplers[0](x) + return x + + +class WanDecoder3d(nn.Module): + r""" + A 3D decoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_upsample (list of bool): Whether to upsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. 
+ """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + non_linearity: str = "silu", + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + + # init block + self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1) + + # upsample blocks + self.up_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i > 0: + in_dim = in_dim // 2 + + # Determine if we need upsampling + upsample_mode = None + if i != len(dim_mult) - 1: + upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d" + + # Create and add the upsampling block + up_block = WanUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + ) + self.up_blocks.append(up_block) + + # Update scale for next iteration + if upsample_mode is not None: + scale *= 2.0 + + # output blocks + self.norm_out = WanRMS_norm(out_dim, images=False) + self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## middle + if torch.is_grad_enabled() and self.gradient_checkpointing: + # middle + x = self._gradient_checkpointing_func(self.mid_block, x, feat_cache, feat_idx) + + ## upsamples + for up_block in self.up_blocks: + x = self._gradient_checkpointing_func(up_block, x, feat_cache, feat_idx) + + else: + x = self.mid_block(x, feat_cache, feat_idx) + + ## upsamples + for up_block in self.up_blocks: + x = up_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class AutoencoderKLWan(ModelMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. + Introduced in [Wan 2.1]. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). 
+ """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + base_dim: int = 96, + z_dim: int = 16, + dim_mult: Tuple[int] = [1, 2, 4, 4], + num_res_blocks: int = 2, + attn_scales: List[float] = [], + temperal_downsample: List[bool] = [False, True, True], + dropout: float = 0.0, + latents_mean: List[float] = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ], + latents_std: List[float] = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ], + ) -> None: + super().__init__() + + self.z_dim = z_dim + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + self.encoder = WanEncoder3d( + base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout + ) + self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1) + self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1) + + self.decoder = WanDecoder3d( + base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout + ) + + def clear_cache(self): + def _count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, WanCausalConv3d): + count += 1 + return count + + self._conv_num = _count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = _count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + self.clear_cache() + ## cache + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + + enc = self.quant_conv(out) + mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :] + enc = torch.cat([mu, logvar], dim=1) + self.clear_cache() + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
+ """ + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + self.clear_cache() + + iter_ = z.shape[2] + x = self.post_quant_conv(z) + for i in range(iter_): + + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + + out = torch.clamp(out, min=-1.0, max=1.0) + self.clear_cache() + if not return_dict: + return (out,) + + return DecoderOutput(sample=out) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + decoded = self._decode(z).sample + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + """ + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. 
+ """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec diff --git a/toolkit/models/wan21/wan21.py b/toolkit/models/wan21/wan21.py index 57b556ed..9f029c7b 100644 --- a/toolkit/models/wan21/wan21.py +++ b/toolkit/models/wan21/wan21.py @@ -9,7 +9,8 @@ from toolkit.dequantize import patch_dequantization_on_save from toolkit.models.base_model import BaseModel from toolkit.prompt_utils import PromptEmbeds from transformers import AutoTokenizer, UMT5EncoderModel -from diffusers import AutoencoderKLWan, WanPipeline, WanTransformer3DModel +from diffusers import WanPipeline, WanTransformer3DModel, AutoencoderKL +from .autoencoder_kl_wan import AutoencoderKLWan import os import sys diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py index 1e0ae2ab..bac7f3fd 100644 --- a/toolkit/samplers/custom_flowmatch_sampler.py +++ b/toolkit/samplers/custom_flowmatch_sampler.py @@ -4,6 +4,7 @@ from torch.distributions import LogNormal from diffusers import FlowMatchEulerDiscreteScheduler import torch import numpy as np +from toolkit.timestep_weighing.default_weighing_scheme import default_weighing_scheme def calculate_shift( @@ -47,20 +48,26 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): hbsmntw_weighing[num_timesteps // 2:] = hbsmntw_weighing[num_timesteps // 2:].max() - # Create linear timesteps from 1000 to 0 - timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu') + # Create linear timesteps from 1000 to 1 + timesteps = torch.linspace(1000, 1, num_timesteps, device='cpu') self.linear_timesteps = timesteps self.linear_timesteps_weights = bsmntw_weighing self.linear_timesteps_weights2 = hbsmntw_weighing pass - def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor: + def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False, timestep_type="linear") -> torch.Tensor: # Get the indices of the timesteps step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps] # Get the weights for the timesteps + if timestep_type == "weighted": + weights = torch.tensor( + [default_weighing_scheme[i] for i in step_indices], + device=timesteps.device, + dtype=timesteps.dtype + ) if v2: weights = self.linear_timesteps_weights2[step_indices].flatten() else: @@ -106,8 +113,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): patch_size=1 ): self.timestep_type = timestep_type - if timestep_type == 'linear': - timesteps = torch.linspace(1000, 0, num_timesteps, device=device) + if timestep_type == 'linear' or timestep_type == 'weighted': + timesteps = torch.linspace(1000, 1, num_timesteps, device=device) self.timesteps = timesteps return timesteps elif timestep_type == 'sigmoid': @@ -198,7 +205,7 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): t1 = ((1 - t1/t1.max()) * 1000) # add half of linear - t2 = torch.linspace(1000, 0, int( + t2 = torch.linspace(1000, 1, int( num_timesteps * (1 - alpha)), device=device) timesteps = torch.cat((t1, t2)) diff --git a/toolkit/timestep_weighing/__init__.py b/toolkit/timestep_weighing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/toolkit/timestep_weighing/default_weighing_scheme.py b/toolkit/timestep_weighing/default_weighing_scheme.py new file mode 100644 index 00000000..72699744 --- /dev/null +++ 
b/toolkit/timestep_weighing/default_weighing_scheme.py @@ -0,0 +1,1004 @@ +# these weights were calculated using flex.1-alpha. A similar weighing scheme has been seen with other flowmatch models as well. + +default_weighing_scheme = [ + 0.905706524848938, + 0.9097874164581299, + 0.9251009821891785, + 0.9399133920669556, + 0.9497355818748474, + 0.962195873260498, + 0.9691638946533203, + 0.9927358627319336, + 1.01659095287323, + 1.0392684936523438, + 1.0429067611694336, + 1.0677919387817383, + 1.092896580696106, + 1.1149251461029053, + 1.1015851497650146, + 1.1209120750427246, + 1.1399472951889038, + 1.1559219360351562, + 1.143755555152893, + 1.160578727722168, + 1.1761468648910522, + 1.1895899772644043, + 1.1867371797561646, + 1.1993824243545532, + 1.2113513946533203, + 1.2201216220855713, + 1.221794605255127, + 1.2329251766204834, + 1.2426395416259766, + 1.251118540763855, + 1.2562459707260132, + 1.2668883800506592, + 1.2760146856307983, + 1.2822540998458862, + 1.2892439365386963, + 1.2972776889801025, + 1.3042182922363281, + 1.3104143142700195, + 1.3190876245498657, + 1.3253265619277954, + 1.3295484781265259, + 1.3324649333953857, + 1.3436106443405151, + 1.3482736349105835, + 1.351183533668518, + 1.3556060791015625, + 1.359933614730835, + 1.3625402450561523, + 1.3641390800476074, + 1.3723753690719604, + 1.3748365640640259, + 1.3776681423187256, + 1.3802807331085205, + 1.3848408460617065, + 1.3877429962158203, + 1.390969157218933, + 1.3927756547927856, + 1.4031920433044434, + 1.4064390659332275, + 1.4097232818603516, + 1.4132285118103027, + 1.4208053350448608, + 1.4254504442214966, + 1.4294012784957886, + 1.4323071241378784, + 1.44380521774292, + 1.4465110301971436, + 1.4490883350372314, + 1.451446533203125, + 1.458960771560669, + 1.460111379623413, + 1.4603426456451416, + 1.4595434665679932, + 1.4694839715957642, + 1.470526099205017, + 1.4705318212509155, + 1.4697961807250977, + 1.477265477180481, + 1.4771337509155273, + 1.4768965244293213, + 1.4760032892227173, + 1.4871419668197632, + 1.4877341985702515, + 1.488408088684082, + 1.489357829093933, + 1.4884581565856934, + 1.4882304668426514, + 1.4878745079040527, + 1.4949846267700195, + 1.4957915544509888, + 1.4951081275939941, + 1.4950779676437378, + 1.4995663166046143, + 1.4994632005691528, + 1.4986882209777832, + 1.4976568222045898, + 1.5038518905639648, + 1.503743052482605, + 1.502982258796692, + 1.502566933631897, + 1.5068002939224243, + 1.506575584411621, + 1.5067178010940552, + 1.5050104856491089, + 1.508437991142273, + 1.5073161125183105, + 1.506223440170288, + 1.504751205444336, + 1.512485384941101, + 1.5121595859527588, + 1.5116060972213745, + 1.5102864503860474, + 1.5156408548355103, + 1.5157924890518188, + 1.5142825841903687, + 1.5141233205795288, + 1.5191627740859985, + 1.518998622894287, + 1.5177574157714844, + 1.516713261604309, + 1.5186537504196167, + 1.5179635286331177, + 1.5159885883331299, + 1.5150446891784668, + 1.5226575136184692, + 1.5215232372283936, + 1.5201103687286377, + 1.5225480794906616, + 1.5215368270874023, + 1.520209789276123, + 1.5184354782104492, + 1.5220975875854492, + 1.5209708213806152, + 1.519667625427246, + 1.5177332162857056, + 1.5203157663345337, + 1.519271731376648, + 1.5175830125808716, + 1.5161852836608887, + 1.517741322517395, + 1.5163099765777588, + 1.5150357484817505, + 1.513055682182312, + 1.5154107809066772, + 1.513993501663208, + 1.5126358270645142, + 1.510875940322876, + 1.5125701427459717, + 1.510590672492981, + 1.5086392164230347, + 1.5068501234054565, + 1.5094420909881592, + 
1.5080277919769287, + 1.5060293674468994, + 1.5040230751037598, + 1.504892110824585, + 1.5031654834747314, + 1.5013214349746704, + 1.499170184135437, + 1.500420331954956, + 1.4979915618896484, + 1.4959659576416016, + 1.494127631187439, + 1.4961415529251099, + 1.494597315788269, + 1.4926577806472778, + 1.4904958009719849, + 1.4913445711135864, + 1.489889144897461, + 1.4880430698394775, + 1.4873780012130737, + 1.485144019126892, + 1.4830609560012817, + 1.48123037815094, + 1.483134150505066, + 1.4808504581451416, + 1.4793643951416016, + 1.4773707389831543, + 1.4774049520492554, + 1.4752962589263916, + 1.4733281135559082, + 1.4712235927581787, + 1.470260500907898, + 1.4684780836105347, + 1.4657323360443115, + 1.4639259576797485, + 1.4644291400909424, + 1.4621423482894897, + 1.4604870080947876, + 1.4585931301116943, + 1.458650827407837, + 1.456871747970581, + 1.4548134803771973, + 1.4528228044509888, + 1.4519706964492798, + 1.4501826763153076, + 1.448083519935608, + 1.446316123008728, + 1.4454604387283325, + 1.4432463645935059, + 1.4412567615509033, + 1.4390138387680054, + 1.4388699531555176, + 1.437005877494812, + 1.4348008632659912, + 1.4327361583709717, + 1.4320861101150513, + 1.4304683208465576, + 1.4287059307098389, + 1.4264065027236938, + 1.424929141998291, + 1.4226492643356323, + 1.4205398559570312, + 1.4207192659378052, + 1.4187930822372437, + 1.4170515537261963, + 1.415097713470459, + 1.4142191410064697, + 1.412431001663208, + 1.4104442596435547, + 1.4082318544387817, + 1.4076896905899048, + 1.4057838916778564, + 1.4034852981567383, + 1.4016139507293701, + 1.4003900289535522, + 1.3983466625213623, + 1.3962912559509277, + 1.3940908908843994, + 1.3933204412460327, + 1.3910188674926758, + 1.3888089656829834, + 1.3865182399749756, + 1.3853325843811035, + 1.3832392692565918, + 1.3812776803970337, + 1.3789170980453491, + 1.3797601461410522, + 1.3777130842208862, + 1.3756908178329468, + 1.3735847473144531, + 1.3711339235305786, + 1.3690725564956665, + 1.366849660873413, + 1.364676833152771, + 1.364690899848938, + 1.362541675567627, + 1.3608990907669067, + 1.358725666999817, + 1.3566731214523315, + 1.3549315929412842, + 1.3522361516952515, + 1.3527694940567017, + 1.3506978750228882, + 1.3486748933792114, + 1.3464853763580322, + 1.344880223274231, + 1.3429863452911377, + 1.3410741090774536, + 1.3392194509506226, + 1.3374866247177124, + 1.3356633186340332, + 1.333699345588684, + 1.3316830396652222, + 1.3293700218200684, + 1.327602744102478, + 1.325606107711792, + 1.323636770248413, + 1.322487473487854, + 1.3201935291290283, + 1.3185865879058838, + 1.3163570165634155, + 1.315348505973816, + 1.3135069608688354, + 1.3115581274032593, + 1.309351921081543, + 1.3080940246582031, + 1.306084394454956, + 1.3041918277740479, + 1.3022172451019287, + 1.30052649974823, + 1.2984906435012817, + 1.296433925628662, + 1.294618844985962, + 1.2924813032150269, + 1.290629267692566, + 1.288609504699707, + 1.286437749862671, + 1.2852808237075806, + 1.2831010818481445, + 1.281022071838379, + 1.2789161205291748, + 1.2785669565200806, + 1.2766060829162598, + 1.274585485458374, + 1.2728400230407715, + 1.2709832191467285, + 1.2691309452056885, + 1.2671318054199219, + 1.265442132949829, + 1.2635501623153687, + 1.2614946365356445, + 1.2593908309936523, + 1.257880449295044, + 1.2560313940048218, + 1.254082441329956, + 1.2522804737091064, + 1.2505096197128296, + 1.2482692003250122, + 1.2462091445922852, + 1.2445822954177856, + 1.2432236671447754, + 1.2414650917053223, + 1.2396503686904907, + 1.2376699447631836, + 
1.2357380390167236, + 1.2339240312576294, + 1.2320566177368164, + 1.2299892902374268, + 1.2286840677261353, + 1.226925015449524, + 1.2250070571899414, + 1.223126769065857, + 1.2215166091918945, + 1.2196996212005615, + 1.2178195714950562, + 1.2158279418945312, + 1.2140803337097168, + 1.2121261358261108, + 1.210100769996643, + 1.2083507776260376, + 1.2063525915145874, + 1.2046012878417969, + 1.2027149200439453, + 1.201154112815857, + 1.1992254257202148, + 1.1971834897994995, + 1.1951549053192139, + 1.1935709714889526, + 1.191764235496521, + 1.1898930072784424, + 1.187896728515625, + 1.186535120010376, + 1.184600591659546, + 1.1826894283294678, + 1.1809728145599365, + 1.1789331436157227, + 1.1774684190750122, + 1.1756458282470703, + 1.1736308336257935, + 1.1719911098480225, + 1.1701127290725708, + 1.1681468486785889, + 1.165921926498413, + 1.16463041305542, + 1.1627451181411743, + 1.1608567237854004, + 1.1590938568115234, + 1.1575335264205933, + 1.1555901765823364, + 1.1538552045822144, + 1.1518657207489014, + 1.14994215965271, + 1.1481153964996338, + 1.14644455909729, + 1.1444120407104492, + 1.1428122520446777, + 1.1410313844680786, + 1.1391836404800415, + 1.137330412864685, + 1.135434627532959, + 1.1336791515350342, + 1.131978154182434, + 1.1300874948501587, + 1.128359317779541, + 1.1264809370040894, + 1.1248611211776733, + 1.122762680053711, + 1.1209162473678589, + 1.1190710067749023, + 1.1172044277191162, + 1.1158984899520874, + 1.1140459775924683, + 1.1124012470245361, + 1.110682487487793, + 1.1087219715118408, + 1.106826901435852, + 1.1050584316253662, + 1.1034021377563477, + 1.1011031866073608, + 1.0996853113174438, + 1.0978131294250488, + 1.0963127613067627, + 1.0944904088974, + 1.0927494764328003, + 1.0910944938659668, + 1.0892736911773682, + 1.0878331661224365, + 1.0860958099365234, + 1.0842169523239136, + 1.0826303958892822, + 1.0806686878204346, + 1.078961730003357, + 1.0773676633834839, + 1.0755786895751953, + 1.073934555053711, + 1.0721861124038696, + 1.0704376697540283, + 1.0689181089401245, + 1.067183256149292, + 1.0654473304748535, + 1.0637754201889038, + 1.0620981454849243, + 1.0604465007781982, + 1.0587077140808105, + 1.0570865869522095, + 1.0553107261657715, + 1.053688883781433, + 1.0520380735397339, + 1.0502020120620728, + 1.048741340637207, + 1.046962022781372, + 1.0453627109527588, + 1.0439050197601318, + 1.041886806488037, + 1.0405514240264893, + 1.0387938022613525, + 1.0370451211929321, + 1.035706877708435, + 1.03403902053833, + 1.0325669050216675, + 1.0308712720870972, + 1.0291008949279785, + 1.0275760889053345, + 1.0258709192276, + 1.024153470993042, + 1.022807240486145, + 1.0211118459701538, + 1.019489049911499, + 1.0178107023239136, + 1.0159832239151, + 1.0143824815750122, + 1.0128840208053589, + 1.0111985206604004, + 1.0098742246627808, + 1.0081874132156372, + 1.0064918994903564, + 1.0050266981124878, + 1.0036821365356445, + 1.0018991231918335, + 1.0004172325134277, + 0.9988566637039185, + 0.9969817399978638, + 0.9954714179039001, + 0.9939242005348206, + 0.9923979640007019, + 0.9910774230957031, + 0.9894015789031982, + 0.9880895614624023, + 0.9861252903938293, + 0.9846389889717102, + 0.9831112027168274, + 0.9815076589584351, + 0.9799305200576782, + 0.9784950017929077, + 0.976923942565918, + 0.975475549697876, + 0.9737277626991272, + 0.9722781181335449, + 0.9707712531089783, + 0.9693742394447327, + 0.9677569270133972, + 0.9663806557655334, + 0.9648120999336243, + 0.963326096534729, + 0.9619874358177185, + 0.9605197906494141, + 0.9590029120445251, + 
0.9575618505477905, + 0.9558634757995605, + 0.9542866945266724, + 0.9530059099197388, + 0.9513764977455139, + 0.9499674439430237, + 0.948621392250061, + 0.947046160697937, + 0.945502519607544, + 0.9441988468170166, + 0.9427464604377747, + 0.9413387179374695, + 0.9397821426391602, + 0.9385508894920349, + 0.9372508525848389, + 0.9356773495674133, + 0.9340954422950745, + 0.9325379133224487, + 0.9311357140541077, + 0.9296550154685974, + 0.9283716082572937, + 0.9268398880958557, + 0.9254037141799927, + 0.9239259362220764, + 0.9225856065750122, + 0.921108603477478, + 0.9197893142700195, + 0.9185012578964233, + 0.9169778823852539, + 0.9154301881790161, + 0.9140625, + 0.9127756357192993, + 0.9113842844963074, + 0.9101965427398682, + 0.9088224172592163, + 0.9074375629425049, + 0.9061430096626282, + 0.9046499133110046, + 0.9033547043800354, + 0.9018712639808655, + 0.9006990790367126, + 0.8993589878082275, + 0.8980291485786438, + 0.8965833187103271, + 0.8953617811203003, + 0.8940249681472778, + 0.8928234577178955, + 0.8914735913276672, + 0.8900470733642578, + 0.8885773420333862, + 0.887448251247406, + 0.8860753178596497, + 0.8848751783370972, + 0.8835704326629639, + 0.8822427988052368, + 0.8808343410491943, + 0.8794860243797302, + 0.8782272338867188, + 0.876940131187439, + 0.8755697011947632, + 0.8743593096733093, + 0.8731096982955933, + 0.8717764019966125, + 0.870373547077179, + 0.869137704372406, + 0.8679963946342468, + 0.8665465116500854, + 0.8653771281242371, + 0.8643192052841187, + 0.8630129098892212, + 0.8618021011352539, + 0.8606610894203186, + 0.8596193194389343, + 0.8584977984428406, + 0.8571111559867859, + 0.8558118343353271, + 0.854767382144928, + 0.8535858392715454, + 0.8525562882423401, + 0.851208508014679, + 0.8500548601150513, + 0.8489854335784912, + 0.8476380109786987, + 0.8465084433555603, + 0.8454263806343079, + 0.8440982699394226, + 0.8429536819458008, + 0.8419493436813354, + 0.8406177759170532, + 0.8395005464553833, + 0.83843994140625, + 0.8372390866279602, + 0.836262583732605, + 0.8351759910583496, + 0.8340833187103271, + 0.8330100178718567, + 0.8318305611610413, + 0.8307360410690308, + 0.8296796083450317, + 0.8287205696105957, + 0.8275678753852844, + 0.8264811038970947, + 0.8253570795059204, + 0.8243551254272461, + 0.8232539296150208, + 0.822137176990509, + 0.8212800025939941, + 0.8199703097343445, + 0.8190608024597168, + 0.8179953098297119, + 0.8167867064476013, + 0.8158150315284729, + 0.8149182200431824, + 0.8140754699707031, + 0.8131433129310608, + 0.8118599057197571, + 0.8109708428382874, + 0.8099024891853333, + 0.8090004324913025, + 0.8079776763916016, + 0.807029664516449, + 0.8058684468269348, + 0.8049055337905884, + 0.8039948344230652, + 0.803061306476593, + 0.8021382689476013, + 0.8012913465499878, + 0.8002091646194458, + 0.7992268204689026, + 0.7981467247009277, + 0.7973214983940125, + 0.7964017987251282, + 0.7954541444778442, + 0.7945792078971863, + 0.7938122153282166, + 0.7926003932952881, + 0.7917800545692444, + 0.7908596396446228, + 0.7899304628372192, + 0.7890149354934692, + 0.7882192730903625, + 0.7870058417320251, + 0.7863731980323792, + 0.7852027416229248, + 0.7844488024711609, + 0.783501386642456, + 0.7827003598213196, + 0.7819803357124329, + 0.7808201909065247, + 0.7800688147544861, + 0.7791293263435364, + 0.7784658670425415, + 0.7775732278823853, + 0.7768633961677551, + 0.7760342359542847, + 0.775243878364563, + 0.7743030786514282, + 0.7735926508903503, + 0.7724748849868774, + 0.7718163728713989, + 0.77097487449646, + 0.7702510356903076, + 0.7693900465965271, 
+ 0.7687169313430786, + 0.7678922414779663, + 0.7672128081321716, + 0.7663589715957642, + 0.7657037377357483, + 0.7647771239280701, + 0.7640203237533569, + 0.7633466720581055, + 0.7625623941421509, + 0.7617509961128235, + 0.7610896229743958, + 0.760379433631897, + 0.7596492767333984, + 0.7588953971862793, + 0.7581916451454163, + 0.7573999166488647, + 0.7568274736404419, + 0.756077229976654, + 0.7554765939712524, + 0.7546539306640625, + 0.7539674043655396, + 0.753139853477478, + 0.7525543570518494, + 0.7519160509109497, + 0.7513154149055481, + 0.7505142688751221, + 0.7497125864028931, + 0.74923175573349, + 0.7484207153320312, + 0.7479155659675598, + 0.7473617792129517, + 0.7468436360359192, + 0.7462318539619446, + 0.7456430792808533, + 0.7447810769081116, + 0.7442206144332886, + 0.7435954809188843, + 0.7431489825248718, + 0.7422271370887756, + 0.7418114542961121, + 0.7412892580032349, + 0.740713357925415, + 0.7401546239852905, + 0.7396021485328674, + 0.7390599846839905, + 0.73844313621521, + 0.7377904653549194, + 0.737305223941803, + 0.7368288636207581, + 0.7363747358322144, + 0.7358483076095581, + 0.7354381680488586, + 0.7348212003707886, + 0.7343763709068298, + 0.7336553335189819, + 0.7332231402397156, + 0.73262619972229, + 0.7321929931640625, + 0.7315752506256104, + 0.7312256693840027, + 0.7306149005889893, + 0.7302426695823669, + 0.7299467325210571, + 0.7294563055038452, + 0.728706419467926, + 0.7283353209495544, + 0.7279900312423706, + 0.7276231646537781, + 0.7273217439651489, + 0.7269001007080078, + 0.7265130877494812, + 0.7261000871658325, + 0.7257733345031738, + 0.725188672542572, + 0.724976658821106, + 0.7242119908332825, + 0.7238465547561646, + 0.7236427664756775, + 0.7232236266136169, + 0.7227877974510193, + 0.7226144075393677, + 0.7221262454986572, + 0.7218216061592102, + 0.7215451002120972, + 0.7213869094848633, + 0.7209206819534302, + 0.7207257747650146, + 0.7203994989395142, + 0.7200448513031006, + 0.7197679281234741, + 0.7195186018943787, + 0.7190226912498474, + 0.7188836932182312, + 0.7186117768287659, + 0.7185105681419373, + 0.718199610710144, + 0.7180152535438538, + 0.7175536155700684, + 0.7173341512680054, + 0.7171205878257751, + 0.7168837189674377, + 0.7163654565811157, + 0.7162774801254272, + 0.7161651253700256, + 0.7160663604736328, + 0.7159175872802734, + 0.7157440185546875, + 0.7154026031494141, + 0.7153436541557312, + 0.715220034122467, + 0.7150475978851318, + 0.7150062322616577, + 0.7149052619934082, + 0.7147804498672485, + 0.7147180438041687, + 0.7146442532539368, + 0.7143230438232422, + 0.7142894268035889, + 0.7143252491950989, + 0.7141361236572266, + 0.7140751481056213, + 0.7138771414756775, + 0.7138750553131104, + 0.7138450145721436, + 0.7138748168945312, + 0.7137607336044312, + 0.7137340903282166, + 0.7137055993080139, + 0.7136792540550232, + 0.7135677337646484, + 0.7134214639663696, + 0.7135778069496155, + 0.7136402130126953, + 0.713737964630127, + 0.7136131525039673, + 0.7135958075523376, + 0.713367760181427, + 0.7136869430541992, + 0.7137601971626282, + 0.7137682437896729, + 0.7137079834938049, + 0.7138195633888245, + 0.713512122631073, + 0.7136629819869995, + 0.7137271761894226, + 0.7138593792915344, + 0.714098334312439, + 0.714293360710144, + 0.7142848372459412, + 0.7144456505775452, + 0.714730978012085, + 0.7147268652915955, + 0.7149925231933594, + 0.7151434421539307, + 0.7151704430580139, + 0.7152837514877319, + 0.7152866125106812, + 0.7155521512031555, + 0.7158286571502686, + 0.7161067724227905, + 0.7161760926246643, + 0.716302752494812, + 
0.7165276408195496, + 0.716739296913147, + 0.7168811559677124, + 0.7171459794044495, + 0.7174181342124939, + 0.7176154255867004, + 0.7177509069442749, + 0.7180988192558289, + 0.718157172203064, + 0.7184935212135315, + 0.7185874581336975, + 0.7188709378242493, + 0.7191057801246643, + 0.7192631959915161, + 0.7195265293121338, + 0.7197085618972778, + 0.720137357711792, + 0.7203015089035034, + 0.7204228639602661, + 0.720592737197876, + 0.7209603786468506, + 0.7212156057357788, + 0.7214911580085754, + 0.72171950340271, + 0.7221818566322327, + 0.7225285768508911, + 0.7227242588996887, + 0.7229769825935364, + 0.7232232689857483, + 0.7235181927680969, + 0.7238690853118896, + 0.7240993976593018, + 0.7244613170623779, + 0.7245984673500061, + 0.7250016331672668, + 0.7251067161560059, + 0.7255734205245972, + 0.7258168458938599, + 0.7260231375694275, + 0.7262008786201477, + 0.7266826629638672, + 0.7266356945037842, + 0.7272213697433472, + 0.7274652719497681, + 0.7279736399650574, + 0.7281084656715393, + 0.7283049821853638, + 0.7285397052764893, + 0.7289696931838989, + 0.7293713688850403, + 0.72969651222229, + 0.7298030853271484, + 0.7300551533699036, + 0.7303762435913086, + 0.7306288480758667, + 0.7307636141777039, + 0.7312272787094116, + 0.7314295172691345, + 0.7316324710845947, + 0.7314282655715942, + 0.7316331267356873, + 0.7319273352622986, + 0.732114315032959, + 0.732216477394104, + 0.7322894930839539, + 0.7325413227081299, + 0.7325369715690613, + 0.7325401902198792, + 0.7323561906814575, + 0.7322503924369812, + 0.7322250008583069, + 0.7320146560668945, + 0.7321474552154541, + 0.732187032699585, + 0.7320796251296997, + 0.7315171360969543, + 0.7313243746757507, + 0.7310792803764343, + 0.7308080196380615, + 0.7306304574012756, + 0.7296295166015625, + 0.7289745807647705, + 0.7288649082183838, + 0.7281786799430847, + 0.7277448773384094, + 0.7272322177886963, + 0.7265949845314026, + 0.725848376750946, + 0.7252488136291504, + 0.7245422601699829, + 0.723800003528595, + 0.7228619456291199, + 0.7218728065490723, + 0.720892071723938, + 0.7198144793510437, + 0.7186352610588074, + 0.717454731464386, + 0.7162171602249146, + 0.7149246335029602, + 0.7136655449867249, + 0.7121627926826477, + 0.7108365297317505, + 0.7092090249061584, + 0.7076245546340942, + 0.7060236930847168, + 0.704273521900177, + 0.7023670673370361, + 0.7007937431335449, + 0.698858916759491, + 0.696875810623169, + 0.69483882188797, + 0.6926799416542053, + 0.6907360553741455, + 0.6884570121765137, + 0.68642657995224, + 0.6837813258171082, + 0.6816286444664001, + 0.6790634989738464, + 0.6767467260360718, + 0.6742716431617737, + 0.6715688705444336, + 0.6689924001693726, + 0.6662940382957458, + 0.6634176969528198, + 0.6603904962539673, + 0.6574862599372864, + 0.6544369459152222, + 0.651530385017395, + 0.6485568284988403, + 0.6453983187675476, + 0.6423068046569824, + 0.6392328143119812, + 0.6360040307044983, + 0.6325976252555847, + 0.6291093826293945, + 0.6256147623062134, + 0.6223041415214539, + 0.6188730001449585, + 0.615329921245575, + 0.6118036508560181, + 0.6081951260566711, + 0.604539155960083, + 0.6009404063224792, + 0.5972440242767334, + 0.5937116146087646, + 0.5897051692008972, + 0.5859677195549011, + 0.5821571946144104, + 0.578381359577179, + 0.5747998952865601, + 0.5709896683692932, + 0.5671953558921814, + 0.5633583068847656, + 0.5594204068183899, + 0.5555382370948792, + 0.5519280433654785, + 0.5482934713363647, + 0.544551432132721, + 0.5410515666007996, + 0.5374910831451416, + 0.5340041518211365, + 0.5304144024848938, + 
0.5269584655761719, + 0.5235306620597839, + 0.520039975643158, + 0.516674280166626, + 0.513296902179718, + 0.5098193883895874, + 0.5064578652381897, + 0.5030517578125, + 0.4997297525405884, + 0.4967145025730133, + 0.49335765838623047, + 0.4902186989784241, + 0.486634224653244, + 0.48311659693717957, + 0.4792158007621765, + 0.4755136966705322, + 0.4720709025859833, + 0.4689248502254486, + 0.4660993814468384, + 0.46355342864990234, + 0.46058982610702515, + 0.45763304829597473, + 0.45535609126091003, + 0.45405313372612, + 0.45241352915763855, + 0.45207348465919495, + 0.45095735788345337, + 0.45052871108055115, + 0.449806272983551, + 0.4484655559062958, + 0.44648951292037964, + 0.44580715894699097, + 0.4447800815105438, + 0.4453802704811096, + 0.4472601115703583 +] \ No newline at end of file diff --git a/toolkit/timestep_weighing/flex_timestep_weights_plot.png b/toolkit/timestep_weighing/flex_timestep_weights_plot.png new file mode 100644 index 00000000..b77632e2 Binary files /dev/null and b/toolkit/timestep_weighing/flex_timestep_weights_plot.png differ diff --git a/ui/src/app/jobs/new/SimpleJob.tsx b/ui/src/app/jobs/new/SimpleJob.tsx index b8ab6b93..ca587564 100644 --- a/ui/src/app/jobs/new/SimpleJob.tsx +++ b/ui/src/app/jobs/new/SimpleJob.tsx @@ -33,7 +33,6 @@ export default function SimpleJob({ gpuList, datasetOptions, }: Props) { - const modelArch = useMemo(() => { return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch; }, [jobConfig.config.process[0].model.arch]); @@ -104,8 +103,7 @@ export default function SimpleJob({ const newDataset = objectCopy(dataset); newDataset.controls = controls; return newDataset; - } - ); + }); setJobConfig(datasets, 'config.process[0].datasets'); }} options={ @@ -131,20 +129,22 @@ export default function SimpleJob({ placeholder="" required /> - -
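As an aside, a minimal Python sketch (not part of the patch) of how the "weighted" timestep pieces above fit together: the scheduler class, its get_weights_for_timesteps(..., timestep_type="weighted") branch, and default_weighing_scheme come from the hunks above, while the no-argument construction, the direct assignment to scheduler.timesteps (mirroring the scheduler's own 'weighted'/'linear' setup branch), and the stand-in loss tensors are assumptions made purely for illustration.

# Hypothetical usage sketch of the weighted timestep scheme (assumptions noted above).
import torch
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from toolkit.timestep_weighing.default_weighing_scheme import default_weighing_scheme

scheduler = CustomFlowMatchEulerDiscreteScheduler()  # assumed: defaults are sufficient here

# The 'weighted' type shares the linear 1000 -> 1 grid, so each sampled timestep maps to an
# index into default_weighing_scheme (values roughly 0.45 to 1.52).
scheduler.timesteps = torch.linspace(1000, 1, len(default_weighing_scheme), device="cpu")

batch_t = scheduler.timesteps[torch.randint(0, len(default_weighing_scheme), (4,))]
weights = scheduler.get_weights_for_timesteps(batch_t, timestep_type="weighted")

# Broadcast the per-sample weight over the remaining loss dimensions before reducing.
pred, target = torch.randn(4, 16, 32, 32), torch.randn(4, 16, 32, 32)
loss = torch.nn.functional.mse_loss(pred, target, reduction="none")
loss = (loss * weights.view(-1, 1, 1, 1)).mean()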
- setJobConfig(value, 'config.process[0].model.quantize')} - /> - setJobConfig(value, 'config.process[0].model.quantize_te')} - /> -
-
+ {modelArch?.disableSections?.includes('model.quantize') ? null : ( + +
+ setJobConfig(value, 'config.process[0].model.quantize')} + /> + setJobConfig(value, 'config.process[0].model.quantize_te')} + /> +
+
+ )} )} {jobConfig.config.process[0].network?.type == 'lora' && ( - { - console.log('onChange', value); - setJobConfig(value, 'config.process[0].network.linear'); - setJobConfig(value, 'config.process[0].network.linear_alpha'); - }} - placeholder="eg. 16" - min={0} - max={1024} - required - /> + <> + { + console.log('onChange', value); + setJobConfig(value, 'config.process[0].network.linear'); + setJobConfig(value, 'config.process[0].network.linear_alpha'); + }} + placeholder="eg. 16" + min={0} + max={1024} + required + /> + { + modelArch?.disableSections?.includes('network.conv') ? null : ( + { + console.log('onChange', value); + setJobConfig(value, 'config.process[0].network.conv'); + setJobConfig(value, 'config.process[0].network.conv_alpha'); + }} + placeholder="eg. 16" + min={0} + max={1024} + /> + ) + } + )} @@ -276,16 +294,19 @@ export default function SimpleJob({ />
- setJobConfig(value, 'config.process[0].train.timestep_type')} - options={[ - { value: 'sigmoid', label: 'Sigmoid' }, - { value: 'linear', label: 'Linear' }, - { value: 'shift', label: 'Shift' }, - ]} - /> + {modelArch?.disableSections?.includes('train.timestep_type') ? null : ( + setJobConfig(value, 'config.process[0].train.timestep_type')} + options={[ + { value: 'sigmoid', label: 'Sigmoid' }, + { value: 'linear', label: 'Linear' }, + { value: 'shift', label: 'Shift' }, + ]} + /> + )} setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')} @@ -353,7 +374,7 @@ export default function SimpleJob({ min={0} /> setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')} @@ -466,10 +487,7 @@ export default function SimpleJob({ // automaticallt add the controls for a new dataset const controls = modelArch?.controls ?? []; newDataset.controls = controls; - setJobConfig( - [...jobConfig.config.process[0].datasets, newDataset], - 'config.process[0].datasets', - ) + setJobConfig([...jobConfig.config.process[0].datasets, newDataset], 'config.process[0].datasets'); }} className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors" > diff --git a/ui/src/app/jobs/new/jobConfig.ts b/ui/src/app/jobs/new/jobConfig.ts index 59268012..dd28461b 100644 --- a/ui/src/app/jobs/new/jobConfig.ts +++ b/ui/src/app/jobs/new/jobConfig.ts @@ -30,6 +30,8 @@ export const defaultJobConfig: JobConfig = { type: 'lora', linear: 32, linear_alpha: 32, + conv: 16, + conv_alpha: 16, lokr_full_rank: true, lokr_factor: -1, network_kwargs: { diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts index b35f32a4..e84bc4a0 100644 --- a/ui/src/app/jobs/new/options.ts +++ b/ui/src/app/jobs/new/options.ts @@ -1,4 +1,3 @@ - type Control = 'depth' | 'line' | 'pose' | 'inpaint'; export interface ModelArch { @@ -6,11 +5,14 @@ export interface ModelArch { label: string; controls?: Control[]; isVideoModel?: boolean; - defaults?: { [key: string]: [any, any] }; + defaults?: { [key: string]: any }; + disableSections?: DisableableSections[]; } const defaultNameOrPath = ''; +type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv'; + export const modelArchs: ModelArch[] = [ { name: 'flux', @@ -23,6 +25,7 @@ export const modelArchs: ModelArch[] = [ 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], }, + disableSections: ['network.conv'], }, { name: 'flex1', @@ -36,6 +39,7 @@ export const modelArchs: ModelArch[] = [ 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], }, + disableSections: ['network.conv'], }, { name: 'flex2', @@ -62,6 +66,7 @@ export const modelArchs: ModelArch[] = [ 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], }, + disableSections: ['network.conv'], }, { name: 'chroma', @@ -74,6 +79,7 @@ export const modelArchs: ModelArch[] = [ 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], }, + disableSections: ['network.conv'], }, { name: 'wan21:1b', @@ -89,6 +95,7 @@ export const modelArchs: ModelArch[] = [ 'config.process[0].sample.num_frames': [40, 1], 'config.process[0].sample.fps': [15, 1], }, + disableSections: ['network.conv'], }, { name: 'wan21:14b', @@ -96,7 +103,7 @@ export const 
modelArchs: ModelArch[] = [ isVideoModel: true, defaults: { // default updates when [selected, unselected] in the UI - 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-T2V-14B-Diffuserss', defaultNameOrPath], + 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-T2V-14B-Diffusers', defaultNameOrPath], 'config.process[0].model.quantize': [true, false], 'config.process[0].model.quantize_te': [true, false], 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], @@ -104,6 +111,7 @@ export const modelArchs: ModelArch[] = [ 'config.process[0].sample.num_frames': [40, 1], 'config.process[0].sample.fps': [15, 1], }, + disableSections: ['network.conv'], }, { name: 'lumina2', @@ -116,6 +124,7 @@ export const modelArchs: ModelArch[] = [ 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], }, + disableSections: ['network.conv'], }, { name: 'hidream', @@ -131,5 +140,37 @@ export const modelArchs: ModelArch[] = [ 'config.process[0].train.timestep_type': ['shift', 'sigmoid'], 'config.process[0].network.network_kwargs.ignore_if_contains': [['ff_i.experts', 'ff_i.gate'], []], }, + disableSections: ['network.conv'], }, -]; + { + name: 'sdxl', + label: 'SDXL', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['stabilityai/stable-diffusion-xl-base-1.0', defaultNameOrPath], + 'config.process[0].model.quantize': [false, false], + 'config.process[0].model.quantize_te': [false, false], + 'config.process[0].sample.sampler': ['ddpm', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'], + 'config.process[0].sample.guidance_scale': [6, 4], + }, + disableSections: ['model.quantize', 'train.timestep_type'], + }, + { + name: 'sd15', + label: 'SD 1.5', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['stable-diffusion-v1-5/stable-diffusion-v1-5', defaultNameOrPath], + 'config.process[0].sample.sampler': ['ddpm', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'], + 'config.process[0].sample.width': [512, 1024], + 'config.process[0].sample.height': [512, 1024], + 'config.process[0].sample.guidance_scale': [6, 4], + }, + disableSections: ['model.quantize', 'train.timestep_type'], + }, +].sort((a, b) => { + // Sort by label, case-insensitive + return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' }) +}) as any; diff --git a/ui/src/components/Sidebar.tsx b/ui/src/components/Sidebar.tsx index 14ec4542..324b6097 100644 --- a/ui/src/components/Sidebar.tsx +++ b/ui/src/components/Sidebar.tsx @@ -11,11 +11,12 @@ const Sidebar = () => { ]; return ( -
-
-

- Ostris AI Toolkit - Ostris - AI Toolkit +
+
+

+ Ostris AI Toolkit + Ostris + AI-Toolkit

- +
- - - - - - - - + + +
-
Support me on Patreon
+
Support AI-Toolkit
); diff --git a/ui/src/components/formInputs.tsx b/ui/src/components/formInputs.tsx index 7f53e681..a9908bd0 100644 --- a/ui/src/components/formInputs.tsx +++ b/ui/src/components/formInputs.tsx @@ -2,8 +2,8 @@ import React, { forwardRef } from 'react'; import classNames from 'classnames'; -import dynamic from "next/dynamic"; -const Select = dynamic(() => import("react-select"), { ssr: false }); +import dynamic from 'next/dynamic'; +const Select = dynamic(() => import('react-select'), { ssr: false }); const labelClasses = 'block text-xs mb-1 mt-2 text-gray-300'; const inputClasses = @@ -42,7 +42,7 @@ export const TextInput = forwardRef( />

); - } + }, ); // 👇 Helpful for debugging @@ -114,6 +114,7 @@ export const NumberInput = (props: NumberInputProps) => { export interface SelectInputProps extends InputProps { value: string; + disabled?: boolean; onChange: (value: string) => void; options: { value: string; label: string }[]; } @@ -122,11 +123,16 @@ export const SelectInput = (props: SelectInputProps) => { const { label, value, onChange, options } = props; const selectedOption = options.find(option => option.value === value); return ( -
+
{label && } -
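Returning to the Python side of the patch, a similar hedged sketch of exercising the frame-wise Wan VAE added earlier: the module path follows the relative import in wan21.py, while the Hugging Face repo id, the vae subfolder, the availability of diffusers' standard from_pretrained loader on the vendored class, and the tensor shape are assumptions for illustration only.

# Hypothetical round trip through the vendored AutoencoderKLWan shown earlier.
import torch
from toolkit.models.wan21.autoencoder_kl_wan import AutoencoderKLWan

# Assumption: the vendored class keeps diffusers' standard loading interface.
vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="vae")
vae.eval()

video = torch.randn(1, 3, 17, 480, 832)  # [B, C, T, H, W], values expected in [-1, 1]
with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
    # decode() walks the temporal axis one latent frame at a time, reusing feat_cache
    # between frames, and clamps the reconstruction to [-1, 1].
    recon = vae.decode(latents).sample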