mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added new experimental timestep weighting that should solve a lot of issues with distribution. Updated example. Removed a warning.
This commit is contained in:
@@ -25,6 +25,8 @@ config:
|
||||
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
||||
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
||||
# images will automatically be resized and bucketed into the resolution specified
|
||||
# on windows, escape back slashes with another backslash so
|
||||
# "C:\\path\\to\\images\\folder"
|
||||
- folder_path: "/path/to/images/folder"
|
||||
caption_ext: "txt"
|
||||
caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
|
||||
@@ -33,17 +35,20 @@ config:
|
||||
resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
|
||||
train:
|
||||
batch_size: 1
|
||||
steps: 4000 # total number of steps to train 500 - 4000 is a good range
|
||||
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
||||
gradient_accumulation_steps: 1
|
||||
train_unet: true
|
||||
train_text_encoder: false # probably won't work with flux
|
||||
content_or_style: balanced # content, style, balanced
|
||||
gradient_checkpointing: true # need this on unless you have a ton of vram
|
||||
noise_scheduler: "flowmatch" # for training only
|
||||
optimizer: "adamw8bit"
|
||||
lr: 4e-4
|
||||
lr: 1e-4
|
||||
# uncomment this to skip the pre training sample
|
||||
# skip_first_sample: true
|
||||
# uncomment to completely disable sampling
|
||||
# disable_sampling: true
|
||||
# uncomment to use new bell curved weighting. Experimental but may produce better results
|
||||
linear_timesteps: true
|
||||
|
||||
# ema will smooth out learning, but could slow it down. Recommended to leave on.
|
||||
ema_config:
|
||||
|
||||
@@ -12,29 +12,23 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
|
||||
with torch.no_grad():
|
||||
# create weights for timesteps
|
||||
num_timesteps = 1000
|
||||
# Bell-Shaped Mean-Normalized Timestep Weighting
|
||||
# bsmntw? need a better name
|
||||
|
||||
# generate the multiplier based on cosmap loss weighing
|
||||
# this is only used on linear timesteps for now
|
||||
x = torch.arange(num_timesteps, dtype=torch.float32)
|
||||
y = torch.exp(-2 * ((x - num_timesteps / 2) / num_timesteps) ** 2)
|
||||
|
||||
# cosine map weighing is higher in the middle and lower at the ends
|
||||
# bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2
|
||||
# cosmap_weighing = 2 / (math.pi * bot)
|
||||
# Shift minimum to 0
|
||||
y_shifted = y - y.min()
|
||||
|
||||
# sigma sqrt weighing is significantly higher at the end and lower at the beginning
|
||||
sigma_sqrt_weighing = (self.sigmas ** -2.0).float()
|
||||
# clip at 1e4 (1e6 is too high)
|
||||
sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4)
|
||||
# bring to a mean of 1
|
||||
sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean()
|
||||
# Scale to make mean 1
|
||||
bsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())
|
||||
|
||||
# Create linear timesteps from 1000 to 0
|
||||
timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu')
|
||||
|
||||
self.linear_timesteps = timesteps
|
||||
# self.linear_timesteps_weights = cosmap_weighing
|
||||
self.linear_timesteps_weights = sigma_sqrt_weighing
|
||||
|
||||
# self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu')
|
||||
self.linear_timesteps_weights = bsmntw_weighing
|
||||
pass
|
||||
|
||||
def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@@ -640,8 +640,8 @@ def add_all_snr_to_noise_scheduler(noise_scheduler, device):
|
||||
all_snr.requires_grad = False
|
||||
noise_scheduler.all_snr = all_snr.to(device)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("Failed to add all_snr to noise_scheduler")
|
||||
# just move on
|
||||
pass
|
||||
|
||||
|
||||
def get_all_snr(noise_scheduler, device):
|
||||
|
||||
Reference in New Issue
Block a user