[CK_TILE] Fix incompatible vector type arguments for the intrinsic calls (#3672)

* Change call to the intrinsics

* fix clang format

* Undo changes under include/ck/utility

* Use named variable as vector size

---------

Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
This commit is contained in:
Po Yen Chen
2026-01-31 04:02:49 +08:00
committed by GitHub
parent 70d71b1514
commit 8c1788757a

View File

@@ -612,7 +612,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
else
{
#if defined(__gfx90a__) || defined(__gfx94__)
c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
bit_cast<ext_vector_t<short, kABKPerLane>>(a_vec),
bit_cast<ext_vector_t<short, kABKPerLane>>(b_vec),
c_vec,
0,
0,
0);
#elif defined(__gfx908__)
static_for<0, 2, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
@@ -637,8 +643,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx90a__) || defined(__gfx94__)
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
bit_cast<ext_vector_t<short, kABKPerLane>>(a_vec),
bit_cast<ext_vector_t<short, kABKPerLane>>(b_vec),
fp32x16_t{0.f},
0,
0,
0));
#elif defined(__gfx908__)
CVecType c_vec{0.f};
static_for<0, 2, 1>{}([&](auto k) {
@@ -700,7 +711,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16bf16_1k", Ctrl)
{
#if defined(__gfx90a__) || defined(__gfx94__)
c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
bit_cast<ext_vector_t<short, kABKPerLane>>(a_vec),
bit_cast<ext_vector_t<short, kABKPerLane>>(b_vec),
c_vec,
0,
0,
0);
#elif defined(__gfx908__)
static_for<0, 2, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
@@ -725,8 +742,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx90a__) || defined(__gfx94__)
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
bit_cast<ext_vector_t<short, kABKPerLane>>(a_vec),
bit_cast<ext_vector_t<short, kABKPerLane>>(b_vec),
fp32x4_t{0.f},
0,
0,
0));
#elif defined(__gfx908__)
CVecType c_vec{0.f};
static_for<0, 2, 1>{}([&](auto k) {
@@ -790,7 +812,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
else
{
#if defined(__gfx90a__) || defined(__gfx94__)
c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
bit_cast<ext_vector_t<short, kABKPerLane>>(a_vec),
bit_cast<ext_vector_t<short, kABKPerLane>>(b_vec),
c_vec,
0,
0,
0);
#elif defined(__gfx908__)
static_for<0, 2, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
@@ -815,8 +843,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx90a__) || defined(__gfx94__)
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
bit_cast<ext_vector_t<short, kABKPerLane>>(a_vec),
bit_cast<ext_vector_t<short, kABKPerLane>>(b_vec),
fp32x4_t{0.f},
0,
0,
0));
#elif defined(__gfx908__)
CVecType c_vec{0.f};
static_for<0, 2, 1>{}([&](auto k) {
@@ -880,7 +913,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
else
{
#if defined(__gfx90a__) || defined(__gfx94__)
c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
bit_cast<ext_vector_t<short, kABKPerLane>>(a_vec),
bit_cast<ext_vector_t<short, kABKPerLane>>(b_vec),
c_vec,
0,
0,
0);
#elif defined(__gfx908__)
static_for<0, 2, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
@@ -905,8 +944,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx90a__) || defined(__gfx94__)
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
bit_cast<ext_vector_t<short, kABKPerLane>>(a_vec),
bit_cast<ext_vector_t<short, kABKPerLane>>(b_vec),
fp32x4_t{0.f},
0,
0,
0));
#elif defined(__gfx908__)
CVecType c_vec{0.f};
static_for<0, 2, 1>{}([&](auto k) {