mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-28 18:21:48 +00:00
revise kernel
This commit is contained in:
@@ -5,7 +5,7 @@ from backend.modules.k_prediction import k_prediction_from_diffusers_scheduler
|
||||
|
||||
|
||||
class KModel(torch.nn.Module):
|
||||
def __init__(self, model, diffusers_scheduler, storage_dtype, computation_dtype):
|
||||
def __init__(self, model, diffusers_scheduler, storage_dtype, computation_dtype, k_predictor=None):
|
||||
super().__init__()
|
||||
|
||||
self.storage_dtype = storage_dtype
|
||||
@@ -17,7 +17,11 @@ class KModel(torch.nn.Module):
|
||||
print(f'K-Model Created: {dict(storage_dtype=storage_dtype, computation_dtype=computation_dtype, manual_cast=need_manual_cast)}')
|
||||
|
||||
self.diffusion_model = model
|
||||
self.predictor = k_prediction_from_diffusers_scheduler(diffusers_scheduler)
|
||||
|
||||
if k_predictor is None:
|
||||
self.predictor = k_prediction_from_diffusers_scheduler(diffusers_scheduler)
|
||||
else:
|
||||
self.predictor = k_predictor
|
||||
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
sigma = t
|
||||
|
||||
@@ -228,7 +228,7 @@ class PredictionFlow(AbstractPrediction):
|
||||
|
||||
|
||||
class PredictionFlux(AbstractPrediction):
|
||||
def __init__(self, sigma_data=1.0, prediction_type='eps', shift=1.0, timesteps=10000):
|
||||
def __init__(self, sigma_data=1.0, prediction_type='const', shift=1.15, timesteps=10000):
|
||||
super().__init__(sigma_data=sigma_data, prediction_type=prediction_type)
|
||||
self.shift = shift
|
||||
ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps))
|
||||
|
||||
Reference in New Issue
Block a user