Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Added experimental dfe 5
@@ -577,7 +577,7 @@ class SDTrainer(BaseSDTrainProcess):
                     dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean")

                 additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
-            elif self.dfe.version == 3 or self.dfe.version == 4:
+            elif self.dfe.version in [3, 4, 5]:
                 dfe_loss = self.dfe(
                     noise=noise,
                     noise_pred=noise_pred,
@@ -470,10 +470,31 @@ class DiffusionFeatureExtractor4(nn.Module):
             output_hidden_states=True,
         )

-        # embeds = id_embeds['hidden_states'][-2] # penultimate layer
-        image_embeds = id_embeds['pooler_output']
-        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+        image_embeds = id_embeds['hidden_states'][-2] # penultimate layer
+        # image_embeds = id_embeds['pooler_output']
+        # image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
         return image_embeds

+    def step_latents(self, noise, noise_pred, noisy_latents, timesteps, scheduler):
+        bs = noise_pred.shape[0]
+        noise_pred_chunks = torch.chunk(noise_pred, bs)
+        timestep_chunks = torch.chunk(timesteps, bs)
+        noisy_latent_chunks = torch.chunk(noisy_latents, bs)
+        stepped_chunks = []
+        for idx in range(bs):
+            model_output = noise_pred_chunks[idx]
+            timestep = timestep_chunks[idx]
+            scheduler._step_index = None
+            scheduler._init_step_index(timestep)
+            sample = noisy_latent_chunks[idx].to(torch.float32)
+
+            sigma = scheduler.sigmas[scheduler.step_index]
+            sigma_next = scheduler.sigmas[-1] # use last sigma for final step
+            prev_sample = sample + (sigma_next - sigma) * model_output
+            stepped_chunks.append(prev_sample)
+
+        stepped_latents = torch.cat(stepped_chunks, dim=0)
+        return stepped_latents
+
     def forward(
         self,
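The new step_latents helper collapses the remaining denoising into a single Euler jump from the current sigma to the scheduler's last sigma. A minimal sketch of that arithmetic, assuming a flow-matching parameterization (x_t = (1 - sigma) * x_0 + sigma * noise, with the model predicting roughly noise - x_0) and a final sigma of 0; the helper and tensors below are illustrative and not part of this commit:

import torch

def one_step_denoise(sample, velocity_pred, sigma, sigma_last=0.0):
    # Single Euler step from `sigma` straight to `sigma_last`, mirroring
    # prev_sample = sample + (sigma_next - sigma) * model_output above.
    return sample + (sigma_last - sigma) * velocity_pred

# Round-trip check under the flow-matching assumption (hypothetical tensors):
x0 = torch.randn(1, 4, 8, 8)
noise = torch.randn_like(x0)
sigma = 0.7
x_t = (1 - sigma) * x0 + sigma * noise
v = noise - x0  # velocity the model is assumed to predict
assert torch.allclose(one_step_denoise(x_t, v, sigma), x0, atol=1e-5)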
@@ -509,26 +530,7 @@ class DiffusionFeatureExtractor4(nn.Module):
         if model is not None and hasattr(model, 'get_stepped_pred'):
             stepped_latents = model.get_stepped_pred(noise_pred, noise)
         else:
-            # stepped_latents = noise - noise_pred
-            # first we step the scheduler from current timestep to the very end for a full denoise
-            bs = noise_pred.shape[0]
-            noise_pred_chunks = torch.chunk(noise_pred, bs)
-            timestep_chunks = torch.chunk(timesteps, bs)
-            noisy_latent_chunks = torch.chunk(noisy_latents, bs)
-            stepped_chunks = []
-            for idx in range(bs):
-                model_output = noise_pred_chunks[idx]
-                timestep = timestep_chunks[idx]
-                scheduler._step_index = None
-                scheduler._init_step_index(timestep)
-                sample = noisy_latent_chunks[idx].to(torch.float32)
-
-                sigma = scheduler.sigmas[scheduler.step_index]
-                sigma_next = scheduler.sigmas[-1] # use last sigma for final step
-                prev_sample = sample + (sigma_next - sigma) * model_output
-                stepped_chunks.append(prev_sample)
-
-            stepped_latents = torch.cat(stepped_chunks, dim=0)
+            stepped_latents = self.step_latents(noise, noise_pred, noisy_latents, timesteps, scheduler)

         latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
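Since forward() now routes through self.step_latents, a subclass only has to override that one method to change the stepping behavior, which appears to be what enables DiffusionFeatureExtractor5 below. An illustrative (hypothetical) subclass pattern:

class MyExtractor(DiffusionFeatureExtractor4):  # hypothetical example, not part of this commit
    def step_latents(self, noise, noise_pred, noisy_latents, timesteps, scheduler):
        # custom stepping logic goes here; forward() picks it up unchanged
        ...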
@@ -590,6 +592,52 @@ class DiffusionFeatureExtractor4(nn.Module):
         self.step += 1

         return total_loss

+class DiffusionFeatureExtractor5(DiffusionFeatureExtractor4):
+    def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None):
+        super().__init__(device=device, dtype=dtype, vae=vae)
+        self.version = 5
+
+    def step_latents(self, noise, noise_pred, noisy_latents, timesteps, scheduler, total_steps: int = 1000, eps: float = 1e-6):
+        bs = noise_pred.shape[0]
+
+        # Chunk inputs per-sample (keeps existing structure)
+        noise_pred_chunks = torch.chunk(noise_pred, bs)
+        timestep_chunks = torch.chunk(timesteps, bs)
+        noisy_latent_chunks = torch.chunk(noisy_latents, bs)
+        noise_chunks = torch.chunk(noise, bs)
+
+        stepped_chunks = []
+        x0_pred_chunks = []
+
+        for idx in range(bs):
+            model_output = noise_pred_chunks[idx]  # predicted noise (same shape as latent)
+            timestep = timestep_chunks[idx]  # scalar tensor per sample (e.g., [t])
+            sample = noisy_latent_chunks[idx].to(torch.float32)
+            noise_i = noise_chunks[idx].to(sample.dtype).to(sample.device)
+
+            # Initialize scheduler step index for this sample
+            scheduler._step_index = None
+            scheduler._init_step_index(timestep)
+
+            # ---- Step +50 indices (or to the end) in sigma-space ----
+            sigma = scheduler.sigmas[scheduler.step_index]
+            target_idx = min(scheduler.step_index + 50, len(scheduler.sigmas) - 1)
+            sigma_next = scheduler.sigmas[target_idx]
+
+            # One-step update along the model-predicted direction
+            stepped = sample + (sigma_next - sigma) * model_output
+            stepped_chunks.append(stepped)
+
+            # ---- Inverse-Gaussian recovery at the target timestep ----
+            t_01 = (scheduler.sigmas[target_idx] / 1000).to(stepped.device).to(stepped.dtype)
+            original_samples = (stepped - t_01 * noise_i) / (1.0 - t_01)
+            x0_pred_chunks.append(original_samples)
+
+        # stepped_latents = torch.cat(stepped_chunks, dim=0)
+        predicted_images = torch.cat(x0_pred_chunks, dim=0)
+        # return stepped_latents, predicted_images
+        return predicted_images
+
 def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
     if model_path == "v3":
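The "inverse-Gaussian recovery" step appears to invert the flow-matching mixing at the target timestep: from x_t = (1 - t) * x_0 + t * noise it follows that x_0 = (x_t - t * noise) / (1 - t), with t taken as the target sigma rescaled into [0, 1] (the division by 1000). A sketch under that assumption; recover_x0 is illustrative, and its eps guard mirrors the (currently unused) eps argument in the new signature:

import torch

def recover_x0(x_t, noise, t, eps=1e-6):
    # Invert x_t = (1 - t) * x0 + t * noise for x0; eps guards the division as t -> 1.
    return (x_t - t * noise) / max(1.0 - t, eps)

# Round-trip check with hypothetical tensors:
x0 = torch.randn(1, 4, 8, 8)
noise = torch.randn_like(x0)
t = 0.35
x_t = (1 - t) * x0 + t * noise
assert torch.allclose(recover_x0(x_t, noise, t), x0, atol=1e-5)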
@@ -600,6 +648,10 @@ def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
         dfe = DiffusionFeatureExtractor4(vae=vae)
         dfe.eval()
         return dfe
+    if model_path == "v5":
+        dfe = DiffusionFeatureExtractor5(vae=vae)
+        dfe.eval()
+        return dfe
     if not os.path.exists(model_path):
         raise FileNotFoundError(f"Model file not found: {model_path}")
     # if it ends with safetensors
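With the new branch in load_dfe, the v5 extractor is selected by version string just like v3 and v4. A hypothetical usage sketch (the module path is assumed from the ai-toolkit layout and may differ):

from toolkit.models.diffusion_feature_extraction import load_dfe  # path assumed

vae = ...  # an already-loaded VAE instance, as required for "v4"
dfe = load_dfe("v5", vae=vae)
assert dfe.version == 5  # DiffusionFeatureExtractor5 sets version = 5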