Merged in from main

This commit is contained in:
Jaret Burkett
2025-06-24 10:56:54 -06:00
32 changed files with 1621 additions and 1060 deletions

View File

@@ -417,6 +417,17 @@ Everything else should work the same including layer targeting.
## Updates
### June 17, 2025
- Performance optimizations for batch preparation
- Added some docs via a popup for items in the simple UI explaining what settings do. Still a WIP
### June 16, 2025
- Hide control images in the UI when viewing datasets
- WIP on mean flow loss
### June 12, 2025
- Fixed an issue that resulted in blank captions in the dataloader
### June 10, 2025
- Decided to keep track of updates in the readme
- Added support for SDXL in the UI

View File

@@ -36,7 +36,6 @@ from toolkit.train_tools import precondition_model_outputs_flow_match
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
from toolkit.util.wavelet_loss import wavelet_loss
import torch.nn.functional as F
from toolkit.models.flux import convert_flux_to_mean_flow
def flush():
@@ -62,7 +61,6 @@ class SDTrainer(BaseSDTrainProcess):
self._clip_image_embeds_unconditional: Union[List[str], None] = None
self.negative_prompt_pool: Union[List[str], None] = None
self.batch_negative_prompt: Union[List[str], None] = None
self.cfm_cache = None
self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
@@ -85,6 +83,7 @@ class SDTrainer(BaseSDTrainProcess):
self.diff_output_preservation_embeds: Optional[PromptEmbeds] = None
self.dfe: Optional[DiffusionFeatureExtractor] = None
self.unconditional_embeds = None
if self.train_config.diff_output_preservation:
if self.trigger_word is None:
@@ -96,6 +95,15 @@ class SDTrainer(BaseSDTrainProcess):
# always do a prior prediction when doing diff output preservation
self.do_prior_prediction = True
# store the guidance loss target for the current batch so we can use it when computing the loss
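# a single number keeps a fixed guidance target; a [min, max] list is re-sampled per batch during the training step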
self._guidance_loss_target_batch: float = 0.0
if isinstance(self.train_config.guidance_loss_target, (int, float)):
self._guidance_loss_target_batch = float(self.train_config.guidance_loss_target)
elif isinstance(self.train_config.guidance_loss_target, list):
self._guidance_loss_target_batch = float(self.train_config.guidance_loss_target[0])
else:
raise ValueError(f"Unknown guidance loss target type {type(self.train_config.guidance_loss_target)}")
def before_model_load(self):
@@ -135,9 +143,16 @@ class SDTrainer(BaseSDTrainProcess):
def hook_before_train_loop(self):
super().hook_before_train_loop()
if self.train_config.loss_type == "mean_flow":
# todo: handle non-flux models
convert_flux_to_mean_flow(self.sd.unet)
# cache unconditional embeds (blank prompt)
with torch.no_grad():
self.unconditional_embeds = self.sd.encode_prompt(
[self.train_config.unconditional_prompt],
long_prompts=self.do_long_prompts
).to(
self.device_torch,
dtype=self.sd.torch_dtype
).detach()
if self.train_config.do_prior_divergence:
self.do_prior_prediction = True
@@ -418,15 +433,41 @@ class SDTrainer(BaseSDTrainProcess):
if self.dfe is not None:
if self.dfe.version == 1:
# do diffusion feature extraction on target
model = self.sd
if model is not None and hasattr(model, 'get_stepped_pred'):
stepped_latents = model.get_stepped_pred(noise_pred, noise)
else:
# stepped_latents = noise - noise_pred
# first we step the scheduler from current timestep to the very end for a full denoise
bs = noise_pred.shape[0]
noise_pred_chunks = torch.chunk(noise_pred, bs)
timestep_chunks = torch.chunk(timesteps, bs)
noisy_latent_chunks = torch.chunk(noisy_latents, bs)
stepped_chunks = []
for idx in range(bs):
model_output = noise_pred_chunks[idx]
timestep = timestep_chunks[idx]
self.sd.noise_scheduler._step_index = None
self.sd.noise_scheduler._init_step_index(timestep)
sample = noisy_latent_chunks[idx].to(torch.float32)
sigma = self.sd.noise_scheduler.sigmas[self.sd.noise_scheduler.step_index]
sigma_next = self.sd.noise_scheduler.sigmas[-1] # use last sigma for final step
prev_sample = sample + (sigma_next - sigma) * model_output
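# a single Euler step from the current sigma to the final sigma, i.e. a full denoise in one step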
stepped_chunks.append(prev_sample)
stepped_latents = torch.cat(stepped_chunks, dim=0)
stepped_latents = stepped_latents.to(self.sd.vae.device, dtype=self.sd.vae.dtype)
pred_features = self.dfe(stepped_latents.float())
with torch.no_grad():
rectified_flow_target = noise.float() - batch.latents.float()
target_features = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))
target_features = self.dfe(batch.latents.to(self.device_torch, dtype=torch.float32))
# scale dfe so it is weaker at higher noise levels
dfe_scaler = 1 - (timesteps.float() / 1000.0).view(-1, 1, 1, 1).to(self.device_torch)
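# e.g. timesteps near 1000 (mostly noise) push the DFE loss weight toward 0, timesteps near 0 push it toward 1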
# do diffusion feature extraction on prediction
pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") * \
self.train_config.diffusion_feature_extractor_weight
dfe_loss = torch.nn.functional.mse_loss(pred_features, target_features, reduction="none") * \
self.train_config.diffusion_feature_extractor_weight * dfe_scaler
additional_loss += dfe_loss.mean()
elif self.dfe.version == 2:
# version 2
# do diffusion feature extraction on target
@@ -454,6 +495,47 @@ class SDTrainer(BaseSDTrainProcess):
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight
else:
raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}")
if self.train_config.do_guidance_loss:
with torch.no_grad():
# expand the cached blank-prompt embeds to match the batch size
unconditional_embeds = concat_prompt_embeds(
[self.unconditional_embeds] * noisy_latents.shape[0],
)
cfm_pred = self.predict_noise(
noisy_latents=noisy_latents,
timesteps=timesteps,
conditional_embeds=unconditional_embeds,
unconditional_embeds=None,
batch=batch,
)
# CFG-Zero* guidance projection
# ref https://github.com/WeichenFan/CFG-Zero-star/blob/cdac25559e3f16cb95f0016c04c709ea1ab9452b/wan_pipeline.py#L557
batch_size = target.shape[0]
positive_flat = target.view(batch_size, -1)
negative_flat = cfm_pred.view(batch_size, -1)
# Calculate dot product
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
# Squared norm of the unconditional prediction
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
st_star = dot_product / squared_norm
alpha = st_star
is_video = len(target.shape) == 5
alpha = alpha.view(batch_size, 1, 1, 1) if not is_video else alpha.view(batch_size, 1, 1, 1, 1)
guidance_scale = self._guidance_loss_target_batch
if isinstance(guidance_scale, list):
guidance_scale = torch.tensor(guidance_scale).to(target.device, dtype=target.dtype)
guidance_scale = guidance_scale.view(-1, 1, 1, 1) if not is_video else guidance_scale.view(-1, 1, 1, 1, 1)
unconditional_target = cfm_pred * alpha
target = unconditional_target + guidance_scale * (target - unconditional_target)
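# net effect: target = alpha * v_uncond + g * (target - alpha * v_uncond),
# where alpha = <target, v_uncond> / ||v_uncond||^2 (the CFG-Zero* projection),
# v_uncond is cfm_pred, and g is the per-batch guidance_loss_target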
if target is None:
@@ -634,102 +716,6 @@ class SDTrainer(BaseSDTrainProcess):
return loss
# ------------------------------------------------------------------
# Mean-Flow loss (Geng et al., “Mean Flows for One-step Generative
# Modelling”, 2025; see Alg. 1 and Eq. (6) of the paper)
# This version avoids jvp / double-back-prop issues with Flash-Attention
# adapted from the work of lodestonerock
# ------------------------------------------------------------------
def get_mean_flow_loss_wip(
self,
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,
match_adapter_assist: bool,
network_weight_list: list,
timesteps: torch.Tensor,
pred_kwargs: dict,
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
unconditional_embeds: Optional[PromptEmbeds] = None,
**kwargs
):
batch_latents = batch.latents.to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
time_end = timesteps.float() / 1000
# for timestep_r, sample random values uniformly between 0.0 and timestep_end
time_origin = torch.rand_like(time_end, device=self.device_torch, dtype=time_end.dtype) * time_end
# time_origin = torch.zeros_like(time_end, device=self.device_torch, dtype=time_end.dtype)
# Compute noised data points
# lerp_vector = noisy_latents
# compute instantaneous vector
instantaneous_vector = noise - batch_latents
# finite difference method
epsilon_fd = 1e-3
jitter_std = 1e-4
epsilon_jittered = epsilon_fd + torch.randn(1, device=batch_latents.device) * jitter_std
epsilon_jittered = torch.clamp(epsilon_jittered, min=1e-4)
# f(x) for the primal (we backprop through here)
# mean_vec_val_pred = self.forward(lerp_vector, class_label)
mean_vec_val_pred = self.predict_noise(
noisy_latents=noisy_latents,
timesteps=torch.cat([time_end, time_origin], dim=0) * 1000,
conditional_embeds=conditional_embeds,
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)
with torch.no_grad():
perturbed_time_end = torch.clamp(time_end + epsilon_jittered, 0.0, 1.0)
# intermediate vector to compute tangent approximation f(x + epsilon * v) ! NO GRAD HERE!
perturbed_lerp_vector = noisy_latents + epsilon_jittered * instantaneous_vector
# f_x_plus_eps_v = self.forward(perturbed_lerp_vector, class_label)
f_x_plus_eps_v = self.predict_noise(
noisy_latents=perturbed_lerp_vector,
timesteps=torch.cat([perturbed_time_end, time_origin], dim=0) * 1000,
conditional_embeds=conditional_embeds,
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)
# JVP approximation: (f(x + epsilon * v) - f(x)) / epsilon
mean_vec_grad_fd = (f_x_plus_eps_v - mean_vec_val_pred) / epsilon_jittered
mean_vec_grad = mean_vec_grad_fd
# regression target for the mean velocity: u ≈ v - (t - r) * du/dt (Eq. 6), using the finite-difference JVP above
time_difference_broadcast = (time_end - time_origin)[:, None, None, None]
mean_vec_target = instantaneous_vector - time_difference_broadcast * mean_vec_grad
# 5) MSE loss
loss = torch.nn.functional.mse_loss(
mean_vec_val_pred.float(),
mean_vec_target.float(),
reduction='none'
)
with torch.no_grad():
pure_loss = loss.mean().detach()
# give pure_loss a grad flag so it can be passed through backward without issues
pure_loss.requires_grad_(True)
# normalize the loss per batch element to 1.0
# the raw mean-flow loss has large swings that can hurt the model; this normalization prevents that
with torch.no_grad():
loss_mean = loss.mean([1, 2, 3], keepdim=True)
loss = loss / loss_mean
loss = loss.mean()
# backward the normalized loss here; the unnormalized pure_loss is only returned for logging
self.accelerator.backward(loss)
# return the real loss for logging
return pure_loss
# ------------------------------------------------------------------
# Mean-Flow loss (Geng et al., “Mean Flows for One-step Generative
# Modelling”, 2025; see Alg. 1 and Eq. (6) of the paper)
@@ -970,6 +956,10 @@ class SDTrainer(BaseSDTrainProcess):
if unconditional_embeds is not None:
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
guidance_embedding_scale = self.train_config.cfg_scale
if self.train_config.do_guidance_loss:
guidance_embedding_scale = self._guidance_loss_target_batch
prior_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
@@ -977,6 +967,7 @@ class SDTrainer(BaseSDTrainProcess):
unconditional_embeddings=unconditional_embeds,
timestep=timesteps,
guidance_scale=self.train_config.cfg_scale,
guidance_embedding_scale=guidance_embedding_scale,
rescale_cfg=self.train_config.cfg_rescale,
batch=batch,
**pred_kwargs # adapter residuals in here
@@ -1020,13 +1011,16 @@ class SDTrainer(BaseSDTrainProcess):
**kwargs,
):
dtype = get_torch_dtype(self.train_config.dtype)
guidance_embedding_scale = self.train_config.cfg_scale
if self.train_config.do_guidance_loss:
guidance_embedding_scale = self._guidance_loss_target_batch
return self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeddings=unconditional_embeds,
timestep=timesteps,
guidance_scale=self.train_config.cfg_scale,
guidance_embedding_scale=self.train_config.cfg_scale,
guidance_embedding_scale=guidance_embedding_scale,
detach_unconditional=False,
rescale_cfg=self.train_config.cfg_rescale,
bypass_guidance_embedding=self.train_config.bypass_guidance_embedding,
@@ -1034,152 +1028,78 @@ class SDTrainer(BaseSDTrainProcess):
**kwargs
)
def cfm_augment_tensors(
self,
images: torch.Tensor
) -> torch.Tensor:
if self.cfm_cache is None:
# flip the current one. Only need this for first time
self.cfm_cache = torch.flip(images, [3]).clone()
augmented_tensor_list = []
for i in range(images.shape[0]):
# get a random one
idx = random.randint(0, self.cfm_cache.shape[0] - 1)
augmented_tensor_list.append(self.cfm_cache[idx:idx + 1])
augmented = torch.cat(augmented_tensor_list, dim=0)
# resize to match the input
augmented = torch.nn.functional.interpolate(augmented, size=(images.shape[2], images.shape[3]), mode='bilinear')
self.cfm_cache = images.clone()
return augmented
def get_cfm_loss(
self,
noisy_latents: torch.Tensor,
noise: torch.Tensor,
noise_pred: torch.Tensor,
conditional_embeds: PromptEmbeds,
timesteps: torch.Tensor,
batch: 'DataLoaderBatchDTO',
alpha: float = 0.1,
):
dtype = get_torch_dtype(self.train_config.dtype)
if hasattr(self.sd, 'get_loss_target'):
target = self.sd.get_loss_target(
noise=noise,
batch=batch,
timesteps=timesteps,
).detach()
elif self.sd.is_flow_matching:
# forward ODE
target = (noise - batch.latents).detach()
else:
raise ValueError("CFM loss only works with flow matching")
fm_loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
with torch.no_grad():
# we need to compute the contrast
cfm_batch_tensors = self.cfm_augment_tensors(batch.tensor).to(self.device_torch, dtype=dtype)
cfm_latents = self.sd.encode_images(cfm_batch_tensors).to(self.device_torch, dtype=dtype)
cfm_noisy_latents = self.sd.add_noise(
original_samples=cfm_latents,
noise=noise,
timesteps=timesteps,
)
cfm_pred = self.predict_noise(
noisy_latents=cfm_noisy_latents,
timesteps=timesteps,
conditional_embeds=conditional_embeds,
unconditional_embeds=None,
batch=batch,
)
# v_neg = torch.nn.functional.normalize(cfm_pred.float(), dim=1)
# v_pos = torch.nn.functional.normalize(noise_pred.float(), dim=1) # shape: (B, C, H, W)
# # Compute cosine similarity at each pixel
# sim = (v_pos * v_neg).sum(dim=1) # shape: (B, H, W)
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
# Compute cosine similarity at each pixel
sim = cos(cfm_pred.float(), noise_pred.float()) # shape: (B, H, W)
# Average over spatial dimensions, then batch
contrastive_loss = -sim.mean()
loss = fm_loss.mean() + alpha * contrastive_loss
return loss
def train_single_accumulation(self, batch: DataLoaderBatchDTO):
self.timer.start('preprocess_batch')
if isinstance(self.adapter, CustomAdapter):
batch = self.adapter.edit_batch_raw(batch)
batch = self.preprocess_batch(batch)
if isinstance(self.adapter, CustomAdapter):
batch = self.adapter.edit_batch_processed(batch)
dtype = get_torch_dtype(self.train_config.dtype)
# sanity check
if self.sd.vae.dtype != self.sd.vae_torch_dtype:
self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype)
if isinstance(self.sd.text_encoder, list):
for encoder in self.sd.text_encoder:
if encoder.dtype != self.sd.te_torch_dtype:
encoder.to(self.sd.te_torch_dtype)
else:
if self.sd.text_encoder.dtype != self.sd.te_torch_dtype:
self.sd.text_encoder.to(self.sd.te_torch_dtype)
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
if self.train_config.do_cfg or self.train_config.do_random_cfg:
# pick random negative prompts
if self.negative_prompt_pool is not None:
negative_prompts = []
for i in range(noisy_latents.shape[0]):
num_neg = random.randint(1, self.train_config.max_negative_prompts)
this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)]
this_neg_prompt = ', '.join(this_neg_prompts)
negative_prompts.append(this_neg_prompt)
self.batch_negative_prompt = negative_prompts
else:
self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])]
if self.adapter and isinstance(self.adapter, CustomAdapter):
# condition the prompt
# todo handle more than one adapter image
conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)
network_weight_list = batch.get_network_weight_list()
if self.train_config.single_item_batching:
network_weight_list = network_weight_list + network_weight_list
has_adapter_img = batch.control_tensor is not None
has_clip_image = batch.clip_image_tensor is not None
has_clip_image_embeds = batch.clip_image_embeds is not None
# force it to be true if doing regs as we handle those differently
if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]):
has_clip_image = True
if self._clip_image_embeds_unconditional is not None:
has_clip_image_embeds = True # we are caching embeds, handle that differently
has_clip_image = False
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
raise ValueError(
"IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")
match_adapter_assist = False
# check if we are matching the adapter assistant
if self.assistant_adapter:
if self.train_config.match_adapter_chance == 1.0:
match_adapter_assist = True
elif self.train_config.match_adapter_chance > 0.0:
match_adapter_assist = torch.rand(
(1,), device=self.device_torch, dtype=dtype
) < self.train_config.match_adapter_chance
self.timer.stop('preprocess_batch')
is_reg = False
with torch.no_grad():
self.timer.start('preprocess_batch')
if isinstance(self.adapter, CustomAdapter):
batch = self.adapter.edit_batch_raw(batch)
batch = self.preprocess_batch(batch)
if isinstance(self.adapter, CustomAdapter):
batch = self.adapter.edit_batch_processed(batch)
dtype = get_torch_dtype(self.train_config.dtype)
# sanity check
if self.sd.vae.dtype != self.sd.vae_torch_dtype:
self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype)
if isinstance(self.sd.text_encoder, list):
for encoder in self.sd.text_encoder:
if encoder.dtype != self.sd.te_torch_dtype:
encoder.to(self.sd.te_torch_dtype)
else:
if self.sd.text_encoder.dtype != self.sd.te_torch_dtype:
self.sd.text_encoder.to(self.sd.te_torch_dtype)
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
if self.train_config.do_cfg or self.train_config.do_random_cfg:
# pick random negative prompts
if self.negative_prompt_pool is not None:
negative_prompts = []
for i in range(noisy_latents.shape[0]):
num_neg = random.randint(1, self.train_config.max_negative_prompts)
this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)]
this_neg_prompt = ', '.join(this_neg_prompts)
negative_prompts.append(this_neg_prompt)
self.batch_negative_prompt = negative_prompts
else:
self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])]
if self.adapter and isinstance(self.adapter, CustomAdapter):
# condition the prompt
# todo handle more than one adapter image
conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)
network_weight_list = batch.get_network_weight_list()
if self.train_config.single_item_batching:
network_weight_list = network_weight_list + network_weight_list
has_adapter_img = batch.control_tensor is not None
has_clip_image = batch.clip_image_tensor is not None
has_clip_image_embeds = batch.clip_image_embeds is not None
# force it to be true if doing regs as we handle those differently
if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]):
has_clip_image = True
if self._clip_image_embeds_unconditional is not None:
has_clip_image_embeds = True # we are caching embeds, handle that differently
has_clip_image = False
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
raise ValueError(
"IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")
match_adapter_assist = False
# check if we are matching the adapter assistant
if self.assistant_adapter:
if self.train_config.match_adapter_chance == 1.0:
match_adapter_assist = True
elif self.train_config.match_adapter_chance > 0.0:
match_adapter_assist = torch.rand(
(1,), device=self.device_torch, dtype=dtype
) < self.train_config.match_adapter_chance
self.timer.stop('preprocess_batch')
is_reg = False
loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
for idx, file_item in enumerate(batch.file_items):
if file_item.is_reg:
@@ -1733,6 +1653,16 @@ class SDTrainer(BaseSDTrainProcess):
)
pred_kwargs['down_block_additional_residuals'] = down_block_res_samples
pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample
if self.train_config.do_guidance_loss and isinstance(self.train_config.guidance_loss_target, list):
batch_size = noisy_latents.shape[0]
# update the guidance targets: one random float per sample between guidance_loss_target[0] and guidance_loss_target[1]
self._guidance_loss_target_batch = [
random.uniform(
self.train_config.guidance_loss_target[0],
self.train_config.guidance_loss_target[1]
) for _ in range(batch_size)
]
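# these per-sample targets are consumed as guidance_embedding_scale by the noise-prediction helpers above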
self.before_unet_predict()
@@ -1832,25 +1762,15 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
prior_to_calculate_loss = None
if self.train_config.loss_type == 'cfm':
loss = self.get_cfm_loss(
noisy_latents=noisy_latents,
noise=noise,
noise_pred=noise_pred,
conditional_embeds=conditional_embeds,
timesteps=timesteps,
batch=batch,
)
else:
loss = self.calculate_loss(
noise_pred=noise_pred,
noise=noise,
noisy_latents=noisy_latents,
timesteps=timesteps,
batch=batch,
mask_multiplier=mask_multiplier,
prior_pred=prior_to_calculate_loss,
)
loss = self.calculate_loss(
noise_pred=noise_pred,
noise=noise,
noisy_latents=noisy_latents,
timesteps=timesteps,
batch=batch,
mask_multiplier=mask_multiplier,
prior_pred=prior_to_calculate_loss,
)
if self.train_config.diff_output_preservation:
# send the loss backwards otherwise checkpointing will fail

View File

@@ -1,8 +1,9 @@
from collections import OrderedDict
from version import VERSION
v = OrderedDict()
v["name"] = "ai-toolkit"
v["repo"] = "https://github.com/ostris/ai-toolkit"
v["version"] = "0.1.0"
v["version"] = VERSION
software_meta = v

View File

@@ -15,7 +15,6 @@ class BaseJob:
self.config = config['config']
self.raw_config = config
self.job = config['job']
self.torch_profiler = self.get_conf('torch_profiler', False)
self.name = self.get_conf('name', required=True)
if 'meta' in config:
self.meta = config['meta']

View File

@@ -236,6 +236,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.ema: ExponentialMovingAverage = None
validate_configs(self.train_config, self.model_config, self.save_config)
do_profiler = self.get_conf('torch_profiler', False)
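# when enabled, this profiler is started/stopped around hook_train_loop below and its key_averages table is printed after the step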
self.torch_profiler = None if not do_profiler else torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
)
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
# override in subclass
@@ -575,10 +583,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
direct_save = False
if self.adapter_config.train_only_image_encoder:
direct_save = True
if self.adapter_config.type == 'redux':
direct_save = True
if self.adapter_config.type in ['control_lora', 'subpixel', 'i2v']:
direct_save = True
elif isinstance(self.adapter, CustomAdapter):
direct_save = self.adapter.do_direct_save
save_ip_adapter_from_diffusers(
state_dict,
output_file=file_path,
@@ -923,7 +929,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
noise = self.get_consistent_noise(latents, batch, dtype=dtype)
else:
if hasattr(self.sd, 'get_latent_noise_from_latents'):
noise = self.sd.get_latent_noise_from_latents(latents).to(self.device_torch, dtype=dtype)
noise = self.sd.get_latent_noise_from_latents(
latents,
noise_offset=self.train_config.noise_offset
).to(self.device_torch, dtype=dtype)
else:
# get noise
noise = self.sd.get_latent_noise(
@@ -933,17 +942,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
batch_size=batch_size,
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype)
# if self.train_config.random_noise_shift > 0.0:
# # get random noise -1 to 1
# noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device,
# dtype=noise.dtype) * 2 - 1
# # multiply by shift amount
# noise_shift *= self.train_config.random_noise_shift
# # add to noise
# noise += noise_shift
if self.train_config.blended_blur_noise:
noise = get_blended_blur_noise(
@@ -1014,7 +1012,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
dtype = get_torch_dtype(self.train_config.dtype)
imgs = None
is_reg = any(batch.get_is_reg_list())
cfm_batch = None
if batch.tensor is not None:
imgs = batch.tensor
imgs = imgs.to(self.device_torch, dtype=dtype)
@@ -1087,19 +1084,20 @@ class BaseSDTrainProcess(BaseTrainProcess):
# we determine noise from the differential of the latents
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
batch_size = len(batch.file_items)
min_noise_steps = self.train_config.min_denoising_steps
max_noise_steps = self.train_config.max_denoising_steps
if self.model_config.refiner_name_or_path is not None:
# if we are not training the unet, then we are only doing refiner and do not need to double up
if self.train_config.train_unet:
max_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
do_double = True
else:
min_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
do_double = False
with self.timer('prepare_scheduler'):
batch_size = len(batch.file_items)
min_noise_steps = self.train_config.min_denoising_steps
max_noise_steps = self.train_config.max_denoising_steps
if self.model_config.refiner_name_or_path is not None:
# if we are not training the unet, then we are only doing refiner and do not need to double up
if self.train_config.train_unet:
max_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
do_double = True
else:
min_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
do_double = False
with self.timer('prepare_noise'):
num_train_timesteps = self.train_config.num_train_timesteps
if self.train_config.noise_scheduler in ['custom_lcm']:
@@ -1146,6 +1144,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.sd.noise_scheduler.set_timesteps(
num_train_timesteps, device=self.device_torch
)
with self.timer('prepare_timesteps_indices'):
content_or_style = self.train_config.content_or_style
if is_reg:
@@ -1195,20 +1194,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
else:
# todo: some schedulers use indices, others use timesteps. Not sure what to do here
min_idx = min_noise_steps + 1
max_idx = max_noise_steps - 1
if self.train_config.noise_scheduler == 'flowmatch':
# flowmatch uses indices, so we need to use indices
min_idx = 0
max_idx = max_noise_steps - 1
timestep_indices = torch.randint(
min_noise_steps + 1,
max_noise_steps - 1,
min_idx,
max_idx,
(batch_size,),
device=self.device_torch
)
timestep_indices = timestep_indices.long()
else:
raise ValueError(f"Unknown content_or_style {content_or_style}")
with self.timer('convert_timestep_indices_to_timesteps'):
# convert the timestep_indices to a timestep
timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
timesteps = torch.stack(timesteps, dim=0)
timesteps = self.sd.noise_scheduler.timesteps[timestep_indices.long()]
with self.timer('prepare_noise'):
# get noise
noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch, timestep=timesteps)
@@ -1242,6 +1247,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
device=noise.device,
dtype=noise.dtype
) * self.train_config.random_noise_multiplier
with self.timer('make_noisy_latents'):
noise = noise * noise_multiplier
@@ -2058,6 +2065,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
# flush()
### HOOK ###
if self.torch_profiler is not None:
self.torch_profiler.start()
with self.accelerator.accumulate(self.modules_being_trained):
try:
loss_dict = self.hook_train_loop(batch_list)
@@ -2069,7 +2078,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
for item in batch.file_items:
print(f" - {item.path}")
raise e
if self.torch_profiler is not None:
torch.cuda.synchronize() # Make sure all CUDA ops are done
self.torch_profiler.stop()
print("\n==== Profile Results ====")
print(self.torch_profiler.key_averages().table(sort_by="cpu_time_total", row_limit=1000))
self.timer.stop('train_loop')
if not did_first_flush:
flush()

View File

@@ -113,6 +113,8 @@ class GenerateProcess(BaseProcess):
prompt_image_configs = []
for _ in range(self.generate_config.num_repeats):
for prompt in self.generate_config.prompts:
# remove --
prompt = prompt.replace('--', '').strip()
width = self.generate_config.width
height = self.generate_config.height
# prompt = self.clean_prompt(prompt)

View File

@@ -1,6 +1,6 @@
import os
import time
from typing import List, Optional, Literal, Union, TYPE_CHECKING, Dict
from typing import List, Optional, Literal, Tuple, Union, TYPE_CHECKING, Dict
import random
import torch
@@ -413,7 +413,7 @@ class TrainConfig:
self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, cfm, mean_flow
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, mean_flow
# scale the prediction by this. Increase for more detail, decrease for less
self.pred_scaler = kwargs.get('pred_scaler', 1.0)
@@ -454,8 +454,12 @@ class TrainConfig:
self.bypass_guidance_embedding = kwargs.get('bypass_guidance_embedding', False)
# diffusion feature extractor
self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', None)
self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', 1.0)
self.latent_feature_extractor_path = kwargs.get('latent_feature_extractor_path', None)
self.latent_feature_loss_weight = kwargs.get('latent_feature_loss_weight', 1.0)
# we still use this name in the code, but it really should be called latent_feature_extractor, as that makes more sense with the new architecture
self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', self.latent_feature_extractor_path)
self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', self.latent_feature_loss_weight)
# optimal noise pairing
self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1)
@@ -463,6 +467,13 @@ class TrainConfig:
# forces same noise for the same image at a given size.
self.force_consistent_noise = kwargs.get('force_consistent_noise', False)
self.blended_blur_noise = kwargs.get('blended_blur_noise', False)
# guidance loss
self.do_guidance_loss = kwargs.get('do_guidance_loss', False)
self.guidance_loss_target: Union[float, List[float]] = kwargs.get('guidance_loss_target', 3.0)
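# a single number fixes the guidance target; a [min, max] list makes the trainer re-sample it per batch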
self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '')
if isinstance(self.guidance_loss_target, tuple):
self.guidance_loss_target = list(self.guidance_loss_target)
ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex1', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21']
@@ -827,6 +838,9 @@ class DatasetConfig:
self.controls = [self.controls]
# remove empty strings
self.controls = [control for control in self.controls if control.strip() != '']
# if true, will use a fast method to get image sizes. This can result in errors. Do not use unless you know what you are doing
self.fast_image_size: bool = kwargs.get('fast_image_size', False)
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
@@ -1141,3 +1155,6 @@ def validate_configs(
if model_config.use_flux_cfg:
# bypass the embedding
train_config.bypass_guidance_embedding = True
if train_config.bypass_guidance_embedding and train_config.do_guidance_loss:
raise ValueError("Cannot bypass guidance embedding and do guidance loss at the same time. "
"Please set bypass_guidance_embedding to False or do_guidance_loss to False.")

View File

@@ -11,6 +11,7 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.models.clip_fusion import CLIPFusionModule
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.models.control_lora_adapter import ControlLoraAdapter
from toolkit.models.mean_flow_adapter import MeanFlowAdapter
from toolkit.models.i2v_adapter import I2VAdapter
from toolkit.models.subpixel_adapter import SubpixelAdapter
from toolkit.models.ilora import InstantLoRAModule
@@ -98,6 +99,7 @@ class CustomAdapter(torch.nn.Module):
self.single_value_adapter: SingleValueAdapter = None
self.redux_adapter: ReduxImageEncoder = None
self.control_lora: ControlLoraAdapter = None
self.mean_flow_adapter: MeanFlowAdapter = None
self.subpixel_adapter: SubpixelAdapter = None
self.i2v_adapter: I2VAdapter = None
@@ -125,6 +127,16 @@ class CustomAdapter(torch.nn.Module):
dtype=self.sd_ref().dtype,
)
self.load_state_dict(loaded_state_dict, strict=False)
@property
def do_direct_save(self):
# some adapters save their weights directly, others like ip adapters split the state dict
if self.config.train_only_image_encoder:
return True
if self.config.type in ['control_lora', 'subpixel', 'i2v', 'redux', 'mean_flow']:
return True
return False
def setup_adapter(self):
torch_dtype = get_torch_dtype(self.sd_ref().dtype)
@@ -245,6 +257,13 @@ class CustomAdapter(torch.nn.Module):
elif self.adapter_type == 'redux':
vision_hidden_size = self.vision_encoder.config.hidden_size
self.redux_adapter = ReduxImageEncoder(vision_hidden_size, 4096, self.device, torch_dtype)
elif self.adapter_type == 'mean_flow':
self.mean_flow_adapter = MeanFlowAdapter(
self,
sd=self.sd_ref(),
config=self.config,
train_config=self.train_config
)
elif self.adapter_type == 'control_lora':
self.control_lora = ControlLoraAdapter(
self,
@@ -309,7 +328,7 @@ class CustomAdapter(torch.nn.Module):
def setup_clip(self):
adapter_config = self.config
sd = self.sd_ref()
if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora", "subpixel"]:
if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora", "subpixel", "mean_flow"]:
return
if self.config.type == 'photo_maker':
try:
@@ -528,6 +547,14 @@ class CustomAdapter(torch.nn.Module):
new_dict[k + '.' + k2] = v2
self.control_lora.load_weights(new_dict, strict=strict)
if self.adapter_type == 'mean_flow':
# state dict is separated, so recombine it
new_dict = {}
for k, v in state_dict.items():
for k2, v2 in v.items():
new_dict[k + '.' + k2] = v2
self.mean_flow_adapter.load_weights(new_dict, strict=strict)
if self.adapter_type == 'i2v':
# state dict is separated, so recombine it
new_dict = {}
@@ -599,6 +626,11 @@ class CustomAdapter(torch.nn.Module):
for k, v in d.items():
state_dict[k] = v
return state_dict
elif self.adapter_type == 'mean_flow':
d = self.mean_flow_adapter.get_state_dict()
for k, v in d.items():
state_dict[k] = v
return state_dict
elif self.adapter_type == 'i2v':
d = self.i2v_adapter.get_state_dict()
for k, v in d.items():
@@ -757,7 +789,7 @@ class CustomAdapter(torch.nn.Module):
prompt: Union[List[str], str],
is_unconditional: bool = False,
):
if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora', 'subpixel', 'i2v']:
if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora', 'subpixel', 'i2v', 'mean_flow']:
return prompt
elif self.adapter_type == 'text_encoder':
# todo allow for training
@@ -1319,6 +1351,10 @@ class CustomAdapter(torch.nn.Module):
param_list = self.control_lora.get_params()
for param in param_list:
yield param
elif self.config.type == 'mean_flow':
param_list = self.mean_flow_adapter.get_params()
for param in param_list:
yield param
elif self.config.type == 'i2v':
param_list = self.i2v_adapter.get_params()
for param in param_list:

View File

@@ -84,15 +84,16 @@ class FileItemDTO(
video.release()
size_database[file_key] = (width, height, file_signature)
else:
# original method is significantly faster, but some images are read sideways. Not sure why. Do slow method for now.
# process width and height
# try:
# w, h = image_utils.get_image_size(self.path)
# except image_utils.UnknownImageFormat:
# print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
# f'This process is faster for png, jpeg')
img = exif_transpose(Image.open(self.path))
w, h = img.size
if self.dataset_config.fast_image_size:
# original method is significantly faster, but some images are read sideways. Not sure why. Do slow method by default.
try:
w, h = image_utils.get_image_size(self.path)
except image_utils.UnknownImageFormat:
print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
f'This process is faster for png, jpeg')
else:
img = exif_transpose(Image.open(self.path))
w, h = img.size
size_database[file_key] = (w, h, file_signature)
self.width: int = w
self.height: int = h

View File

@@ -815,7 +815,10 @@ class BaseModel:
# predict the noise residual
if self.unet.device != self.device_torch:
self.unet.to(self.device_torch)
try:
self.unet.to(self.device_torch)
except Exception:
# moving can fail here (e.g. for quantized or offloaded modules); keep the current placement
pass
if self.unet.dtype != self.torch_dtype:
self.unet = self.unet.to(dtype=self.torch_dtype)

View File

@@ -135,7 +135,7 @@ class ControlLoraAdapter(torch.nn.Module):
network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs
if hasattr(sd, 'target_lora_modules'):
network_kwargs['target_lin_modules'] = self.sd.target_lora_modules
network_kwargs['target_lin_modules'] = sd.target_lora_modules
if 'ignore_if_contains' not in network_kwargs:
network_kwargs['ignore_if_contains'] = []

View File

@@ -127,18 +127,20 @@ class DFEBlock(nn.Module):
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
self.act = nn.GELU()
self.proj = nn.Conv2d(channels, channels, 1)
def forward(self, x):
x_in = x
x = self.conv1(x)
x = self.conv2(x)
x = self.act(x)
x = self.proj(x)
x = x + x_in
return x
class DiffusionFeatureExtractor(nn.Module):
def __init__(self, in_channels=32):
def __init__(self, in_channels=16):
super().__init__()
self.version = 1
num_blocks = 6

View File

@@ -176,60 +176,3 @@ def add_model_gpu_splitter_to_flux(
transformer._pre_gpu_split_to = transformer.to
transformer.to = partial(new_device_to, transformer)
def mean_flow_time_text_embed_forward(self:CombinedTimestepTextProjEmbeddings, timestep, pooled_projection):
# make zero timestep ending if none is passed
if timestep.shape[0] == pooled_projection.shape[0]:
timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0) # timestep - 0 (final timestep) == same as start timestep
timesteps_proj = self.time_proj(timestep)
timesteps_emb_combo = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
timesteps_emb_start, timesteps_emb_end = timesteps_emb_combo.chunk(2, dim=0)
timesteps_emb = timesteps_emb_start + timesteps_emb_end
pooled_projections = self.text_embedder(pooled_projection)
conditioning = timesteps_emb + pooled_projections
return conditioning
def mean_flow_time_text_guidance_embed_forward(self: CombinedTimestepGuidanceTextProjEmbeddings, timestep, guidance, pooled_projection):
# make zero timestep ending if none is passed
if timestep.shape[0] == pooled_projection.shape[0]:
timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0) # timestep - 0 (final timestep) == same as start timestep
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D)
timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)
time_guidance_emb = timesteps_emb_start + timesteps_emb_end + guidance_emb
pooled_projections = self.text_embedder(pooled_projection)
conditioning = time_guidance_emb + pooled_projections
return conditioning
def convert_flux_to_mean_flow(
transformer: FluxTransformer2DModel,
):
if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings):
transformer.time_text_embed.forward = partial(
mean_flow_time_text_embed_forward, transformer.time_text_embed
)
elif isinstance(transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings):
transformer.time_text_embed.forward = partial(
mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed
)
else:
raise ValueError(
"Unsupported time_text_embed type: {}".format(
type(transformer.time_text_embed)
)
)

View File

@@ -1,567 +0,0 @@
# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.utils import logging
from diffusers.models.attention import LuminaFeedForward
from diffusers.models.attention_processor import Attention
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
import torch
from torch.profiler import profile, record_function, ProfilerActivity
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
do_profile = False
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
def __init__(
self,
hidden_size: int = 4096,
cap_feat_dim: int = 2048,
frequency_embedding_size: int = 256,
norm_eps: float = 1e-5,
) -> None:
super().__init__()
self.time_proj = Timesteps(
num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
)
self.timestep_embedder = TimestepEmbedding(
in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
)
self.caption_embedder = nn.Sequential(
RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, hidden_size, bias=True)
)
def forward(
self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
timestep_proj = self.time_proj(timestep).type_as(hidden_states)
time_embed = self.timestep_embedder(timestep_proj)
caption_embed = self.caption_embedder(encoder_hidden_states)
return time_embed, caption_embed
class Lumina2AttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the Lumina2Transformer2DModel model. It applies normalization and RoPE on query and key vectors.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
base_sequence_length: Optional[int] = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
# Get Query-Key-Value Pair
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query_dim = query.shape[-1]
inner_dim = key.shape[-1]
head_dim = query_dim // attn.heads
dtype = query.dtype
# Get key-value heads
kv_heads = inner_dim // head_dim
query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, kv_heads, head_dim)
value = value.view(batch_size, -1, kv_heads, head_dim)
# Apply Query-Key Norm if needed
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
query, key = query.to(dtype), key.to(dtype)
# Apply proportional attention if true
if base_sequence_length is not None:
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
else:
softmax_scale = attn.scale
# perform Grouped-Query Attention (GQA)
n_rep = attn.heads // kv_heads
if n_rep >= 1:
key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, scale=softmax_scale
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.type_as(query)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class Lumina2TransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
num_kv_heads: int,
multiple_of: int,
ffn_dim_multiplier: float,
norm_eps: float,
modulation: bool = True,
) -> None:
super().__init__()
self.head_dim = dim // num_attention_heads
self.modulation = modulation
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=dim // num_attention_heads,
qk_norm="rms_norm",
heads=num_attention_heads,
kv_heads=num_kv_heads,
eps=1e-5,
bias=False,
out_bias=False,
processor=Lumina2AttnProcessor2_0(),
)
self.feed_forward = LuminaFeedForward(
dim=dim,
inner_dim=4 * dim,
multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier,
)
if modulation:
self.norm1 = LuminaRMSNormZero(
embedding_dim=dim,
norm_eps=norm_eps,
norm_elementwise_affine=True,
)
else:
self.norm1 = RMSNorm(dim, eps=norm_eps)
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
self.norm2 = RMSNorm(dim, eps=norm_eps)
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
image_rotary_emb: torch.Tensor,
temb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.modulation:
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
else:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
return hidden_states
class Lumina2RotaryPosEmbed(nn.Module):
def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512), patch_size: int = 2):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
self.axes_lens = axes_lens
self.patch_size = patch_size
self.freqs_cis = self._precompute_freqs_cis(axes_dim, axes_lens, theta)
def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
freqs_cis = []
for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=torch.float64)
freqs_cis.append(emb)
return freqs_cis
def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
result = []
for i in range(len(self.axes_dim)):
freqs = self.freqs_cis[i].to(ids.device)
index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
return torch.cat(result, dim=-1)
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
batch_size = len(hidden_states)
p_h = p_w = self.patch_size
device = hidden_states[0].device
l_effective_cap_len = attention_mask.sum(dim=1).tolist()
# TODO: this should probably be refactored because all subtensors of hidden_states will be of same shape
img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
l_effective_img_len = [(H // p_h) * (W // p_w) for (H, W) in img_sizes]
max_seq_len = max((cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)))
max_img_len = max(l_effective_img_len)
position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
for i in range(batch_size):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
H, W = img_sizes[i]
H_tokens, W_tokens = H // p_h, W // p_w
assert H_tokens * W_tokens == img_len
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
position_ids[i, cap_len : cap_len + img_len, 0] = cap_len
row_ids = (
torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
)
col_ids = (
torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
)
position_ids[i, cap_len : cap_len + img_len, 1] = row_ids
position_ids[i, cap_len : cap_len + img_len, 2] = col_ids
freqs_cis = self._get_freqs_cis(position_ids)
cap_freqs_cis_shape = list(freqs_cis.shape)
cap_freqs_cis_shape[1] = attention_mask.shape[1]
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
img_freqs_cis_shape = list(freqs_cis.shape)
img_freqs_cis_shape[1] = max_img_len
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
for i in range(batch_size):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len]
flat_hidden_states = []
for i in range(batch_size):
img = hidden_states[i]
C, H, W = img.size()
img = img.view(C, H // p_h, p_h, W // p_w, p_w).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
flat_hidden_states.append(img)
hidden_states = flat_hidden_states
padded_img_embed = torch.zeros(
batch_size, max_img_len, hidden_states[0].shape[-1], device=device, dtype=hidden_states[0].dtype
)
padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
for i in range(batch_size):
padded_img_embed[i, : l_effective_img_len[i]] = hidden_states[i]
padded_img_mask[i, : l_effective_img_len[i]] = True
return (
padded_img_embed,
padded_img_mask,
img_sizes,
l_effective_cap_len,
l_effective_img_len,
freqs_cis,
cap_freqs_cis,
img_freqs_cis,
max_seq_len,
)
class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
r"""
Lumina2NextDiT: Diffusion model with a Transformer backbone.
Parameters:
sample_size (`int`): The width of the latent images. This is fixed during training since
it is used to learn a number of position embeddings.
patch_size (`int`, *optional*, defaults to 2):
The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
in_channels (`int`, *optional*, defaults to 4):
The number of input channels for the model. Typically, this matches the number of channels in the input
images.
hidden_size (`int`, *optional*, defaults to 4096):
The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
hidden representations.
num_layers (`int`, *optional*, default to 32):
The number of layers in the model. This defines the depth of the neural network.
num_attention_heads (`int`, *optional*, defaults to 32):
The number of attention heads in each attention layer. This parameter specifies how many separate attention
mechanisms are used.
num_kv_heads (`int`, *optional*, defaults to 8):
The number of key-value heads in the attention mechanism, if different from the number of attention heads.
If None, it defaults to num_attention_heads.
multiple_of (`int`, *optional*, defaults to 256):
A factor that the hidden size should be a multiple of. This can help optimize certain hardware
configurations.
ffn_dim_multiplier (`float`, *optional*):
A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
the model configuration.
norm_eps (`float`, *optional*, defaults to 1e-5):
A small value added to the denominator for numerical stability in normalization layers.
scaling_factor (`float`, *optional*, defaults to 1.0):
A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
overall scale of the model's operations.
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["Lumina2TransformerBlock"]
_skip_layerwise_casting_patterns = ["x_embedder", "norm"]
@register_to_config
def __init__(
self,
sample_size: int = 128,
patch_size: int = 2,
in_channels: int = 16,
out_channels: Optional[int] = None,
hidden_size: int = 2304,
num_layers: int = 26,
num_refiner_layers: int = 2,
num_attention_heads: int = 24,
num_kv_heads: int = 8,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
norm_eps: float = 1e-5,
scaling_factor: float = 1.0,
axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
axes_lens: Tuple[int, int, int] = (300, 512, 512),
cap_feat_dim: int = 1024,
) -> None:
super().__init__()
self.out_channels = out_channels or in_channels
# 1. Positional, patch & conditional embeddings
self.rope_embedder = Lumina2RotaryPosEmbed(
theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, patch_size=patch_size
)
self.x_embedder = nn.Linear(in_features=patch_size * patch_size * in_channels, out_features=hidden_size)
self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
hidden_size=hidden_size, cap_feat_dim=cap_feat_dim, norm_eps=norm_eps
)
# 2. Noise and context refinement blocks
self.noise_refiner = nn.ModuleList(
[
Lumina2TransformerBlock(
hidden_size,
num_attention_heads,
num_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
modulation=True,
)
for _ in range(num_refiner_layers)
]
)
self.context_refiner = nn.ModuleList(
[
Lumina2TransformerBlock(
hidden_size,
num_attention_heads,
num_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
modulation=False,
)
for _ in range(num_refiner_layers)
]
)
# 3. Transformer blocks
self.layers = nn.ModuleList(
[
Lumina2TransformerBlock(
hidden_size,
num_attention_heads,
num_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
modulation=True,
)
for _ in range(num_layers)
]
)
# 4. Output norm & projection
self.norm_out = LuminaLayerNormContinuous(
embedding_dim=hidden_size,
conditioning_embedding_dim=min(hidden_size, 1024),
elementwise_affine=False,
eps=1e-6,
bias=True,
out_dim=patch_size * patch_size * self.out_channels,
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
hidden_size = self.config.get("hidden_size", 2304)
# pad or slice text encoder
if encoder_hidden_states.shape[2] > hidden_size:
encoder_hidden_states = encoder_hidden_states[:, :, :hidden_size]
elif encoder_hidden_states.shape[2] < hidden_size:
encoder_hidden_states = F.pad(encoder_hidden_states, (0, hidden_size - encoder_hidden_states.shape[2]))
batch_size = hidden_states.size(0)
if do_profile:
prof = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
)
prof.start()
# 1. Condition, positional & patch embedding
temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
(
hidden_states,
hidden_mask,
hidden_sizes,
encoder_hidden_len,
hidden_len,
joint_rotary_emb,
encoder_rotary_emb,
hidden_rotary_emb,
max_seq_len,
) = self.rope_embedder(hidden_states, attention_mask)
hidden_states = self.x_embedder(hidden_states)
# 2. Context & noise refinement
for layer in self.context_refiner:
encoder_hidden_states = layer(encoder_hidden_states, attention_mask, encoder_rotary_emb)
for layer in self.noise_refiner:
hidden_states = layer(hidden_states, hidden_mask, hidden_rotary_emb, temb)
# 3. Attention mask preparation
mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
padded_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
for i in range(batch_size):
cap_len = encoder_hidden_len[i]
img_len = hidden_len[i]
mask[i, : cap_len + img_len] = True
padded_hidden_states[i, :cap_len] = encoder_hidden_states[i, :cap_len]
padded_hidden_states[i, cap_len : cap_len + img_len] = hidden_states[i, :img_len]
hidden_states = padded_hidden_states
# 4. Transformer blocks
for layer in self.layers:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(layer, hidden_states, mask, joint_rotary_emb, temb)
else:
hidden_states = layer(hidden_states, mask, joint_rotary_emb, temb)
# 5. Output norm & projection & unpatchify
hidden_states = self.norm_out(hidden_states, temb)
height_tokens = width_tokens = self.config.patch_size
output = []
for i in range(len(hidden_sizes)):
height, width = hidden_sizes[i]
begin = encoder_hidden_len[i]
end = begin + (height // height_tokens) * (width // width_tokens)
output.append(
hidden_states[i][begin:end]
.view(height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels)
.permute(4, 0, 2, 1, 3)
.flatten(3, 4)
.flatten(1, 2)
)
output = torch.stack(output, dim=0)
if do_profile:
torch.cuda.synchronize() # Make sure all CUDA ops are done
prof.stop()
print("\n==== Profile Results ====")
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=1000))
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
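The pad-or-slice step at the top of forward() adapts whatever width the text encoder produces to the transformer's hidden_size. A minimal standalone sketch of just that step, with illustrative shapes (not taken from a real checkpoint):

import torch
import torch.nn.functional as F

hidden_size = 2304                      # transformer width from the config above
enc = torch.randn(1, 77, 1024)          # e.g. text-encoder states narrower than hidden_size
if enc.shape[2] > hidden_size:
    enc = enc[:, :, :hidden_size]       # slice off extra channels
elif enc.shape[2] < hidden_size:
    enc = F.pad(enc, (0, hidden_size - enc.shape[2]))  # zero-pad the channel dim on the right
print(enc.shape)  # torch.Size([1, 77, 2304])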

View File

@@ -0,0 +1,282 @@
import inspect
import weakref
import torch
from typing import TYPE_CHECKING
from toolkit.lora_special import LoRASpecialNetwork
from diffusers import FluxTransformer2DModel
from diffusers.models.embeddings import (
CombinedTimestepTextProjEmbeddings,
CombinedTimestepGuidanceTextProjEmbeddings,
)
from functools import partial
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
from toolkit.custom_adapter import CustomAdapter
def mean_flow_time_text_embed_forward(
self: CombinedTimestepTextProjEmbeddings, timestep, pooled_projection
):
mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref()
# append a zero end-timestep when only start timesteps were passed
if mean_flow_adapter.is_active and timestep.shape[0] == pooled_projection.shape[0]:
timestep = torch.cat(
[timestep, torch.zeros_like(timestep)], dim=0
) # an end timestep of 0 makes the (start, end) pair behave like the plain start timestep
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(
timesteps_proj.to(dtype=pooled_projection.dtype)
) # (N, D)
# mean flow: merge the start and end timestep embeddings
if mean_flow_adapter.is_active:
# todo confirm timestep batching; diffusers may expect non-batched timesteps
orig_dtype = timesteps_emb.dtype
timesteps_emb = timesteps_emb.to(torch.float32)
timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)
timesteps_emb = mean_flow_adapter.mean_flow_timestep_embedder(
torch.cat([timesteps_emb_start, timesteps_emb_end], dim=-1)
)
timesteps_emb = timesteps_emb.to(orig_dtype)
pooled_projections = self.text_embedder(pooled_projection)
conditioning = timesteps_emb + pooled_projections
return conditioning
def mean_flow_time_text_guidance_embed_forward(
self: CombinedTimestepGuidanceTextProjEmbeddings,
timestep,
guidance,
pooled_projection,
):
mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref()
# append a zero end-timestep when only start timesteps were passed
if mean_flow_adapter.is_active and timestep.shape[0] == pooled_projection.shape[0]:
timestep = torch.cat(
[timestep, torch.zeros_like(timestep)], dim=0
) # an end timestep of 0 makes the (start, end) pair behave like the plain start timestep
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(
timesteps_proj.to(dtype=pooled_projection.dtype)
) # (N, D)
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(
guidance_proj.to(dtype=pooled_projection.dtype)
) # (N, D)
# mean flow: merge the start and end timestep embeddings
if mean_flow_adapter.is_active:
# todo confirm timestep batching; diffusers may expect non-batched timesteps
orig_dtype = timesteps_emb.dtype
timesteps_emb = timesteps_emb.to(torch.float32)
timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)
timesteps_emb = mean_flow_adapter.mean_flow_timestep_embedder(
torch.cat([timesteps_emb_start, timesteps_emb_end], dim=-1)
)
timesteps_emb = timesteps_emb.to(orig_dtype)
time_guidance_emb = timesteps_emb + guidance_emb
pooled_projections = self.text_embedder(pooled_projection)
conditioning = time_guidance_emb + pooled_projections
return conditioning
def convert_flux_to_mean_flow(
transformer: FluxTransformer2DModel,
):
if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings):
transformer.time_text_embed.forward = partial(
mean_flow_time_text_embed_forward, transformer.time_text_embed
)
elif isinstance(
transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings
):
transformer.time_text_embed.forward = partial(
mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed
)
else:
raise ValueError(
"Unsupported time_text_embed type: {}".format(
type(transformer.time_text_embed)
)
)
class MeanFlowAdapter(torch.nn.Module):
def __init__(
self,
adapter: "CustomAdapter",
sd: "StableDiffusion",
config: "AdapterConfig",
train_config: "TrainConfig",
):
super().__init__()
self.adapter_ref: weakref.ref = weakref.ref(adapter)
self.sd_ref = weakref.ref(sd)
self.model_config: ModelConfig = sd.model_config
self.network_config = config.lora_config
self.train_config = train_config
self.device_torch = sd.device_torch
self.lora = None
if self.network_config is not None:
network_kwargs = (
{}
if self.network_config.network_kwargs is None
else self.network_config.network_kwargs
)
if hasattr(sd, "target_lora_modules"):
network_kwargs["target_lin_modules"] = sd.target_lora_modules
if "ignore_if_contains" not in network_kwargs:
network_kwargs["ignore_if_contains"] = []
self.lora = LoRASpecialNetwork(
text_encoder=sd.text_encoder,
unet=sd.unet,
lora_dim=self.network_config.linear,
multiplier=1.0,
alpha=self.network_config.linear_alpha,
train_unet=self.train_config.train_unet,
train_text_encoder=self.train_config.train_text_encoder,
conv_lora_dim=self.network_config.conv,
conv_alpha=self.network_config.conv_alpha,
is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
is_v2=self.model_config.is_v2,
is_v3=self.model_config.is_v3,
is_pixart=self.model_config.is_pixart,
is_auraflow=self.model_config.is_auraflow,
is_flux=self.model_config.is_flux,
is_lumina2=self.model_config.is_lumina2,
is_ssd=self.model_config.is_ssd,
is_vega=self.model_config.is_vega,
dropout=self.network_config.dropout,
use_text_encoder_1=self.model_config.use_text_encoder_1,
use_text_encoder_2=self.model_config.use_text_encoder_2,
use_bias=False,
is_lorm=False,
network_config=self.network_config,
network_type=self.network_config.type,
transformer_only=self.network_config.transformer_only,
is_transformer=sd.is_transformer,
base_model=sd,
**network_kwargs,
)
self.lora.force_to(self.device_torch, dtype=torch.float32)
self.lora._update_torch_multiplier()
self.lora.apply_to(
sd.text_encoder,
sd.unet,
self.train_config.train_text_encoder,
self.train_config.train_unet,
)
self.lora.can_merge_in = False
self.lora.prepare_grad_etc(sd.text_encoder, sd.unet)
if self.train_config.gradient_checkpointing:
self.lora.enable_gradient_checkpointing()
emb_dim = None
if self.model_config.arch in ["flux", "flex2", "flex2"]:
transformer: FluxTransformer2DModel = sd.unet
emb_dim = (
transformer.config.num_attention_heads
* transformer.config.attention_head_dim
)
convert_flux_to_mean_flow(transformer)
else:
raise ValueError(f"Unsupported architecture: {self.model_config.arch}")
self.mean_flow_timestep_embedder = torch.nn.Linear(
emb_dim * 2,
emb_dim,
)
# initialize the weights so the model behaves exactly as it did before this adapter was added
with torch.no_grad():
self.mean_flow_timestep_embedder.weight.zero_()
self.mean_flow_timestep_embedder.weight[:, :emb_dim] = torch.eye(emb_dim)
self.mean_flow_timestep_embedder.bias.zero_()
self.mean_flow_timestep_embedder.to(self.device_torch)
# add our adapter as a weak ref
if self.model_config.arch in ["flux", "flex2", "flex2"]:
sd.unet.time_text_embed.mean_flow_adapter_ref = weakref.ref(self)
def get_params(self):
if self.lora is not None:
config = {
"text_encoder_lr": self.train_config.lr,
"unet_lr": self.train_config.lr,
}
sig = inspect.signature(self.lora.prepare_optimizer_params)
if "default_lr" in sig.parameters:
config["default_lr"] = self.train_config.lr
if "learning_rate" in sig.parameters:
config["learning_rate"] = self.train_config.lr
params_net = self.lora.prepare_optimizer_params(**config)
# we want only tensors here
params = []
for p in params_net:
if isinstance(p, dict):
params += p["params"]
elif isinstance(p, torch.Tensor):
params.append(p)
elif isinstance(p, list):
params += p
else:
params = []
# make sure the embedder is float32
self.mean_flow_timestep_embedder.to(torch.float32)
self.mean_flow_timestep_embedder.requires_grad_(True)
self.mean_flow_timestep_embedder.train()
params += list(self.mean_flow_timestep_embedder.parameters())
# return a flat list of tensors so callers can iterate (or yield from) it directly
return params
def load_weights(self, state_dict, strict=True):
lora_sd = {}
mean_flow_embedder_sd = {}
for key, value in state_dict.items():
if "mean_flow_timestep_embedder" in key:
new_key = key.replace("transformer.mean_flow_timestep_embedder.", "")
mean_flow_embedder_sd[new_key] = value
else:
lora_sd[key] = value
# todo process state dict before loading for models that need it
if self.lora is not None:
self.lora.load_weights(lora_sd)
self.mean_flow_timestep_embedder.load_state_dict(
mean_flow_embedder_sd, strict=False
)
def get_state_dict(self):
if self.lora is not None:
lora_sd = self.lora.get_state_dict(dtype=torch.float32)
else:
lora_sd = {}
# todo make sure we match loras elsewhere.
mean_flow_embedder_sd = self.mean_flow_timestep_embedder.state_dict()
for key, value in mean_flow_embedder_sd.items():
lora_sd[f"transformer.mean_flow_timestep_embedder.{key}"] = value
return lora_sd
@property
def is_active(self):
return self.adapter_ref().is_active
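The zero-initialized weight with an identity block in mean_flow_timestep_embedder is what keeps the patched model equivalent to the base model at step zero: with weight = [I | 0] and bias = 0, the embedder simply returns the start-timestep embedding. A small self-contained check of that property (emb_dim is illustrative; the real value is num_attention_heads * attention_head_dim as in the constructor above):

import torch

emb_dim = 8  # illustrative only
embedder = torch.nn.Linear(emb_dim * 2, emb_dim)
with torch.no_grad():
    embedder.weight.zero_()
    embedder.weight[:, :emb_dim] = torch.eye(emb_dim)
    embedder.bias.zero_()

start = torch.randn(2, emb_dim)
end = torch.randn(2, emb_dim)
out = embedder(torch.cat([start, end], dim=-1))
assert torch.allclose(out, start)  # identity on the start embedding until training updates the weights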

View File

@@ -109,7 +109,9 @@ class Adafactor(torch.optim.Optimizer):
do_paramiter_swapping=False,
paramiter_swapping_factor=0.1,
stochastic_accumulation=True,
stochastic_rounding=True,
):
self.stochastic_rounding = stochastic_rounding
if lr is not None and relative_step:
raise ValueError(
"Cannot combine manual `lr` and `relative_step=True` options")
@@ -354,7 +356,7 @@ class Adafactor(torch.optim.Optimizer):
p_data_fp32.add_(-update)
if p.dtype != torch.float32:
if p.dtype != torch.float32 and self.stochastic_rounding:
# apply stochastic rounding
copy_stochastic(p, p_data_fp32)

View File

@@ -112,86 +112,57 @@ def get_format_params(dtype: torch.dtype) -> tuple[int, int]:
return 0, 8 # Int8 doesn't have mantissa bits
else:
raise ValueError(f"Unsupported dtype: {dtype}")
def copy_stochastic_bf16(target: torch.Tensor, source: torch.Tensor):
# adapted from https://github.com/Nerogar/OneTrainer/blob/411532e85f3cf2b52baa37597f9c145073d54511/modules/util/bf16_stochastic_rounding.py#L5
# create a random 16 bit integer
result = torch.randint_like(
source,
dtype=torch.int32,
low=0,
high=(1 << 16),
)
# add the random number to the lower 16 bit of the mantissa
result.add_(source.view(dtype=torch.int32))
# mask off the lower 16 bit of the mantissa
result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
# copy the higher 16 bit into the target tensor
target.copy_(result.view(dtype=torch.float32))
del result
def copy_stochastic(
target: torch.Tensor,
source: torch.Tensor,
eps: Optional[float] = None
) -> None:
"""
Performs stochastic rounding from source tensor to target tensor.
Args:
target: Destination tensor (determines the target format)
source: Source tensor (typically float32)
eps: Optional minimum value for stochastic rounding (for numerical stability)
"""
def copy_stochastic(target: torch.Tensor, source: torch.Tensor, eps: Optional[float] = None) -> None:
with torch.no_grad():
# If target is float32, just copy directly
# assert if target is on cpu, throw error
assert target.device.type != 'cpu', "Target is on cpu!"
assert source.device.type != 'cpu', "Source is on cpu!"
if target.dtype == torch.float32:
target.copy_(source)
return
# Special handling for int8
if target.dtype == torch.int8:
# Scale the source values to utilize the full int8 range
scaled = source * 127.0 # Scale to [-127, 127]
# Add random noise for stochastic rounding
noise = torch.rand_like(scaled) - 0.5
rounded = torch.round(scaled + noise)
# Clamp to int8 range
clamped = torch.clamp(rounded, -127, 127)
target.copy_(clamped.to(torch.int8))
if target.dtype == torch.bfloat16:
copy_stochastic_bf16(target, source)
return
mantissa_bits, _ = get_format_params(target.dtype)
round_factor = 2 ** (23 - mantissa_bits)
# Convert source to int32 view
source_int = source.view(dtype=torch.int32)
# Add uniform noise for stochastic rounding
noise = torch.rand_like(source, device=source.device) - 0.5
rounded = torch.round(source * round_factor + noise)
result_float = rounded / round_factor
# Calculate number of bits to round
bits_to_round = 23 - mantissa_bits # 23 is float32 mantissa bits
# Create random integers for stochastic rounding
rand = torch.randint_like(
source,
dtype=torch.int32,
low=0,
high=(1 << bits_to_round),
)
# Add random values to the bits that will be rounded off
result = source_int.clone()
result.add_(rand)
# Mask to keep only the bits we want
# Create mask with 1s in positions we want to keep
mask = (-1) << bits_to_round
result.bitwise_and_(mask)
# Handle minimum value threshold if specified
if eps is not None:
eps_int = torch.tensor(
eps, dtype=torch.float32).view(dtype=torch.int32)
zero_mask = (result.abs() < eps_int)
result[zero_mask] = torch.sign(source_int[zero_mask]) * eps_int
# Convert back to float32 view
result_float = result.view(dtype=torch.float32)
# Special handling for float8 formats
# Clamp for float8
if target.dtype == torch.float8_e4m3fn:
result_float.clamp_(-448.0, 448.0)
elif target.dtype == torch.float8_e5m2:
result_float.clamp_(-57344.0, 57344.0)
# Copy the result to the target tensor
update_parameter(target, result_float)
# target.copy_(result_float)
del result, rand, source_int
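To see why the bf16 path above matters: a plain cast silently drops updates smaller than bf16's spacing, while stochastic rounding preserves them in expectation. A self-contained illustration that mirrors copy_stochastic_bf16 (reimplemented inline so it also runs on CPU, since the helpers above assert GPU tensors):

import torch

def stochastic_round_to_bf16(x: torch.Tensor) -> torch.Tensor:
    # same bit trick as copy_stochastic_bf16: add uniform noise to the low 16 bits, then truncate
    noise = torch.randint_like(x, dtype=torch.int32, low=0, high=1 << 16)
    bits = x.view(torch.int32) + noise
    bits &= -65536  # keep only the upper 16 bits (the bf16 payload)
    return bits.view(torch.float32).to(torch.bfloat16)

param = torch.full((100_000,), 1.0)
update = torch.full((100_000,), 1e-4)   # far below bf16 spacing (~0.0039) near 1.0

plain = (param + update).to(torch.bfloat16).float().mean().item()
stoch = stochastic_round_to_bf16(param + update).float().mean().item()
print(plain)  # ~1.0     -- the update is rounded away every time
print(stoch)  # ~1.0001  -- correct in expectation across many elements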
class Auto8bitTensor:

View File

@@ -50,8 +50,7 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAda
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \
FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, Lumina2Text2ImgPipeline, \
FluxControlPipeline
from toolkit.models.lumina2 import Lumina2Transformer2DModel
FluxControlPipeline, Lumina2Transformer2DModel
import diffusers
from diffusers import \
AutoencoderKL, \
@@ -1763,6 +1762,15 @@ class StableDiffusion:
)
noise = apply_noise_offset(noise, noise_offset)
return noise
def get_latent_noise_from_latents(
self,
latents: torch.Tensor,
noise_offset=0.0
):
noise = torch.randn_like(latents)
noise = apply_noise_offset(noise, noise_offset)
return noise
def get_time_ids_from_latents(self, latents: torch.Tensor, requires_aesthetic_score=False):
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
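The new get_latent_noise_from_latents simply samples randn_like noise and runs it through apply_noise_offset. For reference, the noise-offset trick is commonly implemented as a per-(batch, channel) DC shift like the sketch below; this illustrates the general technique and is not necessarily the exact body of the toolkit's apply_noise_offset:

import torch

def apply_noise_offset_sketch(noise: torch.Tensor, noise_offset: float) -> torch.Tensor:
    # add a random constant per (batch, channel) so the model can learn overall brightness shifts;
    # an offset of 0.0 leaves the noise untouched
    if noise_offset is None or noise_offset <= 0.0:
        return noise
    b, c = noise.shape[0], noise.shape[1]
    return noise + noise_offset * torch.randn(b, c, 1, 1, device=noise.device, dtype=noise.dtype)

latents = torch.randn(2, 4, 64, 64)
noise = apply_noise_offset_sketch(torch.randn_like(latents), 0.05)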
@@ -2170,7 +2178,7 @@ class StableDiffusion:
noise_pred = self.unet(
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
timestep=t,
attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64),
encoder_attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64),
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
**kwargs,
).sample
@@ -2529,8 +2537,8 @@ class StableDiffusion:
# Move to vae to device if on cpu
if self.vae.device == 'cpu':
self.vae.to(self.device)
latents = latents.to(device, dtype=dtype)
self.vae.to(self.device_torch)
latents = latents.to(self.device_torch, dtype=self.torch_dtype)
latents = (latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor']
images = self.vae.decode(latents).sample
images = images.to(device, dtype=dtype)

31
ui/cron/worker.ts Normal file
View File

@@ -0,0 +1,31 @@
class CronWorker {
interval: number;
is_running: boolean;
intervalId: NodeJS.Timeout;
constructor() {
this.interval = 1000; // Default interval of 1 second
this.is_running = false;
this.intervalId = setInterval(() => {
this.run();
}, this.interval);
}
async run() {
if (this.is_running) {
return;
}
this.is_running = true;
try {
// Loop logic here
await this.loop();
} catch (error) {
console.error('Error in cron worker loop:', error);
}
this.is_running = false;
}
async loop() {}
}
// it automatically starts the loop
const cronWorker = new CronWorker();
console.log('Cron worker started with interval:', cronWorker.interval, 'ms');

654
ui/package-lock.json generated
View File

@@ -30,10 +30,12 @@
"@types/node": "^20",
"@types/react": "^19",
"@types/react-dom": "^19",
"concurrently": "^9.1.2",
"postcss": "^8",
"prettier": "^3.5.1",
"prettier-basic": "^1.0.0",
"tailwindcss": "^3.4.1",
"ts-node-dev": "^2.0.0",
"typescript": "^5"
}
},
@@ -169,6 +171,28 @@
"node": ">=6.9.0"
}
},
"node_modules/@cspotcode/source-map-support": {
"version": "0.8.1",
"resolved": "https://registry.npmjs.org/@cspotcode/source-map-support/-/source-map-support-0.8.1.tgz",
"integrity": "sha512-IchNf6dN4tHoMFIn/7OE8LWZ19Y6q/67Bmf6vnGREv8RSbBVb9LPJxEcnwrcwX6ixSvaiGoomAUvu4YSxXrVgw==",
"dev": true,
"dependencies": {
"@jridgewell/trace-mapping": "0.3.9"
},
"engines": {
"node": ">=12"
}
},
"node_modules/@cspotcode/source-map-support/node_modules/@jridgewell/trace-mapping": {
"version": "0.3.9",
"resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.9.tgz",
"integrity": "sha512-3Belt6tdc8bPgAtbcmdtNJlirVoTmEb5e2gC94PnkwEW9jI6CAHUeoG85tjWP5WquqfavoMtMwiG4P926ZKKuQ==",
"dev": true,
"dependencies": {
"@jridgewell/resolve-uri": "^3.0.3",
"@jridgewell/sourcemap-codec": "^1.4.10"
}
},
"node_modules/@emnapi/runtime": {
"version": "1.3.1",
"resolved": "https://registry.npmjs.org/@emnapi/runtime/-/runtime-1.3.1.tgz",
@@ -1168,6 +1192,30 @@
"node": ">= 6"
}
},
"node_modules/@tsconfig/node10": {
"version": "1.0.11",
"resolved": "https://registry.npmjs.org/@tsconfig/node10/-/node10-1.0.11.tgz",
"integrity": "sha512-DcRjDCujK/kCk/cUe8Xz8ZSpm8mS3mNNpta+jGCA6USEDfktlNvm1+IuZ9eTcDbNk41BHwpHHeW+N1lKCz4zOw==",
"dev": true
},
"node_modules/@tsconfig/node12": {
"version": "1.0.11",
"resolved": "https://registry.npmjs.org/@tsconfig/node12/-/node12-1.0.11.tgz",
"integrity": "sha512-cqefuRsh12pWyGsIoBKJA9luFu3mRxCA+ORZvA4ktLSzIuCUtWVxGIuXigEwO5/ywWFMZ2QEGKWvkZG1zDMTag==",
"dev": true
},
"node_modules/@tsconfig/node14": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/@tsconfig/node14/-/node14-1.0.3.tgz",
"integrity": "sha512-ysT8mhdixWK6Hw3i1V2AeRqZ5WfXg1G43mqoYlM2nc6388Fq5jcXyr5mRsqViLx/GJYdoL0bfXD8nmF+Zn/Iow==",
"dev": true
},
"node_modules/@tsconfig/node16": {
"version": "1.0.4",
"resolved": "https://registry.npmjs.org/@tsconfig/node16/-/node16-1.0.4.tgz",
"integrity": "sha512-vxhUy4J8lyeyinH7Azl1pdd43GJhZH/tP2weN8TntQblOY+A0XbT8DJk1/oCPuOOyg/Ja757rG0CgHcWC8OfMA==",
"dev": true
},
"node_modules/@types/node": {
"version": "20.17.19",
"resolved": "https://registry.npmjs.org/@types/node/-/node-20.17.19.tgz",
@@ -1207,6 +1255,18 @@
"@types/react": "*"
}
},
"node_modules/@types/strip-bom": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/@types/strip-bom/-/strip-bom-3.0.0.tgz",
"integrity": "sha512-xevGOReSYGM7g/kUBZzPqCrR/KYAo+F0yiPc85WFTJa0MSLtyFTVTU6cJu/aV4mid7IffDIWqo69THF2o4JiEQ==",
"dev": true
},
"node_modules/@types/strip-json-comments": {
"version": "0.0.30",
"resolved": "https://registry.npmjs.org/@types/strip-json-comments/-/strip-json-comments-0.0.30.tgz",
"integrity": "sha512-7NQmHra/JILCd1QqpSzl8+mJRc8ZHz3uDm8YV1Ks9IhK0epEiTw8aIErbvH9PI+6XbqhyIQy3462nEsn7UVzjQ==",
"dev": true
},
"node_modules/abbrev": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/abbrev/-/abbrev-1.1.1.tgz",
@@ -1214,6 +1274,30 @@
"license": "ISC",
"optional": true
},
"node_modules/acorn": {
"version": "8.15.0",
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"dev": true,
"bin": {
"acorn": "bin/acorn"
},
"engines": {
"node": ">=0.4.0"
}
},
"node_modules/acorn-walk": {
"version": "8.3.4",
"resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.3.4.tgz",
"integrity": "sha512-ueEepnujpqee2o5aIYnvHU6C0A42MNdsIDeqy5BydrkuC5R1ZuUFnm27EeFJGoEHJQgn3uleRvmTXaJgfXbt4g==",
"dev": true,
"dependencies": {
"acorn": "^8.11.0"
},
"engines": {
"node": ">=0.4.0"
}
},
"node_modules/agent-base": {
"version": "6.0.2",
"resolved": "https://registry.npmjs.org/agent-base/-/agent-base-6.0.2.tgz",
@@ -1465,6 +1549,12 @@
"ieee754": "^1.1.13"
}
},
"node_modules/buffer-from": {
"version": "1.1.2",
"resolved": "https://registry.npmjs.org/buffer-from/-/buffer-from-1.1.2.tgz",
"integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==",
"dev": true
},
"node_modules/busboy": {
"version": "1.6.0",
"resolved": "https://registry.npmjs.org/busboy/-/busboy-1.6.0.tgz",
@@ -1626,6 +1716,49 @@
}
]
},
"node_modules/chalk": {
"version": "4.1.2",
"resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz",
"integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==",
"dev": true,
"dependencies": {
"ansi-styles": "^4.1.0",
"supports-color": "^7.1.0"
},
"engines": {
"node": ">=10"
},
"funding": {
"url": "https://github.com/chalk/chalk?sponsor=1"
}
},
"node_modules/chalk/node_modules/ansi-styles": {
"version": "4.3.0",
"resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz",
"integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==",
"dev": true,
"dependencies": {
"color-convert": "^2.0.1"
},
"engines": {
"node": ">=8"
},
"funding": {
"url": "https://github.com/chalk/ansi-styles?sponsor=1"
}
},
"node_modules/chalk/node_modules/supports-color": {
"version": "7.2.0",
"resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz",
"integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==",
"dev": true,
"dependencies": {
"has-flag": "^4.0.0"
},
"engines": {
"node": ">=8"
}
},
"node_modules/chokidar": {
"version": "3.6.0",
"resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.6.0.tgz",
@@ -1691,6 +1824,93 @@
"resolved": "https://registry.npmjs.org/client-only/-/client-only-0.0.1.tgz",
"integrity": "sha512-IV3Ou0jSMzZrd3pZ48nLkT9DA7Ag1pnPzaiQhpW7c3RbcqqzvzzVu+L8gfqMp/8IM2MQtSiqaCxrrcfu8I8rMA=="
},
"node_modules/cliui": {
"version": "8.0.1",
"resolved": "https://registry.npmjs.org/cliui/-/cliui-8.0.1.tgz",
"integrity": "sha512-BSeNnyus75C4//NQ9gQt1/csTXyo/8Sb+afLAkzAptFuMsod9HFokGNudZpi/oQV73hnVK+sR+5PVRMd+Dr7YQ==",
"dev": true,
"dependencies": {
"string-width": "^4.2.0",
"strip-ansi": "^6.0.1",
"wrap-ansi": "^7.0.0"
},
"engines": {
"node": ">=12"
}
},
"node_modules/cliui/node_modules/ansi-regex": {
"version": "5.0.1",
"resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz",
"integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==",
"dev": true,
"engines": {
"node": ">=8"
}
},
"node_modules/cliui/node_modules/ansi-styles": {
"version": "4.3.0",
"resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz",
"integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==",
"dev": true,
"dependencies": {
"color-convert": "^2.0.1"
},
"engines": {
"node": ">=8"
},
"funding": {
"url": "https://github.com/chalk/ansi-styles?sponsor=1"
}
},
"node_modules/cliui/node_modules/emoji-regex": {
"version": "8.0.0",
"resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz",
"integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==",
"dev": true
},
"node_modules/cliui/node_modules/string-width": {
"version": "4.2.3",
"resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz",
"integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==",
"dev": true,
"dependencies": {
"emoji-regex": "^8.0.0",
"is-fullwidth-code-point": "^3.0.0",
"strip-ansi": "^6.0.1"
},
"engines": {
"node": ">=8"
}
},
"node_modules/cliui/node_modules/strip-ansi": {
"version": "6.0.1",
"resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz",
"integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==",
"dev": true,
"dependencies": {
"ansi-regex": "^5.0.1"
},
"engines": {
"node": ">=8"
}
},
"node_modules/cliui/node_modules/wrap-ansi": {
"version": "7.0.0",
"resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz",
"integrity": "sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==",
"dev": true,
"dependencies": {
"ansi-styles": "^4.0.0",
"string-width": "^4.1.0",
"strip-ansi": "^6.0.0"
},
"engines": {
"node": ">=10"
},
"funding": {
"url": "https://github.com/chalk/wrap-ansi?sponsor=1"
}
},
"node_modules/clone": {
"version": "2.1.2",
"resolved": "https://registry.npmjs.org/clone/-/clone-2.1.2.tgz",
@@ -1782,8 +2002,33 @@
"version": "0.0.1",
"resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz",
"integrity": "sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==",
"license": "MIT",
"optional": true
"devOptional": true,
"license": "MIT"
},
"node_modules/concurrently": {
"version": "9.1.2",
"resolved": "https://registry.npmjs.org/concurrently/-/concurrently-9.1.2.tgz",
"integrity": "sha512-H9MWcoPsYddwbOGM6difjVwVZHl63nwMEwDJG/L7VGtuaJhb12h2caPG2tVPWs7emuYix252iGfqOyrz1GczTQ==",
"dev": true,
"dependencies": {
"chalk": "^4.1.2",
"lodash": "^4.17.21",
"rxjs": "^7.8.1",
"shell-quote": "^1.8.1",
"supports-color": "^8.1.1",
"tree-kill": "^1.2.2",
"yargs": "^17.7.2"
},
"bin": {
"conc": "dist/bin/concurrently.js",
"concurrently": "dist/bin/concurrently.js"
},
"engines": {
"node": ">=18"
},
"funding": {
"url": "https://github.com/open-cli-tools/concurrently?sponsor=1"
}
},
"node_modules/console-control-strings": {
"version": "1.1.0",
@@ -1820,6 +2065,12 @@
"node": ">= 6"
}
},
"node_modules/create-require": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/create-require/-/create-require-1.1.1.tgz",
"integrity": "sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ==",
"dev": true
},
"node_modules/cross-spawn": {
"version": "7.0.6",
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz",
@@ -1921,6 +2172,15 @@
"integrity": "sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw==",
"dev": true
},
"node_modules/diff": {
"version": "4.0.2",
"resolved": "https://registry.npmjs.org/diff/-/diff-4.0.2.tgz",
"integrity": "sha512-58lmxKSA4BNyLz+HHMUzlOEpg09FV+ev6ZMe3vJihgdxzgcwZ8VoEEPmALCZG9LmqfVoNMMKpttIYTVG6uDY7A==",
"dev": true,
"engines": {
"node": ">=0.3.1"
}
},
"node_modules/dlv": {
"version": "1.1.3",
"resolved": "https://registry.npmjs.org/dlv/-/dlv-1.1.3.tgz",
@@ -1949,6 +2209,15 @@
"node": ">= 0.4"
}
},
"node_modules/dynamic-dedupe": {
"version": "0.3.0",
"resolved": "https://registry.npmjs.org/dynamic-dedupe/-/dynamic-dedupe-0.3.0.tgz",
"integrity": "sha512-ssuANeD+z97meYOqd50e04Ze5qp4bPqo8cCkI4TRjZkzAUgIDTrXV1R8QCdINpiI+hw14+rYazvTRdQrz0/rFQ==",
"dev": true,
"dependencies": {
"xtend": "^4.0.0"
}
},
"node_modules/eastasianwidth": {
"version": "0.2.0",
"resolved": "https://registry.npmjs.org/eastasianwidth/-/eastasianwidth-0.2.0.tgz",
@@ -2051,6 +2320,15 @@
"node": ">= 0.4"
}
},
"node_modules/escalade": {
"version": "3.2.0",
"resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz",
"integrity": "sha512-WUj2qlxaQtO4g6Pq5c29GTcWGDyd8itL8zTlipgECz3JesAiiOKotd8JU6otB3PACgG6xkJUyVhboMS+bje/jA==",
"dev": true,
"engines": {
"node": ">=6"
}
},
"node_modules/escape-string-regexp": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-4.0.0.tgz",
@@ -2225,8 +2503,8 @@
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz",
"integrity": "sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==",
"license": "ISC",
"optional": true
"devOptional": true,
"license": "ISC"
},
"node_modules/fsevents": {
"version": "2.3.3",
@@ -2322,6 +2600,15 @@
"node": ">=8"
}
},
"node_modules/get-caller-file": {
"version": "2.0.5",
"resolved": "https://registry.npmjs.org/get-caller-file/-/get-caller-file-2.0.5.tgz",
"integrity": "sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg==",
"dev": true,
"engines": {
"node": "6.* || 8.* || >= 10.*"
}
},
"node_modules/get-intrinsic": {
"version": "1.2.7",
"resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.7.tgz",
@@ -2421,6 +2708,15 @@
"license": "ISC",
"optional": true
},
"node_modules/has-flag": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz",
"integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==",
"dev": true,
"engines": {
"node": ">=8"
}
},
"node_modules/has-symbols": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz",
@@ -2598,8 +2894,8 @@
"resolved": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz",
"integrity": "sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==",
"deprecated": "This module is not supported, and leaks memory. Do not use it. Check out lru-cache if you want a good and tested way to coalesce async requests by a key value, which is much more comprehensive and powerful.",
"devOptional": true,
"license": "ISC",
"optional": true,
"dependencies": {
"once": "^1.3.0",
"wrappy": "1"
@@ -2784,6 +3080,12 @@
"resolved": "https://registry.npmjs.org/lines-and-columns/-/lines-and-columns-1.2.4.tgz",
"integrity": "sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg=="
},
"node_modules/lodash": {
"version": "4.17.21",
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz",
"integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==",
"dev": true
},
"node_modules/loose-envify": {
"version": "1.4.0",
"resolved": "https://registry.npmjs.org/loose-envify/-/loose-envify-1.4.0.tgz",
@@ -2810,6 +3112,12 @@
"react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0"
}
},
"node_modules/make-error": {
"version": "1.3.6",
"resolved": "https://registry.npmjs.org/make-error/-/make-error-1.3.6.tgz",
"integrity": "sha512-s8UhlNe7vPKomQhC1qFelMokr/Sc3AgNbso3n74mVPA5LTZwkB9NlXf4XPamLxJE8h0gh73rM94xvwRT2CVInw==",
"dev": true
},
"node_modules/make-fetch-happen": {
"version": "9.1.0",
"resolved": "https://registry.npmjs.org/make-fetch-happen/-/make-fetch-happen-9.1.0.tgz",
@@ -3499,8 +3807,8 @@
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/path-is-absolute/-/path-is-absolute-1.0.1.tgz",
"integrity": "sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg==",
"devOptional": true,
"license": "MIT",
"optional": true,
"engines": {
"node": ">=0.10.0"
}
@@ -4004,6 +4312,15 @@
"node": ">=8.10.0"
}
},
"node_modules/require-directory": {
"version": "2.1.1",
"resolved": "https://registry.npmjs.org/require-directory/-/require-directory-2.1.1.tgz",
"integrity": "sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==",
"dev": true,
"engines": {
"node": ">=0.10.0"
}
},
"node_modules/resolve": {
"version": "1.22.10",
"resolved": "https://registry.npmjs.org/resolve/-/resolve-1.22.10.tgz",
@@ -4137,6 +4454,15 @@
"queue-microtask": "^1.2.2"
}
},
"node_modules/rxjs": {
"version": "7.8.2",
"resolved": "https://registry.npmjs.org/rxjs/-/rxjs-7.8.2.tgz",
"integrity": "sha512-dhKf903U/PQZY6boNNtAGdWbG85WAbjT/1xYoZIC7FAY0yWapOBQVsVrDl58W86//e1VpMNBtRV4MaXfdMySFA==",
"dev": true,
"dependencies": {
"tslib": "^2.1.0"
}
},
"node_modules/safe-buffer": {
"version": "5.2.1",
"resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz",
@@ -4247,6 +4573,18 @@
"node": ">=8"
}
},
"node_modules/shell-quote": {
"version": "1.8.3",
"resolved": "https://registry.npmjs.org/shell-quote/-/shell-quote-1.8.3.tgz",
"integrity": "sha512-ObmnIF4hXNg1BqhnHmgbDETF8dLPCggZWBjkQfhZpbszZnYur5DUljTcCHii5LC3J5E0yeO/1LIMyH+UvHQgyw==",
"dev": true,
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/signal-exit": {
"version": "4.1.0",
"resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-4.1.0.tgz",
@@ -4370,6 +4708,25 @@
"node": ">=0.10.0"
}
},
"node_modules/source-map-support": {
"version": "0.5.21",
"resolved": "https://registry.npmjs.org/source-map-support/-/source-map-support-0.5.21.tgz",
"integrity": "sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w==",
"dev": true,
"dependencies": {
"buffer-from": "^1.0.0",
"source-map": "^0.6.0"
}
},
"node_modules/source-map-support/node_modules/source-map": {
"version": "0.6.1",
"resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz",
"integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==",
"dev": true,
"engines": {
"node": ">=0.10.0"
}
},
"node_modules/sprintf-js": {
"version": "1.1.3",
"resolved": "https://registry.npmjs.org/sprintf-js/-/sprintf-js-1.1.3.tgz",
@@ -4545,6 +4902,15 @@
"node": ">=8"
}
},
"node_modules/strip-bom": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/strip-bom/-/strip-bom-3.0.0.tgz",
"integrity": "sha512-vavAMRXOgBVNF6nyEEmL3DBK19iRpDcoIwW+swQ+CbGiu7lju6t+JklA1MHweoWtadgt4ISVUsXLyDq34ddcwA==",
"dev": true,
"engines": {
"node": ">=4"
}
},
"node_modules/strip-json-comments": {
"version": "2.0.1",
"resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-2.0.1.tgz",
@@ -4603,6 +4969,21 @@
"node": ">=16 || 14 >=14.17"
}
},
"node_modules/supports-color": {
"version": "8.1.1",
"resolved": "https://registry.npmjs.org/supports-color/-/supports-color-8.1.1.tgz",
"integrity": "sha512-MpUEN2OodtUzxvKQl72cUF7RQ5EiHsGvSsVG0ia9c5RbWGL2CI4C7EpPS8UTBIplnlzZiNuV56w+FuNxy3ty2Q==",
"dev": true,
"dependencies": {
"has-flag": "^4.0.0"
},
"engines": {
"node": ">=10"
},
"funding": {
"url": "https://github.com/chalk/supports-color?sponsor=1"
}
},
"node_modules/supports-preserve-symlinks-flag": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz",
@@ -4749,12 +5130,172 @@
"node": ">=8.0"
}
},
"node_modules/tree-kill": {
"version": "1.2.2",
"resolved": "https://registry.npmjs.org/tree-kill/-/tree-kill-1.2.2.tgz",
"integrity": "sha512-L0Orpi8qGpRG//Nd+H90vFB+3iHnue1zSSGmNOOCh1GLJ7rUKVwV2HvijphGQS2UmhUZewS9VgvxYIdgr+fG1A==",
"dev": true,
"bin": {
"tree-kill": "cli.js"
}
},
"node_modules/ts-interface-checker": {
"version": "0.1.13",
"resolved": "https://registry.npmjs.org/ts-interface-checker/-/ts-interface-checker-0.1.13.tgz",
"integrity": "sha512-Y/arvbn+rrz3JCKl9C4kVNfTfSm2/mEp5FSz5EsZSANGPSlQrpRI5M4PKF+mJnE52jOO90PnPSc3Ur3bTQw0gA==",
"dev": true
},
"node_modules/ts-node-dev": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/ts-node-dev/-/ts-node-dev-2.0.0.tgz",
"integrity": "sha512-ywMrhCfH6M75yftYvrvNarLEY+SUXtUvU8/0Z6llrHQVBx12GiFk5sStF8UdfE/yfzk9IAq7O5EEbTQsxlBI8w==",
"dev": true,
"dependencies": {
"chokidar": "^3.5.1",
"dynamic-dedupe": "^0.3.0",
"minimist": "^1.2.6",
"mkdirp": "^1.0.4",
"resolve": "^1.0.0",
"rimraf": "^2.6.1",
"source-map-support": "^0.5.12",
"tree-kill": "^1.2.2",
"ts-node": "^10.4.0",
"tsconfig": "^7.0.0"
},
"bin": {
"ts-node-dev": "lib/bin.js",
"tsnd": "lib/bin.js"
},
"engines": {
"node": ">=0.8.0"
},
"peerDependencies": {
"node-notifier": "*",
"typescript": "*"
},
"peerDependenciesMeta": {
"node-notifier": {
"optional": true
}
}
},
"node_modules/ts-node-dev/node_modules/arg": {
"version": "4.1.3",
"resolved": "https://registry.npmjs.org/arg/-/arg-4.1.3.tgz",
"integrity": "sha512-58S9QDqG0Xx27YwPSt9fJxivjYl432YCwfDMfZ+71RAqUrZef7LrKQZ3LHLOwCS4FLNBplP533Zx895SeOCHvA==",
"dev": true
},
"node_modules/ts-node-dev/node_modules/brace-expansion": {
"version": "1.1.12",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz",
"integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==",
"dev": true,
"dependencies": {
"balanced-match": "^1.0.0",
"concat-map": "0.0.1"
}
},
"node_modules/ts-node-dev/node_modules/glob": {
"version": "7.2.3",
"resolved": "https://registry.npmjs.org/glob/-/glob-7.2.3.tgz",
"integrity": "sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==",
"deprecated": "Glob versions prior to v9 are no longer supported",
"dev": true,
"dependencies": {
"fs.realpath": "^1.0.0",
"inflight": "^1.0.4",
"inherits": "2",
"minimatch": "^3.1.1",
"once": "^1.3.0",
"path-is-absolute": "^1.0.0"
},
"engines": {
"node": "*"
},
"funding": {
"url": "https://github.com/sponsors/isaacs"
}
},
"node_modules/ts-node-dev/node_modules/minimatch": {
"version": "3.1.2",
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz",
"integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==",
"dev": true,
"dependencies": {
"brace-expansion": "^1.1.7"
},
"engines": {
"node": "*"
}
},
"node_modules/ts-node-dev/node_modules/rimraf": {
"version": "2.7.1",
"resolved": "https://registry.npmjs.org/rimraf/-/rimraf-2.7.1.tgz",
"integrity": "sha512-uWjbaKIK3T1OSVptzX7Nl6PvQ3qAGtKEtVRjRuazjfL3Bx5eI409VZSqgND+4UNnmzLVdPj9FqFJNPqBZFve4w==",
"deprecated": "Rimraf versions prior to v4 are no longer supported",
"dev": true,
"dependencies": {
"glob": "^7.1.3"
},
"bin": {
"rimraf": "bin.js"
}
},
"node_modules/ts-node-dev/node_modules/ts-node": {
"version": "10.9.2",
"resolved": "https://registry.npmjs.org/ts-node/-/ts-node-10.9.2.tgz",
"integrity": "sha512-f0FFpIdcHgn8zcPSbf1dRevwt047YMnaiJM3u2w2RewrB+fob/zePZcrOyQoLMMO7aBIddLcQIEK5dYjkLnGrQ==",
"dev": true,
"dependencies": {
"@cspotcode/source-map-support": "^0.8.0",
"@tsconfig/node10": "^1.0.7",
"@tsconfig/node12": "^1.0.7",
"@tsconfig/node14": "^1.0.0",
"@tsconfig/node16": "^1.0.2",
"acorn": "^8.4.1",
"acorn-walk": "^8.1.1",
"arg": "^4.1.0",
"create-require": "^1.1.0",
"diff": "^4.0.1",
"make-error": "^1.1.1",
"v8-compile-cache-lib": "^3.0.1",
"yn": "3.1.1"
},
"bin": {
"ts-node": "dist/bin.js",
"ts-node-cwd": "dist/bin-cwd.js",
"ts-node-esm": "dist/bin-esm.js",
"ts-node-script": "dist/bin-script.js",
"ts-node-transpile-only": "dist/bin-transpile.js",
"ts-script": "dist/bin-script-deprecated.js"
},
"peerDependencies": {
"@swc/core": ">=1.2.50",
"@swc/wasm": ">=1.2.50",
"@types/node": "*",
"typescript": ">=2.7"
},
"peerDependenciesMeta": {
"@swc/core": {
"optional": true
},
"@swc/wasm": {
"optional": true
}
}
},
"node_modules/tsconfig": {
"version": "7.0.0",
"resolved": "https://registry.npmjs.org/tsconfig/-/tsconfig-7.0.0.tgz",
"integrity": "sha512-vZXmzPrL+EmC4T/4rVlT2jNVMWCi/O4DIiSj3UHg1OE5kCKbk4mfrXc6dZksLgRM/TZlKnousKH9bbTazUWRRw==",
"dev": true,
"dependencies": {
"@types/strip-bom": "^3.0.0",
"@types/strip-json-comments": "0.0.30",
"strip-bom": "^3.0.0",
"strip-json-comments": "^2.0.0"
}
},
"node_modules/tslib": {
"version": "2.8.1",
"resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz",
@@ -4829,6 +5370,12 @@
"resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz",
"integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw=="
},
"node_modules/v8-compile-cache-lib": {
"version": "3.0.1",
"resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",
"integrity": "sha512-wa7YjyUGfNZngI/vtK0UHAN+lgDCxBPCylVXGp0zu59Fz5aiGtNXaq3DhIov063MorB+VfufLh3JlF2KdTK3xg==",
"dev": true
},
"node_modules/which": {
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz",
@@ -4996,6 +5543,24 @@
"integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==",
"license": "ISC"
},
"node_modules/xtend": {
"version": "4.0.2",
"resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz",
"integrity": "sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ==",
"dev": true,
"engines": {
"node": ">=0.4"
}
},
"node_modules/y18n": {
"version": "5.0.8",
"resolved": "https://registry.npmjs.org/y18n/-/y18n-5.0.8.tgz",
"integrity": "sha512-0pfFzegeDWJHJIAmTLRP2DwHjdF5s7jo9tuztdQxAhINCdvS+3nGINqPd00AphqJR/0LhANUS6/+7SCb98YOfA==",
"dev": true,
"engines": {
"node": ">=10"
}
},
"node_modules/yallist": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz",
@@ -5012,6 +5577,83 @@
"engines": {
"node": ">= 14"
}
},
"node_modules/yargs": {
"version": "17.7.2",
"resolved": "https://registry.npmjs.org/yargs/-/yargs-17.7.2.tgz",
"integrity": "sha512-7dSzzRQ++CKnNI/krKnYRV7JKKPUXMEh61soaHKg9mrWEhzFWhFnxPxGl+69cD1Ou63C13NUPCnmIcrvqCuM6w==",
"dev": true,
"dependencies": {
"cliui": "^8.0.1",
"escalade": "^3.1.1",
"get-caller-file": "^2.0.5",
"require-directory": "^2.1.1",
"string-width": "^4.2.3",
"y18n": "^5.0.5",
"yargs-parser": "^21.1.1"
},
"engines": {
"node": ">=12"
}
},
"node_modules/yargs-parser": {
"version": "21.1.1",
"resolved": "https://registry.npmjs.org/yargs-parser/-/yargs-parser-21.1.1.tgz",
"integrity": "sha512-tVpsJW7DdjecAiFpbIB1e3qxIQsE6NoPc5/eTdrbbIC4h0LVsWhnoa3g+m2HclBIujHzsxZ4VJVA+GUuc2/LBw==",
"dev": true,
"engines": {
"node": ">=12"
}
},
"node_modules/yargs/node_modules/ansi-regex": {
"version": "5.0.1",
"resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz",
"integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==",
"dev": true,
"engines": {
"node": ">=8"
}
},
"node_modules/yargs/node_modules/emoji-regex": {
"version": "8.0.0",
"resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz",
"integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==",
"dev": true
},
"node_modules/yargs/node_modules/string-width": {
"version": "4.2.3",
"resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz",
"integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==",
"dev": true,
"dependencies": {
"emoji-regex": "^8.0.0",
"is-fullwidth-code-point": "^3.0.0",
"strip-ansi": "^6.0.1"
},
"engines": {
"node": ">=8"
}
},
"node_modules/yargs/node_modules/strip-ansi": {
"version": "6.0.1",
"resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz",
"integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==",
"dev": true,
"dependencies": {
"ansi-regex": "^5.0.1"
},
"engines": {
"node": ">=8"
}
},
"node_modules/yn": {
"version": "3.1.1",
"resolved": "https://registry.npmjs.org/yn/-/yn-3.1.1.tgz",
"integrity": "sha512-Ux4ygGWsu2c7isFWe8Yu1YluJmqVhxqK2cLXNQA5AcC3QfbGNpM7fu0Y8b/z16pXLnFxZYvWhd3fhBY9DLmC6Q==",
"dev": true,
"engines": {
"node": ">=6"
}
}
}
}

View File

@@ -3,9 +3,9 @@
"version": "0.1.0",
"private": true,
"scripts": {
"dev": "next dev --turbopack",
"build": "next build",
"start": "next start --port 8675",
"dev": "concurrently -k -n WORKER,UI \"ts-node-dev --respawn --watch cron --transpile-only cron/worker.ts\" \"next dev --turbopack\"",
"build": "tsc -p tsconfig.worker.json && next build",
"start": "concurrently --restart-tries -1 --restart-after 1000 -n WORKER,UI \"node dist/worker.js\" \"next start --port 8675\"",
"build_and_start": "npm install && npm run update_db && npm run build && npm run start",
"lint": "next lint",
"update_db": "npx prisma generate && npx prisma db push",
@@ -34,10 +34,12 @@
"@types/node": "^20",
"@types/react": "^19",
"@types/react-dom": "^19",
"concurrently": "^9.1.2",
"postcss": "^8",
"prettier": "^3.5.1",
"prettier-basic": "^1.0.0",
"tailwindcss": "^3.4.1",
"ts-node-dev": "^2.0.0",
"typescript": "^5"
},
"prettier": "prettier-basic"

View File

@@ -26,3 +26,13 @@ model Job {
info String @default("")
speed_string String @default("")
}
model Queue {
id String @id @default(uuid())
channel String
job_id String
created_at DateTime @default(now())
updated_at DateTime @updatedAt
status String @default("waiting")
@@index([job_id, channel])
}

View File

@@ -45,7 +45,7 @@ function findImagesRecursively(dir: string): string[] {
const itemPath = path.join(dir, item);
const stat = fs.statSync(itemPath);
if (stat.isDirectory()) {
if (stat.isDirectory() && item !== '_controls' && !item.startsWith('.')) {
// If it's a directory, recursively search it
results = results.concat(findImagesRecursively(itemPath));
} else {

View File

@@ -47,6 +47,7 @@ export default function SimpleJob({
<TextInput
label="Training Name"
value={jobConfig.config.name}
docKey="config.name"
onChange={value => setJobConfig(value, 'config.name')}
placeholder="Enter training name"
disabled={runId !== null}
@@ -55,12 +56,14 @@ export default function SimpleJob({
<SelectInput
label="GPU ID"
value={`${gpuIDs}`}
docKey="gpuids"
onChange={value => setGpuIDs(value)}
options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))}
/>
<TextInput
label="Trigger Word"
value={jobConfig.config.process[0].trigger_word || ''}
docKey="config.process[0].trigger_word"
onChange={(value: string | null) => {
if (value?.trim() === '') {
value = null;
@@ -120,6 +123,7 @@ export default function SimpleJob({
<TextInput
label="Name or Path"
value={jobConfig.config.process[0].model.name_or_path}
docKey="config.process[0].model.name_or_path"
onChange={(value: string | null) => {
if (value?.trim() === '') {
value = null;
@@ -185,22 +189,20 @@ export default function SimpleJob({
max={1024}
required
/>
{
modelArch?.disableSections?.includes('network.conv') ? null : (
<NumberInput
label="Conv Rank"
value={jobConfig.config.process[0].network.conv}
onChange={value => {
console.log('onChange', value);
setJobConfig(value, 'config.process[0].network.conv');
setJobConfig(value, 'config.process[0].network.conv_alpha');
}}
placeholder="eg. 16"
min={0}
max={1024}
/>
)
}
{modelArch?.disableSections?.includes('network.conv') ? null : (
<NumberInput
label="Conv Rank"
value={jobConfig.config.process[0].network.conv}
onChange={value => {
console.log('onChange', value);
setJobConfig(value, 'config.process[0].network.conv');
setJobConfig(value, 'config.process[0].network.conv_alpha');
}}
placeholder="eg. 16"
min={0}
max={1024}
/>
)}
</>
)}
</Card>

View File

@@ -7,6 +7,7 @@ import ConfirmModal from '@/components/ConfirmModal';
import SampleImageModal from '@/components/SampleImageModal';
import { Suspense } from 'react';
import AuthWrapper from '@/components/AuthWrapper';
import DocModal from '@/components/DocModal';
export const dynamic = 'force-dynamic';
@@ -38,6 +39,7 @@ export default function RootLayout({ children }: { children: React.ReactNode })
</AuthWrapper>
</ThemeProvider>
<ConfirmModal />
<DocModal />
<SampleImageModal />
</body>
</html>

View File

@@ -0,0 +1,59 @@
'use client';
import { createGlobalState } from 'react-global-hooks';
import { Dialog, DialogBackdrop, DialogPanel, DialogTitle } from '@headlessui/react';
import React from 'react';
import { ConfigDoc } from '@/types';
export const docState = createGlobalState<ConfigDoc | null>(null);
export const openDoc = (doc: ConfigDoc) => {
docState.set({ ...doc });
};
export default function DocModal() {
const [doc, setDoc] = docState.use();
const isOpen = !!doc;
const onClose = () => {
setDoc(null);
};
return (
<Dialog open={isOpen} onClose={onClose} className="relative z-10">
<DialogBackdrop
transition
className="fixed inset-0 bg-gray-900/75 transition-opacity data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in"
/>
<div className="fixed inset-0 z-10 w-screen overflow-y-auto">
<div className="flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0">
<DialogPanel
transition
className="relative transform overflow-hidden rounded-lg bg-gray-800 text-left shadow-xl transition-all data-closed:translate-y-4 data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in sm:my-8 sm:w-full sm:max-w-[50rem] data-closed:sm:translate-y-0 data-closed:sm:scale-95"
>
<div className="bg-gray-800 px-4 pt-5 pb-4 sm:p-6 sm:pb-4">
<div className="sm:flex sm:items-start">
<div className="mt-3 text-center sm:mt-0 sm:ml-4 sm:text-left flex-1">
<DialogTitle as="h3" className={`text-base font-semibold `}>
{doc?.title || 'Confirm Action'}
</DialogTitle>
<div className="mt-2 text-sm text-gray-200">{doc?.description}</div>
</div>
</div>
</div>
<div className="bg-gray-700 px-4 py-3 sm:flex sm:flex-row-reverse sm:px-6">
<button
type="button"
data-autofocus
onClick={onClose}
className="mt-3 inline-flex w-full justify-center rounded-md bg-gray-800 px-3 py-2 text-sm font-semibold text-gray-200 hover:bg-gray-800 sm:mt-0 sm:w-auto ring-0"
>
Close
</button>
</div>
</DialogPanel>
</div>
</div>
</Dialog>
);
}

View File

@@ -1,5 +1,6 @@
import Link from 'next/link';
import { Home, Settings, BrainCircuit, Images, Plus } from 'lucide-react';
import { Home, Settings, BrainCircuit, Images, Plus} from 'lucide-react';
import { FaXTwitter, FaDiscord, FaYoutube } from "react-icons/fa6";
const Sidebar = () => {
const navigation = [
@@ -10,13 +11,16 @@ const Sidebar = () => {
{ name: 'Settings', href: '/settings', icon: Settings },
];
const socialsBoxClass = 'flex flex-col items-center justify-center p-1 hover:bg-gray-800 rounded-lg transition-colors';
const socialIconClass = 'w-5 h-5 text-gray-400 hover:text-white';
return (
<div className="flex flex-col w-59 bg-gray-900 text-gray-100">
<div className="px-4 py-3">
<h1 className="text-l">
<img src="/ostris_logo.png" alt="Ostris AI Toolkit" className="w-auto h-7 mr-3 inline" />
<span className="font-bold uppercase">Ostris</span>
<span className='ml-2 uppercase text-gray-300'>AI-Toolkit</span>
<span className="ml-2 uppercase text-gray-300">AI-Toolkit</span>
</h1>
</div>
<nav className="flex-1">
@@ -34,7 +38,12 @@ const Sidebar = () => {
))}
</ul>
</nav>
<a href="https://ostris.com/support" target="_blank" rel="noreferrer" className="flex items-center space-x-2 px-4 py-3">
<a
href="https://ostris.com/support"
target="_blank"
rel="noreferrer"
className="flex items-center space-x-2 px-4 py-3"
>
<div className="min-w-[26px] min-h-[26px]">
<svg height="24" version="1.1" width="24" xmlns="http://www.w3.org/2000/svg">
<g transform="translate(0 -1028.4)">
@@ -47,6 +56,39 @@ const Sidebar = () => {
</div>
<div className="uppercase text-gray-500 text-sm mb-2 flex-1 pt-2 pl-0">Support AI-Toolkit</div>
</a>
{/* Social links grid */}
<div className="px-1 py-1 border-t border-gray-800">
<div className="grid grid-cols-3 gap-4">
<a
href="https://discord.gg/VXmU2f5WEU"
target="_blank"
rel="noreferrer"
className={socialsBoxClass}
>
<FaDiscord className={socialIconClass} />
{/* <span className="text-xs text-gray-500 mt-1">Discord</span> */}
</a>
<a
href="https://www.youtube.com/@ostrisai"
target="_blank"
rel="noreferrer"
className={socialsBoxClass}
>
<FaYoutube className={socialIconClass} />
{/* <span className="text-xs text-gray-500 mt-1">YouTube</span> */}
</a>
<a
href="https://x.com/ostrisai"
target="_blank"
rel="noreferrer"
className={socialsBoxClass}
>
<FaXTwitter className={socialIconClass} />
{/* <span className="text-xs text-gray-500 mt-1">X</span> */}
</a>
</div>
</div>
</div>
);
};

View File

@@ -3,6 +3,10 @@
import React, { forwardRef } from 'react';
import classNames from 'classnames';
import dynamic from 'next/dynamic';
import { CircleHelp } from 'lucide-react';
import { getDoc } from '@/docs';
import { openDoc } from '@/components/DocModal';
const Select = dynamic(() => import('react-select'), { ssr: false });
const labelClasses = 'block text-xs mb-1 mt-2 text-gray-300';
@@ -11,6 +15,7 @@ const inputClasses =
export interface InputProps {
label?: string;
docKey?: string;
className?: string;
placeholder?: string;
required?: boolean;
@@ -24,10 +29,20 @@ export interface TextInputProps extends InputProps {
}
export const TextInput = forwardRef<HTMLInputElement, TextInputProps>(
({ label, value, onChange, placeholder, required, disabled, type = 'text', className }, ref) => {
({ label, value, onChange, placeholder, required, disabled, type = 'text', className, docKey = null }, ref) => {
const doc = getDoc(docKey);
return (
<div className={classNames(className)}>
{label && <label className={labelClasses}>{label}</label>}
{label && (
<label className={labelClasses}>
{label}{' '}
{doc && (
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
<CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
</div>
)}
</label>
)}
<input
ref={ref}
type={type}
@@ -56,7 +71,8 @@ export interface NumberInputProps extends InputProps {
}
export const NumberInput = (props: NumberInputProps) => {
const { label, value, onChange, placeholder, required, min, max } = props;
const { label, value, onChange, placeholder, required, min, max, docKey = null } = props;
const doc = getDoc(docKey);
// Add controlled internal state to properly handle partial inputs
const [inputValue, setInputValue] = React.useState<string | number>(value ?? '');
@@ -68,7 +84,16 @@ export const NumberInput = (props: NumberInputProps) => {
return (
<div className={classNames(props.className)}>
{label && <label className={labelClasses}>{label}</label>}
{label && (
<label className={labelClasses}>
{label}{' '}
{doc && (
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
<CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
</div>
)}
</label>
)}
<input
type="number"
value={inputValue}
@@ -120,7 +145,8 @@ export interface SelectInputProps extends InputProps {
}
export const SelectInput = (props: SelectInputProps) => {
const { label, value, onChange, options } = props;
const { label, value, onChange, options, docKey = null } = props;
const doc = getDoc(docKey);
const selectedOption = options.find(option => option.value === value);
return (
<div
@@ -128,7 +154,16 @@ export const SelectInput = (props: SelectInputProps) => {
'opacity-30 cursor-not-allowed': props.disabled,
})}
>
{label && <label className={labelClasses}>{label}</label>}
{label && (
<label className={labelClasses}>
{label}{' '}
{doc && (
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
<CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
</div>
)}
</label>
)}
<Select
value={selectedOption}
options={options}
@@ -200,13 +235,24 @@ export const Checkbox = (props: CheckboxProps) => {
interface FormGroupProps {
label?: string;
className?: string;
docKey?: string;
children: React.ReactNode;
}
export const FormGroup: React.FC<FormGroupProps> = ({ label, className, children }) => {
export const FormGroup: React.FC<FormGroupProps> = ({ label, className, children, docKey = null }) => {
const doc = getDoc(docKey);
return (
<div className={classNames(className)}>
{label && <label className={labelClasses}>{label}</label>}
{label && (
<label className={labelClasses}>
{label}{' '}
{doc && (
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
<CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
</div>
)}
</label>
)}
<div className="px-4 space-y-2">{children}</div>
</div>
);

ui/src/docs.tsx Normal file
View File

@@ -0,0 +1,60 @@
import React from 'react';
import { ConfigDoc } from '@/types';
const docs: { [key: string]: ConfigDoc } = {
'config.name': {
title: 'Training Name',
description: (
<>
        The name of the training job. This name is used to identify the job in the system and will be the filename
        of the final model. It must be unique and can only contain alphanumeric characters, underscores, and dashes. No
        spaces or other special characters are allowed.
</>
),
},
'gpuids': {
title: 'GPU ID',
description: (
<>
        The GPU that will be used for training. Currently, only one GPU can be used per job via the UI.
        However, you can start multiple jobs in parallel, each on a different GPU.
</>
),
},
'config.process[0].trigger_word': {
title: 'Trigger Word',
description: (
<>
Optional: This will be the word or token used to trigger your concept or character.
<br />
<br />
        When using a trigger word,
        if your captions do not contain it, it will be added automatically to the beginning of the caption. If you do not have
        captions, the caption becomes just the trigger word. If you want to place the trigger word at a different spot in a caption,
        you can use the <code>{'[trigger]'}</code> placeholder, which will be automatically replaced with your trigger word.
<br />
<br />
Trigger words will not automatically be added to your test prompts, so you will need to either add your trigger word manually or use the
<code>{'[trigger]'}</code> placeholder in your test prompts as well.
</>
),
},
'config.process[0].model.name_or_path': {
title: 'Name or Path',
description: (
<>
        The name of a diffusers repo on Huggingface or the local path to the base model you want to train from. For most models, the folder needs to be in
        diffusers format. For some models, such as SDXL and SD1, you can instead put the path to an all-in-one safetensors checkpoint here.
</>
),
},
};
export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
if (key && key in docs) {
return docs[key];
}
return null;
};
export default docs;
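
For context, a minimal sketch of how a form field would opt into one of these entries through the new docKey prop. The import path, the state hook, and the onChange signature are assumptions for illustration; only TextInput, getDoc, and the 'config.name' key come from this change.

// Hypothetical example: passing a docKey that matches an entry in docs.tsx makes the
// label render a CircleHelp icon, and clicking it calls openDoc(doc) to show the popup.
import React, { useState } from 'react';
import { TextInput } from '@/components/formInputs'; // import path is an assumption

export function JobNameField() {
  const [jobName, setJobName] = useState('');
  return (
    <TextInput
      label="Training Name"
      docKey="config.name" // unknown or omitted keys return null from getDoc, so no icon is shown
      value={jobName}
      onChange={(v: string) => setJobName(v)} // onChange signature assumed
      placeholder="my_first_lora"
      required
    />
  );
}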

View File

@@ -59,7 +59,7 @@ export interface NetworkConfig {
lokr_factor: number;
network_kwargs: {
ignore_if_contains: string[];
}
};
}
export interface SaveConfig {
@@ -125,7 +125,7 @@ export interface ModelConfig {
quantize_kwargs?: QuantizeKwargsConfig;
arch: string;
low_vram: boolean;
model_kwargs: {[key: string]: any};
model_kwargs: { [key: string]: any };
}
export interface SampleConfig {
@@ -173,3 +173,8 @@ export interface JobConfig {
config: ConfigObject;
meta: MetaConfig;
}
export interface ConfigDoc {
title: string;
description: React.ReactNode;
}

ui/tsconfig.worker.json Normal file
View File

@@ -0,0 +1,15 @@
{
// tsconfig.worker.json
"compilerOptions": {
"module": "commonjs",
"target": "es2020",
"outDir": "dist",
"moduleResolution": "node",
"types": [
"node"
]
},
"include": [
"cron/**/*.ts"
]
}
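
This worker config compiles the TypeScript under cron/ to CommonJS in dist/ so it can run directly under Node (for example via npx tsc -p tsconfig.worker.json from the ui/ directory; the project's actual build invocation is not shown here). Below is a hypothetical sketch of the kind of script it would compile; the file name and polling logic are assumptions, not part of this commit.

// cron/checkJobs.ts (hypothetical) - compiled by tsconfig.worker.json to dist/checkJobs.js
// and intended to run under Node. The polling logic is illustrative only.
function checkQueuedJobs(): void {
  console.log(`[cron] checking for queued jobs at ${new Date().toISOString()}`);
}

// Poll once a minute.
setInterval(checkQueuedJobs, 60_000);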

View File

@@ -1 +1 @@
VERSION = "0.3.0"
VERSION = "0.3.1"