diff --git a/backend/args.py b/backend/args.py
index 8c8e1cbf..0302dfc1 100644
--- a/backend/args.py
+++ b/backend/args.py
@@ -57,3 +57,9 @@ parser.add_argument("--cuda-stream", action="store_true")
 parser.add_argument("--pin-shared-memory", action="store_true")
 
 args = parser.parse_known_args()[0]
+
+# Dynamic args that may be changed at runtime by the WebUI rather than via command-line flags.
+dynamic_args = dict(
+    embedding_dir='./embeddings',
+    emphasis_name='original'
+)
diff --git a/backend/diffusion_engine/base.py b/backend/diffusion_engine/base.py
new file mode 100644
index 00000000..a4b2a523
--- /dev/null
+++ b/backend/diffusion_engine/base.py
@@ -0,0 +1,60 @@
+import torch
+
+
+class ForgeObjects:
+    def __init__(self, unet, clip, vae, clipvision):
+        self.unet = unet
+        self.clip = clip
+        self.vae = vae
+        self.clipvision = clipvision
+
+    def shallow_copy(self):
+        return ForgeObjects(
+            self.unet,
+            self.clip,
+            self.vae,
+            self.clipvision
+        )
+
+
+class ForgeDiffusionEngine:
+    matched_guesses = []
+
+    def __init__(self, estimated_config, huggingface_components):
+        self.model_config = estimated_config
+        self.is_inpaint = estimated_config.inpaint_model()
+
+        self.forge_objects = None
+        self.forge_objects_original = None
+        self.forge_objects_after_applying_lora = None
+
+        self.current_lora_hash = str([])
+        self.tiling_enabled = False
+
+        # WebUI Dirty Legacy
+        self.cond_stage_key = 'txt'
+        self.is_sd3 = False
+        self.latent_channels = 4
+        self.is_sdxl = False
+        self.is_sdxl_inpaint = False
+        self.is_sd2 = False
+        self.is_sd1 = False
+        self.is_ssd = False
+
+    def set_clip_skip(self, clip_skip):
+        pass
+
+    def get_first_stage_encoding(self, x):
+        return x  # legacy code, do not change
+
+    def get_learned_conditioning(self, prompt: list[str]):
+        pass
+
+    def encode_first_stage(self, x):
+        pass
+
+    def decode_first_stage(self, x):
+        pass
+
+    def get_prompt_lengths_on_ui(self, prompt):
+        pass
diff --git a/backend/diffusion_engine/sd15.py b/backend/diffusion_engine/sd15.py
index 792d6005..82520a67 100644
--- a/backend/diffusion_engine/sd15.py
+++ b/backend/diffusion_engine/sd15.py
@@ -1 +1,80 @@
-#
+import torch
+
+from huggingface_guess import model_list
+from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
+from backend.patcher.clip import CLIP
+from backend.patcher.vae import VAE
+from backend.patcher.unet import UnetPatcher
+from backend.text_processing.classic_engine import ClassicTextProcessingEngine
+from backend.args import dynamic_args
+from backend import memory_management
+
+
+class StableDiffusion(ForgeDiffusionEngine):
+    matched_guesses = [model_list.SD15]
+
+    def __init__(self, estimated_config, huggingface_components):
+        super().__init__(estimated_config, huggingface_components)
+
+        clip = CLIP(
+            model_dict={
+                'clip_l': huggingface_components['text_encoder']
+            },
+            tokenizer_dict={
+                'clip_l': huggingface_components['tokenizer']
+            }
+        )
+
+        vae = VAE(model=huggingface_components['vae'])
+
+        unet = UnetPatcher.from_model(
+            model=huggingface_components['unet'],
+            diffusers_scheduler=huggingface_components['scheduler']
+        )
+
+        self.text_processing_engine = ClassicTextProcessingEngine(
+            text_encoder=clip.cond_stage_model.clip_l,
+            tokenizer=clip.tokenizer.clip_l,
+            embedding_dir=dynamic_args['embedding_dir'],
+            embedding_key='clip_l',
+            embedding_expected_shape=768,
+            emphasis_name=dynamic_args['emphasis_name'],
+            text_projection=False,
+            minimal_clip_skip=1,
+            clip_skip=1,
+            return_pooled=False,
+            final_layer_norm=True
+        )
+
+        self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
+        self.forge_objects_original = self.forge_objects.shallow_copy()
+        self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
+
+        # WebUI Legacy
+        self.is_sd1 = True
+
+    def set_clip_skip(self, clip_skip):
+        self.text_processing_engine.clip_skip = clip_skip
+
+    @torch.inference_mode()
+    def get_learned_conditioning(self, prompt: list[str]):
+        memory_management.load_model_gpu(self.forge_objects.clip.patcher)
+        cond = self.text_processing_engine(prompt)
+        return cond
+
+    @torch.inference_mode()
+    def get_prompt_lengths_on_ui(self, prompt):
+        _, token_count = self.text_processing_engine.process_texts([prompt])
+        return token_count, self.text_processing_engine.get_target_prompt_token_count(token_count)
+
+    @torch.inference_mode()
+    def encode_first_stage(self, x):
+        sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
+        sample = self.forge_objects.vae.first_stage_model.process_in(sample)
+        return sample.to(x)
+
+    @torch.inference_mode()
+    def decode_first_stage(self, x):
+        sample = self.forge_objects.vae.first_stage_model.process_out(x)
+        sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
+        return sample.to(x)
diff --git a/backend/diffusion_engine/sdxl.py b/backend/diffusion_engine/sdxl.py
index 792d6005..f255c3dd 100644
--- a/backend/diffusion_engine/sdxl.py
+++ b/backend/diffusion_engine/sdxl.py
@@ -1 +1,132 @@
-#
+import torch
+
+from huggingface_guess import model_list
+from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
+from backend.patcher.clip import CLIP
+from backend.patcher.vae import VAE
+from backend.patcher.unet import UnetPatcher
+from backend.text_processing.classic_engine import ClassicTextProcessingEngine
+from backend.args import dynamic_args
+from backend import memory_management
+from backend.nn.unet import Timestep
+
+
+class StableDiffusionXL(ForgeDiffusionEngine):
+    matched_guesses = [model_list.SDXL]
+
+    def __init__(self, estimated_config, huggingface_components):
+        super().__init__(estimated_config, huggingface_components)
+
+        clip = CLIP(
+            model_dict={
+                'clip_l': huggingface_components['text_encoder'],
+                'clip_g': huggingface_components['text_encoder_2']
+            },
+            tokenizer_dict={
+                'clip_l': huggingface_components['tokenizer'],
+                'clip_g': huggingface_components['tokenizer_2']
+            }
+        )
+
+        vae = VAE(model=huggingface_components['vae'])
+
+        unet = UnetPatcher.from_model(
+            model=huggingface_components['unet'],
+            diffusers_scheduler=huggingface_components['scheduler']
+        )
+
+        self.text_processing_engine_l = ClassicTextProcessingEngine(
+            text_encoder=clip.cond_stage_model.clip_l,
+            tokenizer=clip.tokenizer.clip_l,
+            embedding_dir=dynamic_args['embedding_dir'],
+            embedding_key='clip_l',
+            embedding_expected_shape=2048,
+            emphasis_name=dynamic_args['emphasis_name'],
+            text_projection=False,
+            minimal_clip_skip=2,
+            clip_skip=2,
+            return_pooled=False,
+            final_layer_norm=False,
+        )
+
+        self.text_processing_engine_g = ClassicTextProcessingEngine(
+            text_encoder=clip.cond_stage_model.clip_g,
+            tokenizer=clip.tokenizer.clip_g,
+            embedding_dir=dynamic_args['embedding_dir'],
+            embedding_key='clip_g',
+            embedding_expected_shape=2048,
+            emphasis_name=dynamic_args['emphasis_name'],
+            text_projection=True,
+            minimal_clip_skip=2,
+            clip_skip=2,
+            return_pooled=True,
+            final_layer_norm=False,
+        )
+
+        self.embedder = Timestep(256)
+
+        self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
+        self.forge_objects_original = self.forge_objects.shallow_copy()
+        self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
+
+        # WebUI Legacy
+        self.is_sdxl = True
+
+    def set_clip_skip(self, clip_skip):
+        self.text_processing_engine_l.clip_skip = clip_skip
+        self.text_processing_engine_g.clip_skip = clip_skip
+
+    @torch.inference_mode()
+    def get_learned_conditioning(self, prompt: list[str]):
+        memory_management.load_model_gpu(self.forge_objects.clip.patcher)
+
+        cond_l = self.text_processing_engine_l(prompt)
+        cond_g, clip_pooled = self.text_processing_engine_g(prompt)
+
+        width = getattr(prompt, 'width', 1024) or 1024
+        height = getattr(prompt, 'height', 1024) or 1024
+        is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
+
+        crop_w = 0
+        crop_h = 0
+        target_width = width
+        target_height = height
+
+        out = [
+            self.embedder(torch.Tensor([height])), self.embedder(torch.Tensor([width])),
+            self.embedder(torch.Tensor([crop_h])), self.embedder(torch.Tensor([crop_w])),
+            self.embedder(torch.Tensor([target_height])), self.embedder(torch.Tensor([target_width]))
+        ]
+
+        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1).to(clip_pooled)
+
+        force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in prompt)
+
+        if force_zero_negative_prompt:
+            clip_pooled = torch.zeros_like(clip_pooled)
+            cond_l = torch.zeros_like(cond_l)
+            cond_g = torch.zeros_like(cond_g)
+
+        cond = dict(
+            crossattn=torch.cat([cond_l, cond_g], dim=2),
+            vector=torch.cat([clip_pooled, flat], dim=1),
+        )
+
+        return cond
+
+    @torch.inference_mode()
+    def get_prompt_lengths_on_ui(self, prompt):
+        _, token_count = self.text_processing_engine_l.process_texts([prompt])
+        return token_count, self.text_processing_engine_l.get_target_prompt_token_count(token_count)
+
+    @torch.inference_mode()
+    def encode_first_stage(self, x):
+        sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
+        sample = self.forge_objects.vae.first_stage_model.process_in(sample)
+        return sample.to(x)
+
+    @torch.inference_mode()
+    def decode_first_stage(self, x):
+        sample = self.forge_objects.vae.first_stage_model.process_out(x)
+        sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
+        return sample.to(x)
diff --git a/backend/loader.py b/backend/loader.py
index f7b4b9d5..8072299a 100644
--- a/backend/loader.py
+++ b/backend/loader.py
@@ -1,16 +1,23 @@
 import os
