150 lines
5.7 KiB
Python
Executable File
150 lines
5.7 KiB
Python
Executable File
import torch
|
|
|
|
from huggingface_guess import model_list
|
|
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
|
|
from backend.patcher.clip import CLIP
|
|
from backend.patcher.vae import VAE
|
|
from backend.patcher.unet import UnetPatcher
|
|
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
|
|
from backend.text_processing.t5_engine import T5TextProcessingEngine
|
|
from backend.args import dynamic_args
|
|
from backend import memory_management
|
|
from backend.modules.k_prediction import PredictionDiscreteFlow
|
|
|
|
from modules.shared import opts
|
|
|
|
|
|
## patch SD3 Class in huggingface_guess.model_list
|
|
def SD3_clip_target(self, state_dict=None):
    """Map each SD3 text-encoder slot to its component name.

    Args:
        self: The object this function is bound to once it is assigned as a
            method; not used.
        state_dict: Not used. Defaults to ``None`` instead of a shared
            mutable ``{}`` so no dict instance is reused across calls.

    Returns:
        A new dict mapping ``'clip_l'``, ``'clip_g'`` and ``'t5xxl'`` to
        ``'text_encoder'``, ``'text_encoder_2'`` and ``'text_encoder_3'``.
    """
    return {'clip_l': 'text_encoder', 'clip_g': 'text_encoder_2', 't5xxl': 'text_encoder_3'}
