Text Processing Engine is Finished

Reproduces all previous results 100%, including TI embeddings, LoRAs in CLIP, emphasis settings, BREAK, timestep swap scheduling, AB mixture, advanced uncond, etc.
Backend is 85% finished
This commit is contained in:
layerdiffusion
2024-08-04 14:13:37 -07:00
parent 2791203d5b
commit a72154405e
8 changed files with 123 additions and 90 deletions

View File

@@ -9,7 +9,7 @@ import backend.nn.unet
from omegaconf import OmegaConf
from modules.sd_models_config import find_checkpoint_config
from modules.shared import cmd_opts
from modules.shared import cmd_opts, opts
from modules import sd_hijack
from modules.sd_models_xl import extend_sdxl
from ldm.util import instantiate_from_config
@@ -17,6 +17,7 @@ from modules_forge import clip
from modules_forge.unet_patcher import UnetPatcher
from backend.loader import load_huggingface_components
from backend.modules.k_model import KModel
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
import open_clip
from transformers import CLIPTextModel, CLIPTokenizer
@@ -148,6 +149,15 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
sd_model.first_stage_model = forge_objects.vae.first_stage_model
sd_model.model.diffusion_model = forge_objects.unet.model
def set_clip_skip_callback(m, ts):
    """Refresh ``m.clip_skip`` from the current ``opts.CLIP_stop_at_last_layers``.

    Passed as ``callback_before_encode`` to the SDXL text processing
    engines. The option is read inside the callback, so whatever value
    ``opts`` holds at call time is the one applied.

    Args:
        m: Object whose ``clip_skip`` attribute is overwritten
            (presumably the text processing engine -- confirm against
            the engine's callback invocation).
        ts: Unused; accepted only to match the callback signature.
    """
    current_skip = opts.CLIP_stop_at_last_layers
    m.clip_skip = current_skip
def set_clip_skip_callback_and_move_model(m, ts):
    """Load the CLIP patcher, then refresh ``m.clip_skip`` from ``opts``.

    Passed as ``callback_before_encode`` to the SD15 and SD21 text
    processing engines. It first hands ``sd_model.forge_objects.clip.patcher``
    to ``memory_management.load_model_gpu`` and only afterwards copies
    ``opts.CLIP_stop_at_last_layers`` onto ``m``.

    Args:
        m: Object whose ``clip_skip`` attribute is overwritten
            (presumably the text processing engine -- confirm against
            the engine's callback invocation).
        ts: Unused; accepted only to match the callback signature.
    """
    # Looked up on every call so the patcher currently attached to sd_model is used.
    clip_patcher = sd_model.forge_objects.clip.patcher
    memory_management.load_model_gpu(clip_patcher)
    m.clip_skip = opts.CLIP_stop_at_last_layers
