mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-11 08:20:35 +00:00
Add DFE8 with partial step
This commit is contained in:
@@ -664,7 +664,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean")
|
||||
|
||||
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
|
||||
elif self.dfe.version in [3, 4, 5, 6, 7]:
|
||||
elif self.dfe.version in [3, 4, 5, 6, 7, 8]:
|
||||
dfe_loss = self.dfe(
|
||||
noise=noise,
|
||||
noise_pred=noise_pred,
|
||||
|
||||
@@ -808,8 +808,16 @@ class DiffusionFeatureExtractor6(nn.Module):
|
||||
|
||||
|
||||
class DiffusionFeatureExtractor7(nn.Module):
|
||||
def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None, sd=None):
|
||||
def __init__(
|
||||
self,
|
||||
device=torch.device("cuda"),
|
||||
dtype=torch.bfloat16,
|
||||
vae=None,
|
||||
sd=None,
|
||||
partial_step: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.version = 7
|
||||
self.sd_ref = weakref.ref(sd) if sd is not None else None
|
||||
pretrained_model_name = "google/tipsv2-b14-dpt"
|
||||
@@ -823,6 +831,7 @@ class DiffusionFeatureExtractor7(nn.Module):
|
||||
self.losses = {}
|
||||
self.log_every = 100
|
||||
self.step = 0
|
||||
self.do_partial_step = partial_step
|
||||
|
||||
def prepare_inputs(self, tensor_0_1: torch.Tensor):
|
||||
"""
|
||||
@@ -886,14 +895,31 @@ class DiffusionFeatureExtractor7(nn.Module):
|
||||
# expand shape to match noise_pred
|
||||
while len(tv.shape) < len(noise_pred.shape):
|
||||
tv = tv.unsqueeze(-1)
|
||||
# min 0.001
|
||||
tv = torch.clamp(tv, min=0.001)
|
||||
|
||||
# step latent
|
||||
x0 = noisy_latents - tv * noise_pred
|
||||
|
||||
stepped_latents = x0
|
||||
with torch.no_grad():
|
||||
target_0_1 = (tensors + 1) / 2 # 0 to 1
|
||||
|
||||
if not self.do_partial_step:
|
||||
# step latent
|
||||
x0 = noisy_latents - tv * noise_pred
|
||||
stepped_latents = x0
|
||||
# min 0.001
|
||||
tv = torch.clamp(tv, min=0.001)
|
||||
else:
|
||||
# step is random, uniform in [0.02, 0.05)
|
||||
step = torch.rand_like(tv) * 0.03 + 0.02
|
||||
next_step = tv - step
|
||||
next_step = torch.clamp(next_step, min=0.0)
|
||||
stepped_latents = noisy_latents + (next_step - tv) * noise_pred
|
||||
|
||||
with torch.no_grad():
|
||||
# make a noisy target at next timestep
|
||||
target_latents = batch.latents.to(self.sd_ref().vae.device, dtype=self.sd_ref().vae.dtype)
|
||||
# add noise
|
||||
target_latents = (1.0 - next_step) * target_latents + next_step * noise
|
||||
target_n1p1 = self.sd_ref().decode_latents(target_latents)
|
||||
target_0_1 = (target_n1p1 + 1) / 2 # 0 to 1
|
||||
|
||||
latents = stepped_latents.to(self.sd_ref().vae.device, dtype=self.sd_ref().vae.dtype)
|
||||
|
||||
tensors_n1p1 = self.sd_ref().decode_latents(latents)
|
||||
@@ -904,10 +930,7 @@ class DiffusionFeatureExtractor7(nn.Module):
|
||||
dtype = self.model.dtype
|
||||
|
||||
with torch.no_grad():
|
||||
target_img = tensors.to(device, dtype=dtype)
|
||||
# go from -1 to 1 to 0 to 1
|
||||
target_img = (target_img + 1) / 2
|
||||
target = self.prepare_inputs(target_img)
|
||||
target = self.prepare_inputs(target_0_1)
|
||||
target = self.model(target)
|
||||
|
||||
pred_images = pred_images.to(device, dtype=dtype)
|
||||
@@ -929,6 +952,9 @@ class DiffusionFeatureExtractor7(nn.Module):
|
||||
|
||||
total_loss = (depth_loss + normals_loss + segmentation_loss) / 3.0
|
||||
|
||||
if self.do_partial_step:
|
||||
total_loss = total_loss * 10.0
|
||||
|
||||
if 'total' not in self.losses:
|
||||
self.losses['total'] = total_loss.item()
|
||||
else:
|
||||
@@ -963,6 +989,11 @@ class DiffusionFeatureExtractor7(nn.Module):
|
||||
|
||||
return total_loss
|
||||
|
||||
class DiffusionFeatureExtractor8(DiffusionFeatureExtractor7):
    """DiffusionFeatureExtractor7 constructed with ``partial_step=True``.

    Accepts the same ``device``, ``dtype``, ``vae`` and ``sd`` arguments as
    the parent constructor and forwards them unchanged; the only differences
    are that partial stepping is always enabled and ``version`` is 8.
    """

    def __init__(
        self,
        device=torch.device("cuda"),
        dtype=torch.bfloat16,
        vae=None,
        sd=None,
    ):
        # Delegate all setup to the v7 constructor, forcing partial stepping on.
        super().__init__(
            device=device,
            dtype=dtype,
            vae=vae,
            sd=sd,
            partial_step=True,
        )
        # The parent constructor sets version to 7; override it afterwards.
        self.version = 8
|
||||
|
||||
def load_dfe(model_path, vae=None, sd: 'BaseModel' = None) -> DiffusionFeatureExtractor:
|
||||
if model_path == "v3":
|
||||
dfe = DiffusionFeatureExtractor3(vae=vae)
|
||||
@@ -984,6 +1015,10 @@ def load_dfe(model_path, vae=None, sd: 'BaseModel' = None) -> DiffusionFeatureEx
|
||||
dfe = DiffusionFeatureExtractor7(vae=vae, sd=sd)
|
||||
dfe.eval()
|
||||
return dfe
|
||||
if model_path == "v8":
|
||||
dfe = DiffusionFeatureExtractor8(vae=vae, sd=sd)
|
||||
dfe.eval()
|
||||
return dfe
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"Model file not found: {model_path}")
|
||||
# if it ends with safetensors
|
||||
|
||||
Reference in New Issue
Block a user