Added experimental DFE 5 (DiffusionFeatureExtractor5)

This commit is contained in:
Jaret Burkett
2025-09-21 10:48:52 -06:00
parent 20dfe1b4d5
commit 28728a1e92
2 changed files with 76 additions and 24 deletions

View File

@@ -470,10 +470,31 @@ class DiffusionFeatureExtractor4(nn.Module):
output_hidden_states=True,
)
# embeds = id_embeds['hidden_states'][-2] # penultimate layer
image_embeds = id_embeds['pooler_output']
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
image_embeds = id_embeds['hidden_states'][-2] # penultimate layer
# image_embeds = id_embeds['pooler_output']
# image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
return image_embeds
def step_latents(self, noise, noise_pred, noisy_latents, timesteps, scheduler):
    """Move every sample from its current sigma directly to the schedule's last sigma.

    Batch elements are processed one at a time: the scheduler's step index
    is re-initialised from that element's timestep, and the float32 copy of
    the noisy latent is shifted along ``noise_pred`` by
    ``scheduler.sigmas[-1] - scheduler.sigmas[step_index]``.

    ``noise`` is accepted but not read by this implementation. The
    scheduler's private ``_step_index`` is overwritten and is left at the
    index of the last processed element.

    Returns:
        The per-sample results concatenated along dim 0.
    """
    batch = noise_pred.shape[0]
    pred_parts = torch.chunk(noise_pred, batch)
    t_parts = torch.chunk(timesteps, batch)
    latent_parts = torch.chunk(noisy_latents, batch)

    results = []
    for i, direction in enumerate(pred_parts):
        step_t = t_parts[i]
        # Clear the cached index so the scheduler looks it up again from
        # this element's own timestep.
        scheduler._step_index = None
        scheduler._init_step_index(step_t)
        latent = latent_parts[i].to(torch.float32)
        sigma_now = scheduler.sigmas[scheduler.step_index]
        sigma_last = scheduler.sigmas[-1]  # final entry of the schedule
        sigma_delta = sigma_last - sigma_now
        results.append(latent + sigma_delta * direction)
    return torch.cat(results, dim=0)
def forward(
self,
@@ -509,26 +530,7 @@ class DiffusionFeatureExtractor4(nn.Module):
if model is not None and hasattr(model, 'get_stepped_pred'):
stepped_latents = model.get_stepped_pred(noise_pred, noise)
else:
# stepped_latents = noise - noise_pred
# first we step the scheduler from current timestep to the very end for a full denoise
bs = noise_pred.shape[0]
noise_pred_chunks = torch.chunk(noise_pred, bs)
timestep_chunks = torch.chunk(timesteps, bs)
noisy_latent_chunks = torch.chunk(noisy_latents, bs)
stepped_chunks = []
for idx in range(bs):
model_output = noise_pred_chunks[idx]
timestep = timestep_chunks[idx]
scheduler._step_index = None
scheduler._init_step_index(timestep)
sample = noisy_latent_chunks[idx].to(torch.float32)
sigma = scheduler.sigmas[scheduler.step_index]
sigma_next = scheduler.sigmas[-1] # use last sigma for final step
prev_sample = sample + (sigma_next - sigma) * model_output
stepped_chunks.append(prev_sample)
stepped_latents = torch.cat(stepped_chunks, dim=0)
stepped_latents = self.step_latents(noise, noise_pred, noisy_latents, timesteps, scheduler)
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
@@ -590,6 +592,52 @@ class DiffusionFeatureExtractor4(nn.Module):
self.step += 1
return total_loss
class DiffusionFeatureExtractor5(DiffusionFeatureExtractor4):
    """Experimental v5 extractor: partial sigma step followed by x0 recovery.

    Behaves like ``DiffusionFeatureExtractor4`` except for ``step_latents``,
    which advances each sample a bounded number of schedule indices instead
    of jumping to the last sigma, and then inverts the linear noise mix with
    the supplied ``noise`` to produce the tensor that is returned.
    """

    def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None):
        super().__init__(device=device, dtype=dtype, vae=vae)
        self.version = 5

    def step_latents(
        self,
        noise,
        noise_pred,
        noisy_latents,
        timesteps,
        scheduler,
        total_steps: int = 1000,
        eps: float = 1e-6,
        step_stride: int = 50,
    ):
        """Step each sample forward in sigma-space and recover a clean estimate.

        Args:
            noise: Noise tensor, chunked per sample and used in the recovery.
            noise_pred: Model output, used as the update direction.
            noisy_latents: Current noisy latents; cast to float32 per sample.
            timesteps: Per-sample timesteps passed to
                ``scheduler._init_step_index``.
            scheduler: Object exposing ``sigmas``, ``step_index``,
                ``_step_index`` and ``_init_step_index``.
            total_steps: Divisor that turns the target sigma into the mixing
                weight ``t_01``. Defaults to 1000.
            eps: Lower bound applied to ``1 - t_01`` before dividing, so a
                weight of 1 cannot cause a division by zero.
            step_stride: Number of schedule indices to advance. The target
                index is capped at the last entry of ``scheduler.sigmas``.

        Returns:
            The recovered per-sample estimates concatenated along dim 0.

        Note:
            The scheduler's private ``_step_index`` is overwritten for every
            sample and is left at the index of the last processed sample.
        """
        bs = noise_pred.shape[0]
        # Chunk inputs per-sample so each one can use its own step index.
        noise_pred_chunks = torch.chunk(noise_pred, bs)
        timestep_chunks = torch.chunk(timesteps, bs)
        noisy_latent_chunks = torch.chunk(noisy_latents, bs)
        noise_chunks = torch.chunk(noise, bs)

        x0_pred_chunks = []
        for idx, model_output in enumerate(noise_pred_chunks):
            timestep = timestep_chunks[idx]
            sample = noisy_latent_chunks[idx].to(torch.float32)
            noise_i = noise_chunks[idx].to(device=sample.device, dtype=sample.dtype)

            # Clear the cached index so it is looked up from this timestep.
            scheduler._step_index = None
            scheduler._init_step_index(timestep)

            # Advance `step_stride` indices, or stop at the end of the schedule.
            sigma = scheduler.sigmas[scheduler.step_index]
            last_idx = len(scheduler.sigmas) - 1
            target_idx = min(scheduler.step_index + step_stride, last_idx)
            sigma_next = scheduler.sigmas[target_idx]

            # One-step update along the model-predicted direction.
            stepped = sample + (sigma_next - sigma) * model_output

            # Invert `x_t = (1 - t) * x0 + t * noise` at the target position.
            # NOTE(review): this divides a sigma (not a timestep) by
            # total_steps. If scheduler.sigmas is already in [0, 1], t_01 is
            # close to 0 and the recovery is nearly a no-op -- confirm the
            # intended scale.
            t_01 = (sigma_next / total_steps).to(stepped.device).to(stepped.dtype)
            denom = torch.clamp(1.0 - t_01, min=eps)
            x0_pred_chunks.append((stepped - t_01 * noise_i) / denom)

        return torch.cat(x0_pred_chunks, dim=0)
def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
if model_path == "v3":
@@ -600,6 +648,10 @@ def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
dfe = DiffusionFeatureExtractor4(vae=vae)
dfe.eval()
return dfe
if model_path == "v5":
dfe = DiffusionFeatureExtractor5(vae=vae)
dfe.eval()
return dfe
if not os.path.exists(model_path):
raise FileNotFoundError(f"Model file not found: {model_path}")
# if it ends with safetensors