Added the ability to split Flux across GPUs (experimental). Changed how timestep scheduling works to prepare for more specific schedules.

Author: Jaret Burkett
Date: 2024-12-31 07:06:55 -07:00
Parent: 8ef07a9c36
Commit: 4723f23c0d
5 changed files with 182 additions and 7 deletions


@@ -1,6 +1,6 @@
 import math
 from typing import Union
+from torch.distributions import LogNormal
 from diffusers import FlowMatchEulerDiscreteScheduler
 import torch
@@ -89,12 +89,12 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
     def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
         return sample
 
-    def set_train_timesteps(self, num_timesteps, device, linear=False):
-        if linear:
+    def set_train_timesteps(self, num_timesteps, device, timestep_type='linear'):
+        if timestep_type == 'linear':
             timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
             self.timesteps = timesteps
             return timesteps
-        else:
+        elif timestep_type == 'sigmoid':
             # distribute timesteps closer to the center; inference biases them toward the first steps
             # generate values from 0 to 1
             t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
@@ -108,3 +108,25 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
             self.timesteps = timesteps.to(device=device)
             return timesteps
+        elif timestep_type == 'lognorm_blend':
+            # distribute timesteps toward the center/early steps, then blend in a linear tail
+            alpha = 0.8
+            lognormal = LogNormal(loc=0, scale=0.333)
+            # sample the lognormal fraction of the steps
+            t1 = lognormal.sample((int(num_timesteps * alpha),)).to(device)
+            # scale and reverse the values so they run from 1000 to 0
+            t1 = (1 - t1 / t1.max()) * 1000
+            # fill the remainder with a linear schedule so the total stays num_timesteps
+            t2 = torch.linspace(1000, 0, num_timesteps - int(num_timesteps * alpha), device=device)
+            timesteps = torch.cat((t1, t2))
+            # sort the timesteps in descending order
+            timesteps, _ = torch.sort(timesteps, descending=True)
+            timesteps = timesteps.to(torch.int)
+            self.timesteps = timesteps
+            return timesteps
+        else:
+            raise ValueError(f"Invalid timestep type: {timestep_type}")
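
For context, a minimal usage sketch of the new selector. The import path and the no-argument constructor are assumptions (the class subclasses diffusers' FlowMatchEulerDiscreteScheduler, whose config parameters all have defaults); they are not confirmed by the diff. Note that the diff also changes the default behavior: the old signature defaulted to linear=False, which fell into the sigmoid-style branch, while the new one defaults to 'linear'.

    # Sketch only: exercises the timestep_type selector added in this commit.
    # The import path below is an assumption, not confirmed by the diff.
    import torch
    from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler

    scheduler = CustomFlowMatchEulerDiscreteScheduler()  # assumes the default config (num_train_timesteps=1000)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for timestep_type in ('linear', 'sigmoid', 'lognorm_blend'):
        # each branch returns a tensor of num_timesteps training timesteps in roughly [0, 1000]
        ts = scheduler.set_train_timesteps(1000, device=device, timestep_type=timestep_type)
        print(timestep_type, ts.shape, ts[0].item(), ts[-1].item())

    # an unknown selector now raises instead of silently falling through to a default branch
    try:
        scheduler.set_train_timesteps(1000, device=device, timestep_type='cosine')
    except ValueError as e:
        print(e)  # Invalid timestep type: cosine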