conditioner = getattr(sd_model, 'conditioner', None)
if conditioner:
text_cond_models = []
@@ -156,23 +166,44 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
embedder = conditioner.embedders[i]
typename = type(embedder).__name__
if typename == 'FrozenCLIPEmbedder': # SDXL Clip L
embedder.tokenizer = forge_objects.clip.tokenizer.clip_l
embedder.transformer = forge_objects.clip.cond_stage_model.clip_l.transformer
model_embeddings = embedder.transformer.text_model.embeddings
model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
model_embeddings.token_embedding, sd_hijack.model_hijack)
embedder = clip.CLIP_SD_XL_L(embedder, sd_hijack.model_hijack)
conditioner.embedders[i] = embedder
engine = ClassicTextProcessingEngine(
text_encoder=forge_objects.clip.cond_stage_model.clip_l,
tokenizer=forge_objects.clip.tokenizer.clip_l,
embedding_dir=cmd_opts.embeddings_dir,
embedding_key='clip_l',
embedding_expected_shape=2048,
emphasis_name=opts.emphasis,
text_projection=False,
minimal_clip_skip=2,
clip_skip=2,
return_pooled=False,
final_layer_norm=False,
callback_before_encode=set_clip_skip_callback
)
engine.is_trainable = False # for sgm codebase
engine.legacy_ucg_val = None # for sgm codebase
engine.input_key = 'txt' # for sgm codebase
conditioner.embedders[i] = engine
text_cond_models.append(embedder)
elif typename == 'FrozenOpenCLIPEmbedder2': # SDXL Clip G
embedder.tokenizer = forge_objects.clip.tokenizer.clip_g
embedder.transformer = forge_objects.clip.cond_stage_model.clip_g.transformer
embedder.text_projection = forge_objects.clip.cond_stage_model.clip_g.text_projection
model_embeddings = embedder.transformer.text_model.embeddings
model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
model_embeddings.token_embedding, sd_hijack.model_hijack, textual_inversion_key='clip_g')
embedder = clip.CLIP_SD_XL_G(embedder, sd_hijack.model_hijack)
conditioner.embedders[i] = embedder
engine = ClassicTextProcessingEngine(
text_encoder=forge_objects.clip.cond_stage_model.clip_g,
tokenizer=forge_objects.clip.tokenizer.clip_g,
embedding_dir=cmd_opts.embeddings_dir,
embedding_key='clip_g',
embedding_expected_shape=2048,
emphasis_name=opts.emphasis,
text_projection=True,
minimal_clip_skip=2,
clip_skip=2,
return_pooled=True,
final_layer_norm=False,
callback_before_encode=set_clip_skip_callback
)
engine.is_trainable = False # for sgm codebase
engine.legacy_ucg_val = None # for sgm codebase
engine.input_key = 'txt' # for sgm codebase
conditioner.embedders[i] = engine
text_cond_models.append(embedder)
if len(text_cond_models) == 1:
@@ -180,19 +211,37 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
else:
sd_model.cond_stage_model = conditioner
elif type(sd_model.cond_stage_model).__name__ == 'FrozenCLIPEmbedder': # SD15 Clip
sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_l
sd_model.cond_stage_model.transformer = forge_objects.clip.cond_stage_model.clip_l.transformer
model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
model_embeddings.token_embedding, sd_hijack.model_hijack)
sd_model.cond_stage_model = clip.CLIP_SD_15_L(sd_model.cond_stage_model, sd_hijack.model_hijack)
engine = ClassicTextProcessingEngine(
text_encoder=forge_objects.clip.cond_stage_model.clip_l,
tokenizer=forge_objects.clip.tokenizer.clip_l,
embedding_dir=cmd_opts.embeddings_dir,
embedding_key='clip_l',
embedding_expected_shape=768,
emphasis_name=opts.emphasis,
text_projection=False,
minimal_clip_skip=1,
clip_skip=1,
return_pooled=False,
final_layer_norm=True,
callback_before_encode=set_clip_skip_callback_and_move_model
)
sd_model.cond_stage_model = engine
elif type(sd_model.cond_stage_model).__name__ == 'FrozenOpenCLIPEmbedder': # SD21 Clip
sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_l
sd_model.cond_stage_model.transformer = forge_objects.clip.cond_stage_model.clip_l.transformer
model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
model_embeddings.token_embedding, sd_hijack.model_hijack)
sd_model.cond_stage_model = clip.CLIP_SD_21_H(sd_model.cond_stage_model, sd_hijack.model_hijack)
engine = ClassicTextProcessingEngine(
text_encoder=forge_objects.clip.cond_stage_model.clip_l,
tokenizer=forge_objects.clip.tokenizer.clip_l,
embedding_dir=cmd_opts.embeddings_dir,
embedding_key='clip_l',
embedding_expected_shape=1024,
emphasis_name=opts.emphasis,
text_projection=False,
minimal_clip_skip=1,
clip_skip=1,
return_pooled=False,
final_layer_norm=True,
callback_before_encode=set_clip_skip_callback_and_move_model
)
sd_model.cond_stage_model = engine
else:
raise NotImplementedError('Bad Clip Class Name:' + type(sd_model.cond_stage_model).__name__)