mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-29 02:31:16 +00:00
Merge pull request #2183 from graemeniedermayer/sd35_integration
sd3.5 integration (naive)
This commit is contained in:
137
backend/diffusion_engine/sd35.py
Normal file
137
backend/diffusion_engine/sd35.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from huggingface_guess import model_list
|
||||||
|
# from huggingface_guess.latent import SD3
|
||||||
|
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
|
||||||
|
from backend.patcher.clip import CLIP
|
||||||
|
from backend.patcher.vae import VAE
|
||||||
|
from backend.patcher.unet import UnetPatcher
|
||||||
|
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
|
||||||
|
from backend.text_processing.t5_engine import T5TextProcessingEngine
|
||||||
|
from backend.args import dynamic_args
|
||||||
|
from backend import memory_management
|
||||||
|
from backend.modules.k_prediction import PredictionDiscreteFlow
|
||||||
|
|
||||||
|
class StableDiffusion3(ForgeDiffusionEngine):
|
||||||
|
matched_guesses = [model_list.SD35]
|
||||||
|
|
||||||
|
def __init__(self, estimated_config, huggingface_components):
|
||||||
|
super().__init__(estimated_config, huggingface_components)
|
||||||
|
|
||||||
|
clip = CLIP(
|
||||||
|
model_dict={
|
||||||
|
'clip_l': huggingface_components['text_encoder'],
|
||||||
|
'clip_g': huggingface_components['text_encoder_2'],
|
||||||
|
't5xxl': huggingface_components['text_encoder_3']
|
||||||
|
},
|
||||||
|
tokenizer_dict={
|
||||||
|
'clip_l': huggingface_components['tokenizer'],
|
||||||
|
'clip_g': huggingface_components['tokenizer_2'],
|
||||||
|
't5xxl': huggingface_components['tokenizer_3']
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
k_predictor = PredictionDiscreteFlow( shift=3.0)
|
||||||
|
|
||||||
|
vae = VAE(model=huggingface_components['vae'])
|
||||||
|
|
||||||
|
unet = UnetPatcher.from_model(
|
||||||
|
model=huggingface_components['transformer'],
|
||||||
|
diffusers_scheduler= None,
|
||||||
|
k_predictor=k_predictor,
|
||||||
|
config=estimated_config
|
||||||
|
)
|
||||||
|
|
||||||
|
self.text_processing_engine_l = ClassicTextProcessingEngine(
|
||||||
|
text_encoder=clip.cond_stage_model.clip_l,
|
||||||
|
tokenizer=clip.tokenizer.clip_l,
|
||||||
|
embedding_dir=dynamic_args['embedding_dir'],
|
||||||
|
embedding_key='clip_l',
|
||||||
|
embedding_expected_shape=768,
|
||||||
|
emphasis_name=dynamic_args['emphasis_name'],
|
||||||
|
text_projection=True,
|
||||||
|
minimal_clip_skip=1,
|
||||||
|
clip_skip=1,
|
||||||
|
return_pooled=True,
|
||||||
|
final_layer_norm=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.text_processing_engine_g = ClassicTextProcessingEngine(
|
||||||
|
text_encoder=clip.cond_stage_model.clip_g,
|
||||||
|
tokenizer=clip.tokenizer.clip_g,
|
||||||
|
embedding_dir=dynamic_args['embedding_dir'],
|
||||||
|
embedding_key='clip_g',
|
||||||
|
embedding_expected_shape=1280,
|
||||||
|
emphasis_name=dynamic_args['emphasis_name'],
|
||||||
|
text_projection=True,
|
||||||
|
minimal_clip_skip=1,
|
||||||
|
clip_skip=1,
|
||||||
|
return_pooled=True,
|
||||||
|
final_layer_norm=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.text_processing_engine_t5 = T5TextProcessingEngine(
|
||||||
|
text_encoder=clip.cond_stage_model.t5xxl,
|
||||||
|
tokenizer=clip.tokenizer.t5xxl,
|
||||||
|
emphasis_name=dynamic_args['emphasis_name'],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
|
||||||
|
self.forge_objects_original = self.forge_objects.shallow_copy()
|
||||||
|
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
|
||||||
|
|
||||||
|
# WebUI Legacy
|
||||||
|
self.is_sd3 = True
|
||||||
|
|
||||||
|
def set_clip_skip(self, clip_skip):
|
||||||
|
self.text_processing_engine_l.clip_skip = clip_skip
|
||||||
|
self.text_processing_engine_g.clip_skip = clip_skip
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def get_learned_conditioning(self, prompt: list[str]):
|
||||||
|
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
|
||||||
|
|
||||||
|
cond_g, g_pooled = self.text_processing_engine_g(prompt)
|
||||||
|
cond_l, l_pooled = self.text_processing_engine_l(prompt)
|
||||||
|
# if enabled?
|
||||||
|
cond_t5 = self.text_processing_engine_t5(prompt)
|
||||||
|
|
||||||
|
is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
|
||||||
|
|
||||||
|
force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in prompt)
|
||||||
|
|
||||||
|
if force_zero_negative_prompt:
|
||||||
|
l_pooled = torch.zeros_like(l_pooled)
|
||||||
|
g_pooled = torch.zeros_like(g_pooled)
|
||||||
|
cond_l = torch.zeros_like(cond_l)
|
||||||
|
cond_g = torch.zeros_like(cond_g)
|
||||||
|
cond_t5 = torch.zeros_like(cond_t5)
|
||||||
|
|
||||||
|
cond_lg = torch.cat([cond_l, cond_g], dim=-1)
|
||||||
|
cond_lg = torch.nn.functional.pad(cond_lg, (0, 4096 - cond_lg.shape[-1]))
|
||||||
|
|
||||||
|
cond = dict(
|
||||||
|
crossattn=torch.cat([cond_lg, cond_t5], dim=-2),
|
||||||
|
vector=torch.cat([l_pooled, g_pooled], dim=-1),
|
||||||
|
)
|
||||||
|
|
||||||
|
return cond
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def get_prompt_lengths_on_ui(self, prompt):
|
||||||
|
token_count = len(self.text_processing_engine_t5.tokenize([prompt])[0])
|
||||||
|
return token_count, max(255, token_count)
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def encode_first_stage(self, x):
|
||||||
|
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
|
||||||
|
sample = self.forge_objects.vae.first_stage_model.process_in(sample)
|
||||||
|
return sample.to(x)
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def decode_first_stage(self, x):
|
||||||
|
sample = self.forge_objects.vae.first_stage_model.process_out(x)
|
||||||
|
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
||||||
|
|
||||||
|
return sample.to(x)
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
{
|
||||||
|
"_class_name": "StableDiffusion3Pipeline",
|
||||||
|
"_diffusers_version": "0.30.3.dev0",
|
||||||
|
"scheduler": [
|
||||||
|
"diffusers",
|
||||||
|
"FlowMatchEulerDiscreteScheduler"
|
||||||
|
],
|
||||||
|
"text_encoder": [
|
||||||
|
"transformers",
|
||||||
|
"CLIPTextModelWithProjection"
|
||||||
|
],
|
||||||
|
"text_encoder_2": [
|
||||||
|
"transformers",
|
||||||
|
"CLIPTextModelWithProjection"
|
||||||
|
],
|
||||||
|
"text_encoder_3": [
|
||||||
|
"transformers",
|
||||||
|
"T5EncoderModel"
|
||||||
|
],
|
||||||
|
"tokenizer": [
|
||||||
|
"transformers",
|
||||||
|
"CLIPTokenizer"
|
||||||
|
],
|
||||||
|
"tokenizer_2": [
|
||||||
|
"transformers",
|
||||||
|
"CLIPTokenizer"
|
||||||
|
],
|
||||||
|
"tokenizer_3": [
|
||||||
|
"transformers",
|
||||||
|
"T5TokenizerFast"
|
||||||
|
],
|
||||||
|
"transformer": [
|
||||||
|
"diffusers",
|
||||||
|
"SD3Transformer2DModel"
|
||||||
|
],
|
||||||
|
"vae": [
|
||||||
|
"diffusers",
|
||||||
|
"AutoencoderKL"
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"_class_name": "FlowMatchEulerDiscreteScheduler",
|
||||||
|
"_diffusers_version": "0.29.0.dev0",
|
||||||
|
"num_train_timesteps": 1000,
|
||||||
|
"shift": 3.0
|
||||||
|
}
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
{
|
||||||
|
"architectures": [
|
||||||
|
"CLIPTextModelWithProjection"
|
||||||
|
],
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"bos_token_id": 0,
|
||||||
|
"dropout": 0.0,
|
||||||
|
"eos_token_id": 2,
|
||||||
|
"hidden_act": "quick_gelu",
|
||||||
|
"hidden_size": 768,
|
||||||
|
"initializer_factor": 1.0,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 3072,
|
||||||
|
"layer_norm_eps": 1e-05,
|
||||||
|
"max_position_embeddings": 77,
|
||||||
|
"model_type": "clip_text_model",
|
||||||
|
"num_attention_heads": 12,
|
||||||
|
"num_hidden_layers": 12,
|
||||||
|
"pad_token_id": 1,
|
||||||
|
"projection_dim": 768,
|
||||||
|
"torch_dtype": "float16",
|
||||||
|
"transformers_version": "4.41.2",
|
||||||
|
"vocab_size": 49408
|
||||||
|
}
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
{
|
||||||
|
"architectures": [
|
||||||
|
"CLIPTextModelWithProjection"
|
||||||
|
],
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"bos_token_id": 0,
|
||||||
|
"dropout": 0.0,
|
||||||
|
"eos_token_id": 2,
|
||||||
|
"hidden_act": "gelu",
|
||||||
|
"hidden_size": 1280,
|
||||||
|
"initializer_factor": 1.0,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 5120,
|
||||||
|
"layer_norm_eps": 1e-05,
|
||||||
|
"max_position_embeddings": 77,
|
||||||
|
"model_type": "clip_text_model",
|
||||||
|
"num_attention_heads": 20,
|
||||||
|
"num_hidden_layers": 32,
|
||||||
|
"pad_token_id": 1,
|
||||||
|
"projection_dim": 1280,
|
||||||
|
"torch_dtype": "float16",
|
||||||
|
"transformers_version": "4.41.2",
|
||||||
|
"vocab_size": 49408
|
||||||
|
}
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
{
|
||||||
|
"architectures": [
|
||||||
|
"T5EncoderModel"
|
||||||
|
],
|
||||||
|
"classifier_dropout": 0.0,
|
||||||
|
"d_ff": 10240,
|
||||||
|
"d_kv": 64,
|
||||||
|
"d_model": 4096,
|
||||||
|
"decoder_start_token_id": 0,
|
||||||
|
"dense_act_fn": "gelu_new",
|
||||||
|
"dropout_rate": 0.1,
|
||||||
|
"eos_token_id": 1,
|
||||||
|
"feed_forward_proj": "gated-gelu",
|
||||||
|
"initializer_factor": 1.0,
|
||||||
|
"is_encoder_decoder": true,
|
||||||
|
"is_gated_act": true,
|
||||||
|
"layer_norm_epsilon": 1e-06,
|
||||||
|
"model_type": "t5",
|
||||||
|
"num_decoder_layers": 24,
|
||||||
|
"num_heads": 64,
|
||||||
|
"num_layers": 24,
|
||||||
|
"output_past": true,
|
||||||
|
"pad_token_id": 0,
|
||||||
|
"relative_attention_max_distance": 128,
|
||||||
|
"relative_attention_num_buckets": 32,
|
||||||
|
"tie_word_embeddings": false,
|
||||||
|
"torch_dtype": "float16",
|
||||||
|
"transformers_version": "4.41.2",
|
||||||
|
"use_cache": true,
|
||||||
|
"vocab_size": 32128
|
||||||
|
}
|
||||||
@@ -0,0 +1,226 @@
|
|||||||
|
{
|
||||||
|
"metadata": {
|
||||||
|
"total_size": 9524621312
|
||||||
|
},
|
||||||
|
"weight_map": {
|
||||||
|
"encoder.block.0.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.0.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.0.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.0.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.0.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.0.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.0.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.0.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.0.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.1.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.1.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.1.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.1.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.1.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.1.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.1.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.1.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.1.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.10.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.10.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.10.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.10.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.10.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.10.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.10.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.10.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.10.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.11.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.11.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.11.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.11.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.11.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.11.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.11.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.11.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.11.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.12.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.12.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.12.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.12.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.12.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.12.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.12.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.12.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.12.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.13.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.13.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.13.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.13.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.13.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.13.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.13.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.13.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.13.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.14.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.14.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.14.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.14.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.14.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.14.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.14.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.14.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.14.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.15.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.15.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.15.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.15.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.15.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.15.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.15.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.15.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.15.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.16.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.16.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.16.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.16.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.16.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.16.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.16.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.16.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.16.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.17.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.17.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.17.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.17.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.17.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.17.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.17.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.17.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.17.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.18.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.18.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.18.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.18.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.18.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.18.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.18.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.18.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.18.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.19.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.19.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.19.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.19.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.19.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.19.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.19.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.19.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.19.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.2.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.2.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.2.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.2.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.2.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.2.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.2.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.2.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.2.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.20.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.20.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.20.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.20.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.20.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.20.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.20.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.20.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.20.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.21.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.21.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.21.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.21.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.21.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.21.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.21.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.21.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.21.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.22.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.22.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.22.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.22.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.22.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.22.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.22.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.22.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.22.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.23.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.23.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.23.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.23.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.23.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.23.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.23.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.23.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.23.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"encoder.block.3.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.3.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.3.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.3.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.3.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.3.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.3.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.3.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.3.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.4.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.4.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.4.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.4.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.4.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.4.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.4.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.4.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.4.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.5.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.5.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.5.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.5.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.5.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.5.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.5.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.5.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.5.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.6.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.6.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.6.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.6.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.6.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.6.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.6.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.6.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.6.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.7.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.7.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.7.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.7.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.7.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.7.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.7.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.7.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.7.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.8.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.8.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.8.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.8.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.8.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.8.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.8.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.8.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.8.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.9.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.9.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.9.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.9.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.9.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.9.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.9.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.9.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.block.9.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"encoder.final_layer_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"shared.weight": "model-00001-of-00002.safetensors"
|
||||||
|
}
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,30 @@
|
|||||||
|
{
|
||||||
|
"bos_token": {
|
||||||
|
"content": "<|startoftext|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": true,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
},
|
||||||
|
"eos_token": {
|
||||||
|
"content": "<|endoftext|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
},
|
||||||
|
"pad_token": {
|
||||||
|
"content": "<|endoftext|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
},
|
||||||
|
"unk_token": {
|
||||||
|
"content": "<|endoftext|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
{
|
||||||
|
"add_prefix_space": false,
|
||||||
|
"added_tokens_decoder": {
|
||||||
|
"49406": {
|
||||||
|
"content": "<|startoftext|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": true,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"49407": {
|
||||||
|
"content": "<|endoftext|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"bos_token": "<|startoftext|>",
|
||||||
|
"clean_up_tokenization_spaces": true,
|
||||||
|
"do_lower_case": true,
|
||||||
|
"eos_token": "<|endoftext|>",
|
||||||
|
"errors": "replace",
|
||||||
|
"model_max_length": 77,
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
"tokenizer_class": "CLIPTokenizer",
|
||||||
|
"unk_token": "<|endoftext|>"
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,30 @@
|
|||||||
|
{
|
||||||
|
"bos_token": {
|
||||||
|
"content": "<|startoftext|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": true,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
},
|
||||||
|
"eos_token": {
|
||||||
|
"content": "<|endoftext|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
},
|
||||||
|
"pad_token": {
|
||||||
|
"content": "!",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
},
|
||||||
|
"unk_token": {
|
||||||
|
"content": "<|endoftext|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
{
|
||||||
|
"add_prefix_space": false,
|
||||||
|
"added_tokens_decoder": {
|
||||||
|
"0": {
|
||||||
|
"content": "!",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"49406": {
|
||||||
|
"content": "<|startoftext|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": true,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"49407": {
|
||||||
|
"content": "<|endoftext|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"bos_token": "<|startoftext|>",
|
||||||
|
"clean_up_tokenization_spaces": true,
|
||||||
|
"do_lower_case": true,
|
||||||
|
"eos_token": "<|endoftext|>",
|
||||||
|
"errors": "replace",
|
||||||
|
"model_max_length": 77,
|
||||||
|
"pad_token": "!",
|
||||||
|
"tokenizer_class": "CLIPTokenizer",
|
||||||
|
"unk_token": "<|endoftext|>"
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,125 @@
|
|||||||
|
{
|
||||||
|
"additional_special_tokens": [
|
||||||
|
"<extra_id_0>",
|
||||||
|
"<extra_id_1>",
|
||||||
|
"<extra_id_2>",
|
||||||
|
"<extra_id_3>",
|
||||||
|
"<extra_id_4>",
|
||||||
|
"<extra_id_5>",
|
||||||
|
"<extra_id_6>",
|
||||||
|
"<extra_id_7>",
|
||||||
|
"<extra_id_8>",
|
||||||
|
"<extra_id_9>",
|
||||||
|
"<extra_id_10>",
|
||||||
|
"<extra_id_11>",
|
||||||
|
"<extra_id_12>",
|
||||||
|
"<extra_id_13>",
|
||||||
|
"<extra_id_14>",
|
||||||
|
"<extra_id_15>",
|
||||||
|
"<extra_id_16>",
|
||||||
|
"<extra_id_17>",
|
||||||
|
"<extra_id_18>",
|
||||||
|
"<extra_id_19>",
|
||||||
|
"<extra_id_20>",
|
||||||
|
"<extra_id_21>",
|
||||||
|
"<extra_id_22>",
|
||||||
|
"<extra_id_23>",
|
||||||
|
"<extra_id_24>",
|
||||||
|
"<extra_id_25>",
|
||||||
|
"<extra_id_26>",
|
||||||
|
"<extra_id_27>",
|
||||||
|
"<extra_id_28>",
|
||||||
|
"<extra_id_29>",
|
||||||
|
"<extra_id_30>",
|
||||||
|
"<extra_id_31>",
|
||||||
|
"<extra_id_32>",
|
||||||
|
"<extra_id_33>",
|
||||||
|
"<extra_id_34>",
|
||||||
|
"<extra_id_35>",
|
||||||
|
"<extra_id_36>",
|
||||||
|
"<extra_id_37>",
|
||||||
|
"<extra_id_38>",
|
||||||
|
"<extra_id_39>",
|
||||||
|
"<extra_id_40>",
|
||||||
|
"<extra_id_41>",
|
||||||
|
"<extra_id_42>",
|
||||||
|
"<extra_id_43>",
|
||||||
|
"<extra_id_44>",
|
||||||
|
"<extra_id_45>",
|
||||||
|
"<extra_id_46>",
|
||||||
|
"<extra_id_47>",
|
||||||
|
"<extra_id_48>",
|
||||||
|
"<extra_id_49>",
|
||||||
|
"<extra_id_50>",
|
||||||
|
"<extra_id_51>",
|
||||||
|
"<extra_id_52>",
|
||||||
|
"<extra_id_53>",
|
||||||
|
"<extra_id_54>",
|
||||||
|
"<extra_id_55>",
|
||||||
|
"<extra_id_56>",
|
||||||
|
"<extra_id_57>",
|
||||||
|
"<extra_id_58>",
|
||||||
|
"<extra_id_59>",
|
||||||
|
"<extra_id_60>",
|
||||||
|
"<extra_id_61>",
|
||||||
|
"<extra_id_62>",
|
||||||
|
"<extra_id_63>",
|
||||||
|
"<extra_id_64>",
|
||||||
|
"<extra_id_65>",
|
||||||
|
"<extra_id_66>",
|
||||||
|
"<extra_id_67>",
|
||||||
|
"<extra_id_68>",
|
||||||
|
"<extra_id_69>",
|
||||||
|
"<extra_id_70>",
|
||||||
|
"<extra_id_71>",
|
||||||
|
"<extra_id_72>",
|
||||||
|
"<extra_id_73>",
|
||||||
|
"<extra_id_74>",
|
||||||
|
"<extra_id_75>",
|
||||||
|
"<extra_id_76>",
|
||||||
|
"<extra_id_77>",
|
||||||
|
"<extra_id_78>",
|
||||||
|
"<extra_id_79>",
|
||||||
|
"<extra_id_80>",
|
||||||
|
"<extra_id_81>",
|
||||||
|
"<extra_id_82>",
|
||||||
|
"<extra_id_83>",
|
||||||
|
"<extra_id_84>",
|
||||||
|
"<extra_id_85>",
|
||||||
|
"<extra_id_86>",
|
||||||
|
"<extra_id_87>",
|
||||||
|
"<extra_id_88>",
|
||||||
|
"<extra_id_89>",
|
||||||
|
"<extra_id_90>",
|
||||||
|
"<extra_id_91>",
|
||||||
|
"<extra_id_92>",
|
||||||
|
"<extra_id_93>",
|
||||||
|
"<extra_id_94>",
|
||||||
|
"<extra_id_95>",
|
||||||
|
"<extra_id_96>",
|
||||||
|
"<extra_id_97>",
|
||||||
|
"<extra_id_98>",
|
||||||
|
"<extra_id_99>"
|
||||||
|
],
|
||||||
|
"eos_token": {
|
||||||
|
"content": "</s>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
},
|
||||||
|
"pad_token": {
|
||||||
|
"content": "<pad>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
},
|
||||||
|
"unk_token": {
|
||||||
|
"content": "<unk>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
}
|
||||||
|
}
|
||||||
Binary file not shown.
File diff suppressed because one or more lines are too long
@@ -0,0 +1,940 @@
|
|||||||
|
{
|
||||||
|
"add_prefix_space": true,
|
||||||
|
"added_tokens_decoder": {
|
||||||
|
"0": {
|
||||||
|
"content": "<pad>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"1": {
|
||||||
|
"content": "</s>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"content": "<unk>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32000": {
|
||||||
|
"content": "<extra_id_99>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32001": {
|
||||||
|
"content": "<extra_id_98>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32002": {
|
||||||
|
"content": "<extra_id_97>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32003": {
|
||||||
|
"content": "<extra_id_96>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32004": {
|
||||||
|
"content": "<extra_id_95>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32005": {
|
||||||
|
"content": "<extra_id_94>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32006": {
|
||||||
|
"content": "<extra_id_93>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32007": {
|
||||||
|
"content": "<extra_id_92>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32008": {
|
||||||
|
"content": "<extra_id_91>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32009": {
|
||||||
|
"content": "<extra_id_90>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32010": {
|
||||||
|
"content": "<extra_id_89>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32011": {
|
||||||
|
"content": "<extra_id_88>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32012": {
|
||||||
|
"content": "<extra_id_87>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32013": {
|
||||||
|
"content": "<extra_id_86>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32014": {
|
||||||
|
"content": "<extra_id_85>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32015": {
|
||||||
|
"content": "<extra_id_84>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32016": {
|
||||||
|
"content": "<extra_id_83>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32017": {
|
||||||
|
"content": "<extra_id_82>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32018": {
|
||||||
|
"content": "<extra_id_81>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32019": {
|
||||||
|
"content": "<extra_id_80>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32020": {
|
||||||
|
"content": "<extra_id_79>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32021": {
|
||||||
|
"content": "<extra_id_78>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32022": {
|
||||||
|
"content": "<extra_id_77>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32023": {
|
||||||
|
"content": "<extra_id_76>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32024": {
|
||||||
|
"content": "<extra_id_75>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32025": {
|
||||||
|
"content": "<extra_id_74>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32026": {
|
||||||
|
"content": "<extra_id_73>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32027": {
|
||||||
|
"content": "<extra_id_72>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32028": {
|
||||||
|
"content": "<extra_id_71>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32029": {
|
||||||
|
"content": "<extra_id_70>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32030": {
|
||||||
|
"content": "<extra_id_69>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32031": {
|
||||||
|
"content": "<extra_id_68>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32032": {
|
||||||
|
"content": "<extra_id_67>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32033": {
|
||||||
|
"content": "<extra_id_66>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32034": {
|
||||||
|
"content": "<extra_id_65>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32035": {
|
||||||
|
"content": "<extra_id_64>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32036": {
|
||||||
|
"content": "<extra_id_63>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32037": {
|
||||||
|
"content": "<extra_id_62>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32038": {
|
||||||
|
"content": "<extra_id_61>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32039": {
|
||||||
|
"content": "<extra_id_60>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32040": {
|
||||||
|
"content": "<extra_id_59>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32041": {
|
||||||
|
"content": "<extra_id_58>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32042": {
|
||||||
|
"content": "<extra_id_57>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32043": {
|
||||||
|
"content": "<extra_id_56>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32044": {
|
||||||
|
"content": "<extra_id_55>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32045": {
|
||||||
|
"content": "<extra_id_54>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32046": {
|
||||||
|
"content": "<extra_id_53>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32047": {
|
||||||
|
"content": "<extra_id_52>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32048": {
|
||||||
|
"content": "<extra_id_51>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32049": {
|
||||||
|
"content": "<extra_id_50>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32050": {
|
||||||
|
"content": "<extra_id_49>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32051": {
|
||||||
|
"content": "<extra_id_48>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32052": {
|
||||||
|
"content": "<extra_id_47>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32053": {
|
||||||
|
"content": "<extra_id_46>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32054": {
|
||||||
|
"content": "<extra_id_45>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32055": {
|
||||||
|
"content": "<extra_id_44>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32056": {
|
||||||
|
"content": "<extra_id_43>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32057": {
|
||||||
|
"content": "<extra_id_42>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32058": {
|
||||||
|
"content": "<extra_id_41>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32059": {
|
||||||
|
"content": "<extra_id_40>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32060": {
|
||||||
|
"content": "<extra_id_39>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32061": {
|
||||||
|
"content": "<extra_id_38>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32062": {
|
||||||
|
"content": "<extra_id_37>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32063": {
|
||||||
|
"content": "<extra_id_36>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32064": {
|
||||||
|
"content": "<extra_id_35>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32065": {
|
||||||
|
"content": "<extra_id_34>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32066": {
|
||||||
|
"content": "<extra_id_33>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32067": {
|
||||||
|
"content": "<extra_id_32>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32068": {
|
||||||
|
"content": "<extra_id_31>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32069": {
|
||||||
|
"content": "<extra_id_30>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32070": {
|
||||||
|
"content": "<extra_id_29>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32071": {
|
||||||
|
"content": "<extra_id_28>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32072": {
|
||||||
|
"content": "<extra_id_27>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32073": {
|
||||||
|
"content": "<extra_id_26>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32074": {
|
||||||
|
"content": "<extra_id_25>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32075": {
|
||||||
|
"content": "<extra_id_24>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32076": {
|
||||||
|
"content": "<extra_id_23>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32077": {
|
||||||
|
"content": "<extra_id_22>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32078": {
|
||||||
|
"content": "<extra_id_21>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32079": {
|
||||||
|
"content": "<extra_id_20>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32080": {
|
||||||
|
"content": "<extra_id_19>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32081": {
|
||||||
|
"content": "<extra_id_18>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32082": {
|
||||||
|
"content": "<extra_id_17>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32083": {
|
||||||
|
"content": "<extra_id_16>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32084": {
|
||||||
|
"content": "<extra_id_15>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32085": {
|
||||||
|
"content": "<extra_id_14>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32086": {
|
||||||
|
"content": "<extra_id_13>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32087": {
|
||||||
|
"content": "<extra_id_12>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32088": {
|
||||||
|
"content": "<extra_id_11>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32089": {
|
||||||
|
"content": "<extra_id_10>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32090": {
|
||||||
|
"content": "<extra_id_9>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32091": {
|
||||||
|
"content": "<extra_id_8>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32092": {
|
||||||
|
"content": "<extra_id_7>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32093": {
|
||||||
|
"content": "<extra_id_6>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32094": {
|
||||||
|
"content": "<extra_id_5>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32095": {
|
||||||
|
"content": "<extra_id_4>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32096": {
|
||||||
|
"content": "<extra_id_3>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32097": {
|
||||||
|
"content": "<extra_id_2>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32098": {
|
||||||
|
"content": "<extra_id_1>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32099": {
|
||||||
|
"content": "<extra_id_0>",
|
||||||
|
"lstrip": true,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": true,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additional_special_tokens": [
|
||||||
|
"<extra_id_0>",
|
||||||
|
"<extra_id_1>",
|
||||||
|
"<extra_id_2>",
|
||||||
|
"<extra_id_3>",
|
||||||
|
"<extra_id_4>",
|
||||||
|
"<extra_id_5>",
|
||||||
|
"<extra_id_6>",
|
||||||
|
"<extra_id_7>",
|
||||||
|
"<extra_id_8>",
|
||||||
|
"<extra_id_9>",
|
||||||
|
"<extra_id_10>",
|
||||||
|
"<extra_id_11>",
|
||||||
|
"<extra_id_12>",
|
||||||
|
"<extra_id_13>",
|
||||||
|
"<extra_id_14>",
|
||||||
|
"<extra_id_15>",
|
||||||
|
"<extra_id_16>",
|
||||||
|
"<extra_id_17>",
|
||||||
|
"<extra_id_18>",
|
||||||
|
"<extra_id_19>",
|
||||||
|
"<extra_id_20>",
|
||||||
|
"<extra_id_21>",
|
||||||
|
"<extra_id_22>",
|
||||||
|
"<extra_id_23>",
|
||||||
|
"<extra_id_24>",
|
||||||
|
"<extra_id_25>",
|
||||||
|
"<extra_id_26>",
|
||||||
|
"<extra_id_27>",
|
||||||
|
"<extra_id_28>",
|
||||||
|
"<extra_id_29>",
|
||||||
|
"<extra_id_30>",
|
||||||
|
"<extra_id_31>",
|
||||||
|
"<extra_id_32>",
|
||||||
|
"<extra_id_33>",
|
||||||
|
"<extra_id_34>",
|
||||||
|
"<extra_id_35>",
|
||||||
|
"<extra_id_36>",
|
||||||
|
"<extra_id_37>",
|
||||||
|
"<extra_id_38>",
|
||||||
|
"<extra_id_39>",
|
||||||
|
"<extra_id_40>",
|
||||||
|
"<extra_id_41>",
|
||||||
|
"<extra_id_42>",
|
||||||
|
"<extra_id_43>",
|
||||||
|
"<extra_id_44>",
|
||||||
|
"<extra_id_45>",
|
||||||
|
"<extra_id_46>",
|
||||||
|
"<extra_id_47>",
|
||||||
|
"<extra_id_48>",
|
||||||
|
"<extra_id_49>",
|
||||||
|
"<extra_id_50>",
|
||||||
|
"<extra_id_51>",
|
||||||
|
"<extra_id_52>",
|
||||||
|
"<extra_id_53>",
|
||||||
|
"<extra_id_54>",
|
||||||
|
"<extra_id_55>",
|
||||||
|
"<extra_id_56>",
|
||||||
|
"<extra_id_57>",
|
||||||
|
"<extra_id_58>",
|
||||||
|
"<extra_id_59>",
|
||||||
|
"<extra_id_60>",
|
||||||
|
"<extra_id_61>",
|
||||||
|
"<extra_id_62>",
|
||||||
|
"<extra_id_63>",
|
||||||
|
"<extra_id_64>",
|
||||||
|
"<extra_id_65>",
|
||||||
|
"<extra_id_66>",
|
||||||
|
"<extra_id_67>",
|
||||||
|
"<extra_id_68>",
|
||||||
|
"<extra_id_69>",
|
||||||
|
"<extra_id_70>",
|
||||||
|
"<extra_id_71>",
|
||||||
|
"<extra_id_72>",
|
||||||
|
"<extra_id_73>",
|
||||||
|
"<extra_id_74>",
|
||||||
|
"<extra_id_75>",
|
||||||
|
"<extra_id_76>",
|
||||||
|
"<extra_id_77>",
|
||||||
|
"<extra_id_78>",
|
||||||
|
"<extra_id_79>",
|
||||||
|
"<extra_id_80>",
|
||||||
|
"<extra_id_81>",
|
||||||
|
"<extra_id_82>",
|
||||||
|
"<extra_id_83>",
|
||||||
|
"<extra_id_84>",
|
||||||
|
"<extra_id_85>",
|
||||||
|
"<extra_id_86>",
|
||||||
|
"<extra_id_87>",
|
||||||
|
"<extra_id_88>",
|
||||||
|
"<extra_id_89>",
|
||||||
|
"<extra_id_90>",
|
||||||
|
"<extra_id_91>",
|
||||||
|
"<extra_id_92>",
|
||||||
|
"<extra_id_93>",
|
||||||
|
"<extra_id_94>",
|
||||||
|
"<extra_id_95>",
|
||||||
|
"<extra_id_96>",
|
||||||
|
"<extra_id_97>",
|
||||||
|
"<extra_id_98>",
|
||||||
|
"<extra_id_99>"
|
||||||
|
],
|
||||||
|
"clean_up_tokenization_spaces": true,
|
||||||
|
"eos_token": "</s>",
|
||||||
|
"extra_ids": 100,
|
||||||
|
"legacy": true,
|
||||||
|
"model_max_length": 512,
|
||||||
|
"pad_token": "<pad>",
|
||||||
|
"sp_model_kwargs": {},
|
||||||
|
"tokenizer_class": "T5Tokenizer",
|
||||||
|
"unk_token": "<unk>"
|
||||||
|
}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
{
|
||||||
|
"_class_name": "SD3Transformer2DModel",
|
||||||
|
"_diffusers_version": "0.31.0.dev0",
|
||||||
|
"attention_head_dim": 64,
|
||||||
|
"caption_projection_dim": 2432,
|
||||||
|
"in_channels": 16,
|
||||||
|
"joint_attention_dim": 4096,
|
||||||
|
"num_attention_heads": 38,
|
||||||
|
"num_layers": 38,
|
||||||
|
"out_channels": 16,
|
||||||
|
"patch_size": 2,
|
||||||
|
"pooled_projection_dim": 2048,
|
||||||
|
"pos_embed_max_size": 192,
|
||||||
|
"qk_norm": "rms_norm",
|
||||||
|
"sample_size": 128
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,38 @@
|
|||||||
|
{
|
||||||
|
"_class_name": "AutoencoderKL",
|
||||||
|
"_diffusers_version": "0.31.0.dev0",
|
||||||
|
"_name_or_path": "../sdxl-vae/",
|
||||||
|
"act_fn": "silu",
|
||||||
|
"block_out_channels": [
|
||||||
|
128,
|
||||||
|
256,
|
||||||
|
512,
|
||||||
|
512
|
||||||
|
],
|
||||||
|
"down_block_types": [
|
||||||
|
"DownEncoderBlock2D",
|
||||||
|
"DownEncoderBlock2D",
|
||||||
|
"DownEncoderBlock2D",
|
||||||
|
"DownEncoderBlock2D"
|
||||||
|
],
|
||||||
|
"force_upcast": true,
|
||||||
|
"in_channels": 3,
|
||||||
|
"latent_channels": 16,
|
||||||
|
"latents_mean": null,
|
||||||
|
"latents_std": null,
|
||||||
|
"layers_per_block": 2,
|
||||||
|
"mid_block_add_attention": true,
|
||||||
|
"norm_num_groups": 32,
|
||||||
|
"out_channels": 3,
|
||||||
|
"sample_size": 1024,
|
||||||
|
"scaling_factor": 1.5305,
|
||||||
|
"shift_factor": 0.0609,
|
||||||
|
"up_block_types": [
|
||||||
|
"UpDecoderBlock2D",
|
||||||
|
"UpDecoderBlock2D",
|
||||||
|
"UpDecoderBlock2D",
|
||||||
|
"UpDecoderBlock2D"
|
||||||
|
],
|
||||||
|
"use_post_quant_conv": false,
|
||||||
|
"use_quant_conv": false
|
||||||
|
}
|
||||||
@@ -19,11 +19,12 @@ from backend.nn.unet import IntegratedUNet2DConditionModel
|
|||||||
|
|
||||||
from backend.diffusion_engine.sd15 import StableDiffusion
|
from backend.diffusion_engine.sd15 import StableDiffusion
|
||||||
from backend.diffusion_engine.sd20 import StableDiffusion2
|
from backend.diffusion_engine.sd20 import StableDiffusion2
|
||||||
|
from backend.diffusion_engine.sd35 import StableDiffusion3
|
||||||
from backend.diffusion_engine.sdxl import StableDiffusionXL
|
from backend.diffusion_engine.sdxl import StableDiffusionXL
|
||||||
from backend.diffusion_engine.flux import Flux
|
from backend.diffusion_engine.flux import Flux
|
||||||
|
|
||||||
|
|
||||||
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXL, Flux]
|
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXL, StableDiffusion3, Flux]
|
||||||
|
|
||||||
|
|
||||||
logging.getLogger("diffusers").setLevel(logging.ERROR)
|
logging.getLogger("diffusers").setLevel(logging.ERROR)
|
||||||
@@ -107,7 +108,7 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
|||||||
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])
|
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])
|
||||||
|
|
||||||
return model
|
return model
|
||||||
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']:
|
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel', 'SD3Transformer2DModel']:
|
||||||
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!'
|
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!'
|
||||||
|
|
||||||
model_loader = None
|
model_loader = None
|
||||||
@@ -116,6 +117,9 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
|||||||
if cls_name == 'FluxTransformer2DModel':
|
if cls_name == 'FluxTransformer2DModel':
|
||||||
from backend.nn.flux import IntegratedFluxTransformer2DModel
|
from backend.nn.flux import IntegratedFluxTransformer2DModel
|
||||||
model_loader = lambda c: IntegratedFluxTransformer2DModel(**c)
|
model_loader = lambda c: IntegratedFluxTransformer2DModel(**c)
|
||||||
|
if cls_name == 'SD3Transformer2DModel':
|
||||||
|
from modules.models.sd35.mmditx import MMDiTX
|
||||||
|
model_loader = lambda c: MMDiTX(**c)
|
||||||
|
|
||||||
unet_config = guess.unet_config.copy()
|
unet_config = guess.unet_config.copy()
|
||||||
state_dict_parameters = memory_management.state_dict_parameters(state_dict)
|
state_dict_parameters = memory_management.state_dict_parameters(state_dict)
|
||||||
@@ -170,7 +174,7 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def replace_state_dict(sd, asd, guess):
|
def replace_state_dict(sd, asd, guess, is_clip_g = False):
|
||||||
vae_key_prefix = guess.vae_key_prefix[0]
|
vae_key_prefix = guess.vae_key_prefix[0]
|
||||||
text_encoder_key_prefix = guess.text_encoder_key_prefix[0]
|
text_encoder_key_prefix = guess.text_encoder_key_prefix[0]
|
||||||
|
|
||||||
@@ -210,11 +214,18 @@ def replace_state_dict(sd, asd, guess):
|
|||||||
sd[vae_key_prefix + k] = v
|
sd[vae_key_prefix + k] = v
|
||||||
|
|
||||||
if 'text_model.encoder.layers.0.layer_norm1.weight' in asd:
|
if 'text_model.encoder.layers.0.layer_norm1.weight' in asd:
|
||||||
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}clip_l.")]
|
if is_clip_g:
|
||||||
for k in keys_to_delete:
|
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}clip_g.")]
|
||||||
del sd[k]
|
for k in keys_to_delete:
|
||||||
for k, v in asd.items():
|
del sd[k]
|
||||||
sd[f"{text_encoder_key_prefix}clip_l.transformer.{k}"] = v
|
for k, v in asd.items():
|
||||||
|
sd[f"{text_encoder_key_prefix}clip_g.transformer.{k}"] = v
|
||||||
|
else:
|
||||||
|
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}clip_l.")]
|
||||||
|
for k in keys_to_delete:
|
||||||
|
del sd[k]
|
||||||
|
for k, v in asd.items():
|
||||||
|
sd[f"{text_encoder_key_prefix}clip_l.transformer.{k}"] = v
|
||||||
|
|
||||||
if 'encoder.block.0.layer.0.SelfAttention.k.weight' in asd:
|
if 'encoder.block.0.layer.0.SelfAttention.k.weight' in asd:
|
||||||
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}t5xxl.")]
|
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}t5xxl.")]
|
||||||
@@ -241,8 +252,9 @@ def split_state_dict(sd, additional_state_dicts: list = None):
|
|||||||
|
|
||||||
if isinstance(additional_state_dicts, list):
|
if isinstance(additional_state_dicts, list):
|
||||||
for asd in additional_state_dicts:
|
for asd in additional_state_dicts:
|
||||||
|
is_clip_g = 'clip_g' in asd
|
||||||
asd = load_torch_file(asd)
|
asd = load_torch_file(asd)
|
||||||
sd = replace_state_dict(sd, asd, guess)
|
sd = replace_state_dict(sd, asd, guess, is_clip_g)
|
||||||
|
|
||||||
guess.clip_target = guess.clip_target(sd)
|
guess.clip_target = guess.clip_target(sd)
|
||||||
guess.model_type = guess.model_type(sd)
|
guess.model_type = guess.model_type(sd)
|
||||||
|
|||||||
@@ -251,6 +251,31 @@ class PredictionFlow(AbstractPrediction):
|
|||||||
return 1.0 - percent
|
return 1.0 - percent
|
||||||
|
|
||||||
|
|
||||||
|
class PredictionDiscreteFlow(AbstractPrediction):
|
||||||
|
def __init__(self, sigma_data=1.0, prediction_type='const', shift=1.0, timesteps = 1000):
|
||||||
|
super().__init__(sigma_data=sigma_data, prediction_type=prediction_type)
|
||||||
|
self.shift = shift
|
||||||
|
ts = self.sigma(torch.arange(1, timesteps + 1, 1))
|
||||||
|
self.register_buffer("sigmas", ts)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sigma_min(self):
|
||||||
|
return self.sigmas[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sigma_max(self):
|
||||||
|
return self.sigmas[-1]
|
||||||
|
|
||||||
|
def timestep(self, sigma):
|
||||||
|
return sigma * 1000
|
||||||
|
|
||||||
|
def sigma(self, timestep: torch.Tensor):
|
||||||
|
timestep = timestep / 1000.0
|
||||||
|
if self.shift == 1.0:
|
||||||
|
return timestep
|
||||||
|
return self.shift * timestep / (1 + (self.shift - 1) * timestep)
|
||||||
|
|
||||||
|
|
||||||
class PredictionFlux(AbstractPrediction):
|
class PredictionFlux(AbstractPrediction):
|
||||||
def __init__(self, sigma_data=1.0, prediction_type='const', shift=1.15, timesteps=10000):
|
def __init__(self, sigma_data=1.0, prediction_type='const', shift=1.15, timesteps=10000):
|
||||||
super().__init__(sigma_data=sigma_data, prediction_type=prediction_type)
|
super().__init__(sigma_data=sigma_data, prediction_type=prediction_type)
|
||||||
|
|||||||
@@ -2,11 +2,12 @@ import torch
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
from backend.attention import attention_pytorch as attention_function
|
from backend.attention import attention_pytorch as attention_function
|
||||||
|
from transformers.activations import NewGELUActivation
|
||||||
|
|
||||||
activations = {
|
activations = {
|
||||||
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
|
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
|
||||||
"relu": torch.nn.functional.relu,
|
"relu": torch.nn.functional.relu,
|
||||||
|
"gelu_new": lambda a: NewGELUActivation()(a)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -139,7 +139,7 @@ class ClassicTextProcessingEngine:
|
|||||||
if self.return_pooled:
|
if self.return_pooled:
|
||||||
pooled_output = outputs.pooler_output
|
pooled_output = outputs.pooler_output
|
||||||
|
|
||||||
if self.text_projection:
|
if self.text_projection and self.embedding_key is not 'clip_l':
|
||||||
pooled_output = self.text_encoder.transformer.text_projection(pooled_output)
|
pooled_output = self.text_encoder.transformer.text_projection(pooled_output)
|
||||||
|
|
||||||
z.pooled = pooled_output
|
z.pooled = pooled_output
|
||||||
|
|||||||
17
modules/models/sd35/LICENSE-CODE
Normal file
17
modules/models/sd35/LICENSE-CODE
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright © 2024 Stability AI
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
918
modules/models/sd35/mmditx.py
Normal file
918
modules/models/sd35/mmditx.py
Normal file
@@ -0,0 +1,918 @@
|
|||||||
|
### This file contains impls for MM-DiT, the core model component of SD3
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
from modules.models.sd35.other_impls import Mlp, attention
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEmbed(nn.Module):
|
||||||
|
"""2D Image to Patch Embedding"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
img_size: Optional[int] = 224,
|
||||||
|
patch_size: int = 16,
|
||||||
|
in_chans: int = 3,
|
||||||
|
embed_dim: int = 768,
|
||||||
|
flatten: bool = True,
|
||||||
|
bias: bool = True,
|
||||||
|
strict_img_size: bool = True,
|
||||||
|
dynamic_img_pad: bool = False,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.patch_size = (patch_size, patch_size)
|
||||||
|
if img_size is not None:
|
||||||
|
self.img_size = (img_size, img_size)
|
||||||
|
self.grid_size = tuple(
|
||||||
|
[s // p for s, p in zip(self.img_size, self.patch_size)]
|
||||||
|
)
|
||||||
|
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.img_size = None
|
||||||
|
self.grid_size = None
|
||||||
|
self.num_patches = None
|
||||||
|
|
||||||
|
# flatten spatial dim and transpose to channels last, kept for bwd compat
|
||||||
|
self.flatten = flatten
|
||||||
|
self.strict_img_size = strict_img_size
|
||||||
|
self.dynamic_img_pad = dynamic_img_pad
|
||||||
|
|
||||||
|
self.proj = nn.Conv2d(
|
||||||
|
in_chans,
|
||||||
|
embed_dim,
|
||||||
|
kernel_size=patch_size,
|
||||||
|
stride=patch_size,
|
||||||
|
bias=bias,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
B, C, H, W = x.shape
|
||||||
|
x = self.proj(x)
|
||||||
|
if self.flatten:
|
||||||
|
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def modulate(x, shift, scale):
|
||||||
|
if shift is None:
|
||||||
|
shift = torch.zeros_like(scale)
|
||||||
|
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################################
|
||||||
|
# Sine/Cosine Positional Embedding Functions #
|
||||||
|
#################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def get_2d_sincos_pos_embed(
|
||||||
|
embed_dim,
|
||||||
|
grid_size,
|
||||||
|
cls_token=False,
|
||||||
|
extra_tokens=0,
|
||||||
|
scaling_factor=None,
|
||||||
|
offset=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
grid_size: int of the grid height and width
|
||||||
|
return:
|
||||||
|
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||||
|
"""
|
||||||
|
grid_h = np.arange(grid_size, dtype=np.float32)
|
||||||
|
grid_w = np.arange(grid_size, dtype=np.float32)
|
||||||
|
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||||
|
grid = np.stack(grid, axis=0)
|
||||||
|
if scaling_factor is not None:
|
||||||
|
grid = grid / scaling_factor
|
||||||
|
if offset is not None:
|
||||||
|
grid = grid - offset
|
||||||
|
grid = grid.reshape([2, 1, grid_size, grid_size])
|
||||||
|
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||||
|
if cls_token and extra_tokens > 0:
|
||||||
|
pos_embed = np.concatenate(
|
||||||
|
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
|
||||||
|
)
|
||||||
|
return pos_embed
|
||||||
|
|
||||||
|
|
||||||
|
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||||
|
assert embed_dim % 2 == 0
|
||||||
|
# use half of dimensions to encode grid_h
|
||||||
|
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||||
|
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||||
|
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||||
|
"""
|
||||||
|
embed_dim: output dimension for each position
|
||||||
|
pos: a list of positions to be encoded: size (M,)
|
||||||
|
out: (M, D)
|
||||||
|
"""
|
||||||
|
assert embed_dim % 2 == 0
|
||||||
|
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||||
|
omega /= embed_dim / 2.0
|
||||||
|
omega = 1.0 / 10000**omega # (D/2,)
|
||||||
|
pos = pos.reshape(-1) # (M,)
|
||||||
|
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||||
|
emb_sin = np.sin(out) # (M, D/2)
|
||||||
|
emb_cos = np.cos(out) # (M, D/2)
|
||||||
|
return np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################################
|
||||||
|
# Embedding Layers for Timesteps and Class Labels #
|
||||||
|
#################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedder(nn.Module):
|
||||||
|
"""Embeds scalar timesteps into vector representations."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, hidden_size, frequency_embedding_size=256, dtype=None, device=None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
nn.Linear(
|
||||||
|
frequency_embedding_size,
|
||||||
|
hidden_size,
|
||||||
|
bias=True,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
self.frequency_embedding_size = frequency_embedding_size
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def timestep_embedding(t, dim, max_period=10000):
|
||||||
|
"""
|
||||||
|
Create sinusoidal timestep embeddings.
|
||||||
|
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||||
|
These may be fractional.
|
||||||
|
:param dim: the dimension of the output.
|
||||||
|
:param max_period: controls the minimum frequency of the embeddings.
|
||||||
|
:return: an (N, D) Tensor of positional embeddings.
|
||||||
|
"""
|
||||||
|
half = dim // 2
|
||||||
|
freqs = torch.exp(
|
||||||
|
-math.log(max_period)
|
||||||
|
* torch.arange(start=0, end=half, dtype=torch.float32)
|
||||||
|
/ half
|
||||||
|
).to(device=t.device)
|
||||||
|
args = t[:, None].float() * freqs[None]
|
||||||
|
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||||
|
if dim % 2:
|
||||||
|
embedding = torch.cat(
|
||||||
|
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||||
|
)
|
||||||
|
if torch.is_floating_point(t):
|
||||||
|
embedding = embedding.to(dtype=t.dtype)
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
def forward(self, t, dtype, **kwargs):
|
||||||
|
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||||
|
t_emb = self.mlp(t_freq)
|
||||||
|
return t_emb
|
||||||
|
|
||||||
|
|
||||||
|
class VectorEmbedder(nn.Module):
|
||||||
|
"""Embeds a flat vector of dimension input_dim"""
|
||||||
|
|
||||||
|
def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None):
|
||||||
|
super().__init__()
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
nn.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.mlp(x)
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################################
|
||||||
|
# Core DiT Model #
|
||||||
|
#################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def split_qkv(qkv, head_dim):
|
||||||
|
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
|
||||||
|
return qkv[0], qkv[1], qkv[2]
|
||||||
|
|
||||||
|
|
||||||
|
def optimized_attention(qkv, num_heads):
|
||||||
|
return attention(qkv[0], qkv[1], qkv[2], num_heads)
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttention(nn.Module):
|
||||||
|
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
num_heads: int = 8,
|
||||||
|
qkv_bias: bool = False,
|
||||||
|
qk_scale: Optional[float] = None,
|
||||||
|
attn_mode: str = "xformers",
|
||||||
|
pre_only: bool = False,
|
||||||
|
qk_norm: Optional[str] = None,
|
||||||
|
rmsnorm: bool = False,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
|
||||||
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||||
|
if not pre_only:
|
||||||
|
self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
|
||||||
|
assert attn_mode in self.ATTENTION_MODES
|
||||||
|
self.attn_mode = attn_mode
|
||||||
|
self.pre_only = pre_only
|
||||||
|
|
||||||
|
if qk_norm == "rms":
|
||||||
|
self.ln_q = RMSNorm(
|
||||||
|
self.head_dim,
|
||||||
|
elementwise_affine=True,
|
||||||
|
eps=1.0e-6,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
self.ln_k = RMSNorm(
|
||||||
|
self.head_dim,
|
||||||
|
elementwise_affine=True,
|
||||||
|
eps=1.0e-6,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
elif qk_norm == "ln":
|
||||||
|
self.ln_q = nn.LayerNorm(
|
||||||
|
self.head_dim,
|
||||||
|
elementwise_affine=True,
|
||||||
|
eps=1.0e-6,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
self.ln_k = nn.LayerNorm(
|
||||||
|
self.head_dim,
|
||||||
|
elementwise_affine=True,
|
||||||
|
eps=1.0e-6,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
elif qk_norm is None:
|
||||||
|
self.ln_q = nn.Identity()
|
||||||
|
self.ln_k = nn.Identity()
|
||||||
|
else:
|
||||||
|
raise ValueError(qk_norm)
|
||||||
|
|
||||||
|
def pre_attention(self, x: torch.Tensor):
|
||||||
|
B, L, C = x.shape
|
||||||
|
qkv = self.qkv(x)
|
||||||
|
q, k, v = split_qkv(qkv, self.head_dim)
|
||||||
|
q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
|
||||||
|
k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)
|
||||||
|
return (q, k, v)
|
||||||
|
|
||||||
|
def post_attention(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
assert not self.pre_only
|
||||||
|
x = self.proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
(q, k, v) = self.pre_attention(x)
|
||||||
|
x = attention(q, k, v, self.num_heads)
|
||||||
|
x = self.post_attention(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
elementwise_affine: bool = False,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the RMSNorm normalization layer.
|
||||||
|
Args:
|
||||||
|
dim (int): The dimension of the input tensor.
|
||||||
|
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
||||||
|
Attributes:
|
||||||
|
eps (float): A small value added to the denominator for numerical stability.
|
||||||
|
weight (nn.Parameter): Learnable scaling parameter.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.learnable_scale = elementwise_affine
|
||||||
|
if self.learnable_scale:
|
||||||
|
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||||
|
else:
|
||||||
|
self.register_parameter("weight", None)
|
||||||
|
|
||||||
|
def _norm(self, x):
|
||||||
|
"""
|
||||||
|
Apply the RMSNorm normalization to the input tensor.
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): The input tensor.
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The normalized tensor.
|
||||||
|
"""
|
||||||
|
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Forward pass through the RMSNorm layer.
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): The input tensor.
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The output tensor after applying RMSNorm.
|
||||||
|
"""
|
||||||
|
x = self._norm(x)
|
||||||
|
if self.learnable_scale:
|
||||||
|
return x * self.weight.to(device=x.device, dtype=x.dtype)
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SwiGLUFeedForward(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
hidden_dim: int,
|
||||||
|
multiple_of: int,
|
||||||
|
ffn_dim_multiplier: Optional[float] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the FeedForward module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dim (int): Input dimension.
|
||||||
|
hidden_dim (int): Hidden dimension of the feedforward layer.
|
||||||
|
multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
|
||||||
|
ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
w1 (ColumnParallelLinear): Linear transformation for the first layer.
|
||||||
|
w2 (RowParallelLinear): Linear transformation for the second layer.
|
||||||
|
w3 (ColumnParallelLinear): Linear transformation for the third layer.
|
||||||
|
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
hidden_dim = int(2 * hidden_dim / 3)
|
||||||
|
# custom dim factor multiplier
|
||||||
|
if ffn_dim_multiplier is not None:
|
||||||
|
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||||
|
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||||
|
|
||||||
|
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||||
|
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||||
|
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||||
|
|
||||||
|
|
||||||
|
class DismantledBlock(nn.Module):
|
||||||
|
"""A DiT block with gated adaptive layer norm (adaLN) conditioning."""
|
||||||
|
|
||||||
|
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
attn_mode: str = "xformers",
|
||||||
|
qkv_bias: bool = False,
|
||||||
|
pre_only: bool = False,
|
||||||
|
rmsnorm: bool = False,
|
||||||
|
scale_mod_only: bool = False,
|
||||||
|
swiglu: bool = False,
|
||||||
|
qk_norm: Optional[str] = None,
|
||||||
|
x_block_self_attn: bool = False,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
**block_kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert attn_mode in self.ATTENTION_MODES
|
||||||
|
if not rmsnorm:
|
||||||
|
self.norm1 = nn.LayerNorm(
|
||||||
|
hidden_size,
|
||||||
|
elementwise_affine=False,
|
||||||
|
eps=1e-6,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.attn = SelfAttention(
|
||||||
|
dim=hidden_size,
|
||||||
|
num_heads=num_heads,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
attn_mode=attn_mode,
|
||||||
|
pre_only=pre_only,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
rmsnorm=rmsnorm,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
if x_block_self_attn:
|
||||||
|
assert not pre_only
|
||||||
|
assert not scale_mod_only
|
||||||
|
self.x_block_self_attn = True
|
||||||
|
self.attn2 = SelfAttention(
|
||||||
|
dim=hidden_size,
|
||||||
|
num_heads=num_heads,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
attn_mode=attn_mode,
|
||||||
|
pre_only=False,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
rmsnorm=rmsnorm,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.x_block_self_attn = False
|
||||||
|
if not pre_only:
|
||||||
|
if not rmsnorm:
|
||||||
|
self.norm2 = nn.LayerNorm(
|
||||||
|
hidden_size,
|
||||||
|
elementwise_affine=False,
|
||||||
|
eps=1e-6,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||||
|
if not pre_only:
|
||||||
|
if not swiglu:
|
||||||
|
self.mlp = Mlp(
|
||||||
|
in_features=hidden_size,
|
||||||
|
hidden_features=mlp_hidden_dim,
|
||||||
|
act_layer=nn.GELU(approximate="tanh"),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.mlp = SwiGLUFeedForward(
|
||||||
|
dim=hidden_size, hidden_dim=mlp_hidden_dim, multiple_of=256
|
||||||
|
)
|
||||||
|
self.scale_mod_only = scale_mod_only
|
||||||
|
if x_block_self_attn:
|
||||||
|
assert not pre_only
|
||||||
|
assert not scale_mod_only
|
||||||
|
n_mods = 9
|
||||||
|
elif not scale_mod_only:
|
||||||
|
n_mods = 6 if not pre_only else 2
|
||||||
|
else:
|
||||||
|
n_mods = 4 if not pre_only else 1
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(
|
||||||
|
hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.pre_only = pre_only
|
||||||
|
|
||||||
|
def pre_attention(self, x: torch.Tensor, c: torch.Tensor):
|
||||||
|
assert x is not None, "pre_attention called with None input"
|
||||||
|
if not self.pre_only:
|
||||||
|
if not self.scale_mod_only:
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||||
|
self.adaLN_modulation(c).chunk(6, dim=1)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
shift_msa = None
|
||||||
|
shift_mlp = None
|
||||||
|
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(
|
||||||
|
c
|
||||||
|
).chunk(4, dim=1)
|
||||||
|
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
|
||||||
|
return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
|
||||||
|
else:
|
||||||
|
if not self.scale_mod_only:
|
||||||
|
shift_msa, scale_msa = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||||
|
else:
|
||||||
|
shift_msa = None
|
||||||
|
scale_msa = self.adaLN_modulation(c)
|
||||||
|
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
|
||||||
|
return qkv, None
|
||||||
|
|
||||||
|
def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
|
||||||
|
assert not self.pre_only
|
||||||
|
x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
|
||||||
|
x = x + gate_mlp.unsqueeze(1) * self.mlp(
|
||||||
|
modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||||
|
)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||||
|
assert self.x_block_self_attn
|
||||||
|
(
|
||||||
|
shift_msa,
|
||||||
|
scale_msa,
|
||||||
|
gate_msa,
|
||||||
|
shift_mlp,
|
||||||
|
scale_mlp,
|
||||||
|
gate_mlp,
|
||||||
|
shift_msa2,
|
||||||
|
scale_msa2,
|
||||||
|
gate_msa2,
|
||||||
|
) = self.adaLN_modulation(c).chunk(9, dim=1)
|
||||||
|
x_norm = self.norm1(x)
|
||||||
|
qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
|
||||||
|
qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
|
||||||
|
return (
|
||||||
|
qkv,
|
||||||
|
qkv2,
|
||||||
|
(
|
||||||
|
x,
|
||||||
|
gate_msa,
|
||||||
|
shift_mlp,
|
||||||
|
scale_mlp,
|
||||||
|
gate_mlp,
|
||||||
|
gate_msa2,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def post_attention_x(
|
||||||
|
self,
|
||||||
|
attn,
|
||||||
|
attn2,
|
||||||
|
x,
|
||||||
|
gate_msa,
|
||||||
|
shift_mlp,
|
||||||
|
scale_mlp,
|
||||||
|
gate_mlp,
|
||||||
|
gate_msa2,
|
||||||
|
attn1_dropout: float = 0.0,
|
||||||
|
):
|
||||||
|
assert not self.pre_only
|
||||||
|
if attn1_dropout > 0.0:
|
||||||
|
# Use torch.bernoulli to implement dropout, only dropout the batch dimension
|
||||||
|
attn1_dropout = torch.bernoulli(
|
||||||
|
torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device)
|
||||||
|
)
|
||||||
|
attn_ = (
|
||||||
|
gate_msa.unsqueeze(1) * self.attn.post_attention(attn) * attn1_dropout
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
|
||||||
|
x = x + attn_
|
||||||
|
attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2)
|
||||||
|
x = x + attn2_
|
||||||
|
mlp_ = gate_mlp.unsqueeze(1) * self.mlp(
|
||||||
|
modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||||
|
)
|
||||||
|
x = x + mlp_
|
||||||
|
return x, (gate_msa, gate_msa2, gate_mlp, attn_, attn2_)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||||
|
assert not self.pre_only
|
||||||
|
if self.x_block_self_attn:
|
||||||
|
(q, k, v), (q2, k2, v2), intermediates = self.pre_attention_x(x, c)
|
||||||
|
attn = attention(q, k, v, self.attn.num_heads)
|
||||||
|
attn2 = attention(q2, k2, v2, self.attn2.num_heads)
|
||||||
|
return self.post_attention_x(attn, attn2, *intermediates)
|
||||||
|
else:
|
||||||
|
(q, k, v), intermediates = self.pre_attention(x, c)
|
||||||
|
attn = attention(q, k, v, self.attn.num_heads)
|
||||||
|
return self.post_attention(attn, *intermediates)
|
||||||
|
|
||||||
|
|
||||||
|
def block_mixing(context, x, context_block, x_block, c):
|
||||||
|
assert context is not None, "block_mixing called with None context"
|
||||||
|
context_qkv, context_intermediates = context_block.pre_attention(context, c)
|
||||||
|
|
||||||
|
if x_block.x_block_self_attn:
|
||||||
|
x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
|
||||||
|
else:
|
||||||
|
x_qkv, x_intermediates = x_block.pre_attention(x, c)
|
||||||
|
|
||||||
|
o = []
|
||||||
|
for t in range(3):
|
||||||
|
o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1))
|
||||||
|
q, k, v = tuple(o)
|
||||||
|
|
||||||
|
attn = attention(q, k, v, x_block.attn.num_heads)
|
||||||
|
context_attn, x_attn = (
|
||||||
|
attn[:, : context_qkv[0].shape[1]],
|
||||||
|
attn[:, context_qkv[0].shape[1] :],
|
||||||
|
)
|
||||||
|
|
||||||
|
if not context_block.pre_only:
|
||||||
|
context = context_block.post_attention(context_attn, *context_intermediates)
|
||||||
|
else:
|
||||||
|
context = None
|
||||||
|
|
||||||
|
if x_block.x_block_self_attn:
|
||||||
|
x_q2, x_k2, x_v2 = x_qkv2
|
||||||
|
attn2 = attention(x_q2, x_k2, x_v2, x_block.attn2.num_heads)
|
||||||
|
else:
|
||||||
|
x = x_block.post_attention(x_attn, *x_intermediates)
|
||||||
|
|
||||||
|
return context, x
|
||||||
|
|
||||||
|
|
||||||
|
class JointBlock(nn.Module):
|
||||||
|
"""just a small wrapper to serve as a fsdp unit"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
pre_only = kwargs.pop("pre_only")
|
||||||
|
qk_norm = kwargs.pop("qk_norm", None)
|
||||||
|
x_block_self_attn = kwargs.pop("x_block_self_attn", False)
|
||||||
|
self.context_block = DismantledBlock(
|
||||||
|
*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs
|
||||||
|
)
|
||||||
|
self.x_block = DismantledBlock(
|
||||||
|
*args,
|
||||||
|
pre_only=False,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
x_block_self_attn=x_block_self_attn,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return block_mixing(
|
||||||
|
*args, context_block=self.context_block, x_block=self.x_block, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FinalLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
The final layer of DiT.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
patch_size: int,
|
||||||
|
out_channels: int,
|
||||||
|
total_out_channels: Optional[int] = None,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_final = nn.LayerNorm(
|
||||||
|
hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
self.linear = (
|
||||||
|
nn.Linear(
|
||||||
|
hidden_size,
|
||||||
|
patch_size * patch_size * out_channels,
|
||||||
|
bias=True,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
if (total_out_channels is None)
|
||||||
|
else nn.Linear(
|
||||||
|
hidden_size, total_out_channels, bias=True, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(
|
||||||
|
hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||||
|
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||||
|
x = modulate(self.norm_final(x), shift, scale)
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MMDiTX(nn.Module):
|
||||||
|
"""Diffusion model with a Transformer backbone."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_size: int = 32,
|
||||||
|
patch_size: int = 4,
|
||||||
|
in_channels: int = 4,
|
||||||
|
depth: int = 28,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
learn_sigma: bool = False,
|
||||||
|
adm_in_channels: Optional[int] = None,
|
||||||
|
context_embedder_config: Optional[Dict] = None,
|
||||||
|
register_length: int = 0,
|
||||||
|
attn_mode: str = "torch",
|
||||||
|
rmsnorm: bool = False,
|
||||||
|
scale_mod_only: bool = False,
|
||||||
|
swiglu: bool = False,
|
||||||
|
out_channels: Optional[int] = None,
|
||||||
|
pos_embed_scaling_factor: Optional[float] = None,
|
||||||
|
pos_embed_offset: Optional[float] = None,
|
||||||
|
pos_embed_max_size: Optional[int] = None,
|
||||||
|
num_patches=None,
|
||||||
|
qk_norm: Optional[str] = None,
|
||||||
|
x_block_self_attn_layers: Optional[List[int]] = [],
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
verbose=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# input_size = 800
|
||||||
|
if verbose:
|
||||||
|
print(
|
||||||
|
f"mmdit initializing with: {input_size=}, {patch_size=}, {in_channels=}, {depth=}, {mlp_ratio=}, {learn_sigma=}, {adm_in_channels=}, {context_embedder_config=}, {register_length=}, {attn_mode=}, {rmsnorm=}, {scale_mod_only=}, {swiglu=}, {out_channels=}, {pos_embed_scaling_factor=}, {pos_embed_offset=}, {pos_embed_max_size=}, {num_patches=}, {qk_norm=}, {qkv_bias=}, {dtype=}, {device=}"
|
||||||
|
)
|
||||||
|
self.dtype = dtype
|
||||||
|
self.learn_sigma = learn_sigma
|
||||||
|
self.in_channels = in_channels
|
||||||
|
default_out_channels = in_channels * 2 if learn_sigma else in_channels
|
||||||
|
self.out_channels = (
|
||||||
|
out_channels if out_channels is not None else default_out_channels
|
||||||
|
)
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.pos_embed_scaling_factor = pos_embed_scaling_factor
|
||||||
|
self.pos_embed_offset = pos_embed_offset
|
||||||
|
self.pos_embed_max_size = pos_embed_max_size
|
||||||
|
self.x_block_self_attn_layers = x_block_self_attn_layers
|
||||||
|
|
||||||
|
# apply magic --> this defines a head_size of 64
|
||||||
|
hidden_size = 64 * depth
|
||||||
|
num_heads = depth
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.x_embedder = PatchEmbed(
|
||||||
|
32,
|
||||||
|
2,
|
||||||
|
in_channels,
|
||||||
|
hidden_size,
|
||||||
|
bias=True,
|
||||||
|
strict_img_size=self.pos_embed_max_size is None,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
if adm_in_channels is not None:
|
||||||
|
assert isinstance(adm_in_channels, int)
|
||||||
|
self.y_embedder = VectorEmbedder(
|
||||||
|
adm_in_channels, hidden_size, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
self.context_embedder = nn.Identity()
|
||||||
|
if context_embedder_config is not None:
|
||||||
|
if context_embedder_config["target"] == "torch.nn.Linear":
|
||||||
|
self.context_embedder = nn.Linear(
|
||||||
|
**context_embedder_config["params"], dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
self.register_length = register_length
|
||||||
|
if self.register_length > 0:
|
||||||
|
self.register = nn.Parameter(
|
||||||
|
torch.randn(1, register_length, hidden_size, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
|
# num_patches = self.x_embedder.num_patches
|
||||||
|
# Will use fixed sin-cos embedding:
|
||||||
|
# just use a buffer already
|
||||||
|
if num_patches is not None:
|
||||||
|
self.register_buffer(
|
||||||
|
"pos_embed",
|
||||||
|
torch.zeros(1, num_patches, hidden_size, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.pos_embed = None
|
||||||
|
|
||||||
|
self.joint_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
JointBlock(
|
||||||
|
hidden_size,
|
||||||
|
num_heads,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
attn_mode=attn_mode,
|
||||||
|
pre_only=i == depth - 1,
|
||||||
|
rmsnorm=rmsnorm,
|
||||||
|
scale_mod_only=scale_mod_only,
|
||||||
|
swiglu=swiglu,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
x_block_self_attn=(i in self.x_block_self_attn_layers),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
for i in range(depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.final_layer = FinalLayer(
|
||||||
|
hidden_size, patch_size, self.out_channels, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
def cropped_pos_embed(self, hw):
|
||||||
|
assert self.pos_embed_max_size is not None
|
||||||
|
p = self.x_embedder.patch_size[0]
|
||||||
|
h, w = hw
|
||||||
|
# patched size
|
||||||
|
h = h // p
|
||||||
|
w = w // p
|
||||||
|
assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
|
||||||
|
assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
|
||||||
|
top = (self.pos_embed_max_size - h) // 2
|
||||||
|
left = (self.pos_embed_max_size - w) // 2
|
||||||
|
spatial_pos_embed = rearrange(
|
||||||
|
self.pos_embed,
|
||||||
|
"1 (h w) c -> 1 h w c",
|
||||||
|
h=self.pos_embed_max_size,
|
||||||
|
w=self.pos_embed_max_size,
|
||||||
|
)
|
||||||
|
spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
|
||||||
|
spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
|
||||||
|
return spatial_pos_embed
|
||||||
|
|
||||||
|
def unpatchify(self, x, hw=None):
|
||||||
|
"""
|
||||||
|
x: (N, T, patch_size**2 * C)
|
||||||
|
imgs: (N, H, W, C)
|
||||||
|
"""
|
||||||
|
c = self.out_channels
|
||||||
|
p = self.x_embedder.patch_size[0]
|
||||||
|
if hw is None:
|
||||||
|
h = w = int(x.shape[1] ** 0.5)
|
||||||
|
else:
|
||||||
|
h, w = hw
|
||||||
|
h = h // p
|
||||||
|
w = w // p
|
||||||
|
assert h * w == x.shape[1]
|
||||||
|
|
||||||
|
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||||
|
x = torch.einsum("nhwpqc->nchpwq", x)
|
||||||
|
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||||
|
return imgs
|
||||||
|
|
||||||
|
def forward_core_with_concat(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
c_mod: torch.Tensor,
|
||||||
|
context: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if self.register_length > 0:
|
||||||
|
context = torch.cat(
|
||||||
|
(
|
||||||
|
repeat(self.register, "1 ... -> b ...", b=x.shape[0]),
|
||||||
|
context if context is not None else torch.Tensor([]).type_as(x),
|
||||||
|
),
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# context is B, L', D
|
||||||
|
# x is B, L, D
|
||||||
|
for block in self.joint_blocks:
|
||||||
|
context, x = block(context, x, c=c_mod)
|
||||||
|
|
||||||
|
x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
t: torch.Tensor,
|
||||||
|
y: Optional[torch.Tensor] = None,
|
||||||
|
context: Optional[torch.Tensor] = None, control=None, transformer_options={}, **kwargs
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Forward pass of DiT.
|
||||||
|
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||||
|
t: (N,) tensor of diffusion timesteps
|
||||||
|
y: (N,) tensor of class labels
|
||||||
|
"""
|
||||||
|
hw = x.shape[-2:]
|
||||||
|
# The line below should be unnecessary when full integrated.
|
||||||
|
x = x[:1,:16,:,:]
|
||||||
|
x = self.x_embedder(x) + self.cropped_pos_embed(hw).to("cuda")
|
||||||
|
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
|
||||||
|
if y is not None:
|
||||||
|
y = self.y_embedder(y) # (N, D)
|
||||||
|
c = c + y # (N, D)
|
||||||
|
|
||||||
|
context = self.context_embedder(context)
|
||||||
|
|
||||||
|
x = self.forward_core_with_concat(x, c, context)
|
||||||
|
|
||||||
|
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
|
||||||
|
return x
|
||||||
868
modules/models/sd35/other_impls.py
Normal file
868
modules/models/sd35/other_impls.py
Normal file
@@ -0,0 +1,868 @@
|
|||||||
|
### This file contains impls for underlying related models (CLIP, T5, etc)
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||||
|
|
||||||
|
#################################################################################################
|
||||||
|
### Core/Utility
|
||||||
|
#################################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def attention(q, k, v, heads, mask=None):
|
||||||
|
"""Convenience wrapper around a basic attention operation"""
|
||||||
|
b, _, dim_head = q.shape
|
||||||
|
dim_head //= heads
|
||||||
|
q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
|
||||||
|
out = torch.nn.functional.scaled_dot_product_attention(
|
||||||
|
q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
|
||||||
|
)
|
||||||
|
return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||||
|
|
||||||
|
|
||||||
|
class Mlp(nn.Module):
|
||||||
|
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_features,
|
||||||
|
hidden_features=None,
|
||||||
|
out_features=None,
|
||||||
|
act_layer=nn.GELU,
|
||||||
|
bias=True,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
out_features = out_features or in_features
|
||||||
|
hidden_features = hidden_features or in_features
|
||||||
|
|
||||||
|
self.fc1 = nn.Linear(
|
||||||
|
in_features, hidden_features, bias=bias, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
self.act = act_layer
|
||||||
|
self.fc2 = nn.Linear(
|
||||||
|
hidden_features, out_features, bias=bias, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################################################
|
||||||
|
### CLIP
|
||||||
|
#################################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPAttention(torch.nn.Module):
|
||||||
|
def __init__(self, embed_dim, heads, dtype, device):
|
||||||
|
super().__init__()
|
||||||
|
self.heads = heads
|
||||||
|
self.q_proj = nn.Linear(
|
||||||
|
embed_dim, embed_dim, bias=True, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
self.k_proj = nn.Linear(
|
||||||
|
embed_dim, embed_dim, bias=True, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
self.v_proj = nn.Linear(
|
||||||
|
embed_dim, embed_dim, bias=True, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
self.out_proj = nn.Linear(
|
||||||
|
embed_dim, embed_dim, bias=True, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, mask=None):
|
||||||
|
q = self.q_proj(x)
|
||||||
|
k = self.k_proj(x)
|
||||||
|
v = self.v_proj(x)
|
||||||
|
out = attention(q, k, v, self.heads, mask)
|
||||||
|
return self.out_proj(out)
|
||||||
|
|
||||||
|
|
||||||
|
ACTIVATIONS = {
|
||||||
|
"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
|
||||||
|
"gelu": torch.nn.functional.gelu,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPLayer(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim,
|
||||||
|
heads,
|
||||||
|
intermediate_size,
|
||||||
|
intermediate_activation,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||||
|
self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
|
||||||
|
self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||||
|
# self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
|
||||||
|
self.mlp = Mlp(
|
||||||
|
embed_dim,
|
||||||
|
intermediate_size,
|
||||||
|
embed_dim,
|
||||||
|
act_layer=ACTIVATIONS[intermediate_activation],
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, mask=None):
|
||||||
|
x += self.self_attn(self.layer_norm1(x), mask)
|
||||||
|
x += self.mlp(self.layer_norm2(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPEncoder(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_layers,
|
||||||
|
embed_dim,
|
||||||
|
heads,
|
||||||
|
intermediate_size,
|
||||||
|
intermediate_activation,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.layers = torch.nn.ModuleList(
|
||||||
|
[
|
||||||
|
CLIPLayer(
|
||||||
|
embed_dim,
|
||||||
|
heads,
|
||||||
|
intermediate_size,
|
||||||
|
intermediate_activation,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
for i in range(num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, mask=None, intermediate_output=None):
|
||||||
|
if intermediate_output is not None:
|
||||||
|
if intermediate_output < 0:
|
||||||
|
intermediate_output = len(self.layers) + intermediate_output
|
||||||
|
intermediate = None
|
||||||
|
for i, l in enumerate(self.layers):
|
||||||
|
x = l(x, mask)
|
||||||
|
if i == intermediate_output:
|
||||||
|
intermediate = x.clone()
|
||||||
|
return x, intermediate
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPEmbeddings(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.token_embedding = torch.nn.Embedding(
|
||||||
|
vocab_size, embed_dim, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
self.position_embedding = torch.nn.Embedding(
|
||||||
|
num_positions, embed_dim, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, input_tokens):
|
||||||
|
return self.token_embedding(input_tokens) + self.position_embedding.weight
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPTextModel_(torch.nn.Module):
|
||||||
|
def __init__(self, config_dict, dtype, device):
|
||||||
|
num_layers = config_dict["num_hidden_layers"]
|
||||||
|
embed_dim = config_dict["hidden_size"]
|
||||||
|
heads = config_dict["num_attention_heads"]
|
||||||
|
intermediate_size = config_dict["intermediate_size"]
|
||||||
|
intermediate_activation = config_dict["hidden_act"]
|
||||||
|
super().__init__()
|
||||||
|
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
|
||||||
|
self.encoder = CLIPEncoder(
|
||||||
|
num_layers,
|
||||||
|
embed_dim,
|
||||||
|
heads,
|
||||||
|
intermediate_size,
|
||||||
|
intermediate_activation,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True
|
||||||
|
):
|
||||||
|
x = self.embeddings(input_tokens)
|
||||||
|
causal_mask = (
|
||||||
|
torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device)
|
||||||
|
.fill_(float("-inf"))
|
||||||
|
.triu_(1)
|
||||||
|
)
|
||||||
|
x, i = self.encoder(
|
||||||
|
x, mask=causal_mask, intermediate_output=intermediate_output
|
||||||
|
)
|
||||||
|
x = self.final_layer_norm(x)
|
||||||
|
if i is not None and final_layer_norm_intermediate:
|
||||||
|
i = self.final_layer_norm(i)
|
||||||
|
pooled_output = x[
|
||||||
|
torch.arange(x.shape[0], device=x.device),
|
||||||
|
input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),
|
||||||
|
]
|
||||||
|
return x, i, pooled_output
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPTextModel(torch.nn.Module):
|
||||||
|
def __init__(self, config_dict, dtype, device):
|
||||||
|
super().__init__()
|
||||||
|
self.num_layers = config_dict["num_hidden_layers"]
|
||||||
|
self.text_model = CLIPTextModel_(config_dict, dtype, device)
|
||||||
|
embed_dim = config_dict["hidden_size"]
|
||||||
|
self.text_projection = nn.Linear(
|
||||||
|
embed_dim, embed_dim, bias=False, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
self.text_projection.weight.copy_(torch.eye(embed_dim))
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.text_model.embeddings.token_embedding
|
||||||
|
|
||||||
|
def set_input_embeddings(self, embeddings):
|
||||||
|
self.text_model.embeddings.token_embedding = embeddings
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
x = self.text_model(*args, **kwargs)
|
||||||
|
out = self.text_projection(x[2])
|
||||||
|
return (x[0], x[1], out, x[2])
|
||||||
|
|
||||||
|
|
||||||
|
def parse_parentheses(string):
|
||||||
|
result = []
|
||||||
|
current_item = ""
|
||||||
|
nesting_level = 0
|
||||||
|
for char in string:
|
||||||
|
if char == "(":
|
||||||
|
if nesting_level == 0:
|
||||||
|
if current_item:
|
||||||
|
result.append(current_item)
|
||||||
|
current_item = "("
|
||||||
|
else:
|
||||||
|
current_item = "("
|
||||||
|
else:
|
||||||
|
current_item += char
|
||||||
|
nesting_level += 1
|
||||||
|
elif char == ")":
|
||||||
|
nesting_level -= 1
|
||||||
|
if nesting_level == 0:
|
||||||
|
result.append(current_item + ")")
|
||||||
|
current_item = ""
|
||||||
|
else:
|
||||||
|
current_item += char
|
||||||
|
else:
|
||||||
|
current_item += char
|
||||||
|
if current_item:
|
||||||
|
result.append(current_item)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def token_weights(string, current_weight):
|
||||||
|
a = parse_parentheses(string)
|
||||||
|
out = []
|
||||||
|
for x in a:
|
||||||
|
weight = current_weight
|
||||||
|
if len(x) >= 2 and x[-1] == ")" and x[0] == "(":
|
||||||
|
x = x[1:-1]
|
||||||
|
xx = x.rfind(":")
|
||||||
|
weight *= 1.1
|
||||||
|
if xx > 0:
|
||||||
|
try:
|
||||||
|
weight = float(x[xx + 1 :])
|
||||||
|
x = x[:xx]
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
out += token_weights(x, weight)
|
||||||
|
else:
|
||||||
|
out += [(x, current_weight)]
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def escape_important(text):
|
||||||
|
text = text.replace("\\)", "\0\1")
|
||||||
|
text = text.replace("\\(", "\0\2")
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def unescape_important(text):
|
||||||
|
text = text.replace("\0\1", ")")
|
||||||
|
text = text.replace("\0\2", "(")
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
class SDTokenizer:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_length=77,
|
||||||
|
pad_with_end=True,
|
||||||
|
tokenizer=None,
|
||||||
|
has_start_token=True,
|
||||||
|
pad_to_max_length=True,
|
||||||
|
min_length=None,
|
||||||
|
extra_padding_token=None,
|
||||||
|
):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.max_length = max_length
|
||||||
|
self.min_length = min_length
|
||||||
|
|
||||||
|
empty = self.tokenizer("")["input_ids"]
|
||||||
|
if has_start_token:
|
||||||
|
self.tokens_start = 1
|
||||||
|
self.start_token = empty[0]
|
||||||
|
self.end_token = empty[1]
|
||||||
|
else:
|
||||||
|
self.tokens_start = 0
|
||||||
|
self.start_token = None
|
||||||
|
self.end_token = empty[0]
|
||||||
|
self.pad_with_end = pad_with_end
|
||||||
|
self.pad_to_max_length = pad_to_max_length
|
||||||
|
self.extra_padding_token = extra_padding_token
|
||||||
|
|
||||||
|
vocab = self.tokenizer.get_vocab()
|
||||||
|
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||||
|
self.max_word_length = 8
|
||||||
|
|
||||||
|
def tokenize_with_weights(self, text: str, return_word_ids=False):
|
||||||
|
"""
|
||||||
|
Tokenize the text, with weight values - presume 1.0 for all and ignore other features here.
|
||||||
|
The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.
|
||||||
|
"""
|
||||||
|
if self.pad_with_end:
|
||||||
|
pad_token = self.end_token
|
||||||
|
else:
|
||||||
|
pad_token = 0
|
||||||
|
|
||||||
|
text = escape_important(text)
|
||||||
|
parsed_weights = token_weights(text, 1.0)
|
||||||
|
|
||||||
|
# tokenize words
|
||||||
|
tokens = []
|
||||||
|
for weighted_segment, weight in parsed_weights:
|
||||||
|
to_tokenize = (
|
||||||
|
unescape_important(weighted_segment).replace("\n", " ").split(" ")
|
||||||
|
)
|
||||||
|
to_tokenize = [x for x in to_tokenize if x != ""]
|
||||||
|
for word in to_tokenize:
|
||||||
|
# parse word
|
||||||
|
tokens.append(
|
||||||
|
[
|
||||||
|
(t, weight)
|
||||||
|
for t in self.tokenizer(word)["input_ids"][
|
||||||
|
self.tokens_start : -1
|
||||||
|
]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# reshape token array to CLIP input size
|
||||||
|
batched_tokens = []
|
||||||
|
batch = []
|
||||||
|
if self.start_token is not None:
|
||||||
|
batch.append((self.start_token, 1.0, 0))
|
||||||
|
batched_tokens.append(batch)
|
||||||
|
for i, t_group in enumerate(tokens):
|
||||||
|
# determine if we're going to try and keep the tokens in a single batch
|
||||||
|
is_large = len(t_group) >= self.max_word_length
|
||||||
|
|
||||||
|
while len(t_group) > 0:
|
||||||
|
if len(t_group) + len(batch) > self.max_length - 1:
|
||||||
|
remaining_length = self.max_length - len(batch) - 1
|
||||||
|
# break word in two and add end token
|
||||||
|
if is_large:
|
||||||
|
batch.extend(
|
||||||
|
[(t, w, i + 1) for t, w in t_group[:remaining_length]]
|
||||||
|
)
|
||||||
|
batch.append((self.end_token, 1.0, 0))
|
||||||
|
t_group = t_group[remaining_length:]
|
||||||
|
# add end token and pad
|
||||||
|
else:
|
||||||
|
batch.append((self.end_token, 1.0, 0))
|
||||||
|
if self.pad_to_max_length:
|
||||||
|
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
|
||||||
|
# start new batch
|
||||||
|
batch = []
|
||||||
|
if self.start_token is not None:
|
||||||
|
batch.append((self.start_token, 1.0, 0))
|
||||||
|
batched_tokens.append(batch)
|
||||||
|
else:
|
||||||
|
batch.extend([(t, w, i + 1) for t, w in t_group])
|
||||||
|
t_group = []
|
||||||
|
|
||||||
|
# pad extra padding token first befor getting to the end token
|
||||||
|
if self.extra_padding_token is not None:
|
||||||
|
batch.extend(
|
||||||
|
[(self.extra_padding_token, 1.0, 0)]
|
||||||
|
* (self.min_length - len(batch) - 1)
|
||||||
|
)
|
||||||
|
# fill last batch
|
||||||
|
batch.append((self.end_token, 1.0, 0))
|
||||||
|
if self.pad_to_max_length:
|
||||||
|
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
||||||
|
if self.min_length is not None and len(batch) < self.min_length:
|
||||||
|
batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))
|
||||||
|
|
||||||
|
if not return_word_ids:
|
||||||
|
batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
|
||||||
|
|
||||||
|
return batched_tokens
|
||||||
|
|
||||||
|
def untokenize(self, token_weight_pair):
|
||||||
|
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLClipGTokenizer(SDTokenizer):
|
||||||
|
def __init__(self, tokenizer):
|
||||||
|
super().__init__(pad_with_end=False, tokenizer=tokenizer)
|
||||||
|
|
||||||
|
|
||||||
|
class SD3Tokenizer:
|
||||||
|
def __init__(self):
|
||||||
|
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||||
|
self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
|
||||||
|
self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
|
||||||
|
self.t5xxl = T5XXLTokenizer()
|
||||||
|
|
||||||
|
def tokenize_with_weights(self, text: str):
|
||||||
|
out = {}
|
||||||
|
out["l"] = self.clip_l.tokenize_with_weights(text)
|
||||||
|
out["g"] = self.clip_g.tokenize_with_weights(text)
|
||||||
|
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text[:226])
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ClipTokenWeightEncoder:
|
||||||
|
def encode_token_weights(self, token_weight_pairs):
|
||||||
|
tokens = list(map(lambda a: a[0], token_weight_pairs[0]))
|
||||||
|
out, pooled = self([tokens])
|
||||||
|
if pooled is not None:
|
||||||
|
first_pooled = pooled[0:1].cpu()
|
||||||
|
else:
|
||||||
|
first_pooled = pooled
|
||||||
|
output = [out[0:1]]
|
||||||
|
return torch.cat(output, dim=-2).cpu(), first_pooled
|
||||||
|
|
||||||
|
|
||||||
|
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||||
|
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||||
|
|
||||||
|
LAYERS = ["last", "pooled", "hidden"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
device="cpu",
|
||||||
|
max_length=77,
|
||||||
|
layer="last",
|
||||||
|
layer_idx=None,
|
||||||
|
textmodel_json_config=None,
|
||||||
|
dtype=None,
|
||||||
|
model_class=CLIPTextModel,
|
||||||
|
special_tokens={"start": 49406, "end": 49407, "pad": 49407},
|
||||||
|
layer_norm_hidden_state=True,
|
||||||
|
return_projected_pooled=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert layer in self.LAYERS
|
||||||
|
self.transformer = model_class(textmodel_json_config, dtype, device)
|
||||||
|
self.num_layers = self.transformer.num_layers
|
||||||
|
self.max_length = max_length
|
||||||
|
self.transformer = self.transformer.eval()
|
||||||
|
for param in self.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
self.layer = layer
|
||||||
|
self.layer_idx = None
|
||||||
|
self.special_tokens = special_tokens
|
||||||
|
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
|
||||||
|
self.layer_norm_hidden_state = layer_norm_hidden_state
|
||||||
|
self.return_projected_pooled = return_projected_pooled
|
||||||
|
if layer == "hidden":
|
||||||
|
assert layer_idx is not None
|
||||||
|
assert abs(layer_idx) < self.num_layers
|
||||||
|
self.set_clip_options({"layer": layer_idx})
|
||||||
|
self.options_default = (
|
||||||
|
self.layer,
|
||||||
|
self.layer_idx,
|
||||||
|
self.return_projected_pooled,
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_clip_options(self, options):
|
||||||
|
layer_idx = options.get("layer", self.layer_idx)
|
||||||
|
self.return_projected_pooled = options.get(
|
||||||
|
"projected_pooled", self.return_projected_pooled
|
||||||
|
)
|
||||||
|
if layer_idx is None or abs(layer_idx) > self.num_layers:
|
||||||
|
self.layer = "last"
|
||||||
|
else:
|
||||||
|
self.layer = "hidden"
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
|
||||||
|
def forward(self, tokens):
|
||||||
|
backup_embeds = self.transformer.get_input_embeddings()
|
||||||
|
device = backup_embeds.weight.device
|
||||||
|
tokens = torch.LongTensor(tokens).to(device)
|
||||||
|
outputs = self.transformer(
|
||||||
|
tokens,
|
||||||
|
intermediate_output=self.layer_idx,
|
||||||
|
final_layer_norm_intermediate=self.layer_norm_hidden_state,
|
||||||
|
)
|
||||||
|
self.transformer.set_input_embeddings(backup_embeds)
|
||||||
|
if self.layer == "last":
|
||||||
|
z = outputs[0]
|
||||||
|
else:
|
||||||
|
z = outputs[1]
|
||||||
|
pooled_output = None
|
||||||
|
if len(outputs) >= 3:
|
||||||
|
if (
|
||||||
|
not self.return_projected_pooled
|
||||||
|
and len(outputs) >= 4
|
||||||
|
and outputs[3] is not None
|
||||||
|
):
|
||||||
|
pooled_output = outputs[3].float()
|
||||||
|
elif outputs[2] is not None:
|
||||||
|
pooled_output = outputs[2].float()
|
||||||
|
return z.float(), pooled_output
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLClipG(SDClipModel):
|
||||||
|
"""Wraps the CLIP-G model into the SD-CLIP-Model interface"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None
|
||||||
|
):
|
||||||
|
if layer == "penultimate":
|
||||||
|
layer = "hidden"
|
||||||
|
layer_idx = -2
|
||||||
|
super().__init__(
|
||||||
|
device=device,
|
||||||
|
layer=layer,
|
||||||
|
layer_idx=layer_idx,
|
||||||
|
textmodel_json_config=config,
|
||||||
|
dtype=dtype,
|
||||||
|
special_tokens={"start": 49406, "end": 49407, "pad": 0},
|
||||||
|
layer_norm_hidden_state=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class T5XXLModel(SDClipModel):
|
||||||
|
"""Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience"""
|
||||||
|
|
||||||
|
def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||||
|
super().__init__(
|
||||||
|
device=device,
|
||||||
|
layer=layer,
|
||||||
|
layer_idx=layer_idx,
|
||||||
|
textmodel_json_config=config,
|
||||||
|
dtype=dtype,
|
||||||
|
special_tokens={"end": 1, "pad": 0},
|
||||||
|
model_class=T5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################################################
|
||||||
|
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
|
||||||
|
#################################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
class T5XXLTokenizer(SDTokenizer):
|
||||||
|
"""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
pad_with_end=False,
|
||||||
|
tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"),
|
||||||
|
has_start_token=False,
|
||||||
|
pad_to_max_length=False,
|
||||||
|
max_length=99999999,
|
||||||
|
min_length=77,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class T5LayerNorm(torch.nn.Module):
|
||||||
|
def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = torch.nn.Parameter(
|
||||||
|
torch.ones(hidden_size, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
variance = x.pow(2).mean(-1, keepdim=True)
|
||||||
|
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
|
return self.weight.to(device=x.device, dtype=x.dtype) * x
|
||||||
|
|
||||||
|
|
||||||
|
class T5DenseGatedActDense(torch.nn.Module):
|
||||||
|
def __init__(self, model_dim, ff_dim, dtype, device):
|
||||||
|
super().__init__()
|
||||||
|
self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
|
||||||
|
hidden_linear = self.wi_1(x)
|
||||||
|
x = hidden_gelu * hidden_linear
|
||||||
|
x = self.wo(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class T5LayerFF(torch.nn.Module):
|
||||||
|
def __init__(self, model_dim, ff_dim, dtype, device):
|
||||||
|
super().__init__()
|
||||||
|
self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device)
|
||||||
|
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
forwarded_states = self.layer_norm(x)
|
||||||
|
forwarded_states = self.DenseReluDense(forwarded_states)
|
||||||
|
x += forwarded_states
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class T5Attention(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# Mesh TensorFlow initialization to avoid scaling before softmax
|
||||||
|
self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.relative_attention_bias = None
|
||||||
|
if relative_attention_bias:
|
||||||
|
self.relative_attention_num_buckets = 32
|
||||||
|
self.relative_attention_max_distance = 128
|
||||||
|
self.relative_attention_bias = torch.nn.Embedding(
|
||||||
|
self.relative_attention_num_buckets, self.num_heads, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _relative_position_bucket(
|
||||||
|
relative_position, bidirectional=True, num_buckets=32, max_distance=128
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Adapted from Mesh Tensorflow:
|
||||||
|
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
|
||||||
|
|
||||||
|
Translate relative position to a bucket number for relative attention. The relative position is defined as
|
||||||
|
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
|
||||||
|
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
|
||||||
|
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
|
||||||
|
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
|
||||||
|
This should allow for more graceful generalization to longer sequences than the model has been trained on
|
||||||
|
|
||||||
|
Args:
|
||||||
|
relative_position: an int32 Tensor
|
||||||
|
bidirectional: a boolean - whether the attention is bidirectional
|
||||||
|
num_buckets: an integer
|
||||||
|
max_distance: an integer
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
|
||||||
|
"""
|
||||||
|
relative_buckets = 0
|
||||||
|
if bidirectional:
|
||||||
|
num_buckets //= 2
|
||||||
|
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
|
||||||
|
relative_position = torch.abs(relative_position)
|
||||||
|
else:
|
||||||
|
relative_position = -torch.min(
|
||||||
|
relative_position, torch.zeros_like(relative_position)
|
||||||
|
)
|
||||||
|
# now relative_position is in the range [0, inf)
|
||||||
|
# half of the buckets are for exact increments in positions
|
||||||
|
max_exact = num_buckets // 2
|
||||||
|
is_small = relative_position < max_exact
|
||||||
|
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
||||||
|
relative_position_if_large = max_exact + (
|
||||||
|
torch.log(relative_position.float() / max_exact)
|
||||||
|
/ math.log(max_distance / max_exact)
|
||||||
|
* (num_buckets - max_exact)
|
||||||
|
).to(torch.long)
|
||||||
|
relative_position_if_large = torch.min(
|
||||||
|
relative_position_if_large,
|
||||||
|
torch.full_like(relative_position_if_large, num_buckets - 1),
|
||||||
|
)
|
||||||
|
relative_buckets += torch.where(
|
||||||
|
is_small, relative_position, relative_position_if_large
|
||||||
|
)
|
||||||
|
return relative_buckets
|
||||||
|
|
||||||
|
def compute_bias(self, query_length, key_length, device):
|
||||||
|
"""Compute binned relative position bias"""
|
||||||
|
context_position = torch.arange(query_length, dtype=torch.long, device=device)[
|
||||||
|
:, None
|
||||||
|
]
|
||||||
|
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[
|
||||||
|
None, :
|
||||||
|
]
|
||||||
|
relative_position = (
|
||||||
|
memory_position - context_position
|
||||||
|
) # shape (query_length, key_length)
|
||||||
|
relative_position_bucket = self._relative_position_bucket(
|
||||||
|
relative_position, # shape (query_length, key_length)
|
||||||
|
bidirectional=True,
|
||||||
|
num_buckets=self.relative_attention_num_buckets,
|
||||||
|
max_distance=self.relative_attention_max_distance,
|
||||||
|
)
|
||||||
|
values = self.relative_attention_bias(
|
||||||
|
relative_position_bucket
|
||||||
|
) # shape (query_length, key_length, num_heads)
|
||||||
|
values = values.permute([2, 0, 1]).unsqueeze(
|
||||||
|
0
|
||||||
|
) # shape (1, num_heads, query_length, key_length)
|
||||||
|
return values
|
||||||
|
|
||||||
|
def forward(self, x, past_bias=None):
|
||||||
|
q = self.q(x)
|
||||||
|
k = self.k(x)
|
||||||
|
v = self.v(x)
|
||||||
|
if self.relative_attention_bias is not None:
|
||||||
|
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
|
||||||
|
if past_bias is not None:
|
||||||
|
mask = past_bias
|
||||||
|
out = attention(
|
||||||
|
q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask
|
||||||
|
)
|
||||||
|
return self.o(out), past_bias
|
||||||
|
|
||||||
|
|
||||||
|
class T5LayerSelfAttention(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_dim,
|
||||||
|
inner_dim,
|
||||||
|
ff_dim,
|
||||||
|
num_heads,
|
||||||
|
relative_attention_bias,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.SelfAttention = T5Attention(
|
||||||
|
model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device
|
||||||
|
)
|
||||||
|
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, x, past_bias=None):
|
||||||
|
output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias)
|
||||||
|
x += output
|
||||||
|
return x, past_bias
|
||||||
|
|
||||||
|
|
||||||
|
class T5Block(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_dim,
|
||||||
|
inner_dim,
|
||||||
|
ff_dim,
|
||||||
|
num_heads,
|
||||||
|
relative_attention_bias,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.layer = torch.nn.ModuleList()
|
||||||
|
self.layer.append(
|
||||||
|
T5LayerSelfAttention(
|
||||||
|
model_dim,
|
||||||
|
inner_dim,
|
||||||
|
ff_dim,
|
||||||
|
num_heads,
|
||||||
|
relative_attention_bias,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device))
|
||||||
|
|
||||||
|
def forward(self, x, past_bias=None):
|
||||||
|
x, past_bias = self.layer[0](x, past_bias)
|
||||||
|
x = self.layer[-1](x)
|
||||||
|
return x, past_bias
|
||||||
|
|
||||||
|
|
||||||
|
class T5Stack(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_layers,
|
||||||
|
model_dim,
|
||||||
|
inner_dim,
|
||||||
|
ff_dim,
|
||||||
|
num_heads,
|
||||||
|
vocab_size,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
|
||||||
|
self.block = torch.nn.ModuleList(
|
||||||
|
[
|
||||||
|
T5Block(
|
||||||
|
model_dim,
|
||||||
|
inner_dim,
|
||||||
|
ff_dim,
|
||||||
|
num_heads,
|
||||||
|
relative_attention_bias=(i == 0),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
for i in range(num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True
|
||||||
|
):
|
||||||
|
intermediate = None
|
||||||
|
x = self.embed_tokens(input_ids)
|
||||||
|
past_bias = None
|
||||||
|
for i, l in enumerate(self.block):
|
||||||
|
x, past_bias = l(x, past_bias)
|
||||||
|
if i == intermediate_output:
|
||||||
|
intermediate = x.clone()
|
||||||
|
x = self.final_layer_norm(x)
|
||||||
|
if intermediate is not None and final_layer_norm_intermediate:
|
||||||
|
intermediate = self.final_layer_norm(intermediate)
|
||||||
|
return x, intermediate
|
||||||
|
|
||||||
|
|
||||||
|
class T5(torch.nn.Module):
|
||||||
|
def __init__(self, config_dict, dtype, device):
|
||||||
|
super().__init__()
|
||||||
|
self.num_layers = config_dict["num_layers"]
|
||||||
|
self.encoder = T5Stack(
|
||||||
|
self.num_layers,
|
||||||
|
config_dict["d_model"],
|
||||||
|
config_dict["d_model"],
|
||||||
|
config_dict["d_ff"],
|
||||||
|
config_dict["num_heads"],
|
||||||
|
config_dict["vocab_size"],
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.encoder.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, embeddings):
|
||||||
|
self.encoder.embed_tokens = embeddings
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return self.encoder(*args, **kwargs)
|
||||||
222
modules/models/sd35/sd3_cond.py
Normal file
222
modules/models/sd35/sd3_cond.py
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
import os
|
||||||
|
import safetensors
|
||||||
|
import torch
|
||||||
|
import typing
|
||||||
|
|
||||||
|
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||||
|
|
||||||
|
from modules import shared, devices, modelloader, sd_hijack_clip, prompt_parser
|
||||||
|
from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class SafetensorsMapping(typing.Mapping):
|
||||||
|
def __init__(self, file):
|
||||||
|
self.file = file
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.file.keys())
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for key in self.file.keys():
|
||||||
|
yield key
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self.file.get_tensor(key)
|
||||||
|
|
||||||
|
|
||||||
|
CLIPL_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors"
|
||||||
|
CLIPL_CONFIG = {
|
||||||
|
"hidden_act": "quick_gelu",
|
||||||
|
"hidden_size": 768,
|
||||||
|
"intermediate_size": 3072,
|
||||||
|
"num_attention_heads": 12,
|
||||||
|
"num_hidden_layers": 12,
|
||||||
|
}
|
||||||
|
|
||||||
|
CLIPG_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors"
|
||||||
|
CLIPG_CONFIG = {
|
||||||
|
"hidden_act": "gelu",
|
||||||
|
"hidden_size": 1280,
|
||||||
|
"intermediate_size": 5120,
|
||||||
|
"num_attention_heads": 20,
|
||||||
|
"num_hidden_layers": 32,
|
||||||
|
"textual_inversion_key": "clip_g",
|
||||||
|
}
|
||||||
|
|
||||||
|
T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
|
||||||
|
T5_CONFIG = {
|
||||||
|
"d_ff": 10240,
|
||||||
|
"d_model": 4096,
|
||||||
|
"num_heads": 64,
|
||||||
|
"num_layers": 24,
|
||||||
|
"vocab_size": 32128,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Sd3ClipLG(sd_hijack_clip.TextConditionalModel):
|
||||||
|
def __init__(self, clip_l, clip_g):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.clip_l = clip_l
|
||||||
|
self.clip_g = clip_g
|
||||||
|
|
||||||
|
self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||||
|
|
||||||
|
empty = self.tokenizer('')["input_ids"]
|
||||||
|
self.id_start = empty[0]
|
||||||
|
self.id_end = empty[1]
|
||||||
|
self.id_pad = empty[1]
|
||||||
|
|
||||||
|
self.return_pooled = True
|
||||||
|
|
||||||
|
def tokenize(self, texts):
|
||||||
|
return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
|
||||||
|
|
||||||
|
def encode_with_transformers(self, tokens):
|
||||||
|
tokens_g = tokens.clone()
|
||||||
|
|
||||||
|
for batch_pos in range(tokens_g.shape[0]):
|
||||||
|
index = tokens_g[batch_pos].cpu().tolist().index(self.id_end)
|
||||||
|
tokens_g[batch_pos, index+1:tokens_g.shape[1]] = 0
|
||||||
|
|
||||||
|
l_out, l_pooled = self.clip_l(tokens)
|
||||||
|
g_out, g_pooled = self.clip_g(tokens_g)
|
||||||
|
|
||||||
|
lg_out = torch.cat([l_out, g_out], dim=-1)
|
||||||
|
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
||||||
|
|
||||||
|
vector_out = torch.cat((l_pooled, g_pooled), dim=-1)
|
||||||
|
|
||||||
|
lg_out.pooled = vector_out
|
||||||
|
return lg_out
|
||||||
|
|
||||||
|
def encode_embedding_init_text(self, init_text, nvpt):
|
||||||
|
return torch.zeros((nvpt, 768+1280), device=devices.device) # XXX
|
||||||
|
|
||||||
|
|
||||||
|
class Sd3T5(torch.nn.Module):
|
||||||
|
def __init__(self, t5xxl):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.t5xxl = t5xxl
|
||||||
|
self.tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")
|
||||||
|
|
||||||
|
empty = self.tokenizer('', padding='max_length', max_length=2)["input_ids"]
|
||||||
|
self.id_end = empty[0]
|
||||||
|
self.id_pad = empty[1]
|
||||||
|
|
||||||
|
def tokenize(self, texts):
|
||||||
|
return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
|
||||||
|
|
||||||
|
def tokenize_line(self, line, *, target_token_count=None):
|
||||||
|
if shared.opts.emphasis != "None":
|
||||||
|
parsed = prompt_parser.parse_prompt_attention(line)
|
||||||
|
else:
|
||||||
|
parsed = [[line, 1.0]]
|
||||||
|
|
||||||
|
tokenized = self.tokenize([text for text, _ in parsed])
|
||||||
|
|
||||||
|
tokens = []
|
||||||
|
multipliers = []
|
||||||
|
|
||||||
|
for text_tokens, (text, weight) in zip(tokenized, parsed):
|
||||||
|
if text == 'BREAK' and weight == -1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
tokens += text_tokens
|
||||||
|
multipliers += [weight] * len(text_tokens)
|
||||||
|
|
||||||
|
tokens += [self.id_end]
|
||||||
|
multipliers += [1.0]
|
||||||
|
|
||||||
|
if target_token_count is not None:
|
||||||
|
if len(tokens) < target_token_count:
|
||||||
|
tokens += [self.id_pad] * (target_token_count - len(tokens))
|
||||||
|
multipliers += [1.0] * (target_token_count - len(tokens))
|
||||||
|
else:
|
||||||
|
tokens = tokens[0:target_token_count]
|
||||||
|
multipliers = multipliers[0:target_token_count]
|
||||||
|
|
||||||
|
return tokens, multipliers
|
||||||
|
|
||||||
|
def forward(self, texts, *, token_count):
|
||||||
|
if not self.t5xxl or not shared.opts.sd3_enable_t5:
|
||||||
|
return torch.zeros((len(texts), token_count, 4096), device=devices.device, dtype=devices.dtype)
|
||||||
|
|
||||||
|
tokens_batch = []
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
tokens, multipliers = self.tokenize_line(text, target_token_count=token_count)
|
||||||
|
tokens_batch.append(tokens)
|
||||||
|
|
||||||
|
t5_out, t5_pooled = self.t5xxl(tokens_batch)
|
||||||
|
|
||||||
|
return t5_out
|
||||||
|
|
||||||
|
def encode_embedding_init_text(self, init_text, nvpt):
|
||||||
|
return torch.zeros((nvpt, 4096), device=devices.device) # XXX
|
||||||
|
|
||||||
|
|
||||||
|
class SD3Cond(torch.nn.Module):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
self.tokenizer = SD3Tokenizer()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
|
||||||
|
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
|
||||||
|
|
||||||
|
if shared.opts.sd3_enable_t5:
|
||||||
|
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
|
||||||
|
else:
|
||||||
|
self.t5xxl = None
|
||||||
|
|
||||||
|
self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)
|
||||||
|
self.model_t5 = Sd3T5(self.t5xxl)
|
||||||
|
|
||||||
|
def forward(self, prompts: list[str]):
|
||||||
|
with devices.without_autocast():
|
||||||
|
lg_out, vector_out = self.model_lg(prompts)
|
||||||
|
t5_out = self.model_t5(prompts, token_count=lg_out.shape[1])
|
||||||
|
lgt_out = torch.cat([lg_out, t5_out], dim=-2)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'crossattn': lgt_out,
|
||||||
|
'vector': vector_out,
|
||||||
|
}
|
||||||
|
|
||||||
|
def before_load_weights(self, state_dict):
|
||||||
|
clip_path = os.path.join(shared.models_path, "CLIP")
|
||||||
|
|
||||||
|
if 'text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
|
||||||
|
clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
|
||||||
|
with safetensors.safe_open(clip_g_file, framework="pt") as file:
|
||||||
|
self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
|
||||||
|
|
||||||
|
if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
|
||||||
|
clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
|
||||||
|
with safetensors.safe_open(clip_l_file, framework="pt") as file:
|
||||||
|
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||||
|
|
||||||
|
if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
|
||||||
|
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
|
||||||
|
with safetensors.safe_open(t5_file, framework="pt") as file:
|
||||||
|
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||||
|
|
||||||
|
def encode_embedding_init_text(self, init_text, nvpt):
|
||||||
|
return self.model_lg.encode_embedding_init_text(init_text, nvpt)
|
||||||
|
|
||||||
|
def tokenize(self, texts):
|
||||||
|
return self.model_lg.tokenize(texts)
|
||||||
|
|
||||||
|
def medvram_modules(self):
|
||||||
|
return [self.clip_g, self.clip_l, self.t5xxl]
|
||||||
|
|
||||||
|
def get_token_count(self, text):
|
||||||
|
_, token_count = self.model_lg.process_texts([text])
|
||||||
|
|
||||||
|
return token_count
|
||||||
|
|
||||||
|
def get_target_prompt_token_count(self, token_count):
|
||||||
|
return self.model_lg.get_target_prompt_token_count(token_count)
|
||||||
623
modules/models/sd35/sd3_impls.py
Normal file
623
modules/models/sd35/sd3_impls.py
Normal file
@@ -0,0 +1,623 @@
|
|||||||
|
### Impls of the SD3 core diffusion model and VAE
|
||||||
|
|
||||||
|
import math
|
||||||
|
import re
|
||||||
|
|
||||||
|
import einops
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from modules.models.sd35.mmditx import MMDiTX
|
||||||
|
|
||||||
|
#################################################################################################
|
||||||
|
### MMDiT Model Wrapping
|
||||||
|
#################################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSamplingDiscreteFlow(torch.nn.Module):
|
||||||
|
"""Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models"""
|
||||||
|
|
||||||
|
def __init__(self, shift=1.0):
|
||||||
|
super().__init__()
|
||||||
|
self.shift = shift
|
||||||
|
timesteps = 1000
|
||||||
|
ts = self.sigma(torch.arange(1, timesteps + 1, 1))
|
||||||
|
self.register_buffer("sigmas", ts)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sigma_min(self):
|
||||||
|
return self.sigmas[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sigma_max(self):
|
||||||
|
return self.sigmas[-1]
|
||||||
|
|
||||||
|
def timestep(self, sigma):
|
||||||
|
return sigma * 1000
|
||||||
|
|
||||||
|
def sigma(self, timestep: torch.Tensor):
|
||||||
|
timestep = timestep / 1000.0
|
||||||
|
if self.shift == 1.0:
|
||||||
|
return timestep
|
||||||
|
return self.shift * timestep / (1 + (self.shift - 1) * timestep)
|
||||||
|
|
||||||
|
def calculate_denoised(self, sigma, model_output, model_input):
|
||||||
|
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||||
|
return model_input - model_output * sigma
|
||||||
|
|
||||||
|
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||||
|
return sigma * noise + (1.0 - sigma) * latent_image
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModel(torch.nn.Module):
|
||||||
|
"""Wrapper around the core MM-DiT model"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
state_dict,
|
||||||
|
shift=1.0,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
prefix = ''
|
||||||
|
# Important configuration values can be quickly determined by checking shapes in the source file
|
||||||
|
# Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
|
||||||
|
patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
|
||||||
|
depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
|
||||||
|
num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
|
||||||
|
pos_embed_max_size = round(math.sqrt(num_patches))
|
||||||
|
adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
|
||||||
|
context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
|
||||||
|
qk_norm = (
|
||||||
|
"rms"
|
||||||
|
if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in state_dict.keys()
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
x_block_self_attn_layers = sorted(
|
||||||
|
[
|
||||||
|
int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1])
|
||||||
|
for key in list(
|
||||||
|
filter(
|
||||||
|
re.compile(".*.x_block.attn2.ln_k.weight").match, state_dict.keys()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
context_embedder_config = {
|
||||||
|
"target": "torch.nn.Linear",
|
||||||
|
"params": {
|
||||||
|
"in_features": context_shape[1],
|
||||||
|
"out_features": context_shape[0],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
self.diffusion_model = MMDiTX(
|
||||||
|
input_size=None,
|
||||||
|
pos_embed_scaling_factor=None,
|
||||||
|
pos_embed_offset=None,
|
||||||
|
pos_embed_max_size=pos_embed_max_size,
|
||||||
|
patch_size=patch_size,
|
||||||
|
in_channels=16,
|
||||||
|
depth=depth,
|
||||||
|
num_patches=num_patches,
|
||||||
|
adm_in_channels=adm_in_channels,
|
||||||
|
context_embedder_config=context_embedder_config,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
x_block_self_attn_layers=x_block_self_attn_layers,
|
||||||
|
# device=kwargs['device'],
|
||||||
|
# dtype=kwargs['dtype'],
|
||||||
|
# verbose=kwargs['verbose'],
|
||||||
|
# **kwargs
|
||||||
|
)
|
||||||
|
self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
|
||||||
|
|
||||||
|
def apply_model(self, x, sigma, y=None, *args, **kwargs):
|
||||||
|
dtype = self.get_dtype()
|
||||||
|
timestep = self.model_sampling.timestep(sigma).float()
|
||||||
|
model_output = self.diffusion_model(
|
||||||
|
x.to(dtype), timestep, context=kwargs["context"].to(dtype), y=y.to(dtype)
|
||||||
|
).float()
|
||||||
|
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return self.apply_model(*args, **kwargs)
|
||||||
|
|
||||||
|
def get_dtype(self):
|
||||||
|
return self.diffusion_model.dtype
|
||||||
|
|
||||||
|
|
||||||
|
class CFGDenoiser(torch.nn.Module):
|
||||||
|
"""Helper for applying CFG Scaling to diffusion outputs"""
|
||||||
|
|
||||||
|
def __init__(self, model):
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
def forward(self, x, timestep, cond, uncond, cond_scale):
|
||||||
|
# Run cond and uncond in a batch together
|
||||||
|
batched = self.model.apply_model(
|
||||||
|
torch.cat([x, x]),
|
||||||
|
torch.cat([timestep, timestep]),
|
||||||
|
c_crossattn=torch.cat([cond["c_crossattn"], uncond["c_crossattn"]]),
|
||||||
|
y=torch.cat([cond["y"], uncond["y"]]),
|
||||||
|
)
|
||||||
|
# Then split and apply CFG Scaling
|
||||||
|
pos_out, neg_out = batched.chunk(2)
|
||||||
|
scaled = neg_out + (pos_out - neg_out) * cond_scale
|
||||||
|
return scaled
|
||||||
|
|
||||||
|
|
||||||
|
class SD3LatentFormat:
|
||||||
|
"""Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.scale_factor = 1.5305
|
||||||
|
self.shift_factor = 0.0609
|
||||||
|
|
||||||
|
def process_in(self, latent):
|
||||||
|
return (latent - self.shift_factor) * self.scale_factor
|
||||||
|
|
||||||
|
def process_out(self, latent):
|
||||||
|
return (latent / self.scale_factor) + self.shift_factor
|
||||||
|
|
||||||
|
def decode_latent_to_preview(self, x0):
|
||||||
|
"""Quick RGB approximate preview of sd3 latents"""
|
||||||
|
factors = torch.tensor(
|
||||||
|
[
|
||||||
|
[-0.0645, 0.0177, 0.1052],
|
||||||
|
[0.0028, 0.0312, 0.0650],
|
||||||
|
[0.1848, 0.0762, 0.0360],
|
||||||
|
[0.0944, 0.0360, 0.0889],
|
||||||
|
[0.0897, 0.0506, -0.0364],
|
||||||
|
[-0.0020, 0.1203, 0.0284],
|
||||||
|
[0.0855, 0.0118, 0.0283],
|
||||||
|
[-0.0539, 0.0658, 0.1047],
|
||||||
|
[-0.0057, 0.0116, 0.0700],
|
||||||
|
[-0.0412, 0.0281, -0.0039],
|
||||||
|
[0.1106, 0.1171, 0.1220],
|
||||||
|
[-0.0248, 0.0682, -0.0481],
|
||||||
|
[0.0815, 0.0846, 0.1207],
|
||||||
|
[-0.0120, -0.0055, -0.0867],
|
||||||
|
[-0.0749, -0.0634, -0.0456],
|
||||||
|
[-0.1418, -0.1457, -0.1259],
|
||||||
|
],
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
latent_image = x0[0].permute(1, 2, 0).cpu() @ factors
|
||||||
|
|
||||||
|
latents_ubyte = (
|
||||||
|
((latent_image + 1) / 2)
|
||||||
|
.clamp(0, 1) # change scale from -1..1 to 0..1
|
||||||
|
.mul(0xFF) # to 0..255
|
||||||
|
.byte()
|
||||||
|
).cpu()
|
||||||
|
|
||||||
|
return Image.fromarray(latents_ubyte.numpy())
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################################################
|
||||||
|
### Samplers
|
||||||
|
#################################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def append_dims(x, target_dims):
|
||||||
|
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||||
|
dims_to_append = target_dims - x.ndim
|
||||||
|
return x[(...,) + (None,) * dims_to_append]
|
||||||
|
|
||||||
|
|
||||||
|
def to_d(x, sigma, denoised):
|
||||||
|
"""Converts a denoiser output to a Karras ODE derivative."""
|
||||||
|
return (x - denoised) / append_dims(sigma, x.ndim)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
@torch.autocast("cuda", dtype=torch.float16)
|
||||||
|
def sample_euler(model, x, sigmas, extra_args=None):
|
||||||
|
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
for i in tqdm(range(len(sigmas) - 1)):
|
||||||
|
sigma_hat = sigmas[i]
|
||||||
|
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||||
|
d = to_d(x, sigma_hat, denoised)
|
||||||
|
dt = sigmas[i + 1] - sigma_hat
|
||||||
|
# Euler method
|
||||||
|
x = x + d * dt
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
@torch.autocast("cuda", dtype=torch.float16)
|
||||||
|
def sample_dpmpp_2m(model, x, sigmas, extra_args=None):
|
||||||
|
"""DPM-Solver++(2M)."""
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
sigma_fn = lambda t: t.neg().exp()
|
||||||
|
t_fn = lambda sigma: sigma.log().neg()
|
||||||
|
old_denoised = None
|
||||||
|
for i in tqdm(range(len(sigmas) - 1)):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
||||||
|
h = t_next - t
|
||||||
|
if old_denoised is None or sigmas[i + 1] == 0:
|
||||||
|
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
|
||||||
|
else:
|
||||||
|
h_last = t - t_fn(sigmas[i - 1])
|
||||||
|
r = h_last / h
|
||||||
|
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
|
||||||
|
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
|
||||||
|
old_denoised = denoised
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################################################
|
||||||
|
### VAE
|
||||||
|
#################################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):
|
||||||
|
return torch.nn.GroupNorm(
|
||||||
|
num_groups=num_groups,
|
||||||
|
num_channels=in_channels,
|
||||||
|
eps=1e-6,
|
||||||
|
affine=True,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ResnetBlock(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, *, in_channels, out_channels=None, dtype=torch.float32, device=None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
out_channels = in_channels if out_channels is None else out_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
|
self.norm1 = Normalize(in_channels, dtype=dtype, device=device)
|
||||||
|
self.conv1 = torch.nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
self.norm2 = Normalize(out_channels, dtype=dtype, device=device)
|
||||||
|
self.conv2 = torch.nn.Conv2d(
|
||||||
|
out_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
if self.in_channels != self.out_channels:
|
||||||
|
self.nin_shortcut = torch.nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.nin_shortcut = None
|
||||||
|
self.swish = torch.nn.SiLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
hidden = x
|
||||||
|
hidden = self.norm1(hidden)
|
||||||
|
hidden = self.swish(hidden)
|
||||||
|
hidden = self.conv1(hidden)
|
||||||
|
hidden = self.norm2(hidden)
|
||||||
|
hidden = self.swish(hidden)
|
||||||
|
hidden = self.conv2(hidden)
|
||||||
|
if self.in_channels != self.out_channels:
|
||||||
|
x = self.nin_shortcut(x)
|
||||||
|
return x + hidden
|
||||||
|
|
||||||
|
|
||||||
|
class AttnBlock(torch.nn.Module):
|
||||||
|
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = Normalize(in_channels, dtype=dtype, device=device)
|
||||||
|
self.q = torch.nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
self.k = torch.nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
self.v = torch.nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
self.proj_out = torch.nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
hidden = self.norm(x)
|
||||||
|
q = self.q(hidden)
|
||||||
|
k = self.k(hidden)
|
||||||
|
v = self.v(hidden)
|
||||||
|
b, c, h, w = q.shape
|
||||||
|
q, k, v = map(
|
||||||
|
lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(),
|
||||||
|
(q, k, v),
|
||||||
|
)
|
||||||
|
hidden = torch.nn.functional.scaled_dot_product_attention(
|
||||||
|
q, k, v
|
||||||
|
) # scale is dim ** -0.5 per default
|
||||||
|
hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
||||||
|
hidden = self.proj_out(hidden)
|
||||||
|
return x + hidden
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample(torch.nn.Module):
|
||||||
|
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = torch.nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=0,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
pad = (0, 1, 0, 1)
|
||||||
|
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample(torch.nn.Module):
|
||||||
|
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = torch.nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class VAEEncoder(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ch=128,
|
||||||
|
ch_mult=(1, 2, 4, 4),
|
||||||
|
num_res_blocks=2,
|
||||||
|
in_channels=3,
|
||||||
|
z_channels=16,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_resolutions = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
# downsampling
|
||||||
|
self.conv_in = torch.nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
ch,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
in_ch_mult = (1,) + tuple(ch_mult)
|
||||||
|
self.in_ch_mult = in_ch_mult
|
||||||
|
self.down = torch.nn.ModuleList()
|
||||||
|
for i_level in range(self.num_resolutions):
|
||||||
|
block = torch.nn.ModuleList()
|
||||||
|
attn = torch.nn.ModuleList()
|
||||||
|
block_in = ch * in_ch_mult[i_level]
|
||||||
|
block_out = ch * ch_mult[i_level]
|
||||||
|
for i_block in range(num_res_blocks):
|
||||||
|
block.append(
|
||||||
|
ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_out,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
block_in = block_out
|
||||||
|
down = torch.nn.Module()
|
||||||
|
down.block = block
|
||||||
|
down.attn = attn
|
||||||
|
if i_level != self.num_resolutions - 1:
|
||||||
|
down.downsample = Downsample(block_in, dtype=dtype, device=device)
|
||||||
|
self.down.append(down)
|
||||||
|
# middle
|
||||||
|
self.mid = torch.nn.Module()
|
||||||
|
self.mid.block_1 = ResnetBlock(
|
||||||
|
in_channels=block_in, out_channels=block_in, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
|
||||||
|
self.mid.block_2 = ResnetBlock(
|
||||||
|
in_channels=block_in, out_channels=block_in, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
# end
|
||||||
|
self.norm_out = Normalize(block_in, dtype=dtype, device=device)
|
||||||
|
self.conv_out = torch.nn.Conv2d(
|
||||||
|
block_in,
|
||||||
|
2 * z_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
self.swish = torch.nn.SiLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# downsampling
|
||||||
|
hs = [self.conv_in(x)]
|
||||||
|
for i_level in range(self.num_resolutions):
|
||||||
|
for i_block in range(self.num_res_blocks):
|
||||||
|
h = self.down[i_level].block[i_block](hs[-1])
|
||||||
|
hs.append(h)
|
||||||
|
if i_level != self.num_resolutions - 1:
|
||||||
|
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||||
|
# middle
|
||||||
|
h = hs[-1]
|
||||||
|
h = self.mid.block_1(h)
|
||||||
|
h = self.mid.attn_1(h)
|
||||||
|
h = self.mid.block_2(h)
|
||||||
|
# end
|
||||||
|
h = self.norm_out(h)
|
||||||
|
h = self.swish(h)
|
||||||
|
h = self.conv_out(h)
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class VAEDecoder(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ch=128,
|
||||||
|
out_ch=3,
|
||||||
|
ch_mult=(1, 2, 4, 4),
|
||||||
|
num_res_blocks=2,
|
||||||
|
resolution=256,
|
||||||
|
z_channels=16,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_resolutions = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||||
|
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||||
|
# z to block_in
|
||||||
|
self.conv_in = torch.nn.Conv2d(
|
||||||
|
z_channels,
|
||||||
|
block_in,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
# middle
|
||||||
|
self.mid = torch.nn.Module()
|
||||||
|
self.mid.block_1 = ResnetBlock(
|
||||||
|
in_channels=block_in, out_channels=block_in, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
|
||||||
|
self.mid.block_2 = ResnetBlock(
|
||||||
|
in_channels=block_in, out_channels=block_in, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
# upsampling
|
||||||
|
self.up = torch.nn.ModuleList()
|
||||||
|
for i_level in reversed(range(self.num_resolutions)):
|
||||||
|
block = torch.nn.ModuleList()
|
||||||
|
block_out = ch * ch_mult[i_level]
|
||||||
|
for i_block in range(self.num_res_blocks + 1):
|
||||||
|
block.append(
|
||||||
|
ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_out,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
block_in = block_out
|
||||||
|
up = torch.nn.Module()
|
||||||
|
up.block = block
|
||||||
|
if i_level != 0:
|
||||||
|
up.upsample = Upsample(block_in, dtype=dtype, device=device)
|
||||||
|
curr_res = curr_res * 2
|
||||||
|
self.up.insert(0, up) # prepend to get consistent order
|
||||||
|
# end
|
||||||
|
self.norm_out = Normalize(block_in, dtype=dtype, device=device)
|
||||||
|
self.conv_out = torch.nn.Conv2d(
|
||||||
|
block_in,
|
||||||
|
out_ch,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
self.swish = torch.nn.SiLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self, z):
|
||||||
|
# z to block_in
|
||||||
|
hidden = self.conv_in(z)
|
||||||
|
# middle
|
||||||
|
hidden = self.mid.block_1(hidden)
|
||||||
|
hidden = self.mid.attn_1(hidden)
|
||||||
|
hidden = self.mid.block_2(hidden)
|
||||||
|
# upsampling
|
||||||
|
for i_level in reversed(range(self.num_resolutions)):
|
||||||
|
for i_block in range(self.num_res_blocks + 1):
|
||||||
|
hidden = self.up[i_level].block[i_block](hidden)
|
||||||
|
if i_level != 0:
|
||||||
|
hidden = self.up[i_level].upsample(hidden)
|
||||||
|
# end
|
||||||
|
hidden = self.norm_out(hidden)
|
||||||
|
hidden = self.swish(hidden)
|
||||||
|
hidden = self.conv_out(hidden)
|
||||||
|
return hidden
|
||||||
|
|
||||||
|
|
||||||
|
class SDVAE(torch.nn.Module):
|
||||||
|
def __init__(self, dtype=torch.float32, device=None):
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = VAEEncoder(dtype=dtype, device=device)
|
||||||
|
self.decoder = VAEDecoder(dtype=dtype, device=device)
|
||||||
|
|
||||||
|
@torch.autocast("cuda", dtype=torch.float16)
|
||||||
|
def decode(self, latent):
|
||||||
|
return self.decoder(latent)
|
||||||
|
|
||||||
|
@torch.autocast("cuda", dtype=torch.float16)
|
||||||
|
def encode(self, image):
|
||||||
|
hidden = self.encoder(image)
|
||||||
|
mean, logvar = torch.chunk(hidden, 2, dim=1)
|
||||||
|
logvar = torch.clamp(logvar, -30.0, 20.0)
|
||||||
|
std = torch.exp(0.5 * logvar)
|
||||||
|
return mean + std * torch.randn_like(mean)
|
||||||
436
modules/models/sd35/sd3_infer.py
Normal file
436
modules/models/sd35/sd3_infer.py
Normal file
@@ -0,0 +1,436 @@
|
|||||||
|
# NOTE: Must have folder `models` with the following files:
|
||||||
|
# - `clip_g.safetensors` (openclip bigG, same as SDXL)
|
||||||
|
# - `clip_l.safetensors` (OpenAI CLIP-L, same as SDXL)
|
||||||
|
# - `t5xxl.safetensors` (google T5-v1.1-XXL)
|
||||||
|
# - `sd3_medium.safetensors` (or whichever main MMDiT model file)
|
||||||
|
# Also can have
|
||||||
|
# - `sd3_vae.safetensors` (holds the VAE separately if needed)
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from safetensors import safe_open
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from modules.models.sd35 import sd3_impls
|
||||||
|
from modules.models.sd35.other_impls import SD3Tokenizer, SDClipModel, SDXLClipG, T5XXLModel
|
||||||
|
from modules.models.sd35.sd3_impls import SDVAE, BaseModel, CFGDenoiser, SD3LatentFormat
|
||||||
|
|
||||||
|
#################################################################################################
|
||||||
|
### Wrappers for model parts
|
||||||
|
#################################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def load_into(f, model, prefix, device, dtype=None):
|
||||||
|
"""Just a debugging-friendly hack to apply the weights in a safetensors file to the pytorch module."""
|
||||||
|
for key in f.keys():
|
||||||
|
if key.startswith(prefix) and not key.startswith("loss."):
|
||||||
|
path = key[len(prefix) :].split(".")
|
||||||
|
obj = model
|
||||||
|
for p in path:
|
||||||
|
if obj is list:
|
||||||
|
obj = obj[int(p)]
|
||||||
|
else:
|
||||||
|
obj = getattr(obj, p, None)
|
||||||
|
if obj is None:
|
||||||
|
print(
|
||||||
|
f"Skipping key '{key}' in safetensors file as '{p}' does not exist in python model"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
if obj is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
tensor = f.get_tensor(key).to(device=device)
|
||||||
|
if dtype is not None:
|
||||||
|
tensor = tensor.to(dtype=dtype)
|
||||||
|
obj.requires_grad_(False)
|
||||||
|
obj.set_(tensor)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to load key '{key}' in safetensors file: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
CLIPG_CONFIG = {
|
||||||
|
"hidden_act": "gelu",
|
||||||
|
"hidden_size": 1280,
|
||||||
|
"intermediate_size": 5120,
|
||||||
|
"num_attention_heads": 20,
|
||||||
|
"num_hidden_layers": 32,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ClipG:
|
||||||
|
def __init__(self):
|
||||||
|
with safe_open("models/clip_g.safetensors", framework="pt", device="cpu") as f:
|
||||||
|
self.model = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=torch.float32)
|
||||||
|
load_into(f, self.model.transformer, "", "cpu", torch.float32)
|
||||||
|
|
||||||
|
|
||||||
|
CLIPL_CONFIG = {
|
||||||
|
"hidden_act": "quick_gelu",
|
||||||
|
"hidden_size": 768,
|
||||||
|
"intermediate_size": 3072,
|
||||||
|
"num_attention_heads": 12,
|
||||||
|
"num_hidden_layers": 12,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ClipL:
|
||||||
|
def __init__(self):
|
||||||
|
with safe_open("models/clip_l.safetensors", framework="pt", device="cpu") as f:
|
||||||
|
self.model = SDClipModel(
|
||||||
|
layer="hidden",
|
||||||
|
layer_idx=-2,
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.float32,
|
||||||
|
layer_norm_hidden_state=False,
|
||||||
|
return_projected_pooled=False,
|
||||||
|
textmodel_json_config=CLIPL_CONFIG,
|
||||||
|
)
|
||||||
|
load_into(f, self.model.transformer, "", "cpu", torch.float32)
|
||||||
|
|
||||||
|
|
||||||
|
T5_CONFIG = {
|
||||||
|
"d_ff": 10240,
|
||||||
|
"d_model": 4096,
|
||||||
|
"num_heads": 64,
|
||||||
|
"num_layers": 24,
|
||||||
|
"vocab_size": 32128,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class T5XXL:
|
||||||
|
def __init__(self):
|
||||||
|
with safe_open("models/t5xxl.safetensors", framework="pt", device="cpu") as f:
|
||||||
|
self.model = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float32)
|
||||||
|
load_into(f, self.model.transformer, "", "cpu", torch.float32)
|
||||||
|
|
||||||
|
|
||||||
|
class SD3:
|
||||||
|
def __init__(self, model, shift, verbose=False):
|
||||||
|
with safe_open(model, framework="pt", device="cpu") as f:
|
||||||
|
self.model = BaseModel(
|
||||||
|
shift=shift,
|
||||||
|
file=f,
|
||||||
|
prefix="model.diffusion_model.",
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.float16,
|
||||||
|
verbose=verbose,
|
||||||
|
).eval()
|
||||||
|
load_into(f, self.model, "model.", "cpu", torch.float16)
|
||||||
|
|
||||||
|
|
||||||
|
class VAE:
|
||||||
|
def __init__(self, model):
|
||||||
|
with safe_open(model, framework="pt", device="cpu") as f:
|
||||||
|
self.model = SDVAE(device="cpu", dtype=torch.float16).eval().cpu()
|
||||||
|
prefix = ""
|
||||||
|
if any(k.startswith("first_stage_model.") for k in f.keys()):
|
||||||
|
prefix = "first_stage_model."
|
||||||
|
load_into(f, self.model, prefix, "cpu", torch.float16)
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################################################
|
||||||
|
### Main inference logic
|
||||||
|
#################################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
# Note: Sigma shift value, publicly released models use 3.0
|
||||||
|
SHIFT = 3.0
|
||||||
|
# Naturally, adjust to the width/height of the model you have
|
||||||
|
WIDTH = 1024
|
||||||
|
HEIGHT = 1024
|
||||||
|
# Pick your prompt
|
||||||
|
PROMPT = "a photo of a cat"
|
||||||
|
# Most models prefer the range of 4-5, but still work well around 7
|
||||||
|
CFG_SCALE = 4.5
|
||||||
|
# Different models want different step counts but most will be good at 50, albeit that's slow to run
|
||||||
|
# sd3_medium is quite decent at 28 steps
|
||||||
|
STEPS = 40
|
||||||
|
# Seed
|
||||||
|
SEED = 23
|
||||||
|
# SEEDTYPE = "fixed"
|
||||||
|
SEEDTYPE = "rand"
|
||||||
|
# SEEDTYPE = "roll"
|
||||||
|
# Actual model file path
|
||||||
|
# MODEL = "models/sd3_medium.safetensors"
|
||||||
|
# MODEL = "models/sd3.5_large_turbo.safetensors"
|
||||||
|
MODEL = "models/sd3.5_large.safetensors"
|
||||||
|
# VAE model file path, or set None to use the same model file
|
||||||
|
VAEFile = None # "models/sd3_vae.safetensors"
|
||||||
|
# Optional init image file path
|
||||||
|
INIT_IMAGE = None
|
||||||
|
# If init_image is given, this is the percentage of denoising steps to run (1.0 = full denoise, 0.0 = no denoise at all)
|
||||||
|
DENOISE = 0.6
|
||||||
|
# Output file path
|
||||||
|
OUTDIR = "outputs"
|
||||||
|
# SAMPLER
|
||||||
|
# SAMPLER = "euler"
|
||||||
|
SAMPLER = "dpmpp_2m"
|
||||||
|
|
||||||
|
|
||||||
|
class SD3Inferencer:
|
||||||
|
def print(self, txt):
|
||||||
|
if self.verbose:
|
||||||
|
print(txt)
|
||||||
|
|
||||||
|
def load(self, model=MODEL, vae=VAEFile, shift=SHIFT, verbose=False):
|
||||||
|
self.verbose = verbose
|
||||||
|
print("Loading tokenizers...")
|
||||||
|
# NOTE: if you need a reference impl for a high performance CLIP tokenizer instead of just using the HF transformers one,
|
||||||
|
# check https://github.com/Stability-AI/StableSwarmUI/blob/master/src/Utils/CliplikeTokenizer.cs
|
||||||
|
# (T5 tokenizer is different though)
|
||||||
|
self.tokenizer = SD3Tokenizer()
|
||||||
|
print("Loading OpenAI CLIP L...")
|
||||||
|
self.clip_l = ClipL()
|
||||||
|
print("Loading OpenCLIP bigG...")
|
||||||
|
self.clip_g = ClipG()
|
||||||
|
print("Loading Google T5-v1-XXL...")
|
||||||
|
self.t5xxl = T5XXL()
|
||||||
|
print(f"Loading SD3 model {os.path.basename(model)}...")
|
||||||
|
self.sd3 = SD3(model, shift, verbose)
|
||||||
|
print("Loading VAE model...")
|
||||||
|
self.vae = VAE(vae or model)
|
||||||
|
print("Models loaded.")
|
||||||
|
|
||||||
|
def get_empty_latent(self, width, height):
|
||||||
|
self.print("Prep an empty latent...")
|
||||||
|
return torch.ones(1, 16, height // 8, width // 8, device="cpu") * 0.0609
|
||||||
|
|
||||||
|
def get_sigmas(self, sampling, steps):
|
||||||
|
start = sampling.timestep(sampling.sigma_max)
|
||||||
|
end = sampling.timestep(sampling.sigma_min)
|
||||||
|
timesteps = torch.linspace(start, end, steps)
|
||||||
|
sigs = []
|
||||||
|
for x in range(len(timesteps)):
|
||||||
|
ts = timesteps[x]
|
||||||
|
sigs.append(sampling.sigma(ts))
|
||||||
|
sigs += [0.0]
|
||||||
|
return torch.FloatTensor(sigs)
|
||||||
|
|
||||||
|
def get_noise(self, seed, latent):
|
||||||
|
generator = torch.manual_seed(seed)
|
||||||
|
self.print(
|
||||||
|
f"dtype = {latent.dtype}, layout = {latent.layout}, device = {latent.device}"
|
||||||
|
)
|
||||||
|
return torch.randn(
|
||||||
|
latent.size(),
|
||||||
|
dtype=torch.float32,
|
||||||
|
layout=latent.layout,
|
||||||
|
generator=generator,
|
||||||
|
device="cpu",
|
||||||
|
).to(latent.dtype)
|
||||||
|
|
||||||
|
def get_cond(self, prompt):
|
||||||
|
self.print("Encode prompt...")
|
||||||
|
tokens = self.tokenizer.tokenize_with_weights(prompt)
|
||||||
|
l_out, l_pooled = self.clip_l.model.encode_token_weights(tokens["l"])
|
||||||
|
g_out, g_pooled = self.clip_g.model.encode_token_weights(tokens["g"])
|
||||||
|
t5_out, t5_pooled = self.t5xxl.model.encode_token_weights(tokens["t5xxl"])
|
||||||
|
lg_out = torch.cat([l_out, g_out], dim=-1)
|
||||||
|
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
||||||
|
return torch.cat([lg_out, t5_out], dim=-2), torch.cat(
|
||||||
|
(l_pooled, g_pooled), dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
|
def max_denoise(self, sigmas):
|
||||||
|
max_sigma = float(self.sd3.model.model_sampling.sigma_max)
|
||||||
|
sigma = float(sigmas[0])
|
||||||
|
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||||
|
|
||||||
|
def fix_cond(self, cond):
|
||||||
|
cond, pooled = (cond[0].half().cuda(), cond[1].half().cuda())
|
||||||
|
return {"c_crossattn": cond, "y": pooled}
|
||||||
|
|
||||||
|
def do_sampling(
|
||||||
|
self,
|
||||||
|
latent,
|
||||||
|
seed,
|
||||||
|
conditioning,
|
||||||
|
neg_cond,
|
||||||
|
steps,
|
||||||
|
cfg_scale,
|
||||||
|
sampler="dpmpp_2m",
|
||||||
|
denoise=1.0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
self.print("Sampling...")
|
||||||
|
latent = latent.half().cuda()
|
||||||
|
self.sd3.model = self.sd3.model.cuda()
|
||||||
|
noise = self.get_noise(seed, latent).cuda()
|
||||||
|
sigmas = self.get_sigmas(self.sd3.model.model_sampling, steps).cuda()
|
||||||
|
sigmas = sigmas[int(steps * (1 - denoise)) :]
|
||||||
|
conditioning = self.fix_cond(conditioning)
|
||||||
|
neg_cond = self.fix_cond(neg_cond)
|
||||||
|
extra_args = {"cond": conditioning, "uncond": neg_cond, "cond_scale": cfg_scale}
|
||||||
|
noise_scaled = self.sd3.model.model_sampling.noise_scaling(
|
||||||
|
sigmas[0], noise, latent, self.max_denoise(sigmas)
|
||||||
|
)
|
||||||
|
sample_fn = getattr(sd3_impls, f"sample_{sampler}")
|
||||||
|
latent = sample_fn(
|
||||||
|
CFGDenoiser(self.sd3.model), noise_scaled, sigmas, extra_args=extra_args
|
||||||
|
)
|
||||||
|
latent = SD3LatentFormat().process_out(latent)
|
||||||
|
self.sd3.model = self.sd3.model.cpu()
|
||||||
|
self.print("Sampling done")
|
||||||
|
return latent
|
||||||
|
|
||||||
|
def vae_encode(self, image) -> torch.Tensor:
|
||||||
|
self.print("Encoding image to latent...")
|
||||||
|
image = image.convert("RGB")
|
||||||
|
image_np = np.array(image).astype(np.float32) / 255.0
|
||||||
|
image_np = np.moveaxis(image_np, 2, 0)
|
||||||
|
batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0)
|
||||||
|
image_torch = torch.from_numpy(batch_images)
|
||||||
|
image_torch = 2.0 * image_torch - 1.0
|
||||||
|
image_torch = image_torch.cuda()
|
||||||
|
self.vae.model = self.vae.model.cuda()
|
||||||
|
latent = self.vae.model.encode(image_torch).cpu()
|
||||||
|
self.vae.model = self.vae.model.cpu()
|
||||||
|
self.print("Encoded")
|
||||||
|
return latent
|
||||||
|
|
||||||
|
def vae_decode(self, latent) -> Image.Image:
|
||||||
|
self.print("Decoding latent to image...")
|
||||||
|
latent = latent.cuda()
|
||||||
|
self.vae.model = self.vae.model.cuda()
|
||||||
|
image = self.vae.model.decode(latent)
|
||||||
|
image = image.float()
|
||||||
|
self.vae.model = self.vae.model.cpu()
|
||||||
|
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
|
||||||
|
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
|
||||||
|
decoded_np = decoded_np.astype(np.uint8)
|
||||||
|
out_image = Image.fromarray(decoded_np)
|
||||||
|
self.print("Decoded")
|
||||||
|
return out_image
|
||||||
|
|
||||||
|
def gen_image(
|
||||||
|
self,
|
||||||
|
prompts=[PROMPT],
|
||||||
|
width=WIDTH,
|
||||||
|
height=HEIGHT,
|
||||||
|
steps=STEPS,
|
||||||
|
cfg_scale=CFG_SCALE,
|
||||||
|
sampler=SAMPLER,
|
||||||
|
seed=SEED,
|
||||||
|
seed_type=SEEDTYPE,
|
||||||
|
out_dir=OUTDIR,
|
||||||
|
init_image=INIT_IMAGE,
|
||||||
|
denoise=DENOISE,
|
||||||
|
):
|
||||||
|
latent = self.get_empty_latent(width, height)
|
||||||
|
if init_image:
|
||||||
|
image_data = Image.open(init_image)
|
||||||
|
image_data = image_data.resize((width, height), Image.LANCZOS)
|
||||||
|
latent = self.vae_encode(image_data)
|
||||||
|
latent = SD3LatentFormat().process_in(latent)
|
||||||
|
neg_cond = self.get_cond("")
|
||||||
|
seed_num = None
|
||||||
|
pbar = tqdm(enumerate(prompts), total=len(prompts), position=0, leave=True)
|
||||||
|
for i, prompt in pbar:
|
||||||
|
if seed_type == "roll":
|
||||||
|
seed_num = seed if seed_num is None else seed_num + 1
|
||||||
|
elif seed_type == "rand":
|
||||||
|
seed_num = torch.randint(0, 100000, (1,)).item()
|
||||||
|
else: # fixed
|
||||||
|
seed_num = seed
|
||||||
|
conditioning = self.get_cond(prompt)
|
||||||
|
sampled_latent = self.do_sampling(
|
||||||
|
latent,
|
||||||
|
seed_num,
|
||||||
|
conditioning,
|
||||||
|
neg_cond,
|
||||||
|
steps,
|
||||||
|
cfg_scale,
|
||||||
|
sampler,
|
||||||
|
denoise if init_image else 1.0,
|
||||||
|
)
|
||||||
|
image = self.vae_decode(sampled_latent)
|
||||||
|
save_path = os.path.join(out_dir, f"{i:06d}.png")
|
||||||
|
self.print(f"Will save to {save_path}")
|
||||||
|
image.save(save_path)
|
||||||
|
self.print("Done")
|
||||||
|
|
||||||
|
|
||||||
|
CONFIGS = {
|
||||||
|
"sd3_medium": {
|
||||||
|
"shift": 1.0,
|
||||||
|
"cfg": 5.0,
|
||||||
|
"steps": 50,
|
||||||
|
"sampler": "dpmpp_2m",
|
||||||
|
},
|
||||||
|
"sd3.5_large": {
|
||||||
|
"shift": 3.0,
|
||||||
|
"cfg": 4.5,
|
||||||
|
"steps": 40,
|
||||||
|
"sampler": "dpmpp_2m",
|
||||||
|
},
|
||||||
|
"sd3.5_large_turbo": {"shift": 3.0, "cfg": 1.0, "steps": 4, "sampler": "euler"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main(
|
||||||
|
prompt=PROMPT,
|
||||||
|
model=MODEL,
|
||||||
|
out_dir=OUTDIR,
|
||||||
|
postfix=None,
|
||||||
|
seed=SEED,
|
||||||
|
seed_type=SEEDTYPE,
|
||||||
|
sampler=None,
|
||||||
|
steps=None,
|
||||||
|
cfg=None,
|
||||||
|
shift=None,
|
||||||
|
width=WIDTH,
|
||||||
|
height=HEIGHT,
|
||||||
|
vae=VAEFile,
|
||||||
|
init_image=INIT_IMAGE,
|
||||||
|
denoise=DENOISE,
|
||||||
|
verbose=False,
|
||||||
|
):
|
||||||
|
steps = steps or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["steps"]
|
||||||
|
cfg = cfg or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["cfg"]
|
||||||
|
shift = shift or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["shift"]
|
||||||
|
sampler = (
|
||||||
|
sampler or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["sampler"]
|
||||||
|
)
|
||||||
|
|
||||||
|
inferencer = SD3Inferencer()
|
||||||
|
inferencer.load(model, vae, shift, verbose)
|
||||||
|
|
||||||
|
if isinstance(prompt, str):
|
||||||
|
if os.path.splitext(prompt)[-1] == ".txt":
|
||||||
|
with open(prompt, "r") as f:
|
||||||
|
prompts = [l.strip() for l in f.readlines()]
|
||||||
|
else:
|
||||||
|
prompts = [prompt]
|
||||||
|
|
||||||
|
out_dir = os.path.join(
|
||||||
|
out_dir,
|
||||||
|
os.path.splitext(os.path.basename(model))[0],
|
||||||
|
os.path.splitext(os.path.basename(prompt))[0][:50]
|
||||||
|
+ (postfix or datetime.datetime.now().strftime("_%Y-%m-%dT%H-%M-%S")),
|
||||||
|
)
|
||||||
|
print(f"Saving images to {out_dir}")
|
||||||
|
os.makedirs(out_dir, exist_ok=False)
|
||||||
|
|
||||||
|
inferencer.gen_image(
|
||||||
|
prompts,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
steps,
|
||||||
|
cfg,
|
||||||
|
sampler,
|
||||||
|
seed,
|
||||||
|
seed_type,
|
||||||
|
out_dir,
|
||||||
|
init_image,
|
||||||
|
denoise,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
fire.Fire(main)
|
||||||
Reference in New Issue
Block a user