+import torch
 import logging
 import importlib
 import huggingface_guess
 
 from diffusers import DiffusionPipeline
 from transformers import modeling_utils
-from backend.state_dict import try_filter_state_dict, transformers_convert, load_state_dict, state_dict_key_replace
+from backend.state_dict import try_filter_state_dict, load_state_dict
 from backend.operations import using_forge_operations
 from backend.nn.vae import IntegratedAutoencoderKL
 from backend.nn.clip import IntegratedCLIP, CLIPTextConfig
 from backend.nn.unet import IntegratedUNet2DConditionModel
+from backend.diffusion_engine.sd15 import StableDiffusion
+from backend.diffusion_engine.sdxl import StableDiffusionXL
+
+
+possible_models = [StableDiffusion, StableDiffusionXL]
+
 
 
 logging.getLogger("diffusers").setLevel(logging.ERROR)
 dir_path = os.path.dirname(__file__)
@@ -27,61 +34,84 @@ def load_component(guess, component_name, lib_name, cls_name, repo_path, state_d
             cls = getattr(importlib.import_module(lib_name), cls_name)
             return cls.from_pretrained(os.path.join(repo_path, component_name))
 
     if cls_name in ['AutoencoderKL']:
-        sd = try_filter_state_dict(state_dict, ['first_stage_model.', 'vae.'])
         config = IntegratedAutoencoderKL.load_config(config_path)
 
         with using_forge_operations():
             model = IntegratedAutoencoderKL.from_config(config)
 
-        load_state_dict(model, sd)
+        load_state_dict(model, state_dict)
         return model
 
     if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
-        if component_name == 'text_encoder':
-            sd = try_filter_state_dict(state_dict, ['cond_stage_model.', 'conditioner.embedders.0.'])
-        elif component_name == 'text_encoder_2':
-            sd = try_filter_state_dict(state_dict, ['conditioner.embedders.1.'])
-        else:
-            raise ValueError(f"Wrong component_name: {component_name}")
-
-        if 'model.text_projection' in sd:
-            sd = transformers_convert(sd, "model.", "transformer.text_model.", 32)
-            sd = state_dict_key_replace(sd, {"model.text_projection": "text_projection",
-                                             "model.text_projection.weight": "text_projection",
-                                             "model.logit_scale": "logit_scale"})
-
         config = CLIPTextConfig.from_pretrained(config_path)
 
         with modeling_utils.no_init_weights():
             with using_forge_operations():
                 model = IntegratedCLIP(config)
 
-        load_state_dict(model, sd, ignore_errors=['text_projection', 'logit_scale',
-                                                  'transformer.text_model.embeddings.position_ids'])
+        load_state_dict(model, state_dict, ignore_errors=[
+            'transformer.text_projection.weight',
+            'transformer.text_model.embeddings.position_ids',
+            'logit_scale'
+        ], log_name=cls_name)
+
         return model
 
     if cls_name == 'UNet2DConditionModel':
-        sd = try_filter_state_dict(state_dict, ['model.diffusion_model.'])
-
         with using_forge_operations():
             model = IntegratedUNet2DConditionModel.from_config(guess.unet_config)
             model._internal_dict = guess.unet_config
 
-        load_state_dict(model, sd)
+        load_state_dict(model, state_dict)
         return model
 
     print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
     return None
 
 
-def load_huggingface_components(sd):
+def split_state_dict(sd):
     guess = huggingface_guess.guess(sd)
-    repo_name = guess.huggingface_repo
+
+    state_dict = {
+        'unet': try_filter_state_dict(sd, ['model.diffusion_model.']),
+        'vae': try_filter_state_dict(sd, guess.vae_key_prefix)
+    }
+
+    sd = guess.process_clip_state_dict(sd)
+    guess.clip_target = guess.clip_target(sd)
+
+    for k, v in guess.clip_target.items():
+        state_dict[v] = try_filter_state_dict(sd, [k + '.'])
+
+    state_dict['ignore'] = sd
+
+    print_dict = {k: len(v) for k, v in state_dict.items()}
+    print(f'StateDict Keys: {print_dict}')
+
+    del state_dict['ignore']
+
+    return state_dict, guess
+
+
+@torch.no_grad()
+def forge_loader(sd):
+    state_dicts, estimated_config = split_state_dict(sd)
+    repo_name = estimated_config.huggingface_repo
+
     local_path = os.path.join(dir_path, 'huggingface', repo_name)
-    config = DiffusionPipeline.load_config(local_path)
-    result = {"repo_path": local_path}
+    config: dict = DiffusionPipeline.load_config(local_path)
+    huggingface_components = {}
     for component_name, v in config.items():
         if isinstance(v, list) and len(v) == 2:
             lib_name, cls_name = v
-            component = load_component(guess, component_name, lib_name, cls_name, local_path, sd)
+            component_sd = state_dicts.get(component_name, None)
+            component = load_component(estimated_config, component_name, lib_name, cls_name, local_path, component_sd)
+            if component_sd is not None:
+                del state_dicts[component_name]
             if component is not None:
-                result[component_name] = component
-    return result
+                huggingface_components[component_name] = component
+
+    for M in possible_models:
+        if any(isinstance(estimated_config, x) for x in M.matched_guesses):
+            return M(estimated_config=estimated_config, huggingface_components=huggingface_components)
+
+    print('Failed to recognize model type!')
+    return None
diff --git a/backend/modules/k_model.py b/backend/modules/k_model.py
index 03e1b644..16236227 100644
--- a/backend/modules/k_model.py
+++ b/backend/modules/k_model.py
@@ -5,14 +5,14 @@ from backend.modules.k_prediction import k_prediction_from_diffusers_scheduler
 
 
 class KModel(torch.nn.Module):
-    def __init__(self, huggingface_components, storage_dtype, computation_dtype):
+    def __init__(self, model, diffusers_scheduler, storage_dtype, computation_dtype):
         super().__init__()
 
         self.storage_dtype = storage_dtype
         self.computation_dtype = computation_dtype
 
-        self.diffusion_model = huggingface_components['unet']
-        self.predictor = k_prediction_from_diffusers_scheduler(huggingface_components['scheduler'])
+        self.diffusion_model = model
+        self.predictor = k_prediction_from_diffusers_scheduler(diffusers_scheduler)
 
     def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
         sigma = t
diff --git a/backend/modules/k_prediction.py b/backend/modules/k_prediction.py
index f49f5d50..ead306bb 100644
--- a/backend/modules/k_prediction.py
+++ b/backend/modules/k_prediction.py
@@ -108,11 +108,11 @@ class Prediction(AbstractPrediction):
         alphas = 1. - betas
         alphas_cumprod = torch.cumprod(alphas, dim=0)
         sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
-        self.set_sigmas(sigmas)
 
-    def set_sigmas(self, sigmas):
+        self.register_buffer('alphas_cumprod', alphas_cumprod.float())
         self.register_buffer('sigmas', sigmas.float())
         self.register_buffer('log_sigmas', sigmas.log().float())
+        return
 
     @property
     def sigma_min(self):
diff --git a/backend/nn/clip.py b/backend/nn/clip.py
index c65f7b2a..373fb73d 100644
--- a/backend/nn/clip.py
+++ b/backend/nn/clip.py
@@ -7,5 +7,6 @@ class IntegratedCLIP(torch.nn.Module):
     def __init__(self, config: CLIPTextConfig):
         super().__init__()
         self.transformer = CLIPTextModel(config)
-        self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
-        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
+        embed_dim = config.hidden_size
+        self.transformer.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
+        self.transformer.text_projection.weight.copy_(torch.eye(embed_dim))
diff --git a/backend/patcher/clip.py b/backend/patcher/clip.py
index 260b91c1..b6c085da 100644
--- a/backend/patcher/clip.py
+++ b/backend/patcher/clip.py
@@ -1,24 +1,10 @@
-import torch
-
 from backend import memory_management
 from backend.patcher.base import ModelPatcher
-
-
-class JointTokenizer:
-    def __init__(self, huggingface_components):
-        self.clip_l = huggingface_components.get('tokenizer', None)
-        self.clip_g = huggingface_components.get('tokenizer_2', None)
-
-
-class JointCLIPTextEncoder(torch.nn.Module):
-    def __init__(self, huggingface_components):
-        super().__init__()
-        self.clip_l = huggingface_components.get('text_encoder', None)
-        self.clip_g = huggingface_components.get('text_encoder_2', None)
+from backend.nn.base import ModuleDict, ObjectDict
 
 
 class CLIP:
-    def __init__(self, huggingface_components=None, no_init=False):
+    def __init__(self, model_dict={}, tokenizer_dict={}, no_init=False):
         if no_init:
             return
 
@@ -26,8 +12,8 @@ class CLIP:
         offload_device = memory_management.text_encoder_offload_device()
         text_encoder_dtype = memory_management.text_encoder_dtype(load_device)
 
-        self.cond_stage_model = JointCLIPTextEncoder(huggingface_components)
-        self.tokenizer = JointTokenizer(huggingface_components)
+        self.cond_stage_model = ModuleDict(model_dict)
+        self.tokenizer = ObjectDict(tokenizer_dict)
 
         self.cond_stage_model.to(dtype=text_encoder_dtype, device=offload_device)
         self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
diff --git a/backend/patcher/unet.py b/backend/patcher/unet.py
index 38528441..b70b257c 100644
--- a/backend/patcher/unet.py
+++ b/backend/patcher/unet.py
@@ -1,12 +1,26 @@
 import copy
 import torch
 
+from backend.modules.k_model import KModel
 from backend.patcher.base import ModelPatcher
+from backend import memory_management
 
 
 class UnetPatcher(ModelPatcher):
-    def __init__(self, model, *args, **kwargs):
-        super().__init__(model, *args, **kwargs)
+    @classmethod
+    def from_model(cls, model, diffusers_scheduler):
+        parameters = memory_management.module_size(model)
+        unet_dtype = memory_management.unet_dtype(model_params=parameters)
+        load_device = memory_management.get_torch_device()
+        initial_load_device = memory_management.unet_inital_load_device(parameters, unet_dtype)
+        manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device)
+        manual_cast_dtype = unet_dtype if manual_cast_dtype is None else manual_cast_dtype
+        model.to(device=initial_load_device, dtype=unet_dtype)
+        model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, storage_dtype=unet_dtype, computation_dtype=manual_cast_dtype)
+        return UnetPatcher(model, load_device=load_device, offload_device=memory_management.unet_offload_device(), current_device=initial_load_device)
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
         self.controlnet_linked_list = None
         self.extra_preserved_memory_during_sampling = 0
         self.extra_model_patchers_during_sampling = []
