mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-01-26 19:09:45 +00:00
Merge pull request #2183 from graemeniedermayer/sd35_integration
sd3.5 integration (naive)
This commit is contained in:
137
backend/diffusion_engine/sd35.py
Normal file
137
backend/diffusion_engine/sd35.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import torch
|
||||
|
||||
from huggingface_guess import model_list
|
||||
# from huggingface_guess.latent import SD3
|
||||
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
|
||||
from backend.patcher.clip import CLIP
|
||||
from backend.patcher.vae import VAE
|
||||
from backend.patcher.unet import UnetPatcher
|
||||
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
|
||||
from backend.text_processing.t5_engine import T5TextProcessingEngine
|
||||
from backend.args import dynamic_args
|
||||
from backend import memory_management
|
||||
from backend.modules.k_prediction import PredictionDiscreteFlow
|
||||
|
||||
class StableDiffusion3(ForgeDiffusionEngine):
|
||||
matched_guesses = [model_list.SD35]
|
||||
|
||||
def __init__(self, estimated_config, huggingface_components):
|
||||
super().__init__(estimated_config, huggingface_components)
|
||||
|
||||
clip = CLIP(
|
||||
model_dict={
|
||||
'clip_l': huggingface_components['text_encoder'],
|
||||
'clip_g': huggingface_components['text_encoder_2'],
|
||||
't5xxl': huggingface_components['text_encoder_3']
|
||||
},
|
||||
tokenizer_dict={
|
||||
'clip_l': huggingface_components['tokenizer'],
|
||||
'clip_g': huggingface_components['tokenizer_2'],
|
||||
't5xxl': huggingface_components['tokenizer_3']
|
||||
}
|
||||
)
|
||||
|
||||
k_predictor = PredictionDiscreteFlow( shift=3.0)
|
||||
|
||||
vae = VAE(model=huggingface_components['vae'])
|
||||
|
||||
unet = UnetPatcher.from_model(
|
||||
model=huggingface_components['transformer'],
|
||||
diffusers_scheduler= None,
|
||||
k_predictor=k_predictor,
|
||||
config=estimated_config
|
||||
)
|
||||
|
||||
self.text_processing_engine_l = ClassicTextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.clip_l,
|
||||
tokenizer=clip.tokenizer.clip_l,
|
||||
embedding_dir=dynamic_args['embedding_dir'],
|
||||
embedding_key='clip_l',
|
||||
embedding_expected_shape=768,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
text_projection=True,
|
||||
minimal_clip_skip=1,
|
||||
clip_skip=1,
|
||||
return_pooled=True,
|
||||
final_layer_norm=False,
|
||||
)
|
||||
|
||||
self.text_processing_engine_g = ClassicTextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.clip_g,
|
||||
tokenizer=clip.tokenizer.clip_g,
|
||||
embedding_dir=dynamic_args['embedding_dir'],
|
||||
embedding_key='clip_g',
|
||||
embedding_expected_shape=1280,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
text_projection=True,
|
||||
minimal_clip_skip=1,
|
||||
clip_skip=1,
|
||||
return_pooled=True,
|
||||
final_layer_norm=False,
|
||||
)
|
||||
|
||||
self.text_processing_engine_t5 = T5TextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.t5xxl,
|
||||
tokenizer=clip.tokenizer.t5xxl,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
)
|
||||
|
||||
|
||||
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
|
||||
self.forge_objects_original = self.forge_objects.shallow_copy()
|
||||
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
|
||||
|
||||
# WebUI Legacy
|
||||
self.is_sd3 = True
|
||||
|
||||
def set_clip_skip(self, clip_skip):
|
||||
self.text_processing_engine_l.clip_skip = clip_skip
|
||||
self.text_processing_engine_g.clip_skip = clip_skip
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_learned_conditioning(self, prompt: list[str]):
|
||||
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
|
||||
|
||||
cond_g, g_pooled = self.text_processing_engine_g(prompt)
|
||||
cond_l, l_pooled = self.text_processing_engine_l(prompt)
|
||||
# if enabled?
|
||||
cond_t5 = self.text_processing_engine_t5(prompt)
|
||||
|
||||
is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
|
||||
|
||||
force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in prompt)
|
||||
|
||||
if force_zero_negative_prompt:
|
||||
l_pooled = torch.zeros_like(l_pooled)
|
||||
g_pooled = torch.zeros_like(g_pooled)
|
||||
cond_l = torch.zeros_like(cond_l)
|
||||
cond_g = torch.zeros_like(cond_g)
|
||||
cond_t5 = torch.zeros_like(cond_t5)
|
||||
|
||||
cond_lg = torch.cat([cond_l, cond_g], dim=-1)
|
||||
cond_lg = torch.nn.functional.pad(cond_lg, (0, 4096 - cond_lg.shape[-1]))
|
||||
|
||||
cond = dict(
|
||||
crossattn=torch.cat([cond_lg, cond_t5], dim=-2),
|
||||
vector=torch.cat([l_pooled, g_pooled], dim=-1),
|
||||
)
|
||||
|
||||
return cond
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_prompt_lengths_on_ui(self, prompt):
|
||||
token_count = len(self.text_processing_engine_t5.tokenize([prompt])[0])
|
||||
return token_count, max(255, token_count)
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
|
||||
sample = self.forge_objects.vae.first_stage_model.process_in(sample)
|
||||
return sample.to(x)
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.first_stage_model.process_out(x)
|
||||
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
||||
|
||||
return sample.to(x)
|
||||
@@ -0,0 +1,40 @@
|
||||
{
|
||||
"_class_name": "StableDiffusion3Pipeline",
|
||||
"_diffusers_version": "0.30.3.dev0",
|
||||
"scheduler": [
|
||||
"diffusers",
|
||||
"FlowMatchEulerDiscreteScheduler"
|
||||
],
|
||||
"text_encoder": [
|
||||
"transformers",
|
||||
"CLIPTextModelWithProjection"
|
||||
],
|
||||
"text_encoder_2": [
|
||||
"transformers",
|
||||
"CLIPTextModelWithProjection"
|
||||
],
|
||||
"text_encoder_3": [
|
||||
"transformers",
|
||||
"T5EncoderModel"
|
||||
],
|
||||
"tokenizer": [
|
||||
"transformers",
|
||||
"CLIPTokenizer"
|
||||
],
|
||||
"tokenizer_2": [
|
||||
"transformers",
|
||||
"CLIPTokenizer"
|
||||
],
|
||||
"tokenizer_3": [
|
||||
"transformers",
|
||||
"T5TokenizerFast"
|
||||
],
|
||||
"transformer": [
|
||||
"diffusers",
|
||||
"SD3Transformer2DModel"
|
||||
],
|
||||
"vae": [
|
||||
"diffusers",
|
||||
"AutoencoderKL"
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"_class_name": "FlowMatchEulerDiscreteScheduler",
|
||||
"_diffusers_version": "0.29.0.dev0",
|
||||
"num_train_timesteps": 1000,
|
||||
"shift": 3.0
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
{
|
||||
"architectures": [
|
||||
"CLIPTextModelWithProjection"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 0,
|
||||
"dropout": 0.0,
|
||||
"eos_token_id": 2,
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 768,
|
||||
"initializer_factor": 1.0,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 3072,
|
||||
"layer_norm_eps": 1e-05,
|
||||
"max_position_embeddings": 77,
|
||||
"model_type": "clip_text_model",
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12,
|
||||
"pad_token_id": 1,
|
||||
"projection_dim": 768,
|
||||
"torch_dtype": "float16",
|
||||
"transformers_version": "4.41.2",
|
||||
"vocab_size": 49408
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
{
|
||||
"architectures": [
|
||||
"CLIPTextModelWithProjection"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 0,
|
||||
"dropout": 0.0,
|
||||
"eos_token_id": 2,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_size": 1280,
|
||||
"initializer_factor": 1.0,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 5120,
|
||||
"layer_norm_eps": 1e-05,
|
||||
"max_position_embeddings": 77,
|
||||
"model_type": "clip_text_model",
|
||||
"num_attention_heads": 20,
|
||||
"num_hidden_layers": 32,
|
||||
"pad_token_id": 1,
|
||||
"projection_dim": 1280,
|
||||
"torch_dtype": "float16",
|
||||
"transformers_version": "4.41.2",
|
||||
"vocab_size": 49408
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
{
|
||||
"architectures": [
|
||||
"T5EncoderModel"
|
||||
],
|
||||
"classifier_dropout": 0.0,
|
||||
"d_ff": 10240,
|
||||
"d_kv": 64,
|
||||
"d_model": 4096,
|
||||
"decoder_start_token_id": 0,
|
||||
"dense_act_fn": "gelu_new",
|
||||
"dropout_rate": 0.1,
|
||||
"eos_token_id": 1,
|
||||
"feed_forward_proj": "gated-gelu",
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": true,
|
||||
"is_gated_act": true,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "t5",
|
||||
"num_decoder_layers": 24,
|
||||
"num_heads": 64,
|
||||
"num_layers": 24,
|
||||
"output_past": true,
|
||||
"pad_token_id": 0,
|
||||
"relative_attention_max_distance": 128,
|
||||
"relative_attention_num_buckets": 32,
|
||||
"tie_word_embeddings": false,
|
||||
"torch_dtype": "float16",
|
||||
"transformers_version": "4.41.2",
|
||||
"use_cache": true,
|
||||
"vocab_size": 32128
|
||||
}
|
||||
@@ -0,0 +1,226 @@
|
||||
{
|
||||
"metadata": {
|
||||
"total_size": 9524621312
|
||||
},
|
||||
"weight_map": {
|
||||
"encoder.block.0.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.0.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.0.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.0.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.0.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.0.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.0.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.0.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.0.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.1.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.1.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.1.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.1.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.1.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.1.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.1.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.1.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.1.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.10.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.10.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.10.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.10.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.10.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.10.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.10.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.10.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.10.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.11.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.11.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.11.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.11.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.11.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.11.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.11.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.11.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.11.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.12.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.12.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.12.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.12.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.12.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.12.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.12.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.12.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.12.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.13.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.13.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.13.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.13.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.13.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.13.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.13.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.13.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.13.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.14.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.14.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.14.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.14.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.14.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.14.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.14.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.14.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.14.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.15.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.15.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.15.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.15.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.15.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.15.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.15.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.15.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.15.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.16.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.16.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.16.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.16.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.16.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.16.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.16.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.16.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.16.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.17.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.17.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.17.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.17.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.17.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.17.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.17.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.17.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.17.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.18.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.18.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.18.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.18.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.18.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.18.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.18.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.18.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.18.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.19.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.19.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.19.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.19.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.19.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.19.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.19.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.19.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.19.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.2.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.2.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.2.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.2.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.2.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.2.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.2.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.2.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.2.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.20.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.20.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.20.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.20.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.20.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.20.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.20.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.20.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.20.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.21.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.21.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.21.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.21.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.21.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.21.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.21.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.21.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.21.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.22.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.22.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.22.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.22.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.22.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.22.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.22.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.22.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.22.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.23.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.23.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.23.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.23.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.23.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.23.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.23.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.23.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.23.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"encoder.block.3.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.3.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.3.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.3.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.3.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.3.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.3.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.3.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.3.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.4.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.4.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.4.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.4.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.4.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.4.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.4.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.4.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.4.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.5.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.5.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.5.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.5.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.5.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.5.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.5.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.5.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.5.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.6.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.6.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.6.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.6.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.6.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.6.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.6.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.6.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.6.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.7.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.7.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.7.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.7.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.7.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.7.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.7.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.7.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.7.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.8.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.8.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.8.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.8.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.8.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.8.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.8.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.8.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.8.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.9.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.9.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.9.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.9.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.9.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.9.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.9.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.9.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.block.9.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||
"encoder.final_layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||
"shared.weight": "model-00001-of-00002.safetensors"
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"bos_token": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"49406": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"49407": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"bos_token": "<|startoftext|>",
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"do_lower_case": true,
|
||||
"eos_token": "<|endoftext|>",
|
||||
"errors": "replace",
|
||||
"model_max_length": 77,
|
||||
"pad_token": "<|endoftext|>",
|
||||
"tokenizer_class": "CLIPTokenizer",
|
||||
"unk_token": "<|endoftext|>"
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"bos_token": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "!",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
{
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "!",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"49406": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"49407": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"bos_token": "<|startoftext|>",
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"do_lower_case": true,
|
||||
"eos_token": "<|endoftext|>",
|
||||
"errors": "replace",
|
||||
"model_max_length": 77,
|
||||
"pad_token": "!",
|
||||
"tokenizer_class": "CLIPTokenizer",
|
||||
"unk_token": "<|endoftext|>"
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,125 @@
|
||||
{
|
||||
"additional_special_tokens": [
|
||||
"<extra_id_0>",
|
||||
"<extra_id_1>",
|
||||
"<extra_id_2>",
|
||||
"<extra_id_3>",
|
||||
"<extra_id_4>",
|
||||
"<extra_id_5>",
|
||||
"<extra_id_6>",
|
||||
"<extra_id_7>",
|
||||
"<extra_id_8>",
|
||||
"<extra_id_9>",
|
||||
"<extra_id_10>",
|
||||
"<extra_id_11>",
|
||||
"<extra_id_12>",
|
||||
"<extra_id_13>",
|
||||
"<extra_id_14>",
|
||||
"<extra_id_15>",
|
||||
"<extra_id_16>",
|
||||
"<extra_id_17>",
|
||||
"<extra_id_18>",
|
||||
"<extra_id_19>",
|
||||
"<extra_id_20>",
|
||||
"<extra_id_21>",
|
||||
"<extra_id_22>",
|
||||
"<extra_id_23>",
|
||||
"<extra_id_24>",
|
||||
"<extra_id_25>",
|
||||
"<extra_id_26>",
|
||||
"<extra_id_27>",
|
||||
"<extra_id_28>",
|
||||
"<extra_id_29>",
|
||||
"<extra_id_30>",
|
||||
"<extra_id_31>",
|
||||
"<extra_id_32>",
|
||||
"<extra_id_33>",
|
||||
"<extra_id_34>",
|
||||
"<extra_id_35>",
|
||||
"<extra_id_36>",
|
||||
"<extra_id_37>",
|
||||
"<extra_id_38>",
|
||||
"<extra_id_39>",
|
||||
"<extra_id_40>",
|
||||
"<extra_id_41>",
|
||||
"<extra_id_42>",
|
||||
"<extra_id_43>",
|
||||
"<extra_id_44>",
|
||||
"<extra_id_45>",
|
||||
"<extra_id_46>",
|
||||
"<extra_id_47>",
|
||||
"<extra_id_48>",
|
||||
"<extra_id_49>",
|
||||
"<extra_id_50>",
|
||||
"<extra_id_51>",
|
||||
"<extra_id_52>",
|
||||
"<extra_id_53>",
|
||||
"<extra_id_54>",
|
||||
"<extra_id_55>",
|
||||
"<extra_id_56>",
|
||||
"<extra_id_57>",
|
||||
"<extra_id_58>",
|
||||
"<extra_id_59>",
|
||||
"<extra_id_60>",
|
||||
"<extra_id_61>",
|
||||
"<extra_id_62>",
|
||||
"<extra_id_63>",
|
||||
"<extra_id_64>",
|
||||
"<extra_id_65>",
|
||||
"<extra_id_66>",
|
||||
"<extra_id_67>",
|
||||
"<extra_id_68>",
|
||||
"<extra_id_69>",
|
||||
"<extra_id_70>",
|
||||
"<extra_id_71>",
|
||||
"<extra_id_72>",
|
||||
"<extra_id_73>",
|
||||
"<extra_id_74>",
|
||||
"<extra_id_75>",
|
||||
"<extra_id_76>",
|
||||
"<extra_id_77>",
|
||||
"<extra_id_78>",
|
||||
"<extra_id_79>",
|
||||
"<extra_id_80>",
|
||||
"<extra_id_81>",
|
||||
"<extra_id_82>",
|
||||
"<extra_id_83>",
|
||||
"<extra_id_84>",
|
||||
"<extra_id_85>",
|
||||
"<extra_id_86>",
|
||||
"<extra_id_87>",
|
||||
"<extra_id_88>",
|
||||
"<extra_id_89>",
|
||||
"<extra_id_90>",
|
||||
"<extra_id_91>",
|
||||
"<extra_id_92>",
|
||||
"<extra_id_93>",
|
||||
"<extra_id_94>",
|
||||
"<extra_id_95>",
|
||||
"<extra_id_96>",
|
||||
"<extra_id_97>",
|
||||
"<extra_id_98>",
|
||||
"<extra_id_99>"
|
||||
],
|
||||
"eos_token": {
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "<pad>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
Binary file not shown.
File diff suppressed because one or more lines are too long
@@ -0,0 +1,940 @@
|
||||
{
|
||||
"add_prefix_space": true,
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "<pad>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"1": {
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"2": {
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32000": {
|
||||
"content": "<extra_id_99>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32001": {
|
||||
"content": "<extra_id_98>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32002": {
|
||||
"content": "<extra_id_97>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32003": {
|
||||
"content": "<extra_id_96>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32004": {
|
||||
"content": "<extra_id_95>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32005": {
|
||||
"content": "<extra_id_94>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32006": {
|
||||
"content": "<extra_id_93>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32007": {
|
||||
"content": "<extra_id_92>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32008": {
|
||||
"content": "<extra_id_91>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32009": {
|
||||
"content": "<extra_id_90>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32010": {
|
||||
"content": "<extra_id_89>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32011": {
|
||||
"content": "<extra_id_88>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32012": {
|
||||
"content": "<extra_id_87>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32013": {
|
||||
"content": "<extra_id_86>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32014": {
|
||||
"content": "<extra_id_85>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32015": {
|
||||
"content": "<extra_id_84>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32016": {
|
||||
"content": "<extra_id_83>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32017": {
|
||||
"content": "<extra_id_82>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32018": {
|
||||
"content": "<extra_id_81>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32019": {
|
||||
"content": "<extra_id_80>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32020": {
|
||||
"content": "<extra_id_79>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32021": {
|
||||
"content": "<extra_id_78>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32022": {
|
||||
"content": "<extra_id_77>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32023": {
|
||||
"content": "<extra_id_76>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32024": {
|
||||
"content": "<extra_id_75>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32025": {
|
||||
"content": "<extra_id_74>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32026": {
|
||||
"content": "<extra_id_73>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32027": {
|
||||
"content": "<extra_id_72>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32028": {
|
||||
"content": "<extra_id_71>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32029": {
|
||||
"content": "<extra_id_70>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32030": {
|
||||
"content": "<extra_id_69>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32031": {
|
||||
"content": "<extra_id_68>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32032": {
|
||||
"content": "<extra_id_67>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32033": {
|
||||
"content": "<extra_id_66>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32034": {
|
||||
"content": "<extra_id_65>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32035": {
|
||||
"content": "<extra_id_64>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32036": {
|
||||
"content": "<extra_id_63>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32037": {
|
||||
"content": "<extra_id_62>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32038": {
|
||||
"content": "<extra_id_61>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32039": {
|
||||
"content": "<extra_id_60>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32040": {
|
||||
"content": "<extra_id_59>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32041": {
|
||||
"content": "<extra_id_58>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32042": {
|
||||
"content": "<extra_id_57>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32043": {
|
||||
"content": "<extra_id_56>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32044": {
|
||||
"content": "<extra_id_55>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32045": {
|
||||
"content": "<extra_id_54>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32046": {
|
||||
"content": "<extra_id_53>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32047": {
|
||||
"content": "<extra_id_52>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32048": {
|
||||
"content": "<extra_id_51>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32049": {
|
||||
"content": "<extra_id_50>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32050": {
|
||||
"content": "<extra_id_49>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32051": {
|
||||
"content": "<extra_id_48>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32052": {
|
||||
"content": "<extra_id_47>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32053": {
|
||||
"content": "<extra_id_46>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32054": {
|
||||
"content": "<extra_id_45>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32055": {
|
||||
"content": "<extra_id_44>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32056": {
|
||||
"content": "<extra_id_43>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32057": {
|
||||
"content": "<extra_id_42>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32058": {
|
||||
"content": "<extra_id_41>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32059": {
|
||||
"content": "<extra_id_40>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32060": {
|
||||
"content": "<extra_id_39>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32061": {
|
||||
"content": "<extra_id_38>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32062": {
|
||||
"content": "<extra_id_37>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32063": {
|
||||
"content": "<extra_id_36>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32064": {
|
||||
"content": "<extra_id_35>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32065": {
|
||||
"content": "<extra_id_34>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32066": {
|
||||
"content": "<extra_id_33>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32067": {
|
||||
"content": "<extra_id_32>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32068": {
|
||||
"content": "<extra_id_31>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32069": {
|
||||
"content": "<extra_id_30>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32070": {
|
||||
"content": "<extra_id_29>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32071": {
|
||||
"content": "<extra_id_28>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32072": {
|
||||
"content": "<extra_id_27>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32073": {
|
||||
"content": "<extra_id_26>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32074": {
|
||||
"content": "<extra_id_25>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32075": {
|
||||
"content": "<extra_id_24>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32076": {
|
||||
"content": "<extra_id_23>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32077": {
|
||||
"content": "<extra_id_22>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32078": {
|
||||
"content": "<extra_id_21>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32079": {
|
||||
"content": "<extra_id_20>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32080": {
|
||||
"content": "<extra_id_19>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32081": {
|
||||
"content": "<extra_id_18>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32082": {
|
||||
"content": "<extra_id_17>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32083": {
|
||||
"content": "<extra_id_16>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32084": {
|
||||
"content": "<extra_id_15>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32085": {
|
||||
"content": "<extra_id_14>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32086": {
|
||||
"content": "<extra_id_13>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32087": {
|
||||
"content": "<extra_id_12>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32088": {
|
||||
"content": "<extra_id_11>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32089": {
|
||||
"content": "<extra_id_10>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32090": {
|
||||
"content": "<extra_id_9>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32091": {
|
||||
"content": "<extra_id_8>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32092": {
|
||||
"content": "<extra_id_7>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32093": {
|
||||
"content": "<extra_id_6>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32094": {
|
||||
"content": "<extra_id_5>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32095": {
|
||||
"content": "<extra_id_4>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32096": {
|
||||
"content": "<extra_id_3>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32097": {
|
||||
"content": "<extra_id_2>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32098": {
|
||||
"content": "<extra_id_1>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32099": {
|
||||
"content": "<extra_id_0>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": true,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"additional_special_tokens": [
|
||||
"<extra_id_0>",
|
||||
"<extra_id_1>",
|
||||
"<extra_id_2>",
|
||||
"<extra_id_3>",
|
||||
"<extra_id_4>",
|
||||
"<extra_id_5>",
|
||||
"<extra_id_6>",
|
||||
"<extra_id_7>",
|
||||
"<extra_id_8>",
|
||||
"<extra_id_9>",
|
||||
"<extra_id_10>",
|
||||
"<extra_id_11>",
|
||||
"<extra_id_12>",
|
||||
"<extra_id_13>",
|
||||
"<extra_id_14>",
|
||||
"<extra_id_15>",
|
||||
"<extra_id_16>",
|
||||
"<extra_id_17>",
|
||||
"<extra_id_18>",
|
||||
"<extra_id_19>",
|
||||
"<extra_id_20>",
|
||||
"<extra_id_21>",
|
||||
"<extra_id_22>",
|
||||
"<extra_id_23>",
|
||||
"<extra_id_24>",
|
||||
"<extra_id_25>",
|
||||
"<extra_id_26>",
|
||||
"<extra_id_27>",
|
||||
"<extra_id_28>",
|
||||
"<extra_id_29>",
|
||||
"<extra_id_30>",
|
||||
"<extra_id_31>",
|
||||
"<extra_id_32>",
|
||||
"<extra_id_33>",
|
||||
"<extra_id_34>",
|
||||
"<extra_id_35>",
|
||||
"<extra_id_36>",
|
||||
"<extra_id_37>",
|
||||
"<extra_id_38>",
|
||||
"<extra_id_39>",
|
||||
"<extra_id_40>",
|
||||
"<extra_id_41>",
|
||||
"<extra_id_42>",
|
||||
"<extra_id_43>",
|
||||
"<extra_id_44>",
|
||||
"<extra_id_45>",
|
||||
"<extra_id_46>",
|
||||
"<extra_id_47>",
|
||||
"<extra_id_48>",
|
||||
"<extra_id_49>",
|
||||
"<extra_id_50>",
|
||||
"<extra_id_51>",
|
||||
"<extra_id_52>",
|
||||
"<extra_id_53>",
|
||||
"<extra_id_54>",
|
||||
"<extra_id_55>",
|
||||
"<extra_id_56>",
|
||||
"<extra_id_57>",
|
||||
"<extra_id_58>",
|
||||
"<extra_id_59>",
|
||||
"<extra_id_60>",
|
||||
"<extra_id_61>",
|
||||
"<extra_id_62>",
|
||||
"<extra_id_63>",
|
||||
"<extra_id_64>",
|
||||
"<extra_id_65>",
|
||||
"<extra_id_66>",
|
||||
"<extra_id_67>",
|
||||
"<extra_id_68>",
|
||||
"<extra_id_69>",
|
||||
"<extra_id_70>",
|
||||
"<extra_id_71>",
|
||||
"<extra_id_72>",
|
||||
"<extra_id_73>",
|
||||
"<extra_id_74>",
|
||||
"<extra_id_75>",
|
||||
"<extra_id_76>",
|
||||
"<extra_id_77>",
|
||||
"<extra_id_78>",
|
||||
"<extra_id_79>",
|
||||
"<extra_id_80>",
|
||||
"<extra_id_81>",
|
||||
"<extra_id_82>",
|
||||
"<extra_id_83>",
|
||||
"<extra_id_84>",
|
||||
"<extra_id_85>",
|
||||
"<extra_id_86>",
|
||||
"<extra_id_87>",
|
||||
"<extra_id_88>",
|
||||
"<extra_id_89>",
|
||||
"<extra_id_90>",
|
||||
"<extra_id_91>",
|
||||
"<extra_id_92>",
|
||||
"<extra_id_93>",
|
||||
"<extra_id_94>",
|
||||
"<extra_id_95>",
|
||||
"<extra_id_96>",
|
||||
"<extra_id_97>",
|
||||
"<extra_id_98>",
|
||||
"<extra_id_99>"
|
||||
],
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"eos_token": "</s>",
|
||||
"extra_ids": 100,
|
||||
"legacy": true,
|
||||
"model_max_length": 512,
|
||||
"pad_token": "<pad>",
|
||||
"sp_model_kwargs": {},
|
||||
"tokenizer_class": "T5Tokenizer",
|
||||
"unk_token": "<unk>"
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"_class_name": "SD3Transformer2DModel",
|
||||
"_diffusers_version": "0.31.0.dev0",
|
||||
"attention_head_dim": 64,
|
||||
"caption_projection_dim": 2432,
|
||||
"in_channels": 16,
|
||||
"joint_attention_dim": 4096,
|
||||
"num_attention_heads": 38,
|
||||
"num_layers": 38,
|
||||
"out_channels": 16,
|
||||
"patch_size": 2,
|
||||
"pooled_projection_dim": 2048,
|
||||
"pos_embed_max_size": 192,
|
||||
"qk_norm": "rms_norm",
|
||||
"sample_size": 128
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,38 @@
|
||||
{
|
||||
"_class_name": "AutoencoderKL",
|
||||
"_diffusers_version": "0.31.0.dev0",
|
||||
"_name_or_path": "../sdxl-vae/",
|
||||
"act_fn": "silu",
|
||||
"block_out_channels": [
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
512
|
||||
],
|
||||
"down_block_types": [
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D"
|
||||
],
|
||||
"force_upcast": true,
|
||||
"in_channels": 3,
|
||||
"latent_channels": 16,
|
||||
"latents_mean": null,
|
||||
"latents_std": null,
|
||||
"layers_per_block": 2,
|
||||
"mid_block_add_attention": true,
|
||||
"norm_num_groups": 32,
|
||||
"out_channels": 3,
|
||||
"sample_size": 1024,
|
||||
"scaling_factor": 1.5305,
|
||||
"shift_factor": 0.0609,
|
||||
"up_block_types": [
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D"
|
||||
],
|
||||
"use_post_quant_conv": false,
|
||||
"use_quant_conv": false
|
||||
}
|
||||
@@ -19,11 +19,12 @@ from backend.nn.unet import IntegratedUNet2DConditionModel
|
||||
|
||||
from backend.diffusion_engine.sd15 import StableDiffusion
|
||||
from backend.diffusion_engine.sd20 import StableDiffusion2
|
||||
from backend.diffusion_engine.sd35 import StableDiffusion3
|
||||
from backend.diffusion_engine.sdxl import StableDiffusionXL
|
||||
from backend.diffusion_engine.flux import Flux
|
||||
|
||||
|
||||
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXL, Flux]
|
||||
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXL, StableDiffusion3, Flux]
|
||||
|
||||
|
||||
logging.getLogger("diffusers").setLevel(logging.ERROR)
|
||||
@@ -107,7 +108,7 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])
|
||||
|
||||
return model
|
||||
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']:
|
||||
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel', 'SD3Transformer2DModel']:
|
||||
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!'
|
||||
|
||||
model_loader = None
|
||||
@@ -116,6 +117,9 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
if cls_name == 'FluxTransformer2DModel':
|
||||
from backend.nn.flux import IntegratedFluxTransformer2DModel
|
||||
model_loader = lambda c: IntegratedFluxTransformer2DModel(**c)
|
||||
if cls_name == 'SD3Transformer2DModel':
|
||||
from modules.models.sd35.mmditx import MMDiTX
|
||||
model_loader = lambda c: MMDiTX(**c)
|
||||
|
||||
unet_config = guess.unet_config.copy()
|
||||
state_dict_parameters = memory_management.state_dict_parameters(state_dict)
|
||||
@@ -170,7 +174,7 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
return None
|
||||
|
||||
|
||||
def replace_state_dict(sd, asd, guess):
|
||||
def replace_state_dict(sd, asd, guess, is_clip_g = False):
|
||||
vae_key_prefix = guess.vae_key_prefix[0]
|
||||
text_encoder_key_prefix = guess.text_encoder_key_prefix[0]
|
||||
|
||||
@@ -210,11 +214,18 @@ def replace_state_dict(sd, asd, guess):
|
||||
sd[vae_key_prefix + k] = v
|
||||
|
||||
if 'text_model.encoder.layers.0.layer_norm1.weight' in asd:
|
||||
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}clip_l.")]
|
||||
for k in keys_to_delete:
|
||||
del sd[k]
|
||||
for k, v in asd.items():
|
||||
sd[f"{text_encoder_key_prefix}clip_l.transformer.{k}"] = v
|
||||
if is_clip_g:
|
||||
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}clip_g.")]
|
||||
for k in keys_to_delete:
|
||||
del sd[k]
|
||||
for k, v in asd.items():
|
||||
sd[f"{text_encoder_key_prefix}clip_g.transformer.{k}"] = v
|
||||
else:
|
||||
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}clip_l.")]
|
||||
for k in keys_to_delete:
|
||||
del sd[k]
|
||||
for k, v in asd.items():
|
||||
sd[f"{text_encoder_key_prefix}clip_l.transformer.{k}"] = v
|
||||
|
||||
if 'encoder.block.0.layer.0.SelfAttention.k.weight' in asd:
|
||||
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}t5xxl.")]
|
||||
@@ -241,8 +252,9 @@ def split_state_dict(sd, additional_state_dicts: list = None):
|
||||
|
||||
if isinstance(additional_state_dicts, list):
|
||||
for asd in additional_state_dicts:
|
||||
is_clip_g = 'clip_g' in asd
|
||||
asd = load_torch_file(asd)
|
||||
sd = replace_state_dict(sd, asd, guess)
|
||||
sd = replace_state_dict(sd, asd, guess, is_clip_g)
|
||||
|
||||
guess.clip_target = guess.clip_target(sd)
|
||||
guess.model_type = guess.model_type(sd)
|
||||
|
||||
@@ -251,6 +251,31 @@ class PredictionFlow(AbstractPrediction):
|
||||
return 1.0 - percent
|
||||
|
||||
|
||||
class PredictionDiscreteFlow(AbstractPrediction):
|
||||
def __init__(self, sigma_data=1.0, prediction_type='const', shift=1.0, timesteps = 1000):
|
||||
super().__init__(sigma_data=sigma_data, prediction_type=prediction_type)
|
||||
self.shift = shift
|
||||
ts = self.sigma(torch.arange(1, timesteps + 1, 1))
|
||||
self.register_buffer("sigmas", ts)
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
||||
@property
|
||||
def sigma_max(self):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def timestep(self, sigma):
|
||||
return sigma * 1000
|
||||
|
||||
def sigma(self, timestep: torch.Tensor):
|
||||
timestep = timestep / 1000.0
|
||||
if self.shift == 1.0:
|
||||
return timestep
|
||||
return self.shift * timestep / (1 + (self.shift - 1) * timestep)
|
||||
|
||||
|
||||
class PredictionFlux(AbstractPrediction):
|
||||
def __init__(self, sigma_data=1.0, prediction_type='const', shift=1.15, timesteps=10000):
|
||||
super().__init__(sigma_data=sigma_data, prediction_type=prediction_type)
|
||||
|
||||
@@ -2,11 +2,12 @@ import torch
|
||||
import math
|
||||
|
||||
from backend.attention import attention_pytorch as attention_function
|
||||
|
||||
from transformers.activations import NewGELUActivation
|
||||
|
||||
activations = {
|
||||
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
|
||||
"relu": torch.nn.functional.relu,
|
||||
"gelu_new": lambda a: NewGELUActivation()(a)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -139,7 +139,7 @@ class ClassicTextProcessingEngine:
|
||||
if self.return_pooled:
|
||||
pooled_output = outputs.pooler_output
|
||||
|
||||
if self.text_projection:
|
||||
if self.text_projection and self.embedding_key is not 'clip_l':
|
||||
pooled_output = self.text_encoder.transformer.text_projection(pooled_output)
|
||||
|
||||
z.pooled = pooled_output
|
||||
|
||||
17
modules/models/sd35/LICENSE-CODE
Normal file
17
modules/models/sd35/LICENSE-CODE
Normal file
@@ -0,0 +1,17 @@
|
||||
MIT License
|
||||
|
||||
Copyright © 2024 Stability AI
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
918
modules/models/sd35/mmditx.py
Normal file
918
modules/models/sd35/mmditx.py
Normal file
@@ -0,0 +1,918 @@
|
||||
### This file contains impls for MM-DiT, the core model component of SD3
|
||||
|
||||
import math
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from modules.models.sd35.other_impls import Mlp, attention
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""2D Image to Patch Embedding"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size: Optional[int] = 224,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
flatten: bool = True,
|
||||
bias: bool = True,
|
||||
strict_img_size: bool = True,
|
||||
dynamic_img_pad: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = (patch_size, patch_size)
|
||||
if img_size is not None:
|
||||
self.img_size = (img_size, img_size)
|
||||
self.grid_size = tuple(
|
||||
[s // p for s, p in zip(self.img_size, self.patch_size)]
|
||||
)
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
|
||||
else:
|
||||
self.img_size = None
|
||||
self.grid_size = None
|
||||
self.num_patches = None
|
||||
|
||||
# flatten spatial dim and transpose to channels last, kept for bwd compat
|
||||
self.flatten = flatten
|
||||
self.strict_img_size = strict_img_size
|
||||
self.dynamic_img_pad = dynamic_img_pad
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_chans,
|
||||
embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
x = self.proj(x)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
||||
return x
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
if shift is None:
|
||||
shift = torch.zeros_like(scale)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Sine/Cosine Positional Embedding Functions #
|
||||
#################################################################################
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(
|
||||
embed_dim,
|
||||
grid_size,
|
||||
cls_token=False,
|
||||
extra_tokens=0,
|
||||
scaling_factor=None,
|
||||
offset=None,
|
||||
):
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
return:
|
||||
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
grid_h = np.arange(grid_size, dtype=np.float32)
|
||||
grid_w = np.arange(grid_size, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
if scaling_factor is not None:
|
||||
grid = grid / scaling_factor
|
||||
if offset is not None:
|
||||
grid = grid - offset
|
||||
grid = grid.reshape([2, 1, grid_size, grid_size])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token and extra_tokens > 0:
|
||||
pos_embed = np.concatenate(
|
||||
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
|
||||
)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
assert embed_dim % 2 == 0
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000**omega # (D/2,)
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
return np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Embedding Layers for Timesteps and Class Labels #
|
||||
#################################################################################
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
"""Embeds scalar timesteps into vector representations."""
|
||||
|
||||
def __init__(
|
||||
self, hidden_size, frequency_embedding_size=256, dtype=None, device=None
|
||||
):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(
|
||||
frequency_embedding_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period)
|
||||
* torch.arange(start=0, end=half, dtype=torch.float32)
|
||||
/ half
|
||||
).to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat(
|
||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||
)
|
||||
if torch.is_floating_point(t):
|
||||
embedding = embedding.to(dtype=t.dtype)
|
||||
return embedding
|
||||
|
||||
def forward(self, t, dtype, **kwargs):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class VectorEmbedder(nn.Module):
|
||||
"""Embeds a flat vector of dimension input_dim"""
|
||||
|
||||
def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.mlp(x)
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Core DiT Model #
|
||||
#################################################################################
|
||||
|
||||
|
||||
def split_qkv(qkv, head_dim):
|
||||
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
|
||||
return qkv[0], qkv[1], qkv[2]
|
||||
|
||||
|
||||
def optimized_attention(qkv, num_heads):
|
||||
return attention(qkv[0], qkv[1], qkv[2], num_heads)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_scale: Optional[float] = None,
|
||||
attn_mode: str = "xformers",
|
||||
pre_only: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
rmsnorm: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
if not pre_only:
|
||||
self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
|
||||
assert attn_mode in self.ATTENTION_MODES
|
||||
self.attn_mode = attn_mode
|
||||
self.pre_only = pre_only
|
||||
|
||||
if qk_norm == "rms":
|
||||
self.ln_q = RMSNorm(
|
||||
self.head_dim,
|
||||
elementwise_affine=True,
|
||||
eps=1.0e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.ln_k = RMSNorm(
|
||||
self.head_dim,
|
||||
elementwise_affine=True,
|
||||
eps=1.0e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
elif qk_norm == "ln":
|
||||
self.ln_q = nn.LayerNorm(
|
||||
self.head_dim,
|
||||
elementwise_affine=True,
|
||||
eps=1.0e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.ln_k = nn.LayerNorm(
|
||||
self.head_dim,
|
||||
elementwise_affine=True,
|
||||
eps=1.0e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
elif qk_norm is None:
|
||||
self.ln_q = nn.Identity()
|
||||
self.ln_k = nn.Identity()
|
||||
else:
|
||||
raise ValueError(qk_norm)
|
||||
|
||||
def pre_attention(self, x: torch.Tensor):
|
||||
B, L, C = x.shape
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = split_qkv(qkv, self.head_dim)
|
||||
q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
|
||||
k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)
|
||||
return (q, k, v)
|
||||
|
||||
def post_attention(self, x: torch.Tensor) -> torch.Tensor:
|
||||
assert not self.pre_only
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
(q, k, v) = self.pre_attention(x)
|
||||
x = attention(q, k, v, self.num_heads)
|
||||
x = self.post_attention(x)
|
||||
return x
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
elementwise_affine: bool = False,
|
||||
eps: float = 1e-6,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
Initialize the RMSNorm normalization layer.
|
||||
Args:
|
||||
dim (int): The dimension of the input tensor.
|
||||
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
||||
Attributes:
|
||||
eps (float): A small value added to the denominator for numerical stability.
|
||||
weight (nn.Parameter): Learnable scaling parameter.
|
||||
"""
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.learnable_scale = elementwise_affine
|
||||
if self.learnable_scale:
|
||||
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
|
||||
def _norm(self, x):
|
||||
"""
|
||||
Apply the RMSNorm normalization to the input tensor.
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
"""
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the RMSNorm layer.
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
Returns:
|
||||
torch.Tensor: The output tensor after applying RMSNorm.
|
||||
"""
|
||||
x = self._norm(x)
|
||||
if self.learnable_scale:
|
||||
return x * self.weight.to(device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class SwiGLUFeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the FeedForward module.
|
||||
|
||||
Args:
|
||||
dim (int): Input dimension.
|
||||
hidden_dim (int): Hidden dimension of the feedforward layer.
|
||||
multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
|
||||
ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
|
||||
|
||||
Attributes:
|
||||
w1 (ColumnParallelLinear): Linear transformation for the first layer.
|
||||
w2 (RowParallelLinear): Linear transformation for the second layer.
|
||||
w3 (ColumnParallelLinear): Linear transformation for the third layer.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
# custom dim factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class DismantledBlock(nn.Module):
|
||||
"""A DiT block with gated adaptive layer norm (adaLN) conditioning."""
|
||||
|
||||
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
attn_mode: str = "xformers",
|
||||
qkv_bias: bool = False,
|
||||
pre_only: bool = False,
|
||||
rmsnorm: bool = False,
|
||||
scale_mod_only: bool = False,
|
||||
swiglu: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
x_block_self_attn: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
assert attn_mode in self.ATTENTION_MODES
|
||||
if not rmsnorm:
|
||||
self.norm1 = nn.LayerNorm(
|
||||
hidden_size,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.attn = SelfAttention(
|
||||
dim=hidden_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=pre_only,
|
||||
qk_norm=qk_norm,
|
||||
rmsnorm=rmsnorm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
if x_block_self_attn:
|
||||
assert not pre_only
|
||||
assert not scale_mod_only
|
||||
self.x_block_self_attn = True
|
||||
self.attn2 = SelfAttention(
|
||||
dim=hidden_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=False,
|
||||
qk_norm=qk_norm,
|
||||
rmsnorm=rmsnorm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.x_block_self_attn = False
|
||||
if not pre_only:
|
||||
if not rmsnorm:
|
||||
self.norm2 = nn.LayerNorm(
|
||||
hidden_size,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
if not pre_only:
|
||||
if not swiglu:
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=nn.GELU(approximate="tanh"),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.mlp = SwiGLUFeedForward(
|
||||
dim=hidden_size, hidden_dim=mlp_hidden_dim, multiple_of=256
|
||||
)
|
||||
self.scale_mod_only = scale_mod_only
|
||||
if x_block_self_attn:
|
||||
assert not pre_only
|
||||
assert not scale_mod_only
|
||||
n_mods = 9
|
||||
elif not scale_mod_only:
|
||||
n_mods = 6 if not pre_only else 2
|
||||
else:
|
||||
n_mods = 4 if not pre_only else 1
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(
|
||||
hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device
|
||||
),
|
||||
)
|
||||
self.pre_only = pre_only
|
||||
|
||||
def pre_attention(self, x: torch.Tensor, c: torch.Tensor):
|
||||
assert x is not None, "pre_attention called with None input"
|
||||
if not self.pre_only:
|
||||
if not self.scale_mod_only:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.adaLN_modulation(c).chunk(6, dim=1)
|
||||
)
|
||||
else:
|
||||
shift_msa = None
|
||||
shift_mlp = None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(
|
||||
c
|
||||
).chunk(4, dim=1)
|
||||
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
|
||||
return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
|
||||
else:
|
||||
if not self.scale_mod_only:
|
||||
shift_msa, scale_msa = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
else:
|
||||
shift_msa = None
|
||||
scale_msa = self.adaLN_modulation(c)
|
||||
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
|
||||
return qkv, None
|
||||
|
||||
def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
|
||||
assert not self.pre_only
|
||||
x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
|
||||
x = x + gate_mlp.unsqueeze(1) * self.mlp(
|
||||
modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||
)
|
||||
return x
|
||||
|
||||
def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
assert self.x_block_self_attn
|
||||
(
|
||||
shift_msa,
|
||||
scale_msa,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
shift_msa2,
|
||||
scale_msa2,
|
||||
gate_msa2,
|
||||
) = self.adaLN_modulation(c).chunk(9, dim=1)
|
||||
x_norm = self.norm1(x)
|
||||
qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
|
||||
qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
|
||||
return (
|
||||
qkv,
|
||||
qkv2,
|
||||
(
|
||||
x,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
gate_msa2,
|
||||
),
|
||||
)
|
||||
|
||||
def post_attention_x(
|
||||
self,
|
||||
attn,
|
||||
attn2,
|
||||
x,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
gate_msa2,
|
||||
attn1_dropout: float = 0.0,
|
||||
):
|
||||
assert not self.pre_only
|
||||
if attn1_dropout > 0.0:
|
||||
# Use torch.bernoulli to implement dropout, only dropout the batch dimension
|
||||
attn1_dropout = torch.bernoulli(
|
||||
torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device)
|
||||
)
|
||||
attn_ = (
|
||||
gate_msa.unsqueeze(1) * self.attn.post_attention(attn) * attn1_dropout
|
||||
)
|
||||
else:
|
||||
attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
|
||||
x = x + attn_
|
||||
attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2)
|
||||
x = x + attn2_
|
||||
mlp_ = gate_mlp.unsqueeze(1) * self.mlp(
|
||||
modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||
)
|
||||
x = x + mlp_
|
||||
return x, (gate_msa, gate_msa2, gate_mlp, attn_, attn2_)
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
assert not self.pre_only
|
||||
if self.x_block_self_attn:
|
||||
(q, k, v), (q2, k2, v2), intermediates = self.pre_attention_x(x, c)
|
||||
attn = attention(q, k, v, self.attn.num_heads)
|
||||
attn2 = attention(q2, k2, v2, self.attn2.num_heads)
|
||||
return self.post_attention_x(attn, attn2, *intermediates)
|
||||
else:
|
||||
(q, k, v), intermediates = self.pre_attention(x, c)
|
||||
attn = attention(q, k, v, self.attn.num_heads)
|
||||
return self.post_attention(attn, *intermediates)
|
||||
|
||||
|
||||
def block_mixing(context, x, context_block, x_block, c):
|
||||
assert context is not None, "block_mixing called with None context"
|
||||
context_qkv, context_intermediates = context_block.pre_attention(context, c)
|
||||
|
||||
if x_block.x_block_self_attn:
|
||||
x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
|
||||
else:
|
||||
x_qkv, x_intermediates = x_block.pre_attention(x, c)
|
||||
|
||||
o = []
|
||||
for t in range(3):
|
||||
o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1))
|
||||
q, k, v = tuple(o)
|
||||
|
||||
attn = attention(q, k, v, x_block.attn.num_heads)
|
||||
context_attn, x_attn = (
|
||||
attn[:, : context_qkv[0].shape[1]],
|
||||
attn[:, context_qkv[0].shape[1] :],
|
||||
)
|
||||
|
||||
if not context_block.pre_only:
|
||||
context = context_block.post_attention(context_attn, *context_intermediates)
|
||||
else:
|
||||
context = None
|
||||
|
||||
if x_block.x_block_self_attn:
|
||||
x_q2, x_k2, x_v2 = x_qkv2
|
||||
attn2 = attention(x_q2, x_k2, x_v2, x_block.attn2.num_heads)
|
||||
else:
|
||||
x = x_block.post_attention(x_attn, *x_intermediates)
|
||||
|
||||
return context, x
|
||||
|
||||
|
||||
class JointBlock(nn.Module):
|
||||
"""just a small wrapper to serve as a fsdp unit"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
pre_only = kwargs.pop("pre_only")
|
||||
qk_norm = kwargs.pop("qk_norm", None)
|
||||
x_block_self_attn = kwargs.pop("x_block_self_attn", False)
|
||||
self.context_block = DismantledBlock(
|
||||
*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs
|
||||
)
|
||||
self.x_block = DismantledBlock(
|
||||
*args,
|
||||
pre_only=False,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn=x_block_self_attn,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return block_mixing(
|
||||
*args, context_block=self.context_block, x_block=self.x_block, **kwargs
|
||||
)
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of DiT.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
patch_size: int,
|
||||
out_channels: int,
|
||||
total_out_channels: Optional[int] = None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(
|
||||
hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
|
||||
)
|
||||
self.linear = (
|
||||
nn.Linear(
|
||||
hidden_size,
|
||||
patch_size * patch_size * out_channels,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
if (total_out_channels is None)
|
||||
else nn.Linear(
|
||||
hidden_size, total_out_channels, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(
|
||||
hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class MMDiTX(nn.Module):
|
||||
"""Diffusion model with a Transformer backbone."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int = 32,
|
||||
patch_size: int = 4,
|
||||
in_channels: int = 4,
|
||||
depth: int = 28,
|
||||
mlp_ratio: float = 4.0,
|
||||
learn_sigma: bool = False,
|
||||
adm_in_channels: Optional[int] = None,
|
||||
context_embedder_config: Optional[Dict] = None,
|
||||
register_length: int = 0,
|
||||
attn_mode: str = "torch",
|
||||
rmsnorm: bool = False,
|
||||
scale_mod_only: bool = False,
|
||||
swiglu: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
pos_embed_scaling_factor: Optional[float] = None,
|
||||
pos_embed_offset: Optional[float] = None,
|
||||
pos_embed_max_size: Optional[int] = None,
|
||||
num_patches=None,
|
||||
qk_norm: Optional[str] = None,
|
||||
x_block_self_attn_layers: Optional[List[int]] = [],
|
||||
qkv_bias: bool = True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
verbose=False,
|
||||
):
|
||||
super().__init__()
|
||||
# input_size = 800
|
||||
if verbose:
|
||||
print(
|
||||
f"mmdit initializing with: {input_size=}, {patch_size=}, {in_channels=}, {depth=}, {mlp_ratio=}, {learn_sigma=}, {adm_in_channels=}, {context_embedder_config=}, {register_length=}, {attn_mode=}, {rmsnorm=}, {scale_mod_only=}, {swiglu=}, {out_channels=}, {pos_embed_scaling_factor=}, {pos_embed_offset=}, {pos_embed_max_size=}, {num_patches=}, {qk_norm=}, {qkv_bias=}, {dtype=}, {device=}"
|
||||
)
|
||||
self.dtype = dtype
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
default_out_channels = in_channels * 2 if learn_sigma else in_channels
|
||||
self.out_channels = (
|
||||
out_channels if out_channels is not None else default_out_channels
|
||||
)
|
||||
self.patch_size = patch_size
|
||||
self.pos_embed_scaling_factor = pos_embed_scaling_factor
|
||||
self.pos_embed_offset = pos_embed_offset
|
||||
self.pos_embed_max_size = pos_embed_max_size
|
||||
self.x_block_self_attn_layers = x_block_self_attn_layers
|
||||
|
||||
# apply magic --> this defines a head_size of 64
|
||||
hidden_size = 64 * depth
|
||||
num_heads = depth
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.x_embedder = PatchEmbed(
|
||||
32,
|
||||
2,
|
||||
in_channels,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
strict_img_size=self.pos_embed_max_size is None,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device)
|
||||
|
||||
if adm_in_channels is not None:
|
||||
assert isinstance(adm_in_channels, int)
|
||||
self.y_embedder = VectorEmbedder(
|
||||
adm_in_channels, hidden_size, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
self.context_embedder = nn.Identity()
|
||||
if context_embedder_config is not None:
|
||||
if context_embedder_config["target"] == "torch.nn.Linear":
|
||||
self.context_embedder = nn.Linear(
|
||||
**context_embedder_config["params"], dtype=dtype, device=device
|
||||
)
|
||||
|
||||
self.register_length = register_length
|
||||
if self.register_length > 0:
|
||||
self.register = nn.Parameter(
|
||||
torch.randn(1, register_length, hidden_size, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
# num_patches = self.x_embedder.num_patches
|
||||
# Will use fixed sin-cos embedding:
|
||||
# just use a buffer already
|
||||
if num_patches is not None:
|
||||
self.register_buffer(
|
||||
"pos_embed",
|
||||
torch.zeros(1, num_patches, hidden_size, dtype=dtype, device=device),
|
||||
)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
self.joint_blocks = nn.ModuleList(
|
||||
[
|
||||
JointBlock(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=i == depth - 1,
|
||||
rmsnorm=rmsnorm,
|
||||
scale_mod_only=scale_mod_only,
|
||||
swiglu=swiglu,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn=(i in self.x_block_self_attn_layers),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = FinalLayer(
|
||||
hidden_size, patch_size, self.out_channels, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
def cropped_pos_embed(self, hw):
|
||||
assert self.pos_embed_max_size is not None
|
||||
p = self.x_embedder.patch_size[0]
|
||||
h, w = hw
|
||||
# patched size
|
||||
h = h // p
|
||||
w = w // p
|
||||
assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
|
||||
assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
|
||||
top = (self.pos_embed_max_size - h) // 2
|
||||
left = (self.pos_embed_max_size - w) // 2
|
||||
spatial_pos_embed = rearrange(
|
||||
self.pos_embed,
|
||||
"1 (h w) c -> 1 h w c",
|
||||
h=self.pos_embed_max_size,
|
||||
w=self.pos_embed_max_size,
|
||||
)
|
||||
spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
|
||||
spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
|
||||
return spatial_pos_embed
|
||||
|
||||
def unpatchify(self, x, hw=None):
|
||||
"""
|
||||
x: (N, T, patch_size**2 * C)
|
||||
imgs: (N, H, W, C)
|
||||
"""
|
||||
c = self.out_channels
|
||||
p = self.x_embedder.patch_size[0]
|
||||
if hw is None:
|
||||
h = w = int(x.shape[1] ** 0.5)
|
||||
else:
|
||||
h, w = hw
|
||||
h = h // p
|
||||
w = w // p
|
||||
assert h * w == x.shape[1]
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||
x = torch.einsum("nhwpqc->nchpwq", x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||
return imgs
|
||||
|
||||
def forward_core_with_concat(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
c_mod: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if self.register_length > 0:
|
||||
context = torch.cat(
|
||||
(
|
||||
repeat(self.register, "1 ... -> b ...", b=x.shape[0]),
|
||||
context if context is not None else torch.Tensor([]).type_as(x),
|
||||
),
|
||||
1,
|
||||
)
|
||||
|
||||
# context is B, L', D
|
||||
# x is B, L, D
|
||||
for block in self.joint_blocks:
|
||||
context, x = block(context, x, c=c_mod)
|
||||
|
||||
x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None, control=None, transformer_options={}, **kwargs
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of DiT.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N,) tensor of class labels
|
||||
"""
|
||||
hw = x.shape[-2:]
|
||||
# The line below should be unnecessary when full integrated.
|
||||
x = x[:1,:16,:,:]
|
||||
x = self.x_embedder(x) + self.cropped_pos_embed(hw).to("cuda")
|
||||
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
|
||||
if y is not None:
|
||||
y = self.y_embedder(y) # (N, D)
|
||||
c = c + y # (N, D)
|
||||
|
||||
context = self.context_embedder(context)
|
||||
|
||||
x = self.forward_core_with_concat(x, c, context)
|
||||
|
||||
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
|
||||
return x
|
||||
868
modules/models/sd35/other_impls.py
Normal file
868
modules/models/sd35/other_impls.py
Normal file
@@ -0,0 +1,868 @@
|
||||
### This file contains impls for underlying related models (CLIP, T5, etc)
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
|
||||
#################################################################################################
|
||||
### Core/Utility
|
||||
#################################################################################################
|
||||
|
||||
|
||||
def attention(q, k, v, heads, mask=None):
|
||||
"""Convenience wrapper around a basic attention operation"""
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
bias=True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
|
||||
self.fc1 = nn.Linear(
|
||||
in_features, hidden_features, bias=bias, dtype=dtype, device=device
|
||||
)
|
||||
self.act = act_layer
|
||||
self.fc2 = nn.Linear(
|
||||
hidden_features, out_features, bias=bias, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### CLIP
|
||||
#################################################################################################
|
||||
|
||||
|
||||
class CLIPAttention(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, dtype, device):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.q_proj = nn.Linear(
|
||||
embed_dim, embed_dim, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
embed_dim, embed_dim, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
embed_dim, embed_dim, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
self.out_proj = nn.Linear(
|
||||
embed_dim, embed_dim, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
q = self.q_proj(x)
|
||||
k = self.k_proj(x)
|
||||
v = self.v_proj(x)
|
||||
out = attention(q, k, v, self.heads, mask)
|
||||
return self.out_proj(out)
|
||||
|
||||
|
||||
ACTIVATIONS = {
|
||||
"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
|
||||
"gelu": torch.nn.functional.gelu,
|
||||
}
|
||||
|
||||
|
||||
class CLIPLayer(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
heads,
|
||||
intermediate_size,
|
||||
intermediate_activation,
|
||||
dtype,
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
|
||||
self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
# self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
|
||||
self.mlp = Mlp(
|
||||
embed_dim,
|
||||
intermediate_size,
|
||||
embed_dim,
|
||||
act_layer=ACTIVATIONS[intermediate_activation],
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
x += self.self_attn(self.layer_norm1(x), mask)
|
||||
x += self.mlp(self.layer_norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class CLIPEncoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers,
|
||||
embed_dim,
|
||||
heads,
|
||||
intermediate_size,
|
||||
intermediate_activation,
|
||||
dtype,
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = torch.nn.ModuleList(
|
||||
[
|
||||
CLIPLayer(
|
||||
embed_dim,
|
||||
heads,
|
||||
intermediate_size,
|
||||
intermediate_activation,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None, intermediate_output=None):
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output < 0:
|
||||
intermediate_output = len(self.layers) + intermediate_output
|
||||
intermediate = None
|
||||
for i, l in enumerate(self.layers):
|
||||
x = l(x, mask)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
return x, intermediate
|
||||
|
||||
|
||||
class CLIPEmbeddings(torch.nn.Module):
|
||||
def __init__(
|
||||
self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None
|
||||
):
|
||||
super().__init__()
|
||||
self.token_embedding = torch.nn.Embedding(
|
||||
vocab_size, embed_dim, dtype=dtype, device=device
|
||||
)
|
||||
self.position_embedding = torch.nn.Embedding(
|
||||
num_positions, embed_dim, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
def forward(self, input_tokens):
|
||||
return self.token_embedding(input_tokens) + self.position_embedding.weight
|
||||
|
||||
|
||||
class CLIPTextModel_(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device):
|
||||
num_layers = config_dict["num_hidden_layers"]
|
||||
embed_dim = config_dict["hidden_size"]
|
||||
heads = config_dict["num_attention_heads"]
|
||||
intermediate_size = config_dict["intermediate_size"]
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
super().__init__()
|
||||
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
|
||||
self.encoder = CLIPEncoder(
|
||||
num_layers,
|
||||
embed_dim,
|
||||
heads,
|
||||
intermediate_size,
|
||||
intermediate_activation,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True
|
||||
):
|
||||
x = self.embeddings(input_tokens)
|
||||
causal_mask = (
|
||||
torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device)
|
||||
.fill_(float("-inf"))
|
||||
.triu_(1)
|
||||
)
|
||||
x, i = self.encoder(
|
||||
x, mask=causal_mask, intermediate_output=intermediate_output
|
||||
)
|
||||
x = self.final_layer_norm(x)
|
||||
if i is not None and final_layer_norm_intermediate:
|
||||
i = self.final_layer_norm(i)
|
||||
pooled_output = x[
|
||||
torch.arange(x.shape[0], device=x.device),
|
||||
input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),
|
||||
]
|
||||
return x, i, pooled_output
|
||||
|
||||
|
||||
class CLIPTextModel(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device):
|
||||
super().__init__()
|
||||
self.num_layers = config_dict["num_hidden_layers"]
|
||||
self.text_model = CLIPTextModel_(config_dict, dtype, device)
|
||||
embed_dim = config_dict["hidden_size"]
|
||||
self.text_projection = nn.Linear(
|
||||
embed_dim, embed_dim, bias=False, dtype=dtype, device=device
|
||||
)
|
||||
self.text_projection.weight.copy_(torch.eye(embed_dim))
|
||||
self.dtype = dtype
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.text_model.embeddings.token_embedding
|
||||
|
||||
def set_input_embeddings(self, embeddings):
|
||||
self.text_model.embeddings.token_embedding = embeddings
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
x = self.text_model(*args, **kwargs)
|
||||
out = self.text_projection(x[2])
|
||||
return (x[0], x[1], out, x[2])
|
||||
|
||||
|
||||
def parse_parentheses(string):
|
||||
result = []
|
||||
current_item = ""
|
||||
nesting_level = 0
|
||||
for char in string:
|
||||
if char == "(":
|
||||
if nesting_level == 0:
|
||||
if current_item:
|
||||
result.append(current_item)
|
||||
current_item = "("
|
||||
else:
|
||||
current_item = "("
|
||||
else:
|
||||
current_item += char
|
||||
nesting_level += 1
|
||||
elif char == ")":
|
||||
nesting_level -= 1
|
||||
if nesting_level == 0:
|
||||
result.append(current_item + ")")
|
||||
current_item = ""
|
||||
else:
|
||||
current_item += char
|
||||
else:
|
||||
current_item += char
|
||||
if current_item:
|
||||
result.append(current_item)
|
||||
return result
|
||||
|
||||
|
||||
def token_weights(string, current_weight):
|
||||
a = parse_parentheses(string)
|
||||
out = []
|
||||
for x in a:
|
||||
weight = current_weight
|
||||
if len(x) >= 2 and x[-1] == ")" and x[0] == "(":
|
||||
x = x[1:-1]
|
||||
xx = x.rfind(":")
|
||||
weight *= 1.1
|
||||
if xx > 0:
|
||||
try:
|
||||
weight = float(x[xx + 1 :])
|
||||
x = x[:xx]
|
||||
except:
|
||||
pass
|
||||
out += token_weights(x, weight)
|
||||
else:
|
||||
out += [(x, current_weight)]
|
||||
return out
|
||||
|
||||
|
||||
def escape_important(text):
|
||||
text = text.replace("\\)", "\0\1")
|
||||
text = text.replace("\\(", "\0\2")
|
||||
return text
|
||||
|
||||
|
||||
def unescape_important(text):
|
||||
text = text.replace("\0\1", ")")
|
||||
text = text.replace("\0\2", "(")
|
||||
return text
|
||||
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(
|
||||
self,
|
||||
max_length=77,
|
||||
pad_with_end=True,
|
||||
tokenizer=None,
|
||||
has_start_token=True,
|
||||
pad_to_max_length=True,
|
||||
min_length=None,
|
||||
extra_padding_token=None,
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.min_length = min_length
|
||||
|
||||
empty = self.tokenizer("")["input_ids"]
|
||||
if has_start_token:
|
||||
self.tokens_start = 1
|
||||
self.start_token = empty[0]
|
||||
self.end_token = empty[1]
|
||||
else:
|
||||
self.tokens_start = 0
|
||||
self.start_token = None
|
||||
self.end_token = empty[0]
|
||||
self.pad_with_end = pad_with_end
|
||||
self.pad_to_max_length = pad_to_max_length
|
||||
self.extra_padding_token = extra_padding_token
|
||||
|
||||
vocab = self.tokenizer.get_vocab()
|
||||
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||
self.max_word_length = 8
|
||||
|
||||
def tokenize_with_weights(self, text: str, return_word_ids=False):
|
||||
"""
|
||||
Tokenize the text, with weight values - presume 1.0 for all and ignore other features here.
|
||||
The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.
|
||||
"""
|
||||
if self.pad_with_end:
|
||||
pad_token = self.end_token
|
||||
else:
|
||||
pad_token = 0
|
||||
|
||||
text = escape_important(text)
|
||||
parsed_weights = token_weights(text, 1.0)
|
||||
|
||||
# tokenize words
|
||||
tokens = []
|
||||
for weighted_segment, weight in parsed_weights:
|
||||
to_tokenize = (
|
||||
unescape_important(weighted_segment).replace("\n", " ").split(" ")
|
||||
)
|
||||
to_tokenize = [x for x in to_tokenize if x != ""]
|
||||
for word in to_tokenize:
|
||||
# parse word
|
||||
tokens.append(
|
||||
[
|
||||
(t, weight)
|
||||
for t in self.tokenizer(word)["input_ids"][
|
||||
self.tokens_start : -1
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
# reshape token array to CLIP input size
|
||||
batched_tokens = []
|
||||
batch = []
|
||||
if self.start_token is not None:
|
||||
batch.append((self.start_token, 1.0, 0))
|
||||
batched_tokens.append(batch)
|
||||
for i, t_group in enumerate(tokens):
|
||||
# determine if we're going to try and keep the tokens in a single batch
|
||||
is_large = len(t_group) >= self.max_word_length
|
||||
|
||||
while len(t_group) > 0:
|
||||
if len(t_group) + len(batch) > self.max_length - 1:
|
||||
remaining_length = self.max_length - len(batch) - 1
|
||||
# break word in two and add end token
|
||||
if is_large:
|
||||
batch.extend(
|
||||
[(t, w, i + 1) for t, w in t_group[:remaining_length]]
|
||||
)
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
t_group = t_group[remaining_length:]
|
||||
# add end token and pad
|
||||
else:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
|
||||
# start new batch
|
||||
batch = []
|
||||
if self.start_token is not None:
|
||||
batch.append((self.start_token, 1.0, 0))
|
||||
batched_tokens.append(batch)
|
||||
else:
|
||||
batch.extend([(t, w, i + 1) for t, w in t_group])
|
||||
t_group = []
|
||||
|
||||
# pad extra padding token first befor getting to the end token
|
||||
if self.extra_padding_token is not None:
|
||||
batch.extend(
|
||||
[(self.extra_padding_token, 1.0, 0)]
|
||||
* (self.min_length - len(batch) - 1)
|
||||
)
|
||||
# fill last batch
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
||||
if self.min_length is not None and len(batch) < self.min_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))
|
||||
|
||||
if not return_word_ids:
|
||||
batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
|
||||
|
||||
return batched_tokens
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
|
||||
|
||||
|
||||
class SDXLClipGTokenizer(SDTokenizer):
|
||||
def __init__(self, tokenizer):
|
||||
super().__init__(pad_with_end=False, tokenizer=tokenizer)
|
||||
|
||||
|
||||
class SD3Tokenizer:
|
||||
def __init__(self):
|
||||
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
|
||||
self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
|
||||
self.t5xxl = T5XXLTokenizer()
|
||||
|
||||
def tokenize_with_weights(self, text: str):
|
||||
out = {}
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text)
|
||||
out["g"] = self.clip_g.tokenize_with_weights(text)
|
||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text[:226])
|
||||
return out
|
||||
|
||||
|
||||
class ClipTokenWeightEncoder:
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
tokens = list(map(lambda a: a[0], token_weight_pairs[0]))
|
||||
out, pooled = self([tokens])
|
||||
if pooled is not None:
|
||||
first_pooled = pooled[0:1].cpu()
|
||||
else:
|
||||
first_pooled = pooled
|
||||
output = [out[0:1]]
|
||||
return torch.cat(output, dim=-2).cpu(), first_pooled
|
||||
|
||||
|
||||
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
|
||||
LAYERS = ["last", "pooled", "hidden"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device="cpu",
|
||||
max_length=77,
|
||||
layer="last",
|
||||
layer_idx=None,
|
||||
textmodel_json_config=None,
|
||||
dtype=None,
|
||||
model_class=CLIPTextModel,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407},
|
||||
layer_norm_hidden_state=True,
|
||||
return_projected_pooled=True,
|
||||
):
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
self.transformer = model_class(textmodel_json_config, dtype, device)
|
||||
self.num_layers = self.transformer.num_layers
|
||||
self.max_length = max_length
|
||||
self.transformer = self.transformer.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
self.layer = layer
|
||||
self.layer_idx = None
|
||||
self.special_tokens = special_tokens
|
||||
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
|
||||
self.layer_norm_hidden_state = layer_norm_hidden_state
|
||||
self.return_projected_pooled = return_projected_pooled
|
||||
if layer == "hidden":
|
||||
assert layer_idx is not None
|
||||
assert abs(layer_idx) < self.num_layers
|
||||
self.set_clip_options({"layer": layer_idx})
|
||||
self.options_default = (
|
||||
self.layer,
|
||||
self.layer_idx,
|
||||
self.return_projected_pooled,
|
||||
)
|
||||
|
||||
def set_clip_options(self, options):
|
||||
layer_idx = options.get("layer", self.layer_idx)
|
||||
self.return_projected_pooled = options.get(
|
||||
"projected_pooled", self.return_projected_pooled
|
||||
)
|
||||
if layer_idx is None or abs(layer_idx) > self.num_layers:
|
||||
self.layer = "last"
|
||||
else:
|
||||
self.layer = "hidden"
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def forward(self, tokens):
|
||||
backup_embeds = self.transformer.get_input_embeddings()
|
||||
device = backup_embeds.weight.device
|
||||
tokens = torch.LongTensor(tokens).to(device)
|
||||
outputs = self.transformer(
|
||||
tokens,
|
||||
intermediate_output=self.layer_idx,
|
||||
final_layer_norm_intermediate=self.layer_norm_hidden_state,
|
||||
)
|
||||
self.transformer.set_input_embeddings(backup_embeds)
|
||||
if self.layer == "last":
|
||||
z = outputs[0]
|
||||
else:
|
||||
z = outputs[1]
|
||||
pooled_output = None
|
||||
if len(outputs) >= 3:
|
||||
if (
|
||||
not self.return_projected_pooled
|
||||
and len(outputs) >= 4
|
||||
and outputs[3] is not None
|
||||
):
|
||||
pooled_output = outputs[3].float()
|
||||
elif outputs[2] is not None:
|
||||
pooled_output = outputs[2].float()
|
||||
return z.float(), pooled_output
|
||||
|
||||
|
||||
class SDXLClipG(SDClipModel):
|
||||
"""Wraps the CLIP-G model into the SD-CLIP-Model interface"""
|
||||
|
||||
def __init__(
|
||||
self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None
|
||||
):
|
||||
if layer == "penultimate":
|
||||
layer = "hidden"
|
||||
layer_idx = -2
|
||||
super().__init__(
|
||||
device=device,
|
||||
layer=layer,
|
||||
layer_idx=layer_idx,
|
||||
textmodel_json_config=config,
|
||||
dtype=dtype,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 0},
|
||||
layer_norm_hidden_state=False,
|
||||
)
|
||||
|
||||
|
||||
class T5XXLModel(SDClipModel):
|
||||
"""Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience"""
|
||||
|
||||
def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
super().__init__(
|
||||
device=device,
|
||||
layer=layer,
|
||||
layer_idx=layer_idx,
|
||||
textmodel_json_config=config,
|
||||
dtype=dtype,
|
||||
special_tokens={"end": 1, "pad": 0},
|
||||
model_class=T5,
|
||||
)
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
|
||||
#################################################################################################
|
||||
|
||||
|
||||
class T5XXLTokenizer(SDTokenizer):
|
||||
"""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
pad_with_end=False,
|
||||
tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"),
|
||||
has_start_token=False,
|
||||
pad_to_max_length=False,
|
||||
max_length=99999999,
|
||||
min_length=77,
|
||||
)
|
||||
|
||||
|
||||
class T5LayerNorm(torch.nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(
|
||||
torch.ones(hidden_size, dtype=dtype, device=device)
|
||||
)
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, x):
|
||||
variance = x.pow(2).mean(-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight.to(device=x.device, dtype=x.dtype) * x
|
||||
|
||||
|
||||
class T5DenseGatedActDense(torch.nn.Module):
|
||||
def __init__(self, model_dim, ff_dim, dtype, device):
|
||||
super().__init__()
|
||||
self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
|
||||
hidden_linear = self.wi_1(x)
|
||||
x = hidden_gelu * hidden_linear
|
||||
x = self.wo(x)
|
||||
return x
|
||||
|
||||
|
||||
class T5LayerFF(torch.nn.Module):
|
||||
def __init__(self, model_dim, ff_dim, dtype, device):
|
||||
super().__init__()
|
||||
self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device)
|
||||
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
forwarded_states = self.layer_norm(x)
|
||||
forwarded_states = self.DenseReluDense(forwarded_states)
|
||||
x += forwarded_states
|
||||
return x
|
||||
|
||||
|
||||
class T5Attention(torch.nn.Module):
|
||||
def __init__(
|
||||
self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device
|
||||
):
|
||||
super().__init__()
|
||||
# Mesh TensorFlow initialization to avoid scaling before softmax
|
||||
self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||
self.num_heads = num_heads
|
||||
self.relative_attention_bias = None
|
||||
if relative_attention_bias:
|
||||
self.relative_attention_num_buckets = 32
|
||||
self.relative_attention_max_distance = 128
|
||||
self.relative_attention_bias = torch.nn.Embedding(
|
||||
self.relative_attention_num_buckets, self.num_heads, device=device
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _relative_position_bucket(
|
||||
relative_position, bidirectional=True, num_buckets=32, max_distance=128
|
||||
):
|
||||
"""
|
||||
Adapted from Mesh Tensorflow:
|
||||
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
|
||||
|
||||
Translate relative position to a bucket number for relative attention. The relative position is defined as
|
||||
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
|
||||
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
|
||||
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
|
||||
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
|
||||
This should allow for more graceful generalization to longer sequences than the model has been trained on
|
||||
|
||||
Args:
|
||||
relative_position: an int32 Tensor
|
||||
bidirectional: a boolean - whether the attention is bidirectional
|
||||
num_buckets: an integer
|
||||
max_distance: an integer
|
||||
|
||||
Returns:
|
||||
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
|
||||
"""
|
||||
relative_buckets = 0
|
||||
if bidirectional:
|
||||
num_buckets //= 2
|
||||
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
|
||||
relative_position = torch.abs(relative_position)
|
||||
else:
|
||||
relative_position = -torch.min(
|
||||
relative_position, torch.zeros_like(relative_position)
|
||||
)
|
||||
# now relative_position is in the range [0, inf)
|
||||
# half of the buckets are for exact increments in positions
|
||||
max_exact = num_buckets // 2
|
||||
is_small = relative_position < max_exact
|
||||
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
||||
relative_position_if_large = max_exact + (
|
||||
torch.log(relative_position.float() / max_exact)
|
||||
/ math.log(max_distance / max_exact)
|
||||
* (num_buckets - max_exact)
|
||||
).to(torch.long)
|
||||
relative_position_if_large = torch.min(
|
||||
relative_position_if_large,
|
||||
torch.full_like(relative_position_if_large, num_buckets - 1),
|
||||
)
|
||||
relative_buckets += torch.where(
|
||||
is_small, relative_position, relative_position_if_large
|
||||
)
|
||||
return relative_buckets
|
||||
|
||||
def compute_bias(self, query_length, key_length, device):
|
||||
"""Compute binned relative position bias"""
|
||||
context_position = torch.arange(query_length, dtype=torch.long, device=device)[
|
||||
:, None
|
||||
]
|
||||
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[
|
||||
None, :
|
||||
]
|
||||
relative_position = (
|
||||
memory_position - context_position
|
||||
) # shape (query_length, key_length)
|
||||
relative_position_bucket = self._relative_position_bucket(
|
||||
relative_position, # shape (query_length, key_length)
|
||||
bidirectional=True,
|
||||
num_buckets=self.relative_attention_num_buckets,
|
||||
max_distance=self.relative_attention_max_distance,
|
||||
)
|
||||
values = self.relative_attention_bias(
|
||||
relative_position_bucket
|
||||
) # shape (query_length, key_length, num_heads)
|
||||
values = values.permute([2, 0, 1]).unsqueeze(
|
||||
0
|
||||
) # shape (1, num_heads, query_length, key_length)
|
||||
return values
|
||||
|
||||
def forward(self, x, past_bias=None):
|
||||
q = self.q(x)
|
||||
k = self.k(x)
|
||||
v = self.v(x)
|
||||
if self.relative_attention_bias is not None:
|
||||
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
|
||||
if past_bias is not None:
|
||||
mask = past_bias
|
||||
out = attention(
|
||||
q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask
|
||||
)
|
||||
return self.o(out), past_bias
|
||||
|
||||
|
||||
class T5LayerSelfAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_dim,
|
||||
inner_dim,
|
||||
ff_dim,
|
||||
num_heads,
|
||||
relative_attention_bias,
|
||||
dtype,
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
self.SelfAttention = T5Attention(
|
||||
model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device
|
||||
)
|
||||
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, past_bias=None):
|
||||
output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias)
|
||||
x += output
|
||||
return x, past_bias
|
||||
|
||||
|
||||
class T5Block(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_dim,
|
||||
inner_dim,
|
||||
ff_dim,
|
||||
num_heads,
|
||||
relative_attention_bias,
|
||||
dtype,
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList()
|
||||
self.layer.append(
|
||||
T5LayerSelfAttention(
|
||||
model_dim,
|
||||
inner_dim,
|
||||
ff_dim,
|
||||
num_heads,
|
||||
relative_attention_bias,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
)
|
||||
self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device))
|
||||
|
||||
def forward(self, x, past_bias=None):
|
||||
x, past_bias = self.layer[0](x, past_bias)
|
||||
x = self.layer[-1](x)
|
||||
return x, past_bias
|
||||
|
||||
|
||||
class T5Stack(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers,
|
||||
model_dim,
|
||||
inner_dim,
|
||||
ff_dim,
|
||||
num_heads,
|
||||
vocab_size,
|
||||
dtype,
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
|
||||
self.block = torch.nn.ModuleList(
|
||||
[
|
||||
T5Block(
|
||||
model_dim,
|
||||
inner_dim,
|
||||
ff_dim,
|
||||
num_heads,
|
||||
relative_attention_bias=(i == 0),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True
|
||||
):
|
||||
intermediate = None
|
||||
x = self.embed_tokens(input_ids)
|
||||
past_bias = None
|
||||
for i, l in enumerate(self.block):
|
||||
x, past_bias = l(x, past_bias)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
x = self.final_layer_norm(x)
|
||||
if intermediate is not None and final_layer_norm_intermediate:
|
||||
intermediate = self.final_layer_norm(intermediate)
|
||||
return x, intermediate
|
||||
|
||||
|
||||
class T5(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device):
|
||||
super().__init__()
|
||||
self.num_layers = config_dict["num_layers"]
|
||||
self.encoder = T5Stack(
|
||||
self.num_layers,
|
||||
config_dict["d_model"],
|
||||
config_dict["d_model"],
|
||||
config_dict["d_ff"],
|
||||
config_dict["num_heads"],
|
||||
config_dict["vocab_size"],
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
self.dtype = dtype
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.encoder.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, embeddings):
|
||||
self.encoder.embed_tokens = embeddings
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.encoder(*args, **kwargs)
|
||||
222
modules/models/sd35/sd3_cond.py
Normal file
222
modules/models/sd35/sd3_cond.py
Normal file
@@ -0,0 +1,222 @@
|
||||
import os
|
||||
import safetensors
|
||||
import torch
|
||||
import typing
|
||||
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
|
||||
from modules import shared, devices, modelloader, sd_hijack_clip, prompt_parser
|
||||
from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer
|
||||
|
||||
|
||||
class SafetensorsMapping(typing.Mapping):
|
||||
def __init__(self, file):
|
||||
self.file = file
|
||||
|
||||
def __len__(self):
|
||||
return len(self.file.keys())
|
||||
|
||||
def __iter__(self):
|
||||
for key in self.file.keys():
|
||||
yield key
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.file.get_tensor(key)
|
||||
|
||||
|
||||
CLIPL_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors"
|
||||
CLIPL_CONFIG = {
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 768,
|
||||
"intermediate_size": 3072,
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12,
|
||||
}
|
||||
|
||||
CLIPG_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors"
|
||||
CLIPG_CONFIG = {
|
||||
"hidden_act": "gelu",
|
||||
"hidden_size": 1280,
|
||||
"intermediate_size": 5120,
|
||||
"num_attention_heads": 20,
|
||||
"num_hidden_layers": 32,
|
||||
"textual_inversion_key": "clip_g",
|
||||
}
|
||||
|
||||
T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
|
||||
T5_CONFIG = {
|
||||
"d_ff": 10240,
|
||||
"d_model": 4096,
|
||||
"num_heads": 64,
|
||||
"num_layers": 24,
|
||||
"vocab_size": 32128,
|
||||
}
|
||||
|
||||
|
||||
class Sd3ClipLG(sd_hijack_clip.TextConditionalModel):
|
||||
def __init__(self, clip_l, clip_g):
|
||||
super().__init__()
|
||||
|
||||
self.clip_l = clip_l
|
||||
self.clip_g = clip_g
|
||||
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
|
||||
empty = self.tokenizer('')["input_ids"]
|
||||
self.id_start = empty[0]
|
||||
self.id_end = empty[1]
|
||||
self.id_pad = empty[1]
|
||||
|
||||
self.return_pooled = True
|
||||
|
||||
def tokenize(self, texts):
|
||||
return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
tokens_g = tokens.clone()
|
||||
|
||||
for batch_pos in range(tokens_g.shape[0]):
|
||||
index = tokens_g[batch_pos].cpu().tolist().index(self.id_end)
|
||||
tokens_g[batch_pos, index+1:tokens_g.shape[1]] = 0
|
||||
|
||||
l_out, l_pooled = self.clip_l(tokens)
|
||||
g_out, g_pooled = self.clip_g(tokens_g)
|
||||
|
||||
lg_out = torch.cat([l_out, g_out], dim=-1)
|
||||
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
||||
|
||||
vector_out = torch.cat((l_pooled, g_pooled), dim=-1)
|
||||
|
||||
lg_out.pooled = vector_out
|
||||
return lg_out
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
return torch.zeros((nvpt, 768+1280), device=devices.device) # XXX
|
||||
|
||||
|
||||
class Sd3T5(torch.nn.Module):
|
||||
def __init__(self, t5xxl):
|
||||
super().__init__()
|
||||
|
||||
self.t5xxl = t5xxl
|
||||
self.tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")
|
||||
|
||||
empty = self.tokenizer('', padding='max_length', max_length=2)["input_ids"]
|
||||
self.id_end = empty[0]
|
||||
self.id_pad = empty[1]
|
||||
|
||||
def tokenize(self, texts):
|
||||
return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
|
||||
def tokenize_line(self, line, *, target_token_count=None):
|
||||
if shared.opts.emphasis != "None":
|
||||
parsed = prompt_parser.parse_prompt_attention(line)
|
||||
else:
|
||||
parsed = [[line, 1.0]]
|
||||
|
||||
tokenized = self.tokenize([text for text, _ in parsed])
|
||||
|
||||
tokens = []
|
||||
multipliers = []
|
||||
|
||||
for text_tokens, (text, weight) in zip(tokenized, parsed):
|
||||
if text == 'BREAK' and weight == -1:
|
||||
continue
|
||||
|
||||
tokens += text_tokens
|
||||
multipliers += [weight] * len(text_tokens)
|
||||
|
||||
tokens += [self.id_end]
|
||||
multipliers += [1.0]
|
||||
|
||||
if target_token_count is not None:
|
||||
if len(tokens) < target_token_count:
|
||||
tokens += [self.id_pad] * (target_token_count - len(tokens))
|
||||
multipliers += [1.0] * (target_token_count - len(tokens))
|
||||
else:
|
||||
tokens = tokens[0:target_token_count]
|
||||
multipliers = multipliers[0:target_token_count]
|
||||
|
||||
return tokens, multipliers
|
||||
|
||||
def forward(self, texts, *, token_count):
|
||||
if not self.t5xxl or not shared.opts.sd3_enable_t5:
|
||||
return torch.zeros((len(texts), token_count, 4096), device=devices.device, dtype=devices.dtype)
|
||||
|
||||
tokens_batch = []
|
||||
|
||||
for text in texts:
|
||||
tokens, multipliers = self.tokenize_line(text, target_token_count=token_count)
|
||||
tokens_batch.append(tokens)
|
||||
|
||||
t5_out, t5_pooled = self.t5xxl(tokens_batch)
|
||||
|
||||
return t5_out
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
return torch.zeros((nvpt, 4096), device=devices.device) # XXX
|
||||
|
||||
|
||||
class SD3Cond(torch.nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.tokenizer = SD3Tokenizer()
|
||||
|
||||
with torch.no_grad():
|
||||
self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
|
||||
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
|
||||
|
||||
if shared.opts.sd3_enable_t5:
|
||||
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
|
||||
else:
|
||||
self.t5xxl = None
|
||||
|
||||
self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)
|
||||
self.model_t5 = Sd3T5(self.t5xxl)
|
||||
|
||||
def forward(self, prompts: list[str]):
|
||||
with devices.without_autocast():
|
||||
lg_out, vector_out = self.model_lg(prompts)
|
||||
t5_out = self.model_t5(prompts, token_count=lg_out.shape[1])
|
||||
lgt_out = torch.cat([lg_out, t5_out], dim=-2)
|
||||
|
||||
return {
|
||||
'crossattn': lgt_out,
|
||||
'vector': vector_out,
|
||||
}
|
||||
|
||||
def before_load_weights(self, state_dict):
|
||||
clip_path = os.path.join(shared.models_path, "CLIP")
|
||||
|
||||
if 'text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
|
||||
clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
|
||||
with safetensors.safe_open(clip_g_file, framework="pt") as file:
|
||||
self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
|
||||
|
||||
if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
|
||||
clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
|
||||
with safetensors.safe_open(clip_l_file, framework="pt") as file:
|
||||
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||
|
||||
if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
|
||||
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
|
||||
with safetensors.safe_open(t5_file, framework="pt") as file:
|
||||
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
return self.model_lg.encode_embedding_init_text(init_text, nvpt)
|
||||
|
||||
def tokenize(self, texts):
|
||||
return self.model_lg.tokenize(texts)
|
||||
|
||||
def medvram_modules(self):
|
||||
return [self.clip_g, self.clip_l, self.t5xxl]
|
||||
|
||||
def get_token_count(self, text):
|
||||
_, token_count = self.model_lg.process_texts([text])
|
||||
|
||||
return token_count
|
||||
|
||||
def get_target_prompt_token_count(self, token_count):
|
||||
return self.model_lg.get_target_prompt_token_count(token_count)
|
||||
623
modules/models/sd35/sd3_impls.py
Normal file
623
modules/models/sd35/sd3_impls.py
Normal file
@@ -0,0 +1,623 @@
|
||||
### Impls of the SD3 core diffusion model and VAE
|
||||
|
||||
import math
|
||||
import re
|
||||
|
||||
import einops
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from modules.models.sd35.mmditx import MMDiTX
|
||||
|
||||
#################################################################################################
|
||||
### MMDiT Model Wrapping
|
||||
#################################################################################################
|
||||
|
||||
|
||||
class ModelSamplingDiscreteFlow(torch.nn.Module):
|
||||
"""Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models"""
|
||||
|
||||
def __init__(self, shift=1.0):
|
||||
super().__init__()
|
||||
self.shift = shift
|
||||
timesteps = 1000
|
||||
ts = self.sigma(torch.arange(1, timesteps + 1, 1))
|
||||
self.register_buffer("sigmas", ts)
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
||||
@property
|
||||
def sigma_max(self):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def timestep(self, sigma):
|
||||
return sigma * 1000
|
||||
|
||||
def sigma(self, timestep: torch.Tensor):
|
||||
timestep = timestep / 1000.0
|
||||
if self.shift == 1.0:
|
||||
return timestep
|
||||
return self.shift * timestep / (1 + (self.shift - 1) * timestep)
|
||||
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input - model_output * sigma
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
return sigma * noise + (1.0 - sigma) * latent_image
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
"""Wrapper around the core MM-DiT model"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state_dict,
|
||||
shift=1.0,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
prefix = ''
|
||||
# Important configuration values can be quickly determined by checking shapes in the source file
|
||||
# Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
|
||||
patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
|
||||
depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
|
||||
num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
|
||||
pos_embed_max_size = round(math.sqrt(num_patches))
|
||||
adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
|
||||
context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
|
||||
qk_norm = (
|
||||
"rms"
|
||||
if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in state_dict.keys()
|
||||
else None
|
||||
)
|
||||
x_block_self_attn_layers = sorted(
|
||||
[
|
||||
int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1])
|
||||
for key in list(
|
||||
filter(
|
||||
re.compile(".*.x_block.attn2.ln_k.weight").match, state_dict.keys()
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
context_embedder_config = {
|
||||
"target": "torch.nn.Linear",
|
||||
"params": {
|
||||
"in_features": context_shape[1],
|
||||
"out_features": context_shape[0],
|
||||
},
|
||||
}
|
||||
self.diffusion_model = MMDiTX(
|
||||
input_size=None,
|
||||
pos_embed_scaling_factor=None,
|
||||
pos_embed_offset=None,
|
||||
pos_embed_max_size=pos_embed_max_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=16,
|
||||
depth=depth,
|
||||
num_patches=num_patches,
|
||||
adm_in_channels=adm_in_channels,
|
||||
context_embedder_config=context_embedder_config,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn_layers=x_block_self_attn_layers,
|
||||
# device=kwargs['device'],
|
||||
# dtype=kwargs['dtype'],
|
||||
# verbose=kwargs['verbose'],
|
||||
# **kwargs
|
||||
)
|
||||
self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
|
||||
|
||||
def apply_model(self, x, sigma, y=None, *args, **kwargs):
|
||||
dtype = self.get_dtype()
|
||||
timestep = self.model_sampling.timestep(sigma).float()
|
||||
model_output = self.diffusion_model(
|
||||
x.to(dtype), timestep, context=kwargs["context"].to(dtype), y=y.to(dtype)
|
||||
).float()
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.apply_model(*args, **kwargs)
|
||||
|
||||
def get_dtype(self):
|
||||
return self.diffusion_model.dtype
|
||||
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
"""Helper for applying CFG Scaling to diffusion outputs"""
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, x, timestep, cond, uncond, cond_scale):
|
||||
# Run cond and uncond in a batch together
|
||||
batched = self.model.apply_model(
|
||||
torch.cat([x, x]),
|
||||
torch.cat([timestep, timestep]),
|
||||
c_crossattn=torch.cat([cond["c_crossattn"], uncond["c_crossattn"]]),
|
||||
y=torch.cat([cond["y"], uncond["y"]]),
|
||||
)
|
||||
# Then split and apply CFG Scaling
|
||||
pos_out, neg_out = batched.chunk(2)
|
||||
scaled = neg_out + (pos_out - neg_out) * cond_scale
|
||||
return scaled
|
||||
|
||||
|
||||
class SD3LatentFormat:
|
||||
"""Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.5305
|
||||
self.shift_factor = 0.0609
|
||||
|
||||
def process_in(self, latent):
|
||||
return (latent - self.shift_factor) * self.scale_factor
|
||||
|
||||
def process_out(self, latent):
|
||||
return (latent / self.scale_factor) + self.shift_factor
|
||||
|
||||
def decode_latent_to_preview(self, x0):
|
||||
"""Quick RGB approximate preview of sd3 latents"""
|
||||
factors = torch.tensor(
|
||||
[
|
||||
[-0.0645, 0.0177, 0.1052],
|
||||
[0.0028, 0.0312, 0.0650],
|
||||
[0.1848, 0.0762, 0.0360],
|
||||
[0.0944, 0.0360, 0.0889],
|
||||
[0.0897, 0.0506, -0.0364],
|
||||
[-0.0020, 0.1203, 0.0284],
|
||||
[0.0855, 0.0118, 0.0283],
|
||||
[-0.0539, 0.0658, 0.1047],
|
||||
[-0.0057, 0.0116, 0.0700],
|
||||
[-0.0412, 0.0281, -0.0039],
|
||||
[0.1106, 0.1171, 0.1220],
|
||||
[-0.0248, 0.0682, -0.0481],
|
||||
[0.0815, 0.0846, 0.1207],
|
||||
[-0.0120, -0.0055, -0.0867],
|
||||
[-0.0749, -0.0634, -0.0456],
|
||||
[-0.1418, -0.1457, -0.1259],
|
||||
],
|
||||
device="cpu",
|
||||
)
|
||||
latent_image = x0[0].permute(1, 2, 0).cpu() @ factors
|
||||
|
||||
latents_ubyte = (
|
||||
((latent_image + 1) / 2)
|
||||
.clamp(0, 1) # change scale from -1..1 to 0..1
|
||||
.mul(0xFF) # to 0..255
|
||||
.byte()
|
||||
).cpu()
|
||||
|
||||
return Image.fromarray(latents_ubyte.numpy())
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### Samplers
|
||||
#################################################################################################
|
||||
|
||||
|
||||
def append_dims(x, target_dims):
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
|
||||
def to_d(x, sigma, denoised):
|
||||
"""Converts a denoiser output to a Karras ODE derivative."""
|
||||
return (x - denoised) / append_dims(sigma, x.ndim)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.autocast("cuda", dtype=torch.float16)
|
||||
def sample_euler(model, x, sigmas, extra_args=None):
|
||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in tqdm(range(len(sigmas) - 1)):
|
||||
sigma_hat = sigmas[i]
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.autocast("cuda", dtype=torch.float16)
|
||||
def sample_dpmpp_2m(model, x, sigmas, extra_args=None):
|
||||
"""DPM-Solver++(2M)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
old_denoised = None
|
||||
for i in tqdm(range(len(sigmas) - 1)):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
||||
h = t_next - t
|
||||
if old_denoised is None or sigmas[i + 1] == 0:
|
||||
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
|
||||
else:
|
||||
h_last = t - t_fn(sigmas[i - 1])
|
||||
r = h_last / h
|
||||
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
|
||||
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### VAE
|
||||
#################################################################################################
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):
|
||||
return torch.nn.GroupNorm(
|
||||
num_groups=num_groups,
|
||||
num_channels=in_channels,
|
||||
eps=1e-6,
|
||||
affine=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
class ResnetBlock(torch.nn.Module):
|
||||
def __init__(
|
||||
self, *, in_channels, out_channels=None, dtype=torch.float32, device=None
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.norm1 = Normalize(in_channels, dtype=dtype, device=device)
|
||||
self.conv1 = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.norm2 = Normalize(out_channels, dtype=dtype, device=device)
|
||||
self.conv2 = torch.nn.Conv2d(
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
if self.in_channels != self.out_channels:
|
||||
self.nin_shortcut = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = None
|
||||
self.swish = torch.nn.SiLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
hidden = x
|
||||
hidden = self.norm1(hidden)
|
||||
hidden = self.swish(hidden)
|
||||
hidden = self.conv1(hidden)
|
||||
hidden = self.norm2(hidden)
|
||||
hidden = self.swish(hidden)
|
||||
hidden = self.conv2(hidden)
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.nin_shortcut(x)
|
||||
return x + hidden
|
||||
|
||||
|
||||
class AttnBlock(torch.nn.Module):
|
||||
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.norm = Normalize(in_channels, dtype=dtype, device=device)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
hidden = self.norm(x)
|
||||
q = self.q(hidden)
|
||||
k = self.k(hidden)
|
||||
v = self.v(hidden)
|
||||
b, c, h, w = q.shape
|
||||
q, k, v = map(
|
||||
lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
hidden = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v
|
||||
) # scale is dim ** -0.5 per default
|
||||
hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
||||
hidden = self.proj_out(hidden)
|
||||
return x + hidden
|
||||
|
||||
|
||||
class Downsample(torch.nn.Module):
|
||||
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(torch.nn.Module):
|
||||
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class VAEEncoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ch=128,
|
||||
ch_mult=(1, 2, 4, 4),
|
||||
num_res_blocks=2,
|
||||
in_channels=3,
|
||||
z_channels=16,
|
||||
dtype=torch.float32,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = torch.nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = torch.nn.ModuleList()
|
||||
attn = torch.nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
down = torch.nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, dtype=dtype, device=device)
|
||||
self.down.append(down)
|
||||
# middle
|
||||
self.mid = torch.nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, dtype=dtype, device=device
|
||||
)
|
||||
self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, dtype=dtype, device=device
|
||||
)
|
||||
# end
|
||||
self.norm_out = Normalize(block_in, dtype=dtype, device=device)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in,
|
||||
2 * z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.swish = torch.nn.SiLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1])
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = self.swish(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class VAEDecoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ch=128,
|
||||
out_ch=3,
|
||||
ch_mult=(1, 2, 4, 4),
|
||||
num_res_blocks=2,
|
||||
resolution=256,
|
||||
z_channels=16,
|
||||
dtype=torch.float32,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
z_channels,
|
||||
block_in,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
# middle
|
||||
self.mid = torch.nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, dtype=dtype, device=device
|
||||
)
|
||||
self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, dtype=dtype, device=device
|
||||
)
|
||||
# upsampling
|
||||
self.up = torch.nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = torch.nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
up = torch.nn.Module()
|
||||
up.block = block
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, dtype=dtype, device=device)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
# end
|
||||
self.norm_out = Normalize(block_in, dtype=dtype, device=device)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in,
|
||||
out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.swish = torch.nn.SiLU(inplace=True)
|
||||
|
||||
def forward(self, z):
|
||||
# z to block_in
|
||||
hidden = self.conv_in(z)
|
||||
# middle
|
||||
hidden = self.mid.block_1(hidden)
|
||||
hidden = self.mid.attn_1(hidden)
|
||||
hidden = self.mid.block_2(hidden)
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
hidden = self.up[i_level].block[i_block](hidden)
|
||||
if i_level != 0:
|
||||
hidden = self.up[i_level].upsample(hidden)
|
||||
# end
|
||||
hidden = self.norm_out(hidden)
|
||||
hidden = self.swish(hidden)
|
||||
hidden = self.conv_out(hidden)
|
||||
return hidden
|
||||
|
||||
|
||||
class SDVAE(torch.nn.Module):
|
||||
def __init__(self, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.encoder = VAEEncoder(dtype=dtype, device=device)
|
||||
self.decoder = VAEDecoder(dtype=dtype, device=device)
|
||||
|
||||
@torch.autocast("cuda", dtype=torch.float16)
|
||||
def decode(self, latent):
|
||||
return self.decoder(latent)
|
||||
|
||||
@torch.autocast("cuda", dtype=torch.float16)
|
||||
def encode(self, image):
|
||||
hidden = self.encoder(image)
|
||||
mean, logvar = torch.chunk(hidden, 2, dim=1)
|
||||
logvar = torch.clamp(logvar, -30.0, 20.0)
|
||||
std = torch.exp(0.5 * logvar)
|
||||
return mean + std * torch.randn_like(mean)
|
||||
436
modules/models/sd35/sd3_infer.py
Normal file
436
modules/models/sd35/sd3_infer.py
Normal file
@@ -0,0 +1,436 @@
|
||||
# NOTE: Must have folder `models` with the following files:
|
||||
# - `clip_g.safetensors` (openclip bigG, same as SDXL)
|
||||
# - `clip_l.safetensors` (OpenAI CLIP-L, same as SDXL)
|
||||
# - `t5xxl.safetensors` (google T5-v1.1-XXL)
|
||||
# - `sd3_medium.safetensors` (or whichever main MMDiT model file)
|
||||
# Also can have
|
||||
# - `sd3_vae.safetensors` (holds the VAE separately if needed)
|
||||
|
||||
import datetime
|
||||
import math
|
||||
import os
|
||||
|
||||
import fire
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from safetensors import safe_open
|
||||
from tqdm import tqdm
|
||||
|
||||
from modules.models.sd35 import sd3_impls
|
||||
from modules.models.sd35.other_impls import SD3Tokenizer, SDClipModel, SDXLClipG, T5XXLModel
|
||||
from modules.models.sd35.sd3_impls import SDVAE, BaseModel, CFGDenoiser, SD3LatentFormat
|
||||
|
||||
#################################################################################################
|
||||
### Wrappers for model parts
|
||||
#################################################################################################
|
||||
|
||||
|
||||
def load_into(f, model, prefix, device, dtype=None):
|
||||
"""Just a debugging-friendly hack to apply the weights in a safetensors file to the pytorch module."""
|
||||
for key in f.keys():
|
||||
if key.startswith(prefix) and not key.startswith("loss."):
|
||||
path = key[len(prefix) :].split(".")
|
||||
obj = model
|
||||
for p in path:
|
||||
if obj is list:
|
||||
obj = obj[int(p)]
|
||||
else:
|
||||
obj = getattr(obj, p, None)
|
||||
if obj is None:
|
||||
print(
|
||||
f"Skipping key '{key}' in safetensors file as '{p}' does not exist in python model"
|
||||
)
|
||||
break
|
||||
if obj is None:
|
||||
continue
|
||||
try:
|
||||
tensor = f.get_tensor(key).to(device=device)
|
||||
if dtype is not None:
|
||||
tensor = tensor.to(dtype=dtype)
|
||||
obj.requires_grad_(False)
|
||||
obj.set_(tensor)
|
||||
except Exception as e:
|
||||
print(f"Failed to load key '{key}' in safetensors file: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
CLIPG_CONFIG = {
|
||||
"hidden_act": "gelu",
|
||||
"hidden_size": 1280,
|
||||
"intermediate_size": 5120,
|
||||
"num_attention_heads": 20,
|
||||
"num_hidden_layers": 32,
|
||||
}
|
||||
|
||||
|
||||
class ClipG:
|
||||
def __init__(self):
|
||||
with safe_open("models/clip_g.safetensors", framework="pt", device="cpu") as f:
|
||||
self.model = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=torch.float32)
|
||||
load_into(f, self.model.transformer, "", "cpu", torch.float32)
|
||||
|
||||
|
||||
CLIPL_CONFIG = {
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 768,
|
||||
"intermediate_size": 3072,
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12,
|
||||
}
|
||||
|
||||
|
||||
class ClipL:
|
||||
def __init__(self):
|
||||
with safe_open("models/clip_l.safetensors", framework="pt", device="cpu") as f:
|
||||
self.model = SDClipModel(
|
||||
layer="hidden",
|
||||
layer_idx=-2,
|
||||
device="cpu",
|
||||
dtype=torch.float32,
|
||||
layer_norm_hidden_state=False,
|
||||
return_projected_pooled=False,
|
||||
textmodel_json_config=CLIPL_CONFIG,
|
||||
)
|
||||
load_into(f, self.model.transformer, "", "cpu", torch.float32)
|
||||
|
||||
|
||||
T5_CONFIG = {
|
||||
"d_ff": 10240,
|
||||
"d_model": 4096,
|
||||
"num_heads": 64,
|
||||
"num_layers": 24,
|
||||
"vocab_size": 32128,
|
||||
}
|
||||
|
||||
|
||||
class T5XXL:
|
||||
def __init__(self):
|
||||
with safe_open("models/t5xxl.safetensors", framework="pt", device="cpu") as f:
|
||||
self.model = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float32)
|
||||
load_into(f, self.model.transformer, "", "cpu", torch.float32)
|
||||
|
||||
|
||||
class SD3:
|
||||
def __init__(self, model, shift, verbose=False):
|
||||
with safe_open(model, framework="pt", device="cpu") as f:
|
||||
self.model = BaseModel(
|
||||
shift=shift,
|
||||
file=f,
|
||||
prefix="model.diffusion_model.",
|
||||
device="cpu",
|
||||
dtype=torch.float16,
|
||||
verbose=verbose,
|
||||
).eval()
|
||||
load_into(f, self.model, "model.", "cpu", torch.float16)
|
||||
|
||||
|
||||
class VAE:
|
||||
def __init__(self, model):
|
||||
with safe_open(model, framework="pt", device="cpu") as f:
|
||||
self.model = SDVAE(device="cpu", dtype=torch.float16).eval().cpu()
|
||||
prefix = ""
|
||||
if any(k.startswith("first_stage_model.") for k in f.keys()):
|
||||
prefix = "first_stage_model."
|
||||
load_into(f, self.model, prefix, "cpu", torch.float16)
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### Main inference logic
|
||||
#################################################################################################
|
||||
|
||||
|
||||
# Note: Sigma shift value, publicly released models use 3.0
|
||||
SHIFT = 3.0
|
||||
# Naturally, adjust to the width/height of the model you have
|
||||
WIDTH = 1024
|
||||
HEIGHT = 1024
|
||||
# Pick your prompt
|
||||
PROMPT = "a photo of a cat"
|
||||
# Most models prefer the range of 4-5, but still work well around 7
|
||||
CFG_SCALE = 4.5
|
||||
# Different models want different step counts but most will be good at 50, albeit that's slow to run
|
||||
# sd3_medium is quite decent at 28 steps
|
||||
STEPS = 40
|
||||
# Seed
|
||||
SEED = 23
|
||||
# SEEDTYPE = "fixed"
|
||||
SEEDTYPE = "rand"
|
||||
# SEEDTYPE = "roll"
|
||||
# Actual model file path
|
||||
# MODEL = "models/sd3_medium.safetensors"
|
||||
# MODEL = "models/sd3.5_large_turbo.safetensors"
|
||||
MODEL = "models/sd3.5_large.safetensors"
|
||||
# VAE model file path, or set None to use the same model file
|
||||
VAEFile = None # "models/sd3_vae.safetensors"
|
||||
# Optional init image file path
|
||||
INIT_IMAGE = None
|
||||
# If init_image is given, this is the percentage of denoising steps to run (1.0 = full denoise, 0.0 = no denoise at all)
|
||||
DENOISE = 0.6
|
||||
# Output file path
|
||||
OUTDIR = "outputs"
|
||||
# SAMPLER
|
||||
# SAMPLER = "euler"
|
||||
SAMPLER = "dpmpp_2m"
|
||||
|
||||
|
||||
class SD3Inferencer:
|
||||
def print(self, txt):
|
||||
if self.verbose:
|
||||
print(txt)
|
||||
|
||||
def load(self, model=MODEL, vae=VAEFile, shift=SHIFT, verbose=False):
|
||||
self.verbose = verbose
|
||||
print("Loading tokenizers...")
|
||||
# NOTE: if you need a reference impl for a high performance CLIP tokenizer instead of just using the HF transformers one,
|
||||
# check https://github.com/Stability-AI/StableSwarmUI/blob/master/src/Utils/CliplikeTokenizer.cs
|
||||
# (T5 tokenizer is different though)
|
||||
self.tokenizer = SD3Tokenizer()
|
||||
print("Loading OpenAI CLIP L...")
|
||||
self.clip_l = ClipL()
|
||||
print("Loading OpenCLIP bigG...")
|
||||
self.clip_g = ClipG()
|
||||
print("Loading Google T5-v1-XXL...")
|
||||
self.t5xxl = T5XXL()
|
||||
print(f"Loading SD3 model {os.path.basename(model)}...")
|
||||
self.sd3 = SD3(model, shift, verbose)
|
||||
print("Loading VAE model...")
|
||||
self.vae = VAE(vae or model)
|
||||
print("Models loaded.")
|
||||
|
||||
def get_empty_latent(self, width, height):
|
||||
self.print("Prep an empty latent...")
|
||||
return torch.ones(1, 16, height // 8, width // 8, device="cpu") * 0.0609
|
||||
|
||||
def get_sigmas(self, sampling, steps):
|
||||
start = sampling.timestep(sampling.sigma_max)
|
||||
end = sampling.timestep(sampling.sigma_min)
|
||||
timesteps = torch.linspace(start, end, steps)
|
||||
sigs = []
|
||||
for x in range(len(timesteps)):
|
||||
ts = timesteps[x]
|
||||
sigs.append(sampling.sigma(ts))
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def get_noise(self, seed, latent):
|
||||
generator = torch.manual_seed(seed)
|
||||
self.print(
|
||||
f"dtype = {latent.dtype}, layout = {latent.layout}, device = {latent.device}"
|
||||
)
|
||||
return torch.randn(
|
||||
latent.size(),
|
||||
dtype=torch.float32,
|
||||
layout=latent.layout,
|
||||
generator=generator,
|
||||
device="cpu",
|
||||
).to(latent.dtype)
|
||||
|
||||
def get_cond(self, prompt):
|
||||
self.print("Encode prompt...")
|
||||
tokens = self.tokenizer.tokenize_with_weights(prompt)
|
||||
l_out, l_pooled = self.clip_l.model.encode_token_weights(tokens["l"])
|
||||
g_out, g_pooled = self.clip_g.model.encode_token_weights(tokens["g"])
|
||||
t5_out, t5_pooled = self.t5xxl.model.encode_token_weights(tokens["t5xxl"])
|
||||
lg_out = torch.cat([l_out, g_out], dim=-1)
|
||||
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
||||
return torch.cat([lg_out, t5_out], dim=-2), torch.cat(
|
||||
(l_pooled, g_pooled), dim=-1
|
||||
)
|
||||
|
||||
def max_denoise(self, sigmas):
|
||||
max_sigma = float(self.sd3.model.model_sampling.sigma_max)
|
||||
sigma = float(sigmas[0])
|
||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||
|
||||
def fix_cond(self, cond):
|
||||
cond, pooled = (cond[0].half().cuda(), cond[1].half().cuda())
|
||||
return {"c_crossattn": cond, "y": pooled}
|
||||
|
||||
def do_sampling(
|
||||
self,
|
||||
latent,
|
||||
seed,
|
||||
conditioning,
|
||||
neg_cond,
|
||||
steps,
|
||||
cfg_scale,
|
||||
sampler="dpmpp_2m",
|
||||
denoise=1.0,
|
||||
) -> torch.Tensor:
|
||||
self.print("Sampling...")
|
||||
latent = latent.half().cuda()
|
||||
self.sd3.model = self.sd3.model.cuda()
|
||||
noise = self.get_noise(seed, latent).cuda()
|
||||
sigmas = self.get_sigmas(self.sd3.model.model_sampling, steps).cuda()
|
||||
sigmas = sigmas[int(steps * (1 - denoise)) :]
|
||||
conditioning = self.fix_cond(conditioning)
|
||||
neg_cond = self.fix_cond(neg_cond)
|
||||
extra_args = {"cond": conditioning, "uncond": neg_cond, "cond_scale": cfg_scale}
|
||||
noise_scaled = self.sd3.model.model_sampling.noise_scaling(
|
||||
sigmas[0], noise, latent, self.max_denoise(sigmas)
|
||||
)
|
||||
sample_fn = getattr(sd3_impls, f"sample_{sampler}")
|
||||
latent = sample_fn(
|
||||
CFGDenoiser(self.sd3.model), noise_scaled, sigmas, extra_args=extra_args
|
||||
)
|
||||
latent = SD3LatentFormat().process_out(latent)
|
||||
self.sd3.model = self.sd3.model.cpu()
|
||||
self.print("Sampling done")
|
||||
return latent
|
||||
|
||||
def vae_encode(self, image) -> torch.Tensor:
|
||||
self.print("Encoding image to latent...")
|
||||
image = image.convert("RGB")
|
||||
image_np = np.array(image).astype(np.float32) / 255.0
|
||||
image_np = np.moveaxis(image_np, 2, 0)
|
||||
batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0)
|
||||
image_torch = torch.from_numpy(batch_images)
|
||||
image_torch = 2.0 * image_torch - 1.0
|
||||
image_torch = image_torch.cuda()
|
||||
self.vae.model = self.vae.model.cuda()
|
||||
latent = self.vae.model.encode(image_torch).cpu()
|
||||
self.vae.model = self.vae.model.cpu()
|
||||
self.print("Encoded")
|
||||
return latent
|
||||
|
||||
def vae_decode(self, latent) -> Image.Image:
|
||||
self.print("Decoding latent to image...")
|
||||
latent = latent.cuda()
|
||||
self.vae.model = self.vae.model.cuda()
|
||||
image = self.vae.model.decode(latent)
|
||||
image = image.float()
|
||||
self.vae.model = self.vae.model.cpu()
|
||||
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
|
||||
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
|
||||
decoded_np = decoded_np.astype(np.uint8)
|
||||
out_image = Image.fromarray(decoded_np)
|
||||
self.print("Decoded")
|
||||
return out_image
|
||||
|
||||
def gen_image(
|
||||
self,
|
||||
prompts=[PROMPT],
|
||||
width=WIDTH,
|
||||
height=HEIGHT,
|
||||
steps=STEPS,
|
||||
cfg_scale=CFG_SCALE,
|
||||
sampler=SAMPLER,
|
||||
seed=SEED,
|
||||
seed_type=SEEDTYPE,
|
||||
out_dir=OUTDIR,
|
||||
init_image=INIT_IMAGE,
|
||||
denoise=DENOISE,
|
||||
):
|
||||
latent = self.get_empty_latent(width, height)
|
||||
if init_image:
|
||||
image_data = Image.open(init_image)
|
||||
image_data = image_data.resize((width, height), Image.LANCZOS)
|
||||
latent = self.vae_encode(image_data)
|
||||
latent = SD3LatentFormat().process_in(latent)
|
||||
neg_cond = self.get_cond("")
|
||||
seed_num = None
|
||||
pbar = tqdm(enumerate(prompts), total=len(prompts), position=0, leave=True)
|
||||
for i, prompt in pbar:
|
||||
if seed_type == "roll":
|
||||
seed_num = seed if seed_num is None else seed_num + 1
|
||||
elif seed_type == "rand":
|
||||
seed_num = torch.randint(0, 100000, (1,)).item()
|
||||
else: # fixed
|
||||
seed_num = seed
|
||||
conditioning = self.get_cond(prompt)
|
||||
sampled_latent = self.do_sampling(
|
||||
latent,
|
||||
seed_num,
|
||||
conditioning,
|
||||
neg_cond,
|
||||
steps,
|
||||
cfg_scale,
|
||||
sampler,
|
||||
denoise if init_image else 1.0,
|
||||
)
|
||||
image = self.vae_decode(sampled_latent)
|
||||
save_path = os.path.join(out_dir, f"{i:06d}.png")
|
||||
self.print(f"Will save to {save_path}")
|
||||
image.save(save_path)
|
||||
self.print("Done")
|
||||
|
||||
|
||||
CONFIGS = {
|
||||
"sd3_medium": {
|
||||
"shift": 1.0,
|
||||
"cfg": 5.0,
|
||||
"steps": 50,
|
||||
"sampler": "dpmpp_2m",
|
||||
},
|
||||
"sd3.5_large": {
|
||||
"shift": 3.0,
|
||||
"cfg": 4.5,
|
||||
"steps": 40,
|
||||
"sampler": "dpmpp_2m",
|
||||
},
|
||||
"sd3.5_large_turbo": {"shift": 3.0, "cfg": 1.0, "steps": 4, "sampler": "euler"},
|
||||
}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main(
|
||||
prompt=PROMPT,
|
||||
model=MODEL,
|
||||
out_dir=OUTDIR,
|
||||
postfix=None,
|
||||
seed=SEED,
|
||||
seed_type=SEEDTYPE,
|
||||
sampler=None,
|
||||
steps=None,
|
||||
cfg=None,
|
||||
shift=None,
|
||||
width=WIDTH,
|
||||
height=HEIGHT,
|
||||
vae=VAEFile,
|
||||
init_image=INIT_IMAGE,
|
||||
denoise=DENOISE,
|
||||
verbose=False,
|
||||
):
|
||||
steps = steps or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["steps"]
|
||||
cfg = cfg or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["cfg"]
|
||||
shift = shift or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["shift"]
|
||||
sampler = (
|
||||
sampler or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["sampler"]
|
||||
)
|
||||
|
||||
inferencer = SD3Inferencer()
|
||||
inferencer.load(model, vae, shift, verbose)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
if os.path.splitext(prompt)[-1] == ".txt":
|
||||
with open(prompt, "r") as f:
|
||||
prompts = [l.strip() for l in f.readlines()]
|
||||
else:
|
||||
prompts = [prompt]
|
||||
|
||||
out_dir = os.path.join(
|
||||
out_dir,
|
||||
os.path.splitext(os.path.basename(model))[0],
|
||||
os.path.splitext(os.path.basename(prompt))[0][:50]
|
||||
+ (postfix or datetime.datetime.now().strftime("_%Y-%m-%dT%H-%M-%S")),
|
||||
)
|
||||
print(f"Saving images to {out_dir}")
|
||||
os.makedirs(out_dir, exist_ok=False)
|
||||
|
||||
inferencer.gen_image(
|
||||
prompts,
|
||||
width,
|
||||
height,
|
||||
steps,
|
||||
cfg,
|
||||
sampler,
|
||||
seed,
|
||||
seed_type,
|
||||
out_dir,
|
||||
init_image,
|
||||
denoise,
|
||||
)
|
||||
|
||||
|
||||
fire.Fire(main)
|
||||
Reference in New Issue
Block a user