diff --git a/.vscode/launch.json b/.vscode/launch.json
index 483703eb..02d5cacf 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -16,6 +16,22 @@
             "console": "integratedTerminal",
             "justMyCode": false
         },
+        {
+            "name": "Run current config (cuda:1)",
+            "type": "python",
+            "request": "launch",
+            "program": "${workspaceFolder}/run.py",
+            "args": [
+                "${file}"
+            ],
+            "env": {
+                "CUDA_LAUNCH_BLOCKING": "1",
+                "DEBUG_TOOLKIT": "1",
+                "CUDA_VISIBLE_DEVICES": "1"
+            },
+            "console": "integratedTerminal",
+            "justMyCode": false
+        },
         {
             "name": "Python: Debug Current File",
             "type": "python",
diff --git a/build_and_push_docker_dev b/build_and_push_docker_dev
new file mode 100644
index 00000000..6a1a17d0
--- /dev/null
+++ b/build_and_push_docker_dev
@@ -0,0 +1,21 @@
+#!/usr/bin/env bash
+
+VERSION=dev
+GIT_COMMIT=dev
+
+echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo."
+echo "Building version: $VERSION and latest"
+# wait 2 seconds
+sleep 2
+
+# Build the image with cache busting
+docker build --build-arg CACHEBUST=$(date +%s) -t aitoolkit:$VERSION -f docker/Dockerfile .
+
+# Tag with version and latest
+docker tag aitoolkit:$VERSION ostris/aitoolkit:$VERSION
+
+# Push both tags
+echo "Pushing images to Docker Hub..."
+docker push ostris/aitoolkit:$VERSION
+
+echo "Successfully built and pushed ostris/aitoolkit:$VERSION"
\ No newline at end of file
diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 8a8ba738..7468e0fc 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -35,6 +35,7 @@ import math
 from toolkit.train_tools import precondition_model_outputs_flow_match
 from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
 from toolkit.util.wavelet_loss import wavelet_loss
+import torch.nn.functional as F
 
 
 def flush():
@@ -60,6 +61,7 @@ class SDTrainer(BaseSDTrainProcess):
         self._clip_image_embeds_unconditional: Union[List[str], None] = None
         self.negative_prompt_pool: Union[List[str], None] = None
         self.batch_negative_prompt: Union[List[str], None] = None
+        self.cfm_cache = None
 
         self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
 
@@ -197,7 +199,7 @@ class SDTrainer(BaseSDTrainProcess):
            flush()
 
        if self.train_config.diffusion_feature_extractor_path is not None:
-            vae = None
+            vae = self.sd.vae
            # if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
            #     vae = self.sd.vae
            self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae)
@@ -756,13 +758,13 @@ class SDTrainer(BaseSDTrainProcess):
         pass
 
     def predict_noise(
-            self,
-            noisy_latents: torch.Tensor,
-            timesteps: Union[int, torch.Tensor] = 1,
-            conditional_embeds: Union[PromptEmbeds, None] = None,
-            unconditional_embeds: Union[PromptEmbeds, None] = None,
-            batch: Optional['DataLoaderBatchDTO'] = None,
-            **kwargs,
+        self,
+        noisy_latents: torch.Tensor,
+        timesteps: Union[int, torch.Tensor] = 1,
+        conditional_embeds: Union[PromptEmbeds, None] = None,
+        unconditional_embeds: Union[PromptEmbeds, None] = None,
+        batch: Optional['DataLoaderBatchDTO'] = None,
+        **kwargs,
     ):
         dtype = get_torch_dtype(self.train_config.dtype)
         return self.sd.predict_noise(
@@ -778,6 +780,81 @@ class SDTrainer(BaseSDTrainProcess):
             batch=batch,
             **kwargs
         )
+
+    def cfm_augment_tensors(
+        self,
+        images: torch.Tensor
+    ) -> torch.Tensor:
+        if self.cfm_cache is None:
+            # flip the current one. Only need this for first time
+            self.cfm_cache = torch.flip(images, [3]).clone()
+        augmented_tensor_list = []
+        for i in range(images.shape[0]):
+            # get a random one
+            idx = random.randint(0, self.cfm_cache.shape[0] - 1)
+            augmented_tensor_list.append(self.cfm_cache[idx:idx + 1])
+        augmented = torch.cat(augmented_tensor_list, dim=0)
+        # resize to match the input
+        augmented = torch.nn.functional.interpolate(augmented, size=(images.shape[2], images.shape[3]), mode='bilinear')
+        self.cfm_cache = images.clone()
+        return augmented
+
+    def get_cfm_loss(
+        self,
+        noisy_latents: torch.Tensor,
+        noise: torch.Tensor,
+        noise_pred: torch.Tensor,
+        conditional_embeds: PromptEmbeds,
+        timesteps: torch.Tensor,
+        batch: 'DataLoaderBatchDTO',
+        alpha: float = 0.1,
+    ):
+        dtype = get_torch_dtype(self.train_config.dtype)
+        if hasattr(self.sd, 'get_loss_target'):
+            target = self.sd.get_loss_target(
+                noise=noise,
+                batch=batch,
+                timesteps=timesteps,
+            ).detach()
+
+        elif self.sd.is_flow_matching:
+            # forward ODE
+            target = (noise - batch.latents).detach()
+        else:
+            raise ValueError("CFM loss only works with flow matching")
+        fm_loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+        with torch.no_grad():
+            # we need to compute the contrast
+            cfm_batch_tensors = self.cfm_augment_tensors(batch.tensor).to(self.device_torch, dtype=dtype)
+            cfm_latents = self.sd.encode_images(cfm_batch_tensors).to(self.device_torch, dtype=dtype)
+            cfm_noisy_latents = self.sd.add_noise(
+                original_samples=cfm_latents,
+                noise=noise,
+                timesteps=timesteps,
+            )
+            cfm_pred = self.predict_noise(
+                noisy_latents=cfm_noisy_latents,
+                timesteps=timesteps,
+                conditional_embeds=conditional_embeds,
+                unconditional_embeds=None,
+                batch=batch,
+            )
+
+        # v_neg = torch.nn.functional.normalize(cfm_pred.float(), dim=1)
+        # v_pos = torch.nn.functional.normalize(noise_pred.float(), dim=1)  # shape: (B, C, H, W)
+
+        # # Compute cosine similarity at each pixel
+        # sim = (v_pos * v_neg).sum(dim=1)  # shape: (B, H, W)
+
+        cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
+        # Compute cosine similarity at each pixel
+        sim = cos(cfm_pred.float(), noise_pred.float())  # shape: (B, H, W)
+
+        # Average over spatial dimensions, then batch
+        contrastive_loss = -sim.mean()
+
+        loss = fm_loss.mean() + alpha * contrastive_loss
+        return loss
 
     def train_single_accumulation(self, batch: DataLoaderBatchDTO):
         self.timer.start('preprocess_batch')
@@ -1431,6 +1508,44 @@
         if self.adapter and isinstance(self.adapter, CustomAdapter):
             noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch)
 
+        if self.train_config.timestep_type == 'next_sample':
+            with self.timer('next_sample_step'):
+                with torch.no_grad():
+
+                    stepped_timestep_indicies = [self.sd.noise_scheduler.index_for_timestep(t) + 1 for t in timesteps]
+                    stepped_timesteps = [self.sd.noise_scheduler.timesteps[x] for x in stepped_timestep_indicies]
+                    stepped_timesteps = torch.stack(stepped_timesteps, dim=0)
+
+                    # do a sample at the current timestep and step it, then determine new noise
+                    next_sample_pred = self.predict_noise(
+                        noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
+                        timesteps=timesteps,
+                        conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
+                        unconditional_embeds=unconditional_embeds,
+                        batch=batch,
+                        **pred_kwargs
+                    )
+                    stepped_latents = self.sd.step_scheduler(
+                        next_sample_pred,
+                        noisy_latents,
+                        timesteps,
+                        self.sd.noise_scheduler
+                    )
+                    # stepped latents is our new noisy latents. Now we need to determine noise in the current sample
+                    noisy_latents = stepped_latents
+                    original_samples = batch.latents.to(self.device_torch, dtype=dtype)
+                    # todo calc next timestep, for now this may work as is
+                    t_01 = (stepped_timesteps / 1000).to(original_samples.device)
+                    if len(stepped_latents.shape) == 4:
+                        t_01 = t_01.view(-1, 1, 1, 1)
+                    elif len(stepped_latents.shape) == 5:
+                        t_01 = t_01.view(-1, 1, 1, 1, 1)
+                    else:
+                        raise ValueError("Unknown stepped latents shape", stepped_latents.shape)
+                    next_sample_noise = (stepped_latents - (1.0 - t_01) * original_samples) / t_01
+                    noise = next_sample_noise
+                    timesteps = stepped_timesteps
+
         with self.timer('predict_unet'):
             noise_pred = self.predict_noise(
                 noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
@@ -1450,15 +1565,25 @@
             if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
                 prior_to_calculate_loss = None
 
-            loss = self.calculate_loss(
-                noise_pred=noise_pred,
-                noise=noise,
-                noisy_latents=noisy_latents,
-                timesteps=timesteps,
-                batch=batch,
-                mask_multiplier=mask_multiplier,
-                prior_pred=prior_to_calculate_loss,
-            )
+            if self.train_config.loss_type == 'cfm':
+                loss = self.get_cfm_loss(
+                    noisy_latents=noisy_latents,
+                    noise=noise,
+                    noise_pred=noise_pred,
+                    conditional_embeds=conditional_embeds,
+                    timesteps=timesteps,
+                    batch=batch,
+                )
+            else:
+                loss = self.calculate_loss(
+                    noise_pred=noise_pred,
+                    noise=noise,
+                    noisy_latents=noisy_latents,
+                    timesteps=timesteps,
+                    batch=batch,
+                    mask_multiplier=mask_multiplier,
+                    prior_pred=prior_to_calculate_loss,
+                )
 
         if self.train_config.diff_output_preservation:
             # send the loss backwards otherwise checkpointing will fail
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 1ccb0c3d..393ba831 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -931,16 +931,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 noise_offset=self.train_config.noise_offset,
             ).to(self.device_torch, dtype=dtype)
 
-            if self.train_config.random_noise_shift > 0.0:
-                # get random noise -1 to 1
-                noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device,
-                                         dtype=noise.dtype) * 2 - 1
+            # if self.train_config.random_noise_shift > 0.0:
+            #     # get random noise -1 to 1
+            #     noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device,
+            #                              dtype=noise.dtype) * 2 - 1
 
-                # multiply by shift amount
-                noise_shift *= self.train_config.random_noise_shift
+            #     # multiply by shift amount
+            #     noise_shift *= self.train_config.random_noise_shift
 
-                # add to noise
-                noise += noise_shift
+            #     # add to noise
+            #     noise += noise_shift
 
         if self.train_config.blended_blur_noise:
             noise = get_blended_blur_noise(
@@ -1011,6 +1011,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         dtype = get_torch_dtype(self.train_config.dtype)
         imgs = None
         is_reg = any(batch.get_is_reg_list())
+        cfm_batch = None
         if batch.tensor is not None:
             imgs = batch.tensor
             imgs = imgs.to(self.device_torch, dtype=dtype)
@@ -1118,8 +1119,13 @@
         if timestep_type is None:
             timestep_type = self.train_config.timestep_type
 
+        if self.train_config.timestep_type == 'next_sample':
+            # simulate a sample
+            num_train_timesteps = self.train_config.next_sample_timesteps
+            timestep_type = 'shift'
+
         patch_size = 1
-        if self.sd.is_flux:
+        if self.sd.is_flux or 'flex' in self.sd.arch:
             # flux is a patch size of 1, but latents are divided by 2, so we need to double it
             patch_size = 2
         elif hasattr(self.sd.unet.config, 'patch_size'):
@@ -1142,7 +1148,15 @@
             content_or_style = self.train_config.content_or_style_reg
 
         # if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
-        if content_or_style in ['style', 'content']:
+        if self.train_config.timestep_type == 'next_sample':
+            timestep_indices = torch.randint(
+                0,
+                num_train_timesteps - 2,  # -1 for 0 idx, -1 so we can step
+                (batch_size,),
+                device=self.device_torch
+            )
+            timestep_indices = timestep_indices.long()
+        elif content_or_style in ['style', 'content']:
             # this is from diffusers training code
             # Cubic sampling for favoring later or earlier timesteps
             # For more details about why cubic sampling is used for content / structure,
@@ -1169,7 +1183,7 @@
                 min_noise_steps + 1,
                 max_noise_steps - 1
             )
-            
+
         elif content_or_style == 'balanced':
             if min_noise_steps == max_noise_steps:
                 timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
@@ -1185,16 +1199,6 @@
         else:
             raise ValueError(f"Unknown content_or_style {content_or_style}")
 
-        # do flow matching
-        # if self.sd.is_flow_matching:
-        #     u = compute_density_for_timestep_sampling(
-        #         weighting_scheme="logit_normal",  # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
-        #         batch_size=batch_size,
-        #         logit_mean=0.0,
-        #         logit_std=1.0,
-        #         mode_scale=1.29,
-        #     )
-        #     timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()
         # convert the timestep_indices to a timestep
         timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
         timesteps = torch.stack(timesteps, dim=0)
@@ -1218,8 +1222,32 @@
             latents = unaugmented_latents
 
         noise_multiplier = self.train_config.noise_multiplier
+
+        s = (noise.shape[0], noise.shape[1], 1, 1)
+        if len(noise.shape) == 5:
+            # if we have a 5d tensor, then we need to do it on a per batch item, per channel basis, per frame
+            s = (noise.shape[0], noise.shape[1], noise.shape[2], 1, 1)
+
+        if self.train_config.random_noise_multiplier > 0.0:
+
+            # do it on a per batch item, per channel basis
+            noise_multiplier = 1 + torch.randn(
+                s,
+                device=noise.device,
+                dtype=noise.dtype
+            ) * self.train_config.random_noise_multiplier
 
         noise = noise * noise_multiplier
+
+        if self.train_config.random_noise_shift > 0.0:
+            # get a random gaussian noise shift
+            noise_shift = torch.randn(
+                s,
+                device=noise.device,
+                dtype=noise.dtype
+            ) * self.train_config.random_noise_shift
+            # add to noise
+            noise += noise_shift
 
         latent_multiplier = self.train_config.latent_multiplier
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index e3f80ae3..aa3d84a6 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -325,6 +325,8 @@ class TrainConfig:
         self.adapter_assist_type: Optional[str] = kwargs.get('adapter_assist_type', 't2i')  # t2i, control_net
         self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
         self.target_noise_multiplier = kwargs.get('target_noise_multiplier', 1.0)
+        self.random_noise_multiplier = kwargs.get('random_noise_multiplier', 0.0)
+        self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
         self.img_multiplier = kwargs.get('img_multiplier', 1.0)
         self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0)
         self.latent_multiplier = kwargs.get('latent_multiplier', 1.0)
@@ -333,7 +335,6 @@ class TrainConfig:
         # multiplier applied to loos on regularization images
         self.reg_weight = kwargs.get('reg_weight', 1.0)
         self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000)
-        self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
         # automatically adapte the vae scaling based on the image norm
         self.adaptive_scaling_factor = kwargs.get('adaptive_scaling_factor', False)
 
@@ -412,7 +413,7 @@
         self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
         self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
 
-        self.loss_type = kwargs.get('loss_type', 'mse')  # mse, mae, wavelet, pixelspace
+        self.loss_type = kwargs.get('loss_type', 'mse')  # mse, mae, wavelet, pixelspace, cfm
 
         # scale the prediction by this. Increase for more detail, decrease for less
         self.pred_scaler = kwargs.get('pred_scaler', 1.0)
@@ -436,7 +437,8 @@
         # adds an additional loss to the network to encourage it output a normalized standard deviation
         self.target_norm_std = kwargs.get('target_norm_std', None)
         self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
-        self.timestep_type = kwargs.get('timestep_type', 'sigmoid')  # sigmoid, linear, lognorm_blend
+        self.timestep_type = kwargs.get('timestep_type', 'sigmoid')  # sigmoid, linear, lognorm_blend, next_sample
+        self.next_sample_timesteps = kwargs.get('next_sample_timesteps', 8)
         self.linear_timesteps = kwargs.get('linear_timesteps', False)
         self.linear_timesteps2 = kwargs.get('linear_timesteps2', False)
         self.disable_sampling = kwargs.get('disable_sampling', False)
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index d35fc09e..b9e1b71d 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -142,7 +142,9 @@ class StableDiffusion:
     ):
         self.accelerator = get_accelerator()
         self.custom_pipeline = custom_pipeline
-        self.device = device
+        self.device = str(device)
+        if "cuda" in self.device and ":" not in self.device:
+            self.device = f"{self.device}:0"
         self.device_torch = torch.device(device)
         self.dtype = dtype
         self.torch_dtype = get_torch_dtype(dtype)
@@ -2086,7 +2088,10 @@ class StableDiffusion:
                 noise_pred = noise_pred
             else:
                 if self.unet.device != self.device_torch:
-                    self.unet.to(self.device_torch)
+                    try:
+                        self.unet.to(self.device_torch)
+                    except Exception as e:
+                        pass
                 if self.unet.dtype != self.torch_dtype:
                     self.unet = self.unet.to(dtype=self.torch_dtype)
                 if self.is_flux:
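
Usage sketch (not part of the patch): a minimal, hedged example of the options this change introduces. It assumes TrainConfig simply forwards keyword arguments to the kwargs.get(...) lookups added above, and the values shown are illustrative; note the 'cfm' loss raises a ValueError for non flow-matching models.

# Illustrative only -- assumes TrainConfig(**kwargs) passes keywords straight
# through to the kwargs.get(...) calls shown in toolkit/config_modules.py.
from toolkit.config_modules import TrainConfig

train_config = TrainConfig(
    loss_type='cfm',               # new: contrastive flow-matching loss (SDTrainer.get_cfm_loss)
    timestep_type='next_sample',   # new: simulate one scheduler step before predicting
    next_sample_timesteps=8,       # new: scheduler length used for the simulated step (default 8)
    random_noise_multiplier=0.1,   # new: per-channel gaussian scaling of the noise (default 0.0)
    random_noise_shift=0.1,        # now a per-channel gaussian shift added to the noise (default 0.0)
)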