Add DFE8 with partial step

This commit is contained in:
Jaret Burkett
2026-04-17 17:40:16 -06:00
parent 7c4f18ce51
commit beb40ae29b
2 changed files with 47 additions and 12 deletions

View File

@@ -664,7 +664,7 @@ class SDTrainer(BaseSDTrainProcess):
dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean")
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
elif self.dfe.version in [3, 4, 5, 6, 7]:
elif self.dfe.version in [3, 4, 5, 6, 7, 8]:
dfe_loss = self.dfe(
noise=noise,
noise_pred=noise_pred,

View File

@@ -808,8 +808,16 @@ class DiffusionFeatureExtractor6(nn.Module):
class DiffusionFeatureExtractor7(nn.Module):
def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None, sd=None):
def __init__(
self,
device=torch.device("cuda"),
dtype=torch.bfloat16,
vae=None,
sd=None,
partial_step: bool = False
):
super().__init__()
self.version = 7
self.sd_ref = weakref.ref(sd) if sd is not None else None
pretrained_model_name = "google/tipsv2-b14-dpt"
@@ -823,6 +831,7 @@ class DiffusionFeatureExtractor7(nn.Module):
self.losses = {}
self.log_every = 100
self.step = 0
self.do_partial_step = partial_step
def prepare_inputs(self, tensor_0_1: torch.Tensor):
"""
@@ -886,14 +895,31 @@ class DiffusionFeatureExtractor7(nn.Module):
# expand shape to match noise_pred
while len(tv.shape) < len(noise_pred.shape):
tv = tv.unsqueeze(-1)
# min 0.001
tv = torch.clamp(tv, min=0.001)
# step latent
x0 = noisy_latents - tv * noise_pred
stepped_latents = x0
with torch.no_grad():
target_0_1 = (tensors + 1) / 2 # 0 to 1
if not self.do_partial_step:
# step latent
x0 = noisy_latents - tv * noise_pred
stepped_latents = x0
# min 0.001
tv = torch.clamp(tv, min=0.001)
else:
# step is a random size in [0.02, 0.05)
step = torch.rand_like(tv) * 0.03 + 0.02
next_step = tv - step
next_step = torch.clamp(next_step, min=0.0)
stepped_latents = noisy_latents + (next_step - tv) * noise_pred
with torch.no_grad():
# make a noisy target at next timestep
target_latents = batch.latents.to(self.sd_ref().vae.device, dtype=self.sd_ref().vae.dtype)
# add noise
target_latents = (1.0 - next_step) * target_latents + next_step * noise
target_n1p1 = self.sd_ref().decode_latents(target_latents)
target_0_1 = (target_n1p1 + 1) / 2 # 0 to 1
latents = stepped_latents.to(self.sd_ref().vae.device, dtype=self.sd_ref().vae.dtype)
tensors_n1p1 = self.sd_ref().decode_latents(latents)
@@ -904,10 +930,7 @@ class DiffusionFeatureExtractor7(nn.Module):
dtype = self.model.dtype
with torch.no_grad():
target_img = tensors.to(device, dtype=dtype)
# go from -1 to 1 to 0 to 1
target_img = (target_img + 1) / 2
target = self.prepare_inputs(target_img)
target = self.prepare_inputs(target_0_1)
target = self.model(target)
pred_images = pred_images.to(device, dtype=dtype)
@@ -929,6 +952,9 @@ class DiffusionFeatureExtractor7(nn.Module):
total_loss = (depth_loss + normals_loss + segmentation_loss) / 3.0
if self.do_partial_step:
total_loss = total_loss * 10.0
if 'total' not in self.losses:
self.losses['total'] = total_loss.item()
else:
@@ -963,6 +989,11 @@ class DiffusionFeatureExtractor7(nn.Module):
return total_loss
class DiffusionFeatureExtractor8(DiffusionFeatureExtractor7):
    """Version-8 feature extractor: identical to DFE7 except partial
    stepping is always enabled (a small random step toward x0 instead
    of a full denoising step)."""

    def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None, sd=None):
        # Construct exactly like DFE7, but force partial_step on.
        super().__init__(
            device=device,
            dtype=dtype,
            vae=vae,
            sd=sd,
            partial_step=True,
        )
        # Parent sets version = 7; override it to tag this variant.
        self.version = 8
def load_dfe(model_path, vae=None, sd: 'BaseModel' = None) -> DiffusionFeatureExtractor:
if model_path == "v3":
dfe = DiffusionFeatureExtractor3(vae=vae)
@@ -984,6 +1015,10 @@ def load_dfe(model_path, vae=None, sd: 'BaseModel' = None) -> DiffusionFeatureEx
dfe = DiffusionFeatureExtractor7(vae=vae, sd=sd)
dfe.eval()
return dfe
if model_path == "v8":
dfe = DiffusionFeatureExtractor8(vae=vae, sd=sd)
dfe.eval()
return dfe
if not os.path.exists(model_path):
raise FileNotFoundError(f"Model file not found: {model_path}")
# if it ends with safetensors