diff --git a/backend/sampling/sampling_function.py b/backend/sampling/sampling_function.py
index adab833e..aa39961f 100644
--- a/backend/sampling/sampling_function.py
+++ b/backend/sampling/sampling_function.py
@@ -301,14 +301,15 @@ def sampling_function_inner(model, x, timestep, uncond, cond, cond_scale, model_
 
 
 def sampling_function(self, denoiser_params, cond_scale, cond_composition):
-    model = self.inner_model.inner_model.forge_objects.unet.model
-    control = self.inner_model.inner_model.forge_objects.unet.controlnet_linked_list
-    extra_concat_condition = self.inner_model.inner_model.forge_objects.unet.extra_concat_condition
+    unet_patcher = self.inner_model.forge_objects.unet
+    model = unet_patcher.model
+    control = unet_patcher.controlnet_linked_list
+    extra_concat_condition = unet_patcher.extra_concat_condition
     x = denoiser_params.x
     timestep = denoiser_params.sigma
     uncond = compile_conditions(denoiser_params.text_uncond)
    cond = compile_weighted_conditions(denoiser_params.text_cond, cond_composition)
-    model_options = self.inner_model.inner_model.forge_objects.unet.model_options
+    model_options = unet_patcher.model_options
     seed = self.p.seeds[0]
 
     if extra_concat_condition is not None:
diff --git a/backend/state_dict.py b/backend/state_dict.py
index 6a5ab7e5..ef3bdc8b 100644
--- a/backend/state_dict.py
+++ b/backend/state_dict.py
@@ -1,14 +1,15 @@
 import torch
 
 
-def load_state_dict(model, sd, ignore_errors=[]):
+def load_state_dict(model, sd, ignore_errors=[], log_name=None):
     missing, unexpected = model.load_state_dict(sd, strict=False)
 
     missing = [x for x in missing if x not in ignore_errors]
     unexpected = [x for x in unexpected if x not in ignore_errors]
+    log_name = log_name or type(model).__name__
 
     if len(missing) > 0:
-        print(f'{type(model).__name__} Missing: {missing}')
+        print(f'{log_name} Missing: {missing}')
     if len(unexpected) > 0:
-        print(f'{type(model).__name__} Unexpected: {unexpected}')
+        print(f'{log_name} Unexpected: {unexpected}')
     return
diff --git a/backend/text_processing/classic_engine.py b/backend/text_processing/classic_engine.py
index ecbdc213..3832fe65 100644
--- a/backend/text_processing/classic_engine.py
+++ b/backend/text_processing/classic_engine.py
@@ -5,7 +5,6 @@ from collections import namedtuple
 from backend.text_processing import parsing, emphasis
 from backend.text_processing.textual_inversion import EmbeddingDatabase
 
-
 PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
 last_extra_generation_params = {}
 
@@ -47,11 +46,12 @@ class CLIPEmbeddingForTextualInversion(torch.nn.Module):
         return torch.stack(vecs)
 
 
-class ClassicTextProcessingEngine(torch.nn.Module):
-    def __init__(self, text_encoder, tokenizer, chunk_length=75,
-                 embedding_dir=None, embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original",
-                 text_projection=False, minimal_clip_skip=1, clip_skip=1, return_pooled=False, final_layer_norm=True,
-                 callback_before_encode=None):
+class ClassicTextProcessingEngine:
+    def __init__(
+            self, text_encoder, tokenizer, chunk_length=75,
+            embedding_dir=None, embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="Original",
+            text_projection=False, minimal_clip_skip=1, clip_skip=1, return_pooled=False, final_layer_norm=True
+    ):
         super().__init__()
 
         self.embeddings = EmbeddingDatabase(tokenizer, embedding_expected_shape)
@@ -71,7 +71,6 @@ class ClassicTextProcessingEngine(torch.nn.Module):
         self.clip_skip = clip_skip
         self.return_pooled = return_pooled
         self.final_layer_norm = final_layer_norm
-        self.callback_before_encode = callback_before_encode
 
         self.chunk_length = chunk_length
 
@@ -133,7 +132,7 @@ class ClassicTextProcessingEngine(torch.nn.Module):
         pooled_output = outputs.pooler_output
 
         if self.text_projection:
-            pooled_output = pooled_output.float().to(self.text_encoder.text_projection.device) @ self.text_encoder.text_projection.float()
+            pooled_output = self.text_encoder.transformer.text_projection(pooled_output)
 
         z.pooled = pooled_output
         return z
@@ -240,10 +239,7 @@ class ClassicTextProcessingEngine(torch.nn.Module):
 
         return batch_chunks, token_count
 
-    def forward(self, texts):
-        if self.callback_before_encode is not None:
-            self.callback_before_encode(self, texts)
-
+    def __call__(self, texts):
         batch_chunks, token_count = self.process_texts(texts)
 
         used_embeddings = {}
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 5b31861c..f933de64 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -401,7 +401,7 @@ def prepare_environment():
     stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
     stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
     k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
-    huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "aebabb94eaaa1a26a3b37128d1c079838c134623")
+    huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "78f7d1da6a00721a6670e33a9132fd73c4e987b4")
     blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
 
     try:
diff --git a/modules/processing.py b/modules/processing.py
index d7d36eea..2d0f13fa 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -100,7 +100,7 @@ def create_binary_mask(image, round=True):
     return image
 
 def txt2img_image_conditioning(sd_model, x, width, height):
-    if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
+    if sd_model.is_inpaint: # Inpainting models
 
         # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
         image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
@@ -111,24 +111,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):
         image_conditioning = image_conditioning.to(x.dtype)
 
         return image_conditioning
-
-    elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models
-
-        return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
-
     else:
-        if sd_model.is_sdxl_inpaint:
-            # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
-            image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
-            image_conditioning = images_tensor_to_samples(image_conditioning,
-                                                          approximation_indexes.get(opts.sd_vae_encode_method))
-
-            # Add the fake full 1s mask to the first dimension.
-            image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
-            image_conditioning = image_conditioning.to(x.dtype)
-
-            return image_conditioning
-
         # Dummy zero conditioning if we're not using inpainting or unclip models.
         # Still takes up a bit of memory, but no encoder call.
         # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
@@ -307,7 +290,7 @@ class StableDiffusionProcessing:
         self.comments[text] = 1
 
     def txt2img_image_conditioning(self, x, width=None, height=None):
-        self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
+        self.is_using_inpainting_conditioning = self.sd_model.is_inpaint
 
         return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
 
