rework sd1.5 and sdxl from scratch

This commit is contained in:
layerdiffusion
2024-08-04 20:23:01 -07:00
parent e28e11fa97
commit 0863765173
25 changed files with 440 additions and 162 deletions

View File

@@ -57,3 +57,9 @@ parser.add_argument("--cuda-stream", action="store_true")
parser.add_argument("--pin-shared-memory", action="store_true")
args = parser.parse_known_args()[0]
# Some dynamic args that may be changed by webui rather than cmd flags.
dynamic_args = dict(
embedding_dir='./embeddings',
emphasis_name='original'
)

View File

@@ -0,0 +1,60 @@
import torch
class ForgeObjects:
def __init__(self, unet, clip, vae, clipvision):
self.unet = unet
self.clip = clip
self.vae = vae
self.clipvision = clipvision
def shallow_copy(self):
return ForgeObjects(
self.unet,
self.clip,
self.vae,
self.clipvision
)
class ForgeDiffusionEngine:
matched_guesses = []
def __init__(self, estimated_config, huggingface_components):
self.model_config = estimated_config
self.is_inpaint = estimated_config.inpaint_model()
self.forge_objects = None
self.forge_objects_original = None
self.forge_objects_after_applying_lora = None
self.current_lora_hash = str([])
self.tiling_enabled = False
# WebUI Dirty Legacy
self.cond_stage_key = 'txt'
self.is_sd3 = False
self.latent_channels = 4
self.is_sdxl = False
self.is_sdxl_inpaint = False
self.is_sd2 = False
self.is_sd1 = False
self.is_ssd = False
def set_clip_skip(self, clip_skip):
pass
def get_first_stage_encoding(self, x):
return x # legacy code, do not change
def get_learned_conditioning(self, prompt: list[str]):
pass
def encode_first_stage(self, x):
pass
def decode_first_stage(self, x):
pass
def get_prompt_lengths_on_ui(self, prompt):
pass

View File

@@ -1 +1,81 @@
#
import torch
from huggingface_guess import model_list
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
from backend.patcher.clip import CLIP
from backend.patcher.vae import VAE
from backend.patcher.unet import UnetPatcher
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
from backend.args import dynamic_args
from backend import memory_management
class StableDiffusion(ForgeDiffusionEngine):
matched_guesses = [model_list.SD15]
def __init__(self, estimated_config, huggingface_components):
super().__init__(estimated_config, huggingface_components)
clip = CLIP(
model_dict={
'clip_l': huggingface_components['text_encoder']
},
tokenizer_dict={
'clip_l': huggingface_components['tokenizer']
}
)
vae = VAE(model=huggingface_components['vae'])
unet = UnetPatcher.from_model(
model=huggingface_components['unet'],
diffusers_scheduler=huggingface_components['scheduler']
)
self.text_processing_engine = ClassicTextProcessingEngine(
text_encoder=clip.cond_stage_model.clip_l,
tokenizer=clip.tokenizer.clip_l,
embedding_dir=dynamic_args['embedding_dir'],
embedding_key='clip_l',
embedding_expected_shape=768,
emphasis_name=dynamic_args['emphasis_name'],
text_projection=False,
minimal_clip_skip=1,
clip_skip=1,
return_pooled=False,
final_layer_norm=True,
callback_before_encode=None
)
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
self.forge_objects_original = self.forge_objects.shallow_copy()
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
# WebUI Legacy
self.is_sd1 = True
def set_clip_skip(self, clip_skip):
self.text_processing_engine.clip_skip = clip_skip
@torch.inference_mode()
def get_learned_conditioning(self, prompt: list[str]):
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
cond = self.text_processing_engine(prompt)
return cond
@torch.inference_mode()
def get_prompt_lengths_on_ui(self, prompt):
_, token_count = self.text_processing_engine.process_texts([prompt])
return token_count, self.text_processing_engine.get_target_prompt_token_count(token_count)
@torch.inference_mode()
def encode_first_stage(self, x):
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
sample = self.forge_objects.vae.first_stage_model.process_in(sample)
return sample.to(x)
@torch.inference_mode()
def decode_first_stage(self, x):
sample = self.forge_objects.vae.first_stage_model.process_out(x)
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
return sample.to(x)

