mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added some split prompting started code, adamw8bit, replacements improving, learnable snr gos. A lot of good stuff.
This commit is contained in:
@@ -12,7 +12,7 @@ from .tools.dataset_tools_config_modules import RAW_DIR, TRAIN_DIR, Step, ImgInf
|
|||||||
from .tools.fuyu_utils import FuyuImageProcessor
|
from .tools.fuyu_utils import FuyuImageProcessor
|
||||||
from .tools.image_tools import load_image, ImageProcessor, resize_to_max
|
from .tools.image_tools import load_image, ImageProcessor, resize_to_max
|
||||||
from .tools.llava_utils import LLaVAImageProcessor
|
from .tools.llava_utils import LLaVAImageProcessor
|
||||||
from .tools.caption import default_long_prompt, default_short_prompt
|
from .tools.caption import default_long_prompt, default_short_prompt, default_replacements
|
||||||
from jobs.process import BaseExtensionProcess
|
from jobs.process import BaseExtensionProcess
|
||||||
from .tools.sync_tools import get_img_paths
|
from .tools.sync_tools import get_img_paths
|
||||||
|
|
||||||
@@ -39,6 +39,8 @@ class SuperTagger(BaseExtensionProcess):
|
|||||||
self.caption_prompt = config.get('caption_prompt', default_long_prompt)
|
self.caption_prompt = config.get('caption_prompt', default_long_prompt)
|
||||||
self.caption_short_prompt = config.get('caption_short_prompt', default_short_prompt)
|
self.caption_short_prompt = config.get('caption_short_prompt', default_short_prompt)
|
||||||
self.force_reprocess_img = config.get('force_reprocess_img', False)
|
self.force_reprocess_img = config.get('force_reprocess_img', False)
|
||||||
|
self.caption_replacements = config.get('caption_replacements', default_replacements)
|
||||||
|
self.caption_short_replacements = config.get('caption_short_replacements', default_replacements)
|
||||||
self.master_dataset_dict = OrderedDict()
|
self.master_dataset_dict = OrderedDict()
|
||||||
self.dataset_master_config_file = config.get('dataset_master_config_file', None)
|
self.dataset_master_config_file = config.get('dataset_master_config_file', None)
|
||||||
if parent_dir is not None and len(self.dataset_paths) == 0:
|
if parent_dir is not None and len(self.dataset_paths) == 0:
|
||||||
@@ -118,7 +120,8 @@ class SuperTagger(BaseExtensionProcess):
|
|||||||
|
|
||||||
img_info.caption = self.image_processor.generate_caption(
|
img_info.caption = self.image_processor.generate_caption(
|
||||||
image=caption_image,
|
image=caption_image,
|
||||||
prompt=self.caption_prompt
|
prompt=self.caption_prompt,
|
||||||
|
replacements=self.caption_replacements
|
||||||
)
|
)
|
||||||
img_info.mark_step_complete(step)
|
img_info.mark_step_complete(step)
|
||||||
elif step == 'caption_short':
|
elif step == 'caption_short':
|
||||||
@@ -134,7 +137,8 @@ class SuperTagger(BaseExtensionProcess):
|
|||||||
self.image_processor.load_model()
|
self.image_processor.load_model()
|
||||||
img_info.caption_short = self.image_processor.generate_caption(
|
img_info.caption_short = self.image_processor.generate_caption(
|
||||||
image=caption_image,
|
image=caption_image,
|
||||||
prompt=self.caption_short_prompt
|
prompt=self.caption_short_prompt,
|
||||||
|
replacements=self.caption_short_replacements
|
||||||
)
|
)
|
||||||
img_info.mark_step_complete(step)
|
img_info.mark_step_complete(step)
|
||||||
elif step == 'contrast_stretch':
|
elif step == 'contrast_stretch':
|
||||||
|
|||||||
@@ -33,7 +33,13 @@ def clean_caption(cap, replacements=None):
|
|||||||
cap = " ".join(cap.split())
|
cap = " ".join(cap.split())
|
||||||
|
|
||||||
for replacement in replacements:
|
for replacement in replacements:
|
||||||
cap = cap.replace(replacement[0], replacement[1])
|
if replacement[0].startswith('*'):
|
||||||
|
# we are removing all text if it starts with this and the rest matches
|
||||||
|
search_text = replacement[0][1:]
|
||||||
|
if cap.startswith(search_text):
|
||||||
|
cap = ""
|
||||||
|
else:
|
||||||
|
cap = cap.replace(replacement[0].lower(), replacement[1].lower())
|
||||||
|
|
||||||
cap_list = cap.split(",")
|
cap_list = cap.split(",")
|
||||||
# trim whitespace
|
# trim whitespace
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ class LLaVAImageProcessor:
|
|||||||
output_ids = self.model.generate(
|
output_ids = self.model.generate(
|
||||||
input_ids, images=image_tensor, do_sample=True, temperature=0.1,
|
input_ids, images=image_tensor, do_sample=True, temperature=0.1,
|
||||||
max_new_tokens=max_new_tokens, use_cache=True, stopping_criteria=[stopping_criteria],
|
max_new_tokens=max_new_tokens, use_cache=True, stopping_criteria=[stopping_criteria],
|
||||||
top_p=0.9
|
top_p=0.8
|
||||||
)
|
)
|
||||||
outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
||||||
conv.messages[-1][-1] = outputs
|
conv.messages[-1][-1] = outputs
|
||||||
|
|||||||
@@ -6,7 +6,8 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
|||||||
from toolkit.ip_adapter import IPAdapter
|
from toolkit.ip_adapter import IPAdapter
|
||||||
from toolkit.prompt_utils import PromptEmbeds
|
from toolkit.prompt_utils import PromptEmbeds
|
||||||
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
|
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
|
||||||
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
|
from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \
|
||||||
|
apply_learnable_snr_gos, LearnableSNRGamma
|
||||||
import gc
|
import gc
|
||||||
import torch
|
import torch
|
||||||
from jobs.process import BaseSDTrainProcess
|
from jobs.process import BaseSDTrainProcess
|
||||||
@@ -59,6 +60,9 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
self.sd.vae.to('cpu')
|
self.sd.vae.to('cpu')
|
||||||
flush()
|
flush()
|
||||||
|
|
||||||
|
self.sd.noise_scheduler.set_timesteps(1000)
|
||||||
|
add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)
|
||||||
|
|
||||||
# you can expand these in a child class to make customization easier
|
# you can expand these in a child class to make customization easier
|
||||||
def calculate_loss(
|
def calculate_loss(
|
||||||
self,
|
self,
|
||||||
@@ -145,7 +149,9 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
|
if self.train_config.learnable_snr_gos:
|
||||||
|
# add snr_gamma
|
||||||
|
loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
|
||||||
if self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
|
if self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
|
||||||
# add snr_gamma
|
# add snr_gamma
|
||||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
|
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
|
||||||
@@ -315,14 +321,20 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
# activate network if it exits
|
# activate network if it exits
|
||||||
|
|
||||||
# make the batch splits
|
prompts_1 = conditioned_prompts
|
||||||
|
prompts_2 = None
|
||||||
|
if self.train_config.short_and_long_captions_encoder_split and self.sd.is_xl:
|
||||||
|
prompts_1 = batch.get_caption_short_list()
|
||||||
|
prompts_2 = conditioned_prompts
|
||||||
|
|
||||||
|
# make the batch splits
|
||||||
if self.train_config.single_item_batching:
|
if self.train_config.single_item_batching:
|
||||||
batch_size = noisy_latents.shape[0]
|
batch_size = noisy_latents.shape[0]
|
||||||
# chunk/split everything
|
# chunk/split everything
|
||||||
noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
|
noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
|
||||||
noise_list = torch.chunk(noise, batch_size, dim=0)
|
noise_list = torch.chunk(noise, batch_size, dim=0)
|
||||||
timesteps_list = torch.chunk(timesteps, batch_size, dim=0)
|
timesteps_list = torch.chunk(timesteps, batch_size, dim=0)
|
||||||
conditioned_prompts_list = [[prompt] for prompt in conditioned_prompts]
|
conditioned_prompts_list = [[prompt] for prompt in prompts_1]
|
||||||
if imgs is not None:
|
if imgs is not None:
|
||||||
imgs_list = torch.chunk(imgs, batch_size, dim=0)
|
imgs_list = torch.chunk(imgs, batch_size, dim=0)
|
||||||
else:
|
else:
|
||||||
@@ -332,32 +344,44 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
else:
|
else:
|
||||||
adapter_images_list = [None for _ in range(batch_size)]
|
adapter_images_list = [None for _ in range(batch_size)]
|
||||||
mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0)
|
mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0)
|
||||||
|
if prompts_2 is None:
|
||||||
|
prompt_2_list = [None for _ in range(batch_size)]
|
||||||
|
else:
|
||||||
|
prompt_2_list = [[prompt] for prompt in prompts_2]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# but it all in an array
|
# but it all in an array
|
||||||
noisy_latents_list = [noisy_latents]
|
noisy_latents_list = [noisy_latents]
|
||||||
noise_list = [noise]
|
noise_list = [noise]
|
||||||
timesteps_list = [timesteps]
|
timesteps_list = [timesteps]
|
||||||
conditioned_prompts_list = [conditioned_prompts]
|
conditioned_prompts_list = [prompts_1]
|
||||||
imgs_list = [imgs]
|
imgs_list = [imgs]
|
||||||
adapter_images_list = [adapter_images]
|
adapter_images_list = [adapter_images]
|
||||||
mask_multiplier_list = [mask_multiplier]
|
mask_multiplier_list = [mask_multiplier]
|
||||||
|
if prompts_2 is None:
|
||||||
|
prompt_2_list = [None]
|
||||||
|
else:
|
||||||
|
prompt_2_list = [prompts_2]
|
||||||
|
|
||||||
for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier in zip(
|
|
||||||
|
|
||||||
|
for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier, prompt_2 in zip(
|
||||||
noisy_latents_list,
|
noisy_latents_list,
|
||||||
noise_list,
|
noise_list,
|
||||||
timesteps_list,
|
timesteps_list,
|
||||||
conditioned_prompts_list,
|
conditioned_prompts_list,
|
||||||
imgs_list,
|
imgs_list,
|
||||||
adapter_images_list,
|
adapter_images_list,
|
||||||
mask_multiplier_list
|
mask_multiplier_list,
|
||||||
|
prompt_2_list
|
||||||
):
|
):
|
||||||
|
|
||||||
with network:
|
with network:
|
||||||
with self.timer('encode_prompt'):
|
with self.timer('encode_prompt'):
|
||||||
if grad_on_text_encoder:
|
if grad_on_text_encoder:
|
||||||
with torch.set_grad_enabled(True):
|
with torch.set_grad_enabled(True):
|
||||||
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
|
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=True).to(
|
||||||
|
# conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
|
||||||
self.device_torch,
|
self.device_torch,
|
||||||
dtype=dtype)
|
dtype=dtype)
|
||||||
else:
|
else:
|
||||||
@@ -368,7 +392,8 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
te.eval()
|
te.eval()
|
||||||
else:
|
else:
|
||||||
self.sd.text_encoder.eval()
|
self.sd.text_encoder.eval()
|
||||||
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
|
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=True).to(
|
||||||
|
# conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
|
||||||
self.device_torch,
|
self.device_torch,
|
||||||
dtype=dtype)
|
dtype=dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import glob
|
import glob
|
||||||
import inspect
|
import inspect
|
||||||
|
import json
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import os
|
import os
|
||||||
from typing import Union, List
|
from typing import Union, List
|
||||||
@@ -36,7 +37,7 @@ from toolkit.stable_diffusion_model import StableDiffusion
|
|||||||
|
|
||||||
from jobs.process import BaseTrainProcess
|
from jobs.process import BaseTrainProcess
|
||||||
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta
|
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta
|
||||||
from toolkit.train_tools import get_torch_dtype
|
from toolkit.train_tools import get_torch_dtype, LearnableSNRGamma
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -158,6 +159,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
self.named_lora = False
|
self.named_lora = False
|
||||||
if self.embed_config is not None or is_training_adapter:
|
if self.embed_config is not None or is_training_adapter:
|
||||||
self.named_lora = True
|
self.named_lora = True
|
||||||
|
self.snr_gos: Union[LearnableSNRGamma, None] = None
|
||||||
|
|
||||||
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
|
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
|
||||||
# override in subclass
|
# override in subclass
|
||||||
@@ -370,6 +372,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
get_torch_dtype(self.save_config.dtype)
|
get_torch_dtype(self.save_config.dtype)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# save learnable params as json if we have thim
|
||||||
|
if self.snr_gos:
|
||||||
|
json_data = {
|
||||||
|
'offset': self.snr_gos.offset.item(),
|
||||||
|
'scale': self.snr_gos.scale.item(),
|
||||||
|
'gamma': self.snr_gos.gamma.item(),
|
||||||
|
}
|
||||||
|
path_to_save = file_path = os.path.join(self.save_root, 'learnable_snr.json')
|
||||||
|
with open(path_to_save, 'w') as f:
|
||||||
|
json.dump(json_data, f, indent=4)
|
||||||
|
|
||||||
self.print(f"Saved to {file_path}")
|
self.print(f"Saved to {file_path}")
|
||||||
self.clean_up_saves()
|
self.clean_up_saves()
|
||||||
self.post_save_hook(file_path)
|
self.post_save_hook(file_path)
|
||||||
@@ -789,6 +802,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
vae = vae.to(torch.device('cpu'), dtype=dtype)
|
vae = vae.to(torch.device('cpu'), dtype=dtype)
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
|
if self.train_config.learnable_snr_gos:
|
||||||
|
self.snr_gos = LearnableSNRGamma(
|
||||||
|
self.sd.noise_scheduler, device=self.device_torch
|
||||||
|
)
|
||||||
|
# check to see if previous settings exist
|
||||||
|
path_to_load = os.path.join(self.save_root, 'learnable_snr.json')
|
||||||
|
if os.path.exists(path_to_load):
|
||||||
|
with open(path_to_load, 'r') as f:
|
||||||
|
json_data = json.load(f)
|
||||||
|
self.snr_gos.offset.data = torch.tensor(json_data['offset'], device=self.device_torch)
|
||||||
|
self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
|
||||||
|
self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)
|
||||||
|
|
||||||
flush()
|
flush()
|
||||||
|
|
||||||
### HOOk ###
|
### HOOk ###
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ torch
|
|||||||
torchvision
|
torchvision
|
||||||
safetensors
|
safetensors
|
||||||
diffusers==0.21.3
|
diffusers==0.21.3
|
||||||
git+https://github.com/huggingface/transformers.git@master
|
git+https://github.com/huggingface/transformers.git
|
||||||
lycoris-lora==1.8.3
|
lycoris-lora==1.8.3
|
||||||
flatten_json
|
flatten_json
|
||||||
pyyaml
|
pyyaml
|
||||||
|
|||||||
@@ -169,6 +169,9 @@ class TrainConfig:
|
|||||||
self.train_text_encoder = kwargs.get('train_text_encoder', True)
|
self.train_text_encoder = kwargs.get('train_text_encoder', True)
|
||||||
self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
|
self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
|
||||||
self.snr_gamma = kwargs.get('snr_gamma', None)
|
self.snr_gamma = kwargs.get('snr_gamma', None)
|
||||||
|
# trains a gamma, offset, and scale to adjust loss to adapt to timestep differentials
|
||||||
|
# this should balance the learning rate across all timesteps over time
|
||||||
|
self.learnable_snr_gos = kwargs.get('learnable_snr_gos', False)
|
||||||
self.noise_offset = kwargs.get('noise_offset', 0.0)
|
self.noise_offset = kwargs.get('noise_offset', 0.0)
|
||||||
self.skip_first_sample = kwargs.get('skip_first_sample', False)
|
self.skip_first_sample = kwargs.get('skip_first_sample', False)
|
||||||
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
|
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
|
||||||
@@ -190,6 +193,8 @@ class TrainConfig:
|
|||||||
# Double up every image and run it through with both short and long captions. The idea
|
# Double up every image and run it through with both short and long captions. The idea
|
||||||
# is that the network will learn how to generate good images with both short and long captions
|
# is that the network will learn how to generate good images with both short and long captions
|
||||||
self.short_and_long_captions = kwargs.get('short_and_long_captions', False)
|
self.short_and_long_captions = kwargs.get('short_and_long_captions', False)
|
||||||
|
# if above is NOT true, this will make it so the long caption foes to te2 and the short caption goes to te1 for sdxl only
|
||||||
|
self.short_and_long_captions_encoder_split = kwargs.get('short_and_long_captions_encoder_split', False)
|
||||||
|
|
||||||
# basically gradient accumulation but we run just 1 item through the network
|
# basically gradient accumulation but we run just 1 item through the network
|
||||||
# and accumulate gradients. This can be used as basic gradient accumulation but is very helpful
|
# and accumulate gradients. This can be used as basic gradient accumulation but is very helpful
|
||||||
|
|||||||
@@ -46,6 +46,8 @@ def get_optimizer(
|
|||||||
|
|
||||||
if lower_type == "adam8bit":
|
if lower_type == "adam8bit":
|
||||||
return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, **optimizer_params)
|
return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, **optimizer_params)
|
||||||
|
elif lower_type == "adamw8bit":
|
||||||
|
return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, **optimizer_params)
|
||||||
elif lower_type == "lion8bit":
|
elif lower_type == "lion8bit":
|
||||||
return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
|
return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -683,6 +683,68 @@ def get_all_snr(noise_scheduler, device):
|
|||||||
all_snr.requires_grad = False
|
all_snr.requires_grad = False
|
||||||
return all_snr.to(device)
|
return all_snr.to(device)
|
||||||
|
|
||||||
|
class LearnableSNRGamma:
|
||||||
|
"""
|
||||||
|
This is a trainer for learnable snr gamma
|
||||||
|
It will adapt to the dataset and attempt to adjust the snr multiplier to balance the loss over the timesteps
|
||||||
|
"""
|
||||||
|
def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'):
|
||||||
|
self.device = device
|
||||||
|
self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler
|
||||||
|
self.offset = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
|
||||||
|
self.scale = torch.nn.Parameter(torch.tensor(0.001, dtype=torch.float32, device=device))
|
||||||
|
self.gamma = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
|
||||||
|
self.optimizer = torch.optim.AdamW([self.offset, self.gamma, self.scale], lr=0.1)
|
||||||
|
self.buffer = []
|
||||||
|
self.max_buffer_size = 100
|
||||||
|
|
||||||
|
def forward(self, loss, timesteps):
|
||||||
|
# do a our train loop for lsnr here and return our values detached
|
||||||
|
loss = loss.detach()
|
||||||
|
with torch.no_grad():
|
||||||
|
loss_chunks = torch.chunk(loss, loss.shape[0], dim=0)
|
||||||
|
for loss_chunk in loss_chunks:
|
||||||
|
self.buffer.append(loss_chunk.mean().detach())
|
||||||
|
if len(self.buffer) > self.max_buffer_size:
|
||||||
|
self.buffer.pop(0)
|
||||||
|
all_snr = get_all_snr(self.noise_scheduler, loss.device)
|
||||||
|
snr: torch.Tensor = torch.stack([all_snr[t] for t in timesteps]).detach().float().to(loss.device)
|
||||||
|
base_snrs = snr.clone().detach()
|
||||||
|
snr.requires_grad = True
|
||||||
|
snr = snr * self.scale + self.offset
|
||||||
|
|
||||||
|
gamma_over_snr = torch.div(torch.ones_like(snr) * self.gamma, snr)
|
||||||
|
snr_weight = torch.abs(gamma_over_snr).float().to(loss.device) # directly using gamma over snr
|
||||||
|
snr_adjusted_loss = loss * snr_weight
|
||||||
|
with torch.no_grad():
|
||||||
|
target = torch.mean(torch.stack(self.buffer)).detach()
|
||||||
|
|
||||||
|
# local_loss = torch.mean(torch.abs(snr_adjusted_loss - target))
|
||||||
|
squared_differences = (snr_adjusted_loss - target) ** 2
|
||||||
|
local_loss = torch.mean(squared_differences)
|
||||||
|
local_loss.backward()
|
||||||
|
self.optimizer.step()
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
|
return base_snrs, self.gamma.detach(), self.offset.detach(), self.scale.detach()
|
||||||
|
|
||||||
|
|
||||||
|
def apply_learnable_snr_gos(
|
||||||
|
loss,
|
||||||
|
timesteps,
|
||||||
|
learnable_snr_trainer:LearnableSNRGamma
|
||||||
|
):
|
||||||
|
|
||||||
|
snr, gamma, offset, scale = learnable_snr_trainer.forward(loss, timesteps)
|
||||||
|
|
||||||
|
snr = snr * scale + offset
|
||||||
|
|
||||||
|
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
|
||||||
|
snr_weight = torch.abs(gamma_over_snr).float().to(loss.device) # directly using gamma over snr
|
||||||
|
snr_adjusted_loss = loss * snr_weight
|
||||||
|
|
||||||
|
return snr_adjusted_loss
|
||||||
|
|
||||||
|
|
||||||
def apply_snr_weight(
|
def apply_snr_weight(
|
||||||
loss,
|
loss,
|
||||||
@@ -700,5 +762,6 @@ def apply_snr_weight(
|
|||||||
snr_weight = gamma_over_snr.float().to(loss.device) # directly using gamma over snr
|
snr_weight = gamma_over_snr.float().to(loss.device) # directly using gamma over snr
|
||||||
else:
|
else:
|
||||||
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)
|
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)
|
||||||
loss = loss * snr_weight
|
snr_adjusted_loss = loss * snr_weight
|
||||||
return loss
|
|
||||||
|
return snr_adjusted_loss
|
||||||
|
|||||||
Reference in New Issue
Block a user