Various features and fixes. Too much brain fog to do a proper description

This commit is contained in:
Jaret Burkett
2024-07-18 07:34:14 -06:00
parent 58dffd43a8
commit 11e426fdf1
6 changed files with 119 additions and 25 deletions

View File

@@ -123,6 +123,13 @@ class StableDiffusion:
self.dtype = dtype
self.torch_dtype = get_torch_dtype(dtype)
self.device_torch = torch.device(self.device)
self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device(model_config.vae_device)
self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype)
self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device(model_config.te_device)
self.te_torch_dtype = get_torch_dtype(model_config.te_dtype)
self.model_config = model_config
self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
@@ -220,11 +227,13 @@ class StableDiffusion:
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
for text_encoder in text_encoders:
text_encoder.to(self.device_torch, dtype=dtype)
text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
text_encoder = text_encoders
pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
if self.model_config.experimental_xl:
print("Experimental XL mode enabled")
print("Loading and injecting alt weights")
@@ -333,6 +342,8 @@ class StableDiffusion:
# replace the to function with a no-op since it throws an error instead of a warning
text_encoder.to = lambda *args, **kwargs: None
text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
if self.model_config.is_pixart_sigma:
# load the transformer only from the save
transformer = Transformer2DModel.from_pretrained(
@@ -375,6 +386,8 @@ class StableDiffusion:
pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
tokenizer = pipe.tokenizer
pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
elif self.model_config.is_auraflow:
te_kwargs = {}
@@ -427,7 +440,7 @@ class StableDiffusion:
pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
# patch auraflow so it can handle other aspect ratios
patch_auraflow_pos_embed(pipe.transformer.pos_embed)
# patch_auraflow_pos_embed(pipe.transformer.pos_embed)
flush()
# text_encoder = pipe.text_encoder
@@ -442,6 +455,31 @@ class StableDiffusion:
else:
pipln = StableDiffusionPipeline
if self.model_config.text_encoder_bits < 16:
# this is only supported for T5 models for now
te_kwargs = {}
# handle quantization of TE
te_is_quantized = False
if self.model_config.text_encoder_bits == 8:
te_kwargs['load_in_8bit'] = True
te_kwargs['device_map'] = "auto"
te_is_quantized = True
elif self.model_config.text_encoder_bits == 4:
te_kwargs['load_in_4bit'] = True
te_kwargs['device_map'] = "auto"
te_is_quantized = True
text_encoder = T5EncoderModel.from_pretrained(
model_path,
subfolder="text_encoder",
torch_dtype=self.te_torch_dtype,
**te_kwargs
)
# replace the to function with a no-op since it throws an error instead of a warning
text_encoder.to = lambda *args, **kwargs: None
load_args['text_encoder'] = text_encoder
# see if path exists
if not os.path.exists(model_path) or os.path.isdir(model_path):
# try to load with default diffusers
@@ -455,7 +493,7 @@ class StableDiffusion:
# variant="fp16",
trust_remote_code=True,
**load_args
).to(self.device_torch)
)
else:
pipe = pipln.from_single_file(
model_path,
@@ -467,12 +505,12 @@ class StableDiffusion:
safety_checker=None,
trust_remote_code=True,
**load_args
).to(self.device_torch)
)
flush()
pipe.register_to_config(requires_safety_checker=False)
text_encoder = pipe.text_encoder
text_encoder.to(self.device_torch, dtype=dtype)
text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
tokenizer = pipe.tokenizer
@@ -488,7 +526,7 @@ class StableDiffusion:
self.unet = pipe.transformer
else:
self.unet: 'UNet2DConditionModel' = pipe.unet
self.vae: 'AutoencoderKL' = pipe.vae.to(self.device_torch, dtype=dtype)
self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
self.vae.eval()
self.vae.requires_grad_(False)
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
@@ -707,7 +745,7 @@ class StableDiffusion:
feature_extractor=None,
requires_safety_checker=False,
**extra_args
).to(self.device_torch)
)
flush()
# disable progress bar
pipeline.set_progress_bar_config(disable=True)
@@ -873,6 +911,9 @@ class StableDiffusion:
if not self.is_xl:
raise ValueError("Refiner is only supported for XL models")
conditional_embeds = conditional_embeds.to(self.device_torch, dtype=self.unet.dtype)
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=self.unet.dtype)
if self.is_xl:
# fix guidance rescale for sdxl
# was trained on 0.7 (I believe)
@@ -1014,6 +1055,7 @@ class StableDiffusion:
self.network.train()
self.network.multiplier = start_multiplier
self.unet.to(self.device_torch, dtype=self.torch_dtype)
if network.is_merged_in:
network.merge_out(merge_multiplier)
# self.tokenizer.to(original_device_dict['tokenizer'])
@@ -1655,18 +1697,18 @@ class StableDiffusion:
dtype=None
):
if device is None:
device = self.device
device = self.vae_device_torch
if dtype is None:
dtype = self.torch_dtype
dtype = self.vae_torch_dtype
latent_list = []
# Move to vae to device if on cpu
if self.vae.device == 'cpu':
self.vae.to(self.device)
self.vae.to(device)
self.vae.eval()
self.vae.requires_grad_(False)
# move to device and dtype
image_list = [image.to(self.device, dtype=self.torch_dtype) for image in image_list]
image_list = [image.to(device, dtype=dtype) for image in image_list]
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
@@ -2158,7 +2200,7 @@ class StableDiffusion:
# vae
state['vae'] = {
'training': 'vae' in training_modules,
'device': self.device_torch if 'vae' in active_modules else 'cpu',
'device': self.vae_device_torch if 'vae' in active_modules else 'cpu',
'requires_grad': 'vae' in training_modules,
}
@@ -2182,13 +2224,13 @@ class StableDiffusion:
for i, encoder in enumerate(self.text_encoder):
state['text_encoder'].append({
'training': 'text_encoder' in training_modules,
'device': self.device_torch if 'text_encoder' in active_modules else 'cpu',
'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu',
'requires_grad': 'text_encoder' in training_modules,
})
else:
state['text_encoder'] = {
'training': 'text_encoder' in training_modules,
'device': self.device_torch if 'text_encoder' in active_modules else 'cpu',
'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu',
'requires_grad': 'text_encoder' in training_modules,
}