View File

@@ -1 +1,132 @@
#
import torch
from huggingface_guess import model_list
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
from backend.patcher.clip import CLIP
from backend.patcher.vae import VAE
from backend.patcher.unet import UnetPatcher
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
from backend.args import dynamic_args
from backend import memory_management
from backend.nn.unet import Timestep
class StableDiffusionXL(ForgeDiffusionEngine):
matched_guesses = [model_list.SDXL]
def __init__(self, estimated_config, huggingface_components):
super().__init__(estimated_config, huggingface_components)
clip = CLIP(
model_dict={
'clip_l': huggingface_components['text_encoder'],
'clip_g': huggingface_components['text_encoder_2']
},
tokenizer_dict={
'clip_l': huggingface_components['tokenizer'],
'clip_g': huggingface_components['tokenizer_2']
}
)
vae = VAE(model=huggingface_components['vae'])
unet = UnetPatcher.from_model(
model=huggingface_components['unet'],
diffusers_scheduler=huggingface_components['scheduler']
)
self.text_processing_engine_l = ClassicTextProcessingEngine(
text_encoder=clip.cond_stage_model.clip_l,
tokenizer=clip.tokenizer.clip_l,
embedding_dir=dynamic_args['embedding_dir'],
embedding_key='clip_l',
embedding_expected_shape=2048,
emphasis_name=dynamic_args['emphasis_name'],
text_projection=False,
minimal_clip_skip=2,
clip_skip=2,
return_pooled=False,
final_layer_norm=False,
)
self.text_processing_engine_g = ClassicTextProcessingEngine(
text_encoder=clip.cond_stage_model.clip_g,
tokenizer=clip.tokenizer.clip_g,
embedding_dir=dynamic_args['embedding_dir'],
embedding_key='clip_g',
embedding_expected_shape=2048,
emphasis_name=dynamic_args['emphasis_name'],
text_projection=True,
minimal_clip_skip=2,
clip_skip=2,
return_pooled=True,
final_layer_norm=False,
)
self.embedder = Timestep(256)
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
self.forge_objects_original = self.forge_objects.shallow_copy()
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
# WebUI Legacy
self.is_sdxl = True
def set_clip_skip(self, clip_skip):
self.text_processing_engine_l.clip_skip = clip_skip
self.text_processing_engine_g.clip_skip = clip_skip
@torch.inference_mode()
def get_learned_conditioning(self, prompt: list[str]):
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
cond_l = self.text_processing_engine_l(prompt)
cond_g, clip_pooled = self.text_processing_engine_g(prompt)
width = getattr(prompt, 'width', 1024) or 1024
height = getattr(prompt, 'height', 1024) or 1024
is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
crop_w = 0
crop_h = 0
target_width = width
target_height = height
out = [
self.embedder(torch.Tensor([height])), self.embedder(torch.Tensor([width])),
self.embedder(torch.Tensor([crop_h])), self.embedder(torch.Tensor([crop_w])),
self.embedder(torch.Tensor([target_height])), self.embedder(torch.Tensor([target_width]))
]
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1).to(clip_pooled)
force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in prompt)
if force_zero_negative_prompt:
clip_pooled = torch.zeros_like(clip_pooled)
cond_l = torch.zeros_like(cond_l)
cond_g = torch.zeros_like(cond_g)
cond = dict(
crossattn=torch.cat([cond_l, cond_g], dim=2),
vector=torch.cat([clip_pooled, flat], dim=1),
)
return cond
@torch.inference_mode()
def get_prompt_lengths_on_ui(self, prompt):
_, token_count = self.text_processing_engine_l.process_texts([prompt])
return token_count, self.text_processing_engine_l.get_target_prompt_token_count(token_count)
@torch.inference_mode()
def encode_first_stage(self, x):
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
sample = self.forge_objects.vae.first_stage_model.process_in(sample)
return sample.to(x)
@torch.inference_mode()
def decode_first_stage(self, x):
sample = self.forge_objects.vae.first_stage_model.process_out(x)
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
return sample.to(x)

