revise kernel

This commit is contained in:
lllyasviel
2024-08-07 13:28:12 -07:00
committed by GitHub
parent 1ef0844225
commit 14a759b5ca
10 changed files with 317 additions and 420 deletions

View File

@@ -5,7 +5,7 @@ from backend.modules.k_prediction import k_prediction_from_diffusers_scheduler
class KModel(torch.nn.Module):
def __init__(self, model, diffusers_scheduler, storage_dtype, computation_dtype):
def __init__(self, model, diffusers_scheduler, storage_dtype, computation_dtype, k_predictor=None):
super().__init__()
self.storage_dtype = storage_dtype
@@ -17,7 +17,11 @@ class KModel(torch.nn.Module):
print(f'K-Model Created: {dict(storage_dtype=storage_dtype, computation_dtype=computation_dtype, manual_cast=need_manual_cast)}')
self.diffusion_model = model
self.predictor = k_prediction_from_diffusers_scheduler(diffusers_scheduler)
if k_predictor is None:
self.predictor = k_prediction_from_diffusers_scheduler(diffusers_scheduler)
else:
self.predictor = k_predictor
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t

View File

@@ -228,7 +228,7 @@ class PredictionFlow(AbstractPrediction):
class PredictionFlux(AbstractPrediction):
def __init__(self, sigma_data=1.0, prediction_type='eps', shift=1.0, timesteps=10000):
def __init__(self, sigma_data=1.0, prediction_type='const', shift=1.15, timesteps=10000):
super().__init__(sigma_data=sigma_data, prediction_type=prediction_type)
self.shift = shift
ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps))