mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added ability to split up flux across gpus (experimental). Changed the way timestep scheduling works to prep for more specific schedules.
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import math
|
||||
from typing import Union
|
||||
|
||||
from torch.distributions import LogNormal
|
||||
from diffusers import FlowMatchEulerDiscreteScheduler
|
||||
import torch
|
||||
|
||||
@@ -89,12 +89,12 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
|
||||
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
|
||||
return sample
|
||||
|
||||
def set_train_timesteps(self, num_timesteps, device, linear=False):
|
||||
if linear:
|
||||
def set_train_timesteps(self, num_timesteps, device, timestep_type='linear'):
|
||||
if timestep_type == 'linear':
|
||||
timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
|
||||
self.timesteps = timesteps
|
||||
return timesteps
|
||||
else:
|
||||
elif timestep_type == 'sigmoid':
|
||||
# distribute them closer to center. Inference distributes them as a bias toward first
|
||||
# Generate values from 0 to 1
|
||||
t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
|
||||
@@ -108,3 +108,25 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
|
||||
self.timesteps = timesteps.to(device=device)
|
||||
|
||||
return timesteps
|
||||
elif timestep_type == 'lognorm_blend':
|
||||
# disgtribute timestepd to the center/early and blend in linear
|
||||
alpha = 0.8
|
||||
|
||||
lognormal = LogNormal(loc=0, scale=0.333)
|
||||
|
||||
# Sample from the distribution
|
||||
t1 = lognormal.sample((int(num_timesteps * alpha),)).to(device)
|
||||
|
||||
# Scale and reverse the values to go from 1000 to 0
|
||||
t1 = ((1 - t1/t1.max()) * 1000)
|
||||
|
||||
# add half of linear
|
||||
t2 = torch.linspace(1000, 0, int(num_timesteps * (1 - alpha)), device=device)
|
||||
timesteps = torch.cat((t1, t2))
|
||||
|
||||
# Sort the timesteps in descending order
|
||||
timesteps, _ = torch.sort(timesteps, descending=True)
|
||||
|
||||
timesteps = timesteps.to(torch.int)
|
||||
else:
|
||||
raise ValueError(f"Invalid timestep type: {timestep_type}")
|
||||
|
||||
Reference in New Issue
Block a user