View File

@@ -1,16 +1,23 @@
import os
import torch
import logging
import importlib
import huggingface_guess
from diffusers import DiffusionPipeline
from transformers import modeling_utils
from backend.state_dict import try_filter_state_dict, transformers_convert, load_state_dict, state_dict_key_replace
from backend.state_dict import try_filter_state_dict, load_state_dict
from backend.operations import using_forge_operations
from backend.nn.vae import IntegratedAutoencoderKL
from backend.nn.clip import IntegratedCLIP, CLIPTextConfig
from backend.nn.unet import IntegratedUNet2DConditionModel
from backend.diffusion_engine.sd15 import StableDiffusion
from backend.diffusion_engine.sdxl import StableDiffusionXL
possible_models = [StableDiffusion, StableDiffusionXL]
logging.getLogger("diffusers").setLevel(logging.ERROR)
dir_path = os.path.dirname(__file__)
@@ -27,61 +34,84 @@ def load_component(guess, component_name, lib_name, cls_name, repo_path, state_d
cls = getattr(importlib.import_module(lib_name), cls_name)
return cls.from_pretrained(os.path.join(repo_path, component_name))
if cls_name in ['AutoencoderKL']:
sd = try_filter_state_dict(state_dict, ['first_stage_model.', 'vae.'])
config = IntegratedAutoencoderKL.load_config(config_path)
with using_forge_operations():
model = IntegratedAutoencoderKL.from_config(config)
load_state_dict(model, sd)
load_state_dict(model, state_dict)
return model
if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
if component_name == 'text_encoder':
sd = try_filter_state_dict(state_dict, ['cond_stage_model.', 'conditioner.embedders.0.'])
elif component_name == 'text_encoder_2':
sd = try_filter_state_dict(state_dict, ['conditioner.embedders.1.'])
else:
raise ValueError(f"Wrong component_name: {component_name}")
if 'model.text_projection' in sd:
sd = transformers_convert(sd, "model.", "transformer.text_model.", 32)
sd = state_dict_key_replace(sd, {"model.text_projection": "text_projection",
"model.text_projection.weight": "text_projection",
"model.logit_scale": "logit_scale"})
config = CLIPTextConfig.from_pretrained(config_path)
with modeling_utils.no_init_weights():
with using_forge_operations():
model = IntegratedCLIP(config)
load_state_dict(model, sd, ignore_errors=['text_projection', 'logit_scale',
'transformer.text_model.embeddings.position_ids'])
load_state_dict(model, state_dict, ignore_errors=[
'transformer.text_projection.weight',
'transformer.text_model.embeddings.position_ids',
'logit_scale'
], log_name=cls_name)
return model
if cls_name == 'UNet2DConditionModel':
sd = try_filter_state_dict(state_dict, ['model.diffusion_model.'])
with using_forge_operations():
model = IntegratedUNet2DConditionModel.from_config(guess.unet_config)
model._internal_dict = guess.unet_config
load_state_dict(model, sd)
load_state_dict(model, state_dict)
return model
print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
return None
def load_huggingface_components(sd):
def split_state_dict(sd):
guess = huggingface_guess.guess(sd)
repo_name = guess.huggingface_repo
state_dict = {
'unet': try_filter_state_dict(sd, ['model.diffusion_model.']),
'vae': try_filter_state_dict(sd, guess.vae_key_prefix)
}
sd = guess.process_clip_state_dict(sd)
guess.clip_target = guess.clip_target(sd)
for k, v in guess.clip_target.items():
state_dict[v] = try_filter_state_dict(sd, [k + '.'])
state_dict['ignore'] = sd
print_dict = {k: len(v) for k, v in state_dict.items()}
print(f'StateDict Keys: {print_dict}')
del state_dict['ignore']
return state_dict, guess
@torch.no_grad()
def forge_loader(sd):
state_dicts, estimated_config = split_state_dict(sd)
repo_name = estimated_config.huggingface_repo
local_path = os.path.join(dir_path, 'huggingface', repo_name)
config = DiffusionPipeline.load_config(local_path)
result = {"repo_path": local_path}
config: dict = DiffusionPipeline.load_config(local_path)
huggingface_components = {}
for component_name, v in config.items():
if isinstance(v, list) and len(v) == 2:
lib_name, cls_name = v
component = load_component(guess, component_name, lib_name, cls_name, local_path, sd)
component_sd = state_dicts.get(component_name, None)
component = load_component(estimated_config, component_name, lib_name, cls_name, local_path, component_sd)
if component_sd is not None:
del state_dicts[component_name]
if component is not None:
result[component_name] = component
return result
huggingface_components[component_name] = component
for M in possible_models:
if any(isinstance(estimated_config, x) for x in M.matched_guesses):
return M(estimated_config=estimated_config, huggingface_components=huggingface_components)
print('Failed to recognize model type!')
return None

View File

@@ -5,14 +5,14 @@ from backend.modules.k_prediction import k_prediction_from_diffusers_scheduler
class KModel(torch.nn.Module):
def __init__(self, huggingface_components, storage_dtype, computation_dtype):
def __init__(self, model, diffusers_scheduler, storage_dtype, computation_dtype):
super().__init__()
self.storage_dtype = storage_dtype
self.computation_dtype = computation_dtype
self.diffusion_model = huggingface_components['unet']
self.predictor = k_prediction_from_diffusers_scheduler(huggingface_components['scheduler'])
self.diffusion_model = model
self.predictor = k_prediction_from_diffusers_scheduler(diffusers_scheduler)
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t

View File

@@ -108,11 +108,11 @@ class Prediction(AbstractPrediction):
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
self.set_sigmas(sigmas)
def set_sigmas(self, sigmas):
self.register_buffer('alphas_cumprod', alphas_cumprod.float())
self.register_buffer('sigmas', sigmas.float())
self.register_buffer('log_sigmas', sigmas.log().float())
return
@property
def sigma_min(self):

View File

@@ -7,5 +7,6 @@ class IntegratedCLIP(torch.nn.Module):
def __init__(self, config: CLIPTextConfig):
super().__init__()
self.transformer = CLIPTextModel(config)
self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
embed_dim = config.hidden_size
self.transformer.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
self.transformer.text_projection.weight.copy_(torch.eye(embed_dim))

View File

@@ -1,24 +1,10 @@
import torch
from backend import memory_management
from backend.patcher.base import ModelPatcher
class JointTokenizer:
def __init__(self, huggingface_components):
self.clip_l = huggingface_components.get('tokenizer', None)
self.clip_g = huggingface_components.get('tokenizer_2', None)
class JointCLIPTextEncoder(torch.nn.Module):
def __init__(self, huggingface_components):
super().__init__()
self.clip_l = huggingface_components.get('text_encoder', None)
self.clip_g = huggingface_components.get('text_encoder_2', None)
from backend.nn.base import ModuleDict, ObjectDict
class CLIP:
def __init__(self, huggingface_components=None, no_init=False):
def __init__(self, model_dict={}, tokenizer_dict={}, no_init=False):
if no_init:
return
@@ -26,8 +12,8 @@ class CLIP:
offload_device = memory_management.text_encoder_offload_device()
text_encoder_dtype = memory_management.text_encoder_dtype(load_device)
self.cond_stage_model = JointCLIPTextEncoder(huggingface_components)
self.tokenizer = JointTokenizer(huggingface_components)
self.cond_stage_model = ModuleDict(model_dict)
self.tokenizer = ObjectDict(tokenizer_dict)
self.cond_stage_model.to(dtype=text_encoder_dtype, device=offload_device)
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)

