mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Updated diffusion feature extractor
This commit is contained in:
@@ -454,8 +454,12 @@ class TrainConfig:
|
||||
self.bypass_guidance_embedding = kwargs.get('bypass_guidance_embedding', False)
|
||||
|
||||
# diffusion feature extractor
|
||||
self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', None)
|
||||
self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', 1.0)
|
||||
self.latent_feature_extractor_path = kwargs.get('latent_feature_extractor_path', None)
|
||||
self.latent_feature_loss_weight = kwargs.get('latent_feature_loss_weight', 1.0)
|
||||
|
||||
# The code still reads the diffusion_feature_extractor_* names, but latent_feature_extractor is the more accurate name for the new architecture, so the latent_* values are used as the defaults below
|
||||
self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', self.latent_feature_extractor_path)
|
||||
self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', self.latent_feature_loss_weight)
|
||||
|
||||
# optimal noise pairing
|
||||
self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1)
|
||||
|
||||
@@ -127,18 +127,20 @@ class DFEBlock(nn.Module):
|
||||
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
|
||||
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
|
||||
self.act = nn.GELU()
|
||||
self.proj = nn.Conv2d(channels, channels, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x_in = x
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.act(x)
|
||||
x = self.proj(x)
|
||||
x = x + x_in
|
||||
return x
|
||||
|
||||
|
||||
class DiffusionFeatureExtractor(nn.Module):
|
||||
def __init__(self, in_channels=32):
|
||||
def __init__(self, in_channels=16):
|
||||
super().__init__()
|
||||
self.version = 1
|
||||
num_blocks = 6
|
||||
|
||||
@@ -2538,8 +2538,8 @@ class StableDiffusion:
|
||||
|
||||
# Move the vae to the device if it is on cpu
|
||||
if self.vae.device == 'cpu':
|
||||
self.vae.to(self.device)
|
||||
latents = latents.to(device, dtype=dtype)
|
||||
self.vae.to(self.device_torch)
|
||||
latents = latents.to(self.device_torch, dtype=self.torch_dtype)
|
||||
latents = (latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor']
|
||||
images = self.vae.decode(latents).sample
|
||||
images = images.to(device, dtype=dtype)
|
||||
|
||||
Reference in New Issue
Block a user