|
|
|
|
# Monkey-patch the SD3 entry of huggingface_guess.model_list: replace its
# clip_target with SD3_clip_target and set its unet_target to 'transformer'.
model_list.SD3.clip_target = SD3_clip_target
model_list.SD3.unet_target = 'transformer'
|
|
## end patch
|
|
|
|
class StableDiffusion3(ForgeDiffusionEngine):
    """Forge diffusion engine for models matched as ``model_list.SD3``.

    Wraps the loaded components (three text encoders with their tokenizers,
    the VAE and the transformer) in Forge patcher objects, and provides
    prompt conditioning plus VAE encode/decode helpers.
    """

    matched_guesses = [model_list.SD3]

    def __init__(self, estimated_config, huggingface_components):
        """Build patchers and text processing engines from loaded components.

        Args:
            estimated_config: Model configuration; passed to the base class
                and to ``UnetPatcher.from_model``.
            huggingface_components: Mapping that must contain the keys
                ``'text_encoder'``, ``'text_encoder_2'``, ``'text_encoder_3'``,
                ``'tokenizer'``, ``'tokenizer_2'``, ``'tokenizer_3'``,
                ``'vae'`` and ``'transformer'``.
        """
        super().__init__(estimated_config, huggingface_components)
        self.is_inpaint = False

        clip = CLIP(
            model_dict={
                'clip_l': huggingface_components['text_encoder'],
                'clip_g': huggingface_components['text_encoder_2'],
                't5xxl': huggingface_components['text_encoder_3'],
            },
            tokenizer_dict={
                'clip_l': huggingface_components['tokenizer'],
                'clip_g': huggingface_components['tokenizer_2'],
                't5xxl': huggingface_components['tokenizer_3'],
            },
        )

        k_predictor = PredictionDiscreteFlow(shift=3.0)

        vae = VAE(model=huggingface_components['vae'])

        unet = UnetPatcher.from_model(
            model=huggingface_components['transformer'],
            diffusers_scheduler=None,
            k_predictor=k_predictor,
            config=estimated_config,
        )

        # CLIP-L and CLIP-G use the same engine settings; only the encoder,
        # tokenizer, embedding key and expected embedding width differ.
        self.text_processing_engine_l = ClassicTextProcessingEngine(
            text_encoder=clip.cond_stage_model.clip_l,
            tokenizer=clip.tokenizer.clip_l,
            embedding_dir=dynamic_args['embedding_dir'],
            embedding_key='clip_l',
            embedding_expected_shape=768,
            emphasis_name=dynamic_args['emphasis_name'],
            text_projection=True,
            minimal_clip_skip=1,
            clip_skip=1,
            return_pooled=True,
            final_layer_norm=False,
        )

        self.text_processing_engine_g = ClassicTextProcessingEngine(
            text_encoder=clip.cond_stage_model.clip_g,
            tokenizer=clip.tokenizer.clip_g,
            embedding_dir=dynamic_args['embedding_dir'],
            embedding_key='clip_g',
            embedding_expected_shape=1280,
            emphasis_name=dynamic_args['emphasis_name'],
            text_projection=True,
            minimal_clip_skip=1,
            clip_skip=1,
            return_pooled=True,
            final_layer_norm=False,
        )

        self.text_processing_engine_t5 = T5TextProcessingEngine(
            text_encoder=clip.cond_stage_model.t5xxl,
            tokenizer=clip.tokenizer.t5xxl,
            emphasis_name=dynamic_args['emphasis_name'],
        )

        self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
        self.forge_objects_original = self.forge_objects.shallow_copy()
        self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()

        # WebUI Legacy
        self.is_sd3 = True

    def set_clip_skip(self, clip_skip):
        """Set ``clip_skip`` on the CLIP-L and CLIP-G text engines.

        The T5 engine is not modified.
        """
        self.text_processing_engine_l.clip_skip = clip_skip
        self.text_processing_engine_g.clip_skip = clip_skip

    @torch.inference_mode()
    def get_learned_conditioning(self, prompt: list[str]):
        """Encode prompts into the conditioning dict.

        Args:
            prompt: Prompt strings. An optional ``is_negative_prompt``
                attribute on the object is read (default ``False``).

        Returns:
            A dict with two entries:
            ``crossattn``: CLIP-L and CLIP-G features concatenated on the
            last dim, zero-padded on that dim to width 4096, then
            concatenated with the T5 features along dim -2.
            ``vector``: pooled CLIP-L and CLIP-G outputs concatenated on the
            last dim.
        """
        memory_management.load_model_gpu(self.forge_objects.clip.patcher)

        cond_g, g_pooled = self.text_processing_engine_g(prompt)
        cond_l, l_pooled = self.text_processing_engine_l(prompt)

        if opts.sd3_enable_t5:
            cond_t5 = self.text_processing_engine_t5(prompt)
        else:
            # T5 disabled: use an all-zero placeholder. It is created directly
            # on the CLIP output's device instead of on the host followed by a
            # copy; values and (default) dtype are unchanged.
            cond_t5 = torch.zeros([len(prompt), 256, 4096], device=cond_l.device)

        is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)

        # A negative prompt consisting only of empty strings is replaced by
        # all-zero conditioning rather than the encoders' output.
        force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in prompt)

        if force_zero_negative_prompt:
            l_pooled = torch.zeros_like(l_pooled)
            g_pooled = torch.zeros_like(g_pooled)
            cond_l = torch.zeros_like(cond_l)
            cond_g = torch.zeros_like(cond_g)
            cond_t5 = torch.zeros_like(cond_t5)

        cond_lg = torch.cat([cond_l, cond_g], dim=-1)
        # Right-pad the last dim with zeros up to 4096 so it can be joined
        # with the T5 features along dim -2.
        cond_lg = torch.nn.functional.pad(cond_lg, (0, 4096 - cond_lg.shape[-1]))

        cond = dict(
            crossattn=torch.cat([cond_lg, cond_t5], dim=-2),
            vector=torch.cat([l_pooled, g_pooled], dim=-1),
        )

        return cond

    @torch.inference_mode()
    def get_prompt_lengths_on_ui(self, prompt):
        """Return ``(token_count, max(255, token_count))`` for one prompt.

        ``token_count`` is the length of the first entry returned by the T5
        engine's ``tokenize`` for ``[prompt]``.
        """
        token_count = len(self.text_processing_engine_t5.tokenize([prompt])[0])
        return token_count, max(255, token_count)

    @torch.inference_mode()
    def encode_first_stage(self, x):
        """Encode ``x`` with the VAE and apply ``process_in``.

        ``x`` has dim 1 moved to the last position and is mapped through
        ``x * 0.5 + 0.5`` before encoding (presumably [-1, 1] channels-first
        images -> [0, 1] channels-last; confirm against callers). The result
        is converted to ``x``'s dtype and device.
        """
        sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
        sample = self.forge_objects.vae.first_stage_model.process_in(sample)
        return sample.to(x)

    @torch.inference_mode()
    def decode_first_stage(self, x):
        """Apply ``process_out`` to ``x`` and decode it with the VAE.

        The decoded output has its last dim moved to position 1 and is mapped
        through ``y * 2.0 - 1.0`` (the inverse of the mapping used in
        ``encode_first_stage``). The result is converted to ``x``'s dtype and
        device.
        """
        sample = self.forge_objects.vae.first_stage_model.process_out(x)
        sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0

        return sample.to(x)
|