mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-08 08:59:58 +00:00
SD3+ (#2688)
Co-authored-by: graemeniedermayer graemeniedermayer@users.noreply.github.com
This commit is contained in:
@@ -250,6 +250,38 @@ class PredictionFlow(AbstractPrediction):
|
||||
return 1.0 - percent
|
||||
|
||||
|
||||
class PredictionDiscreteFlow(AbstractPrediction):
|
||||
def __init__(self, sigma_data=1.0, prediction_type='const', shift=1.0, timesteps = 1000):
|
||||
super().__init__(sigma_data=sigma_data, prediction_type=prediction_type)
|
||||
self.shift = shift
|
||||
ts = self.sigma(torch.arange(1, timesteps + 1, 1))
|
||||
self.register_buffer("sigmas", ts)
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
||||
@property
|
||||
def sigma_max(self):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def timestep(self, sigma):
|
||||
return sigma * 1000
|
||||
|
||||
def sigma(self, timestep: torch.Tensor):
|
||||
timestep = timestep / 1000.0
|
||||
if self.shift == 1.0:
|
||||
return timestep
|
||||
return self.shift * timestep / (1 + (self.shift - 1) * timestep)
|
||||
|
||||
def percent_to_sigma(self, percent):
|
||||
if percent <= 0.0:
|
||||
return 1.0
|
||||
if percent >= 1.0:
|
||||
return 0.0
|
||||
return 1.0 - percent
|
||||
|
||||
|
||||
class PredictionFlux(AbstractPrediction):
|
||||
def __init__(self, seq_len=4096, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15, pseudo_timestep_range=10000, mu=None):
|
||||
super().__init__(sigma_data=1.0, prediction_type='const')
|
||||
|
||||
Reference in New Issue
Block a user