Co-authored-by: graemeniedermayer graemeniedermayer@users.noreply.github.com
This commit is contained in:
DenOfEquity
2025-02-27 17:54:44 +00:00
committed by GitHub
parent 8dd92501e6
commit f23bc80d2f
6 changed files with 1184 additions and 22 deletions

View File

@@ -250,6 +250,38 @@ class PredictionFlow(AbstractPrediction):
return 1.0 - percent
class PredictionDiscreteFlow(AbstractPrediction):
def __init__(self, sigma_data=1.0, prediction_type='const', shift=1.0, timesteps = 1000):
super().__init__(sigma_data=sigma_data, prediction_type=prediction_type)
self.shift = shift
ts = self.sigma(torch.arange(1, timesteps + 1, 1))
self.register_buffer("sigmas", ts)
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
return sigma * 1000
def sigma(self, timestep: torch.Tensor):
timestep = timestep / 1000.0
if self.shift == 1.0:
return timestep
return self.shift * timestep / (1 + (self.shift - 1) * timestep)
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 1.0
if percent >= 1.0:
return 0.0
return 1.0 - percent
class PredictionFlux(AbstractPrediction):
def __init__(self, seq_len=4096, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15, pseudo_timestep_range=10000, mu=None):
super().__init__(sigma_data=1.0, prediction_type='const')