mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-26 17:29:09 +00:00
rework sd1.5 and sdxl from scratch
This commit is contained in:
@@ -5,7 +5,6 @@ from collections import namedtuple
|
||||
from backend.text_processing import parsing, emphasis
|
||||
from backend.text_processing.textual_inversion import EmbeddingDatabase
|
||||
|
||||
|
||||
PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
|
||||
last_extra_generation_params = {}
|
||||
|
||||
@@ -47,11 +46,12 @@ class CLIPEmbeddingForTextualInversion(torch.nn.Module):
|
||||
return torch.stack(vecs)
|
||||
|
||||
|
||||
class ClassicTextProcessingEngine(torch.nn.Module):
|
||||
def __init__(self, text_encoder, tokenizer, chunk_length=75,
|
||||
embedding_dir=None, embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original",
|
||||
text_projection=False, minimal_clip_skip=1, clip_skip=1, return_pooled=False, final_layer_norm=True,
|
||||
callback_before_encode=None):
|
||||
class ClassicTextProcessingEngine:
|
||||
def __init__(
|
||||
self, text_encoder, tokenizer, chunk_length=75,
|
||||
embedding_dir=None, embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="Original",
|
||||
text_projection=False, minimal_clip_skip=1, clip_skip=1, return_pooled=False, final_layer_norm=True
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.embeddings = EmbeddingDatabase(tokenizer, embedding_expected_shape)
|
||||
@@ -71,7 +71,6 @@ class ClassicTextProcessingEngine(torch.nn.Module):
|
||||
self.clip_skip = clip_skip
|
||||
self.return_pooled = return_pooled
|
||||
self.final_layer_norm = final_layer_norm
|
||||
self.callback_before_encode = callback_before_encode
|
||||
|
||||
self.chunk_length = chunk_length
|
||||
|
||||
@@ -133,7 +132,7 @@ class ClassicTextProcessingEngine(torch.nn.Module):
|
||||
pooled_output = outputs.pooler_output
|
||||
|
||||
if self.text_projection:
|
||||
pooled_output = pooled_output.float().to(self.text_encoder.text_projection.device) @ self.text_encoder.text_projection.float()
|
||||
pooled_output = self.text_encoder.transformer.text_projection(pooled_output)
|
||||
|
||||
z.pooled = pooled_output
|
||||
return z
|
||||
@@ -240,10 +239,7 @@ class ClassicTextProcessingEngine(torch.nn.Module):
|
||||
|
||||
return batch_chunks, token_count
|
||||
|
||||
def forward(self, texts):
|
||||
if self.callback_before_encode is not None:
|
||||
self.callback_before_encode(self, texts)
|
||||
|
||||
def __call__(self, texts):
|
||||
batch_chunks, token_count = self.process_texts(texts)
|
||||
|
||||
used_embeddings = {}
|
||||
|
||||
Reference in New Issue
Block a user