mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Allow DFE to not have a VAE
This commit is contained in:
@@ -197,7 +197,10 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
flush()
|
flush()
|
||||||
|
|
||||||
if self.train_config.diffusion_feature_extractor_path is not None:
|
if self.train_config.diffusion_feature_extractor_path is not None:
|
||||||
self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path)
|
vae = None
|
||||||
|
if self.model_config.arch != "flux":
|
||||||
|
vae = self.sd.vae
|
||||||
|
self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae)
|
||||||
self.dfe.to(self.device_torch)
|
self.dfe.to(self.device_torch)
|
||||||
self.dfe.eval()
|
self.dfe.eval()
|
||||||
|
|
||||||
|
|||||||
@@ -154,11 +154,12 @@ class DiffusionFeatureExtractor(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class DiffusionFeatureExtractor3(nn.Module):
|
class DiffusionFeatureExtractor3(nn.Module):
|
||||||
def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16):
|
def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.version = 3
|
self.version = 3
|
||||||
vae = AutoencoderTiny.from_pretrained(
|
if vae is None:
|
||||||
"madebyollin/taef1", torch_dtype=torch.bfloat16)
|
vae = AutoencoderTiny.from_pretrained(
|
||||||
|
"madebyollin/taef1", torch_dtype=torch.bfloat16)
|
||||||
self.vae = vae
|
self.vae = vae
|
||||||
image_encoder_path = "google/siglip-so400m-patch14-384"
|
image_encoder_path = "google/siglip-so400m-patch14-384"
|
||||||
try:
|
try:
|
||||||
@@ -342,9 +343,9 @@ class DiffusionFeatureExtractor3(nn.Module):
|
|||||||
return total_loss
|
return total_loss
|
||||||
|
|
||||||
|
|
||||||
def load_dfe(model_path) -> DiffusionFeatureExtractor:
|
def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
|
||||||
if model_path == "v3":
|
if model_path == "v3":
|
||||||
dfe = DiffusionFeatureExtractor3()
|
dfe = DiffusionFeatureExtractor3(vae=vae)
|
||||||
dfe.eval()
|
dfe.eval()
|
||||||
return dfe
|
return dfe
|
||||||
if not os.path.exists(model_path):
|
if not os.path.exists(model_path):
|
||||||
|
|||||||
Reference in New Issue
Block a user