diff --git a/backend/diffusion_engine/sd15.py b/backend/diffusion_engine/sd15.py new file mode 100644 index 00000000..792d6005 --- /dev/null +++ b/backend/diffusion_engine/sd15.py @@ -0,0 +1 @@ +# diff --git a/backend/diffusion_engine/sdxl.py b/backend/diffusion_engine/sdxl.py new file mode 100644 index 00000000..792d6005 --- /dev/null +++ b/backend/diffusion_engine/sdxl.py @@ -0,0 +1 @@ +# diff --git a/backend/huggingface_mapping.py b/backend/huggingface_mapping.py deleted file mode 100644 index 39c566e8..00000000 --- a/backend/huggingface_mapping.py +++ /dev/null @@ -1,64 +0,0 @@ -from backend import latent_spaces - - -class SupportedModel: - unet_config = {} - latent = latent_spaces.LatentSpace - huggingface_mappings = [] - - -class SD15(SupportedModel): - unet_config = { - "context_dim": 768, - "model_channels": 320, - "use_linear_in_transformer": False, - "adm_in_channels": None, - } - latent = latent_spaces.SD15 - huggingface_mappings = [ - "runwayml/stable-diffusion-v1-5", - "runwayml/stable-diffusion-inpainting" - ] - - -class SD21(SupportedModel): - unet_config = { - "context_dim": 1024, - "model_channels": 320, - "use_linear_in_transformer": True, - "adm_in_channels": None, - } - latent = latent_spaces.SD15 - huggingface_mappings = [ - "stabilityai/stable-diffusion-2-1", - "stabilityai/stable-diffusion-2-inpainting" - ] - - -class SDXL(SupportedModel): - unet_config = { - "model_channels": 320, - "use_linear_in_transformer": True, - "transformer_depth": [0, 0, 2, 2, 10, 10], - "context_dim": 2048, - "adm_in_channels": 2816, - } - latent = latent_spaces.SDXL - huggingface_mappings = [ - "stabilityai/stable-diffusion-xl-base-1.0", - "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", - "playgroundai/playground-v2.5-1024px-aesthetic", - - ] - - -class SD3(SupportedModel): - unet_config = {} - latent = latent_spaces.SD3 - huggingface_mappings = [ - "stabilityai/stable-diffusion-3-medium-diffusers" - ] - - -class Flux(SupportedModel): - pass diff --git a/backend/latent_spaces.py b/backend/latent_spaces.py deleted file mode 100644 index 23b8bb80..00000000 --- a/backend/latent_spaces.py +++ /dev/null @@ -1,96 +0,0 @@ -import torch - - -class LatentSpace: - scale_factor = 0.18215 - rgb_conversion = [ - [0.3512, 0.2297, 0.3227], - [0.3250, 0.4974, 0.2350], - [-0.2829, 0.1762, 0.2721], - [-0.2120, -0.2616, -0.7177] - ] - taesd_name = "taesd_decoder" - - -class SD15(LatentSpace): - scale_factor = 0.18215 - rgb_conversion = [ - [0.3512, 0.2297, 0.3227], - [0.3250, 0.4974, 0.2350], - [-0.2829, 0.1762, 0.2721], - [-0.2120, -0.2616, -0.7177] - ] - taesd_name = "taesd_decoder" - - -class SDXL(LatentSpace): - scale_factor = 0.13025 - rgb_conversion = [ - [0.3920, 0.4054, 0.4549], - [-0.2634, -0.0196, 0.0653], - [0.0568, 0.1687, -0.0755], - [-0.3112, -0.2359, -0.2076] - ] - taesd_name = "taesdxl_decoder" - - -class SDXL_Playground_2_5(LatentSpace): - scale_factor = 0.5 - latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1) - latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1) - rgb_conversion = [ - [0.3920, 0.4054, 0.4549], - [-0.2634, -0.0196, 0.0653], - [0.0568, 0.1687, -0.0755], - [-0.3112, -0.2359, -0.2076] - ] - taesd_name = "taesdxl_decoder" - - -class SD3(LatentSpace): - latent_channels = 16 - scale_factor = 1.5305 - shift_factor = 0.0609 - rgb_conversion = [ - [-0.0645, 0.0177, 0.1052], - [0.0028, 0.0312, 0.0650], - [0.1848, 0.0762, 0.0360], - [0.0944, 0.0360, 0.0889], - [0.0897, 0.0506, -0.0364], - [-0.0020, 0.1203, 0.0284], - [0.0855, 0.0118, 0.0283], - [-0.0539, 0.0658, 0.1047], - [-0.0057, 0.0116, 0.0700], - [-0.0412, 0.0281, -0.0039], - [0.1106, 0.1171, 0.1220], - [-0.0248, 0.0682, -0.0481], - [0.0815, 0.0846, 0.1207], - [-0.0120, -0.0055, -0.0867], - [-0.0749, -0.0634, -0.0456], - [-0.1418, -0.1457, -0.1259] - ] - taesd_name = "taesd3_decoder" - - -class Flux(LatentSpace): - latent_channels = 16 - scale_factor = 0.3611 - shift_factor = 0.1159 - rgb_conversion = [ - [-0.0404, 0.0159, 0.0609], - [0.0043, 0.0298, 0.0850], - [0.0328, -0.0749, -0.0503], - [-0.0245, 0.0085, 0.0549], - [0.0966, 0.0894, 0.0530], - [0.0035, 0.0399, 0.0123], - [0.0583, 0.1184, 0.1262], - [-0.0191, -0.0206, -0.0306], - [-0.0324, 0.0055, 0.1001], - [0.0955, 0.0659, -0.0545], - [-0.0504, 0.0231, -0.0013], - [0.0500, -0.0008, -0.0088], - [0.0982, 0.0941, 0.0976], - [-0.1233, -0.0280, -0.0897], - [-0.0005, -0.0530, -0.0020], - [-0.1273, -0.0932, -0.0680] - ]