mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Did some work on SD rescaler. Need to run a long test on it eventually.
This commit is contained in:
@@ -114,7 +114,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
tokenizer=self.sd.tokenizer[0],
|
tokenizer=self.sd.tokenizer[0],
|
||||||
tokenizer_2=self.sd.tokenizer[1],
|
tokenizer_2=self.sd.tokenizer[1],
|
||||||
scheduler=self.sd.noise_scheduler,
|
scheduler=self.sd.noise_scheduler,
|
||||||
)
|
).to(self.device_torch)
|
||||||
else:
|
else:
|
||||||
pipeline = StableDiffusionPipeline(
|
pipeline = StableDiffusionPipeline(
|
||||||
vae=self.sd.vae,
|
vae=self.sd.vae,
|
||||||
@@ -125,7 +125,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
feature_extractor=None,
|
feature_extractor=None,
|
||||||
requires_safety_checker=False,
|
requires_safety_checker=False,
|
||||||
)
|
).to(self.device_torch)
|
||||||
# disable progress bar
|
# disable progress bar
|
||||||
pipeline.set_progress_bar_config(disable=True)
|
pipeline.set_progress_bar_config(disable=True)
|
||||||
|
|
||||||
@@ -387,7 +387,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
text_embeddings: PromptEmbeds,
|
text_embeddings: PromptEmbeds,
|
||||||
timestep: int,
|
timestep: int,
|
||||||
guidance_scale=7.5,
|
guidance_scale=7.5,
|
||||||
guidance_rescale=0, # 0.7
|
guidance_rescale=0, # 0.7
|
||||||
add_time_ids=None,
|
add_time_ids=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@@ -585,17 +585,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
unet.eval()
|
unet.eval()
|
||||||
|
|
||||||
if self.network_config is not None:
|
if self.network_config is not None:
|
||||||
conv = self.network_config.conv if self.network_config.conv is not None and self.network_config.conv > 0 else None
|
|
||||||
self.network = LoRASpecialNetwork(
|
self.network = LoRASpecialNetwork(
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
unet=unet,
|
unet=unet,
|
||||||
lora_dim=self.network_config.linear,
|
lora_dim=self.network_config.linear,
|
||||||
multiplier=1.0,
|
multiplier=1.0,
|
||||||
alpha=self.network_config.alpha,
|
alpha=self.network_config.linear_alpha,
|
||||||
train_unet=self.train_config.train_unet,
|
train_unet=self.train_config.train_unet,
|
||||||
train_text_encoder=self.train_config.train_text_encoder,
|
train_text_encoder=self.train_config.train_text_encoder,
|
||||||
conv_lora_dim=conv,
|
conv_lora_dim=self.network_config.conv,
|
||||||
conv_alpha=self.network_config.alpha if conv is not None else None,
|
conv_alpha=self.network_config.conv_alpha,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.network.force_to(self.device_torch, dtype=dtype)
|
self.network.force_to(self.device_torch, dtype=dtype)
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ class RescaleConfig:
|
|||||||
self.prompt_file = kwargs.get('prompt_file', None)
|
self.prompt_file = kwargs.get('prompt_file', None)
|
||||||
self.prompt_tensors = kwargs.get('prompt_tensors', None)
|
self.prompt_tensors = kwargs.get('prompt_tensors', None)
|
||||||
self.to_resolution = kwargs.get('to_resolution', int(self.from_resolution * self.scale))
|
self.to_resolution = kwargs.get('to_resolution', int(self.from_resolution * self.scale))
|
||||||
|
self.prompt_dropout = kwargs.get('prompt_dropout', 0.1)
|
||||||
|
|
||||||
if self.prompt_file is None:
|
if self.prompt_file is None:
|
||||||
raise ValueError("prompt_file is required")
|
raise ValueError("prompt_file is required")
|
||||||
@@ -64,7 +65,7 @@ class PromptEmbedsCache:
|
|||||||
class TrainSDRescaleProcess(BaseSDTrainProcess):
|
class TrainSDRescaleProcess(BaseSDTrainProcess):
|
||||||
def __init__(self, process_id: int, job, config: OrderedDict):
|
def __init__(self, process_id: int, job, config: OrderedDict):
|
||||||
# pass our custom pipeline to super so it sets it up
|
# pass our custom pipeline to super so it sets it up
|
||||||
super().__init__(process_id, job, config, custom_pipeline=TransferStableDiffusionXLPipeline)
|
super().__init__(process_id, job, config)
|
||||||
self.step_num = 0
|
self.step_num = 0
|
||||||
self.start_step = 0
|
self.start_step = 0
|
||||||
self.device = self.get_conf('device', self.job.device)
|
self.device = self.get_conf('device', self.job.device)
|
||||||
@@ -158,31 +159,36 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
|
|||||||
def hook_train_loop(self):
|
def hook_train_loop(self):
|
||||||
dtype = get_torch_dtype(self.train_config.dtype)
|
dtype = get_torch_dtype(self.train_config.dtype)
|
||||||
|
|
||||||
|
do_dropout = False
|
||||||
|
|
||||||
|
# see if we should dropout
|
||||||
|
if self.rescale_config.prompt_dropout > 0.0:
|
||||||
|
thresh = int(self.rescale_config.prompt_dropout * 100)
|
||||||
|
if torch.randint(0, 100, (1,)).item() < thresh:
|
||||||
|
do_dropout = True
|
||||||
|
|
||||||
# get random encoded prompt from cache
|
# get random encoded prompt from cache
|
||||||
prompt_txt = self.prompt_txt_list[
|
positive_prompt_txt = self.prompt_txt_list[
|
||||||
torch.randint(0, len(self.prompt_txt_list), (1,)).item()
|
torch.randint(0, len(self.prompt_txt_list), (1,)).item()
|
||||||
]
|
]
|
||||||
prompt = self.prompt_cache[prompt_txt].to(device=self.device_torch, dtype=dtype)
|
negative_prompt_txt = self.prompt_txt_list[
|
||||||
prompt.text_embeds.to(device=self.device_torch, dtype=dtype)
|
torch.randint(0, len(self.prompt_txt_list), (1,)).item()
|
||||||
neutral = self.prompt_cache[""].to(device=self.device_torch, dtype=dtype)
|
]
|
||||||
neutral.text_embeds.to(device=self.device_torch, dtype=dtype)
|
if do_dropout:
|
||||||
if hasattr(prompt, 'pooled_embeds') \
|
positive_prompt = self.prompt_cache[''].to(device=self.device_torch, dtype=dtype)
|
||||||
and hasattr(neutral, 'pooled_embeds') \
|
negative_prompt = self.prompt_cache[''].to(device=self.device_torch, dtype=dtype)
|
||||||
and prompt.pooled_embeds is not None \
|
else:
|
||||||
and neutral.pooled_embeds is not None:
|
positive_prompt = self.prompt_cache[positive_prompt_txt].to(device=self.device_torch, dtype=dtype)
|
||||||
prompt.pooled_embeds.to(device=self.device_torch, dtype=dtype)
|
negative_prompt = self.prompt_cache[negative_prompt_txt].to(device=self.device_torch, dtype=dtype)
|
||||||
neutral.pooled_embeds.to(device=self.device_torch, dtype=dtype)
|
|
||||||
|
|
||||||
if prompt is None:
|
if positive_prompt is None:
|
||||||
raise ValueError(f"Prompt {prompt_txt} is not in cache")
|
raise ValueError(f"Prompt {positive_prompt_txt} is not in cache")
|
||||||
|
if negative_prompt is None:
|
||||||
|
raise ValueError(f"Prompt {negative_prompt_txt} is not in cache")
|
||||||
|
|
||||||
loss_function = torch.nn.MSELoss()
|
loss_function = torch.nn.MSELoss()
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# self.sd.noise_scheduler.set_timesteps(
|
|
||||||
# self.train_config.max_denoising_steps, device=self.device_torch
|
|
||||||
# )
|
|
||||||
|
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
# # get a random number of steps
|
# # get a random number of steps
|
||||||
@@ -190,63 +196,89 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
|
|||||||
1, self.train_config.max_denoising_steps, (1,)
|
1, self.train_config.max_denoising_steps, (1,)
|
||||||
).item()
|
).item()
|
||||||
|
|
||||||
|
# set the scheduler to the number of steps
|
||||||
|
self.sd.noise_scheduler.set_timesteps(
|
||||||
|
timesteps_to, device=self.device_torch
|
||||||
|
)
|
||||||
|
|
||||||
# get noise
|
# get noise
|
||||||
latents = self.get_latent_noise(
|
noise = self.get_latent_noise(
|
||||||
pixel_height=self.rescale_config.from_resolution,
|
pixel_height=self.rescale_config.from_resolution,
|
||||||
pixel_width=self.rescale_config.from_resolution,
|
pixel_width=self.rescale_config.from_resolution,
|
||||||
).to(self.device_torch, dtype=dtype)
|
).to(self.device_torch, dtype=dtype)
|
||||||
|
|
||||||
self.sd.pipeline.to(self.device_torch)
|
|
||||||
torch.set_default_device(self.device_torch)
|
torch.set_default_device(self.device_torch)
|
||||||
|
|
||||||
# turn off progress bar
|
# get latents
|
||||||
self.sd.pipeline.set_progress_bar_config(disable=True)
|
latents = noise * self.sd.noise_scheduler.init_noise_sigma
|
||||||
|
latents = latents.to(self.device_torch, dtype=dtype)
|
||||||
|
|
||||||
# get random guidance scale from 1.0 to 10.0
|
# get random guidance scale from 1.0 to 10.0 (CFG)
|
||||||
guidance_scale = torch.rand(1).item() * 9.0 + 1.0
|
guidance_scale = torch.rand(1).item() * 9.0 + 1.0
|
||||||
|
|
||||||
loss_arr = []
|
loss_arr = []
|
||||||
|
|
||||||
|
|
||||||
max_len_timestep_str = len(str(self.train_config.max_denoising_steps))
|
max_len_timestep_str = len(str(self.train_config.max_denoising_steps))
|
||||||
# pad with spaces
|
# pad with spaces
|
||||||
timestep_str = str(timesteps_to).rjust(max_len_timestep_str, " ")
|
timestep_str = str(timesteps_to).rjust(max_len_timestep_str, " ")
|
||||||
new_description = f"{self.job.name} ts: {timestep_str}"
|
new_description = f"{self.job.name} ts: {timestep_str}"
|
||||||
self.progress_bar.set_description(new_description)
|
self.progress_bar.set_description(new_description)
|
||||||
|
|
||||||
def pre_condition_callback(target_pred, input_latents):
|
# Begin gradient accumulation
|
||||||
# handle any manipulations before feeding to our network
|
self.optimizer.zero_grad()
|
||||||
reduced_pred = self.reduce_size_fn(target_pred)
|
|
||||||
reduced_latents = self.reduce_size_fn(input_latents)
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
return reduced_pred, reduced_latents
|
|
||||||
|
|
||||||
def each_step_callback(noise_target, noise_train_pred):
|
# perform the diffusion
|
||||||
noise_target.requires_grad = False
|
for timestep in tqdm(self.sd.noise_scheduler.timesteps, leave=False):
|
||||||
loss = loss_function(noise_target, noise_train_pred)
|
assert not self.network.is_active
|
||||||
loss_arr.append(loss.item())
|
|
||||||
loss.backward()
|
|
||||||
self.optimizer.step()
|
|
||||||
self.lr_scheduler.step()
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
|
|
||||||
# run the pipeline
|
text_embeddings = train_tools.concat_prompt_embeddings(
|
||||||
self.sd.pipeline.transfer_diffuse(
|
negative_prompt, # unconditional (negative prompt)
|
||||||
num_inference_steps=timesteps_to,
|
positive_prompt, # conditional (positive prompt)
|
||||||
latents=latents,
|
self.train_config.batch_size,
|
||||||
prompt_embeds=prompt.text_embeds,
|
|
||||||
negative_prompt_embeds=neutral.text_embeds,
|
|
||||||
pooled_prompt_embeds=prompt.pooled_embeds,
|
|
||||||
negative_pooled_prompt_embeds=neutral.pooled_embeds,
|
|
||||||
output_type="latent",
|
|
||||||
num_images_per_prompt=self.train_config.batch_size,
|
|
||||||
guidance_scale=guidance_scale,
|
|
||||||
network=self.network,
|
|
||||||
target_unet=self.sd.unet,
|
|
||||||
pre_condition_callback=pre_condition_callback,
|
|
||||||
each_step_callback=each_step_callback,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
noise_pred_target = self.predict_noise(
|
||||||
|
latents,
|
||||||
|
text_embeddings=text_embeddings,
|
||||||
|
timestep=timestep,
|
||||||
|
guidance_scale=guidance_scale
|
||||||
|
)
|
||||||
|
|
||||||
|
# todo should we do every step?
|
||||||
|
do_train_cycle = True
|
||||||
|
|
||||||
|
if do_train_cycle:
|
||||||
|
# get the reduced latents
|
||||||
|
with torch.no_grad():
|
||||||
|
reduced_pred = self.reduce_size_fn(noise_pred_target.detach())
|
||||||
|
reduced_latents = self.reduce_size_fn(latents.detach())
|
||||||
|
with self.network:
|
||||||
|
assert self.network.is_active
|
||||||
|
self.network.multiplier = 1.0
|
||||||
|
noise_pred_train = self.predict_noise(
|
||||||
|
reduced_latents,
|
||||||
|
text_embeddings=text_embeddings,
|
||||||
|
timestep=timestep,
|
||||||
|
guidance_scale=guidance_scale
|
||||||
|
)
|
||||||
|
|
||||||
|
reduced_pred.requires_grad = False
|
||||||
|
loss = loss_function(noise_pred_train, reduced_pred)
|
||||||
|
loss_arr.append(loss.item())
|
||||||
|
loss.backward()
|
||||||
|
self.optimizer.step()
|
||||||
|
self.lr_scheduler.step()
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
|
# get next latents
|
||||||
|
# todo allow to show latent here
|
||||||
|
latents = self.sd.noise_scheduler.step(noise_pred_target, timestep, latents).prev_sample
|
||||||
|
|
||||||
|
# reset prompt embeds
|
||||||
|
positive_prompt.to(device="cpu")
|
||||||
|
negative_prompt.to(device="cpu")
|
||||||
|
|
||||||
flush()
|
flush()
|
||||||
|
|
||||||
# reset network
|
# reset network
|
||||||
|
|||||||
@@ -42,6 +42,8 @@ class NetworkConfig:
|
|||||||
self.linear: int = linear
|
self.linear: int = linear
|
||||||
self.conv: int = kwargs.get('conv', None)
|
self.conv: int = kwargs.get('conv', None)
|
||||||
self.alpha: float = kwargs.get('alpha', 1.0)
|
self.alpha: float = kwargs.get('alpha', 1.0)
|
||||||
|
self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha)
|
||||||
|
self.conv_alpha: float = kwargs.get('conv_alpha', self.conv)
|
||||||
|
|
||||||
|
|
||||||
class TrainConfig:
|
class TrainConfig:
|
||||||
|
|||||||
@@ -241,6 +241,9 @@ class LoRASpecialNetwork(LoRANetwork):
|
|||||||
|
|
||||||
@multiplier.setter
|
@multiplier.setter
|
||||||
def multiplier(self, value):
|
def multiplier(self, value):
|
||||||
|
# only update if changed
|
||||||
|
if self._multiplier == value:
|
||||||
|
return
|
||||||
self._multiplier = value
|
self._multiplier = value
|
||||||
self._update_lora_multiplier()
|
self._update_lora_multiplier()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user