class StableDiffusionModelHijack: def apply_optimizations(self, option=None): pass def convert_sdxl_to_ssd(self, m): pass def hijack(self, m): pass def undo_hijack(self, m): pass def apply_circular(self, enable): pass def clear_comments(self): pass def get_prompt_lengths(self, text, cond_stage_model): from modules import shared return shared.sd_model.get_prompt_lengths_on_ui(text) def redo_hijack(self, m): pass model_hijack = StableDiffusionModelHijack() # import torch # from torch.nn.functional import silu # from types import MethodType # # from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches # from modules.hypernetworks import hypernetwork # from modules.shared import cmd_opts # from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18 # # import ldm.modules.attention # import ldm.modules.diffusionmodules.model # import ldm.modules.diffusionmodules.openaimodel # import ldm.models.diffusion.ddpm # import ldm.models.diffusion.ddim # import ldm.models.diffusion.plms # import ldm.modules.encoders.modules # # import sgm.modules.attention # import sgm.modules.diffusionmodules.model # import sgm.modules.diffusionmodules.openaimodel # import sgm.modules.encoders.modules # # attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward # diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity # diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward # # # new memory efficient cross attention blocks do not support hypernets and we already # # have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention # ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention # ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention # # # silence new console spam from SD2 # ldm.modules.attention.print = shared.ldm_print # ldm.modules.diffusionmodules.model.print = shared.ldm_print # ldm.util.print = shared.ldm_print # ldm.models.diffusion.ddpm.print = shared.ldm_print # # optimizers = [] # current_optimizer: sd_hijack_optimizations.SdOptimization = None # # ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward) # ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward) # # sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward) # sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward) # # # def list_optimizers(): # new_optimizers = script_callbacks.list_optimizers_callback() # # new_optimizers = [x for x in new_optimizers if x.is_available()] # # new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True) # # optimizers.clear() # optimizers.extend(new_optimizers) # # # def apply_optimizations(option=None): # return # # # def undo_optimizations(): # return # # # def fix_checkpoint(): # """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want # checkpoints to be added when not training (there's a warning)""" # # pass # # # def weighted_loss(sd_model, pred, target, mean=True): # #Calculate the weight normally, but ignore the mean # loss = sd_model._old_get_loss(pred, target, mean=False) # # #Check if we have weights available # weight = getattr(sd_model, '_custom_loss_weight', None) # if weight is not None: # loss *= weight # # #Return the loss, as mean if specified # return loss.mean() if mean else loss # # def weighted_forward(sd_model, x, c, w, *args, **kwargs): # try: # #Temporarily append weights to a place accessible during loss calc # sd_model._custom_loss_weight = w # # #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely # #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set # if not hasattr(sd_model, '_old_get_loss'): # sd_model._old_get_loss = sd_model.get_loss # sd_model.get_loss = MethodType(weighted_loss, sd_model) # # #Run the standard forward function, but with the patched 'get_loss' # return sd_model.forward(x, c, *args, **kwargs) # finally: # try: # #Delete temporary weights if appended # del sd_model._custom_loss_weight # except AttributeError: # pass # # #If we have an old loss function, reset the loss function to the original one # if hasattr(sd_model, '_old_get_loss'): # sd_model.get_loss = sd_model._old_get_loss # del sd_model._old_get_loss # # def apply_weighted_forward(sd_model): # #Add new function 'weighted_forward' that can be called to calc weighted loss # sd_model.weighted_forward = MethodType(weighted_forward, sd_model) # # def undo_weighted_forward(sd_model): # try: # del sd_model.weighted_forward # except AttributeError: # pass # # # class StableDiffusionModelHijack: # fixes = None # layers = None # circular_enabled = False # clip = None # optimization_method = None # # def __init__(self): # self.extra_generation_params = {} # self.comments = [] # # def apply_optimizations(self, option=None): # pass # # def convert_sdxl_to_ssd(self, m): # pass # # def hijack(self, m): # pass # # def undo_hijack(self, m): # pass # # def apply_circular(self, enable): # pass # # def clear_comments(self): # self.comments = [] # self.extra_generation_params = {} # # def get_prompt_lengths(self, text, cond_stage_model): # pass # # def redo_hijack(self, m): # pass # # # class EmbeddingsWithFixes(torch.nn.Module): # def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'): # super().__init__() # self.wrapped = wrapped # self.embeddings = embeddings # self.textual_inversion_key = textual_inversion_key # self.weight = self.wrapped.weight # # def forward(self, input_ids): # batch_fixes = self.embeddings.fixes # self.embeddings.fixes = None # # inputs_embeds = self.wrapped(input_ids) # # if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0: # return inputs_embeds # # vecs = [] # for fixes, tensor in zip(batch_fixes, inputs_embeds): # for offset, embedding in fixes: # vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec # emb = devices.cond_cast_unet(vec) # emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) # tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype) # # vecs.append(tensor) # # return torch.stack(vecs) # # # class TextualInversionEmbeddings(torch.nn.Embedding): # def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs): # super().__init__(num_embeddings, embedding_dim, **kwargs) # # self.embeddings = model_hijack # self.textual_inversion_key = textual_inversion_key # # @property # def wrapped(self): # return super().forward # # def forward(self, input_ids): # return EmbeddingsWithFixes.forward(self, input_ids) # # # def add_circular_option_to_conv_2d(): # conv2d_constructor = torch.nn.Conv2d.__init__ # # def conv2d_constructor_circular(self, *args, **kwargs): # return conv2d_constructor(self, *args, padding_mode='circular', **kwargs) # # torch.nn.Conv2d.__init__ = conv2d_constructor_circular # # # # # # def register_buffer(self, name, attr): # """ # Fix register buffer bug for Mac OS. # """ # # if type(attr) == torch.Tensor: # if attr.device != devices.device: # attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None)) # # setattr(self, name, attr) # # # ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer # ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer