diff --git a/exllamav3/exllamav3_ext/hgemm.cu b/exllamav3/exllamav3_ext/hgemm.cu
index 88b5f7f..ec91ca3 100644
--- a/exllamav3/exllamav3_ext/hgemm.cu
+++ b/exllamav3/exllamav3_ext/hgemm.cu
@@ -9,10 +9,12 @@
 
 /*
 Row-major matmul using cuBLAS, a @ b -> c
-- if c is float16, operation is float16 @ float16 -> float16
-- if c is float32, operation is float16 @ float16 -> float23
+- if c is float16, operation is float16 @ float16 -> float16 (float16 accumulate)
+- if c is float32, operation is float16 @ float16 -> float32 (float32 accumulate)
 */
 
+using bfloat16 = __nv_bfloat16;
+
 void hgemm_gr
 (
     at::Tensor a,
@@ -25,8 +27,9 @@ void hgemm_gr
 
     cudaStream_t stream = graph ? graph->capture_stream : at::cuda::getCurrentCUDAStream().stream();
     bool output_fp32 = c.dtype() == at::kFloat;
-    if (!output_fp32)
-        TORCH_CHECK_DTYPE(c, kHalf);
+    bool output_fp16 = c.dtype() == at::kHalf;
+
+    TORCH_CHECK(output_fp32 || output_fp16, "c must be float32 or float16");
 
     // Check shapes of a,b,c are compatible
     TORCH_CHECK_DTYPE(a, kHalf);
@@ -51,7 +54,7 @@ void hgemm_gr
     void* ws = DevCtx::instance().get_ws(device);
     cublasSetWorkspace(cublas_handle, ws, WORKSPACE_SIZE);
 
-    if (!output_fp32)
+    if (output_fp16)
     {
         half alpha_ = __float2half(1.0f);
         half beta_ = __float2half(0.0f);
@@ -70,7 +73,7 @@ void hgemm_gr
         cublas_check(r);
         cuda_check(cudaPeekAtLastError());
     }
-    else
+    if (output_fp32)
     {
         float alpha_ = 1.0f;
         float beta_ = 0.0f;