@@ -497,6 +480,8 @@ class StableDiffusionProcessing:
            cache = caches[0]
 
        with devices.autocast():
+            shared.sd_model.set_clip_skip(opts.CLIP_stop_at_last_layers)
+
            cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
 
        import backend.text_processing.classic_engine
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index ffaadc89..f292073b 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -150,8 +150,7 @@ class StableDiffusionModelHijack:
         self.extra_generation_params = {}
 
     def get_prompt_lengths(self, text, cond_stage_model):
-        _, token_count = cond_stage_model.process_texts([text])
-        return token_count, cond_stage_model.get_target_prompt_token_count(token_count)
+        pass
 
     def redo_hijack(self, m):
         pass
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 92537ece..d89a8326 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,11 +14,12 @@ import ldm.modules.midas as midas
 import gc
 
 from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
-from modules.shared import opts
+from modules.shared import opts, cmd_opts
 from modules.timer import Timer
 import numpy as np
 
-from modules_forge import loader
+from backend.loader import forge_loader
 from backend import memory_management
+from backend.args import dynamic_args
 
 model_dir = "Stable-diffusion"
@@ -636,6 +637,7 @@ def get_obj_from_str(string, reload=False):
     return getattr(importlib.import_module(module, package=None), cls)
 
 
+@torch.no_grad()
 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     from modules import sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
