Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-05-01 03:31:35 +00:00
Various features and fixes. Too much brain fog to do a proper description
@@ -356,7 +356,7 @@ class SDTrainer(BaseSDTrainProcess):
 # we have to encode images into latents for now
 # we also denoise as the unaugmented tensor is not a noisy diffirental
 with torch.no_grad():
-    unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
+    unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor).to(self.device_torch, dtype=dtype)
     unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier
     target = unaugmented_latents.detach()
 
@@ -907,6 +907,17 @@ class SDTrainer(BaseSDTrainProcess):
 self.timer.start('preprocess_batch')
 batch = self.preprocess_batch(batch)
 dtype = get_torch_dtype(self.train_config.dtype)
+# sanity check
+if self.sd.vae.dtype != self.sd.vae_torch_dtype:
+    self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype)
+if isinstance(self.sd.text_encoder, list):
+    for encoder in self.sd.text_encoder:
+        if encoder.dtype != self.sd.te_torch_dtype:
+            encoder.to(self.sd.te_torch_dtype)
+else:
+    if self.sd.text_encoder.dtype != self.sd.te_torch_dtype:
+        self.sd.text_encoder.to(self.sd.te_torch_dtype)
+
 noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
 if self.train_config.do_cfg or self.train_config.do_random_cfg:
     # pick random negative prompts
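Note: the sanity check added above re-asserts the configured VAE and text-encoder dtypes at the top of every training step, since sampling or offloading between steps can leave a component in a different dtype. A minimal standalone sketch of the same guard, assuming an object exposing vae, text_encoder, and the resolved vae_torch_dtype/te_torch_dtype attributes:

def enforce_component_dtypes(sd):
    # keep the VAE in its configured dtype
    if sd.vae.dtype != sd.vae_torch_dtype:
        sd.vae = sd.vae.to(sd.vae_torch_dtype)
    # SDXL-style models carry a list of text encoders; normalize to a list
    encoders = sd.text_encoder if isinstance(sd.text_encoder, list) else [sd.text_encoder]
    for encoder in encoders:
        if encoder.dtype != sd.te_torch_dtype:
            encoder.to(sd.te_torch_dtype)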
@@ -83,7 +83,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
 else:
     self.network_config = None
 self.train_config = TrainConfig(**self.get_conf('train', {}))
-self.model_config = ModelConfig(**self.get_conf('model', {}))
+model_config = self.get_conf('model', {})
+
+# update modelconfig dtype to match train
+model_config['dtype'] = self.train_config.dtype
+self.model_config = ModelConfig(**model_config)
+
 self.save_config = SaveConfig(**self.get_conf('save', {}))
 self.sample_config = SampleConfig(**self.get_conf('sample', {}))
 first_sample_config = self.get_conf('first_sample', None)
@@ -723,6 +728,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
     noise_offset=self.train_config.noise_offset,
 ).to(self.device_torch, dtype=dtype)
 
+if self.train_config.random_noise_shift > 0.0:
+    # get random noise -1 to 1
+    noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device,
+                             dtype=noise.dtype) * 2 - 1
+
+    # multiply by shift amount
+    noise_shift *= self.train_config.random_noise_shift
+
+    # add to noise
+    noise += noise_shift
+
 return noise
 
 def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
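The random_noise_shift hunk above draws one uniform value per sample and channel (shape (B, C, 1, 1), broadcast over height and width), rescales it from [0, 1) to [-1, 1], multiplies by the configured amount, and adds it to the initial noise, shifting each channel's mean. A self-contained sketch of the same computation (function name hypothetical):

import torch

def apply_random_noise_shift(noise: torch.Tensor, shift: float) -> torch.Tensor:
    if shift <= 0.0:
        return noise
    # uniform in [-1, 1], one value per (sample, channel)
    noise_shift = torch.rand(
        (noise.shape[0], noise.shape[1], 1, 1),
        device=noise.device, dtype=noise.dtype,
    ) * 2 - 1
    return noise + noise_shift * shift

noise = torch.randn(4, 4, 64, 64)
shifted = apply_random_noise_shift(noise, 0.1)  # per-sample means now vary by up to ±0.1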
@@ -160,6 +160,8 @@ class AdapterConfig:
 self.train_only_image_encoder: bool = kwargs.get('train_only_image_encoder', False)
 if self.train_only_image_encoder:
     self.train_image_encoder = True
+self.train_only_image_encoder_positional_embedding: bool = kwargs.get(
+    'train_only_image_encoder_positional_embedding', False)
 self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip')  # clip vit vit_hybrid, safe
 self.safe_reducer_channels: int = kwargs.get('safe_reducer_channels', 512)
 self.safe_channels: int = kwargs.get('safe_channels', 2048)
@@ -260,6 +262,7 @@ class TrainConfig:
 # multiplier applied to loos on regularization images
 self.reg_weight = kwargs.get('reg_weight', 1.0)
 self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000)
+self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
 
 # dropout that happens before encoding. It functions independently per text encoder
 self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
@@ -385,6 +388,11 @@ class ModelConfig:
 self.text_encoder_bits = kwargs.get('text_encoder_bits', 16)  # 16, 8, 4
 self.unet_path = kwargs.get("unet_path", None)
 self.unet_sample_size = kwargs.get("unet_sample_size", None)
+self.vae_device = kwargs.get("vae_device", None)
+self.vae_dtype = kwargs.get("vae_dtype", self.dtype)
+self.te_device = kwargs.get("te_device", None)
+self.te_dtype = kwargs.get("te_dtype", self.dtype)
+pass
 
 
 class EMAConfig:
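These four new keys let the VAE and text encoder(s) run on their own device and dtype, each defaulting to the model-wide dtype (and, at load time, the main device) when unset. A hedged example of passing them; the key names come from the diff, while the values and the standalone construction are illustrative:

model_config = ModelConfig(
    dtype='bf16',          # global default for the model
    vae_device='cpu',      # keep the VAE off the training GPU
    vae_dtype='float32',   # fp32 VAE for more stable encodes
    te_device='cuda:1',    # park the text encoder on a second GPU
    te_dtype='float16',
)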
@@ -394,7 +394,7 @@ class IPAdapter(torch.nn.Module):
 elif adapter_config.type == 'ip+':
     heads = 12 if not sd.is_xl else 20
     dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
-    embedding_dim = self.image_encoder.config.target_hidden_size if not self.config.image_encoder_arch.startswith(
+    embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith(
         'convnext') else \
         self.image_encoder.config.hidden_sizes[-1]
 
@@ -964,7 +964,10 @@ class IPAdapter(torch.nn.Module):
 
 def get_non_scaler_parameters(self, recurse: bool = True) -> Iterator[Parameter]:
     if self.config.train_only_image_encoder:
-        yield from self.image_encoder.parameters(recurse)
+        if self.config.train_only_image_encoder_positional_embedding:
+            yield from self.image_encoder.vision_model.embeddings.position_embedding.parameters(recurse)
+        else:
+            yield from self.image_encoder.parameters(recurse)
         return
     if self.config.train_scaler:
         # no params
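The reworked generator above narrows what trains when train_only_image_encoder is set: either just the vision tower's positional embedding, or the whole image encoder. A minimal sketch of the same selection pattern against a transformers CLIP vision model (attribute path follows the CLIPVisionModel layout):

def iter_image_encoder_params(image_encoder, only_position_embedding: bool):
    if only_position_embedding:
        # train only the positional embedding over the patch/token grid
        yield from image_encoder.vision_model.embeddings.position_embedding.parameters()
    else:
        yield from image_encoder.parameters()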
@@ -21,7 +21,7 @@ from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
 
 
 if TYPE_CHECKING:
-    from toolkit.stable_diffusion_model import StableDiffusion
+    from toolkit.stable_diffusion_model import StableDiffusion, PixArtSigmaPipeline
     from toolkit.custom_adapter import CustomAdapter
 
 
@@ -202,6 +202,10 @@ class TEAdapterAttnProcessor(nn.Module):
 
 # the output of sdp = (batch, num_heads, seq_len, head_dim)
 # TODO: add support for attn.scale when we move to Torch 2.1
+# remove attn mask if doing clip
+if self.adapter_ref().adapter_ref().config.text_encoder_arch == "clip":
+    attention_mask = None
+
 hidden_states = F.scaled_dot_product_attention(
     query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
 )
@@ -246,7 +250,7 @@ class TEAdapter(torch.nn.Module):
 if self.adapter_ref().config.text_encoder_arch == "t5" or self.adapter_ref().config.text_encoder_arch == "pile-t5":
     self.token_size = self.te_ref().config.d_model
 else:
-    self.token_size = self.te_ref().config.target_hidden_size
+    self.token_size = self.te_ref().config.hidden_size
 
 # add text projection if is sdxl
 self.text_projection = None
@@ -388,8 +392,17 @@ class TEAdapter(torch.nn.Module):
 # ).input_ids.to(te.device)
 # outputs = te(input_ids=input_ids)
 # outputs = outputs.last_hidden_state
-if self.adapter_ref().config.text_encoder_arch == "pile-t5":
+if self.adapter_ref().config.text_encoder_arch == "clip":
+    embeds = train_tools.encode_prompts(
+        tokenizer,
+        te,
+        text,
+        truncate=True,
+        max_length=self.adapter_ref().config.num_tokens,
+    )
+    attention_mask = torch.ones(embeds.shape[:2], device=embeds.device)
+
+elif self.adapter_ref().config.text_encoder_arch == "pile-t5":
     # just use aura pile
     embeds, attention_mask = train_tools.encode_prompts_auraflow(
         tokenizer,
@@ -407,7 +420,8 @@ class TEAdapter(torch.nn.Module):
     truncate=True,
     max_length=self.adapter_ref().config.num_tokens,
 )
-attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype)
+if attention_mask is not None:
+    attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype)
 if self.text_projection is not None:
     # pool the output of embeds ignoring 0 in the attention mask
     pooled_output = embeds * attn_mask_float.unsqueeze(-1)
@@ -420,19 +434,19 @@ class TEAdapter(torch.nn.Module):
 
     pooled_embeds = self.text_projection(pooled_output)
 
-    t5_embeds = PromptEmbeds(
+    prompt_embeds = PromptEmbeds(
         (embeds, pooled_embeds),
         attention_mask=attention_mask,
     ).detach()
 
 else:
 
-    t5_embeds = PromptEmbeds(
+    prompt_embeds = PromptEmbeds(
         embeds,
         attention_mask=attention_mask,
     ).detach()
 
-return t5_embeds
+return prompt_embeds
 
 
 
@@ -123,6 +123,13 @@ class StableDiffusion:
 self.dtype = dtype
 self.torch_dtype = get_torch_dtype(dtype)
 self.device_torch = torch.device(self.device)
+
+self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device(model_config.vae_device)
+self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype)
+
+self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device(model_config.te_device)
+self.te_torch_dtype = get_torch_dtype(model_config.te_dtype)
+
 self.model_config = model_config
 self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
 
@@ -220,11 +227,13 @@ class StableDiffusion:
 text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
 tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
 for text_encoder in text_encoders:
-    text_encoder.to(self.device_torch, dtype=dtype)
+    text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
     text_encoder.requires_grad_(False)
     text_encoder.eval()
 text_encoder = text_encoders
 
+pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
+
 if self.model_config.experimental_xl:
     print("Experimental XL mode enabled")
     print("Loading and injecting alt weights")
@@ -333,6 +342,8 @@ class StableDiffusion:
 # replace the to function with a no-op since it throws an error instead of a warning
 text_encoder.to = lambda *args, **kwargs: None
 
+text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
+
 if self.model_config.is_pixart_sigma:
     # load the transformer only from the save
     transformer = Transformer2DModel.from_pretrained(
@@ -375,6 +386,8 @@ class StableDiffusion:
 pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
 tokenizer = pipe.tokenizer
 
+pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
+
 
 elif self.model_config.is_auraflow:
     te_kwargs = {}
@@ -427,7 +440,7 @@ class StableDiffusion:
 pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
 
 # patch auraflow so it can handle other aspect ratios
-patch_auraflow_pos_embed(pipe.transformer.pos_embed)
+# patch_auraflow_pos_embed(pipe.transformer.pos_embed)
 
 flush()
 # text_encoder = pipe.text_encoder
@@ -442,6 +455,31 @@ class StableDiffusion:
 else:
     pipln = StableDiffusionPipeline
 
+if self.model_config.text_encoder_bits < 16:
+    # this is only supported for T5 models for now
+    te_kwargs = {}
+    # handle quantization of TE
+    te_is_quantized = False
+    if self.model_config.text_encoder_bits == 8:
+        te_kwargs['load_in_8bit'] = True
+        te_kwargs['device_map'] = "auto"
+        te_is_quantized = True
+    elif self.model_config.text_encoder_bits == 4:
+        te_kwargs['load_in_4bit'] = True
+        te_kwargs['device_map'] = "auto"
+        te_is_quantized = True
+
+    text_encoder = T5EncoderModel.from_pretrained(
+        model_path,
+        subfolder="text_encoder",
+        torch_dtype=self.te_torch_dtype,
+        **te_kwargs
+    )
+    # replace the to function with a no-op since it throws an error instead of a warning
+    text_encoder.to = lambda *args, **kwargs: None
+
+    load_args['text_encoder'] = text_encoder
+
 # see if path exists
 if not os.path.exists(model_path) or os.path.isdir(model_path):
     # try to load with default diffusers
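The block above loads the T5 text encoder pre-quantized through bitsandbytes; passing load_in_8bit/load_in_4bit directly to from_pretrained matches the transformers API of that era, and the no-op .to() stub is needed because quantized models raise if moved afterwards. For reference, a hedged sketch of the equivalent load against current transformers, where the flags are wrapped in a BitsAndBytesConfig (checkpoint path illustrative):

import torch
from transformers import T5EncoderModel, BitsAndBytesConfig

text_encoder = T5EncoderModel.from_pretrained(
    'some/t5-based-checkpoint',    # illustrative; any repo with a text_encoder subfolder
    subfolder='text_encoder',
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map='auto',
    torch_dtype=torch.float16,
)
text_encoder.to = lambda *args, **kwargs: None  # same no-op guard as above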
@@ -455,7 +493,7 @@ class StableDiffusion:
     # variant="fp16",
     trust_remote_code=True,
     **load_args
-).to(self.device_torch)
+)
 else:
     pipe = pipln.from_single_file(
         model_path,
@@ -467,12 +505,12 @@ class StableDiffusion:
     safety_checker=None,
     trust_remote_code=True,
     **load_args
-).to(self.device_torch)
+)
 flush()
 
 pipe.register_to_config(requires_safety_checker=False)
 text_encoder = pipe.text_encoder
-text_encoder.to(self.device_torch, dtype=dtype)
+text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
 text_encoder.requires_grad_(False)
 text_encoder.eval()
 tokenizer = pipe.tokenizer
@@ -488,7 +526,7 @@ class StableDiffusion:
     self.unet = pipe.transformer
 else:
     self.unet: 'UNet2DConditionModel' = pipe.unet
-self.vae: 'AutoencoderKL' = pipe.vae.to(self.device_torch, dtype=dtype)
+self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
 self.vae.eval()
 self.vae.requires_grad_(False)
 VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
@@ -707,7 +745,7 @@ class StableDiffusion:
     feature_extractor=None,
     requires_safety_checker=False,
     **extra_args
-).to(self.device_torch)
+)
 flush()
 # disable progress bar
 pipeline.set_progress_bar_config(disable=True)
@@ -873,6 +911,9 @@ class StableDiffusion:
 if not self.is_xl:
     raise ValueError("Refiner is only supported for XL models")
 
+conditional_embeds = conditional_embeds.to(self.device_torch, dtype=self.unet.dtype)
+unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=self.unet.dtype)
+
 if self.is_xl:
     # fix guidance rescale for sdxl
     # was trained on 0.7 (I believe)
@@ -1014,6 +1055,7 @@ class StableDiffusion:
 self.network.train()
 self.network.multiplier = start_multiplier
 
+self.unet.to(self.device_torch, dtype=self.torch_dtype)
 if network.is_merged_in:
     network.merge_out(merge_multiplier)
 # self.tokenizer.to(original_device_dict['tokenizer'])
@@ -1655,18 +1697,18 @@ class StableDiffusion:
     dtype=None
 ):
     if device is None:
-        device = self.device
+        device = self.vae_device_torch
     if dtype is None:
-        dtype = self.torch_dtype
+        dtype = self.vae_torch_dtype
 
     latent_list = []
     # Move to vae to device if on cpu
     if self.vae.device == 'cpu':
-        self.vae.to(self.device)
+        self.vae.to(device)
     self.vae.eval()
     self.vae.requires_grad_(False)
     # move to device and dtype
-    image_list = [image.to(self.device, dtype=self.torch_dtype) for image in image_list]
+    image_list = [image.to(device, dtype=dtype) for image in image_list]
 
     VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
 
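With this change encode_images defaults to the VAE's configured device and dtype instead of the training device, so a CPU or fp32 VAE set via vae_device/vae_dtype is honored end to end. For orientation, a minimal diffusers-style sketch of what such an encode computes (helper name hypothetical, assuming a standard AutoencoderKL):

import torch
from diffusers import AutoencoderKL

@torch.no_grad()
def encode_to_latents(vae: AutoencoderKL, images: torch.Tensor) -> torch.Tensor:
    # images: (B, 3, H, W) scaled to [-1, 1]; follow the VAE's device/dtype
    images = images.to(vae.device, dtype=vae.dtype)
    latents = vae.encode(images).latent_dist.sample()
    return latents * vae.config.scaling_factor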
@@ -2158,7 +2200,7 @@ class StableDiffusion:
 # vae
 state['vae'] = {
     'training': 'vae' in training_modules,
-    'device': self.device_torch if 'vae' in active_modules else 'cpu',
+    'device': self.vae_device_torch if 'vae' in active_modules else 'cpu',
     'requires_grad': 'vae' in training_modules,
 }
 
@@ -2182,13 +2224,13 @@ class StableDiffusion:
 for i, encoder in enumerate(self.text_encoder):
     state['text_encoder'].append({
         'training': 'text_encoder' in training_modules,
-        'device': self.device_torch if 'text_encoder' in active_modules else 'cpu',
+        'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu',
         'requires_grad': 'text_encoder' in training_modules,
     })
 else:
     state['text_encoder'] = {
         'training': 'text_encoder' in training_modules,
-        'device': self.device_torch if 'text_encoder' in active_modules else 'cpu',
+        'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu',
         'requires_grad': 'text_encoder' in training_modules,
     }
 