Updated diffusion feature extractor

Jaret Burkett
2025-06-19 15:36:10 -06:00
parent 4586eb5392
commit 8602470952
4 changed files with 44 additions and 12 deletions


@@ -414,15 +414,41 @@ class SDTrainer(BaseSDTrainProcess):
         if self.dfe is not None:
             if self.dfe.version == 1:
                 # do diffusion feature extraction on target
+                model = self.sd
+                if model is not None and hasattr(model, 'get_stepped_pred'):
+                    stepped_latents = model.get_stepped_pred(noise_pred, noise)
+                else:
+                    # stepped_latents = noise - noise_pred
+                    # first we step the scheduler from current timestep to the very end for a full denoise
+                    bs = noise_pred.shape[0]
+                    noise_pred_chunks = torch.chunk(noise_pred, bs)
+                    timestep_chunks = torch.chunk(timesteps, bs)
+                    noisy_latent_chunks = torch.chunk(noisy_latents, bs)
+                    stepped_chunks = []
+                    for idx in range(bs):
+                        model_output = noise_pred_chunks[idx]
+                        timestep = timestep_chunks[idx]
+                        self.sd.noise_scheduler._step_index = None
+                        self.sd.noise_scheduler._init_step_index(timestep)
+                        sample = noisy_latent_chunks[idx].to(torch.float32)
+                        sigma = self.sd.noise_scheduler.sigmas[self.sd.noise_scheduler.step_index]
+                        sigma_next = self.sd.noise_scheduler.sigmas[-1]  # use last sigma for final step
+                        prev_sample = sample + (sigma_next - sigma) * model_output
+                        stepped_chunks.append(prev_sample)
+                    stepped_latents = torch.cat(stepped_chunks, dim=0)
+                stepped_latents = stepped_latents.to(self.sd.vae.device, dtype=self.sd.vae.dtype)
+                pred_features = self.dfe(stepped_latents.float())
                 with torch.no_grad():
-                    rectified_flow_target = noise.float() - batch.latents.float()
-                    target_features = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))
+                    target_features = self.dfe(batch.latents.to(self.device_torch, dtype=torch.float32))
+                # scale dfe so it is weaker at higher noise levels
+                dfe_scaler = 1 - (timesteps.float() / 1000.0).view(-1, 1, 1, 1).to(self.device_torch)
-                # do diffusion feature extraction on prediction
-                pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
-                additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") * \
-                    self.train_config.diffusion_feature_extractor_weight
+                dfe_loss = torch.nn.functional.mse_loss(pred_features, target_features, reduction="none") * \
+                    self.train_config.diffusion_feature_extractor_weight * dfe_scaler
+                additional_loss += dfe_loss.mean()
             elif self.dfe.version == 2:
                 # version 2
                 # do diffusion feature extraction on target
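For reference, the per-sample loop in the new code performs a single Euler step from the current sigma straight to the last sigma in the schedule. A minimal standalone sketch of that math, with hypothetical tensors and assuming a flow-matching schedule whose final sigma is 0:

    import torch

    def step_to_end(sample, model_output, sigma, sigma_last=0.0):
        # one Euler step across the entire remaining schedule:
        #   x_end = x_t + (sigma_last - sigma_t) * v_pred
        # with sigma_last == 0 this reduces to x_t - sigma_t * v_pred,
        # i.e. the model's fully denoised prediction
        return sample + (sigma_last - sigma) * model_output

    x_t = torch.randn(1, 16, 64, 64)     # hypothetical noisy latent at sigma = 0.7
    v_pred = torch.randn(1, 16, 64, 64)  # hypothetical model output
    x0_pred = step_to_end(x_t, v_pred, sigma=0.7)

With x_t = (1 - sigma) * x0 + sigma * noise and v_pred = noise - x0, the step recovers x0 exactly, which is why the stepped latents can be compared against batch.latents directly.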


@@ -454,8 +454,12 @@ class TrainConfig:
         self.bypass_guidance_embedding = kwargs.get('bypass_guidance_embedding', False)
         # diffusion feature extractor
-        self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', None)
-        self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', 1.0)
+        self.latent_feature_extractor_path = kwargs.get('latent_feature_extractor_path', None)
+        self.latent_feature_loss_weight = kwargs.get('latent_feature_loss_weight', 1.0)
+        # we use this in the code, but it really needs to be called latent_feature_extractor as that makes more sense with new architecture
+        self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', self.latent_feature_extractor_path)
+        self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', self.latent_feature_loss_weight)
         # optimal noise pairing
         self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1)
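The aliasing above keeps old configs working: if only the new latent_feature_* keys are set, the old diffusion_feature_extractor_* attributes pick up their values, while an explicitly set old key still wins. A minimal sketch with hypothetical kwargs:

    kwargs = {'latent_feature_extractor_path': 'dfe.safetensors',
              'latent_feature_loss_weight': 0.5}
    latent_path = kwargs.get('latent_feature_extractor_path', None)
    latent_weight = kwargs.get('latent_feature_loss_weight', 1.0)
    # old keys fall back to the new keys' resolved values
    path = kwargs.get('diffusion_feature_extractor_path', latent_path)
    weight = kwargs.get('diffusion_feature_extractor_weight', latent_weight)
    assert path == 'dfe.safetensors' and weight == 0.5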


@@ -127,18 +127,20 @@ class DFEBlock(nn.Module):
         self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
         self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
         self.act = nn.GELU()
+        self.proj = nn.Conv2d(channels, channels, 1)

     def forward(self, x):
         x_in = x
         x = self.conv1(x)
         x = self.conv2(x)
         x = self.act(x)
+        x = self.proj(x)
         x = x + x_in
         return x


 class DiffusionFeatureExtractor(nn.Module):
-    def __init__(self, in_channels=32):
+    def __init__(self, in_channels=16):
         super().__init__()
         self.version = 1
         num_blocks = 6
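Read together with the first hunk, the in_channels drop from 32 to 16 matches the new call sites: the extractor now sees a single 16-channel latent instead of the old torch.cat([target, noise], dim=1) pair. A self-contained sketch of the residual block as it appears after this change (only what the diff itself shows; the extractor head and the num_blocks stack live elsewhere in the file):

    import torch
    import torch.nn as nn

    class DFEBlock(nn.Module):
        def __init__(self, channels):
            super().__init__()
            self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
            self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
            self.act = nn.GELU()
            self.proj = nn.Conv2d(channels, channels, 1)  # new 1x1 projection before the residual add

        def forward(self, x):
            x_in = x
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.act(x)
            x = self.proj(x)
            return x + x_in  # residual connection

    x = torch.randn(2, 16, 64, 64)  # 16-channel latents, the new default
    print(DFEBlock(16)(x).shape)    # torch.Size([2, 16, 64, 64])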


@@ -2538,8 +2538,8 @@ class StableDiffusion:
         # Move vae to device if on cpu
         if self.vae.device == 'cpu':
-            self.vae.to(self.device)
-        latents = latents.to(device, dtype=dtype)
+            self.vae.to(self.device_torch)
+        latents = latents.to(self.device_torch, dtype=self.torch_dtype)
         latents = (latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor']
         images = self.vae.decode(latents).sample
         images = images.to(device, dtype=dtype)
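The scaling line above undoes the VAE's latent normalization before decoding. A minimal sketch of that denormalization, with scaling_factor and shift_factor values borrowed from an SD3-class VAE config (an assumption for illustration; the real values come from self.vae.config):

    import torch

    def denormalize_latents(latents, scaling_factor=1.5305, shift_factor=0.0609):
        # inverse of the encode-side normalization:
        #   normalized = (raw - shift_factor) * scaling_factor
        return latents / scaling_factor + shift_factor

    latents = torch.randn(1, 16, 64, 64)  # hypothetical normalized latents
    raw = denormalize_latents(latents)    # ready for vae.decode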