diff --git a/README.md b/README.md index 5e68605c..afb59dce 100644 --- a/README.md +++ b/README.md @@ -417,6 +417,17 @@ Everything else should work the same including layer targeting. ## Updates +### June 17, 2024 +- Performance optimizations for batch preparation +- Added some docs via a popup for items in the simple ui explaining what settings do. Still a WIP + +### June 16, 2024 +- Hide control images in the UI when viewing datasets +- WIP on mean flow loss + +### June 12, 2024 +- Fixed issue that resulted in blank captions in the dataloader + ### June 10, 2024 - Decided to keep track up updates in the readme - Added support for SDXL in the UI diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index f3813133..b4439f30 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -36,7 +36,6 @@ from toolkit.train_tools import precondition_model_outputs_flow_match from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe from toolkit.util.wavelet_loss import wavelet_loss import torch.nn.functional as F -from toolkit.models.flux import convert_flux_to_mean_flow def flush(): @@ -62,7 +61,6 @@ class SDTrainer(BaseSDTrainProcess): self._clip_image_embeds_unconditional: Union[List[str], None] = None self.negative_prompt_pool: Union[List[str], None] = None self.batch_negative_prompt: Union[List[str], None] = None - self.cfm_cache = None self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16" @@ -85,6 +83,7 @@ class SDTrainer(BaseSDTrainProcess): self.diff_output_preservation_embeds: Optional[PromptEmbeds] = None self.dfe: Optional[DiffusionFeatureExtractor] = None + self.unconditional_embeds = None if self.train_config.diff_output_preservation: if self.trigger_word is None: @@ -96,6 +95,15 @@ class SDTrainer(BaseSDTrainProcess): # always do a prior prediction when doing diff output preservation self.do_prior_prediction = True + + # store the loss target for a batch so we can use it in a loss + self._guidance_loss_target_batch: float = 0.0 + if isinstance(self.train_config.guidance_loss_target, (int, float)): + self._guidance_loss_target_batch = float(self.train_config.guidance_loss_target) + elif isinstance(self.train_config.guidance_loss_target, list): + self._guidance_loss_target_batch = float(self.train_config.guidance_loss_target[0]) + else: + raise ValueError(f"Unknown guidance loss target type {type(self.train_config.guidance_loss_target)}") def before_model_load(self): @@ -135,9 +143,16 @@ class SDTrainer(BaseSDTrainProcess): def hook_before_train_loop(self): super().hook_before_train_loop() - if self.train_config.loss_type == "mean_flow": - # todo handle non flux models - convert_flux_to_mean_flow(self.sd.unet) + + # cache unconditional embeds (blank prompt) + with torch.no_grad(): + self.unconditional_embeds = self.sd.encode_prompt( + [self.train_config.unconditional_prompt], + long_prompts=self.do_long_prompts + ).to( + self.device_torch, + dtype=self.sd.torch_dtype + ).detach() if self.train_config.do_prior_divergence: self.do_prior_prediction = True @@ -418,15 +433,41 @@ class SDTrainer(BaseSDTrainProcess): if self.dfe is not None: if self.dfe.version == 1: - # do diffusion feature extraction on target + model = self.sd + if model is not None and hasattr(model, 'get_stepped_pred'): + stepped_latents = model.get_stepped_pred(noise_pred, noise) + else: + # stepped_latents = noise - noise_pred + # first we step the 
scheduler from current timestep to the very end for a full denoise + bs = noise_pred.shape[0] + noise_pred_chunks = torch.chunk(noise_pred, bs) + timestep_chunks = torch.chunk(timesteps, bs) + noisy_latent_chunks = torch.chunk(noisy_latents, bs) + stepped_chunks = [] + for idx in range(bs): + model_output = noise_pred_chunks[idx] + timestep = timestep_chunks[idx] + self.sd.noise_scheduler._step_index = None + self.sd.noise_scheduler._init_step_index(timestep) + sample = noisy_latent_chunks[idx].to(torch.float32) + + sigma = self.sd.noise_scheduler.sigmas[self.sd.noise_scheduler.step_index] + sigma_next = self.sd.noise_scheduler.sigmas[-1] # use last sigma for final step + prev_sample = sample + (sigma_next - sigma) * model_output + stepped_chunks.append(prev_sample) + + stepped_latents = torch.cat(stepped_chunks, dim=0) + + stepped_latents = stepped_latents.to(self.sd.vae.device, dtype=self.sd.vae.dtype) + pred_features = self.dfe(stepped_latents.float()) with torch.no_grad(): - rectified_flow_target = noise.float() - batch.latents.float() - target_features = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1)) + target_features = self.dfe(batch.latents.to(self.device_torch, dtype=torch.float32)) + # scale dfe so it is weaker at higher noise levels + dfe_scaler = 1 - (timesteps.float() / 1000.0).view(-1, 1, 1, 1).to(self.device_torch) - # do diffusion feature extraction on prediction - pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1)) - additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") * \ - self.train_config.diffusion_feature_extractor_weight + dfe_loss = torch.nn.functional.mse_loss(pred_features, target_features, reduction="none") * \ + self.train_config.diffusion_feature_extractor_weight * dfe_scaler + additional_loss += dfe_loss.mean() elif self.dfe.version == 2: # version 2 # do diffusion feature extraction on target @@ -454,6 +495,47 @@ class SDTrainer(BaseSDTrainProcess): additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight else: raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}") + + if self.train_config.do_guidance_loss: + with torch.no_grad(): + # we make cached blank prompt embeds that match the batch size + unconditional_embeds = concat_prompt_embeds( + [self.unconditional_embeds] * noisy_latents.shape[0], + ) + cfm_pred = self.predict_noise( + noisy_latents=noisy_latents, + timesteps=timesteps, + conditional_embeds=unconditional_embeds, + unconditional_embeds=None, + batch=batch, + ) + + # zero cfg + + # ref https://github.com/WeichenFan/CFG-Zero-star/blob/cdac25559e3f16cb95f0016c04c709ea1ab9452b/wan_pipeline.py#L557 + batch_size = target.shape[0] + positive_flat = target.view(batch_size, -1) + negative_flat = cfm_pred.view(batch_size, -1) + # Calculate dot production + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + # Squared norm of uncondition + squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + st_star = dot_product / squared_norm + + alpha = st_star + + is_video = len(target.shape) == 5 + + alpha = alpha.view(batch_size, 1, 1, 1) if not is_video else alpha.view(batch_size, 1, 1, 1, 1) + + guidance_scale = self._guidance_loss_target_batch + if isinstance(guidance_scale, list): + guidance_scale = torch.tensor(guidance_scale).to(target.device, dtype=target.dtype) + guidance_scale = guidance_scale.view(-1, 1, 1, 1) if 
not is_video else guidance_scale.view(-1, 1, 1, 1, 1) + + unconditional_target = cfm_pred * alpha + target = unconditional_target + guidance_scale * (target - unconditional_target) if target is None: @@ -634,102 +716,6 @@ class SDTrainer(BaseSDTrainProcess): return loss - # ------------------------------------------------------------------ - # Mean-Flow loss (Geng et al., “Mean Flows for One-step Generative - # Modelling”, 2025 – see Alg. 1 + Eq. (6) of the paper) - # This version avoids jvp / double-back-prop issues with Flash-Attention - # adapted from the work of lodestonerock - # ------------------------------------------------------------------ - def get_mean_flow_loss_wip( - self, - noisy_latents: torch.Tensor, - conditional_embeds: PromptEmbeds, - match_adapter_assist: bool, - network_weight_list: list, - timesteps: torch.Tensor, - pred_kwargs: dict, - batch: 'DataLoaderBatchDTO', - noise: torch.Tensor, - unconditional_embeds: Optional[PromptEmbeds] = None, - **kwargs - ): - batch_latents = batch.latents.to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)) - - - time_end = timesteps.float() / 1000 - # for timestep_r, we need values from timestep_end to 0.0 randomly - time_origin = torch.rand_like(time_end, device=self.device_torch, dtype=time_end.dtype) * time_end - - # time_origin = torch.zeros_like(time_end, device=self.device_torch, dtype=time_end.dtype) - # Compute noised data points - # lerp_vector = noisy_latents - # compute instantaneous vector - instantaneous_vector = noise - batch_latents - - # finite difference method - epsilon_fd = 1e-3 - jitter_std = 1e-4 - epsilon_jittered = epsilon_fd + torch.randn(1, device=batch_latents.device) * jitter_std - epsilon_jittered = torch.clamp(epsilon_jittered, min=1e-4) - - # f(x + epsilon * v) for the primal (we backprop through here) - # mean_vec_val_pred = self.forward(lerp_vector, class_label) - mean_vec_val_pred = self.predict_noise( - noisy_latents=noisy_latents, - timesteps=torch.cat([time_end, time_origin], dim=0) * 1000, - conditional_embeds=conditional_embeds, - unconditional_embeds=unconditional_embeds, - batch=batch, - **pred_kwargs - ) - - with torch.no_grad(): - perturbed_time_end = torch.clamp(time_end + epsilon_jittered, 0.0, 1.0) - # intermediate vector to compute tangent approximation f(x + epsilon * v) ! NO GRAD HERE! 
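The guidance-loss branch added to `SDTrainer.calculate_loss` above follows the CFG-Zero* formulation referenced in the comment: the conditional flow target is projected onto the blank-prompt (unconditional) prediction via `st_star`, and the final target is then pushed away from that projection by the configured guidance scale. Below is a minimal, self-contained sketch of just that projection step; the helper name and tensor shapes are illustrative, with random tensors standing in for the real target and unconditional prediction.

```python
# Sketch of the CFG-Zero*-style guided target used by the guidance loss above.
import torch

def build_guided_target(target: torch.Tensor,
                        uncond_pred: torch.Tensor,
                        guidance_scale: float) -> torch.Tensor:
    bs = target.shape[0]
    pos = target.reshape(bs, -1)
    neg = uncond_pred.reshape(bs, -1)
    # st_star = <v_cond, v_uncond> / ||v_uncond||^2, one scalar per batch element
    st_star = (pos * neg).sum(dim=1, keepdim=True) / (neg.pow(2).sum(dim=1, keepdim=True) + 1e-8)
    # broadcast over channel/spatial (and frame, for video) dimensions
    alpha = st_star.view(bs, *([1] * (target.dim() - 1)))
    uncond_target = uncond_pred * alpha
    # push the target away from the projected unconditional direction
    return uncond_target + guidance_scale * (target - uncond_target)

if __name__ == "__main__":
    t = torch.randn(2, 16, 32, 32)   # stand-in for the conditional flow target
    u = torch.randn(2, 16, 32, 32)   # stand-in for the blank-prompt prediction
    print(build_guided_target(t, u, guidance_scale=3.0).shape)  # torch.Size([2, 16, 32, 32])
```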
- perturbed_lerp_vector = noisy_latents + epsilon_jittered * instantaneous_vector - # f_x_plus_eps_v = self.forward(perturbed_lerp_vector, class_label) - f_x_plus_eps_v = self.predict_noise( - noisy_latents=perturbed_lerp_vector, - timesteps=torch.cat([perturbed_time_end, time_origin], dim=0) * 1000, - conditional_embeds=conditional_embeds, - unconditional_embeds=unconditional_embeds, - batch=batch, - **pred_kwargs - ) - - # JVP approximation: (f(x + epsilon * v) - f(x)) / epsilon - mean_vec_grad_fd = (f_x_plus_eps_v - mean_vec_val_pred) / epsilon_jittered - mean_vec_grad = mean_vec_grad_fd - - - # calculate the regression target the mean vector - time_difference_broadcast = (time_end - time_origin)[:, None, None, None] - mean_vec_target = instantaneous_vector - time_difference_broadcast * mean_vec_grad - - # 5) MSE loss - loss = torch.nn.functional.mse_loss( - mean_vec_val_pred.float(), - mean_vec_target.float(), - reduction='none' - ) - with torch.no_grad(): - pure_loss = loss.mean().detach() - # add grad to pure_loss so it can be backwards without issues - pure_loss.requires_grad_(True) - # normalize the loss per batch element to 1.0 - # this method has large loss swings that can hurt the model. This method will prevent that - with torch.no_grad(): - loss_mean = loss.mean([1, 2, 3], keepdim=True) - loss = loss / loss_mean - loss = loss.mean() - - # backward the pure loss for logging - self.accelerator.backward(loss) - - # return the real loss for logging - return pure_loss - - # ------------------------------------------------------------------ # Mean-Flow loss (Geng et al., “Mean Flows for One-step Generative # Modelling”, 2025 – see Alg. 1 + Eq. (6) of the paper) @@ -970,6 +956,10 @@ class SDTrainer(BaseSDTrainProcess): if unconditional_embeds is not None: unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach() + + guidance_embedding_scale = self.train_config.cfg_scale + if self.train_config.do_guidance_loss: + guidance_embedding_scale = self._guidance_loss_target_batch prior_pred = self.sd.predict_noise( latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(), @@ -977,6 +967,7 @@ class SDTrainer(BaseSDTrainProcess): unconditional_embeddings=unconditional_embeds, timestep=timesteps, guidance_scale=self.train_config.cfg_scale, + guidance_embedding_scale=guidance_embedding_scale, rescale_cfg=self.train_config.cfg_rescale, batch=batch, **pred_kwargs # adapter residuals in here @@ -1020,13 +1011,16 @@ class SDTrainer(BaseSDTrainProcess): **kwargs, ): dtype = get_torch_dtype(self.train_config.dtype) + guidance_embedding_scale = self.train_config.cfg_scale + if self.train_config.do_guidance_loss: + guidance_embedding_scale = self._guidance_loss_target_batch return self.sd.predict_noise( latents=noisy_latents.to(self.device_torch, dtype=dtype), conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), unconditional_embeddings=unconditional_embeds, timestep=timesteps, guidance_scale=self.train_config.cfg_scale, - guidance_embedding_scale=self.train_config.cfg_scale, + guidance_embedding_scale=guidance_embedding_scale, detach_unconditional=False, rescale_cfg=self.train_config.cfg_rescale, bypass_guidance_embedding=self.train_config.bypass_guidance_embedding, @@ -1034,152 +1028,78 @@ class SDTrainer(BaseSDTrainProcess): **kwargs ) - def cfm_augment_tensors( - self, - images: torch.Tensor - ) -> torch.Tensor: - if self.cfm_cache is None: - # flip the current one. 
Only need this for first time - self.cfm_cache = torch.flip(images, [3]).clone() - augmented_tensor_list = [] - for i in range(images.shape[0]): - # get a random one - idx = random.randint(0, self.cfm_cache.shape[0] - 1) - augmented_tensor_list.append(self.cfm_cache[idx:idx + 1]) - augmented = torch.cat(augmented_tensor_list, dim=0) - # resize to match the input - augmented = torch.nn.functional.interpolate(augmented, size=(images.shape[2], images.shape[3]), mode='bilinear') - self.cfm_cache = images.clone() - return augmented - - def get_cfm_loss( - self, - noisy_latents: torch.Tensor, - noise: torch.Tensor, - noise_pred: torch.Tensor, - conditional_embeds: PromptEmbeds, - timesteps: torch.Tensor, - batch: 'DataLoaderBatchDTO', - alpha: float = 0.1, - ): - dtype = get_torch_dtype(self.train_config.dtype) - if hasattr(self.sd, 'get_loss_target'): - target = self.sd.get_loss_target( - noise=noise, - batch=batch, - timesteps=timesteps, - ).detach() - - elif self.sd.is_flow_matching: - # forward ODE - target = (noise - batch.latents).detach() - else: - raise ValueError("CFM loss only works with flow matching") - fm_loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - with torch.no_grad(): - # we need to compute the contrast - cfm_batch_tensors = self.cfm_augment_tensors(batch.tensor).to(self.device_torch, dtype=dtype) - cfm_latents = self.sd.encode_images(cfm_batch_tensors).to(self.device_torch, dtype=dtype) - cfm_noisy_latents = self.sd.add_noise( - original_samples=cfm_latents, - noise=noise, - timesteps=timesteps, - ) - cfm_pred = self.predict_noise( - noisy_latents=cfm_noisy_latents, - timesteps=timesteps, - conditional_embeds=conditional_embeds, - unconditional_embeds=None, - batch=batch, - ) - - # v_neg = torch.nn.functional.normalize(cfm_pred.float(), dim=1) - # v_pos = torch.nn.functional.normalize(noise_pred.float(), dim=1) # shape: (B, C, H, W) - - # # Compute cosine similarity at each pixel - # sim = (v_pos * v_neg).sum(dim=1) # shape: (B, H, W) - - cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6) - # Compute cosine similarity at each pixel - sim = cos(cfm_pred.float(), noise_pred.float()) # shape: (B, H, W) - - # Average over spatial dimensions, then batch - contrastive_loss = -sim.mean() - - loss = fm_loss.mean() + alpha * contrastive_loss - return loss def train_single_accumulation(self, batch: DataLoaderBatchDTO): - self.timer.start('preprocess_batch') - if isinstance(self.adapter, CustomAdapter): - batch = self.adapter.edit_batch_raw(batch) - batch = self.preprocess_batch(batch) - if isinstance(self.adapter, CustomAdapter): - batch = self.adapter.edit_batch_processed(batch) - dtype = get_torch_dtype(self.train_config.dtype) - # sanity check - if self.sd.vae.dtype != self.sd.vae_torch_dtype: - self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype) - if isinstance(self.sd.text_encoder, list): - for encoder in self.sd.text_encoder: - if encoder.dtype != self.sd.te_torch_dtype: - encoder.to(self.sd.te_torch_dtype) - else: - if self.sd.text_encoder.dtype != self.sd.te_torch_dtype: - self.sd.text_encoder.to(self.sd.te_torch_dtype) - - noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) - if self.train_config.do_cfg or self.train_config.do_random_cfg: - # pick random negative prompts - if self.negative_prompt_pool is not None: - negative_prompts = [] - for i in range(noisy_latents.shape[0]): - num_neg = random.randint(1, self.train_config.max_negative_prompts) - this_neg_prompts = 
[random.choice(self.negative_prompt_pool) for _ in range(num_neg)] - this_neg_prompt = ', '.join(this_neg_prompts) - negative_prompts.append(this_neg_prompt) - self.batch_negative_prompt = negative_prompts - else: - self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])] - - if self.adapter and isinstance(self.adapter, CustomAdapter): - # condition the prompt - # todo handle more than one adapter image - conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts) - - network_weight_list = batch.get_network_weight_list() - if self.train_config.single_item_batching: - network_weight_list = network_weight_list + network_weight_list - - has_adapter_img = batch.control_tensor is not None - has_clip_image = batch.clip_image_tensor is not None - has_clip_image_embeds = batch.clip_image_embeds is not None - # force it to be true if doing regs as we handle those differently - if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]): - has_clip_image = True - if self._clip_image_embeds_unconditional is not None: - has_clip_image_embeds = True # we are caching embeds, handle that differently - has_clip_image = False - - if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img: - raise ValueError( - "IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ") - - match_adapter_assist = False - - # check if we are matching the adapter assistant - if self.assistant_adapter: - if self.train_config.match_adapter_chance == 1.0: - match_adapter_assist = True - elif self.train_config.match_adapter_chance > 0.0: - match_adapter_assist = torch.rand( - (1,), device=self.device_torch, dtype=dtype - ) < self.train_config.match_adapter_chance - - self.timer.stop('preprocess_batch') - - is_reg = False with torch.no_grad(): + self.timer.start('preprocess_batch') + if isinstance(self.adapter, CustomAdapter): + batch = self.adapter.edit_batch_raw(batch) + batch = self.preprocess_batch(batch) + if isinstance(self.adapter, CustomAdapter): + batch = self.adapter.edit_batch_processed(batch) + dtype = get_torch_dtype(self.train_config.dtype) + # sanity check + if self.sd.vae.dtype != self.sd.vae_torch_dtype: + self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype) + if isinstance(self.sd.text_encoder, list): + for encoder in self.sd.text_encoder: + if encoder.dtype != self.sd.te_torch_dtype: + encoder.to(self.sd.te_torch_dtype) + else: + if self.sd.text_encoder.dtype != self.sd.te_torch_dtype: + self.sd.text_encoder.to(self.sd.te_torch_dtype) + + noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) + if self.train_config.do_cfg or self.train_config.do_random_cfg: + # pick random negative prompts + if self.negative_prompt_pool is not None: + negative_prompts = [] + for i in range(noisy_latents.shape[0]): + num_neg = random.randint(1, self.train_config.max_negative_prompts) + this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)] + this_neg_prompt = ', '.join(this_neg_prompts) + negative_prompts.append(this_neg_prompt) + self.batch_negative_prompt = negative_prompts + else: + self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])] + + if self.adapter and isinstance(self.adapter, CustomAdapter): + # condition the prompt + # todo handle more than one adapter image + conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts) + + network_weight_list = 
batch.get_network_weight_list() + if self.train_config.single_item_batching: + network_weight_list = network_weight_list + network_weight_list + + has_adapter_img = batch.control_tensor is not None + has_clip_image = batch.clip_image_tensor is not None + has_clip_image_embeds = batch.clip_image_embeds is not None + # force it to be true if doing regs as we handle those differently + if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]): + has_clip_image = True + if self._clip_image_embeds_unconditional is not None: + has_clip_image_embeds = True # we are caching embeds, handle that differently + has_clip_image = False + + if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img: + raise ValueError( + "IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ") + + match_adapter_assist = False + + # check if we are matching the adapter assistant + if self.assistant_adapter: + if self.train_config.match_adapter_chance == 1.0: + match_adapter_assist = True + elif self.train_config.match_adapter_chance > 0.0: + match_adapter_assist = torch.rand( + (1,), device=self.device_torch, dtype=dtype + ) < self.train_config.match_adapter_chance + + self.timer.stop('preprocess_batch') + + is_reg = False loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype) for idx, file_item in enumerate(batch.file_items): if file_item.is_reg: @@ -1733,6 +1653,16 @@ class SDTrainer(BaseSDTrainProcess): ) pred_kwargs['down_block_additional_residuals'] = down_block_res_samples pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample + + if self.train_config.do_guidance_loss and isinstance(self.train_config.guidance_loss_target, list): + batch_size = noisy_latents.shape[0] + # update the guidance value, random float between guidance_loss_target[0] and guidance_loss_target[1] + self._guidance_loss_target_batch = [ + random.uniform( + self.train_config.guidance_loss_target[0], + self.train_config.guidance_loss_target[1] + ) for _ in range(batch_size) + ] self.before_unet_predict() @@ -1832,25 +1762,15 @@ class SDTrainer(BaseSDTrainProcess): if self.train_config.diff_output_preservation and not do_inverted_masked_prior: prior_to_calculate_loss = None - if self.train_config.loss_type == 'cfm': - loss = self.get_cfm_loss( - noisy_latents=noisy_latents, - noise=noise, - noise_pred=noise_pred, - conditional_embeds=conditional_embeds, - timesteps=timesteps, - batch=batch, - ) - else: - loss = self.calculate_loss( - noise_pred=noise_pred, - noise=noise, - noisy_latents=noisy_latents, - timesteps=timesteps, - batch=batch, - mask_multiplier=mask_multiplier, - prior_pred=prior_to_calculate_loss, - ) + loss = self.calculate_loss( + noise_pred=noise_pred, + noise=noise, + noisy_latents=noisy_latents, + timesteps=timesteps, + batch=batch, + mask_multiplier=mask_multiplier, + prior_pred=prior_to_calculate_loss, + ) if self.train_config.diff_output_preservation: # send the loss backwards otherwise checkpointing will fail diff --git a/info.py b/info.py index 9f2f0a97..2eb2a82e 100644 --- a/info.py +++ b/info.py @@ -1,8 +1,9 @@ from collections import OrderedDict +from version import VERSION v = OrderedDict() v["name"] = "ai-toolkit" v["repo"] = "https://github.com/ostris/ai-toolkit" -v["version"] = "0.1.0" +v["version"] = VERSION software_meta = v diff --git a/jobs/BaseJob.py b/jobs/BaseJob.py index 8efd0097..5b339ebe 100644 --- a/jobs/BaseJob.py +++ 
b/jobs/BaseJob.py @@ -15,7 +15,6 @@ class BaseJob: self.config = config['config'] self.raw_config = config self.job = config['job'] - self.torch_profiler = self.get_conf('torch_profiler', False) self.name = self.get_conf('name', required=True) if 'meta' in config: self.meta = config['meta'] diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 064e89fc..814796e6 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -236,6 +236,14 @@ class BaseSDTrainProcess(BaseTrainProcess): self.ema: ExponentialMovingAverage = None validate_configs(self.train_config, self.model_config, self.save_config) + + do_profiler = self.get_conf('torch_profiler', False) + self.torch_profiler = None if not do_profiler else torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + ) def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]): # override in subclass @@ -575,10 +583,8 @@ class BaseSDTrainProcess(BaseTrainProcess): direct_save = False if self.adapter_config.train_only_image_encoder: direct_save = True - if self.adapter_config.type == 'redux': - direct_save = True - if self.adapter_config.type in ['control_lora', 'subpixel', 'i2v']: - direct_save = True + elif isinstance(self.adapter, CustomAdapter): + direct_save = self.adapter.do_direct_save save_ip_adapter_from_diffusers( state_dict, output_file=file_path, @@ -923,7 +929,10 @@ class BaseSDTrainProcess(BaseTrainProcess): noise = self.get_consistent_noise(latents, batch, dtype=dtype) else: if hasattr(self.sd, 'get_latent_noise_from_latents'): - noise = self.sd.get_latent_noise_from_latents(latents).to(self.device_torch, dtype=dtype) + noise = self.sd.get_latent_noise_from_latents( + latents, + noise_offset=self.train_config.noise_offset + ).to(self.device_torch, dtype=dtype) else: # get noise noise = self.sd.get_latent_noise( @@ -933,17 +942,6 @@ class BaseSDTrainProcess(BaseTrainProcess): batch_size=batch_size, noise_offset=self.train_config.noise_offset, ).to(self.device_torch, dtype=dtype) - - # if self.train_config.random_noise_shift > 0.0: - # # get random noise -1 to 1 - # noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device, - # dtype=noise.dtype) * 2 - 1 - - # # multiply by shift amount - # noise_shift *= self.train_config.random_noise_shift - - # # add to noise - # noise += noise_shift if self.train_config.blended_blur_noise: noise = get_blended_blur_noise( @@ -1014,7 +1012,6 @@ class BaseSDTrainProcess(BaseTrainProcess): dtype = get_torch_dtype(self.train_config.dtype) imgs = None is_reg = any(batch.get_is_reg_list()) - cfm_batch = None if batch.tensor is not None: imgs = batch.tensor imgs = imgs.to(self.device_torch, dtype=dtype) @@ -1087,19 +1084,20 @@ class BaseSDTrainProcess(BaseTrainProcess): # we determine noise from the differential of the latents unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor) - batch_size = len(batch.file_items) - min_noise_steps = self.train_config.min_denoising_steps - max_noise_steps = self.train_config.max_denoising_steps - if self.model_config.refiner_name_or_path is not None: - # if we are not training the unet, then we are only doing refiner and do not need to double up - if self.train_config.train_unet: - max_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at) - do_double = True - else: - min_noise_steps = 
round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at) - do_double = False + with self.timer('prepare_scheduler'): + + batch_size = len(batch.file_items) + min_noise_steps = self.train_config.min_denoising_steps + max_noise_steps = self.train_config.max_denoising_steps + if self.model_config.refiner_name_or_path is not None: + # if we are not training the unet, then we are only doing refiner and do not need to double up + if self.train_config.train_unet: + max_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at) + do_double = True + else: + min_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at) + do_double = False - with self.timer('prepare_noise'): num_train_timesteps = self.train_config.num_train_timesteps if self.train_config.noise_scheduler in ['custom_lcm']: @@ -1146,6 +1144,7 @@ class BaseSDTrainProcess(BaseTrainProcess): self.sd.noise_scheduler.set_timesteps( num_train_timesteps, device=self.device_torch ) + with self.timer('prepare_timesteps_indices'): content_or_style = self.train_config.content_or_style if is_reg: @@ -1195,20 +1194,26 @@ class BaseSDTrainProcess(BaseTrainProcess): timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps else: # todo, some schedulers use indices, otheres use timesteps. Not sure what to do here + min_idx = min_noise_steps + 1 + max_idx = max_noise_steps - 1 + if self.train_config.noise_scheduler == 'flowmatch': + # flowmatch uses indices, so we need to use indices + min_idx = 0 + max_idx = max_noise_steps - 1 timestep_indices = torch.randint( - min_noise_steps + 1, - max_noise_steps - 1, + min_idx, + max_idx, (batch_size,), device=self.device_torch ) timestep_indices = timestep_indices.long() else: raise ValueError(f"Unknown content_or_style {content_or_style}") - + with self.timer('convert_timestep_indices_to_timesteps'): # convert the timestep_indices to a timestep - timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices] - timesteps = torch.stack(timesteps, dim=0) - + timesteps = self.sd.noise_scheduler.timesteps[timestep_indices.long()] + + with self.timer('prepare_noise'): # get noise noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch, timestep=timesteps) @@ -1242,6 +1247,8 @@ class BaseSDTrainProcess(BaseTrainProcess): device=noise.device, dtype=noise.dtype ) * self.train_config.random_noise_multiplier + + with self.timer('make_noisy_latents'): noise = noise * noise_multiplier @@ -2058,6 +2065,8 @@ class BaseSDTrainProcess(BaseTrainProcess): # flush() ### HOOK ### + if self.torch_profiler is not None: + self.torch_profiler.start() with self.accelerator.accumulate(self.modules_being_trained): try: loss_dict = self.hook_train_loop(batch_list) @@ -2069,7 +2078,12 @@ class BaseSDTrainProcess(BaseTrainProcess): for item in batch.file_items: print(f" - {item.path}") raise e - + if self.torch_profiler is not None: + torch.cuda.synchronize() # Make sure all CUDA ops are done + self.torch_profiler.stop() + + print("\n==== Profile Results ====") + print(self.torch_profiler.key_averages().table(sort_by="cpu_time_total", row_limit=1000)) self.timer.stop('train_loop') if not did_first_flush: flush() diff --git a/jobs/process/GenerateProcess.py b/jobs/process/GenerateProcess.py index e0cb32d8..44fc6b28 100644 --- a/jobs/process/GenerateProcess.py +++ b/jobs/process/GenerateProcess.py @@ -113,6 +113,8 @@ class GenerateProcess(BaseProcess): prompt_image_configs = [] 
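The `torch_profiler` option wired into `BaseSDTrainProcess` above builds a `torch.profiler.profile` session, starts it before the train-loop hook, synchronizes CUDA, stops it after the step, and prints the aggregated table. A minimal sketch of that pattern follows; `train_one_step` is a stand-in for the real `hook_train_loop` call and the row limit is illustrative.

```python
# Sketch of the profiler wrapping added around the train loop hook above.
import torch

def profiled_step(train_one_step, enable_profiler: bool = True):
    profiler = None
    if enable_profiler:
        activities = [torch.profiler.ProfilerActivity.CPU]
        if torch.cuda.is_available():
            activities.append(torch.profiler.ProfilerActivity.CUDA)
        profiler = torch.profiler.profile(activities=activities)
        profiler.start()

    loss = train_one_step()

    if profiler is not None:
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # make sure queued CUDA ops finish before reporting
        profiler.stop()
        print(profiler.key_averages().table(sort_by="cpu_time_total", row_limit=50))
    return loss

if __name__ == "__main__":
    profiled_step(lambda: (torch.randn(256, 256) @ torch.randn(256, 256)).sum())
```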
for _ in range(self.generate_config.num_repeats): for prompt in self.generate_config.prompts: + # remove -- + prompt = prompt.replace('--', '').strip() width = self.generate_config.width height = self.generate_config.height # prompt = self.clean_prompt(prompt) diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index c1fd396c..aeca126e 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -1,6 +1,6 @@ import os import time -from typing import List, Optional, Literal, Union, TYPE_CHECKING, Dict +from typing import List, Optional, Literal, Tuple, Union, TYPE_CHECKING, Dict import random import torch @@ -413,7 +413,7 @@ class TrainConfig: self.correct_pred_norm = kwargs.get('correct_pred_norm', False) self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0) - self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, cfm, mean_flow + self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, mean_flow # scale the prediction by this. Increase for more detail, decrease for less self.pred_scaler = kwargs.get('pred_scaler', 1.0) @@ -454,8 +454,12 @@ class TrainConfig: self.bypass_guidance_embedding = kwargs.get('bypass_guidance_embedding', False) # diffusion feature extractor - self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', None) - self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', 1.0) + self.latent_feature_extractor_path = kwargs.get('latent_feature_extractor_path', None) + self.latent_feature_loss_weight = kwargs.get('latent_feature_loss_weight', 1.0) + + # we use this in the code, but it really needs to be called latent_feature_extractor as that makes more sense with new architecture + self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', self.latent_feature_extractor_path) + self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', self.latent_feature_loss_weight) # optimal noise pairing self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1) @@ -463,6 +467,13 @@ class TrainConfig: # forces same noise for the same image at a given size. self.force_consistent_noise = kwargs.get('force_consistent_noise', False) self.blended_blur_noise = kwargs.get('blended_blur_noise', False) + + # contrastive loss + self.do_guidance_loss = kwargs.get('do_guidance_loss', False) + self.guidance_loss_target: Union[int, List[int, int]] = kwargs.get('guidance_loss_target', 3.0) + self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '') + if isinstance(self.guidance_loss_target, tuple): + self.guidance_loss_target = list(self.guidance_loss_target) ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex1', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21'] @@ -827,6 +838,9 @@ class DatasetConfig: self.controls = [self.controls] # remove empty strings self.controls = [control for control in self.controls if control.strip() != ''] + + # if true, will use a fask method to get image sizes. This can result in errors. 
Do not use unless you know what you are doing + self.fast_image_size: bool = kwargs.get('fast_image_size', False) def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]: @@ -1141,3 +1155,6 @@ def validate_configs( if model_config.use_flux_cfg: # bypass the embedding train_config.bypass_guidance_embedding = True + if train_config.bypass_guidance_embedding and train_config.do_guidance_loss: + raise ValueError("Cannot bypass guidance embedding and do guidance loss at the same time. " + "Please set bypass_guidance_embedding to False or do_guidance_loss to False.") diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py index a9ec5d10..cc58c5b5 100644 --- a/toolkit/custom_adapter.py +++ b/toolkit/custom_adapter.py @@ -11,6 +11,7 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO from toolkit.models.clip_fusion import CLIPFusionModule from toolkit.models.clip_pre_processor import CLIPImagePreProcessor from toolkit.models.control_lora_adapter import ControlLoraAdapter +from toolkit.models.mean_flow_adapter import MeanFlowAdapter from toolkit.models.i2v_adapter import I2VAdapter from toolkit.models.subpixel_adapter import SubpixelAdapter from toolkit.models.ilora import InstantLoRAModule @@ -98,6 +99,7 @@ class CustomAdapter(torch.nn.Module): self.single_value_adapter: SingleValueAdapter = None self.redux_adapter: ReduxImageEncoder = None self.control_lora: ControlLoraAdapter = None + self.mean_flow_adapter: MeanFlowAdapter = None self.subpixel_adapter: SubpixelAdapter = None self.i2v_adapter: I2VAdapter = None @@ -125,6 +127,16 @@ class CustomAdapter(torch.nn.Module): dtype=self.sd_ref().dtype, ) self.load_state_dict(loaded_state_dict, strict=False) + + @property + def do_direct_save(self): + # some adapters save their weights directly, others like ip adapters split the state dict + if self.config.train_only_image_encoder: + return True + if self.config.type in ['control_lora', 'subpixel', 'i2v', 'redux', 'mean_flow']: + return True + return False + def setup_adapter(self): torch_dtype = get_torch_dtype(self.sd_ref().dtype) @@ -245,6 +257,13 @@ class CustomAdapter(torch.nn.Module): elif self.adapter_type == 'redux': vision_hidden_size = self.vision_encoder.config.hidden_size self.redux_adapter = ReduxImageEncoder(vision_hidden_size, 4096, self.device, torch_dtype) + elif self.adapter_type == 'mean_flow': + self.mean_flow_adapter = MeanFlowAdapter( + self, + sd=self.sd_ref(), + config=self.config, + train_config=self.train_config + ) elif self.adapter_type == 'control_lora': self.control_lora = ControlLoraAdapter( self, @@ -309,7 +328,7 @@ class CustomAdapter(torch.nn.Module): def setup_clip(self): adapter_config = self.config sd = self.sd_ref() - if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora", "subpixel"]: + if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora", "subpixel", "mean_flow"]: return if self.config.type == 'photo_maker': try: @@ -528,6 +547,14 @@ class CustomAdapter(torch.nn.Module): new_dict[k + '.' + k2] = v2 self.control_lora.load_weights(new_dict, strict=strict) + if self.adapter_type == 'mean_flow': + # state dict is seperated. so recombine it + new_dict = {} + for k, v in state_dict.items(): + for k2, v2 in v.items(): + new_dict[k + '.' + k2] = v2 + self.mean_flow_adapter.load_weights(new_dict, strict=strict) + if self.adapter_type == 'i2v': # state dict is seperated. 
so recombine it new_dict = {} @@ -599,6 +626,11 @@ class CustomAdapter(torch.nn.Module): for k, v in d.items(): state_dict[k] = v return state_dict + elif self.adapter_type == 'mean_flow': + d = self.mean_flow_adapter.get_state_dict() + for k, v in d.items(): + state_dict[k] = v + return state_dict elif self.adapter_type == 'i2v': d = self.i2v_adapter.get_state_dict() for k, v in d.items(): @@ -757,7 +789,7 @@ class CustomAdapter(torch.nn.Module): prompt: Union[List[str], str], is_unconditional: bool = False, ): - if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora', 'subpixel', 'i2v']: + if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora', 'subpixel', 'i2v', 'mean_flow']: return prompt elif self.adapter_type == 'text_encoder': # todo allow for training @@ -1319,6 +1351,10 @@ class CustomAdapter(torch.nn.Module): param_list = self.control_lora.get_params() for param in param_list: yield param + elif self.config.type == 'mean_flow': + param_list = self.mean_flow_adapter.get_params() + for param in param_list: + yield param elif self.config.type == 'i2v': param_list = self.i2v_adapter.get_params() for param in param_list: diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index ba235e91..0c7d7562 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -84,15 +84,16 @@ class FileItemDTO( video.release() size_database[file_key] = (width, height, file_signature) else: - # original method is significantly faster, but some images are read sideways. Not sure why. Do slow method for now. - # process width and height - # try: - # w, h = image_utils.get_image_size(self.path) - # except image_utils.UnknownImageFormat: - # print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \ - # f'This process is faster for png, jpeg') - img = exif_transpose(Image.open(self.path)) - w, h = img.size + if self.dataset_config.fast_image_size: + # original method is significantly faster, but some images are read sideways. Not sure why. Do slow method by default. + try: + w, h = image_utils.get_image_size(self.path) + except image_utils.UnknownImageFormat: + print_once(f'Warning: Some images in the dataset cannot be fast read. 
' + \ + f'This process is faster for png, jpeg') + else: + img = exif_transpose(Image.open(self.path)) + w, h = img.size size_database[file_key] = (w, h, file_signature) self.width: int = w self.height: int = h diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index 17550dde..17dc6b55 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -815,7 +815,10 @@ class BaseModel: # predict the noise residual if self.unet.device != self.device_torch: - self.unet.to(self.device_torch) + try: + self.unet.to(self.device_torch) + except Exception as e: + pass if self.unet.dtype != self.torch_dtype: self.unet = self.unet.to(dtype=self.torch_dtype) diff --git a/toolkit/models/control_lora_adapter.py b/toolkit/models/control_lora_adapter.py index 3588302d..38147ea9 100644 --- a/toolkit/models/control_lora_adapter.py +++ b/toolkit/models/control_lora_adapter.py @@ -135,7 +135,7 @@ class ControlLoraAdapter(torch.nn.Module): network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs if hasattr(sd, 'target_lora_modules'): - network_kwargs['target_lin_modules'] = self.sd.target_lora_modules + network_kwargs['target_lin_modules'] = sd.target_lora_modules if 'ignore_if_contains' not in network_kwargs: network_kwargs['ignore_if_contains'] = [] diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py index 01d7f278..0f7bff09 100644 --- a/toolkit/models/diffusion_feature_extraction.py +++ b/toolkit/models/diffusion_feature_extraction.py @@ -127,18 +127,20 @@ class DFEBlock(nn.Module): self.conv1 = nn.Conv2d(channels, channels, 3, padding=1) self.conv2 = nn.Conv2d(channels, channels, 3, padding=1) self.act = nn.GELU() + self.proj = nn.Conv2d(channels, channels, 1) def forward(self, x): x_in = x x = self.conv1(x) x = self.conv2(x) x = self.act(x) + x = self.proj(x) x = x + x_in return x class DiffusionFeatureExtractor(nn.Module): - def __init__(self, in_channels=32): + def __init__(self, in_channels=16): super().__init__() self.version = 1 num_blocks = 6 diff --git a/toolkit/models/flux.py b/toolkit/models/flux.py index 5d3064f4..0241ce2f 100644 --- a/toolkit/models/flux.py +++ b/toolkit/models/flux.py @@ -176,60 +176,3 @@ def add_model_gpu_splitter_to_flux( transformer._pre_gpu_split_to = transformer.to transformer.to = partial(new_device_to, transformer) - -def mean_flow_time_text_embed_forward(self:CombinedTimestepTextProjEmbeddings, timestep, pooled_projection): - # make zero timestep ending if none is passed - if timestep.shape[0] == pooled_projection.shape[0]: - timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0) # timestep - 0 (final timestep) == same as start timestep - - timesteps_proj = self.time_proj(timestep) - timesteps_emb_combo = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) - - timesteps_emb_start, timesteps_emb_end = timesteps_emb_combo.chunk(2, dim=0) - - timesteps_emb = timesteps_emb_start + timesteps_emb_end - - pooled_projections = self.text_embedder(pooled_projection) - - conditioning = timesteps_emb + pooled_projections - - return conditioning - -def mean_flow_time_text_guidance_embed_forward(self: CombinedTimestepGuidanceTextProjEmbeddings, timestep, guidance, pooled_projection): - # make zero timestep ending if none is passed - if timestep.shape[0] == pooled_projection.shape[0]: - timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0) # timestep - 0 (final timestep) == same as start 
timestep - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) - - guidance_proj = self.time_proj(guidance) - guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D) - - timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0) - - time_guidance_emb = timesteps_emb_start + timesteps_emb_end + guidance_emb - - pooled_projections = self.text_embedder(pooled_projection) - conditioning = time_guidance_emb + pooled_projections - - return conditioning - - -def convert_flux_to_mean_flow( - transformer: FluxTransformer2DModel, -): - if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings): - transformer.time_text_embed.forward = partial( - mean_flow_time_text_embed_forward, transformer.time_text_embed - ) - elif isinstance(transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings): - transformer.time_text_embed.forward = partial( - mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed - ) - else: - raise ValueError( - "Unsupported time_text_embed type: {}".format( - type(transformer.time_text_embed) - ) - ) - \ No newline at end of file diff --git a/toolkit/models/lumina2.py b/toolkit/models/lumina2.py deleted file mode 100644 index 628f28b8..00000000 --- a/toolkit/models/lumina2.py +++ /dev/null @@ -1,567 +0,0 @@ -# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
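The mean-flow helpers removed from `toolkit/models/flux.py` above (and reintroduced through the new `MeanFlowAdapter`) condition the model on a timestep pair: when only the start timestep is supplied, a zero end timestep is appended, both halves are run through the same sinusoidal projection and embedder, and the two resulting embeddings are summed. A small sketch of that pattern, assuming `diffusers` is installed; the frequency and embedding sizes are illustrative.

```python
# Sketch of the (t_start, t_end) timestep-pair embedding used by the mean-flow code above.
import torch
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

class MeanFlowTimestepEmbed(torch.nn.Module):
    def __init__(self, freq_dim: int = 256, embed_dim: int = 1024):
        super().__init__()
        self.time_proj = Timesteps(num_channels=freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0)
        self.timestep_embedder = TimestepEmbedding(in_channels=freq_dim, time_embed_dim=embed_dim)

    def forward(self, timestep: torch.Tensor, batch_size: int) -> torch.Tensor:
        # if only start timesteps are given, pair each with a zero end timestep
        if timestep.shape[0] == batch_size:
            timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0)
        emb = self.timestep_embedder(self.time_proj(timestep))
        emb_start, emb_end = emb.chunk(2, dim=0)
        return emb_start + emb_end

if __name__ == "__main__":
    embed = MeanFlowTimestepEmbed()
    t = torch.rand(4) * 1000  # start timesteps for a batch of 4
    print(embed(t, batch_size=4).shape)  # torch.Size([4, 1024])
```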
- -import math -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.loaders import PeftAdapterMixin -from diffusers.utils import logging -from diffusers.models.attention import LuminaFeedForward -from diffusers.models.attention_processor import Attention -from diffusers.models.embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed -from diffusers.models.modeling_outputs import Transformer2DModelOutput -from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm -import torch -from torch.profiler import profile, record_function, ProfilerActivity - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -do_profile = False - - -class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): - def __init__( - self, - hidden_size: int = 4096, - cap_feat_dim: int = 2048, - frequency_embedding_size: int = 256, - norm_eps: float = 1e-5, - ) -> None: - super().__init__() - - self.time_proj = Timesteps( - num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0 - ) - - self.timestep_embedder = TimestepEmbedding( - in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024) - ) - - self.caption_embedder = nn.Sequential( - RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, hidden_size, bias=True) - ) - - def forward( - self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - timestep_proj = self.time_proj(timestep).type_as(hidden_states) - time_embed = self.timestep_embedder(timestep_proj) - caption_embed = self.caption_embedder(encoder_hidden_states) - return time_embed, caption_embed - - -class Lumina2AttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is - used in the Lumina2Transformer2DModel model. It applies normalization and RoPE on query and key vectors. 
- """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - base_sequence_length: Optional[int] = None, - ) -> torch.Tensor: - batch_size, sequence_length, _ = hidden_states.shape - - # Get Query-Key-Value Pair - query = attn.to_q(hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - query_dim = query.shape[-1] - inner_dim = key.shape[-1] - head_dim = query_dim // attn.heads - dtype = query.dtype - - # Get key-value heads - kv_heads = inner_dim // head_dim - - query = query.view(batch_size, -1, attn.heads, head_dim) - key = key.view(batch_size, -1, kv_heads, head_dim) - value = value.view(batch_size, -1, kv_heads, head_dim) - - # Apply Query-Key Norm if needed - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb, use_real=False) - key = apply_rotary_emb(key, image_rotary_emb, use_real=False) - - query, key = query.to(dtype), key.to(dtype) - - # Apply proportional attention if true - if base_sequence_length is not None: - softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale - else: - softmax_scale = attn.scale - - # perform Grouped-qurey Attention (GQA) - n_rep = attn.heads // kv_heads - if n_rep >= 1: - key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) - value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) - - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1) - attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1) - - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, scale=softmax_scale - ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.type_as(query) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - hidden_states = attn.to_out[1](hidden_states) - return hidden_states - - -class Lumina2TransformerBlock(nn.Module): - def __init__( - self, - dim: int, - num_attention_heads: int, - num_kv_heads: int, - multiple_of: int, - ffn_dim_multiplier: float, - norm_eps: float, - modulation: bool = True, - ) -> None: - super().__init__() - self.head_dim = dim // num_attention_heads - self.modulation = modulation - - self.attn = Attention( - query_dim=dim, - cross_attention_dim=None, - dim_head=dim // num_attention_heads, - qk_norm="rms_norm", - heads=num_attention_heads, - kv_heads=num_kv_heads, - eps=1e-5, - bias=False, - out_bias=False, - processor=Lumina2AttnProcessor2_0(), - ) - - self.feed_forward = LuminaFeedForward( - dim=dim, - inner_dim=4 * dim, - multiple_of=multiple_of, - ffn_dim_multiplier=ffn_dim_multiplier, - ) - - if modulation: - self.norm1 = LuminaRMSNormZero( - embedding_dim=dim, - norm_eps=norm_eps, - norm_elementwise_affine=True, - ) - else: - self.norm1 = RMSNorm(dim, eps=norm_eps) 
- self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) - - self.norm2 = RMSNorm(dim, eps=norm_eps) - self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - image_rotary_emb: torch.Tensor, - temb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if self.modulation: - norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) - attn_output = self.attn( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_hidden_states, - attention_mask=attention_mask, - image_rotary_emb=image_rotary_emb, - ) - hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output) - mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))) - hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output) - else: - norm_hidden_states = self.norm1(hidden_states) - attn_output = self.attn( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_hidden_states, - attention_mask=attention_mask, - image_rotary_emb=image_rotary_emb, - ) - hidden_states = hidden_states + self.norm2(attn_output) - mlp_output = self.feed_forward(self.ffn_norm1(hidden_states)) - hidden_states = hidden_states + self.ffn_norm2(mlp_output) - - return hidden_states - - -class Lumina2RotaryPosEmbed(nn.Module): - def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512), patch_size: int = 2): - super().__init__() - self.theta = theta - self.axes_dim = axes_dim - self.axes_lens = axes_lens - self.patch_size = patch_size - - self.freqs_cis = self._precompute_freqs_cis(axes_dim, axes_lens, theta) - - def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]: - freqs_cis = [] - for i, (d, e) in enumerate(zip(axes_dim, axes_lens)): - emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=torch.float64) - freqs_cis.append(emb) - return freqs_cis - - def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor: - result = [] - for i in range(len(self.axes_dim)): - freqs = self.freqs_cis[i].to(ids.device) - index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64) - result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)) - return torch.cat(result, dim=-1) - - def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): - batch_size = len(hidden_states) - p_h = p_w = self.patch_size - device = hidden_states[0].device - - l_effective_cap_len = attention_mask.sum(dim=1).tolist() - # TODO: this should probably be refactored because all subtensors of hidden_states will be of same shape - img_sizes = [(img.size(1), img.size(2)) for img in hidden_states] - l_effective_img_len = [(H // p_h) * (W // p_w) for (H, W) in img_sizes] - - max_seq_len = max((cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))) - max_img_len = max(l_effective_img_len) - - position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device) - - for i in range(batch_size): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - H, W = img_sizes[i] - H_tokens, W_tokens = H // p_h, W // p_w - assert H_tokens * W_tokens == img_len - - position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) - position_ids[i, cap_len : cap_len + img_len, 0] = cap_len - row_ids = ( - torch.arange(H_tokens, dtype=torch.int32, 
device=device).view(-1, 1).repeat(1, W_tokens).flatten() - ) - col_ids = ( - torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() - ) - position_ids[i, cap_len : cap_len + img_len, 1] = row_ids - position_ids[i, cap_len : cap_len + img_len, 2] = col_ids - - freqs_cis = self._get_freqs_cis(position_ids) - - cap_freqs_cis_shape = list(freqs_cis.shape) - cap_freqs_cis_shape[1] = attention_mask.shape[1] - cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) - - img_freqs_cis_shape = list(freqs_cis.shape) - img_freqs_cis_shape[1] = max_img_len - img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) - - for i in range(batch_size): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] - img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len] - - flat_hidden_states = [] - for i in range(batch_size): - img = hidden_states[i] - C, H, W = img.size() - img = img.view(C, H // p_h, p_h, W // p_w, p_w).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1) - flat_hidden_states.append(img) - hidden_states = flat_hidden_states - padded_img_embed = torch.zeros( - batch_size, max_img_len, hidden_states[0].shape[-1], device=device, dtype=hidden_states[0].dtype - ) - padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device) - for i in range(batch_size): - padded_img_embed[i, : l_effective_img_len[i]] = hidden_states[i] - padded_img_mask[i, : l_effective_img_len[i]] = True - - return ( - padded_img_embed, - padded_img_mask, - img_sizes, - l_effective_cap_len, - l_effective_img_len, - freqs_cis, - cap_freqs_cis, - img_freqs_cis, - max_seq_len, - ) - - -class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): - r""" - Lumina2NextDiT: Diffusion model with a Transformer backbone. - - Parameters: - sample_size (`int`): The width of the latent images. This is fixed during training since - it is used to learn a number of position embeddings. - patch_size (`int`, *optional*, (`int`, *optional*, defaults to 2): - The size of each patch in the image. This parameter defines the resolution of patches fed into the model. - in_channels (`int`, *optional*, defaults to 4): - The number of input channels for the model. Typically, this matches the number of channels in the input - images. - hidden_size (`int`, *optional*, defaults to 4096): - The dimensionality of the hidden layers in the model. This parameter determines the width of the model's - hidden representations. - num_layers (`int`, *optional*, default to 32): - The number of layers in the model. This defines the depth of the neural network. - num_attention_heads (`int`, *optional*, defaults to 32): - The number of attention heads in each attention layer. This parameter specifies how many separate attention - mechanisms are used. - num_kv_heads (`int`, *optional*, defaults to 8): - The number of key-value heads in the attention mechanism, if different from the number of attention heads. - If None, it defaults to num_attention_heads. - multiple_of (`int`, *optional*, defaults to 256): - A factor that the hidden size should be a multiple of. This can help optimize certain hardware - configurations. - ffn_dim_multiplier (`float`, *optional*): - A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on - the model configuration. 
- norm_eps (`float`, *optional*, defaults to 1e-5): - A small value added to the denominator for numerical stability in normalization layers. - scaling_factor (`float`, *optional*, defaults to 1.0): - A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the - overall scale of the model's operations. - """ - - _supports_gradient_checkpointing = True - _no_split_modules = ["Lumina2TransformerBlock"] - _skip_layerwise_casting_patterns = ["x_embedder", "norm"] - - @register_to_config - def __init__( - self, - sample_size: int = 128, - patch_size: int = 2, - in_channels: int = 16, - out_channels: Optional[int] = None, - hidden_size: int = 2304, - num_layers: int = 26, - num_refiner_layers: int = 2, - num_attention_heads: int = 24, - num_kv_heads: int = 8, - multiple_of: int = 256, - ffn_dim_multiplier: Optional[float] = None, - norm_eps: float = 1e-5, - scaling_factor: float = 1.0, - axes_dim_rope: Tuple[int, int, int] = (32, 32, 32), - axes_lens: Tuple[int, int, int] = (300, 512, 512), - cap_feat_dim: int = 1024, - ) -> None: - super().__init__() - self.out_channels = out_channels or in_channels - - # 1. Positional, patch & conditional embeddings - self.rope_embedder = Lumina2RotaryPosEmbed( - theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, patch_size=patch_size - ) - - self.x_embedder = nn.Linear(in_features=patch_size * patch_size * in_channels, out_features=hidden_size) - - self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding( - hidden_size=hidden_size, cap_feat_dim=cap_feat_dim, norm_eps=norm_eps - ) - - # 2. Noise and context refinement blocks - self.noise_refiner = nn.ModuleList( - [ - Lumina2TransformerBlock( - hidden_size, - num_attention_heads, - num_kv_heads, - multiple_of, - ffn_dim_multiplier, - norm_eps, - modulation=True, - ) - for _ in range(num_refiner_layers) - ] - ) - - self.context_refiner = nn.ModuleList( - [ - Lumina2TransformerBlock( - hidden_size, - num_attention_heads, - num_kv_heads, - multiple_of, - ffn_dim_multiplier, - norm_eps, - modulation=False, - ) - for _ in range(num_refiner_layers) - ] - ) - - # 3. Transformer blocks - self.layers = nn.ModuleList( - [ - Lumina2TransformerBlock( - hidden_size, - num_attention_heads, - num_kv_heads, - multiple_of, - ffn_dim_multiplier, - norm_eps, - modulation=True, - ) - for _ in range(num_layers) - ] - ) - - # 4. Output norm & projection - self.norm_out = LuminaLayerNormContinuous( - embedding_dim=hidden_size, - conditioning_embedding_dim=min(hidden_size, 1024), - elementwise_affine=False, - eps=1e-6, - bias=True, - out_dim=patch_size * patch_size * self.out_channels, - ) - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - timestep: torch.Tensor, - encoder_hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - return_dict: bool = True, - ) -> Union[torch.Tensor, Transformer2DModelOutput]: - - hidden_size = self.config.get("hidden_size", 2304) - # pad or slice text encoder - if encoder_hidden_states.shape[2] > hidden_size: - encoder_hidden_states = encoder_hidden_states[:, :, :hidden_size] - elif encoder_hidden_states.shape[2] < hidden_size: - encoder_hidden_states = F.pad(encoder_hidden_states, (0, hidden_size - encoder_hidden_states.shape[2])) - - batch_size = hidden_states.size(0) - - if do_profile: - prof = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - ) - - prof.start() - - # 1. 
Condition, positional & patch embedding - temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states) - - ( - hidden_states, - hidden_mask, - hidden_sizes, - encoder_hidden_len, - hidden_len, - joint_rotary_emb, - encoder_rotary_emb, - hidden_rotary_emb, - max_seq_len, - ) = self.rope_embedder(hidden_states, attention_mask) - - hidden_states = self.x_embedder(hidden_states) - - # 2. Context & noise refinement - for layer in self.context_refiner: - encoder_hidden_states = layer(encoder_hidden_states, attention_mask, encoder_rotary_emb) - - for layer in self.noise_refiner: - hidden_states = layer(hidden_states, hidden_mask, hidden_rotary_emb, temb) - - # 3. Attention mask preparation - mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool) - padded_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size) - for i in range(batch_size): - cap_len = encoder_hidden_len[i] - img_len = hidden_len[i] - mask[i, : cap_len + img_len] = True - padded_hidden_states[i, :cap_len] = encoder_hidden_states[i, :cap_len] - padded_hidden_states[i, cap_len : cap_len + img_len] = hidden_states[i, :img_len] - hidden_states = padded_hidden_states - - # 4. Transformer blocks - for layer in self.layers: - if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(layer, hidden_states, mask, joint_rotary_emb, temb) - else: - hidden_states = layer(hidden_states, mask, joint_rotary_emb, temb) - - # 5. Output norm & projection & unpatchify - hidden_states = self.norm_out(hidden_states, temb) - - height_tokens = width_tokens = self.config.patch_size - output = [] - for i in range(len(hidden_sizes)): - height, width = hidden_sizes[i] - begin = encoder_hidden_len[i] - end = begin + (height // height_tokens) * (width // width_tokens) - output.append( - hidden_states[i][begin:end] - .view(height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels) - .permute(4, 0, 2, 1, 3) - .flatten(3, 4) - .flatten(1, 2) - ) - output = torch.stack(output, dim=0) - - if do_profile: - torch.cuda.synchronize() # Make sure all CUDA ops are done - prof.stop() - - print("\n==== Profile Results ====") - print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=1000)) - - if not return_dict: - return (output,) - return Transformer2DModelOutput(sample=output) \ No newline at end of file diff --git a/toolkit/models/mean_flow_adapter.py b/toolkit/models/mean_flow_adapter.py new file mode 100644 index 00000000..1c5a6c15 --- /dev/null +++ b/toolkit/models/mean_flow_adapter.py @@ -0,0 +1,282 @@ +import inspect +import weakref +import torch +from typing import TYPE_CHECKING +from toolkit.lora_special import LoRASpecialNetwork +from diffusers import FluxTransformer2DModel +from diffusers.models.embeddings import ( + CombinedTimestepTextProjEmbeddings, + CombinedTimestepGuidanceTextProjEmbeddings, +) +from functools import partial + + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig + from toolkit.custom_adapter import CustomAdapter + + +def mean_flow_time_text_embed_forward( + self: CombinedTimestepTextProjEmbeddings, timestep, pooled_projection +): + mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref() + # make zero timestep ending if none is passed + if mean_flow_adapter.is_active and timestep.shape[0] == pooled_projection.shape[0]: + timestep = 
torch.cat( + [timestep, torch.zeros_like(timestep)], dim=0 + ) # timestep - 0 (final timestep) == same as start timestep + + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder( + timesteps_proj.to(dtype=pooled_projection.dtype) + ) # (N, D) + + # mean flow stuff + if mean_flow_adapter.is_active: + # todo make sure that timesteps is batched correctly, I think diffusers expects non batched timesteps + orig_dtype = timesteps_emb.dtype + timesteps_emb = timesteps_emb.to(torch.float32) + timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0) + timesteps_emb = mean_flow_adapter.mean_flow_timestep_embedder( + torch.cat([timesteps_emb_start, timesteps_emb_end], dim=-1) + ) + timesteps_emb = timesteps_emb.to(orig_dtype) + + pooled_projections = self.text_embedder(pooled_projection) + + conditioning = timesteps_emb + pooled_projections + + return conditioning + + +def mean_flow_time_text_guidance_embed_forward( + self: CombinedTimestepGuidanceTextProjEmbeddings, + timestep, + guidance, + pooled_projection, +): + mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref() + # make zero timestep ending if none is passed + if mean_flow_adapter.is_active and timestep.shape[0] == pooled_projection.shape[0]: + timestep = torch.cat( + [timestep, torch.zeros_like(timestep)], dim=0 + ) # timestep - 0 (final timestep) == same as start timestep + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder( + timesteps_proj.to(dtype=pooled_projection.dtype) + ) # (N, D) + + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder( + guidance_proj.to(dtype=pooled_projection.dtype) + ) # (N, D) + + # mean flow stuff + if mean_flow_adapter.is_active: + # todo make sure that timesteps is batched correctly, I think diffusers expects non batched timesteps + orig_dtype = timesteps_emb.dtype + timesteps_emb = timesteps_emb.to(torch.float32) + timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0) + timesteps_emb = mean_flow_adapter.mean_flow_timestep_embedder( + torch.cat([timesteps_emb_start, timesteps_emb_end], dim=-1) + ) + timesteps_emb = timesteps_emb.to(orig_dtype) + + time_guidance_emb = timesteps_emb + guidance_emb + + pooled_projections = self.text_embedder(pooled_projection) + conditioning = time_guidance_emb + pooled_projections + + return conditioning + + +def convert_flux_to_mean_flow( + transformer: FluxTransformer2DModel, +): + if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings): + transformer.time_text_embed.forward = partial( + mean_flow_time_text_embed_forward, transformer.time_text_embed + ) + elif isinstance( + transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings + ): + transformer.time_text_embed.forward = partial( + mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed + ) + else: + raise ValueError( + "Unsupported time_text_embed type: {}".format( + type(transformer.time_text_embed) + ) + ) + + +class MeanFlowAdapter(torch.nn.Module): + def __init__( + self, + adapter: "CustomAdapter", + sd: "StableDiffusion", + config: "AdapterConfig", + train_config: "TrainConfig", + ): + super().__init__() + self.adapter_ref: weakref.ref = weakref.ref(adapter) + self.sd_ref = weakref.ref(sd) + self.model_config: ModelConfig = sd.model_config + self.network_config = config.lora_config + self.train_config = train_config + self.device_torch = sd.device_torch + self.lora = None + + if self.network_config is not None: + network_kwargs 
= ( + {} + if self.network_config.network_kwargs is None + else self.network_config.network_kwargs + ) + if hasattr(sd, "target_lora_modules"): + network_kwargs["target_lin_modules"] = sd.target_lora_modules + + if "ignore_if_contains" not in network_kwargs: + network_kwargs["ignore_if_contains"] = [] + + self.lora = LoRASpecialNetwork( + text_encoder=sd.text_encoder, + unet=sd.unet, + lora_dim=self.network_config.linear, + multiplier=1.0, + alpha=self.network_config.linear_alpha, + train_unet=self.train_config.train_unet, + train_text_encoder=self.train_config.train_text_encoder, + conv_lora_dim=self.network_config.conv, + conv_alpha=self.network_config.conv_alpha, + is_sdxl=self.model_config.is_xl or self.model_config.is_ssd, + is_v2=self.model_config.is_v2, + is_v3=self.model_config.is_v3, + is_pixart=self.model_config.is_pixart, + is_auraflow=self.model_config.is_auraflow, + is_flux=self.model_config.is_flux, + is_lumina2=self.model_config.is_lumina2, + is_ssd=self.model_config.is_ssd, + is_vega=self.model_config.is_vega, + dropout=self.network_config.dropout, + use_text_encoder_1=self.model_config.use_text_encoder_1, + use_text_encoder_2=self.model_config.use_text_encoder_2, + use_bias=False, + is_lorm=False, + network_config=self.network_config, + network_type=self.network_config.type, + transformer_only=self.network_config.transformer_only, + is_transformer=sd.is_transformer, + base_model=sd, + **network_kwargs, + ) + self.lora.force_to(self.device_torch, dtype=torch.float32) + self.lora._update_torch_multiplier() + self.lora.apply_to( + sd.text_encoder, + sd.unet, + self.train_config.train_text_encoder, + self.train_config.train_unet, + ) + self.lora.can_merge_in = False + self.lora.prepare_grad_etc(sd.text_encoder, sd.unet) + if self.train_config.gradient_checkpointing: + self.lora.enable_gradient_checkpointing() + + emb_dim = None + if self.model_config.arch in ["flux", "flex2", "flex2"]: + transformer: FluxTransformer2DModel = sd.unet + emb_dim = ( + transformer.config.num_attention_heads + * transformer.config.attention_head_dim + ) + convert_flux_to_mean_flow(transformer) + else: + raise ValueError(f"Unsupported architecture: {self.model_config.arch}") + + self.mean_flow_timestep_embedder = torch.nn.Linear( + emb_dim * 2, + emb_dim, + ) + + # make the model function as before adding this adapter by initializing the weights + with torch.no_grad(): + self.mean_flow_timestep_embedder.weight.zero_() + self.mean_flow_timestep_embedder.weight[:, :emb_dim] = torch.eye(emb_dim) + self.mean_flow_timestep_embedder.bias.zero_() + + self.mean_flow_timestep_embedder.to(self.device_torch) + + # add our adapter as a weak ref + if self.model_config.arch in ["flux", "flex2", "flex2"]: + sd.unet.time_text_embed.mean_flow_adapter_ref = weakref.ref(self) + + def get_params(self): + if self.lora is not None: + config = { + "text_encoder_lr": self.train_config.lr, + "unet_lr": self.train_config.lr, + } + sig = inspect.signature(self.lora.prepare_optimizer_params) + if "default_lr" in sig.parameters: + config["default_lr"] = self.train_config.lr + if "learning_rate" in sig.parameters: + config["learning_rate"] = self.train_config.lr + params_net = self.lora.prepare_optimizer_params(**config) + + # we want only tensors here + params = [] + for p in params_net: + if isinstance(p, dict): + params += p["params"] + elif isinstance(p, torch.Tensor): + params.append(p) + elif isinstance(p, list): + params += p + else: + params = [] + + # make sure the embedder is float32 + 
self.mean_flow_timestep_embedder.to(torch.float32) + self.mean_flow_timestep_embedder.requires_grad = True + self.mean_flow_timestep_embedder.train() + + params += list(self.mean_flow_timestep_embedder.parameters()) + + # we need to be able to yield from the list like yield from params + + return params + + def load_weights(self, state_dict, strict=True): + lora_sd = {} + mean_flow_embedder_sd = {} + for key, value in state_dict.items(): + if "mean_flow_timestep_embedder" in key: + new_key = key.replace("transformer.mean_flow_timestep_embedder.", "") + mean_flow_embedder_sd[new_key] = value + else: + lora_sd[key] = value + + # todo process state dict before loading for models that need it + if self.lora is not None: + self.lora.load_weights(lora_sd) + self.mean_flow_timestep_embedder.load_state_dict( + mean_flow_embedder_sd, strict=False + ) + + def get_state_dict(self): + if self.lora is not None: + lora_sd = self.lora.get_state_dict(dtype=torch.float32) + else: + lora_sd = {} + # todo make sure we match loras elseware. + mean_flow_embedder_sd = self.mean_flow_timestep_embedder.state_dict() + for key, value in mean_flow_embedder_sd.items(): + lora_sd[f"transformer.mean_flow_timestep_embedder.{key}"] = value + return lora_sd + + @property + def is_active(self): + return self.adapter_ref().is_active diff --git a/toolkit/optimizers/adafactor.py b/toolkit/optimizers/adafactor.py index 4d97b2cd..8897bdc0 100644 --- a/toolkit/optimizers/adafactor.py +++ b/toolkit/optimizers/adafactor.py @@ -109,7 +109,9 @@ class Adafactor(torch.optim.Optimizer): do_paramiter_swapping=False, paramiter_swapping_factor=0.1, stochastic_accumulation=True, + stochastic_rounding=True, ): + self.stochastic_rounding = stochastic_rounding if lr is not None and relative_step: raise ValueError( "Cannot combine manual `lr` and `relative_step=True` options") @@ -354,7 +356,7 @@ class Adafactor(torch.optim.Optimizer): p_data_fp32.add_(-update) - if p.dtype != torch.float32: + if p.dtype != torch.float32 and self.stochastic_rounding: # apply stochastic rounding copy_stochastic(p, p_data_fp32) diff --git a/toolkit/optimizers/optimizer_utils.py b/toolkit/optimizers/optimizer_utils.py index 67991f21..3adae275 100644 --- a/toolkit/optimizers/optimizer_utils.py +++ b/toolkit/optimizers/optimizer_utils.py @@ -112,86 +112,57 @@ def get_format_params(dtype: torch.dtype) -> tuple[int, int]: return 0, 8 # Int8 doesn't have mantissa bits else: raise ValueError(f"Unsupported dtype: {dtype}") + +def copy_stochastic_bf16(target: torch.Tensor, source: torch.Tensor): + # adapted from https://github.com/Nerogar/OneTrainer/blob/411532e85f3cf2b52baa37597f9c145073d54511/modules/util/bf16_stochastic_rounding.py#L5 + # create a random 16 bit integer + result = torch.randint_like( + source, + dtype=torch.int32, + low=0, + high=(1 << 16), + ) + + # add the random number to the lower 16 bit of the mantissa + result.add_(source.view(dtype=torch.int32)) + + # mask off the lower 16 bit of the mantissa + result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32 + + # copy the higher 16 bit into the target tensor + target.copy_(result.view(dtype=torch.float32)) + + del result -def copy_stochastic( - target: torch.Tensor, - source: torch.Tensor, - eps: Optional[float] = None -) -> None: - """ - Performs stochastic rounding from source tensor to target tensor. 
- - Args: - target: Destination tensor (determines the target format) - source: Source tensor (typically float32) - eps: Optional minimum value for stochastic rounding (for numerical stability) - """ +def copy_stochastic(target: torch.Tensor, source: torch.Tensor, eps: Optional[float] = None) -> None: with torch.no_grad(): - # If target is float32, just copy directly + # assert if target is on cpu, throw error + assert target.device.type != 'cpu', "Target is on cpu!" + assert source.device.type != 'cpu', "Source is on cpu!" + if target.dtype == torch.float32: target.copy_(source) return - - # Special handling for int8 - if target.dtype == torch.int8: - # Scale the source values to utilize the full int8 range - scaled = source * 127.0 # Scale to [-127, 127] - - # Add random noise for stochastic rounding - noise = torch.rand_like(scaled) - 0.5 - rounded = torch.round(scaled + noise) - - # Clamp to int8 range - clamped = torch.clamp(rounded, -127, 127) - target.copy_(clamped.to(torch.int8)) + if target.dtype == torch.bfloat16: + copy_stochastic_bf16(target, source) return mantissa_bits, _ = get_format_params(target.dtype) + round_factor = 2 ** (23 - mantissa_bits) - # Convert source to int32 view - source_int = source.view(dtype=torch.int32) + # Add uniform noise for stochastic rounding + noise = torch.rand_like(source, device=source.device) - 0.5 + rounded = torch.round(source * round_factor + noise) + result_float = rounded / round_factor - # Calculate number of bits to round - bits_to_round = 23 - mantissa_bits # 23 is float32 mantissa bits - - # Create random integers for stochastic rounding - rand = torch.randint_like( - source, - dtype=torch.int32, - low=0, - high=(1 << bits_to_round), - ) - - # Add random values to the bits that will be rounded off - result = source_int.clone() - result.add_(rand) - - # Mask to keep only the bits we want - # Create mask with 1s in positions we want to keep - mask = (-1) << bits_to_round - result.bitwise_and_(mask) - - # Handle minimum value threshold if specified - if eps is not None: - eps_int = torch.tensor( - eps, dtype=torch.float32).view(dtype=torch.int32) - zero_mask = (result.abs() < eps_int) - result[zero_mask] = torch.sign(source_int[zero_mask]) * eps_int - - # Convert back to float32 view - result_float = result.view(dtype=torch.float32) - - # Special handling for float8 formats + # Clamp for float8 if target.dtype == torch.float8_e4m3fn: result_float.clamp_(-448.0, 448.0) elif target.dtype == torch.float8_e5m2: result_float.clamp_(-57344.0, 57344.0) - # Copy the result to the target tensor update_parameter(target, result_float) - # target.copy_(result_float) - del result, rand, source_int class Auto8bitTensor: diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 068e747e..de8865f6 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -50,8 +50,7 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAda StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \ StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \ FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, Lumina2Text2ImgPipeline, \ - FluxControlPipeline -from toolkit.models.lumina2 import Lumina2Transformer2DModel + FluxControlPipeline, Lumina2Transformer2DModel import diffusers from diffusers import \ AutoencoderKL, \ @@ -1763,6 +1762,15 @@ class 
StableDiffusion: ) noise = apply_noise_offset(noise, noise_offset) return noise + + def get_latent_noise_from_latents( + self, + latents: torch.Tensor, + noise_offset=0.0 + ): + noise = torch.randn_like(latents) + noise = apply_noise_offset(noise, noise_offset) + return noise def get_time_ids_from_latents(self, latents: torch.Tensor, requires_aesthetic_score=False): VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) @@ -2170,7 +2178,7 @@ class StableDiffusion: noise_pred = self.unet( hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), timestep=t, - attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64), + encoder_attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64), encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), **kwargs, ).sample @@ -2529,8 +2537,8 @@ class StableDiffusion: # Move to vae to device if on cpu if self.vae.device == 'cpu': - self.vae.to(self.device) - latents = latents.to(device, dtype=dtype) + self.vae.to(self.device_torch) + latents = latents.to(self.device_torch, dtype=self.torch_dtype) latents = (latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor'] images = self.vae.decode(latents).sample images = images.to(device, dtype=dtype) diff --git a/ui/cron/worker.ts b/ui/cron/worker.ts new file mode 100644 index 00000000..589393a4 --- /dev/null +++ b/ui/cron/worker.ts @@ -0,0 +1,31 @@ +class CronWorker { + interval: number; + is_running: boolean; + intervalId: NodeJS.Timeout; + constructor() { + this.interval = 1000; // Default interval of 1 second + this.is_running = false; + this.intervalId = setInterval(() => { + this.run(); + }, this.interval); + } + async run() { + if (this.is_running) { + return; + } + this.is_running = true; + try { + // Loop logic here + await this.loop(); + } catch (error) { + console.error('Error in cron worker loop:', error); + } + this.is_running = false; + } + + async loop() {} +} + +// it automatically starts the loop +const cronWorker = new CronWorker(); +console.log('Cron worker started with interval:', cronWorker.interval, 'ms'); diff --git a/ui/package-lock.json b/ui/package-lock.json index f20ef7a4..6bebedce 100644 --- a/ui/package-lock.json +++ b/ui/package-lock.json @@ -30,10 +30,12 @@ "@types/node": "^20", "@types/react": "^19", "@types/react-dom": "^19", + "concurrently": "^9.1.2", "postcss": "^8", "prettier": "^3.5.1", "prettier-basic": "^1.0.0", "tailwindcss": "^3.4.1", + "ts-node-dev": "^2.0.0", "typescript": "^5" } }, @@ -169,6 +171,28 @@ "node": ">=6.9.0" } }, + "node_modules/@cspotcode/source-map-support": { + "version": "0.8.1", + "resolved": "https://registry.npmjs.org/@cspotcode/source-map-support/-/source-map-support-0.8.1.tgz", + "integrity": "sha512-IchNf6dN4tHoMFIn/7OE8LWZ19Y6q/67Bmf6vnGREv8RSbBVb9LPJxEcnwrcwX6ixSvaiGoomAUvu4YSxXrVgw==", + "dev": true, + "dependencies": { + "@jridgewell/trace-mapping": "0.3.9" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/@cspotcode/source-map-support/node_modules/@jridgewell/trace-mapping": { + "version": "0.3.9", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.9.tgz", + "integrity": "sha512-3Belt6tdc8bPgAtbcmdtNJlirVoTmEb5e2gC94PnkwEW9jI6CAHUeoG85tjWP5WquqfavoMtMwiG4P926ZKKuQ==", + "dev": true, + "dependencies": { + "@jridgewell/resolve-uri": "^3.0.3", + "@jridgewell/sourcemap-codec": "^1.4.10" + } + }, "node_modules/@emnapi/runtime": { "version": "1.3.1", 
"resolved": "https://registry.npmjs.org/@emnapi/runtime/-/runtime-1.3.1.tgz", @@ -1168,6 +1192,30 @@ "node": ">= 6" } }, + "node_modules/@tsconfig/node10": { + "version": "1.0.11", + "resolved": "https://registry.npmjs.org/@tsconfig/node10/-/node10-1.0.11.tgz", + "integrity": "sha512-DcRjDCujK/kCk/cUe8Xz8ZSpm8mS3mNNpta+jGCA6USEDfktlNvm1+IuZ9eTcDbNk41BHwpHHeW+N1lKCz4zOw==", + "dev": true + }, + "node_modules/@tsconfig/node12": { + "version": "1.0.11", + "resolved": "https://registry.npmjs.org/@tsconfig/node12/-/node12-1.0.11.tgz", + "integrity": "sha512-cqefuRsh12pWyGsIoBKJA9luFu3mRxCA+ORZvA4ktLSzIuCUtWVxGIuXigEwO5/ywWFMZ2QEGKWvkZG1zDMTag==", + "dev": true + }, + "node_modules/@tsconfig/node14": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@tsconfig/node14/-/node14-1.0.3.tgz", + "integrity": "sha512-ysT8mhdixWK6Hw3i1V2AeRqZ5WfXg1G43mqoYlM2nc6388Fq5jcXyr5mRsqViLx/GJYdoL0bfXD8nmF+Zn/Iow==", + "dev": true + }, + "node_modules/@tsconfig/node16": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/@tsconfig/node16/-/node16-1.0.4.tgz", + "integrity": "sha512-vxhUy4J8lyeyinH7Azl1pdd43GJhZH/tP2weN8TntQblOY+A0XbT8DJk1/oCPuOOyg/Ja757rG0CgHcWC8OfMA==", + "dev": true + }, "node_modules/@types/node": { "version": "20.17.19", "resolved": "https://registry.npmjs.org/@types/node/-/node-20.17.19.tgz", @@ -1207,6 +1255,18 @@ "@types/react": "*" } }, + "node_modules/@types/strip-bom": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/@types/strip-bom/-/strip-bom-3.0.0.tgz", + "integrity": "sha512-xevGOReSYGM7g/kUBZzPqCrR/KYAo+F0yiPc85WFTJa0MSLtyFTVTU6cJu/aV4mid7IffDIWqo69THF2o4JiEQ==", + "dev": true + }, + "node_modules/@types/strip-json-comments": { + "version": "0.0.30", + "resolved": "https://registry.npmjs.org/@types/strip-json-comments/-/strip-json-comments-0.0.30.tgz", + "integrity": "sha512-7NQmHra/JILCd1QqpSzl8+mJRc8ZHz3uDm8YV1Ks9IhK0epEiTw8aIErbvH9PI+6XbqhyIQy3462nEsn7UVzjQ==", + "dev": true + }, "node_modules/abbrev": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/abbrev/-/abbrev-1.1.1.tgz", @@ -1214,6 +1274,30 @@ "license": "ISC", "optional": true }, + "node_modules/acorn": { + "version": "8.15.0", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", + "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", + "dev": true, + "bin": { + "acorn": "bin/acorn" + }, + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/acorn-walk": { + "version": "8.3.4", + "resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.3.4.tgz", + "integrity": "sha512-ueEepnujpqee2o5aIYnvHU6C0A42MNdsIDeqy5BydrkuC5R1ZuUFnm27EeFJGoEHJQgn3uleRvmTXaJgfXbt4g==", + "dev": true, + "dependencies": { + "acorn": "^8.11.0" + }, + "engines": { + "node": ">=0.4.0" + } + }, "node_modules/agent-base": { "version": "6.0.2", "resolved": "https://registry.npmjs.org/agent-base/-/agent-base-6.0.2.tgz", @@ -1465,6 +1549,12 @@ "ieee754": "^1.1.13" } }, + "node_modules/buffer-from": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/buffer-from/-/buffer-from-1.1.2.tgz", + "integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==", + "dev": true + }, "node_modules/busboy": { "version": "1.6.0", "resolved": "https://registry.npmjs.org/busboy/-/busboy-1.6.0.tgz", @@ -1626,6 +1716,49 @@ } ] }, + "node_modules/chalk": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", + "integrity": 
"sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "dev": true, + "dependencies": { + "ansi-styles": "^4.1.0", + "supports-color": "^7.1.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/chalk?sponsor=1" + } + }, + "node_modules/chalk/node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dev": true, + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/chalk/node_modules/supports-color": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", + "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "dev": true, + "dependencies": { + "has-flag": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/chokidar": { "version": "3.6.0", "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.6.0.tgz", @@ -1691,6 +1824,93 @@ "resolved": "https://registry.npmjs.org/client-only/-/client-only-0.0.1.tgz", "integrity": "sha512-IV3Ou0jSMzZrd3pZ48nLkT9DA7Ag1pnPzaiQhpW7c3RbcqqzvzzVu+L8gfqMp/8IM2MQtSiqaCxrrcfu8I8rMA==" }, + "node_modules/cliui": { + "version": "8.0.1", + "resolved": "https://registry.npmjs.org/cliui/-/cliui-8.0.1.tgz", + "integrity": "sha512-BSeNnyus75C4//NQ9gQt1/csTXyo/8Sb+afLAkzAptFuMsod9HFokGNudZpi/oQV73hnVK+sR+5PVRMd+Dr7YQ==", + "dev": true, + "dependencies": { + "string-width": "^4.2.0", + "strip-ansi": "^6.0.1", + "wrap-ansi": "^7.0.0" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/cliui/node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/cliui/node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dev": true, + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/cliui/node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "dev": true + }, + "node_modules/cliui/node_modules/string-width": { + "version": "4.2.3", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "dev": true, + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/cliui/node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": 
"sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "dev": true, + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/cliui/node_modules/wrap-ansi": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz", + "integrity": "sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==", + "dev": true, + "dependencies": { + "ansi-styles": "^4.0.0", + "string-width": "^4.1.0", + "strip-ansi": "^6.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/wrap-ansi?sponsor=1" + } + }, "node_modules/clone": { "version": "2.1.2", "resolved": "https://registry.npmjs.org/clone/-/clone-2.1.2.tgz", @@ -1782,8 +2002,33 @@ "version": "0.0.1", "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", "integrity": "sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==", - "license": "MIT", - "optional": true + "devOptional": true, + "license": "MIT" + }, + "node_modules/concurrently": { + "version": "9.1.2", + "resolved": "https://registry.npmjs.org/concurrently/-/concurrently-9.1.2.tgz", + "integrity": "sha512-H9MWcoPsYddwbOGM6difjVwVZHl63nwMEwDJG/L7VGtuaJhb12h2caPG2tVPWs7emuYix252iGfqOyrz1GczTQ==", + "dev": true, + "dependencies": { + "chalk": "^4.1.2", + "lodash": "^4.17.21", + "rxjs": "^7.8.1", + "shell-quote": "^1.8.1", + "supports-color": "^8.1.1", + "tree-kill": "^1.2.2", + "yargs": "^17.7.2" + }, + "bin": { + "conc": "dist/bin/concurrently.js", + "concurrently": "dist/bin/concurrently.js" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/open-cli-tools/concurrently?sponsor=1" + } }, "node_modules/console-control-strings": { "version": "1.1.0", @@ -1820,6 +2065,12 @@ "node": ">= 6" } }, + "node_modules/create-require": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/create-require/-/create-require-1.1.1.tgz", + "integrity": "sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ==", + "dev": true + }, "node_modules/cross-spawn": { "version": "7.0.6", "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", @@ -1921,6 +2172,15 @@ "integrity": "sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw==", "dev": true }, + "node_modules/diff": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/diff/-/diff-4.0.2.tgz", + "integrity": "sha512-58lmxKSA4BNyLz+HHMUzlOEpg09FV+ev6ZMe3vJihgdxzgcwZ8VoEEPmALCZG9LmqfVoNMMKpttIYTVG6uDY7A==", + "dev": true, + "engines": { + "node": ">=0.3.1" + } + }, "node_modules/dlv": { "version": "1.1.3", "resolved": "https://registry.npmjs.org/dlv/-/dlv-1.1.3.tgz", @@ -1949,6 +2209,15 @@ "node": ">= 0.4" } }, + "node_modules/dynamic-dedupe": { + "version": "0.3.0", + "resolved": "https://registry.npmjs.org/dynamic-dedupe/-/dynamic-dedupe-0.3.0.tgz", + "integrity": "sha512-ssuANeD+z97meYOqd50e04Ze5qp4bPqo8cCkI4TRjZkzAUgIDTrXV1R8QCdINpiI+hw14+rYazvTRdQrz0/rFQ==", + "dev": true, + "dependencies": { + "xtend": "^4.0.0" + } + }, "node_modules/eastasianwidth": { "version": "0.2.0", "resolved": "https://registry.npmjs.org/eastasianwidth/-/eastasianwidth-0.2.0.tgz", @@ -2051,6 +2320,15 @@ "node": ">= 0.4" } }, + "node_modules/escalade": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz", + "integrity": 
"sha512-WUj2qlxaQtO4g6Pq5c29GTcWGDyd8itL8zTlipgECz3JesAiiOKotd8JU6otB3PACgG6xkJUyVhboMS+bje/jA==", + "dev": true, + "engines": { + "node": ">=6" + } + }, "node_modules/escape-string-regexp": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-4.0.0.tgz", @@ -2225,8 +2503,8 @@ "version": "1.0.0", "resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz", "integrity": "sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==", - "license": "ISC", - "optional": true + "devOptional": true, + "license": "ISC" }, "node_modules/fsevents": { "version": "2.3.3", @@ -2322,6 +2600,15 @@ "node": ">=8" } }, + "node_modules/get-caller-file": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/get-caller-file/-/get-caller-file-2.0.5.tgz", + "integrity": "sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg==", + "dev": true, + "engines": { + "node": "6.* || 8.* || >= 10.*" + } + }, "node_modules/get-intrinsic": { "version": "1.2.7", "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.7.tgz", @@ -2421,6 +2708,15 @@ "license": "ISC", "optional": true }, + "node_modules/has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "dev": true, + "engines": { + "node": ">=8" + } + }, "node_modules/has-symbols": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz", @@ -2598,8 +2894,8 @@ "resolved": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz", "integrity": "sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==", "deprecated": "This module is not supported, and leaks memory. Do not use it. 
Check out lru-cache if you want a good and tested way to coalesce async requests by a key value, which is much more comprehensive and powerful.", + "devOptional": true, "license": "ISC", - "optional": true, "dependencies": { "once": "^1.3.0", "wrappy": "1" @@ -2784,6 +3080,12 @@ "resolved": "https://registry.npmjs.org/lines-and-columns/-/lines-and-columns-1.2.4.tgz", "integrity": "sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==" }, + "node_modules/lodash": { + "version": "4.17.21", + "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz", + "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==", + "dev": true + }, "node_modules/loose-envify": { "version": "1.4.0", "resolved": "https://registry.npmjs.org/loose-envify/-/loose-envify-1.4.0.tgz", @@ -2810,6 +3112,12 @@ "react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, + "node_modules/make-error": { + "version": "1.3.6", + "resolved": "https://registry.npmjs.org/make-error/-/make-error-1.3.6.tgz", + "integrity": "sha512-s8UhlNe7vPKomQhC1qFelMokr/Sc3AgNbso3n74mVPA5LTZwkB9NlXf4XPamLxJE8h0gh73rM94xvwRT2CVInw==", + "dev": true + }, "node_modules/make-fetch-happen": { "version": "9.1.0", "resolved": "https://registry.npmjs.org/make-fetch-happen/-/make-fetch-happen-9.1.0.tgz", @@ -3499,8 +3807,8 @@ "version": "1.0.1", "resolved": "https://registry.npmjs.org/path-is-absolute/-/path-is-absolute-1.0.1.tgz", "integrity": "sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg==", + "devOptional": true, "license": "MIT", - "optional": true, "engines": { "node": ">=0.10.0" } @@ -4004,6 +4312,15 @@ "node": ">=8.10.0" } }, + "node_modules/require-directory": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/require-directory/-/require-directory-2.1.1.tgz", + "integrity": "sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/resolve": { "version": "1.22.10", "resolved": "https://registry.npmjs.org/resolve/-/resolve-1.22.10.tgz", @@ -4137,6 +4454,15 @@ "queue-microtask": "^1.2.2" } }, + "node_modules/rxjs": { + "version": "7.8.2", + "resolved": "https://registry.npmjs.org/rxjs/-/rxjs-7.8.2.tgz", + "integrity": "sha512-dhKf903U/PQZY6boNNtAGdWbG85WAbjT/1xYoZIC7FAY0yWapOBQVsVrDl58W86//e1VpMNBtRV4MaXfdMySFA==", + "dev": true, + "dependencies": { + "tslib": "^2.1.0" + } + }, "node_modules/safe-buffer": { "version": "5.2.1", "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", @@ -4247,6 +4573,18 @@ "node": ">=8" } }, + "node_modules/shell-quote": { + "version": "1.8.3", + "resolved": "https://registry.npmjs.org/shell-quote/-/shell-quote-1.8.3.tgz", + "integrity": "sha512-ObmnIF4hXNg1BqhnHmgbDETF8dLPCggZWBjkQfhZpbszZnYur5DUljTcCHii5LC3J5E0yeO/1LIMyH+UvHQgyw==", + "dev": true, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, "node_modules/signal-exit": { "version": "4.1.0", "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-4.1.0.tgz", @@ -4370,6 +4708,25 @@ "node": ">=0.10.0" } }, + "node_modules/source-map-support": { + "version": "0.5.21", + "resolved": "https://registry.npmjs.org/source-map-support/-/source-map-support-0.5.21.tgz", + "integrity": "sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w==", + "dev": true, + 
"dependencies": { + "buffer-from": "^1.0.0", + "source-map": "^0.6.0" + } + }, + "node_modules/source-map-support/node_modules/source-map": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", + "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/sprintf-js": { "version": "1.1.3", "resolved": "https://registry.npmjs.org/sprintf-js/-/sprintf-js-1.1.3.tgz", @@ -4545,6 +4902,15 @@ "node": ">=8" } }, + "node_modules/strip-bom": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/strip-bom/-/strip-bom-3.0.0.tgz", + "integrity": "sha512-vavAMRXOgBVNF6nyEEmL3DBK19iRpDcoIwW+swQ+CbGiu7lju6t+JklA1MHweoWtadgt4ISVUsXLyDq34ddcwA==", + "dev": true, + "engines": { + "node": ">=4" + } + }, "node_modules/strip-json-comments": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-2.0.1.tgz", @@ -4603,6 +4969,21 @@ "node": ">=16 || 14 >=14.17" } }, + "node_modules/supports-color": { + "version": "8.1.1", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-8.1.1.tgz", + "integrity": "sha512-MpUEN2OodtUzxvKQl72cUF7RQ5EiHsGvSsVG0ia9c5RbWGL2CI4C7EpPS8UTBIplnlzZiNuV56w+FuNxy3ty2Q==", + "dev": true, + "dependencies": { + "has-flag": "^4.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/supports-color?sponsor=1" + } + }, "node_modules/supports-preserve-symlinks-flag": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz", @@ -4749,12 +5130,172 @@ "node": ">=8.0" } }, + "node_modules/tree-kill": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/tree-kill/-/tree-kill-1.2.2.tgz", + "integrity": "sha512-L0Orpi8qGpRG//Nd+H90vFB+3iHnue1zSSGmNOOCh1GLJ7rUKVwV2HvijphGQS2UmhUZewS9VgvxYIdgr+fG1A==", + "dev": true, + "bin": { + "tree-kill": "cli.js" + } + }, "node_modules/ts-interface-checker": { "version": "0.1.13", "resolved": "https://registry.npmjs.org/ts-interface-checker/-/ts-interface-checker-0.1.13.tgz", "integrity": "sha512-Y/arvbn+rrz3JCKl9C4kVNfTfSm2/mEp5FSz5EsZSANGPSlQrpRI5M4PKF+mJnE52jOO90PnPSc3Ur3bTQw0gA==", "dev": true }, + "node_modules/ts-node-dev": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/ts-node-dev/-/ts-node-dev-2.0.0.tgz", + "integrity": "sha512-ywMrhCfH6M75yftYvrvNarLEY+SUXtUvU8/0Z6llrHQVBx12GiFk5sStF8UdfE/yfzk9IAq7O5EEbTQsxlBI8w==", + "dev": true, + "dependencies": { + "chokidar": "^3.5.1", + "dynamic-dedupe": "^0.3.0", + "minimist": "^1.2.6", + "mkdirp": "^1.0.4", + "resolve": "^1.0.0", + "rimraf": "^2.6.1", + "source-map-support": "^0.5.12", + "tree-kill": "^1.2.2", + "ts-node": "^10.4.0", + "tsconfig": "^7.0.0" + }, + "bin": { + "ts-node-dev": "lib/bin.js", + "tsnd": "lib/bin.js" + }, + "engines": { + "node": ">=0.8.0" + }, + "peerDependencies": { + "node-notifier": "*", + "typescript": "*" + }, + "peerDependenciesMeta": { + "node-notifier": { + "optional": true + } + } + }, + "node_modules/ts-node-dev/node_modules/arg": { + "version": "4.1.3", + "resolved": "https://registry.npmjs.org/arg/-/arg-4.1.3.tgz", + "integrity": "sha512-58S9QDqG0Xx27YwPSt9fJxivjYl432YCwfDMfZ+71RAqUrZef7LrKQZ3LHLOwCS4FLNBplP533Zx895SeOCHvA==", + "dev": true + }, + "node_modules/ts-node-dev/node_modules/brace-expansion": { + "version": "1.1.12", + "resolved": 
"https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", + "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", + "dev": true, + "dependencies": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "node_modules/ts-node-dev/node_modules/glob": { + "version": "7.2.3", + "resolved": "https://registry.npmjs.org/glob/-/glob-7.2.3.tgz", + "integrity": "sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==", + "deprecated": "Glob versions prior to v9 are no longer supported", + "dev": true, + "dependencies": { + "fs.realpath": "^1.0.0", + "inflight": "^1.0.4", + "inherits": "2", + "minimatch": "^3.1.1", + "once": "^1.3.0", + "path-is-absolute": "^1.0.0" + }, + "engines": { + "node": "*" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/ts-node-dev/node_modules/minimatch": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", + "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "dev": true, + "dependencies": { + "brace-expansion": "^1.1.7" + }, + "engines": { + "node": "*" + } + }, + "node_modules/ts-node-dev/node_modules/rimraf": { + "version": "2.7.1", + "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-2.7.1.tgz", + "integrity": "sha512-uWjbaKIK3T1OSVptzX7Nl6PvQ3qAGtKEtVRjRuazjfL3Bx5eI409VZSqgND+4UNnmzLVdPj9FqFJNPqBZFve4w==", + "deprecated": "Rimraf versions prior to v4 are no longer supported", + "dev": true, + "dependencies": { + "glob": "^7.1.3" + }, + "bin": { + "rimraf": "bin.js" + } + }, + "node_modules/ts-node-dev/node_modules/ts-node": { + "version": "10.9.2", + "resolved": "https://registry.npmjs.org/ts-node/-/ts-node-10.9.2.tgz", + "integrity": "sha512-f0FFpIdcHgn8zcPSbf1dRevwt047YMnaiJM3u2w2RewrB+fob/zePZcrOyQoLMMO7aBIddLcQIEK5dYjkLnGrQ==", + "dev": true, + "dependencies": { + "@cspotcode/source-map-support": "^0.8.0", + "@tsconfig/node10": "^1.0.7", + "@tsconfig/node12": "^1.0.7", + "@tsconfig/node14": "^1.0.0", + "@tsconfig/node16": "^1.0.2", + "acorn": "^8.4.1", + "acorn-walk": "^8.1.1", + "arg": "^4.1.0", + "create-require": "^1.1.0", + "diff": "^4.0.1", + "make-error": "^1.1.1", + "v8-compile-cache-lib": "^3.0.1", + "yn": "3.1.1" + }, + "bin": { + "ts-node": "dist/bin.js", + "ts-node-cwd": "dist/bin-cwd.js", + "ts-node-esm": "dist/bin-esm.js", + "ts-node-script": "dist/bin-script.js", + "ts-node-transpile-only": "dist/bin-transpile.js", + "ts-script": "dist/bin-script-deprecated.js" + }, + "peerDependencies": { + "@swc/core": ">=1.2.50", + "@swc/wasm": ">=1.2.50", + "@types/node": "*", + "typescript": ">=2.7" + }, + "peerDependenciesMeta": { + "@swc/core": { + "optional": true + }, + "@swc/wasm": { + "optional": true + } + } + }, + "node_modules/tsconfig": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/tsconfig/-/tsconfig-7.0.0.tgz", + "integrity": "sha512-vZXmzPrL+EmC4T/4rVlT2jNVMWCi/O4DIiSj3UHg1OE5kCKbk4mfrXc6dZksLgRM/TZlKnousKH9bbTazUWRRw==", + "dev": true, + "dependencies": { + "@types/strip-bom": "^3.0.0", + "@types/strip-json-comments": "0.0.30", + "strip-bom": "^3.0.0", + "strip-json-comments": "^2.0.0" + } + }, "node_modules/tslib": { "version": "2.8.1", "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", @@ -4829,6 +5370,12 @@ "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", "integrity": 
"sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==" }, + "node_modules/v8-compile-cache-lib": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz", + "integrity": "sha512-wa7YjyUGfNZngI/vtK0UHAN+lgDCxBPCylVXGp0zu59Fz5aiGtNXaq3DhIov063MorB+VfufLh3JlF2KdTK3xg==", + "dev": true + }, "node_modules/which": { "version": "2.0.2", "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", @@ -4996,6 +5543,24 @@ "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==", "license": "ISC" }, + "node_modules/xtend": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz", + "integrity": "sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ==", + "dev": true, + "engines": { + "node": ">=0.4" + } + }, + "node_modules/y18n": { + "version": "5.0.8", + "resolved": "https://registry.npmjs.org/y18n/-/y18n-5.0.8.tgz", + "integrity": "sha512-0pfFzegeDWJHJIAmTLRP2DwHjdF5s7jo9tuztdQxAhINCdvS+3nGINqPd00AphqJR/0LhANUS6/+7SCb98YOfA==", + "dev": true, + "engines": { + "node": ">=10" + } + }, "node_modules/yallist": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", @@ -5012,6 +5577,83 @@ "engines": { "node": ">= 14" } + }, + "node_modules/yargs": { + "version": "17.7.2", + "resolved": "https://registry.npmjs.org/yargs/-/yargs-17.7.2.tgz", + "integrity": "sha512-7dSzzRQ++CKnNI/krKnYRV7JKKPUXMEh61soaHKg9mrWEhzFWhFnxPxGl+69cD1Ou63C13NUPCnmIcrvqCuM6w==", + "dev": true, + "dependencies": { + "cliui": "^8.0.1", + "escalade": "^3.1.1", + "get-caller-file": "^2.0.5", + "require-directory": "^2.1.1", + "string-width": "^4.2.3", + "y18n": "^5.0.5", + "yargs-parser": "^21.1.1" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/yargs-parser": { + "version": "21.1.1", + "resolved": "https://registry.npmjs.org/yargs-parser/-/yargs-parser-21.1.1.tgz", + "integrity": "sha512-tVpsJW7DdjecAiFpbIB1e3qxIQsE6NoPc5/eTdrbbIC4h0LVsWhnoa3g+m2HclBIujHzsxZ4VJVA+GUuc2/LBw==", + "dev": true, + "engines": { + "node": ">=12" + } + }, + "node_modules/yargs/node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/yargs/node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "dev": true + }, + "node_modules/yargs/node_modules/string-width": { + "version": "4.2.3", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "dev": true, + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/yargs/node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "dev": true, + "dependencies": 
{ + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/yn": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/yn/-/yn-3.1.1.tgz", + "integrity": "sha512-Ux4ygGWsu2c7isFWe8Yu1YluJmqVhxqK2cLXNQA5AcC3QfbGNpM7fu0Y8b/z16pXLnFxZYvWhd3fhBY9DLmC6Q==", + "dev": true, + "engines": { + "node": ">=6" + } } } } diff --git a/ui/package.json b/ui/package.json index 43e0d861..de04bd2b 100644 --- a/ui/package.json +++ b/ui/package.json @@ -3,9 +3,9 @@ "version": "0.1.0", "private": true, "scripts": { - "dev": "next dev --turbopack", - "build": "next build", - "start": "next start --port 8675", + "dev": "concurrently -k -n WORKER,UI \"ts-node-dev --respawn --watch cron --transpile-only cron/worker.ts\" \"next dev --turbopack\"", + "build": "tsc -p tsconfig.worker.json && next build", + "start": "concurrently --restart-tries -1 --restart-after 1000 -n WORKER,UI \"node dist/worker.js\" \"next start --port 8675\"", "build_and_start": "npm install && npm run update_db && npm run build && npm run start", "lint": "next lint", "update_db": "npx prisma generate && npx prisma db push", @@ -34,10 +34,12 @@ "@types/node": "^20", "@types/react": "^19", "@types/react-dom": "^19", + "concurrently": "^9.1.2", "postcss": "^8", "prettier": "^3.5.1", "prettier-basic": "^1.0.0", "tailwindcss": "^3.4.1", + "ts-node-dev": "^2.0.0", "typescript": "^5" }, "prettier": "prettier-basic" diff --git a/ui/prisma/schema.prisma b/ui/prisma/schema.prisma index 1489e26e..96f6399b 100644 --- a/ui/prisma/schema.prisma +++ b/ui/prisma/schema.prisma @@ -26,3 +26,13 @@ model Job { info String @default("") speed_string String @default("") } + +model Queue { + id String @id @default(uuid()) + channel String + job_id String + created_at DateTime @default(now()) + updated_at DateTime @updatedAt + status String @default("waiting") + @@index([job_id, channel]) +} \ No newline at end of file diff --git a/ui/src/app/api/datasets/listImages/route.ts b/ui/src/app/api/datasets/listImages/route.ts index 55a11057..06dca84a 100644 --- a/ui/src/app/api/datasets/listImages/route.ts +++ b/ui/src/app/api/datasets/listImages/route.ts @@ -45,7 +45,7 @@ function findImagesRecursively(dir: string): string[] { const itemPath = path.join(dir, item); const stat = fs.statSync(itemPath); - if (stat.isDirectory()) { + if (stat.isDirectory() && item !== '_controls' && !item.startsWith('.')) { // If it's a directory, recursively search it results = results.concat(findImagesRecursively(itemPath)); } else { diff --git a/ui/src/app/jobs/new/SimpleJob.tsx b/ui/src/app/jobs/new/SimpleJob.tsx index ca587564..0b8b3e47 100644 --- a/ui/src/app/jobs/new/SimpleJob.tsx +++ b/ui/src/app/jobs/new/SimpleJob.tsx @@ -47,6 +47,7 @@ export default function SimpleJob({ setJobConfig(value, 'config.name')} placeholder="Enter training name" disabled={runId !== null} @@ -55,12 +56,14 @@ export default function SimpleJob({ setGpuIDs(value)} options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))} /> { if (value?.trim() === '') { value = null; @@ -120,6 +123,7 @@ export default function SimpleJob({ { if (value?.trim() === '') { value = null; @@ -185,22 +189,20 @@ export default function SimpleJob({ max={1024} required /> - { - modelArch?.disableSections?.includes('network.conv') ? null : ( - { - console.log('onChange', value); - setJobConfig(value, 'config.process[0].network.conv'); - setJobConfig(value, 'config.process[0].network.conv_alpha'); - }} - placeholder="eg. 
16" - min={0} - max={1024} - /> - ) - } + {modelArch?.disableSections?.includes('network.conv') ? null : ( + { + console.log('onChange', value); + setJobConfig(value, 'config.process[0].network.conv'); + setJobConfig(value, 'config.process[0].network.conv_alpha'); + }} + placeholder="eg. 16" + min={0} + max={1024} + /> + )} )} diff --git a/ui/src/app/layout.tsx b/ui/src/app/layout.tsx index 50f8d79a..e45363f9 100644 --- a/ui/src/app/layout.tsx +++ b/ui/src/app/layout.tsx @@ -7,6 +7,7 @@ import ConfirmModal from '@/components/ConfirmModal'; import SampleImageModal from '@/components/SampleImageModal'; import { Suspense } from 'react'; import AuthWrapper from '@/components/AuthWrapper'; +import DocModal from '@/components/DocModal'; export const dynamic = 'force-dynamic'; @@ -38,6 +39,7 @@ export default function RootLayout({ children }: { children: React.ReactNode }) + diff --git a/ui/src/components/DocModal.tsx b/ui/src/components/DocModal.tsx new file mode 100644 index 00000000..bfdd6bf4 --- /dev/null +++ b/ui/src/components/DocModal.tsx @@ -0,0 +1,59 @@ +'use client'; +import { createGlobalState } from 'react-global-hooks'; +import { Dialog, DialogBackdrop, DialogPanel, DialogTitle } from '@headlessui/react'; +import React from 'react'; +import { ConfigDoc } from '@/types'; + +export const docState = createGlobalState(null); + +export const openDoc = (doc: ConfigDoc) => { + docState.set({ ...doc }); +}; + +export default function DocModal() { + const [doc, setDoc] = docState.use(); + const isOpen = !!doc; + + const onClose = () => { + setDoc(null); + }; + + return ( + + + +
+        {/* [dialog markup lost in extraction; recoverable content below] */}
+        {doc?.title || 'Confirm Action'}
+        {doc?.description}
+ ); +} diff --git a/ui/src/components/Sidebar.tsx b/ui/src/components/Sidebar.tsx index 324b6097..a5b3e2d6 100644 --- a/ui/src/components/Sidebar.tsx +++ b/ui/src/components/Sidebar.tsx @@ -1,5 +1,6 @@ import Link from 'next/link'; -import { Home, Settings, BrainCircuit, Images, Plus } from 'lucide-react'; +import { Home, Settings, BrainCircuit, Images, Plus} from 'lucide-react'; +import { FaXTwitter, FaDiscord, FaYoutube } from "react-icons/fa6"; const Sidebar = () => { const navigation = [ @@ -10,13 +11,16 @@ const Sidebar = () => { { name: 'Settings', href: '/settings', icon: Settings }, ]; + const socialsBoxClass = 'flex flex-col items-center justify-center p-1 hover:bg-gray-800 rounded-lg transition-colors'; + const socialIconClass = 'w-5 h-5 text-gray-400 hover:text-white'; + return (

        {/* [sidebar markup lost in extraction: "Ostris AI Toolkit" logo, "Ostris" / "AI-Toolkit" title, "Support AI-Toolkit" link] */}
+       {/* Social links grid */}
); }; diff --git a/ui/src/components/formInputs.tsx b/ui/src/components/formInputs.tsx index a9908bd0..a556a266 100644 --- a/ui/src/components/formInputs.tsx +++ b/ui/src/components/formInputs.tsx @@ -3,6 +3,10 @@ import React, { forwardRef } from 'react'; import classNames from 'classnames'; import dynamic from 'next/dynamic'; +import { CircleHelp } from 'lucide-react'; +import { getDoc } from '@/docs'; +import { openDoc } from '@/components/DocModal'; + const Select = dynamic(() => import('react-select'), { ssr: false }); const labelClasses = 'block text-xs mb-1 mt-2 text-gray-300'; @@ -11,6 +15,7 @@ const inputClasses = export interface InputProps { label?: string; + docKey?: string; className?: string; placeholder?: string; required?: boolean; @@ -24,10 +29,20 @@ export interface TextInputProps extends InputProps { } export const TextInput = forwardRef( - ({ label, value, onChange, placeholder, required, disabled, type = 'text', className }, ref) => { + ({ label, value, onChange, placeholder, required, disabled, type = 'text', className, docKey = null }, ref) => { + const doc = getDoc(docKey); return (
- {label && } + {label && ( + + )} { - const { label, value, onChange, placeholder, required, min, max } = props; + const { label, value, onChange, placeholder, required, min, max, docKey = null } = props; + const doc = getDoc(docKey); // Add controlled internal state to properly handle partial inputs const [inputValue, setInputValue] = React.useState(value ?? ''); @@ -68,7 +84,16 @@ export const NumberInput = (props: NumberInputProps) => { return (
- {label && } + {label && ( + + )} { - const { label, value, onChange, options } = props; + const { label, value, onChange, options, docKey = null } = props; + const doc = getDoc(docKey); const selectedOption = options.find(option => option.value === value); return (
{ 'opacity-30 cursor-not-allowed': props.disabled, })} > - {label && } + {label && ( + + )}
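Note on the documentation popup wiring above: the form inputs now accept a `docKey`, look it up with `getDoc` from `@/docs`, and expose the result via `openDoc`, which `DocModal` then renders. The sketch below is only a minimal illustration of how the `ConfigDoc` type and the `@/docs` registry could be shaped; the field names beyond `title`/`description`, the example key string, and the exact `getDoc` signature are assumptions, not the repository's actual implementation.

    // Hypothetical sketch only; key names and exact types are assumptions.
    export interface ConfigDoc {
      title: string;
      description: string; // DocModal renders this under the title
    }

    // Illustrative registry; the real @/docs module defines its own docKey entries.
    const docs: Record<string, ConfigDoc> = {
      'config.name': {
        title: 'Training Name',
        description: 'A unique name for the training job; also used for output files.',
      },
    };

    export function getDoc(key?: string | null): ConfigDoc | null {
      if (!key) return null;
      return docs[key] ?? null;
    }

With something like this in place, an input that passes a matching `docKey` presumably renders the imported `CircleHelp` icon next to its label and calls `openDoc(doc)` on click, setting `docState` and opening the `DocModal` shown earlier; inputs without a matching entry render unchanged.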