HGEMM: Cleanup

This commit is contained in:
turboderp
2026-03-03 06:08:07 +01:00
parent 30178941f0
commit e5b522872b

View File

@@ -9,10 +9,12 @@
/*
Row-major matmul using cuBLAS, a @ b -> c
- if c is float16, operation is float16 @ float16 -> float16
- if c is float32, operation is float16 @ float16 -> float32
- if c is float16, operation is float16 @ float16 -> float16 (float16 accumulate)
- if c is float32, operation is float16 @ float16 -> float32 (float32 accumulate)
*/
using bfloat16 = __nv_bfloat16;
void hgemm_gr
(
at::Tensor a,
@@ -25,8 +27,9 @@ void hgemm_gr
cudaStream_t stream = graph ? graph->capture_stream : at::cuda::getCurrentCUDAStream().stream();
bool output_fp32 = c.dtype() == at::kFloat;
if (!output_fp32)
TORCH_CHECK_DTYPE(c, kHalf);
bool output_fp16 = c.dtype() == at::kHalf;
TORCH_CHECK(output_fp32 || output_fp16, "c must be float32 or float16");
// Check shapes of a,b,c are compatible
TORCH_CHECK_DTYPE(a, kHalf);
@@ -51,7 +54,7 @@ void hgemm_gr
void* ws = DevCtx::instance().get_ws(device);
cublasSetWorkspace(cublas_handle, ws, WORKSPACE_SIZE);
if (!output_fp32)
if (output_fp16)
{
half alpha_ = __float2half(1.0f);
half beta_ = __float2half(0.0f);
@@ -70,7 +73,7 @@ void hgemm_gr
cublas_check(r);
cuda_check(cudaPeekAtLastError());
}
else
if (output_fp32)
{
float alpha_ = 1.0f;
float beta_ = 0.0f;