3 Commits
py3.12 ... sd35

Author | SHA1 | Message | Date
layerdiffusion | c5853ef9f8 | dep ver | 2024-10-28 23:04:13 -07:00
lllyasviel | 2d5b6cacef | Merge pull request #2183 from graemeniedermayer/sd35_integration (sd3.5 integration (naive)) | 2024-10-28 21:54:10 -07:00
grae | 1363999fb1 | sd3.5 integration | 2024-10-25 18:39:45 -06:00

101 changed files with 202320 additions and 3898 deletions


@@ -1,7 +0,0 @@
About Gradio 5: we will try to upgrade to Gradio 5 around March 2025. If that fails, we will try again around June 2025. We are relatively positive that we can have Gradio 5 before next summer.
2024 Oct 28: A new branch `sd35` has been contributed via [#2183](https://github.com/lllyasviel/stable-diffusion-webui-forge/pull/2183). I will take a look at quants, sampling, and the transformer's clip-g versus that clip-g rewrite before merging to main ... (Oct 29: okay, maybe medium also needs a look later.)
About Flux ControlNet (sync [here](https://github.com/lllyasviel/stable-diffusion-webui-forge/discussions/932)): the rewrite of ControlNet Integrated will ~start at about Sep 29~ (delayed) ~start at about Oct 15~ (delayed) ~start at about Oct 30~ (delayed) start at about Nov 20. (As of this announcement, the main targets include some diffusers-formatted Flux ControlNets and some community implementations of Union ControlNets; this may be extended if stronger models come out after this note.)
2024 Sep 7: New sampler `Flux Realistic` is available now! Recommended scheduler is "simple".


@@ -6,7 +6,9 @@ The name "Forge" is inspired from "Minecraft Forge". This project is aimed at be
Forge is currently based on SD-WebUI 1.10.1 at [this commit](https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/82a973c04367123ae98bd9abdf80d9eda9b910e2). (Because the original SD-WebUI is almost static now, Forge will sync with the original WebUI every 90 days, or whenever there are important fixes.)
News has been moved to this link: [Click here to see the News section](https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/NEWS.md)
# News
2024 Sep 7: New sampler `Flux Realistic` is available now! Recommended scheduler is "simple".
# Quick List


@@ -1,72 +0,0 @@
import torch
from huggingface_guess import model_list
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
from backend.patcher.clip import CLIP
from backend.patcher.vae import VAE
from backend.patcher.unet import UnetPatcher
from backend.text_processing.t5_engine import T5TextProcessingEngine
from backend.args import dynamic_args
from backend.modules.k_prediction import PredictionFlux
from backend import memory_management
class Chroma(ForgeDiffusionEngine):
def __init__(self, estimated_config, huggingface_components):
super().__init__(estimated_config, huggingface_components)
self.is_inpaint = False
clip = CLIP(
model_dict={
't5xxl': huggingface_components['text_encoder']
},
tokenizer_dict={
't5xxl': huggingface_components['tokenizer']
}
)
vae = VAE(model=huggingface_components['vae'])
k_predictor = PredictionFlux(
mu=1.0
)
unet = UnetPatcher.from_model(
model=huggingface_components['transformer'],
diffusers_scheduler=None,
k_predictor=k_predictor,
config=estimated_config
)
self.text_processing_engine_t5 = T5TextProcessingEngine(
text_encoder=clip.cond_stage_model.t5xxl,
tokenizer=clip.tokenizer.t5xxl,
emphasis_name=dynamic_args['emphasis_name'],
min_length=1
)
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
self.forge_objects_original = self.forge_objects.shallow_copy()
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
def set_clip_skip(self, clip_skip):
pass
@torch.inference_mode()
def get_learned_conditioning(self, prompt: list[str]):
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
return self.text_processing_engine_t5(prompt)
@torch.inference_mode()
def get_prompt_lengths_on_ui(self, prompt):
token_count = len(self.text_processing_engine_t5.tokenize([prompt])[0])
return token_count, max(255, token_count)
@torch.inference_mode()
def encode_first_stage(self, x):
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
sample = self.forge_objects.vae.first_stage_model.process_in(sample)
return sample.to(x)
@torch.inference_mode()
def decode_first_stage(self, x):
sample = self.forge_objects.vae.first_stage_model.process_out(x)
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
return sample.to(x)


@@ -33,17 +33,9 @@ class Flux(ForgeDiffusionEngine):
vae = VAE(model=huggingface_components['vae'])
if 'schnell' in estimated_config.huggingface_repo.lower():
k_predictor = PredictionFlux(
mu=1.0
)
k_predictor = PredictionFlux(sigma_data=1.0, prediction_type='const', shift=1.0, timesteps=10000)
else:
k_predictor = PredictionFlux(
seq_len=4096,
base_seq_len=256,
max_seq_len=4096,
base_shift=0.5,
max_shift=1.15,
)
k_predictor = PredictionFlux(sigma_data=1.0, prediction_type='const', shift=1.15, timesteps=10000)
self.use_distilled_cfg_scale = True
unet = UnetPatcher.from_model(


@@ -9,9 +9,6 @@ from backend.text_processing.classic_engine import ClassicTextProcessingEngine
from backend.args import dynamic_args
from backend import memory_management
import safetensors.torch as sf
from backend import utils
class StableDiffusion(ForgeDiffusionEngine):
matched_guesses = [model_list.SD15]
@@ -82,19 +79,3 @@ class StableDiffusion(ForgeDiffusionEngine):
sample = self.forge_objects.vae.first_stage_model.process_out(x)
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
return sample.to(x)
def save_checkpoint(self, filename):
sd = {}
sd.update(
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
)
sd.update(
model_list.SD15.process_clip_state_dict_for_saving(self,
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
)
)
sd.update(
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
)
sf.save_file(sd, filename)
return filename


@@ -9,9 +9,6 @@ from backend.text_processing.classic_engine import ClassicTextProcessingEngine
from backend.args import dynamic_args
from backend import memory_management
import safetensors.torch as sf
from backend import utils
class StableDiffusion2(ForgeDiffusionEngine):
matched_guesses = [model_list.SD20]
@@ -82,19 +79,3 @@ class StableDiffusion2(ForgeDiffusionEngine):
sample = self.forge_objects.vae.first_stage_model.process_out(x)
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
return sample.to(x)
def save_checkpoint(self, filename):
sd = {}
sd.update(
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
)
sd.update(
model_list.SD20.process_clip_state_dict_for_saving(self,
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
)
)
sd.update(
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
)
sf.save_file(sd, filename)
return filename


@@ -1,149 +1,137 @@
import torch
from huggingface_guess import model_list
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
from backend.patcher.clip import CLIP
from backend.patcher.vae import VAE
from backend.patcher.unet import UnetPatcher
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
from backend.text_processing.t5_engine import T5TextProcessingEngine
from backend.args import dynamic_args
from backend import memory_management
from backend.modules.k_prediction import PredictionDiscreteFlow
from modules.shared import opts
## patch SD3 Class in huggingface_guess.model_list
def SD3_clip_target(self, state_dict={}):
return {'clip_l': 'text_encoder', 'clip_g': 'text_encoder_2', 't5xxl': 'text_encoder_3'}
model_list.SD3.unet_target = 'transformer'
model_list.SD3.clip_target = SD3_clip_target
## end patch
class StableDiffusion3(ForgeDiffusionEngine):
matched_guesses = [model_list.SD3]
def __init__(self, estimated_config, huggingface_components):
super().__init__(estimated_config, huggingface_components)
self.is_inpaint = False
clip = CLIP(
model_dict={
'clip_l': huggingface_components['text_encoder'],
'clip_g': huggingface_components['text_encoder_2'],
't5xxl' : huggingface_components['text_encoder_3']
},
tokenizer_dict={
'clip_l': huggingface_components['tokenizer'],
'clip_g': huggingface_components['tokenizer_2'],
't5xxl' : huggingface_components['tokenizer_3']
}
)
k_predictor = PredictionDiscreteFlow(shift=3.0)
vae = VAE(model=huggingface_components['vae'])
unet = UnetPatcher.from_model(
model=huggingface_components['transformer'],
diffusers_scheduler= None,
k_predictor=k_predictor,
config=estimated_config
)
self.text_processing_engine_l = ClassicTextProcessingEngine(
text_encoder=clip.cond_stage_model.clip_l,
tokenizer=clip.tokenizer.clip_l,
embedding_dir=dynamic_args['embedding_dir'],
embedding_key='clip_l',
embedding_expected_shape=768,
emphasis_name=dynamic_args['emphasis_name'],
text_projection=True,
minimal_clip_skip=1,
clip_skip=1,
return_pooled=True,
final_layer_norm=False,
)
self.text_processing_engine_g = ClassicTextProcessingEngine(
text_encoder=clip.cond_stage_model.clip_g,
tokenizer=clip.tokenizer.clip_g,
embedding_dir=dynamic_args['embedding_dir'],
embedding_key='clip_g',
embedding_expected_shape=1280,
emphasis_name=dynamic_args['emphasis_name'],
text_projection=True,
minimal_clip_skip=1,
clip_skip=1,
return_pooled=True,
final_layer_norm=False,
)
self.text_processing_engine_t5 = T5TextProcessingEngine(
text_encoder=clip.cond_stage_model.t5xxl,
tokenizer=clip.tokenizer.t5xxl,
emphasis_name=dynamic_args['emphasis_name'],
)
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
self.forge_objects_original = self.forge_objects.shallow_copy()
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
# WebUI Legacy
self.is_sd3 = True
def set_clip_skip(self, clip_skip):
self.text_processing_engine_l.clip_skip = clip_skip
self.text_processing_engine_g.clip_skip = clip_skip
@torch.inference_mode()
def get_learned_conditioning(self, prompt: list[str]):
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
cond_g, g_pooled = self.text_processing_engine_g(prompt)
cond_l, l_pooled = self.text_processing_engine_l(prompt)
if opts.sd3_enable_t5:
cond_t5 = self.text_processing_engine_t5(prompt)
else:
cond_t5 = torch.zeros([len(prompt), 256, 4096]).to(cond_l.device)
is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in prompt)
if force_zero_negative_prompt:
l_pooled = torch.zeros_like(l_pooled)
g_pooled = torch.zeros_like(g_pooled)
cond_l = torch.zeros_like(cond_l)
cond_g = torch.zeros_like(cond_g)
cond_t5 = torch.zeros_like(cond_t5)
cond_lg = torch.cat([cond_l, cond_g], dim=-1)
cond_lg = torch.nn.functional.pad(cond_lg, (0, 4096 - cond_lg.shape[-1]))
cond = dict(
crossattn=torch.cat([cond_lg, cond_t5], dim=-2),
vector=torch.cat([l_pooled, g_pooled], dim=-1),
)
return cond
@torch.inference_mode()
def get_prompt_lengths_on_ui(self, prompt):
token_count = len(self.text_processing_engine_t5.tokenize([prompt])[0])
return token_count, max(255, token_count)
@torch.inference_mode()
def encode_first_stage(self, x):
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
sample = self.forge_objects.vae.first_stage_model.process_in(sample)
return sample.to(x)
@torch.inference_mode()
def decode_first_stage(self, x):
sample = self.forge_objects.vae.first_stage_model.process_out(x)
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
return sample.to(x)
import torch
from huggingface_guess import model_list
# from huggingface_guess.latent import SD3
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
from backend.patcher.clip import CLIP
from backend.patcher.vae import VAE
from backend.patcher.unet import UnetPatcher
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
from backend.text_processing.t5_engine import T5TextProcessingEngine
from backend.args import dynamic_args
from backend import memory_management
from backend.modules.k_prediction import PredictionDiscreteFlow
class StableDiffusion3(ForgeDiffusionEngine):
matched_guesses = [model_list.SD35]
def __init__(self, estimated_config, huggingface_components):
super().__init__(estimated_config, huggingface_components)
clip = CLIP(
model_dict={
'clip_l': huggingface_components['text_encoder'],
'clip_g': huggingface_components['text_encoder_2'],
't5xxl': huggingface_components['text_encoder_3']
},
tokenizer_dict={
'clip_l': huggingface_components['tokenizer'],
'clip_g': huggingface_components['tokenizer_2'],
't5xxl': huggingface_components['tokenizer_3']
}
)
k_predictor = PredictionDiscreteFlow( shift=3.0)
vae = VAE(model=huggingface_components['vae'])
unet = UnetPatcher.from_model(
model=huggingface_components['transformer'],
diffusers_scheduler= None,
k_predictor=k_predictor,
config=estimated_config
)
self.text_processing_engine_l = ClassicTextProcessingEngine(
text_encoder=clip.cond_stage_model.clip_l,
tokenizer=clip.tokenizer.clip_l,
embedding_dir=dynamic_args['embedding_dir'],
embedding_key='clip_l',
embedding_expected_shape=768,
emphasis_name=dynamic_args['emphasis_name'],
text_projection=True,
minimal_clip_skip=1,
clip_skip=1,
return_pooled=True,
final_layer_norm=False,
)
self.text_processing_engine_g = ClassicTextProcessingEngine(
text_encoder=clip.cond_stage_model.clip_g,
tokenizer=clip.tokenizer.clip_g,
embedding_dir=dynamic_args['embedding_dir'],
embedding_key='clip_g',
embedding_expected_shape=1280,
emphasis_name=dynamic_args['emphasis_name'],
text_projection=True,
minimal_clip_skip=1,
clip_skip=1,
return_pooled=True,
final_layer_norm=False,
)
self.text_processing_engine_t5 = T5TextProcessingEngine(
text_encoder=clip.cond_stage_model.t5xxl,
tokenizer=clip.tokenizer.t5xxl,
emphasis_name=dynamic_args['emphasis_name'],
)
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
self.forge_objects_original = self.forge_objects.shallow_copy()
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
# WebUI Legacy
self.is_sd3 = True
def set_clip_skip(self, clip_skip):
self.text_processing_engine_l.clip_skip = clip_skip
self.text_processing_engine_g.clip_skip = clip_skip
@torch.inference_mode()
def get_learned_conditioning(self, prompt: list[str]):
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
cond_g, g_pooled = self.text_processing_engine_g(prompt)
cond_l, l_pooled = self.text_processing_engine_l(prompt)
# if enabled?
cond_t5 = self.text_processing_engine_t5(prompt)
is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in prompt)
if force_zero_negative_prompt:
l_pooled = torch.zeros_like(l_pooled)
g_pooled = torch.zeros_like(g_pooled)
cond_l = torch.zeros_like(cond_l)
cond_g = torch.zeros_like(cond_g)
cond_t5 = torch.zeros_like(cond_t5)
cond_lg = torch.cat([cond_l, cond_g], dim=-1)
cond_lg = torch.nn.functional.pad(cond_lg, (0, 4096 - cond_lg.shape[-1]))
cond = dict(
crossattn=torch.cat([cond_lg, cond_t5], dim=-2),
vector=torch.cat([l_pooled, g_pooled], dim=-1),
)
return cond
@torch.inference_mode()
def get_prompt_lengths_on_ui(self, prompt):
token_count = len(self.text_processing_engine_t5.tokenize([prompt])[0])
return token_count, max(255, token_count)
@torch.inference_mode()
def encode_first_stage(self, x):
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
sample = self.forge_objects.vae.first_stage_model.process_in(sample)
return sample.to(x)
@torch.inference_mode()
def decode_first_stage(self, x):
sample = self.forge_objects.vae.first_stage_model.process_out(x)
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
return sample.to(x)
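
For reference, here is a small, hypothetical shape walk-through of the `crossattn`/`vector` conditioning built in `get_learned_conditioning` above, assuming a batch of one prompt, 77-token CLIP sequences, and a 256-token T5 sequence (the hidden sizes come from the text-encoder configs added in this PR):

```python
import torch

cond_l   = torch.zeros(1, 77, 768)    # CLIP-L hidden states
cond_g   = torch.zeros(1, 77, 1280)   # CLIP-G hidden states
cond_t5  = torch.zeros(1, 256, 4096)  # T5-XXL hidden states
l_pooled = torch.zeros(1, 768)
g_pooled = torch.zeros(1, 1280)

cond_lg = torch.cat([cond_l, cond_g], dim=-1)                              # [1, 77, 2048]
cond_lg = torch.nn.functional.pad(cond_lg, (0, 4096 - cond_lg.shape[-1]))  # [1, 77, 4096]
crossattn = torch.cat([cond_lg, cond_t5], dim=-2)                          # [1, 333, 4096]
vector = torch.cat([l_pooled, g_pooled], dim=-1)                           # [1, 2048]
```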


@@ -10,11 +10,6 @@ from backend.args import dynamic_args
from backend import memory_management
from backend.nn.unet import Timestep
import safetensors.torch as sf
from backend import utils
from modules.shared import opts
class StableDiffusionXL(ForgeDiffusionEngine):
matched_guesses = [model_list.SDXL]
@@ -93,8 +88,8 @@ class StableDiffusionXL(ForgeDiffusionEngine):
height = getattr(prompt, 'height', 1024) or 1024
is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
crop_w = opts.sdxl_crop_left
crop_h = opts.sdxl_crop_top
crop_w = 0
crop_h = 0
target_width = width
target_height = height
@@ -136,137 +131,3 @@ class StableDiffusionXL(ForgeDiffusionEngine):
sample = self.forge_objects.vae.first_stage_model.process_out(x)
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
return sample.to(x)
def save_checkpoint(self, filename):
sd = {}
sd.update(
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
)
sd.update(
model_list.SDXL.process_clip_state_dict_for_saving(self,
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
)
)
sd.update(
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
)
sf.save_file(sd, filename)
return filename
class StableDiffusionXLRefiner(ForgeDiffusionEngine):
matched_guesses = [model_list.SDXLRefiner]
def __init__(self, estimated_config, huggingface_components):
super().__init__(estimated_config, huggingface_components)
clip = CLIP(
model_dict={
'clip_g': huggingface_components['text_encoder']
},
tokenizer_dict={
'clip_g': huggingface_components['tokenizer'],
}
)
vae = VAE(model=huggingface_components['vae'])
unet = UnetPatcher.from_model(
model=huggingface_components['unet'],
diffusers_scheduler=huggingface_components['scheduler'],
config=estimated_config
)
self.text_processing_engine_g = ClassicTextProcessingEngine(
text_encoder=clip.cond_stage_model.clip_g,
tokenizer=clip.tokenizer.clip_g,
embedding_dir=dynamic_args['embedding_dir'],
embedding_key='clip_g',
embedding_expected_shape=2048,
emphasis_name=dynamic_args['emphasis_name'],
text_projection=True,
minimal_clip_skip=2,
clip_skip=2,
return_pooled=True,
final_layer_norm=False,
)
self.embedder = Timestep(256)
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
self.forge_objects_original = self.forge_objects.shallow_copy()
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
# WebUI Legacy
self.is_sdxl = True
def set_clip_skip(self, clip_skip):
self.text_processing_engine_g.clip_skip = clip_skip
@torch.inference_mode()
def get_learned_conditioning(self, prompt: list[str]):
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
cond_g, clip_pooled = self.text_processing_engine_g(prompt)
width = getattr(prompt, 'width', 1024) or 1024
height = getattr(prompt, 'height', 1024) or 1024
is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
crop_w = opts.sdxl_crop_left
crop_h = opts.sdxl_crop_top
aesthetic = opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else opts.sdxl_refiner_high_aesthetic_score
out = [
self.embedder(torch.Tensor([height])), self.embedder(torch.Tensor([width])),
self.embedder(torch.Tensor([crop_h])), self.embedder(torch.Tensor([crop_w])),
self.embedder(torch.Tensor([aesthetic]))
]
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1).to(clip_pooled)
force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in prompt)
if force_zero_negative_prompt:
clip_pooled = torch.zeros_like(clip_pooled)
cond_g = torch.zeros_like(cond_g)
cond = dict(
crossattn=cond_g,
vector=torch.cat([clip_pooled, flat], dim=1),
)
return cond
@torch.inference_mode()
def get_prompt_lengths_on_ui(self, prompt):
_, token_count = self.text_processing_engine_g.process_texts([prompt])
return token_count, self.text_processing_engine_g.get_target_prompt_token_count(token_count)
@torch.inference_mode()
def encode_first_stage(self, x):
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
sample = self.forge_objects.vae.first_stage_model.process_in(sample)
return sample.to(x)
@torch.inference_mode()
def decode_first_stage(self, x):
sample = self.forge_objects.vae.first_stage_model.process_out(x)
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
return sample.to(x)
def save_checkpoint(self, filename):
sd = {}
sd.update(
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
)
sd.update(
model_list.SDXLRefiner.process_clip_state_dict_for_saving(self,
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
)
)
sd.update(
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
)
sf.save_file(sd, filename)
return filename


@@ -1,24 +0,0 @@
{
"_class_name": "FluxPipeline",
"_diffusers_version": "0.30.0.dev0",
"scheduler": [
"diffusers",
"FlowMatchEulerDiscreteScheduler"
],
"text_encoder": [
"transformers",
"T5EncoderModel"
],
"tokenizer": [
"transformers",
"T5TokenizerFast"
],
"transformer": [
"diffusers",
"ChromaTransformer2DModel"
],
"vae": [
"diffusers",
"AutoencoderKL"
]
}


@@ -1,11 +0,0 @@
{
"_class_name": "FlowMatchEulerDiscreteScheduler",
"_diffusers_version": "0.30.0.dev0",
"base_image_seq_len": 256,
"base_shift": 0.5,
"max_image_seq_len": 4096,
"max_shift": 1.15,
"num_train_timesteps": 1000,
"shift": 1.0,
"use_dynamic_shifting": false
}


@@ -0,0 +1,40 @@
{
"_class_name": "StableDiffusion3Pipeline",
"_diffusers_version": "0.30.3.dev0",
"scheduler": [
"diffusers",
"FlowMatchEulerDiscreteScheduler"
],
"text_encoder": [
"transformers",
"CLIPTextModelWithProjection"
],
"text_encoder_2": [
"transformers",
"CLIPTextModelWithProjection"
],
"text_encoder_3": [
"transformers",
"T5EncoderModel"
],
"tokenizer": [
"transformers",
"CLIPTokenizer"
],
"tokenizer_2": [
"transformers",
"CLIPTokenizer"
],
"tokenizer_3": [
"transformers",
"T5TokenizerFast"
],
"transformer": [
"diffusers",
"SD3Transformer2DModel"
],
"vae": [
"diffusers",
"AutoencoderKL"
]
}


@@ -0,0 +1,6 @@
{
"_class_name": "FlowMatchEulerDiscreteScheduler",
"_diffusers_version": "0.29.0.dev0",
"num_train_timesteps": 1000,
"shift": 3.0
}


@@ -0,0 +1,24 @@
{
"architectures": [
"CLIPTextModelWithProjection"
],
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 2,
"hidden_act": "quick_gelu",
"hidden_size": 768,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-05,
"max_position_embeddings": 77,
"model_type": "clip_text_model",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 1,
"projection_dim": 768,
"torch_dtype": "float16",
"transformers_version": "4.41.2",
"vocab_size": 49408
}


@@ -0,0 +1,24 @@
{
"architectures": [
"CLIPTextModelWithProjection"
],
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 2,
"hidden_act": "gelu",
"hidden_size": 1280,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 5120,
"layer_norm_eps": 1e-05,
"max_position_embeddings": 77,
"model_type": "clip_text_model",
"num_attention_heads": 20,
"num_hidden_layers": 32,
"pad_token_id": 1,
"projection_dim": 1280,
"torch_dtype": "float16",
"transformers_version": "4.41.2",
"vocab_size": 49408
}


@@ -1,11 +1,16 @@
{
"architectures": [
"T5EncoderModel"
],
"classifier_dropout": 0.0,
"d_ff": 10240,
"d_kv": 64,
"d_model": 4096,
"decoder_start_token_id": 0,
"dense_act_fn": "gelu_new",
"dropout_rate": 0.1,
"eos_token_id": 1,
"dense_act_fn": "gelu_pytorch_tanh",
"feed_forward_proj": "gated-gelu",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
@@ -16,7 +21,11 @@
"num_layers": 24,
"output_past": true,
"pad_token_id": 0,
"relative_attention_max_distance": 128,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"torch_dtype": "float16",
"transformers_version": "4.41.2",
"use_cache": true,
"vocab_size": 32128
}


@@ -0,0 +1,30 @@
{
"bos_token": {
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}


@@ -0,0 +1,30 @@
{
"add_prefix_space": false,
"added_tokens_decoder": {
"49406": {
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": true
},
"49407": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"bos_token": "<|startoftext|>",
"clean_up_tokenization_spaces": true,
"do_lower_case": true,
"eos_token": "<|endoftext|>",
"errors": "replace",
"model_max_length": 77,
"pad_token": "<|endoftext|>",
"tokenizer_class": "CLIPTokenizer",
"unk_token": "<|endoftext|>"
}

File diff suppressed because it is too large.


@@ -0,0 +1,30 @@
{
"bos_token": {
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "!",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}


@@ -0,0 +1,38 @@
{
"add_prefix_space": false,
"added_tokens_decoder": {
"0": {
"content": "!",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"49406": {
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": true
},
"49407": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"bos_token": "<|startoftext|>",
"clean_up_tokenization_spaces": true,
"do_lower_case": true,
"eos_token": "<|endoftext|>",
"errors": "replace",
"model_max_length": 77,
"pad_token": "!",
"tokenizer_class": "CLIPTokenizer",
"unk_token": "<|endoftext|>"
}

File diff suppressed because it is too large.


@@ -0,0 +1,16 @@
{
"_class_name": "SD3Transformer2DModel",
"_diffusers_version": "0.31.0.dev0",
"attention_head_dim": 64,
"caption_projection_dim": 2432,
"in_channels": 16,
"joint_attention_dim": 4096,
"num_attention_heads": 38,
"num_layers": 38,
"out_channels": 16,
"patch_size": 2,
"pooled_projection_dim": 2048,
"pos_embed_max_size": 192,
"qk_norm": "rms_norm",
"sample_size": 128
}


@@ -1,7 +1,7 @@
{
"_class_name": "AutoencoderKL",
"_diffusers_version": "0.30.0.dev0",
"_name_or_path": "../checkpoints/flux-dev",
"_diffusers_version": "0.31.0.dev0",
"_name_or_path": "../sdxl-vae/",
"act_fn": "silu",
"block_out_channels": [
128,
@@ -25,8 +25,8 @@
"norm_num_groups": 32,
"out_channels": 3,
"sample_size": 1024,
"scaling_factor": 0.3611,
"shift_factor": 0.1159,
"scaling_factor": 1.5305,
"shift_factor": 0.0609,
"up_block_types": [
"UpDecoderBlock2D",
"UpDecoderBlock2D",


@@ -9,10 +9,18 @@
"EulerDiscreteScheduler"
],
"text_encoder": [
null,
null
],
"text_encoder_2": [
"transformers",
"CLIPTextModelWithProjection"
],
"tokenizer": [
null,
null
],
"tokenizer_2": [
"transformers",
"CLIPTokenizer"
],


@@ -19,13 +19,12 @@ from backend.nn.unet import IntegratedUNet2DConditionModel
from backend.diffusion_engine.sd15 import StableDiffusion
from backend.diffusion_engine.sd20 import StableDiffusion2
from backend.diffusion_engine.sdxl import StableDiffusionXL, StableDiffusionXLRefiner
from backend.diffusion_engine.sd35 import StableDiffusion3
from backend.diffusion_engine.sdxl import StableDiffusionXL
from backend.diffusion_engine.flux import Flux
from backend.diffusion_engine.chroma import Chroma
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, StableDiffusion3, Chroma, Flux]
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXL, StableDiffusion3, Flux]
logging.getLogger("diffusers").setLevel(logging.ERROR)
@@ -109,20 +108,17 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])
return model
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel', 'SD3Transformer2DModel', 'ChromaTransformer2DModel']:
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel', 'SD3Transformer2DModel']:
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!'
model_loader = None
if cls_name == 'UNet2DConditionModel':
model_loader = lambda c: IntegratedUNet2DConditionModel.from_config(c)
elif cls_name == 'FluxTransformer2DModel':
if cls_name == 'FluxTransformer2DModel':
from backend.nn.flux import IntegratedFluxTransformer2DModel
model_loader = lambda c: IntegratedFluxTransformer2DModel(**c)
elif cls_name == 'ChromaTransformer2DModel':
from backend.nn.chroma import IntegratedChromaTransformer2DModel
model_loader = lambda c: IntegratedChromaTransformer2DModel(**c)
elif cls_name == 'SD3Transformer2DModel':
from backend.nn.mmditx import MMDiTX
if cls_name == 'SD3Transformer2DModel':
from modules.models.sd35.mmditx import MMDiTX
model_loader = lambda c: MMDiTX(**c)
unet_config = guess.unet_config.copy()
@@ -178,7 +174,7 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
return None
def replace_state_dict(sd, asd, guess):
def replace_state_dict(sd, asd, guess, is_clip_g = False):
vae_key_prefix = guess.vae_key_prefix[0]
text_encoder_key_prefix = guess.text_encoder_key_prefix[0]
@@ -217,217 +213,19 @@ def replace_state_dict(sd, asd, guess):
for k, v in asd.items():
sd[vae_key_prefix + k] = v
## identify model type
flux_test_key = "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale"
sd3_test_key = "model.diffusion_model.final_layer.adaLN_modulation.1.bias"
legacy_test_key = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
model_type = "-"
if legacy_test_key in sd:
match sd[legacy_test_key].shape[1]:
case 768:
model_type = "sd1"
case 1024:
model_type = "sd2"
case 1280:
model_type = "xlrf" # sdxl refiner model
case 2048:
model_type = "sdxl"
elif flux_test_key in sd:
model_type = "flux"
elif sd3_test_key in sd:
model_type = "sd3"
## prefixes used by various model types for CLIP-L
prefix_L = {
"-" : None,
"sd1" : "cond_stage_model.transformer.",
"sd2" : None,
"xlrf": None,
"sdxl": "conditioner.embedders.0.transformer.",
"flux": "text_encoders.clip_l.transformer.",
"sd3" : "text_encoders.clip_l.transformer.",
}
## prefixes used by various model types for CLIP-G
prefix_G = {
"-" : None,
"sd1" : None,
"sd2" : None,
"xlrf": "conditioner.embedders.0.model.transformer.",
"sdxl": "conditioner.embedders.1.model.transformer.",
"flux": None,
"sd3" : "text_encoders.clip_g.transformer.",
}
## prefixes used by various model types for CLIP-H
prefix_H = {
"-" : None,
"sd1" : None,
"sd2" : "conditioner.embedders.0.model.",
"xlrf": None,
"sdxl": None,
"flux": None,
"sd3" : None,
}
## VAE format 0 (extracted from model, could be sd1, sd2, sdxl, sd3).
if "first_stage_model.decoder.conv_in.weight" in asd:
channels = asd["first_stage_model.decoder.conv_in.weight"].shape[1]
if model_type == "sd1" or model_type == "sd2" or model_type == "xlrf" or model_type == "sdxl":
if channels == 4:
for k, v in asd.items():
sd[k] = v
elif model_type == "sd3":
if channels == 16:
for k, v in asd.items():
sd[k] = v
## CLIP-H
CLIP_H = { # key to identify source model old_prefix
'cond_stage_model.model.ln_final.weight' : 'cond_stage_model.model.',
# 'text_model.encoder.layers.0.layer_norm1.bias' : 'text_model'. # would need converting
}
for CLIP_key in CLIP_H.keys():
if CLIP_key in asd and asd[CLIP_key].shape[0] == 1024:
new_prefix = prefix_H[model_type]
old_prefix = CLIP_H[CLIP_key]
if new_prefix is not None:
for k, v in asd.items():
new_k = k.replace(old_prefix, new_prefix)
sd[new_k] = v
## CLIP-G
CLIP_G = { # key to identify source model old_prefix
'conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias' : 'conditioner.embedders.1.model.transformer.',
'text_encoders.clip_g.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'text_encoders.clip_g.transformer.',
'text_model.encoder.layers.0.layer_norm1.bias' : '',
'transformer.resblocks.0.ln_1.bias' : 'transformer.'
}
for CLIP_key in CLIP_G.keys():
if CLIP_key in asd and asd[CLIP_key].shape[0] == 1280:
new_prefix = prefix_G[model_type]
old_prefix = CLIP_G[CLIP_key]
if new_prefix is not None:
if "resblocks" not in CLIP_key and model_type != "sd3": # need to convert
def convert_transformers(statedict, prefix_from, prefix_to, number):
keys_to_replace = {
"{}text_model.embeddings.position_embedding.weight" : "{}positional_embedding",
"{}text_model.embeddings.token_embedding.weight" : "{}token_embedding.weight",
"{}text_model.final_layer_norm.weight" : "{}ln_final.weight",
"{}text_model.final_layer_norm.bias" : "{}ln_final.bias",
"text_projection.weight" : "{}text_projection",
}
resblock_to_replace = {
"layer_norm1" : "ln_1",
"layer_norm2" : "ln_2",
"mlp.fc1" : "mlp.c_fc",
"mlp.fc2" : "mlp.c_proj",
"self_attn.out_proj" : "attn.out_proj" ,
}
for x in keys_to_replace: # remove trailing 'transformer.' from new prefix
k = x.format(prefix_from)
statedict[keys_to_replace[x].format(prefix_to[:-12])] = statedict.pop(k)
for resblock in range(number):
for y in ["weight", "bias"]:
for x in resblock_to_replace:
k = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, x, y)
k_to = "{}resblocks.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
statedict[k_to] = statedict.pop(k)
k_from = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, "self_attn.q_proj", y)
weightsQ = statedict.pop(k_from)
k_from = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, "self_attn.k_proj", y)
weightsK = statedict.pop(k_from)
k_from = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, "self_attn.v_proj", y)
weightsV = statedict.pop(k_from)
k_to = "{}resblocks.{}.attn.in_proj_{}".format(prefix_to, resblock, y)
statedict[k_to] = torch.cat((weightsQ, weightsK, weightsV))
return statedict
asd = convert_transformers(asd, old_prefix, new_prefix, 32)
for k, v in asd.items():
sd[k] = v
elif old_prefix == "":
for k, v in asd.items():
new_k = new_prefix + k
sd[new_k] = v
else:
for k, v in asd.items():
new_k = k.replace(old_prefix, new_prefix)
sd[new_k] = v
## CLIP-L
CLIP_L = { # key to identify source model old_prefix
'cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'cond_stage_model.transformer.',
'conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'conditioner.embedders.0.transformer.',
'text_encoders.clip_l.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'text_encoders.clip_l.transformer.',
'text_model.encoder.layers.0.layer_norm1.bias' : '',
'transformer.resblocks.0.ln_1.bias' : 'transformer.'
}
for CLIP_key in CLIP_L.keys():
if CLIP_key in asd and asd[CLIP_key].shape[0] == 768:
new_prefix = prefix_L[model_type]
old_prefix = CLIP_L[CLIP_key]
if new_prefix is not None:
if "resblocks" in CLIP_key: # need to convert
def transformers_convert(statedict, prefix_from, prefix_to, number):
keys_to_replace = {
"positional_embedding" : "{}text_model.embeddings.position_embedding.weight",
"token_embedding.weight": "{}text_model.embeddings.token_embedding.weight",
"ln_final.weight" : "{}text_model.final_layer_norm.weight",
"ln_final.bias" : "{}text_model.final_layer_norm.bias",
"text_projection" : "text_projection.weight",
}
resblock_to_replace = {
"ln_1" : "layer_norm1",
"ln_2" : "layer_norm2",
"mlp.c_fc" : "mlp.fc1",
"mlp.c_proj" : "mlp.fc2",
"attn.out_proj" : "self_attn.out_proj",
}
for k in keys_to_replace:
statedict[keys_to_replace[k].format(prefix_to)] = statedict.pop(k)
for resblock in range(number):
for y in ["weight", "bias"]:
for x in resblock_to_replace:
k = "{}resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
k_to = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
statedict[k_to] = statedict.pop(k)
k_from = "{}resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
weights = statedict.pop(k_from)
shape_from = weights.shape[0] // 3
for x in range(3):
p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
k_to = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
statedict[k_to] = weights[shape_from*x:shape_from*(x + 1)]
return statedict
asd = transformers_convert(asd, old_prefix, new_prefix, 12)
for k, v in asd.items():
sd[k] = v
elif old_prefix == "":
for k, v in asd.items():
new_k = new_prefix + k
sd[new_k] = v
else:
for k, v in asd.items():
new_k = k.replace(old_prefix, new_prefix)
sd[new_k] = v
if 'text_model.encoder.layers.0.layer_norm1.weight' in asd:
if is_clip_g:
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}clip_g.")]
for k in keys_to_delete:
del sd[k]
for k, v in asd.items():
sd[f"{text_encoder_key_prefix}clip_g.transformer.{k}"] = v
else:
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}clip_l.")]
for k in keys_to_delete:
del sd[k]
for k, v in asd.items():
sd[f"{text_encoder_key_prefix}clip_l.transformer.{k}"] = v
if 'encoder.block.0.layer.0.SelfAttention.k.weight' in asd:
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}t5xxl.")]
@@ -440,8 +238,9 @@ def replace_state_dict(sd, asd, guess):
def preprocess_state_dict(sd):
if not any(k.startswith("model.diffusion_model") for k in sd.keys()):
sd = {f"model.diffusion_model.{k}": v for k, v in sd.items()}
if any("double_block" in k for k in sd.keys()):
if not any(k.startswith("model.diffusion_model") for k in sd.keys()):
sd = {f"model.diffusion_model.{k}": v for k, v in sd.items()}
return sd
@@ -453,16 +252,14 @@ def split_state_dict(sd, additional_state_dicts: list = None):
if isinstance(additional_state_dicts, list):
for asd in additional_state_dicts:
is_clip_g = 'clip_g' in asd
asd = load_torch_file(asd)
sd = replace_state_dict(sd, asd, guess)
del asd
sd = replace_state_dict(sd, asd, guess, is_clip_g)
guess.clip_target = guess.clip_target(sd)
guess.model_type = guess.model_type(sd)
guess.ztsnr = 'ztsnr' in sd
sd = guess.process_vae_state_dict(sd)
state_dict = {
guess.unet_target: try_filter_state_dict(sd, guess.unet_key_prefix),
guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix)
@@ -482,18 +279,7 @@ def split_state_dict(sd, additional_state_dicts: list = None):
return state_dict, guess
# To be removed once PR merged on huggingface_guess
chroma_is_in_huggingface_guess = hasattr(huggingface_guess.model_list, "Chroma")
if not chroma_is_in_huggingface_guess:
class GuessChroma:
huggingface_repo = 'Chroma'
unet_extra_config = {
'guidance_out_dim': 3072,
'guidance_hidden_dim': 5120,
'guidance_n_layers': 5
}
unet_remove_config = ['guidance_embed']
@torch.inference_mode()
def forge_loader(sd, additional_state_dicts=None):
try:
@@ -501,17 +287,6 @@ def forge_loader(sd, additional_state_dicts=None):
except:
raise ValueError('Failed to recognize model type!')
if not chroma_is_in_huggingface_guess \
and estimated_config.huggingface_repo == "black-forest-labs/FLUX.1-schnell" \
and "transformer" in state_dicts \
and "distilled_guidance_layer.layers.0.in_layer.bias" in state_dicts["transformer"]:
estimated_config.huggingface_repo = GuessChroma.huggingface_repo
for x in GuessChroma.unet_extra_config:
estimated_config.unet_config[x] = GuessChroma.unet_extra_config[x]
for x in GuessChroma.unet_remove_config:
del estimated_config.unet_config[x]
state_dicts['text_encoder'] = state_dicts['text_encoder_2']
del state_dicts['text_encoder_2']
repo_name = estimated_config.huggingface_repo
local_path = os.path.join(dir_path, 'huggingface', repo_name)
@@ -527,47 +302,15 @@ def forge_loader(sd, additional_state_dicts=None):
if component is not None:
huggingface_components[component_name] = component
yaml_config = None
yaml_config_prediction_type = None
try:
import yaml
from pathlib import Path
config_filename = os.path.splitext(sd)[0] + '.yaml'
if Path(config_filename).is_file():
with open(config_filename, 'r') as stream:
yaml_config = yaml.safe_load(stream)
except ImportError:
pass
# Fix Huggingface prediction type using .yaml config or estimated config detection
# Fix Huggingface prediction type using estimated config detection
prediction_types = {
'EPS': 'epsilon',
'V_PREDICTION': 'v_prediction',
'EDM': 'edm',
}
if 'scheduler' in huggingface_components and hasattr(huggingface_components['scheduler'], 'config') and 'prediction_type' in huggingface_components['scheduler'].config:
huggingface_components['scheduler'].config.prediction_type = prediction_types.get(estimated_config.model_type.name, huggingface_components['scheduler'].config.prediction_type)
has_prediction_type = 'scheduler' in huggingface_components and hasattr(huggingface_components['scheduler'], 'config') and 'prediction_type' in huggingface_components['scheduler'].config
if yaml_config is not None:
yaml_config_prediction_type: str = (
yaml_config.get('model', {}).get('params', {}).get('parameterization', '')
or yaml_config.get('model', {}).get('params', {}).get('denoiser_config', {}).get('params', {}).get('scaling_config', {}).get('target', '')
)
if yaml_config_prediction_type == 'v' or yaml_config_prediction_type.endswith(".VScaling"):
yaml_config_prediction_type = 'v_prediction'
else:
# Use estimated prediction config if no suitable prediction type found
yaml_config_prediction_type = ''
if has_prediction_type:
if yaml_config_prediction_type:
huggingface_components['scheduler'].config.prediction_type = yaml_config_prediction_type
else:
huggingface_components['scheduler'].config.prediction_type = prediction_types.get(estimated_config.model_type.name, huggingface_components['scheduler'].config.prediction_type)
if not chroma_is_in_huggingface_guess and estimated_config.huggingface_repo == "Chroma":
return Chroma(estimated_config=estimated_config, huggingface_components=huggingface_components)
for M in possible_models:
if any(isinstance(estimated_config, x) for x in M.matched_guesses):
return M(estimated_config=estimated_config, huggingface_components=huggingface_components)
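
As a side note on the CLIP-G handling in `replace_state_dict` above: a minimal, hedged sketch of the q/k/v merge that `convert_transformers` performs for a single layer (placeholder tensors; the variable names here are illustrative only):

```python
import torch

# Placeholder projection weights for one CLIP-G layer (hidden size 1280).
q_proj = torch.zeros(1280, 1280)  # ...text_model.encoder.layers.0.self_attn.q_proj.weight
k_proj = torch.zeros(1280, 1280)  # ...text_model.encoder.layers.0.self_attn.k_proj.weight
v_proj = torch.zeros(1280, 1280)  # ...text_model.encoder.layers.0.self_attn.v_proj.weight

# The conversion concatenates them into the fused in_proj layout used by the resblocks.* keys:
in_proj_weight = torch.cat((q_proj, k_proj, v_proj))  # -> resblocks.0.attn.in_proj_weight
print(in_proj_weight.shape)  # torch.Size([3840, 1280])
```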


@@ -560,11 +560,6 @@ def unload_model_clones(model):
def free_memory(memory_required, device, keep_loaded=[], free_all=False):
# this check fully unloads any 'abandoned' models
for i in range(len(current_loaded_models) - 1, -1, -1):
if sys.getrefcount(current_loaded_models[i].model) <= 2:
current_loaded_models.pop(i).model_unload(avoid_model_moving=True)
if free_all:
memory_required = 1e30
print(f"[Unload] Trying to free all memory for {device} with {len(keep_loaded)} models keep loaded ... ", end="")


@@ -2,9 +2,6 @@ import math
import torch
import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.flux.pipeline_flux import calculate_shift
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
betas = []
@@ -44,6 +41,10 @@ def time_snr_shift(alpha, t):
return alpha * t / (1 + (alpha - 1) * t)
def flux_time_shift(mu, sigma, t):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
alphas_bar_sqrt = alphas_cumprod.sqrt()
@@ -274,31 +275,13 @@ class PredictionDiscreteFlow(AbstractPrediction):
return timestep
return self.shift * timestep / (1 + (self.shift - 1) * timestep)
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 1.0
if percent >= 1.0:
return 0.0
return 1.0 - percent
class PredictionFlux(AbstractPrediction):
def __init__(self, seq_len=4096, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15, pseudo_timestep_range=10000, mu=None):
super().__init__(sigma_data=1.0, prediction_type='const')
self.mu = mu
self.pseudo_timestep_range = pseudo_timestep_range
self.apply_mu_transform(seq_len=seq_len, base_seq_len=base_seq_len, max_seq_len=max_seq_len, base_shift=base_shift, max_shift=max_shift, mu=mu)
def apply_mu_transform(self, seq_len=4096, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15, mu=None):
# TODO: Add an UI option to let user choose whether to call this in each generation to bind latent size to sigmas
# And some cases may want their own mu values or other parameters
if mu is None:
self.mu = calculate_shift(image_seq_len=seq_len, base_seq_len=base_seq_len, max_seq_len=max_seq_len, base_shift=base_shift, max_shift=max_shift)
else:
self.mu = mu
sigmas = torch.arange(1, self.pseudo_timestep_range + 1, 1) / self.pseudo_timestep_range
sigmas = FlowMatchEulerDiscreteScheduler.time_shift(None, self.mu, 1.0, sigmas)
self.register_buffer('sigmas', sigmas)
def __init__(self, sigma_data=1.0, prediction_type='const', shift=1.15, timesteps=10000):
super().__init__(sigma_data=sigma_data, prediction_type=prediction_type)
self.shift = shift
ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps))
self.register_buffer('sigmas', ts)
@property
def sigma_min(self):
@@ -312,7 +295,7 @@ class PredictionFlux(AbstractPrediction):
return sigma
def sigma(self, timestep):
return timestep
return flux_time_shift(self.shift, 1.0, timestep)
def percent_to_sigma(self, percent):
if percent <= 0.0:
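
To make the schedule change above concrete, here is a minimal, standalone sketch of how the rewritten `PredictionFlux` builds its sigma table from `shift` and `timesteps`; it mirrors the code in this diff, but the wrapper function itself is hypothetical:

```python
import math
import torch

def flux_time_shift(mu, sigma, t):
    # same formula as in the diff above
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

def build_flux_sigmas(shift: float = 1.15, timesteps: int = 10000) -> torch.Tensor:
    t = torch.arange(1, timesteps + 1) / timesteps   # t in (0, 1]
    return flux_time_shift(shift, 1.0, t)            # warped noise levels, ending at 1.0

sigmas = build_flux_sigmas()
print(float(sigmas.min()), float(sigmas.max()))      # smallest and largest sigma in the table
```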


@@ -1,307 +0,0 @@
# implementation of Chroma for Forge, inspired by https://github.com/lodestone-rock/ComfyUI_FluxMod
from dataclasses import dataclass
import math
import torch
from torch import nn
from einops import rearrange, repeat
from backend.attention import attention_function
from backend.utils import fp16_fix, tensor2parameter
from backend.nn.flux import attention, rope, timestep_embedding, EmbedND, MLPEmbedder, RMSNorm, QKNorm, SelfAttention
class Approximator(nn.Module):
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers = 4):
super().__init__()
self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim) for x in range( n_layers)])
self.norms = nn.ModuleList([RMSNorm( hidden_dim) for x in range( n_layers)])
self.out_proj = nn.Linear(hidden_dim, out_dim)
def forward(self, x):
x = self.in_proj(x)
for layer, norms in zip(self.layers, self.norms):
x = x + layer(norms(x))
x = self.out_proj(x)
return x
@dataclass
class ModulationOut:
shift: torch.Tensor
scale: torch.Tensor
gate: torch.Tensor
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size, num_heads, mlp_ratio, qkv_bias=False):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
def forward(self, img, txt, mod, pe):
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = mod
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
B, L, _ = img_qkv.shape
H = self.num_heads
D = img_qkv.shape[-1] // (3 * H)
img_q, img_k, img_v = img_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
B, L, _ = txt_qkv.shape
txt_q, txt_k, txt_v = txt_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
attn = attention(q, k, v, pe=pe)
txt_attn, img_attn = attn[:, :txt.shape[1]], attn[:, txt.shape[1]:]
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
txt = fp16_fix(txt)
return img, txt
class SingleStreamBlock(nn.Module):
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, qk_scale=None):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.norm = QKNorm(head_dim)
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_act = nn.GELU(approximate="tanh")
def forward(self, x, mod, pe):
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
del x_mod
# q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
qkv = qkv.view(qkv.size(0), qkv.size(1), 3, self.num_heads, self.hidden_size // self.num_heads)
q, k, v = qkv.permute(2, 0, 3, 1, 4)
del qkv
q, k = self.norm(q, k, v)
attn = attention(q, k, v, pe=pe)
del q, k, v, pe
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), dim=2))
del attn, mlp
x = x + mod.gate * output
x = fp16_fix(x)
return x
class LastLayer(nn.Module):
def __init__(self, hidden_size, patch_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
def forward(self, x, mod):
shift, scale = mod
shift = shift.squeeze(1)
scale = scale.squeeze(1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x
class IntegratedChromaTransformer2DModel(nn.Module):
def __init__(self, in_channels: int, vec_in_dim: int, context_in_dim: int, hidden_size: int, mlp_ratio: float, num_heads: int, depth: int, depth_single_blocks: int, axes_dim: list[int], theta: int, qkv_bias: bool, guidance_out_dim: int, guidance_hidden_dim: int, guidance_n_layers: int):
super().__init__()
self.in_channels = in_channels * 4
self.out_channels = self.in_channels
if hidden_size % num_heads != 0:
raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")
pe_dim = hidden_size // num_heads
if sum(axes_dim) != pe_dim:
raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = hidden_size
self.num_heads = num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.distilled_guidance_layer = Approximator(64, guidance_out_dim, guidance_hidden_dim, guidance_n_layers)
self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
)
for _ in range(depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio)
for _ in range(depth_single_blocks)
]
)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
@staticmethod
def distribute_modulations(tensor, single_block_count: int = 38, double_blocks_count: int = 19):
"""
Distributes slices of the tensor into the block_dict as ModulationOut objects.
Args:
tensor (torch.Tensor): Input tensor with shape [batch_size, vectors, dim].
"""
batch_size, vectors, dim = tensor.shape
block_dict = {}
for i in range(single_block_count):
key = f"single_blocks.{i}.modulation.lin"
block_dict[key] = None
for i in range(double_blocks_count):
key = f"double_blocks.{i}.img_mod.lin"
block_dict[key] = None
for i in range(double_blocks_count):
key = f"double_blocks.{i}.txt_mod.lin"
block_dict[key] = None
block_dict["final_layer.adaLN_modulation.1"] = None
idx = 0 # Index to keep track of the vector slices
for key in block_dict.keys():
if "single_blocks" in key:
# Single block: 1 ModulationOut
block_dict[key] = ModulationOut(
shift=tensor[:, idx:idx+1, :],
scale=tensor[:, idx+1:idx+2, :],
gate=tensor[:, idx+2:idx+3, :]
)
idx += 3 # Advance by 3 vectors
elif "img_mod" in key:
# Double block: List of 2 ModulationOut
double_block = []
for _ in range(2): # Create 2 ModulationOut objects
double_block.append(
ModulationOut(
shift=tensor[:, idx:idx+1, :],
scale=tensor[:, idx+1:idx+2, :],
gate=tensor[:, idx+2:idx+3, :]
)
)
idx += 3 # Advance by 3 vectors per ModulationOut
block_dict[key] = double_block
elif "txt_mod" in key:
# Double block: List of 2 ModulationOut
double_block = []
for _ in range(2): # Create 2 ModulationOut objects
double_block.append(
ModulationOut(
shift=tensor[:, idx:idx+1, :],
scale=tensor[:, idx+1:idx+2, :],
gate=tensor[:, idx+2:idx+3, :]
)
)
idx += 3 # Advance by 3 vectors per ModulationOut
block_dict[key] = double_block
elif "final_layer" in key:
# Final layer: 1 ModulationOut
block_dict[key] = [
tensor[:, idx:idx+1, :],
tensor[:, idx+1:idx+2, :],
]
idx += 2 # Advance by 2 vectors
return block_dict
def inner_forward(self, img, img_ids, txt, txt_ids, timesteps):
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
img = self.img_in(img)
device = img.device
dtype = img.dtype # torch.bfloat16
nb_double_block = len(self.double_blocks)
nb_single_block = len(self.single_blocks)
mod_index_length = nb_double_block*12 + nb_single_block*3 + 2
distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(device=device, dtype=dtype)
distil_guidance = timestep_embedding(torch.zeros_like(timesteps), 16).to(device=device, dtype=dtype)
modulation_index = timestep_embedding(torch.arange(mod_index_length), 32).to(device=device, dtype=dtype)
modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1)
timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1)
mod_vectors = self.distilled_guidance_layer(input_vec)
mod_vectors_dict = self.distribute_modulations(mod_vectors, nb_single_block, nb_double_block)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
del txt_ids, img_ids
pe = self.pe_embedder(ids)
del ids
for i, block in enumerate(self.double_blocks):
img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"]
txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"]
double_mod = [img_mod, txt_mod]
img, txt = block(img=img, txt=txt, mod=double_mod, pe=pe)
img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks):
single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
img = block(img, mod=single_mod, pe=pe)
del pe
img = img[:, txt.shape[1]:, ...]
final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"]
img = self.final_layer(img, final_mod)
return img
def forward(self, x, timestep, context, **kwargs):
bs, c, h, w = x.shape
input_device = x.device
input_dtype = x.dtype
patch_size = 2
pad_h = (patch_size - x.shape[-2] % patch_size) % patch_size
pad_w = (patch_size - x.shape[-1] % patch_size) % patch_size
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="circular")
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
del x, pad_h, pad_w
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=input_device, dtype=input_dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=input_device, dtype=input_dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=input_device, dtype=input_dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=input_device, dtype=input_dtype)
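# position ids: text tokens keep all-zero ids while image tokens carry (0, row, col);
# inner_forward concatenates both and hands the result to pe_embedder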
del input_device, input_dtype
out = self.inner_forward(img, img_ids, context, txt_ids, timestep)
del img, img_ids, txt_ids, timestep, context
out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h, :w]
del h_len, w_len, bs
return out
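# Illustrative sketch (not part of this file): the 2x2 patchify round trip used by forward()
# above, shown on a dummy latent. Shapes are made up, and rearrange is assumed to be
# einops.rearrange, as the patterns suggest.
import torch
from einops import rearrange
b, c, h, w = 1, 16, 96, 64                      # assume already padded to multiples of 2
x = torch.randn(b, c, h, w)
tokens = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
print(tokens.shape)                             # torch.Size([1, 1536, 64]): (h/2)*(w/2) tokens of length c*4
restored = rearrange(tokens, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h // 2, w=w // 2, ph=2, pw=2)
assert torch.equal(x, restored)                 # the packing/unpacking is lossless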

View File

@@ -4,11 +4,10 @@ import math
from backend.attention import attention_pytorch as attention_function
from transformers.activations import NewGELUActivation
activations = {
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
"relu": torch.nn.functional.relu,
"gelu_new": lambda a: NewGELUActivation()(a),
"gelu_new": lambda a: NewGELUActivation()(a)
}

View File

@@ -12,7 +12,6 @@ stash = {}
def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None, bias_fn=None):
scale_weight = getattr(layer, 'scale_weight', None)
patches = getattr(layer, 'forge_online_loras', None)
weight_patches, bias_patches = None, None
@@ -33,8 +32,6 @@ def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None,
weight = weight_fn(weight)
if weight_args is not None:
weight = weight.to(**weight_args)
if scale_weight is not None:
weight = weight*scale_weight.to(device=weight.device, dtype=weight.dtype)
if weight_patches is not None:
weight = merge_lora_to_weight(patches=weight_patches, weight=weight, key="online weight lora", computation_dtype=weight.dtype)
@@ -130,7 +127,6 @@ class ForgeOperations:
self.out_features = out_features
self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
self.weight = None
self.scale_weight = None
self.bias = None
self.parameters_manual_cast = current_manual_cast_enabled
@@ -138,8 +134,6 @@ class ForgeOperations:
if hasattr(self, 'dummy'):
if prefix + 'weight' in state_dict:
self.weight = torch.nn.Parameter(state_dict[prefix + 'weight'].to(self.dummy))
if prefix + 'scale_weight' in state_dict:
self.scale_weight = torch.nn.Parameter(state_dict[prefix + 'scale_weight'])
if prefix + 'bias' in state_dict:
self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
del self.dummy

View File

@@ -13,7 +13,6 @@ quants_mapping = {
gguf.GGMLQuantizationType.Q5_K: gguf.Q5_K,
gguf.GGMLQuantizationType.Q6_K: gguf.Q6_K,
gguf.GGMLQuantizationType.Q8_0: gguf.Q8_0,
gguf.GGMLQuantizationType.BF16: gguf.BF16,
}

View File

@@ -49,33 +49,24 @@ def model_lora_keys_unet(model, key_map={}):
@torch.inference_mode()
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function):
# Modified from https://github.com/comfyanonymous/ComfyUI/blob/80a44b97f5cbcb890896e2b9e65d177f1ac6a588/comfy/weight_adapter/base.py#L42
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype):
# Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L33
dora_scale = memory_management.cast_to_device(dora_scale, weight.device, computation_dtype)
lora_diff *= alpha
weight_calc = weight + function(lora_diff).type(weight.dtype)
wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0]
if wd_on_output_axis:
weight_norm = (
weight.reshape(weight.shape[0], -1)
.norm(dim=1, keepdim=True)
.reshape(weight.shape[0], *[1] * (weight.dim() - 1))
)
else:
weight_norm = (
weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
.transpose(0, 1)
)
weight_norm = weight_norm + torch.finfo(weight.dtype).eps
weight_calc = weight + lora_diff.type(weight.dtype)
weight_norm = (
weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
.transpose(0, 1)
)
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
if strength != 1.0:
weight_calc -= weight
weight += strength * (weight_calc)
weight += strength * weight_calc
else:
weight[:] = weight_calc
return weight
@@ -138,15 +129,10 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
elif patch_type == "set":
weight.copy_(v[0])
elif patch_type == "lora":
mat1 = memory_management.cast_to_device(v[0], weight.device, computation_dtype)
mat2 = memory_management.cast_to_device(v[1], weight.device, computation_dtype)
dora_scale = v[4]
if v[2] is not None:
alpha = v[2] / mat2.shape[0]
else:
@@ -156,26 +142,12 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
mat3 = memory_management.cast_to_device(v[3], weight.device, computation_dtype)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))
try:
lora_diff = lora_diff.reshape(weight.shape)
except:
if weight.shape[1] < lora_diff.shape[1]:
expand_factor = (lora_diff.shape[1] - weight.shape[1])
weight = torch.nn.functional.pad(weight, (0, expand_factor), mode='constant', value=0)
elif weight.shape[1] > lora_diff.shape[1]:
# expand factor should be 1*64 (for FluxTools Canny or Depth), or 5*64 (for FluxTools Fill)
expand_factor = (weight.shape[1] - lora_diff.shape[1])
lora_diff = torch.nn.functional.pad(lora_diff, (0, expand_factor), mode='constant', value=0)
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function)
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
print("ERROR {} {} {}".format(patch_type, key, e))
raise e
@@ -220,7 +192,7 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function)
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
@@ -258,51 +230,29 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
try:
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function)
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
print("ERROR {} {} {}".format(patch_type, key, e))
raise e
elif patch_type == "glora":
dora_scale = v[5]
old_glora = False
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
old_glora = True
if v[4] is not None:
alpha = v[4] / v[0].shape[0]
else:
alpha = 1.0
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
pass
else:
old_glora = False
dora_scale = v[5]
a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, computation_dtype)
a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, computation_dtype)
b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, computation_dtype)
b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, computation_dtype)
if v[4] is None:
alpha = 1.0
else:
if old_glora:
alpha = v[4] / v[0].shape[0]
else:
alpha = v[4] / v[1].shape[0]
try:
if old_glora:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=computation_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
else:
if weight.dim() > 2:
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=computation_dtype), a1), a2).reshape(weight.shape)
else:
lora_diff = torch.mm(torch.mm(weight.to(dtype=computation_dtype), a1), a2).reshape(weight.shape)
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function)
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
@@ -349,10 +299,10 @@ class LoraLoader:
self.loaded_hash = str([])
@torch.inference_mode()
def refresh(self, lora_patches, offload_device=torch.device('cpu'), force_refresh=False):
def refresh(self, lora_patches, offload_device=torch.device('cpu')):
hashes = str(list(lora_patches.keys()))
if hashes == self.loaded_hash and not force_refresh:
if hashes == self.loaded_hash:
return
# Merge Patches

View File

@@ -6,8 +6,6 @@ from backend.text_processing import parsing, emphasis
from backend.text_processing.textual_inversion import EmbeddingDatabase
from backend import memory_management
from modules.shared import opts
PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
last_extra_generation_params = {}
@@ -69,7 +67,7 @@ class ClassicTextProcessingEngine:
self.text_encoder = text_encoder
self.tokenizer = tokenizer
self.emphasis = emphasis.get_current_option(opts.emphasis)()
self.emphasis = emphasis.get_current_option(emphasis_name)()
self.text_projection = text_projection
self.minimal_clip_skip = minimal_clip_skip
self.clip_skip = clip_skip
@@ -141,14 +139,14 @@ class ClassicTextProcessingEngine:
if self.return_pooled:
pooled_output = outputs.pooler_output
if self.text_projection and self.embedding_key != 'clip_l':
if self.text_projection and self.embedding_key is not 'clip_l':
pooled_output = self.text_encoder.transformer.text_projection(pooled_output)
z.pooled = pooled_output
return z
def tokenize_line(self, line):
parsed = parsing.parse_prompt_attention(line, self.emphasis.name)
parsed = parsing.parse_prompt_attention(line)
tokenized = self.tokenize([text for text, _ in parsed])
@@ -250,8 +248,6 @@ class ClassicTextProcessingEngine:
return batch_chunks, token_count
def __call__(self, texts):
self.emphasis = emphasis.get_current_option(opts.emphasis)()
batch_chunks, token_count = self.process_texts(texts)
used_embeddings = {}

View File

@@ -20,7 +20,7 @@ re_attention = re.compile(r"""
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
def parse_prompt_attention(text, emphasis):
def parse_prompt_attention(text):
res = []
round_brackets = []
square_brackets = []
@@ -32,48 +32,44 @@ def parse_prompt_attention(text, emphasis):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
if emphasis == "None":
# interpret literally
res = [[text, 1.0]]
else:
for m in re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
for m in re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith('\\'):
res.append([text[1:], 1.0])
elif text == '(':
round_brackets.append(len(res))
elif text == '[':
square_brackets.append(len(res))
elif weight is not None and round_brackets:
multiply_range(round_brackets.pop(), float(weight))
elif text == ')' and round_brackets:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == ']' and square_brackets:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
parts = re.split(re_break, text)
for i, part in enumerate(parts):
if i > 0:
res.append(["BREAK", -1])
res.append([part, 1.0])
if text.startswith('\\'):
res.append([text[1:], 1.0])
elif text == '(':
round_brackets.append(len(res))
elif text == '[':
square_brackets.append(len(res))
elif weight is not None and round_brackets:
multiply_range(round_brackets.pop(), float(weight))
elif text == ')' and round_brackets:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == ']' and square_brackets:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
parts = re.split(re_break, text)
for i, part in enumerate(parts):
if i > 0:
res.append(["BREAK", -1])
res.append([part, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
if len(res) == 0:
res = [["", 1.0]]
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res
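# Illustrative example (assumed behaviour of the bracket handling above):
#   parse_prompt_attention("a (cat:1.2) and dog")
#   -> [['a ', 1.0], ['cat', 1.2], [' and dog', 1.0]]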

View File

@@ -4,8 +4,6 @@ from collections import namedtuple
from backend.text_processing import parsing, emphasis
from backend import memory_management
from modules.shared import opts
PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
@@ -23,7 +21,7 @@ class T5TextProcessingEngine:
self.text_encoder = text_encoder.transformer
self.tokenizer = tokenizer
self.emphasis = emphasis.get_current_option(opts.emphasis)()
self.emphasis = emphasis.get_current_option(emphasis_name)()
self.min_length = min_length
self.id_end = 1
self.id_pad = 0
@@ -66,7 +64,7 @@ class T5TextProcessingEngine:
return z
def tokenize_line(self, line):
parsed = parsing.parse_prompt_attention(line, self.emphasis.name)
parsed = parsing.parse_prompt_attention(line)
tokenized = self.tokenize([text for text, _ in parsed])
@@ -113,8 +111,6 @@ class T5TextProcessingEngine:
zs = []
cache = {}
self.emphasis = emphasis.get_current_option(opts.emphasis)()
for line in texts:
if line in cache:
line_z_values = cache[line]

View File

@@ -25,11 +25,6 @@ class ExtraOptionsSection(scripts.Script):
extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img")
not_allowed = ['sd_model_checkpoint', 'sd_vae', 'CLIP_stop_at_last_layers', 'forge_additional_modules']
for na in not_allowed:
if na in extra_options:
extra_options.remove(na)
mapping = {k: v for v, k in infotext_utils.infotext_to_setting_name_mapping}
with gr.Blocks() as interface:

View File

@@ -1,287 +1,69 @@
import spaces
import os
import gradio as gr
import gc
try:
import moviepy.editor as mp
got_mp = True
except:
got_mp = False
from gradio_imageslider import ImageSlider
from loadimg import load_img
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import glob
import pathlib
from PIL import Image
import numpy
# torch.set_float32_matmul_precision(["high", "highest"][0])
transform_image = None
birefnet = None
def load_model(model):
global birefnet
birefnet = None
gc.collect()
torch.cuda.empty_cache()
birefnet = AutoModelForImageSegmentation.from_pretrained(
"ZhengPeng7/"+model, trust_remote_code=True
)
birefnet.eval()
birefnet.half()
spaces.automatically_move_to_gpu_when_forward(birefnet)
os.environ['HOME'] = spaces.convert_root_path() + 'home'
with spaces.capture_gpu_object() as birefnet_gpu_obj:
load_model("BiRefNet_HR")
def common_setup(w, h):
global transform_image
transform_image = transforms.Compose(
[
transforms.Resize((w, h)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
birefnet = AutoModelForImageSegmentation.from_pretrained(
"ZhengPeng7/BiRefNet", trust_remote_code=True
)
spaces.automatically_move_to_gpu_when_forward(birefnet)
transform_image = transforms.Compose(
[
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
@spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True)
def process(image, save_flat, bg_colour):
def fn(image):
im = load_img(image, output_type="pil")
im = im.convert("RGB")
image_size = im.size
origin = im.copy()
image = load_img(im)
input_image = transform_image(image).unsqueeze(0).to(spaces.gpu).to(torch.float16)
input_images = transform_image(image).unsqueeze(0).to(spaces.gpu)
# Prediction
with torch.no_grad():
preds = birefnet(input_image)[-1].sigmoid().cpu()
preds = birefnet(input_images)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image_size)
image.putalpha(mask)
if save_flat:
bg_colour += "FF"
colour_rgb = tuple(int(bg_colour[i:i+2], 16) for i in (1, 3, 5, 7))
background = Image.new("RGBA", image_size, colour_rgb)
image = Image.alpha_composite(background, image)
image = image.convert("RGB")
return image
# video processing based on https://huggingface.co/spaces/brokerrobin/video-background-removal/blob/main/app.py
@spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True)
def video_process(video, bg_colour):
# Load the video using moviepy
video = mp.VideoFileClip(video)
fps = video.fps
# Extract audio from the video
audio = video.audio
# Extract frames at the specified FPS
frames = video.iter_frames(fps=fps)
# Process each frame for background removal
processed_frames = []
for i, frame in enumerate(frames):
print (f"birefnet [video]: frame {i+1}", end='\r', flush=True)
image = Image.fromarray(frame)
if i == 0:
image_size = image.size
colour_rgb = tuple(int(bg_colour[i:i+2], 16) for i in (1, 3, 5))
background = Image.new("RGBA", image_size, colour_rgb + (255,))
input_image = transform_image(image).unsqueeze(0).to(spaces.gpu).to(torch.float16)
# Prediction
with torch.no_grad():
preds = birefnet(input_image)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image_size)
# Apply mask and composite
image.putalpha(mask)
processed_image = Image.alpha_composite(background, image)
processed_frames.append(numpy.array(processed_image))
# Create a new video from the processed frames
processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)
# Add the original audio back to the processed video
processed_video = processed_video.set_audio(audio)
# Save the processed video using modified original filename (goes to gradio temp)
filename, _ = os.path.splitext(video.filename)
filename += "-birefnet.mp4"
processed_video.write_videofile(filename, codec="libx264")
return filename
return (image, origin)
@spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True)
def batch_process(input_folder, output_folder, save_png, save_flat, bg_colour):
# Ensure output folder exists
os.makedirs(output_folder, exist_ok=True)
# Supported image extensions
image_extensions = ['.jpg', '.jpeg', '.jfif', '.png', '.bmp', '.webp', ".avif"]
# Collect all image files from input folder
input_images = []
for ext in image_extensions:
input_images.extend(glob.glob(os.path.join(input_folder, f'*{ext}')))
if save_flat:
bg_colour += "FF"
colour_rgb = tuple(int(bg_colour[i:i+2], 16) for i in (1, 3, 5, 7))
# Process each image
processed_images = []
for i, image_path in enumerate(input_images):
print (f"birefnet [batch]: image {i+1}", end='\r', flush=True)
try:
# Load image
im = load_img(image_path, output_type="pil")
im = im.convert("RGB")
image_size = im.size
image = load_img(im)
# Prepare image for processing
input_image = transform_image(image).unsqueeze(0).to(spaces.gpu).to(torch.float16)
# Prediction
with torch.no_grad():
preds = birefnet(input_image)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image_size)
# Apply mask
image.putalpha(mask)
# Save processed image
output_filename = os.path.join(output_folder, f"{pathlib.Path(image_path).name}")
if save_flat:
background = Image.new("RGBA", image_size, colour_rgb)
image = Image.alpha_composite(background, image)
image = image.convert("RGB")
elif output_filename.lower().endswith(".jpg") or output_filename.lower().endswith(".jpeg"):
# JPEG has no alpha channel, so append a .png extension (appended rather than replaced, to avoid overwriting existing files)
output_filename += ".png"
if save_png and not output_filename.lower().endswith(".png"):
output_filename += ".png"
image.save(output_filename)
processed_images.append(output_filename)
except Exception as e:
print(f"Error processing {image_path}: {str(e)}")
return processed_images
slider1 = ImageSlider(label="birefnet", type="pil")
slider2 = ImageSlider(label="birefnet", type="pil")
image = gr.Image(label="Upload an image")
text = gr.Textbox(label="Paste an image URL")
def unload():
global birefnet, transform_image
birefnet = None
transform_image = None
gc.collect()
torch.cuda.empty_cache()
chameleon = load_img(spaces.convert_root_path() + "chameleon.jpg", output_type="pil")
url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
tab1 = gr.Interface(
fn, inputs=image, outputs=slider1, examples=[chameleon], api_name="image", allow_flagging="never"
)
tab2 = gr.Interface(fn, inputs=text, outputs=slider2, examples=[url], api_name="text", allow_flagging="never")
css = """
.gradio-container {
max-width: 1280px !important;
}
footer {
display: none !important;
}
"""
with gr.Blocks(css=css, analytics_enabled=False) as demo:
gr.Markdown("# birefnet for background removal")
with gr.Tab("image"):
with gr.Row():
with gr.Column():
image = gr.Image(label="Upload an image", type='pil', height=584)
go_image = gr.Button("Remove background")
with gr.Column():
result1 = gr.Image(label="birefnet", type="pil", height=544)
with gr.Tab("URL"):
with gr.Row():
with gr.Column():
text = gr.Textbox(label="URL to image, or local path to image", max_lines=1)
go_text = gr.Button("Remove background")
with gr.Column():
result2 = gr.Image(label="birefnet", type="pil", height=544)
if got_mp:
with gr.Tab("video"):
with gr.Row():
with gr.Column():
video = gr.Video(label="Upload a video", height=584)
go_video = gr.Button("Remove background")
with gr.Column():
result4 = gr.Video(label="birefnet", height=544, show_share_button=False)
with gr.Tab("batch"):
with gr.Row():
with gr.Column():
input_dir = gr.Textbox(label="Input folder path", max_lines=1)
output_dir = gr.Textbox(label="Output folder path (saved images will overwrite existing files)", max_lines=1)
always_png = gr.Checkbox(label="Always save as PNG", value=True)
go_batch = gr.Button("Remove background(s)")
with gr.Column():
result3 = gr.File(label="Processed image(s)", type="filepath", file_count="multiple")
with gr.Tab("options"):
gr.Markdown("*HR* : high resolution; *matting* : better with transparency; *lite* : faster.")
model = gr.Dropdown(label="Model (download on selection, see console for progress)",
choices=["BiRefNet_512x512", "BiRefNet", "BiRefNet_HR", "BiRefNet-matting", "BiRefNet_HR-matting", "BiRefNet_lite", "BiRefNet_lite-2K", "BiRefNet-portrait", "BiRefNet-COD", "BiRefNet-DIS5K", "BiRefNet-DIS5k-TR_TEs", "BiRefNet-HRSOD"], value="BiRefNet_HR", type="value")
gr.Markdown("Regular models trained at 1024 \u00D7 1024; HR models trained at 2048 \u00D7 2048; 2K model trained at 2560 \u00D7 1440.")
gr.Markdown("Greater processing image size will typically give more accurate results, but also requires more VRAM (shared memory works well).")
with gr.Row():
proc_sizeW = gr.Slider(label="birefnet processing image width",
minimum=256, maximum=2560, value=2048, step=32)
proc_sizeH = gr.Slider(label="birefnet processing image height",
minimum=256, maximum=2048, value=2048, step=32)
with gr.Row():
save_flat = gr.Checkbox(label="Save flat (no mask)", value=False)
bg_colour = gr.ColorPicker(label="Background colour for saving flat, and video", value="#00FF00", visible=True, interactive=True)
model.change(fn=load_model, inputs=model, outputs=None)
gr.Markdown("### https://github.com/ZhengPeng7/BiRefNet\n### https://huggingface.co/ZhengPeng7")
go_image.click(fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(fn=process, inputs=[image, save_flat, bg_colour], outputs=result1)
go_text.click( fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(fn=process, inputs=[text, save_flat, bg_colour], outputs=result2)
if got_mp:
go_video.click(fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(
fn=video_process, inputs=[video, bg_colour], outputs=result4)
go_batch.click(fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(
fn=batch_process, inputs=[input_dir, output_dir, always_png, save_flat, bg_colour], outputs=result3)
demo.unload(unload)
demo = gr.TabbedInterface(
[tab1, tab2], ["image", "text"], title="birefnet for background removal"
)
if __name__ == "__main__":
demo.launch(inbrowser=True)

View File

@@ -3,6 +3,7 @@ import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
import os
import requests
import copy
from PIL import Image, ImageDraw, ImageFont
@@ -13,22 +14,33 @@ import matplotlib.patches as patches
import random
import numpy as np
from modules import shared
# import subprocess
# subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
if not str(filename).endswith("modeling_florence2.py"):
return get_imports(filename)
imports = get_imports(filename)
imports.remove("flash_attn")
return imports
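# Florence-2's remote modeling code unconditionally lists flash_attn as an import;
# stripping it here lets the model load on machines without that package installed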
with spaces.capture_gpu_object() as gpu_object:
models = {
# 'microsoft/Florence-2-large-ft': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large-ft', attn_implementation='sdpa', trust_remote_code=True).to("cuda").eval(),
'microsoft/Florence-2-large': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True).to("cuda").eval(),
# 'microsoft/Florence-2-base-ft': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base-ft', trust_remote_code=True).to("cuda").eval(),
'microsoft/Florence-2-base': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to("cuda").eval(),
}
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
models = {
# 'microsoft/Florence-2-large-ft': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large-ft', attn_implementation='sdpa', trust_remote_code=True).to("cuda").eval(),
'microsoft/Florence-2-large': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True).to("cuda").eval(),
# 'microsoft/Florence-2-base-ft': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base-ft', trust_remote_code=True).to("cuda").eval(),
# 'microsoft/Florence-2-base': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to("cuda").eval(),
}
processors = {
# 'microsoft/Florence-2-large-ft': AutoProcessor.from_pretrained('microsoft/Florence-2-large-ft', trust_remote_code=True),
'microsoft/Florence-2-large': AutoProcessor.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True),
# 'microsoft/Florence-2-base-ft': AutoProcessor.from_pretrained('microsoft/Florence-2-base-ft', trust_remote_code=True),
'microsoft/Florence-2-base': AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True),
# 'microsoft/Florence-2-base': AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True),
}
@@ -122,8 +134,8 @@ def draw_ocr_bboxes(image, prediction):
fill=color)
return image
def process_image(image, task_prompt, text_input=None, model_id='microsoft/Florence-2-large'):
image = Image.fromarray(image) # Convert NumPy array to PIL Image
if task_prompt == 'Caption':
task_prompt = '<CAPTION>'
results = run_example(task_prompt, image, model_id=model_id)
@@ -222,61 +234,6 @@ def process_image(image, task_prompt, text_input=None, model_id='microsoft/Flore
else:
return "", None # Return empty string and None for unknown task prompts
@spaces.GPU(gpu_objects=[gpu_object], manual_load=False)
def run_example_batch(directory, task_prompt, model_id='microsoft/Florence-2-large', save_caption=False, prefix=""):
model = models[model_id]
processor = processors[model_id]
match task_prompt:
case 'More Detailed Caption':
prompt = '<MORE_DETAILED_CAPTION>'
case 'Detailed Caption':
prompt = '<DETAILED_CAPTION>'
case 'Caption':
prompt = '<CAPTION>'
case _:
prompt = '<CAPTION>'
results = ""
# batch_images block lifted from modules/img2img.py
if isinstance(directory, str):
batch_images = list(shared.walk_files(directory, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff", ".avif")))
else:
batch_images = [os.path.abspath(x.name) for x in directory]
for file in batch_images:
image = Image.open(file)
if image:
inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda")
generated_ids = model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
early_stopping=False,
do_sample=False,
num_beams=3,
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
parsed_answer = processor.post_process_generation(
generated_text,
task=task_prompt,
image_size=(image.width, image.height)
)
caption_text = prefix + parsed_answer[task_prompt] # prepend optional user prefix to the caption
results += f"File: {file}\nCaption: {caption_text}\n\n"
if save_caption:
caption_file = file + ".txt"
with open(caption_file, 'w') as f:
f.write(caption_text)
return results
css = """
#output {
height: 500px;
@@ -285,9 +242,6 @@ css = """
}
"""
caption_task_list = [
'Caption', 'Detailed Caption', 'More Detailed Caption'
]
single_task_list =[
'Caption', 'Detailed Caption', 'More Detailed Caption', 'Object Detection',
@@ -297,29 +251,30 @@ single_task_list =[
'OCR', 'OCR with Region'
]
cascaded_task_list =[
cascased_task_list =[
'Caption + Grounding', 'Detailed Caption + Grounding', 'More Detailed Caption + Grounding'
]
def update_task_dropdown(choice):
if choice == 'Cascaded task':
return gr.Dropdown(choices=cascaded_task_list, value='Caption + Grounding')
if choice == 'Cascased task':
return gr.Dropdown(choices=cascased_task_list, value='Caption + Grounding')
else:
return gr.Dropdown(choices=single_task_list, value='Caption')
with gr.Blocks(css=css) as demo:
gr.Markdown(DESCRIPTION)
with gr.Tab(label="Florence-2 Image Captioning"):
with gr.Row():
with gr.Column():
input_img = gr.Image(label="Input picture", type="pil")
input_img = gr.Image(label="Input Picture")
model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value=list(models.keys())[0])
task_type = gr.Radio(choices=['Single task', 'Cascaded task'], label='Task type selector', value='Single task')
task_prompt = gr.Dropdown(choices=single_task_list, label="Task prompt", value="More Detailed Caption")
task_type = gr.Radio(choices=['Single task', 'Cascased task'], label='Task type selector', value='Single task')
task_prompt = gr.Dropdown(choices=single_task_list, label="Task Prompt", value="More Detailed Caption")
task_type.change(fn=update_task_dropdown, inputs=task_type, outputs=task_prompt)
text_input = gr.Textbox(label="Text input (optional)")
text_input = gr.Textbox(label="Text Input (optional)")
submit_btn = gr.Button(value="Submit")
with gr.Column():
output_text = gr.Textbox(label="Output Text")
@@ -339,25 +294,6 @@ with gr.Blocks(css=css) as demo:
submit_btn.click(process_image, [input_img, task_prompt, text_input, model_selector], [output_text, output_img])
with gr.Tab(label="Batch captioning"):
with gr.Row():
with gr.Column():
input_directory = gr.Textbox(label="Input directory")
model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value=list(models.keys())[0])
task_prompt = gr.Dropdown(choices=caption_task_list, label="Task prompt", value="More Detailed Caption")
save_captions = gr.Checkbox(label="Save captions to textfiles (same filename, same directory)", value=False)
prefix_input = gr.Textbox(label="Prefix to add to captions (optional)")
batch_btn = gr.Button(value="Submit")
with gr.Column():
output_text = gr.Textbox(label="Output captions")
batch_btn.click(
fn=run_example_batch,
inputs=[input_directory, task_prompt, model_selector, save_captions, prefix_input],
outputs=output_text
)
if __name__ == "__main__":
demo.launch()
demo.launch(debug=True)

View File

@@ -1,6 +1,5 @@
import spaces
import math
import os
import gradio as gr
import numpy as np
import torch
@@ -133,11 +132,6 @@ i2i_pipe = StableDiffusionImg2ImgPipeline(
spaces.automatically_move_pipeline_components(t2i_pipe)
def overwrite_components(components):
global tokenizer, text_encoder, vae, unet, t2i_pipe, i2i_pipe
tokenizer, text_encoder, vae, unet, t2i_pipe, i2i_pipe = components
@torch.inference_mode()
def encode_prompt_inner(txt: str):
max_length = tokenizer.model_max_length
@@ -240,14 +234,8 @@ def run_rmbg(img, sigma=0.0):
return result.clip(0, 255).astype(np.uint8), alpha
external_processor = None
@torch.inference_mode()
def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
if external_processor is not None:
return external_processor(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
bg_source = BGSource(bg_source)
input_bg = None

View File

@@ -2,5 +2,5 @@
"tag": "Face Swap, Human Identification, and Style Transfer",
"title": "PhotoMaker V2: Improved ID Fidelity and Better Controllability",
"repo_id": "TencentARC/PhotoMaker-V2",
"revision": "745c135ad240f80e168e21db53e7acf9605edcb5"
"revision": "abbf3252f4188b0373bb5384db31f429be40176c"
}

View File

@@ -1,7 +1,6 @@
import gradio as gr
from modules import scripts, shared
from modules.ui_components import InputAccordion
from modules import scripts
from backend.misc.image_resize import adaptive_resize
@@ -15,16 +14,11 @@ class PatchModelAddDownscale:
sigma = transformer_options["sigmas"][0].item()
if sigma <= sigma_start and sigma >= sigma_end:
h = adaptive_resize(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
shared.kohya_shrink_shape = (h.shape[-1], h.shape[-2])
shared.kohya_shrink_shape_out = None
return h
def output_block_patch(h, hsp, transformer_options):
if h.shape[2] != hsp.shape[2]:
h = adaptive_resize(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
shared.kohya_shrink_shape_out = (h.shape[-1], h.shape[-2])
return h, hsp
m = model.clone()
@@ -50,28 +44,15 @@ class KohyaHRFixForForge(scripts.Script):
def ui(self, *args, **kwargs):
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
with InputAccordion(False, label=self.title()) as enabled:
with gr.Row():
block_number = gr.Slider(label='Block Number', value=3, minimum=1, maximum=32, step=1)
downscale_factor = gr.Slider(label='Downscale Factor', value=2.0, minimum=0.1, maximum=9.0, step=0.001)
with gr.Row():
start_percent = gr.Slider(label='Start Percent', value=0.0, minimum=0.0, maximum=1.0, step=0.001)
end_percent = gr.Slider(label='End Percent', value=0.35, minimum=0.0, maximum=1.0, step=0.001)
with gr.Accordion(open=False, label=self.title()):
enabled = gr.Checkbox(label='Enabled', value=False)
block_number = gr.Slider(label='Block Number', value=3, minimum=1, maximum=32, step=1)
downscale_factor = gr.Slider(label='Downscale Factor', value=2.0, minimum=0.1, maximum=9.0, step=0.001)
start_percent = gr.Slider(label='Start Percent', value=0.0, minimum=0.0, maximum=1.0, step=0.001)
end_percent = gr.Slider(label='End Percent', value=0.35, minimum=0.0, maximum=1.0, step=0.001)
downscale_after_skip = gr.Checkbox(label='Downscale After Skip', value=True)
with gr.Row():
downscale_method = gr.Dropdown(label='Downscale Method', choices=upscale_methods, value=upscale_methods[0])
upscale_method = gr.Dropdown(label='Upscale Method', choices=upscale_methods, value=upscale_methods[0])
self.infotext_fields = [
(enabled, lambda d: d.get("kohya_hrfix_enabled", False)),
(block_number, "kohya_hrfix_block_number"),
(downscale_factor, "kohya_hrfix_downscale_factor"),
(start_percent, "kohya_hrfix_start_percent"),
(end_percent, "kohya_hrfix_end_percent"),
(downscale_after_skip, "kohya_hrfix_downscale_after_skip"),
(downscale_method, "kohya_hrfix_downscale_method"),
(upscale_method, "kohya_hrfix_upscale_method"),
]
downscale_method = gr.Radio(label='Downscale Method', choices=upscale_methods, value=upscale_methods[0])
upscale_method = gr.Radio(label='Upscale Method', choices=upscale_methods, value=upscale_methods[0])
return enabled, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method

View File

@@ -1,18 +1,13 @@
import gradio as gr
from modules import scripts
from modules.script_callbacks import on_cfg_denoiser, remove_current_script_callbacks
from backend.patcher.base import set_model_options_patch_replace
from backend.sampling.sampling_function import calc_cond_uncond_batch
from modules.ui_components import InputAccordion
class PerturbedAttentionGuidanceForForge(scripts.Script):
sorting_priority = 13
attenuated_scale = 3.0
doPAG = True
def title(self):
return "PerturbedAttentionGuidance Integrated"
@@ -20,82 +15,43 @@ class PerturbedAttentionGuidanceForForge(scripts.Script):
return scripts.AlwaysVisible
def ui(self, *args, **kwargs):
with InputAccordion(False, label=self.title()) as enabled:
with gr.Row():
scale = gr.Slider(label='Scale', minimum=0.0, maximum=100.0, step=0.1, value=3.0)
attenuation = gr.Slider(label='Attenuation (linear, % of scale)', minimum=0.0, maximum=100.0, step=0.1, value=0.0)
with gr.Row():
start_step = gr.Slider(label='Start step', minimum=0.0, maximum=1.0, step=0.01, value=0.0)
end_step = gr.Slider(label='End step', minimum=0.0, maximum=1.0, step=0.01, value=1.0)
with gr.Accordion(open=False, label=self.title()):
enabled = gr.Checkbox(label='Enabled', value=False)
scale = gr.Slider(label='Scale', minimum=0.0, maximum=100.0, step=0.1, value=3.0)
self.infotext_fields = [
(enabled, lambda d: d.get("pagi_enabled", False)),
(scale, "pagi_scale"),
(attenuation, "pagi_attenuation"),
(start_step, "pagi_start_step"),
(end_step, "pagi_end_step"),
]
return enabled, scale, attenuation, start_step, end_step
def denoiser_callback(self, params):
thisStep = (params.sampling_step) / (params.total_sampling_steps - 1)
if thisStep >= PerturbedAttentionGuidanceForForge.PAG_start and thisStep <= PerturbedAttentionGuidanceForForge.PAG_end:
PerturbedAttentionGuidanceForForge.doPAG = True
else:
PerturbedAttentionGuidanceForForge.doPAG = False
return enabled, scale
def process_before_every_sampling(self, p, *script_args, **kwargs):
enabled, scale, attenuation, start_step, end_step = script_args
enabled, scale = script_args
if not enabled:
return
PerturbedAttentionGuidanceForForge.scale = scale
PerturbedAttentionGuidanceForForge.PAG_start = start_step
PerturbedAttentionGuidanceForForge.PAG_end = end_step
on_cfg_denoiser(self.denoiser_callback)
unet = p.sd_model.forge_objects.unet.clone()
def attn_proc(q, k, v, to):
return v
def post_cfg_function(args):
denoised = args["denoised"]
model, cond_denoised, cond, denoised, sigma, x = \
args["model"], args["cond_denoised"], args["cond"], args["denoised"], args["sigma"], args["input"]
if PerturbedAttentionGuidanceForForge.scale <= 0.0:
return denoised
new_options = set_model_options_patch_replace(args["model_options"], attn_proc, "attn1", "middle", 0)
if not PerturbedAttentionGuidanceForForge.doPAG:
if scale == 0:
return denoised
model, cond_denoised, cond, sigma, x, options = \
args["model"], args["cond_denoised"], args["cond"], args["sigma"], args["input"], args["model_options"].copy()
new_options = set_model_options_patch_replace(options, attn_proc, "attn1", "middle", 0)
degraded, _ = calc_cond_uncond_batch(model, cond, None, x, sigma, new_options)
result = denoised + (cond_denoised - degraded) * PerturbedAttentionGuidanceForForge.scale
PerturbedAttentionGuidanceForForge.scale -= scale * attenuation / 100.0
return result
return denoised + (cond_denoised - degraded) * scale
unet.set_model_sampler_post_cfg_function(post_cfg_function)
p.sd_model.forge_objects.unet = unet
p.extra_generation_params.update(dict(
pagi_enabled = enabled,
pagi_scale = scale,
pagi_attenuation = attenuation,
pagi_start_step = start_step,
pagi_end_step = end_step,
PerturbedAttentionGuidance_enabled=enabled,
PerturbedAttentionGuidance_scale=scale,
))
return
def postprocess(self, params, processed, *args):
remove_current_script_callbacks()
return

View File

@@ -62,21 +62,12 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
attn = attn.reshape(b, -1, hw1, hw2)
# Global Average Pool
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
# original method: works for all normal inputs that *do not* have Kohya HRFix scaling; typically fails with scaling
ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
h = math.ceil(lh / ratio)
w = math.ceil(lw / ratio)
if h * w != mask.size(1):
kohya_shrink_shape = getattr(shared, 'kohya_shrink_shape', None)
if kohya_shrink_shape:
w = kohya_shrink_shape[0] # works with all block numbers for kohya hrfix
h = kohya_shrink_shape[1]
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
# Reshape
mask = (
mask.reshape(b, h, w)
mask.reshape(b, *mid_shape)
.unsqueeze(1)
.type(attn.dtype)
)
@@ -201,14 +192,11 @@ class SAGForForge(scripts.Script):
# not for Flux
if not shared.sd_model.is_webui_legacy_model(): # ideally would be is_flux
print("Self Attention Guidance is not compatible with Flux")
gr.Info ("Self Attention Guidance is not compatible with Flux")
return
# Self Attention Guidance errors if CFG is 1
if p.is_hr_pass == False and p.cfg_scale <= 1:
print("Self Attention Guidance requires CFG > 1")
return
if p.is_hr_pass == True and p.hr_cfg <= 1:
print("Self Attention Guidance (hires pass) requires Hires CFG > 1")
if p.cfg_scale == 1:
gr.Info ("Self Attention Guidance requires CFG > 1")
return
unet = p.sd_model.forge_objects.unet

View File

@@ -5,9 +5,6 @@ from modules.ui_components import InputAccordion
import modules.scripts as scripts
from modules.torch_utils import float64
from concurrent.futures import ThreadPoolExecutor
from scipy.ndimage import convolve
from joblib import Parallel, delayed, cpu_count
class SoftInpaintingSettings:
def __init__(self,
@@ -247,76 +244,7 @@ def apply_masks(
return masks_for_overlay
def weighted_histogram_filter_single_pixel(idx, img, kernel, kernel_center, percentile_min, percentile_max, min_width):
"""
Apply the weighted histogram filter to a single pixel.
Refactored into a standalone function so it can be called in parallel.
"""
idx = np.array(idx)
kernel_min = -kernel_center
kernel_max = np.array(kernel.shape) - kernel_center
# Precompute the minimum and maximum valid indices for the kernel
min_index = np.maximum(0, idx + kernel_min)
max_index = np.minimum(np.array(img.shape), idx + kernel_max)
window_shape = max_index - min_index
# Initialize values and weights arrays
values = []
weights = []
for window_tup in np.ndindex(*window_shape):
window_index = np.array(window_tup)
image_index = window_index + min_index
centered_kernel_index = image_index - idx
kernel_index = centered_kernel_index + kernel_center
values.append(img[tuple(image_index)])
weights.append(kernel[tuple(kernel_index)])
# Convert to NumPy arrays
values = np.array(values)
weights = np.array(weights)
# Sort values and weights by values
sorted_indices = np.argsort(values)
values = values[sorted_indices]
weights = weights[sorted_indices]
# Calculate cumulative weights
cumulative_weights = np.cumsum(weights)
# Define window boundaries
sum_weights = cumulative_weights[-1]
window_min = sum_weights * percentile_min
window_max = sum_weights * percentile_max
window_width = window_max - window_min
# Ensure window is at least `min_width` wide
if window_width < min_width:
window_center = (window_min + window_max) / 2
window_min = window_center - min_width / 2
window_max = window_center + min_width / 2
if window_max > sum_weights:
window_max = sum_weights
window_min = sum_weights - min_width
if window_min < 0:
window_min = 0
window_max = min_width
# Calculate overlap for each value
overlap_start = np.maximum(window_min, np.concatenate(([0], cumulative_weights[:-1])))
overlap_end = np.minimum(window_max, cumulative_weights)
overlap = np.maximum(0, overlap_end - overlap_start)
# Weighted average calculation
result = np.sum(values * overlap) / np.sum(overlap) if np.sum(overlap) > 0 else 0
return result
def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0, n_jobs=-1):
def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0):
"""
Generalized convolution filter capable of applying
weighted mean, median, maximum, and minimum filters
@@ -343,74 +271,101 @@ def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, pe
(nparray): A filtered copy of the input image "img", a 2-D array of floats.
"""
# Ensure kernel_center is a 1D array
if isinstance(kernel_center, int):
kernel_center = np.array([kernel_center, kernel_center])
elif len(kernel_center) == 1:
kernel_center = np.array([kernel_center[0], kernel_center[0]])
kernel_radius = max(kernel_center)
padded_img = np.pad(img, kernel_radius, mode='constant', constant_values=0)
img_out = np.zeros_like(img)
img_shape = img.shape
pixel_coords = [(i, j) for i in range(img_shape[0]) for j in range(img_shape[1])]
# Converts an index tuple into a vector.
def vec(x):
return np.array(x)
kernel_min = -kernel_center
kernel_max = vec(kernel.shape) - kernel_center
def weighted_histogram_filter_single(idx):
"""
Single-pixel weighted histogram calculation.
"""
row, col = idx
idx = (row + kernel_radius, col + kernel_radius)
min_index = np.array(idx) - kernel_center
max_index = min_index + kernel.shape
idx = vec(idx)
min_index = np.maximum(0, idx + kernel_min)
max_index = np.minimum(vec(img.shape), idx + kernel_max)
window_shape = max_index - min_index
window = padded_img[min_index[0]:max_index[0], min_index[1]:max_index[1]]
window_values = window.flatten()
window_weights = kernel.flatten()
class WeightedElement:
"""
An element of the histogram, its weight
and bounds.
"""
sorted_indices = np.argsort(window_values)
values = window_values[sorted_indices]
weights = window_weights[sorted_indices]
def __init__(self, value, weight):
self.value: float = value
self.weight: float = weight
self.window_min: float = 0.0
self.window_max: float = 1.0
cumulative_weights = np.cumsum(weights)
sum_weights = cumulative_weights[-1]
window_min = max(0, sum_weights * percentile_min)
window_max = min(sum_weights, sum_weights * percentile_max)
# Collect the values in the image as WeightedElements,
# weighted by their corresponding kernel values.
values = []
for window_tup in np.ndindex(tuple(window_shape)):
window_index = vec(window_tup)
image_index = window_index + min_index
centered_kernel_index = image_index - idx
kernel_index = centered_kernel_index + kernel_center
element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)])
values.append(element)
def sort_key(x: WeightedElement):
return x.value
values.sort(key=sort_key)
# Calculate the height of the stack (sum)
# and each sample's range they occupy in the stack
sum = 0
for i in range(len(values)):
values[i].window_min = sum
sum += values[i].weight
values[i].window_max = sum
# Calculate what range of this stack ("window")
# we want to get the weighted average across.
window_min = sum * percentile_min
window_max = sum * percentile_max
window_width = window_max - window_min
# Ensure the window is within the stack and at least a certain size.
if window_width < min_width:
window_center = (window_min + window_max) / 2
window_min = max(0, window_center - min_width / 2)
window_max = min(sum_weights, window_center + min_width / 2)
window_min = window_center - min_width / 2
window_max = window_center + min_width / 2
overlap_start = np.maximum(window_min, np.concatenate(([0], cumulative_weights[:-1])))
overlap_end = np.minimum(window_max, cumulative_weights)
overlap = np.maximum(0, overlap_end - overlap_start)
if window_max > sum:
window_max = sum
window_min = sum - min_width
return np.sum(values * overlap) / np.sum(overlap) if np.sum(overlap) > 0 else 0
if window_min < 0:
window_min = 0
window_max = min_width
# Split pixel_coords into equal chunks based on n_jobs
n_jobs = -1
if cpu_count() > 6:
n_jobs = 6 # more than 6 workers isn't worthwhile unless the image exceeds roughly 3000x3000 px
value = 0
value_weight = 0
chunk_size = len(pixel_coords) // n_jobs
pixel_chunks = [pixel_coords[i:i + chunk_size] for i in range(0, len(pixel_coords), chunk_size)]
# Get the weighted average of all the samples
# that overlap with the window, weighted
# by the size of their overlap.
for i in range(len(values)):
if window_min >= values[i].window_max:
continue
if window_max <= values[i].window_min:
break
# joblib to process chunks in parallel
def process_chunk(chunk):
chunk_result = {}
for idx in chunk:
chunk_result[idx] = weighted_histogram_filter_single(idx)
return chunk_result
s = max(window_min, values[i].window_min)
e = min(window_max, values[i].window_max)
w = e - s
results = Parallel(n_jobs=n_jobs, backend="loky")( # loky is fastest in my configuration
delayed(process_chunk)(chunk) for chunk in pixel_chunks
)
value += values[i].value * w
value_weight += w
# Combine results into the output image
for chunk_result in results:
for (row, col), value in chunk_result.items():
img_out[row, col] = value
return value / value_weight if value_weight != 0 else 0
img_out = img.copy()
# Apply the kernel operation over each pixel.
for index in np.ndindex(img.shape):
img_out[index] = weighted_histogram_filter_single(index)
return img_out
@@ -530,7 +485,7 @@ el_ids = SoftInpaintingSettings(
class Script(scripts.Script):
def __init__(self):
self.section = "inpaint"
# self.section = "inpaint"
self.masks_for_overlay = None
self.overlay_images = None

View File

@@ -177,7 +177,7 @@ function modalTileImageToggle(event) {
}
onAfterUiUpdate(function() {
var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > button > button > img, .gradio-gallery > .livePreview');
var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > button > button > img');
if (fullImg_preview != null) {
fullImg_preview.forEach(setupImageForLightbox);
}

View File

@@ -7,7 +7,6 @@ from torchdiffeq import odeint
import torchsde
from tqdm.auto import trange, tqdm
from k_diffusion import deis
from backend.modules.k_prediction import PredictionFlux
from . import utils
@@ -140,8 +139,6 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps."""
if isinstance(model.inner_model.predictor, PredictionFlux):
return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
@@ -158,32 +155,6 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
#sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta
sigma_down = sigmas[i + 1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i + 1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i + 1]**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2)**0.5
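# flow-matching bookkeeping: alpha = 1 - sigma, so after stepping to sigma_down the sample is
# rescaled by alpha_ip1 / alpha_down and renoise_coeff is sized so the re-noised result sits
# exactly on the sigma_{i+1} noise level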
# Euler method
sigma_down_i_ratio = sigma_down / sigmas[i]
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
if eta > 0:
x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
return x
@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
@@ -248,8 +219,6 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,
@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver second-order steps."""
if isinstance(model.inner_model.predictor, PredictionFlux):
return sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
@@ -275,38 +244,6 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
sigma_down = sigmas[i+1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i+1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], denoised)
if sigma_down == 0:
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
else:
# DPM-Solver-2
sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
dt_1 = sigma_mid - sigmas[i]
dt_2 = sigma_down - sigmas[i]
x_2 = x + d * dt_1
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
return x
def linear_multistep_coeff(order, t, i, j):
if order - 1 > i:

View File

@@ -21,9 +21,7 @@ from modules import sd_samplers, deepbooru, images, scripts, ui, postprocessing,
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images, process_extra_images
import modules.textual_inversion.textual_inversion
from modules.shared import cmd_opts
from modules.textual_inversion.textual_inversion import create_embedding
from PIL import PngImagePlugin
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
@@ -122,7 +120,7 @@ def encode_pil_to_base64(image):
if opts.samples_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality)
else:
image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality)
else:
raise HTTPException(status_code=500, detail="Invalid image format")
@@ -267,10 +265,6 @@ class Api:
if not self.default_script_arg_img2img:
self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner)
self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
self.embedding_db.load_textual_inversion_embeddings(force_reload=True, sync_with_sd_model=False)
def add_api_route(self, path: str, endpoint, **kwargs):
@@ -750,6 +744,8 @@ class Api:
return styleList
def get_embeddings(self):
db = sd_hijack.model_hijack.embedding_db
def convert_embedding(embedding):
return {
"step": embedding.step,
@@ -763,13 +759,13 @@ class Api:
return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
return {
"loaded": convert_embeddings(self.embedding_db.word_embeddings),
"skipped": convert_embeddings(self.embedding_db.skipped_embeddings),
"loaded": convert_embeddings(db.word_embeddings),
"skipped": convert_embeddings(db.skipped_embeddings),
}
def refresh_embeddings(self):
with self.queue_lock:
self.embedding_db.load_textual_inversion_embeddings(force_reload=True, sync_with_sd_model=False)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
def refresh_checkpoints(self):
with self.queue_lock:
@@ -782,14 +778,15 @@ class Api:
def create_embedding(self, args: dict):
try:
shared.state.begin(job="create_embedding")
filename = modules.textual_inversion.textual_inversion.create_embedding(**args) # create empty embedding
self.embedding_db.load_textual_inversion_embeddings(force_reload=True, sync_with_sd_model=False) # reload embeddings so new one can be immediately used
filename = create_embedding(**args) # create empty embedding
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
return models.CreateResponse(info=f"create embedding filename: {filename}")
except AssertionError as e:
return models.TrainResponse(info=f"create embedding error: {e}")
finally:
shared.state.end()
def create_hypernetwork(self, args: dict):
try:
shared.state.begin(job="create_hypernetwork")


@@ -15,7 +15,6 @@ parser.add_argument("--update-check", action='store_true', help="launch.py argum
parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
parser.add_argument("--log-startup", action='store_true', help="launch.py argument: print a detailed log of what's happening at startup")
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
parser.add_argument("--skip-google-blockly", action='store_true', help="launch.py argument: do not initialize google blockly modules")
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
parser.add_argument("--dump-sysinfo", action='store_true', help="launch.py argument: dump limited sysinfo file (without information about extensions, options) to disk and quit")
parser.add_argument("--loglevel", type=str, help="log level; one of: CRITICAL, ERROR, WARNING, INFO, DEBUG", default=None)


@@ -14,7 +14,7 @@ class UpscalerDAT(Upscaler):
self.scalers = []
super().__init__()
for file in self.find_models(ext_filter=[".pt", ".pth", ".safetensors"]):
for file in self.find_models(ext_filter=[".pt", ".pth"]):
name = modelloader.friendly_name(file)
scaler_data = UpscalerData(name, file, upscaler=self, scale=None)
self.scalers.append(scaler_data)
@@ -51,18 +51,7 @@ class UpscalerDAT(Upscaler):
scaler.local_data_path = modelloader.load_file_from_url(
scaler.data_path,
model_dir=self.model_download_path,
hash_prefix=scaler.sha256,
)
if os.path.getsize(scaler.local_data_path) < 200:
# Re-download if the file is too small, probably an LFS pointer
scaler.local_data_path = modelloader.load_file_from_url(
scaler.data_path,
model_dir=self.model_download_path,
hash_prefix=scaler.sha256,
re_download=True,
)
if not os.path.exists(scaler.local_data_path):
raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}")
return scaler
@@ -73,23 +62,20 @@ def get_dat_models(scaler):
return [
UpscalerData(
name="DAT x2",
path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x2.pth",
path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x2.pth",
scale=2,
upscaler=scaler,
sha256='7760aa96e4ee77e29d4f89c3a4486200042e019461fdb8aa286f49aa00b89b51',
),
UpscalerData(
name="DAT x3",
path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x3.pth",
path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x3.pth",
scale=3,
upscaler=scaler,
sha256='581973e02c06f90d4eb90acf743ec9604f56f3c2c6f9e1e2c2b38ded1f80d197',
),
UpscalerData(
name="DAT x4",
path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x4.pth",
path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x4.pth",
scale=4,
upscaler=scaler,
sha256='391a6ce69899dff5ea3214557e9d585608254579217169faf3d4c353caff049e',
),
]


@@ -54,7 +54,7 @@ device_interrogate: torch.device = memory_management.text_encoder_device() # fo
device_gfpgan: torch.device = memory_management.get_torch_device() # will be managed by memory management system
device_esrgan: torch.device = memory_management.get_torch_device() # will be managed by memory management system
device_codeformer: torch.device = memory_management.get_torch_device() # will be managed by memory management system
dtype: torch.dtype = torch.float32 if memory_management.unet_dtype() is torch.float32 else torch.float16
dtype: torch.dtype = memory_management.unet_dtype()
dtype_vae: torch.dtype = memory_management.vae_dtype()
dtype_unet: torch.dtype = memory_management.unet_dtype()
dtype_inference: torch.dtype = memory_management.unet_dtype()


@@ -13,7 +13,7 @@ class UpscalerESRGAN(Upscaler):
self.scalers = []
self.user_path = dirname
super().__init__()
model_paths = self.find_models(ext_filter=[".pt", ".pth", ".safetensors"])
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
scalers = []
if len(model_paths) == 0:
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)


@@ -20,8 +20,7 @@ def calculate_sha256_real(filename):
def calculate_sha256(filename):
print("Calculating real hash: ", filename)
return calculate_sha256_real(filename)
return forge_fake_calculate_sha256(filename)
def forge_fake_calculate_sha256(filename):
@@ -60,8 +59,8 @@ def sha256(filename, title, use_addnet_hash=False):
if shared.cmd_opts.no_hashing:
return None
print(f"Calculating sha256 for {filename}: ", end='', flush=True)
sha256_value = calculate_sha256_real(filename)
print(f"Calculating sha256 for {filename}: ", end='')
sha256_value = forge_fake_calculate_sha256(filename)
print(f"{sha256_value}")
hashes[title] = {


@@ -13,15 +13,13 @@ import numpy as np
import piexif
import piexif.helper
from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin, ImageOps
from PIL import __version__ as pillow_version
from pkg_resources import parse_version
# pillow_avif needs to be imported somewhere in code for it to work
import pillow_avif # noqa: F401
import string
import json
import hashlib
from modules import sd_samplers, shared, script_callbacks, errors, stealth_infotext
from modules import sd_samplers, shared, script_callbacks, errors
from modules.paths_internal import roboto_ttf_file
from modules.shared import opts
@@ -170,18 +168,9 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
for line in lines:
fnt = initial_fnt
fontsize = initial_fontsize
if parse_version(pillow_version) >= parse_version('10.0.0'):
# New code for Pillow 10.0.0+
text_width, text_height = drawing.multiline_textbbox((0, 0), line.text, font=fnt)[2:]
while text_width > line.allowed_width and fontsize > 0:
fontsize -= 1
fnt = get_font(fontsize)
text_width, text_height = drawing.multiline_textbbox((0, 0), line.text, font=fnt)[2:]
else:
# Old code for Pillow versions below 10.0.0
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
fontsize -= 1
fnt = get_font(fontsize)
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
fontsize -= 1
fnt = get_font(fontsize)
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
if not line.is_active:
@@ -275,9 +264,6 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None, force_RGBA=
upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
"""
if not force_RGBA and im.mode == 'RGBA':
im = im.convert('RGB')
upscaler_name = upscaler_name or opts.upscaler_for_img2img
def resize(im, w, h):
@@ -656,7 +642,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
Additional PNG info. `existing_info == {pngsectionname: info, ...}`
no_prompt:
TODO I don't know its meaning.
p (`StableDiffusionProcessing` or `Processing`)
p (`StableDiffusionProcessing`)
forced_filename (`str`):
If specified, `basename` and filename pattern will be ignored.
save_to_dirs (bool):
@@ -687,13 +673,10 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
if forced_filename is None:
if short_filename or seed is None:
file_decoration = ""
elif hasattr(p, 'override_settings'):
file_decoration = p.override_settings.get("samples_filename_pattern")
elif opts.save_to_dirs:
file_decoration = opts.samples_filename_pattern or "[seed]"
else:
file_decoration = None
if file_decoration is None:
file_decoration = opts.samples_filename_pattern or ("[seed]" if opts.save_to_dirs else "[seed]-[prompt_spaces]")
file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
file_decoration = namegen.apply(file_decoration) + suffix
@@ -720,8 +703,6 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
pnginfo[pnginfo_section_name] = info
params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
if opts.enable_pnginfo:
stealth_infotext.add_stealth_pnginfo(params)
script_callbacks.before_image_saved_callback(params)
image = params.image
@@ -737,15 +718,12 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
filename = filename_without_extension + extension
without_extension = filename_without_extension
if shared.opts.save_images_replace_action != "Replace":
n = 0
while os.path.exists(filename):
n += 1
without_extension = f"{filename_without_extension}-{n}"
filename = without_extension + extension
filename = f"{filename_without_extension}-{n}{extension}"
os.replace(temp_file_path, filename)
return without_extension
fullfn_without_extension, extension = os.path.splitext(params.filename)
if hasattr(os, 'statvfs'):
@@ -753,9 +731,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
fullfn_without_extension = fullfn_without_extension[:max_name_len - max(4, len(extension))]
params.filename = fullfn_without_extension + extension
fullfn = params.filename
_atomically_save_image(image, fullfn_without_extension, extension)
fullfn_without_extension = _atomically_save_image(image, fullfn_without_extension, extension)
fullfn = fullfn_without_extension + extension
image.already_saved_as = fullfn
oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
@@ -774,7 +751,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
except Exception:
image = image.resize(resize_to)
try:
_ = _atomically_save_image(image, fullfn_without_extension, ".jpg")
_atomically_save_image(image, fullfn_without_extension, ".jpg")
except Exception as e:
errors.display(e, "saving image as downscaled JPG")
@@ -798,53 +775,44 @@ IGNORED_INFO_KEYS = {
def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
"""Read generation info from an image, checking standard metadata first, then stealth info if needed."""
items = (image.info or {}).copy()
def read_standard():
items = (image.info or {}).copy()
geninfo = items.pop('parameters', None)
geninfo = items.pop('parameters', None)
if "exif" in items:
exif_data = items["exif"]
try:
exif = piexif.load(exif_data)
except OSError:
# memory / exif was not valid so piexif tried to read from a file
exif = None
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
try:
exif_comment = piexif.helper.UserComment.load(exif_comment)
except ValueError:
exif_comment = exif_comment.decode('utf8', errors="ignore")
if "exif" in items:
exif_data = items["exif"]
try:
exif = piexif.load(exif_data)
except OSError:
# memory / exif was not valid so piexif tried to read from a file
exif = None
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
try:
exif_comment = piexif.helper.UserComment.load(exif_comment)
except ValueError:
exif_comment = exif_comment.decode('utf8', errors="ignore")
if exif_comment:
geninfo = exif_comment
elif "comment" in items: # for gif
if isinstance(items["comment"], bytes):
geninfo = items["comment"].decode('utf8', errors="ignore")
else:
geninfo = items["comment"]
if exif_comment:
geninfo = exif_comment
elif "comment" in items: # for gif
if isinstance(items["comment"], bytes):
geninfo = items["comment"].decode('utf8', errors="ignore")
else:
geninfo = items["comment"]
for field in IGNORED_INFO_KEYS:
items.pop(field, None)
for field in IGNORED_INFO_KEYS:
items.pop(field, None)
if items.get("Software", None) == "NovelAI":
try:
json_info = json.loads(items["Comment"])
sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
if items.get("Software", None) == "NovelAI":
try:
json_info = json.loads(items["Comment"])
sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
geninfo = f"""{items["Description"]}
Negative prompt: {json_info["uc"]}
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
except Exception:
errors.report("Error parsing NovelAI image generation parameters", exc_info=True)
return geninfo, items
geninfo, items = read_standard()
if geninfo is None:
geninfo = stealth_infotext.read_info_from_image_stealth(image)
geninfo = f"""{items["Description"]}
Negative prompt: {json_info["uc"]}
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
except Exception:
errors.report("Error parsing NovelAI image generation parameters", exc_info=True)
return geninfo, items
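As a minimal sketch of the metadata lookup order implemented above (standard `parameters` key first, then the EXIF UserComment), not part of the diff and using a placeholder file path:

import piexif
import piexif.helper
from PIL import Image

image = Image.open("example.png")  # placeholder path
geninfo = (image.info or {}).get("parameters")
if geninfo is None and "exif" in (image.info or {}):
    exif = piexif.load(image.info["exif"])
    comment = exif.get("Exif", {}).get(piexif.ExifIFD.UserComment, b"")
    if comment:
        geninfo = piexif.helper.UserComment.load(comment)
print(geninfo)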


@@ -122,14 +122,16 @@ def process_batch(p, input, output_dir, inpaint_mask_dir, args, to_scale=False,
if output_dir:
p.outpath_samples = output_dir
p.override_settings['save_to_dirs'] = False
if opts.img2img_batch_use_original_name:
filename_pattern = f'{image_path.stem}-[generation_number]' if p.n_iter > 1 or p.batch_size > 1 else f'{image_path.stem}'
p.override_settings['samples_filename_pattern'] = filename_pattern
p.override_settings['save_images_replace_action'] = "Add number suffix"
if p.n_iter > 1 or p.batch_size > 1:
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
else:
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
proc = modules.scripts.scripts_img2img.run(p, *args)
if proc is None:
p.override_settings.pop('save_images_replace_action', None)
proc = process_images(p)
if not discard_further_results and proc:


@@ -197,14 +197,10 @@ def connect_paste_params_buttons():
def send_image_and_dimensions(x):
if isinstance(x, Image.Image):
img = x
if img.mode == 'RGBA':
img = img.convert('RGB')
elif isinstance(x, list) and isinstance(x[0], tuple):
img = x[0][0]
else:
img = image_from_url_text(x)
if img is not None and img.mode == 'RGBA':
img = img.convert('RGB')
if shared.opts.send_size and isinstance(img, Image.Image):
w = img.width


@@ -133,4 +133,8 @@ def initialize_rest(*, reload_script_modules=False):
extra_networks.register_default_extra_networks()
startup_timer.record("initialize extra networks")
from modules_forge import google_blockly
google_blockly.initialization()
startup_timer.record("initialize google blockly")
return


@@ -395,13 +395,15 @@ def prepare_environment():
# stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
# k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
huggingface_guess_repo = os.environ.get('HUGGINGFACE_GUESS_REPO', 'https://github.com/lllyasviel/huggingface_guess.git')
google_blockly_repo = os.environ.get('GOOGLE_BLOCKLY_REPO', 'https://github.com/lllyasviel/google_blockly_prototypes')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917")
# stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
# stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
# k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "84826248b49bb7ca754c73293299c4d4e23a548d")
huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "70942022b6bcd17d941c1b4172804175758618e2")
google_blockly_commit_hash = os.environ.get('GOOGLE_BLOCKLY_COMMIT_HASH', "bf36e6fd3750a081f44209ba4f645adb598f7e37")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
try:
@@ -462,6 +464,7 @@ def prepare_environment():
# git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
# git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone(huggingface_guess_repo, repo_dir('huggingface_guess'), "huggingface_guess", huggingface_guess_commit_hash)
git_clone(google_blockly_repo, repo_dir('google_blockly_prototypes'), "google_blockly", google_blockly_commit_hash)
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
startup_timer.record("clone repositores")
@@ -560,7 +563,7 @@ def dump_sysinfo():
import datetime
text = sysinfo.get()
filename = f"sysinfo-{datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%d-%H-%M')}.json"
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json"
with open(filename, "w", encoding="utf8") as file:
file.write(text)


@@ -10,7 +10,6 @@ import torch
from modules import shared
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
from modules.util import load_file_from_url # noqa, backwards compatibility
if TYPE_CHECKING:
import spandrel
@@ -18,6 +17,30 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def load_file_from_url(
url: str,
*,
model_dir: str,
progress: bool = True,
file_name: str | None = None,
hash_prefix: str | None = None,
) -> str:
"""Download a file from `url` into `model_dir`, using the file present if possible.
Returns the path to the downloaded file.
"""
os.makedirs(model_dir, exist_ok=True)
if not file_name:
parts = urlparse(url)
file_name = os.path.basename(parts.path)
cached_file = os.path.abspath(os.path.join(model_dir, file_name))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
from torch.hub import download_url_to_file
download_url_to_file(url, cached_file, progress=progress, hash_prefix=hash_prefix)
return cached_file
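A minimal usage sketch for the helper above (not part of the diff; the URL and directory are placeholders): repeated calls reuse the cached file instead of downloading again.

path = load_file_from_url(
    "https://example.com/models/some_model.safetensors",  # placeholder URL
    model_dir="models/placeholder",
    file_name="some_model.safetensors",
)
print(path)  # absolute path of the cached file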
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None, hash_prefix=None) -> list:
"""
A one-and-done loader to try finding the desired models in specified directories.


@@ -0,0 +1,17 @@
MIT License
Copyright © 2024 Stability AI
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

File diff suppressed because it is too large.


@@ -0,0 +1,868 @@
### This file contains impls for underlying related models (CLIP, T5, etc)
import logging
import math
import os
import torch
from torch import nn
from transformers import CLIPTokenizer, T5TokenizerFast
#################################################################################################
### Core/Utility
#################################################################################################
def attention(q, k, v, heads, mask=None):
"""Convenience wrapper around a basic attention operation"""
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
)
return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
class Mlp(nn.Module):
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
bias=True,
dtype=None,
device=None,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(
in_features, hidden_features, bias=bias, dtype=dtype, device=device
)
self.act = act_layer
self.fc2 = nn.Linear(
hidden_features, out_features, bias=bias, dtype=dtype, device=device
)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
#################################################################################################
### CLIP
#################################################################################################
class CLIPAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device):
super().__init__()
self.heads = heads
self.q_proj = nn.Linear(
embed_dim, embed_dim, bias=True, dtype=dtype, device=device
)
self.k_proj = nn.Linear(
embed_dim, embed_dim, bias=True, dtype=dtype, device=device
)
self.v_proj = nn.Linear(
embed_dim, embed_dim, bias=True, dtype=dtype, device=device
)
self.out_proj = nn.Linear(
embed_dim, embed_dim, bias=True, dtype=dtype, device=device
)
def forward(self, x, mask=None):
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
out = attention(q, k, v, self.heads, mask)
return self.out_proj(out)
ACTIVATIONS = {
"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
"gelu": torch.nn.functional.gelu,
}
class CLIPLayer(torch.nn.Module):
def __init__(
self,
embed_dim,
heads,
intermediate_size,
intermediate_activation,
dtype,
device,
):
super().__init__()
self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
# self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
self.mlp = Mlp(
embed_dim,
intermediate_size,
embed_dim,
act_layer=ACTIVATIONS[intermediate_activation],
dtype=dtype,
device=device,
)
def forward(self, x, mask=None):
x += self.self_attn(self.layer_norm1(x), mask)
x += self.mlp(self.layer_norm2(x))
return x
class CLIPEncoder(torch.nn.Module):
def __init__(
self,
num_layers,
embed_dim,
heads,
intermediate_size,
intermediate_activation,
dtype,
device,
):
super().__init__()
self.layers = torch.nn.ModuleList(
[
CLIPLayer(
embed_dim,
heads,
intermediate_size,
intermediate_activation,
dtype,
device,
)
for i in range(num_layers)
]
)
def forward(self, x, mask=None, intermediate_output=None):
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
intermediate = None
for i, l in enumerate(self.layers):
x = l(x, mask)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
class CLIPEmbeddings(torch.nn.Module):
def __init__(
self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None
):
super().__init__()
self.token_embedding = torch.nn.Embedding(
vocab_size, embed_dim, dtype=dtype, device=device
)
self.position_embedding = torch.nn.Embedding(
num_positions, embed_dim, dtype=dtype, device=device
)
def forward(self, input_tokens):
return self.token_embedding(input_tokens) + self.position_embedding.weight
class CLIPTextModel_(torch.nn.Module):
def __init__(self, config_dict, dtype, device):
num_layers = config_dict["num_hidden_layers"]
embed_dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
super().__init__()
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
self.encoder = CLIPEncoder(
num_layers,
embed_dim,
heads,
intermediate_size,
intermediate_activation,
dtype,
device,
)
self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
def forward(
self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True
):
x = self.embeddings(input_tokens)
causal_mask = (
torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device)
.fill_(float("-inf"))
.triu_(1)
)
x, i = self.encoder(
x, mask=causal_mask, intermediate_output=intermediate_output
)
x = self.final_layer_norm(x)
if i is not None and final_layer_norm_intermediate:
i = self.final_layer_norm(i)
pooled_output = x[
torch.arange(x.shape[0], device=x.device),
input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),
]
return x, i, pooled_output
class CLIPTextModel(torch.nn.Module):
def __init__(self, config_dict, dtype, device):
super().__init__()
self.num_layers = config_dict["num_hidden_layers"]
self.text_model = CLIPTextModel_(config_dict, dtype, device)
embed_dim = config_dict["hidden_size"]
self.text_projection = nn.Linear(
embed_dim, embed_dim, bias=False, dtype=dtype, device=device
)
self.text_projection.weight.copy_(torch.eye(embed_dim))
self.dtype = dtype
def get_input_embeddings(self):
return self.text_model.embeddings.token_embedding
def set_input_embeddings(self, embeddings):
self.text_model.embeddings.token_embedding = embeddings
def forward(self, *args, **kwargs):
x = self.text_model(*args, **kwargs)
out = self.text_projection(x[2])
return (x[0], x[1], out, x[2])
def parse_parentheses(string):
result = []
current_item = ""
nesting_level = 0
for char in string:
if char == "(":
if nesting_level == 0:
if current_item:
result.append(current_item)
current_item = "("
else:
current_item = "("
else:
current_item += char
nesting_level += 1
elif char == ")":
nesting_level -= 1
if nesting_level == 0:
result.append(current_item + ")")
current_item = ""
else:
current_item += char
else:
current_item += char
if current_item:
result.append(current_item)
return result
def token_weights(string, current_weight):
a = parse_parentheses(string)
out = []
for x in a:
weight = current_weight
if len(x) >= 2 and x[-1] == ")" and x[0] == "(":
x = x[1:-1]
xx = x.rfind(":")
weight *= 1.1
if xx > 0:
try:
weight = float(x[xx + 1 :])
x = x[:xx]
except:
pass
out += token_weights(x, weight)
else:
out += [(x, current_weight)]
return out
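A minimal sketch of what the weight parser above produces, not part of the diff:

example = token_weights("a (red:1.3) car", 1.0)
print(example)  # [('a ', 1.0), ('red', 1.3), (' car', 1.0)]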
def escape_important(text):
text = text.replace("\\)", "\0\1")
text = text.replace("\\(", "\0\2")
return text
def unescape_important(text):
text = text.replace("\0\1", ")")
text = text.replace("\0\2", "(")
return text
class SDTokenizer:
def __init__(
self,
max_length=77,
pad_with_end=True,
tokenizer=None,
has_start_token=True,
pad_to_max_length=True,
min_length=None,
extra_padding_token=None,
):
self.tokenizer = tokenizer
self.max_length = max_length
self.min_length = min_length
empty = self.tokenizer("")["input_ids"]
if has_start_token:
self.tokens_start = 1
self.start_token = empty[0]
self.end_token = empty[1]
else:
self.tokens_start = 0
self.start_token = None
self.end_token = empty[0]
self.pad_with_end = pad_with_end
self.pad_to_max_length = pad_to_max_length
self.extra_padding_token = extra_padding_token
vocab = self.tokenizer.get_vocab()
self.inv_vocab = {v: k for k, v in vocab.items()}
self.max_word_length = 8
def tokenize_with_weights(self, text: str, return_word_ids=False):
"""
Tokenize the text, with weight values - presume 1.0 for all and ignore other features here.
The details aren't relevant for a reference impl, and the weights themselves have only a weak effect on SD3.
"""
if self.pad_with_end:
pad_token = self.end_token
else:
pad_token = 0
text = escape_important(text)
parsed_weights = token_weights(text, 1.0)
# tokenize words
tokens = []
for weighted_segment, weight in parsed_weights:
to_tokenize = (
unescape_important(weighted_segment).replace("\n", " ").split(" ")
)
to_tokenize = [x for x in to_tokenize if x != ""]
for word in to_tokenize:
# parse word
tokens.append(
[
(t, weight)
for t in self.tokenizer(word)["input_ids"][
self.tokens_start : -1
]
]
)
# reshape token array to CLIP input size
batched_tokens = []
batch = []
if self.start_token is not None:
batch.append((self.start_token, 1.0, 0))
batched_tokens.append(batch)
for i, t_group in enumerate(tokens):
# determine if we're going to try and keep the tokens in a single batch
is_large = len(t_group) >= self.max_word_length
while len(t_group) > 0:
if len(t_group) + len(batch) > self.max_length - 1:
remaining_length = self.max_length - len(batch) - 1
# break word in two and add end token
if is_large:
batch.extend(
[(t, w, i + 1) for t, w in t_group[:remaining_length]]
)
batch.append((self.end_token, 1.0, 0))
t_group = t_group[remaining_length:]
# add end token and pad
else:
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
# start new batch
batch = []
if self.start_token is not None:
batch.append((self.start_token, 1.0, 0))
batched_tokens.append(batch)
else:
batch.extend([(t, w, i + 1) for t, w in t_group])
t_group = []
# pad with the extra padding token first, before adding the end token
if self.extra_padding_token is not None:
batch.extend(
[(self.extra_padding_token, 1.0, 0)]
* (self.min_length - len(batch) - 1)
)
# fill last batch
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
if self.min_length is not None and len(batch) < self.min_length:
batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))
if not return_word_ids:
batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
return batched_tokens
def untokenize(self, token_weight_pair):
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
class SDXLClipGTokenizer(SDTokenizer):
def __init__(self, tokenizer):
super().__init__(pad_with_end=False, tokenizer=tokenizer)
class SD3Tokenizer:
def __init__(self):
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
self.t5xxl = T5XXLTokenizer()
def tokenize_with_weights(self, text: str):
out = {}
out["l"] = self.clip_l.tokenize_with_weights(text)
out["g"] = self.clip_g.tokenize_with_weights(text)
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text[:226])
return out
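A minimal usage sketch (not part of the diff; the `from_pretrained` calls above download tokenizer files from Hugging Face on first use):

tokenizer = SD3Tokenizer()
tokens = tokenizer.tokenize_with_weights("a photo of a cat")
print(list(tokens.keys()))   # ['l', 'g', 't5xxl']
print(len(tokens["l"][0]))   # 77 (CLIP branches are padded to max length)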
class ClipTokenWeightEncoder:
def encode_token_weights(self, token_weight_pairs):
tokens = list(map(lambda a: a[0], token_weight_pairs[0]))
out, pooled = self([tokens])
if pooled is not None:
first_pooled = pooled[0:1].cpu()
else:
first_pooled = pooled
output = [out[0:1]]
return torch.cat(output, dim=-2).cpu(), first_pooled
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = ["last", "pooled", "hidden"]
def __init__(
self,
device="cpu",
max_length=77,
layer="last",
layer_idx=None,
textmodel_json_config=None,
dtype=None,
model_class=CLIPTextModel,
special_tokens={"start": 49406, "end": 49407, "pad": 49407},
layer_norm_hidden_state=True,
return_projected_pooled=True,
):
super().__init__()
assert layer in self.LAYERS
self.transformer = model_class(textmodel_json_config, dtype, device)
self.num_layers = self.transformer.num_layers
self.max_length = max_length
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False
self.layer = layer
self.layer_idx = None
self.special_tokens = special_tokens
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
self.layer_norm_hidden_state = layer_norm_hidden_state
self.return_projected_pooled = return_projected_pooled
if layer == "hidden":
assert layer_idx is not None
assert abs(layer_idx) < self.num_layers
self.set_clip_options({"layer": layer_idx})
self.options_default = (
self.layer,
self.layer_idx,
self.return_projected_pooled,
)
def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get(
"projected_pooled", self.return_projected_pooled
)
if layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last"
else:
self.layer = "hidden"
self.layer_idx = layer_idx
def forward(self, tokens):
backup_embeds = self.transformer.get_input_embeddings()
device = backup_embeds.weight.device
tokens = torch.LongTensor(tokens).to(device)
outputs = self.transformer(
tokens,
intermediate_output=self.layer_idx,
final_layer_norm_intermediate=self.layer_norm_hidden_state,
)
self.transformer.set_input_embeddings(backup_embeds)
if self.layer == "last":
z = outputs[0]
else:
z = outputs[1]
pooled_output = None
if len(outputs) >= 3:
if (
not self.return_projected_pooled
and len(outputs) >= 4
and outputs[3] is not None
):
pooled_output = outputs[3].float()
elif outputs[2] is not None:
pooled_output = outputs[2].float()
return z.float(), pooled_output
class SDXLClipG(SDClipModel):
"""Wraps the CLIP-G model into the SD-CLIP-Model interface"""
def __init__(
self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None
):
if layer == "penultimate":
layer = "hidden"
layer_idx = -2
super().__init__(
device=device,
layer=layer,
layer_idx=layer_idx,
textmodel_json_config=config,
dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0},
layer_norm_hidden_state=False,
)
class T5XXLModel(SDClipModel):
"""Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience"""
def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None):
super().__init__(
device=device,
layer=layer,
layer_idx=layer_idx,
textmodel_json_config=config,
dtype=dtype,
special_tokens={"end": 1, "pad": 0},
model_class=T5,
)
#################################################################################################
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
#################################################################################################
class T5XXLTokenizer(SDTokenizer):
"""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
def __init__(self):
super().__init__(
pad_with_end=False,
tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"),
has_start_token=False,
pad_to_max_length=False,
max_length=99999999,
min_length=77,
)
class T5LayerNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None):
super().__init__()
self.weight = torch.nn.Parameter(
torch.ones(hidden_size, dtype=dtype, device=device)
)
self.variance_epsilon = eps
def forward(self, x):
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
return self.weight.to(device=x.device, dtype=x.dtype) * x
class T5DenseGatedActDense(torch.nn.Module):
def __init__(self, model_dim, ff_dim, dtype, device):
super().__init__()
self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
def forward(self, x):
hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
hidden_linear = self.wi_1(x)
x = hidden_gelu * hidden_linear
x = self.wo(x)
return x
class T5LayerFF(torch.nn.Module):
def __init__(self, model_dim, ff_dim, dtype, device):
super().__init__()
self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device)
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
def forward(self, x):
forwarded_states = self.layer_norm(x)
forwarded_states = self.DenseReluDense(forwarded_states)
x += forwarded_states
return x
class T5Attention(torch.nn.Module):
def __init__(
self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device
):
super().__init__()
# Mesh TensorFlow initialization to avoid scaling before softmax
self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
self.num_heads = num_heads
self.relative_attention_bias = None
if relative_attention_bias:
self.relative_attention_num_buckets = 32
self.relative_attention_max_distance = 128
self.relative_attention_bias = torch.nn.Embedding(
self.relative_attention_num_buckets, self.num_heads, device=device
)
@staticmethod
def _relative_position_bucket(
relative_position, bidirectional=True, num_buckets=32, max_distance=128
):
"""
Adapted from Mesh Tensorflow:
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
Translate relative position to a bucket number for relative attention. The relative position is defined as
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the model has been trained on
Args:
relative_position: an int32 Tensor
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
relative_position = -torch.min(
relative_position, torch.zeros_like(relative_position)
)
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_position_if_large = torch.min(
relative_position_if_large,
torch.full_like(relative_position_if_large, num_buckets - 1),
)
relative_buckets += torch.where(
is_small, relative_position, relative_position_if_large
)
return relative_buckets
def compute_bias(self, query_length, key_length, device):
"""Compute binned relative position bias"""
context_position = torch.arange(query_length, dtype=torch.long, device=device)[
:, None
]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[
None, :
]
relative_position = (
memory_position - context_position
) # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
bidirectional=True,
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(
relative_position_bucket
) # shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(
0
) # shape (1, num_heads, query_length, key_length)
return values
def forward(self, x, past_bias=None):
q = self.q(x)
k = self.k(x)
v = self.v(x)
if self.relative_attention_bias is not None:
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
if past_bias is not None:
mask = past_bias
out = attention(
q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask
)
return self.o(out), past_bias
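A minimal sketch of the relative-position bucketing used above (not part of the diff): nearby offsets get individual buckets, distant offsets share logarithmically sized buckets, and positive offsets occupy the upper half of the bucket range.

import torch

rel = torch.arange(-10, 11)[None, :]   # relative positions -10 .. 10
buckets = T5Attention._relative_position_bucket(rel, bidirectional=True, num_buckets=32, max_distance=128)
print(buckets)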
class T5LayerSelfAttention(torch.nn.Module):
def __init__(
self,
model_dim,
inner_dim,
ff_dim,
num_heads,
relative_attention_bias,
dtype,
device,
):
super().__init__()
self.SelfAttention = T5Attention(
model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device
)
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
def forward(self, x, past_bias=None):
output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias)
x += output
return x, past_bias
class T5Block(torch.nn.Module):
def __init__(
self,
model_dim,
inner_dim,
ff_dim,
num_heads,
relative_attention_bias,
dtype,
device,
):
super().__init__()
self.layer = torch.nn.ModuleList()
self.layer.append(
T5LayerSelfAttention(
model_dim,
inner_dim,
ff_dim,
num_heads,
relative_attention_bias,
dtype,
device,
)
)
self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device))
def forward(self, x, past_bias=None):
x, past_bias = self.layer[0](x, past_bias)
x = self.layer[-1](x)
return x, past_bias
class T5Stack(torch.nn.Module):
def __init__(
self,
num_layers,
model_dim,
inner_dim,
ff_dim,
num_heads,
vocab_size,
dtype,
device,
):
super().__init__()
self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
self.block = torch.nn.ModuleList(
[
T5Block(
model_dim,
inner_dim,
ff_dim,
num_heads,
relative_attention_bias=(i == 0),
dtype=dtype,
device=device,
)
for i in range(num_layers)
]
)
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
def forward(
self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True
):
intermediate = None
x = self.embed_tokens(input_ids)
past_bias = None
for i, l in enumerate(self.block):
x, past_bias = l(x, past_bias)
if i == intermediate_output:
intermediate = x.clone()
x = self.final_layer_norm(x)
if intermediate is not None and final_layer_norm_intermediate:
intermediate = self.final_layer_norm(intermediate)
return x, intermediate
class T5(torch.nn.Module):
def __init__(self, config_dict, dtype, device):
super().__init__()
self.num_layers = config_dict["num_layers"]
self.encoder = T5Stack(
self.num_layers,
config_dict["d_model"],
config_dict["d_model"],
config_dict["d_ff"],
config_dict["num_heads"],
config_dict["vocab_size"],
dtype,
device,
)
self.dtype = dtype
def get_input_embeddings(self):
return self.encoder.embed_tokens
def set_input_embeddings(self, embeddings):
self.encoder.embed_tokens = embeddings
def forward(self, *args, **kwargs):
return self.encoder(*args, **kwargs)


@@ -0,0 +1,222 @@
import os
import safetensors
import torch
import typing
from transformers import CLIPTokenizer, T5TokenizerFast
from modules import shared, devices, modelloader, sd_hijack_clip, prompt_parser
from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer
class SafetensorsMapping(typing.Mapping):
def __init__(self, file):
self.file = file
def __len__(self):
return len(self.file.keys())
def __iter__(self):
for key in self.file.keys():
yield key
def __getitem__(self, key):
return self.file.get_tensor(key)
CLIPL_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors"
CLIPL_CONFIG = {
"hidden_act": "quick_gelu",
"hidden_size": 768,
"intermediate_size": 3072,
"num_attention_heads": 12,
"num_hidden_layers": 12,
}
CLIPG_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors"
CLIPG_CONFIG = {
"hidden_act": "gelu",
"hidden_size": 1280,
"intermediate_size": 5120,
"num_attention_heads": 20,
"num_hidden_layers": 32,
"textual_inversion_key": "clip_g",
}
T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
T5_CONFIG = {
"d_ff": 10240,
"d_model": 4096,
"num_heads": 64,
"num_layers": 24,
"vocab_size": 32128,
}
class Sd3ClipLG(sd_hijack_clip.TextConditionalModel):
def __init__(self, clip_l, clip_g):
super().__init__()
self.clip_l = clip_l
self.clip_g = clip_g
self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
empty = self.tokenizer('')["input_ids"]
self.id_start = empty[0]
self.id_end = empty[1]
self.id_pad = empty[1]
self.return_pooled = True
def tokenize(self, texts):
return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
def encode_with_transformers(self, tokens):
tokens_g = tokens.clone()
for batch_pos in range(tokens_g.shape[0]):
index = tokens_g[batch_pos].cpu().tolist().index(self.id_end)
tokens_g[batch_pos, index+1:tokens_g.shape[1]] = 0
l_out, l_pooled = self.clip_l(tokens)
g_out, g_pooled = self.clip_g(tokens_g)
lg_out = torch.cat([l_out, g_out], dim=-1)
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
vector_out = torch.cat((l_pooled, g_pooled), dim=-1)
lg_out.pooled = vector_out
return lg_out
def encode_embedding_init_text(self, init_text, nvpt):
return torch.zeros((nvpt, 768+1280), device=devices.device) # XXX
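A minimal shape sketch (not part of the diff) of how `encode_with_transformers` above combines the two CLIP outputs: the hidden states are concatenated on the channel dim and zero-padded to 4096 so they can later be stacked with the T5-XXL output.

import torch

l_out = torch.zeros(1, 77, 768)
g_out = torch.zeros(1, 77, 1280)
lg_out = torch.cat([l_out, g_out], dim=-1)                               # (1, 77, 2048)
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))   # (1, 77, 4096)
vector_out = torch.cat([torch.zeros(1, 768), torch.zeros(1, 1280)], dim=-1)  # pooled: (1, 2048)
print(lg_out.shape, vector_out.shape)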
class Sd3T5(torch.nn.Module):
def __init__(self, t5xxl):
super().__init__()
self.t5xxl = t5xxl
self.tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")
empty = self.tokenizer('', padding='max_length', max_length=2)["input_ids"]
self.id_end = empty[0]
self.id_pad = empty[1]
def tokenize(self, texts):
return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
def tokenize_line(self, line, *, target_token_count=None):
if shared.opts.emphasis != "None":
parsed = prompt_parser.parse_prompt_attention(line)
else:
parsed = [[line, 1.0]]
tokenized = self.tokenize([text for text, _ in parsed])
tokens = []
multipliers = []
for text_tokens, (text, weight) in zip(tokenized, parsed):
if text == 'BREAK' and weight == -1:
continue
tokens += text_tokens
multipliers += [weight] * len(text_tokens)
tokens += [self.id_end]
multipliers += [1.0]
if target_token_count is not None:
if len(tokens) < target_token_count:
tokens += [self.id_pad] * (target_token_count - len(tokens))
multipliers += [1.0] * (target_token_count - len(tokens))
else:
tokens = tokens[0:target_token_count]
multipliers = multipliers[0:target_token_count]
return tokens, multipliers
def forward(self, texts, *, token_count):
if not self.t5xxl or not shared.opts.sd3_enable_t5:
return torch.zeros((len(texts), token_count, 4096), device=devices.device, dtype=devices.dtype)
tokens_batch = []
for text in texts:
tokens, multipliers = self.tokenize_line(text, target_token_count=token_count)
tokens_batch.append(tokens)
t5_out, t5_pooled = self.t5xxl(tokens_batch)
return t5_out
def encode_embedding_init_text(self, init_text, nvpt):
return torch.zeros((nvpt, 4096), device=devices.device) # XXX
class SD3Cond(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tokenizer = SD3Tokenizer()
with torch.no_grad():
self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
if shared.opts.sd3_enable_t5:
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
else:
self.t5xxl = None
self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)
self.model_t5 = Sd3T5(self.t5xxl)
def forward(self, prompts: list[str]):
with devices.without_autocast():
lg_out, vector_out = self.model_lg(prompts)
t5_out = self.model_t5(prompts, token_count=lg_out.shape[1])
lgt_out = torch.cat([lg_out, t5_out], dim=-2)
return {
'crossattn': lgt_out,
'vector': vector_out,
}
def before_load_weights(self, state_dict):
clip_path = os.path.join(shared.models_path, "CLIP")
if 'text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
with safetensors.safe_open(clip_g_file, framework="pt") as file:
self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
with safetensors.safe_open(clip_l_file, framework="pt") as file:
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
with safetensors.safe_open(t5_file, framework="pt") as file:
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
def encode_embedding_init_text(self, init_text, nvpt):
return self.model_lg.encode_embedding_init_text(init_text, nvpt)
def tokenize(self, texts):
return self.model_lg.tokenize(texts)
def medvram_modules(self):
return [self.clip_g, self.clip_l, self.t5xxl]
def get_token_count(self, text):
_, token_count = self.model_lg.process_texts([text])
return token_count
def get_target_prompt_token_count(self, token_count):
return self.model_lg.get_target_prompt_token_count(token_count)


@@ -0,0 +1,623 @@
### Impls of the SD3 core diffusion model and VAE
import math
import re
import einops
import torch
from PIL import Image
from tqdm import tqdm
from modules.models.sd35.mmditx import MMDiTX
#################################################################################################
### MMDiT Model Wrapping
#################################################################################################
class ModelSamplingDiscreteFlow(torch.nn.Module):
"""Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models"""
def __init__(self, shift=1.0):
super().__init__()
self.shift = shift
timesteps = 1000
ts = self.sigma(torch.arange(1, timesteps + 1, 1))
self.register_buffer("sigmas", ts)
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
return sigma * 1000
def sigma(self, timestep: torch.Tensor):
timestep = timestep / 1000.0
if self.shift == 1.0:
return timestep
return self.shift * timestep / (1 + (self.shift - 1) * timestep)
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
return sigma * noise + (1.0 - sigma) * latent_image
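A minimal sketch of the shifted schedule above (not part of the diff): `shift=1.0` maps timesteps to sigma linearly, while a larger shift keeps more of the schedule at high noise. The example values are arbitrary.

import torch

sampling = ModelSamplingDiscreteFlow(shift=3.0)
t = torch.tensor([250.0, 500.0, 750.0])
print(sampling.sigma(t))   # tensor([0.5000, 0.7500, 0.9000])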
class BaseModel(torch.nn.Module):
"""Wrapper around the core MM-DiT model"""
def __init__(
self,
state_dict,
shift=1.0,
*args,
**kwargs
):
super().__init__()
prefix = ''
# Important configuration values can be quickly determined by checking shapes in the source file
# Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
pos_embed_max_size = round(math.sqrt(num_patches))
adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
qk_norm = (
"rms"
if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in state_dict.keys()
else None
)
x_block_self_attn_layers = sorted(
[
int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1])
for key in list(
filter(
re.compile(".*.x_block.attn2.ln_k.weight").match, state_dict.keys()
)
)
]
)
context_embedder_config = {
"target": "torch.nn.Linear",
"params": {
"in_features": context_shape[1],
"out_features": context_shape[0],
},
}
self.diffusion_model = MMDiTX(
input_size=None,
pos_embed_scaling_factor=None,
pos_embed_offset=None,
pos_embed_max_size=pos_embed_max_size,
patch_size=patch_size,
in_channels=16,
depth=depth,
num_patches=num_patches,
adm_in_channels=adm_in_channels,
context_embedder_config=context_embedder_config,
qk_norm=qk_norm,
x_block_self_attn_layers=x_block_self_attn_layers,
# device=kwargs['device'],
# dtype=kwargs['dtype'],
# verbose=kwargs['verbose'],
# **kwargs
)
self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
def apply_model(self, x, sigma, y=None, *args, **kwargs):
dtype = self.get_dtype()
timestep = self.model_sampling.timestep(sigma).float()
model_output = self.diffusion_model(
x.to(dtype), timestep, context=kwargs["context"].to(dtype), y=y.to(dtype)
).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
def forward(self, *args, **kwargs):
return self.apply_model(*args, **kwargs)
def get_dtype(self):
return self.diffusion_model.dtype
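A minimal sketch of the shape-based configuration probing used in the constructor above (not part of the diff; the tensor shapes below are hypothetical, chosen to mimic a 2B-class checkpoint):

import math
import torch

state_dict = {
    "x_embedder.proj.weight": torch.zeros(1536, 16, 2, 2),
    "pos_embed": torch.zeros(1, 36864, 1536),
}
patch_size = state_dict["x_embedder.proj.weight"].shape[2]               # 2
depth = state_dict["x_embedder.proj.weight"].shape[0] // 64              # 24
pos_embed_max_size = round(math.sqrt(state_dict["pos_embed"].shape[1]))  # 192
print(patch_size, depth, pos_embed_max_size)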
class CFGDenoiser(torch.nn.Module):
"""Helper for applying CFG Scaling to diffusion outputs"""
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, x, timestep, cond, uncond, cond_scale):
# Run cond and uncond in a batch together
batched = self.model.apply_model(
torch.cat([x, x]),
torch.cat([timestep, timestep]),
c_crossattn=torch.cat([cond["c_crossattn"], uncond["c_crossattn"]]),
y=torch.cat([cond["y"], uncond["y"]]),
)
# Then split and apply CFG Scaling
pos_out, neg_out = batched.chunk(2)
scaled = neg_out + (pos_out - neg_out) * cond_scale
return scaled
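A minimal numeric sketch of the CFG blend above (not part of the diff): the guided prediction moves the unconditional output toward the conditional one by `cond_scale`.

import torch

pos_out = torch.tensor([1.0, 2.0])
neg_out = torch.tensor([0.5, 1.0])
cond_scale = 5.0
print(neg_out + (pos_out - neg_out) * cond_scale)   # tensor([3., 6.])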
class SD3LatentFormat:
"""Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""
def __init__(self):
self.scale_factor = 1.5305
self.shift_factor = 0.0609
def process_in(self, latent):
return (latent - self.shift_factor) * self.scale_factor
def process_out(self, latent):
return (latent / self.scale_factor) + self.shift_factor
def decode_latent_to_preview(self, x0):
"""Quick RGB approximate preview of sd3 latents"""
factors = torch.tensor(
[
[-0.0645, 0.0177, 0.1052],
[0.0028, 0.0312, 0.0650],
[0.1848, 0.0762, 0.0360],
[0.0944, 0.0360, 0.0889],
[0.0897, 0.0506, -0.0364],
[-0.0020, 0.1203, 0.0284],
[0.0855, 0.0118, 0.0283],
[-0.0539, 0.0658, 0.1047],
[-0.0057, 0.0116, 0.0700],
[-0.0412, 0.0281, -0.0039],
[0.1106, 0.1171, 0.1220],
[-0.0248, 0.0682, -0.0481],
[0.0815, 0.0846, 0.1207],
[-0.0120, -0.0055, -0.0867],
[-0.0749, -0.0634, -0.0456],
[-0.1418, -0.1457, -0.1259],
],
device="cpu",
)
latent_image = x0[0].permute(1, 2, 0).cpu() @ factors
latents_ubyte = (
((latent_image + 1) / 2)
.clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
.byte()
).cpu()
return Image.fromarray(latents_ubyte.numpy())
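A minimal sketch (not part of the diff) showing that `process_in` and `process_out` above are exact inverses around the SD3 scale and shift factors:

import torch

fmt = SD3LatentFormat()
latent = torch.randn(1, 16, 64, 64)
restored = fmt.process_out(fmt.process_in(latent))
print(torch.allclose(latent, restored, atol=1e-6))   # True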
#################################################################################################
### Samplers
#################################################################################################
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
return x[(...,) + (None,) * dims_to_append]
def to_d(x, sigma, denoised):
"""Converts a denoiser output to a Karras ODE derivative."""
return (x - denoised) / append_dims(sigma, x.ndim)
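# A quick worked example of the two helpers above (illustrative only): with denoised == 0 the
# derivative is simply x / sigma, and one Euler step moves x toward the next sigma, which is
# exactly what sample_euler() below does on every iteration.
def _single_euler_step_example():
    x = torch.randn(1, 16, 8, 8)
    sigma, sigma_next = torch.tensor([1.0]), torch.tensor([0.8])
    denoised = torch.zeros_like(x)       # pretend the model predicts a fully clean latent of 0
    d = to_d(x, sigma, denoised)         # (x - denoised) / sigma, sigma broadcast by append_dims
    return x + d * (sigma_next - sigma)  # Euler update with dt = sigma_next - sigma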
@torch.no_grad()
@torch.autocast("cuda", dtype=torch.float16)
def sample_euler(model, x, sigmas, extra_args=None):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in tqdm(range(len(sigmas) - 1)):
sigma_hat = sigmas[i]
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = x + d * dt
return x
@torch.no_grad()
@torch.autocast("cuda", dtype=torch.float16)
def sample_dpmpp_2m(model, x, sigmas, extra_args=None):
"""DPM-Solver++(2M)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
old_denoised = None
for i in tqdm(range(len(sigmas) - 1)):
denoised = model(x, sigmas[i] * s_in, **extra_args)
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
if old_denoised is None or sigmas[i + 1] == 0:
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
else:
h_last = t - t_fn(sigmas[i - 1])
r = h_last / h
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
old_denoised = denoised
return x
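# Both samplers share the signature sample_<name>(model, x, sigmas, extra_args); the inference
# script later in this diff selects one by name via getattr(sd3_impls, f"sample_{sampler}") and
# passes a CFGDenoiser-wrapped model plus {"cond", "uncond", "cond_scale"} as extra_args.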
#################################################################################################
### VAE
#################################################################################################
def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):
return torch.nn.GroupNorm(
num_groups=num_groups,
num_channels=in_channels,
eps=1e-6,
affine=True,
dtype=dtype,
device=device,
)
class ResnetBlock(torch.nn.Module):
def __init__(
self, *, in_channels, out_channels=None, dtype=torch.float32, device=None
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = Normalize(in_channels, dtype=dtype, device=device)
self.conv1 = torch.nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
dtype=dtype,
device=device,
)
self.norm2 = Normalize(out_channels, dtype=dtype, device=device)
self.conv2 = torch.nn.Conv2d(
out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
dtype=dtype,
device=device,
)
if self.in_channels != self.out_channels:
self.nin_shortcut = torch.nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
dtype=dtype,
device=device,
)
else:
self.nin_shortcut = None
self.swish = torch.nn.SiLU(inplace=True)
def forward(self, x):
hidden = x
hidden = self.norm1(hidden)
hidden = self.swish(hidden)
hidden = self.conv1(hidden)
hidden = self.norm2(hidden)
hidden = self.swish(hidden)
hidden = self.conv2(hidden)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + hidden
class AttnBlock(torch.nn.Module):
def __init__(self, in_channels, dtype=torch.float32, device=None):
super().__init__()
self.norm = Normalize(in_channels, dtype=dtype, device=device)
self.q = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0,
dtype=dtype,
device=device,
)
self.k = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0,
dtype=dtype,
device=device,
)
self.v = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0,
dtype=dtype,
device=device,
)
self.proj_out = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0,
dtype=dtype,
device=device,
)
def forward(self, x):
hidden = self.norm(x)
q = self.q(hidden)
k = self.k(hidden)
v = self.v(hidden)
b, c, h, w = q.shape
q, k, v = map(
lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(),
(q, k, v),
)
hidden = torch.nn.functional.scaled_dot_product_attention(
q, k, v
) # scale is dim ** -0.5 per default
hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
hidden = self.proj_out(hidden)
return x + hidden
class Downsample(torch.nn.Module):
def __init__(self, in_channels, dtype=torch.float32, device=None):
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=3,
stride=2,
padding=0,
dtype=dtype,
device=device,
)
def forward(self, x):
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(torch.nn.Module):
def __init__(self, in_channels, dtype=torch.float32, device=None):
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1,
dtype=dtype,
device=device,
)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class VAEEncoder(torch.nn.Module):
def __init__(
self,
ch=128,
ch_mult=(1, 2, 4, 4),
num_res_blocks=2,
in_channels=3,
z_channels=16,
dtype=torch.float32,
device=None,
):
super().__init__()
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
# downsampling
self.conv_in = torch.nn.Conv2d(
in_channels,
ch,
kernel_size=3,
stride=1,
padding=1,
dtype=dtype,
device=device,
)
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = torch.nn.ModuleList()
for i_level in range(self.num_resolutions):
block = torch.nn.ModuleList()
attn = torch.nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
dtype=dtype,
device=device,
)
)
block_in = block_out
down = torch.nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, dtype=dtype, device=device)
self.down.append(down)
# middle
self.mid = torch.nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in, out_channels=block_in, dtype=dtype, device=device
)
self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
self.mid.block_2 = ResnetBlock(
in_channels=block_in, out_channels=block_in, dtype=dtype, device=device
)
# end
self.norm_out = Normalize(block_in, dtype=dtype, device=device)
self.conv_out = torch.nn.Conv2d(
block_in,
2 * z_channels,
kernel_size=3,
stride=1,
padding=1,
dtype=dtype,
device=device,
)
self.swish = torch.nn.SiLU(inplace=True)
def forward(self, x):
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# end
h = self.norm_out(h)
h = self.swish(h)
h = self.conv_out(h)
return h
class VAEDecoder(torch.nn.Module):
def __init__(
self,
ch=128,
out_ch=3,
ch_mult=(1, 2, 4, 4),
num_res_blocks=2,
resolution=256,
z_channels=16,
dtype=torch.float32,
device=None,
):
super().__init__()
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
# z to block_in
self.conv_in = torch.nn.Conv2d(
z_channels,
block_in,
kernel_size=3,
stride=1,
padding=1,
dtype=dtype,
device=device,
)
# middle
self.mid = torch.nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in, out_channels=block_in, dtype=dtype, device=device
)
self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
self.mid.block_2 = ResnetBlock(
in_channels=block_in, out_channels=block_in, dtype=dtype, device=device
)
# upsampling
self.up = torch.nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = torch.nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
dtype=dtype,
device=device,
)
)
block_in = block_out
up = torch.nn.Module()
up.block = block
if i_level != 0:
up.upsample = Upsample(block_in, dtype=dtype, device=device)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in, dtype=dtype, device=device)
self.conv_out = torch.nn.Conv2d(
block_in,
out_ch,
kernel_size=3,
stride=1,
padding=1,
dtype=dtype,
device=device,
)
self.swish = torch.nn.SiLU(inplace=True)
def forward(self, z):
# z to block_in
hidden = self.conv_in(z)
# middle
hidden = self.mid.block_1(hidden)
hidden = self.mid.attn_1(hidden)
hidden = self.mid.block_2(hidden)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
hidden = self.up[i_level].block[i_block](hidden)
if i_level != 0:
hidden = self.up[i_level].upsample(hidden)
# end
hidden = self.norm_out(hidden)
hidden = self.swish(hidden)
hidden = self.conv_out(hidden)
return hidden
class SDVAE(torch.nn.Module):
def __init__(self, dtype=torch.float32, device=None):
super().__init__()
self.encoder = VAEEncoder(dtype=dtype, device=device)
self.decoder = VAEDecoder(dtype=dtype, device=device)
@torch.autocast("cuda", dtype=torch.float16)
def decode(self, latent):
return self.decoder(latent)
@torch.autocast("cuda", dtype=torch.float16)
def encode(self, image):
hidden = self.encoder(image)
mean, logvar = torch.chunk(hidden, 2, dim=1)
logvar = torch.clamp(logvar, -30.0, 20.0)
std = torch.exp(0.5 * logvar)
return mean + std * torch.randn_like(mean)
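# encode() returns a reparameterized sample from the diagonal-Gaussian posterior,
# z = mean + std * eps with eps ~ N(0, I); clamping logvar keeps std numerically sane.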

View File

@@ -0,0 +1,436 @@
# NOTE: Must have folder `models` with the following files:
# - `clip_g.safetensors` (openclip bigG, same as SDXL)
# - `clip_l.safetensors` (OpenAI CLIP-L, same as SDXL)
# - `t5xxl.safetensors` (google T5-v1.1-XXL)
# - `sd3_medium.safetensors` (or whichever main MMDiT model file)
# Also can have
# - `sd3_vae.safetensors` (holds the VAE separately if needed)
import datetime
import math
import os
import fire
import numpy as np
import torch
from PIL import Image
from safetensors import safe_open
from tqdm import tqdm
from modules.models.sd35 import sd3_impls
from modules.models.sd35.other_impls import SD3Tokenizer, SDClipModel, SDXLClipG, T5XXLModel
from modules.models.sd35.sd3_impls import SDVAE, BaseModel, CFGDenoiser, SD3LatentFormat
#################################################################################################
### Wrappers for model parts
#################################################################################################
def load_into(f, model, prefix, device, dtype=None):
"""Just a debugging-friendly hack to apply the weights in a safetensors file to the pytorch module."""
for key in f.keys():
if key.startswith(prefix) and not key.startswith("loss."):
path = key[len(prefix) :].split(".")
obj = model
for p in path:
if isinstance(obj, list):
obj = obj[int(p)]
else:
obj = getattr(obj, p, None)
if obj is None:
print(
f"Skipping key '{key}' in safetensors file as '{p}' does not exist in python model"
)
break
if obj is None:
continue
try:
tensor = f.get_tensor(key).to(device=device)
if dtype is not None:
tensor = tensor.to(dtype=dtype)
obj.requires_grad_(False)
obj.set_(tensor)
except Exception as e:
print(f"Failed to load key '{key}' in safetensors file: {e}")
raise e
CLIPG_CONFIG = {
"hidden_act": "gelu",
"hidden_size": 1280,
"intermediate_size": 5120,
"num_attention_heads": 20,
"num_hidden_layers": 32,
}
class ClipG:
def __init__(self):
with safe_open("models/clip_g.safetensors", framework="pt", device="cpu") as f:
self.model = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=torch.float32)
load_into(f, self.model.transformer, "", "cpu", torch.float32)
CLIPL_CONFIG = {
"hidden_act": "quick_gelu",
"hidden_size": 768,
"intermediate_size": 3072,
"num_attention_heads": 12,
"num_hidden_layers": 12,
}
class ClipL:
def __init__(self):
with safe_open("models/clip_l.safetensors", framework="pt", device="cpu") as f:
self.model = SDClipModel(
layer="hidden",
layer_idx=-2,
device="cpu",
dtype=torch.float32,
layer_norm_hidden_state=False,
return_projected_pooled=False,
textmodel_json_config=CLIPL_CONFIG,
)
load_into(f, self.model.transformer, "", "cpu", torch.float32)
T5_CONFIG = {
"d_ff": 10240,
"d_model": 4096,
"num_heads": 64,
"num_layers": 24,
"vocab_size": 32128,
}
class T5XXL:
def __init__(self):
with safe_open("models/t5xxl.safetensors", framework="pt", device="cpu") as f:
self.model = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float32)
load_into(f, self.model.transformer, "", "cpu", torch.float32)
class SD3:
def __init__(self, model, shift, verbose=False):
with safe_open(model, framework="pt", device="cpu") as f:
self.model = BaseModel(
shift=shift,
file=f,
prefix="model.diffusion_model.",
device="cpu",
dtype=torch.float16,
verbose=verbose,
).eval()
load_into(f, self.model, "model.", "cpu", torch.float16)
class VAE:
def __init__(self, model):
with safe_open(model, framework="pt", device="cpu") as f:
self.model = SDVAE(device="cpu", dtype=torch.float16).eval().cpu()
prefix = ""
if any(k.startswith("first_stage_model.") for k in f.keys()):
prefix = "first_stage_model."
load_into(f, self.model, prefix, "cpu", torch.float16)
#################################################################################################
### Main inference logic
#################################################################################################
# Note: Sigma shift value, publicly released models use 3.0
SHIFT = 3.0
# Naturally, adjust to the width/height of the model you have
WIDTH = 1024
HEIGHT = 1024
# Pick your prompt
PROMPT = "a photo of a cat"
# Most models prefer the range of 4-5, but still work well around 7
CFG_SCALE = 4.5
# Different models want different step counts, but most are good at 50, though that's slow to run
# sd3_medium is quite decent at 28 steps
STEPS = 40
# Seed
SEED = 23
# SEEDTYPE = "fixed"
SEEDTYPE = "rand"
# SEEDTYPE = "roll"
# Actual model file path
# MODEL = "models/sd3_medium.safetensors"
# MODEL = "models/sd3.5_large_turbo.safetensors"
MODEL = "models/sd3.5_large.safetensors"
# VAE model file path, or set None to use the same model file
VAEFile = None # "models/sd3_vae.safetensors"
# Optional init image file path
INIT_IMAGE = None
# If init_image is given, this is the fraction of denoising steps to run (1.0 = full denoise, 0.0 = no denoise at all)
DENOISE = 0.6
# Output file path
OUTDIR = "outputs"
# SAMPLER
# SAMPLER = "euler"
SAMPLER = "dpmpp_2m"
class SD3Inferencer:
def print(self, txt):
if self.verbose:
print(txt)
def load(self, model=MODEL, vae=VAEFile, shift=SHIFT, verbose=False):
self.verbose = verbose
print("Loading tokenizers...")
# NOTE: if you need a reference impl for a high performance CLIP tokenizer instead of just using the HF transformers one,
# check https://github.com/Stability-AI/StableSwarmUI/blob/master/src/Utils/CliplikeTokenizer.cs
# (T5 tokenizer is different though)
self.tokenizer = SD3Tokenizer()
print("Loading OpenAI CLIP L...")
self.clip_l = ClipL()
print("Loading OpenCLIP bigG...")
self.clip_g = ClipG()
print("Loading Google T5-v1-XXL...")
self.t5xxl = T5XXL()
print(f"Loading SD3 model {os.path.basename(model)}...")
self.sd3 = SD3(model, shift, verbose)
print("Loading VAE model...")
self.vae = VAE(vae or model)
print("Models loaded.")
def get_empty_latent(self, width, height):
self.print("Prep an empty latent...")
return torch.ones(1, 16, height // 8, width // 8, device="cpu") * 0.0609
def get_sigmas(self, sampling, steps):
start = sampling.timestep(sampling.sigma_max)
end = sampling.timestep(sampling.sigma_min)
timesteps = torch.linspace(start, end, steps)
sigs = []
for x in range(len(timesteps)):
ts = timesteps[x]
sigs.append(sampling.sigma(ts))
sigs += [0.0]
return torch.FloatTensor(sigs)
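# The schedule is built uniformly in timestep space between sigma_max and sigma_min, mapped back
# through the flow sigma(t), and terminated with an explicit 0.0 so the final step lands on clean data.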
def get_noise(self, seed, latent):
generator = torch.manual_seed(seed)
self.print(
f"dtype = {latent.dtype}, layout = {latent.layout}, device = {latent.device}"
)
return torch.randn(
latent.size(),
dtype=torch.float32,
layout=latent.layout,
generator=generator,
device="cpu",
).to(latent.dtype)
def get_cond(self, prompt):
self.print("Encode prompt...")
tokens = self.tokenizer.tokenize_with_weights(prompt)
l_out, l_pooled = self.clip_l.model.encode_token_weights(tokens["l"])
g_out, g_pooled = self.clip_g.model.encode_token_weights(tokens["g"])
t5_out, t5_pooled = self.t5xxl.model.encode_token_weights(tokens["t5xxl"])
lg_out = torch.cat([l_out, g_out], dim=-1)
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
return torch.cat([lg_out, t5_out], dim=-2), torch.cat(
(l_pooled, g_pooled), dim=-1
)
def max_denoise(self, sigmas):
max_sigma = float(self.sd3.model.model_sampling.sigma_max)
sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
def fix_cond(self, cond):
cond, pooled = (cond[0].half().cuda(), cond[1].half().cuda())
return {"c_crossattn": cond, "y": pooled}
def do_sampling(
self,
latent,
seed,
conditioning,
neg_cond,
steps,
cfg_scale,
sampler="dpmpp_2m",
denoise=1.0,
) -> torch.Tensor:
self.print("Sampling...")
latent = latent.half().cuda()
self.sd3.model = self.sd3.model.cuda()
noise = self.get_noise(seed, latent).cuda()
sigmas = self.get_sigmas(self.sd3.model.model_sampling, steps).cuda()
sigmas = sigmas[int(steps * (1 - denoise)) :]
conditioning = self.fix_cond(conditioning)
neg_cond = self.fix_cond(neg_cond)
extra_args = {"cond": conditioning, "uncond": neg_cond, "cond_scale": cfg_scale}
noise_scaled = self.sd3.model.model_sampling.noise_scaling(
sigmas[0], noise, latent, self.max_denoise(sigmas)
)
sample_fn = getattr(sd3_impls, f"sample_{sampler}")
latent = sample_fn(
CFGDenoiser(self.sd3.model), noise_scaled, sigmas, extra_args=extra_args
)
latent = SD3LatentFormat().process_out(latent)
self.sd3.model = self.sd3.model.cpu()
self.print("Sampling done")
return latent
def vae_encode(self, image) -> torch.Tensor:
self.print("Encoding image to latent...")
image = image.convert("RGB")
image_np = np.array(image).astype(np.float32) / 255.0
image_np = np.moveaxis(image_np, 2, 0)
batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0)
image_torch = torch.from_numpy(batch_images)
image_torch = 2.0 * image_torch - 1.0
image_torch = image_torch.cuda()
self.vae.model = self.vae.model.cuda()
latent = self.vae.model.encode(image_torch).cpu()
self.vae.model = self.vae.model.cpu()
self.print("Encoded")
return latent
def vae_decode(self, latent) -> Image.Image:
self.print("Decoding latent to image...")
latent = latent.cuda()
self.vae.model = self.vae.model.cuda()
image = self.vae.model.decode(latent)
image = image.float()
self.vae.model = self.vae.model.cpu()
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
decoded_np = decoded_np.astype(np.uint8)
out_image = Image.fromarray(decoded_np)
self.print("Decoded")
return out_image
def gen_image(
self,
prompts=[PROMPT],
width=WIDTH,
height=HEIGHT,
steps=STEPS,
cfg_scale=CFG_SCALE,
sampler=SAMPLER,
seed=SEED,
seed_type=SEEDTYPE,
out_dir=OUTDIR,
init_image=INIT_IMAGE,
denoise=DENOISE,
):
latent = self.get_empty_latent(width, height)
if init_image:
image_data = Image.open(init_image)
image_data = image_data.resize((width, height), Image.LANCZOS)
latent = self.vae_encode(image_data)
latent = SD3LatentFormat().process_in(latent)
neg_cond = self.get_cond("")
seed_num = None
pbar = tqdm(enumerate(prompts), total=len(prompts), position=0, leave=True)
for i, prompt in pbar:
if seed_type == "roll":
seed_num = seed if seed_num is None else seed_num + 1
elif seed_type == "rand":
seed_num = torch.randint(0, 100000, (1,)).item()
else: # fixed
seed_num = seed
conditioning = self.get_cond(prompt)
sampled_latent = self.do_sampling(
latent,
seed_num,
conditioning,
neg_cond,
steps,
cfg_scale,
sampler,
denoise if init_image else 1.0,
)
image = self.vae_decode(sampled_latent)
save_path = os.path.join(out_dir, f"{i:06d}.png")
self.print(f"Will save to {save_path}")
image.save(save_path)
self.print("Done")
CONFIGS = {
"sd3_medium": {
"shift": 1.0,
"cfg": 5.0,
"steps": 50,
"sampler": "dpmpp_2m",
},
"sd3.5_large": {
"shift": 3.0,
"cfg": 4.5,
"steps": 40,
"sampler": "dpmpp_2m",
},
"sd3.5_large_turbo": {"shift": 3.0, "cfg": 1.0, "steps": 4, "sampler": "euler"},
}
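# The defaults above are keyed by the model filename stem, e.g. "models/sd3.5_large.safetensors"
# resolves to CONFIGS["sd3.5_large"]; any steps/cfg/shift/sampler passed explicitly to main()
# below overrides the corresponding table entry.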
@torch.no_grad()
def main(
prompt=PROMPT,
model=MODEL,
out_dir=OUTDIR,
postfix=None,
seed=SEED,
seed_type=SEEDTYPE,
sampler=None,
steps=None,
cfg=None,
shift=None,
width=WIDTH,
height=HEIGHT,
vae=VAEFile,
init_image=INIT_IMAGE,
denoise=DENOISE,
verbose=False,
):
steps = steps or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["steps"]
cfg = cfg or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["cfg"]
shift = shift or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["shift"]
sampler = (
sampler or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["sampler"]
)
inferencer = SD3Inferencer()
inferencer.load(model, vae, shift, verbose)
if isinstance(prompt, str):
if os.path.splitext(prompt)[-1] == ".txt":
with open(prompt, "r") as f:
prompts = [l.strip() for l in f.readlines()]
else:
prompts = [prompt]
out_dir = os.path.join(
out_dir,
os.path.splitext(os.path.basename(model))[0],
os.path.splitext(os.path.basename(prompt))[0][:50]
+ (postfix or datetime.datetime.now().strftime("_%Y-%m-%dT%H-%M-%S")),
)
print(f"Saving images to {out_dir}")
os.makedirs(out_dir, exist_ok=False)
inferencer.gen_image(
prompts,
width,
height,
steps,
cfg,
sampler,
seed,
seed_type,
out_dir,
init_image,
denoise,
)
fire.Fire(main)
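# Example invocation (the script's filename is not shown in this diff; "sd3_demo.py" below is a
# placeholder), relying on fire to map command-line flags onto main()'s keyword arguments:
#   python sd3_demo.py --prompt "a photo of a cat" --model models/sd3.5_large.safetensors --steps 40 --cfg 4.5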

View File

@@ -1402,7 +1402,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
fp_additional_modules = getattr(shared.opts, 'forge_additional_modules')
reload = False
if hasattr(self, 'hr_additional_modules') and 'Use same choices' not in self.hr_additional_modules:
if 'Use same choices' not in self.hr_additional_modules:
modules_changed = main_entry.modules_change(self.hr_additional_modules, save=False, refresh=False)
if modules_changed:
reload = True

View File

@@ -21,29 +21,34 @@ class ScriptRefiner(scripts.ScriptBuiltinUI):
def ui(self, is_img2img):
with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner:
with gr.Row():
refiner_checkpoint = gr.Dropdown(label='Checkpoint', info='(use model of same architecture)', elem_id=self.elem_id("checkpoint"), choices=["", *sd_models.checkpoint_tiles(use_short=True)], value='', tooltip="switch to another model in the middle of generation")
create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles(use_short=True)}, self.elem_id("checkpoint_refresh"))
refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation")
def lookup_checkpoint(title):
info = sd_models.get_closet_checkpoint_match(title)
return None if info is None else info.short_title
self.infotext_fields = [
PasteField(enable_refiner, lambda d: 'Refiner' in d),
PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"),
PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"),
]
gr.Markdown('Refiner is currently under maintenance and unavailable. Sorry for the inconvenience.')
return enable_refiner, refiner_checkpoint, refiner_switch_at
#
# with gr.Row():
# refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=["", *sd_models.checkpoint_tiles()], value='', tooltip="switch to another model in the middle of generation", interactive=False, visible=False)
# # create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh"))
#
# refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation", interactive=False, visible=False)
#
# def lookup_checkpoint(title):
# info = sd_models.get_closet_checkpoint_match(title)
# return None if info is None else info.title
#
# self.infotext_fields = [
# PasteField(enable_refiner, lambda d: 'Refiner' in d),
# PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"),
# PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"),
# ]
def setup(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
# the actual implementation is in sd_samplers_common.py, apply_refiner
if not enable_refiner or refiner_checkpoint in (None, "", "None"):
p.refiner_checkpoint = None
p.refiner_switch_at = None
else:
p.refiner_checkpoint = refiner_checkpoint
p.refiner_switch_at = refiner_switch_at
return [enable_refiner] # , refiner_checkpoint, refiner_switch_at
def setup(self, p, enable_refiner):
pass
# # the actual implementation is in sd_samplers_common.py, apply_refiner
#
# if not enable_refiner or refiner_checkpoint in (None, "", "None"):
# p.refiner_checkpoint = None
# p.refiner_switch_at = None
# else:
# p.refiner_checkpoint = refiner_checkpoint
# p.refiner_switch_at = refiner_switch_at

View File

@@ -450,7 +450,7 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):
def unload_model_weights(sd_model=None, info=None):
memory_management.unload_all_models()
return
pass
def apply_token_merging(sd_model, token_merging_ratio):

View File

@@ -63,10 +63,9 @@ def set_samplers():
def add_sampler(sampler):
global all_samplers, all_samplers_map
if sampler.name not in [x.name for x in all_samplers]:
all_samplers.append(sampler)
all_samplers_map = {x.name: x for x in all_samplers}
set_samplers()
all_samplers.append(sampler)
all_samplers_map = {x.name: x for x in all_samplers}
set_samplers()
return

View File

@@ -168,9 +168,9 @@ class CFGDenoiser(torch.nn.Module):
x = x * (((real_sigma ** 2.0 + real_sigma_data ** 2.0) ** 0.5)[:, None, None, None])
sigma = real_sigma
if sd_samplers_common.apply_refiner(self, x):
cond = self.sampler.sampler_extra_args['cond']
uncond = self.sampler.sampler_extra_args['uncond']
# if sd_samplers_common.apply_refiner(self, x):
# cond = self.sampler.sampler_extra_args['cond']
# uncond = self.sampler.sampler_extra_args['uncond']
cond_composition, cond = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) if uncond is not None else None

View File

@@ -8,7 +8,7 @@ from modules.shared import opts, state
from backend.sampling.sampling_function import sampling_prepare, sampling_cleanup
from modules import extra_networks
import k_diffusion.sampling
from modules_forge import main_entry
SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
@@ -70,12 +70,9 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
def single_sample_to_image(sample, approximation=None):
x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
x_sample = x_sample.cpu()
x_sample.clamp_(0.0, 1.0)
x_sample.mul_(255.)
x_sample.round_()
x_sample = x_sample.to(torch.uint8)
x_sample = np.moveaxis(x_sample.numpy(), 0, 2)
x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample)
@@ -164,51 +161,45 @@ replace_torchsde_browinan()
def apply_refiner(cfg_denoiser, x):
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
return False
if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
return False
if getattr(cfg_denoiser.p, "enable_hr", False):
is_second_pass = cfg_denoiser.p.is_hr_pass
if opts.hires_fix_refiner_pass == "first pass" and is_second_pass:
return False
if opts.hires_fix_refiner_pass == "second pass" and not is_second_pass:
return False
if opts.hires_fix_refiner_pass != "second pass":
cfg_denoiser.p.extra_generation_params['Hires refiner'] = opts.hires_fix_refiner_pass
cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
sampling_cleanup(sd_models.model_data.get_sd_model().forge_objects.unet)
with sd_models.SkipWritingToConfig():
fp_checkpoint = getattr(shared.opts, 'sd_model_checkpoint')
checkpoint_changed = main_entry.checkpoint_change(refiner_checkpoint_info.short_title, save=False, refresh=False)
if checkpoint_changed:
try:
main_entry.refresh_model_loading_parameters()
sd_models.forge_model_reload()
finally:
main_entry.checkpoint_change(fp_checkpoint, save=False, refresh=True)
if not cfg_denoiser.p.disable_extra_networks:
extra_networks.activate(cfg_denoiser.p, cfg_denoiser.p.extra_network_data)
cfg_denoiser.p.setup_conds()
cfg_denoiser.update_inner_model()
sampling_prepare(sd_models.model_data.get_sd_model().forge_objects.unet, x=x)
return True
# completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
# refiner_switch_at = cfg_denoiser.p.refiner_switch_at
# refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
#
# if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
# return False
#
# if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
# return False
#
# if getattr(cfg_denoiser.p, "enable_hr", False):
# is_second_pass = cfg_denoiser.p.is_hr_pass
#
# if opts.hires_fix_refiner_pass == "first pass" and is_second_pass:
# return False
#
# if opts.hires_fix_refiner_pass == "second pass" and not is_second_pass:
# return False
#
# if opts.hires_fix_refiner_pass != "second pass":
# cfg_denoiser.p.extra_generation_params['Hires refiner'] = opts.hires_fix_refiner_pass
#
# cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
# cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
#
# sampling_cleanup(sd_models.model_data.get_sd_model().forge_objects.unet)
#
# with sd_models.SkipWritingToConfig():
# sd_models.reload_model_weights(info=refiner_checkpoint_info)
#
# if not cfg_denoiser.p.disable_extra_networks:
# extra_networks.activate(cfg_denoiser.p, cfg_denoiser.p.extra_network_data)
#
# cfg_denoiser.p.setup_conds()
# cfg_denoiser.update_inner_model()
#
# sampling_prepare(sd_models.model_data.get_sd_model().forge_objects.unet, x=x)
# return True
pass
class TorchHijack:

View File

@@ -188,10 +188,10 @@ options_templates.update(options_section(('sd', "Stable Diffusion", "sd"), {
}))
options_templates.update(options_section(('sdxl', "Stable Diffusion XL", "sd"), {
"sdxl_crop_top": OptionInfo(0, "crop top coordinate", gr.Number, {"minimum": 0, "maximum": 1024, "step": 1}),
"sdxl_crop_left": OptionInfo(0, "crop left coordinate", gr.Number, {"minimum": 0, "maximum": 1024, "step": 1}),
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Slider, {"minimum": 0, "maximum": 10, "step": 0.1}).info("used for refiner model negative prompt"),
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Slider, {"minimum": 0, "maximum": 10, "step": 0.1}).info("used for refiner model prompt"),
"sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
"sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
}))
options_templates.update(options_section(('sd3', "Stable Diffusion 3", "sd"), {
@@ -225,13 +225,12 @@ options_templates.update(options_section(('img2img', "img2img", "sd"), {
"img2img_inpaint_mask_brush_color": OptionInfo("#ffffff", "Inpaint mask brush color", ui_components.FormColorPicker, {}).info("brush color of inpaint mask").needs_reload_ui(),
"img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(),
"img2img_inpaint_mask_high_contrast": OptionInfo(True, "For inpainting, use a high-contrast brush pattern").info("use a checkerboard brush pattern instead of color brush").needs_reload_ui(),
"img2img_inpaint_mask_scribble_alpha": OptionInfo(75, "Inpaint mask alpha (transparency)", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}).info("only affects non-high-contrast brush").needs_reload_ui(),
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
"img2img_batch_show_results_limit": OptionInfo(32, "Show the first N batch img2img results in UI", gr.Slider, {"minimum": -1, "maximum": 1000, "step": 1}).info('0: disable, -1: show all images. Too many images can cause lag'),
"overlay_inpaint": OptionInfo(True, "Overlay original for inpaint").info("when inpainting, overlay the original image over the areas that weren't inpainted."),
"img2img_autosize": OptionInfo(False, "After loading into Img2img, automatically update Width and Height"),
"img2img_batch_use_original_name": OptionInfo(False, "Save using original filename in img2img batch. Applies to 'Upload' and 'From directory' tabs.").info("Warning: overwriting is possible, based on Settings > Saving images/grids > Saving the image to an existing file.")
}))
options_templates.update(options_section(('optimizations', "Optimizations", "sd"), {
@@ -320,7 +319,6 @@ options_templates.update(options_section(('ui_gallery', "Gallery", "ui"), {
"sd_webui_modal_lightbox_toolbar_opacity": OptionInfo(0.9, "Full page image viewer: tool bar opacity", gr.Slider, {"minimum": 0.0, "maximum": 1, "step": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(),
"gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("can be any valid CSS value, for example 768px or 20em").needs_reload_ui(),
"open_dir_button_choice": OptionInfo("Subdirectory", "What directory the [📂] button opens", gr.Radio, {"choices": ["Output Root", "Subdirectory", "Subdirectory (even temp dir)"]}),
"hires_button_gallery_insert": OptionInfo(False, "Insert [✨] hires button results into gallery").info("Default: original image will be replaced"),
}))
options_templates.update(options_section(('ui_alternatives', "UI alternatives", "ui"), {
@@ -340,7 +338,6 @@ options_templates.update(options_section(('ui', "User interface", "ui"), {
"quick_setting_list": OptionInfo([], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
"tabs_without_quick_settings_bar": OptionInfo(["Spaces"], "UI tabs without Quicksettings bar (top row)", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}),
"ui_reorder_list": OptionInfo([], "UI item order for txt2img/img2img tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
"gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
@@ -358,7 +355,6 @@ Infotext is what this software calls the text that contains generation parameter
It is displayed in UI below the image. To use infotext, paste it into the prompt and click the ↙️ paste button.
"""),
"enable_pnginfo": OptionInfo(True, "Write infotext to metadata of the generated image"),
"stealth_pnginfo_option": OptionInfo("Alpha", "Stealth infotext mode", gr.Radio, {"choices": ["Alpha", "RGB", "None"]}).info("Ignored if infotext is disabled"),
"save_txt": OptionInfo(False, "Create a text file with infotext next to every generated image"),
"add_model_name_to_info": OptionInfo(True, "Add model name to infotext"),

View File

@@ -4,10 +4,8 @@ import threading
import time
import traceback
import torch
from contextlib import nullcontext
from modules import errors, shared, devices
from backend.args import args
from typing import Optional
log = logging.getLogger(__name__)
@@ -36,10 +34,6 @@ class State:
def __init__(self):
self.server_start = time.time()
if args.cuda_stream:
self.vae_stream = torch.cuda.Stream()
else:
self.vae_stream = None
@property
def need_restart(self) -> bool:
@@ -159,18 +153,10 @@ class State:
import modules.sd_samplers
try:
if self.vae_stream is not None:
# not waiting on default stream will result in corrupt results
# will not block main stream under any circumstances
self.vae_stream.wait_stream(torch.cuda.default_stream())
vae_context = torch.cuda.stream(self.vae_stream)
if shared.opts.show_progress_grid:
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
else:
vae_context = nullcontext()
with vae_context:
if shared.opts.show_progress_grid:
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
else:
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
self.current_image_sampling_step = self.sampling_step

View File

@@ -1,163 +0,0 @@
import gzip
from modules.script_callbacks import ImageSaveParams
from modules import shared
def add_stealth_pnginfo(params: ImageSaveParams):
stealth_pnginfo_option = shared.opts.data.get('stealth_pnginfo_option', 'Alpha')
if not stealth_pnginfo_option or stealth_pnginfo_option == 'None':
return
if not params.filename.endswith('.png') or params.pnginfo is None:
return
if 'parameters' not in params.pnginfo:
return
add_data(params, str(stealth_pnginfo_option), True)
def prepare_data(params, mode='Alpha', compressed=True):
signature = f"stealth_{'png' if mode == 'Alpha' else 'rgb'}{'info' if not compressed else 'comp'}"
binary_signature = ''.join(format(byte, '08b') for byte in signature.encode('utf-8'))
param = params.encode('utf-8') if not compressed else gzip.compress(bytes(params, 'utf-8'))
binary_param = ''.join(format(byte, '08b') for byte in param)
binary_param_len = format(len(binary_param), '032b')
return binary_signature + binary_param_len + binary_param
def add_data(params, mode='Alpha', compressed=True):
binary_data = prepare_data(params.pnginfo['parameters'], mode, compressed)
if mode == 'Alpha':
params.image.putalpha(255)
width, height = params.image.size
pixels = params.image.load()
index = 0
end_write = False
for x in range(width):
for y in range(height):
if index >= len(binary_data):
end_write = True
break
values = pixels[x, y]
if mode == 'Alpha':
r, g, b, a = values
else:
r, g, b = values
if mode == 'Alpha':
a = (a & ~1) | int(binary_data[index])
index += 1
else:
r = (r & ~1) | int(binary_data[index])
if index + 1 < len(binary_data):
g = (g & ~1) | int(binary_data[index + 1])
if index + 2 < len(binary_data):
b = (b & ~1) | int(binary_data[index + 2])
index += 3
pixels[x, y] = (r, g, b, a) if mode == 'Alpha' else (r, g, b)
if end_write:
break
def read_info_from_image_stealth(image):
geninfo = None
width, height = image.size
pixels = image.load()
has_alpha = True if image.mode == 'RGBA' else False
mode = None
compressed = False
binary_data = ''
buffer_a = ''
buffer_rgb = ''
index_a = 0
index_rgb = 0
sig_confirmed = False
confirming_signature = True
reading_param_len = False
reading_param = False
read_end = False
for x in range(width):
for y in range(height):
if has_alpha:
r, g, b, a = pixels[x, y]
buffer_a += str(a & 1)
index_a += 1
else:
r, g, b = pixels[x, y]
buffer_rgb += str(r & 1)
buffer_rgb += str(g & 1)
buffer_rgb += str(b & 1)
index_rgb += 3
if confirming_signature:
if index_a == len('stealth_pnginfo') * 8:
decoded_sig = bytearray(int(buffer_a[i:i + 8], 2) for i in
range(0, len(buffer_a), 8)).decode('utf-8', errors='ignore')
if decoded_sig in {'stealth_pnginfo', 'stealth_pngcomp'}:
confirming_signature = False
sig_confirmed = True
reading_param_len = True
mode = 'alpha'
if decoded_sig == 'stealth_pngcomp':
compressed = True
buffer_a = ''
index_a = 0
else:
read_end = True
break
elif index_rgb == len('stealth_pnginfo') * 8:
decoded_sig = bytearray(int(buffer_rgb[i:i + 8], 2) for i in
range(0, len(buffer_rgb), 8)).decode('utf-8', errors='ignore')
if decoded_sig in {'stealth_rgbinfo', 'stealth_rgbcomp'}:
confirming_signature = False
sig_confirmed = True
reading_param_len = True
mode = 'rgb'
if decoded_sig == 'stealth_rgbcomp':
compressed = True
buffer_rgb = ''
index_rgb = 0
elif reading_param_len:
if mode == 'alpha':
if index_a == 32:
param_len = int(buffer_a, 2)
reading_param_len = False
reading_param = True
buffer_a = ''
index_a = 0
else:
if index_rgb == 33:
pop = buffer_rgb[-1]
buffer_rgb = buffer_rgb[:-1]
param_len = int(buffer_rgb, 2)
reading_param_len = False
reading_param = True
buffer_rgb = pop
index_rgb = 1
elif reading_param:
if mode == 'alpha':
if index_a == param_len:
binary_data = buffer_a
read_end = True
break
else:
if index_rgb >= param_len:
diff = param_len - index_rgb
if diff < 0:
buffer_rgb = buffer_rgb[:diff]
binary_data = buffer_rgb
read_end = True
break
else:
# impossible
read_end = True
break
if read_end:
break
if sig_confirmed and binary_data != '':
# Convert binary string to UTF-8 encoded text
byte_data = bytearray(int(binary_data[i:i + 8], 2) for i in range(0, len(binary_data), 8))
try:
if compressed:
decoded_data = gzip.decompress(bytes(byte_data)).decode('utf-8')
else:
decoded_data = byte_data.decode('utf-8', errors='ignore')
geninfo = decoded_data
except:
pass
return geninfo

View File

@@ -102,21 +102,14 @@ def txt2img_upscale_function(id_task: str, request: gr.Request, gallery, gallery
shared.total_tqdm.clear()
insert = getattr(shared.opts, 'hires_button_gallery_insert', False)
new_gallery = []
for i, image in enumerate(gallery):
if insert or i != gallery_index:
image[0].already_saved_as = image[0].filename.rsplit('?', 1)[0]
new_gallery.append(image)
if i == gallery_index:
new_gallery.extend(processed.images)
new_index = gallery_index
if insert:
new_index += 1
geninfo["infotexts"].insert(new_index, processed.info)
else:
geninfo["infotexts"][gallery_index] = processed.info
else:
new_gallery.append(image)
geninfo["infotexts"][gallery_index] = processed.info
return new_gallery, json.dumps(geninfo), plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")

View File

@@ -99,7 +99,7 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz
new_width = p.hr_resize_x or p.hr_upscale_to_x
new_height = p.hr_resize_y or p.hr_upscale_to_y
new_width -= new_width % 8 # note: hardcoded latent size 8
new_height -= new_height % 8
@@ -346,7 +346,7 @@ def create_ui():
with FormRow(elem_id="txt2img_hires_fix_row_cfg", variant="compact"):
hr_distilled_cfg = gr.Slider(minimum=0.0, maximum=30.0, step=0.1, label="Hires Distilled CFG Scale", value=3.5, elem_id="txt2img_hr_distilled_cfg")
hr_cfg = gr.Slider(minimum=1.0, maximum=30.0, step=0.1, label="Hires CFG Scale", value=7.0, elem_id="txt2img_hr_cfg")
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=shared.opts.hires_fix_show_sampler) as hr_checkpoint_container:
hr_checkpoint_name = gr.Dropdown(label='Hires Checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint", scale=2)
@@ -360,7 +360,7 @@ def create_ui():
else:
modules_list += list(main_entry.module_list.keys())
return modules_list
modules_list = get_additional_modules()
def refresh_model_and_modules():
@@ -470,12 +470,6 @@ def create_ui():
toprow.prompt.submit(**txt2img_args)
toprow.submit.click(**txt2img_args)
def select_gallery_image(index):
index = int(index)
if getattr(shared.opts, 'hires_button_gallery_insert', False):
index += 1
return gr.update(selected_index=index)
txt2img_upscale_inputs = txt2img_inputs[0:1] + [output_panel.gallery, dummy_component_number, output_panel.generation_info] + txt2img_inputs[1:]
output_panel.button_upscale.click(
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img_upscale, extra_outputs=[None, '', '']),
@@ -483,7 +477,7 @@ def create_ui():
inputs=txt2img_upscale_inputs,
outputs=txt2img_outputs,
show_progress=False,
).then(fn=select_gallery_image, js="selected_gallery_index", inputs=[dummy_component], outputs=[output_panel.gallery])
)
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
@@ -593,7 +587,7 @@ def create_ui():
add_copy_image_controls('sketch', sketch)
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
init_img_with_mask = ForgeCanvas(elem_id="img2maskimg", height=512, contrast_scribbles=opts.img2img_inpaint_mask_high_contrast, scribble_color=opts.img2img_inpaint_mask_brush_color, scribble_color_fixed=True, scribble_alpha=opts.img2img_inpaint_mask_scribble_alpha, scribble_alpha_fixed=True, scribble_softness_fixed=True)
init_img_with_mask = ForgeCanvas(elem_id="img2maskimg", height=512, contrast_scribbles=opts.img2img_inpaint_mask_high_contrast, scribble_color=opts.img2img_inpaint_mask_brush_color, scribble_color_fixed=True, scribble_alpha_fixed=True, scribble_softness_fixed=True)
add_copy_image_controls('inpaint', init_img_with_mask)
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
@@ -921,13 +915,13 @@ def create_ui():
scripts.scripts_current = None
with gr.Blocks(analytics_enabled=False) as extras_interface:
with gr.Blocks(analytics_enabled=False, head=canvas_head) as extras_interface:
ui_postprocessing.create_ui()
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
with gr.Blocks(analytics_enabled=False, head=canvas_head) as pnginfo_interface:
with ResizeHandleRow(equal_height=False):
with gr.Column(variant='panel'):
image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil", height="50vh", image_mode="RGBA")
image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")
with gr.Column(variant='panel'):
html = gr.HTML()
@@ -969,6 +963,8 @@ def create_ui():
extensions_interface = ui_extensions.create_ui()
interfaces += [(extensions_interface, "Extensions", "extensions")]
interface_names_without_quick_setting_bars = ["Spaces"]
shared.tab_names = []
for _interface, label, _ifid in interfaces:
shared.tab_names.append(label)
@@ -996,8 +992,7 @@ def create_ui():
loadsave.setup_ui()
def tab_changed(evt: gr.SelectData):
no_quick_setting = getattr(shared.opts, "tabs_without_quick_settings_bar", [])
return gr.update(visible=evt.value not in no_quick_setting)
return gr.update(visible=evt.value not in interface_names_without_quick_setting_bars)
tabs.select(tab_changed, outputs=[quicksettings_row], show_progress=False, queue=False)
@@ -1074,7 +1069,7 @@ def setup_ui_api(app):
from fastapi.responses import PlainTextResponse
text = sysinfo.get()
filename = f"sysinfo-{datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%d-%H-%M')}.json"
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json"
return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'})

View File

@@ -13,7 +13,7 @@ def create_ui():
with gr.Column(variant='compact'):
with gr.Tabs(elem_id="mode_extras"):
with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
extras_image = gr.Image(label="Source", interactive=True, type="pil", elem_id="extras_image", image_mode="RGBA", height="55vh")
extras_image = ForgeCanvas(elem_id="extras_image", height=512, no_scribbles=True).background
with gr.TabItem('Batch Process', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch:
image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch")

View File

@@ -93,14 +93,13 @@ class UpscalerData:
scaler: Upscaler = None
model: None
def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None, sha256: str = None):
def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
self.name = name
self.data_path = path
self.local_data_path = path
self.scaler = upscaler
self.scale = scale
self.model = model
self.sha256 = sha256
def __repr__(self):
return f"<UpscalerData name={self.name} path={self.data_path} scale={self.scale}>"

View File

@@ -211,66 +211,3 @@ Requested path was: {path}
subprocess.Popen(["explorer.exe", subprocess.check_output(["wslpath", "-w", path])])
else:
subprocess.Popen(["xdg-open", path])
def load_file_from_url(
url: str,
*,
model_dir: str,
progress: bool = True,
file_name: str | None = None,
hash_prefix: str | None = None,
re_download: bool = False,
) -> str:
"""Download a file from `url` into `model_dir`, using the file present if possible.
Returns the path to the downloaded file.
file_name: if specified, it will be used as the filename, otherwise the filename will be extracted from the url.
file is downloaded to {file_name}.tmp then moved to the final location after download is complete.
hash_prefix: sha256 hex string, if provided, the hash of the downloaded file will be checked against this prefix.
if the hash does not match, the temporary file is deleted and a ValueError is raised.
re_download: forcibly re-download the file even if it already exists.
"""
from urllib.parse import urlparse
import requests
from tqdm import tqdm
if not file_name:
parts = urlparse(url)
file_name = os.path.basename(parts.path)
cached_file = os.path.abspath(os.path.join(model_dir, file_name))
if re_download or not os.path.exists(cached_file):
os.makedirs(model_dir, exist_ok=True)
temp_file = os.path.join(model_dir, f"{file_name}.tmp")
print(f'\nDownloading: "{url}" to {cached_file}')
response = requests.get(url, stream=True)
response.raise_for_status()
total_size = int(response.headers.get('content-length', 0))
with tqdm(total=total_size, unit='B', unit_scale=True, desc=file_name, disable=not progress) as progress_bar:
with open(temp_file, 'wb') as file:
for chunk in response.iter_content(chunk_size=1024):
if chunk:
file.write(chunk)
progress_bar.update(len(chunk))
if hash_prefix and not compare_sha256(temp_file, hash_prefix):
print(f"Hash mismatch for {temp_file}. Deleting the temporary file.")
os.remove(temp_file)
raise ValueError(f"File hash does not match the expected hash prefix {hash_prefix}!")
os.rename(temp_file, cached_file)
return cached_file
def compare_sha256(file_path: str, hash_prefix: str) -> bool:
"""Check if the SHA256 hash of the file matches the given prefix."""
import hashlib
hash_sha256 = hashlib.sha256()
blksize = 1024 * 1024
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(blksize), b""):
hash_sha256.update(chunk)
return hash_sha256.hexdigest().startswith(hash_prefix.strip().lower())

View File

@@ -2,7 +2,7 @@ import pkg_resources
from modules.launch_utils import run_pip
target_bitsandbytes_version = '0.45.3'
target_bitsandbytes_version = '0.43.3'
def try_install_bnb():

View File

@@ -5,16 +5,6 @@
overflow: hidden;
}
.forge-image-container-plain {
width: 100%;
height: calc(100% - 6px);
position: relative;
overflow: hidden;
background-color: #202020;
background-size: 20px 20px;
background-position: 0 0, 10px 10px;
}
.forge-image-container {
width: 100%;
height: calc(100% - 6px);
@@ -58,16 +48,6 @@
left: 0;
}
.forge-toolbar-static {
position: absolute;
top: 0px;
left: 0px;
z-index: 10 !important;
background: rgba(47, 47, 47, 0.8);
padding: 6px 10px;
opacity: 1.0 !important;
}
.forge-toolbar {
position: absolute;
top: 0px;
@@ -79,7 +59,7 @@
transition: opacity 0.3s ease;
}
.forge-toolbar .forge-btn, .forge-toolbar-static .forge-btn {
.forge-toolbar .forge-btn {
padding: 2px 6px;
border: none;
background-color: #4a4a4a;
@@ -89,11 +69,11 @@
transition: background-color 0.3s ease;
}
.forge-toolbar .forge-btn, .forge-toolbar-static .forge-btn:hover {
.forge-toolbar .forge-btn:hover {
background-color: #5e5e5e;
}
.forge-toolbar .forge-btn, .forge-toolbar-static .forge-btn:active {
.forge-toolbar .forge-btn:active {
background-color: #3e3e3e;
}
@@ -176,4 +156,4 @@
width: 30%;
height: 30%;
transform: translate(-50%, -50%);
}
}

View File

@@ -32,7 +32,6 @@ from io import BytesIO
from gradio.context import Context
from functools import wraps
from modules.shared import opts
canvas_js_root_path = os.path.dirname(__file__)
@@ -137,15 +136,7 @@ class ForgeCanvas:
elem_classes=None
):
self.uuid = 'uuid_' + uuid.uuid4().hex
canvas_html_uuid = canvas_html.replace('forge_mixin', self.uuid)
if opts.forge_canvas_plain:
canvas_html_uuid = canvas_html_uuid.replace('class="forge-image-container"', 'class="forge-image-container-plain"').replace('stroke="white"', 'stroke=#444')
if opts.forge_canvas_toolbar_always:
canvas_html_uuid = canvas_html_uuid.replace('class="forge-toolbar"', 'class="forge-toolbar-static"')
self.block = gr.HTML(canvas_html_uuid, visible=visible, elem_id=elem_id, elem_classes=elem_classes)
self.block = gr.HTML(canvas_html.replace('forge_mixin', self.uuid), visible=visible, elem_id=elem_id, elem_classes=elem_classes)
self.foreground = LogicalImage(visible=DEBUG_MODE, label='foreground', numpy=numpy, elem_id=self.uuid, elem_classes=['logical_image_foreground'])
self.background = LogicalImage(visible=DEBUG_MODE, label='background', numpy=numpy, value=initial_image, elem_id=self.uuid, elem_classes=['logical_image_background'])
Context.root_block.load(None, js=f'async ()=>{{new ForgeCanvas("{self.uuid}", {no_upload}, {no_scribbles}, {contrast_scribbles}, {height}, '

View File

@@ -0,0 +1,30 @@
# See also: https://github.com/lllyasviel/google_blockly_prototypes/blob/main/LICENSE_pyz
import os
import gzip
import importlib.util
pyz_dir = os.path.abspath(os.path.realpath(os.path.join(__file__, '../../repositories/google_blockly_prototypes/forge')))
module_suffix = ".pyz"
def initialization():
print('Loading additional modules ... ', end='')
for filename in os.listdir(pyz_dir):
if not filename.endswith(module_suffix):
continue
module_name = filename[:-len(module_suffix)]
module_package_name = __package__ + '.' + module_name
dynamic_module = importlib.util.module_from_spec(importlib.util.spec_from_loader(module_package_name, loader=None))
dynamic_module.__dict__['__file__'] = os.path.join(pyz_dir, module_name + '.py')
dynamic_module.__dict__['__package__'] = module_package_name
google_blockly_context = gzip.open(os.path.join(pyz_dir, filename), 'rb').read().decode('utf-8')
exec(google_blockly_context, dynamic_module.__dict__)
globals()[module_name] = dynamic_module
print('done.')
return

View File

@@ -1,3 +1,4 @@
def register(options_templates, options_section, OptionInfo):
options_templates.update(options_section((None, "Forge Hidden options"), {
"forge_unet_storage_dtype": OptionInfo('Automatic'),
@@ -7,7 +8,3 @@ def register(options_templates, options_section, OptionInfo):
"forge_preset": OptionInfo('sd'),
"forge_additional_modules": OptionInfo([]),
}))
options_templates.update(options_section(('ui_alternatives', "UI alternatives", "ui"), {
"forge_canvas_plain": OptionInfo(False, "ForgeCanvas: use plain background").needs_reload_ui(),
"forge_canvas_toolbar_always": OptionInfo(False, "ForgeCanvas: toolbar always visible").needs_reload_ui(),
}))

View File

@@ -30,19 +30,6 @@ LORA_CLIP_MAP = {
def load_lora(lora, to_load):
# BFL loras for Flux; from ComfyUI: comfy/lora_convert.py
def convert_lora_bfl_control(sd):
import torch
sd_out = {}
for k in sd:
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
sd_out[k_to] = sd[k]
return sd_out
if "img_in.lora_A.weight" in lora and "single_blocks.0.norm.key_norm.scale" in lora:
lora = convert_lora_bfl_control(lora)
patch_dict = {}
loaded_keys = set()
for x in to_load:
@@ -202,12 +189,6 @@ def load_lora(lora, to_load):
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
loaded_keys.add(diff_bias_name)
set_weight_name = "{}.set_weight".format(x)
set_weight = lora.get(set_weight_name, None)
if set_weight is not None:
patch_dict[to_load[x]] = ("set", (set_weight,))
loaded_keys.add(set_weight_name)
remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
return patch_dict, remaining_dict
@@ -253,32 +234,32 @@ def model_lora_keys_clip(model, key_map={}):
lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
key_map[lora_key] = k
for k in sdk:
if k.endswith(".weight"):
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
l_key = k[len("t5xxl.transformer."):-len(".weight")]
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
key_map[lora_key] = k
#####
lora_key = "lora_te2_{}".format(l_key.replace(".", "_"))#OneTrainer Flux lora, by Forge
key_map[lora_key] = k
#####
# for k in sdk:
# if k.endswith(".weight"):
# if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
# l_key = k[len("t5xxl.transformer."):-len(".weight")]
# lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
# key_map[lora_key] = k
#
# #####
# lora_key = "lora_te2_{}".format(l_key.replace(".", "_"))#OneTrainer Flux lora, by Forge
# key_map[lora_key] = k
# #####
# elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
# l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
# lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
# key_map[lora_key] = k
k = "clip_g.transformer.text_projection.weight"
if k in sdk:
# key_map["lora_prior_te_text_projection"] = k #cascade lora?
key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
key_map["lora_te2_text_projection"] = k #OneTrainer SD3 lora
k = "clip_l.transformer.text_projection.weight"
if k in sdk:
key_map["lora_te1_text_projection"] = k #OneTrainer SD3 lora, not necessary but omits warning
#
#
# k = "clip_g.transformer.text_projection.weight"
# if k in sdk:
# key_map["lora_prior_te_text_projection"] = k #cascade lora?
# # key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
# key_map["lora_te2_text_projection"] = k #OneTrainer SD3 lora
#
# k = "clip_l.transformer.text_projection.weight"
# if k in sdk:
# key_map["lora_te1_text_projection"] = k #OneTrainer SD3 lora, not necessary but omits warning
return sdk, key_map
@@ -288,13 +269,11 @@ def model_lora_keys_unet(model, key_map={}):
sdk = sd.keys()
for k in sdk:
if k.startswith("diffusion_model."):
if k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
else:
key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
if k.startswith("diffusion_model.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
diffusers_keys = utils.unet_to_diffusers(model.diffusion_model.config)
for k in diffusers_keys:
@@ -302,8 +281,7 @@ def model_lora_keys_unet(model, key_map={}):
unet_key = "diffusion_model.{}".format(diffusers_keys[k])
key_lora = k[:-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = unet_key
key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format
diffusers_lora_prefix = ["", "unet."]
for p in diffusers_lora_prefix:
diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
@@ -311,19 +289,19 @@ def model_lora_keys_unet(model, key_map={}):
diffusers_lora_key = diffusers_lora_key[:-2]
key_map[diffusers_lora_key] = unet_key
# if 'stable-diffusion-3' in model.config.huggingface_repo.lower(): #Diffusers lora SD3
# diffusers_keys = utils.mmdit_to_diffusers(model.diffusion_model.config, output_prefix="diffusion_model.")
# for k in diffusers_keys:
# if k.endswith(".weight"):
# to = diffusers_keys[k]
# key_lora = "transformer.{}".format(k[:-len(".weight")]) #regular diffusers sd3 lora format
# key_map[key_lora] = to
# key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others?
# key_map[key_lora] = to
# key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
# key_map[key_lora] = to
# if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3
# diffusers_keys = utils.mmdit_to_diffusers(model.diffusion_model.config, output_prefix="diffusion_model.")
# for k in diffusers_keys:
# if k.endswith(".weight"):
# to = diffusers_keys[k]
# key_lora = "transformer.{}".format(k[:-len(".weight")]) #regular diffusers sd3 lora format
# key_map[key_lora] = to
#
# key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others?
# key_map[key_lora] = to
#
# key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
# key_map[key_lora] = to
#
# if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
# diffusers_keys = utils.auraflow_to_diffusers(model.diffusion_model.config, output_prefix="diffusion_model.")
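The unet key mapping kept above is a plain string transform from a state-dict key to the `lora_unet_*` lookup name. A small sketch of that transform on a made-up key:

```python
# Illustrative only: mirrors the string transform in model_lora_keys_unet above.
def unet_key_to_lora_key(state_dict_key: str) -> str:
    assert state_dict_key.startswith("diffusion_model.") and state_dict_key.endswith(".weight")
    body = state_dict_key[len("diffusion_model."):-len(".weight")]
    return "lora_unet_" + body.replace(".", "_")

# Hypothetical key name, for illustration:
print(unet_key_to_lora_key("diffusion_model.double_blocks.0.img_attn.qkv.weight"))
# -> lora_unet_double_blocks_0_img_attn_qkv
```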


@@ -268,9 +268,6 @@ class BF16(__Quant, qtype=GGMLQuantizationType.BF16):
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        return (blocks.view(np.int16).astype(np.int32) << 16).view(np.float32)

    @classmethod
    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
        return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)


class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
    @classmethod
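Both dequantize paths above rely on the same bit trick: a bf16 value is the upper 16 bits of an IEEE-754 float32. A self-contained round-trip sketch of that trick (torch-only; names are illustrative):

```python
import torch

def bf16_bits_to_fp32(raw: torch.Tensor) -> torch.Tensor:
    # Widen the 16-bit pattern to int32 and shift it into the upper half of the
    # word; reinterpreting those 32 bits as float32 recovers the bf16 value exactly.
    return (raw.view(torch.int16).to(torch.int32) << 16).view(torch.float32)

x = torch.tensor([1.0, -2.5, 0.15625], dtype=torch.bfloat16)
raw = x.view(torch.int16)          # reinterpret bf16 storage as 16-bit integers
print(bf16_bits_to_fp32(raw))      # 1.0, -2.5, 0.15625
```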


@@ -1,8 +1,8 @@
setuptools==69.5.1 # temp fix for compatibility with some old packages
GitPython==3.1.32
Pillow==10.4.0
accelerate==0.31.0
blendmodes==2024.1.1
Pillow==9.5.0
accelerate==0.21.0
blendmodes==2022
clean-fid==0.1.35
diskcache==5.6.3
einops==0.4.1
@@ -14,7 +14,7 @@ inflection==0.5.1
jsonmerge==1.8.0
kornia==0.6.7
lark==1.1.2
numpy==1.26.4
numpy==1.26.2
omegaconf==2.2.3
open-clip-torch==2.20.0
piexif==1.1.3
@@ -30,14 +30,14 @@ tomesd==0.1.3
torch
torchdiffeq==0.2.3
torchsde==0.2.6
transformers==4.46.1
transformers==4.44.0
httpx==0.24.1
pillow-avif-plugin==1.4.3
diffusers==0.31.0
diffusers==0.29.2
gradio_rangeslider==0.0.6
gradio_imageslider==0.0.20
loadimg==0.1.2
tqdm==4.66.1
peft==0.13.2
peft==0.12.0
pydantic==2.8.2
huggingface-hub==0.26.2
huggingface-hub==0.24.6


@@ -8,8 +8,7 @@ import gradio as gr
from modules import sd_samplers, errors, sd_models
from modules.processing import Processed, process_images
from modules.shared import state
from modules.images import image_grid, save_image
from modules.shared import opts
def process_model_tag(tag):
info = sd_models.get_closet_checkpoint_match(tag)
@@ -102,10 +101,10 @@ def cmdargs(line):
def load_prompt_file(file):
if file is None:
return None, gr.update()
return None, gr.update(), gr.update(lines=7)
else:
lines = [x.strip() for x in file.decode('utf8', errors='ignore').split("\n")]
return None, "\n".join(lines)
return None, "\n".join(lines), gr.update(lines=7)
class Script(scripts.Script):
@@ -116,16 +115,19 @@ class Script(scripts.Script):
checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate"))
checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch"))
prompt_position = gr.Radio(["start", "end"], label="Insert prompts at the", elem_id=self.elem_id("prompt_position"), value="start")
make_combined = gr.Checkbox(label="Make a combined image containing all outputs (if more than one)", value=False)
prompt_txt = gr.Textbox(label="List of prompt inputs", lines=2, elem_id=self.elem_id("prompt_txt"))
prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt"))
file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file"))
file.upload(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt], show_progress=False)
file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt], show_progress=False)
return [checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt, make_combined]
# We start at one line. When the text changes, we jump to seven lines, or two lines if no \n.
# We don't shrink back to 1, because that causes the control to ignore [enter], and it may
# be unclear to the user that shift-enter is needed.
prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt], show_progress=False)
return [checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt]
def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt: str, make_combined):
def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt: str):
lines = [x for x in (x.strip() for x in prompt_txt.splitlines()) if x]
p.do_not_save_grid = True
@@ -186,36 +188,4 @@ class Script(scripts.Script):
            all_prompts += proc.all_prompts
            infotexts += proc.infotexts

        if make_combined and len(images) > 1:
            combined_image = image_grid(images, batch_size=1, rows=None).convert("RGB")
            full_infotext = "\n".join(infotexts)
            is_img2img = getattr(p, "init_images", None) is not None
            if opts.grid_save:  # use grid specific Settings
                save_image(
                    combined_image,
                    opts.outdir_grids or (opts.outdir_img2img_grids if is_img2img else opts.outdir_txt2img_grids),
                    "",
                    -1,
                    prompt_txt,
                    opts.grid_format,
                    full_infotext,
                    grid=True
                )
            else:  # use normal output Settings
                save_image(
                    combined_image,
                    opts.outdir_samples or (opts.outdir_img2img_samples if is_img2img else opts.outdir_txt2img_samples),
                    "",
                    -1,
                    prompt_txt,
                    opts.samples_format,
                    full_infotext
                )
            images.insert(0, combined_image)
            all_prompts.insert(0, prompt_txt)
            infotexts.insert(0, full_infotext)

        return Processed(p, images, p.seed, "", all_prompts=all_prompts, infotexts=infotexts)
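The removed branch above stitches every output into a single grid image before saving; `image_grid` in `modules.images` does the real compositing, but a naive PIL-only stand-in looks roughly like this (illustrative, not the webui implementation):

```python
import math
from PIL import Image

def naive_image_grid(images, rows=None):
    # Paste all images onto one sheet, row-major; assumes equal-sized inputs.
    rows = rows or math.floor(math.sqrt(len(images)))
    cols = math.ceil(len(images) / rows)
    w, h = images[0].size
    sheet = Image.new("RGB", (cols * w, rows * h), color="black")
    for i, im in enumerate(images):
        sheet.paste(im, ((i % cols) * w, (i // cols) * h))
    return sheet
```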
