initial commit
This commit is contained in:
87
backend/diffusion_engine/base.py
Executable file
87
backend/diffusion_engine/base.py
Executable file
@@ -0,0 +1,87 @@
|
||||
import torch
|
||||
import safetensors.torch as sf
|
||||
|
||||
from backend import utils
|
||||
|
||||
|
||||
class ForgeObjects:
|
||||
def __init__(self, unet, clip, vae, clipvision):
|
||||
self.unet = unet
|
||||
self.clip = clip
|
||||
self.vae = vae
|
||||
self.clipvision = clipvision
|
||||
|
||||
def shallow_copy(self):
|
||||
return ForgeObjects(
|
||||
self.unet,
|
||||
self.clip,
|
||||
self.vae,
|
||||
self.clipvision
|
||||
)
|
||||
|
||||
|
||||
class ForgeDiffusionEngine:
|
||||
matched_guesses = []
|
||||
|
||||
def __init__(self, estimated_config, huggingface_components):
|
||||
self.model_config = estimated_config
|
||||
self.is_inpaint = estimated_config.inpaint_model()
|
||||
|
||||
self.forge_objects = None
|
||||
self.forge_objects_original = None
|
||||
self.forge_objects_after_applying_lora = None
|
||||
|
||||
self.current_lora_hash = str([])
|
||||
|
||||
self.fix_for_webui_backward_compatibility()
|
||||
|
||||
def set_clip_skip(self, clip_skip):
|
||||
pass
|
||||
|
||||
def get_first_stage_encoding(self, x):
|
||||
return x # legacy code, do not change
|
||||
|
||||
def get_learned_conditioning(self, prompt: list[str]):
|
||||
pass
|
||||
|
||||
def encode_first_stage(self, x):
|
||||
pass
|
||||
|
||||
def decode_first_stage(self, x):
|
||||
pass
|
||||
|
||||
def get_prompt_lengths_on_ui(self, prompt):
|
||||
return 0, 75
|
||||
|
||||
def is_webui_legacy_model(self):
|
||||
return self.is_sd1 or self.is_sd2 or self.is_sdxl or self.is_sd3
|
||||
|
||||
def fix_for_webui_backward_compatibility(self):
|
||||
self.tiling_enabled = False
|
||||
self.first_stage_model = None
|
||||
self.cond_stage_model = None
|
||||
self.use_distilled_cfg_scale = False
|
||||
self.is_sd1 = False
|
||||
self.is_sd2 = False
|
||||
self.is_sdxl = False
|
||||
self.is_sd3 = False
|
||||
return
|
||||
|
||||
def save_unet(self, filename):
|
||||
sd = utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model)
|
||||
sf.save_file(sd, filename)
|
||||
return filename
|
||||
|
||||
def save_checkpoint(self, filename):
|
||||
sd = {}
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
|
||||
)
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='text_encoders.')
|
||||
)
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='vae.')
|
||||
)
|
||||
sf.save_file(sd, filename)
|
||||
return filename
|
||||
114
backend/diffusion_engine/flux.py
Executable file
114
backend/diffusion_engine/flux.py
Executable file
@@ -0,0 +1,114 @@
|
||||
import torch
|
||||
|
||||
from huggingface_guess import model_list
|
||||
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
|
||||
from backend.patcher.clip import CLIP
|
||||
from backend.patcher.vae import VAE
|
||||
from backend.patcher.unet import UnetPatcher
|
||||
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
|
||||
from backend.text_processing.t5_engine import T5TextProcessingEngine
|
||||
from backend.args import dynamic_args
|
||||
from backend.modules.k_prediction import PredictionFlux
|
||||
from backend import memory_management
|
||||
|
||||
|
||||
class Flux(ForgeDiffusionEngine):
|
||||
matched_guesses = [model_list.Flux, model_list.FluxSchnell]
|
||||
|
||||
def __init__(self, estimated_config, huggingface_components):
|
||||
super().__init__(estimated_config, huggingface_components)
|
||||
self.is_inpaint = False
|
||||
|
||||
clip = CLIP(
|
||||
model_dict={
|
||||
'clip_l': huggingface_components['text_encoder'],
|
||||
't5xxl': huggingface_components['text_encoder_2']
|
||||
},
|
||||
tokenizer_dict={
|
||||
'clip_l': huggingface_components['tokenizer'],
|
||||
't5xxl': huggingface_components['tokenizer_2']
|
||||
}
|
||||
)
|
||||
|
||||
vae = VAE(model=huggingface_components['vae'])
|
||||
|
||||
if 'schnell' in estimated_config.huggingface_repo.lower():
|
||||
k_predictor = PredictionFlux(
|
||||
mu=1.0
|
||||
)
|
||||
else:
|
||||
k_predictor = PredictionFlux(
|
||||
seq_len=4096,
|
||||
base_seq_len=256,
|
||||
max_seq_len=4096,
|
||||
base_shift=0.5,
|
||||
max_shift=1.15,
|
||||
)
|
||||
self.use_distilled_cfg_scale = True
|
||||
|
||||
unet = UnetPatcher.from_model(
|
||||
model=huggingface_components['transformer'],
|
||||
diffusers_scheduler=None,
|
||||
k_predictor=k_predictor,
|
||||
config=estimated_config
|
||||
)
|
||||
|
||||
self.text_processing_engine_l = ClassicTextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.clip_l,
|
||||
tokenizer=clip.tokenizer.clip_l,
|
||||
embedding_dir=dynamic_args['embedding_dir'],
|
||||
embedding_key='clip_l',
|
||||
embedding_expected_shape=768,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
text_projection=False,
|
||||
minimal_clip_skip=1,
|
||||
clip_skip=1,
|
||||
return_pooled=True,
|
||||
final_layer_norm=True,
|
||||
)
|
||||
|
||||
self.text_processing_engine_t5 = T5TextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.t5xxl,
|
||||
tokenizer=clip.tokenizer.t5xxl,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
)
|
||||
|
||||
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
|
||||
self.forge_objects_original = self.forge_objects.shallow_copy()
|
||||
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
|
||||
|
||||
def set_clip_skip(self, clip_skip):
|
||||
self.text_processing_engine_l.clip_skip = clip_skip
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_learned_conditioning(self, prompt: list[str]):
|
||||
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
|
||||
cond_l, pooled_l = self.text_processing_engine_l(prompt)
|
||||
cond_t5 = self.text_processing_engine_t5(prompt)
|
||||
cond = dict(crossattn=cond_t5, vector=pooled_l)
|
||||
|
||||
if self.use_distilled_cfg_scale:
|
||||
distilled_cfg_scale = getattr(prompt, 'distilled_cfg_scale', 3.5) or 3.5
|
||||
cond['guidance'] = torch.FloatTensor([distilled_cfg_scale] * len(prompt))
|
||||
print(f'Distilled CFG Scale: {distilled_cfg_scale}')
|
||||
else:
|
||||
print('Distilled CFG Scale will be ignored for Schnell')
|
||||
|
||||
return cond
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_prompt_lengths_on_ui(self, prompt):
|
||||
token_count = len(self.text_processing_engine_t5.tokenize([prompt])[0])
|
||||
return token_count, max(255, token_count)
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
|
||||
sample = self.forge_objects.vae.first_stage_model.process_in(sample)
|
||||
return sample.to(x)
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.first_stage_model.process_out(x)
|
||||
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
||||
return sample.to(x)
|
||||
100
backend/diffusion_engine/sd15.py
Executable file
100
backend/diffusion_engine/sd15.py
Executable file
@@ -0,0 +1,100 @@
|
||||
import torch
|
||||
|
||||
from huggingface_guess import model_list
|
||||
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
|
||||
from backend.patcher.clip import CLIP
|
||||
from backend.patcher.vae import VAE
|
||||
from backend.patcher.unet import UnetPatcher
|
||||
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
|
||||
from backend.args import dynamic_args
|
||||
from backend import memory_management
|
||||
|
||||
import safetensors.torch as sf
|
||||
from backend import utils
|
||||
|
||||
|
||||
class StableDiffusion(ForgeDiffusionEngine):
|
||||
matched_guesses = [model_list.SD15]
|
||||
|
||||
def __init__(self, estimated_config, huggingface_components):
|
||||
super().__init__(estimated_config, huggingface_components)
|
||||
|
||||
clip = CLIP(
|
||||
model_dict={
|
||||
'clip_l': huggingface_components['text_encoder']
|
||||
},
|
||||
tokenizer_dict={
|
||||
'clip_l': huggingface_components['tokenizer']
|
||||
}
|
||||
)
|
||||
|
||||
vae = VAE(model=huggingface_components['vae'])
|
||||
|
||||
unet = UnetPatcher.from_model(
|
||||
model=huggingface_components['unet'],
|
||||
diffusers_scheduler=huggingface_components['scheduler'],
|
||||
config=estimated_config
|
||||
)
|
||||
|
||||
self.text_processing_engine = ClassicTextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.clip_l,
|
||||
tokenizer=clip.tokenizer.clip_l,
|
||||
embedding_dir=dynamic_args['embedding_dir'],
|
||||
embedding_key='clip_l',
|
||||
embedding_expected_shape=768,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
text_projection=False,
|
||||
minimal_clip_skip=1,
|
||||
clip_skip=1,
|
||||
return_pooled=False,
|
||||
final_layer_norm=True,
|
||||
)
|
||||
|
||||
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
|
||||
self.forge_objects_original = self.forge_objects.shallow_copy()
|
||||
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
|
||||
|
||||
# WebUI Legacy
|
||||
self.is_sd1 = True
|
||||
|
||||
def set_clip_skip(self, clip_skip):
|
||||
self.text_processing_engine.clip_skip = clip_skip
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_learned_conditioning(self, prompt: list[str]):
|
||||
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
|
||||
cond = self.text_processing_engine(prompt)
|
||||
return cond
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_prompt_lengths_on_ui(self, prompt):
|
||||
_, token_count = self.text_processing_engine.process_texts([prompt])
|
||||
return token_count, self.text_processing_engine.get_target_prompt_token_count(token_count)
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
|
||||
sample = self.forge_objects.vae.first_stage_model.process_in(sample)
|
||||
return sample.to(x)
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.first_stage_model.process_out(x)
|
||||
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
||||
return sample.to(x)
|
||||
|
||||
def save_checkpoint(self, filename):
|
||||
sd = {}
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
|
||||
)
|
||||
sd.update(
|
||||
model_list.SD15.process_clip_state_dict_for_saving(self,
|
||||
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
|
||||
)
|
||||
)
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
|
||||
)
|
||||
sf.save_file(sd, filename)
|
||||
return filename
|
||||
100
backend/diffusion_engine/sd20.py
Executable file
100
backend/diffusion_engine/sd20.py
Executable file
@@ -0,0 +1,100 @@
|
||||
import torch
|
||||
|
||||
from huggingface_guess import model_list
|
||||
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
|
||||
from backend.patcher.clip import CLIP
|
||||
from backend.patcher.vae import VAE
|
||||
from backend.patcher.unet import UnetPatcher
|
||||
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
|
||||
from backend.args import dynamic_args
|
||||
from backend import memory_management
|
||||
|
||||
import safetensors.torch as sf
|
||||
from backend import utils
|
||||
|
||||
|
||||
class StableDiffusion2(ForgeDiffusionEngine):
|
||||
matched_guesses = [model_list.SD20]
|
||||
|
||||
def __init__(self, estimated_config, huggingface_components):
|
||||
super().__init__(estimated_config, huggingface_components)
|
||||
|
||||
clip = CLIP(
|
||||
model_dict={
|
||||
'clip_h': huggingface_components['text_encoder']
|
||||
},
|
||||
tokenizer_dict={
|
||||
'clip_h': huggingface_components['tokenizer']
|
||||
}
|
||||
)
|
||||
|
||||
vae = VAE(model=huggingface_components['vae'])
|
||||
|
||||
unet = UnetPatcher.from_model(
|
||||
model=huggingface_components['unet'],
|
||||
diffusers_scheduler=huggingface_components['scheduler'],
|
||||
config=estimated_config
|
||||
)
|
||||
|
||||
self.text_processing_engine = ClassicTextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.clip_h,
|
||||
tokenizer=clip.tokenizer.clip_h,
|
||||
embedding_dir=dynamic_args['embedding_dir'],
|
||||
embedding_key='clip_h',
|
||||
embedding_expected_shape=1024,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
text_projection=False,
|
||||
minimal_clip_skip=1,
|
||||
clip_skip=1,
|
||||
return_pooled=False,
|
||||
final_layer_norm=True,
|
||||
)
|
||||
|
||||
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
|
||||
self.forge_objects_original = self.forge_objects.shallow_copy()
|
||||
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
|
||||
|
||||
# WebUI Legacy
|
||||
self.is_sd2 = True
|
||||
|
||||
def set_clip_skip(self, clip_skip):
|
||||
self.text_processing_engine.clip_skip = clip_skip
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_learned_conditioning(self, prompt: list[str]):
|
||||
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
|
||||
cond = self.text_processing_engine(prompt)
|
||||
return cond
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_prompt_lengths_on_ui(self, prompt):
|
||||
_, token_count = self.text_processing_engine.process_texts([prompt])
|
||||
return token_count, self.text_processing_engine.get_target_prompt_token_count(token_count)
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
|
||||
sample = self.forge_objects.vae.first_stage_model.process_in(sample)
|
||||
return sample.to(x)
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.first_stage_model.process_out(x)
|
||||
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
||||
return sample.to(x)
|
||||
|
||||
def save_checkpoint(self, filename):
|
||||
sd = {}
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
|
||||
)
|
||||
sd.update(
|
||||
model_list.SD20.process_clip_state_dict_for_saving(self,
|
||||
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
|
||||
)
|
||||
)
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
|
||||
)
|
||||
sf.save_file(sd, filename)
|
||||
return filename
|
||||
149
backend/diffusion_engine/sd35.py
Executable file
149
backend/diffusion_engine/sd35.py
Executable file
@@ -0,0 +1,149 @@
|
||||
import torch
|
||||
|
||||
from huggingface_guess import model_list
|
||||
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
|
||||
from backend.patcher.clip import CLIP
|
||||
from backend.patcher.vae import VAE
|
||||
from backend.patcher.unet import UnetPatcher
|
||||
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
|
||||
from backend.text_processing.t5_engine import T5TextProcessingEngine
|
||||
from backend.args import dynamic_args
|
||||
from backend import memory_management
|
||||
from backend.modules.k_prediction import PredictionDiscreteFlow
|
||||
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
## patch SD3 Class in huggingface_guess.model_list
|
||||
def SD3_clip_target(self, state_dict={}):
|
||||
return {'clip_l': 'text_encoder', 'clip_g': 'text_encoder_2', 't5xxl': 'text_encoder_3'}
|
||||
|
||||
model_list.SD3.unet_target = 'transformer'
|
||||
model_list.SD3.clip_target = SD3_clip_target
|
||||
## end patch
|
||||
|
||||
class StableDiffusion3(ForgeDiffusionEngine):
|
||||
matched_guesses = [model_list.SD3]
|
||||
|
||||
def __init__(self, estimated_config, huggingface_components):
|
||||
super().__init__(estimated_config, huggingface_components)
|
||||
self.is_inpaint = False
|
||||
|
||||
clip = CLIP(
|
||||
model_dict={
|
||||
'clip_l': huggingface_components['text_encoder'],
|
||||
'clip_g': huggingface_components['text_encoder_2'],
|
||||
't5xxl' : huggingface_components['text_encoder_3']
|
||||
},
|
||||
tokenizer_dict={
|
||||
'clip_l': huggingface_components['tokenizer'],
|
||||
'clip_g': huggingface_components['tokenizer_2'],
|
||||
't5xxl' : huggingface_components['tokenizer_3']
|
||||
}
|
||||
)
|
||||
|
||||
k_predictor = PredictionDiscreteFlow(shift=3.0)
|
||||
|
||||
vae = VAE(model=huggingface_components['vae'])
|
||||
|
||||
unet = UnetPatcher.from_model(
|
||||
model=huggingface_components['transformer'],
|
||||
diffusers_scheduler= None,
|
||||
k_predictor=k_predictor,
|
||||
config=estimated_config
|
||||
)
|
||||
|
||||
self.text_processing_engine_l = ClassicTextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.clip_l,
|
||||
tokenizer=clip.tokenizer.clip_l,
|
||||
embedding_dir=dynamic_args['embedding_dir'],
|
||||
embedding_key='clip_l',
|
||||
embedding_expected_shape=768,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
text_projection=True,
|
||||
minimal_clip_skip=1,
|
||||
clip_skip=1,
|
||||
return_pooled=True,
|
||||
final_layer_norm=False,
|
||||
)
|
||||
|
||||
self.text_processing_engine_g = ClassicTextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.clip_g,
|
||||
tokenizer=clip.tokenizer.clip_g,
|
||||
embedding_dir=dynamic_args['embedding_dir'],
|
||||
embedding_key='clip_g',
|
||||
embedding_expected_shape=1280,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
text_projection=True,
|
||||
minimal_clip_skip=1,
|
||||
clip_skip=1,
|
||||
return_pooled=True,
|
||||
final_layer_norm=False,
|
||||
)
|
||||
|
||||
self.text_processing_engine_t5 = T5TextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.t5xxl,
|
||||
tokenizer=clip.tokenizer.t5xxl,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
)
|
||||
|
||||
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
|
||||
self.forge_objects_original = self.forge_objects.shallow_copy()
|
||||
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
|
||||
|
||||
# WebUI Legacy
|
||||
self.is_sd3 = True
|
||||
|
||||
def set_clip_skip(self, clip_skip):
|
||||
self.text_processing_engine_l.clip_skip = clip_skip
|
||||
self.text_processing_engine_g.clip_skip = clip_skip
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_learned_conditioning(self, prompt: list[str]):
|
||||
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
|
||||
|
||||
cond_g, g_pooled = self.text_processing_engine_g(prompt)
|
||||
cond_l, l_pooled = self.text_processing_engine_l(prompt)
|
||||
if opts.sd3_enable_t5:
|
||||
cond_t5 = self.text_processing_engine_t5(prompt)
|
||||
else:
|
||||
cond_t5 = torch.zeros([len(prompt), 256, 4096]).to(cond_l.device)
|
||||
|
||||
is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
|
||||
|
||||
force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in prompt)
|
||||
|
||||
if force_zero_negative_prompt:
|
||||
l_pooled = torch.zeros_like(l_pooled)
|
||||
g_pooled = torch.zeros_like(g_pooled)
|
||||
cond_l = torch.zeros_like(cond_l)
|
||||
cond_g = torch.zeros_like(cond_g)
|
||||
cond_t5 = torch.zeros_like(cond_t5)
|
||||
|
||||
cond_lg = torch.cat([cond_l, cond_g], dim=-1)
|
||||
cond_lg = torch.nn.functional.pad(cond_lg, (0, 4096 - cond_lg.shape[-1]))
|
||||
|
||||
cond = dict(
|
||||
crossattn=torch.cat([cond_lg, cond_t5], dim=-2),
|
||||
vector=torch.cat([l_pooled, g_pooled], dim=-1),
|
||||
)
|
||||
|
||||
return cond
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_prompt_lengths_on_ui(self, prompt):
|
||||
token_count = len(self.text_processing_engine_t5.tokenize([prompt])[0])
|
||||
return token_count, max(255, token_count)
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
|
||||
sample = self.forge_objects.vae.first_stage_model.process_in(sample)
|
||||
return sample.to(x)
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.first_stage_model.process_out(x)
|
||||
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
||||
|
||||
return sample.to(x)
|
||||
272
backend/diffusion_engine/sdxl.py
Executable file
272
backend/diffusion_engine/sdxl.py
Executable file
@@ -0,0 +1,272 @@
|
||||
import torch
|
||||
|
||||
from huggingface_guess import model_list
|
||||
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
|
||||
from backend.patcher.clip import CLIP
|
||||
from backend.patcher.vae import VAE
|
||||
from backend.patcher.unet import UnetPatcher
|
||||
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
|
||||
from backend.args import dynamic_args
|
||||
from backend import memory_management
|
||||
from backend.nn.unet import Timestep
|
||||
|
||||
import safetensors.torch as sf
|
||||
from backend import utils
|
||||
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
class StableDiffusionXL(ForgeDiffusionEngine):
|
||||
matched_guesses = [model_list.SDXL]
|
||||
|
||||
def __init__(self, estimated_config, huggingface_components):
|
||||
super().__init__(estimated_config, huggingface_components)
|
||||
|
||||
clip = CLIP(
|
||||
model_dict={
|
||||
'clip_l': huggingface_components['text_encoder'],
|
||||
'clip_g': huggingface_components['text_encoder_2']
|
||||
},
|
||||
tokenizer_dict={
|
||||
'clip_l': huggingface_components['tokenizer'],
|
||||
'clip_g': huggingface_components['tokenizer_2']
|
||||
}
|
||||
)
|
||||
|
||||
vae = VAE(model=huggingface_components['vae'])
|
||||
|
||||
unet = UnetPatcher.from_model(
|
||||
model=huggingface_components['unet'],
|
||||
diffusers_scheduler=huggingface_components['scheduler'],
|
||||
config=estimated_config
|
||||
)
|
||||
|
||||
self.text_processing_engine_l = ClassicTextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.clip_l,
|
||||
tokenizer=clip.tokenizer.clip_l,
|
||||
embedding_dir=dynamic_args['embedding_dir'],
|
||||
embedding_key='clip_l',
|
||||
embedding_expected_shape=2048,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
text_projection=False,
|
||||
minimal_clip_skip=2,
|
||||
clip_skip=2,
|
||||
return_pooled=False,
|
||||
final_layer_norm=False,
|
||||
)
|
||||
|
||||
self.text_processing_engine_g = ClassicTextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.clip_g,
|
||||
tokenizer=clip.tokenizer.clip_g,
|
||||
embedding_dir=dynamic_args['embedding_dir'],
|
||||
embedding_key='clip_g',
|
||||
embedding_expected_shape=2048,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
text_projection=True,
|
||||
minimal_clip_skip=2,
|
||||
clip_skip=2,
|
||||
return_pooled=True,
|
||||
final_layer_norm=False,
|
||||
)
|
||||
|
||||
self.embedder = Timestep(256)
|
||||
|
||||
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
|
||||
self.forge_objects_original = self.forge_objects.shallow_copy()
|
||||
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
|
||||
|
||||
# WebUI Legacy
|
||||
self.is_sdxl = True
|
||||
|
||||
def set_clip_skip(self, clip_skip):
|
||||
self.text_processing_engine_l.clip_skip = clip_skip
|
||||
self.text_processing_engine_g.clip_skip = clip_skip
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_learned_conditioning(self, prompt: list[str]):
|
||||
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
|
||||
|
||||
cond_l = self.text_processing_engine_l(prompt)
|
||||
cond_g, clip_pooled = self.text_processing_engine_g(prompt)
|
||||
|
||||
width = getattr(prompt, 'width', 1024) or 1024
|
||||
height = getattr(prompt, 'height', 1024) or 1024
|
||||
is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
|
||||
|
||||
crop_w = opts.sdxl_crop_left
|
||||
crop_h = opts.sdxl_crop_top
|
||||
target_width = width
|
||||
target_height = height
|
||||
|
||||
out = [
|
||||
self.embedder(torch.Tensor([height])), self.embedder(torch.Tensor([width])),
|
||||
self.embedder(torch.Tensor([crop_h])), self.embedder(torch.Tensor([crop_w])),
|
||||
self.embedder(torch.Tensor([target_height])), self.embedder(torch.Tensor([target_width]))
|
||||
]
|
||||
|
||||
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1).to(clip_pooled)
|
||||
|
||||
force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in prompt)
|
||||
|
||||
if force_zero_negative_prompt:
|
||||
clip_pooled = torch.zeros_like(clip_pooled)
|
||||
cond_l = torch.zeros_like(cond_l)
|
||||
cond_g = torch.zeros_like(cond_g)
|
||||
|
||||
cond = dict(
|
||||
crossattn=torch.cat([cond_l, cond_g], dim=2),
|
||||
vector=torch.cat([clip_pooled, flat], dim=1),
|
||||
)
|
||||
|
||||
return cond
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_prompt_lengths_on_ui(self, prompt):
|
||||
_, token_count = self.text_processing_engine_l.process_texts([prompt])
|
||||
return token_count, self.text_processing_engine_l.get_target_prompt_token_count(token_count)
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
|
||||
sample = self.forge_objects.vae.first_stage_model.process_in(sample)
|
||||
return sample.to(x)
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.first_stage_model.process_out(x)
|
||||
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
||||
return sample.to(x)
|
||||
|
||||
def save_checkpoint(self, filename):
|
||||
sd = {}
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
|
||||
)
|
||||
sd.update(
|
||||
model_list.SDXL.process_clip_state_dict_for_saving(self,
|
||||
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
|
||||
)
|
||||
)
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
|
||||
)
|
||||
sf.save_file(sd, filename)
|
||||
return filename
|
||||
|
||||
|
||||
class StableDiffusionXLRefiner(ForgeDiffusionEngine):
|
||||
matched_guesses = [model_list.SDXLRefiner]
|
||||
|
||||
def __init__(self, estimated_config, huggingface_components):
|
||||
super().__init__(estimated_config, huggingface_components)
|
||||
|
||||
clip = CLIP(
|
||||
model_dict={
|
||||
'clip_g': huggingface_components['text_encoder']
|
||||
},
|
||||
tokenizer_dict={
|
||||
'clip_g': huggingface_components['tokenizer'],
|
||||
}
|
||||
)
|
||||
|
||||
vae = VAE(model=huggingface_components['vae'])
|
||||
|
||||
unet = UnetPatcher.from_model(
|
||||
model=huggingface_components['unet'],
|
||||
diffusers_scheduler=huggingface_components['scheduler'],
|
||||
config=estimated_config
|
||||
)
|
||||
|
||||
self.text_processing_engine_g = ClassicTextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.clip_g,
|
||||
tokenizer=clip.tokenizer.clip_g,
|
||||
embedding_dir=dynamic_args['embedding_dir'],
|
||||
embedding_key='clip_g',
|
||||
embedding_expected_shape=2048,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
text_projection=True,
|
||||
minimal_clip_skip=2,
|
||||
clip_skip=2,
|
||||
return_pooled=True,
|
||||
final_layer_norm=False,
|
||||
)
|
||||
|
||||
self.embedder = Timestep(256)
|
||||
|
||||
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
|
||||
self.forge_objects_original = self.forge_objects.shallow_copy()
|
||||
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
|
||||
|
||||
# WebUI Legacy
|
||||
self.is_sdxl = True
|
||||
|
||||
def set_clip_skip(self, clip_skip):
|
||||
self.text_processing_engine_g.clip_skip = clip_skip
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_learned_conditioning(self, prompt: list[str]):
|
||||
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
|
||||
|
||||
cond_g, clip_pooled = self.text_processing_engine_g(prompt)
|
||||
|
||||
width = getattr(prompt, 'width', 1024) or 1024
|
||||
height = getattr(prompt, 'height', 1024) or 1024
|
||||
is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
|
||||
|
||||
crop_w = opts.sdxl_crop_left
|
||||
crop_h = opts.sdxl_crop_top
|
||||
aesthetic = opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else opts.sdxl_refiner_high_aesthetic_score
|
||||
|
||||
out = [
|
||||
self.embedder(torch.Tensor([height])), self.embedder(torch.Tensor([width])),
|
||||
self.embedder(torch.Tensor([crop_h])), self.embedder(torch.Tensor([crop_w])),
|
||||
self.embedder(torch.Tensor([aesthetic]))
|
||||
]
|
||||
|
||||
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1).to(clip_pooled)
|
||||
|
||||
force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in prompt)
|
||||
|
||||
if force_zero_negative_prompt:
|
||||
clip_pooled = torch.zeros_like(clip_pooled)
|
||||
cond_g = torch.zeros_like(cond_g)
|
||||
|
||||
cond = dict(
|
||||
crossattn=cond_g,
|
||||
vector=torch.cat([clip_pooled, flat], dim=1),
|
||||
)
|
||||
|
||||
return cond
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_prompt_lengths_on_ui(self, prompt):
|
||||
_, token_count = self.text_processing_engine_g.process_texts([prompt])
|
||||
return token_count, self.text_processing_engine_g.get_target_prompt_token_count(token_count)
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
|
||||
sample = self.forge_objects.vae.first_stage_model.process_in(sample)
|
||||
return sample.to(x)
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.first_stage_model.process_out(x)
|
||||
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
||||
return sample.to(x)
|
||||
|
||||
def save_checkpoint(self, filename):
|
||||
sd = {}
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
|
||||
)
|
||||
sd.update(
|
||||
model_list.SDXLRefiner.process_clip_state_dict_for_saving(self,
|
||||
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
|
||||
)
|
||||
)
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
|
||||
)
|
||||
sf.save_file(sd, filename)
|
||||
return filename
|
||||
Reference in New Issue
Block a user