Added cogview4. Loss still needs work.

This commit is contained in:
Jaret Burkett
2025-03-04 18:43:52 -07:00
parent c57434ad7b
commit 6f6fb90812
12 changed files with 661 additions and 152 deletions

View File

@@ -44,7 +44,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
hbsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())
# flatten second half to max
hbsmntw_weighing[num_timesteps // 2:] = hbsmntw_weighing[num_timesteps // 2:].max()
hbsmntw_weighing[num_timesteps //
2:] = hbsmntw_weighing[num_timesteps // 2:].max()
# Create linear timesteps from 1000 to 0
timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu')
@@ -56,7 +57,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor:
# Get the indices of the timesteps
step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
step_indices = [(self.timesteps == t).nonzero().item()
for t in timesteps]
# Get the weights for the timesteps
if v2:
@@ -70,7 +72,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
sigmas = self.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = self.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
step_indices = [(schedule_timesteps == t).nonzero().item()
for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
@@ -84,27 +87,24 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
## Add noise according to flow matching.
## zt = (1 - texp) * x + texp * z1
# sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
# noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
# timestep needs to be in [0, 1], we store them in [0, 1000]
# noisy_sample = (1 - timestep) * latent + timestep * noise
t_01 = (timesteps / 1000).to(original_samples.device)
# forward ODE
noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
# n_dim = original_samples.ndim
# sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
# noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
# reverse ODE
# noisy_model_input = (1 - t_01) * noise + t_01 * original_samples
return noisy_model_input
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
return sample
def set_train_timesteps(self, num_timesteps, device, timestep_type='linear', latents=None):
def set_train_timesteps(
self,
num_timesteps,
device,
timestep_type='linear',
latents=None,
patch_size=1
):
self.timestep_type = timestep_type
if timestep_type == 'linear':
timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
@@ -124,42 +124,67 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
self.timesteps = timesteps.to(device=device)
return timesteps
elif timestep_type == 'flux_shift' or timestep_type == 'lumina2_shift':
elif timestep_type in ['flux_shift', 'lumina2_shift', 'shift']:
# matches inference dynamic shifting
timesteps = np.linspace(
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_timesteps
self._sigma_to_t(self.sigma_max), self._sigma_to_t(
self.sigma_min), num_timesteps
)
sigmas = timesteps / self.config.num_train_timesteps
if latents is None:
raise ValueError('latents is None')
h = latents.shape[2] // 2 # Divide by ph
w = latents.shape[3] // 2 # Divide by pw
image_seq_len = h * w
# todo need to know the mu for the shift
mu = calculate_shift(
image_seq_len,
self.config.get("base_image_seq_len", 256),
self.config.get("max_image_seq_len", 4096),
self.config.get("base_shift", 0.5),
self.config.get("max_shift", 1.16),
)
sigmas = self.time_shift(mu, 1.0, sigmas)
if self.config.use_dynamic_shifting:
if latents is None:
raise ValueError('latents is None')
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
# for flux we double up the patch size before sending her to simulate the latent reduction
h = latents.shape[2]
w = latents.shape[3]
image_seq_len = h * w // (patch_size**2)
mu = calculate_shift(
image_seq_len,
self.config.get("base_image_seq_len", 256),
self.config.get("max_image_seq_len", 4096),
self.config.get("base_shift", 0.5),
self.config.get("max_shift", 1.16),
)
sigmas = self.time_shift(mu, 1.0, sigmas)
else:
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
if self.config.shift_terminal:
sigmas = self.stretch_shift_to_terminal(sigmas)
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(
in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(
in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(
in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps)
sigmas = torch.from_numpy(sigmas).to(
dtype=torch.float32, device=device)
timesteps = sigmas * self.config.num_train_timesteps
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
if self.config.invert_sigmas:
sigmas = 1.0 - sigmas
timesteps = sigmas * self.config.num_train_timesteps
sigmas = torch.cat(
[sigmas, torch.ones(1, device=sigmas.device)])
else:
sigmas = torch.cat(
[sigmas, torch.zeros(1, device=sigmas.device)])
self.timesteps = timesteps.to(device=device)
self.sigmas = sigmas
self.timesteps = timesteps.to(device=device)
return timesteps
elif timestep_type == 'lognorm_blend':
# disgtribute timestepd to the center/early and blend in linear
alpha = 0.75
@@ -173,7 +198,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
t1 = ((1 - t1/t1.max()) * 1000)
# add half of linear
t2 = torch.linspace(1000, 0, int(num_timesteps * (1 - alpha)), device=device)
t2 = torch.linspace(1000, 0, int(
num_timesteps * (1 - alpha)), device=device)
timesteps = torch.cat((t1, t2))
# Sort the timesteps in descending order