Fix cublas ops on dynamic vram. (#12776)

This commit is contained in:
comfyanonymous
2026-03-04 22:21:55 -08:00
committed by GitHub
parent 43c64b6308
commit f2ee7f2d36

View File

@@ -660,23 +660,29 @@ class fp8_ops(manual_cast):
 CUBLAS_IS_AVAILABLE = False
 try:
-    from cublas_ops import CublasLinear
+    from cublas_ops import CublasLinear, cublas_half_matmul
     CUBLAS_IS_AVAILABLE = True
 except ImportError:
     pass

 if CUBLAS_IS_AVAILABLE:
-    class cublas_ops(disable_weight_init):
-        class Linear(CublasLinear, disable_weight_init.Linear):
+    class cublas_ops(manual_cast):
+        class Linear(CublasLinear, manual_cast.Linear):
             def reset_parameters(self):
                 return None

             def forward_comfy_cast_weights(self, input):
-                return super().forward(input)
+                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+                x = cublas_half_matmul(input, weight, bias, self._epilogue_str, self.has_bias)
+                uncast_bias_weight(self, weight, bias, offload_stream)
+                return x

             def forward(self, *args, **kwargs):
-                return super().forward(*args, **kwargs)
+                run_every_op()
+                if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+                    return self.forward_comfy_cast_weights(*args, **kwargs)
+                else:
+                    return super().forward(*args, **kwargs)

 # ==============================================================================
 # Mixed Precision Operations