diff --git a/backend/args.py b/backend/args.py
index 8c8e1cbf..0302dfc1 100644
--- a/backend/args.py
+++ b/backend/args.py
@@ -57,3 +57,9 @@ parser.add_argument("--cuda-stream", action="store_true")
parser.add_argument("--pin-shared-memory", action="store_true")
args = parser.parse_known_args()[0]
+
+# Some dynamic args that may be changed by webui rather than cmd flags.
+dynamic_args = dict(
+ embedding_dir='./embeddings',
+ emphasis_name='original'
+)
diff --git a/backend/diffusion_engine/base.py b/backend/diffusion_engine/base.py
new file mode 100644
index 00000000..a4b2a523
--- /dev/null
+++ b/backend/diffusion_engine/base.py
@@ -0,0 +1,60 @@
+import torch
+
+
+class ForgeObjects:
+ def __init__(self, unet, clip, vae, clipvision):
+ self.unet = unet
+ self.clip = clip
+ self.vae = vae
+ self.clipvision = clipvision
+
+ def shallow_copy(self):
+ return ForgeObjects(
+ self.unet,
+ self.clip,
+ self.vae,
+ self.clipvision
+ )
+
+
+class ForgeDiffusionEngine:
+ matched_guesses = []
+
+ def __init__(self, estimated_config, huggingface_components):
+ self.model_config = estimated_config
+ self.is_inpaint = estimated_config.inpaint_model()
+
+ self.forge_objects = None
+ self.forge_objects_original = None
+ self.forge_objects_after_applying_lora = None
+
+ self.current_lora_hash = str([])
+ self.tiling_enabled = False
+
+ # WebUI Dirty Legacy
+ self.cond_stage_key = 'txt'
+ self.is_sd3 = False
+ self.latent_channels = 4
+ self.is_sdxl = False
+ self.is_sdxl_inpaint = False
+ self.is_sd2 = False
+ self.is_sd1 = False
+ self.is_ssd = False
+
+ def set_clip_skip(self, clip_skip):
+ pass
+
+ def get_first_stage_encoding(self, x):
+ return x # legacy code, do not change
+
+ def get_learned_conditioning(self, prompt: list[str]):
+ pass
+
+ def encode_first_stage(self, x):
+ pass
+
+ def decode_first_stage(self, x):
+ pass
+
+ def get_prompt_lengths_on_ui(self, prompt):
+ pass
diff --git a/backend/diffusion_engine/sd15.py b/backend/diffusion_engine/sd15.py
index 792d6005..82520a67 100644
--- a/backend/diffusion_engine/sd15.py
+++ b/backend/diffusion_engine/sd15.py
@@ -1 +1,81 @@
-#
+import torch
+
+from huggingface_guess import model_list
+from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
+from backend.patcher.clip import CLIP
+from backend.patcher.vae import VAE
+from backend.patcher.unet import UnetPatcher
+from backend.text_processing.classic_engine import ClassicTextProcessingEngine
+from backend.args import dynamic_args
+from backend import memory_management
+
+
+class StableDiffusion(ForgeDiffusionEngine):
+ matched_guesses = [model_list.SD15]
+
+ def __init__(self, estimated_config, huggingface_components):
+ super().__init__(estimated_config, huggingface_components)
+
+ clip = CLIP(
+ model_dict={
+ 'clip_l': huggingface_components['text_encoder']
+ },
+ tokenizer_dict={
+ 'clip_l': huggingface_components['tokenizer']
+ }
+ )
+
+ vae = VAE(model=huggingface_components['vae'])
+
+ unet = UnetPatcher.from_model(
+ model=huggingface_components['unet'],
+ diffusers_scheduler=huggingface_components['scheduler']
+ )
+
+ self.text_processing_engine = ClassicTextProcessingEngine(
+ text_encoder=clip.cond_stage_model.clip_l,
+ tokenizer=clip.tokenizer.clip_l,
+ embedding_dir=dynamic_args['embedding_dir'],
+ embedding_key='clip_l',
+ embedding_expected_shape=768,
+ emphasis_name=dynamic_args['emphasis_name'],
+ text_projection=False,
+ minimal_clip_skip=1,
+ clip_skip=1,
+ return_pooled=False,
+ final_layer_norm=True,
+ callback_before_encode=None
+ )
+
+ self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
+ self.forge_objects_original = self.forge_objects.shallow_copy()
+ self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
+
+ # WebUI Legacy
+ self.is_sd1 = True
+
+ def set_clip_skip(self, clip_skip):
+ self.text_processing_engine.clip_skip = clip_skip
+
+ @torch.inference_mode()
+ def get_learned_conditioning(self, prompt: list[str]):
+ memory_management.load_model_gpu(self.forge_objects.clip.patcher)
+ cond = self.text_processing_engine(prompt)
+ return cond
+
+ @torch.inference_mode()
+ def get_prompt_lengths_on_ui(self, prompt):
+ _, token_count = self.text_processing_engine.process_texts([prompt])
+ return token_count, self.text_processing_engine.get_target_prompt_token_count(token_count)
+
+ @torch.inference_mode()
+ def encode_first_stage(self, x):
+ sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
+ sample = self.forge_objects.vae.first_stage_model.process_in(sample)
+ return sample.to(x)
+
+ @torch.inference_mode()
+ def decode_first_stage(self, x):
+ sample = self.forge_objects.vae.first_stage_model.process_out(x)
+ sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
+ return sample.to(x)
diff --git a/backend/diffusion_engine/sdxl.py b/backend/diffusion_engine/sdxl.py
index 792d6005..f255c3dd 100644
--- a/backend/diffusion_engine/sdxl.py
+++ b/backend/diffusion_engine/sdxl.py
@@ -1 +1,132 @@
-#
+import torch
+
+from huggingface_guess import model_list
+from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
+from backend.patcher.clip import CLIP
+from backend.patcher.vae import VAE
+from backend.patcher.unet import UnetPatcher
+from backend.text_processing.classic_engine import ClassicTextProcessingEngine
+from backend.args import dynamic_args
+from backend import memory_management
+from backend.nn.unet import Timestep
+
+
+class StableDiffusionXL(ForgeDiffusionEngine):
+ matched_guesses = [model_list.SDXL]
+
+ def __init__(self, estimated_config, huggingface_components):
+ super().__init__(estimated_config, huggingface_components)
+
+ clip = CLIP(
+ model_dict={
+ 'clip_l': huggingface_components['text_encoder'],
+ 'clip_g': huggingface_components['text_encoder_2']
+ },
+ tokenizer_dict={
+ 'clip_l': huggingface_components['tokenizer'],
+ 'clip_g': huggingface_components['tokenizer_2']
+ }
+ )
+
+ vae = VAE(model=huggingface_components['vae'])
+
+ unet = UnetPatcher.from_model(
+ model=huggingface_components['unet'],
+ diffusers_scheduler=huggingface_components['scheduler']
+ )
+
+ self.text_processing_engine_l = ClassicTextProcessingEngine(
+ text_encoder=clip.cond_stage_model.clip_l,
+ tokenizer=clip.tokenizer.clip_l,
+ embedding_dir=dynamic_args['embedding_dir'],
+ embedding_key='clip_l',
+ embedding_expected_shape=2048,
+ emphasis_name=dynamic_args['emphasis_name'],
+ text_projection=False,
+ minimal_clip_skip=2,
+ clip_skip=2,
+ return_pooled=False,
+ final_layer_norm=False,
+ )
+
+ self.text_processing_engine_g = ClassicTextProcessingEngine(
+ text_encoder=clip.cond_stage_model.clip_g,
+ tokenizer=clip.tokenizer.clip_g,
+ embedding_dir=dynamic_args['embedding_dir'],
+ embedding_key='clip_g',
+ embedding_expected_shape=2048,
+ emphasis_name=dynamic_args['emphasis_name'],
+ text_projection=True,
+ minimal_clip_skip=2,
+ clip_skip=2,
+ return_pooled=True,
+ final_layer_norm=False,
+ )
+
+ self.embedder = Timestep(256)
+
+ self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
+ self.forge_objects_original = self.forge_objects.shallow_copy()
+ self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
+
+ # WebUI Legacy
+ self.is_sdxl = True
+
+ def set_clip_skip(self, clip_skip):
+ self.text_processing_engine_l.clip_skip = clip_skip
+ self.text_processing_engine_g.clip_skip = clip_skip
+
+ @torch.inference_mode()
+ def get_learned_conditioning(self, prompt: list[str]):
+ memory_management.load_model_gpu(self.forge_objects.clip.patcher)
+
+ cond_l = self.text_processing_engine_l(prompt)
+ cond_g, clip_pooled = self.text_processing_engine_g(prompt)
+
+ width = getattr(prompt, 'width', 1024) or 1024
+ height = getattr(prompt, 'height', 1024) or 1024
+ is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
+
+ crop_w = 0
+ crop_h = 0
+ target_width = width
+ target_height = height
+
+ out = [
+ self.embedder(torch.Tensor([height])), self.embedder(torch.Tensor([width])),
+ self.embedder(torch.Tensor([crop_h])), self.embedder(torch.Tensor([crop_w])),
+ self.embedder(torch.Tensor([target_height])), self.embedder(torch.Tensor([target_width]))
+ ]
+
+ flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1).to(clip_pooled)
+
+ force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in prompt)
+
+ if force_zero_negative_prompt:
+ clip_pooled = torch.zeros_like(clip_pooled)
+ cond_l = torch.zeros_like(cond_l)
+ cond_g = torch.zeros_like(cond_g)
+
+ cond = dict(
+ crossattn=torch.cat([cond_l, cond_g], dim=2),
+ vector=torch.cat([clip_pooled, flat], dim=1),
+ )
+
+ return cond
+
+ @torch.inference_mode()
+ def get_prompt_lengths_on_ui(self, prompt):
+ _, token_count = self.text_processing_engine_l.process_texts([prompt])
+ return token_count, self.text_processing_engine_l.get_target_prompt_token_count(token_count)
+
+ @torch.inference_mode()
+ def encode_first_stage(self, x):
+ sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
+ sample = self.forge_objects.vae.first_stage_model.process_in(sample)
+ return sample.to(x)
+
+ @torch.inference_mode()
+ def decode_first_stage(self, x):
+ sample = self.forge_objects.vae.first_stage_model.process_out(x)
+ sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
+ return sample.to(x)
diff --git a/backend/loader.py b/backend/loader.py
index f7b4b9d5..8072299a 100644
--- a/backend/loader.py
+++ b/backend/loader.py
@@ -1,16 +1,23 @@
import os
+import torch
import logging
import importlib
import huggingface_guess
from diffusers import DiffusionPipeline
from transformers import modeling_utils
-from backend.state_dict import try_filter_state_dict, transformers_convert, load_state_dict, state_dict_key_replace
+from backend.state_dict import try_filter_state_dict, load_state_dict
from backend.operations import using_forge_operations
from backend.nn.vae import IntegratedAutoencoderKL
from backend.nn.clip import IntegratedCLIP, CLIPTextConfig
from backend.nn.unet import IntegratedUNet2DConditionModel
+from backend.diffusion_engine.sd15 import StableDiffusion
+from backend.diffusion_engine.sdxl import StableDiffusionXL
+
+
+possible_models = [StableDiffusion, StableDiffusionXL]
+
logging.getLogger("diffusers").setLevel(logging.ERROR)
dir_path = os.path.dirname(__file__)
@@ -27,61 +34,84 @@ def load_component(guess, component_name, lib_name, cls_name, repo_path, state_d
cls = getattr(importlib.import_module(lib_name), cls_name)
return cls.from_pretrained(os.path.join(repo_path, component_name))
if cls_name in ['AutoencoderKL']:
- sd = try_filter_state_dict(state_dict, ['first_stage_model.', 'vae.'])
config = IntegratedAutoencoderKL.load_config(config_path)
with using_forge_operations():
model = IntegratedAutoencoderKL.from_config(config)
- load_state_dict(model, sd)
+ load_state_dict(model, state_dict)
return model
if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
- if component_name == 'text_encoder':
- sd = try_filter_state_dict(state_dict, ['cond_stage_model.', 'conditioner.embedders.0.'])
- elif component_name == 'text_encoder_2':
- sd = try_filter_state_dict(state_dict, ['conditioner.embedders.1.'])
- else:
- raise ValueError(f"Wrong component_name: {component_name}")
-
- if 'model.text_projection' in sd:
- sd = transformers_convert(sd, "model.", "transformer.text_model.", 32)
- sd = state_dict_key_replace(sd, {"model.text_projection": "text_projection",
- "model.text_projection.weight": "text_projection",
- "model.logit_scale": "logit_scale"})
-
config = CLIPTextConfig.from_pretrained(config_path)
with modeling_utils.no_init_weights():
with using_forge_operations():
model = IntegratedCLIP(config)
- load_state_dict(model, sd, ignore_errors=['text_projection', 'logit_scale',
- 'transformer.text_model.embeddings.position_ids'])
+ load_state_dict(model, state_dict, ignore_errors=[
+ 'transformer.text_projection.weight',
+ 'transformer.text_model.embeddings.position_ids',
+ 'logit_scale'
+ ], log_name=cls_name)
+
return model
if cls_name == 'UNet2DConditionModel':
- sd = try_filter_state_dict(state_dict, ['model.diffusion_model.'])
-
with using_forge_operations():
model = IntegratedUNet2DConditionModel.from_config(guess.unet_config)
model._internal_dict = guess.unet_config
- load_state_dict(model, sd)
+ load_state_dict(model, state_dict)
return model
print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
return None
-def load_huggingface_components(sd):
+def split_state_dict(sd):
guess = huggingface_guess.guess(sd)
- repo_name = guess.huggingface_repo
+
+ state_dict = {
+ 'unet': try_filter_state_dict(sd, ['model.diffusion_model.']),
+ 'vae': try_filter_state_dict(sd, guess.vae_key_prefix)
+ }
+
+ sd = guess.process_clip_state_dict(sd)
+ guess.clip_target = guess.clip_target(sd)
+
+ for k, v in guess.clip_target.items():
+ state_dict[v] = try_filter_state_dict(sd, [k + '.'])
+
+ state_dict['ignore'] = sd
+
+ print_dict = {k: len(v) for k, v in state_dict.items()}
+ print(f'StateDict Keys: {print_dict}')
+
+ del state_dict['ignore']
+
+ return state_dict, guess
+
+
+@torch.no_grad()
+def forge_loader(sd):
+ state_dicts, estimated_config = split_state_dict(sd)
+ repo_name = estimated_config.huggingface_repo
+
local_path = os.path.join(dir_path, 'huggingface', repo_name)
- config = DiffusionPipeline.load_config(local_path)
- result = {"repo_path": local_path}
+ config: dict = DiffusionPipeline.load_config(local_path)
+ huggingface_components = {}
for component_name, v in config.items():
if isinstance(v, list) and len(v) == 2:
lib_name, cls_name = v
- component = load_component(guess, component_name, lib_name, cls_name, local_path, sd)
+ component_sd = state_dicts.get(component_name, None)
+ component = load_component(estimated_config, component_name, lib_name, cls_name, local_path, component_sd)
+ if component_sd is not None:
+ del state_dicts[component_name]
if component is not None:
- result[component_name] = component
- return result
+ huggingface_components[component_name] = component
+
+ for M in possible_models:
+ if any(isinstance(estimated_config, x) for x in M.matched_guesses):
+ return M(estimated_config=estimated_config, huggingface_components=huggingface_components)
+
+ print('Failed to recognize model type!')
+ return None
diff --git a/backend/modules/k_model.py b/backend/modules/k_model.py
index 03e1b644..16236227 100644
--- a/backend/modules/k_model.py
+++ b/backend/modules/k_model.py
@@ -5,14 +5,14 @@ from backend.modules.k_prediction import k_prediction_from_diffusers_scheduler
class KModel(torch.nn.Module):
- def __init__(self, huggingface_components, storage_dtype, computation_dtype):
+ def __init__(self, model, diffusers_scheduler, storage_dtype, computation_dtype):
super().__init__()
self.storage_dtype = storage_dtype
self.computation_dtype = computation_dtype
- self.diffusion_model = huggingface_components['unet']
- self.predictor = k_prediction_from_diffusers_scheduler(huggingface_components['scheduler'])
+ self.diffusion_model = model
+ self.predictor = k_prediction_from_diffusers_scheduler(diffusers_scheduler)
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t
diff --git a/backend/modules/k_prediction.py b/backend/modules/k_prediction.py
index f49f5d50..ead306bb 100644
--- a/backend/modules/k_prediction.py
+++ b/backend/modules/k_prediction.py
@@ -108,11 +108,11 @@ class Prediction(AbstractPrediction):
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
- self.set_sigmas(sigmas)
- def set_sigmas(self, sigmas):
+ self.register_buffer('alphas_cumprod', alphas_cumprod.float())
self.register_buffer('sigmas', sigmas.float())
self.register_buffer('log_sigmas', sigmas.log().float())
+ return
@property
def sigma_min(self):
diff --git a/backend/nn/clip.py b/backend/nn/clip.py
index c65f7b2a..373fb73d 100644
--- a/backend/nn/clip.py
+++ b/backend/nn/clip.py
@@ -7,5 +7,6 @@ class IntegratedCLIP(torch.nn.Module):
def __init__(self, config: CLIPTextConfig):
super().__init__()
self.transformer = CLIPTextModel(config)
- self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
- self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
+ embed_dim = config.hidden_size
+ self.transformer.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
+ self.transformer.text_projection.weight.copy_(torch.eye(embed_dim))
diff --git a/backend/patcher/clip.py b/backend/patcher/clip.py
index 260b91c1..b6c085da 100644
--- a/backend/patcher/clip.py
+++ b/backend/patcher/clip.py
@@ -1,24 +1,10 @@
-import torch
-
from backend import memory_management
from backend.patcher.base import ModelPatcher
-
-
-class JointTokenizer:
- def __init__(self, huggingface_components):
- self.clip_l = huggingface_components.get('tokenizer', None)
- self.clip_g = huggingface_components.get('tokenizer_2', None)
-
-
-class JointCLIPTextEncoder(torch.nn.Module):
- def __init__(self, huggingface_components):
- super().__init__()
- self.clip_l = huggingface_components.get('text_encoder', None)
- self.clip_g = huggingface_components.get('text_encoder_2', None)
+from backend.nn.base import ModuleDict, ObjectDict
class CLIP:
- def __init__(self, huggingface_components=None, no_init=False):
+ def __init__(self, model_dict={}, tokenizer_dict={}, no_init=False):
if no_init:
return
@@ -26,8 +12,8 @@ class CLIP:
offload_device = memory_management.text_encoder_offload_device()
text_encoder_dtype = memory_management.text_encoder_dtype(load_device)
- self.cond_stage_model = JointCLIPTextEncoder(huggingface_components)
- self.tokenizer = JointTokenizer(huggingface_components)
+ self.cond_stage_model = ModuleDict(model_dict)
+ self.tokenizer = ObjectDict(tokenizer_dict)
self.cond_stage_model.to(dtype=text_encoder_dtype, device=offload_device)
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
diff --git a/backend/patcher/unet.py b/backend/patcher/unet.py
index 38528441..b70b257c 100644
--- a/backend/patcher/unet.py
+++ b/backend/patcher/unet.py
@@ -1,12 +1,26 @@
import copy
import torch
+from backend.modules.k_model import KModel
from backend.patcher.base import ModelPatcher
+from backend import memory_management
class UnetPatcher(ModelPatcher):
- def __init__(self, model, *args, **kwargs):
- super().__init__(model, *args, **kwargs)
+ @classmethod
+ def from_model(cls, model, diffusers_scheduler):
+ parameters = memory_management.module_size(model)
+ unet_dtype = memory_management.unet_dtype(model_params=parameters)
+ load_device = memory_management.get_torch_device()
+ initial_load_device = memory_management.unet_inital_load_device(parameters, unet_dtype)
+ manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device)
+ manual_cast_dtype = unet_dtype if manual_cast_dtype is None else manual_cast_dtype
+ model.to(device=initial_load_device, dtype=unet_dtype)
+ model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, storage_dtype=unet_dtype, computation_dtype=manual_cast_dtype)
+ return UnetPatcher(model, load_device=load_device, offload_device=memory_management.unet_offload_device(), current_device=initial_load_device)
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
self.controlnet_linked_list = None
self.extra_preserved_memory_during_sampling = 0
self.extra_model_patchers_during_sampling = []
diff --git a/backend/sampling/sampling_function.py b/backend/sampling/sampling_function.py
index adab833e..aa39961f 100644
--- a/backend/sampling/sampling_function.py
+++ b/backend/sampling/sampling_function.py
@@ -301,14 +301,15 @@ def sampling_function_inner(model, x, timestep, uncond, cond, cond_scale, model_
def sampling_function(self, denoiser_params, cond_scale, cond_composition):
- model = self.inner_model.inner_model.forge_objects.unet.model
- control = self.inner_model.inner_model.forge_objects.unet.controlnet_linked_list
- extra_concat_condition = self.inner_model.inner_model.forge_objects.unet.extra_concat_condition
+ unet_patcher = self.inner_model.forge_objects.unet
+ model = unet_patcher.model
+ control = unet_patcher.controlnet_linked_list
+ extra_concat_condition = unet_patcher.extra_concat_condition
x = denoiser_params.x
timestep = denoiser_params.sigma
uncond = compile_conditions(denoiser_params.text_uncond)
cond = compile_weighted_conditions(denoiser_params.text_cond, cond_composition)
- model_options = self.inner_model.inner_model.forge_objects.unet.model_options
+ model_options = unet_patcher.model_options
seed = self.p.seeds[0]
if extra_concat_condition is not None:
diff --git a/backend/state_dict.py b/backend/state_dict.py
index 6a5ab7e5..ef3bdc8b 100644
--- a/backend/state_dict.py
+++ b/backend/state_dict.py
@@ -1,14 +1,15 @@
import torch
-def load_state_dict(model, sd, ignore_errors=[]):
+def load_state_dict(model, sd, ignore_errors=[], log_name=None):
missing, unexpected = model.load_state_dict(sd, strict=False)
missing = [x for x in missing if x not in ignore_errors]
unexpected = [x for x in unexpected if x not in ignore_errors]
+ log_name = log_name or type(model).__name__
if len(missing) > 0:
- print(f'{type(model).__name__} Missing: {missing}')
+ print(f'{log_name} Missing: {missing}')
if len(unexpected) > 0:
- print(f'{type(model).__name__} Unexpected: {unexpected}')
+ print(f'{log_name} Unexpected: {unexpected}')
return
diff --git a/backend/text_processing/classic_engine.py b/backend/text_processing/classic_engine.py
index ecbdc213..3832fe65 100644
--- a/backend/text_processing/classic_engine.py
+++ b/backend/text_processing/classic_engine.py
@@ -5,7 +5,6 @@ from collections import namedtuple
from backend.text_processing import parsing, emphasis
from backend.text_processing.textual_inversion import EmbeddingDatabase
-
PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
last_extra_generation_params = {}
@@ -47,11 +46,12 @@ class CLIPEmbeddingForTextualInversion(torch.nn.Module):
return torch.stack(vecs)
-class ClassicTextProcessingEngine(torch.nn.Module):
- def __init__(self, text_encoder, tokenizer, chunk_length=75,
- embedding_dir=None, embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original",
- text_projection=False, minimal_clip_skip=1, clip_skip=1, return_pooled=False, final_layer_norm=True,
- callback_before_encode=None):
+class ClassicTextProcessingEngine:
+ def __init__(
+ self, text_encoder, tokenizer, chunk_length=75,
+ embedding_dir=None, embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="Original",
+ text_projection=False, minimal_clip_skip=1, clip_skip=1, return_pooled=False, final_layer_norm=True
+ ):
super().__init__()
self.embeddings = EmbeddingDatabase(tokenizer, embedding_expected_shape)
@@ -71,7 +71,6 @@ class ClassicTextProcessingEngine(torch.nn.Module):
self.clip_skip = clip_skip
self.return_pooled = return_pooled
self.final_layer_norm = final_layer_norm
- self.callback_before_encode = callback_before_encode
self.chunk_length = chunk_length
@@ -133,7 +132,7 @@ class ClassicTextProcessingEngine(torch.nn.Module):
pooled_output = outputs.pooler_output
if self.text_projection:
- pooled_output = pooled_output.float().to(self.text_encoder.text_projection.device) @ self.text_encoder.text_projection.float()
+ pooled_output = self.text_encoder.transformer.text_projection(pooled_output)
z.pooled = pooled_output
return z
@@ -240,10 +239,7 @@ class ClassicTextProcessingEngine(torch.nn.Module):
return batch_chunks, token_count
- def forward(self, texts):
- if self.callback_before_encode is not None:
- self.callback_before_encode(self, texts)
-
+ def __call__(self, texts):
batch_chunks, token_count = self.process_texts(texts)
used_embeddings = {}
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 5b31861c..f933de64 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -401,7 +401,7 @@ def prepare_environment():
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
- huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "aebabb94eaaa1a26a3b37128d1c079838c134623")
+ huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "78f7d1da6a00721a6670e33a9132fd73c4e987b4")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
try:
diff --git a/modules/processing.py b/modules/processing.py
index d7d36eea..2d0f13fa 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -100,7 +100,7 @@ def create_binary_mask(image, round=True):
return image
def txt2img_image_conditioning(sd_model, x, width, height):
- if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
+ if sd_model.is_inpaint: # Inpainting models
# The "masked-image" in this case will just be all 0.5 since the entire image is masked.
image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
@@ -111,24 +111,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):
image_conditioning = image_conditioning.to(x.dtype)
return image_conditioning
-
- elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models
-
- return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
-
else:
- if sd_model.is_sdxl_inpaint:
- # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
- image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
- image_conditioning = images_tensor_to_samples(image_conditioning,
- approximation_indexes.get(opts.sd_vae_encode_method))
-
- # Add the fake full 1s mask to the first dimension.
- image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
- image_conditioning = image_conditioning.to(x.dtype)
-
- return image_conditioning
-
# Dummy zero conditioning if we're not using inpainting or unclip models.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
@@ -307,7 +290,7 @@ class StableDiffusionProcessing:
self.comments[text] = 1
def txt2img_image_conditioning(self, x, width=None, height=None):
- self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
+ self.is_using_inpainting_conditioning = self.sd_model.is_inpaint
return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
@@ -497,6 +480,8 @@ class StableDiffusionProcessing:
cache = caches[0]
with devices.autocast():
+ shared.sd_model.set_clip_skip(opts.CLIP_stop_at_last_layers)
+
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
import backend.text_processing.classic_engine
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index ffaadc89..f292073b 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -150,8 +150,7 @@ class StableDiffusionModelHijack:
self.extra_generation_params = {}
def get_prompt_lengths(self, text, cond_stage_model):
- _, token_count = cond_stage_model.process_texts([text])
- return token_count, cond_stage_model.get_target_prompt_token_count(token_count)
+ pass
def redo_hijack(self, m):
pass
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 92537ece..d89a8326 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,11 +14,12 @@ import ldm.modules.midas as midas
import gc
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
-from modules.shared import opts
+from modules.shared import opts, cmd_opts
from modules.timer import Timer
import numpy as np
-from modules_forge import loader
+from backend.loader import forge_loader
from backend import memory_management
+from backend.args import dynamic_args
model_dir = "Stable-diffusion"
@@ -636,6 +637,7 @@ def get_obj_from_str(string, reload=False):
return getattr(importlib.import_module(module, package=None), cls)
+@torch.no_grad()
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -663,8 +665,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
# cache newly loaded model
checkpoints_loaded[checkpoint_info] = state_dict.copy()
- sd_model = loader.load_model_for_a1111(timer=timer, checkpoint_info=checkpoint_info, state_dict=state_dict)
+ dynamic_args['embedding_dir'] = cmd_opts.embeddings_dir
+ dynamic_args['emphasis_name'] = opts.emphasis
+ sd_model = forge_loader(state_dict)
+ timer.record("forge model load")
+
+ sd_model.sd_checkpoint_info = checkpoint_info
sd_model.filename = checkpoint_info.filename
+ sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
+ timer.record("calculate hash")
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index 0934ccfe..a3a896a1 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -39,8 +39,10 @@ class CFGDenoiser(torch.nn.Module):
negative prompt.
"""
- def __init__(self, sampler):
+ def __init__(self, sampler, model):
super().__init__()
+ self.inner_model = model
+
self.model_wrap = None
self.mask = None
self.nmask = None
@@ -64,10 +66,6 @@ class CFGDenoiser(torch.nn.Module):
self.classic_ddim_eps_estimation = False
- @property
- def inner_model(self):
- raise NotImplementedError()
-
def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in, cond):
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
@@ -158,7 +156,7 @@ class CFGDenoiser(torch.nn.Module):
original_x_dtype = x.dtype
if self.classic_ddim_eps_estimation:
- acd = self.inner_model.inner_model.alphas_cumprod
+ acd = self.inner_model.alphas_cumprod
fake_sigmas = ((1 - acd) / acd) ** 0.5
real_sigma = fake_sigmas[sigma.round().long().clip(0, int(fake_sigmas.shape[0]))]
real_sigma_data = 1.0
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 01723c30..fee65f83 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -237,7 +237,7 @@ class Sampler:
self.eta_infotext_field = 'Eta'
self.eta_default = 1.0
- self.conditioning_key = getattr(shared.sd_model.model, 'conditioning_key', 'crossattn')
+ self.conditioning_key = 'crossattn'
self.p = None
self.model_wrap_cfg = None
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index f0b88c95..a702f7b2 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -1,12 +1,11 @@
import torch
import inspect
import k_diffusion.sampling
-from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_schedulers, devices
-from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401
+from modules import sd_samplers_common, sd_samplers_extra, sd_schedulers, devices
+from modules.sd_samplers_cfg_denoiser import CFGDenoiser
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
from modules.shared import opts
-import modules.shared as shared
from backend.sampling.sampling_function import sampling_prepare, sampling_cleanup
@@ -51,21 +50,6 @@ k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
k_diffusion_scheduler = {x.name: x.function for x in sd_schedulers.schedulers}
-class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
- @property
- def inner_model(self):
- if self.model_wrap is None:
- denoiser_constructor = getattr(shared.sd_model, 'create_denoiser', None)
-
- if denoiser_constructor is not None:
- self.model_wrap = denoiser_constructor()
- else:
- denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
- self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
-
- return self.model_wrap
-
-
class KDiffusionSampler(sd_samplers_common.Sampler):
def __init__(self, funcname, sd_model, options=None):
super().__init__(funcname)
@@ -75,8 +59,11 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
self.options = options or {}
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
- self.model_wrap_cfg = CFGDenoiserKDiffusion(self)
- self.model_wrap = self.model_wrap_cfg.inner_model
+ self.model_wrap = self.model_wrap_cfg = CFGDenoiser(self, sd_model)
+ self.predictor = sd_model.forge_objects.unet.model.predictor
+
+ self.model_wrap_cfg.sigmas = self.predictor.sigmas
+ self.model_wrap_cfg.log_sigmas = self.predictor.sigmas.log()
def get_sigmas(self, p, steps):
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
@@ -92,13 +79,13 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
scheduler = sd_schedulers.schedulers_map.get(scheduler_name)
- m_sigma_min, m_sigma_max = self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()
+ m_sigma_min, m_sigma_max = self.predictor.sigmas[0].item(), self.predictor.sigmas[-1].item()
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
elif scheduler is None or scheduler.function is None:
- sigmas = self.model_wrap.get_sigmas(steps)
+ raise ValueError('Wrong scheduler!')
else:
sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}
@@ -120,7 +107,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
p.extra_generation_params["Schedule rho"] = opts.rho
if scheduler.need_inner_model:
- sigmas_kwargs['inner_model'] = self.model_wrap
+ sigmas_kwargs['inner_model'] = self.model_wrap_cfg
if scheduler.label == 'Beta':
p.extra_generation_params["Beta schedule alpha"] = opts.beta_dist_alpha
@@ -134,11 +121,11 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
return sigmas.cpu()
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- unet_patcher = self.model_wrap.inner_model.forge_objects.unet
- sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
+ unet_patcher = self.model_wrap_cfg.inner_model.forge_objects.unet
+ sampling_prepare(self.model_wrap_cfg.inner_model.forge_objects.unet, x=x)
- self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device)
- self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device)
+ self.model_wrap_cfg.sigmas = self.model_wrap_cfg.sigmas.to(x.device)
+ self.model_wrap_cfg.log_sigmas = self.model_wrap_cfg.log_sigmas.to(x.device)
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
@@ -196,11 +183,11 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- unet_patcher = self.model_wrap.inner_model.forge_objects.unet
- sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
+ unet_patcher = self.model_wrap_cfg.inner_model.forge_objects.unet
+ sampling_prepare(self.model_wrap_cfg.inner_model.forge_objects.unet, x=x)
- self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device)
- self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device)
+ self.model_wrap_cfg.sigmas = self.model_wrap_cfg.sigmas.to(x.device)
+ self.model_wrap_cfg.log_sigmas = self.model_wrap_cfg.log_sigmas.to(x.device)
steps = steps or p.steps
@@ -219,8 +206,8 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
extra_params_kwargs['n'] = steps
if 'sigma_min' in parameters:
- extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
- extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
+ extra_params_kwargs['sigma_min'] = self.model_wrap_cfg.sigmas[0].item()
+ extra_params_kwargs['sigma_max'] = self.model_wrap_cfg.sigmas[-1].item()
if 'sigmas' in parameters:
extra_params_kwargs['sigmas'] = sigmas
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py
index 08956497..c5a97290 100644
--- a/modules/sd_samplers_timesteps.py
+++ b/modules/sd_samplers_timesteps.py
@@ -49,10 +49,10 @@ class CompVisTimestepsVDenoiser(torch.nn.Module):
class CFGDenoiserTimesteps(CFGDenoiser):
- def __init__(self, sampler):
- super().__init__(sampler)
+ def __init__(self, sampler, model):
+ super().__init__(sampler, model)
- self.alphas = shared.sd_model.alphas_cumprod
+ self.alphas = model.forge_objects.unet.model.predictor.alphas_cumprod
self.classic_ddim_eps_estimation = True
def get_pred_x0(self, x_in, x_out, sigma):
@@ -66,14 +66,6 @@ class CFGDenoiserTimesteps(CFGDenoiser):
return pred_x0
- @property
- def inner_model(self):
- if self.model_wrap is None:
- denoiser = CompVisTimestepsVDenoiser if shared.sd_model.parameterization == "v" else CompVisTimestepsDenoiser
- self.model_wrap = denoiser(shared.sd_model)
-
- return self.model_wrap
-
class CompVisSampler(sd_samplers_common.Sampler):
def __init__(self, funcname, sd_model):
@@ -83,8 +75,10 @@ class CompVisSampler(sd_samplers_common.Sampler):
self.eta_infotext_field = 'Eta DDIM'
self.eta_default = 0.0
- self.model_wrap_cfg = CFGDenoiserTimesteps(self)
- self.model_wrap = self.model_wrap_cfg.inner_model
+ self.model_wrap = self.model_wrap_cfg = CFGDenoiserTimesteps(self, sd_model)
+ self.predictor = sd_model.forge_objects.unet.model.predictor
+
+ self.model_wrap.inner_model.alphas_cumprod = self.predictor.alphas_cumprod
def get_timesteps(self, p, steps):
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py
index 180e4389..a63179e6 100644
--- a/modules/sd_samplers_timesteps_impl.py
+++ b/modules/sd_samplers_timesteps_impl.py
@@ -10,7 +10,7 @@ from modules.torch_utils import float64
@torch.no_grad()
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
- alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+ alphas_cumprod = model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps]
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
@@ -46,7 +46,7 @@ def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None
Uses the unconditional noise prediction instead of the conditional noise to guide the denoising direction.
The CFG scale is divided by 12.5 to map CFG from [0.0, 12.5] to [0, 1.0].
"""
- alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+ alphas_cumprod = model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps]
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
@@ -82,7 +82,7 @@ def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None
@torch.no_grad()
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
- alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+ alphas_cumprod = model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps]
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
@@ -168,7 +168,7 @@ class UniPCCFG(uni_pc.UniPC):
def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
- alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+ alphas_cumprod = model.inner_model.alphas_cumprod
ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None # this is likely off by a bit - if someone wants to fix it please by all means
diff --git a/modules/sd_schedulers.py b/modules/sd_schedulers.py
index af873dc9..7fc6ae1b 100644
--- a/modules/sd_schedulers.py
+++ b/modules/sd_schedulers.py
@@ -130,7 +130,7 @@ def beta_scheduler(n, sigma_min, sigma_max, inner_model, device):
def turbo_scheduler(n, sigma_min, sigma_max, inner_model, device):
unet = inner_model.inner_model.forge_objects.unet
timesteps = torch.flip(torch.arange(1, n + 1) * float(1000.0 / n) - 1, (0,)).round().long().clip(0, 999)
- sigmas = unet.model.model_sampling.sigma(timesteps)
+ sigmas = unet.model.predictor.sigma(timesteps)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
return sigmas.to(device)
diff --git a/modules/ui.py b/modules/ui.py
index 676132ff..7acb63d2 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -181,14 +181,14 @@ def update_token_counter(text, steps, styles, *, is_positive=True):
prompt_schedules = [[[steps, text]]]
try:
- cond_stage_model = sd_models.model_data.sd_model.cond_stage_model
- assert cond_stage_model is not None
+ get_prompt_lengths_on_ui = sd_models.model_data.sd_model.get_prompt_lengths_on_ui
+ assert get_prompt_lengths_on_ui is not None
except Exception:
return f"?/?"
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
prompts = [prompt_text for step, prompt_text in flat_prompts]
- token_count, max_length = max([model_hijack.get_prompt_lengths(prompt, cond_stage_model) for prompt in prompts], key=lambda args: args[0])
+ token_count, max_length = max([get_prompt_lengths_on_ui(prompt) for prompt in prompts], key=lambda args: args[0])
return f"{token_count}/{max_length}"
diff --git a/modules_forge/loader.py b/modules_forge/forge_old_loader.py
similarity index 100%
rename from modules_forge/loader.py
rename to modules_forge/forge_old_loader.py