mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2026-02-28 02:34:18 +00:00
add AltDiffusion to webui
Signed-off-by: zhaohu xing <920232796@qq.com>
This commit is contained in:
@@ -36,8 +36,8 @@ def get_optimal_device():
|
||||
else:
|
||||
return torch.device("cuda")
|
||||
|
||||
if has_mps():
|
||||
return torch.device("mps")
|
||||
# if has_mps():
|
||||
# return torch.device("mps")
|
||||
|
||||
return cpu
|
||||
|
||||
|
||||
@@ -70,14 +70,19 @@ class StableDiffusionModelHijack:
|
||||
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
|
||||
|
||||
def hijack(self, m):
|
||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||
|
||||
if shared.text_model_name == "XLMR-Large":
|
||||
model_embeddings = m.cond_stage_model.roberta.embeddings
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
|
||||
else :
|
||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embeddings, self)
|
||||
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
|
||||
self.clip = m.cond_stage_model
|
||||
|
||||
apply_optimizations()
|
||||
# apply_optimizations()
|
||||
|
||||
def flatten(el):
|
||||
flattened = [flatten(children) for children in el.children()]
|
||||
@@ -125,8 +130,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||
self.tokenizer = wrapped.tokenizer
|
||||
self.token_mults = {}
|
||||
|
||||
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
|
||||
|
||||
try:
|
||||
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
|
||||
except:
|
||||
self.comma_token = None
|
||||
|
||||
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
||||
for text, ident in tokens_with_parens:
|
||||
mult = 1.0
|
||||
@@ -298,6 +306,9 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||
|
||||
def forward(self, text):
|
||||
if shared.text_model_name == "XLMR-Large":
|
||||
return self.wrapped.encode(text)
|
||||
|
||||
use_old = opts.use_old_emphasis_implementation
|
||||
if use_old:
|
||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
|
||||
@@ -359,7 +370,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||
z = self.wrapped.transformer.text_model.final_layer_norm(z)
|
||||
else:
|
||||
z = outputs.last_hidden_state
|
||||
|
||||
|
||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||
batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
|
||||
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
|
||||
|
||||
@@ -21,7 +21,7 @@ from modules.paths import models_path, script_path, sd_path
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
default_sd_model_file = sd_model_file
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
|
||||
parser.add_argument("--config", type=str, default="configs/altdiffusion/ad-inference.yaml", help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
@@ -106,6 +106,10 @@ restricted_opts = {
|
||||
"outdir_txt2img_grids",
|
||||
"outdir_save",
|
||||
}
|
||||
from omegaconf import OmegaConf
|
||||
config = OmegaConf.load(f"{cmd_opts.config}")
|
||||
# XLMR-Large
|
||||
text_model_name = config.model.params.cond_stage_config.params.name
|
||||
|
||||
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
|
||||
|
||||
|
||||
Reference in New Issue
Block a user