Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git (synced 2026-01-27 11:29:46 +00:00)

Compare commits (3 commits)
| Author | SHA1 | Date |
|---|---|---|
| | c5853ef9f8 | |
| | 2d5b6cacef | |
| | 1363999fb1 | |
NEWS.md (7 changed lines)
@@ -1,7 +0,0 @@
About Gradio 5: we will try to upgrade to Gradio 5 around March 2025. If that fails, we will try again around June 2025. We are relatively positive that we can have Gradio 5 before next summer.

2024 Oct 28: A new branch `sd35` is contributed by [#2183](https://github.com/lllyasviel/stable-diffusion-webui-forge/pull/2183). I will take a look at quants, sampling, and the transformer's clip-g vs that clip-g rewrite before merging to main ... (Oct 29: okay, maybe medium also needs a look later)

About Flux ControlNet (sync [here](https://github.com/lllyasviel/stable-diffusion-webui-forge/discussions/932)): The rewrite of ControlNet Integrated will ~start at about Sep 29~ (delayed) ~start at about Oct 15~ (delayed) ~start at about Oct 30~ (delayed) start at about Nov 20. (When this note is announced, the main targets include some diffusers-formatted Flux ControlNets and some community implementations of Union ControlNets. However, this may be extended if stronger models come out after this note.)

2024 Sep 7: New sampler `Flux Realistic` is available now! Recommended scheduler is "simple".
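As a usage illustration for the `Flux Realistic` announcement above: a minimal sketch, assuming Forge is launched with `--api` and that the sampler and scheduler appear under these names in the stock txt2img endpoint (the endpoint path and field names follow the standard SD-WebUI API and are not taken from this diff).

```python
# Hedged sketch: request one image with the "Flux Realistic" sampler and the
# recommended "simple" scheduler via the standard txt2img API.
import base64
import requests

payload = {
    "prompt": "a photo of a forest at dawn",
    "steps": 20,
    "width": 1024,
    "height": 1024,
    "sampler_name": "Flux Realistic",  # sampler announced above
    "scheduler": "simple",             # recommended scheduler
}

resp = requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload)
resp.raise_for_status()
with open("out.png", "wb") as f:
    f.write(base64.b64decode(resp.json()["images"][0]))
```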
@@ -6,7 +6,9 @@ The name "Forge" is inspired from "Minecraft Forge". This project is aimed at be

Forge is currently based on SD-WebUI 1.10.1 at [this commit](https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/82a973c04367123ae98bd9abdf80d9eda9b910e2). (Because the original SD-WebUI is almost static now, Forge will sync with the original WebUI every 90 days, or when there are important fixes.)

News are moved to this link: [Click here to see the News section](https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/NEWS.md)

# News

2024 Sep 7: New sampler `Flux Realistic` is available now! Recommended scheduler is "simple".

# Quick List
@@ -1,72 +0,0 @@
import torch

from huggingface_guess import model_list
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
from backend.patcher.clip import CLIP
from backend.patcher.vae import VAE
from backend.patcher.unet import UnetPatcher
from backend.text_processing.t5_engine import T5TextProcessingEngine
from backend.args import dynamic_args
from backend.modules.k_prediction import PredictionFlux
from backend import memory_management

class Chroma(ForgeDiffusionEngine):
    def __init__(self, estimated_config, huggingface_components):
        super().__init__(estimated_config, huggingface_components)
        self.is_inpaint = False

        clip = CLIP(
            model_dict={
                't5xxl': huggingface_components['text_encoder']
            },
            tokenizer_dict={
                't5xxl': huggingface_components['tokenizer']
            }
        )

        vae = VAE(model=huggingface_components['vae'])
        k_predictor = PredictionFlux(
            mu=1.0
        )
        unet = UnetPatcher.from_model(
            model=huggingface_components['transformer'],
            diffusers_scheduler=None,
            k_predictor=k_predictor,
            config=estimated_config
        )

        self.text_processing_engine_t5 = T5TextProcessingEngine(
            text_encoder=clip.cond_stage_model.t5xxl,
            tokenizer=clip.tokenizer.t5xxl,
            emphasis_name=dynamic_args['emphasis_name'],
            min_length=1
        )

        self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
        self.forge_objects_original = self.forge_objects.shallow_copy()
        self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()

    def set_clip_skip(self, clip_skip):
        pass

    @torch.inference_mode()
    def get_learned_conditioning(self, prompt: list[str]):
        memory_management.load_model_gpu(self.forge_objects.clip.patcher)
        return self.text_processing_engine_t5(prompt)

    @torch.inference_mode()
    def get_prompt_lengths_on_ui(self, prompt):
        token_count = len(self.text_processing_engine_t5.tokenize([prompt])[0])
        return token_count, max(255, token_count)

    @torch.inference_mode()
    def encode_first_stage(self, x):
        sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
        sample = self.forge_objects.vae.first_stage_model.process_in(sample)
        return sample.to(x)

    @torch.inference_mode()
    def decode_first_stage(self, x):
        sample = self.forge_objects.vae.first_stage_model.process_out(x)
        sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
        return sample.to(x)
@@ -33,17 +33,9 @@ class Flux(ForgeDiffusionEngine):
        vae = VAE(model=huggingface_components['vae'])

        if 'schnell' in estimated_config.huggingface_repo.lower():
            k_predictor = PredictionFlux(
                mu=1.0
            )
            k_predictor = PredictionFlux(sigma_data=1.0, prediction_type='const', shift=1.0, timesteps=10000)
        else:
            k_predictor = PredictionFlux(
                seq_len=4096,
                base_seq_len=256,
                max_seq_len=4096,
                base_shift=0.5,
                max_shift=1.15,
            )
            k_predictor = PredictionFlux(sigma_data=1.0, prediction_type='const', shift=1.15, timesteps=10000)
        self.use_distilled_cfg_scale = True

        unet = UnetPatcher.from_model(
@@ -9,9 +9,6 @@ from backend.text_processing.classic_engine import ClassicTextProcessingEngine
from backend.args import dynamic_args
from backend import memory_management

import safetensors.torch as sf
from backend import utils


class StableDiffusion(ForgeDiffusionEngine):
    matched_guesses = [model_list.SD15]
@@ -82,19 +79,3 @@ class StableDiffusion(ForgeDiffusionEngine):
        sample = self.forge_objects.vae.first_stage_model.process_out(x)
        sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
        return sample.to(x)

    def save_checkpoint(self, filename):
        sd = {}
        sd.update(
            utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
        )
        sd.update(
            model_list.SD15.process_clip_state_dict_for_saving(self,
                utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
            )
        )
        sd.update(
            utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
        )
        sf.save_file(sd, filename)
        return filename
@@ -9,9 +9,6 @@ from backend.text_processing.classic_engine import ClassicTextProcessingEngine
from backend.args import dynamic_args
from backend import memory_management

import safetensors.torch as sf
from backend import utils


class StableDiffusion2(ForgeDiffusionEngine):
    matched_guesses = [model_list.SD20]
@@ -82,19 +79,3 @@ class StableDiffusion2(ForgeDiffusionEngine):
        sample = self.forge_objects.vae.first_stage_model.process_out(x)
        sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
        return sample.to(x)

    def save_checkpoint(self, filename):
        sd = {}
        sd.update(
            utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
        )
        sd.update(
            model_list.SD20.process_clip_state_dict_for_saving(self,
                utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
            )
        )
        sd.update(
            utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
        )
        sf.save_file(sd, filename)
        return filename
@@ -1,149 +1,137 @@
|
||||
import torch
|
||||
|
||||
from huggingface_guess import model_list
|
||||
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
|
||||
from backend.patcher.clip import CLIP
|
||||
from backend.patcher.vae import VAE
|
||||
from backend.patcher.unet import UnetPatcher
|
||||
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
|
||||
from backend.text_processing.t5_engine import T5TextProcessingEngine
|
||||
from backend.args import dynamic_args
|
||||
from backend import memory_management
|
||||
from backend.modules.k_prediction import PredictionDiscreteFlow
|
||||
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
## patch SD3 Class in huggingface_guess.model_list
|
||||
def SD3_clip_target(self, state_dict={}):
|
||||
return {'clip_l': 'text_encoder', 'clip_g': 'text_encoder_2', 't5xxl': 'text_encoder_3'}
|
||||
|
||||
model_list.SD3.unet_target = 'transformer'
|
||||
model_list.SD3.clip_target = SD3_clip_target
|
||||
## end patch
|
||||
|
||||
class StableDiffusion3(ForgeDiffusionEngine):
|
||||
matched_guesses = [model_list.SD3]
|
||||
|
||||
def __init__(self, estimated_config, huggingface_components):
|
||||
super().__init__(estimated_config, huggingface_components)
|
||||
self.is_inpaint = False
|
||||
|
||||
clip = CLIP(
|
||||
model_dict={
|
||||
'clip_l': huggingface_components['text_encoder'],
|
||||
'clip_g': huggingface_components['text_encoder_2'],
|
||||
't5xxl' : huggingface_components['text_encoder_3']
|
||||
},
|
||||
tokenizer_dict={
|
||||
'clip_l': huggingface_components['tokenizer'],
|
||||
'clip_g': huggingface_components['tokenizer_2'],
|
||||
't5xxl' : huggingface_components['tokenizer_3']
|
||||
}
|
||||
)
|
||||
|
||||
k_predictor = PredictionDiscreteFlow(shift=3.0)
|
||||
|
||||
vae = VAE(model=huggingface_components['vae'])
|
||||
|
||||
unet = UnetPatcher.from_model(
|
||||
model=huggingface_components['transformer'],
|
||||
diffusers_scheduler= None,
|
||||
k_predictor=k_predictor,
|
||||
config=estimated_config
|
||||
)
|
||||
|
||||
self.text_processing_engine_l = ClassicTextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.clip_l,
|
||||
tokenizer=clip.tokenizer.clip_l,
|
||||
embedding_dir=dynamic_args['embedding_dir'],
|
||||
embedding_key='clip_l',
|
||||
embedding_expected_shape=768,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
text_projection=True,
|
||||
minimal_clip_skip=1,
|
||||
clip_skip=1,
|
||||
return_pooled=True,
|
||||
final_layer_norm=False,
|
||||
)
|
||||
|
||||
self.text_processing_engine_g = ClassicTextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.clip_g,
|
||||
tokenizer=clip.tokenizer.clip_g,
|
||||
embedding_dir=dynamic_args['embedding_dir'],
|
||||
embedding_key='clip_g',
|
||||
embedding_expected_shape=1280,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
text_projection=True,
|
||||
minimal_clip_skip=1,
|
||||
clip_skip=1,
|
||||
return_pooled=True,
|
||||
final_layer_norm=False,
|
||||
)
|
||||
|
||||
self.text_processing_engine_t5 = T5TextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.t5xxl,
|
||||
tokenizer=clip.tokenizer.t5xxl,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
)
|
||||
|
||||
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
|
||||
self.forge_objects_original = self.forge_objects.shallow_copy()
|
||||
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
|
||||
|
||||
# WebUI Legacy
|
||||
self.is_sd3 = True
|
||||
|
||||
def set_clip_skip(self, clip_skip):
|
||||
self.text_processing_engine_l.clip_skip = clip_skip
|
||||
self.text_processing_engine_g.clip_skip = clip_skip
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_learned_conditioning(self, prompt: list[str]):
|
||||
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
|
||||
|
||||
cond_g, g_pooled = self.text_processing_engine_g(prompt)
|
||||
cond_l, l_pooled = self.text_processing_engine_l(prompt)
|
||||
if opts.sd3_enable_t5:
|
||||
cond_t5 = self.text_processing_engine_t5(prompt)
|
||||
else:
|
||||
cond_t5 = torch.zeros([len(prompt), 256, 4096]).to(cond_l.device)
|
||||
|
||||
is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
|
||||
|
||||
force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in prompt)
|
||||
|
||||
if force_zero_negative_prompt:
|
||||
l_pooled = torch.zeros_like(l_pooled)
|
||||
g_pooled = torch.zeros_like(g_pooled)
|
||||
cond_l = torch.zeros_like(cond_l)
|
||||
cond_g = torch.zeros_like(cond_g)
|
||||
cond_t5 = torch.zeros_like(cond_t5)
|
||||
|
||||
cond_lg = torch.cat([cond_l, cond_g], dim=-1)
|
||||
cond_lg = torch.nn.functional.pad(cond_lg, (0, 4096 - cond_lg.shape[-1]))
|
||||
|
||||
cond = dict(
|
||||
crossattn=torch.cat([cond_lg, cond_t5], dim=-2),
|
||||
vector=torch.cat([l_pooled, g_pooled], dim=-1),
|
||||
)
|
||||
|
||||
return cond
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_prompt_lengths_on_ui(self, prompt):
|
||||
token_count = len(self.text_processing_engine_t5.tokenize([prompt])[0])
|
||||
return token_count, max(255, token_count)
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
|
||||
sample = self.forge_objects.vae.first_stage_model.process_in(sample)
|
||||
return sample.to(x)
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.first_stage_model.process_out(x)
|
||||
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
||||
|
||||
return sample.to(x)
|
||||
import torch
|
||||
|
||||
from huggingface_guess import model_list
|
||||
# from huggingface_guess.latent import SD3
|
||||
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
|
||||
from backend.patcher.clip import CLIP
|
||||
from backend.patcher.vae import VAE
|
||||
from backend.patcher.unet import UnetPatcher
|
||||
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
|
||||
from backend.text_processing.t5_engine import T5TextProcessingEngine
|
||||
from backend.args import dynamic_args
|
||||
from backend import memory_management
|
||||
from backend.modules.k_prediction import PredictionDiscreteFlow
|
||||
|
||||
class StableDiffusion3(ForgeDiffusionEngine):
|
||||
matched_guesses = [model_list.SD35]
|
||||
|
||||
def __init__(self, estimated_config, huggingface_components):
|
||||
super().__init__(estimated_config, huggingface_components)
|
||||
|
||||
clip = CLIP(
|
||||
model_dict={
|
||||
'clip_l': huggingface_components['text_encoder'],
|
||||
'clip_g': huggingface_components['text_encoder_2'],
|
||||
't5xxl': huggingface_components['text_encoder_3']
|
||||
},
|
||||
tokenizer_dict={
|
||||
'clip_l': huggingface_components['tokenizer'],
|
||||
'clip_g': huggingface_components['tokenizer_2'],
|
||||
't5xxl': huggingface_components['tokenizer_3']
|
||||
}
|
||||
)
|
||||
|
||||
k_predictor = PredictionDiscreteFlow( shift=3.0)
|
||||
|
||||
vae = VAE(model=huggingface_components['vae'])
|
||||
|
||||
unet = UnetPatcher.from_model(
|
||||
model=huggingface_components['transformer'],
|
||||
diffusers_scheduler= None,
|
||||
k_predictor=k_predictor,
|
||||
config=estimated_config
|
||||
)
|
||||
|
||||
self.text_processing_engine_l = ClassicTextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.clip_l,
|
||||
tokenizer=clip.tokenizer.clip_l,
|
||||
embedding_dir=dynamic_args['embedding_dir'],
|
||||
embedding_key='clip_l',
|
||||
embedding_expected_shape=768,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
text_projection=True,
|
||||
minimal_clip_skip=1,
|
||||
clip_skip=1,
|
||||
return_pooled=True,
|
||||
final_layer_norm=False,
|
||||
)
|
||||
|
||||
self.text_processing_engine_g = ClassicTextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.clip_g,
|
||||
tokenizer=clip.tokenizer.clip_g,
|
||||
embedding_dir=dynamic_args['embedding_dir'],
|
||||
embedding_key='clip_g',
|
||||
embedding_expected_shape=1280,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
text_projection=True,
|
||||
minimal_clip_skip=1,
|
||||
clip_skip=1,
|
||||
return_pooled=True,
|
||||
final_layer_norm=False,
|
||||
)
|
||||
|
||||
self.text_processing_engine_t5 = T5TextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.t5xxl,
|
||||
tokenizer=clip.tokenizer.t5xxl,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
)
|
||||
|
||||
|
||||
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
|
||||
self.forge_objects_original = self.forge_objects.shallow_copy()
|
||||
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
|
||||
|
||||
# WebUI Legacy
|
||||
self.is_sd3 = True
|
||||
|
||||
def set_clip_skip(self, clip_skip):
|
||||
self.text_processing_engine_l.clip_skip = clip_skip
|
||||
self.text_processing_engine_g.clip_skip = clip_skip
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_learned_conditioning(self, prompt: list[str]):
|
||||
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
|
||||
|
||||
cond_g, g_pooled = self.text_processing_engine_g(prompt)
|
||||
cond_l, l_pooled = self.text_processing_engine_l(prompt)
|
||||
# if enabled?
|
||||
cond_t5 = self.text_processing_engine_t5(prompt)
|
||||
|
||||
is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
|
||||
|
||||
force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in prompt)
|
||||
|
||||
if force_zero_negative_prompt:
|
||||
l_pooled = torch.zeros_like(l_pooled)
|
||||
g_pooled = torch.zeros_like(g_pooled)
|
||||
cond_l = torch.zeros_like(cond_l)
|
||||
cond_g = torch.zeros_like(cond_g)
|
||||
cond_t5 = torch.zeros_like(cond_t5)
|
||||
|
||||
cond_lg = torch.cat([cond_l, cond_g], dim=-1)
|
||||
cond_lg = torch.nn.functional.pad(cond_lg, (0, 4096 - cond_lg.shape[-1]))
|
||||
|
||||
cond = dict(
|
||||
crossattn=torch.cat([cond_lg, cond_t5], dim=-2),
|
||||
vector=torch.cat([l_pooled, g_pooled], dim=-1),
|
||||
)
|
||||
|
||||
return cond
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_prompt_lengths_on_ui(self, prompt):
|
||||
token_count = len(self.text_processing_engine_t5.tokenize([prompt])[0])
|
||||
return token_count, max(255, token_count)
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
|
||||
sample = self.forge_objects.vae.first_stage_model.process_in(sample)
|
||||
return sample.to(x)
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.first_stage_model.process_out(x)
|
||||
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
||||
|
||||
return sample.to(x)
|
||||
|
||||
@@ -10,11 +10,6 @@ from backend.args import dynamic_args
|
||||
from backend import memory_management
|
||||
from backend.nn.unet import Timestep
|
||||
|
||||
import safetensors.torch as sf
|
||||
from backend import utils
|
||||
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
class StableDiffusionXL(ForgeDiffusionEngine):
|
||||
matched_guesses = [model_list.SDXL]
|
||||
@@ -93,8 +88,8 @@ class StableDiffusionXL(ForgeDiffusionEngine):
|
||||
height = getattr(prompt, 'height', 1024) or 1024
|
||||
is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
|
||||
|
||||
crop_w = opts.sdxl_crop_left
|
||||
crop_h = opts.sdxl_crop_top
|
||||
crop_w = 0
|
||||
crop_h = 0
|
||||
target_width = width
|
||||
target_height = height
|
||||
|
||||
@@ -136,137 +131,3 @@ class StableDiffusionXL(ForgeDiffusionEngine):
|
||||
sample = self.forge_objects.vae.first_stage_model.process_out(x)
|
||||
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
||||
return sample.to(x)
|
||||
|
||||
def save_checkpoint(self, filename):
|
||||
sd = {}
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
|
||||
)
|
||||
sd.update(
|
||||
model_list.SDXL.process_clip_state_dict_for_saving(self,
|
||||
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
|
||||
)
|
||||
)
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
|
||||
)
|
||||
sf.save_file(sd, filename)
|
||||
return filename
|
||||
|
||||
|
||||
class StableDiffusionXLRefiner(ForgeDiffusionEngine):
|
||||
matched_guesses = [model_list.SDXLRefiner]
|
||||
|
||||
def __init__(self, estimated_config, huggingface_components):
|
||||
super().__init__(estimated_config, huggingface_components)
|
||||
|
||||
clip = CLIP(
|
||||
model_dict={
|
||||
'clip_g': huggingface_components['text_encoder']
|
||||
},
|
||||
tokenizer_dict={
|
||||
'clip_g': huggingface_components['tokenizer'],
|
||||
}
|
||||
)
|
||||
|
||||
vae = VAE(model=huggingface_components['vae'])
|
||||
|
||||
unet = UnetPatcher.from_model(
|
||||
model=huggingface_components['unet'],
|
||||
diffusers_scheduler=huggingface_components['scheduler'],
|
||||
config=estimated_config
|
||||
)
|
||||
|
||||
self.text_processing_engine_g = ClassicTextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.clip_g,
|
||||
tokenizer=clip.tokenizer.clip_g,
|
||||
embedding_dir=dynamic_args['embedding_dir'],
|
||||
embedding_key='clip_g',
|
||||
embedding_expected_shape=2048,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
text_projection=True,
|
||||
minimal_clip_skip=2,
|
||||
clip_skip=2,
|
||||
return_pooled=True,
|
||||
final_layer_norm=False,
|
||||
)
|
||||
|
||||
self.embedder = Timestep(256)
|
||||
|
||||
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
|
||||
self.forge_objects_original = self.forge_objects.shallow_copy()
|
||||
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
|
||||
|
||||
# WebUI Legacy
|
||||
self.is_sdxl = True
|
||||
|
||||
def set_clip_skip(self, clip_skip):
|
||||
self.text_processing_engine_g.clip_skip = clip_skip
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_learned_conditioning(self, prompt: list[str]):
|
||||
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
|
||||
|
||||
cond_g, clip_pooled = self.text_processing_engine_g(prompt)
|
||||
|
||||
width = getattr(prompt, 'width', 1024) or 1024
|
||||
height = getattr(prompt, 'height', 1024) or 1024
|
||||
is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
|
||||
|
||||
crop_w = opts.sdxl_crop_left
|
||||
crop_h = opts.sdxl_crop_top
|
||||
aesthetic = opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else opts.sdxl_refiner_high_aesthetic_score
|
||||
|
||||
out = [
|
||||
self.embedder(torch.Tensor([height])), self.embedder(torch.Tensor([width])),
|
||||
self.embedder(torch.Tensor([crop_h])), self.embedder(torch.Tensor([crop_w])),
|
||||
self.embedder(torch.Tensor([aesthetic]))
|
||||
]
|
||||
|
||||
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1).to(clip_pooled)
|
||||
|
||||
force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in prompt)
|
||||
|
||||
if force_zero_negative_prompt:
|
||||
clip_pooled = torch.zeros_like(clip_pooled)
|
||||
cond_g = torch.zeros_like(cond_g)
|
||||
|
||||
cond = dict(
|
||||
crossattn=cond_g,
|
||||
vector=torch.cat([clip_pooled, flat], dim=1),
|
||||
)
|
||||
|
||||
return cond
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_prompt_lengths_on_ui(self, prompt):
|
||||
_, token_count = self.text_processing_engine_g.process_texts([prompt])
|
||||
return token_count, self.text_processing_engine_g.get_target_prompt_token_count(token_count)
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
|
||||
sample = self.forge_objects.vae.first_stage_model.process_in(sample)
|
||||
return sample.to(x)
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.first_stage_model.process_out(x)
|
||||
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
||||
return sample.to(x)
|
||||
|
||||
def save_checkpoint(self, filename):
|
||||
sd = {}
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
|
||||
)
|
||||
sd.update(
|
||||
model_list.SDXLRefiner.process_clip_state_dict_for_saving(self,
|
||||
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
|
||||
)
|
||||
)
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
|
||||
)
|
||||
sf.save_file(sd, filename)
|
||||
return filename
|
||||
|
||||
@@ -1,24 +0,0 @@
{
  "_class_name": "FluxPipeline",
  "_diffusers_version": "0.30.0.dev0",
  "scheduler": [
    "diffusers",
    "FlowMatchEulerDiscreteScheduler"
  ],
  "text_encoder": [
    "transformers",
    "T5EncoderModel"
  ],
  "tokenizer": [
    "transformers",
    "T5TokenizerFast"
  ],
  "transformer": [
    "diffusers",
    "ChromaTransformer2DModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}
@@ -1,11 +0,0 @@
{
  "_class_name": "FlowMatchEulerDiscreteScheduler",
  "_diffusers_version": "0.30.0.dev0",
  "base_image_seq_len": 256,
  "base_shift": 0.5,
  "max_image_seq_len": 4096,
  "max_shift": 1.15,
  "num_train_timesteps": 1000,
  "shift": 1.0,
  "use_dynamic_shifting": false
}
@@ -0,0 +1,40 @@
{
  "_class_name": "StableDiffusion3Pipeline",
  "_diffusers_version": "0.30.3.dev0",
  "scheduler": [
    "diffusers",
    "FlowMatchEulerDiscreteScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModelWithProjection"
  ],
  "text_encoder_2": [
    "transformers",
    "CLIPTextModelWithProjection"
  ],
  "text_encoder_3": [
    "transformers",
    "T5EncoderModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "tokenizer_2": [
    "transformers",
    "CLIPTokenizer"
  ],
  "tokenizer_3": [
    "transformers",
    "T5TokenizerFast"
  ],
  "transformer": [
    "diffusers",
    "SD3Transformer2DModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}
@@ -0,0 +1,6 @@
{
  "_class_name": "FlowMatchEulerDiscreteScheduler",
  "_diffusers_version": "0.29.0.dev0",
  "num_train_timesteps": 1000,
  "shift": 3.0
}
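The scheduler config above sets `shift: 3.0` for the SD3.5 flow-matching schedule. For intuition, a small illustrative sketch (not code from this repository) of the shifted mapping `sigma = shift * t / (1 + (shift - 1) * t)`, which is the same expression used by `PredictionDiscreteFlow.sigma` later in this diff:

```python
# Illustrative sketch: how shift=3.0 stretches a uniform flow-matching schedule
# toward higher noise levels.
import torch

def shifted_sigma(t: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
    # sigma = shift * t / (1 + (shift - 1) * t)
    return shift * t / (1 + (shift - 1) * t)

t = torch.linspace(0.1, 1.0, 10)
print(shifted_sigma(t))         # e.g. t=0.5 maps to sigma=0.75 when shift=3.0
print(shifted_sigma(t, 1.0))    # shift=1.0 leaves the schedule unchanged
```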
@@ -0,0 +1,24 @@
|
||||
{
|
||||
"architectures": [
|
||||
"CLIPTextModelWithProjection"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 0,
|
||||
"dropout": 0.0,
|
||||
"eos_token_id": 2,
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 768,
|
||||
"initializer_factor": 1.0,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 3072,
|
||||
"layer_norm_eps": 1e-05,
|
||||
"max_position_embeddings": 77,
|
||||
"model_type": "clip_text_model",
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12,
|
||||
"pad_token_id": 1,
|
||||
"projection_dim": 768,
|
||||
"torch_dtype": "float16",
|
||||
"transformers_version": "4.41.2",
|
||||
"vocab_size": 49408
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
{
|
||||
"architectures": [
|
||||
"CLIPTextModelWithProjection"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 0,
|
||||
"dropout": 0.0,
|
||||
"eos_token_id": 2,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_size": 1280,
|
||||
"initializer_factor": 1.0,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 5120,
|
||||
"layer_norm_eps": 1e-05,
|
||||
"max_position_embeddings": 77,
|
||||
"model_type": "clip_text_model",
|
||||
"num_attention_heads": 20,
|
||||
"num_hidden_layers": 32,
|
||||
"pad_token_id": 1,
|
||||
"projection_dim": 1280,
|
||||
"torch_dtype": "float16",
|
||||
"transformers_version": "4.41.2",
|
||||
"vocab_size": 49408
|
||||
}
|
||||
@@ -1,11 +1,16 @@
|
||||
{
|
||||
"architectures": [
|
||||
"T5EncoderModel"
|
||||
],
|
||||
"classifier_dropout": 0.0,
|
||||
"d_ff": 10240,
|
||||
"d_kv": 64,
|
||||
"d_model": 4096,
|
||||
"decoder_start_token_id": 0,
|
||||
"dense_act_fn": "gelu_new",
|
||||
"dropout_rate": 0.1,
|
||||
"eos_token_id": 1,
|
||||
"dense_act_fn": "gelu_pytorch_tanh",
|
||||
"feed_forward_proj": "gated-gelu",
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": true,
|
||||
"is_gated_act": true,
|
||||
@@ -16,7 +21,11 @@
|
||||
"num_layers": 24,
|
||||
"output_past": true,
|
||||
"pad_token_id": 0,
|
||||
"relative_attention_max_distance": 128,
|
||||
"relative_attention_num_buckets": 32,
|
||||
"tie_word_embeddings": false,
|
||||
"torch_dtype": "float16",
|
||||
"transformers_version": "4.41.2",
|
||||
"use_cache": true,
|
||||
"vocab_size": 32128
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"bos_token": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"49406": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"49407": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"bos_token": "<|startoftext|>",
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"do_lower_case": true,
|
||||
"eos_token": "<|endoftext|>",
|
||||
"errors": "replace",
|
||||
"model_max_length": 77,
|
||||
"pad_token": "<|endoftext|>",
|
||||
"tokenizer_class": "CLIPTokenizer",
|
||||
"unk_token": "<|endoftext|>"
|
||||
}
|
||||
File diff suppressed because it is too large
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"bos_token": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "!",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
{
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "!",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"49406": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"49407": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"bos_token": "<|startoftext|>",
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"do_lower_case": true,
|
||||
"eos_token": "<|endoftext|>",
|
||||
"errors": "replace",
|
||||
"model_max_length": 77,
|
||||
"pad_token": "!",
|
||||
"tokenizer_class": "CLIPTokenizer",
|
||||
"unk_token": "<|endoftext|>"
|
||||
}
|
||||
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,16 @@
{
  "_class_name": "SD3Transformer2DModel",
  "_diffusers_version": "0.31.0.dev0",
  "attention_head_dim": 64,
  "caption_projection_dim": 2432,
  "in_channels": 16,
  "joint_attention_dim": 4096,
  "num_attention_heads": 38,
  "num_layers": 38,
  "out_channels": 16,
  "patch_size": 2,
  "pooled_projection_dim": 2048,
  "pos_embed_max_size": 192,
  "qk_norm": "rms_norm",
  "sample_size": 128
}
File diff suppressed because it is too large
@@ -1,7 +1,7 @@
{
  "_class_name": "AutoencoderKL",
  "_diffusers_version": "0.30.0.dev0",
  "_name_or_path": "../checkpoints/flux-dev",
  "_diffusers_version": "0.31.0.dev0",
  "_name_or_path": "../sdxl-vae/",
  "act_fn": "silu",
  "block_out_channels": [
    128,
@@ -25,8 +25,8 @@
  "norm_num_groups": 32,
  "out_channels": 3,
  "sample_size": 1024,
  "scaling_factor": 0.3611,
  "shift_factor": 0.1159,
  "scaling_factor": 1.5305,
  "shift_factor": 0.0609,
  "up_block_types": [
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
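The hunk above swaps the VAE's `scaling_factor`/`shift_factor` from the Flux values (0.3611 / 0.1159) to SD3-style values (1.5305 / 0.0609). As a hedged illustration of how diffusers-style pipelines conventionally apply these two numbers (the helper names below are made up for the example):

```python
# Sketch of the usual diffusers convention for scaling_factor / shift_factor:
# encode: latents are shifted, then scaled before the diffusion model sees them;
# decode: the inverse is applied before the VAE decoder.
import torch

scaling_factor = 1.5305  # values added in the config above
shift_factor = 0.0609

def to_model_space(vae_latents: torch.Tensor) -> torch.Tensor:
    return (vae_latents - shift_factor) * scaling_factor

def to_vae_space(model_latents: torch.Tensor) -> torch.Tensor:
    return model_latents / scaling_factor + shift_factor

z = torch.randn(1, 16, 128, 128)  # 16 latent channels, per the transformer config above
assert torch.allclose(to_vae_space(to_model_space(z)), z, atol=1e-6)
```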
@@ -9,10 +9,18 @@
    "EulerDiscreteScheduler"
  ],
  "text_encoder": [
    null,
    null
  ],
  "text_encoder_2": [
    "transformers",
    "CLIPTextModelWithProjection"
  ],
  "tokenizer": [
    null,
    null
  ],
  "tokenizer_2": [
    "transformers",
    "CLIPTokenizer"
  ],
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -19,13 +19,12 @@ from backend.nn.unet import IntegratedUNet2DConditionModel
|
||||
|
||||
from backend.diffusion_engine.sd15 import StableDiffusion
|
||||
from backend.diffusion_engine.sd20 import StableDiffusion2
|
||||
from backend.diffusion_engine.sdxl import StableDiffusionXL, StableDiffusionXLRefiner
|
||||
from backend.diffusion_engine.sd35 import StableDiffusion3
|
||||
from backend.diffusion_engine.sdxl import StableDiffusionXL
|
||||
from backend.diffusion_engine.flux import Flux
|
||||
from backend.diffusion_engine.chroma import Chroma
|
||||
|
||||
|
||||
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, StableDiffusion3, Chroma, Flux]
|
||||
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXL, StableDiffusion3, Flux]
|
||||
|
||||
|
||||
logging.getLogger("diffusers").setLevel(logging.ERROR)
|
||||
@@ -109,20 +108,17 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])
|
||||
|
||||
return model
|
||||
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel', 'SD3Transformer2DModel', 'ChromaTransformer2DModel']:
|
||||
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel', 'SD3Transformer2DModel']:
|
||||
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!'
|
||||
|
||||
model_loader = None
|
||||
if cls_name == 'UNet2DConditionModel':
|
||||
model_loader = lambda c: IntegratedUNet2DConditionModel.from_config(c)
|
||||
elif cls_name == 'FluxTransformer2DModel':
|
||||
if cls_name == 'FluxTransformer2DModel':
|
||||
from backend.nn.flux import IntegratedFluxTransformer2DModel
|
||||
model_loader = lambda c: IntegratedFluxTransformer2DModel(**c)
|
||||
elif cls_name == 'ChromaTransformer2DModel':
|
||||
from backend.nn.chroma import IntegratedChromaTransformer2DModel
|
||||
model_loader = lambda c: IntegratedChromaTransformer2DModel(**c)
|
||||
elif cls_name == 'SD3Transformer2DModel':
|
||||
from backend.nn.mmditx import MMDiTX
|
||||
if cls_name == 'SD3Transformer2DModel':
|
||||
from modules.models.sd35.mmditx import MMDiTX
|
||||
model_loader = lambda c: MMDiTX(**c)
|
||||
|
||||
unet_config = guess.unet_config.copy()
|
||||
@@ -178,7 +174,7 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
return None
|
||||
|
||||
|
||||
def replace_state_dict(sd, asd, guess):
|
||||
def replace_state_dict(sd, asd, guess, is_clip_g = False):
|
||||
vae_key_prefix = guess.vae_key_prefix[0]
|
||||
text_encoder_key_prefix = guess.text_encoder_key_prefix[0]
|
||||
|
||||
@@ -217,217 +213,19 @@ def replace_state_dict(sd, asd, guess):
|
||||
for k, v in asd.items():
|
||||
sd[vae_key_prefix + k] = v
|
||||
|
||||
|
||||
## identify model type
|
||||
flux_test_key = "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale"
|
||||
sd3_test_key = "model.diffusion_model.final_layer.adaLN_modulation.1.bias"
|
||||
legacy_test_key = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
|
||||
model_type = "-"
|
||||
if legacy_test_key in sd:
|
||||
match sd[legacy_test_key].shape[1]:
|
||||
case 768:
|
||||
model_type = "sd1"
|
||||
case 1024:
|
||||
model_type = "sd2"
|
||||
case 1280:
|
||||
model_type = "xlrf" # sdxl refiner model
|
||||
case 2048:
|
||||
model_type = "sdxl"
|
||||
elif flux_test_key in sd:
|
||||
model_type = "flux"
|
||||
elif sd3_test_key in sd:
|
||||
model_type = "sd3"
|
||||
|
||||
## prefixes used by various model types for CLIP-L
|
||||
prefix_L = {
|
||||
"-" : None,
|
||||
"sd1" : "cond_stage_model.transformer.",
|
||||
"sd2" : None,
|
||||
"xlrf": None,
|
||||
"sdxl": "conditioner.embedders.0.transformer.",
|
||||
"flux": "text_encoders.clip_l.transformer.",
|
||||
"sd3" : "text_encoders.clip_l.transformer.",
|
||||
}
|
||||
## prefixes used by various model types for CLIP-G
|
||||
prefix_G = {
|
||||
"-" : None,
|
||||
"sd1" : None,
|
||||
"sd2" : None,
|
||||
"xlrf": "conditioner.embedders.0.model.transformer.",
|
||||
"sdxl": "conditioner.embedders.1.model.transformer.",
|
||||
"flux": None,
|
||||
"sd3" : "text_encoders.clip_g.transformer.",
|
||||
}
|
||||
## prefixes used by various model types for CLIP-H
|
||||
prefix_H = {
|
||||
"-" : None,
|
||||
"sd1" : None,
|
||||
"sd2" : "conditioner.embedders.0.model.",
|
||||
"xlrf": None,
|
||||
"sdxl": None,
|
||||
"flux": None,
|
||||
"sd3" : None,
|
||||
}
|
||||
|
||||
|
||||
## VAE format 0 (extracted from model, could be sd1, sd2, sdxl, sd3).
|
||||
if "first_stage_model.decoder.conv_in.weight" in asd:
|
||||
channels = asd["first_stage_model.decoder.conv_in.weight"].shape[1]
|
||||
if model_type == "sd1" or model_type == "sd2" or model_type == "xlrf" or model_type == "sdxl":
|
||||
if channels == 4:
|
||||
for k, v in asd.items():
|
||||
sd[k] = v
|
||||
elif model_type == "sd3":
|
||||
if channels == 16:
|
||||
for k, v in asd.items():
|
||||
sd[k] = v
|
||||
|
||||
## CLIP-H
|
||||
CLIP_H = { # key to identify source model old_prefix
|
||||
'cond_stage_model.model.ln_final.weight' : 'cond_stage_model.model.',
|
||||
# 'text_model.encoder.layers.0.layer_norm1.bias' : 'text_model'. # would need converting
|
||||
}
|
||||
for CLIP_key in CLIP_H.keys():
|
||||
if CLIP_key in asd and asd[CLIP_key].shape[0] == 1024:
|
||||
new_prefix = prefix_H[model_type]
|
||||
old_prefix = CLIP_H[CLIP_key]
|
||||
|
||||
if new_prefix is not None:
|
||||
for k, v in asd.items():
|
||||
new_k = k.replace(old_prefix, new_prefix)
|
||||
sd[new_k] = v
|
||||
|
||||
## CLIP-G
|
||||
CLIP_G = { # key to identify source model old_prefix
|
||||
'conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias' : 'conditioner.embedders.1.model.transformer.',
|
||||
'text_encoders.clip_g.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'text_encoders.clip_g.transformer.',
|
||||
'text_model.encoder.layers.0.layer_norm1.bias' : '',
|
||||
'transformer.resblocks.0.ln_1.bias' : 'transformer.'
|
||||
}
|
||||
for CLIP_key in CLIP_G.keys():
|
||||
if CLIP_key in asd and asd[CLIP_key].shape[0] == 1280:
|
||||
new_prefix = prefix_G[model_type]
|
||||
old_prefix = CLIP_G[CLIP_key]
|
||||
|
||||
if new_prefix is not None:
|
||||
if "resblocks" not in CLIP_key and model_type != "sd3": # need to convert
|
||||
def convert_transformers(statedict, prefix_from, prefix_to, number):
|
||||
keys_to_replace = {
|
||||
"{}text_model.embeddings.position_embedding.weight" : "{}positional_embedding",
|
||||
"{}text_model.embeddings.token_embedding.weight" : "{}token_embedding.weight",
|
||||
"{}text_model.final_layer_norm.weight" : "{}ln_final.weight",
|
||||
"{}text_model.final_layer_norm.bias" : "{}ln_final.bias",
|
||||
"text_projection.weight" : "{}text_projection",
|
||||
}
|
||||
resblock_to_replace = {
|
||||
"layer_norm1" : "ln_1",
|
||||
"layer_norm2" : "ln_2",
|
||||
"mlp.fc1" : "mlp.c_fc",
|
||||
"mlp.fc2" : "mlp.c_proj",
|
||||
"self_attn.out_proj" : "attn.out_proj" ,
|
||||
}
|
||||
|
||||
for x in keys_to_replace: # remove trailing 'transformer.' from new prefix
|
||||
k = x.format(prefix_from)
|
||||
statedict[keys_to_replace[x].format(prefix_to[:-12])] = statedict.pop(k)
|
||||
|
||||
for resblock in range(number):
|
||||
for y in ["weight", "bias"]:
|
||||
for x in resblock_to_replace:
|
||||
k = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, x, y)
|
||||
k_to = "{}resblocks.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
|
||||
statedict[k_to] = statedict.pop(k)
|
||||
|
||||
k_from = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, "self_attn.q_proj", y)
|
||||
weightsQ = statedict.pop(k_from)
|
||||
k_from = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, "self_attn.k_proj", y)
|
||||
weightsK = statedict.pop(k_from)
|
||||
k_from = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, "self_attn.v_proj", y)
|
||||
weightsV = statedict.pop(k_from)
|
||||
|
||||
k_to = "{}resblocks.{}.attn.in_proj_{}".format(prefix_to, resblock, y)
|
||||
|
||||
statedict[k_to] = torch.cat((weightsQ, weightsK, weightsV))
|
||||
return statedict
|
||||
|
||||
asd = convert_transformers(asd, old_prefix, new_prefix, 32)
|
||||
for k, v in asd.items():
|
||||
sd[k] = v
|
||||
|
||||
elif old_prefix == "":
|
||||
for k, v in asd.items():
|
||||
new_k = new_prefix + k
|
||||
sd[new_k] = v
|
||||
else:
|
||||
for k, v in asd.items():
|
||||
new_k = k.replace(old_prefix, new_prefix)
|
||||
sd[new_k] = v
|
||||
|
||||
## CLIP-L
|
||||
CLIP_L = { # key to identify source model old_prefix
|
||||
'cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'cond_stage_model.transformer.',
|
||||
'conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'conditioner.embedders.0.transformer.',
|
||||
'text_encoders.clip_l.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'text_encoders.clip_l.transformer.',
|
||||
'text_model.encoder.layers.0.layer_norm1.bias' : '',
|
||||
'transformer.resblocks.0.ln_1.bias' : 'transformer.'
|
||||
}
|
||||
|
||||
for CLIP_key in CLIP_L.keys():
|
||||
if CLIP_key in asd and asd[CLIP_key].shape[0] == 768:
|
||||
new_prefix = prefix_L[model_type]
|
||||
old_prefix = CLIP_L[CLIP_key]
|
||||
|
||||
if new_prefix is not None:
|
||||
if "resblocks" in CLIP_key: # need to convert
|
||||
def transformers_convert(statedict, prefix_from, prefix_to, number):
|
||||
keys_to_replace = {
|
||||
"positional_embedding" : "{}text_model.embeddings.position_embedding.weight",
|
||||
"token_embedding.weight": "{}text_model.embeddings.token_embedding.weight",
|
||||
"ln_final.weight" : "{}text_model.final_layer_norm.weight",
|
||||
"ln_final.bias" : "{}text_model.final_layer_norm.bias",
|
||||
"text_projection" : "text_projection.weight",
|
||||
}
|
||||
resblock_to_replace = {
|
||||
"ln_1" : "layer_norm1",
|
||||
"ln_2" : "layer_norm2",
|
||||
"mlp.c_fc" : "mlp.fc1",
|
||||
"mlp.c_proj" : "mlp.fc2",
|
||||
"attn.out_proj" : "self_attn.out_proj",
|
||||
}
|
||||
|
||||
for k in keys_to_replace:
|
||||
statedict[keys_to_replace[k].format(prefix_to)] = statedict.pop(k)
|
||||
|
||||
for resblock in range(number):
|
||||
for y in ["weight", "bias"]:
|
||||
for x in resblock_to_replace:
|
||||
k = "{}resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
|
||||
k_to = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
|
||||
statedict[k_to] = statedict.pop(k)
|
||||
|
||||
k_from = "{}resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
|
||||
weights = statedict.pop(k_from)
|
||||
shape_from = weights.shape[0] // 3
|
||||
for x in range(3):
|
||||
p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
|
||||
k_to = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
|
||||
statedict[k_to] = weights[shape_from*x:shape_from*(x + 1)]
|
||||
return statedict
|
||||
|
||||
asd = transformers_convert(asd, old_prefix, new_prefix, 12)
|
||||
for k, v in asd.items():
|
||||
sd[k] = v
|
||||
|
||||
elif old_prefix == "":
|
||||
for k, v in asd.items():
|
||||
new_k = new_prefix + k
|
||||
sd[new_k] = v
|
||||
else:
|
||||
for k, v in asd.items():
|
||||
new_k = k.replace(old_prefix, new_prefix)
|
||||
sd[new_k] = v
|
||||
|
||||
if 'text_model.encoder.layers.0.layer_norm1.weight' in asd:
|
||||
if is_clip_g:
|
||||
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}clip_g.")]
|
||||
for k in keys_to_delete:
|
||||
del sd[k]
|
||||
for k, v in asd.items():
|
||||
sd[f"{text_encoder_key_prefix}clip_g.transformer.{k}"] = v
|
||||
else:
|
||||
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}clip_l.")]
|
||||
for k in keys_to_delete:
|
||||
del sd[k]
|
||||
for k, v in asd.items():
|
||||
sd[f"{text_encoder_key_prefix}clip_l.transformer.{k}"] = v
|
||||
|
||||
if 'encoder.block.0.layer.0.SelfAttention.k.weight' in asd:
|
||||
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}t5xxl.")]
|
||||
@@ -440,8 +238,9 @@ def replace_state_dict(sd, asd, guess):
|
||||
|
||||
|
||||
def preprocess_state_dict(sd):
|
||||
if not any(k.startswith("model.diffusion_model") for k in sd.keys()):
|
||||
sd = {f"model.diffusion_model.{k}": v for k, v in sd.items()}
|
||||
if any("double_block" in k for k in sd.keys()):
|
||||
if not any(k.startswith("model.diffusion_model") for k in sd.keys()):
|
||||
sd = {f"model.diffusion_model.{k}": v for k, v in sd.items()}
|
||||
|
||||
return sd
|
||||
|
||||
@@ -453,16 +252,14 @@ def split_state_dict(sd, additional_state_dicts: list = None):
|
||||
|
||||
if isinstance(additional_state_dicts, list):
|
||||
for asd in additional_state_dicts:
|
||||
is_clip_g = 'clip_g' in asd
|
||||
asd = load_torch_file(asd)
|
||||
sd = replace_state_dict(sd, asd, guess)
|
||||
del asd
|
||||
sd = replace_state_dict(sd, asd, guess, is_clip_g)
|
||||
|
||||
guess.clip_target = guess.clip_target(sd)
|
||||
guess.model_type = guess.model_type(sd)
|
||||
guess.ztsnr = 'ztsnr' in sd
|
||||
|
||||
sd = guess.process_vae_state_dict(sd)
|
||||
|
||||
state_dict = {
|
||||
guess.unet_target: try_filter_state_dict(sd, guess.unet_key_prefix),
|
||||
guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix)
|
||||
@@ -482,18 +279,7 @@ def split_state_dict(sd, additional_state_dicts: list = None):
|
||||
|
||||
return state_dict, guess
|
||||
|
||||
# To be removed once PR merged on huggingface_guess
|
||||
chroma_is_in_huggingface_guess = hasattr(huggingface_guess.model_list, "Chroma")
|
||||
|
||||
if not chroma_is_in_huggingface_guess:
|
||||
class GuessChroma:
|
||||
huggingface_repo = 'Chroma'
|
||||
unet_extra_config = {
|
||||
'guidance_out_dim': 3072,
|
||||
'guidance_hidden_dim': 5120,
|
||||
'guidance_n_layers': 5
|
||||
}
|
||||
unet_remove_config = ['guidance_embed']
|
||||
@torch.inference_mode()
|
||||
def forge_loader(sd, additional_state_dicts=None):
|
||||
try:
|
||||
@@ -501,17 +287,6 @@ def forge_loader(sd, additional_state_dicts=None):
|
||||
except:
|
||||
raise ValueError('Failed to recognize model type!')
|
||||
|
||||
if not chroma_is_in_huggingface_guess \
|
||||
and estimated_config.huggingface_repo == "black-forest-labs/FLUX.1-schnell" \
|
||||
and "transformer" in state_dicts \
|
||||
and "distilled_guidance_layer.layers.0.in_layer.bias" in state_dicts["transformer"]:
|
||||
estimated_config.huggingface_repo = GuessChroma.huggingface_repo
|
||||
for x in GuessChroma.unet_extra_config:
|
||||
estimated_config.unet_config[x] = GuessChroma.unet_extra_config[x]
|
||||
for x in GuessChroma.unet_remove_config:
|
||||
del estimated_config.unet_config[x]
|
||||
state_dicts['text_encoder'] = state_dicts['text_encoder_2']
|
||||
del state_dicts['text_encoder_2']
|
||||
repo_name = estimated_config.huggingface_repo
|
||||
|
||||
local_path = os.path.join(dir_path, 'huggingface', repo_name)
|
||||
@@ -527,47 +302,15 @@ def forge_loader(sd, additional_state_dicts=None):
|
||||
if component is not None:
|
||||
huggingface_components[component_name] = component
|
||||
|
||||
yaml_config = None
|
||||
yaml_config_prediction_type = None
|
||||
|
||||
try:
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
config_filename = os.path.splitext(sd)[0] + '.yaml'
|
||||
if Path(config_filename).is_file():
|
||||
with open(config_filename, 'r') as stream:
|
||||
yaml_config = yaml.safe_load(stream)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Fix Huggingface prediction type using .yaml config or estimated config detection
|
||||
# Fix Huggingface prediction type using estimated config detection
|
||||
prediction_types = {
|
||||
'EPS': 'epsilon',
|
||||
'V_PREDICTION': 'v_prediction',
|
||||
'EDM': 'edm',
|
||||
}
|
||||
if 'scheduler' in huggingface_components and hasattr(huggingface_components['scheduler'], 'config') and 'prediction_type' in huggingface_components['scheduler'].config:
|
||||
huggingface_components['scheduler'].config.prediction_type = prediction_types.get(estimated_config.model_type.name, huggingface_components['scheduler'].config.prediction_type)
|
||||
|
||||
has_prediction_type = 'scheduler' in huggingface_components and hasattr(huggingface_components['scheduler'], 'config') and 'prediction_type' in huggingface_components['scheduler'].config
|
||||
|
||||
if yaml_config is not None:
|
||||
yaml_config_prediction_type: str = (
|
||||
yaml_config.get('model', {}).get('params', {}).get('parameterization', '')
|
||||
or yaml_config.get('model', {}).get('params', {}).get('denoiser_config', {}).get('params', {}).get('scaling_config', {}).get('target', '')
|
||||
)
|
||||
if yaml_config_prediction_type == 'v' or yaml_config_prediction_type.endswith(".VScaling"):
|
||||
yaml_config_prediction_type = 'v_prediction'
|
||||
else:
|
||||
# Use estimated prediction config if no suitable prediction type found
|
||||
yaml_config_prediction_type = ''
|
||||
|
||||
if has_prediction_type:
|
||||
if yaml_config_prediction_type:
|
||||
huggingface_components['scheduler'].config.prediction_type = yaml_config_prediction_type
|
||||
else:
|
||||
huggingface_components['scheduler'].config.prediction_type = prediction_types.get(estimated_config.model_type.name, huggingface_components['scheduler'].config.prediction_type)
|
||||
|
||||
if not chroma_is_in_huggingface_guess and estimated_config.huggingface_repo == "Chroma":
|
||||
return Chroma(estimated_config=estimated_config, huggingface_components=huggingface_components)
|
||||
for M in possible_models:
|
||||
if any(isinstance(estimated_config, x) for x in M.matched_guesses):
|
||||
return M(estimated_config=estimated_config, huggingface_components=huggingface_components)
|
||||
|
||||
@@ -560,11 +560,6 @@ def unload_model_clones(model):


def free_memory(memory_required, device, keep_loaded=[], free_all=False):
    # this check fully unloads any 'abandoned' models
    for i in range(len(current_loaded_models) - 1, -1, -1):
        if sys.getrefcount(current_loaded_models[i].model) <= 2:
            current_loaded_models.pop(i).model_unload(avoid_model_moving=True)

    if free_all:
        memory_required = 1e30
        print(f"[Unload] Trying to free all memory for {device} with {len(keep_loaded)} models keep loaded ... ", end="")
@@ -2,9 +2,6 @@ import math
import torch
import numpy as np

from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.flux.pipeline_flux import calculate_shift


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    betas = []
@@ -44,6 +41,10 @@ def time_snr_shift(alpha, t):
    return alpha * t / (1 + (alpha - 1) * t)


def flux_time_shift(mu, sigma, t):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def rescale_zero_terminal_snr_sigmas(sigmas):
    alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
    alphas_bar_sqrt = alphas_cumprod.sqrt()
@@ -274,31 +275,13 @@ class PredictionDiscreteFlow(AbstractPrediction):
            return timestep
        return self.shift * timestep / (1 + (self.shift - 1) * timestep)

    def percent_to_sigma(self, percent):
        if percent <= 0.0:
            return 1.0
        if percent >= 1.0:
            return 0.0
        return 1.0 - percent


class PredictionFlux(AbstractPrediction):
    def __init__(self, seq_len=4096, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15, pseudo_timestep_range=10000, mu=None):
        super().__init__(sigma_data=1.0, prediction_type='const')
        self.mu = mu
        self.pseudo_timestep_range = pseudo_timestep_range
        self.apply_mu_transform(seq_len=seq_len, base_seq_len=base_seq_len, max_seq_len=max_seq_len, base_shift=base_shift, max_shift=max_shift, mu=mu)

    def apply_mu_transform(self, seq_len=4096, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15, mu=None):
        # TODO: Add an UI option to let user choose whether to call this in each generation to bind latent size to sigmas
        # And some cases may want their own mu values or other parameters
        if mu is None:
            self.mu = calculate_shift(image_seq_len=seq_len, base_seq_len=base_seq_len, max_seq_len=max_seq_len, base_shift=base_shift, max_shift=max_shift)
        else:
            self.mu = mu
        sigmas = torch.arange(1, self.pseudo_timestep_range + 1, 1) / self.pseudo_timestep_range
        sigmas = FlowMatchEulerDiscreteScheduler.time_shift(None, self.mu, 1.0, sigmas)
        self.register_buffer('sigmas', sigmas)
    def __init__(self, sigma_data=1.0, prediction_type='const', shift=1.15, timesteps=10000):
        super().__init__(sigma_data=sigma_data, prediction_type=prediction_type)
        self.shift = shift
        ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps))
        self.register_buffer('sigmas', ts)

    @property
    def sigma_min(self):
@@ -312,7 +295,7 @@ class PredictionFlux(AbstractPrediction):
        return sigma

    def sigma(self, timestep):
        return timestep
        return flux_time_shift(self.shift, 1.0, timestep)

    def percent_to_sigma(self, percent):
        if percent <= 0.0:
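For intuition on the rewritten `PredictionFlux` above: a standalone sketch (plain PyTorch, not code from this repository) of how the new constructor builds its sigma table by pushing a uniform timestep grid through `flux_time_shift`; `shift=1.15` corresponds to the dev-style schedule and `shift=1.0` to the schnell one selected in the flux.py hunk earlier.

```python
# Standalone sketch: reproduce the sigma table that the new PredictionFlux
# constructor registers, by shifting a uniform timestep grid.
import math
import torch

def flux_time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
    # same expression as the flux_time_shift added in the hunk above;
    # exp(mu) acts as the multiplicative shift applied to the schedule
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

def build_sigmas(shift: float = 1.15, timesteps: int = 10000) -> torch.Tensor:
    t = torch.arange(1, timesteps + 1, 1) / timesteps  # uniform grid in (0, 1]
    return flux_time_shift(shift, 1.0, t)

sigmas = build_sigmas(shift=1.15)           # dev-style schedule
print(sigmas[0].item(), sigmas[-1].item())  # roughly 3e-4 at the first step, exactly 1.0 at t = 1
```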
@@ -1,307 +0,0 @@
# implementation of Chroma for Forge, inspired by https://github.com/lodestone-rock/ComfyUI_FluxMod

from dataclasses import dataclass

import math
import torch

from torch import nn
from einops import rearrange, repeat
from backend.attention import attention_function
from backend.utils import fp16_fix, tensor2parameter
from backend.nn.flux import attention, rope, timestep_embedding, EmbedND, MLPEmbedder, RMSNorm, QKNorm, SelfAttention


class Approximator(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=4):
        super().__init__()
        self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
        self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim) for _ in range(n_layers)])
        self.norms = nn.ModuleList([RMSNorm(hidden_dim) for _ in range(n_layers)])
        self.out_proj = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        x = self.in_proj(x)
        for layer, norms in zip(self.layers, self.norms):
            x = x + layer(norms(x))
        x = self.out_proj(x)
        return x
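
# Hedged usage sketch for the Approximator ("distilled guidance" MLP); the
# out_dim/hidden_dim/n_layers values here are illustrative, not necessarily the
# model's actual config. In the transformer below it receives a
# [batch, mod_index_length, 64] tensor (timestep + guidance embedding
# concatenated with a modulation-index embedding) and returns one modulation
# vector per index.
# approx = Approximator(in_dim=64, out_dim=3072, hidden_dim=5120, n_layers=5)
# fake_input = torch.randn(1, 344, 64)   # 344 = 19*12 + 38*3 + 2, see distribute_modulations
# print(approx(fake_input).shape)        # torch.Size([1, 344, 3072])
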
||||
@dataclass
|
||||
class ModulationOut:
|
||||
shift: torch.Tensor
|
||||
scale: torch.Tensor
|
||||
gate: torch.Tensor
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size, num_heads, mlp_ratio, qkv_bias=False):
|
||||
super().__init__()
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
||||
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.img_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
||||
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, img, txt, mod, pe):
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = mod
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
B, L, _ = img_qkv.shape
|
||||
H = self.num_heads
|
||||
D = img_qkv.shape[-1] // (3 * H)
|
||||
img_q, img_k, img_v = img_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
B, L, _ = txt_qkv.shape
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
txt_attn, img_attn = attn[:, :txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
txt = fp16_fix(txt)
|
||||
return img, txt
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, qk_scale=None):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
head_dim = hidden_size // num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
||||
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
||||
self.norm = QKNorm(head_dim)
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
|
||||
def forward(self, x, mod, pe):
|
||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
del x_mod
|
||||
|
||||
# q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
qkv = qkv.view(qkv.size(0), qkv.size(1), 3, self.num_heads, self.hidden_size // self.num_heads)
|
||||
q, k, v = qkv.permute(2, 0, 3, 1, 4)
|
||||
del qkv
|
||||
|
||||
q, k = self.norm(q, k, v)
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
del q, k, v, pe
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), dim=2))
|
||||
del attn, mlp
|
||||
|
||||
x = x + mod.gate * output
|
||||
x = fp16_fix(x)
|
||||
return x
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size, patch_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
|
||||
def forward(self, x, mod):
|
||||
shift, scale = mod
|
||||
shift = shift.squeeze(1)
|
||||
scale = scale.squeeze(1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class IntegratedChromaTransformer2DModel(nn.Module):
|
||||
def __init__(self, in_channels: int, vec_in_dim: int, context_in_dim: int, hidden_size: int, mlp_ratio: float, num_heads: int, depth: int, depth_single_blocks: int, axes_dim: list[int], theta: int, qkv_bias: bool, guidance_out_dim: int, guidance_hidden_dim: int, guidance_n_layers: int):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels * 4
|
||||
self.out_channels = self.in_channels
|
||||
|
||||
if hidden_size % num_heads != 0:
|
||||
raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")
|
||||
|
||||
pe_dim = hidden_size // num_heads
|
||||
if sum(axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
|
||||
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
||||
self.distilled_guidance_layer = Approximator(64, guidance_out_dim, guidance_hidden_dim, guidance_n_layers)
|
||||
self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio)
|
||||
for _ in range(depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
||||
|
||||
@staticmethod
|
||||
def distribute_modulations(tensor, single_block_count: int = 38, double_blocks_count: int = 19):
|
||||
"""
|
||||
Distributes slices of the tensor into the block_dict as ModulationOut objects.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): Input tensor with shape [batch_size, vectors, dim].
|
||||
"""
|
||||
batch_size, vectors, dim = tensor.shape
|
||||
block_dict = {}
|
||||
for i in range(single_block_count):
|
||||
key = f"single_blocks.{i}.modulation.lin"
|
||||
block_dict[key] = None
|
||||
for i in range(double_blocks_count):
|
||||
key = f"double_blocks.{i}.img_mod.lin"
|
||||
block_dict[key] = None
|
||||
for i in range(double_blocks_count):
|
||||
key = f"double_blocks.{i}.txt_mod.lin"
|
||||
block_dict[key] = None
|
||||
block_dict["final_layer.adaLN_modulation.1"] = None
|
||||
idx = 0 # Index to keep track of the vector slices
|
||||
for key in block_dict.keys():
|
||||
if "single_blocks" in key:
|
||||
# Single block: 1 ModulationOut
|
||||
block_dict[key] = ModulationOut(
|
||||
shift=tensor[:, idx:idx+1, :],
|
||||
scale=tensor[:, idx+1:idx+2, :],
|
||||
gate=tensor[:, idx+2:idx+3, :]
|
||||
)
|
||||
idx += 3 # Advance by 3 vectors
|
||||
elif "img_mod" in key:
|
||||
# Double block: List of 2 ModulationOut
|
||||
double_block = []
|
||||
for _ in range(2): # Create 2 ModulationOut objects
|
||||
double_block.append(
|
||||
ModulationOut(
|
||||
shift=tensor[:, idx:idx+1, :],
|
||||
scale=tensor[:, idx+1:idx+2, :],
|
||||
gate=tensor[:, idx+2:idx+3, :]
|
||||
)
|
||||
)
|
||||
idx += 3 # Advance by 3 vectors per ModulationOut
|
||||
block_dict[key] = double_block
|
||||
elif "txt_mod" in key:
|
||||
# Double block: List of 2 ModulationOut
|
||||
double_block = []
|
||||
for _ in range(2): # Create 2 ModulationOut objects
|
||||
double_block.append(
|
||||
ModulationOut(
|
||||
shift=tensor[:, idx:idx+1, :],
|
||||
scale=tensor[:, idx+1:idx+2, :],
|
||||
gate=tensor[:, idx+2:idx+3, :]
|
||||
)
|
||||
)
|
||||
idx += 3 # Advance by 3 vectors per ModulationOut
|
||||
block_dict[key] = double_block
|
||||
elif "final_layer" in key:
|
||||
# Final layer: 1 ModulationOut
|
||||
block_dict[key] = [
|
||||
tensor[:, idx:idx+1, :],
|
||||
tensor[:, idx+1:idx+2, :],
|
||||
]
|
||||
idx += 2 # Advance by 2 vectors
|
||||
return block_dict
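
# Hedged sanity check of the modulation budget consumed by distribute_modulations:
# each double block takes two ModulationOut for img and two for txt (3 vectors
# each), each single block takes one ModulationOut (3 vectors), and the final
# layer takes a shift/scale pair (2 vectors).
double_blocks_count = 19
single_block_count = 38
per_double = 2 * 3 + 2 * 3          # img_mod + txt_mod
per_single = 3
final_layer = 2
total = double_blocks_count * per_double + single_block_count * per_single + final_layer
print(total)   # 344, matching mod_index_length = nb_double_block*12 + nb_single_block*3 + 2 in inner_forward below
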
|
||||
|
||||
def inner_forward(self, img, img_ids, txt, txt_ids, timesteps):
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
img = self.img_in(img)
|
||||
device = img.device
|
||||
dtype = img.dtype # torch.bfloat16
|
||||
nb_double_block = len(self.double_blocks)
|
||||
nb_single_block = len(self.single_blocks)
|
||||
|
||||
mod_index_length = nb_double_block*12 + nb_single_block*3 + 2
|
||||
distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(device=device, dtype=dtype)
|
||||
distil_guidance = timestep_embedding(torch.zeros_like(timesteps), 16).to(device=device, dtype=dtype)
|
||||
modulation_index = timestep_embedding(torch.arange(mod_index_length), 32).to(device=device, dtype=dtype)
|
||||
modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1)
|
||||
timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
|
||||
input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1)
|
||||
mod_vectors = self.distilled_guidance_layer(input_vec)
|
||||
mod_vectors_dict = self.distribute_modulations(mod_vectors, nb_single_block, nb_double_block)
|
||||
|
||||
txt = self.txt_in(txt)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
del txt_ids, img_ids
|
||||
pe = self.pe_embedder(ids)
|
||||
del ids
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"]
|
||||
txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"]
|
||||
double_mod = [img_mod, txt_mod]
|
||||
img, txt = block(img=img, txt=txt, mod=double_mod, pe=pe)
|
||||
img = torch.cat((txt, img), 1)
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
|
||||
img = block(img, mod=single_mod, pe=pe)
|
||||
del pe
|
||||
img = img[:, txt.shape[1]:, ...]
|
||||
final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"]
|
||||
img = self.final_layer(img, final_mod)
|
||||
return img
|
||||
|
||||
    def forward(self, x, timestep, context, **kwargs):
        bs, c, h, w = x.shape
        input_device = x.device
        input_dtype = x.dtype
        patch_size = 2
        pad_h = (patch_size - x.shape[-2] % patch_size) % patch_size
        pad_w = (patch_size - x.shape[-1] % patch_size) % patch_size
        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="circular")
        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
        del x, pad_h, pad_w
        h_len = ((h + (patch_size // 2)) // patch_size)
        w_len = ((w + (patch_size // 2)) // patch_size)
        img_ids = torch.zeros((h_len, w_len, 3), device=input_device, dtype=input_dtype)
        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=input_device, dtype=input_dtype)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=input_device, dtype=input_dtype)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
        txt_ids = torch.zeros((bs, context.shape[1], 3), device=input_device, dtype=input_dtype)
        del input_device, input_dtype
        out = self.inner_forward(img, img_ids, context, txt_ids, timestep)
        del img, img_ids, txt_ids, timestep, context
        out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h, :w]
        del h_len, w_len, bs
        return out
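
# Hedged walk-through of the patchify bookkeeping above for a 1024x1024
# generation (values are illustrative): the VAE latent is 128x128 with 16
# channels, patching with patch_size 2 gives 64*64 = 4096 tokens of dimension
# 16*2*2 = 64, and the final rearrange inverts that mapping before cropping
# back to the unpadded size.
bs, c, h, w = 1, 16, 128, 128
patch_size = 2
h_len = (h + patch_size // 2) // patch_size     # 64
w_len = (w + patch_size // 2) // patch_size     # 64
tokens = h_len * w_len                          # 4096
token_dim = c * patch_size * patch_size         # 64
print(tokens, token_dim)                        # 4096 64
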
|
||||
@@ -4,11 +4,10 @@ import math
|
||||
from backend.attention import attention_pytorch as attention_function
|
||||
from transformers.activations import NewGELUActivation
|
||||
|
||||
|
||||
activations = {
|
||||
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
|
||||
"relu": torch.nn.functional.relu,
|
||||
"gelu_new": lambda a: NewGELUActivation()(a),
|
||||
"gelu_new": lambda a: NewGELUActivation()(a)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ stash = {}
|
||||
|
||||
|
||||
def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None, bias_fn=None):
|
||||
scale_weight = getattr(layer, 'scale_weight', None)
|
||||
patches = getattr(layer, 'forge_online_loras', None)
|
||||
weight_patches, bias_patches = None, None
|
||||
|
||||
@@ -33,8 +32,6 @@ def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None,
|
||||
weight = weight_fn(weight)
|
||||
if weight_args is not None:
|
||||
weight = weight.to(**weight_args)
|
||||
if scale_weight is not None:
|
||||
weight = weight*scale_weight.to(device=weight.device, dtype=weight.dtype)
|
||||
if weight_patches is not None:
|
||||
weight = merge_lora_to_weight(patches=weight_patches, weight=weight, key="online weight lora", computation_dtype=weight.dtype)
|
||||
|
||||
@@ -130,7 +127,6 @@ class ForgeOperations:
|
||||
self.out_features = out_features
|
||||
self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
|
||||
self.weight = None
|
||||
self.scale_weight = None
|
||||
self.bias = None
|
||||
self.parameters_manual_cast = current_manual_cast_enabled
|
||||
|
||||
@@ -138,8 +134,6 @@ class ForgeOperations:
|
||||
if hasattr(self, 'dummy'):
|
||||
if prefix + 'weight' in state_dict:
|
||||
self.weight = torch.nn.Parameter(state_dict[prefix + 'weight'].to(self.dummy))
|
||||
if prefix + 'scale_weight' in state_dict:
|
||||
self.scale_weight = torch.nn.Parameter(state_dict[prefix + 'scale_weight'])
|
||||
if prefix + 'bias' in state_dict:
|
||||
self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
|
||||
del self.dummy
|
||||
|
||||
@@ -13,7 +13,6 @@ quants_mapping = {
|
||||
gguf.GGMLQuantizationType.Q5_K: gguf.Q5_K,
|
||||
gguf.GGMLQuantizationType.Q6_K: gguf.Q6_K,
|
||||
gguf.GGMLQuantizationType.Q8_0: gguf.Q8_0,
|
||||
gguf.GGMLQuantizationType.BF16: gguf.BF16,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -49,33 +49,24 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function):
|
||||
# Modified from https://github.com/comfyanonymous/ComfyUI/blob/80a44b97f5cbcb890896e2b9e65d177f1ac6a588/comfy/weight_adapter/base.py#L42
|
||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype):
|
||||
# Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L33
|
||||
|
||||
dora_scale = memory_management.cast_to_device(dora_scale, weight.device, computation_dtype)
|
||||
lora_diff *= alpha
|
||||
weight_calc = weight + function(lora_diff).type(weight.dtype)
|
||||
|
||||
wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0]
|
||||
if wd_on_output_axis:
|
||||
weight_norm = (
|
||||
weight.reshape(weight.shape[0], -1)
|
||||
.norm(dim=1, keepdim=True)
|
||||
.reshape(weight.shape[0], *[1] * (weight.dim() - 1))
|
||||
)
|
||||
else:
|
||||
weight_norm = (
|
||||
weight_calc.transpose(0, 1)
|
||||
.reshape(weight_calc.shape[1], -1)
|
||||
.norm(dim=1, keepdim=True)
|
||||
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
||||
.transpose(0, 1)
|
||||
)
|
||||
weight_norm = weight_norm + torch.finfo(weight.dtype).eps
|
||||
weight_calc = weight + lora_diff.type(weight.dtype)
|
||||
weight_norm = (
|
||||
weight_calc.transpose(0, 1)
|
||||
.reshape(weight_calc.shape[1], -1)
|
||||
.norm(dim=1, keepdim=True)
|
||||
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
||||
if strength != 1.0:
|
||||
weight_calc -= weight
|
||||
weight += strength * (weight_calc)
|
||||
weight += strength * weight_calc
|
||||
else:
|
||||
weight[:] = weight_calc
|
||||
return weight
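
# Hedged summary of the DoRA merge performed by weight_decompose above: the
# LoRA delta is applied first, the result is renormalised per column, and the
# learned per-column magnitude dora_scale is reapplied, roughly
#   W' = dora_scale * (W + alpha * dW) / ||W + alpha * dW||.
# Minimal sketch with plain tensors (not the Forge call signature):
import torch

W = torch.randn(4, 8)
dW = 0.01 * torch.randn(4, 8)                 # alpha-scaled LoRA diff
dora_scale = W.norm(dim=1, keepdim=True)      # stands in for the trained magnitudes
W_calc = W + dW
W_merged = W_calc * (dora_scale / (W_calc.norm(dim=1, keepdim=True) + torch.finfo(W.dtype).eps))
print(W_merged.shape)                         # torch.Size([4, 8])
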
|
||||
@@ -138,15 +129,10 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
|
||||
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
||||
else:
|
||||
weight += strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
|
||||
|
||||
elif patch_type == "set":
|
||||
weight.copy_(v[0])
|
||||
|
||||
elif patch_type == "lora":
|
||||
mat1 = memory_management.cast_to_device(v[0], weight.device, computation_dtype)
|
||||
mat2 = memory_management.cast_to_device(v[1], weight.device, computation_dtype)
|
||||
dora_scale = v[4]
|
||||
|
||||
if v[2] is not None:
|
||||
alpha = v[2] / mat2.shape[0]
|
||||
else:
|
||||
@@ -156,26 +142,12 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
|
||||
mat3 = memory_management.cast_to_device(v[3], weight.device, computation_dtype)
|
||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
||||
|
||||
try:
|
||||
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))
|
||||
|
||||
try:
|
||||
lora_diff = lora_diff.reshape(weight.shape)
|
||||
except:
|
||||
if weight.shape[1] < lora_diff.shape[1]:
|
||||
expand_factor = (lora_diff.shape[1] - weight.shape[1])
|
||||
weight = torch.nn.functional.pad(weight, (0, expand_factor), mode='constant', value=0)
|
||||
elif weight.shape[1] > lora_diff.shape[1]:
|
||||
# expand factor should be 1*64 (for FluxTools Canny or Depth), or 5*64 (for FluxTools Fill)
|
||||
expand_factor = (weight.shape[1] - lora_diff.shape[1])
|
||||
lora_diff = torch.nn.functional.pad(lora_diff, (0, expand_factor), mode='constant', value=0)
|
||||
|
||||
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function)
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
|
||||
except Exception as e:
|
||||
print("ERROR {} {} {}".format(patch_type, key, e))
|
||||
raise e
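
# Hedged illustration of the basic "lora" merge path above: the low-rank update
# is up @ down, scaled by strength * alpha (with alpha already divided by the
# rank), then added to the weight in the weight's own dtype.
import torch

weight = torch.zeros(16, 32)
up = torch.randn(16, 4)          # v[0], rank 4
down = torch.randn(4, 32)        # v[1]
alpha = 8 / down.shape[0]        # stored alpha of 8 divided by rank -> 2.0
strength = 1.0

lora_diff = torch.mm(up.flatten(start_dim=1), down.flatten(start_dim=1)).reshape(weight.shape)
weight += ((strength * alpha) * lora_diff).type(weight.dtype)
print(weight.abs().mean() > 0)   # tensor(True)
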
|
||||
@@ -220,7 +192,7 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
|
||||
try:
|
||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function)
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
@@ -258,51 +230,29 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
|
||||
try:
|
||||
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function)
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
print("ERROR {} {} {}".format(patch_type, key, e))
|
||||
raise e
|
||||
|
||||
elif patch_type == "glora":
|
||||
dora_scale = v[5]
|
||||
|
||||
old_glora = False
|
||||
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
|
||||
old_glora = True
|
||||
if v[4] is not None:
|
||||
alpha = v[4] / v[0].shape[0]
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
|
||||
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
|
||||
pass
|
||||
else:
|
||||
old_glora = False
|
||||
dora_scale = v[5]
|
||||
|
||||
a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, computation_dtype)
|
||||
a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, computation_dtype)
|
||||
b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, computation_dtype)
|
||||
b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, computation_dtype)
|
||||
|
||||
if v[4] is None:
|
||||
alpha = 1.0
|
||||
else:
|
||||
if old_glora:
|
||||
alpha = v[4] / v[0].shape[0]
|
||||
else:
|
||||
alpha = v[4] / v[1].shape[0]
|
||||
|
||||
try:
|
||||
if old_glora:
|
||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=computation_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
|
||||
else:
|
||||
if weight.dim() > 2:
|
||||
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=computation_dtype), a1), a2).reshape(weight.shape)
|
||||
else:
|
||||
lora_diff = torch.mm(torch.mm(weight.to(dtype=computation_dtype), a1), a2).reshape(weight.shape)
|
||||
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
|
||||
|
||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function)
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
@@ -349,10 +299,10 @@ class LoraLoader:
        self.loaded_hash = str([])

    @torch.inference_mode()
    def refresh(self, lora_patches, offload_device=torch.device('cpu'), force_refresh=False):
    def refresh(self, lora_patches, offload_device=torch.device('cpu')):
        hashes = str(list(lora_patches.keys()))

        if hashes == self.loaded_hash and not force_refresh:
        if hashes == self.loaded_hash:
            return

        # Merge Patches

@@ -6,8 +6,6 @@ from backend.text_processing import parsing, emphasis
|
||||
from backend.text_processing.textual_inversion import EmbeddingDatabase
|
||||
from backend import memory_management
|
||||
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
|
||||
last_extra_generation_params = {}
|
||||
@@ -69,7 +67,7 @@ class ClassicTextProcessingEngine:
|
||||
self.text_encoder = text_encoder
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.emphasis = emphasis.get_current_option(opts.emphasis)()
|
||||
self.emphasis = emphasis.get_current_option(emphasis_name)()
|
||||
self.text_projection = text_projection
|
||||
self.minimal_clip_skip = minimal_clip_skip
|
||||
self.clip_skip = clip_skip
|
||||
@@ -141,14 +139,14 @@ class ClassicTextProcessingEngine:
|
||||
if self.return_pooled:
|
||||
pooled_output = outputs.pooler_output
|
||||
|
||||
if self.text_projection and self.embedding_key != 'clip_l':
|
||||
if self.text_projection and self.embedding_key is not 'clip_l':
|
||||
pooled_output = self.text_encoder.transformer.text_projection(pooled_output)
|
||||
|
||||
z.pooled = pooled_output
|
||||
return z
|
||||
|
||||
def tokenize_line(self, line):
|
||||
parsed = parsing.parse_prompt_attention(line, self.emphasis.name)
|
||||
parsed = parsing.parse_prompt_attention(line)
|
||||
|
||||
tokenized = self.tokenize([text for text, _ in parsed])
|
||||
|
||||
@@ -250,8 +248,6 @@ class ClassicTextProcessingEngine:
|
||||
return batch_chunks, token_count
|
||||
|
||||
def __call__(self, texts):
|
||||
self.emphasis = emphasis.get_current_option(opts.emphasis)()
|
||||
|
||||
batch_chunks, token_count = self.process_texts(texts)
|
||||
|
||||
used_embeddings = {}
|
||||
|
||||
@@ -20,7 +20,7 @@ re_attention = re.compile(r"""
|
||||
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
|
||||
|
||||
|
||||
def parse_prompt_attention(text, emphasis):
|
||||
def parse_prompt_attention(text):
|
||||
res = []
|
||||
round_brackets = []
|
||||
square_brackets = []
|
||||
@@ -32,48 +32,44 @@ def parse_prompt_attention(text, emphasis):
|
||||
for p in range(start_position, len(res)):
|
||||
res[p][1] *= multiplier
|
||||
|
||||
if emphasis == "None":
|
||||
# interpret literally
|
||||
res = [[text, 1.0]]
|
||||
else:
|
||||
for m in re_attention.finditer(text):
|
||||
text = m.group(0)
|
||||
weight = m.group(1)
|
||||
for m in re_attention.finditer(text):
|
||||
text = m.group(0)
|
||||
weight = m.group(1)
|
||||
|
||||
if text.startswith('\\'):
|
||||
res.append([text[1:], 1.0])
|
||||
elif text == '(':
|
||||
round_brackets.append(len(res))
|
||||
elif text == '[':
|
||||
square_brackets.append(len(res))
|
||||
elif weight is not None and round_brackets:
|
||||
multiply_range(round_brackets.pop(), float(weight))
|
||||
elif text == ')' and round_brackets:
|
||||
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
||||
elif text == ']' and square_brackets:
|
||||
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
||||
else:
|
||||
parts = re.split(re_break, text)
|
||||
for i, part in enumerate(parts):
|
||||
if i > 0:
|
||||
res.append(["BREAK", -1])
|
||||
res.append([part, 1.0])
|
||||
if text.startswith('\\'):
|
||||
res.append([text[1:], 1.0])
|
||||
elif text == '(':
|
||||
round_brackets.append(len(res))
|
||||
elif text == '[':
|
||||
square_brackets.append(len(res))
|
||||
elif weight is not None and round_brackets:
|
||||
multiply_range(round_brackets.pop(), float(weight))
|
||||
elif text == ')' and round_brackets:
|
||||
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
||||
elif text == ']' and square_brackets:
|
||||
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
||||
else:
|
||||
parts = re.split(re_break, text)
|
||||
for i, part in enumerate(parts):
|
||||
if i > 0:
|
||||
res.append(["BREAK", -1])
|
||||
res.append([part, 1.0])
|
||||
|
||||
for pos in round_brackets:
|
||||
multiply_range(pos, round_bracket_multiplier)
|
||||
for pos in round_brackets:
|
||||
multiply_range(pos, round_bracket_multiplier)
|
||||
|
||||
for pos in square_brackets:
|
||||
multiply_range(pos, square_bracket_multiplier)
|
||||
for pos in square_brackets:
|
||||
multiply_range(pos, square_bracket_multiplier)
|
||||
|
||||
if len(res) == 0:
|
||||
res = [["", 1.0]]
|
||||
if len(res) == 0:
|
||||
res = [["", 1.0]]
|
||||
|
||||
i = 0
|
||||
while i + 1 < len(res):
|
||||
if res[i][1] == res[i + 1][1]:
|
||||
res[i][0] += res[i + 1][0]
|
||||
res.pop(i + 1)
|
||||
else:
|
||||
i += 1
|
||||
i = 0
|
||||
while i + 1 < len(res):
|
||||
if res[i][1] == res[i + 1][1]:
|
||||
res[i][0] += res[i + 1][0]
|
||||
res.pop(i + 1)
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return res
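
# Hedged usage example for parse_prompt_attention, assuming the new
# two-argument signature and the default multipliers (1.1 for round brackets,
# 1/1.1 for square brackets); the expected parse is approximately:
#   parse_prompt_attention("a (cat:1.2) in a [garden]", "Original")
#   -> [['a ', 1.0], ['cat', 1.2], [' in a ', 1.0], ['garden', 0.909...]]
# With emphasis == "None" the whole prompt comes back as a single [text, 1.0]
# entry, interpreted literally.
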
|
||||
|
||||
@@ -4,8 +4,6 @@ from collections import namedtuple
|
||||
from backend.text_processing import parsing, emphasis
|
||||
from backend import memory_management
|
||||
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
|
||||
|
||||
@@ -23,7 +21,7 @@ class T5TextProcessingEngine:
|
||||
self.text_encoder = text_encoder.transformer
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.emphasis = emphasis.get_current_option(opts.emphasis)()
|
||||
self.emphasis = emphasis.get_current_option(emphasis_name)()
|
||||
self.min_length = min_length
|
||||
self.id_end = 1
|
||||
self.id_pad = 0
|
||||
@@ -66,7 +64,7 @@ class T5TextProcessingEngine:
|
||||
return z
|
||||
|
||||
def tokenize_line(self, line):
|
||||
parsed = parsing.parse_prompt_attention(line, self.emphasis.name)
|
||||
parsed = parsing.parse_prompt_attention(line)
|
||||
|
||||
tokenized = self.tokenize([text for text, _ in parsed])
|
||||
|
||||
@@ -113,8 +111,6 @@ class T5TextProcessingEngine:
|
||||
zs = []
|
||||
cache = {}
|
||||
|
||||
self.emphasis = emphasis.get_current_option(opts.emphasis)()
|
||||
|
||||
for line in texts:
|
||||
if line in cache:
|
||||
line_z_values = cache[line]
|
||||
|
||||
@@ -25,11 +25,6 @@ class ExtraOptionsSection(scripts.Script):
|
||||
extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
|
||||
elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img")
|
||||
|
||||
not_allowed = ['sd_model_checkpoint', 'sd_vae', 'CLIP_stop_at_last_layers', 'forge_additional_modules']
|
||||
for na in not_allowed:
|
||||
if na in extra_options:
|
||||
extra_options.remove(na)
|
||||
|
||||
mapping = {k: v for v, k in infotext_utils.infotext_to_setting_name_mapping}
|
||||
|
||||
with gr.Blocks() as interface:
|
||||
|
||||
@@ -1,287 +1,69 @@
|
||||
|
||||
import spaces
|
||||
import os
|
||||
import gradio as gr
|
||||
import gc
|
||||
|
||||
try:
|
||||
import moviepy.editor as mp
|
||||
got_mp = True
|
||||
except:
|
||||
got_mp = False
|
||||
|
||||
from gradio_imageslider import ImageSlider
|
||||
from loadimg import load_img
|
||||
from transformers import AutoModelForImageSegmentation
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
|
||||
import glob
|
||||
import pathlib
|
||||
from PIL import Image
|
||||
import numpy
|
||||
# torch.set_float32_matmul_precision(["high", "highest"][0])
|
||||
|
||||
|
||||
transform_image = None
|
||||
birefnet = None
|
||||
|
||||
def load_model(model):
|
||||
global birefnet
|
||||
birefnet = None
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
birefnet = AutoModelForImageSegmentation.from_pretrained(
|
||||
"ZhengPeng7/"+model, trust_remote_code=True
|
||||
)
|
||||
birefnet.eval()
|
||||
birefnet.half()
|
||||
|
||||
spaces.automatically_move_to_gpu_when_forward(birefnet)
|
||||
os.environ['HOME'] = spaces.convert_root_path() + 'home'
|
||||
|
||||
with spaces.capture_gpu_object() as birefnet_gpu_obj:
|
||||
load_model("BiRefNet_HR")
|
||||
|
||||
def common_setup(w, h):
|
||||
global transform_image
|
||||
|
||||
transform_image = transforms.Compose(
|
||||
[
|
||||
transforms.Resize((w, h)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||
]
|
||||
birefnet = AutoModelForImageSegmentation.from_pretrained(
|
||||
"ZhengPeng7/BiRefNet", trust_remote_code=True
|
||||
)
|
||||
|
||||
spaces.automatically_move_to_gpu_when_forward(birefnet)
|
||||
|
||||
transform_image = transforms.Compose(
|
||||
[
|
||||
transforms.Resize((1024, 1024)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True)
|
||||
def process(image, save_flat, bg_colour):
|
||||
def fn(image):
|
||||
im = load_img(image, output_type="pil")
|
||||
im = im.convert("RGB")
|
||||
image_size = im.size
|
||||
origin = im.copy()
|
||||
image = load_img(im)
|
||||
input_image = transform_image(image).unsqueeze(0).to(spaces.gpu).to(torch.float16)
|
||||
input_images = transform_image(image).unsqueeze(0).to(spaces.gpu)
|
||||
# Prediction
|
||||
with torch.no_grad():
|
||||
preds = birefnet(input_image)[-1].sigmoid().cpu()
|
||||
preds = birefnet(input_images)[-1].sigmoid().cpu()
|
||||
pred = preds[0].squeeze()
|
||||
pred_pil = transforms.ToPILImage()(pred)
|
||||
mask = pred_pil.resize(image_size)
|
||||
image.putalpha(mask)
|
||||
|
||||
if save_flat:
|
||||
bg_colour += "FF"
|
||||
colour_rgb = tuple(int(bg_colour[i:i+2], 16) for i in (1, 3, 5, 7))
|
||||
background = Image.new("RGBA", image_size, colour_rgb)
|
||||
image = Image.alpha_composite(background, image)
|
||||
image = image.convert("RGB")
|
||||
|
||||
return image
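
# Hedged worked example of the flat-background colour handling above: the
# "#RRGGBB" picker value gets an alpha byte appended, then each pair of hex
# digits is parsed into an RGBA tuple.
bg_colour = "#00FF00" + "FF"
colour_rgb = tuple(int(bg_colour[i:i + 2], 16) for i in (1, 3, 5, 7))
print(colour_rgb)   # (0, 255, 0, 255)
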
|
||||
|
||||
# video processing based on https://huggingface.co/spaces/brokerrobin/video-background-removal/blob/main/app.py
|
||||
@spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True)
|
||||
def video_process(video, bg_colour):
|
||||
# Load the video using moviepy
|
||||
video = mp.VideoFileClip(video)
|
||||
|
||||
fps = video.fps
|
||||
|
||||
# Extract audio from the video
|
||||
audio = video.audio
|
||||
|
||||
# Extract frames at the specified FPS
|
||||
frames = video.iter_frames(fps=fps)
|
||||
|
||||
# Process each frame for background removal
|
||||
processed_frames = []
|
||||
|
||||
for i, frame in enumerate(frames):
|
||||
print (f"birefnet [video]: frame {i+1}", end='\r', flush=True)
|
||||
|
||||
image = Image.fromarray(frame)
|
||||
|
||||
if i == 0:
|
||||
image_size = image.size
|
||||
|
||||
colour_rgb = tuple(int(bg_colour[i:i+2], 16) for i in (1, 3, 5))
|
||||
background = Image.new("RGBA", image_size, colour_rgb + (255,))
|
||||
|
||||
input_image = transform_image(image).unsqueeze(0).to(spaces.gpu).to(torch.float16)
|
||||
# Prediction
|
||||
with torch.no_grad():
|
||||
preds = birefnet(input_image)[-1].sigmoid().cpu()
|
||||
pred = preds[0].squeeze()
|
||||
pred_pil = transforms.ToPILImage()(pred)
|
||||
mask = pred_pil.resize(image_size)
|
||||
|
||||
# Apply mask and composite
|
||||
image.putalpha(mask)
|
||||
processed_image = Image.alpha_composite(background, image)
|
||||
|
||||
processed_frames.append(numpy.array(processed_image))
|
||||
|
||||
# Create a new video from the processed frames
|
||||
processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)
|
||||
|
||||
# Add the original audio back to the processed video
|
||||
processed_video = processed_video.set_audio(audio)
|
||||
|
||||
# Save the processed video using modified original filename (goes to gradio temp)
|
||||
filename, _ = os.path.splitext(video.filename)
|
||||
filename += "-birefnet.mp4"
|
||||
processed_video.write_videofile(filename, codec="libx264")
|
||||
|
||||
return filename
|
||||
return (image, origin)
|
||||
|
||||
|
||||
@spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True)
|
||||
def batch_process(input_folder, output_folder, save_png, save_flat, bg_colour):
|
||||
# Ensure output folder exists
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
|
||||
# Supported image extensions
|
||||
image_extensions = ['.jpg', '.jpeg', '.jfif', '.png', '.bmp', '.webp', ".avif"]
|
||||
|
||||
# Collect all image files from input folder
|
||||
input_images = []
|
||||
for ext in image_extensions:
|
||||
input_images.extend(glob.glob(os.path.join(input_folder, f'*{ext}')))
|
||||
|
||||
if save_flat:
|
||||
bg_colour += "FF"
|
||||
colour_rgb = tuple(int(bg_colour[i:i+2], 16) for i in (1, 3, 5, 7))
|
||||
# Process each image
|
||||
processed_images = []
|
||||
for i, image_path in enumerate(input_images):
|
||||
print (f"birefnet [batch]: image {i+1}", end='\r', flush=True)
|
||||
try:
|
||||
# Load image
|
||||
im = load_img(image_path, output_type="pil")
|
||||
im = im.convert("RGB")
|
||||
image_size = im.size
|
||||
image = load_img(im)
|
||||
|
||||
# Prepare image for processing
|
||||
input_image = transform_image(image).unsqueeze(0).to(spaces.gpu).to(torch.float16)
|
||||
|
||||
# Prediction
|
||||
with torch.no_grad():
|
||||
preds = birefnet(input_image)[-1].sigmoid().cpu()
|
||||
|
||||
pred = preds[0].squeeze()
|
||||
pred_pil = transforms.ToPILImage()(pred)
|
||||
mask = pred_pil.resize(image_size)
|
||||
|
||||
# Apply mask
|
||||
image.putalpha(mask)
|
||||
|
||||
# Save processed image
|
||||
output_filename = os.path.join(output_folder, f"{pathlib.Path(image_path).name}")
|
||||
|
||||
if save_flat:
|
||||
background = Image.new("RGBA", image_size, colour_rgb)
|
||||
image = Image.alpha_composite(background, image)
|
||||
image = image.convert("RGB")
|
||||
elif output_filename.lower().endswith(".jpg") or output_filename.lower().endswith(".jpeg"):
|
||||
# JPEGs don't support an alpha channel, so append a .png extension (appended rather than replaced, to avoid potential overwrites)
|
||||
output_filename += ".png"
|
||||
if save_png and not output_filename.lower().endswith(".png"):
|
||||
output_filename += ".png"
|
||||
|
||||
image.save(output_filename)
|
||||
|
||||
processed_images.append(output_filename)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {image_path}: {str(e)}")
|
||||
|
||||
return processed_images
|
||||
slider1 = ImageSlider(label="birefnet", type="pil")
|
||||
slider2 = ImageSlider(label="birefnet", type="pil")
|
||||
image = gr.Image(label="Upload an image")
|
||||
text = gr.Textbox(label="Paste an image URL")
|
||||
|
||||
|
||||
def unload():
|
||||
global birefnet, transform_image
|
||||
birefnet = None
|
||||
transform_image = None
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
chameleon = load_img(spaces.convert_root_path() + "chameleon.jpg", output_type="pil")
|
||||
|
||||
url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
|
||||
tab1 = gr.Interface(
|
||||
fn, inputs=image, outputs=slider1, examples=[chameleon], api_name="image", allow_flagging="never"
|
||||
)
|
||||
|
||||
tab2 = gr.Interface(fn, inputs=text, outputs=slider2, examples=[url], api_name="text", allow_flagging="never")
|
||||
|
||||
|
||||
css = """
|
||||
.gradio-container {
|
||||
max-width: 1280px !important;
|
||||
}
|
||||
footer {
|
||||
display: none !important;
|
||||
}
|
||||
"""
|
||||
|
||||
with gr.Blocks(css=css, analytics_enabled=False) as demo:
|
||||
gr.Markdown("# birefnet for background removal")
|
||||
|
||||
with gr.Tab("image"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
image = gr.Image(label="Upload an image", type='pil', height=584)
|
||||
go_image = gr.Button("Remove background")
|
||||
with gr.Column():
|
||||
result1 = gr.Image(label="birefnet", type="pil", height=544)
|
||||
|
||||
with gr.Tab("URL"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
text = gr.Textbox(label="URL to image, or local path to image", max_lines=1)
|
||||
go_text = gr.Button("Remove background")
|
||||
with gr.Column():
|
||||
result2 = gr.Image(label="birefnet", type="pil", height=544)
|
||||
|
||||
if got_mp:
|
||||
with gr.Tab("video"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
video = gr.Video(label="Upload a video", height=584)
|
||||
go_video = gr.Button("Remove background")
|
||||
with gr.Column():
|
||||
result4 = gr.Video(label="birefnet", height=544, show_share_button=False)
|
||||
|
||||
with gr.Tab("batch"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
input_dir = gr.Textbox(label="Input folder path", max_lines=1)
|
||||
output_dir = gr.Textbox(label="Output folder path (save images will overwrite)", max_lines=1)
|
||||
always_png = gr.Checkbox(label="Always save as PNG", value=True)
|
||||
|
||||
go_batch = gr.Button("Remove background(s)")
|
||||
with gr.Column():
|
||||
result3 = gr.File(label="Processed image(s)", type="filepath", file_count="multiple")
|
||||
|
||||
with gr.Tab("options"):
|
||||
gr.Markdown("*HR* : high resolution; *matting* : better with transparency; *lite* : faster.")
|
||||
model = gr.Dropdown(label="Model (download on selection, see console for progress)",
|
||||
choices=["BiRefNet_512x512", "BiRefNet", "BiRefNet_HR", "BiRefNet-matting", "BiRefNet_HR-matting", "BiRefNet_lite", "BiRefNet_lite-2K", "BiRefNet-portrait", "BiRefNet-COD", "BiRefNet-DIS5K", "BiRefNet-DIS5k-TR_TEs", "BiRefNet-HRSOD"], value="BiRefNet_HR", type="value")
|
||||
|
||||
gr.Markdown("Regular models trained at 1024 \u00D7 1024; HR models trained at 2048 \u00D7 2048; 2K model trained at 2560 \u00D7 1440.")
|
||||
gr.Markdown("Greater processing image size will typically give more accurate results, but also requires more VRAM (shared memory works well).")
|
||||
with gr.Row():
|
||||
proc_sizeW = gr.Slider(label="birefnet processing image width",
|
||||
minimum=256, maximum=2560, value=2048, step=32)
|
||||
proc_sizeH = gr.Slider(label="birefnet processing image height",
|
||||
minimum=256, maximum=2048, value=2048, step=32)
|
||||
with gr.Row():
|
||||
save_flat = gr.Checkbox(label="Save flat (no mask)", value=False)
|
||||
bg_colour = gr.ColorPicker(label="Background colour for saving flat, and video", value="#00FF00", visible=True, interactive=True)
|
||||
|
||||
model.change(fn=load_model, inputs=model, outputs=None)
|
||||
|
||||
gr.Markdown("### https://github.com/ZhengPeng7/BiRefNet\n### https://huggingface.co/ZhengPeng7")
|
||||
|
||||
go_image.click(fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(fn=process, inputs=[image, save_flat, bg_colour], outputs=result1)
|
||||
go_text.click( fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(fn=process, inputs=[text, save_flat, bg_colour], outputs=result2)
|
||||
if got_mp:
|
||||
go_video.click(fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(
|
||||
fn=video_process, inputs=[video, bg_colour], outputs=result4)
|
||||
go_batch.click(fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(
|
||||
fn=batch_process, inputs=[input_dir, output_dir, always_png, save_flat, bg_colour], outputs=result3)
|
||||
|
||||
demo.unload(unload)
|
||||
demo = gr.TabbedInterface(
|
||||
[tab1, tab2], ["image", "text"], title="birefnet for background removal"
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch(inbrowser=True)
|
||||
|
||||
@@ -3,6 +3,7 @@ import gradio as gr
|
||||
from transformers import AutoProcessor, AutoModelForCausalLM
|
||||
import os
|
||||
|
||||
import requests
|
||||
import copy
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
@@ -13,22 +14,33 @@ import matplotlib.patches as patches
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
from modules import shared
|
||||
# import subprocess
|
||||
# subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
||||
|
||||
from unittest.mock import patch
|
||||
from transformers.dynamic_module_utils import get_imports
|
||||
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
|
||||
if not str(filename).endswith("modeling_florence2.py"):
|
||||
return get_imports(filename)
|
||||
imports = get_imports(filename)
|
||||
imports.remove("flash_attn")
|
||||
return imports
|
||||
|
||||
|
||||
with spaces.capture_gpu_object() as gpu_object:
|
||||
models = {
|
||||
# 'microsoft/Florence-2-large-ft': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large-ft', attn_implementation='sdpa', trust_remote_code=True).to("cuda").eval(),
|
||||
'microsoft/Florence-2-large': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True).to("cuda").eval(),
|
||||
# 'microsoft/Florence-2-base-ft': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base-ft', trust_remote_code=True).to("cuda").eval(),
|
||||
'microsoft/Florence-2-base': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to("cuda").eval(),
|
||||
}
|
||||
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
|
||||
models = {
|
||||
# 'microsoft/Florence-2-large-ft': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large-ft', attn_implementation='sdpa', trust_remote_code=True).to("cuda").eval(),
|
||||
'microsoft/Florence-2-large': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True).to("cuda").eval(),
|
||||
# 'microsoft/Florence-2-base-ft': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base-ft', trust_remote_code=True).to("cuda").eval(),
|
||||
# 'microsoft/Florence-2-base': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to("cuda").eval(),
|
||||
}
|
||||
|
||||
processors = {
|
||||
# 'microsoft/Florence-2-large-ft': AutoProcessor.from_pretrained('microsoft/Florence-2-large-ft', trust_remote_code=True),
|
||||
'microsoft/Florence-2-large': AutoProcessor.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True),
|
||||
# 'microsoft/Florence-2-base-ft': AutoProcessor.from_pretrained('microsoft/Florence-2-base-ft', trust_remote_code=True),
|
||||
'microsoft/Florence-2-base': AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True),
|
||||
# 'microsoft/Florence-2-base': AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True),
|
||||
}
|
||||
|
||||
|
||||
@@ -122,8 +134,8 @@ def draw_ocr_bboxes(image, prediction):
|
||||
fill=color)
|
||||
return image
|
||||
|
||||
|
||||
def process_image(image, task_prompt, text_input=None, model_id='microsoft/Florence-2-large'):
|
||||
image = Image.fromarray(image) # Convert NumPy array to PIL Image
|
||||
if task_prompt == 'Caption':
|
||||
task_prompt = '<CAPTION>'
|
||||
results = run_example(task_prompt, image, model_id=model_id)
|
||||
@@ -222,61 +234,6 @@ def process_image(image, task_prompt, text_input=None, model_id='microsoft/Flore
|
||||
else:
|
||||
return "", None # Return empty string and None for unknown task prompts
|
||||
|
||||
|
||||
@spaces.GPU(gpu_objects=[gpu_object], manual_load=False)
|
||||
def run_example_batch(directory, task_prompt, model_id='microsoft/Florence-2-large', save_caption=False, prefix=""):
|
||||
model = models[model_id]
|
||||
processor = processors[model_id]
|
||||
|
||||
match task_prompt:
|
||||
case 'More Detailed Caption':
|
||||
prompt = '<MORE_DETAILED_CAPTION>'
|
||||
case 'Detailed Caption':
|
||||
prompt = '<DETAILED_CAPTION>'
|
||||
case 'Caption':
|
||||
prompt = '<CAPTION>'
|
||||
case _:
|
||||
prompt = '<CAPTION>'
|
||||
|
||||
results = ""
|
||||
|
||||
# batch_images block lifted from modules/img2img.py
|
||||
if isinstance(directory, str):
|
||||
batch_images = list(shared.walk_files(directory, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff", ".avif")))
|
||||
else:
|
||||
batch_images = [os.path.abspath(x.name) for x in directory]
|
||||
|
||||
for file in batch_images:
|
||||
image = Image.open(file)
|
||||
|
||||
if image:
|
||||
inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda")
|
||||
generated_ids = model.generate(
|
||||
input_ids=inputs["input_ids"],
|
||||
pixel_values=inputs["pixel_values"],
|
||||
max_new_tokens=1024,
|
||||
early_stopping=False,
|
||||
do_sample=False,
|
||||
num_beams=3,
|
||||
)
|
||||
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
||||
parsed_answer = processor.post_process_generation(
|
||||
generated_text,
|
||||
task=task_prompt,
|
||||
image_size=(image.width, image.height)
|
||||
)
|
||||
|
||||
caption_text = prefix + parsed_answer[task_prompt]  # prepend the optional user prefix to the caption
|
||||
results += f"File: {file}\nCaption: {caption_text}\n\n"
|
||||
|
||||
if save_caption:
|
||||
caption_file = file + ".txt"
|
||||
with open(caption_file, 'w') as f:
|
||||
f.write(caption_text)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
css = """
|
||||
#output {
|
||||
height: 500px;
|
||||
@@ -285,9 +242,6 @@ css = """
|
||||
}
|
||||
"""
|
||||
|
||||
caption_task_list = [
|
||||
'Caption', 'Detailed Caption', 'More Detailed Caption'
|
||||
]
|
||||
|
||||
single_task_list =[
|
||||
'Caption', 'Detailed Caption', 'More Detailed Caption', 'Object Detection',
|
||||
@@ -297,29 +251,30 @@ single_task_list =[
|
||||
'OCR', 'OCR with Region'
|
||||
]
|
||||
|
||||
cascaded_task_list =[
|
||||
cascased_task_list =[
|
||||
'Caption + Grounding', 'Detailed Caption + Grounding', 'More Detailed Caption + Grounding'
|
||||
]
|
||||
|
||||
|
||||
def update_task_dropdown(choice):
|
||||
if choice == 'Cascaded task':
|
||||
return gr.Dropdown(choices=cascaded_task_list, value='Caption + Grounding')
|
||||
if choice == 'Cascased task':
|
||||
return gr.Dropdown(choices=cascased_task_list, value='Caption + Grounding')
|
||||
else:
|
||||
return gr.Dropdown(choices=single_task_list, value='Caption')
|
||||
|
||||
|
||||
|
||||
with gr.Blocks(css=css) as demo:
|
||||
gr.Markdown(DESCRIPTION)
|
||||
with gr.Tab(label="Florence-2 Image Captioning"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
input_img = gr.Image(label="Input picture", type="pil")
|
||||
input_img = gr.Image(label="Input Picture")
|
||||
model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value=list(models.keys())[0])
|
||||
task_type = gr.Radio(choices=['Single task', 'Cascaded task'], label='Task type selector', value='Single task')
|
||||
task_prompt = gr.Dropdown(choices=single_task_list, label="Task prompt", value="More Detailed Caption")
|
||||
task_type = gr.Radio(choices=['Single task', 'Cascased task'], label='Task type selector', value='Single task')
|
||||
task_prompt = gr.Dropdown(choices=single_task_list, label="Task Prompt", value="More Detailed Caption")
|
||||
task_type.change(fn=update_task_dropdown, inputs=task_type, outputs=task_prompt)
|
||||
text_input = gr.Textbox(label="Text input (optional)")
|
||||
text_input = gr.Textbox(label="Text Input (optional)")
|
||||
submit_btn = gr.Button(value="Submit")
|
||||
with gr.Column():
|
||||
output_text = gr.Textbox(label="Output Text")
|
||||
@@ -339,25 +294,6 @@ with gr.Blocks(css=css) as demo:
|
||||
|
||||
submit_btn.click(process_image, [input_img, task_prompt, text_input, model_selector], [output_text, output_img])
|
||||
|
||||
with gr.Tab(label="Batch captioning"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
input_directory = gr.Textbox(label="Input directory")
|
||||
model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value=list(models.keys())[0])
|
||||
task_prompt = gr.Dropdown(choices=caption_task_list, label="Task prompt", value="More Detailed Caption")
|
||||
save_captions = gr.Checkbox(label="Save captions to textfiles (same filename, same directory)", value=False)
|
||||
prefix_input = gr.Textbox(label="Prefix to add to captions (optional)")
|
||||
batch_btn = gr.Button(value="Submit")
|
||||
with gr.Column():
|
||||
output_text = gr.Textbox(label="Output captions")
|
||||
|
||||
|
||||
batch_btn.click(
|
||||
fn=run_example_batch,
|
||||
inputs=[input_directory, task_prompt, model_selector, save_captions, prefix_input],
|
||||
outputs=output_text
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
demo.launch(debug=True)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import spaces
|
||||
import math
|
||||
import os
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -133,11 +132,6 @@ i2i_pipe = StableDiffusionImg2ImgPipeline(
|
||||
spaces.automatically_move_pipeline_components(t2i_pipe)
|
||||
|
||||
|
||||
def overwrite_components(components):
|
||||
global tokenizer, text_encoder, vae, unet, t2i_pipe, i2i_pipe
|
||||
tokenizer, text_encoder, vae, unet, t2i_pipe, i2i_pipe = components
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_prompt_inner(txt: str):
|
||||
max_length = tokenizer.model_max_length
|
||||
@@ -240,14 +234,8 @@ def run_rmbg(img, sigma=0.0):
|
||||
return result.clip(0, 255).astype(np.uint8), alpha
|
||||
|
||||
|
||||
external_processor = None
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
|
||||
if external_processor is not None:
|
||||
return external_processor(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
|
||||
|
||||
bg_source = BGSource(bg_source)
|
||||
input_bg = None
|
||||
|
||||
|
||||
@@ -2,5 +2,5 @@
|
||||
"tag": "Face Swap, Human Identification, and Style Transfer",
|
||||
"title": "PhotoMaker V2: Improved ID Fidelity and Better Controllability",
|
||||
"repo_id": "TencentARC/PhotoMaker-V2",
|
||||
"revision": "745c135ad240f80e168e21db53e7acf9605edcb5"
|
||||
"revision": "abbf3252f4188b0373bb5384db31f429be40176c"
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import scripts, shared
|
||||
from modules.ui_components import InputAccordion
|
||||
from modules import scripts
|
||||
from backend.misc.image_resize import adaptive_resize
|
||||
|
||||
|
||||
@@ -15,16 +14,11 @@ class PatchModelAddDownscale:
|
||||
sigma = transformer_options["sigmas"][0].item()
|
||||
if sigma <= sigma_start and sigma >= sigma_end:
|
||||
h = adaptive_resize(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
|
||||
|
||||
shared.kohya_shrink_shape = (h.shape[-1], h.shape[-2])
|
||||
shared.kohya_shrink_shape_out = None
|
||||
return h
|
||||
|
||||
def output_block_patch(h, hsp, transformer_options):
|
||||
if h.shape[2] != hsp.shape[2]:
|
||||
h = adaptive_resize(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
|
||||
|
||||
shared.kohya_shrink_shape_out = (h.shape[-1], h.shape[-2])
|
||||
return h, hsp
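
# Hedged shape walk-through for the two patches above: with downscale_factor = 2.0
# and an SDXL 1024x1024 generation (128x128 latent), the input-block patch shrinks
# the hidden states to 64x64 while sigma is inside the configured window, and the
# output-block patch resizes them back to match the skip connection before use.
downscale_factor = 2.0
h_latent, w_latent = 128, 128
shrunk = (round(w_latent / downscale_factor), round(h_latent / downscale_factor))
print(shrunk)   # (64, 64)
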
|
||||
|
||||
m = model.clone()
|
||||
@@ -50,28 +44,15 @@ class KohyaHRFixForForge(scripts.Script):
|
||||
|
||||
def ui(self, *args, **kwargs):
|
||||
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
|
||||
with InputAccordion(False, label=self.title()) as enabled:
|
||||
with gr.Row():
|
||||
block_number = gr.Slider(label='Block Number', value=3, minimum=1, maximum=32, step=1)
|
||||
downscale_factor = gr.Slider(label='Downscale Factor', value=2.0, minimum=0.1, maximum=9.0, step=0.001)
|
||||
with gr.Row():
|
||||
start_percent = gr.Slider(label='Start Percent', value=0.0, minimum=0.0, maximum=1.0, step=0.001)
|
||||
end_percent = gr.Slider(label='End Percent', value=0.35, minimum=0.0, maximum=1.0, step=0.001)
|
||||
with gr.Accordion(open=False, label=self.title()):
|
||||
enabled = gr.Checkbox(label='Enabled', value=False)
|
||||
block_number = gr.Slider(label='Block Number', value=3, minimum=1, maximum=32, step=1)
|
||||
downscale_factor = gr.Slider(label='Downscale Factor', value=2.0, minimum=0.1, maximum=9.0, step=0.001)
|
||||
start_percent = gr.Slider(label='Start Percent', value=0.0, minimum=0.0, maximum=1.0, step=0.001)
|
||||
end_percent = gr.Slider(label='End Percent', value=0.35, minimum=0.0, maximum=1.0, step=0.001)
|
||||
downscale_after_skip = gr.Checkbox(label='Downscale After Skip', value=True)
|
||||
with gr.Row():
|
||||
downscale_method = gr.Dropdown(label='Downscale Method', choices=upscale_methods, value=upscale_methods[0])
|
||||
upscale_method = gr.Dropdown(label='Upscale Method', choices=upscale_methods, value=upscale_methods[0])
|
||||
|
||||
self.infotext_fields = [
|
||||
(enabled, lambda d: d.get("kohya_hrfix_enabled", False)),
|
||||
(block_number, "kohya_hrfix_block_number"),
|
||||
(downscale_factor, "kohya_hrfix_downscale_factor"),
|
||||
(start_percent, "kohya_hrfix_start_percent"),
|
||||
(end_percent, "kohya_hrfix_end_percent"),
|
||||
(downscale_after_skip, "kohya_hrfix_downscale_after_skip"),
|
||||
(downscale_method, "kohya_hrfix_downscale_method"),
|
||||
(upscale_method, "kohya_hrfix_upscale_method"),
|
||||
]
|
||||
downscale_method = gr.Radio(label='Downscale Method', choices=upscale_methods, value=upscale_methods[0])
|
||||
upscale_method = gr.Radio(label='Upscale Method', choices=upscale_methods, value=upscale_methods[0])
|
||||
|
||||
return enabled, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method
|
||||
|
||||
|
||||
@@ -1,18 +1,13 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import scripts
|
||||
from modules.script_callbacks import on_cfg_denoiser, remove_current_script_callbacks
|
||||
from backend.patcher.base import set_model_options_patch_replace
|
||||
from backend.sampling.sampling_function import calc_cond_uncond_batch
|
||||
from modules.ui_components import InputAccordion
|
||||
|
||||
|
||||
class PerturbedAttentionGuidanceForForge(scripts.Script):
|
||||
sorting_priority = 13
|
||||
|
||||
attenuated_scale = 3.0
|
||||
doPAG = True
|
||||
|
||||
def title(self):
|
||||
return "PerturbedAttentionGuidance Integrated"
|
||||
|
||||
@@ -20,82 +15,43 @@ class PerturbedAttentionGuidanceForForge(scripts.Script):
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
def ui(self, *args, **kwargs):
|
||||
with InputAccordion(False, label=self.title()) as enabled:
|
||||
with gr.Row():
|
||||
scale = gr.Slider(label='Scale', minimum=0.0, maximum=100.0, step=0.1, value=3.0)
|
||||
attenuation = gr.Slider(label='Attenuation (linear, % of scale)', minimum=0.0, maximum=100.0, step=0.1, value=0.0)
|
||||
with gr.Row():
|
||||
start_step = gr.Slider(label='Start step', minimum=0.0, maximum=1.0, step=0.01, value=0.0)
|
||||
end_step = gr.Slider(label='End step', minimum=0.0, maximum=1.0, step=0.01, value=1.0)
|
||||
with gr.Accordion(open=False, label=self.title()):
|
||||
enabled = gr.Checkbox(label='Enabled', value=False)
|
||||
scale = gr.Slider(label='Scale', minimum=0.0, maximum=100.0, step=0.1, value=3.0)
|
||||
|
||||
self.infotext_fields = [
|
||||
(enabled, lambda d: d.get("pagi_enabled", False)),
|
||||
(scale, "pagi_scale"),
|
||||
(attenuation, "pagi_attenuation"),
|
||||
(start_step, "pagi_start_step"),
|
||||
(end_step, "pagi_end_step"),
|
||||
]
|
||||
|
||||
return enabled, scale, attenuation, start_step, end_step
|
||||
|
||||
def denoiser_callback(self, params):
|
||||
thisStep = (params.sampling_step) / (params.total_sampling_steps - 1)
|
||||
|
||||
if thisStep >= PerturbedAttentionGuidanceForForge.PAG_start and thisStep <= PerturbedAttentionGuidanceForForge.PAG_end:
|
||||
PerturbedAttentionGuidanceForForge.doPAG = True
|
||||
else:
|
||||
PerturbedAttentionGuidanceForForge.doPAG = False
|
||||
return enabled, scale
|
||||
|
||||
def process_before_every_sampling(self, p, *script_args, **kwargs):
|
||||
enabled, scale, attenuation, start_step, end_step = script_args
|
||||
enabled, scale = script_args
|
||||
|
||||
if not enabled:
|
||||
return
|
||||
|
||||
PerturbedAttentionGuidanceForForge.scale = scale
|
||||
PerturbedAttentionGuidanceForForge.PAG_start = start_step
|
||||
PerturbedAttentionGuidanceForForge.PAG_end = end_step
|
||||
on_cfg_denoiser(self.denoiser_callback)
|
||||
|
||||
unet = p.sd_model.forge_objects.unet.clone()
|
||||
|
||||
def attn_proc(q, k, v, to):
|
||||
return v
|
||||
|
||||
def post_cfg_function(args):
|
||||
denoised = args["denoised"]
|
||||
model, cond_denoised, cond, denoised, sigma, x = \
|
||||
args["model"], args["cond_denoised"], args["cond"], args["denoised"], args["sigma"], args["input"]
|
||||
|
||||
if PerturbedAttentionGuidanceForForge.scale <= 0.0:
|
||||
return denoised
|
||||
new_options = set_model_options_patch_replace(args["model_options"], attn_proc, "attn1", "middle", 0)
|
||||
|
||||
if not PerturbedAttentionGuidanceForForge.doPAG:
|
||||
if scale == 0:
|
||||
return denoised
|
||||
|
||||
model, cond_denoised, cond, sigma, x, options = \
|
||||
args["model"], args["cond_denoised"], args["cond"], args["sigma"], args["input"], args["model_options"].copy()
|
||||
new_options = set_model_options_patch_replace(options, attn_proc, "attn1", "middle", 0)
|
||||
|
||||
degraded, _ = calc_cond_uncond_batch(model, cond, None, x, sigma, new_options)
|
||||
|
||||
result = denoised + (cond_denoised - degraded) * PerturbedAttentionGuidanceForForge.scale
|
||||
PerturbedAttentionGuidanceForForge.scale -= scale * attenuation / 100.0
|
||||
|
||||
return result
|
||||
return denoised + (cond_denoised - degraded) * scale
|
||||
|
||||
unet.set_model_sampler_post_cfg_function(post_cfg_function)
|
||||
|
||||
p.sd_model.forge_objects.unet = unet
|
||||
|
||||
p.extra_generation_params.update(dict(
|
||||
pagi_enabled = enabled,
|
||||
pagi_scale = scale,
|
||||
pagi_attenuation = attenuation,
|
||||
pagi_start_step = start_step,
|
||||
pagi_end_step = end_step,
|
||||
PerturbedAttentionGuidance_enabled=enabled,
|
||||
PerturbedAttentionGuidance_scale=scale,
|
||||
))
|
||||
|
||||
return
|
||||
|
||||
def postprocess(self, params, processed, *args):
|
||||
remove_current_script_callbacks()
|
||||
return
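
The `post_cfg_function` above applies perturbed attention guidance: a degraded prediction is computed with the middle-block self-attention replaced by an identity map (`attn_proc` returns `v`), and the result is `denoised + (cond_denoised - degraded) * scale`, with the scale decayed each step by a percentage of the slider value. The combination step on its own can be sketched as below; the tensors are placeholders, and producing `degraded` is what `set_model_options_patch_replace` plus `calc_cond_uncond_batch` do in the script:

# Sketch of the guidance combination used by the post_cfg_function above.
import torch

def apply_pag(denoised, cond_denoised, degraded, current_scale, base_scale, attenuation_percent=0.0):
    """Return the guided latent and the scale to use on the next step."""
    if current_scale <= 0.0:
        return denoised, current_scale
    guided = denoised + (cond_denoised - degraded) * current_scale
    # the script above decays the scale linearly by a percentage of the slider value
    next_scale = current_scale - base_scale * attenuation_percent / 100.0
    return guided, next_scale

if __name__ == "__main__":
    shape = (1, 4, 64, 64)
    denoised, cond_denoised, degraded = (torch.randn(shape) for _ in range(3))
    out, next_scale = apply_pag(denoised, cond_denoised, degraded, current_scale=3.0, base_scale=3.0, attenuation_percent=10.0)
    print(out.shape, next_scale)  # guided latent shape and the decayed scale for the next step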
|
||||
|
||||
@@ -62,21 +62,12 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
|
||||
attn = attn.reshape(b, -1, hw1, hw2)
|
||||
# Global Average Pool
|
||||
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
|
||||
|
||||
# original method: works for all normal inputs that *do not* have Kohya HRFix scaling; typically fails with scaling
|
||||
ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
|
||||
h = math.ceil(lh / ratio)
|
||||
w = math.ceil(lw / ratio)
|
||||
|
||||
if h * w != mask.size(1):
|
||||
kohya_shrink_shape = getattr(shared, 'kohya_shrink_shape', None)
|
||||
if kohya_shrink_shape:
|
||||
w = kohya_shrink_shape[0] # works with all block numbers for kohya hrfix
|
||||
h = kohya_shrink_shape[1]
|
||||
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
|
||||
|
||||
# Reshape
|
||||
mask = (
|
||||
mask.reshape(b, h, w)
|
||||
mask.reshape(b, *mid_shape)
|
||||
.unsqueeze(1)
|
||||
.type(attn.dtype)
|
||||
)
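
The hunk above recovers the 2-D shape of the attention mask from the latent size: `ratio = 2**((ceil(sqrt(lh*lw/hw1)) - 1).bit_length())` and `h, w = ceil(lh/ratio), ceil(lw/ratio)`. When Kohya HRFix has shrunk the latent, `h * w` no longer matches the mask and the shape recorded in `shared.kohya_shrink_shape` is used instead. A small sketch of that shape recovery, with the fallback shape passed in explicitly and hypothetical sizes:

# Sketch of the mask-shape recovery used in create_blur_map above.
import math

def attention_mask_shape(lh, lw, hw1, kohya_shrink_shape=None):
    """Return (h, w) such that h * w equals the number of attention positions hw1."""
    ratio = 2 ** ((math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length())
    h = math.ceil(lh / ratio)
    w = math.ceil(lw / ratio)
    if h * w != hw1 and kohya_shrink_shape is not None:
        w, h = kohya_shrink_shape  # (width, height) recorded while the latent was downscaled
    return h, w

if __name__ == "__main__":
    # 128x128 latent, middle-block attention over a 16x16 grid (hw1 = 256)
    print(attention_mask_shape(128, 128, 256))             # (16, 16)
    # with HRFix the latent was shrunk, so fall back to the recorded shape
    print(attention_mask_shape(128, 128, 121, (11, 11)))   # (11, 11)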
|
||||
@@ -201,14 +192,11 @@ class SAGForForge(scripts.Script):
|
||||
|
||||
# not for Flux
if not shared.sd_model.is_webui_legacy_model(): # ideally would be is_flux
print("Self Attention Guidance is not compatible with Flux")
gr.Info("Self Attention Guidance is not compatible with Flux")
return
|
||||
# Self Attention Guidance errors if CFG is 1
|
||||
if p.is_hr_pass == False and p.cfg_scale <= 1:
|
||||
print("Self Attention Guidance requires CFG > 1")
|
||||
return
|
||||
if p.is_hr_pass == True and p.hr_cfg <= 1:
|
||||
print("Self Attention Guidance (hires pass) requires Hires CFG > 1")
|
||||
if p.cfg_scale == 1:
|
||||
gr.Info ("Self Attention Guidance requires CFG > 1")
|
||||
return
|
||||
|
||||
unet = p.sd_model.forge_objects.unet
|
||||
|
||||
@@ -5,9 +5,6 @@ from modules.ui_components import InputAccordion
|
||||
import modules.scripts as scripts
|
||||
from modules.torch_utils import float64
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from scipy.ndimage import convolve
|
||||
from joblib import Parallel, delayed, cpu_count
|
||||
|
||||
class SoftInpaintingSettings:
|
||||
def __init__(self,
|
||||
@@ -247,76 +244,7 @@ def apply_masks(
|
||||
return masks_for_overlay
|
||||
|
||||
|
||||
|
||||
|
||||
def weighted_histogram_filter_single_pixel(idx, img, kernel, kernel_center, percentile_min, percentile_max, min_width):
|
||||
"""
|
||||
Apply the weighted histogram filter to a single pixel.
|
||||
This function is now refactored to be accessible for parallelization.
|
||||
"""
|
||||
idx = np.array(idx)
|
||||
kernel_min = -kernel_center
|
||||
kernel_max = np.array(kernel.shape) - kernel_center
|
||||
|
||||
# Precompute the minimum and maximum valid indices for the kernel
|
||||
min_index = np.maximum(0, idx + kernel_min)
|
||||
max_index = np.minimum(np.array(img.shape), idx + kernel_max)
|
||||
window_shape = max_index - min_index
|
||||
|
||||
# Initialize values and weights arrays
|
||||
values = []
|
||||
weights = []
|
||||
|
||||
for window_tup in np.ndindex(*window_shape):
|
||||
window_index = np.array(window_tup)
|
||||
image_index = window_index + min_index
|
||||
centered_kernel_index = image_index - idx
|
||||
kernel_index = centered_kernel_index + kernel_center
|
||||
values.append(img[tuple(image_index)])
|
||||
weights.append(kernel[tuple(kernel_index)])
|
||||
|
||||
# Convert to NumPy arrays
|
||||
values = np.array(values)
|
||||
weights = np.array(weights)
|
||||
|
||||
# Sort values and weights by values
|
||||
sorted_indices = np.argsort(values)
|
||||
values = values[sorted_indices]
|
||||
weights = weights[sorted_indices]
|
||||
|
||||
# Calculate cumulative weights
|
||||
cumulative_weights = np.cumsum(weights)
|
||||
|
||||
# Define window boundaries
|
||||
sum_weights = cumulative_weights[-1]
|
||||
window_min = sum_weights * percentile_min
|
||||
window_max = sum_weights * percentile_max
|
||||
window_width = window_max - window_min
|
||||
|
||||
# Ensure window is at least `min_width` wide
|
||||
if window_width < min_width:
|
||||
window_center = (window_min + window_max) / 2
|
||||
window_min = window_center - min_width / 2
|
||||
window_max = window_center + min_width / 2
|
||||
|
||||
if window_max > sum_weights:
|
||||
window_max = sum_weights
|
||||
window_min = sum_weights - min_width
|
||||
|
||||
if window_min < 0:
|
||||
window_min = 0
|
||||
window_max = min_width
|
||||
|
||||
# Calculate overlap for each value
|
||||
overlap_start = np.maximum(window_min, np.concatenate(([0], cumulative_weights[:-1])))
|
||||
overlap_end = np.minimum(window_max, cumulative_weights)
|
||||
overlap = np.maximum(0, overlap_end - overlap_start)
|
||||
|
||||
# Weighted average calculation
|
||||
result = np.sum(values * overlap) / np.sum(overlap) if np.sum(overlap) > 0 else 0
|
||||
return result
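
The helper above reduces a pixel's kernel neighbourhood to a weighted percentile window: values are sorted, their cumulative weights define a stack, a `[percentile_min, percentile_max]` window (widened to at least `min_width`) is intersected with each value's span, and the result is the overlap-weighted average. With `percentile_min = percentile_max = 0.5` this behaves like a weighted median. The windowing can be illustrated on a plain list of values and weights; the numbers below are made up and are not taken from the module:

# Tiny illustration of the cumulative-weight window used above.
import numpy as np

def weighted_percentile_window(values, weights, pct_min=0.5, pct_max=0.5, min_width=1.0):
    values = np.asarray(values, dtype=float)
    weights = np.asarray(weights, dtype=float)
    order = np.argsort(values)
    values, weights = values[order], weights[order]

    cumulative = np.cumsum(weights)
    total = cumulative[-1]
    window_min, window_max = total * pct_min, total * pct_max

    # widen the window to at least min_width, clamped to [0, total]
    if window_max - window_min < min_width:
        center = (window_min + window_max) / 2
        window_min, window_max = center - min_width / 2, center + min_width / 2
        if window_max > total:
            window_min, window_max = total - min_width, total
        if window_min < 0:
            window_min, window_max = 0, min_width

    # overlap of each value's span [cum - w, cum] with the window
    starts = np.concatenate(([0.0], cumulative[:-1]))
    overlap = np.maximum(0.0, np.minimum(window_max, cumulative) - np.maximum(window_min, starts))
    return float(np.sum(values * overlap) / np.sum(overlap)) if overlap.sum() > 0 else 0.0

if __name__ == "__main__":
    # weighted "median" of three samples; the heavily weighted 0.5 dominates the window
    print(weighted_percentile_window([0.2, 0.5, 0.9], [1.0, 4.0, 1.0]))  # -> 0.5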
|
||||
|
||||
def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0, n_jobs=-1):
|
||||
def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0):
|
||||
"""
|
||||
Generalized convolution filter capable of applying
|
||||
weighted mean, median, maximum, and minimum filters
|
||||
@@ -343,74 +271,101 @@ def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, pe
|
||||
(nparray): A filtered copy of the input image "img", a 2-D array of floats.
|
||||
"""
|
||||
|
||||
# Ensure kernel_center is a 1D array
|
||||
if isinstance(kernel_center, int):
|
||||
kernel_center = np.array([kernel_center, kernel_center])
|
||||
elif len(kernel_center) == 1:
|
||||
kernel_center = np.array([kernel_center[0], kernel_center[0]])
|
||||
kernel_radius = max(kernel_center)
|
||||
padded_img = np.pad(img, kernel_radius, mode='constant', constant_values=0)
|
||||
img_out = np.zeros_like(img)
|
||||
img_shape = img.shape
|
||||
pixel_coords = [(i, j) for i in range(img_shape[0]) for j in range(img_shape[1])]
|
||||
# Converts an index tuple into a vector.
|
||||
def vec(x):
|
||||
return np.array(x)
|
||||
|
||||
kernel_min = -kernel_center
|
||||
kernel_max = vec(kernel.shape) - kernel_center
|
||||
|
||||
def weighted_histogram_filter_single(idx):
|
||||
"""
|
||||
Single-pixel weighted histogram calculation.
|
||||
"""
|
||||
row, col = idx
|
||||
idx = (row + kernel_radius, col + kernel_radius)
|
||||
min_index = np.array(idx) - kernel_center
|
||||
max_index = min_index + kernel.shape
|
||||
idx = vec(idx)
|
||||
min_index = np.maximum(0, idx + kernel_min)
|
||||
max_index = np.minimum(vec(img.shape), idx + kernel_max)
|
||||
window_shape = max_index - min_index
|
||||
|
||||
window = padded_img[min_index[0]:max_index[0], min_index[1]:max_index[1]]
|
||||
window_values = window.flatten()
|
||||
window_weights = kernel.flatten()
|
||||
class WeightedElement:
|
||||
"""
|
||||
An element of the histogram, its weight
|
||||
and bounds.
|
||||
"""
|
||||
|
||||
sorted_indices = np.argsort(window_values)
|
||||
values = window_values[sorted_indices]
|
||||
weights = window_weights[sorted_indices]
|
||||
def __init__(self, value, weight):
|
||||
self.value: float = value
|
||||
self.weight: float = weight
|
||||
self.window_min: float = 0.0
|
||||
self.window_max: float = 1.0
|
||||
|
||||
cumulative_weights = np.cumsum(weights)
|
||||
sum_weights = cumulative_weights[-1]
|
||||
window_min = max(0, sum_weights * percentile_min)
|
||||
window_max = min(sum_weights, sum_weights * percentile_max)
|
||||
# Collect the values in the image as WeightedElements,
|
||||
# weighted by their corresponding kernel values.
|
||||
values = []
|
||||
for window_tup in np.ndindex(tuple(window_shape)):
|
||||
window_index = vec(window_tup)
|
||||
image_index = window_index + min_index
|
||||
centered_kernel_index = image_index - idx
|
||||
kernel_index = centered_kernel_index + kernel_center
|
||||
element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)])
|
||||
values.append(element)
|
||||
|
||||
def sort_key(x: WeightedElement):
|
||||
return x.value
|
||||
|
||||
values.sort(key=sort_key)
|
||||
|
||||
# Calculate the height of the stack (sum)
|
||||
# and each sample's range they occupy in the stack
|
||||
sum = 0
|
||||
for i in range(len(values)):
|
||||
values[i].window_min = sum
|
||||
sum += values[i].weight
|
||||
values[i].window_max = sum
|
||||
|
||||
# Calculate what range of this stack ("window")
|
||||
# we want to get the weighted average across.
|
||||
window_min = sum * percentile_min
|
||||
window_max = sum * percentile_max
|
||||
window_width = window_max - window_min
|
||||
|
||||
# Ensure the window is within the stack and at least a certain size.
|
||||
if window_width < min_width:
|
||||
window_center = (window_min + window_max) / 2
|
||||
window_min = max(0, window_center - min_width / 2)
|
||||
window_max = min(sum_weights, window_center + min_width / 2)
|
||||
window_min = window_center - min_width / 2
|
||||
window_max = window_center + min_width / 2
|
||||
|
||||
overlap_start = np.maximum(window_min, np.concatenate(([0], cumulative_weights[:-1])))
|
||||
overlap_end = np.minimum(window_max, cumulative_weights)
|
||||
overlap = np.maximum(0, overlap_end - overlap_start)
|
||||
if window_max > sum:
|
||||
window_max = sum
|
||||
window_min = sum - min_width
|
||||
|
||||
return np.sum(values * overlap) / np.sum(overlap) if np.sum(overlap) > 0 else 0
|
||||
if window_min < 0:
|
||||
window_min = 0
|
||||
window_max = min_width
|
||||
|
||||
# Split pixel_coords into equal chunks based on n_jobs
|
||||
n_jobs = -1
|
||||
if cpu_count() > 6:
|
||||
n_jobs = 6 # more than 6 isn't worth it unless the image is larger than 3000x3000 px
|
||||
value = 0
|
||||
value_weight = 0
|
||||
|
||||
chunk_size = len(pixel_coords) // n_jobs
|
||||
pixel_chunks = [pixel_coords[i:i + chunk_size] for i in range(0, len(pixel_coords), chunk_size)]
|
||||
# Get the weighted average of all the samples
|
||||
# that overlap with the window, weighted
|
||||
# by the size of their overlap.
|
||||
for i in range(len(values)):
|
||||
if window_min >= values[i].window_max:
|
||||
continue
|
||||
if window_max <= values[i].window_min:
|
||||
break
|
||||
|
||||
# joblib to process chunks in parallel
|
||||
def process_chunk(chunk):
|
||||
chunk_result = {}
|
||||
for idx in chunk:
|
||||
chunk_result[idx] = weighted_histogram_filter_single(idx)
|
||||
return chunk_result
|
||||
s = max(window_min, values[i].window_min)
|
||||
e = min(window_max, values[i].window_max)
|
||||
w = e - s
|
||||
|
||||
results = Parallel(n_jobs=n_jobs, backend="loky")( # loky is fastest in my configuration
|
||||
delayed(process_chunk)(chunk) for chunk in pixel_chunks
|
||||
)
|
||||
value += values[i].value * w
|
||||
value_weight += w
|
||||
|
||||
# Combine results into the output image
|
||||
for chunk_result in results:
|
||||
for (row, col), value in chunk_result.items():
|
||||
img_out[row, col] = value
|
||||
return value / value_weight if value_weight != 0 else 0
|
||||
|
||||
img_out = img.copy()
|
||||
|
||||
# Apply the kernel operation over each pixel.
|
||||
for index in np.ndindex(img.shape):
|
||||
img_out[index] = weighted_histogram_filter_single(index)
|
||||
|
||||
return img_out
|
||||
|
||||
@@ -530,7 +485,7 @@ el_ids = SoftInpaintingSettings(
|
||||
|
||||
class Script(scripts.Script):
|
||||
def __init__(self):
|
||||
self.section = "inpaint"
|
||||
# self.section = "inpaint"
|
||||
self.masks_for_overlay = None
|
||||
self.overlay_images = None
|
||||
|
||||
|
||||
@@ -177,7 +177,7 @@ function modalTileImageToggle(event) {
|
||||
}
|
||||
|
||||
onAfterUiUpdate(function() {
|
||||
var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > button > button > img, .gradio-gallery > .livePreview');
|
||||
var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > button > button > img');
|
||||
if (fullImg_preview != null) {
|
||||
fullImg_preview.forEach(setupImageForLightbox);
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ from torchdiffeq import odeint
|
||||
import torchsde
|
||||
from tqdm.auto import trange, tqdm
|
||||
from k_diffusion import deis
|
||||
from backend.modules.k_prediction import PredictionFlux
|
||||
|
||||
from . import utils
|
||||
|
||||
@@ -140,8 +139,6 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
@torch.no_grad()
|
||||
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
if isinstance(model.inner_model.predictor, PredictionFlux):
|
||||
return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
@@ -158,32 +155,6 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
#sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta
|
||||
sigma_down = sigmas[i + 1] * downstep_ratio
|
||||
alpha_ip1 = 1 - sigmas[i + 1]
|
||||
alpha_down = 1 - sigma_down
|
||||
renoise_coeff = (sigmas[i + 1]**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2)**0.5
|
||||
# Euler method
|
||||
sigma_down_i_ratio = sigma_down / sigmas[i]
|
||||
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
|
||||
if eta > 0:
|
||||
x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
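
`sample_euler_ancestral_RF` above adapts the ancestral Euler step to rectified-flow prediction (selected via `PredictionFlux`): the down-step sigma is `sigma_down = sigma_{i+1} * (1 + eta * (sigma_{i+1}/sigma_i - 1))`, the alphas are `1 - sigma`, and the re-noise coefficient is `sqrt(sigma_{i+1}^2 - sigma_down^2 * alpha_{i+1}^2 / alpha_down^2)`. The per-step quantities can be isolated as pure arithmetic; the schedule values below are a toy example, and the model evaluation and noise sampler are omitted:

# Per-step quantities of the rectified-flow ancestral Euler step above.
import math

def rf_ancestral_step(sigma_i: float, sigma_ip1: float, eta: float = 1.0):
    downstep_ratio = 1 + (sigma_ip1 / sigma_i - 1) * eta
    sigma_down = sigma_ip1 * downstep_ratio
    alpha_ip1 = 1 - sigma_ip1
    alpha_down = 1 - sigma_down
    renoise_coeff = math.sqrt(sigma_ip1 ** 2 - sigma_down ** 2 * alpha_ip1 ** 2 / alpha_down ** 2)
    return sigma_down, alpha_ip1, alpha_down, renoise_coeff

if __name__ == "__main__":
    # one step of a toy schedule going from sigma = 0.8 to sigma = 0.6
    sigma_down, alpha_ip1, alpha_down, renoise = rf_ancestral_step(0.8, 0.6, eta=1.0)
    print(f"sigma_down={sigma_down:.4f} renoise={renoise:.4f}")
    # the update is then: x = (sigma_down / sigma_i) * x + (1 - sigma_down / sigma_i) * denoised
    # followed by:        x = (alpha_ip1 / alpha_down) * x + noise * s_noise * renoise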
|
||||
@@ -248,8 +219,6 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
@torch.no_grad()
|
||||
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver second-order steps."""
|
||||
if isinstance(model.inner_model.predictor, PredictionFlux):
|
||||
return sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
@@ -275,38 +244,6 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
|
||||
sigma_down = sigmas[i+1] * downstep_ratio
|
||||
alpha_ip1 = 1 - sigmas[i+1]
|
||||
alpha_down = 1 - sigma_down
|
||||
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
|
||||
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
if sigma_down == 0:
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
else:
|
||||
# DPM-Solver-2
|
||||
sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
|
||||
dt_1 = sigma_mid - sigmas[i]
|
||||
dt_2 = sigma_down - sigmas[i]
|
||||
x_2 = x + d * dt_1
|
||||
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, sigma_mid, denoised_2)
|
||||
x = x + d_2 * dt_2
|
||||
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
||||
return x
|
||||
|
||||
def linear_multistep_coeff(order, t, i, j):
|
||||
if order - 1 > i:
|
||||
|
||||
@@ -21,9 +21,7 @@ from modules import sd_samplers, deepbooru, images, scripts, ui, postprocessing,
|
||||
from modules.api import models
|
||||
from modules.shared import opts
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images, process_extra_images
|
||||
import modules.textual_inversion.textual_inversion
|
||||
from modules.shared import cmd_opts
|
||||
|
||||
from modules.textual_inversion.textual_inversion import create_embedding
|
||||
from PIL import PngImagePlugin
|
||||
from modules.realesrgan_model import get_realesrgan_models
|
||||
from modules import devices
|
||||
@@ -122,7 +120,7 @@ def encode_pil_to_base64(image):
|
||||
if opts.samples_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality)
|
||||
else:
|
||||
image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
|
||||
image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality)
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Invalid image format")
|
||||
@@ -267,10 +265,6 @@ class Api:
|
||||
if not self.default_script_arg_img2img:
|
||||
self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner)
|
||||
|
||||
self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
|
||||
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
||||
self.embedding_db.load_textual_inversion_embeddings(force_reload=True, sync_with_sd_model=False)
|
||||
|
||||
|
||||
|
||||
def add_api_route(self, path: str, endpoint, **kwargs):
|
||||
@@ -750,6 +744,8 @@ class Api:
|
||||
return styleList
|
||||
|
||||
def get_embeddings(self):
|
||||
db = sd_hijack.model_hijack.embedding_db
|
||||
|
||||
def convert_embedding(embedding):
|
||||
return {
|
||||
"step": embedding.step,
|
||||
@@ -763,13 +759,13 @@ class Api:
|
||||
return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
|
||||
|
||||
return {
|
||||
"loaded": convert_embeddings(self.embedding_db.word_embeddings),
|
||||
"skipped": convert_embeddings(self.embedding_db.skipped_embeddings),
|
||||
"loaded": convert_embeddings(db.word_embeddings),
|
||||
"skipped": convert_embeddings(db.skipped_embeddings),
|
||||
}
|
||||
|
||||
def refresh_embeddings(self):
|
||||
with self.queue_lock:
|
||||
self.embedding_db.load_textual_inversion_embeddings(force_reload=True, sync_with_sd_model=False)
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
|
||||
|
||||
def refresh_checkpoints(self):
|
||||
with self.queue_lock:
|
||||
@@ -782,14 +778,15 @@ class Api:
|
||||
def create_embedding(self, args: dict):
|
||||
try:
|
||||
shared.state.begin(job="create_embedding")
|
||||
filename = modules.textual_inversion.textual_inversion.create_embedding(**args) # create empty embedding
|
||||
self.embedding_db.load_textual_inversion_embeddings(force_reload=True, sync_with_sd_model=False) # reload embeddings so new one can be immediately used
|
||||
filename = create_embedding(**args) # create empty embedding
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
|
||||
return models.CreateResponse(info=f"create embedding filename: {filename}")
|
||||
except AssertionError as e:
|
||||
return models.TrainResponse(info=f"create embedding error: {e}")
|
||||
finally:
|
||||
shared.state.end()
|
||||
|
||||
|
||||
def create_hypernetwork(self, args: dict):
|
||||
try:
|
||||
shared.state.begin(job="create_hypernetwork")
|
||||
|
||||
@@ -15,7 +15,6 @@ parser.add_argument("--update-check", action='store_true', help="launch.py argum
|
||||
parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
|
||||
parser.add_argument("--log-startup", action='store_true', help="launch.py argument: print a detailed log of what's happening at startup")
|
||||
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
|
||||
parser.add_argument("--skip-google-blockly", action='store_true', help="launch.py argument: do not initialize google blockly modules")
|
||||
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
|
||||
parser.add_argument("--dump-sysinfo", action='store_true', help="launch.py argument: dump limited sysinfo file (without information about extensions, options) to disk and quit")
|
||||
parser.add_argument("--loglevel", type=str, help="log level; one of: CRITICAL, ERROR, WARNING, INFO, DEBUG", default=None)
|
||||
|
||||
@@ -14,7 +14,7 @@ class UpscalerDAT(Upscaler):
|
||||
self.scalers = []
|
||||
super().__init__()
|
||||
|
||||
for file in self.find_models(ext_filter=[".pt", ".pth", ".safetensors"]):
|
||||
for file in self.find_models(ext_filter=[".pt", ".pth"]):
|
||||
name = modelloader.friendly_name(file)
|
||||
scaler_data = UpscalerData(name, file, upscaler=self, scale=None)
|
||||
self.scalers.append(scaler_data)
|
||||
@@ -51,18 +51,7 @@ class UpscalerDAT(Upscaler):
|
||||
scaler.local_data_path = modelloader.load_file_from_url(
|
||||
scaler.data_path,
|
||||
model_dir=self.model_download_path,
|
||||
hash_prefix=scaler.sha256,
|
||||
)
|
||||
|
||||
if os.path.getsize(scaler.local_data_path) < 200:
|
||||
# Re-download if the file is too small, probably an LFS pointer
|
||||
scaler.local_data_path = modelloader.load_file_from_url(
|
||||
scaler.data_path,
|
||||
model_dir=self.model_download_path,
|
||||
hash_prefix=scaler.sha256,
|
||||
re_download=True,
|
||||
)
|
||||
|
||||
if not os.path.exists(scaler.local_data_path):
|
||||
raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}")
|
||||
return scaler
|
||||
@@ -73,23 +62,20 @@ def get_dat_models(scaler):
|
||||
return [
|
||||
UpscalerData(
|
||||
name="DAT x2",
|
||||
path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x2.pth",
|
||||
path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x2.pth",
|
||||
scale=2,
|
||||
upscaler=scaler,
|
||||
sha256='7760aa96e4ee77e29d4f89c3a4486200042e019461fdb8aa286f49aa00b89b51',
|
||||
),
|
||||
UpscalerData(
|
||||
name="DAT x3",
|
||||
path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x3.pth",
|
||||
path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x3.pth",
|
||||
scale=3,
|
||||
upscaler=scaler,
|
||||
sha256='581973e02c06f90d4eb90acf743ec9604f56f3c2c6f9e1e2c2b38ded1f80d197',
|
||||
),
|
||||
UpscalerData(
|
||||
name="DAT x4",
|
||||
path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x4.pth",
|
||||
path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x4.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
sha256='391a6ce69899dff5ea3214557e9d585608254579217169faf3d4c353caff049e',
|
||||
),
|
||||
]
|
||||
|
||||
@@ -54,7 +54,7 @@ device_interrogate: torch.device = memory_management.text_encoder_device() # fo
|
||||
device_gfpgan: torch.device = memory_management.get_torch_device() # will be managed by memory management system
|
||||
device_esrgan: torch.device = memory_management.get_torch_device() # will be managed by memory management system
|
||||
device_codeformer: torch.device = memory_management.get_torch_device() # will be managed by memory management system
|
||||
dtype: torch.dtype = torch.float32 if memory_management.unet_dtype() is torch.float32 else torch.float16
|
||||
dtype: torch.dtype = memory_management.unet_dtype()
|
||||
dtype_vae: torch.dtype = memory_management.vae_dtype()
|
||||
dtype_unet: torch.dtype = memory_management.unet_dtype()
|
||||
dtype_inference: torch.dtype = memory_management.unet_dtype()
|
||||
|
||||
@@ -13,7 +13,7 @@ class UpscalerESRGAN(Upscaler):
|
||||
self.scalers = []
|
||||
self.user_path = dirname
|
||||
super().__init__()
|
||||
model_paths = self.find_models(ext_filter=[".pt", ".pth", ".safetensors"])
|
||||
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
|
||||
scalers = []
|
||||
if len(model_paths) == 0:
|
||||
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
|
||||
|
||||
@@ -20,8 +20,7 @@ def calculate_sha256_real(filename):
|
||||
|
||||
|
||||
def calculate_sha256(filename):
|
||||
print("Calculating real hash: ", filename)
|
||||
return calculate_sha256_real(filename)
|
||||
return forge_fake_calculate_sha256(filename)
|
||||
|
||||
|
||||
def forge_fake_calculate_sha256(filename):
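
The hunk above restores `calculate_sha256` to call `calculate_sha256_real` instead of Forge's fake hash. `calculate_sha256_real` itself is not shown in this hunk; a typical streaming implementation, offered here only as a sketch under that assumption, reads the file in fixed-size blocks so large checkpoints never have to fit in memory:

# Streaming SHA-256 sketch; the actual calculate_sha256_real is not shown above.
import hashlib

def calculate_sha256_streaming(filename: str, block_size: int = 1024 * 1024) -> str:
    hash_sha256 = hashlib.sha256()
    with open(filename, "rb") as f:
        for block in iter(lambda: f.read(block_size), b""):
            hash_sha256.update(block)
    return hash_sha256.hexdigest()

if __name__ == "__main__":
    import os
    import tempfile
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(b"hello")
        path = tmp.name
    print(calculate_sha256_streaming(path))  # sha256 of b"hello"
    os.remove(path)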
|
||||
@@ -60,8 +59,8 @@ def sha256(filename, title, use_addnet_hash=False):
|
||||
if shared.cmd_opts.no_hashing:
|
||||
return None
|
||||
|
||||
print(f"Calculating sha256 for {filename}: ", end='', flush=True)
|
||||
sha256_value = calculate_sha256_real(filename)
|
||||
print(f"Calculating sha256 for {filename}: ", end='')
|
||||
sha256_value = forge_fake_calculate_sha256(filename)
|
||||
print(f"{sha256_value}")
|
||||
|
||||
hashes[title] = {
|
||||
|
||||
@@ -13,15 +13,13 @@ import numpy as np
|
||||
import piexif
|
||||
import piexif.helper
|
||||
from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin, ImageOps
|
||||
from PIL import __version__ as pillow_version
|
||||
from pkg_resources import parse_version
|
||||
# pillow_avif needs to be imported somewhere in code for it to work
|
||||
import pillow_avif # noqa: F401
|
||||
import string
|
||||
import json
|
||||
import hashlib
|
||||
|
||||
from modules import sd_samplers, shared, script_callbacks, errors, stealth_infotext
|
||||
from modules import sd_samplers, shared, script_callbacks, errors
|
||||
from modules.paths_internal import roboto_ttf_file
|
||||
from modules.shared import opts
|
||||
|
||||
@@ -170,18 +168,9 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
||||
for line in lines:
|
||||
fnt = initial_fnt
|
||||
fontsize = initial_fontsize
|
||||
if parse_version(pillow_version) >= parse_version('10.0.0'):
|
||||
# New code for Pillow 10.0.0+
|
||||
text_width, text_height = drawing.multiline_textbbox((0, 0), line.text, font=fnt)[2:]
|
||||
while text_width > line.allowed_width and fontsize > 0:
|
||||
fontsize -= 1
|
||||
fnt = get_font(fontsize)
|
||||
text_width, text_height = drawing.multiline_textbbox((0, 0), line.text, font=fnt)[2:]
|
||||
else:
|
||||
# Old code for Pillow versions below 10.0.0
|
||||
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
|
||||
fontsize -= 1
|
||||
fnt = get_font(fontsize)
|
||||
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
|
||||
fontsize -= 1
|
||||
fnt = get_font(fontsize)
|
||||
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
|
||||
|
||||
if not line.is_active:
|
||||
@@ -275,9 +264,6 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None, force_RGBA=
|
||||
upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
|
||||
"""
|
||||
|
||||
if not force_RGBA and im.mode == 'RGBA':
|
||||
im = im.convert('RGB')
|
||||
|
||||
upscaler_name = upscaler_name or opts.upscaler_for_img2img
|
||||
|
||||
def resize(im, w, h):
|
||||
@@ -656,7 +642,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
Additional PNG info. `existing_info == {pngsectionname: info, ...}`
|
||||
no_prompt:
|
||||
TODO I don't know its meaning.
|
||||
p (`StableDiffusionProcessing` or `Processing`)
|
||||
p (`StableDiffusionProcessing`)
|
||||
forced_filename (`str`):
|
||||
If specified, `basename` and filename pattern will be ignored.
|
||||
save_to_dirs (bool):
|
||||
@@ -687,13 +673,10 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
if forced_filename is None:
|
||||
if short_filename or seed is None:
|
||||
file_decoration = ""
|
||||
elif hasattr(p, 'override_settings'):
|
||||
file_decoration = p.override_settings.get("samples_filename_pattern")
|
||||
elif opts.save_to_dirs:
|
||||
file_decoration = opts.samples_filename_pattern or "[seed]"
|
||||
else:
|
||||
file_decoration = None
|
||||
|
||||
if file_decoration is None:
|
||||
file_decoration = opts.samples_filename_pattern or ("[seed]" if opts.save_to_dirs else "[seed]-[prompt_spaces]")
|
||||
file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
|
||||
|
||||
file_decoration = namegen.apply(file_decoration) + suffix
|
||||
|
||||
@@ -720,8 +703,6 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
pnginfo[pnginfo_section_name] = info
|
||||
|
||||
params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
|
||||
if opts.enable_pnginfo:
|
||||
stealth_infotext.add_stealth_pnginfo(params)
|
||||
script_callbacks.before_image_saved_callback(params)
|
||||
|
||||
image = params.image
|
||||
@@ -737,15 +718,12 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
|
||||
|
||||
filename = filename_without_extension + extension
|
||||
without_extension = filename_without_extension
|
||||
if shared.opts.save_images_replace_action != "Replace":
|
||||
n = 0
|
||||
while os.path.exists(filename):
|
||||
n += 1
|
||||
without_extension = f"{filename_without_extension}-{n}"
|
||||
filename = without_extension + extension
|
||||
filename = f"{filename_without_extension}-{n}{extension}"
|
||||
os.replace(temp_file_path, filename)
|
||||
return without_extension
|
||||
|
||||
fullfn_without_extension, extension = os.path.splitext(params.filename)
|
||||
if hasattr(os, 'statvfs'):
|
||||
@@ -753,9 +731,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
fullfn_without_extension = fullfn_without_extension[:max_name_len - max(4, len(extension))]
|
||||
params.filename = fullfn_without_extension + extension
|
||||
fullfn = params.filename
|
||||
_atomically_save_image(image, fullfn_without_extension, extension)
|
||||
|
||||
fullfn_without_extension = _atomically_save_image(image, fullfn_without_extension, extension)
|
||||
fullfn = fullfn_without_extension + extension
|
||||
image.already_saved_as = fullfn
|
||||
|
||||
oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
|
||||
@@ -774,7 +751,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
except Exception:
|
||||
image = image.resize(resize_to)
|
||||
try:
|
||||
_ = _atomically_save_image(image, fullfn_without_extension, ".jpg")
|
||||
_atomically_save_image(image, fullfn_without_extension, ".jpg")
|
||||
except Exception as e:
|
||||
errors.display(e, "saving image as downscaled JPG")
|
||||
|
||||
@@ -798,53 +775,44 @@ IGNORED_INFO_KEYS = {
|
||||
|
||||
|
||||
def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
|
||||
"""Read generation info from an image, checking standard metadata first, then stealth info if needed."""
|
||||
items = (image.info or {}).copy()
|
||||
|
||||
def read_standard():
|
||||
items = (image.info or {}).copy()
|
||||
geninfo = items.pop('parameters', None)
|
||||
|
||||
geninfo = items.pop('parameters', None)
|
||||
if "exif" in items:
|
||||
exif_data = items["exif"]
|
||||
try:
|
||||
exif = piexif.load(exif_data)
|
||||
except OSError:
|
||||
# memory / exif was not valid so piexif tried to read from a file
|
||||
exif = None
|
||||
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
|
||||
try:
|
||||
exif_comment = piexif.helper.UserComment.load(exif_comment)
|
||||
except ValueError:
|
||||
exif_comment = exif_comment.decode('utf8', errors="ignore")
|
||||
|
||||
if "exif" in items:
|
||||
exif_data = items["exif"]
|
||||
try:
|
||||
exif = piexif.load(exif_data)
|
||||
except OSError:
|
||||
# memory / exif was not valid so piexif tried to read from a file
|
||||
exif = None
|
||||
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
|
||||
try:
|
||||
exif_comment = piexif.helper.UserComment.load(exif_comment)
|
||||
except ValueError:
|
||||
exif_comment = exif_comment.decode('utf8', errors="ignore")
|
||||
if exif_comment:
|
||||
geninfo = exif_comment
|
||||
elif "comment" in items: # for gif
|
||||
if isinstance(items["comment"], bytes):
|
||||
geninfo = items["comment"].decode('utf8', errors="ignore")
|
||||
else:
|
||||
geninfo = items["comment"]
|
||||
|
||||
if exif_comment:
|
||||
geninfo = exif_comment
|
||||
elif "comment" in items: # for gif
|
||||
if isinstance(items["comment"], bytes):
|
||||
geninfo = items["comment"].decode('utf8', errors="ignore")
|
||||
else:
|
||||
geninfo = items["comment"]
|
||||
for field in IGNORED_INFO_KEYS:
|
||||
items.pop(field, None)
|
||||
|
||||
for field in IGNORED_INFO_KEYS:
|
||||
items.pop(field, None)
|
||||
if items.get("Software", None) == "NovelAI":
|
||||
try:
|
||||
json_info = json.loads(items["Comment"])
|
||||
sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
|
||||
|
||||
if items.get("Software", None) == "NovelAI":
|
||||
try:
|
||||
json_info = json.loads(items["Comment"])
|
||||
sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
|
||||
|
||||
geninfo = f"""{items["Description"]}
|
||||
Negative prompt: {json_info["uc"]}
|
||||
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
|
||||
except Exception:
|
||||
errors.report("Error parsing NovelAI image generation parameters", exc_info=True)
|
||||
|
||||
return geninfo, items
|
||||
|
||||
geninfo, items = read_standard()
|
||||
if geninfo is None:
|
||||
geninfo = stealth_infotext.read_info_from_image_stealth(image)
|
||||
geninfo = f"""{items["Description"]}
|
||||
Negative prompt: {json_info["uc"]}
|
||||
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
|
||||
except Exception:
|
||||
errors.report("Error parsing NovelAI image generation parameters", exc_info=True)
|
||||
|
||||
return geninfo, items
|
||||
|
||||
|
||||
@@ -122,14 +122,16 @@ def process_batch(p, input, output_dir, inpaint_mask_dir, args, to_scale=False,
|
||||
if output_dir:
|
||||
p.outpath_samples = output_dir
|
||||
p.override_settings['save_to_dirs'] = False
|
||||
|
||||
if opts.img2img_batch_use_original_name:
|
||||
filename_pattern = f'{image_path.stem}-[generation_number]' if p.n_iter > 1 or p.batch_size > 1 else f'{image_path.stem}'
|
||||
p.override_settings['samples_filename_pattern'] = filename_pattern
|
||||
p.override_settings['save_images_replace_action'] = "Add number suffix"
|
||||
if p.n_iter > 1 or p.batch_size > 1:
|
||||
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
|
||||
else:
|
||||
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
|
||||
|
||||
proc = modules.scripts.scripts_img2img.run(p, *args)
|
||||
|
||||
if proc is None:
|
||||
p.override_settings.pop('save_images_replace_action', None)
|
||||
proc = process_images(p)
|
||||
|
||||
if not discard_further_results and proc:
|
||||
|
||||
@@ -197,14 +197,10 @@ def connect_paste_params_buttons():
|
||||
def send_image_and_dimensions(x):
|
||||
if isinstance(x, Image.Image):
|
||||
img = x
|
||||
if img.mode == 'RGBA':
|
||||
img = img.convert('RGB')
|
||||
elif isinstance(x, list) and isinstance(x[0], tuple):
|
||||
img = x[0][0]
|
||||
else:
|
||||
img = image_from_url_text(x)
|
||||
if img is not None and img.mode == 'RGBA':
|
||||
img = img.convert('RGB')
|
||||
|
||||
if shared.opts.send_size and isinstance(img, Image.Image):
|
||||
w = img.width
|
||||
|
||||
@@ -133,4 +133,8 @@ def initialize_rest(*, reload_script_modules=False):
|
||||
extra_networks.register_default_extra_networks()
|
||||
startup_timer.record("initialize extra networks")
|
||||
|
||||
from modules_forge import google_blockly
|
||||
google_blockly.initialization()
|
||||
startup_timer.record("initialize google blockly")
|
||||
|
||||
return
|
||||
|
||||
@@ -395,13 +395,15 @@ def prepare_environment():
|
||||
# stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
|
||||
# k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
||||
huggingface_guess_repo = os.environ.get('HUGGINGFACE_GUESS_REPO', 'https://github.com/lllyasviel/huggingface_guess.git')
|
||||
google_blockly_repo = os.environ.get('GOOGLE_BLOCKLY_REPO', 'https://github.com/lllyasviel/google_blockly_prototypes')
|
||||
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
||||
|
||||
assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917")
|
||||
# stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
||||
# stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
|
||||
# k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
|
||||
huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "84826248b49bb7ca754c73293299c4d4e23a548d")
|
||||
huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "70942022b6bcd17d941c1b4172804175758618e2")
|
||||
google_blockly_commit_hash = os.environ.get('GOOGLE_BLOCKLY_COMMIT_HASH', "bf36e6fd3750a081f44209ba4f645adb598f7e37")
|
||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||
|
||||
try:
|
||||
@@ -462,6 +464,7 @@ def prepare_environment():
|
||||
# git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
|
||||
# git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
||||
git_clone(huggingface_guess_repo, repo_dir('huggingface_guess'), "huggingface_guess", huggingface_guess_commit_hash)
|
||||
git_clone(google_blockly_repo, repo_dir('google_blockly_prototypes'), "google_blockly", google_blockly_commit_hash)
|
||||
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||
|
||||
startup_timer.record("clone repositores")
|
||||
@@ -560,7 +563,7 @@ def dump_sysinfo():
|
||||
import datetime
|
||||
|
||||
text = sysinfo.get()
|
||||
filename = f"sysinfo-{datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%d-%H-%M')}.json"
|
||||
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json"
|
||||
|
||||
with open(filename, "w", encoding="utf8") as file:
|
||||
file.write(text)
|
||||
|
||||
@@ -10,7 +10,6 @@ import torch
|
||||
|
||||
from modules import shared
|
||||
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
|
||||
from modules.util import load_file_from_url # noqa, backwards compatibility
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import spandrel
|
||||
@@ -18,6 +17,30 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_file_from_url(
|
||||
url: str,
|
||||
*,
|
||||
model_dir: str,
|
||||
progress: bool = True,
|
||||
file_name: str | None = None,
|
||||
hash_prefix: str | None = None,
|
||||
) -> str:
|
||||
"""Download a file from `url` into `model_dir`, using the file present if possible.
|
||||
|
||||
Returns the path to the downloaded file.
|
||||
"""
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
if not file_name:
|
||||
parts = urlparse(url)
|
||||
file_name = os.path.basename(parts.path)
|
||||
cached_file = os.path.abspath(os.path.join(model_dir, file_name))
|
||||
if not os.path.exists(cached_file):
|
||||
print(f'Downloading: "{url}" to {cached_file}\n')
|
||||
from torch.hub import download_url_to_file
|
||||
download_url_to_file(url, cached_file, progress=progress, hash_prefix=hash_prefix)
|
||||
return cached_file
|
||||
|
||||
|
||||
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None, hash_prefix=None) -> list:
|
||||
"""
|
||||
A one-and done loader to try finding the desired models in specified directories.
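
The `load_file_from_url` added above downloads into `model_dir` via `torch.hub.download_url_to_file` only when the target file is missing, and forwards `hash_prefix` so the download can be verified. A usage sketch based on that signature; the URL, directory and file name here are placeholders and are not files the repository actually references:

# Usage sketch for the load_file_from_url defined above (placeholder values).
from modules.modelloader import load_file_from_url

checkpoint_path = load_file_from_url(
    "https://example.com/models/some_upscaler.pth",  # hypothetical URL
    model_dir="models/ESRGAN",
    file_name="some_upscaler.pth",
    hash_prefix=None,    # or the expected sha256 prefix to verify the download
)
print(checkpoint_path)   # absolute path to the cached or freshly downloaded file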
|
||||
|
||||
17
modules/models/sd35/LICENSE-CODE
Normal file
@@ -0,0 +1,17 @@
MIT License
|
||||
|
||||
Copyright © 2024 Stability AI
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
File diff suppressed because it is too large
868
modules/models/sd35/other_impls.py
Normal file
@@ -0,0 +1,868 @@
|
||||
### This file contains impls for underlying related models (CLIP, T5, etc)
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
|
||||
#################################################################################################
|
||||
### Core/Utility
|
||||
#################################################################################################
|
||||
|
||||
|
||||
def attention(q, k, v, heads, mask=None):
|
||||
"""Convenience wrapper around a basic attention operation"""
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
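
The `attention` helper above expects `q`, `k`, `v` of shape `(batch, sequence, heads * dim_head)`, splits the heads for `scaled_dot_product_attention`, and merges them again on the way out. A short shape check; the sizes are made up for illustration:

# Shape check for the attention() helper above (hypothetical sizes).
import torch
from modules.models.sd35.other_impls import attention

b, seq, heads, dim_head = 2, 77, 12, 64
q = k = v = torch.randn(b, seq, heads * dim_head)
out = attention(q, k, v, heads)   # splits heads, runs SDPA, merges heads again
print(out.shape)                  # torch.Size([2, 77, 768])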
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
bias=True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
|
||||
self.fc1 = nn.Linear(
|
||||
in_features, hidden_features, bias=bias, dtype=dtype, device=device
|
||||
)
|
||||
self.act = act_layer
|
||||
self.fc2 = nn.Linear(
|
||||
hidden_features, out_features, bias=bias, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### CLIP
|
||||
#################################################################################################
|
||||
|
||||
|
||||
class CLIPAttention(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, dtype, device):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.q_proj = nn.Linear(
|
||||
embed_dim, embed_dim, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
embed_dim, embed_dim, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
embed_dim, embed_dim, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
self.out_proj = nn.Linear(
|
||||
embed_dim, embed_dim, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
q = self.q_proj(x)
|
||||
k = self.k_proj(x)
|
||||
v = self.v_proj(x)
|
||||
out = attention(q, k, v, self.heads, mask)
|
||||
return self.out_proj(out)
|
||||
|
||||
|
||||
ACTIVATIONS = {
|
||||
"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
|
||||
"gelu": torch.nn.functional.gelu,
|
||||
}
|
||||
|
||||
|
||||
class CLIPLayer(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
heads,
|
||||
intermediate_size,
|
||||
intermediate_activation,
|
||||
dtype,
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
|
||||
self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
# self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
|
||||
self.mlp = Mlp(
|
||||
embed_dim,
|
||||
intermediate_size,
|
||||
embed_dim,
|
||||
act_layer=ACTIVATIONS[intermediate_activation],
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
x += self.self_attn(self.layer_norm1(x), mask)
|
||||
x += self.mlp(self.layer_norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class CLIPEncoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers,
|
||||
embed_dim,
|
||||
heads,
|
||||
intermediate_size,
|
||||
intermediate_activation,
|
||||
dtype,
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = torch.nn.ModuleList(
|
||||
[
|
||||
CLIPLayer(
|
||||
embed_dim,
|
||||
heads,
|
||||
intermediate_size,
|
||||
intermediate_activation,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None, intermediate_output=None):
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output < 0:
|
||||
intermediate_output = len(self.layers) + intermediate_output
|
||||
intermediate = None
|
||||
for i, l in enumerate(self.layers):
|
||||
x = l(x, mask)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
return x, intermediate
|
||||
|
||||
|
||||
class CLIPEmbeddings(torch.nn.Module):
|
||||
def __init__(
|
||||
self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None
|
||||
):
|
||||
super().__init__()
|
||||
self.token_embedding = torch.nn.Embedding(
|
||||
vocab_size, embed_dim, dtype=dtype, device=device
|
||||
)
|
||||
self.position_embedding = torch.nn.Embedding(
|
||||
num_positions, embed_dim, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
def forward(self, input_tokens):
|
||||
return self.token_embedding(input_tokens) + self.position_embedding.weight
|
||||
|
||||
|
||||
class CLIPTextModel_(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device):
|
||||
num_layers = config_dict["num_hidden_layers"]
|
||||
embed_dim = config_dict["hidden_size"]
|
||||
heads = config_dict["num_attention_heads"]
|
||||
intermediate_size = config_dict["intermediate_size"]
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
super().__init__()
|
||||
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
|
||||
self.encoder = CLIPEncoder(
|
||||
num_layers,
|
||||
embed_dim,
|
||||
heads,
|
||||
intermediate_size,
|
||||
intermediate_activation,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True
|
||||
):
|
||||
x = self.embeddings(input_tokens)
|
||||
causal_mask = (
|
||||
torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device)
|
||||
.fill_(float("-inf"))
|
||||
.triu_(1)
|
||||
)
|
||||
x, i = self.encoder(
|
||||
x, mask=causal_mask, intermediate_output=intermediate_output
|
||||
)
|
||||
x = self.final_layer_norm(x)
|
||||
if i is not None and final_layer_norm_intermediate:
|
||||
i = self.final_layer_norm(i)
|
||||
pooled_output = x[
|
||||
torch.arange(x.shape[0], device=x.device),
|
||||
input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),
|
||||
]
|
||||
return x, i, pooled_output
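
`CLIPTextModel_.forward` above builds an upper-triangular `-inf` causal mask and takes the pooled output from the hidden state at the position of each sequence's highest token id, which in CLIP's vocabulary is the end-of-text token. A short sketch of that index selection; the token ids are made up, with 49407 standing in for CLIP's usual end-of-text id:

# Sketch of the pooled-output selection used above: pick, per sequence,
# the hidden state at the argmax of the token ids.
import torch

hidden = torch.randn(2, 5, 8)                       # (batch, tokens, embed_dim)
tokens = torch.tensor([[49406, 320, 1125, 49407, 49407],
                       [49406, 525, 49407, 49407, 49407]])
pooled = hidden[torch.arange(hidden.shape[0]), tokens.argmax(dim=-1)]
print(pooled.shape)                                 # torch.Size([2, 8])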
|
||||
|
||||
|
||||
class CLIPTextModel(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device):
|
||||
super().__init__()
|
||||
self.num_layers = config_dict["num_hidden_layers"]
|
||||
self.text_model = CLIPTextModel_(config_dict, dtype, device)
|
||||
embed_dim = config_dict["hidden_size"]
|
||||
self.text_projection = nn.Linear(
|
||||
embed_dim, embed_dim, bias=False, dtype=dtype, device=device
|
||||
)
|
||||
self.text_projection.weight.copy_(torch.eye(embed_dim))
|
||||
self.dtype = dtype
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.text_model.embeddings.token_embedding
|
||||
|
||||
def set_input_embeddings(self, embeddings):
|
||||
self.text_model.embeddings.token_embedding = embeddings
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
x = self.text_model(*args, **kwargs)
|
||||
out = self.text_projection(x[2])
|
||||
return (x[0], x[1], out, x[2])
|
||||
|
||||
|
||||
def parse_parentheses(string):
|
||||
result = []
|
||||
current_item = ""
|
||||
nesting_level = 0
|
||||
for char in string:
|
||||
if char == "(":
|
||||
if nesting_level == 0:
|
||||
if current_item:
|
||||
result.append(current_item)
|
||||
current_item = "("
|
||||
else:
|
||||
current_item = "("
|
||||
else:
|
||||
current_item += char
|
||||
nesting_level += 1
|
||||
elif char == ")":
|
||||
nesting_level -= 1
|
||||
if nesting_level == 0:
|
||||
result.append(current_item + ")")
|
||||
current_item = ""
|
||||
else:
|
||||
current_item += char
|
||||
else:
|
||||
current_item += char
|
||||
if current_item:
|
||||
result.append(current_item)
|
||||
return result
|
||||
|
||||
|
||||
def token_weights(string, current_weight):
|
||||
a = parse_parentheses(string)
|
||||
out = []
|
||||
for x in a:
|
||||
weight = current_weight
|
||||
if len(x) >= 2 and x[-1] == ")" and x[0] == "(":
|
||||
x = x[1:-1]
|
||||
xx = x.rfind(":")
|
||||
weight *= 1.1
|
||||
if xx > 0:
|
||||
try:
|
||||
weight = float(x[xx + 1 :])
|
||||
x = x[:xx]
|
||||
except:
|
||||
pass
|
||||
out += token_weights(x, weight)
|
||||
else:
|
||||
out += [(x, current_weight)]
|
||||
return out
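
# Illustrative sketch (not part of the upstream file): how parse_parentheses and
# token_weights interpret A1111-style attention syntax. The prompt string below is
# a made-up example.
def _example_token_weights():
    parsed = token_weights("a photo of a (cat:1.3) on a ((table))", 1.0)
    # -> [("a photo of a ", 1.0), ("cat", 1.3), (" on a ", 1.0), ("table", ~1.21)]
    # unparenthesized text keeps weight 1.0, "(x:w)" sets w explicitly, and each
    # bare parenthesis level multiplies the weight by 1.1
    return parsed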
|
||||
|
||||
|
||||
def escape_important(text):
|
||||
text = text.replace("\\)", "\0\1")
|
||||
text = text.replace("\\(", "\0\2")
|
||||
return text
|
||||
|
||||
|
||||
def unescape_important(text):
|
||||
text = text.replace("\0\1", ")")
|
||||
text = text.replace("\0\2", "(")
|
||||
return text
|
||||
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(
|
||||
self,
|
||||
max_length=77,
|
||||
pad_with_end=True,
|
||||
tokenizer=None,
|
||||
has_start_token=True,
|
||||
pad_to_max_length=True,
|
||||
min_length=None,
|
||||
extra_padding_token=None,
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.min_length = min_length
|
||||
|
||||
empty = self.tokenizer("")["input_ids"]
|
||||
if has_start_token:
|
||||
self.tokens_start = 1
|
||||
self.start_token = empty[0]
|
||||
self.end_token = empty[1]
|
||||
else:
|
||||
self.tokens_start = 0
|
||||
self.start_token = None
|
||||
self.end_token = empty[0]
|
||||
self.pad_with_end = pad_with_end
|
||||
self.pad_to_max_length = pad_to_max_length
|
||||
self.extra_padding_token = extra_padding_token
|
||||
|
||||
vocab = self.tokenizer.get_vocab()
|
||||
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||
self.max_word_length = 8
|
||||
|
||||
def tokenize_with_weights(self, text: str, return_word_ids=False):
|
||||
"""
|
||||
Tokenize the text, with weight values - presume 1.0 for all and ignore other features here.
|
||||
The details aren't relevant for a reference impl, and the weights themselves have only a weak effect on SD3.
|
||||
"""
|
||||
if self.pad_with_end:
|
||||
pad_token = self.end_token
|
||||
else:
|
||||
pad_token = 0
|
||||
|
||||
text = escape_important(text)
|
||||
parsed_weights = token_weights(text, 1.0)
|
||||
|
||||
# tokenize words
|
||||
tokens = []
|
||||
for weighted_segment, weight in parsed_weights:
|
||||
to_tokenize = (
|
||||
unescape_important(weighted_segment).replace("\n", " ").split(" ")
|
||||
)
|
||||
to_tokenize = [x for x in to_tokenize if x != ""]
|
||||
for word in to_tokenize:
|
||||
# parse word
|
||||
tokens.append(
|
||||
[
|
||||
(t, weight)
|
||||
for t in self.tokenizer(word)["input_ids"][
|
||||
self.tokens_start : -1
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
# reshape token array to CLIP input size
|
||||
batched_tokens = []
|
||||
batch = []
|
||||
if self.start_token is not None:
|
||||
batch.append((self.start_token, 1.0, 0))
|
||||
batched_tokens.append(batch)
|
||||
for i, t_group in enumerate(tokens):
|
||||
# determine if we're going to try and keep the tokens in a single batch
|
||||
is_large = len(t_group) >= self.max_word_length
|
||||
|
||||
while len(t_group) > 0:
|
||||
if len(t_group) + len(batch) > self.max_length - 1:
|
||||
remaining_length = self.max_length - len(batch) - 1
|
||||
# break word in two and add end token
|
||||
if is_large:
|
||||
batch.extend(
|
||||
[(t, w, i + 1) for t, w in t_group[:remaining_length]]
|
||||
)
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
t_group = t_group[remaining_length:]
|
||||
# add end token and pad
|
||||
else:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
|
||||
# start new batch
|
||||
batch = []
|
||||
if self.start_token is not None:
|
||||
batch.append((self.start_token, 1.0, 0))
|
||||
batched_tokens.append(batch)
|
||||
else:
|
||||
batch.extend([(t, w, i + 1) for t, w in t_group])
|
||||
t_group = []
|
||||
|
||||
# pad with the extra padding token first, before adding the end token
|
||||
if self.extra_padding_token is not None:
|
||||
batch.extend(
|
||||
[(self.extra_padding_token, 1.0, 0)]
|
||||
* (self.min_length - len(batch) - 1)
|
||||
)
|
||||
# fill last batch
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
||||
if self.min_length is not None and len(batch) < self.min_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))
|
||||
|
||||
if not return_word_ids:
|
||||
batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
|
||||
|
||||
return batched_tokens
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
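
# Illustrative sketch (not part of the upstream file): tokenizing a short prompt with
# the default settings yields a single batch of 77 (token_id, weight) pairs, wrapped
# in start/end tokens and padded to max_length.
def _example_sd_tokenizer():
    clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    tok = SDTokenizer(tokenizer=clip_tok)
    batches = tok.tokenize_with_weights("a photo of a cat")
    return len(batches), len(batches[0])  # (1, 77)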
|
||||
|
||||
|
||||
class SDXLClipGTokenizer(SDTokenizer):
|
||||
def __init__(self, tokenizer):
|
||||
super().__init__(pad_with_end=False, tokenizer=tokenizer)
|
||||
|
||||
|
||||
class SD3Tokenizer:
|
||||
def __init__(self):
|
||||
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
|
||||
self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
|
||||
self.t5xxl = T5XXLTokenizer()
|
||||
|
||||
def tokenize_with_weights(self, text: str):
|
||||
out = {}
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text)
|
||||
out["g"] = self.clip_g.tokenize_with_weights(text)
|
||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text[:226])
|
||||
return out
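
# Illustrative sketch (not part of the upstream file): the combined tokenizer returns
# one token/weight list per text encoder, keyed "l", "g" and "t5xxl".
def _example_sd3_tokenizer():
    out = SD3Tokenizer().tokenize_with_weights("a photo of a cat")
    return sorted(out.keys())  # ["g", "l", "t5xxl"]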
|
||||
|
||||
|
||||
class ClipTokenWeightEncoder:
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
tokens = list(map(lambda a: a[0], token_weight_pairs[0]))
|
||||
out, pooled = self([tokens])
|
||||
if pooled is not None:
|
||||
first_pooled = pooled[0:1].cpu()
|
||||
else:
|
||||
first_pooled = pooled
|
||||
output = [out[0:1]]
|
||||
return torch.cat(output, dim=-2).cpu(), first_pooled
|
||||
|
||||
|
||||
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
|
||||
LAYERS = ["last", "pooled", "hidden"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device="cpu",
|
||||
max_length=77,
|
||||
layer="last",
|
||||
layer_idx=None,
|
||||
textmodel_json_config=None,
|
||||
dtype=None,
|
||||
model_class=CLIPTextModel,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407},
|
||||
layer_norm_hidden_state=True,
|
||||
return_projected_pooled=True,
|
||||
):
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
self.transformer = model_class(textmodel_json_config, dtype, device)
|
||||
self.num_layers = self.transformer.num_layers
|
||||
self.max_length = max_length
|
||||
self.transformer = self.transformer.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
self.layer = layer
|
||||
self.layer_idx = None
|
||||
self.special_tokens = special_tokens
|
||||
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
|
||||
self.layer_norm_hidden_state = layer_norm_hidden_state
|
||||
self.return_projected_pooled = return_projected_pooled
|
||||
if layer == "hidden":
|
||||
assert layer_idx is not None
|
||||
assert abs(layer_idx) < self.num_layers
|
||||
self.set_clip_options({"layer": layer_idx})
|
||||
self.options_default = (
|
||||
self.layer,
|
||||
self.layer_idx,
|
||||
self.return_projected_pooled,
|
||||
)
|
||||
|
||||
def set_clip_options(self, options):
|
||||
layer_idx = options.get("layer", self.layer_idx)
|
||||
self.return_projected_pooled = options.get(
|
||||
"projected_pooled", self.return_projected_pooled
|
||||
)
|
||||
if layer_idx is None or abs(layer_idx) > self.num_layers:
|
||||
self.layer = "last"
|
||||
else:
|
||||
self.layer = "hidden"
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def forward(self, tokens):
|
||||
backup_embeds = self.transformer.get_input_embeddings()
|
||||
device = backup_embeds.weight.device
|
||||
tokens = torch.LongTensor(tokens).to(device)
|
||||
outputs = self.transformer(
|
||||
tokens,
|
||||
intermediate_output=self.layer_idx,
|
||||
final_layer_norm_intermediate=self.layer_norm_hidden_state,
|
||||
)
|
||||
self.transformer.set_input_embeddings(backup_embeds)
|
||||
if self.layer == "last":
|
||||
z = outputs[0]
|
||||
else:
|
||||
z = outputs[1]
|
||||
pooled_output = None
|
||||
if len(outputs) >= 3:
|
||||
if (
|
||||
not self.return_projected_pooled
|
||||
and len(outputs) >= 4
|
||||
and outputs[3] is not None
|
||||
):
|
||||
pooled_output = outputs[3].float()
|
||||
elif outputs[2] is not None:
|
||||
pooled_output = outputs[2].float()
|
||||
return z.float(), pooled_output
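
# Illustrative sketch (not part of the upstream file): instantiating SDClipModel with a
# minimal CLIP-L style config (the same values sd3_cond.py uses). Construction happens
# under no_grad, as sd3_cond.py does, because CLIPTextModel initializes its text
# projection in place.
def _example_sdclipmodel():
    config = {
        "hidden_act": "quick_gelu",
        "hidden_size": 768,
        "intermediate_size": 3072,
        "num_attention_heads": 12,
        "num_hidden_layers": 12,
    }
    with torch.no_grad():
        model = SDClipModel(
            layer="hidden", layer_idx=-2, device="cpu", dtype=torch.float32,
            layer_norm_hidden_state=False, return_projected_pooled=False,
            textmodel_json_config=config,
        )
    return model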
|
||||
|
||||
|
||||
class SDXLClipG(SDClipModel):
|
||||
"""Wraps the CLIP-G model into the SD-CLIP-Model interface"""
|
||||
|
||||
def __init__(
|
||||
self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None
|
||||
):
|
||||
if layer == "penultimate":
|
||||
layer = "hidden"
|
||||
layer_idx = -2
|
||||
super().__init__(
|
||||
device=device,
|
||||
layer=layer,
|
||||
layer_idx=layer_idx,
|
||||
textmodel_json_config=config,
|
||||
dtype=dtype,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 0},
|
||||
layer_norm_hidden_state=False,
|
||||
)
|
||||
|
||||
|
||||
class T5XXLModel(SDClipModel):
|
||||
"""Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience"""
|
||||
|
||||
def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
super().__init__(
|
||||
device=device,
|
||||
layer=layer,
|
||||
layer_idx=layer_idx,
|
||||
textmodel_json_config=config,
|
||||
dtype=dtype,
|
||||
special_tokens={"end": 1, "pad": 0},
|
||||
model_class=T5,
|
||||
)
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
|
||||
#################################################################################################
|
||||
|
||||
|
||||
class T5XXLTokenizer(SDTokenizer):
|
||||
"""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
pad_with_end=False,
|
||||
tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"),
|
||||
has_start_token=False,
|
||||
pad_to_max_length=False,
|
||||
max_length=99999999,
|
||||
min_length=77,
|
||||
)
|
||||
|
||||
|
||||
class T5LayerNorm(torch.nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(
|
||||
torch.ones(hidden_size, dtype=dtype, device=device)
|
||||
)
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, x):
|
||||
variance = x.pow(2).mean(-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight.to(device=x.device, dtype=x.dtype) * x
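
# Illustrative sketch (not part of the upstream file): T5LayerNorm is an RMSNorm -
# no mean subtraction and no bias - so after normalization the per-position mean
# square is ~1 before the learned scale is applied.
def _example_t5_layernorm():
    norm = T5LayerNorm(hidden_size=8)
    x = torch.randn(2, 3, 8)
    return norm(x).pow(2).mean(-1)  # ~1.0 everywhere (weight is initialized to ones)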
|
||||
|
||||
|
||||
class T5DenseGatedActDense(torch.nn.Module):
|
||||
def __init__(self, model_dim, ff_dim, dtype, device):
|
||||
super().__init__()
|
||||
self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
|
||||
hidden_linear = self.wi_1(x)
|
||||
x = hidden_gelu * hidden_linear
|
||||
x = self.wo(x)
|
||||
return x
|
||||
|
||||
|
||||
class T5LayerFF(torch.nn.Module):
|
||||
def __init__(self, model_dim, ff_dim, dtype, device):
|
||||
super().__init__()
|
||||
self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device)
|
||||
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
forwarded_states = self.layer_norm(x)
|
||||
forwarded_states = self.DenseReluDense(forwarded_states)
|
||||
x += forwarded_states
|
||||
return x
|
||||
|
||||
|
||||
class T5Attention(torch.nn.Module):
|
||||
def __init__(
|
||||
self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device
|
||||
):
|
||||
super().__init__()
|
||||
# Mesh TensorFlow initialization to avoid scaling before softmax
|
||||
self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||
self.num_heads = num_heads
|
||||
self.relative_attention_bias = None
|
||||
if relative_attention_bias:
|
||||
self.relative_attention_num_buckets = 32
|
||||
self.relative_attention_max_distance = 128
|
||||
self.relative_attention_bias = torch.nn.Embedding(
|
||||
self.relative_attention_num_buckets, self.num_heads, device=device
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _relative_position_bucket(
|
||||
relative_position, bidirectional=True, num_buckets=32, max_distance=128
|
||||
):
|
||||
"""
|
||||
Adapted from Mesh Tensorflow:
|
||||
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
|
||||
|
||||
Translate relative position to a bucket number for relative attention. The relative position is defined as
|
||||
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
|
||||
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
|
||||
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
|
||||
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
|
||||
This should allow for more graceful generalization to longer sequences than the model has been trained on.
|
||||
|
||||
Args:
|
||||
relative_position: an int32 Tensor
|
||||
bidirectional: a boolean - whether the attention is bidirectional
|
||||
num_buckets: an integer
|
||||
max_distance: an integer
|
||||
|
||||
Returns:
|
||||
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
|
||||
"""
|
||||
relative_buckets = 0
|
||||
if bidirectional:
|
||||
num_buckets //= 2
|
||||
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
|
||||
relative_position = torch.abs(relative_position)
|
||||
else:
|
||||
relative_position = -torch.min(
|
||||
relative_position, torch.zeros_like(relative_position)
|
||||
)
|
||||
# now relative_position is in the range [0, inf)
|
||||
# half of the buckets are for exact increments in positions
|
||||
max_exact = num_buckets // 2
|
||||
is_small = relative_position < max_exact
|
||||
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
||||
relative_position_if_large = max_exact + (
|
||||
torch.log(relative_position.float() / max_exact)
|
||||
/ math.log(max_distance / max_exact)
|
||||
* (num_buckets - max_exact)
|
||||
).to(torch.long)
|
||||
relative_position_if_large = torch.min(
|
||||
relative_position_if_large,
|
||||
torch.full_like(relative_position_if_large, num_buckets - 1),
|
||||
)
|
||||
relative_buckets += torch.where(
|
||||
is_small, relative_position, relative_position_if_large
|
||||
)
|
||||
return relative_buckets
|
||||
|
||||
def compute_bias(self, query_length, key_length, device):
|
||||
"""Compute binned relative position bias"""
|
||||
context_position = torch.arange(query_length, dtype=torch.long, device=device)[
|
||||
:, None
|
||||
]
|
||||
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[
|
||||
None, :
|
||||
]
|
||||
relative_position = (
|
||||
memory_position - context_position
|
||||
) # shape (query_length, key_length)
|
||||
relative_position_bucket = self._relative_position_bucket(
|
||||
relative_position, # shape (query_length, key_length)
|
||||
bidirectional=True,
|
||||
num_buckets=self.relative_attention_num_buckets,
|
||||
max_distance=self.relative_attention_max_distance,
|
||||
)
|
||||
values = self.relative_attention_bias(
|
||||
relative_position_bucket
|
||||
) # shape (query_length, key_length, num_heads)
|
||||
values = values.permute([2, 0, 1]).unsqueeze(
|
||||
0
|
||||
) # shape (1, num_heads, query_length, key_length)
|
||||
return values
|
||||
|
||||
def forward(self, x, past_bias=None):
|
||||
q = self.q(x)
|
||||
k = self.k(x)
|
||||
v = self.v(x)
|
||||
if self.relative_attention_bias is not None:
|
||||
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
|
||||
if past_bias is not None:
|
||||
mask = past_bias
|
||||
out = attention(
|
||||
q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask
|
||||
)
|
||||
return self.o(out), past_bias
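
# Illustrative sketch (not part of the upstream file): the static bucketing helper maps
# signed relative positions to bucket indices. Small offsets get their own buckets,
# larger offsets are binned logarithmically up to max_distance.
def _example_relative_buckets():
    rel = torch.arange(-4, 5)[None, :]  # relative positions -4..4
    return T5Attention._relative_position_bucket(
        rel, bidirectional=True, num_buckets=32, max_distance=128
    )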
|
||||
|
||||
|
||||
class T5LayerSelfAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_dim,
|
||||
inner_dim,
|
||||
ff_dim,
|
||||
num_heads,
|
||||
relative_attention_bias,
|
||||
dtype,
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
self.SelfAttention = T5Attention(
|
||||
model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device
|
||||
)
|
||||
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, past_bias=None):
|
||||
output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias)
|
||||
x += output
|
||||
return x, past_bias
|
||||
|
||||
|
||||
class T5Block(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_dim,
|
||||
inner_dim,
|
||||
ff_dim,
|
||||
num_heads,
|
||||
relative_attention_bias,
|
||||
dtype,
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList()
|
||||
self.layer.append(
|
||||
T5LayerSelfAttention(
|
||||
model_dim,
|
||||
inner_dim,
|
||||
ff_dim,
|
||||
num_heads,
|
||||
relative_attention_bias,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
)
|
||||
self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device))
|
||||
|
||||
def forward(self, x, past_bias=None):
|
||||
x, past_bias = self.layer[0](x, past_bias)
|
||||
x = self.layer[-1](x)
|
||||
return x, past_bias
|
||||
|
||||
|
||||
class T5Stack(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers,
|
||||
model_dim,
|
||||
inner_dim,
|
||||
ff_dim,
|
||||
num_heads,
|
||||
vocab_size,
|
||||
dtype,
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
|
||||
self.block = torch.nn.ModuleList(
|
||||
[
|
||||
T5Block(
|
||||
model_dim,
|
||||
inner_dim,
|
||||
ff_dim,
|
||||
num_heads,
|
||||
relative_attention_bias=(i == 0),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True
|
||||
):
|
||||
intermediate = None
|
||||
x = self.embed_tokens(input_ids)
|
||||
past_bias = None
|
||||
for i, l in enumerate(self.block):
|
||||
x, past_bias = l(x, past_bias)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
x = self.final_layer_norm(x)
|
||||
if intermediate is not None and final_layer_norm_intermediate:
|
||||
intermediate = self.final_layer_norm(intermediate)
|
||||
return x, intermediate
|
||||
|
||||
|
||||
class T5(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device):
|
||||
super().__init__()
|
||||
self.num_layers = config_dict["num_layers"]
|
||||
self.encoder = T5Stack(
|
||||
self.num_layers,
|
||||
config_dict["d_model"],
|
||||
config_dict["d_model"],
|
||||
config_dict["d_ff"],
|
||||
config_dict["num_heads"],
|
||||
config_dict["vocab_size"],
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
self.dtype = dtype
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.encoder.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, embeddings):
|
||||
self.encoder.embed_tokens = embeddings
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.encoder(*args, **kwargs)
|
||||
222  modules/models/sd35/sd3_cond.py  Normal file
@@ -0,0 +1,222 @@
|
||||
import os
|
||||
import safetensors
|
||||
import torch
|
||||
import typing
|
||||
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
|
||||
from modules import shared, devices, modelloader, sd_hijack_clip, prompt_parser
|
||||
from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer
|
||||
|
||||
|
||||
class SafetensorsMapping(typing.Mapping):
|
||||
def __init__(self, file):
|
||||
self.file = file
|
||||
|
||||
def __len__(self):
|
||||
return len(self.file.keys())
|
||||
|
||||
def __iter__(self):
|
||||
for key in self.file.keys():
|
||||
yield key
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.file.get_tensor(key)
|
||||
|
||||
|
||||
CLIPL_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors"
|
||||
CLIPL_CONFIG = {
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 768,
|
||||
"intermediate_size": 3072,
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12,
|
||||
}
|
||||
|
||||
CLIPG_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors"
|
||||
CLIPG_CONFIG = {
|
||||
"hidden_act": "gelu",
|
||||
"hidden_size": 1280,
|
||||
"intermediate_size": 5120,
|
||||
"num_attention_heads": 20,
|
||||
"num_hidden_layers": 32,
|
||||
"textual_inversion_key": "clip_g",
|
||||
}
|
||||
|
||||
T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
|
||||
T5_CONFIG = {
|
||||
"d_ff": 10240,
|
||||
"d_model": 4096,
|
||||
"num_heads": 64,
|
||||
"num_layers": 24,
|
||||
"vocab_size": 32128,
|
||||
}
|
||||
|
||||
|
||||
class Sd3ClipLG(sd_hijack_clip.TextConditionalModel):
|
||||
def __init__(self, clip_l, clip_g):
|
||||
super().__init__()
|
||||
|
||||
self.clip_l = clip_l
|
||||
self.clip_g = clip_g
|
||||
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
|
||||
empty = self.tokenizer('')["input_ids"]
|
||||
self.id_start = empty[0]
|
||||
self.id_end = empty[1]
|
||||
self.id_pad = empty[1]
|
||||
|
||||
self.return_pooled = True
|
||||
|
||||
def tokenize(self, texts):
|
||||
return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
tokens_g = tokens.clone()
|
||||
|
||||
for batch_pos in range(tokens_g.shape[0]):
|
||||
index = tokens_g[batch_pos].cpu().tolist().index(self.id_end)
|
||||
tokens_g[batch_pos, index+1:tokens_g.shape[1]] = 0
|
||||
|
||||
l_out, l_pooled = self.clip_l(tokens)
|
||||
g_out, g_pooled = self.clip_g(tokens_g)
|
||||
|
||||
lg_out = torch.cat([l_out, g_out], dim=-1)
|
||||
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
||||
|
||||
vector_out = torch.cat((l_pooled, g_pooled), dim=-1)
|
||||
|
||||
lg_out.pooled = vector_out
|
||||
return lg_out
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
return torch.zeros((nvpt, 768+1280), device=devices.device) # XXX
|
||||
|
||||
|
||||
class Sd3T5(torch.nn.Module):
|
||||
def __init__(self, t5xxl):
|
||||
super().__init__()
|
||||
|
||||
self.t5xxl = t5xxl
|
||||
self.tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")
|
||||
|
||||
empty = self.tokenizer('', padding='max_length', max_length=2)["input_ids"]
|
||||
self.id_end = empty[0]
|
||||
self.id_pad = empty[1]
|
||||
|
||||
def tokenize(self, texts):
|
||||
return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
|
||||
def tokenize_line(self, line, *, target_token_count=None):
|
||||
if shared.opts.emphasis != "None":
|
||||
parsed = prompt_parser.parse_prompt_attention(line)
|
||||
else:
|
||||
parsed = [[line, 1.0]]
|
||||
|
||||
tokenized = self.tokenize([text for text, _ in parsed])
|
||||
|
||||
tokens = []
|
||||
multipliers = []
|
||||
|
||||
for text_tokens, (text, weight) in zip(tokenized, parsed):
|
||||
if text == 'BREAK' and weight == -1:
|
||||
continue
|
||||
|
||||
tokens += text_tokens
|
||||
multipliers += [weight] * len(text_tokens)
|
||||
|
||||
tokens += [self.id_end]
|
||||
multipliers += [1.0]
|
||||
|
||||
if target_token_count is not None:
|
||||
if len(tokens) < target_token_count:
    padding = target_token_count - len(tokens)
    tokens += [self.id_pad] * padding
    multipliers += [1.0] * padding  # pad multipliers by the same amount as tokens
|
||||
else:
|
||||
tokens = tokens[0:target_token_count]
|
||||
multipliers = multipliers[0:target_token_count]
|
||||
|
||||
return tokens, multipliers
|
||||
|
||||
def forward(self, texts, *, token_count):
|
||||
if not self.t5xxl or not shared.opts.sd3_enable_t5:
|
||||
return torch.zeros((len(texts), token_count, 4096), device=devices.device, dtype=devices.dtype)
|
||||
|
||||
tokens_batch = []
|
||||
|
||||
for text in texts:
|
||||
tokens, multipliers = self.tokenize_line(text, target_token_count=token_count)
|
||||
tokens_batch.append(tokens)
|
||||
|
||||
t5_out, t5_pooled = self.t5xxl(tokens_batch)
|
||||
|
||||
return t5_out
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
return torch.zeros((nvpt, 4096), device=devices.device) # XXX
|
||||
|
||||
|
||||
class SD3Cond(torch.nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.tokenizer = SD3Tokenizer()
|
||||
|
||||
with torch.no_grad():
|
||||
self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
|
||||
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
|
||||
|
||||
if shared.opts.sd3_enable_t5:
|
||||
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
|
||||
else:
|
||||
self.t5xxl = None
|
||||
|
||||
self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)
|
||||
self.model_t5 = Sd3T5(self.t5xxl)
|
||||
|
||||
def forward(self, prompts: list[str]):
|
||||
with devices.without_autocast():
|
||||
lg_out, vector_out = self.model_lg(prompts)
|
||||
t5_out = self.model_t5(prompts, token_count=lg_out.shape[1])
|
||||
lgt_out = torch.cat([lg_out, t5_out], dim=-2)
|
||||
|
||||
return {
|
||||
'crossattn': lgt_out,
|
||||
'vector': vector_out,
|
||||
}
|
||||
|
||||
def before_load_weights(self, state_dict):
|
||||
clip_path = os.path.join(shared.models_path, "CLIP")
|
||||
|
||||
if 'text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
|
||||
clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
|
||||
with safetensors.safe_open(clip_g_file, framework="pt") as file:
|
||||
self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
|
||||
|
||||
if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
|
||||
clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
|
||||
with safetensors.safe_open(clip_l_file, framework="pt") as file:
|
||||
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||
|
||||
if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
|
||||
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
|
||||
with safetensors.safe_open(t5_file, framework="pt") as file:
|
||||
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
return self.model_lg.encode_embedding_init_text(init_text, nvpt)
|
||||
|
||||
def tokenize(self, texts):
|
||||
return self.model_lg.tokenize(texts)
|
||||
|
||||
def medvram_modules(self):
|
||||
return [self.clip_g, self.clip_l, self.t5xxl]
|
||||
|
||||
def get_token_count(self, text):
|
||||
_, token_count = self.model_lg.process_texts([text])
|
||||
|
||||
return token_count
|
||||
|
||||
def get_target_prompt_token_count(self, token_count):
|
||||
return self.model_lg.get_target_prompt_token_count(token_count)
|
||||
623  modules/models/sd35/sd3_impls.py  Normal file
@@ -0,0 +1,623 @@
|
||||
### Impls of the SD3 core diffusion model and VAE
|
||||
|
||||
import math
|
||||
import re
|
||||
|
||||
import einops
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from modules.models.sd35.mmditx import MMDiTX
|
||||
|
||||
#################################################################################################
|
||||
### MMDiT Model Wrapping
|
||||
#################################################################################################
|
||||
|
||||
|
||||
class ModelSamplingDiscreteFlow(torch.nn.Module):
|
||||
"""Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models"""
|
||||
|
||||
def __init__(self, shift=1.0):
|
||||
super().__init__()
|
||||
self.shift = shift
|
||||
timesteps = 1000
|
||||
ts = self.sigma(torch.arange(1, timesteps + 1, 1))
|
||||
self.register_buffer("sigmas", ts)
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
||||
@property
|
||||
def sigma_max(self):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def timestep(self, sigma):
|
||||
return sigma * 1000
|
||||
|
||||
def sigma(self, timestep: torch.Tensor):
|
||||
timestep = timestep / 1000.0
|
||||
if self.shift == 1.0:
|
||||
return timestep
|
||||
return self.shift * timestep / (1 + (self.shift - 1) * timestep)
|
||||
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input - model_output * sigma
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
return sigma * noise + (1.0 - sigma) * latent_image
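
# Illustrative sketch (not part of the upstream file): with shift=3.0 (the value used by
# the released SD3 checkpoints, see SHIFT in sd3_infer.py), sigma is pushed up relative
# to the plain t/1000 mapping: sigma = shift*t' / (1 + (shift - 1)*t') with t' = t/1000.
def _example_discrete_flow_shift():
    sampling = ModelSamplingDiscreteFlow(shift=3.0)
    t = torch.tensor([250.0, 500.0, 750.0, 1000.0])
    return sampling.sigma(t)  # ~[0.50, 0.75, 0.90, 1.00]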
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
"""Wrapper around the core MM-DiT model"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state_dict,
|
||||
shift=1.0,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
prefix = ''
|
||||
# Important configuration values can be quickly determined by checking shapes in the source file
|
||||
# Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
|
||||
patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
|
||||
depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
|
||||
num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
|
||||
pos_embed_max_size = round(math.sqrt(num_patches))
|
||||
adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
|
||||
context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
|
||||
qk_norm = (
|
||||
"rms"
|
||||
if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in state_dict.keys()
|
||||
else None
|
||||
)
|
||||
x_block_self_attn_layers = sorted(
|
||||
[
|
||||
int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1])
|
||||
for key in list(
|
||||
filter(
|
||||
re.compile(".*.x_block.attn2.ln_k.weight").match, state_dict.keys()
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
context_embedder_config = {
|
||||
"target": "torch.nn.Linear",
|
||||
"params": {
|
||||
"in_features": context_shape[1],
|
||||
"out_features": context_shape[0],
|
||||
},
|
||||
}
|
||||
self.diffusion_model = MMDiTX(
|
||||
input_size=None,
|
||||
pos_embed_scaling_factor=None,
|
||||
pos_embed_offset=None,
|
||||
pos_embed_max_size=pos_embed_max_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=16,
|
||||
depth=depth,
|
||||
num_patches=num_patches,
|
||||
adm_in_channels=adm_in_channels,
|
||||
context_embedder_config=context_embedder_config,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn_layers=x_block_self_attn_layers,
|
||||
# device=kwargs['device'],
|
||||
# dtype=kwargs['dtype'],
|
||||
# verbose=kwargs['verbose'],
|
||||
# **kwargs
|
||||
)
|
||||
self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
|
||||
|
||||
def apply_model(self, x, sigma, y=None, *args, **kwargs):
|
||||
dtype = self.get_dtype()
|
||||
timestep = self.model_sampling.timestep(sigma).float()
|
||||
model_output = self.diffusion_model(
|
||||
x.to(dtype), timestep, context=kwargs["context"].to(dtype), y=y.to(dtype)
|
||||
).float()
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.apply_model(*args, **kwargs)
|
||||
|
||||
def get_dtype(self):
|
||||
return self.diffusion_model.dtype
|
||||
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
"""Helper for applying CFG Scaling to diffusion outputs"""
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, x, timestep, cond, uncond, cond_scale):
|
||||
# Run cond and uncond in a batch together
|
||||
batched = self.model.apply_model(
|
||||
torch.cat([x, x]),
|
||||
torch.cat([timestep, timestep]),
|
||||
c_crossattn=torch.cat([cond["c_crossattn"], uncond["c_crossattn"]]),
|
||||
y=torch.cat([cond["y"], uncond["y"]]),
|
||||
)
|
||||
# Then split and apply CFG Scaling
|
||||
pos_out, neg_out = batched.chunk(2)
|
||||
scaled = neg_out + (pos_out - neg_out) * cond_scale
|
||||
return scaled
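
# Illustrative sketch (not part of the upstream file): the CFG combination used above,
# written out for a single pair of predictions. cond_scale=1.0 returns the conditional
# prediction unchanged; larger values extrapolate away from the unconditional one.
def _example_cfg_combination(pos_out: torch.Tensor, neg_out: torch.Tensor, cond_scale: float = 4.5):
    return neg_out + (pos_out - neg_out) * cond_scale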
|
||||
|
||||
|
||||
class SD3LatentFormat:
|
||||
"""Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.5305
|
||||
self.shift_factor = 0.0609
|
||||
|
||||
def process_in(self, latent):
|
||||
return (latent - self.shift_factor) * self.scale_factor
|
||||
|
||||
def process_out(self, latent):
|
||||
return (latent / self.scale_factor) + self.shift_factor
|
||||
|
||||
def decode_latent_to_preview(self, x0):
|
||||
"""Quick RGB approximate preview of sd3 latents"""
|
||||
factors = torch.tensor(
|
||||
[
|
||||
[-0.0645, 0.0177, 0.1052],
|
||||
[0.0028, 0.0312, 0.0650],
|
||||
[0.1848, 0.0762, 0.0360],
|
||||
[0.0944, 0.0360, 0.0889],
|
||||
[0.0897, 0.0506, -0.0364],
|
||||
[-0.0020, 0.1203, 0.0284],
|
||||
[0.0855, 0.0118, 0.0283],
|
||||
[-0.0539, 0.0658, 0.1047],
|
||||
[-0.0057, 0.0116, 0.0700],
|
||||
[-0.0412, 0.0281, -0.0039],
|
||||
[0.1106, 0.1171, 0.1220],
|
||||
[-0.0248, 0.0682, -0.0481],
|
||||
[0.0815, 0.0846, 0.1207],
|
||||
[-0.0120, -0.0055, -0.0867],
|
||||
[-0.0749, -0.0634, -0.0456],
|
||||
[-0.1418, -0.1457, -0.1259],
|
||||
],
|
||||
device="cpu",
|
||||
)
|
||||
latent_image = x0[0].permute(1, 2, 0).cpu() @ factors
|
||||
|
||||
latents_ubyte = (
|
||||
((latent_image + 1) / 2)
|
||||
.clamp(0, 1) # change scale from -1..1 to 0..1
|
||||
.mul(0xFF) # to 0..255
|
||||
.byte()
|
||||
).cpu()
|
||||
|
||||
return Image.fromarray(latents_ubyte.numpy())
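
# Illustrative sketch (not part of the upstream file): process_in and process_out are
# exact inverses, so a latent survives the scale/shift round trip unchanged.
def _example_latent_format_roundtrip():
    fmt = SD3LatentFormat()
    latent = torch.randn(1, 16, 128, 128)
    restored = fmt.process_out(fmt.process_in(latent))
    return torch.allclose(restored, latent, atol=1e-6)  # True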
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### Samplers
|
||||
#################################################################################################
|
||||
|
||||
|
||||
def append_dims(x, target_dims):
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
|
||||
def to_d(x, sigma, denoised):
|
||||
"""Converts a denoiser output to a Karras ODE derivative."""
|
||||
return (x - denoised) / append_dims(sigma, x.ndim)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.autocast("cuda", dtype=torch.float16)
|
||||
def sample_euler(model, x, sigmas, extra_args=None):
|
||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in tqdm(range(len(sigmas) - 1)):
|
||||
sigma_hat = sigmas[i]
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.autocast("cuda", dtype=torch.float16)
|
||||
def sample_dpmpp_2m(model, x, sigmas, extra_args=None):
|
||||
"""DPM-Solver++(2M)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
old_denoised = None
|
||||
for i in tqdm(range(len(sigmas) - 1)):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
||||
h = t_next - t
|
||||
if old_denoised is None or sigmas[i + 1] == 0:
|
||||
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
|
||||
else:
|
||||
h_last = t - t_fn(sigmas[i - 1])
|
||||
r = h_last / h
|
||||
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
|
||||
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
|
||||
old_denoised = denoised
|
||||
return x
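
# Illustrative sketch (not part of the upstream file): the samplers above only need a
# callable denoiser(x, sigma, **extra_args) -> denoised, so they can be exercised with a
# toy model. (On a CUDA-less machine the autocast("cuda") decorator is expected to be
# disabled with a warning.)
def _example_sampler_toy():
    def toy_denoiser(x, sigma, **extra_args):
        return torch.zeros_like(x)  # pretends every input denoises to zero

    x = torch.randn(1, 16, 8, 8)
    sigmas = torch.tensor([1.0, 0.5, 0.0])
    return sample_euler(toy_denoiser, x, sigmas)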
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### VAE
|
||||
#################################################################################################
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):
|
||||
return torch.nn.GroupNorm(
|
||||
num_groups=num_groups,
|
||||
num_channels=in_channels,
|
||||
eps=1e-6,
|
||||
affine=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
class ResnetBlock(torch.nn.Module):
|
||||
def __init__(
|
||||
self, *, in_channels, out_channels=None, dtype=torch.float32, device=None
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.norm1 = Normalize(in_channels, dtype=dtype, device=device)
|
||||
self.conv1 = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.norm2 = Normalize(out_channels, dtype=dtype, device=device)
|
||||
self.conv2 = torch.nn.Conv2d(
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
if self.in_channels != self.out_channels:
|
||||
self.nin_shortcut = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = None
|
||||
self.swish = torch.nn.SiLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
hidden = x
|
||||
hidden = self.norm1(hidden)
|
||||
hidden = self.swish(hidden)
|
||||
hidden = self.conv1(hidden)
|
||||
hidden = self.norm2(hidden)
|
||||
hidden = self.swish(hidden)
|
||||
hidden = self.conv2(hidden)
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.nin_shortcut(x)
|
||||
return x + hidden
|
||||
|
||||
|
||||
class AttnBlock(torch.nn.Module):
|
||||
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.norm = Normalize(in_channels, dtype=dtype, device=device)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
hidden = self.norm(x)
|
||||
q = self.q(hidden)
|
||||
k = self.k(hidden)
|
||||
v = self.v(hidden)
|
||||
b, c, h, w = q.shape
|
||||
q, k, v = map(
|
||||
lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
hidden = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v
|
||||
) # scale is dim ** -0.5 per default
|
||||
hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
||||
hidden = self.proj_out(hidden)
|
||||
return x + hidden
|
||||
|
||||
|
||||
class Downsample(torch.nn.Module):
|
||||
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(torch.nn.Module):
|
||||
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class VAEEncoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ch=128,
|
||||
ch_mult=(1, 2, 4, 4),
|
||||
num_res_blocks=2,
|
||||
in_channels=3,
|
||||
z_channels=16,
|
||||
dtype=torch.float32,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = torch.nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = torch.nn.ModuleList()
|
||||
attn = torch.nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
down = torch.nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, dtype=dtype, device=device)
|
||||
self.down.append(down)
|
||||
# middle
|
||||
self.mid = torch.nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, dtype=dtype, device=device
|
||||
)
|
||||
self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, dtype=dtype, device=device
|
||||
)
|
||||
# end
|
||||
self.norm_out = Normalize(block_in, dtype=dtype, device=device)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in,
|
||||
2 * z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.swish = torch.nn.SiLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1])
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = self.swish(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class VAEDecoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ch=128,
|
||||
out_ch=3,
|
||||
ch_mult=(1, 2, 4, 4),
|
||||
num_res_blocks=2,
|
||||
resolution=256,
|
||||
z_channels=16,
|
||||
dtype=torch.float32,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
z_channels,
|
||||
block_in,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
# middle
|
||||
self.mid = torch.nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, dtype=dtype, device=device
|
||||
)
|
||||
self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, dtype=dtype, device=device
|
||||
)
|
||||
# upsampling
|
||||
self.up = torch.nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = torch.nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
up = torch.nn.Module()
|
||||
up.block = block
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, dtype=dtype, device=device)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
# end
|
||||
self.norm_out = Normalize(block_in, dtype=dtype, device=device)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in,
|
||||
out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.swish = torch.nn.SiLU(inplace=True)
|
||||
|
||||
def forward(self, z):
|
||||
# z to block_in
|
||||
hidden = self.conv_in(z)
|
||||
# middle
|
||||
hidden = self.mid.block_1(hidden)
|
||||
hidden = self.mid.attn_1(hidden)
|
||||
hidden = self.mid.block_2(hidden)
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
hidden = self.up[i_level].block[i_block](hidden)
|
||||
if i_level != 0:
|
||||
hidden = self.up[i_level].upsample(hidden)
|
||||
# end
|
||||
hidden = self.norm_out(hidden)
|
||||
hidden = self.swish(hidden)
|
||||
hidden = self.conv_out(hidden)
|
||||
return hidden
|
||||
|
||||
|
||||
class SDVAE(torch.nn.Module):
|
||||
def __init__(self, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.encoder = VAEEncoder(dtype=dtype, device=device)
|
||||
self.decoder = VAEDecoder(dtype=dtype, device=device)
|
||||
|
||||
@torch.autocast("cuda", dtype=torch.float16)
|
||||
def decode(self, latent):
|
||||
return self.decoder(latent)
|
||||
|
||||
@torch.autocast("cuda", dtype=torch.float16)
|
||||
def encode(self, image):
|
||||
hidden = self.encoder(image)
|
||||
mean, logvar = torch.chunk(hidden, 2, dim=1)
|
||||
logvar = torch.clamp(logvar, -30.0, 20.0)
|
||||
std = torch.exp(0.5 * logvar)
|
||||
return mean + std * torch.randn_like(mean)
|
||||
436  modules/models/sd35/sd3_infer.py  Normal file
@@ -0,0 +1,436 @@
|
||||
# NOTE: Must have folder `models` with the following files:
|
||||
# - `clip_g.safetensors` (openclip bigG, same as SDXL)
|
||||
# - `clip_l.safetensors` (OpenAI CLIP-L, same as SDXL)
|
||||
# - `t5xxl.safetensors` (google T5-v1.1-XXL)
|
||||
# - `sd3_medium.safetensors` (or whichever main MMDiT model file)
|
||||
# Also can have
|
||||
# - `sd3_vae.safetensors` (holds the VAE separately if needed)
|
||||
|
||||
import datetime
|
||||
import math
|
||||
import os
|
||||
|
||||
import fire
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from safetensors import safe_open
|
||||
from tqdm import tqdm
|
||||
|
||||
from modules.models.sd35 import sd3_impls
|
||||
from modules.models.sd35.other_impls import SD3Tokenizer, SDClipModel, SDXLClipG, T5XXLModel
|
||||
from modules.models.sd35.sd3_impls import SDVAE, BaseModel, CFGDenoiser, SD3LatentFormat
|
||||
|
||||
#################################################################################################
|
||||
### Wrappers for model parts
|
||||
#################################################################################################
|
||||
|
||||
|
||||
def load_into(f, model, prefix, device, dtype=None):
|
||||
"""Just a debugging-friendly hack to apply the weights in a safetensors file to the pytorch module."""
|
||||
for key in f.keys():
|
||||
if key.startswith(prefix) and not key.startswith("loss."):
|
||||
path = key[len(prefix) :].split(".")
|
||||
obj = model
|
||||
for p in path:
|
||||
if isinstance(obj, list):
|
||||
obj = obj[int(p)]
|
||||
else:
|
||||
obj = getattr(obj, p, None)
|
||||
if obj is None:
|
||||
print(
|
||||
f"Skipping key '{key}' in safetensors file as '{p}' does not exist in python model"
|
||||
)
|
||||
break
|
||||
if obj is None:
|
||||
continue
|
||||
try:
|
||||
tensor = f.get_tensor(key).to(device=device)
|
||||
if dtype is not None:
|
||||
tensor = tensor.to(dtype=dtype)
|
||||
obj.requires_grad_(False)
|
||||
obj.set_(tensor)
|
||||
except Exception as e:
|
||||
print(f"Failed to load key '{key}' in safetensors file: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
CLIPG_CONFIG = {
|
||||
"hidden_act": "gelu",
|
||||
"hidden_size": 1280,
|
||||
"intermediate_size": 5120,
|
||||
"num_attention_heads": 20,
|
||||
"num_hidden_layers": 32,
|
||||
}
|
||||
|
||||
|
||||
class ClipG:
|
||||
def __init__(self):
|
||||
with safe_open("models/clip_g.safetensors", framework="pt", device="cpu") as f:
|
||||
self.model = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=torch.float32)
|
||||
load_into(f, self.model.transformer, "", "cpu", torch.float32)
|
||||
|
||||
|
||||
CLIPL_CONFIG = {
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 768,
|
||||
"intermediate_size": 3072,
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12,
|
||||
}
|
||||
|
||||
|
||||
class ClipL:
|
||||
def __init__(self):
|
||||
with safe_open("models/clip_l.safetensors", framework="pt", device="cpu") as f:
|
||||
self.model = SDClipModel(
|
||||
layer="hidden",
|
||||
layer_idx=-2,
|
||||
device="cpu",
|
||||
dtype=torch.float32,
|
||||
layer_norm_hidden_state=False,
|
||||
return_projected_pooled=False,
|
||||
textmodel_json_config=CLIPL_CONFIG,
|
||||
)
|
||||
load_into(f, self.model.transformer, "", "cpu", torch.float32)
|
||||
|
||||
|
||||
T5_CONFIG = {
|
||||
"d_ff": 10240,
|
||||
"d_model": 4096,
|
||||
"num_heads": 64,
|
||||
"num_layers": 24,
|
||||
"vocab_size": 32128,
|
||||
}
|
||||
|
||||
|
||||
class T5XXL:
|
||||
def __init__(self):
|
||||
with safe_open("models/t5xxl.safetensors", framework="pt", device="cpu") as f:
|
||||
self.model = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float32)
|
||||
load_into(f, self.model.transformer, "", "cpu", torch.float32)
|
||||
|
||||
|
||||
class SD3:
|
||||
def __init__(self, model, shift, verbose=False):
|
||||
with safe_open(model, framework="pt", device="cpu") as f:
|
||||
self.model = BaseModel(
|
||||
shift=shift,
|
||||
file=f,
|
||||
prefix="model.diffusion_model.",
|
||||
device="cpu",
|
||||
dtype=torch.float16,
|
||||
verbose=verbose,
|
||||
).eval()
|
||||
load_into(f, self.model, "model.", "cpu", torch.float16)
|
||||
|
||||
|
||||
class VAE:
|
||||
def __init__(self, model):
|
||||
with safe_open(model, framework="pt", device="cpu") as f:
|
||||
self.model = SDVAE(device="cpu", dtype=torch.float16).eval().cpu()
|
||||
prefix = ""
|
||||
if any(k.startswith("first_stage_model.") for k in f.keys()):
|
||||
prefix = "first_stage_model."
|
||||
load_into(f, self.model, prefix, "cpu", torch.float16)
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### Main inference logic
|
||||
#################################################################################################
|
||||
|
||||
|
||||
# Note: Sigma shift value, publicly released models use 3.0
|
||||
SHIFT = 3.0
|
||||
# Naturally, adjust to the width/height of the model you have
|
||||
WIDTH = 1024
|
||||
HEIGHT = 1024
|
||||
# Pick your prompt
|
||||
PROMPT = "a photo of a cat"
|
||||
# Most models prefer the range of 4-5, but still work well around 7
|
||||
CFG_SCALE = 4.5
|
||||
# Different models want different step counts, but most are good at 50, though that's slow to run
|
||||
# sd3_medium is quite decent at 28 steps
|
||||
STEPS = 40
|
||||
# Seed
|
||||
SEED = 23
|
||||
# SEEDTYPE = "fixed"
|
||||
SEEDTYPE = "rand"
|
||||
# SEEDTYPE = "roll"
|
||||
# Actual model file path
|
||||
# MODEL = "models/sd3_medium.safetensors"
|
||||
# MODEL = "models/sd3.5_large_turbo.safetensors"
|
||||
MODEL = "models/sd3.5_large.safetensors"
|
||||
# VAE model file path, or set None to use the same model file
|
||||
VAEFile = None # "models/sd3_vae.safetensors"
|
||||
# Optional init image file path
|
||||
INIT_IMAGE = None
|
||||
# If init_image is given, this is the fraction of denoising steps to run (1.0 = full denoise, 0.0 = no denoise at all)
|
||||
DENOISE = 0.6
|
||||
# Output file path
|
||||
OUTDIR = "outputs"
|
||||
# SAMPLER
|
||||
# SAMPLER = "euler"
|
||||
SAMPLER = "dpmpp_2m"
|
||||
|
||||
|
||||
class SD3Inferencer:
|
||||
def print(self, txt):
|
||||
if self.verbose:
|
||||
print(txt)
|
||||
|
||||
def load(self, model=MODEL, vae=VAEFile, shift=SHIFT, verbose=False):
|
||||
self.verbose = verbose
|
||||
print("Loading tokenizers...")
|
||||
# NOTE: if you need a reference impl for a high performance CLIP tokenizer instead of just using the HF transformers one,
|
||||
# check https://github.com/Stability-AI/StableSwarmUI/blob/master/src/Utils/CliplikeTokenizer.cs
|
||||
# (T5 tokenizer is different though)
|
||||
self.tokenizer = SD3Tokenizer()
|
||||
print("Loading OpenAI CLIP L...")
|
||||
self.clip_l = ClipL()
|
||||
print("Loading OpenCLIP bigG...")
|
||||
self.clip_g = ClipG()
|
||||
print("Loading Google T5-v1-XXL...")
|
||||
self.t5xxl = T5XXL()
|
||||
print(f"Loading SD3 model {os.path.basename(model)}...")
|
||||
self.sd3 = SD3(model, shift, verbose)
|
||||
print("Loading VAE model...")
|
||||
self.vae = VAE(vae or model)
|
||||
print("Models loaded.")
|
||||
|
||||
def get_empty_latent(self, width, height):
|
||||
self.print("Prep an empty latent...")
|
||||
return torch.ones(1, 16, height // 8, width // 8, device="cpu") * 0.0609
|
||||
|
||||
def get_sigmas(self, sampling, steps):
|
||||
start = sampling.timestep(sampling.sigma_max)
|
||||
end = sampling.timestep(sampling.sigma_min)
|
||||
timesteps = torch.linspace(start, end, steps)
|
||||
sigs = []
|
||||
for x in range(len(timesteps)):
|
||||
ts = timesteps[x]
|
||||
sigs.append(sampling.sigma(ts))
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def get_noise(self, seed, latent):
|
||||
generator = torch.manual_seed(seed)
|
||||
self.print(
|
||||
f"dtype = {latent.dtype}, layout = {latent.layout}, device = {latent.device}"
|
||||
)
|
||||
return torch.randn(
|
||||
latent.size(),
|
||||
dtype=torch.float32,
|
||||
layout=latent.layout,
|
||||
generator=generator,
|
||||
device="cpu",
|
||||
).to(latent.dtype)
|
||||
|
||||
def get_cond(self, prompt):
|
||||
self.print("Encode prompt...")
|
||||
tokens = self.tokenizer.tokenize_with_weights(prompt)
|
||||
l_out, l_pooled = self.clip_l.model.encode_token_weights(tokens["l"])
|
||||
g_out, g_pooled = self.clip_g.model.encode_token_weights(tokens["g"])
|
||||
t5_out, t5_pooled = self.t5xxl.model.encode_token_weights(tokens["t5xxl"])
|
||||
lg_out = torch.cat([l_out, g_out], dim=-1)
|
||||
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
||||
return torch.cat([lg_out, t5_out], dim=-2), torch.cat(
|
||||
(l_pooled, g_pooled), dim=-1
|
||||
)
|
||||
|
||||
def max_denoise(self, sigmas):
|
||||
max_sigma = float(self.sd3.model.model_sampling.sigma_max)
|
||||
sigma = float(sigmas[0])
|
||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||
|
||||
def fix_cond(self, cond):
|
||||
cond, pooled = (cond[0].half().cuda(), cond[1].half().cuda())
|
||||
return {"c_crossattn": cond, "y": pooled}
|
||||
|
||||
def do_sampling(
|
||||
self,
|
||||
latent,
|
||||
seed,
|
||||
conditioning,
|
||||
neg_cond,
|
||||
steps,
|
||||
cfg_scale,
|
||||
sampler="dpmpp_2m",
|
||||
denoise=1.0,
|
||||
) -> torch.Tensor:
|
||||
self.print("Sampling...")
|
||||
latent = latent.half().cuda()
|
||||
self.sd3.model = self.sd3.model.cuda()
|
||||
noise = self.get_noise(seed, latent).cuda()
|
||||
sigmas = self.get_sigmas(self.sd3.model.model_sampling, steps).cuda()
|
||||
sigmas = sigmas[int(steps * (1 - denoise)) :]
|
||||
conditioning = self.fix_cond(conditioning)
|
||||
neg_cond = self.fix_cond(neg_cond)
|
||||
extra_args = {"cond": conditioning, "uncond": neg_cond, "cond_scale": cfg_scale}
|
||||
noise_scaled = self.sd3.model.model_sampling.noise_scaling(
|
||||
sigmas[0], noise, latent, self.max_denoise(sigmas)
|
||||
)
|
||||
sample_fn = getattr(sd3_impls, f"sample_{sampler}")
|
||||
latent = sample_fn(
|
||||
CFGDenoiser(self.sd3.model), noise_scaled, sigmas, extra_args=extra_args
|
||||
)
|
||||
latent = SD3LatentFormat().process_out(latent)
|
||||
self.sd3.model = self.sd3.model.cpu()
|
||||
self.print("Sampling done")
|
||||
return latent
|
||||
|
||||
def vae_encode(self, image) -> torch.Tensor:
|
||||
self.print("Encoding image to latent...")
|
||||
image = image.convert("RGB")
|
||||
image_np = np.array(image).astype(np.float32) / 255.0
|
||||
image_np = np.moveaxis(image_np, 2, 0)
|
||||
batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0)
|
||||
image_torch = torch.from_numpy(batch_images)
|
||||
image_torch = 2.0 * image_torch - 1.0
|
||||
image_torch = image_torch.cuda()
|
||||
self.vae.model = self.vae.model.cuda()
|
||||
latent = self.vae.model.encode(image_torch).cpu()
|
||||
self.vae.model = self.vae.model.cpu()
|
||||
self.print("Encoded")
|
||||
return latent
|
||||
|
||||
def vae_decode(self, latent) -> Image.Image:
|
||||
self.print("Decoding latent to image...")
|
||||
latent = latent.cuda()
|
||||
self.vae.model = self.vae.model.cuda()
|
||||
image = self.vae.model.decode(latent)
|
||||
image = image.float()
|
||||
self.vae.model = self.vae.model.cpu()
|
||||
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
|
||||
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
|
||||
decoded_np = decoded_np.astype(np.uint8)
|
||||
out_image = Image.fromarray(decoded_np)
|
||||
self.print("Decoded")
|
||||
return out_image
|
||||
|
||||
def gen_image(
|
||||
self,
|
||||
prompts=[PROMPT],
|
||||
width=WIDTH,
|
||||
height=HEIGHT,
|
||||
steps=STEPS,
|
||||
cfg_scale=CFG_SCALE,
|
||||
sampler=SAMPLER,
|
||||
seed=SEED,
|
||||
seed_type=SEEDTYPE,
|
||||
out_dir=OUTDIR,
|
||||
init_image=INIT_IMAGE,
|
||||
denoise=DENOISE,
|
||||
):
|
||||
latent = self.get_empty_latent(width, height)
|
||||
if init_image:
|
||||
image_data = Image.open(init_image)
|
||||
image_data = image_data.resize((width, height), Image.LANCZOS)
|
||||
latent = self.vae_encode(image_data)
|
||||
latent = SD3LatentFormat().process_in(latent)
|
||||
neg_cond = self.get_cond("")
|
||||
seed_num = None
|
||||
pbar = tqdm(enumerate(prompts), total=len(prompts), position=0, leave=True)
|
||||
for i, prompt in pbar:
|
||||
if seed_type == "roll":
|
||||
seed_num = seed if seed_num is None else seed_num + 1
|
||||
elif seed_type == "rand":
|
||||
seed_num = torch.randint(0, 100000, (1,)).item()
|
||||
else: # fixed
|
||||
seed_num = seed
|
||||
conditioning = self.get_cond(prompt)
|
||||
sampled_latent = self.do_sampling(
|
||||
latent,
|
||||
seed_num,
|
||||
conditioning,
|
||||
neg_cond,
|
||||
steps,
|
||||
cfg_scale,
|
||||
sampler,
|
||||
denoise if init_image else 1.0,
|
||||
)
|
||||
image = self.vae_decode(sampled_latent)
|
||||
save_path = os.path.join(out_dir, f"{i:06d}.png")
|
||||
self.print(f"Will save to {save_path}")
|
||||
image.save(save_path)
|
||||
self.print("Done")
|
||||
|
||||
|
||||
CONFIGS = {
|
||||
"sd3_medium": {
|
||||
"shift": 1.0,
|
||||
"cfg": 5.0,
|
||||
"steps": 50,
|
||||
"sampler": "dpmpp_2m",
|
||||
},
|
||||
"sd3.5_large": {
|
||||
"shift": 3.0,
|
||||
"cfg": 4.5,
|
||||
"steps": 40,
|
||||
"sampler": "dpmpp_2m",
|
||||
},
|
||||
"sd3.5_large_turbo": {"shift": 3.0, "cfg": 1.0, "steps": 4, "sampler": "euler"},
|
||||
}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main(
|
||||
prompt=PROMPT,
|
||||
model=MODEL,
|
||||
out_dir=OUTDIR,
|
||||
postfix=None,
|
||||
seed=SEED,
|
||||
seed_type=SEEDTYPE,
|
||||
sampler=None,
|
||||
steps=None,
|
||||
cfg=None,
|
||||
shift=None,
|
||||
width=WIDTH,
|
||||
height=HEIGHT,
|
||||
vae=VAEFile,
|
||||
init_image=INIT_IMAGE,
|
||||
denoise=DENOISE,
|
||||
verbose=False,
|
||||
):
|
||||
steps = steps or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["steps"]
|
||||
cfg = cfg or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["cfg"]
|
||||
shift = shift or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["shift"]
|
||||
sampler = (
|
||||
sampler or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["sampler"]
|
||||
)
|
||||
|
||||
inferencer = SD3Inferencer()
|
||||
inferencer.load(model, vae, shift, verbose)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
if os.path.splitext(prompt)[-1] == ".txt":
|
||||
with open(prompt, "r") as f:
|
||||
prompts = [l.strip() for l in f.readlines()]
|
||||
else:
|
||||
prompts = [prompt]
|
||||
|
||||
out_dir = os.path.join(
|
||||
out_dir,
|
||||
os.path.splitext(os.path.basename(model))[0],
|
||||
os.path.splitext(os.path.basename(prompt))[0][:50]
|
||||
+ (postfix or datetime.datetime.now().strftime("_%Y-%m-%dT%H-%M-%S")),
|
||||
)
|
||||
print(f"Saving images to {out_dir}")
|
||||
os.makedirs(out_dir, exist_ok=False)
|
||||
|
||||
inferencer.gen_image(
|
||||
prompts,
|
||||
width,
|
||||
height,
|
||||
steps,
|
||||
cfg,
|
||||
sampler,
|
||||
seed,
|
||||
seed_type,
|
||||
out_dir,
|
||||
init_image,
|
||||
denoise,
|
||||
)
|
||||
|
||||
|
||||
fire.Fire(main)
|
||||
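A brief usage sketch for the reference script above. This is illustrative only: it assumes the listing is saved as `sd3_infer.py` and that the model file exists locally; neither the module name nor the paths come from this diff.

# Hypothetical invocation through the Fire CLI wrapper that `fire.Fire(main)` exposes.
# Module name, model path, and flag values are placeholder assumptions.
import subprocess

subprocess.run([
    "python", "sd3_infer.py",
    "--prompt", "a photo of a cat",
    "--model", "models/sd3.5_large.safetensors",
    "--steps", "40",
    "--cfg", "4.5",
    "--sampler", "dpmpp_2m",
    "--seed_type", "fixed",
], check=True)

Equivalently, `main()` can be imported and called directly with the same keyword arguments.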
@@ -1402,7 +1402,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
fp_additional_modules = getattr(shared.opts, 'forge_additional_modules')
|
||||
|
||||
reload = False
|
||||
if hasattr(self, 'hr_additional_modules') and 'Use same choices' not in self.hr_additional_modules:
|
||||
if 'Use same choices' not in self.hr_additional_modules:
|
||||
modules_changed = main_entry.modules_change(self.hr_additional_modules, save=False, refresh=False)
|
||||
if modules_changed:
|
||||
reload = True
|
||||
|
||||
@@ -21,29 +21,34 @@ class ScriptRefiner(scripts.ScriptBuiltinUI):
|
||||
|
||||
def ui(self, is_img2img):
|
||||
with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner:
|
||||
with gr.Row():
|
||||
refiner_checkpoint = gr.Dropdown(label='Checkpoint', info='(use model of same architecture)', elem_id=self.elem_id("checkpoint"), choices=["", *sd_models.checkpoint_tiles(use_short=True)], value='', tooltip="switch to another model in the middle of generation")
|
||||
create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles(use_short=True)}, self.elem_id("checkpoint_refresh"))
|
||||
|
||||
refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation")
|
||||
|
||||
def lookup_checkpoint(title):
|
||||
info = sd_models.get_closet_checkpoint_match(title)
|
||||
return None if info is None else info.short_title
|
||||
|
||||
self.infotext_fields = [
|
||||
PasteField(enable_refiner, lambda d: 'Refiner' in d),
|
||||
PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"),
|
||||
PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"),
|
||||
]
|
||||
gr.Markdown('Refiner is currently under maintenance and unavailable. Sorry for the inconvenience.')
|
||||
|
||||
return enable_refiner, refiner_checkpoint, refiner_switch_at
|
||||
#
|
||||
# with gr.Row():
|
||||
# refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=["", *sd_models.checkpoint_tiles()], value='', tooltip="switch to another model in the middle of generation", interactive=False, visible=False)
|
||||
# # create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh"))
|
||||
#
|
||||
# refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation", interactive=False, visible=False)
|
||||
#
|
||||
# def lookup_checkpoint(title):
|
||||
# info = sd_models.get_closet_checkpoint_match(title)
|
||||
# return None if info is None else info.title
|
||||
#
|
||||
# self.infotext_fields = [
|
||||
# PasteField(enable_refiner, lambda d: 'Refiner' in d),
|
||||
# PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"),
|
||||
# PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"),
|
||||
# ]
|
||||
|
||||
def setup(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
|
||||
# the actual implementation is in sd_samplers_common.py, apply_refiner
|
||||
if not enable_refiner or refiner_checkpoint in (None, "", "None"):
|
||||
p.refiner_checkpoint = None
|
||||
p.refiner_switch_at = None
|
||||
else:
|
||||
p.refiner_checkpoint = refiner_checkpoint
|
||||
p.refiner_switch_at = refiner_switch_at
|
||||
return [enable_refiner] # , refiner_checkpoint, refiner_switch_at
|
||||
|
||||
def setup(self, p, enable_refiner):
|
||||
pass
|
||||
# # the actual implementation is in sd_samplers_common.py, apply_refiner
|
||||
#
|
||||
# if not enable_refiner or refiner_checkpoint in (None, "", "None"):
|
||||
# p.refiner_checkpoint = None
|
||||
# p.refiner_switch_at = None
|
||||
# else:
|
||||
# p.refiner_checkpoint = refiner_checkpoint
|
||||
# p.refiner_switch_at = refiner_switch_at
|
||||
|
||||
@@ -450,7 +450,7 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):
|
||||
|
||||
def unload_model_weights(sd_model=None, info=None):
|
||||
memory_management.unload_all_models()
|
||||
return
|
||||
pass
|
||||
|
||||
|
||||
def apply_token_merging(sd_model, token_merging_ratio):
|
||||
|
||||
@@ -63,10 +63,9 @@ def set_samplers():
|
||||
|
||||
def add_sampler(sampler):
|
||||
global all_samplers, all_samplers_map
|
||||
if sampler.name not in [x.name for x in all_samplers]:
|
||||
all_samplers.append(sampler)
|
||||
all_samplers_map = {x.name: x for x in all_samplers}
|
||||
set_samplers()
|
||||
all_samplers.append(sampler)
|
||||
all_samplers_map = {x.name: x for x in all_samplers}
|
||||
set_samplers()
|
||||
return
|
||||
|
||||
|
||||
|
||||
@@ -168,9 +168,9 @@ class CFGDenoiser(torch.nn.Module):
|
||||
x = x * (((real_sigma ** 2.0 + real_sigma_data ** 2.0) ** 0.5)[:, None, None, None])
|
||||
sigma = real_sigma
|
||||
|
||||
if sd_samplers_common.apply_refiner(self, x):
|
||||
cond = self.sampler.sampler_extra_args['cond']
|
||||
uncond = self.sampler.sampler_extra_args['uncond']
|
||||
# if sd_samplers_common.apply_refiner(self, x):
|
||||
# cond = self.sampler.sampler_extra_args['cond']
|
||||
# uncond = self.sampler.sampler_extra_args['uncond']
|
||||
|
||||
cond_composition, cond = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) if uncond is not None else None
|
||||
|
||||
@@ -8,7 +8,7 @@ from modules.shared import opts, state
|
||||
from backend.sampling.sampling_function import sampling_prepare, sampling_cleanup
|
||||
from modules import extra_networks
|
||||
import k_diffusion.sampling
|
||||
from modules_forge import main_entry
|
||||
|
||||
|
||||
SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||
|
||||
@@ -70,12 +70,9 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
|
||||
def single_sample_to_image(sample, approximation=None):
|
||||
x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
|
||||
|
||||
x_sample = x_sample.cpu()
|
||||
x_sample.clamp_(0.0, 1.0)
|
||||
x_sample.mul_(255.)
|
||||
x_sample.round_()
|
||||
x_sample = x_sample.to(torch.uint8)
|
||||
x_sample = np.moveaxis(x_sample.numpy(), 0, 2)
|
||||
x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
|
||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
|
||||
return Image.fromarray(x_sample)
|
||||
|
||||
@@ -164,51 +161,45 @@ replace_torchsde_browinan()
|
||||
|
||||
|
||||
def apply_refiner(cfg_denoiser, x):
|
||||
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
|
||||
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
|
||||
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
|
||||
|
||||
if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
|
||||
return False
|
||||
|
||||
if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
|
||||
return False
|
||||
|
||||
if getattr(cfg_denoiser.p, "enable_hr", False):
|
||||
is_second_pass = cfg_denoiser.p.is_hr_pass
|
||||
|
||||
if opts.hires_fix_refiner_pass == "first pass" and is_second_pass:
|
||||
return False
|
||||
|
||||
if opts.hires_fix_refiner_pass == "second pass" and not is_second_pass:
|
||||
return False
|
||||
|
||||
if opts.hires_fix_refiner_pass != "second pass":
|
||||
cfg_denoiser.p.extra_generation_params['Hires refiner'] = opts.hires_fix_refiner_pass
|
||||
|
||||
cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
|
||||
cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
|
||||
|
||||
sampling_cleanup(sd_models.model_data.get_sd_model().forge_objects.unet)
|
||||
|
||||
with sd_models.SkipWritingToConfig():
|
||||
fp_checkpoint = getattr(shared.opts, 'sd_model_checkpoint')
|
||||
checkpoint_changed = main_entry.checkpoint_change(refiner_checkpoint_info.short_title, save=False, refresh=False)
|
||||
if checkpoint_changed:
|
||||
try:
|
||||
main_entry.refresh_model_loading_parameters()
|
||||
sd_models.forge_model_reload()
|
||||
finally:
|
||||
main_entry.checkpoint_change(fp_checkpoint, save=False, refresh=True)
|
||||
|
||||
if not cfg_denoiser.p.disable_extra_networks:
|
||||
extra_networks.activate(cfg_denoiser.p, cfg_denoiser.p.extra_network_data)
|
||||
|
||||
cfg_denoiser.p.setup_conds()
|
||||
cfg_denoiser.update_inner_model()
|
||||
|
||||
sampling_prepare(sd_models.model_data.get_sd_model().forge_objects.unet, x=x)
|
||||
return True
|
||||
# completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
|
||||
# refiner_switch_at = cfg_denoiser.p.refiner_switch_at
|
||||
# refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
|
||||
#
|
||||
# if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
|
||||
# return False
|
||||
#
|
||||
# if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
|
||||
# return False
|
||||
#
|
||||
# if getattr(cfg_denoiser.p, "enable_hr", False):
|
||||
# is_second_pass = cfg_denoiser.p.is_hr_pass
|
||||
#
|
||||
# if opts.hires_fix_refiner_pass == "first pass" and is_second_pass:
|
||||
# return False
|
||||
#
|
||||
# if opts.hires_fix_refiner_pass == "second pass" and not is_second_pass:
|
||||
# return False
|
||||
#
|
||||
# if opts.hires_fix_refiner_pass != "second pass":
|
||||
# cfg_denoiser.p.extra_generation_params['Hires refiner'] = opts.hires_fix_refiner_pass
|
||||
#
|
||||
# cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
|
||||
# cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
|
||||
#
|
||||
# sampling_cleanup(sd_models.model_data.get_sd_model().forge_objects.unet)
|
||||
#
|
||||
# with sd_models.SkipWritingToConfig():
|
||||
# sd_models.reload_model_weights(info=refiner_checkpoint_info)
|
||||
#
|
||||
# if not cfg_denoiser.p.disable_extra_networks:
|
||||
# extra_networks.activate(cfg_denoiser.p, cfg_denoiser.p.extra_network_data)
|
||||
#
|
||||
# cfg_denoiser.p.setup_conds()
|
||||
# cfg_denoiser.update_inner_model()
|
||||
#
|
||||
# sampling_prepare(sd_models.model_data.get_sd_model().forge_objects.unet, x=x)
|
||||
# return True
|
||||
pass
|
||||
|
||||
|
||||
class TorchHijack:
|
||||
|
||||
@@ -188,10 +188,10 @@ options_templates.update(options_section(('sd', "Stable Diffusion", "sd"), {
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sdxl', "Stable Diffusion XL", "sd"), {
|
||||
"sdxl_crop_top": OptionInfo(0, "crop top coordinate", gr.Number, {"minimum": 0, "maximum": 1024, "step": 1}),
|
||||
"sdxl_crop_left": OptionInfo(0, "crop left coordinate", gr.Number, {"minimum": 0, "maximum": 1024, "step": 1}),
|
||||
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Slider, {"minimum": 0, "maximum": 10, "step": 0.1}).info("used for refiner model negative prompt"),
|
||||
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Slider, {"minimum": 0, "maximum": 10, "step": 0.1}).info("used for refiner model prompt"),
|
||||
"sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
|
||||
"sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
|
||||
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
|
||||
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sd3', "Stable Diffusion 3", "sd"), {
|
||||
@@ -225,13 +225,12 @@ options_templates.update(options_section(('img2img', "img2img", "sd"), {
|
||||
"img2img_inpaint_mask_brush_color": OptionInfo("#ffffff", "Inpaint mask brush color", ui_components.FormColorPicker, {}).info("brush color of inpaint mask").needs_reload_ui(),
|
||||
"img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(),
|
||||
"img2img_inpaint_mask_high_contrast": OptionInfo(True, "For inpainting, use a high-contrast brush pattern").info("use a checkerboard brush pattern instead of color brush").needs_reload_ui(),
|
||||
"img2img_inpaint_mask_scribble_alpha": OptionInfo(75, "Inpaint mask alpha (transparency)", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}).info("only affects non-high-contrast brush").needs_reload_ui(),
|
||||
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
||||
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
||||
"img2img_batch_show_results_limit": OptionInfo(32, "Show the first N batch img2img results in UI", gr.Slider, {"minimum": -1, "maximum": 1000, "step": 1}).info('0: disable, -1: show all images. Too many images can cause lag'),
|
||||
"overlay_inpaint": OptionInfo(True, "Overlay original for inpaint").info("when inpainting, overlay the original image over the areas that weren't inpainted."),
|
||||
"img2img_autosize": OptionInfo(False, "After loading into Img2img, automatically update Width and Height"),
|
||||
"img2img_batch_use_original_name": OptionInfo(False, "Save using original filename in img2img batch. Applies to 'Upload' and 'From directory' tabs.").info("Warning: overwriting is possible, based on Settings > Saving images/grids > Saving the image to an existing file.")
|
||||
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('optimizations', "Optimizations", "sd"), {
|
||||
@@ -320,7 +319,6 @@ options_templates.update(options_section(('ui_gallery', "Gallery", "ui"), {
|
||||
"sd_webui_modal_lightbox_toolbar_opacity": OptionInfo(0.9, "Full page image viewer: tool bar opacity", gr.Slider, {"minimum": 0.0, "maximum": 1, "step": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(),
|
||||
"gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("can be any valid CSS value, for example 768px or 20em").needs_reload_ui(),
|
||||
"open_dir_button_choice": OptionInfo("Subdirectory", "What directory the [📂] button opens", gr.Radio, {"choices": ["Output Root", "Subdirectory", "Subdirectory (even temp dir)"]}),
|
||||
"hires_button_gallery_insert": OptionInfo(False, "Insert [✨] hires button results into gallery").info("Default: original image will be replaced"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui_alternatives', "UI alternatives", "ui"), {
|
||||
@@ -340,7 +338,6 @@ options_templates.update(options_section(('ui', "User interface", "ui"), {
|
||||
"quick_setting_list": OptionInfo([], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
|
||||
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
|
||||
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
|
||||
"tabs_without_quick_settings_bar": OptionInfo(["Spaces"], "UI tabs without Quicksettings bar (top row)", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}),
|
||||
"ui_reorder_list": OptionInfo([], "UI item order for txt2img/img2img tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),
|
||||
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
|
||||
"gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
|
||||
@@ -358,7 +355,6 @@ Infotext is what this software calls the text that contains generation parameter
|
||||
It is displayed in UI below the image. To use infotext, paste it into the prompt and click the ↙️ paste button.
|
||||
"""),
|
||||
"enable_pnginfo": OptionInfo(True, "Write infotext to metadata of the generated image"),
|
||||
"stealth_pnginfo_option": OptionInfo("Alpha", "Stealth infotext mode", gr.Radio, {"choices": ["Alpha", "RGB", "None"]}).info("Ignored if infotext is disabled"),
|
||||
"save_txt": OptionInfo(False, "Create a text file with infotext next to every generated image"),
|
||||
|
||||
"add_model_name_to_info": OptionInfo(True, "Add model name to infotext"),
|
||||
|
||||
@@ -4,10 +4,8 @@ import threading
|
||||
import time
|
||||
import traceback
|
||||
import torch
|
||||
from contextlib import nullcontext
|
||||
|
||||
from modules import errors, shared, devices
|
||||
from backend.args import args
|
||||
from typing import Optional
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -36,10 +34,6 @@ class State:
|
||||
|
||||
def __init__(self):
|
||||
self.server_start = time.time()
|
||||
if args.cuda_stream:
|
||||
self.vae_stream = torch.cuda.Stream()
|
||||
else:
|
||||
self.vae_stream = None
|
||||
|
||||
@property
|
||||
def need_restart(self) -> bool:
|
||||
@@ -159,18 +153,10 @@ class State:
|
||||
import modules.sd_samplers
|
||||
|
||||
try:
|
||||
if self.vae_stream is not None:
|
||||
# not waiting on default stream will result in corrupt results
|
||||
# will not block main stream under any circumstances
|
||||
self.vae_stream.wait_stream(torch.cuda.default_stream())
|
||||
vae_context = torch.cuda.stream(self.vae_stream)
|
||||
if shared.opts.show_progress_grid:
|
||||
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
|
||||
else:
|
||||
vae_context = nullcontext()
|
||||
with vae_context:
|
||||
if shared.opts.show_progress_grid:
|
||||
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
|
||||
else:
|
||||
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
|
||||
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
|
||||
|
||||
self.current_image_sampling_step = self.sampling_step
|
||||
|
||||
|
||||
@@ -1,163 +0,0 @@
|
||||
import gzip
|
||||
|
||||
from modules.script_callbacks import ImageSaveParams
|
||||
from modules import shared
|
||||
|
||||
|
||||
def add_stealth_pnginfo(params: ImageSaveParams):
|
||||
stealth_pnginfo_option = shared.opts.data.get('stealth_pnginfo_option', 'Alpha')
|
||||
if not stealth_pnginfo_option or stealth_pnginfo_option == 'None':
|
||||
return
|
||||
if not params.filename.endswith('.png') or params.pnginfo is None:
|
||||
return
|
||||
if 'parameters' not in params.pnginfo:
|
||||
return
|
||||
add_data(params, str(stealth_pnginfo_option), True)
|
||||
|
||||
def prepare_data(params, mode='Alpha', compressed=True):
|
||||
signature = f"stealth_{'png' if mode == 'Alpha' else 'rgb'}{'info' if not compressed else 'comp'}"
|
||||
binary_signature = ''.join(format(byte, '08b') for byte in signature.encode('utf-8'))
|
||||
param = params.encode('utf-8') if not compressed else gzip.compress(bytes(params, 'utf-8'))
|
||||
binary_param = ''.join(format(byte, '08b') for byte in param)
|
||||
binary_param_len = format(len(binary_param), '032b')
|
||||
return binary_signature + binary_param_len + binary_param
|
||||
|
||||
def add_data(params, mode='Alpha', compressed=True):
|
||||
binary_data = prepare_data(params.pnginfo['parameters'], mode, compressed)
|
||||
if mode == 'Alpha':
|
||||
params.image.putalpha(255)
|
||||
width, height = params.image.size
|
||||
pixels = params.image.load()
|
||||
index = 0
|
||||
end_write = False
|
||||
for x in range(width):
|
||||
for y in range(height):
|
||||
if index >= len(binary_data):
|
||||
end_write = True
|
||||
break
|
||||
values = pixels[x, y]
|
||||
if mode == 'Alpha':
|
||||
r, g, b, a = values
|
||||
else:
|
||||
r, g, b = values
|
||||
if mode == 'Alpha':
|
||||
a = (a & ~1) | int(binary_data[index])
|
||||
index += 1
|
||||
else:
|
||||
r = (r & ~1) | int(binary_data[index])
|
||||
if index + 1 < len(binary_data):
|
||||
g = (g & ~1) | int(binary_data[index + 1])
|
||||
if index + 2 < len(binary_data):
|
||||
b = (b & ~1) | int(binary_data[index + 2])
|
||||
index += 3
|
||||
pixels[x, y] = (r, g, b, a) if mode == 'Alpha' else (r, g, b)
|
||||
if end_write:
|
||||
break
|
||||
|
||||
def read_info_from_image_stealth(image):
|
||||
geninfo = None
|
||||
width, height = image.size
|
||||
pixels = image.load()
|
||||
|
||||
has_alpha = True if image.mode == 'RGBA' else False
|
||||
mode = None
|
||||
compressed = False
|
||||
binary_data = ''
|
||||
buffer_a = ''
|
||||
buffer_rgb = ''
|
||||
index_a = 0
|
||||
index_rgb = 0
|
||||
sig_confirmed = False
|
||||
confirming_signature = True
|
||||
reading_param_len = False
|
||||
reading_param = False
|
||||
read_end = False
|
||||
for x in range(width):
|
||||
for y in range(height):
|
||||
if has_alpha:
|
||||
r, g, b, a = pixels[x, y]
|
||||
buffer_a += str(a & 1)
|
||||
index_a += 1
|
||||
else:
|
||||
r, g, b = pixels[x, y]
|
||||
buffer_rgb += str(r & 1)
|
||||
buffer_rgb += str(g & 1)
|
||||
buffer_rgb += str(b & 1)
|
||||
index_rgb += 3
|
||||
if confirming_signature:
|
||||
if index_a == len('stealth_pnginfo') * 8:
|
||||
decoded_sig = bytearray(int(buffer_a[i:i + 8], 2) for i in
|
||||
range(0, len(buffer_a), 8)).decode('utf-8', errors='ignore')
|
||||
if decoded_sig in {'stealth_pnginfo', 'stealth_pngcomp'}:
|
||||
confirming_signature = False
|
||||
sig_confirmed = True
|
||||
reading_param_len = True
|
||||
mode = 'alpha'
|
||||
if decoded_sig == 'stealth_pngcomp':
|
||||
compressed = True
|
||||
buffer_a = ''
|
||||
index_a = 0
|
||||
else:
|
||||
read_end = True
|
||||
break
|
||||
elif index_rgb == len('stealth_pnginfo') * 8:
|
||||
decoded_sig = bytearray(int(buffer_rgb[i:i + 8], 2) for i in
|
||||
range(0, len(buffer_rgb), 8)).decode('utf-8', errors='ignore')
|
||||
if decoded_sig in {'stealth_rgbinfo', 'stealth_rgbcomp'}:
|
||||
confirming_signature = False
|
||||
sig_confirmed = True
|
||||
reading_param_len = True
|
||||
mode = 'rgb'
|
||||
if decoded_sig == 'stealth_rgbcomp':
|
||||
compressed = True
|
||||
buffer_rgb = ''
|
||||
index_rgb = 0
|
||||
elif reading_param_len:
|
||||
if mode == 'alpha':
|
||||
if index_a == 32:
|
||||
param_len = int(buffer_a, 2)
|
||||
reading_param_len = False
|
||||
reading_param = True
|
||||
buffer_a = ''
|
||||
index_a = 0
|
||||
else:
|
||||
if index_rgb == 33:
|
||||
pop = buffer_rgb[-1]
|
||||
buffer_rgb = buffer_rgb[:-1]
|
||||
param_len = int(buffer_rgb, 2)
|
||||
reading_param_len = False
|
||||
reading_param = True
|
||||
buffer_rgb = pop
|
||||
index_rgb = 1
|
||||
elif reading_param:
|
||||
if mode == 'alpha':
|
||||
if index_a == param_len:
|
||||
binary_data = buffer_a
|
||||
read_end = True
|
||||
break
|
||||
else:
|
||||
if index_rgb >= param_len:
|
||||
diff = param_len - index_rgb
|
||||
if diff < 0:
|
||||
buffer_rgb = buffer_rgb[:diff]
|
||||
binary_data = buffer_rgb
|
||||
read_end = True
|
||||
break
|
||||
else:
|
||||
# impossible
|
||||
read_end = True
|
||||
break
|
||||
if read_end:
|
||||
break
|
||||
if sig_confirmed and binary_data != '':
|
||||
# Convert binary string to UTF-8 encoded text
|
||||
byte_data = bytearray(int(binary_data[i:i + 8], 2) for i in range(0, len(binary_data), 8))
|
||||
try:
|
||||
if compressed:
|
||||
decoded_data = gzip.decompress(bytes(byte_data)).decode('utf-8')
|
||||
else:
|
||||
decoded_data = byte_data.decode('utf-8', errors='ignore')
|
||||
geninfo = decoded_data
|
||||
except:
|
||||
pass
|
||||
return geninfo
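For reference, a minimal round-trip sketch of the least-significant-bit layout the removed module writes: a signature, a 32-bit bit-length, then the gzip-compressed parameters, one bit per alpha value in column order. This is an independent illustration under those assumptions, not part of the project's code.

# Standalone sketch of the same alpha-channel LSB scheme; names here are illustrative.
import gzip
from PIL import Image


def embed_alpha_lsb(img: Image.Image, text: str) -> Image.Image:
    payload = gzip.compress(text.encode("utf-8"))
    bits = "".join(format(b, "08b") for b in b"stealth_pngcomp")   # signature
    bits += format(len(payload) * 8, "032b")                       # payload length in bits
    bits += "".join(format(b, "08b") for b in payload)             # gzip payload
    img = img.convert("RGBA")
    pixels = img.load()
    width, height = img.size
    index = 0
    for x in range(width):          # same column-major scan order as the module above
        for y in range(height):
            if index >= len(bits):
                return img
            r, g, b, a = pixels[x, y]
            pixels[x, y] = (r, g, b, (a & ~1) | int(bits[index]))  # overwrite the alpha LSB
            index += 1
    raise ValueError("image too small to hold the payload")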
|
||||
@@ -102,21 +102,14 @@ def txt2img_upscale_function(id_task: str, request: gr.Request, gallery, gallery
|
||||
|
||||
shared.total_tqdm.clear()
|
||||
|
||||
insert = getattr(shared.opts, 'hires_button_gallery_insert', False)
|
||||
new_gallery = []
|
||||
for i, image in enumerate(gallery):
|
||||
if insert or i != gallery_index:
|
||||
image[0].already_saved_as = image[0].filename.rsplit('?', 1)[0]
|
||||
new_gallery.append(image)
|
||||
if i == gallery_index:
|
||||
new_gallery.extend(processed.images)
|
||||
|
||||
new_index = gallery_index
|
||||
if insert:
|
||||
new_index += 1
|
||||
geninfo["infotexts"].insert(new_index, processed.info)
|
||||
else:
|
||||
geninfo["infotexts"][gallery_index] = processed.info
|
||||
else:
|
||||
new_gallery.append(image)
|
||||
|
||||
geninfo["infotexts"][gallery_index] = processed.info
|
||||
|
||||
return new_gallery, json.dumps(geninfo), plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
|
||||
|
||||
|
||||
@@ -99,7 +99,7 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz
|
||||
|
||||
new_width = p.hr_resize_x or p.hr_upscale_to_x
|
||||
new_height = p.hr_resize_y or p.hr_upscale_to_y
|
||||
|
||||
|
||||
new_width -= new_width % 8 # note: hardcoded latent size 8
|
||||
new_height -= new_height % 8
|
||||
|
||||
@@ -346,7 +346,7 @@ def create_ui():
|
||||
with FormRow(elem_id="txt2img_hires_fix_row_cfg", variant="compact"):
|
||||
hr_distilled_cfg = gr.Slider(minimum=0.0, maximum=30.0, step=0.1, label="Hires Distilled CFG Scale", value=3.5, elem_id="txt2img_hr_distilled_cfg")
|
||||
hr_cfg = gr.Slider(minimum=1.0, maximum=30.0, step=0.1, label="Hires CFG Scale", value=7.0, elem_id="txt2img_hr_cfg")
|
||||
|
||||
|
||||
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=shared.opts.hires_fix_show_sampler) as hr_checkpoint_container:
|
||||
hr_checkpoint_name = gr.Dropdown(label='Hires Checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint", scale=2)
|
||||
|
||||
@@ -360,7 +360,7 @@ def create_ui():
|
||||
else:
|
||||
modules_list += list(main_entry.module_list.keys())
|
||||
return modules_list
|
||||
|
||||
|
||||
modules_list = get_additional_modules()
|
||||
|
||||
def refresh_model_and_modules():
|
||||
@@ -470,12 +470,6 @@ def create_ui():
|
||||
toprow.prompt.submit(**txt2img_args)
|
||||
toprow.submit.click(**txt2img_args)
|
||||
|
||||
def select_gallery_image(index):
|
||||
index = int(index)
|
||||
if getattr(shared.opts, 'hires_button_gallery_insert', False):
|
||||
index += 1
|
||||
return gr.update(selected_index=index)
|
||||
|
||||
txt2img_upscale_inputs = txt2img_inputs[0:1] + [output_panel.gallery, dummy_component_number, output_panel.generation_info] + txt2img_inputs[1:]
|
||||
output_panel.button_upscale.click(
|
||||
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img_upscale, extra_outputs=[None, '', '']),
|
||||
@@ -483,7 +477,7 @@ def create_ui():
|
||||
inputs=txt2img_upscale_inputs,
|
||||
outputs=txt2img_outputs,
|
||||
show_progress=False,
|
||||
).then(fn=select_gallery_image, js="selected_gallery_index", inputs=[dummy_component], outputs=[output_panel.gallery])
|
||||
)
|
||||
|
||||
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
|
||||
|
||||
@@ -593,7 +587,7 @@ def create_ui():
|
||||
add_copy_image_controls('sketch', sketch)
|
||||
|
||||
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
|
||||
init_img_with_mask = ForgeCanvas(elem_id="img2maskimg", height=512, contrast_scribbles=opts.img2img_inpaint_mask_high_contrast, scribble_color=opts.img2img_inpaint_mask_brush_color, scribble_color_fixed=True, scribble_alpha=opts.img2img_inpaint_mask_scribble_alpha, scribble_alpha_fixed=True, scribble_softness_fixed=True)
|
||||
init_img_with_mask = ForgeCanvas(elem_id="img2maskimg", height=512, contrast_scribbles=opts.img2img_inpaint_mask_high_contrast, scribble_color=opts.img2img_inpaint_mask_brush_color, scribble_color_fixed=True, scribble_alpha_fixed=True, scribble_softness_fixed=True)
|
||||
add_copy_image_controls('inpaint', init_img_with_mask)
|
||||
|
||||
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
|
||||
@@ -921,13 +915,13 @@ def create_ui():
|
||||
|
||||
scripts.scripts_current = None
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as extras_interface:
|
||||
with gr.Blocks(analytics_enabled=False, head=canvas_head) as extras_interface:
|
||||
ui_postprocessing.create_ui()
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
|
||||
with gr.Blocks(analytics_enabled=False, head=canvas_head) as pnginfo_interface:
|
||||
with ResizeHandleRow(equal_height=False):
|
||||
with gr.Column(variant='panel'):
|
||||
image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil", height="50vh", image_mode="RGBA")
|
||||
image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")
|
||||
|
||||
with gr.Column(variant='panel'):
|
||||
html = gr.HTML()
|
||||
@@ -969,6 +963,8 @@ def create_ui():
|
||||
extensions_interface = ui_extensions.create_ui()
|
||||
interfaces += [(extensions_interface, "Extensions", "extensions")]
|
||||
|
||||
interface_names_without_quick_setting_bars = ["Spaces"]
|
||||
|
||||
shared.tab_names = []
|
||||
for _interface, label, _ifid in interfaces:
|
||||
shared.tab_names.append(label)
|
||||
@@ -996,8 +992,7 @@ def create_ui():
|
||||
loadsave.setup_ui()
|
||||
|
||||
def tab_changed(evt: gr.SelectData):
|
||||
no_quick_setting = getattr(shared.opts, "tabs_without_quick_settings_bar", [])
|
||||
return gr.update(visible=evt.value not in no_quick_setting)
|
||||
return gr.update(visible=evt.value not in interface_names_without_quick_setting_bars)
|
||||
|
||||
tabs.select(tab_changed, outputs=[quicksettings_row], show_progress=False, queue=False)
|
||||
|
||||
@@ -1074,7 +1069,7 @@ def setup_ui_api(app):
|
||||
from fastapi.responses import PlainTextResponse
|
||||
|
||||
text = sysinfo.get()
|
||||
filename = f"sysinfo-{datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%d-%H-%M')}.json"
|
||||
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json"
|
||||
|
||||
return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'})
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ def create_ui():
|
||||
with gr.Column(variant='compact'):
|
||||
with gr.Tabs(elem_id="mode_extras"):
|
||||
with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
|
||||
extras_image = gr.Image(label="Source", interactive=True, type="pil", elem_id="extras_image", image_mode="RGBA", height="55vh")
|
||||
extras_image = ForgeCanvas(elem_id="extras_image", height=512, no_scribbles=True).background
|
||||
|
||||
with gr.TabItem('Batch Process', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch:
|
||||
image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch")
|
||||
|
||||
@@ -93,14 +93,13 @@ class UpscalerData:
|
||||
scaler: Upscaler = None
|
||||
model: None
|
||||
|
||||
def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None, sha256: str = None):
|
||||
def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
|
||||
self.name = name
|
||||
self.data_path = path
|
||||
self.local_data_path = path
|
||||
self.scaler = upscaler
|
||||
self.scale = scale
|
||||
self.model = model
|
||||
self.sha256 = sha256
|
||||
|
||||
def __repr__(self):
|
||||
return f"<UpscalerData name={self.name} path={self.data_path} scale={self.scale}>"
|
||||
|
||||
@@ -211,66 +211,3 @@ Requested path was: {path}
|
||||
subprocess.Popen(["explorer.exe", subprocess.check_output(["wslpath", "-w", path])])
|
||||
else:
|
||||
subprocess.Popen(["xdg-open", path])
|
||||
|
||||
|
||||
def load_file_from_url(
|
||||
url: str,
|
||||
*,
|
||||
model_dir: str,
|
||||
progress: bool = True,
|
||||
file_name: str | None = None,
|
||||
hash_prefix: str | None = None,
|
||||
re_download: bool = False,
|
||||
) -> str:
|
||||
"""Download a file from `url` into `model_dir`, using the file present if possible.
|
||||
Returns the path to the downloaded file.
|
||||
|
||||
file_name: if specified, it will be used as the filename, otherwise the filename will be extracted from the url.
|
||||
file is downloaded to {file_name}.tmp then moved to the final location after download is complete.
|
||||
hash_prefix: sha256 hex string, if provided, the hash of the downloaded file will be checked against this prefix.
|
||||
if the hash does not match, the temporary file is deleted and a ValueError is raised.
|
||||
re_download: forcibly re-download the file even if it already exists.
|
||||
"""
|
||||
from urllib.parse import urlparse
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
if not file_name:
|
||||
parts = urlparse(url)
|
||||
file_name = os.path.basename(parts.path)
|
||||
|
||||
cached_file = os.path.abspath(os.path.join(model_dir, file_name))
|
||||
|
||||
if re_download or not os.path.exists(cached_file):
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
temp_file = os.path.join(model_dir, f"{file_name}.tmp")
|
||||
print(f'\nDownloading: "{url}" to {cached_file}')
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
with tqdm(total=total_size, unit='B', unit_scale=True, desc=file_name, disable=not progress) as progress_bar:
|
||||
with open(temp_file, 'wb') as file:
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
file.write(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
|
||||
if hash_prefix and not compare_sha256(temp_file, hash_prefix):
|
||||
print(f"Hash mismatch for {temp_file}. Deleting the temporary file.")
|
||||
os.remove(temp_file)
|
||||
raise ValueError(f"File hash does not match the expected hash prefix {hash_prefix}!")
|
||||
|
||||
os.rename(temp_file, cached_file)
|
||||
return cached_file
|
||||
|
||||
|
||||
def compare_sha256(file_path: str, hash_prefix: str) -> bool:
|
||||
"""Check if the SHA256 hash of the file matches the given prefix."""
|
||||
import hashlib
|
||||
hash_sha256 = hashlib.sha256()
|
||||
blksize = 1024 * 1024
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(blksize), b""):
|
||||
hash_sha256.update(chunk)
|
||||
return hash_sha256.hexdigest().startswith(hash_prefix.strip().lower())
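A short usage sketch for the two helpers above. It is illustrative only: the URL, hash prefix, and target directory are placeholders, and it simply calls the functions defined in this file.

# Hypothetical download of an upscaler checkpoint with hash verification.
# The URL and sha256 prefix below are placeholders, not values from this repository.
weights_path = load_file_from_url(
    "https://example.com/4x_foo.pth",
    model_dir="models/ESRGAN",
    file_name="4x_foo.pth",
    hash_prefix="0123abcd",
)
assert compare_sha256(weights_path, "0123abcd")  # same check the downloader already performs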
|
||||
|
||||
@@ -2,7 +2,7 @@ import pkg_resources
|
||||
|
||||
from modules.launch_utils import run_pip
|
||||
|
||||
target_bitsandbytes_version = '0.45.3'
|
||||
target_bitsandbytes_version = '0.43.3'
|
||||
|
||||
|
||||
def try_install_bnb():
|
||||
|
||||
@@ -5,16 +5,6 @@
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.forge-image-container-plain {
|
||||
width: 100%;
|
||||
height: calc(100% - 6px);
|
||||
position: relative;
|
||||
overflow: hidden;
|
||||
background-color: #202020;
|
||||
background-size: 20px 20px;
|
||||
background-position: 0 0, 10px 10px;
|
||||
}
|
||||
|
||||
.forge-image-container {
|
||||
width: 100%;
|
||||
height: calc(100% - 6px);
|
||||
@@ -58,16 +48,6 @@
|
||||
left: 0;
|
||||
}
|
||||
|
||||
.forge-toolbar-static {
|
||||
position: absolute;
|
||||
top: 0px;
|
||||
left: 0px;
|
||||
z-index: 10 !important;
|
||||
background: rgba(47, 47, 47, 0.8);
|
||||
padding: 6px 10px;
|
||||
opacity: 1.0 !important;
|
||||
}
|
||||
|
||||
.forge-toolbar {
|
||||
position: absolute;
|
||||
top: 0px;
|
||||
@@ -79,7 +59,7 @@
|
||||
transition: opacity 0.3s ease;
|
||||
}
|
||||
|
||||
.forge-toolbar .forge-btn, .forge-toolbar-static .forge-btn {
|
||||
.forge-toolbar .forge-btn {
|
||||
padding: 2px 6px;
|
||||
border: none;
|
||||
background-color: #4a4a4a;
|
||||
@@ -89,11 +69,11 @@
|
||||
transition: background-color 0.3s ease;
|
||||
}
|
||||
|
||||
.forge-toolbar .forge-btn, .forge-toolbar-static .forge-btn:hover {
|
||||
.forge-toolbar .forge-btn:hover {
|
||||
background-color: #5e5e5e;
|
||||
}
|
||||
|
||||
.forge-toolbar .forge-btn, .forge-toolbar-static .forge-btn:active {
|
||||
.forge-toolbar .forge-btn:active {
|
||||
background-color: #3e3e3e;
|
||||
}
|
||||
|
||||
@@ -176,4 +156,4 @@
|
||||
width: 30%;
|
||||
height: 30%;
|
||||
transform: translate(-50%, -50%);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,6 @@ from io import BytesIO
|
||||
from gradio.context import Context
|
||||
from functools import wraps
|
||||
|
||||
from modules.shared import opts
|
||||
|
||||
canvas_js_root_path = os.path.dirname(__file__)
|
||||
|
||||
@@ -137,15 +136,7 @@ class ForgeCanvas:
|
||||
elem_classes=None
|
||||
):
|
||||
self.uuid = 'uuid_' + uuid.uuid4().hex
|
||||
|
||||
canvas_html_uuid = canvas_html.replace('forge_mixin', self.uuid)
|
||||
|
||||
if opts.forge_canvas_plain:
|
||||
canvas_html_uuid = canvas_html_uuid.replace('class="forge-image-container"', 'class="forge-image-container-plain"').replace('stroke="white"', 'stroke=#444')
|
||||
if opts.forge_canvas_toolbar_always:
|
||||
canvas_html_uuid = canvas_html_uuid.replace('class="forge-toolbar"', 'class="forge-toolbar-static"')
|
||||
|
||||
self.block = gr.HTML(canvas_html_uuid, visible=visible, elem_id=elem_id, elem_classes=elem_classes)
|
||||
self.block = gr.HTML(canvas_html.replace('forge_mixin', self.uuid), visible=visible, elem_id=elem_id, elem_classes=elem_classes)
|
||||
self.foreground = LogicalImage(visible=DEBUG_MODE, label='foreground', numpy=numpy, elem_id=self.uuid, elem_classes=['logical_image_foreground'])
|
||||
self.background = LogicalImage(visible=DEBUG_MODE, label='background', numpy=numpy, value=initial_image, elem_id=self.uuid, elem_classes=['logical_image_background'])
|
||||
Context.root_block.load(None, js=f'async ()=>{{new ForgeCanvas("{self.uuid}", {no_upload}, {no_scribbles}, {contrast_scribbles}, {height}, '
|
||||
|
||||
30
modules_forge/google_blockly.py
Normal file
30
modules_forge/google_blockly.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# See also: https://github.com/lllyasviel/google_blockly_prototypes/blob/main/LICENSE_pyz
|
||||
|
||||
|
||||
import os
|
||||
import gzip
|
||||
import importlib.util
|
||||
|
||||
|
||||
pyz_dir = os.path.abspath(os.path.realpath(os.path.join(__file__, '../../repositories/google_blockly_prototypes/forge')))
|
||||
module_suffix = ".pyz"
|
||||
|
||||
|
||||
def initialization():
|
||||
print('Loading additional modules ... ', end='')
|
||||
|
||||
for filename in os.listdir(pyz_dir):
|
||||
if not filename.endswith(module_suffix):
|
||||
continue
|
||||
|
||||
module_name = filename[:-len(module_suffix)]
|
||||
module_package_name = __package__ + '.' + module_name
|
||||
dynamic_module = importlib.util.module_from_spec(importlib.util.spec_from_loader(module_package_name, loader=None))
|
||||
dynamic_module.__dict__['__file__'] = os.path.join(pyz_dir, module_name + '.py')
|
||||
dynamic_module.__dict__['__package__'] = module_package_name
|
||||
google_blockly_context = gzip.open(os.path.join(pyz_dir, filename), 'rb').read().decode('utf-8')
|
||||
exec(google_blockly_context, dynamic_module.__dict__)
|
||||
globals()[module_name] = dynamic_module
|
||||
|
||||
print('done.')
|
||||
return
|
||||
@@ -1,3 +1,4 @@
|
||||
|
||||
def register(options_templates, options_section, OptionInfo):
|
||||
options_templates.update(options_section((None, "Forge Hidden options"), {
|
||||
"forge_unet_storage_dtype": OptionInfo('Automatic'),
|
||||
@@ -7,7 +8,3 @@ def register(options_templates, options_section, OptionInfo):
|
||||
"forge_preset": OptionInfo('sd'),
|
||||
"forge_additional_modules": OptionInfo([]),
|
||||
}))
|
||||
options_templates.update(options_section(('ui_alternatives', "UI alternatives", "ui"), {
|
||||
"forge_canvas_plain": OptionInfo(False, "ForgeCanvas: use plain background").needs_reload_ui(),
|
||||
"forge_canvas_toolbar_always": OptionInfo(False, "ForgeCanvas: toolbar always visible").needs_reload_ui(),
|
||||
}))
|
||||
|
||||
104
packages_3rdparty/comfyui_lora_collection/lora.py
vendored
@@ -30,19 +30,6 @@ LORA_CLIP_MAP = {
|
||||
|
||||
|
||||
def load_lora(lora, to_load):
|
||||
# BFL loras for Flux; from ComfyUI: comfy/lora_convert.py
|
||||
def convert_lora_bfl_control(sd):
|
||||
import torch
|
||||
sd_out = {}
|
||||
for k in sd:
|
||||
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
|
||||
sd_out[k_to] = sd[k]
|
||||
|
||||
return sd_out
|
||||
|
||||
if "img_in.lora_A.weight" in lora and "single_blocks.0.norm.key_norm.scale" in lora:
|
||||
lora = convert_lora_bfl_control(lora)
|
||||
|
||||
patch_dict = {}
|
||||
loaded_keys = set()
|
||||
for x in to_load:
|
||||
@@ -202,12 +189,6 @@ def load_lora(lora, to_load):
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
|
||||
loaded_keys.add(diff_bias_name)
|
||||
|
||||
set_weight_name = "{}.set_weight".format(x)
|
||||
set_weight = lora.get(set_weight_name, None)
|
||||
if set_weight is not None:
|
||||
patch_dict[to_load[x]] = ("set", (set_weight,))
|
||||
loaded_keys.add(set_weight_name)
|
||||
|
||||
remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
|
||||
return patch_dict, remaining_dict
|
||||
|
||||
@@ -253,32 +234,32 @@ def model_lora_keys_clip(model, key_map={}):
|
||||
lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
|
||||
key_map[lora_key] = k
|
||||
|
||||
for k in sdk:
|
||||
if k.endswith(".weight"):
|
||||
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
|
||||
l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
||||
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
|
||||
key_map[lora_key] = k
|
||||
|
||||
#####
|
||||
lora_key = "lora_te2_{}".format(l_key.replace(".", "_"))#OneTrainer Flux lora, by Forge
|
||||
key_map[lora_key] = k
|
||||
#####
|
||||
# for k in sdk:
|
||||
# if k.endswith(".weight"):
|
||||
# if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
|
||||
# l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
||||
# lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
|
||||
# key_map[lora_key] = k
|
||||
#
|
||||
# #####
|
||||
# lora_key = "lora_te2_{}".format(l_key.replace(".", "_"))#OneTrainer Flux lora, by Forge
|
||||
# key_map[lora_key] = k
|
||||
# #####
|
||||
# elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
|
||||
# l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
|
||||
# lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
|
||||
# key_map[lora_key] = k
|
||||
|
||||
|
||||
k = "clip_g.transformer.text_projection.weight"
|
||||
if k in sdk:
|
||||
# key_map["lora_prior_te_text_projection"] = k #cascade lora?
|
||||
key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
|
||||
key_map["lora_te2_text_projection"] = k #OneTrainer SD3 lora
|
||||
|
||||
k = "clip_l.transformer.text_projection.weight"
|
||||
if k in sdk:
|
||||
key_map["lora_te1_text_projection"] = k #OneTrainer SD3 lora, not necessary but omits warning
|
||||
#
|
||||
#
|
||||
# k = "clip_g.transformer.text_projection.weight"
|
||||
# if k in sdk:
|
||||
# key_map["lora_prior_te_text_projection"] = k #cascade lora?
|
||||
# # key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
|
||||
# key_map["lora_te2_text_projection"] = k #OneTrainer SD3 lora
|
||||
#
|
||||
# k = "clip_l.transformer.text_projection.weight"
|
||||
# if k in sdk:
|
||||
# key_map["lora_te1_text_projection"] = k #OneTrainer SD3 lora, not necessary but omits warning
|
||||
|
||||
return sdk, key_map
|
||||
|
||||
@@ -288,13 +269,11 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
sdk = sd.keys()
|
||||
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model."):
|
||||
if k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = k
|
||||
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
else:
|
||||
key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = k
|
||||
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
|
||||
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
|
||||
diffusers_keys = utils.unet_to_diffusers(model.diffusion_model.config)
|
||||
for k in diffusers_keys:
|
||||
@@ -302,8 +281,7 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
|
||||
key_lora = k[:-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = unet_key
|
||||
key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format
|
||||
|
||||
|
||||
diffusers_lora_prefix = ["", "unet."]
|
||||
for p in diffusers_lora_prefix:
|
||||
diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
|
||||
@@ -311,19 +289,19 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
diffusers_lora_key = diffusers_lora_key[:-2]
|
||||
key_map[diffusers_lora_key] = unet_key
|
||||
|
||||
# if 'stable-diffusion-3' in model.config.huggingface_repo.lower(): #Diffusers lora SD3
|
||||
# diffusers_keys = utils.mmdit_to_diffusers(model.diffusion_model.config, output_prefix="diffusion_model.")
|
||||
# for k in diffusers_keys:
|
||||
# if k.endswith(".weight"):
|
||||
# to = diffusers_keys[k]
|
||||
# key_lora = "transformer.{}".format(k[:-len(".weight")]) #regular diffusers sd3 lora format
|
||||
# key_map[key_lora] = to
|
||||
|
||||
# key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others?
|
||||
# key_map[key_lora] = to
|
||||
|
||||
# key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
|
||||
# key_map[key_lora] = to
|
||||
# if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3
|
||||
# diffusers_keys = utils.mmdit_to_diffusers(model.diffusion_model.config, output_prefix="diffusion_model.")
|
||||
# for k in diffusers_keys:
|
||||
# if k.endswith(".weight"):
|
||||
# to = diffusers_keys[k]
|
||||
# key_lora = "transformer.{}".format(k[:-len(".weight")]) #regular diffusers sd3 lora format
|
||||
# key_map[key_lora] = to
|
||||
#
|
||||
# key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others?
|
||||
# key_map[key_lora] = to
|
||||
#
|
||||
# key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
|
||||
# key_map[key_lora] = to
|
||||
#
|
||||
# if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
|
||||
# diffusers_keys = utils.auraflow_to_diffusers(model.diffusion_model.config, output_prefix="diffusion_model.")
|
||||
|
||||
3
packages_3rdparty/gguf/quants.py
vendored
@@ -268,9 +268,6 @@ class BF16(__Quant, qtype=GGMLQuantizationType.BF16):
|
||||
def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
|
||||
return (blocks.view(np.int16).astype(np.int32) << 16).view(np.float32)
|
||||
|
||||
@classmethod
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
|
||||
return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)
|
||||
|
||||
class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
|
||||
@classmethod
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
setuptools==69.5.1 # temp fix for compatibility with some old packages
|
||||
GitPython==3.1.32
|
||||
Pillow==10.4.0
|
||||
accelerate==0.31.0
|
||||
blendmodes==2024.1.1
|
||||
Pillow==9.5.0
|
||||
accelerate==0.21.0
|
||||
blendmodes==2022
|
||||
clean-fid==0.1.35
|
||||
diskcache==5.6.3
|
||||
einops==0.4.1
|
||||
@@ -14,7 +14,7 @@ inflection==0.5.1
|
||||
jsonmerge==1.8.0
|
||||
kornia==0.6.7
|
||||
lark==1.1.2
|
||||
numpy==1.26.4
|
||||
numpy==1.26.2
|
||||
omegaconf==2.2.3
|
||||
open-clip-torch==2.20.0
|
||||
piexif==1.1.3
|
||||
@@ -30,14 +30,14 @@ tomesd==0.1.3
|
||||
torch
|
||||
torchdiffeq==0.2.3
|
||||
torchsde==0.2.6
|
||||
transformers==4.46.1
|
||||
transformers==4.44.0
|
||||
httpx==0.24.1
|
||||
pillow-avif-plugin==1.4.3
|
||||
diffusers==0.31.0
|
||||
diffusers==0.29.2
|
||||
gradio_rangeslider==0.0.6
|
||||
gradio_imageslider==0.0.20
|
||||
loadimg==0.1.2
|
||||
tqdm==4.66.1
|
||||
peft==0.13.2
|
||||
peft==0.12.0
|
||||
pydantic==2.8.2
|
||||
huggingface-hub==0.26.2
|
||||
huggingface-hub==0.24.6
|
||||
|
||||
@@ -8,8 +8,7 @@ import gradio as gr
from modules import sd_samplers, errors, sd_models
from modules.processing import Processed, process_images
from modules.shared import state
from modules.images import image_grid, save_image
from modules.shared import opts


def process_model_tag(tag):
    info = sd_models.get_closet_checkpoint_match(tag)
@@ -102,10 +101,10 @@ def cmdargs(line):

def load_prompt_file(file):
    if file is None:
        return None, gr.update()
        return None, gr.update(), gr.update(lines=7)
    else:
        lines = [x.strip() for x in file.decode('utf8', errors='ignore').split("\n")]
        return None, "\n".join(lines)
        return None, "\n".join(lines), gr.update(lines=7)


class Script(scripts.Script):
@@ -116,16 +115,19 @@ class Script(scripts.Script):
        checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate"))
        checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch"))
        prompt_position = gr.Radio(["start", "end"], label="Insert prompts at the", elem_id=self.elem_id("prompt_position"), value="start")
        make_combined = gr.Checkbox(label="Make a combined image containing all outputs (if more than one)", value=False)

        prompt_txt = gr.Textbox(label="List of prompt inputs", lines=2, elem_id=self.elem_id("prompt_txt"))
        prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt"))
        file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file"))

        file.upload(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt], show_progress=False)
        file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt], show_progress=False)

        return [checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt, make_combined]
        # We start at one line. When the text changes, we jump to seven lines, or two lines if no \n.
        # We don't shrink back to 1, because that causes the control to ignore [enter], and it may
        # be unclear to the user that shift-enter is needed.
        prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt], show_progress=False)
        return [checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt]

    def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt: str, make_combined):
    def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt: str):
        lines = [x for x in (x.strip() for x in prompt_txt.splitlines()) if x]

        p.do_not_save_grid = True
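This hunk is apparently the prompts-from-file script: `load_prompt_file` gains a third `gr.update(lines=7)` output so the prompt textbox expands when a file is loaded, the textbox now starts at a single line, and a `.change()` handler grows it once a newline appears. A condensed, standalone sketch of the same pattern, assuming Gradio's Blocks API and folding the value and line-count changes into one `gr.update` (component names are illustrative):

```python
# Condensed sketch of the auto-expanding prompt textbox behaviour changed above.
# Assumes Gradio's Blocks API; not the script's exact wiring.
import gradio as gr


def load_prompt_file(file_bytes):
    if file_bytes is None:
        # File cleared: leave the text alone but keep the box tall enough to edit.
        return None, gr.update(lines=7)
    lines = [x.strip() for x in file_bytes.decode("utf8", errors="ignore").split("\n")]
    # Fill the textbox with the uploaded prompts and expand it to 7 lines.
    return None, gr.update(value="\n".join(lines), lines=7)


with gr.Blocks() as demo:
    prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1)
    file = gr.File(label="Upload prompt inputs", type="binary")

    # Populate and expand the textbox whenever the uploaded file changes.
    file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt], show_progress=False)

    # Grow to 7 lines once a newline appears; otherwise stay compact at 2 lines
    # (never shrinking back to 1, so [enter] keeps working, as the comment above explains).
    prompt_txt.change(
        lambda tb: gr.update(lines=7) if "\n" in tb else gr.update(lines=2),
        inputs=[prompt_txt],
        outputs=[prompt_txt],
        show_progress=False,
    )

if __name__ == "__main__":
    demo.launch()
```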
@@ -186,36 +188,4 @@ class Script(scripts.Script):
            all_prompts += proc.all_prompts
            infotexts += proc.infotexts

        if make_combined and len(images) > 1:
            combined_image = image_grid(images, batch_size=1, rows=None).convert("RGB")
            full_infotext = "\n".join(infotexts)

            is_img2img = getattr(p, "init_images", None) is not None

            if opts.grid_save: # use grid specific Settings
                save_image(
                    combined_image,
                    opts.outdir_grids or (opts.outdir_img2img_grids if is_img2img else opts.outdir_txt2img_grids),
                    "",
                    -1,
                    prompt_txt,
                    opts.grid_format,
                    full_infotext,
                    grid=True
                )
            else: # use normal output Settings
                save_image(
                    combined_image,
                    opts.outdir_samples or (opts.outdir_img2img_samples if is_img2img else opts.outdir_txt2img_samples),
                    "",
                    -1,
                    prompt_txt,
                    opts.samples_format,
                    full_infotext
                )

            images.insert(0, combined_image)
            all_prompts.insert(0, prompt_txt)
            infotexts.insert(0, full_infotext)

        return Processed(p, images, p.seed, "", all_prompts=all_prompts, infotexts=infotexts)
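The `make_combined` branch above stitches all outputs into one grid via `image_grid`, saves it with either the grid or the regular sample output settings, and prepends it to the results. A standalone PIL sketch of the grid-building step, using made-up file names and a simple square-ish layout (the webui's own `image_grid` / `save_image` helpers do more, e.g. infotext handling):

```python
# Standalone sketch of combining several outputs into one grid image with PIL.
# This only approximates what modules.images.image_grid does; the helper name,
# layout heuristic, and output path are illustrative.
import math
from PIL import Image


def make_grid(images, rows=None):
    if rows is None:
        rows = int(math.sqrt(len(images)))           # roughly square layout
    cols = math.ceil(len(images) / rows)
    w, h = images[0].size
    grid = Image.new("RGB", (cols * w, rows * h), color="black")
    for i, img in enumerate(images):
        grid.paste(img, ((i % cols) * w, (i // cols) * h))
    return grid


if __name__ == "__main__":
    # Dummy inputs standing in for the per-prompt outputs collected in run().
    outputs = [Image.new("RGB", (64, 64), color=c) for c in ("red", "green", "blue", "white")]
    combined = make_grid(outputs)
    combined.save("combined_grid.png")  # hypothetical output path
```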
Some files were not shown because too many files have changed in this diff.