mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 17:51:41 +00:00
Added some experimental training techniques. Ignore for now. Still in testing.
This commit is contained in:
16
.vscode/launch.json
vendored
16
.vscode/launch.json
vendored
@@ -16,6 +16,22 @@
|
|||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"justMyCode": false
|
"justMyCode": false
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "Run current config (cuda:1)",
|
||||||
|
"type": "python",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "${workspaceFolder}/run.py",
|
||||||
|
"args": [
|
||||||
|
"${file}"
|
||||||
|
],
|
||||||
|
"env": {
|
||||||
|
"CUDA_LAUNCH_BLOCKING": "1",
|
||||||
|
"DEBUG_TOOLKIT": "1",
|
||||||
|
"CUDA_VISIBLE_DEVICES": "1"
|
||||||
|
},
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"justMyCode": false
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "Python: Debug Current File",
|
"name": "Python: Debug Current File",
|
||||||
"type": "python",
|
"type": "python",
|
||||||
|
|||||||
21
build_and_push_docker_dev
Normal file
21
build_and_push_docker_dev
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
VERSION=dev
|
||||||
|
GIT_COMMIT=dev
|
||||||
|
|
||||||
|
echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo."
|
||||||
|
echo "Building version: $VERSION and latest"
|
||||||
|
# wait 2 seconds
|
||||||
|
sleep 2
|
||||||
|
|
||||||
|
# Build the image with cache busting
|
||||||
|
docker build --build-arg CACHEBUST=$(date +%s) -t aitoolkit:$VERSION -f docker/Dockerfile .
|
||||||
|
|
||||||
|
# Tag with version and latest
|
||||||
|
docker tag aitoolkit:$VERSION ostris/aitoolkit:$VERSION
|
||||||
|
|
||||||
|
# Push both tags
|
||||||
|
echo "Pushing images to Docker Hub..."
|
||||||
|
docker push ostris/aitoolkit:$VERSION
|
||||||
|
|
||||||
|
echo "Successfully built and pushed ostris/aitoolkit:$VERSION"
|
||||||
@@ -35,6 +35,7 @@ import math
|
|||||||
from toolkit.train_tools import precondition_model_outputs_flow_match
|
from toolkit.train_tools import precondition_model_outputs_flow_match
|
||||||
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
|
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
|
||||||
from toolkit.util.wavelet_loss import wavelet_loss
|
from toolkit.util.wavelet_loss import wavelet_loss
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
def flush():
|
def flush():
|
||||||
@@ -60,6 +61,7 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
self._clip_image_embeds_unconditional: Union[List[str], None] = None
|
self._clip_image_embeds_unconditional: Union[List[str], None] = None
|
||||||
self.negative_prompt_pool: Union[List[str], None] = None
|
self.negative_prompt_pool: Union[List[str], None] = None
|
||||||
self.batch_negative_prompt: Union[List[str], None] = None
|
self.batch_negative_prompt: Union[List[str], None] = None
|
||||||
|
self.cfm_cache = None
|
||||||
|
|
||||||
self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
|
self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
|
||||||
|
|
||||||
@@ -197,7 +199,7 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
flush()
|
flush()
|
||||||
|
|
||||||
if self.train_config.diffusion_feature_extractor_path is not None:
|
if self.train_config.diffusion_feature_extractor_path is not None:
|
||||||
vae = None
|
vae = self.sd.vae
|
||||||
# if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
|
# if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
|
||||||
# vae = self.sd.vae
|
# vae = self.sd.vae
|
||||||
self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae)
|
self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae)
|
||||||
@@ -756,13 +758,13 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def predict_noise(
|
def predict_noise(
|
||||||
self,
|
self,
|
||||||
noisy_latents: torch.Tensor,
|
noisy_latents: torch.Tensor,
|
||||||
timesteps: Union[int, torch.Tensor] = 1,
|
timesteps: Union[int, torch.Tensor] = 1,
|
||||||
conditional_embeds: Union[PromptEmbeds, None] = None,
|
conditional_embeds: Union[PromptEmbeds, None] = None,
|
||||||
unconditional_embeds: Union[PromptEmbeds, None] = None,
|
unconditional_embeds: Union[PromptEmbeds, None] = None,
|
||||||
batch: Optional['DataLoaderBatchDTO'] = None,
|
batch: Optional['DataLoaderBatchDTO'] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
dtype = get_torch_dtype(self.train_config.dtype)
|
dtype = get_torch_dtype(self.train_config.dtype)
|
||||||
return self.sd.predict_noise(
|
return self.sd.predict_noise(
|
||||||
@@ -778,6 +780,81 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
batch=batch,
|
batch=batch,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def cfm_augment_tensors(
|
||||||
|
self,
|
||||||
|
images: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if self.cfm_cache is None:
|
||||||
|
# flip the current one. Only need this for first time
|
||||||
|
self.cfm_cache = torch.flip(images, [3]).clone()
|
||||||
|
augmented_tensor_list = []
|
||||||
|
for i in range(images.shape[0]):
|
||||||
|
# get a random one
|
||||||
|
idx = random.randint(0, self.cfm_cache.shape[0] - 1)
|
||||||
|
augmented_tensor_list.append(self.cfm_cache[idx:idx + 1])
|
||||||
|
augmented = torch.cat(augmented_tensor_list, dim=0)
|
||||||
|
# resize to match the input
|
||||||
|
augmented = torch.nn.functional.interpolate(augmented, size=(images.shape[2], images.shape[3]), mode='bilinear')
|
||||||
|
self.cfm_cache = images.clone()
|
||||||
|
return augmented
|
||||||
|
|
||||||
|
def get_cfm_loss(
|
||||||
|
self,
|
||||||
|
noisy_latents: torch.Tensor,
|
||||||
|
noise: torch.Tensor,
|
||||||
|
noise_pred: torch.Tensor,
|
||||||
|
conditional_embeds: PromptEmbeds,
|
||||||
|
timesteps: torch.Tensor,
|
||||||
|
batch: 'DataLoaderBatchDTO',
|
||||||
|
alpha: float = 0.1,
|
||||||
|
):
|
||||||
|
dtype = get_torch_dtype(self.train_config.dtype)
|
||||||
|
if hasattr(self.sd, 'get_loss_target'):
|
||||||
|
target = self.sd.get_loss_target(
|
||||||
|
noise=noise,
|
||||||
|
batch=batch,
|
||||||
|
timesteps=timesteps,
|
||||||
|
).detach()
|
||||||
|
|
||||||
|
elif self.sd.is_flow_matching:
|
||||||
|
# forward ODE
|
||||||
|
target = (noise - batch.latents).detach()
|
||||||
|
else:
|
||||||
|
raise ValueError("CFM loss only works with flow matching")
|
||||||
|
fm_loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||||
|
with torch.no_grad():
|
||||||
|
# we need to compute the contrast
|
||||||
|
cfm_batch_tensors = self.cfm_augment_tensors(batch.tensor).to(self.device_torch, dtype=dtype)
|
||||||
|
cfm_latents = self.sd.encode_images(cfm_batch_tensors).to(self.device_torch, dtype=dtype)
|
||||||
|
cfm_noisy_latents = self.sd.add_noise(
|
||||||
|
original_samples=cfm_latents,
|
||||||
|
noise=noise,
|
||||||
|
timesteps=timesteps,
|
||||||
|
)
|
||||||
|
cfm_pred = self.predict_noise(
|
||||||
|
noisy_latents=cfm_noisy_latents,
|
||||||
|
timesteps=timesteps,
|
||||||
|
conditional_embeds=conditional_embeds,
|
||||||
|
unconditional_embeds=None,
|
||||||
|
batch=batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
# v_neg = torch.nn.functional.normalize(cfm_pred.float(), dim=1)
|
||||||
|
# v_pos = torch.nn.functional.normalize(noise_pred.float(), dim=1) # shape: (B, C, H, W)
|
||||||
|
|
||||||
|
# # Compute cosine similarity at each pixel
|
||||||
|
# sim = (v_pos * v_neg).sum(dim=1) # shape: (B, H, W)
|
||||||
|
|
||||||
|
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
|
||||||
|
# Compute cosine similarity at each pixel
|
||||||
|
sim = cos(cfm_pred.float(), noise_pred.float()) # shape: (B, H, W)
|
||||||
|
|
||||||
|
# Average over spatial dimensions, then batch
|
||||||
|
contrastive_loss = -sim.mean()
|
||||||
|
|
||||||
|
loss = fm_loss.mean() + alpha * contrastive_loss
|
||||||
|
return loss
|
||||||
|
|
||||||
def train_single_accumulation(self, batch: DataLoaderBatchDTO):
|
def train_single_accumulation(self, batch: DataLoaderBatchDTO):
|
||||||
self.timer.start('preprocess_batch')
|
self.timer.start('preprocess_batch')
|
||||||
@@ -1431,6 +1508,44 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
if self.adapter and isinstance(self.adapter, CustomAdapter):
|
if self.adapter and isinstance(self.adapter, CustomAdapter):
|
||||||
noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch)
|
noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch)
|
||||||
|
|
||||||
|
if self.train_config.timestep_type == 'next_sample':
|
||||||
|
with self.timer('next_sample_step'):
|
||||||
|
with torch.no_grad():
|
||||||
|
|
||||||
|
stepped_timestep_indicies = [self.sd.noise_scheduler.index_for_timestep(t) + 1 for t in timesteps]
|
||||||
|
stepped_timesteps = [self.sd.noise_scheduler.timesteps[x] for x in stepped_timestep_indicies]
|
||||||
|
stepped_timesteps = torch.stack(stepped_timesteps, dim=0)
|
||||||
|
|
||||||
|
# do a sample at the current timestep and step it, then determine new noise
|
||||||
|
next_sample_pred = self.predict_noise(
|
||||||
|
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||||
|
timesteps=timesteps,
|
||||||
|
conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||||
|
unconditional_embeds=unconditional_embeds,
|
||||||
|
batch=batch,
|
||||||
|
**pred_kwargs
|
||||||
|
)
|
||||||
|
stepped_latents = self.sd.step_scheduler(
|
||||||
|
next_sample_pred,
|
||||||
|
noisy_latents,
|
||||||
|
timesteps,
|
||||||
|
self.sd.noise_scheduler
|
||||||
|
)
|
||||||
|
# stepped latents is our new noisy latents. Now we need to determine noise in the current sample
|
||||||
|
noisy_latents = stepped_latents
|
||||||
|
original_samples = batch.latents.to(self.device_torch, dtype=dtype)
|
||||||
|
# todo calc next timestep, for now this may work as it
|
||||||
|
t_01 = (stepped_timesteps / 1000).to(original_samples.device)
|
||||||
|
if len(stepped_latents.shape) == 4:
|
||||||
|
t_01 = t_01.view(-1, 1, 1, 1)
|
||||||
|
elif len(stepped_latents.shape) == 5:
|
||||||
|
t_01 = t_01.view(-1, 1, 1, 1, 1)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown stepped latents shape", stepped_latents.shape)
|
||||||
|
next_sample_noise = (stepped_latents - (1.0 - t_01) * original_samples) / t_01
|
||||||
|
noise = next_sample_noise
|
||||||
|
timesteps = stepped_timesteps
|
||||||
|
|
||||||
with self.timer('predict_unet'):
|
with self.timer('predict_unet'):
|
||||||
noise_pred = self.predict_noise(
|
noise_pred = self.predict_noise(
|
||||||
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||||
@@ -1450,15 +1565,25 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
|
if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
|
||||||
prior_to_calculate_loss = None
|
prior_to_calculate_loss = None
|
||||||
|
|
||||||
loss = self.calculate_loss(
|
if self.train_config.loss_type == 'cfm':
|
||||||
noise_pred=noise_pred,
|
loss = self.get_cfm_loss(
|
||||||
noise=noise,
|
noisy_latents=noisy_latents,
|
||||||
noisy_latents=noisy_latents,
|
noise=noise,
|
||||||
timesteps=timesteps,
|
noise_pred=noise_pred,
|
||||||
batch=batch,
|
conditional_embeds=conditional_embeds,
|
||||||
mask_multiplier=mask_multiplier,
|
timesteps=timesteps,
|
||||||
prior_pred=prior_to_calculate_loss,
|
batch=batch,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
loss = self.calculate_loss(
|
||||||
|
noise_pred=noise_pred,
|
||||||
|
noise=noise,
|
||||||
|
noisy_latents=noisy_latents,
|
||||||
|
timesteps=timesteps,
|
||||||
|
batch=batch,
|
||||||
|
mask_multiplier=mask_multiplier,
|
||||||
|
prior_pred=prior_to_calculate_loss,
|
||||||
|
)
|
||||||
|
|
||||||
if self.train_config.diff_output_preservation:
|
if self.train_config.diff_output_preservation:
|
||||||
# send the loss backwards otherwise checkpointing will fail
|
# send the loss backwards otherwise checkpointing will fail
|
||||||
|
|||||||
@@ -931,16 +931,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
noise_offset=self.train_config.noise_offset,
|
noise_offset=self.train_config.noise_offset,
|
||||||
).to(self.device_torch, dtype=dtype)
|
).to(self.device_torch, dtype=dtype)
|
||||||
|
|
||||||
if self.train_config.random_noise_shift > 0.0:
|
# if self.train_config.random_noise_shift > 0.0:
|
||||||
# get random noise -1 to 1
|
# # get random noise -1 to 1
|
||||||
noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device,
|
# noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device,
|
||||||
dtype=noise.dtype) * 2 - 1
|
# dtype=noise.dtype) * 2 - 1
|
||||||
|
|
||||||
# multiply by shift amount
|
# # multiply by shift amount
|
||||||
noise_shift *= self.train_config.random_noise_shift
|
# noise_shift *= self.train_config.random_noise_shift
|
||||||
|
|
||||||
# add to noise
|
# # add to noise
|
||||||
noise += noise_shift
|
# noise += noise_shift
|
||||||
|
|
||||||
if self.train_config.blended_blur_noise:
|
if self.train_config.blended_blur_noise:
|
||||||
noise = get_blended_blur_noise(
|
noise = get_blended_blur_noise(
|
||||||
@@ -1011,6 +1011,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
dtype = get_torch_dtype(self.train_config.dtype)
|
dtype = get_torch_dtype(self.train_config.dtype)
|
||||||
imgs = None
|
imgs = None
|
||||||
is_reg = any(batch.get_is_reg_list())
|
is_reg = any(batch.get_is_reg_list())
|
||||||
|
cfm_batch = None
|
||||||
if batch.tensor is not None:
|
if batch.tensor is not None:
|
||||||
imgs = batch.tensor
|
imgs = batch.tensor
|
||||||
imgs = imgs.to(self.device_torch, dtype=dtype)
|
imgs = imgs.to(self.device_torch, dtype=dtype)
|
||||||
@@ -1118,8 +1119,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
if timestep_type is None:
|
if timestep_type is None:
|
||||||
timestep_type = self.train_config.timestep_type
|
timestep_type = self.train_config.timestep_type
|
||||||
|
|
||||||
|
if self.train_config.timestep_type == 'next_sample':
|
||||||
|
# simulate a sample
|
||||||
|
num_train_timesteps = self.train_config.next_sample_timesteps
|
||||||
|
timestep_type = 'shift'
|
||||||
|
|
||||||
patch_size = 1
|
patch_size = 1
|
||||||
if self.sd.is_flux:
|
if self.sd.is_flux or 'flex' in self.sd.arch:
|
||||||
# flux is a patch size of 1, but latents are divided by 2, so we need to double it
|
# flux is a patch size of 1, but latents are divided by 2, so we need to double it
|
||||||
patch_size = 2
|
patch_size = 2
|
||||||
elif hasattr(self.sd.unet.config, 'patch_size'):
|
elif hasattr(self.sd.unet.config, 'patch_size'):
|
||||||
@@ -1142,7 +1148,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
content_or_style = self.train_config.content_or_style_reg
|
content_or_style = self.train_config.content_or_style_reg
|
||||||
|
|
||||||
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
|
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
|
||||||
if content_or_style in ['style', 'content']:
|
if self.train_config.timestep_type == 'next_sample':
|
||||||
|
timestep_indices = torch.randint(
|
||||||
|
0,
|
||||||
|
num_train_timesteps - 2, # -1 for 0 idx, -1 so we can step
|
||||||
|
(batch_size,),
|
||||||
|
device=self.device_torch
|
||||||
|
)
|
||||||
|
timestep_indices = timestep_indices.long()
|
||||||
|
elif content_or_style in ['style', 'content']:
|
||||||
# this is from diffusers training code
|
# this is from diffusers training code
|
||||||
# Cubic sampling for favoring later or earlier timesteps
|
# Cubic sampling for favoring later or earlier timesteps
|
||||||
# For more details about why cubic sampling is used for content / structure,
|
# For more details about why cubic sampling is used for content / structure,
|
||||||
@@ -1169,7 +1183,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
min_noise_steps + 1,
|
min_noise_steps + 1,
|
||||||
max_noise_steps - 1
|
max_noise_steps - 1
|
||||||
)
|
)
|
||||||
|
|
||||||
elif content_or_style == 'balanced':
|
elif content_or_style == 'balanced':
|
||||||
if min_noise_steps == max_noise_steps:
|
if min_noise_steps == max_noise_steps:
|
||||||
timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
|
timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
|
||||||
@@ -1185,16 +1199,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown content_or_style {content_or_style}")
|
raise ValueError(f"Unknown content_or_style {content_or_style}")
|
||||||
|
|
||||||
# do flow matching
|
|
||||||
# if self.sd.is_flow_matching:
|
|
||||||
# u = compute_density_for_timestep_sampling(
|
|
||||||
# weighting_scheme="logit_normal", # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
|
|
||||||
# batch_size=batch_size,
|
|
||||||
# logit_mean=0.0,
|
|
||||||
# logit_std=1.0,
|
|
||||||
# mode_scale=1.29,
|
|
||||||
# )
|
|
||||||
# timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()
|
|
||||||
# convert the timestep_indices to a timestep
|
# convert the timestep_indices to a timestep
|
||||||
timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
|
timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
|
||||||
timesteps = torch.stack(timesteps, dim=0)
|
timesteps = torch.stack(timesteps, dim=0)
|
||||||
@@ -1218,8 +1222,32 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
latents = unaugmented_latents
|
latents = unaugmented_latents
|
||||||
|
|
||||||
noise_multiplier = self.train_config.noise_multiplier
|
noise_multiplier = self.train_config.noise_multiplier
|
||||||
|
|
||||||
|
s = (noise.shape[0], noise.shape[1], 1, 1)
|
||||||
|
if len(noise.shape) == 5:
|
||||||
|
# if we have a 5d tensor, then we need to do it on a per batch item, per channel basis, per frame
|
||||||
|
s = (noise.shape[0], noise.shape[1], noise.shape[2], 1, 1)
|
||||||
|
|
||||||
|
if self.train_config.random_noise_multiplier > 0.0:
|
||||||
|
|
||||||
|
# do it on a per batch item, per channel basis
|
||||||
|
noise_multiplier = 1 + torch.randn(
|
||||||
|
s,
|
||||||
|
device=noise.device,
|
||||||
|
dtype=noise.dtype
|
||||||
|
) * self.train_config.random_noise_multiplier
|
||||||
|
|
||||||
noise = noise * noise_multiplier
|
noise = noise * noise_multiplier
|
||||||
|
|
||||||
|
if self.train_config.random_noise_shift > 0.0:
|
||||||
|
# get random noise -1 to 1
|
||||||
|
noise_shift = torch.randn(
|
||||||
|
s,
|
||||||
|
device=noise.device,
|
||||||
|
dtype=noise.dtype
|
||||||
|
) * self.train_config.random_noise_shift
|
||||||
|
# add to noise
|
||||||
|
noise += noise_shift
|
||||||
|
|
||||||
latent_multiplier = self.train_config.latent_multiplier
|
latent_multiplier = self.train_config.latent_multiplier
|
||||||
|
|
||||||
|
|||||||
@@ -325,6 +325,8 @@ class TrainConfig:
|
|||||||
self.adapter_assist_type: Optional[str] = kwargs.get('adapter_assist_type', 't2i') # t2i, control_net
|
self.adapter_assist_type: Optional[str] = kwargs.get('adapter_assist_type', 't2i') # t2i, control_net
|
||||||
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
|
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
|
||||||
self.target_noise_multiplier = kwargs.get('target_noise_multiplier', 1.0)
|
self.target_noise_multiplier = kwargs.get('target_noise_multiplier', 1.0)
|
||||||
|
self.random_noise_multiplier = kwargs.get('random_noise_multiplier', 0.0)
|
||||||
|
self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
|
||||||
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
|
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
|
||||||
self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0)
|
self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0)
|
||||||
self.latent_multiplier = kwargs.get('latent_multiplier', 1.0)
|
self.latent_multiplier = kwargs.get('latent_multiplier', 1.0)
|
||||||
@@ -333,7 +335,6 @@ class TrainConfig:
|
|||||||
# multiplier applied to loos on regularization images
|
# multiplier applied to loos on regularization images
|
||||||
self.reg_weight = kwargs.get('reg_weight', 1.0)
|
self.reg_weight = kwargs.get('reg_weight', 1.0)
|
||||||
self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000)
|
self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000)
|
||||||
self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
|
|
||||||
# automatically adapte the vae scaling based on the image norm
|
# automatically adapte the vae scaling based on the image norm
|
||||||
self.adaptive_scaling_factor = kwargs.get('adaptive_scaling_factor', False)
|
self.adaptive_scaling_factor = kwargs.get('adaptive_scaling_factor', False)
|
||||||
|
|
||||||
@@ -412,7 +413,7 @@ class TrainConfig:
|
|||||||
self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
|
self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
|
||||||
self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
|
self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
|
||||||
|
|
||||||
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace
|
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, cfm
|
||||||
|
|
||||||
# scale the prediction by this. Increase for more detail, decrease for less
|
# scale the prediction by this. Increase for more detail, decrease for less
|
||||||
self.pred_scaler = kwargs.get('pred_scaler', 1.0)
|
self.pred_scaler = kwargs.get('pred_scaler', 1.0)
|
||||||
@@ -436,7 +437,8 @@ class TrainConfig:
|
|||||||
# adds an additional loss to the network to encourage it output a normalized standard deviation
|
# adds an additional loss to the network to encourage it output a normalized standard deviation
|
||||||
self.target_norm_std = kwargs.get('target_norm_std', None)
|
self.target_norm_std = kwargs.get('target_norm_std', None)
|
||||||
self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
|
self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
|
||||||
self.timestep_type = kwargs.get('timestep_type', 'sigmoid') # sigmoid, linear, lognorm_blend
|
self.timestep_type = kwargs.get('timestep_type', 'sigmoid') # sigmoid, linear, lognorm_blend, next_sample
|
||||||
|
self.next_sample_timesteps = kwargs.get('next_sample_timesteps', 8)
|
||||||
self.linear_timesteps = kwargs.get('linear_timesteps', False)
|
self.linear_timesteps = kwargs.get('linear_timesteps', False)
|
||||||
self.linear_timesteps2 = kwargs.get('linear_timesteps2', False)
|
self.linear_timesteps2 = kwargs.get('linear_timesteps2', False)
|
||||||
self.disable_sampling = kwargs.get('disable_sampling', False)
|
self.disable_sampling = kwargs.get('disable_sampling', False)
|
||||||
|
|||||||
@@ -142,7 +142,9 @@ class StableDiffusion:
|
|||||||
):
|
):
|
||||||
self.accelerator = get_accelerator()
|
self.accelerator = get_accelerator()
|
||||||
self.custom_pipeline = custom_pipeline
|
self.custom_pipeline = custom_pipeline
|
||||||
self.device = device
|
self.device = str(device)
|
||||||
|
if "cuda" in self.device and ":" not in self.device:
|
||||||
|
self.device = f"{self.device}:0"
|
||||||
self.device_torch = torch.device(device)
|
self.device_torch = torch.device(device)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.torch_dtype = get_torch_dtype(dtype)
|
self.torch_dtype = get_torch_dtype(dtype)
|
||||||
@@ -2086,7 +2088,10 @@ class StableDiffusion:
|
|||||||
noise_pred = noise_pred
|
noise_pred = noise_pred
|
||||||
else:
|
else:
|
||||||
if self.unet.device != self.device_torch:
|
if self.unet.device != self.device_torch:
|
||||||
self.unet.to(self.device_torch)
|
try:
|
||||||
|
self.unet.to(self.device_torch)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
if self.unet.dtype != self.torch_dtype:
|
if self.unet.dtype != self.torch_dtype:
|
||||||
self.unet = self.unet.to(dtype=self.torch_dtype)
|
self.unet = self.unet.to(dtype=self.torch_dtype)
|
||||||
if self.is_flux:
|
if self.is_flux:
|
||||||
|
|||||||
Reference in New Issue
Block a user