rework sd1.5 and sdxl from scratch

layerdiffusion
2024-08-04 20:23:01 -07:00
parent e28e11fa97
commit 0863765173
25 changed files with 440 additions and 162 deletions


@@ -5,7 +5,6 @@ from collections import namedtuple
 from backend.text_processing import parsing, emphasis
 from backend.text_processing.textual_inversion import EmbeddingDatabase
 
 PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
-last_extra_generation_params = {}
@@ -47,11 +46,12 @@ class CLIPEmbeddingForTextualInversion(torch.nn.Module):
         return torch.stack(vecs)
 
 
-class ClassicTextProcessingEngine(torch.nn.Module):
-    def __init__(self, text_encoder, tokenizer, chunk_length=75,
-                 embedding_dir=None, embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original",
-                 text_projection=False, minimal_clip_skip=1, clip_skip=1, return_pooled=False, final_layer_norm=True,
-                 callback_before_encode=None):
+class ClassicTextProcessingEngine:
+    def __init__(
+            self, text_encoder, tokenizer, chunk_length=75,
+            embedding_dir=None, embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="Original",
+            text_projection=False, minimal_clip_skip=1, clip_skip=1, return_pooled=False, final_layer_norm=True
+    ):
         super().__init__()
 
         self.embeddings = EmbeddingDatabase(tokenizer, embedding_expected_shape)
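After this change, ClassicTextProcessingEngine is a plain Python class rather than a torch.nn.Module: it is constructed directly, the callback_before_encode hook is gone, the default emphasis name becomes "Original", and (as a later hunk shows) it is invoked through __call__ instead of forward. A minimal usage sketch follows; the import path and the Hugging Face CLIP tokenizer/encoder pair are assumptions for illustration, not something this diff confirms.

# Sketch only: import path, model id, and encoder compatibility are assumptions.
from transformers import CLIPTokenizer, CLIPTextModel

from backend.text_processing.classic_engine import ClassicTextProcessingEngine  # assumed module path

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

engine = ClassicTextProcessingEngine(
    text_encoder, tokenizer,
    embedding_key='clip_l',
    embedding_expected_shape=768,
    emphasis_name="Original",  # default renamed from "original" in this commit
    clip_skip=1,
)

cond = engine(["a photo of a cat"])  # plain __call__ replaces the old nn.Module forward()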
@@ -71,7 +71,6 @@ class ClassicTextProcessingEngine(torch.nn.Module):
         self.clip_skip = clip_skip
         self.return_pooled = return_pooled
         self.final_layer_norm = final_layer_norm
-        self.callback_before_encode = callback_before_encode
 
         self.chunk_length = chunk_length
@@ -133,7 +132,7 @@ class ClassicTextProcessingEngine(torch.nn.Module):
         pooled_output = outputs.pooler_output
 
         if self.text_projection:
-            pooled_output = pooled_output.float().to(self.text_encoder.text_projection.device) @ self.text_encoder.text_projection.float()
+            pooled_output = self.text_encoder.transformer.text_projection(pooled_output)
 
         z.pooled = pooled_output
         return z
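The pooled-output projection also moves from an explicit matrix multiply against self.text_encoder.text_projection to calling the projection on the wrapped transformer. As a small sanity sketch, assuming the projection is exposed as a bias-free nn.Linear (the call syntax suggests this, but the diff does not show it) with an arbitrary width of 1280 for illustration, calling the module is equivalent to a matmul against its transposed weight, i.e. the kind of explicit product the old line wrote out by hand:

import torch
import torch.nn as nn

# Assumption for illustration: the projection is a bias-free nn.Linear.
hidden_dim = 1280
text_projection = nn.Linear(hidden_dim, hidden_dim, bias=False)

pooled_output = torch.randn(2, hidden_dim)

projected_new = text_projection(pooled_output)            # new path: call the projection module
projected_old = pooled_output @ text_projection.weight.T  # old-style explicit matmul

assert torch.allclose(projected_new, projected_old)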
@@ -240,10 +239,7 @@ class ClassicTextProcessingEngine(torch.nn.Module):
         return batch_chunks, token_count
 
-    def forward(self, texts):
-        if self.callback_before_encode is not None:
-            self.callback_before_encode(self, texts)
-
+    def __call__(self, texts):
         batch_chunks, token_count = self.process_texts(texts)
 
         used_embeddings = {}