mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Merged in from main

README.md
@@ -417,6 +417,17 @@ Everything else should work the same including layer targeting.

## Updates

### June 17, 2024
- Performance optimizations for batch preparation
- Added some docs via a popup for items in the simple UI explaining what settings do. Still a WIP.

### June 16, 2024
- Hide control images in the UI when viewing datasets
- WIP on mean flow loss

### June 12, 2024
- Fixed an issue that resulted in blank captions in the dataloader

### June 10, 2024
- Decided to keep track of updates in the readme
- Added support for SDXL in the UI
@@ -36,7 +36,6 @@ from toolkit.train_tools import precondition_model_outputs_flow_match
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
from toolkit.util.wavelet_loss import wavelet_loss
import torch.nn.functional as F
from toolkit.models.flux import convert_flux_to_mean_flow


def flush():
@@ -62,7 +61,6 @@ class SDTrainer(BaseSDTrainProcess):
self._clip_image_embeds_unconditional: Union[List[str], None] = None
self.negative_prompt_pool: Union[List[str], None] = None
self.batch_negative_prompt: Union[List[str], None] = None
self.cfm_cache = None

self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"

@@ -85,6 +83,7 @@ class SDTrainer(BaseSDTrainProcess):
self.diff_output_preservation_embeds: Optional[PromptEmbeds] = None

self.dfe: Optional[DiffusionFeatureExtractor] = None
self.unconditional_embeds = None

if self.train_config.diff_output_preservation:
if self.trigger_word is None:
@@ -96,6 +95,15 @@ class SDTrainer(BaseSDTrainProcess):

# always do a prior prediction when doing diff output preservation
self.do_prior_prediction = True

# store the loss target for a batch so we can use it in a loss
self._guidance_loss_target_batch: float = 0.0
if isinstance(self.train_config.guidance_loss_target, (int, float)):
self._guidance_loss_target_batch = float(self.train_config.guidance_loss_target)
elif isinstance(self.train_config.guidance_loss_target, list):
self._guidance_loss_target_batch = float(self.train_config.guidance_loss_target[0])
else:
raise ValueError(f"Unknown guidance loss target type {type(self.train_config.guidance_loss_target)}")


def before_model_load(self):
@@ -135,9 +143,16 @@ class SDTrainer(BaseSDTrainProcess):

def hook_before_train_loop(self):
super().hook_before_train_loop()
if self.train_config.loss_type == "mean_flow":
# todo handle non flux models
convert_flux_to_mean_flow(self.sd.unet)

# cache unconditional embeds (blank prompt)
with torch.no_grad():
self.unconditional_embeds = self.sd.encode_prompt(
[self.train_config.unconditional_prompt],
long_prompts=self.do_long_prompts
).to(
self.device_torch,
dtype=self.sd.torch_dtype
).detach()

if self.train_config.do_prior_divergence:
self.do_prior_prediction = True
@@ -418,15 +433,41 @@ class SDTrainer(BaseSDTrainProcess):

if self.dfe is not None:
if self.dfe.version == 1:
# do diffusion feature extraction on target
model = self.sd
if model is not None and hasattr(model, 'get_stepped_pred'):
stepped_latents = model.get_stepped_pred(noise_pred, noise)
else:
# stepped_latents = noise - noise_pred
# first we step the scheduler from current timestep to the very end for a full denoise
bs = noise_pred.shape[0]
noise_pred_chunks = torch.chunk(noise_pred, bs)
timestep_chunks = torch.chunk(timesteps, bs)
noisy_latent_chunks = torch.chunk(noisy_latents, bs)
stepped_chunks = []
for idx in range(bs):
model_output = noise_pred_chunks[idx]
timestep = timestep_chunks[idx]
self.sd.noise_scheduler._step_index = None
self.sd.noise_scheduler._init_step_index(timestep)
sample = noisy_latent_chunks[idx].to(torch.float32)

sigma = self.sd.noise_scheduler.sigmas[self.sd.noise_scheduler.step_index]
sigma_next = self.sd.noise_scheduler.sigmas[-1]  # use last sigma for final step
prev_sample = sample + (sigma_next - sigma) * model_output
stepped_chunks.append(prev_sample)

stepped_latents = torch.cat(stepped_chunks, dim=0)

stepped_latents = stepped_latents.to(self.sd.vae.device, dtype=self.sd.vae.dtype)
pred_features = self.dfe(stepped_latents.float())
with torch.no_grad():
rectified_flow_target = noise.float() - batch.latents.float()
target_features = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))
target_features = self.dfe(batch.latents.to(self.device_torch, dtype=torch.float32))
# scale dfe so it is weaker at higher noise levels
dfe_scaler = 1 - (timesteps.float() / 1000.0).view(-1, 1, 1, 1).to(self.device_torch)

# do diffusion feature extraction on prediction
pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") * \
self.train_config.diffusion_feature_extractor_weight
dfe_loss = torch.nn.functional.mse_loss(pred_features, target_features, reduction="none") * \
self.train_config.diffusion_feature_extractor_weight * dfe_scaler
additional_loss += dfe_loss.mean()
elif self.dfe.version == 2:
# version 2
# do diffusion feature extraction on target
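The per-sample loop in the hunk above takes a single Euler step of the flow-matching ODE from the current sigma straight to the final sigma, which for rectified flow amounts to the model's one-step estimate of the clean latent. A minimal standalone sketch of that step (hypothetical shapes and sigma values, not part of the commit):

    import torch

    # One Euler step from the current sigma straight to the final sigma.
    # Under rectified flow, model_output approximates (noise - x0), so with
    # sigma_next = 0 this gives sample - sigma * model_output, an estimate of x0.
    sample = torch.randn(1, 16, 64, 64)        # noisy latent at the current sigma
    model_output = torch.randn(1, 16, 64, 64)  # predicted velocity for this sample
    sigma, sigma_next = 0.7, 0.0               # hypothetical current / final sigmas

    prev_sample = sample + (sigma_next - sigma) * model_output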
@@ -454,6 +495,47 @@ class SDTrainer(BaseSDTrainProcess):
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight
else:
raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}")

if self.train_config.do_guidance_loss:
with torch.no_grad():
# we make cached blank prompt embeds that match the batch size
unconditional_embeds = concat_prompt_embeds(
[self.unconditional_embeds] * noisy_latents.shape[0],
)
cfm_pred = self.predict_noise(
noisy_latents=noisy_latents,
timesteps=timesteps,
conditional_embeds=unconditional_embeds,
unconditional_embeds=None,
batch=batch,
)

# zero cfg

# ref https://github.com/WeichenFan/CFG-Zero-star/blob/cdac25559e3f16cb95f0016c04c709ea1ab9452b/wan_pipeline.py#L557
batch_size = target.shape[0]
positive_flat = target.view(batch_size, -1)
negative_flat = cfm_pred.view(batch_size, -1)
# Calculate the dot product
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
# Squared norm of the unconditional prediction
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
st_star = dot_product / squared_norm

alpha = st_star

is_video = len(target.shape) == 5

alpha = alpha.view(batch_size, 1, 1, 1) if not is_video else alpha.view(batch_size, 1, 1, 1, 1)

guidance_scale = self._guidance_loss_target_batch
if isinstance(guidance_scale, list):
guidance_scale = torch.tensor(guidance_scale).to(target.device, dtype=target.dtype)
guidance_scale = guidance_scale.view(-1, 1, 1, 1) if not is_video else guidance_scale.view(-1, 1, 1, 1, 1)

unconditional_target = cfm_pred * alpha
target = unconditional_target + guidance_scale * (target - unconditional_target)


if target is None:
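For reference, the CFG-Zero-star projection in the hunk above can be checked in isolation. A minimal sketch with hypothetical tensor shapes (not part of the commit), mirroring st_star = v_cond^T v_uncond / ||v_uncond||^2 followed by the usual CFG blend:

    import torch

    # Standalone check of the CFG-Zero-star projection math used above.
    B, C, H, W = 2, 16, 8, 8
    target = torch.randn(B, C, H, W)    # conditional flow target
    cfm_pred = torch.randn(B, C, H, W)  # unconditional prediction

    positive_flat = target.view(B, -1)
    negative_flat = cfm_pred.view(B, -1)
    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
    st_star = (dot_product / squared_norm).view(B, 1, 1, 1)

    guidance_scale = 3.0
    unconditional_target = cfm_pred * st_star
    guided = unconditional_target + guidance_scale * (target - unconditional_target)

    # With guidance_scale == 1.0 the blend must return the original target.
    assert torch.allclose(unconditional_target + 1.0 * (target - unconditional_target), target)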
@@ -634,102 +716,6 @@ class SDTrainer(BaseSDTrainProcess):
return loss


# ------------------------------------------------------------------
# Mean-Flow loss (Geng et al., “Mean Flows for One-step Generative
# Modelling”, 2025 – see Alg. 1 + Eq. (6) of the paper)
# This version avoids jvp / double-back-prop issues with Flash-Attention
# adapted from the work of lodestonerock
# ------------------------------------------------------------------
def get_mean_flow_loss_wip(
self,
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,
match_adapter_assist: bool,
network_weight_list: list,
timesteps: torch.Tensor,
pred_kwargs: dict,
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
unconditional_embeds: Optional[PromptEmbeds] = None,
**kwargs
):
batch_latents = batch.latents.to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))


time_end = timesteps.float() / 1000
# for timestep_r, we need values from timestep_end to 0.0 randomly
time_origin = torch.rand_like(time_end, device=self.device_torch, dtype=time_end.dtype) * time_end

# time_origin = torch.zeros_like(time_end, device=self.device_torch, dtype=time_end.dtype)
# Compute noised data points
# lerp_vector = noisy_latents
# compute instantaneous vector
instantaneous_vector = noise - batch_latents

# finite difference method
epsilon_fd = 1e-3
jitter_std = 1e-4
epsilon_jittered = epsilon_fd + torch.randn(1, device=batch_latents.device) * jitter_std
epsilon_jittered = torch.clamp(epsilon_jittered, min=1e-4)

# f(x) for the primal (we backprop through here)
# mean_vec_val_pred = self.forward(lerp_vector, class_label)
mean_vec_val_pred = self.predict_noise(
noisy_latents=noisy_latents,
timesteps=torch.cat([time_end, time_origin], dim=0) * 1000,
conditional_embeds=conditional_embeds,
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)

with torch.no_grad():
perturbed_time_end = torch.clamp(time_end + epsilon_jittered, 0.0, 1.0)
# intermediate vector to compute tangent approximation f(x + epsilon * v) ! NO GRAD HERE!
perturbed_lerp_vector = noisy_latents + epsilon_jittered * instantaneous_vector
# f_x_plus_eps_v = self.forward(perturbed_lerp_vector, class_label)
f_x_plus_eps_v = self.predict_noise(
noisy_latents=perturbed_lerp_vector,
timesteps=torch.cat([perturbed_time_end, time_origin], dim=0) * 1000,
conditional_embeds=conditional_embeds,
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)

# JVP approximation: (f(x + epsilon * v) - f(x)) / epsilon
mean_vec_grad_fd = (f_x_plus_eps_v - mean_vec_val_pred) / epsilon_jittered
mean_vec_grad = mean_vec_grad_fd


# calculate the regression target for the mean vector
time_difference_broadcast = (time_end - time_origin)[:, None, None, None]
mean_vec_target = instantaneous_vector - time_difference_broadcast * mean_vec_grad

# MSE loss
loss = torch.nn.functional.mse_loss(
mean_vec_val_pred.float(),
mean_vec_target.float(),
reduction='none'
)
with torch.no_grad():
pure_loss = loss.mean().detach()
# give pure_loss a grad flag so calling backward on it later does not error
pure_loss.requires_grad_(True)
# normalize the loss per batch element to 1.0
# this loss has large swings that can hurt the model; normalizing per element prevents that
with torch.no_grad():
loss_mean = loss.mean([1, 2, 3], keepdim=True)
loss = loss / loss_mean
loss = loss.mean()

# backward the normalized loss
self.accelerator.backward(loss)

# return the real (unnormalized) loss for logging
return pure_loss
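Since the function above replaces an exact JVP with a finite difference, the approximation can be sanity-checked against torch.autograd on a toy function where the true Jacobian-vector product is computable. A minimal sketch (hypothetical f standing in for the network, not part of the commit):

    import torch

    # Finite-difference JVP check: (f(x + eps * v) - f(x)) / eps ≈ J_f(x) v.
    def f(x):
        # hypothetical smooth function standing in for the model
        return torch.sin(x) * x

    x = torch.randn(8, dtype=torch.float64)
    v = torch.randn(8, dtype=torch.float64)
    eps = 1e-6

    fd_jvp = (f(x + eps * v) - f(x)) / eps
    _, exact_jvp = torch.autograd.functional.jvp(f, (x,), (v,))

    assert torch.allclose(fd_jvp, exact_jvp, atol=1e-4)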
# ------------------------------------------------------------------
# Mean-Flow loss (Geng et al., “Mean Flows for One-step Generative
# Modelling”, 2025 – see Alg. 1 + Eq. (6) of the paper)
@@ -970,6 +956,10 @@ class SDTrainer(BaseSDTrainProcess):

if unconditional_embeds is not None:
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()

guidance_embedding_scale = self.train_config.cfg_scale
if self.train_config.do_guidance_loss:
guidance_embedding_scale = self._guidance_loss_target_batch

prior_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
@@ -977,6 +967,7 @@ class SDTrainer(BaseSDTrainProcess):
unconditional_embeddings=unconditional_embeds,
timestep=timesteps,
guidance_scale=self.train_config.cfg_scale,
guidance_embedding_scale=guidance_embedding_scale,
rescale_cfg=self.train_config.cfg_rescale,
batch=batch,
**pred_kwargs  # adapter residuals in here
@@ -1020,13 +1011,16 @@ class SDTrainer(BaseSDTrainProcess):
**kwargs,
):
dtype = get_torch_dtype(self.train_config.dtype)
guidance_embedding_scale = self.train_config.cfg_scale
if self.train_config.do_guidance_loss:
guidance_embedding_scale = self._guidance_loss_target_batch
return self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeddings=unconditional_embeds,
timestep=timesteps,
guidance_scale=self.train_config.cfg_scale,
guidance_embedding_scale=self.train_config.cfg_scale,
guidance_embedding_scale=guidance_embedding_scale,
detach_unconditional=False,
rescale_cfg=self.train_config.cfg_rescale,
bypass_guidance_embedding=self.train_config.bypass_guidance_embedding,
@@ -1034,152 +1028,78 @@ class SDTrainer(BaseSDTrainProcess):
**kwargs
)

def cfm_augment_tensors(
self,
images: torch.Tensor
) -> torch.Tensor:
if self.cfm_cache is None:
# flip the current one. Only need this for first time
self.cfm_cache = torch.flip(images, [3]).clone()
augmented_tensor_list = []
for i in range(images.shape[0]):
# get a random one
idx = random.randint(0, self.cfm_cache.shape[0] - 1)
augmented_tensor_list.append(self.cfm_cache[idx:idx + 1])
augmented = torch.cat(augmented_tensor_list, dim=0)
# resize to match the input
augmented = torch.nn.functional.interpolate(augmented, size=(images.shape[2], images.shape[3]), mode='bilinear')
self.cfm_cache = images.clone()
return augmented

def get_cfm_loss(
self,
noisy_latents: torch.Tensor,
noise: torch.Tensor,
noise_pred: torch.Tensor,
conditional_embeds: PromptEmbeds,
timesteps: torch.Tensor,
batch: 'DataLoaderBatchDTO',
alpha: float = 0.1,
):
dtype = get_torch_dtype(self.train_config.dtype)
if hasattr(self.sd, 'get_loss_target'):
target = self.sd.get_loss_target(
noise=noise,
batch=batch,
timesteps=timesteps,
).detach()

elif self.sd.is_flow_matching:
# forward ODE
target = (noise - batch.latents).detach()
else:
raise ValueError("CFM loss only works with flow matching")
fm_loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
with torch.no_grad():
# we need to compute the contrast
cfm_batch_tensors = self.cfm_augment_tensors(batch.tensor).to(self.device_torch, dtype=dtype)
cfm_latents = self.sd.encode_images(cfm_batch_tensors).to(self.device_torch, dtype=dtype)
cfm_noisy_latents = self.sd.add_noise(
original_samples=cfm_latents,
noise=noise,
timesteps=timesteps,
)
cfm_pred = self.predict_noise(
noisy_latents=cfm_noisy_latents,
timesteps=timesteps,
conditional_embeds=conditional_embeds,
unconditional_embeds=None,
batch=batch,
)

# v_neg = torch.nn.functional.normalize(cfm_pred.float(), dim=1)
# v_pos = torch.nn.functional.normalize(noise_pred.float(), dim=1)  # shape: (B, C, H, W)

# # Compute cosine similarity at each pixel
# sim = (v_pos * v_neg).sum(dim=1)  # shape: (B, H, W)

cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
# Compute cosine similarity at each pixel
sim = cos(cfm_pred.float(), noise_pred.float())  # shape: (B, H, W)

# Average over spatial dimensions, then batch
contrastive_loss = -sim.mean()

loss = fm_loss.mean() + alpha * contrastive_loss
return loss
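For reference, the objective computed above is, in symbols (notation ours, not in the commit):

    L = E[ || v_theta(x_t, c) - (eps - x_0) ||^2 ] + alpha * ( -E[ cos( v_theta(x_t, c), v_theta(x_tilde_t, c) ) ] )

where x_tilde_t is the noised latent of a mismatched (flipped and shuffled) image produced by cfm_augment_tensors, cos is the channel-wise per-pixel cosine similarity, and alpha defaults to 0.1.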
def train_single_accumulation(self, batch: DataLoaderBatchDTO):
self.timer.start('preprocess_batch')
if isinstance(self.adapter, CustomAdapter):
batch = self.adapter.edit_batch_raw(batch)
batch = self.preprocess_batch(batch)
if isinstance(self.adapter, CustomAdapter):
batch = self.adapter.edit_batch_processed(batch)
dtype = get_torch_dtype(self.train_config.dtype)
# sanity check
if self.sd.vae.dtype != self.sd.vae_torch_dtype:
self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype)
if isinstance(self.sd.text_encoder, list):
for encoder in self.sd.text_encoder:
if encoder.dtype != self.sd.te_torch_dtype:
encoder.to(self.sd.te_torch_dtype)
else:
if self.sd.text_encoder.dtype != self.sd.te_torch_dtype:
self.sd.text_encoder.to(self.sd.te_torch_dtype)

noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
if self.train_config.do_cfg or self.train_config.do_random_cfg:
# pick random negative prompts
if self.negative_prompt_pool is not None:
negative_prompts = []
for i in range(noisy_latents.shape[0]):
num_neg = random.randint(1, self.train_config.max_negative_prompts)
this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)]
this_neg_prompt = ', '.join(this_neg_prompts)
negative_prompts.append(this_neg_prompt)
self.batch_negative_prompt = negative_prompts
else:
self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])]

if self.adapter and isinstance(self.adapter, CustomAdapter):
# condition the prompt
# todo handle more than one adapter image
conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)

network_weight_list = batch.get_network_weight_list()
if self.train_config.single_item_batching:
network_weight_list = network_weight_list + network_weight_list

has_adapter_img = batch.control_tensor is not None
has_clip_image = batch.clip_image_tensor is not None
has_clip_image_embeds = batch.clip_image_embeds is not None
# force it to be true if doing regs as we handle those differently
if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]):
has_clip_image = True
if self._clip_image_embeds_unconditional is not None:
has_clip_image_embeds = True  # we are caching embeds, handle that differently
has_clip_image = False

if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
raise ValueError(
"IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")

match_adapter_assist = False

# check if we are matching the adapter assistant
if self.assistant_adapter:
if self.train_config.match_adapter_chance == 1.0:
match_adapter_assist = True
elif self.train_config.match_adapter_chance > 0.0:
match_adapter_assist = torch.rand(
(1,), device=self.device_torch, dtype=dtype
) < self.train_config.match_adapter_chance

self.timer.stop('preprocess_batch')

is_reg = False
with torch.no_grad():
self.timer.start('preprocess_batch')
if isinstance(self.adapter, CustomAdapter):
batch = self.adapter.edit_batch_raw(batch)
batch = self.preprocess_batch(batch)
if isinstance(self.adapter, CustomAdapter):
batch = self.adapter.edit_batch_processed(batch)
dtype = get_torch_dtype(self.train_config.dtype)
# sanity check
if self.sd.vae.dtype != self.sd.vae_torch_dtype:
self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype)
if isinstance(self.sd.text_encoder, list):
for encoder in self.sd.text_encoder:
if encoder.dtype != self.sd.te_torch_dtype:
encoder.to(self.sd.te_torch_dtype)
else:
if self.sd.text_encoder.dtype != self.sd.te_torch_dtype:
self.sd.text_encoder.to(self.sd.te_torch_dtype)

noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
if self.train_config.do_cfg or self.train_config.do_random_cfg:
# pick random negative prompts
if self.negative_prompt_pool is not None:
negative_prompts = []
for i in range(noisy_latents.shape[0]):
num_neg = random.randint(1, self.train_config.max_negative_prompts)
this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)]
this_neg_prompt = ', '.join(this_neg_prompts)
negative_prompts.append(this_neg_prompt)
self.batch_negative_prompt = negative_prompts
else:
self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])]

if self.adapter and isinstance(self.adapter, CustomAdapter):
# condition the prompt
# todo handle more than one adapter image
conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)

network_weight_list = batch.get_network_weight_list()
if self.train_config.single_item_batching:
network_weight_list = network_weight_list + network_weight_list

has_adapter_img = batch.control_tensor is not None
has_clip_image = batch.clip_image_tensor is not None
has_clip_image_embeds = batch.clip_image_embeds is not None
# force it to be true if doing regs as we handle those differently
if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]):
has_clip_image = True
if self._clip_image_embeds_unconditional is not None:
has_clip_image_embeds = True  # we are caching embeds, handle that differently
has_clip_image = False

if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
raise ValueError(
"IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")

match_adapter_assist = False

# check if we are matching the adapter assistant
if self.assistant_adapter:
if self.train_config.match_adapter_chance == 1.0:
match_adapter_assist = True
elif self.train_config.match_adapter_chance > 0.0:
match_adapter_assist = torch.rand(
(1,), device=self.device_torch, dtype=dtype
) < self.train_config.match_adapter_chance

self.timer.stop('preprocess_batch')

is_reg = False
loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
for idx, file_item in enumerate(batch.file_items):
if file_item.is_reg:
@@ -1733,6 +1653,16 @@ class SDTrainer(BaseSDTrainProcess):
)
pred_kwargs['down_block_additional_residuals'] = down_block_res_samples
pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample

if self.train_config.do_guidance_loss and isinstance(self.train_config.guidance_loss_target, list):
batch_size = noisy_latents.shape[0]
# update the guidance value, a random float between guidance_loss_target[0] and guidance_loss_target[1]
self._guidance_loss_target_batch = [
random.uniform(
self.train_config.guidance_loss_target[0],
self.train_config.guidance_loss_target[1]
) for _ in range(batch_size)
]

self.before_unet_predict()

@@ -1832,25 +1762,15 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
prior_to_calculate_loss = None

if self.train_config.loss_type == 'cfm':
loss = self.get_cfm_loss(
noisy_latents=noisy_latents,
noise=noise,
noise_pred=noise_pred,
conditional_embeds=conditional_embeds,
timesteps=timesteps,
batch=batch,
)
else:
loss = self.calculate_loss(
noise_pred=noise_pred,
noise=noise,
noisy_latents=noisy_latents,
timesteps=timesteps,
batch=batch,
mask_multiplier=mask_multiplier,
prior_pred=prior_to_calculate_loss,
)
loss = self.calculate_loss(
noise_pred=noise_pred,
noise=noise,
noisy_latents=noisy_latents,
timesteps=timesteps,
batch=batch,
mask_multiplier=mask_multiplier,
prior_pred=prior_to_calculate_loss,
)

if self.train_config.diff_output_preservation:
# send the loss backwards otherwise checkpointing will fail
info.py
@@ -1,8 +1,9 @@
from collections import OrderedDict
from version import VERSION

v = OrderedDict()
v["name"] = "ai-toolkit"
v["repo"] = "https://github.com/ostris/ai-toolkit"
v["version"] = "0.1.0"
v["version"] = VERSION

software_meta = v

@@ -15,7 +15,6 @@ class BaseJob:
self.config = config['config']
self.raw_config = config
self.job = config['job']
self.torch_profiler = self.get_conf('torch_profiler', False)
self.name = self.get_conf('name', required=True)
if 'meta' in config:
self.meta = config['meta']
@@ -236,6 +236,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.ema: ExponentialMovingAverage = None

validate_configs(self.train_config, self.model_config, self.save_config)

do_profiler = self.get_conf('torch_profiler', False)
self.torch_profiler = None if not do_profiler else torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
)

def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
# override in subclass
@@ -575,10 +583,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
direct_save = False
if self.adapter_config.train_only_image_encoder:
direct_save = True
if self.adapter_config.type == 'redux':
direct_save = True
if self.adapter_config.type in ['control_lora', 'subpixel', 'i2v']:
direct_save = True
elif isinstance(self.adapter, CustomAdapter):
direct_save = self.adapter.do_direct_save
save_ip_adapter_from_diffusers(
state_dict,
output_file=file_path,
@@ -923,7 +929,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
noise = self.get_consistent_noise(latents, batch, dtype=dtype)
else:
if hasattr(self.sd, 'get_latent_noise_from_latents'):
noise = self.sd.get_latent_noise_from_latents(latents).to(self.device_torch, dtype=dtype)
noise = self.sd.get_latent_noise_from_latents(
latents,
noise_offset=self.train_config.noise_offset
).to(self.device_torch, dtype=dtype)
else:
# get noise
noise = self.sd.get_latent_noise(
@@ -933,17 +942,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
batch_size=batch_size,
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype)

# if self.train_config.random_noise_shift > 0.0:
#     # get random noise -1 to 1
#     noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device,
#                              dtype=noise.dtype) * 2 - 1

#     # multiply by shift amount
#     noise_shift *= self.train_config.random_noise_shift

#     # add to noise
#     noise += noise_shift

if self.train_config.blended_blur_noise:
noise = get_blended_blur_noise(
@@ -1014,7 +1012,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
dtype = get_torch_dtype(self.train_config.dtype)
imgs = None
is_reg = any(batch.get_is_reg_list())
cfm_batch = None
if batch.tensor is not None:
imgs = batch.tensor
imgs = imgs.to(self.device_torch, dtype=dtype)
@@ -1087,19 +1084,20 @@ class BaseSDTrainProcess(BaseTrainProcess):
# we determine noise from the differential of the latents
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)

batch_size = len(batch.file_items)
min_noise_steps = self.train_config.min_denoising_steps
max_noise_steps = self.train_config.max_denoising_steps
if self.model_config.refiner_name_or_path is not None:
# if we are not training the unet, then we are only doing refiner and do not need to double up
if self.train_config.train_unet:
max_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
do_double = True
else:
min_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
do_double = False
with self.timer('prepare_scheduler'):

batch_size = len(batch.file_items)
min_noise_steps = self.train_config.min_denoising_steps
max_noise_steps = self.train_config.max_denoising_steps
if self.model_config.refiner_name_or_path is not None:
# if we are not training the unet, then we are only doing refiner and do not need to double up
if self.train_config.train_unet:
max_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
do_double = True
else:
min_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
do_double = False

with self.timer('prepare_noise'):
num_train_timesteps = self.train_config.num_train_timesteps

if self.train_config.noise_scheduler in ['custom_lcm']:
@@ -1146,6 +1144,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.sd.noise_scheduler.set_timesteps(
num_train_timesteps, device=self.device_torch
)
with self.timer('prepare_timesteps_indices'):

content_or_style = self.train_config.content_or_style
if is_reg:
@@ -1195,20 +1194,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
else:
# todo, some schedulers use indices, others use timesteps. Not sure what to do here
min_idx = min_noise_steps + 1
max_idx = max_noise_steps - 1
if self.train_config.noise_scheduler == 'flowmatch':
# flowmatch uses indices, so we need to use indices
min_idx = 0
max_idx = max_noise_steps - 1
timestep_indices = torch.randint(
min_noise_steps + 1,
max_noise_steps - 1,
min_idx,
max_idx,
(batch_size,),
device=self.device_torch
)
timestep_indices = timestep_indices.long()
else:
raise ValueError(f"Unknown content_or_style {content_or_style}")

with self.timer('convert_timestep_indices_to_timesteps'):
# convert the timestep_indices to a timestep
timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
timesteps = torch.stack(timesteps, dim=0)

timesteps = self.sd.noise_scheduler.timesteps[timestep_indices.long()]

with self.timer('prepare_noise'):
# get noise
noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch, timestep=timesteps)

@@ -1242,6 +1247,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
device=noise.device,
dtype=noise.dtype
) * self.train_config.random_noise_multiplier

with self.timer('make_noisy_latents'):

noise = noise * noise_multiplier

@@ -2058,6 +2065,8 @@ class BaseSDTrainProcess(BaseTrainProcess):

# flush()
### HOOK ###
if self.torch_profiler is not None:
self.torch_profiler.start()
with self.accelerator.accumulate(self.modules_being_trained):
try:
loss_dict = self.hook_train_loop(batch_list)
@@ -2069,7 +2078,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
for item in batch.file_items:
print(f" - {item.path}")
raise e

if self.torch_profiler is not None:
torch.cuda.synchronize()  # Make sure all CUDA ops are done
self.torch_profiler.stop()

print("\n==== Profile Results ====")
print(self.torch_profiler.key_averages().table(sort_by="cpu_time_total", row_limit=1000))
self.timer.stop('train_loop')
if not did_first_flush:
flush()

@@ -113,6 +113,8 @@ class GenerateProcess(BaseProcess):
prompt_image_configs = []
for _ in range(self.generate_config.num_repeats):
for prompt in self.generate_config.prompts:
# remove --
prompt = prompt.replace('--', '').strip()
width = self.generate_config.width
height = self.generate_config.height
# prompt = self.clean_prompt(prompt)

@@ -1,6 +1,6 @@
import os
import time
from typing import List, Optional, Literal, Union, TYPE_CHECKING, Dict
from typing import List, Optional, Literal, Tuple, Union, TYPE_CHECKING, Dict
import random

import torch
@@ -413,7 +413,7 @@ class TrainConfig:
self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)

self.loss_type = kwargs.get('loss_type', 'mse')  # mse, mae, wavelet, pixelspace, cfm, mean_flow
self.loss_type = kwargs.get('loss_type', 'mse')  # mse, mae, wavelet, pixelspace, mean_flow

# scale the prediction by this. Increase for more detail, decrease for less
self.pred_scaler = kwargs.get('pred_scaler', 1.0)
@@ -454,8 +454,12 @@ class TrainConfig:
self.bypass_guidance_embedding = kwargs.get('bypass_guidance_embedding', False)

# diffusion feature extractor
self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', None)
self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', 1.0)
self.latent_feature_extractor_path = kwargs.get('latent_feature_extractor_path', None)
self.latent_feature_loss_weight = kwargs.get('latent_feature_loss_weight', 1.0)

# we use this in the code, but it really needs to be called latent_feature_extractor as that makes more sense with the new architecture
self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', self.latent_feature_extractor_path)
self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', self.latent_feature_loss_weight)

# optimal noise pairing
self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1)
@@ -463,6 +467,13 @@ class TrainConfig:
# forces same noise for the same image at a given size.
self.force_consistent_noise = kwargs.get('force_consistent_noise', False)
self.blended_blur_noise = kwargs.get('blended_blur_noise', False)

# guidance loss
self.do_guidance_loss = kwargs.get('do_guidance_loss', False)
self.guidance_loss_target: Union[float, List[float]] = kwargs.get('guidance_loss_target', 3.0)
self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '')
if isinstance(self.guidance_loss_target, tuple):
self.guidance_loss_target = list(self.guidance_loss_target)
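The new guidance-loss kwargs above accept either a single target or a [min, max] range; with a range, sd_trainer resamples one target per batch element each step (see the random.uniform block earlier in this commit). A hypothetical usage fragment illustrating both forms (key names taken from the kwargs above, values invented):

    train_kwargs = dict(
        do_guidance_loss=True,
        # a fixed guidance strength:
        # guidance_loss_target=3.0,
        # or a [min, max] range, resampled per batch element:
        guidance_loss_target=[1.0, 5.0],
        unconditional_prompt='',
    )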

ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex1', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21']
@@ -827,6 +838,9 @@ class DatasetConfig:
self.controls = [self.controls]
# remove empty strings
self.controls = [control for control in self.controls if control.strip() != '']

# if true, will use a fast method to get image sizes. This can result in errors. Do not use unless you know what you are doing
self.fast_image_size: bool = kwargs.get('fast_image_size', False)


def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
@@ -1141,3 +1155,6 @@ def validate_configs(
if model_config.use_flux_cfg:
# bypass the embedding
train_config.bypass_guidance_embedding = True
if train_config.bypass_guidance_embedding and train_config.do_guidance_loss:
raise ValueError("Cannot bypass guidance embedding and do guidance loss at the same time. "
"Please set bypass_guidance_embedding to False or do_guidance_loss to False.")
@@ -11,6 +11,7 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.models.clip_fusion import CLIPFusionModule
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.models.control_lora_adapter import ControlLoraAdapter
from toolkit.models.mean_flow_adapter import MeanFlowAdapter
from toolkit.models.i2v_adapter import I2VAdapter
from toolkit.models.subpixel_adapter import SubpixelAdapter
from toolkit.models.ilora import InstantLoRAModule
@@ -98,6 +99,7 @@ class CustomAdapter(torch.nn.Module):
self.single_value_adapter: SingleValueAdapter = None
self.redux_adapter: ReduxImageEncoder = None
self.control_lora: ControlLoraAdapter = None
self.mean_flow_adapter: MeanFlowAdapter = None
self.subpixel_adapter: SubpixelAdapter = None
self.i2v_adapter: I2VAdapter = None

@@ -125,6 +127,16 @@ class CustomAdapter(torch.nn.Module):
dtype=self.sd_ref().dtype,
)
self.load_state_dict(loaded_state_dict, strict=False)

@property
def do_direct_save(self):
# some adapters save their weights directly, others like ip adapters split the state dict
if self.config.train_only_image_encoder:
return True
if self.config.type in ['control_lora', 'subpixel', 'i2v', 'redux', 'mean_flow']:
return True
return False


def setup_adapter(self):
torch_dtype = get_torch_dtype(self.sd_ref().dtype)
@@ -245,6 +257,13 @@ class CustomAdapter(torch.nn.Module):
elif self.adapter_type == 'redux':
vision_hidden_size = self.vision_encoder.config.hidden_size
self.redux_adapter = ReduxImageEncoder(vision_hidden_size, 4096, self.device, torch_dtype)
elif self.adapter_type == 'mean_flow':
self.mean_flow_adapter = MeanFlowAdapter(
self,
sd=self.sd_ref(),
config=self.config,
train_config=self.train_config
)
elif self.adapter_type == 'control_lora':
self.control_lora = ControlLoraAdapter(
self,
@@ -309,7 +328,7 @@ class CustomAdapter(torch.nn.Module):
def setup_clip(self):
adapter_config = self.config
sd = self.sd_ref()
if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora", "subpixel"]:
if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora", "subpixel", "mean_flow"]:
return
if self.config.type == 'photo_maker':
try:
@@ -528,6 +547,14 @@ class CustomAdapter(torch.nn.Module):
new_dict[k + '.' + k2] = v2
self.control_lora.load_weights(new_dict, strict=strict)

if self.adapter_type == 'mean_flow':
# state dict is separated, so recombine it
new_dict = {}
for k, v in state_dict.items():
for k2, v2 in v.items():
new_dict[k + '.' + k2] = v2
self.mean_flow_adapter.load_weights(new_dict, strict=strict)

if self.adapter_type == 'i2v':
# state dict is separated, so recombine it
new_dict = {}
@@ -599,6 +626,11 @@ class CustomAdapter(torch.nn.Module):
for k, v in d.items():
state_dict[k] = v
return state_dict
elif self.adapter_type == 'mean_flow':
d = self.mean_flow_adapter.get_state_dict()
for k, v in d.items():
state_dict[k] = v
return state_dict
elif self.adapter_type == 'i2v':
d = self.i2v_adapter.get_state_dict()
for k, v in d.items():
@@ -757,7 +789,7 @@ class CustomAdapter(torch.nn.Module):
prompt: Union[List[str], str],
is_unconditional: bool = False,
):
if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora', 'subpixel', 'i2v']:
if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora', 'subpixel', 'i2v', 'mean_flow']:
return prompt
elif self.adapter_type == 'text_encoder':
# todo allow for training
@@ -1319,6 +1351,10 @@ class CustomAdapter(torch.nn.Module):
param_list = self.control_lora.get_params()
for param in param_list:
yield param
elif self.config.type == 'mean_flow':
param_list = self.mean_flow_adapter.get_params()
for param in param_list:
yield param
elif self.config.type == 'i2v':
param_list = self.i2v_adapter.get_params()
for param in param_list:
@@ -84,15 +84,16 @@ class FileItemDTO(
video.release()
size_database[file_key] = (width, height, file_signature)
else:
# original method is significantly faster, but some images are read sideways. Not sure why. Do slow method for now.
# process width and height
# try:
#     w, h = image_utils.get_image_size(self.path)
# except image_utils.UnknownImageFormat:
#     print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
#               f'This process is faster for png, jpeg')
img = exif_transpose(Image.open(self.path))
w, h = img.size
if self.dataset_config.fast_image_size:
# original method is significantly faster, but some images are read sideways. Not sure why. Do slow method by default.
try:
w, h = image_utils.get_image_size(self.path)
except image_utils.UnknownImageFormat:
print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
f'This process is faster for png, jpeg')
else:
img = exif_transpose(Image.open(self.path))
w, h = img.size
size_database[file_key] = (w, h, file_signature)
self.width: int = w
self.height: int = h
@@ -815,7 +815,10 @@ class BaseModel:

# predict the noise residual
if self.unet.device != self.device_torch:
self.unet.to(self.device_torch)
try:
self.unet.to(self.device_torch)
except Exception as e:
pass
if self.unet.dtype != self.torch_dtype:
self.unet = self.unet.to(dtype=self.torch_dtype)

@@ -135,7 +135,7 @@ class ControlLoraAdapter(torch.nn.Module):

network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs
if hasattr(sd, 'target_lora_modules'):
network_kwargs['target_lin_modules'] = self.sd.target_lora_modules
network_kwargs['target_lin_modules'] = sd.target_lora_modules

if 'ignore_if_contains' not in network_kwargs:
network_kwargs['ignore_if_contains'] = []
@@ -127,18 +127,20 @@ class DFEBlock(nn.Module):
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
self.act = nn.GELU()
self.proj = nn.Conv2d(channels, channels, 1)

def forward(self, x):
x_in = x
x = self.conv1(x)
x = self.conv2(x)
x = self.act(x)
x = self.proj(x)
x = x + x_in
return x


class DiffusionFeatureExtractor(nn.Module):
def __init__(self, in_channels=32):
def __init__(self, in_channels=16):
super().__init__()
self.version = 1
num_blocks = 6
@@ -176,60 +176,3 @@ def add_model_gpu_splitter_to_flux(
transformer._pre_gpu_split_to = transformer.to
transformer.to = partial(new_device_to, transformer)


def mean_flow_time_text_embed_forward(self: CombinedTimestepTextProjEmbeddings, timestep, pooled_projection):
# make zero timestep ending if none is passed
if timestep.shape[0] == pooled_projection.shape[0]:
timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0)  # timestep - 0 (final timestep) == same as start timestep

timesteps_proj = self.time_proj(timestep)
timesteps_emb_combo = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)

timesteps_emb_start, timesteps_emb_end = timesteps_emb_combo.chunk(2, dim=0)

timesteps_emb = timesteps_emb_start + timesteps_emb_end

pooled_projections = self.text_embedder(pooled_projection)

conditioning = timesteps_emb + pooled_projections

return conditioning

def mean_flow_time_text_guidance_embed_forward(self: CombinedTimestepGuidanceTextProjEmbeddings, timestep, guidance, pooled_projection):
# make zero timestep ending if none is passed
if timestep.shape[0] == pooled_projection.shape[0]:
timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0)  # timestep - 0 (final timestep) == same as start timestep
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)

guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))  # (N, D)

timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)

time_guidance_emb = timesteps_emb_start + timesteps_emb_end + guidance_emb

pooled_projections = self.text_embedder(pooled_projection)
conditioning = time_guidance_emb + pooled_projections

return conditioning


def convert_flux_to_mean_flow(
transformer: FluxTransformer2DModel,
):
if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings):
transformer.time_text_embed.forward = partial(
mean_flow_time_text_embed_forward, transformer.time_text_embed
)
elif isinstance(transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings):
transformer.time_text_embed.forward = partial(
mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed
)
else:
raise ValueError(
"Unsupported time_text_embed type: {}".format(
type(transformer.time_text_embed)
)
)
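The patched forwards above assume the trainer passes the interval start and end timesteps concatenated along the batch dimension, as get_mean_flow_loss_wip does with torch.cat([time_end, time_origin]) * 1000. A minimal sketch of that convention (hypothetical batch size, not part of the commit):

    import torch

    # Doubled-timestep convention used by the patched time_text_embed.forward:
    # the first half of the tensor is the interval end t, the second half is the
    # interval origin r; the embedder chunks them apart and sums the two embeddings.
    B = 2
    t_end = torch.rand(B)             # current timestep in [0, 1]
    t_origin = torch.rand(B) * t_end  # random earlier timestep toward 0

    timestep = torch.cat([t_end, t_origin], dim=0) * 1000  # what the trainer passes
    emb_in_start, emb_in_end = timestep.chunk(2, dim=0)    # what the patched forward recovers
    assert emb_in_start.shape[0] == B and emb_in_end.shape[0] == B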
@@ -1,567 +0,0 @@
# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.utils import logging
from diffusers.models.attention import LuminaFeedForward
from diffusers.models.attention_processor import Attention
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
import torch
from torch.profiler import profile, record_function, ProfilerActivity

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

do_profile = False


class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
def __init__(
self,
hidden_size: int = 4096,
cap_feat_dim: int = 2048,
frequency_embedding_size: int = 256,
norm_eps: float = 1e-5,
) -> None:
super().__init__()

self.time_proj = Timesteps(
num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
)

self.timestep_embedder = TimestepEmbedding(
in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
)

self.caption_embedder = nn.Sequential(
RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, hidden_size, bias=True)
)

def forward(
self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
timestep_proj = self.time_proj(timestep).type_as(hidden_states)
time_embed = self.timestep_embedder(timestep_proj)
caption_embed = self.caption_embedder(encoder_hidden_states)
return time_embed, caption_embed


class Lumina2AttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the Lumina2Transformer2DModel model. It applies normalization and RoPE on query and key vectors.
"""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
base_sequence_length: Optional[int] = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape

# Get Query-Key-Value Pair
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

query_dim = query.shape[-1]
inner_dim = key.shape[-1]
head_dim = query_dim // attn.heads
dtype = query.dtype

# Get key-value heads
kv_heads = inner_dim // head_dim

query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, kv_heads, head_dim)
value = value.view(batch_size, -1, kv_heads, head_dim)

# Apply Query-Key Norm if needed
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
key = apply_rotary_emb(key, image_rotary_emb, use_real=False)

query, key = query.to(dtype), key.to(dtype)

# Apply proportional attention if true
if base_sequence_length is not None:
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
else:
softmax_scale = attn.scale

# perform Grouped-query Attention (GQA)
n_rep = attn.heads // kv_heads
if n_rep >= 1:
key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)

query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)

hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, scale=softmax_scale
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.type_as(query)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states


class Lumina2TransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
num_kv_heads: int,
multiple_of: int,
ffn_dim_multiplier: float,
norm_eps: float,
modulation: bool = True,
) -> None:
super().__init__()
self.head_dim = dim // num_attention_heads
self.modulation = modulation

self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=dim // num_attention_heads,
qk_norm="rms_norm",
heads=num_attention_heads,
kv_heads=num_kv_heads,
eps=1e-5,
bias=False,
out_bias=False,
processor=Lumina2AttnProcessor2_0(),
)

self.feed_forward = LuminaFeedForward(
dim=dim,
inner_dim=4 * dim,
multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier,
)

if modulation:
self.norm1 = LuminaRMSNormZero(
embedding_dim=dim,
norm_eps=norm_eps,
norm_elementwise_affine=True,
)
else:
self.norm1 = RMSNorm(dim, eps=norm_eps)
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)

self.norm2 = RMSNorm(dim, eps=norm_eps)
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
image_rotary_emb: torch.Tensor,
temb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.modulation:
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
else:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
hidden_states = hidden_states + self.ffn_norm2(mlp_output)

return hidden_states


class Lumina2RotaryPosEmbed(nn.Module):
def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512), patch_size: int = 2):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
self.axes_lens = axes_lens
self.patch_size = patch_size

self.freqs_cis = self._precompute_freqs_cis(axes_dim, axes_lens, theta)

def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
freqs_cis = []
for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=torch.float64)
freqs_cis.append(emb)
return freqs_cis

def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
result = []
for i in range(len(self.axes_dim)):
freqs = self.freqs_cis[i].to(ids.device)
index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
return torch.cat(result, dim=-1)

def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
batch_size = len(hidden_states)
p_h = p_w = self.patch_size
device = hidden_states[0].device

l_effective_cap_len = attention_mask.sum(dim=1).tolist()
# TODO: this should probably be refactored because all subtensors of hidden_states will be of same shape
img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
l_effective_img_len = [(H // p_h) * (W // p_w) for (H, W) in img_sizes]

max_seq_len = max((cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)))
max_img_len = max(l_effective_img_len)

position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)

for i in range(batch_size):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
H, W = img_sizes[i]
H_tokens, W_tokens = H // p_h, W // p_w
assert H_tokens * W_tokens == img_len

position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
position_ids[i, cap_len : cap_len + img_len, 0] = cap_len
row_ids = (
torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
)
col_ids = (
torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
)
position_ids[i, cap_len : cap_len + img_len, 1] = row_ids
position_ids[i, cap_len : cap_len + img_len, 2] = col_ids

freqs_cis = self._get_freqs_cis(position_ids)

cap_freqs_cis_shape = list(freqs_cis.shape)
cap_freqs_cis_shape[1] = attention_mask.shape[1]
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
img_freqs_cis_shape = list(freqs_cis.shape)
|
||||
img_freqs_cis_shape[1] = max_img_len
|
||||
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
for i in range(batch_size):
|
||||
cap_len = l_effective_cap_len[i]
|
||||
img_len = l_effective_img_len[i]
|
||||
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
|
||||
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len]
|
||||
|
||||
flat_hidden_states = []
|
||||
for i in range(batch_size):
|
||||
img = hidden_states[i]
|
||||
C, H, W = img.size()
|
||||
img = img.view(C, H // p_h, p_h, W // p_w, p_w).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
|
||||
flat_hidden_states.append(img)
|
||||
hidden_states = flat_hidden_states
|
||||
padded_img_embed = torch.zeros(
|
||||
batch_size, max_img_len, hidden_states[0].shape[-1], device=device, dtype=hidden_states[0].dtype
|
||||
)
|
||||
padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
|
||||
for i in range(batch_size):
|
||||
padded_img_embed[i, : l_effective_img_len[i]] = hidden_states[i]
|
||||
padded_img_mask[i, : l_effective_img_len[i]] = True
|
||||
|
||||
return (
|
||||
padded_img_embed,
|
||||
padded_img_mask,
|
||||
img_sizes,
|
||||
l_effective_cap_len,
|
||||
l_effective_img_len,
|
||||
freqs_cis,
|
||||
cap_freqs_cis,
|
||||
img_freqs_cis,
|
||||
max_seq_len,
|
||||
)
|
||||
|
||||
|
||||
class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
r"""
|
||||
Lumina2NextDiT: Diffusion model with a Transformer backbone.
|
||||
|
||||
Parameters:
|
||||
sample_size (`int`): The width of the latent images. This is fixed during training since
|
||||
it is used to learn a number of position embeddings.
|
||||
patch_size (`int`, *optional*, (`int`, *optional*, defaults to 2):
|
||||
The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
|
||||
in_channels (`int`, *optional*, defaults to 4):
|
||||
The number of input channels for the model. Typically, this matches the number of channels in the input
|
||||
images.
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
|
||||
hidden representations.
|
||||
num_layers (`int`, *optional*, default to 32):
|
||||
The number of layers in the model. This defines the depth of the neural network.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
The number of attention heads in each attention layer. This parameter specifies how many separate attention
|
||||
mechanisms are used.
|
||||
num_kv_heads (`int`, *optional*, defaults to 8):
|
||||
The number of key-value heads in the attention mechanism, if different from the number of attention heads.
|
||||
If None, it defaults to num_attention_heads.
|
||||
multiple_of (`int`, *optional*, defaults to 256):
|
||||
A factor that the hidden size should be a multiple of. This can help optimize certain hardware
|
||||
configurations.
|
||||
ffn_dim_multiplier (`float`, *optional*):
|
||||
A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
|
||||
the model configuration.
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5):
|
||||
A small value added to the denominator for numerical stability in normalization layers.
|
||||
scaling_factor (`float`, *optional*, defaults to 1.0):
|
||||
A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
|
||||
overall scale of the model's operations.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Lumina2TransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["x_embedder", "norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: int = 128,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 16,
|
||||
out_channels: Optional[int] = None,
|
||||
hidden_size: int = 2304,
|
||||
num_layers: int = 26,
|
||||
num_refiner_layers: int = 2,
|
||||
num_attention_heads: int = 24,
|
||||
num_kv_heads: int = 8,
|
||||
multiple_of: int = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
norm_eps: float = 1e-5,
|
||||
scaling_factor: float = 1.0,
|
||||
axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
|
||||
axes_lens: Tuple[int, int, int] = (300, 512, 512),
|
||||
cap_feat_dim: int = 1024,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.out_channels = out_channels or in_channels
|
||||
|
||||
# 1. Positional, patch & conditional embeddings
|
||||
self.rope_embedder = Lumina2RotaryPosEmbed(
|
||||
theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, patch_size=patch_size
|
||||
)
|
||||
|
||||
self.x_embedder = nn.Linear(in_features=patch_size * patch_size * in_channels, out_features=hidden_size)
|
||||
|
||||
self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
|
||||
hidden_size=hidden_size, cap_feat_dim=cap_feat_dim, norm_eps=norm_eps
|
||||
)
|
||||
|
||||
# 2. Noise and context refinement blocks
|
||||
self.noise_refiner = nn.ModuleList(
|
||||
[
|
||||
Lumina2TransformerBlock(
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
num_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
modulation=True,
|
||||
)
|
||||
for _ in range(num_refiner_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.context_refiner = nn.ModuleList(
|
||||
[
|
||||
Lumina2TransformerBlock(
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
num_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
modulation=False,
|
||||
)
|
||||
for _ in range(num_refiner_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 3. Transformer blocks
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
Lumina2TransformerBlock(
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
num_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
modulation=True,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Output norm & projection
|
||||
self.norm_out = LuminaLayerNormContinuous(
|
||||
embedding_dim=hidden_size,
|
||||
conditioning_embedding_dim=min(hidden_size, 1024),
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
bias=True,
|
||||
out_dim=patch_size * patch_size * self.out_channels,
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
|
||||
hidden_size = self.config.get("hidden_size", 2304)
|
||||
# pad or slice text encoder
|
||||
if encoder_hidden_states.shape[2] > hidden_size:
|
||||
encoder_hidden_states = encoder_hidden_states[:, :, :hidden_size]
|
||||
elif encoder_hidden_states.shape[2] < hidden_size:
|
||||
encoder_hidden_states = F.pad(encoder_hidden_states, (0, hidden_size - encoder_hidden_states.shape[2]))
|
||||
|
||||
batch_size = hidden_states.size(0)
|
||||
|
||||
if do_profile:
|
||||
prof = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
)
|
||||
|
||||
prof.start()
|
||||
|
||||
# 1. Condition, positional & patch embedding
|
||||
temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
|
||||
|
||||
(
|
||||
hidden_states,
|
||||
hidden_mask,
|
||||
hidden_sizes,
|
||||
encoder_hidden_len,
|
||||
hidden_len,
|
||||
joint_rotary_emb,
|
||||
encoder_rotary_emb,
|
||||
hidden_rotary_emb,
|
||||
max_seq_len,
|
||||
) = self.rope_embedder(hidden_states, attention_mask)
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
# 2. Context & noise refinement
|
||||
for layer in self.context_refiner:
|
||||
encoder_hidden_states = layer(encoder_hidden_states, attention_mask, encoder_rotary_emb)
|
||||
|
||||
for layer in self.noise_refiner:
|
||||
hidden_states = layer(hidden_states, hidden_mask, hidden_rotary_emb, temb)
|
||||
|
||||
# 3. Attention mask preparation
|
||||
mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
|
||||
padded_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
|
||||
for i in range(batch_size):
|
||||
cap_len = encoder_hidden_len[i]
|
||||
img_len = hidden_len[i]
|
||||
mask[i, : cap_len + img_len] = True
|
||||
padded_hidden_states[i, :cap_len] = encoder_hidden_states[i, :cap_len]
|
||||
padded_hidden_states[i, cap_len : cap_len + img_len] = hidden_states[i, :img_len]
|
||||
hidden_states = padded_hidden_states
|
||||
|
||||
# 4. Transformer blocks
|
||||
for layer in self.layers:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(layer, hidden_states, mask, joint_rotary_emb, temb)
|
||||
else:
|
||||
hidden_states = layer(hidden_states, mask, joint_rotary_emb, temb)
|
||||
|
||||
# 5. Output norm & projection & unpatchify
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
|
||||
height_tokens = width_tokens = self.config.patch_size
|
||||
output = []
|
||||
for i in range(len(hidden_sizes)):
|
||||
height, width = hidden_sizes[i]
|
||||
begin = encoder_hidden_len[i]
|
||||
end = begin + (height // height_tokens) * (width // width_tokens)
|
||||
output.append(
|
||||
hidden_states[i][begin:end]
|
||||
.view(height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels)
|
||||
.permute(4, 0, 2, 1, 3)
|
||||
.flatten(3, 4)
|
||||
.flatten(1, 2)
|
||||
)
|
||||
output = torch.stack(output, dim=0)
|
||||
|
||||
if do_profile:
|
||||
torch.cuda.synchronize() # Make sure all CUDA ops are done
|
||||
prof.stop()
|
||||
|
||||
print("\n==== Profile Results ====")
|
||||
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=1000))
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
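The grouped-query attention step above widens the key/value tensors by plain repetition, so each of the n_rep query heads in a group attends over a copy of the same KV head. A minimal standalone sketch of that reshape, with small illustrative shapes (all names below are hypothetical, not taken from this file):

import torch

batch, seq, heads, kv_heads, head_dim = 2, 8, 8, 2, 16
n_rep = heads // kv_heads  # query heads served by each KV head

key = torch.randn(batch, seq, kv_heads, head_dim)

# insert a repeat axis after the KV-head axis, tile it n_rep times,
# then fold it back into the head axis -- the same trick the processor uses
key_expanded = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

print(key_expanded.shape)  # torch.Size([2, 8, 8, 16]), now matching attn.heads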
282  toolkit/models/mean_flow_adapter.py  Normal file
@@ -0,0 +1,282 @@
import inspect
import weakref
import torch
from typing import TYPE_CHECKING
from toolkit.lora_special import LoRASpecialNetwork
from diffusers import FluxTransformer2DModel
from diffusers.models.embeddings import (
    CombinedTimestepTextProjEmbeddings,
    CombinedTimestepGuidanceTextProjEmbeddings,
)
from functools import partial


if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion
    from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
    from toolkit.custom_adapter import CustomAdapter


def mean_flow_time_text_embed_forward(
    self: CombinedTimestepTextProjEmbeddings, timestep, pooled_projection
):
    mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref()
    # append a zero end timestep if none is passed
    if mean_flow_adapter.is_active and timestep.shape[0] == pooled_projection.shape[0]:
        timestep = torch.cat(
            [timestep, torch.zeros_like(timestep)], dim=0
        )  # timestep - 0 (final timestep) == same as start timestep

    timesteps_proj = self.time_proj(timestep)
    timesteps_emb = self.timestep_embedder(
        timesteps_proj.to(dtype=pooled_projection.dtype)
    )  # (N, D)

    # mean flow stuff
    if mean_flow_adapter.is_active:
        # todo make sure that timesteps is batched correctly, I think diffusers expects non batched timesteps
        orig_dtype = timesteps_emb.dtype
        timesteps_emb = timesteps_emb.to(torch.float32)
        timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)
        timesteps_emb = mean_flow_adapter.mean_flow_timestep_embedder(
            torch.cat([timesteps_emb_start, timesteps_emb_end], dim=-1)
        )
        timesteps_emb = timesteps_emb.to(orig_dtype)

    pooled_projections = self.text_embedder(pooled_projection)

    conditioning = timesteps_emb + pooled_projections

    return conditioning


def mean_flow_time_text_guidance_embed_forward(
    self: CombinedTimestepGuidanceTextProjEmbeddings,
    timestep,
    guidance,
    pooled_projection,
):
    mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref()
    # append a zero end timestep if none is passed
    if mean_flow_adapter.is_active and timestep.shape[0] == pooled_projection.shape[0]:
        timestep = torch.cat(
            [timestep, torch.zeros_like(timestep)], dim=0
        )  # timestep - 0 (final timestep) == same as start timestep
    timesteps_proj = self.time_proj(timestep)
    timesteps_emb = self.timestep_embedder(
        timesteps_proj.to(dtype=pooled_projection.dtype)
    )  # (N, D)

    guidance_proj = self.time_proj(guidance)
    guidance_emb = self.guidance_embedder(
        guidance_proj.to(dtype=pooled_projection.dtype)
    )  # (N, D)

    # mean flow stuff
    if mean_flow_adapter.is_active:
        # todo make sure that timesteps is batched correctly, I think diffusers expects non batched timesteps
        orig_dtype = timesteps_emb.dtype
        timesteps_emb = timesteps_emb.to(torch.float32)
        timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)
        timesteps_emb = mean_flow_adapter.mean_flow_timestep_embedder(
            torch.cat([timesteps_emb_start, timesteps_emb_end], dim=-1)
        )
        timesteps_emb = timesteps_emb.to(orig_dtype)

    time_guidance_emb = timesteps_emb + guidance_emb

    pooled_projections = self.text_embedder(pooled_projection)
    conditioning = time_guidance_emb + pooled_projections

    return conditioning


def convert_flux_to_mean_flow(
    transformer: FluxTransformer2DModel,
):
    if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings):
        transformer.time_text_embed.forward = partial(
            mean_flow_time_text_embed_forward, transformer.time_text_embed
        )
    elif isinstance(
        transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings
    ):
        transformer.time_text_embed.forward = partial(
            mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed
        )
    else:
        raise ValueError(
            "Unsupported time_text_embed type: {}".format(
                type(transformer.time_text_embed)
            )
        )


class MeanFlowAdapter(torch.nn.Module):
    def __init__(
        self,
        adapter: "CustomAdapter",
        sd: "StableDiffusion",
        config: "AdapterConfig",
        train_config: "TrainConfig",
    ):
        super().__init__()
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.sd_ref = weakref.ref(sd)
        self.model_config: ModelConfig = sd.model_config
        self.network_config = config.lora_config
        self.train_config = train_config
        self.device_torch = sd.device_torch
        self.lora = None

        if self.network_config is not None:
            network_kwargs = (
                {}
                if self.network_config.network_kwargs is None
                else self.network_config.network_kwargs
            )
            if hasattr(sd, "target_lora_modules"):
                network_kwargs["target_lin_modules"] = sd.target_lora_modules

            if "ignore_if_contains" not in network_kwargs:
                network_kwargs["ignore_if_contains"] = []

            self.lora = LoRASpecialNetwork(
                text_encoder=sd.text_encoder,
                unet=sd.unet,
                lora_dim=self.network_config.linear,
                multiplier=1.0,
                alpha=self.network_config.linear_alpha,
                train_unet=self.train_config.train_unet,
                train_text_encoder=self.train_config.train_text_encoder,
                conv_lora_dim=self.network_config.conv,
                conv_alpha=self.network_config.conv_alpha,
                is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
                is_v2=self.model_config.is_v2,
                is_v3=self.model_config.is_v3,
                is_pixart=self.model_config.is_pixart,
                is_auraflow=self.model_config.is_auraflow,
                is_flux=self.model_config.is_flux,
                is_lumina2=self.model_config.is_lumina2,
                is_ssd=self.model_config.is_ssd,
                is_vega=self.model_config.is_vega,
                dropout=self.network_config.dropout,
                use_text_encoder_1=self.model_config.use_text_encoder_1,
                use_text_encoder_2=self.model_config.use_text_encoder_2,
                use_bias=False,
                is_lorm=False,
                network_config=self.network_config,
                network_type=self.network_config.type,
                transformer_only=self.network_config.transformer_only,
                is_transformer=sd.is_transformer,
                base_model=sd,
                **network_kwargs,
            )
            self.lora.force_to(self.device_torch, dtype=torch.float32)
            self.lora._update_torch_multiplier()
            self.lora.apply_to(
                sd.text_encoder,
                sd.unet,
                self.train_config.train_text_encoder,
                self.train_config.train_unet,
            )
            self.lora.can_merge_in = False
            self.lora.prepare_grad_etc(sd.text_encoder, sd.unet)
            if self.train_config.gradient_checkpointing:
                self.lora.enable_gradient_checkpointing()

        emb_dim = None
        if self.model_config.arch in ["flux", "flex2"]:
            transformer: FluxTransformer2DModel = sd.unet
            emb_dim = (
                transformer.config.num_attention_heads
                * transformer.config.attention_head_dim
            )
            convert_flux_to_mean_flow(transformer)
        else:
            raise ValueError(f"Unsupported architecture: {self.model_config.arch}")

        self.mean_flow_timestep_embedder = torch.nn.Linear(
            emb_dim * 2,
            emb_dim,
        )

        # initialize the weights so the model functions exactly as it did before adding this adapter
        with torch.no_grad():
            self.mean_flow_timestep_embedder.weight.zero_()
            self.mean_flow_timestep_embedder.weight[:, :emb_dim] = torch.eye(emb_dim)
            self.mean_flow_timestep_embedder.bias.zero_()

        self.mean_flow_timestep_embedder.to(self.device_torch)

        # add our adapter as a weak ref
        if self.model_config.arch in ["flux", "flex2"]:
            sd.unet.time_text_embed.mean_flow_adapter_ref = weakref.ref(self)

    def get_params(self):
        if self.lora is not None:
            config = {
                "text_encoder_lr": self.train_config.lr,
                "unet_lr": self.train_config.lr,
            }
            sig = inspect.signature(self.lora.prepare_optimizer_params)
            if "default_lr" in sig.parameters:
                config["default_lr"] = self.train_config.lr
            if "learning_rate" in sig.parameters:
                config["learning_rate"] = self.train_config.lr
            params_net = self.lora.prepare_optimizer_params(**config)

            # we want only tensors here
            params = []
            for p in params_net:
                if isinstance(p, dict):
                    params += p["params"]
                elif isinstance(p, torch.Tensor):
                    params.append(p)
                elif isinstance(p, list):
                    params += p
        else:
            params = []

        # make sure the embedder is float32
        self.mean_flow_timestep_embedder.to(torch.float32)
        self.mean_flow_timestep_embedder.requires_grad = True
        self.mean_flow_timestep_embedder.train()

        params += list(self.mean_flow_timestep_embedder.parameters())

        # we need to be able to yield from the list like yield from params
        return params

    def load_weights(self, state_dict, strict=True):
        lora_sd = {}
        mean_flow_embedder_sd = {}
        for key, value in state_dict.items():
            if "mean_flow_timestep_embedder" in key:
                new_key = key.replace("transformer.mean_flow_timestep_embedder.", "")
                mean_flow_embedder_sd[new_key] = value
            else:
                lora_sd[key] = value

        # todo process state dict before loading for models that need it
        if self.lora is not None:
            self.lora.load_weights(lora_sd)
        self.mean_flow_timestep_embedder.load_state_dict(
            mean_flow_embedder_sd, strict=False
        )

    def get_state_dict(self):
        if self.lora is not None:
            lora_sd = self.lora.get_state_dict(dtype=torch.float32)
        else:
            lora_sd = {}
        # todo make sure we match loras elsewhere.
        mean_flow_embedder_sd = self.mean_flow_timestep_embedder.state_dict()
        for key, value in mean_flow_embedder_sd.items():
            lora_sd[f"transformer.mean_flow_timestep_embedder.{key}"] = value
        return lora_sd

    @property
    def is_active(self):
        return self.adapter_ref().is_active
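The zero/identity initialization of mean_flow_timestep_embedder is what makes a freshly attached adapter a no-op: with the left half of the weight set to the identity and everything else zeroed, the embedder passes the start-timestep embedding through untouched and ignores the end-timestep embedding. A small sketch of that invariant (emb_dim is an arbitrary illustrative size, not the real config value):

import torch

emb_dim = 32  # illustrative; the real value comes from the transformer config

embedder = torch.nn.Linear(emb_dim * 2, emb_dim)
with torch.no_grad():
    embedder.weight.zero_()
    embedder.weight[:, :emb_dim] = torch.eye(emb_dim)
    embedder.bias.zero_()

t_start = torch.randn(4, emb_dim)  # stands in for the start-timestep embedding
t_end = torch.randn(4, emb_dim)    # stands in for the end-timestep embedding

out = embedder(torch.cat([t_start, t_end], dim=-1))
print(torch.allclose(out, t_start))  # True: the end embedding contributes nothing at init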
@@ -109,7 +109,9 @@ class Adafactor(torch.optim.Optimizer):
        do_paramiter_swapping=False,
        paramiter_swapping_factor=0.1,
        stochastic_accumulation=True,
+       stochastic_rounding=True,
    ):
+       self.stochastic_rounding = stochastic_rounding
        if lr is not None and relative_step:
            raise ValueError(
                "Cannot combine manual `lr` and `relative_step=True` options")
@@ -354,7 +356,7 @@ class Adafactor(torch.optim.Optimizer):

        p_data_fp32.add_(-update)

-       if p.dtype != torch.float32:
+       if p.dtype != torch.float32 and self.stochastic_rounding:
            # apply stochastic rounding
            copy_stochastic(p, p_data_fp32)
@@ -112,86 +112,57 @@ def get_format_params(dtype: torch.dtype) -> tuple[int, int]:
        return 0, 8  # Int8 doesn't have mantissa bits
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")


def copy_stochastic_bf16(target: torch.Tensor, source: torch.Tensor):
    # adapted from https://github.com/Nerogar/OneTrainer/blob/411532e85f3cf2b52baa37597f9c145073d54511/modules/util/bf16_stochastic_rounding.py#L5
    # create a random 16 bit integer
    result = torch.randint_like(
        source,
        dtype=torch.int32,
        low=0,
        high=(1 << 16),
    )

    # add the random number to the lower 16 bits of the mantissa
    result.add_(source.view(dtype=torch.int32))

    # mask off the lower 16 bits of the mantissa
    result.bitwise_and_(-65536)  # -65536 = 0xFFFF0000 as a signed int32

    # copy the upper 16 bits into the target tensor
    target.copy_(result.view(dtype=torch.float32))

    del result


-def copy_stochastic(
-    target: torch.Tensor,
-    source: torch.Tensor,
-    eps: Optional[float] = None
-) -> None:
-    """
-    Performs stochastic rounding from source tensor to target tensor.
-
-    Args:
-        target: Destination tensor (determines the target format)
-        source: Source tensor (typically float32)
-        eps: Optional minimum value for stochastic rounding (for numerical stability)
-    """
+def copy_stochastic(target: torch.Tensor, source: torch.Tensor, eps: Optional[float] = None) -> None:
    with torch.no_grad():
        # If target is float32, just copy directly
        # assert that neither tensor is on the cpu
        assert target.device.type != 'cpu', "Target is on cpu!"
        assert source.device.type != 'cpu', "Source is on cpu!"

        if target.dtype == torch.float32:
            target.copy_(source)
            return

-        # Special handling for int8
-        if target.dtype == torch.int8:
-            # Scale the source values to utilize the full int8 range
-            scaled = source * 127.0  # Scale to [-127, 127]
-
-            # Add random noise for stochastic rounding
-            noise = torch.rand_like(scaled) - 0.5
-            rounded = torch.round(scaled + noise)
-
-            # Clamp to int8 range
-            clamped = torch.clamp(rounded, -127, 127)
-            target.copy_(clamped.to(torch.int8))
+        if target.dtype == torch.bfloat16:
+            copy_stochastic_bf16(target, source)
+            return

        mantissa_bits, _ = get_format_params(target.dtype)
-        round_factor = 2 ** (23 - mantissa_bits)

        # Convert source to int32 view
        source_int = source.view(dtype=torch.int32)
-        # Add uniform noise for stochastic rounding
-        noise = torch.rand_like(source, device=source.device) - 0.5
-        rounded = torch.round(source * round_factor + noise)
-        result_float = rounded / round_factor

        # Calculate number of bits to round
        bits_to_round = 23 - mantissa_bits  # 23 is float32 mantissa bits

        # Create random integers for stochastic rounding
        rand = torch.randint_like(
            source,
            dtype=torch.int32,
            low=0,
            high=(1 << bits_to_round),
        )

        # Add random values to the bits that will be rounded off
        result = source_int.clone()
        result.add_(rand)

        # Mask to keep only the bits we want
        # Create mask with 1s in positions we want to keep
        mask = (-1) << bits_to_round
        result.bitwise_and_(mask)

        # Handle minimum value threshold if specified
        if eps is not None:
            eps_int = torch.tensor(
                eps, dtype=torch.float32).view(dtype=torch.int32)
            zero_mask = (result.abs() < eps_int)
            result[zero_mask] = torch.sign(source_int[zero_mask]) * eps_int

        # Convert back to float32 view
        result_float = result.view(dtype=torch.float32)

        # Clamp to the representable range for float8 formats
        if target.dtype == torch.float8_e4m3fn:
            result_float.clamp_(-448.0, 448.0)
        elif target.dtype == torch.float8_e5m2:
            result_float.clamp_(-57344.0, 57344.0)

        # Copy the result to the target tensor
        update_parameter(target, result_float)
        # target.copy_(result_float)
        del result, rand, source_int
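The bit trick above rounds up with probability equal to the truncated fraction, which makes the result unbiased in expectation where round-to-nearest is not. A standalone sketch of the same manipulation (the helpers above assert non-CPU tensors, so this reimplements just the bf16 bit math for illustration; stochastic_round_to_bf16 is a hypothetical name, not part of the toolkit):

import torch

def stochastic_round_to_bf16(source: torch.Tensor) -> torch.Tensor:
    # same bit manipulation as copy_stochastic_bf16: add random noise to the
    # low 16 bits of the fp32 encoding, then truncate them
    result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
    result.add_(source.view(dtype=torch.int32))
    result.bitwise_and_(-65536)  # keep only the high 16 bits
    return result.view(dtype=torch.float32).to(torch.bfloat16)

x = torch.full((1_000_000,), 1.0 + 2 ** -10)  # sits between the bf16 values 1.0 and 1 + 2**-7

nearest = x.to(torch.bfloat16)            # deterministic: always rounds down to 1.0 here
stochastic = stochastic_round_to_bf16(x)  # rounds up to 1 + 2**-7 one time in eight

print(nearest.float().mean().item())      # 1.0 -> systematically biased low
print(stochastic.float().mean().item())   # ~1.00098 -> matches the source mean in expectation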


class Auto8bitTensor:
@@ -50,8 +50,7 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAda
    StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
    StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \
    FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, Lumina2Text2ImgPipeline, \
-   FluxControlPipeline
-from toolkit.models.lumina2 import Lumina2Transformer2DModel
+   FluxControlPipeline, Lumina2Transformer2DModel
import diffusers
from diffusers import \
    AutoencoderKL, \
@@ -1763,6 +1762,15 @@ class StableDiffusion:
        )
        noise = apply_noise_offset(noise, noise_offset)
        return noise

+   def get_latent_noise_from_latents(
+       self,
+       latents: torch.Tensor,
+       noise_offset=0.0
+   ):
+       noise = torch.randn_like(latents)
+       noise = apply_noise_offset(noise, noise_offset)
+       return noise
+
    def get_time_ids_from_latents(self, latents: torch.Tensor, requires_aesthetic_score=False):
        VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
@@ -2170,7 +2178,7 @@ class StableDiffusion:
        noise_pred = self.unet(
            hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
            timestep=t,
-           attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64),
+           encoder_attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64),
            encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
            **kwargs,
        ).sample
@@ -2529,8 +2537,8 @@ class StableDiffusion:

        # Move the vae to the device if it is on cpu
        if self.vae.device == 'cpu':
-           self.vae.to(self.device)
-       latents = latents.to(device, dtype=dtype)
+           self.vae.to(self.device_torch)
+       latents = latents.to(self.device_torch, dtype=self.torch_dtype)
        latents = (latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor']
        images = self.vae.decode(latents).sample
        images = images.to(device, dtype=dtype)
31  ui/cron/worker.ts  Normal file
@@ -0,0 +1,31 @@
class CronWorker {
  interval: number;
  is_running: boolean;
  intervalId: NodeJS.Timeout;
  constructor() {
    this.interval = 1000; // default interval of 1 second
    this.is_running = false;
    this.intervalId = setInterval(() => {
      this.run();
    }, this.interval);
  }
  async run() {
    if (this.is_running) {
      return;
    }
    this.is_running = true;
    try {
      // loop logic here
      await this.loop();
    } catch (error) {
      console.error('Error in cron worker loop:', error);
    }
    this.is_running = false;
  }

  async loop() {}
}

// it automatically starts the loop
const cronWorker = new CronWorker();
console.log('Cron worker started with interval:', cronWorker.interval, 'ms');
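Worth noting about the worker above: run() fires on a fixed one-second setInterval, and the is_running flag turns overlapping ticks into no-ops, so a loop() iteration that outlives the interval is never re-entered; the first tick after it resolves simply picks the cycle back up.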
654  ui/package-lock.json  generated
@@ -30,10 +30,12 @@
    "@types/node": "^20",
    "@types/react": "^19",
    "@types/react-dom": "^19",
+   "concurrently": "^9.1.2",
    "postcss": "^8",
    "prettier": "^3.5.1",
    "prettier-basic": "^1.0.0",
    "tailwindcss": "^3.4.1",
+   "ts-node-dev": "^2.0.0",
    "typescript": "^5"
  }
},
|
||||
@@ -169,6 +171,28 @@
|
||||
"node": ">=6.9.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@cspotcode/source-map-support": {
|
||||
"version": "0.8.1",
|
||||
"resolved": "https://registry.npmjs.org/@cspotcode/source-map-support/-/source-map-support-0.8.1.tgz",
|
||||
"integrity": "sha512-IchNf6dN4tHoMFIn/7OE8LWZ19Y6q/67Bmf6vnGREv8RSbBVb9LPJxEcnwrcwX6ixSvaiGoomAUvu4YSxXrVgw==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"@jridgewell/trace-mapping": "0.3.9"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
}
|
||||
},
|
||||
"node_modules/@cspotcode/source-map-support/node_modules/@jridgewell/trace-mapping": {
|
||||
"version": "0.3.9",
|
||||
"resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.9.tgz",
|
||||
"integrity": "sha512-3Belt6tdc8bPgAtbcmdtNJlirVoTmEb5e2gC94PnkwEW9jI6CAHUeoG85tjWP5WquqfavoMtMwiG4P926ZKKuQ==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"@jridgewell/resolve-uri": "^3.0.3",
|
||||
"@jridgewell/sourcemap-codec": "^1.4.10"
|
||||
}
|
||||
},
|
||||
"node_modules/@emnapi/runtime": {
|
||||
"version": "1.3.1",
|
||||
"resolved": "https://registry.npmjs.org/@emnapi/runtime/-/runtime-1.3.1.tgz",
|
||||
@@ -1168,6 +1192,30 @@
|
||||
"node": ">= 6"
|
||||
}
|
||||
},
|
||||
"node_modules/@tsconfig/node10": {
|
||||
"version": "1.0.11",
|
||||
"resolved": "https://registry.npmjs.org/@tsconfig/node10/-/node10-1.0.11.tgz",
|
||||
"integrity": "sha512-DcRjDCujK/kCk/cUe8Xz8ZSpm8mS3mNNpta+jGCA6USEDfktlNvm1+IuZ9eTcDbNk41BHwpHHeW+N1lKCz4zOw==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/@tsconfig/node12": {
|
||||
"version": "1.0.11",
|
||||
"resolved": "https://registry.npmjs.org/@tsconfig/node12/-/node12-1.0.11.tgz",
|
||||
"integrity": "sha512-cqefuRsh12pWyGsIoBKJA9luFu3mRxCA+ORZvA4ktLSzIuCUtWVxGIuXigEwO5/ywWFMZ2QEGKWvkZG1zDMTag==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/@tsconfig/node14": {
|
||||
"version": "1.0.3",
|
||||
"resolved": "https://registry.npmjs.org/@tsconfig/node14/-/node14-1.0.3.tgz",
|
||||
"integrity": "sha512-ysT8mhdixWK6Hw3i1V2AeRqZ5WfXg1G43mqoYlM2nc6388Fq5jcXyr5mRsqViLx/GJYdoL0bfXD8nmF+Zn/Iow==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/@tsconfig/node16": {
|
||||
"version": "1.0.4",
|
||||
"resolved": "https://registry.npmjs.org/@tsconfig/node16/-/node16-1.0.4.tgz",
|
||||
"integrity": "sha512-vxhUy4J8lyeyinH7Azl1pdd43GJhZH/tP2weN8TntQblOY+A0XbT8DJk1/oCPuOOyg/Ja757rG0CgHcWC8OfMA==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/@types/node": {
|
||||
"version": "20.17.19",
|
||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-20.17.19.tgz",
|
||||
@@ -1207,6 +1255,18 @@
|
||||
"@types/react": "*"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/strip-bom": {
|
||||
"version": "3.0.0",
|
||||
"resolved": "https://registry.npmjs.org/@types/strip-bom/-/strip-bom-3.0.0.tgz",
|
||||
"integrity": "sha512-xevGOReSYGM7g/kUBZzPqCrR/KYAo+F0yiPc85WFTJa0MSLtyFTVTU6cJu/aV4mid7IffDIWqo69THF2o4JiEQ==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/@types/strip-json-comments": {
|
||||
"version": "0.0.30",
|
||||
"resolved": "https://registry.npmjs.org/@types/strip-json-comments/-/strip-json-comments-0.0.30.tgz",
|
||||
"integrity": "sha512-7NQmHra/JILCd1QqpSzl8+mJRc8ZHz3uDm8YV1Ks9IhK0epEiTw8aIErbvH9PI+6XbqhyIQy3462nEsn7UVzjQ==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/abbrev": {
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmjs.org/abbrev/-/abbrev-1.1.1.tgz",
|
||||
@@ -1214,6 +1274,30 @@
|
||||
"license": "ISC",
|
||||
"optional": true
|
||||
},
|
||||
"node_modules/acorn": {
|
||||
"version": "8.15.0",
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"dev": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=0.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/acorn-walk": {
|
||||
"version": "8.3.4",
|
||||
"resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.3.4.tgz",
|
||||
"integrity": "sha512-ueEepnujpqee2o5aIYnvHU6C0A42MNdsIDeqy5BydrkuC5R1ZuUFnm27EeFJGoEHJQgn3uleRvmTXaJgfXbt4g==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"acorn": "^8.11.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=0.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/agent-base": {
|
||||
"version": "6.0.2",
|
||||
"resolved": "https://registry.npmjs.org/agent-base/-/agent-base-6.0.2.tgz",
|
||||
@@ -1465,6 +1549,12 @@
|
||||
"ieee754": "^1.1.13"
|
||||
}
|
||||
},
|
||||
"node_modules/buffer-from": {
|
||||
"version": "1.1.2",
|
||||
"resolved": "https://registry.npmjs.org/buffer-from/-/buffer-from-1.1.2.tgz",
|
||||
"integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/busboy": {
|
||||
"version": "1.6.0",
|
||||
"resolved": "https://registry.npmjs.org/busboy/-/busboy-1.6.0.tgz",
|
||||
@@ -1626,6 +1716,49 @@
|
||||
}
|
||||
]
|
||||
},
|
||||
"node_modules/chalk": {
|
||||
"version": "4.1.2",
|
||||
"resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz",
|
||||
"integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"ansi-styles": "^4.1.0",
|
||||
"supports-color": "^7.1.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=10"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/chalk/chalk?sponsor=1"
|
||||
}
|
||||
},
|
||||
"node_modules/chalk/node_modules/ansi-styles": {
|
||||
"version": "4.3.0",
|
||||
"resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz",
|
||||
"integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"color-convert": "^2.0.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=8"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/chalk/ansi-styles?sponsor=1"
|
||||
}
|
||||
},
|
||||
"node_modules/chalk/node_modules/supports-color": {
|
||||
"version": "7.2.0",
|
||||
"resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz",
|
||||
"integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"has-flag": "^4.0.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/chokidar": {
|
||||
"version": "3.6.0",
|
||||
"resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.6.0.tgz",
|
||||
@@ -1691,6 +1824,93 @@
|
||||
"resolved": "https://registry.npmjs.org/client-only/-/client-only-0.0.1.tgz",
|
||||
"integrity": "sha512-IV3Ou0jSMzZrd3pZ48nLkT9DA7Ag1pnPzaiQhpW7c3RbcqqzvzzVu+L8gfqMp/8IM2MQtSiqaCxrrcfu8I8rMA=="
|
||||
},
|
||||
"node_modules/cliui": {
|
||||
"version": "8.0.1",
|
||||
"resolved": "https://registry.npmjs.org/cliui/-/cliui-8.0.1.tgz",
|
||||
"integrity": "sha512-BSeNnyus75C4//NQ9gQt1/csTXyo/8Sb+afLAkzAptFuMsod9HFokGNudZpi/oQV73hnVK+sR+5PVRMd+Dr7YQ==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"string-width": "^4.2.0",
|
||||
"strip-ansi": "^6.0.1",
|
||||
"wrap-ansi": "^7.0.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
}
|
||||
},
|
||||
"node_modules/cliui/node_modules/ansi-regex": {
|
||||
"version": "5.0.1",
|
||||
"resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz",
|
||||
"integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==",
|
||||
"dev": true,
|
||||
"engines": {
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/cliui/node_modules/ansi-styles": {
|
||||
"version": "4.3.0",
|
||||
"resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz",
|
||||
"integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"color-convert": "^2.0.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=8"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/chalk/ansi-styles?sponsor=1"
|
||||
}
|
||||
},
|
||||
"node_modules/cliui/node_modules/emoji-regex": {
|
||||
"version": "8.0.0",
|
||||
"resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz",
|
||||
"integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/cliui/node_modules/string-width": {
|
||||
"version": "4.2.3",
|
||||
"resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz",
|
||||
"integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"emoji-regex": "^8.0.0",
|
||||
"is-fullwidth-code-point": "^3.0.0",
|
||||
"strip-ansi": "^6.0.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/cliui/node_modules/strip-ansi": {
|
||||
"version": "6.0.1",
|
||||
"resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz",
|
||||
"integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"ansi-regex": "^5.0.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/cliui/node_modules/wrap-ansi": {
|
||||
"version": "7.0.0",
|
||||
"resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz",
|
||||
"integrity": "sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"ansi-styles": "^4.0.0",
|
||||
"string-width": "^4.1.0",
|
||||
"strip-ansi": "^6.0.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=10"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/chalk/wrap-ansi?sponsor=1"
|
||||
}
|
||||
},
|
||||
"node_modules/clone": {
|
||||
"version": "2.1.2",
|
||||
"resolved": "https://registry.npmjs.org/clone/-/clone-2.1.2.tgz",
|
||||
@@ -1782,8 +2002,33 @@
|
||||
"version": "0.0.1",
|
||||
"resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz",
|
||||
"integrity": "sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==",
|
||||
"license": "MIT",
|
||||
"optional": true
|
||||
"devOptional": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/concurrently": {
|
||||
"version": "9.1.2",
|
||||
"resolved": "https://registry.npmjs.org/concurrently/-/concurrently-9.1.2.tgz",
|
||||
"integrity": "sha512-H9MWcoPsYddwbOGM6difjVwVZHl63nwMEwDJG/L7VGtuaJhb12h2caPG2tVPWs7emuYix252iGfqOyrz1GczTQ==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"chalk": "^4.1.2",
|
||||
"lodash": "^4.17.21",
|
||||
"rxjs": "^7.8.1",
|
||||
"shell-quote": "^1.8.1",
|
||||
"supports-color": "^8.1.1",
|
||||
"tree-kill": "^1.2.2",
|
||||
"yargs": "^17.7.2"
|
||||
},
|
||||
"bin": {
|
||||
"conc": "dist/bin/concurrently.js",
|
||||
"concurrently": "dist/bin/concurrently.js"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/open-cli-tools/concurrently?sponsor=1"
|
||||
}
|
||||
},
|
||||
"node_modules/console-control-strings": {
|
||||
"version": "1.1.0",
|
||||
@@ -1820,6 +2065,12 @@
|
||||
"node": ">= 6"
|
||||
}
|
||||
},
|
||||
"node_modules/create-require": {
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmjs.org/create-require/-/create-require-1.1.1.tgz",
|
||||
"integrity": "sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/cross-spawn": {
|
||||
"version": "7.0.6",
|
||||
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz",
|
||||
@@ -1921,6 +2172,15 @@
|
||||
"integrity": "sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/diff": {
|
||||
"version": "4.0.2",
|
||||
"resolved": "https://registry.npmjs.org/diff/-/diff-4.0.2.tgz",
|
||||
"integrity": "sha512-58lmxKSA4BNyLz+HHMUzlOEpg09FV+ev6ZMe3vJihgdxzgcwZ8VoEEPmALCZG9LmqfVoNMMKpttIYTVG6uDY7A==",
|
||||
"dev": true,
|
||||
"engines": {
|
||||
"node": ">=0.3.1"
|
||||
}
|
||||
},
|
||||
"node_modules/dlv": {
|
||||
"version": "1.1.3",
|
||||
"resolved": "https://registry.npmjs.org/dlv/-/dlv-1.1.3.tgz",
|
||||
@@ -1949,6 +2209,15 @@
|
||||
"node": ">= 0.4"
|
||||
}
|
||||
},
|
||||
"node_modules/dynamic-dedupe": {
|
||||
"version": "0.3.0",
|
||||
"resolved": "https://registry.npmjs.org/dynamic-dedupe/-/dynamic-dedupe-0.3.0.tgz",
|
||||
"integrity": "sha512-ssuANeD+z97meYOqd50e04Ze5qp4bPqo8cCkI4TRjZkzAUgIDTrXV1R8QCdINpiI+hw14+rYazvTRdQrz0/rFQ==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"xtend": "^4.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/eastasianwidth": {
|
||||
"version": "0.2.0",
|
||||
"resolved": "https://registry.npmjs.org/eastasianwidth/-/eastasianwidth-0.2.0.tgz",
|
||||
@@ -2051,6 +2320,15 @@
|
||||
"node": ">= 0.4"
|
||||
}
|
||||
},
|
||||
"node_modules/escalade": {
|
||||
"version": "3.2.0",
|
||||
"resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz",
|
||||
"integrity": "sha512-WUj2qlxaQtO4g6Pq5c29GTcWGDyd8itL8zTlipgECz3JesAiiOKotd8JU6otB3PACgG6xkJUyVhboMS+bje/jA==",
|
||||
"dev": true,
|
||||
"engines": {
|
||||
"node": ">=6"
|
||||
}
|
||||
},
|
||||
"node_modules/escape-string-regexp": {
|
||||
"version": "4.0.0",
|
||||
"resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-4.0.0.tgz",
|
||||
@@ -2225,8 +2503,8 @@
|
||||
"version": "1.0.0",
|
||||
"resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz",
|
||||
"integrity": "sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==",
|
||||
"license": "ISC",
|
||||
"optional": true
|
||||
"devOptional": true,
|
||||
"license": "ISC"
|
||||
},
|
||||
"node_modules/fsevents": {
|
||||
"version": "2.3.3",
|
||||
@@ -2322,6 +2600,15 @@
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/get-caller-file": {
|
||||
"version": "2.0.5",
|
||||
"resolved": "https://registry.npmjs.org/get-caller-file/-/get-caller-file-2.0.5.tgz",
|
||||
"integrity": "sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg==",
|
||||
"dev": true,
|
||||
"engines": {
|
||||
"node": "6.* || 8.* || >= 10.*"
|
||||
}
|
||||
},
|
||||
"node_modules/get-intrinsic": {
|
||||
"version": "1.2.7",
|
||||
"resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.7.tgz",
|
||||
@@ -2421,6 +2708,15 @@
|
||||
"license": "ISC",
|
||||
"optional": true
|
||||
},
|
||||
"node_modules/has-flag": {
|
||||
"version": "4.0.0",
|
||||
"resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz",
|
||||
"integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==",
|
||||
"dev": true,
|
||||
"engines": {
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/has-symbols": {
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz",
|
||||
@@ -2598,8 +2894,8 @@
|
||||
"resolved": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz",
|
||||
"integrity": "sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==",
|
||||
"deprecated": "This module is not supported, and leaks memory. Do not use it. Check out lru-cache if you want a good and tested way to coalesce async requests by a key value, which is much more comprehensive and powerful.",
|
||||
"devOptional": true,
|
||||
"license": "ISC",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"once": "^1.3.0",
|
||||
"wrappy": "1"
|
||||
@@ -2784,6 +3080,12 @@
|
||||
"resolved": "https://registry.npmjs.org/lines-and-columns/-/lines-and-columns-1.2.4.tgz",
|
||||
"integrity": "sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg=="
|
||||
},
|
||||
"node_modules/lodash": {
|
||||
"version": "4.17.21",
|
||||
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz",
|
||||
"integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/loose-envify": {
|
||||
"version": "1.4.0",
|
||||
"resolved": "https://registry.npmjs.org/loose-envify/-/loose-envify-1.4.0.tgz",
|
||||
@@ -2810,6 +3112,12 @@
|
||||
"react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/make-error": {
|
||||
"version": "1.3.6",
|
||||
"resolved": "https://registry.npmjs.org/make-error/-/make-error-1.3.6.tgz",
|
||||
"integrity": "sha512-s8UhlNe7vPKomQhC1qFelMokr/Sc3AgNbso3n74mVPA5LTZwkB9NlXf4XPamLxJE8h0gh73rM94xvwRT2CVInw==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/make-fetch-happen": {
|
||||
"version": "9.1.0",
|
||||
"resolved": "https://registry.npmjs.org/make-fetch-happen/-/make-fetch-happen-9.1.0.tgz",
|
||||
@@ -3499,8 +3807,8 @@
|
||||
"version": "1.0.1",
|
||||
"resolved": "https://registry.npmjs.org/path-is-absolute/-/path-is-absolute-1.0.1.tgz",
|
||||
"integrity": "sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg==",
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"engines": {
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
@@ -4004,6 +4312,15 @@
|
||||
"node": ">=8.10.0"
|
||||
}
|
||||
},
|
||||
"node_modules/require-directory": {
|
||||
"version": "2.1.1",
|
||||
"resolved": "https://registry.npmjs.org/require-directory/-/require-directory-2.1.1.tgz",
|
||||
"integrity": "sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==",
|
||||
"dev": true,
|
||||
"engines": {
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
},
|
||||
"node_modules/resolve": {
|
||||
"version": "1.22.10",
|
||||
"resolved": "https://registry.npmjs.org/resolve/-/resolve-1.22.10.tgz",
|
||||
@@ -4137,6 +4454,15 @@
|
||||
"queue-microtask": "^1.2.2"
|
||||
}
|
||||
},
|
||||
"node_modules/rxjs": {
|
||||
"version": "7.8.2",
|
||||
"resolved": "https://registry.npmjs.org/rxjs/-/rxjs-7.8.2.tgz",
|
||||
"integrity": "sha512-dhKf903U/PQZY6boNNtAGdWbG85WAbjT/1xYoZIC7FAY0yWapOBQVsVrDl58W86//e1VpMNBtRV4MaXfdMySFA==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"tslib": "^2.1.0"
|
||||
}
|
||||
},
|
||||
"node_modules/safe-buffer": {
|
||||
"version": "5.2.1",
|
||||
"resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz",
|
||||
@@ -4247,6 +4573,18 @@
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/shell-quote": {
|
||||
"version": "1.8.3",
|
||||
"resolved": "https://registry.npmjs.org/shell-quote/-/shell-quote-1.8.3.tgz",
|
||||
"integrity": "sha512-ObmnIF4hXNg1BqhnHmgbDETF8dLPCggZWBjkQfhZpbszZnYur5DUljTcCHii5LC3J5E0yeO/1LIMyH+UvHQgyw==",
|
||||
"dev": true,
|
||||
"engines": {
|
||||
"node": ">= 0.4"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/ljharb"
|
||||
}
|
||||
},
|
||||
"node_modules/signal-exit": {
|
||||
"version": "4.1.0",
|
||||
"resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-4.1.0.tgz",
|
||||
@@ -4370,6 +4708,25 @@
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
},
|
||||
"node_modules/source-map-support": {
|
||||
"version": "0.5.21",
|
||||
"resolved": "https://registry.npmjs.org/source-map-support/-/source-map-support-0.5.21.tgz",
|
||||
"integrity": "sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w==",
"dev": true,
"dependencies": {
"buffer-from": "^1.0.0",
"source-map": "^0.6.0"
}
},
"node_modules/source-map-support/node_modules/source-map": {
"version": "0.6.1",
"resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz",
"integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==",
"dev": true,
"engines": {
"node": ">=0.10.0"
}
},
"node_modules/sprintf-js": {
"version": "1.1.3",
"resolved": "https://registry.npmjs.org/sprintf-js/-/sprintf-js-1.1.3.tgz",
@@ -4545,6 +4902,15 @@
"node": ">=8"
}
},
"node_modules/strip-bom": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/strip-bom/-/strip-bom-3.0.0.tgz",
"integrity": "sha512-vavAMRXOgBVNF6nyEEmL3DBK19iRpDcoIwW+swQ+CbGiu7lju6t+JklA1MHweoWtadgt4ISVUsXLyDq34ddcwA==",
"dev": true,
"engines": {
"node": ">=4"
}
},
"node_modules/strip-json-comments": {
"version": "2.0.1",
"resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-2.0.1.tgz",
@@ -4603,6 +4969,21 @@
"node": ">=16 || 14 >=14.17"
}
},
"node_modules/supports-color": {
"version": "8.1.1",
"resolved": "https://registry.npmjs.org/supports-color/-/supports-color-8.1.1.tgz",
"integrity": "sha512-MpUEN2OodtUzxvKQl72cUF7RQ5EiHsGvSsVG0ia9c5RbWGL2CI4C7EpPS8UTBIplnlzZiNuV56w+FuNxy3ty2Q==",
"dev": true,
"dependencies": {
"has-flag": "^4.0.0"
},
"engines": {
"node": ">=10"
},
"funding": {
"url": "https://github.com/chalk/supports-color?sponsor=1"
}
},
"node_modules/supports-preserve-symlinks-flag": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz",
@@ -4749,12 +5130,172 @@
"node": ">=8.0"
}
},
"node_modules/tree-kill": {
"version": "1.2.2",
"resolved": "https://registry.npmjs.org/tree-kill/-/tree-kill-1.2.2.tgz",
"integrity": "sha512-L0Orpi8qGpRG//Nd+H90vFB+3iHnue1zSSGmNOOCh1GLJ7rUKVwV2HvijphGQS2UmhUZewS9VgvxYIdgr+fG1A==",
"dev": true,
"bin": {
"tree-kill": "cli.js"
}
},
"node_modules/ts-interface-checker": {
"version": "0.1.13",
"resolved": "https://registry.npmjs.org/ts-interface-checker/-/ts-interface-checker-0.1.13.tgz",
"integrity": "sha512-Y/arvbn+rrz3JCKl9C4kVNfTfSm2/mEp5FSz5EsZSANGPSlQrpRI5M4PKF+mJnE52jOO90PnPSc3Ur3bTQw0gA==",
"dev": true
},
"node_modules/ts-node-dev": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/ts-node-dev/-/ts-node-dev-2.0.0.tgz",
"integrity": "sha512-ywMrhCfH6M75yftYvrvNarLEY+SUXtUvU8/0Z6llrHQVBx12GiFk5sStF8UdfE/yfzk9IAq7O5EEbTQsxlBI8w==",
"dev": true,
"dependencies": {
"chokidar": "^3.5.1",
"dynamic-dedupe": "^0.3.0",
"minimist": "^1.2.6",
"mkdirp": "^1.0.4",
"resolve": "^1.0.0",
"rimraf": "^2.6.1",
"source-map-support": "^0.5.12",
"tree-kill": "^1.2.2",
"ts-node": "^10.4.0",
"tsconfig": "^7.0.0"
},
"bin": {
"ts-node-dev": "lib/bin.js",
"tsnd": "lib/bin.js"
},
"engines": {
"node": ">=0.8.0"
},
"peerDependencies": {
"node-notifier": "*",
"typescript": "*"
},
"peerDependenciesMeta": {
"node-notifier": {
"optional": true
}
}
},
"node_modules/ts-node-dev/node_modules/arg": {
"version": "4.1.3",
"resolved": "https://registry.npmjs.org/arg/-/arg-4.1.3.tgz",
"integrity": "sha512-58S9QDqG0Xx27YwPSt9fJxivjYl432YCwfDMfZ+71RAqUrZef7LrKQZ3LHLOwCS4FLNBplP533Zx895SeOCHvA==",
"dev": true
},
"node_modules/ts-node-dev/node_modules/brace-expansion": {
"version": "1.1.12",
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz",
"integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==",
"dev": true,
"dependencies": {
"balanced-match": "^1.0.0",
"concat-map": "0.0.1"
}
},
"node_modules/ts-node-dev/node_modules/glob": {
"version": "7.2.3",
"resolved": "https://registry.npmjs.org/glob/-/glob-7.2.3.tgz",
"integrity": "sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==",
"deprecated": "Glob versions prior to v9 are no longer supported",
"dev": true,
"dependencies": {
"fs.realpath": "^1.0.0",
"inflight": "^1.0.4",
"inherits": "2",
"minimatch": "^3.1.1",
"once": "^1.3.0",
"path-is-absolute": "^1.0.0"
},
"engines": {
"node": "*"
},
"funding": {
"url": "https://github.com/sponsors/isaacs"
}
},
"node_modules/ts-node-dev/node_modules/minimatch": {
"version": "3.1.2",
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz",
"integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==",
"dev": true,
"dependencies": {
"brace-expansion": "^1.1.7"
},
"engines": {
"node": "*"
}
},
"node_modules/ts-node-dev/node_modules/rimraf": {
"version": "2.7.1",
"resolved": "https://registry.npmjs.org/rimraf/-/rimraf-2.7.1.tgz",
"integrity": "sha512-uWjbaKIK3T1OSVptzX7Nl6PvQ3qAGtKEtVRjRuazjfL3Bx5eI409VZSqgND+4UNnmzLVdPj9FqFJNPqBZFve4w==",
"deprecated": "Rimraf versions prior to v4 are no longer supported",
"dev": true,
"dependencies": {
"glob": "^7.1.3"
},
"bin": {
"rimraf": "bin.js"
}
},
"node_modules/ts-node-dev/node_modules/ts-node": {
"version": "10.9.2",
"resolved": "https://registry.npmjs.org/ts-node/-/ts-node-10.9.2.tgz",
"integrity": "sha512-f0FFpIdcHgn8zcPSbf1dRevwt047YMnaiJM3u2w2RewrB+fob/zePZcrOyQoLMMO7aBIddLcQIEK5dYjkLnGrQ==",
"dev": true,
"dependencies": {
"@cspotcode/source-map-support": "^0.8.0",
"@tsconfig/node10": "^1.0.7",
"@tsconfig/node12": "^1.0.7",
"@tsconfig/node14": "^1.0.0",
"@tsconfig/node16": "^1.0.2",
"acorn": "^8.4.1",
"acorn-walk": "^8.1.1",
"arg": "^4.1.0",
"create-require": "^1.1.0",
"diff": "^4.0.1",
"make-error": "^1.1.1",
"v8-compile-cache-lib": "^3.0.1",
"yn": "3.1.1"
},
"bin": {
"ts-node": "dist/bin.js",
"ts-node-cwd": "dist/bin-cwd.js",
"ts-node-esm": "dist/bin-esm.js",
"ts-node-script": "dist/bin-script.js",
"ts-node-transpile-only": "dist/bin-transpile.js",
"ts-script": "dist/bin-script-deprecated.js"
},
"peerDependencies": {
"@swc/core": ">=1.2.50",
"@swc/wasm": ">=1.2.50",
"@types/node": "*",
"typescript": ">=2.7"
},
"peerDependenciesMeta": {
"@swc/core": {
"optional": true
},
"@swc/wasm": {
"optional": true
}
}
},
"node_modules/tsconfig": {
"version": "7.0.0",
"resolved": "https://registry.npmjs.org/tsconfig/-/tsconfig-7.0.0.tgz",
"integrity": "sha512-vZXmzPrL+EmC4T/4rVlT2jNVMWCi/O4DIiSj3UHg1OE5kCKbk4mfrXc6dZksLgRM/TZlKnousKH9bbTazUWRRw==",
"dev": true,
"dependencies": {
"@types/strip-bom": "^3.0.0",
"@types/strip-json-comments": "0.0.30",
"strip-bom": "^3.0.0",
"strip-json-comments": "^2.0.0"
}
},
"node_modules/tslib": {
"version": "2.8.1",
"resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz",
@@ -4829,6 +5370,12 @@
"resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz",
"integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw=="
},
"node_modules/v8-compile-cache-lib": {
"version": "3.0.1",
"resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",
"integrity": "sha512-wa7YjyUGfNZngI/vtK0UHAN+lgDCxBPCylVXGp0zu59Fz5aiGtNXaq3DhIov063MorB+VfufLh3JlF2KdTK3xg==",
"dev": true
},
"node_modules/which": {
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz",
@@ -4996,6 +5543,24 @@
"integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==",
"license": "ISC"
},
"node_modules/xtend": {
"version": "4.0.2",
"resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz",
"integrity": "sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ==",
"dev": true,
"engines": {
"node": ">=0.4"
}
},
"node_modules/y18n": {
"version": "5.0.8",
"resolved": "https://registry.npmjs.org/y18n/-/y18n-5.0.8.tgz",
"integrity": "sha512-0pfFzegeDWJHJIAmTLRP2DwHjdF5s7jo9tuztdQxAhINCdvS+3nGINqPd00AphqJR/0LhANUS6/+7SCb98YOfA==",
"dev": true,
"engines": {
"node": ">=10"
}
},
"node_modules/yallist": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz",
@@ -5012,6 +5577,83 @@
"engines": {
"node": ">= 14"
}
},
"node_modules/yargs": {
"version": "17.7.2",
"resolved": "https://registry.npmjs.org/yargs/-/yargs-17.7.2.tgz",
"integrity": "sha512-7dSzzRQ++CKnNI/krKnYRV7JKKPUXMEh61soaHKg9mrWEhzFWhFnxPxGl+69cD1Ou63C13NUPCnmIcrvqCuM6w==",
"dev": true,
"dependencies": {
"cliui": "^8.0.1",
"escalade": "^3.1.1",
"get-caller-file": "^2.0.5",
"require-directory": "^2.1.1",
"string-width": "^4.2.3",
"y18n": "^5.0.5",
"yargs-parser": "^21.1.1"
},
"engines": {
"node": ">=12"
}
},
"node_modules/yargs-parser": {
"version": "21.1.1",
"resolved": "https://registry.npmjs.org/yargs-parser/-/yargs-parser-21.1.1.tgz",
"integrity": "sha512-tVpsJW7DdjecAiFpbIB1e3qxIQsE6NoPc5/eTdrbbIC4h0LVsWhnoa3g+m2HclBIujHzsxZ4VJVA+GUuc2/LBw==",
"dev": true,
"engines": {
"node": ">=12"
}
},
"node_modules/yargs/node_modules/ansi-regex": {
"version": "5.0.1",
"resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz",
"integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==",
"dev": true,
"engines": {
"node": ">=8"
}
},
"node_modules/yargs/node_modules/emoji-regex": {
"version": "8.0.0",
"resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz",
"integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==",
"dev": true
},
"node_modules/yargs/node_modules/string-width": {
"version": "4.2.3",
"resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz",
"integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==",
"dev": true,
"dependencies": {
"emoji-regex": "^8.0.0",
"is-fullwidth-code-point": "^3.0.0",
"strip-ansi": "^6.0.1"
},
"engines": {
"node": ">=8"
}
},
"node_modules/yargs/node_modules/strip-ansi": {
"version": "6.0.1",
"resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz",
"integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==",
"dev": true,
"dependencies": {
"ansi-regex": "^5.0.1"
},
"engines": {
"node": ">=8"
}
},
"node_modules/yn": {
"version": "3.1.1",
"resolved": "https://registry.npmjs.org/yn/-/yn-3.1.1.tgz",
"integrity": "sha512-Ux4ygGWsu2c7isFWe8Yu1YluJmqVhxqK2cLXNQA5AcC3QfbGNpM7fu0Y8b/z16pXLnFxZYvWhd3fhBY9DLmC6Q==",
"dev": true,
"engines": {
"node": ">=6"
}
}
}
}

@@ -3,9 +3,9 @@
"version": "0.1.0",
"private": true,
"scripts": {
"dev": "next dev --turbopack",
"build": "next build",
"start": "next start --port 8675",
"dev": "concurrently -k -n WORKER,UI \"ts-node-dev --respawn --watch cron --transpile-only cron/worker.ts\" \"next dev --turbopack\"",
"build": "tsc -p tsconfig.worker.json && next build",
"start": "concurrently --restart-tries -1 --restart-after 1000 -n WORKER,UI \"node dist/worker.js\" \"next start --port 8675\"",
"build_and_start": "npm install && npm run update_db && npm run build && npm run start",
"lint": "next lint",
"update_db": "npx prisma generate && npx prisma db push",
@@ -34,10 +34,12 @@
"@types/node": "^20",
"@types/react": "^19",
"@types/react-dom": "^19",
"concurrently": "^9.1.2",
"postcss": "^8",
"prettier": "^3.5.1",
"prettier-basic": "^1.0.0",
"tailwindcss": "^3.4.1",
"ts-node-dev": "^2.0.0",
"typescript": "^5"
},
"prettier": "prettier-basic"
@@ -26,3 +26,13 @@ model Job {
info String @default("")
speed_string String @default("")
}

model Queue {
id String @id @default(uuid())
channel String
job_id String
created_at DateTime @default(now())
updated_at DateTime @updatedAt
status String @default("waiting")
@@index([job_id, channel])
}
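
The new Queue model gives the worker a simple polling queue keyed by channel. How it is consumed is not shown in this diff; below is a minimal sketch of one plausible consumer, assuming the generated Prisma client (field and model names match the Queue model above, everything else is illustrative):

```ts
import { PrismaClient } from '@prisma/client';

const prisma = new PrismaClient();

// Claim the oldest waiting entry on a channel and return its job id.
async function claimNextJob(channel: string): Promise<string | null> {
  const next = await prisma.queue.findFirst({
    where: { channel, status: 'waiting' },
    orderBy: { created_at: 'asc' },
  });
  if (!next) return null;
  await prisma.queue.update({
    where: { id: next.id },
    data: { status: 'running' },
  });
  return next.job_id;
}
```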
@@ -45,7 +45,7 @@ function findImagesRecursively(dir: string): string[] {
const itemPath = path.join(dir, item);
const stat = fs.statSync(itemPath);

if (stat.isDirectory()) {
if (stat.isDirectory() && item !== '_controls' && !item.startsWith('.')) {
// If it's a directory, recursively search it
results = results.concat(findImagesRecursively(itemPath));
} else {
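
Only the directory filter is visible in this hunk: the walk now skips `_controls` folders (control images) and hidden directories. For context, a self-contained sketch of the whole walk; everything outside the changed lines, including the extension check, is inferred rather than taken from the source:

```ts
import * as fs from 'fs';
import * as path from 'path';

// Assumed set of image extensions; the real list may differ.
const IMAGE_EXTENSIONS = new Set(['.jpg', '.jpeg', '.png', '.webp']);

function findImagesRecursively(dir: string): string[] {
  let results: string[] = [];
  for (const item of fs.readdirSync(dir)) {
    const itemPath = path.join(dir, item);
    const stat = fs.statSync(itemPath);

    // New behavior: skip control-image folders and hidden directories.
    if (stat.isDirectory() && item !== '_controls' && !item.startsWith('.')) {
      results = results.concat(findImagesRecursively(itemPath));
    } else if (stat.isFile() && IMAGE_EXTENSIONS.has(path.extname(item).toLowerCase())) {
      results.push(itemPath);
    }
  }
  return results;
}
```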
@@ -47,6 +47,7 @@ export default function SimpleJob({
<TextInput
label="Training Name"
value={jobConfig.config.name}
docKey="config.name"
onChange={value => setJobConfig(value, 'config.name')}
placeholder="Enter training name"
disabled={runId !== null}
@@ -55,12 +56,14 @@ export default function SimpleJob({
<SelectInput
label="GPU ID"
value={`${gpuIDs}`}
docKey="gpuids"
onChange={value => setGpuIDs(value)}
options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))}
/>
<TextInput
label="Trigger Word"
value={jobConfig.config.process[0].trigger_word || ''}
docKey="config.process[0].trigger_word"
onChange={(value: string | null) => {
if (value?.trim() === '') {
value = null;
@@ -120,6 +123,7 @@ export default function SimpleJob({
<TextInput
label="Name or Path"
value={jobConfig.config.process[0].model.name_or_path}
docKey="config.process[0].model.name_or_path"
onChange={(value: string | null) => {
if (value?.trim() === '') {
value = null;
@@ -185,22 +189,20 @@ export default function SimpleJob({
max={1024}
required
/>
{
modelArch?.disableSections?.includes('network.conv') ? null : (
<NumberInput
label="Conv Rank"
value={jobConfig.config.process[0].network.conv}
onChange={value => {
console.log('onChange', value);
setJobConfig(value, 'config.process[0].network.conv');
setJobConfig(value, 'config.process[0].network.conv_alpha');
}}
placeholder="eg. 16"
min={0}
max={1024}
/>
)
}
{modelArch?.disableSections?.includes('network.conv') ? null : (
<NumberInput
label="Conv Rank"
value={jobConfig.config.process[0].network.conv}
onChange={value => {
console.log('onChange', value);
setJobConfig(value, 'config.process[0].network.conv');
setJobConfig(value, 'config.process[0].network.conv_alpha');
}}
placeholder="eg. 16"
min={0}
max={1024}
/>
)}
</>
)}
</Card>

@@ -7,6 +7,7 @@ import ConfirmModal from '@/components/ConfirmModal';
import SampleImageModal from '@/components/SampleImageModal';
import { Suspense } from 'react';
import AuthWrapper from '@/components/AuthWrapper';
import DocModal from '@/components/DocModal';

export const dynamic = 'force-dynamic';

@@ -38,6 +39,7 @@ export default function RootLayout({ children }: { children: React.ReactNode })
</AuthWrapper>
</ThemeProvider>
<ConfirmModal />
<DocModal />
<SampleImageModal />
</body>
</html>

59
ui/src/components/DocModal.tsx
Normal file
@@ -0,0 +1,59 @@
'use client';
import { createGlobalState } from 'react-global-hooks';
import { Dialog, DialogBackdrop, DialogPanel, DialogTitle } from '@headlessui/react';
import React from 'react';
import { ConfigDoc } from '@/types';

export const docState = createGlobalState<ConfigDoc | null>(null);

export const openDoc = (doc: ConfigDoc) => {
docState.set({ ...doc });
};

export default function DocModal() {
const [doc, setDoc] = docState.use();
const isOpen = !!doc;

const onClose = () => {
setDoc(null);
};

return (
<Dialog open={isOpen} onClose={onClose} className="relative z-10">
<DialogBackdrop
transition
className="fixed inset-0 bg-gray-900/75 transition-opacity data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in"
/>

<div className="fixed inset-0 z-10 w-screen overflow-y-auto">
<div className="flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0">
<DialogPanel
transition
className="relative transform overflow-hidden rounded-lg bg-gray-800 text-left shadow-xl transition-all data-closed:translate-y-4 data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in sm:my-8 sm:w-full sm:max-w-[50rem] data-closed:sm:translate-y-0 data-closed:sm:scale-95"
>
<div className="bg-gray-800 px-4 pt-5 pb-4 sm:p-6 sm:pb-4">
<div className="sm:flex sm:items-start">
<div className="mt-3 text-center sm:mt-0 sm:ml-4 sm:text-left flex-1">
<DialogTitle as="h3" className={`text-base font-semibold `}>
{doc?.title || 'Confirm Action'}
</DialogTitle>
<div className="mt-2 text-sm text-gray-200">{doc?.description}</div>
</div>
</div>
</div>
<div className="bg-gray-700 px-4 py-3 sm:flex sm:flex-row-reverse sm:px-6">
<button
type="button"
data-autofocus
onClick={onClose}
className="mt-3 inline-flex w-full justify-center rounded-md bg-gray-800 px-3 py-2 text-sm font-semibold text-gray-200 hover:bg-gray-800 sm:mt-0 sm:w-auto ring-0"
>
Close
</button>
</div>
</DialogPanel>
</div>
</div>
</Dialog>
);
}
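
Any component can raise this modal through the exported helper. A minimal usage sketch; the HelpButton component here is illustrative, not part of the commit:

```tsx
import React from 'react';
import { openDoc } from '@/components/DocModal';

// Illustrative only: opens the doc popup with an ad-hoc entry.
export function HelpButton() {
  return (
    <button
      onClick={() =>
        openDoc({ title: 'Example Setting', description: <>Explains what this setting does.</> })
      }
    >
      ?
    </button>
  );
}
```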
@@ -1,5 +1,6 @@
import Link from 'next/link';
import { Home, Settings, BrainCircuit, Images, Plus } from 'lucide-react';
import { Home, Settings, BrainCircuit, Images, Plus} from 'lucide-react';
import { FaXTwitter, FaDiscord, FaYoutube } from "react-icons/fa6";

const Sidebar = () => {
const navigation = [
@@ -10,13 +11,16 @@ const Sidebar = () => {
{ name: 'Settings', href: '/settings', icon: Settings },
];

const socialsBoxClass = 'flex flex-col items-center justify-center p-1 hover:bg-gray-800 rounded-lg transition-colors';
const socialIconClass = 'w-5 h-5 text-gray-400 hover:text-white';

return (
<div className="flex flex-col w-59 bg-gray-900 text-gray-100">
<div className="px-4 py-3">
<h1 className="text-l">
<img src="/ostris_logo.png" alt="Ostris AI Toolkit" className="w-auto h-7 mr-3 inline" />
<span className="font-bold uppercase">Ostris</span>
<span className='ml-2 uppercase text-gray-300'>AI-Toolkit</span>
<span className="ml-2 uppercase text-gray-300">AI-Toolkit</span>
</h1>
</div>
<nav className="flex-1">
@@ -34,7 +38,12 @@ const Sidebar = () => {
))}
</ul>
</nav>
<a href="https://ostris.com/support" target="_blank" rel="noreferrer" className="flex items-center space-x-2 px-4 py-3">
<a
href="https://ostris.com/support"
target="_blank"
rel="noreferrer"
className="flex items-center space-x-2 px-4 py-3"
>
<div className="min-w-[26px] min-h-[26px]">
<svg height="24" version="1.1" width="24" xmlns="http://www.w3.org/2000/svg">
<g transform="translate(0 -1028.4)">
@@ -47,6 +56,39 @@ const Sidebar = () => {
</div>
<div className="uppercase text-gray-500 text-sm mb-2 flex-1 pt-2 pl-0">Support AI-Toolkit</div>
</a>

{/* Social links grid */}
<div className="px-1 py-1 border-t border-gray-800">
<div className="grid grid-cols-3 gap-4">
<a
href="https://discord.gg/VXmU2f5WEU"
target="_blank"
rel="noreferrer"
className={socialsBoxClass}
>
<FaDiscord className={socialIconClass} />
{/* <span className="text-xs text-gray-500 mt-1">Discord</span> */}
</a>
<a
href="https://www.youtube.com/@ostrisai"
target="_blank"
rel="noreferrer"
className={socialsBoxClass}
>
<FaYoutube className={socialIconClass} />
{/* <span className="text-xs text-gray-500 mt-1">YouTube</span> */}
</a>
<a
href="https://x.com/ostrisai"
target="_blank"
rel="noreferrer"
className={socialsBoxClass}
>
<FaXTwitter className={socialIconClass} />
{/* <span className="text-xs text-gray-500 mt-1">X</span> */}
</a>
</div>
</div>
</div>
);
};

@@ -3,6 +3,10 @@
import React, { forwardRef } from 'react';
import classNames from 'classnames';
import dynamic from 'next/dynamic';
import { CircleHelp } from 'lucide-react';
import { getDoc } from '@/docs';
import { openDoc } from '@/components/DocModal';

const Select = dynamic(() => import('react-select'), { ssr: false });

const labelClasses = 'block text-xs mb-1 mt-2 text-gray-300';
@@ -11,6 +15,7 @@ const inputClasses =

export interface InputProps {
label?: string;
docKey?: string;
className?: string;
placeholder?: string;
required?: boolean;
@@ -24,10 +29,20 @@ export interface TextInputProps extends InputProps {
}

export const TextInput = forwardRef<HTMLInputElement, TextInputProps>(
({ label, value, onChange, placeholder, required, disabled, type = 'text', className }, ref) => {
({ label, value, onChange, placeholder, required, disabled, type = 'text', className, docKey = null }, ref) => {
const doc = getDoc(docKey);
return (
<div className={classNames(className)}>
{label && <label className={labelClasses}>{label}</label>}
{label && (
<label className={labelClasses}>
{label}{' '}
{doc && (
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
<CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
</div>
)}
</label>
)}
<input
ref={ref}
type={type}
@@ -56,7 +71,8 @@ export interface NumberInputProps extends InputProps {
}

export const NumberInput = (props: NumberInputProps) => {
const { label, value, onChange, placeholder, required, min, max } = props;
const { label, value, onChange, placeholder, required, min, max, docKey = null } = props;
const doc = getDoc(docKey);

// Add controlled internal state to properly handle partial inputs
const [inputValue, setInputValue] = React.useState<string | number>(value ?? '');
@@ -68,7 +84,16 @@

return (
<div className={classNames(props.className)}>
{label && <label className={labelClasses}>{label}</label>}
{label && (
<label className={labelClasses}>
{label}{' '}
{doc && (
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
<CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
</div>
)}
</label>
)}
<input
type="number"
value={inputValue}
@@ -120,7 +145,8 @@ export interface SelectInputProps extends InputProps {
}

export const SelectInput = (props: SelectInputProps) => {
const { label, value, onChange, options } = props;
const { label, value, onChange, options, docKey = null } = props;
const doc = getDoc(docKey);
const selectedOption = options.find(option => option.value === value);
return (
<div
@@ -128,7 +154,16 @@
'opacity-30 cursor-not-allowed': props.disabled,
})}
>
{label && <label className={labelClasses}>{label}</label>}
{label && (
<label className={labelClasses}>
{label}{' '}
{doc && (
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
<CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
</div>
)}
</label>
)}
<Select
value={selectedOption}
options={options}
@@ -200,13 +235,24 @@ export const Checkbox = (props: CheckboxProps) => {
interface FormGroupProps {
label?: string;
className?: string;
docKey?: string;
children: React.ReactNode;
}

export const FormGroup: React.FC<FormGroupProps> = ({ label, className, children }) => {
export const FormGroup: React.FC<FormGroupProps> = ({ label, className, children, docKey = null }) => {
const doc = getDoc(docKey);
return (
<div className={classNames(className)}>
{label && <label className={labelClasses}>{label}</label>}
{label && (
<label className={labelClasses}>
{label}{' '}
{doc && (
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
<CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
</div>
)}
</label>
)}
<div className="px-4 space-y-2">{children}</div>
</div>
);

60
ui/src/docs.tsx
Normal file
@@ -0,0 +1,60 @@
import React from 'react';
import { ConfigDoc } from '@/types';

const docs: { [key: string]: ConfigDoc } = {
'config.name': {
title: 'Training Name',
description: (
<>
The name of the training job. This name will be used to identify the job in the system and will be the filename
of the final model. It must be unique and can only contain alphanumeric characters, underscores, and dashes. No
spaces or special characters are allowed.
</>
),
},
'gpuids': {
title: 'GPU ID',
description: (
<>
This is the GPU that will be used for training. Only one GPU can be used per job at a time via the UI currently.
However, you can start multiple jobs in parallel, each using a different GPU.
</>
),
},
'config.process[0].trigger_word': {
title: 'Trigger Word',
description: (
<>
Optional: This will be the word or token used to trigger your concept or character.
<br />
<br />
When using a trigger word, if your captions do not contain it, the trigger word will be added automatically to the beginning of the caption. If you do not have
captions, the caption will become just the trigger word. If you want to place the trigger word somewhere else in your captions,
you can use the <code>{'[trigger]'}</code> placeholder in your captions. This will be automatically replaced with your trigger word.
<br />
<br />
Trigger words will not automatically be added to your test prompts, so you will need to either add your trigger word manually or
use the <code>{'[trigger]'}</code> placeholder in your test prompts as well.
</>
),
},
'config.process[0].model.name_or_path': {
title: 'Name or Path',
description: (
<>
The name of a diffusers repo on Huggingface or the local path to the base model you want to train from. The folder needs to be in
diffusers format for most models. For some models, such as SDXL and SD1, you can put the path to an all-in-one safetensors checkpoint here.
</>
),
},
};

export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
if (key && key in docs) {
return docs[key];
}
return null;
};

export default docs;
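
Adding a help popup for a new setting is then a two-step change: register an entry in this map and pass the matching docKey to the input component. A sketch with a hypothetical key and setting:

```tsx
import React from 'react';
import { ConfigDoc } from '@/types';

// Hypothetical additional entry: the key only has to match the docKey string
// an input passes, e.g. <NumberInput docKey="config.process[0].train.batch_size" ... />.
export const extraDocs: { [key: string]: ConfigDoc } = {
  'config.process[0].train.batch_size': {
    title: 'Batch Size',
    description: <>How many samples are processed per training step.</>,
  },
};
```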
@@ -59,7 +59,7 @@ export interface NetworkConfig {
lokr_factor: number;
network_kwargs: {
ignore_if_contains: string[];
}
};
}

export interface SaveConfig {
@@ -125,7 +125,7 @@ export interface ModelConfig {
quantize_kwargs?: QuantizeKwargsConfig;
arch: string;
low_vram: boolean;
model_kwargs: {[key: string]: any};
model_kwargs: { [key: string]: any };
}

export interface SampleConfig {
@@ -173,3 +173,8 @@ export interface JobConfig {
config: ConfigObject;
meta: MetaConfig;
}

export interface ConfigDoc {
title: string;
description: React.ReactNode;
}

15
ui/tsconfig.worker.json
Normal file
@@ -0,0 +1,15 @@
{
// tsconfig.worker.json
"compilerOptions": {
"module": "commonjs",
"target": "es2020",
"outDir": "dist",
"moduleResolution": "node",
"types": [
"node"
]
},
"include": [
"cron/**/*.ts"
]
}

@@ -1 +1 @@
VERSION = "0.3.0"
VERSION = "0.3.1"