Allow DFE to not have a VAE

This commit is contained in:
Jaret Burkett
2025-03-30 09:23:01 -06:00
parent 860d892214
commit c083a0e5ea
2 changed files with 10 additions and 6 deletions

View File

@@ -197,7 +197,10 @@ class SDTrainer(BaseSDTrainProcess):
flush()
if self.train_config.diffusion_feature_extractor_path is not None:
self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path)
vae = None
if self.model_config.arch != "flux":
vae = self.sd.vae
self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae)
self.dfe.to(self.device_torch)
self.dfe.eval()

View File

@@ -154,11 +154,12 @@ class DiffusionFeatureExtractor(nn.Module):
class DiffusionFeatureExtractor3(nn.Module):
def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16):
def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None):
super().__init__()
self.version = 3
vae = AutoencoderTiny.from_pretrained(
"madebyollin/taef1", torch_dtype=torch.bfloat16)
if vae is None:
vae = AutoencoderTiny.from_pretrained(
"madebyollin/taef1", torch_dtype=torch.bfloat16)
self.vae = vae
image_encoder_path = "google/siglip-so400m-patch14-384"
try:
@@ -342,9 +343,9 @@ class DiffusionFeatureExtractor3(nn.Module):
return total_loss
def load_dfe(model_path) -> DiffusionFeatureExtractor:
def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
if model_path == "v3":
dfe = DiffusionFeatureExtractor3()
dfe = DiffusionFeatureExtractor3(vae=vae)
dfe.eval()
return dfe
if not os.path.exists(model_path):