View File

@@ -1,12 +1,26 @@
import copy
import torch
from backend.modules.k_model import KModel
from backend.patcher.base import ModelPatcher
from backend import memory_management
class UnetPatcher(ModelPatcher):
def __init__(self, model, *args, **kwargs):
super().__init__(model, *args, **kwargs)
@classmethod
def from_model(cls, model, diffusers_scheduler):
parameters = memory_management.module_size(model)
unet_dtype = memory_management.unet_dtype(model_params=parameters)
load_device = memory_management.get_torch_device()
initial_load_device = memory_management.unet_inital_load_device(parameters, unet_dtype)
manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device)
manual_cast_dtype = unet_dtype if manual_cast_dtype is None else manual_cast_dtype
model.to(device=initial_load_device, dtype=unet_dtype)
model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, storage_dtype=unet_dtype, computation_dtype=manual_cast_dtype)
return UnetPatcher(model, load_device=load_device, offload_device=memory_management.unet_offload_device(), current_device=initial_load_device)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.controlnet_linked_list = None
self.extra_preserved_memory_during_sampling = 0
self.extra_model_patchers_during_sampling = []

View File

@@ -301,14 +301,15 @@ def sampling_function_inner(model, x, timestep, uncond, cond, cond_scale, model_
def sampling_function(self, denoiser_params, cond_scale, cond_composition):
model = self.inner_model.inner_model.forge_objects.unet.model
control = self.inner_model.inner_model.forge_objects.unet.controlnet_linked_list
extra_concat_condition = self.inner_model.inner_model.forge_objects.unet.extra_concat_condition
unet_patcher = self.inner_model.forge_objects.unet
model = unet_patcher.model
control = unet_patcher.controlnet_linked_list
extra_concat_condition = unet_patcher.extra_concat_condition
x = denoiser_params.x
timestep = denoiser_params.sigma
uncond = compile_conditions(denoiser_params.text_uncond)
cond = compile_weighted_conditions(denoiser_params.text_cond, cond_composition)
model_options = self.inner_model.inner_model.forge_objects.unet.model_options
model_options = unet_patcher.model_options
seed = self.p.seeds[0]
if extra_concat_condition is not None:

View File

@@ -1,14 +1,15 @@
import torch
def load_state_dict(model, sd, ignore_errors=[]):
def load_state_dict(model, sd, ignore_errors=[], log_name=None):
missing, unexpected = model.load_state_dict(sd, strict=False)
missing = [x for x in missing if x not in ignore_errors]
unexpected = [x for x in unexpected if x not in ignore_errors]
log_name = log_name or type(model).__name__
if len(missing) > 0:
print(f'{type(model).__name__} Missing: {missing}')
print(f'{log_name} Missing: {missing}')
if len(unexpected) > 0:
print(f'{type(model).__name__} Unexpected: {unexpected}')
print(f'{log_name} Unexpected: {unexpected}')
return

View File

@@ -5,7 +5,6 @@ from collections import namedtuple
from backend.text_processing import parsing, emphasis
from backend.text_processing.textual_inversion import EmbeddingDatabase
PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
last_extra_generation_params = {}
@@ -47,11 +46,12 @@ class CLIPEmbeddingForTextualInversion(torch.nn.Module):
return torch.stack(vecs)
class ClassicTextProcessingEngine(torch.nn.Module):
def __init__(self, text_encoder, tokenizer, chunk_length=75,
embedding_dir=None, embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original",
text_projection=False, minimal_clip_skip=1, clip_skip=1, return_pooled=False, final_layer_norm=True,
callback_before_encode=None):
class ClassicTextProcessingEngine:
def __init__(
self, text_encoder, tokenizer, chunk_length=75,
embedding_dir=None, embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="Original",
text_projection=False, minimal_clip_skip=1, clip_skip=1, return_pooled=False, final_layer_norm=True
):
super().__init__()
self.embeddings = EmbeddingDatabase(tokenizer, embedding_expected_shape)
@@ -71,7 +71,6 @@ class ClassicTextProcessingEngine(torch.nn.Module):
self.clip_skip = clip_skip
self.return_pooled = return_pooled
self.final_layer_norm = final_layer_norm
self.callback_before_encode = callback_before_encode
self.chunk_length = chunk_length
@@ -133,7 +132,7 @@ class ClassicTextProcessingEngine(torch.nn.Module):
pooled_output = outputs.pooler_output
if self.text_projection:
pooled_output = pooled_output.float().to(self.text_encoder.text_projection.device) @ self.text_encoder.text_projection.float()
pooled_output = self.text_encoder.transformer.text_projection(pooled_output)
z.pooled = pooled_output
return z
@@ -240,10 +239,7 @@ class ClassicTextProcessingEngine(torch.nn.Module):
return batch_chunks, token_count
def forward(self, texts):
if self.callback_before_encode is not None:
self.callback_before_encode(self, texts)
def __call__(self, texts):
batch_chunks, token_count = self.process_texts(texts)
used_embeddings = {}

View File

@@ -401,7 +401,7 @@ def prepare_environment():
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "aebabb94eaaa1a26a3b37128d1c079838c134623")
huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "78f7d1da6a00721a6670e33a9132fd73c4e987b4")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
try:

