mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Changes to handle a different DFE (DiffusionFeatureExtractor) architecture
This commit is contained in:
@@ -305,8 +305,11 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
# enable gradient checkpointing on the vae
|
||||
if vae is not None and self.train_config.gradient_checkpointing:
|
||||
vae.enable_gradient_checkpointing()
|
||||
vae.train()
|
||||
try:
|
||||
vae.enable_gradient_checkpointing()
|
||||
vae.train()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch):
|
||||
|
||||
@@ -140,13 +140,13 @@ class DFEBlock(nn.Module):
|
||||
|
||||
|
||||
class DiffusionFeatureExtractor(nn.Module):
|
||||
def __init__(self, in_channels=16):
|
||||
def __init__(self, in_channels=16, out_channels=512):
|
||||
super().__init__()
|
||||
self.version = 1
|
||||
num_blocks = 6
|
||||
self.conv_in = nn.Conv2d(in_channels, 512, 1)
|
||||
self.blocks = nn.ModuleList([DFEBlock(512) for _ in range(num_blocks)])
|
||||
self.conv_out = nn.Conv2d(512, 512, 1)
|
||||
self.conv_out = nn.Conv2d(512, out_channels, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
@@ -611,7 +611,9 @@ def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
|
||||
state_dict = state_dict['model_state_dict']
|
||||
|
||||
if 'conv_in.weight' in state_dict:
|
||||
dfe = DiffusionFeatureExtractor()
|
||||
# determine num out channels
|
||||
out_channels = state_dict['conv_out.weight'].shape[0]
|
||||
dfe = DiffusionFeatureExtractor(out_channels=out_channels)
|
||||
else:
|
||||
dfe = DiffusionFeatureExtractor2()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user