mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-03-15 00:07:24 +00:00
HGEMM: Cleanup
This commit is contained in:
@@ -9,10 +9,12 @@
|
||||
/*
|
||||
|
||||
Row-major matmul using cuBLAS, a @ b -> c
|
||||
- if c is float16, operation is float16 @ float16 -> float16
|
||||
- if c is float32, operation is float16 @ float16 -> float23
|
||||
- if c is float16, operation is float16 @ float16 -> float16 (float16 accumulate)
|
||||
- if c is float32, operation is float16 @ float16 -> float32 (float32 accumulate)
|
||||
*/
|
||||
|
||||
using bfloat16 = __nv_bfloat16;
|
||||
|
||||
void hgemm_gr
|
||||
(
|
||||
at::Tensor a,
|
||||
@@ -25,8 +27,9 @@ void hgemm_gr
|
||||
cudaStream_t stream = graph ? graph->capture_stream : at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
bool output_fp32 = c.dtype() == at::kFloat;
|
||||
if (!output_fp32)
|
||||
TORCH_CHECK_DTYPE(c, kHalf);
|
||||
bool output_fp16 = c.dtype() == at::kHalf;
|
||||
|
||||
TORCH_CHECK(output_fp32 || output_fp16, "c must be float32 or float16");
|
||||
|
||||
// Check shapes of a,b,c are compatible
|
||||
TORCH_CHECK_DTYPE(a, kHalf);
|
||||
@@ -51,7 +54,7 @@ void hgemm_gr
|
||||
void* ws = DevCtx::instance().get_ws(device);
|
||||
cublasSetWorkspace(cublas_handle, ws, WORKSPACE_SIZE);
|
||||
|
||||
if (!output_fp32)
|
||||
if (output_fp16)
|
||||
{
|
||||
half alpha_ = __float2half(1.0f);
|
||||
half beta_ = __float2half(0.0f);
|
||||
@@ -70,7 +73,7 @@ void hgemm_gr
|
||||
cublas_check(r);
|
||||
cuda_check(cudaPeekAtLastError());
|
||||
}
|
||||
else
|
||||
if (output_fp32)
|
||||
{
|
||||
float alpha_ = 1.0f;
|
||||
float beta_ = 0.0f;
|
||||
|
||||
Reference in New Issue
Block a user