View File

@@ -100,7 +100,7 @@ def create_binary_mask(image, round=True):
return image
def txt2img_image_conditioning(sd_model, x, width, height):
if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
if sd_model.is_inpaint: # Inpainting models
# The "masked-image" in this case will just be all 0.5 since the entire image is masked.
image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
@@ -111,24 +111,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):
image_conditioning = image_conditioning.to(x.dtype)
return image_conditioning
elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models
return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
else:
if sd_model.is_sdxl_inpaint:
# The "masked-image" in this case will just be all 0.5 since the entire image is masked.
image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
image_conditioning = images_tensor_to_samples(image_conditioning,
approximation_indexes.get(opts.sd_vae_encode_method))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype)
return image_conditioning
# Dummy zero conditioning if we're not using inpainting or unclip models.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
@@ -307,7 +290,7 @@ class StableDiffusionProcessing:
self.comments[text] = 1
def txt2img_image_conditioning(self, x, width=None, height=None):
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
self.is_using_inpainting_conditioning = self.sd_model.is_inpaint
return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
@@ -497,6 +480,8 @@ class StableDiffusionProcessing:
cache = caches[0]
with devices.autocast():
shared.sd_model.set_clip_skip(opts.CLIP_stop_at_last_layers)
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
import backend.text_processing.classic_engine

