mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2026-04-27 17:51:50 +00:00
Add consistency decoder
This commit is contained in:
35
modules/sd_vae_consistency.py
Normal file
35
modules/sd_vae_consistency.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
Consistency Decoder
|
||||
Improved decoding for stable diffusion vaes.
|
||||
|
||||
https://github.com/openai/consistencydecoder
|
||||
"""
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from modules import devices, paths_internal, shared
|
||||
from consistencydecoder import ConsistencyDecoder
|
||||
|
||||
|
||||
sd_vae_consistency_models = None
|
||||
model_path = os.path.join(paths_internal.models_path, 'consistencydecoder')
|
||||
|
||||
|
||||
def decoder_model():
|
||||
global sd_vae_consistency_models
|
||||
if getattr(shared.sd_model, 'is_sdxl', False):
|
||||
raise NotImplementedError("SDXL is not supported for consistency decoder")
|
||||
if sd_vae_consistency_models is not None:
|
||||
sd_vae_consistency_models.ckpt.to(devices.device)
|
||||
return sd_vae_consistency_models
|
||||
|
||||
loaded_model = ConsistencyDecoder(devices.device, model_path)
|
||||
sd_vae_consistency_models = loaded_model
|
||||
return loaded_model
|
||||
|
||||
|
||||
def unload():
|
||||
global sd_vae_consistency_models
|
||||
if sd_vae_consistency_models is not None:
|
||||
sd_vae_consistency_models.ckpt.to('cpu')
|
||||
Reference in New Issue
Block a user