Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-04-26 09:19:20 +00:00)
Various features and fixes. Too much brain fog to do a proper description
@@ -123,6 +123,13 @@ class StableDiffusion:
         self.dtype = dtype
         self.torch_dtype = get_torch_dtype(dtype)
         self.device_torch = torch.device(self.device)
+
+        self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device(model_config.vae_device)
+        self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype)
+
+        self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device(model_config.te_device)
+        self.te_torch_dtype = get_torch_dtype(model_config.te_dtype)
+
         self.model_config = model_config
         self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"

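Note: the new constructor fields fall back to the shared device and dtype whenever the config leaves the per-component override unset. A minimal standalone sketch of that fallback pattern (the function name and config values below are hypothetical; in the diff the same logic is inlined against `model_config.vae_device` / `model_config.te_device`):

import torch

def resolve_device(default_device: str, override=None) -> torch.device:
    # fall back to the shared device when no per-component override is configured
    return torch.device(default_device if override is None else override)

# hypothetical setup: train on GPU while parking the VAE on CPU
unet_device = resolve_device("cuda:0")
vae_device = resolve_device("cuda:0", override="cpu")
print(unet_device, vae_device)  # -> cuda:0 cpu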
@@ -220,11 +227,13 @@ class StableDiffusion:
             text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
             tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
             for text_encoder in text_encoders:
-                text_encoder.to(self.device_torch, dtype=dtype)
+                text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
                 text_encoder.requires_grad_(False)
                 text_encoder.eval()
             text_encoder = text_encoders

+            pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
+
             if self.model_config.experimental_xl:
                 print("Experimental XL mode enabled")
                 print("Loading and injecting alt weights")
@@ -333,6 +342,8 @@ class StableDiffusion:
                 # replace the to function with a no-op since it throws an error instead of a warning
                 text_encoder.to = lambda *args, **kwargs: None

+            text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
+
             if self.model_config.is_pixart_sigma:
                 # load the transformer only from the save
                 transformer = Transformer2DModel.from_pretrained(
@@ -375,6 +386,8 @@ class StableDiffusion:
             pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
             tokenizer = pipe.tokenizer

+            pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
+

         elif self.model_config.is_auraflow:
             te_kwargs = {}
@@ -427,7 +440,7 @@ class StableDiffusion:
             pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)

             # patch auraflow so it can handle other aspect ratios
-            patch_auraflow_pos_embed(pipe.transformer.pos_embed)
+            # patch_auraflow_pos_embed(pipe.transformer.pos_embed)

             flush()
             # text_encoder = pipe.text_encoder
@@ -442,6 +455,31 @@ class StableDiffusion:
         else:
             pipln = StableDiffusionPipeline

+        if self.model_config.text_encoder_bits < 16:
+            # this is only supported for T5 models for now
+            te_kwargs = {}
+            # handle quantization of TE
+            te_is_quantized = False
+            if self.model_config.text_encoder_bits == 8:
+                te_kwargs['load_in_8bit'] = True
+                te_kwargs['device_map'] = "auto"
+                te_is_quantized = True
+            elif self.model_config.text_encoder_bits == 4:
+                te_kwargs['load_in_4bit'] = True
+                te_kwargs['device_map'] = "auto"
+                te_is_quantized = True
+
+            text_encoder = T5EncoderModel.from_pretrained(
+                model_path,
+                subfolder="text_encoder",
+                torch_dtype=self.te_torch_dtype,
+                **te_kwargs
+            )
+            # replace the to function with a no-op since it throws an error instead of a warning
+            text_encoder.to = lambda *args, **kwargs: None
+
+            load_args['text_encoder'] = text_encoder
+
         # see if path exists
         if not os.path.exists(model_path) or os.path.isdir(model_path):
             # try to load with default diffusers
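Note: the quantization branch passes bitsandbytes flags (`load_in_8bit` / `load_in_4bit`, `device_map="auto"`) straight through `from_pretrained`, then neuters `.to()` because quantized models raise an error on it rather than warn. A hedged sketch of the same load using the `BitsAndBytesConfig` wrapper that newer transformers releases prefer (the model path is a placeholder, not the repo's):

import torch
from transformers import BitsAndBytesConfig, T5EncoderModel

quant_config = BitsAndBytesConfig(load_in_8bit=True)  # or load_in_4bit=True
text_encoder = T5EncoderModel.from_pretrained(
    "path/to/model",                  # placeholder path
    subfolder="text_encoder",
    torch_dtype=torch.float16,
    quantization_config=quant_config,
    device_map="auto",                # let accelerate place the quantized weights
)
# quantized weights cannot be moved; make later generic .to() calls harmless no-ops
text_encoder.to = lambda *args, **kwargs: None

The no-op keeps the shared code path, which unconditionally moves every text encoder to `te_device_torch`, from crashing on a model whose placement is already fixed by `device_map`.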
@@ -455,7 +493,7 @@ class StableDiffusion:
                 # variant="fp16",
                 trust_remote_code=True,
                 **load_args
-            ).to(self.device_torch)
+            )
         else:
             pipe = pipln.from_single_file(
                 model_path,
@@ -467,12 +505,12 @@ class StableDiffusion:
                 safety_checker=None,
                 trust_remote_code=True,
                 **load_args
-            ).to(self.device_torch)
+            )
         flush()

         pipe.register_to_config(requires_safety_checker=False)
         text_encoder = pipe.text_encoder
-        text_encoder.to(self.device_torch, dtype=dtype)
+        text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
         text_encoder.requires_grad_(False)
         text_encoder.eval()
         tokenizer = pipe.tokenizer
@@ -488,7 +526,7 @@ class StableDiffusion:
             self.unet = pipe.transformer
         else:
             self.unet: 'UNet2DConditionModel' = pipe.unet
-        self.vae: 'AutoencoderKL' = pipe.vae.to(self.device_torch, dtype=dtype)
+        self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
         self.vae.eval()
         self.vae.requires_grad_(False)
         VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
@@ -707,7 +745,7 @@ class StableDiffusion:
             feature_extractor=None,
             requires_safety_checker=False,
             **extra_args
-        ).to(self.device_torch)
+        )
         flush()
         # disable progress bar
         pipeline.set_progress_bar_config(disable=True)
@@ -873,6 +911,9 @@ class StableDiffusion:
         if not self.is_xl:
             raise ValueError("Refiner is only supported for XL models")

+        conditional_embeds = conditional_embeds.to(self.device_torch, dtype=self.unet.dtype)
+        unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=self.unet.dtype)
+
         if self.is_xl:
             # fix guidance rescale for sdxl
             # was trained on 0.7 (I believe)
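Note: casting both embedding batches to the UNet's device and dtype up front avoids mixed-precision matmul errors inside the refiner pass. Illustration with stand-in tensors (shapes and values are hypothetical):

import torch

device_torch = torch.device("cpu")   # stand-in for self.device_torch
unet_dtype = torch.float16           # stand-in for self.unet.dtype

conditional_embeds = torch.randn(1, 77, 2048)    # hypothetical SDXL prompt embeds
unconditional_embeds = torch.randn(1, 77, 2048)

# align with the UNet before any matmul touches the embeddings
conditional_embeds = conditional_embeds.to(device_torch, dtype=unet_dtype)
unconditional_embeds = unconditional_embeds.to(device_torch, dtype=unet_dtype)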
@@ -1014,6 +1055,7 @@ class StableDiffusion:
             self.network.train()
             self.network.multiplier = start_multiplier

+        self.unet.to(self.device_torch, dtype=self.torch_dtype)
         if network.is_merged_in:
             network.merge_out(merge_multiplier)
         # self.tokenizer.to(original_device_dict['tokenizer'])
@@ -1655,18 +1697,18 @@ class StableDiffusion:
             dtype=None
     ):
         if device is None:
-            device = self.device
+            device = self.vae_device_torch
         if dtype is None:
-            dtype = self.torch_dtype
+            dtype = self.vae_torch_dtype

         latent_list = []
         # Move to vae to device if on cpu
         if self.vae.device == 'cpu':
-            self.vae.to(self.device)
+            self.vae.to(device)
         self.vae.eval()
         self.vae.requires_grad_(False)
         # move to device and dtype
-        image_list = [image.to(self.device, dtype=self.torch_dtype) for image in image_list]
+        image_list = [image.to(device, dtype=dtype) for image in image_list]

         VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)

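Note: after this hunk, latent encoding defaults to the VAE-specific device and dtype and only relocates the VAE when it is parked on the CPU. A hedged, runnable sketch of that flow (the default-constructed `AutoencoderKL` below is a tiny stand-in so the example executes; the repo's real VAE comes from the loaded pipeline):

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL()                       # stand-in VAE, not the repo's checkpoint
device, dtype = torch.device("cpu"), torch.float32

if vae.device.type == "cpu":
    vae.to(device)                          # would move to GPU when device is cuda
vae.eval().requires_grad_(False)

image = torch.rand(1, 3, 64, 64, device=device, dtype=dtype) * 2 - 1  # scaled to [-1, 1]
with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample()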
@@ -2158,7 +2200,7 @@ class StableDiffusion:
         # vae
         state['vae'] = {
             'training': 'vae' in training_modules,
-            'device': self.device_torch if 'vae' in active_modules else 'cpu',
+            'device': self.vae_device_torch if 'vae' in active_modules else 'cpu',
             'requires_grad': 'vae' in training_modules,
         }

@@ -2182,13 +2224,13 @@ class StableDiffusion:
             for i, encoder in enumerate(self.text_encoder):
                 state['text_encoder'].append({
                     'training': 'text_encoder' in training_modules,
-                    'device': self.device_torch if 'text_encoder' in active_modules else 'cpu',
+                    'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu',
                     'requires_grad': 'text_encoder' in training_modules,
                 })
         else:
             state['text_encoder'] = {
                 'training': 'text_encoder' in training_modules,
-                'device': self.device_torch if 'text_encoder' in active_modules else 'cpu',
+                'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu',
                 'requires_grad': 'text_encoder' in training_modules,
             }

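Note: both hunks above reuse one pattern in the device-state report: a module advertises its component-specific device only while active and 'cpu' otherwise. A minimal sketch (the helper and the set contents are hypothetical; only the dict fields come from the diff):

import torch

def module_state(name, training_modules, active_modules, active_device):
    # active modules sit on their target device; everything else is offloaded to CPU
    return {
        'training': name in training_modules,
        'device': active_device if name in active_modules else 'cpu',
        'requires_grad': name in training_modules,
    }

state = {
    'vae': module_state('vae', set(), {'vae'}, torch.device('cuda:0')),
    'text_encoder': module_state('text_encoder', set(), set(), torch.device('cuda:0')),
}
print(state['vae']['device'], state['text_encoder']['device'])  # -> cuda:0 cpu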