View File

@@ -150,8 +150,7 @@ class StableDiffusionModelHijack:
self.extra_generation_params = {}
def get_prompt_lengths(self, text, cond_stage_model):
_, token_count = cond_stage_model.process_texts([text])
return token_count, cond_stage_model.get_target_prompt_token_count(token_count)
pass
def redo_hijack(self, m):
pass

View File

@@ -14,11 +14,12 @@ import ldm.modules.midas as midas
import gc
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
from modules.shared import opts
from modules.shared import opts, cmd_opts
from modules.timer import Timer
import numpy as np
from modules_forge import loader
from backend.loader import forge_loader
from backend import memory_management
from backend.args import dynamic_args
model_dir = "Stable-diffusion"
@@ -636,6 +637,7 @@ def get_obj_from_str(string, reload=False):
return getattr(importlib.import_module(module, package=None), cls)
@torch.no_grad()
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -663,8 +665,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
# cache newly loaded model
checkpoints_loaded[checkpoint_info] = state_dict.copy()
sd_model = loader.load_model_for_a1111(timer=timer, checkpoint_info=checkpoint_info, state_dict=state_dict)
dynamic_args['embedding_dir'] = cmd_opts.embeddings_dir
dynamic_args['emphasis_name'] = opts.emphasis
sd_model = forge_loader(state_dict)
timer.record("forge model load")
sd_model.sd_checkpoint_info = checkpoint_info
sd_model.filename = checkpoint_info.filename
sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

View File

