mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Updated diffusion feature extractor
This commit is contained in:
@@ -414,15 +414,41 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
if self.dfe is not None:
|
if self.dfe is not None:
|
||||||
if self.dfe.version == 1:
|
if self.dfe.version == 1:
|
||||||
# do diffusion feature extraction on target
|
model = self.sd
|
||||||
|
if model is not None and hasattr(model, 'get_stepped_pred'):
|
||||||
|
stepped_latents = model.get_stepped_pred(noise_pred, noise)
|
||||||
|
else:
|
||||||
|
# stepped_latents = noise - noise_pred
|
||||||
|
# first we step the scheduler from current timestep to the very end for a full denoise
|
||||||
|
bs = noise_pred.shape[0]
|
||||||
|
noise_pred_chunks = torch.chunk(noise_pred, bs)
|
||||||
|
timestep_chunks = torch.chunk(timesteps, bs)
|
||||||
|
noisy_latent_chunks = torch.chunk(noisy_latents, bs)
|
||||||
|
stepped_chunks = []
|
||||||
|
for idx in range(bs):
|
||||||
|
model_output = noise_pred_chunks[idx]
|
||||||
|
timestep = timestep_chunks[idx]
|
||||||
|
self.sd.noise_scheduler._step_index = None
|
||||||
|
self.sd.noise_scheduler._init_step_index(timestep)
|
||||||
|
sample = noisy_latent_chunks[idx].to(torch.float32)
|
||||||
|
|
||||||
|
sigma = self.sd.noise_scheduler.sigmas[self.sd.noise_scheduler.step_index]
|
||||||
|
sigma_next = self.sd.noise_scheduler.sigmas[-1] # use last sigma for final step
|
||||||
|
prev_sample = sample + (sigma_next - sigma) * model_output
|
||||||
|
stepped_chunks.append(prev_sample)
|
||||||
|
|
||||||
|
stepped_latents = torch.cat(stepped_chunks, dim=0)
|
||||||
|
|
||||||
|
stepped_latents = stepped_latents.to(self.sd.vae.device, dtype=self.sd.vae.dtype)
|
||||||
|
pred_features = self.dfe(stepped_latents.float())
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
rectified_flow_target = noise.float() - batch.latents.float()
|
target_features = self.dfe(batch.latents.to(self.device_torch, dtype=torch.float32))
|
||||||
target_features = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))
|
# scale dfe so it is weaker at higher noise levels
|
||||||
|
dfe_scaler = 1 - (timesteps.float() / 1000.0).view(-1, 1, 1, 1).to(self.device_torch)
|
||||||
|
|
||||||
# do diffusion feature extraction on prediction
|
dfe_loss = torch.nn.functional.mse_loss(pred_features, target_features, reduction="none") * \
|
||||||
pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
|
self.train_config.diffusion_feature_extractor_weight * dfe_scaler
|
||||||
additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") * \
|
additional_loss += dfe_loss.mean()
|
||||||
self.train_config.diffusion_feature_extractor_weight
|
|
||||||
elif self.dfe.version == 2:
|
elif self.dfe.version == 2:
|
||||||
# version 2
|
# version 2
|
||||||
# do diffusion feature extraction on target
|
# do diffusion feature extraction on target
|
||||||
|
|||||||
@@ -454,8 +454,12 @@ class TrainConfig:
|
|||||||
self.bypass_guidance_embedding = kwargs.get('bypass_guidance_embedding', False)
|
self.bypass_guidance_embedding = kwargs.get('bypass_guidance_embedding', False)
|
||||||
|
|
||||||
# diffusion feature extractor
|
# diffusion feature extractor
|
||||||
self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', None)
|
self.latent_feature_extractor_path = kwargs.get('latent_feature_extractor_path', None)
|
||||||
self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', 1.0)
|
self.latent_feature_loss_weight = kwargs.get('latent_feature_loss_weight', 1.0)
|
||||||
|
|
||||||
|
# the code still uses the diffusion_feature_extractor_* names, but they should really be called latent_feature_extractor_*, which makes more sense with the new architecture
|
||||||
|
self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', self.latent_feature_extractor_path)
|
||||||
|
self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', self.latent_feature_loss_weight)
|
||||||
|
|
||||||
# optimal noise pairing
|
# optimal noise pairing
|
||||||
self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1)
|
self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1)
|
||||||
|
|||||||
@@ -127,18 +127,20 @@ class DFEBlock(nn.Module):
|
|||||||
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
|
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
|
||||||
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
|
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
|
||||||
self.act = nn.GELU()
|
self.act = nn.GELU()
|
||||||
|
self.proj = nn.Conv2d(channels, channels, 1)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x_in = x
|
x_in = x
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
x = self.act(x)
|
x = self.act(x)
|
||||||
|
x = self.proj(x)
|
||||||
x = x + x_in
|
x = x + x_in
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class DiffusionFeatureExtractor(nn.Module):
|
class DiffusionFeatureExtractor(nn.Module):
|
||||||
def __init__(self, in_channels=32):
|
def __init__(self, in_channels=16):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.version = 1
|
self.version = 1
|
||||||
num_blocks = 6
|
num_blocks = 6
|
||||||
|
|||||||
@@ -2538,8 +2538,8 @@ class StableDiffusion:
|
|||||||
|
|
||||||
# Move vae to device if on cpu
|
# Move vae to device if on cpu
|
||||||
if self.vae.device == 'cpu':
|
if self.vae.device == 'cpu':
|
||||||
self.vae.to(self.device)
|
self.vae.to(self.device_torch)
|
||||||
latents = latents.to(device, dtype=dtype)
|
latents = latents.to(self.device_torch, dtype=self.torch_dtype)
|
||||||
latents = (latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor']
|
latents = (latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor']
|
||||||
images = self.vae.decode(latents).sample
|
images = self.vae.decode(latents).sample
|
||||||
images = images.to(device, dtype=dtype)
|
images = images.to(device, dtype=dtype)
|
||||||
|
|||||||
Reference in New Issue
Block a user