@@ -663,8 +665,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
         # cache newly loaded model
         checkpoints_loaded[checkpoint_info] = state_dict.copy()
 
-    sd_model = loader.load_model_for_a1111(timer=timer, checkpoint_info=checkpoint_info, state_dict=state_dict)
+    dynamic_args['embedding_dir'] = cmd_opts.embeddings_dir
+    dynamic_args['emphasis_name'] = opts.emphasis
+    sd_model = forge_loader(state_dict)
+    timer.record("forge model load")
+
+    sd_model.sd_checkpoint_info = checkpoint_info
     sd_model.filename = checkpoint_info.filename
+    sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
+    timer.record("calculate hash")
 
     if not SkipWritingToConfig.skip:
         shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index 0934ccfe..a3a896a1 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -39,8 +39,10 @@ class CFGDenoiser(torch.nn.Module):
     negative prompt.
     """
 
-    def __init__(self, sampler):
+    def __init__(self, sampler, model):
         super().__init__()
+        self.inner_model = model
+
         self.model_wrap = None
         self.mask = None
         self.nmask = None
@@ -64,10 +66,6 @@ class CFGDenoiser(torch.nn.Module):
 
         self.classic_ddim_eps_estimation = False
 
-    @property
-    def inner_model(self):
-        raise NotImplementedError()
-
     def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in, cond):
         denoised_uncond = x_out[-uncond.shape[0]:]
         denoised = torch.clone(denoised_uncond)
@@ -158,7 +156,7 @@ class CFGDenoiser(torch.nn.Module):
         original_x_dtype = x.dtype
 
         if self.classic_ddim_eps_estimation:
-            acd = self.inner_model.inner_model.alphas_cumprod
+            acd = self.inner_model.alphas_cumprod
             fake_sigmas = ((1 - acd) / acd) ** 0.5
             real_sigma = fake_sigmas[sigma.round().long().clip(0, int(fake_sigmas.shape[0]))]
             real_sigma_data = 1.0
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 01723c30..fee65f83 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -237,7 +237,7 @@ class Sampler:
         self.eta_infotext_field = 'Eta'
         self.eta_default = 1.0
 
-        self.conditioning_key = getattr(shared.sd_model.model, 'conditioning_key', 'crossattn')
+        self.conditioning_key = 'crossattn'
 
         self.p = None
         self.model_wrap_cfg = None
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index f0b88c95..a702f7b2 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -1,12 +1,11 @@
 import torch
 import inspect
 import k_diffusion.sampling
-from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_schedulers, devices
-from modules.sd_samplers_cfg_denoiser import CFGDenoiser  # noqa: F401
+from modules import sd_samplers_common, sd_samplers_extra, sd_schedulers, devices
+from modules.sd_samplers_cfg_denoiser import CFGDenoiser
 from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
 
 from modules.shared import opts
-import modules.shared as shared
 from backend.sampling.sampling_function import sampling_prepare, sampling_cleanup
 
 
@@ -51,21 +50,6 @@ k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
 k_diffusion_scheduler = {x.name: x.function for x in sd_schedulers.schedulers}
 
 
-class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
-    @property
-    def inner_model(self):
-        if self.model_wrap is None:
-            denoiser_constructor = getattr(shared.sd_model, 'create_denoiser', None)
-
-            if denoiser_constructor is not None:
-                self.model_wrap = denoiser_constructor()
-            else:
-                denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
-                self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
-
-        return self.model_wrap
-
-
 class KDiffusionSampler(sd_samplers_common.Sampler):
     def __init__(self, funcname, sd_model, options=None):
         super().__init__(funcname)
@@ -75,8 +59,11 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
         self.options = options or {}
         self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
 
-        self.model_wrap_cfg = CFGDenoiserKDiffusion(self)
-        self.model_wrap = self.model_wrap_cfg.inner_model
+        self.model_wrap = self.model_wrap_cfg = CFGDenoiser(self, sd_model)
+        self.predictor = sd_model.forge_objects.unet.model.predictor
+
+        self.model_wrap_cfg.sigmas = self.predictor.sigmas
+        self.model_wrap_cfg.log_sigmas = self.predictor.sigmas.log()
 
     def get_sigmas(self, p, steps):
         discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
@@ -92,13 +79,13 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
 
         scheduler = sd_schedulers.schedulers_map.get(scheduler_name)
 
-        m_sigma_min, m_sigma_max = self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()
+        m_sigma_min, m_sigma_max = self.predictor.sigmas[0].item(), self.predictor.sigmas[-1].item()
         sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
 
         if p.sampler_noise_scheduler_override:
             sigmas = p.sampler_noise_scheduler_override(steps)
         elif scheduler is None or scheduler.function is None:
-            sigmas = self.model_wrap.get_sigmas(steps)
+            raise ValueError('Wrong scheduler!')
         else:
             sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}
 
@@ -120,7 +107,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
                 p.extra_generation_params["Schedule rho"] = opts.rho
 
             if scheduler.need_inner_model:
-                sigmas_kwargs['inner_model'] = self.model_wrap
+                sigmas_kwargs['inner_model'] = self.model_wrap_cfg
 
             if scheduler.label == 'Beta':
                 p.extra_generation_params["Beta schedule alpha"] = opts.beta_dist_alpha
@@ -134,11 +121,11 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
         return sigmas.cpu()
 
     def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
-        unet_patcher = self.model_wrap.inner_model.forge_objects.unet
-        sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
+        unet_patcher = self.model_wrap_cfg.inner_model.forge_objects.unet
+        sampling_prepare(self.model_wrap_cfg.inner_model.forge_objects.unet, x=x)
 
-        self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device)
-        self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device)
+        self.model_wrap_cfg.sigmas = self.model_wrap_cfg.sigmas.to(x.device)
+        self.model_wrap_cfg.log_sigmas = self.model_wrap_cfg.log_sigmas.to(x.device)
 
         steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
 
@@ -196,11 +183,11 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
         return samples
 
     def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
-        unet_patcher = self.model_wrap.inner_model.forge_objects.unet
-        sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
+        unet_patcher = self.model_wrap_cfg.inner_model.forge_objects.unet
+        sampling_prepare(self.model_wrap_cfg.inner_model.forge_objects.unet, x=x)
 
-        self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device)
-        self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device)
+        self.model_wrap_cfg.sigmas = self.model_wrap_cfg.sigmas.to(x.device)
+        self.model_wrap_cfg.log_sigmas = self.model_wrap_cfg.log_sigmas.to(x.device)
 
         steps = steps or p.steps
 
@@ -219,8 +206,8 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
             extra_params_kwargs['n'] = steps
 
         if 'sigma_min' in parameters:
-            extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
-            extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
+            extra_params_kwargs['sigma_min'] = self.model_wrap_cfg.sigmas[0].item()
+            extra_params_kwargs['sigma_max'] = self.model_wrap_cfg.sigmas[-1].item()
 
         if 'sigmas' in parameters:
             extra_params_kwargs['sigmas'] = sigmas
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py
index 08956497..c5a97290 100644
--- a/modules/sd_samplers_timesteps.py
+++ b/modules/sd_samplers_timesteps.py
@@ -49,10 +49,10 @@ class CompVisTimestepsVDenoiser(torch.nn.Module):
 
 
 class CFGDenoiserTimesteps(CFGDenoiser):
-    def __init__(self, sampler):
-        super().__init__(sampler)
+    def __init__(self, sampler, model):
+        super().__init__(sampler, model)
 
-        self.alphas = shared.sd_model.alphas_cumprod
+        self.alphas = model.forge_objects.unet.model.predictor.alphas_cumprod
         self.classic_ddim_eps_estimation = True
 
     def get_pred_x0(self, x_in, x_out, sigma):
@@ -66,14 +66,6 @@ class CFGDenoiserTimesteps(CFGDenoiser):
 
         return pred_x0
 
-    @property
-    def inner_model(self):
-        if self.model_wrap is None:
-            denoiser = CompVisTimestepsVDenoiser if shared.sd_model.parameterization == "v" else CompVisTimestepsDenoiser
-            self.model_wrap = denoiser(shared.sd_model)
-
-        return self.model_wrap
-
 
 class CompVisSampler(sd_samplers_common.Sampler):
     def __init__(self, funcname, sd_model):
@@ -83,8 +75,10 @@ class CompVisSampler(sd_samplers_common.Sampler):
         self.eta_infotext_field = 'Eta DDIM'
         self.eta_default = 0.0
 
-        self.model_wrap_cfg = CFGDenoiserTimesteps(self)
-        self.model_wrap = self.model_wrap_cfg.inner_model
+        self.model_wrap = self.model_wrap_cfg = CFGDenoiserTimesteps(self, sd_model)
+        self.predictor = sd_model.forge_objects.unet.model.predictor
+
+        self.model_wrap.inner_model.alphas_cumprod = self.predictor.alphas_cumprod
 
     def get_timesteps(self, p, steps):
         discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py
index 180e4389..a63179e6 100644
--- a/modules/sd_samplers_timesteps_impl.py
+++ b/modules/sd_samplers_timesteps_impl.py
@@ -10,7 +10,7 @@ from modules.torch_utils import float64
 
 @torch.no_grad()
 def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
-    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+    alphas_cumprod = model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
     alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
@@ -46,7 +46,7 @@ def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None
     Uses the unconditional noise prediction instead of the conditional noise to guide the denoising direction.
     The CFG scale is divided by 12.5 to map CFG from [0.0, 12.5] to [0, 1.0].
     """