@@ -39,8 +39,10 @@ class CFGDenoiser(torch.nn.Module):
negative prompt.
"""
def __init__(self, sampler):
def __init__(self, sampler, model):
super().__init__()
self.inner_model = model
self.model_wrap = None
self.mask = None
self.nmask = None
@@ -64,10 +66,6 @@ class CFGDenoiser(torch.nn.Module):
self.classic_ddim_eps_estimation = False
@property
def inner_model(self):
raise NotImplementedError()
def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in, cond):
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
@@ -158,7 +156,7 @@ class CFGDenoiser(torch.nn.Module):
original_x_dtype = x.dtype
if self.classic_ddim_eps_estimation:
acd = self.inner_model.inner_model.alphas_cumprod
acd = self.inner_model.alphas_cumprod
fake_sigmas = ((1 - acd) / acd) ** 0.5
real_sigma = fake_sigmas[sigma.round().long().clip(0, int(fake_sigmas.shape[0]))]
real_sigma_data = 1.0

View File

@@ -237,7 +237,7 @@ class Sampler:
self.eta_infotext_field = 'Eta'
self.eta_default = 1.0
self.conditioning_key = getattr(shared.sd_model.model, 'conditioning_key', 'crossattn')
self.conditioning_key = 'crossattn'
self.p = None
self.model_wrap_cfg = None

View File

@@ -1,12 +1,11 @@
import torch
import inspect
import k_diffusion.sampling
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_schedulers, devices
from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401
from modules import sd_samplers_common, sd_samplers_extra, sd_schedulers, devices
from modules.sd_samplers_cfg_denoiser import CFGDenoiser
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
from modules.shared import opts
import modules.shared as shared
from backend.sampling.sampling_function import sampling_prepare, sampling_cleanup
@@ -51,21 +50,6 @@ k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
k_diffusion_scheduler = {x.name: x.function for x in sd_schedulers.schedulers}
class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
@property
def inner_model(self):
if self.model_wrap is None:
denoiser_constructor = getattr(shared.sd_model, 'create_denoiser', None)
if denoiser_constructor is not None:
self.model_wrap = denoiser_constructor()
else:
denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
return self.model_wrap
class KDiffusionSampler(sd_samplers_common.Sampler):
def __init__(self, funcname, sd_model, options=None):
super().__init__(funcname)
@@ -75,8 +59,11 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
self.options = options or {}
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
self.model_wrap_cfg = CFGDenoiserKDiffusion(self)
self.model_wrap = self.model_wrap_cfg.inner_model
self.model_wrap = self.model_wrap_cfg = CFGDenoiser(self, sd_model)
self.predictor = sd_model.forge_objects.unet.model.predictor
self.model_wrap_cfg.sigmas = self.predictor.sigmas
self.model_wrap_cfg.log_sigmas = self.predictor.sigmas.log()
def get_sigmas(self, p, steps):
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
@@ -92,13 +79,13 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
scheduler = sd_schedulers.schedulers_map.get(scheduler_name)
m_sigma_min, m_sigma_max = self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()
m_sigma_min, m_sigma_max = self.predictor.sigmas[0].item(), self.predictor.sigmas[-1].item()
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
elif scheduler is None or scheduler.function is None:
sigmas = self.model_wrap.get_sigmas(steps)
raise ValueError('Wrong scheduler!')
else:
sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}
@@ -120,7 +107,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
p.extra_generation_params["Schedule rho"] = opts.rho
if scheduler.need_inner_model:
sigmas_kwargs['inner_model'] = self.model_wrap
sigmas_kwargs['inner_model'] = self.model_wrap_cfg
if scheduler.label == 'Beta':
p.extra_generation_params["Beta schedule alpha"] = opts.beta_dist_alpha
@@ -134,11 +121,11 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
return sigmas.cpu()
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
unet_patcher = self.model_wrap_cfg.inner_model.forge_objects.unet
sampling_prepare(self.model_wrap_cfg.inner_model.forge_objects.unet, x=x)
self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device)
self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device)
self.model_wrap_cfg.sigmas = self.model_wrap_cfg.sigmas.to(x.device)
self.model_wrap_cfg.log_sigmas = self.model_wrap_cfg.log_sigmas.to(x.device)
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
@@ -196,11 +183,11 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
unet_patcher = self.model_wrap_cfg.inner_model.forge_objects.unet
sampling_prepare(self.model_wrap_cfg.inner_model.forge_objects.unet, x=x)
self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device)
self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device)
self.model_wrap_cfg.sigmas = self.model_wrap_cfg.sigmas.to(x.device)
self.model_wrap_cfg.log_sigmas = self.model_wrap_cfg.log_sigmas.to(x.device)
steps = steps or p.steps
@@ -219,8 +206,8 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
extra_params_kwargs['n'] = steps
if 'sigma_min' in parameters:
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
extra_params_kwargs['sigma_min'] = self.model_wrap_cfg.sigmas[0].item()
extra_params_kwargs['sigma_max'] = self.model_wrap_cfg.sigmas[-1].item()
if 'sigmas' in parameters:
extra_params_kwargs['sigmas'] = sigmas

View File

@@ -49,10 +49,10 @@ class CompVisTimestepsVDenoiser(torch.nn.Module):
class CFGDenoiserTimesteps(CFGDenoiser):
def __init__(self, sampler):
super().__init__(sampler)
def __init__(self, sampler, model):
super().__init__(sampler, model)
self.alphas = shared.sd_model.alphas_cumprod
self.alphas = model.forge_objects.unet.model.predictor.alphas_cumprod
self.classic_ddim_eps_estimation = True
def get_pred_x0(self, x_in, x_out, sigma):
@@ -66,14 +66,6 @@ class CFGDenoiserTimesteps(CFGDenoiser):
return pred_x0
@property
def inner_model(self):
if self.model_wrap is None:
denoiser = CompVisTimestepsVDenoiser if shared.sd_model.parameterization == "v" else CompVisTimestepsDenoiser
self.model_wrap = denoiser(shared.sd_model)
return self.model_wrap
class CompVisSampler(sd_samplers_common.Sampler):
def __init__(self, funcname, sd_model):
@@ -83,8 +75,10 @@ class CompVisSampler(sd_samplers_common.Sampler):
self.eta_infotext_field = 'Eta DDIM'
self.eta_default = 0.0
self.model_wrap_cfg = CFGDenoiserTimesteps(self)
self.model_wrap = self.model_wrap_cfg.inner_model
self.model_wrap = self.model_wrap_cfg = CFGDenoiserTimesteps(self, sd_model)
self.predictor = sd_model.forge_objects.unet.model.predictor
self.model_wrap.inner_model.alphas_cumprod = self.predictor.alphas_cumprod
def get_timesteps(self, p, steps):
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)

View File

@@ -10,7 +10,7 @@ from modules.torch_utils import float64
@torch.no_grad()
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas_cumprod = model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps]
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
@@ -46,7 +46,7 @@ def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None
Uses the unconditional noise prediction instead of the conditional noise to guide the denoising direction.
The CFG scale is divided by 12.5 to map CFG from [0.0, 12.5] to [0, 1.0].
"""
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas_cumprod = model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps]
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
@@ -82,7 +82,7 @@ def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None
@torch.no_grad()
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas_cumprod = model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps]
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
@@ -168,7 +168,7 @@ class UniPCCFG(uni_pc.UniPC):
def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas_cumprod = model.inner_model.alphas_cumprod
ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None # this is likely off by a bit - if someone wants to fix it please by all means

