forge 2.0.0

see also discussions
This commit is contained in:
lllyasviel
2024-08-10 19:24:19 -07:00
committed by GitHub
parent 4014013d05
commit cfa5242a75
28 changed files with 785 additions and 1249 deletions

View File

@@ -5,16 +5,13 @@ from backend.modules.k_prediction import k_prediction_from_diffusers_scheduler
class KModel(torch.nn.Module):
def __init__(self, model, diffusers_scheduler, storage_dtype, computation_dtype, k_predictor=None):
def __init__(self, model, diffusers_scheduler, k_predictor=None):
super().__init__()
self.storage_dtype = storage_dtype
self.computation_dtype = computation_dtype
self.storage_dtype = model.storage_dtype
self.computation_dtype = model.computation_dtype
need_manual_cast = self.storage_dtype != self.computation_dtype
operations.shift_manual_cast(model, enabled=need_manual_cast)
print(f'K-Model Created: {dict(storage_dtype=storage_dtype, computation_dtype=computation_dtype, manual_cast=need_manual_cast)}')
print(f'K-Model Created: {dict(storage_dtype=self.storage_dtype, computation_dtype=self.computation_dtype)}')
self.diffusion_model = model