-    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+    alphas_cumprod = model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
     alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
@@ -82,7 +82,7 @@ def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None
 
 @torch.no_grad()
 def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
-    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+    alphas_cumprod = model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
     alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
@@ -168,7 +168,7 @@ class UniPCCFG(uni_pc.UniPC):
 
 
 def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
-    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+    alphas_cumprod = model.inner_model.alphas_cumprod
     ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
     t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None  # this is likely off by a bit - if someone wants to fix it please by all means
 
diff --git a/modules/sd_schedulers.py b/modules/sd_schedulers.py
index af873dc9..7fc6ae1b 100644
--- a/modules/sd_schedulers.py
+++ b/modules/sd_schedulers.py
@@ -130,7 +130,7 @@ def beta_scheduler(n, sigma_min, sigma_max, inner_model, device):
 def turbo_scheduler(n, sigma_min, sigma_max, inner_model, device):
     unet = inner_model.inner_model.forge_objects.unet
     timesteps = torch.flip(torch.arange(1, n + 1) * float(1000.0 / n) - 1, (0,)).round().long().clip(0, 999)
-    sigmas = unet.model.model_sampling.sigma(timesteps)
+    sigmas = unet.model.predictor.sigma(timesteps)
     sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
     return sigmas.to(device)
 
diff --git a/modules/ui.py b/modules/ui.py
index 676132ff..7acb63d2 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -181,14 +181,14 @@ def update_token_counter(text, steps, styles, *, is_positive=True):
         prompt_schedules = [[[steps, text]]]
 
     try:
-        cond_stage_model = sd_models.model_data.sd_model.cond_stage_model
-        assert cond_stage_model is not None
+        get_prompt_lengths_on_ui = sd_models.model_data.sd_model.get_prompt_lengths_on_ui
+        assert get_prompt_lengths_on_ui is not None
     except Exception:
         return f"?/?"
 
     flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
     prompts = [prompt_text for step, prompt_text in flat_prompts]
-    token_count, max_length = max([model_hijack.get_prompt_lengths(prompt, cond_stage_model) for prompt in prompts], key=lambda args: args[0])
+    token_count, max_length = max([get_prompt_lengths_on_ui(prompt) for prompt in prompts], key=lambda args: args[0])
     return f"{token_count}/{max_length}"
 
 
diff --git a/modules_forge/loader.py b/modules_forge/forge_old_loader.py
similarity index 100%
rename from modules_forge/loader.py
rename to modules_forge/forge_old_loader.py