"""Per-model-family latent-space constants.

Every class in this module is a plain namespace of class attributes; nothing
here instantiates the classes or consumes the values. Notes about what a
value means are marked as assumptions wherever this module alone cannot
confirm them.
"""

import torch


# 4-row matrix used by both LatentSpace and SD15 (it used to be written out
# twice). One row per latent channel, three values per row -- presumably the
# R, G, B weights of a cheap latent preview; confirm against the consumer.
# The same list object is shared by both classes, so treat it as read-only.
_SD15_RGB_CONVERSION = [
    [0.3512, 0.2297, 0.3227],
    [0.3250, 0.4974, 0.2350],
    [-0.2829, 0.1762, 0.2721],
    [-0.2120, -0.2616, -0.7177],
]

# 4-row matrix used by both SDXL and SDXL_Playground_2_5 (also previously
# duplicated). Shared object: treat it as read-only.
_SDXL_RGB_CONVERSION = [
    [0.3920, 0.4054, 0.4549],
    [-0.2634, -0.0196, 0.0653],
    [0.0568, 0.1687, -0.0755],
    [-0.3112, -0.2359, -0.2076],
]


class LatentSpace:
    """Base namespace; its defaults are identical to the SD15 values.

    Subclasses override whichever attributes differ for their model family.
    """

    # Equals the number of rows in rgb_conversion. Declared on the base so
    # that every latent space exposes it, not only the 16-channel ones.
    latent_channels = 4
    # Presumably the multiplier applied to VAE latents -- confirm with caller.
    scale_factor = 0.18215
    rgb_conversion = _SD15_RGB_CONVERSION
    # Presumably the name (file stem) of the matching TAESD decoder weights.
    taesd_name = "taesd_decoder"


class SD15(LatentSpace):
    """Constants for SD15; the values are the same as the base defaults."""

    # Kept explicit (rather than inherited) so SD15 stays pinned to these
    # values even if the base-class defaults are edited later.
    scale_factor = 0.18215
    rgb_conversion = _SD15_RGB_CONVERSION
    taesd_name = "taesd_decoder"


class SDXL(LatentSpace):
    """Constants for SDXL (4 latent channels, inherited from the base)."""

    scale_factor = 0.13025
    rgb_conversion = _SDXL_RGB_CONVERSION
    taesd_name = "taesdxl_decoder"


class SDXL_Playground_2_5(LatentSpace):
    """Constants for SDXL_Playground_2_5.

    Reuses SDXL's conversion matrix and decoder name, but deliberately does
    not subclass SDXL, so ``issubclass(SDXL_Playground_2_5, SDXL)`` is False.
    """

    scale_factor = 0.5
    # Four values each, reshaped to (1, 4, 1, 1). Presumably per-channel
    # statistics meant to broadcast over (batch, channel, height, width)
    # latents -- confirm against the caller.
    latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1)
    latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1)
    rgb_conversion = _SDXL_RGB_CONVERSION
    taesd_name = "taesdxl_decoder"


class SD3(LatentSpace):
    """Constants for SD3: 16 latent channels and a 16-row conversion matrix."""

    latent_channels = 16
    scale_factor = 1.5305
    # Presumably an offset applied to latents together with scale_factor;
    # the exact formula lives in the consumer -- confirm there.
    shift_factor = 0.0609
    rgb_conversion = [
        [-0.0645, 0.0177, 0.1052],
        [0.0028, 0.0312, 0.0650],
        [0.1848, 0.0762, 0.0360],
        [0.0944, 0.0360, 0.0889],
        [0.0897, 0.0506, -0.0364],
        [-0.0020, 0.1203, 0.0284],
        [0.0855, 0.0118, 0.0283],
        [-0.0539, 0.0658, 0.1047],
        [-0.0057, 0.0116, 0.0700],
        [-0.0412, 0.0281, -0.0039],
        [0.1106, 0.1171, 0.1220],
        [-0.0248, 0.0682, -0.0481],
        [0.0815, 0.0846, 0.1207],
        [-0.0120, -0.0055, -0.0867],
        [-0.0749, -0.0634, -0.0456],
        [-0.1418, -0.1457, -0.1259],
    ]
    taesd_name = "taesd3_decoder"


class Flux(LatentSpace):
    """Constants for Flux: 16 latent channels and a 16-row conversion matrix."""

    latent_channels = 16
    scale_factor = 0.3611
    # Presumably an offset applied to latents together with scale_factor;
    # the exact formula lives in the consumer -- confirm there.
    shift_factor = 0.1159
    rgb_conversion = [
        [-0.0404, 0.0159, 0.0609],
        [0.0043, 0.0298, 0.0850],
        [0.0328, -0.0749, -0.0503],
        [-0.0245, 0.0085, 0.0549],
        [0.0966, 0.0894, 0.0530],
        [0.0035, 0.0399, 0.0123],
        [0.0583, 0.1184, 0.1262],
        [-0.0191, -0.0206, -0.0306],
        [-0.0324, 0.0055, 0.1001],
        [0.0955, 0.0659, -0.0545],
        [-0.0504, 0.0231, -0.0013],
        [0.0500, -0.0008, -0.0088],
        [0.0982, 0.0941, 0.0976],
        [-0.1233, -0.0280, -0.0897],
        [-0.0005, -0.0530, -0.0020],
        [-0.1273, -0.0932, -0.0680],
    ]
    # Flux previously set no taesd_name and therefore inherited the base
    # class's "taesd_decoder", the name used by the 4-channel spaces, despite
    # declaring 16 latent channels. "taef1_decoder" is the decoder name TAESD
    # publishes for FLUX.1.
    # NOTE(review): confirm the weights file name the preview loader expects.
    taesd_name = "taef1_decoder"