GEMM Multiply Multiply Fix (#2102)

* Fix the float-to-BF16 type conversion (non-RNE path) and update the data-type list in the profiler's help comment

* fix the CI

[ROCm/composable_kernel commit: 0cca8fa28f]
This commit is contained in:
Thomas Ning
2025-04-22 01:13:22 -07:00
committed by GitHub
parent 863ec4eb88
commit 005e61ce63
2 changed files with 2 additions and 2 deletions

View File

@@ -117,7 +117,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
#if CK_USE_RNE_BF16_CONVERSION
return bf16_convert_rtn<bhalf_t>(x);
#else
return uint16_t(uint32_t{x} >> 16);
return uint16_t(static_cast<uint32_t>(x) >> 16);
#endif
}

View File

@@ -42,7 +42,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
"f16->f8; 7: f8->bf16, "
"comp f8; 8: int8->bf16; 9: f8->f16, comp f8;)\n");
"comp f8; 8: int8->bf16; 9: int8->f16, 10. f8->f16;)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");