mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added cogview4. Loss still needs work.
This commit is contained in:
@@ -44,7 +44,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
|
||||
hbsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())
|
||||
|
||||
# flatten second half to max
|
||||
hbsmntw_weighing[num_timesteps // 2:] = hbsmntw_weighing[num_timesteps // 2:].max()
|
||||
hbsmntw_weighing[num_timesteps //
|
||||
2:] = hbsmntw_weighing[num_timesteps // 2:].max()
|
||||
|
||||
# Create linear timesteps from 1000 to 0
|
||||
timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu')
|
||||
@@ -56,7 +57,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
|
||||
|
||||
def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor:
|
||||
# Get the indices of the timesteps
|
||||
step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
|
||||
step_indices = [(self.timesteps == t).nonzero().item()
|
||||
for t in timesteps]
|
||||
|
||||
# Get the weights for the timesteps
|
||||
if v2:
|
||||
@@ -70,7 +72,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
|
||||
sigmas = self.sigmas.to(device=device, dtype=dtype)
|
||||
schedule_timesteps = self.timesteps.to(device)
|
||||
timesteps = timesteps.to(device)
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item()
|
||||
for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < n_dim:
|
||||
@@ -84,27 +87,24 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
|
||||
noise: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
|
||||
## Add noise according to flow matching.
|
||||
## zt = (1 - texp) * x + texp * z1
|
||||
|
||||
# sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
|
||||
# noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
|
||||
|
||||
# timestep needs to be in [0, 1], we store them in [0, 1000]
|
||||
# noisy_sample = (1 - timestep) * latent + timestep * noise
|
||||
t_01 = (timesteps / 1000).to(original_samples.device)
|
||||
# forward ODE
|
||||
noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
|
||||
|
||||
# n_dim = original_samples.ndim
|
||||
# sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
|
||||
# noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
|
||||
# reverse ODE
|
||||
# noisy_model_input = (1 - t_01) * noise + t_01 * original_samples
|
||||
return noisy_model_input
|
||||
|
||||
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
|
||||
return sample
|
||||
|
||||
def set_train_timesteps(self, num_timesteps, device, timestep_type='linear', latents=None):
|
||||
def set_train_timesteps(
|
||||
self,
|
||||
num_timesteps,
|
||||
device,
|
||||
timestep_type='linear',
|
||||
latents=None,
|
||||
patch_size=1
|
||||
):
|
||||
self.timestep_type = timestep_type
|
||||
if timestep_type == 'linear':
|
||||
timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
|
||||
@@ -124,42 +124,67 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
|
||||
self.timesteps = timesteps.to(device=device)
|
||||
|
||||
return timesteps
|
||||
elif timestep_type == 'flux_shift' or timestep_type == 'lumina2_shift':
|
||||
elif timestep_type in ['flux_shift', 'lumina2_shift', 'shift']:
|
||||
# matches inference dynamic shifting
|
||||
timesteps = np.linspace(
|
||||
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_timesteps
|
||||
self._sigma_to_t(self.sigma_max), self._sigma_to_t(
|
||||
self.sigma_min), num_timesteps
|
||||
)
|
||||
|
||||
sigmas = timesteps / self.config.num_train_timesteps
|
||||
|
||||
if latents is None:
|
||||
raise ValueError('latents is None')
|
||||
|
||||
h = latents.shape[2] // 2 # Divide by ph
|
||||
w = latents.shape[3] // 2 # Divide by pw
|
||||
image_seq_len = h * w
|
||||
|
||||
# todo need to know the mu for the shift
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.config.get("base_image_seq_len", 256),
|
||||
self.config.get("max_image_seq_len", 4096),
|
||||
self.config.get("base_shift", 0.5),
|
||||
self.config.get("max_shift", 1.16),
|
||||
)
|
||||
sigmas = self.time_shift(mu, 1.0, sigmas)
|
||||
if self.config.use_dynamic_shifting:
|
||||
if latents is None:
|
||||
raise ValueError('latents is None')
|
||||
|
||||
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
||||
# for flux we double up the patch size before sending her to simulate the latent reduction
|
||||
h = latents.shape[2]
|
||||
w = latents.shape[3]
|
||||
image_seq_len = h * w // (patch_size**2)
|
||||
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.config.get("base_image_seq_len", 256),
|
||||
self.config.get("max_image_seq_len", 4096),
|
||||
self.config.get("base_shift", 0.5),
|
||||
self.config.get("max_shift", 1.16),
|
||||
)
|
||||
sigmas = self.time_shift(mu, 1.0, sigmas)
|
||||
else:
|
||||
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
|
||||
|
||||
if self.config.shift_terminal:
|
||||
sigmas = self.stretch_shift_to_terminal(sigmas)
|
||||
|
||||
if self.config.use_karras_sigmas:
|
||||
sigmas = self._convert_to_karras(
|
||||
in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
|
||||
elif self.config.use_exponential_sigmas:
|
||||
sigmas = self._convert_to_exponential(
|
||||
in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
|
||||
elif self.config.use_beta_sigmas:
|
||||
sigmas = self._convert_to_beta(
|
||||
in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
|
||||
|
||||
sigmas = torch.from_numpy(sigmas).to(
|
||||
dtype=torch.float32, device=device)
|
||||
timesteps = sigmas * self.config.num_train_timesteps
|
||||
|
||||
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
if self.config.invert_sigmas:
|
||||
sigmas = 1.0 - sigmas
|
||||
timesteps = sigmas * self.config.num_train_timesteps
|
||||
sigmas = torch.cat(
|
||||
[sigmas, torch.ones(1, device=sigmas.device)])
|
||||
else:
|
||||
sigmas = torch.cat(
|
||||
[sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
|
||||
self.timesteps = timesteps.to(device=device)
|
||||
self.sigmas = sigmas
|
||||
|
||||
|
||||
self.timesteps = timesteps.to(device=device)
|
||||
return timesteps
|
||||
|
||||
|
||||
elif timestep_type == 'lognorm_blend':
|
||||
# disgtribute timestepd to the center/early and blend in linear
|
||||
alpha = 0.75
|
||||
@@ -173,7 +198,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
|
||||
t1 = ((1 - t1/t1.max()) * 1000)
|
||||
|
||||
# add half of linear
|
||||
t2 = torch.linspace(1000, 0, int(num_timesteps * (1 - alpha)), device=device)
|
||||
t2 = torch.linspace(1000, 0, int(
|
||||
num_timesteps * (1 - alpha)), device=device)
|
||||
timesteps = torch.cat((t1, t2))
|
||||
|
||||
# Sort the timesteps in descending order
|
||||
|
||||
Reference in New Issue
Block a user