Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git (synced 2026-01-26 19:09:45 +00:00)

Commit: rework sd1.5 and sdxl from scratch
@@ -57,3 +57,9 @@ parser.add_argument("--cuda-stream", action="store_true")
parser.add_argument("--pin-shared-memory", action="store_true")

args = parser.parse_known_args()[0]

# Some dynamic args that may be changed by webui rather than cmd flags.
dynamic_args = dict(
    embedding_dir='./embeddings',
    emphasis_name='original'
)
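The new dynamic_args dict is the backend's hook for settings that the webui decides at runtime rather than via command-line flags. A minimal usage sketch, mirroring the modules/sd_models.py hunk further down in this commit (the concrete values there come from cmd_opts.embeddings_dir and opts.emphasis; the path below is hypothetical):

from backend.args import dynamic_args

# Override the backend defaults from the webui side before a model is loaded,
# as done in modules/sd_models.py later in this commit.
dynamic_args['embedding_dir'] = '/path/to/embeddings'  # hypothetical path
dynamic_args['emphasis_name'] = 'original'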
backend/diffusion_engine/base.py (new file, 60 lines)
@@ -0,0 +1,60 @@
import torch


class ForgeObjects:
    def __init__(self, unet, clip, vae, clipvision):
        self.unet = unet
        self.clip = clip
        self.vae = vae
        self.clipvision = clipvision

    def shallow_copy(self):
        return ForgeObjects(
            self.unet,
            self.clip,
            self.vae,
            self.clipvision
        )


class ForgeDiffusionEngine:
    matched_guesses = []

    def __init__(self, estimated_config, huggingface_components):
        self.model_config = estimated_config
        self.is_inpaint = estimated_config.inpaint_model()

        self.forge_objects = None
        self.forge_objects_original = None
        self.forge_objects_after_applying_lora = None

        self.current_lora_hash = str([])
        self.tiling_enabled = False

        # WebUI Dirty Legacy
        self.cond_stage_key = 'txt'
        self.is_sd3 = False
        self.latent_channels = 4
        self.is_sdxl = False
        self.is_sdxl_inpaint = False
        self.is_sd2 = False
        self.is_sd1 = False
        self.is_ssd = False

    def set_clip_skip(self, clip_skip):
        pass

    def get_first_stage_encoding(self, x):
        return x # legacy code, do not change

    def get_learned_conditioning(self, prompt: list[str]):
        pass

    def encode_first_stage(self, x):
        pass

    def decode_first_stage(self, x):
        pass

    def get_prompt_lengths_on_ui(self, prompt):
        pass
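For orientation: concrete engines subclass ForgeDiffusionEngine, declare which huggingface_guess configs they handle through matched_guesses, and populate forge_objects; forge_loader (later in this commit) instantiates the first class whose matched_guesses contains the estimated config. A hypothetical minimal subclass, shown only to illustrate that contract (the real SD1.5 and SDXL engines follow below):

from huggingface_guess import model_list
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects


class MyMinimalEngine(ForgeDiffusionEngine):
    # forge_loader matches the estimated config against this list
    matched_guesses = [model_list.SD15]  # hypothetical: reuses the SD1.5 guess

    def __init__(self, estimated_config, huggingface_components):
        super().__init__(estimated_config, huggingface_components)
        # a real engine builds its UnetPatcher / CLIP / VAE here
        self.forge_objects = ForgeObjects(unet=None, clip=None, vae=None, clipvision=None)
        self.forge_objects_original = self.forge_objects.shallow_copy()
        self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()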
@@ -1 +1,81 @@
#
import torch

from huggingface_guess import model_list
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
from backend.patcher.clip import CLIP
from backend.patcher.vae import VAE
from backend.patcher.unet import UnetPatcher
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
from backend.args import dynamic_args
from backend import memory_management


class StableDiffusion(ForgeDiffusionEngine):
    matched_guesses = [model_list.SD15]

    def __init__(self, estimated_config, huggingface_components):
        super().__init__(estimated_config, huggingface_components)

        clip = CLIP(
            model_dict={
                'clip_l': huggingface_components['text_encoder']
            },
            tokenizer_dict={
                'clip_l': huggingface_components['tokenizer']
            }
        )

        vae = VAE(model=huggingface_components['vae'])

        unet = UnetPatcher.from_model(
            model=huggingface_components['unet'],
            diffusers_scheduler=huggingface_components['scheduler']
        )

        self.text_processing_engine = ClassicTextProcessingEngine(
            text_encoder=clip.cond_stage_model.clip_l,
            tokenizer=clip.tokenizer.clip_l,
            embedding_dir=dynamic_args['embedding_dir'],
            embedding_key='clip_l',
            embedding_expected_shape=768,
            emphasis_name=dynamic_args['emphasis_name'],
            text_projection=False,
            minimal_clip_skip=1,
            clip_skip=1,
            return_pooled=False,
            final_layer_norm=True,
            callback_before_encode=None
        )

        self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
        self.forge_objects_original = self.forge_objects.shallow_copy()
        self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()

        # WebUI Legacy
        self.is_sd1 = True

    def set_clip_skip(self, clip_skip):
        self.text_processing_engine.clip_skip = clip_skip

    @torch.inference_mode()
    def get_learned_conditioning(self, prompt: list[str]):
        memory_management.load_model_gpu(self.forge_objects.clip.patcher)
        cond = self.text_processing_engine(prompt)
        return cond

    @torch.inference_mode()
    def get_prompt_lengths_on_ui(self, prompt):
        _, token_count = self.text_processing_engine.process_texts([prompt])
        return token_count, self.text_processing_engine.get_target_prompt_token_count(token_count)

    @torch.inference_mode()
    def encode_first_stage(self, x):
        sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
        sample = self.forge_objects.vae.first_stage_model.process_in(sample)
        return sample.to(x)

    @torch.inference_mode()
    def decode_first_stage(self, x):
        sample = self.forge_objects.vae.first_stage_model.process_out(x)
        sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
        return sample.to(x)
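The engine above is what the webui now talks to directly. A rough usage sketch, assuming engine is the StableDiffusion instance that forge_loader returns for an SD1.5 checkpoint (tensor conventions follow the code above: images in [-1, 1], NCHW):

import torch

# engine = forge_loader(state_dict)  # assumed: returns a StableDiffusion instance
engine.set_clip_skip(1)
cond = engine.get_learned_conditioning(['a photo of a cat'])  # CLIP-L conditioning tensor

image = torch.randn(1, 3, 512, 512)           # stand-in for a [-1, 1] image batch
latent = engine.encode_first_stage(image)     # VAE encode, scaled into model space
restored = engine.decode_first_stage(latent)  # back to [-1, 1] image space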
@@ -1 +1,132 @@
#
import torch

from huggingface_guess import model_list
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
from backend.patcher.clip import CLIP
from backend.patcher.vae import VAE
from backend.patcher.unet import UnetPatcher
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
from backend.args import dynamic_args
from backend import memory_management
from backend.nn.unet import Timestep


class StableDiffusionXL(ForgeDiffusionEngine):
    matched_guesses = [model_list.SDXL]

    def __init__(self, estimated_config, huggingface_components):
        super().__init__(estimated_config, huggingface_components)

        clip = CLIP(
            model_dict={
                'clip_l': huggingface_components['text_encoder'],
                'clip_g': huggingface_components['text_encoder_2']
            },
            tokenizer_dict={
                'clip_l': huggingface_components['tokenizer'],
                'clip_g': huggingface_components['tokenizer_2']
            }
        )

        vae = VAE(model=huggingface_components['vae'])

        unet = UnetPatcher.from_model(
            model=huggingface_components['unet'],
            diffusers_scheduler=huggingface_components['scheduler']
        )

        self.text_processing_engine_l = ClassicTextProcessingEngine(
            text_encoder=clip.cond_stage_model.clip_l,
            tokenizer=clip.tokenizer.clip_l,
            embedding_dir=dynamic_args['embedding_dir'],
            embedding_key='clip_l',
            embedding_expected_shape=2048,
            emphasis_name=dynamic_args['emphasis_name'],
            text_projection=False,
            minimal_clip_skip=2,
            clip_skip=2,
            return_pooled=False,
            final_layer_norm=False,
        )

        self.text_processing_engine_g = ClassicTextProcessingEngine(
            text_encoder=clip.cond_stage_model.clip_g,
            tokenizer=clip.tokenizer.clip_g,
            embedding_dir=dynamic_args['embedding_dir'],
            embedding_key='clip_g',
            embedding_expected_shape=2048,
            emphasis_name=dynamic_args['emphasis_name'],
            text_projection=True,
            minimal_clip_skip=2,
            clip_skip=2,
            return_pooled=True,
            final_layer_norm=False,
        )

        self.embedder = Timestep(256)

        self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
        self.forge_objects_original = self.forge_objects.shallow_copy()
        self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()

        # WebUI Legacy
        self.is_sdxl = True

    def set_clip_skip(self, clip_skip):
        self.text_processing_engine_l.clip_skip = clip_skip
        self.text_processing_engine_g.clip_skip = clip_skip

    @torch.inference_mode()
    def get_learned_conditioning(self, prompt: list[str]):
        memory_management.load_model_gpu(self.forge_objects.clip.patcher)

        cond_l = self.text_processing_engine_l(prompt)
        cond_g, clip_pooled = self.text_processing_engine_g(prompt)

        width = getattr(prompt, 'width', 1024) or 1024
        height = getattr(prompt, 'height', 1024) or 1024
        is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)

        crop_w = 0
        crop_h = 0
        target_width = width
        target_height = height

        out = [
            self.embedder(torch.Tensor([height])), self.embedder(torch.Tensor([width])),
            self.embedder(torch.Tensor([crop_h])), self.embedder(torch.Tensor([crop_w])),
            self.embedder(torch.Tensor([target_height])), self.embedder(torch.Tensor([target_width]))
        ]

        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1).to(clip_pooled)

        force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in prompt)

        if force_zero_negative_prompt:
            clip_pooled = torch.zeros_like(clip_pooled)
            cond_l = torch.zeros_like(cond_l)
            cond_g = torch.zeros_like(cond_g)

        cond = dict(
            crossattn=torch.cat([cond_l, cond_g], dim=2),
            vector=torch.cat([clip_pooled, flat], dim=1),
        )

        return cond

    @torch.inference_mode()
    def get_prompt_lengths_on_ui(self, prompt):
        _, token_count = self.text_processing_engine_l.process_texts([prompt])
        return token_count, self.text_processing_engine_l.get_target_prompt_token_count(token_count)

    @torch.inference_mode()
    def encode_first_stage(self, x):
        sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
        sample = self.forge_objects.vae.first_stage_model.process_in(sample)
        return sample.to(x)

    @torch.inference_mode()
    def decode_first_stage(self, x):
        sample = self.forge_objects.vae.first_stage_model.process_out(x)
        sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
        return sample.to(x)
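Besides the cross-attention conditioning, the SDXL engine builds the pooled "vector" input: the CLIP-G pooled output concatenated with six Timestep(256) embeddings of (height, width, crop_h, crop_w, target_height, target_width). A sketch of just that construction, lifted from get_learned_conditioning above (with a 1280-dim pooled output this yields 1280 + 6*256 = 2816 dims):

import torch
from backend.nn.unet import Timestep

embedder = Timestep(256)  # sinusoidal embedding of one scalar into 256 dims

def sdxl_size_vector(clip_pooled, height=1024, width=1024, crop_h=0, crop_w=0,
                     target_height=1024, target_width=1024):
    # six scalars -> six 256-dim embeddings, flattened behind the pooled CLIP-G output
    out = [
        embedder(torch.Tensor([height])), embedder(torch.Tensor([width])),
        embedder(torch.Tensor([crop_h])), embedder(torch.Tensor([crop_w])),
        embedder(torch.Tensor([target_height])), embedder(torch.Tensor([target_width])),
    ]
    flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1).to(clip_pooled)
    return torch.cat([clip_pooled, flat], dim=1)  # what the UNet receives as 'vector'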
@@ -1,16 +1,23 @@
import os
import torch
import logging
import importlib
import huggingface_guess

from diffusers import DiffusionPipeline
from transformers import modeling_utils
from backend.state_dict import try_filter_state_dict, transformers_convert, load_state_dict, state_dict_key_replace
from backend.state_dict import try_filter_state_dict, load_state_dict
from backend.operations import using_forge_operations
from backend.nn.vae import IntegratedAutoencoderKL
from backend.nn.clip import IntegratedCLIP, CLIPTextConfig
from backend.nn.unet import IntegratedUNet2DConditionModel

from backend.diffusion_engine.sd15 import StableDiffusion
from backend.diffusion_engine.sdxl import StableDiffusionXL


possible_models = [StableDiffusion, StableDiffusionXL]


logging.getLogger("diffusers").setLevel(logging.ERROR)
dir_path = os.path.dirname(__file__)

@@ -27,61 +34,84 @@ def load_component(guess, component_name, lib_name, cls_name, repo_path, state_d
        cls = getattr(importlib.import_module(lib_name), cls_name)
        return cls.from_pretrained(os.path.join(repo_path, component_name))
    if cls_name in ['AutoencoderKL']:
        sd = try_filter_state_dict(state_dict, ['first_stage_model.', 'vae.'])
        config = IntegratedAutoencoderKL.load_config(config_path)

        with using_forge_operations():
            model = IntegratedAutoencoderKL.from_config(config)

        load_state_dict(model, sd)
        load_state_dict(model, state_dict)
        return model
    if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
        if component_name == 'text_encoder':
            sd = try_filter_state_dict(state_dict, ['cond_stage_model.', 'conditioner.embedders.0.'])
        elif component_name == 'text_encoder_2':
            sd = try_filter_state_dict(state_dict, ['conditioner.embedders.1.'])
        else:
            raise ValueError(f"Wrong component_name: {component_name}")

        if 'model.text_projection' in sd:
            sd = transformers_convert(sd, "model.", "transformer.text_model.", 32)
            sd = state_dict_key_replace(sd, {"model.text_projection": "text_projection",
                                             "model.text_projection.weight": "text_projection",
                                             "model.logit_scale": "logit_scale"})

        config = CLIPTextConfig.from_pretrained(config_path)

        with modeling_utils.no_init_weights():
            with using_forge_operations():
                model = IntegratedCLIP(config)

        load_state_dict(model, sd, ignore_errors=['text_projection', 'logit_scale',
                                                  'transformer.text_model.embeddings.position_ids'])
        load_state_dict(model, state_dict, ignore_errors=[
            'transformer.text_projection.weight',
            'transformer.text_model.embeddings.position_ids',
            'logit_scale'
        ], log_name=cls_name)

        return model
    if cls_name == 'UNet2DConditionModel':
        sd = try_filter_state_dict(state_dict, ['model.diffusion_model.'])

        with using_forge_operations():
            model = IntegratedUNet2DConditionModel.from_config(guess.unet_config)
            model._internal_dict = guess.unet_config

        load_state_dict(model, sd)
        load_state_dict(model, state_dict)
        return model

    print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
    return None


def load_huggingface_components(sd):
def split_state_dict(sd):
    guess = huggingface_guess.guess(sd)
    repo_name = guess.huggingface_repo

    state_dict = {
        'unet': try_filter_state_dict(sd, ['model.diffusion_model.']),
        'vae': try_filter_state_dict(sd, guess.vae_key_prefix)
    }

    sd = guess.process_clip_state_dict(sd)
    guess.clip_target = guess.clip_target(sd)

    for k, v in guess.clip_target.items():
        state_dict[v] = try_filter_state_dict(sd, [k + '.'])

    state_dict['ignore'] = sd

    print_dict = {k: len(v) for k, v in state_dict.items()}
    print(f'StateDict Keys: {print_dict}')

    del state_dict['ignore']

    return state_dict, guess


@torch.no_grad()
def forge_loader(sd):
    state_dicts, estimated_config = split_state_dict(sd)
    repo_name = estimated_config.huggingface_repo

    local_path = os.path.join(dir_path, 'huggingface', repo_name)
    config = DiffusionPipeline.load_config(local_path)
    result = {"repo_path": local_path}
    config: dict = DiffusionPipeline.load_config(local_path)
    huggingface_components = {}
    for component_name, v in config.items():
        if isinstance(v, list) and len(v) == 2:
            lib_name, cls_name = v
            component = load_component(guess, component_name, lib_name, cls_name, local_path, sd)
            component_sd = state_dicts.get(component_name, None)
            component = load_component(estimated_config, component_name, lib_name, cls_name, local_path, component_sd)
            if component_sd is not None:
                del state_dicts[component_name]
            if component is not None:
                result[component_name] = component
    return result
                huggingface_components[component_name] = component

    for M in possible_models:
        if any(isinstance(estimated_config, x) for x in M.matched_guesses):
            return M(estimated_config=estimated_config, huggingface_components=huggingface_components)

    print('Failed to recognize model type!')
    return None
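forge_loader is the new single entry point: a checkpoint state dict goes in and a ready engine (StableDiffusion, StableDiffusionXL, or None) comes out. A hedged standalone sketch; reading the checkpoint with safetensors is an assumption here, since in the webui modules/sd_models.py passes in the state dict it has already loaded:

import safetensors.torch
from backend.loader import forge_loader

# hypothetical checkpoint path
state_dict = safetensors.torch.load_file('/path/to/checkpoint.safetensors')

engine = forge_loader(state_dict)  # StableDiffusion, StableDiffusionXL, or None
if engine is not None:
    print(type(engine).__name__, 'is_sd1 =', engine.is_sd1, 'is_sdxl =', engine.is_sdxl)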
@@ -5,14 +5,14 @@ from backend.modules.k_prediction import k_prediction_from_diffusers_scheduler


class KModel(torch.nn.Module):
    def __init__(self, huggingface_components, storage_dtype, computation_dtype):
    def __init__(self, model, diffusers_scheduler, storage_dtype, computation_dtype):
        super().__init__()

        self.storage_dtype = storage_dtype
        self.computation_dtype = computation_dtype

        self.diffusion_model = huggingface_components['unet']
        self.predictor = k_prediction_from_diffusers_scheduler(huggingface_components['scheduler'])
        self.diffusion_model = model
        self.predictor = k_prediction_from_diffusers_scheduler(diffusers_scheduler)

    def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
        sigma = t
@@ -108,11 +108,11 @@ class Prediction(AbstractPrediction):
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        self.set_sigmas(sigmas)

    def set_sigmas(self, sigmas):
        self.register_buffer('alphas_cumprod', alphas_cumprod.float())
        self.register_buffer('sigmas', sigmas.float())
        self.register_buffer('log_sigmas', sigmas.log().float())
        return

    @property
    def sigma_min(self):
@@ -7,5 +7,6 @@ class IntegratedCLIP(torch.nn.Module):
    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        self.transformer = CLIPTextModel(config)
        self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
        embed_dim = config.hidden_size
        self.transformer.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
        self.transformer.text_projection.weight.copy_(torch.eye(embed_dim))
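This hunk replaces the old free-standing text_projection parameter with an identity-initialized torch.nn.Linear hung on the transformer, which is what the classic_engine hunk below now calls. A small sketch, assuming an illustrative embed_dim, showing that the identity weight makes the projection a no-op until real checkpoint weights are loaded over it:

import torch

embed_dim = 1280  # illustrative; the real value is config.hidden_size
proj = torch.nn.Linear(embed_dim, embed_dim, bias=False)
with torch.no_grad():
    proj.weight.copy_(torch.eye(embed_dim))

pooled = torch.randn(2, embed_dim)
assert torch.allclose(proj(pooled), pooled)  # identity weight: output equals input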
@@ -1,24 +1,10 @@
import torch

from backend import memory_management
from backend.patcher.base import ModelPatcher


class JointTokenizer:
    def __init__(self, huggingface_components):
        self.clip_l = huggingface_components.get('tokenizer', None)
        self.clip_g = huggingface_components.get('tokenizer_2', None)


class JointCLIPTextEncoder(torch.nn.Module):
    def __init__(self, huggingface_components):
        super().__init__()
        self.clip_l = huggingface_components.get('text_encoder', None)
        self.clip_g = huggingface_components.get('text_encoder_2', None)
from backend.nn.base import ModuleDict, ObjectDict


class CLIP:
    def __init__(self, huggingface_components=None, no_init=False):
    def __init__(self, model_dict={}, tokenizer_dict={}, no_init=False):
        if no_init:
            return

@@ -26,8 +12,8 @@ class CLIP:
        offload_device = memory_management.text_encoder_offload_device()
        text_encoder_dtype = memory_management.text_encoder_dtype(load_device)

        self.cond_stage_model = JointCLIPTextEncoder(huggingface_components)
        self.tokenizer = JointTokenizer(huggingface_components)
        self.cond_stage_model = ModuleDict(model_dict)
        self.tokenizer = ObjectDict(tokenizer_dict)
        self.cond_stage_model.to(dtype=text_encoder_dtype, device=offload_device)
        self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
@@ -1,12 +1,26 @@
import copy
import torch

from backend.modules.k_model import KModel
from backend.patcher.base import ModelPatcher
from backend import memory_management


class UnetPatcher(ModelPatcher):
    def __init__(self, model, *args, **kwargs):
        super().__init__(model, *args, **kwargs)
    @classmethod
    def from_model(cls, model, diffusers_scheduler):
        parameters = memory_management.module_size(model)
        unet_dtype = memory_management.unet_dtype(model_params=parameters)
        load_device = memory_management.get_torch_device()
        initial_load_device = memory_management.unet_inital_load_device(parameters, unet_dtype)
        manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device)
        manual_cast_dtype = unet_dtype if manual_cast_dtype is None else manual_cast_dtype
        model.to(device=initial_load_device, dtype=unet_dtype)
        model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, storage_dtype=unet_dtype, computation_dtype=manual_cast_dtype)
        return UnetPatcher(model, load_device=load_device, offload_device=memory_management.unet_offload_device(), current_device=initial_load_device)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.controlnet_linked_list = None
        self.extra_preserved_memory_during_sampling = 0
        self.extra_model_patchers_during_sampling = []
@@ -301,14 +301,15 @@ def sampling_function_inner(model, x, timestep, uncond, cond, cond_scale, model_


def sampling_function(self, denoiser_params, cond_scale, cond_composition):
    model = self.inner_model.inner_model.forge_objects.unet.model
    control = self.inner_model.inner_model.forge_objects.unet.controlnet_linked_list
    extra_concat_condition = self.inner_model.inner_model.forge_objects.unet.extra_concat_condition
    unet_patcher = self.inner_model.forge_objects.unet
    model = unet_patcher.model
    control = unet_patcher.controlnet_linked_list
    extra_concat_condition = unet_patcher.extra_concat_condition
    x = denoiser_params.x
    timestep = denoiser_params.sigma
    uncond = compile_conditions(denoiser_params.text_uncond)
    cond = compile_weighted_conditions(denoiser_params.text_cond, cond_composition)
    model_options = self.inner_model.inner_model.forge_objects.unet.model_options
    model_options = unet_patcher.model_options
    seed = self.p.seeds[0]

    if extra_concat_condition is not None:
@@ -1,14 +1,15 @@
import torch


def load_state_dict(model, sd, ignore_errors=[]):
def load_state_dict(model, sd, ignore_errors=[], log_name=None):
    missing, unexpected = model.load_state_dict(sd, strict=False)
    missing = [x for x in missing if x not in ignore_errors]
    unexpected = [x for x in unexpected if x not in ignore_errors]
    log_name = log_name or type(model).__name__
    if len(missing) > 0:
        print(f'{type(model).__name__} Missing: {missing}')
        print(f'{log_name} Missing: {missing}')
    if len(unexpected) > 0:
        print(f'{type(model).__name__} Unexpected: {unexpected}')
        print(f'{log_name} Unexpected: {unexpected}')
    return
@@ -5,7 +5,6 @@ from collections import namedtuple
from backend.text_processing import parsing, emphasis
from backend.text_processing.textual_inversion import EmbeddingDatabase


PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
last_extra_generation_params = {}

@@ -47,11 +46,12 @@ class CLIPEmbeddingForTextualInversion(torch.nn.Module):
        return torch.stack(vecs)


class ClassicTextProcessingEngine(torch.nn.Module):
    def __init__(self, text_encoder, tokenizer, chunk_length=75,
                 embedding_dir=None, embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original",
                 text_projection=False, minimal_clip_skip=1, clip_skip=1, return_pooled=False, final_layer_norm=True,
                 callback_before_encode=None):
class ClassicTextProcessingEngine:
    def __init__(
            self, text_encoder, tokenizer, chunk_length=75,
            embedding_dir=None, embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="Original",
            text_projection=False, minimal_clip_skip=1, clip_skip=1, return_pooled=False, final_layer_norm=True
    ):
        super().__init__()

        self.embeddings = EmbeddingDatabase(tokenizer, embedding_expected_shape)
@@ -71,7 +71,6 @@ class ClassicTextProcessingEngine(torch.nn.Module):
        self.clip_skip = clip_skip
        self.return_pooled = return_pooled
        self.final_layer_norm = final_layer_norm
        self.callback_before_encode = callback_before_encode

        self.chunk_length = chunk_length

@@ -133,7 +132,7 @@ class ClassicTextProcessingEngine(torch.nn.Module):
        pooled_output = outputs.pooler_output

        if self.text_projection:
            pooled_output = pooled_output.float().to(self.text_encoder.text_projection.device) @ self.text_encoder.text_projection.float()
            pooled_output = self.text_encoder.transformer.text_projection(pooled_output)

        z.pooled = pooled_output
        return z
@@ -240,10 +239,7 @@ class ClassicTextProcessingEngine(torch.nn.Module):

        return batch_chunks, token_count

    def forward(self, texts):
        if self.callback_before_encode is not None:
            self.callback_before_encode(self, texts)

    def __call__(self, texts):
        batch_chunks, token_count = self.process_texts(texts)

        used_embeddings = {}
@@ -401,7 +401,7 @@ def prepare_environment():
    stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
    stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
    k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
    huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "aebabb94eaaa1a26a3b37128d1c079838c134623")
    huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "78f7d1da6a00721a6670e33a9132fd73c4e987b4")
    blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")

    try:
@@ -100,7 +100,7 @@ def create_binary_mask(image, round=True):
    return image

def txt2img_image_conditioning(sd_model, x, width, height):
    if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
    if sd_model.is_inpaint: # Inpainting models

        # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
        image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
@@ -111,24 +111,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):
        image_conditioning = image_conditioning.to(x.dtype)

        return image_conditioning

    elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models

        return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)

    else:
    if sd_model.is_sdxl_inpaint:
        # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
        image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
        image_conditioning = images_tensor_to_samples(image_conditioning,
                                                      approximation_indexes.get(opts.sd_vae_encode_method))

        # Add the fake full 1s mask to the first dimension.
        image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
        image_conditioning = image_conditioning.to(x.dtype)

        return image_conditioning

    # Dummy zero conditioning if we're not using inpainting or unclip models.
    # Still takes up a bit of memory, but no encoder call.
    # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
@@ -307,7 +290,7 @@ class StableDiffusionProcessing:
        self.comments[text] = 1

    def txt2img_image_conditioning(self, x, width=None, height=None):
        self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
        self.is_using_inpainting_conditioning = self.sd_model.is_inpaint

        return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
@@ -497,6 +480,8 @@ class StableDiffusionProcessing:
        cache = caches[0]

        with devices.autocast():
            shared.sd_model.set_clip_skip(opts.CLIP_stop_at_last_layers)

            cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)

import backend.text_processing.classic_engine

@@ -150,8 +150,7 @@ class StableDiffusionModelHijack:
        self.extra_generation_params = {}

    def get_prompt_lengths(self, text, cond_stage_model):
        _, token_count = cond_stage_model.process_texts([text])
        return token_count, cond_stage_model.get_target_prompt_token_count(token_count)
        pass

    def redo_hijack(self, m):
        pass
@@ -14,11 +14,12 @@ import ldm.modules.midas as midas
import gc

from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
from modules.shared import opts
from modules.shared import opts, cmd_opts
from modules.timer import Timer
import numpy as np
from modules_forge import loader
from backend.loader import forge_loader
from backend import memory_management
from backend.args import dynamic_args


model_dir = "Stable-diffusion"
@@ -636,6 +637,7 @@ def get_obj_from_str(string, reload=False):
    return getattr(importlib.import_module(module, package=None), cls)


@torch.no_grad()
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    from modules import sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()
@@ -663,8 +665,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
        # cache newly loaded model
        checkpoints_loaded[checkpoint_info] = state_dict.copy()

    sd_model = loader.load_model_for_a1111(timer=timer, checkpoint_info=checkpoint_info, state_dict=state_dict)
    dynamic_args['embedding_dir'] = cmd_opts.embeddings_dir
    dynamic_args['emphasis_name'] = opts.emphasis
    sd_model = forge_loader(state_dict)
    timer.record("forge model load")

    sd_model.sd_checkpoint_info = checkpoint_info
    sd_model.filename = checkpoint_info.filename
    sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    if not SkipWritingToConfig.skip:
        shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
@@ -39,8 +39,10 @@ class CFGDenoiser(torch.nn.Module):
    negative prompt.
    """

    def __init__(self, sampler):
    def __init__(self, sampler, model):
        super().__init__()
        self.inner_model = model

        self.model_wrap = None
        self.mask = None
        self.nmask = None
@@ -64,10 +66,6 @@ class CFGDenoiser(torch.nn.Module):

        self.classic_ddim_eps_estimation = False

    @property
    def inner_model(self):
        raise NotImplementedError()

    def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in, cond):
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)
@@ -158,7 +156,7 @@ class CFGDenoiser(torch.nn.Module):
        original_x_dtype = x.dtype

        if self.classic_ddim_eps_estimation:
            acd = self.inner_model.inner_model.alphas_cumprod
            acd = self.inner_model.alphas_cumprod
            fake_sigmas = ((1 - acd) / acd) ** 0.5
            real_sigma = fake_sigmas[sigma.round().long().clip(0, int(fake_sigmas.shape[0]))]
            real_sigma_data = 1.0

@@ -237,7 +237,7 @@ class Sampler:
        self.eta_infotext_field = 'Eta'
        self.eta_default = 1.0

        self.conditioning_key = getattr(shared.sd_model.model, 'conditioning_key', 'crossattn')
        self.conditioning_key = 'crossattn'

        self.p = None
        self.model_wrap_cfg = None
@@ -1,12 +1,11 @@
import torch
import inspect
import k_diffusion.sampling
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_schedulers, devices
from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401
from modules import sd_samplers_common, sd_samplers_extra, sd_schedulers, devices
from modules.sd_samplers_cfg_denoiser import CFGDenoiser
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback

from modules.shared import opts
import modules.shared as shared
from backend.sampling.sampling_function import sampling_prepare, sampling_cleanup


@@ -51,21 +50,6 @@ k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
k_diffusion_scheduler = {x.name: x.function for x in sd_schedulers.schedulers}


class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
    @property
    def inner_model(self):
        if self.model_wrap is None:
            denoiser_constructor = getattr(shared.sd_model, 'create_denoiser', None)

            if denoiser_constructor is not None:
                self.model_wrap = denoiser_constructor()
            else:
                denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
                self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)

        return self.model_wrap


class KDiffusionSampler(sd_samplers_common.Sampler):
    def __init__(self, funcname, sd_model, options=None):
        super().__init__(funcname)
@@ -75,8 +59,11 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
        self.options = options or {}
        self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)

        self.model_wrap_cfg = CFGDenoiserKDiffusion(self)
        self.model_wrap = self.model_wrap_cfg.inner_model
        self.model_wrap = self.model_wrap_cfg = CFGDenoiser(self, sd_model)
        self.predictor = sd_model.forge_objects.unet.model.predictor

        self.model_wrap_cfg.sigmas = self.predictor.sigmas
        self.model_wrap_cfg.log_sigmas = self.predictor.sigmas.log()

    def get_sigmas(self, p, steps):
        discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
@@ -92,13 +79,13 @@ class KDiffusionSampler(sd_samplers_common.Sampler):

        scheduler = sd_schedulers.schedulers_map.get(scheduler_name)

        m_sigma_min, m_sigma_max = self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()
        m_sigma_min, m_sigma_max = self.predictor.sigmas[0].item(), self.predictor.sigmas[-1].item()
        sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)

        if p.sampler_noise_scheduler_override:
            sigmas = p.sampler_noise_scheduler_override(steps)
        elif scheduler is None or scheduler.function is None:
            sigmas = self.model_wrap.get_sigmas(steps)
            raise ValueError('Wrong scheduler!')
        else:
            sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}

@@ -120,7 +107,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
                p.extra_generation_params["Schedule rho"] = opts.rho

            if scheduler.need_inner_model:
                sigmas_kwargs['inner_model'] = self.model_wrap
                sigmas_kwargs['inner_model'] = self.model_wrap_cfg

            if scheduler.label == 'Beta':
                p.extra_generation_params["Beta schedule alpha"] = opts.beta_dist_alpha
@@ -134,11 +121,11 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
        return sigmas.cpu()

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        unet_patcher = self.model_wrap.inner_model.forge_objects.unet
        sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
        unet_patcher = self.model_wrap_cfg.inner_model.forge_objects.unet
        sampling_prepare(self.model_wrap_cfg.inner_model.forge_objects.unet, x=x)

        self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device)
        self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device)
        self.model_wrap_cfg.sigmas = self.model_wrap_cfg.sigmas.to(x.device)
        self.model_wrap_cfg.log_sigmas = self.model_wrap_cfg.log_sigmas.to(x.device)

        steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)

@@ -196,11 +183,11 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
        return samples

    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        unet_patcher = self.model_wrap.inner_model.forge_objects.unet
        sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
        unet_patcher = self.model_wrap_cfg.inner_model.forge_objects.unet
        sampling_prepare(self.model_wrap_cfg.inner_model.forge_objects.unet, x=x)

        self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device)
        self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device)
        self.model_wrap_cfg.sigmas = self.model_wrap_cfg.sigmas.to(x.device)
        self.model_wrap_cfg.log_sigmas = self.model_wrap_cfg.log_sigmas.to(x.device)

        steps = steps or p.steps

@@ -219,8 +206,8 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
            extra_params_kwargs['n'] = steps

        if 'sigma_min' in parameters:
            extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
            extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
            extra_params_kwargs['sigma_min'] = self.model_wrap_cfg.sigmas[0].item()
            extra_params_kwargs['sigma_max'] = self.model_wrap_cfg.sigmas[-1].item()

        if 'sigmas' in parameters:
            extra_params_kwargs['sigmas'] = sigmas
@@ -49,10 +49,10 @@ class CompVisTimestepsVDenoiser(torch.nn.Module):

class CFGDenoiserTimesteps(CFGDenoiser):

    def __init__(self, sampler):
        super().__init__(sampler)
    def __init__(self, sampler, model):
        super().__init__(sampler, model)

        self.alphas = shared.sd_model.alphas_cumprod
        self.alphas = model.forge_objects.unet.model.predictor.alphas_cumprod
        self.classic_ddim_eps_estimation = True

    def get_pred_x0(self, x_in, x_out, sigma):
@@ -66,14 +66,6 @@ class CFGDenoiserTimesteps(CFGDenoiser):

        return pred_x0

    @property
    def inner_model(self):
        if self.model_wrap is None:
            denoiser = CompVisTimestepsVDenoiser if shared.sd_model.parameterization == "v" else CompVisTimestepsDenoiser
            self.model_wrap = denoiser(shared.sd_model)

        return self.model_wrap


class CompVisSampler(sd_samplers_common.Sampler):
    def __init__(self, funcname, sd_model):
@@ -83,8 +75,10 @@ class CompVisSampler(sd_samplers_common.Sampler):
        self.eta_infotext_field = 'Eta DDIM'
        self.eta_default = 0.0

        self.model_wrap_cfg = CFGDenoiserTimesteps(self)
        self.model_wrap = self.model_wrap_cfg.inner_model
        self.model_wrap = self.model_wrap_cfg = CFGDenoiserTimesteps(self, sd_model)
        self.predictor = sd_model.forge_objects.unet.model.predictor

        self.model_wrap.inner_model.alphas_cumprod = self.predictor.alphas_cumprod

    def get_timesteps(self, p, steps):
        discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
@@ -10,7 +10,7 @@ from modules.torch_utils import float64

@torch.no_grad()
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
    alphas_cumprod = model.inner_model.alphas_cumprod
    alphas = alphas_cumprod[timesteps]
    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
    sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
@@ -46,7 +46,7 @@ def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None
    Uses the unconditional noise prediction instead of the conditional noise to guide the denoising direction.
    The CFG scale is divided by 12.5 to map CFG from [0.0, 12.5] to [0, 1.0].
    """
    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
    alphas_cumprod = model.inner_model.alphas_cumprod
    alphas = alphas_cumprod[timesteps]
    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
    sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
@@ -82,7 +82,7 @@ def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None

@torch.no_grad()
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
    alphas_cumprod = model.inner_model.alphas_cumprod
    alphas = alphas_cumprod[timesteps]
    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
    sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
@@ -168,7 +168,7 @@ class UniPCCFG(uni_pc.UniPC):


def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
    alphas_cumprod = model.inner_model.alphas_cumprod

    ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
    t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None # this is likely off by a bit - if someone wants to fix it please by all means
@@ -130,7 +130,7 @@ def beta_scheduler(n, sigma_min, sigma_max, inner_model, device):
def turbo_scheduler(n, sigma_min, sigma_max, inner_model, device):
    unet = inner_model.inner_model.forge_objects.unet
    timesteps = torch.flip(torch.arange(1, n + 1) * float(1000.0 / n) - 1, (0,)).round().long().clip(0, 999)
    sigmas = unet.model.model_sampling.sigma(timesteps)
    sigmas = unet.model.predictor.sigma(timesteps)
    sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
    return sigmas.to(device)
@@ -181,14 +181,14 @@ def update_token_counter(text, steps, styles, *, is_positive=True):
        prompt_schedules = [[[steps, text]]]

    try:
        cond_stage_model = sd_models.model_data.sd_model.cond_stage_model
        assert cond_stage_model is not None
        get_prompt_lengths_on_ui = sd_models.model_data.sd_model.get_prompt_lengths_on_ui
        assert get_prompt_lengths_on_ui is not None
    except Exception:
        return f"<span class='gr-box gr-text-input'>?/?</span>"

    flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
    prompts = [prompt_text for step, prompt_text in flat_prompts]
    token_count, max_length = max([model_hijack.get_prompt_lengths(prompt, cond_stage_model) for prompt in prompts], key=lambda args: args[0])
    token_count, max_length = max([get_prompt_lengths_on_ui(prompt) for prompt in prompts], key=lambda args: args[0])
    return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"