View File

@@ -130,7 +130,7 @@ def beta_scheduler(n, sigma_min, sigma_max, inner_model, device):
def turbo_scheduler(n, sigma_min, sigma_max, inner_model, device):
unet = inner_model.inner_model.forge_objects.unet
timesteps = torch.flip(torch.arange(1, n + 1) * float(1000.0 / n) - 1, (0,)).round().long().clip(0, 999)
sigmas = unet.model.model_sampling.sigma(timesteps)
sigmas = unet.model.predictor.sigma(timesteps)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
return sigmas.to(device)

View File

@@ -181,14 +181,14 @@ def update_token_counter(text, steps, styles, *, is_positive=True):
prompt_schedules = [[[steps, text]]]
try:
cond_stage_model = sd_models.model_data.sd_model.cond_stage_model
assert cond_stage_model is not None
get_prompt_lengths_on_ui = sd_models.model_data.sd_model.get_prompt_lengths_on_ui
assert get_prompt_lengths_on_ui is not None
except Exception:
return f"<span class='gr-box gr-text-input'>?/?</span>"
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
prompts = [prompt_text for step, prompt_text in flat_prompts]
token_count, max_length = max([model_hijack.get_prompt_lengths(prompt, cond_stage_model) for prompt in prompts], key=lambda args: args[0])
token_count, max_length = max([get_prompt_lengths_on_ui(prompt) for prompt in prompts], key=lambda args: args[0])
return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"