From bf3518b45aeadaa4918205f84230bd21b52d6a69 Mon Sep 17 00:00:00 2001
From: jakpiase
Date: Wed, 11 Sep 2024 15:19:42 +0200
Subject: [PATCH] Added structural sparsity blockwise gemm (#1435)

* Implemented smfmac xdlops
* Added smfmac blockwise xdlops
* fixes
* add reviewers suggestions

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
[ROCm/composable_kernel commit: 2a261afcdfc2351dece8d872d413310ff1992988]
---
 .../block/blockwise_gemm_smfmac_xdlops.hpp    | 453 ++++++++++++++++++
 .../gpu/warp/smfmac_xdlops_gemm.hpp           |  52 +-
 include/ck/utility/amd_smfmac.hpp             |  26 +-
 3 files changed, 505 insertions(+), 26 deletions(-)
 create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp

diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp
new file mode 100644
index 0000000000..e9f9b0be7e
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp
@@ -0,0 +1,453 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck/utility/common_header.hpp"
+#include "ck/utility/loop_scheduler.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
+#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
+#include "ck/tensor_description/tensor_adaptor.hpp"
+
+namespace ck {
+
+template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXDL, typename TileDesc_K0_MN_K1>
+__host__ __device__ static constexpr auto
+MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
+{
+    constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
+    constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
+
+    return transform_tensor_descriptor(
+        TileDesc_K0_MN_K1{},
+        make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
+                   make_unmerge_transform(
+                       make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXDL>{}))),
+        make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
+        make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
+}
+
+template <index_t BlockSize,
+          typename ComputeTypeA,
+          typename ComputeTypeB,
+          typename AccDataType,
+          typename AK0MK1BlockDesc,
+          typename BK0NK1BlockDesc,
+          index_t MPerXDL,
+          index_t NPerXDL,
+          index_t MRepeat,
+          index_t NRepeat,
+          index_t KPack>
+struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
+    static constexpr auto I3 = Number<3>{};
+
+    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+
+    static constexpr index_t WaveSize = get_warp_size();
+
+    static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
+    static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
+    static constexpr index_t KPerBlock =
+        BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
+
+    static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
+    static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
+    static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
+    static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
+
+    static constexpr auto xdlops_gemm =
+        SparseXdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack>{};
+
+    static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
+
+    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
+    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
+
+    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
+                              AccDataType,
+                              MRepeat * NRepeat,
+                              xdlops_gemm.GetRegSizePerXdlops(),
+                              true>
+        c_thread_buf_;
+
+    __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
+
+    __device__ static auto GetWaveIdx()
+    {
+        const index_t thread_id = ThisThreadBlock::GetThreadId();
+
+        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
+            make_tuple(Sequence<0, 1, 2>{}),
+            make_tuple(Sequence<0>{}));
+
+        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
+    }
+
+    __device__ static auto CalculateAThreadOriginDataIndex()
+    {
+        const auto wave_idx     = GetWaveIdx();
+        const auto waveId_m     = wave_idx[I0];
+        const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
+
+        return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
+    }
+
+    __device__ static auto CalculateBThreadOriginDataIndex()
+    {
+        const auto wave_idx     = GetWaveIdx();
+        const auto waveId_n     = wave_idx[I1];
+        const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
+
+        return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
+    }
+
+    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
+    __device__ static auto
+    CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
+    {
+        const auto wave_idx = GetWaveIdx();
+        const auto waveId_m = wave_idx[I0];
+        const auto waveId_n = wave_idx[I1];
+
+        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
+
+        constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
+            make_tuple(Sequence<0>{}),
+            make_tuple(Sequence<0, 1, 2>{}));
+
+        constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
+            make_tuple(Sequence<0>{}),
+            make_tuple(Sequence<0, 1, 2>{}));
+
+        const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
+            make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
+        const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
+            make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
+
+        return make_tuple(c_thread_m, c_thread_n);
+    }
+
+    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
+    __device__ static auto
+    CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
+    {
+        const auto wave_idx = GetWaveIdx();
+        const auto waveId_m = wave_idx[I0];
+        const auto waveId_n = wave_idx[I1];
+
+        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
+
+        return make_tuple(Number<m0>{},
+                          Number<n0>{},
+                          waveId_m,
+                          waveId_n,
+                          blk_idx[I0],
+                          blk_idx[I1],
+                          blk_idx[I2],
+                          blk_idx[I3]);
+    }
+
+    __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
+    {
+        static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
+                          BK0NK1BlockDesc::IsKnownAtCompileTime(),
+                      "wrong! Desc should be known at compile-time");
+
+        static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
+                      "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
+
+        static_assert(MPerBlock % (MPerXDL * MRepeat) == 0,
+                      "MPerBlock must be divisible by MPerXDL * MRepeat");
+        static_assert(NPerBlock % (NPerXDL * NRepeat) == 0,
+                      "NPerBlock must be divisible by NPerXDL * NRepeat");
+
+        static_assert(KPack % (16 * sizeof(ComputeTypeA)) == 0,
+                      "KPack must be divisible by the number of elements processed in a single "
+                      "smfmac instruction");
+    }
+
+    __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
+    {
+        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
+
+        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
+        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
+        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
+        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];
+
+        return make_naive_tensor_descriptor_packed(
+            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
+    }
+
+    __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
+    {
+        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
+
+        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
+        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
+        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
+        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];
+
+        return make_naive_tensor_descriptor_packed(
+            make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
+    }
+
+    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
+    {
+        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
+            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
+                                                           Number<NRepeat>{},
+                                                           Number<MWaves>{},
+                                                           Number<NWaves>{},
+                                                           Number<MPerXDL>{},
+                                                           Number<NPerXDL>{}));
+
+        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
+    }
+
+    __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
+    {
+        constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
+            make_naive_tensor_descriptor_packed(make_tuple(I1,
+                                                           Number<MRepeat>{},
+                                                           Number<NRepeat>{},
+                                                           Number<MWaves>{},
+                                                           Number<NWaves>{},
+                                                           Number<MPerXDL>{},
+                                                           Number<NPerXDL>{}));
+
+        return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
+            c_block_desc_g_m0_n0_m1_n1_m2_n2);
+    }
+
+    template <typename CGridDesc_M_N>
+    __host__ __device__ static constexpr auto
+    MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
+    {
+        const auto M = c_grid_desc_m_n.GetLength(I0);
+        const auto N = c_grid_desc_m_n.GetLength(I1);
+
+        const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
+            c_grid_desc_m_n,
+            make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
+                       make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
+
+        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
+    }
+
+    template <typename CGridDesc_G_M_N>
+    __host__ __device__ static constexpr auto
+    MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
+    {
+        const auto G = c_grid_desc_g_m_n.GetLength(I0);
+        const auto M = c_grid_desc_g_m_n.GetLength(I1);
+        const auto N = c_grid_desc_g_m_n.GetLength(I2);
+
+        const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
+            c_grid_desc_g_m_n,
+            make_tuple(make_pass_through_transform(G),
+                       make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
+                       make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));
+
+        return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
+            c_grid_desc_g_m0_n0_m1_n1_m2_n2);
+    }
+
+    __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
+    {
+        return transform_tensor_descriptor(
+            AK0MK1BlockDesc{},
+            make_tuple(
+                make_merge_transform_v3_division_mod(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
+                make_unmerge_transform(
+                    make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{}))),
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
+            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
+    }
+
+    __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
+    {
+        return transform_tensor_descriptor(
+            BK0NK1BlockDesc{},
+            make_tuple(
+                make_merge_transform_v3_division_mod(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
+                make_unmerge_transform(
+                    make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{}))),
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
+            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
+    }
+
+    static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
+    static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
+
+    // Prepares data in a_thread_buf by squeezing values, omitting zeros, to adjust it to 2:4
+    // structural sparsity. The indices of the non-zero elements are stored in idx_buf and used
+    // later by the smfmac instruction.
+    template <index_t num_elems, typename AThreadBuf, typename IdxBuf>
+    __device__ void SetIdxSqueezeA(AThreadBuf& a_thread_buf, IdxBuf& idx_buf)
+    {
+        static constexpr int32_t bit_clear_masks[4] = {0b11, 0b1100, 0b110000, 0b11000000};
+        static constexpr int32_t processed_elems    = 16 / sizeof(ComputeTypeA);
+
+        static_for<0, num_elems, processed_elems>{}([&](auto i) {
+            constexpr int idx_reg_num  = i / (16 * sizeof(ComputeTypeA));
+            constexpr int idx_reg_part = (i % 32) / processed_elems;
+
+            vector_type<ComputeTypeA, processed_elems> a_thread_vec;
+            static_for<0, processed_elems, 1>{}([&](auto j) {
+                a_thread_vec.template AsType<ComputeTypeA>()(j) = a_thread_buf[Number<i + j>{}];
+            });
+
+            uint8_t idx = 0b11101110; // set to last 2 elems for both 4-elem subgroups by default
+            for(int j = 0; j < processed_elems; j += 4)
+            {
+                int32_t nonzero_pos = 0;
+                ComputeTypeA nonzero_elems[2] = {a_thread_vec[j + 2], a_thread_vec[j + 3]};
+                for(int k = 0; k < 3; k += 1)
+                {
+                    // at most two non-zeros per 4-elem group (2:4 structural sparsity)
+                    if(a_thread_vec[j + k] != 0.0f && nonzero_pos < 2)
+                    {
+                        nonzero_elems[nonzero_pos] = a_thread_vec[j + k];
+                        idx &= ~bit_clear_masks[j / 2 + nonzero_pos];
+                        idx |= k << 2 * (j / 2 + nonzero_pos);
+                        ++nonzero_pos;
+                    }
+                }
+                a_thread_vec[j / 2]     = nonzero_elems[0];
+                a_thread_vec[j / 2 + 1] = nonzero_elems[1];
+            }
+            // pack this 8-bit index group into byte idx_reg_part of the 32-bit index register
+            int32_t& idx_reg = idx_buf(Number<idx_reg_num>{});
+            idx_reg = (idx_reg & ~static_cast<int32_t>(0xFFu << (8 * idx_reg_part))) |
+                      static_cast<int32_t>(uint32_t{idx} << (8 * idx_reg_part));
+
+            static_for<0, processed_elems / 2, 1>{}([&](auto j) {
+                a_thread_buf(Number<i / 2 + j>{}) = a_thread_vec[j];
+            });
+        });
+    }
+
+    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
+    __device__ void Run(const ABlockBuffer& a_block_buf,
+                        const BBlockBuffer& b_block_buf,
+                        CThreadBuffer& c_thread_buf) const
+    {
+        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
+            a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
+            b_thread_desc_.GetElementSpaceSize());
+        static constexpr int32_t elems_per_idx = 16 * sizeof(ComputeTypeA);
+        auto idx_buf = make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
+            (a_thread_desc_.GetElementSpaceSize() + elems_per_idx - 1) / elems_per_idx);
+
+        static_for<0, MRepeat, 1>{}([&](auto m0) {
+            // read A
+            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                               make_tuple(m0, I0, I0, I0),
+                               a_block_buf,
+                               a_thread_desc_,
+                               make_tuple(I0, I0, I0, I0),
+                               a_thread_buf);
+
+            // squeeze A to 2:4 sparsity in place; num_elems is passed as a template argument
+            SetIdxSqueezeA<a_thread_desc_.GetElementSpaceSize()>(a_thread_buf, idx_buf);
+
+            static_for<0, NRepeat, 1>{}([&](auto n0) {
+                // read B
+                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                   make_tuple(n0, I0, I0, I0),
+                                   b_block_buf,
+                                   b_thread_desc_,
+                                   make_tuple(I0, I0, I0, I0),
+                                   b_thread_buf);
+
+                static_for<0, KPerThread, KPack>{}([&](auto k) {
+                    // a_thread_vec is smaller because it's structurally sparse 2:4
+                    vector_type<ComputeTypeA, KPack / 2> a_thread_vec;
+                    vector_type<ComputeTypeB, KPack> b_thread_vec;
+                    vector_type<int32_t, KPack / elems_per_idx> idx_vec;
+
+                    static_for<0, KPack / 2, 1>{}([&](auto i) {
+                        a_thread_vec.template AsType<ComputeTypeA>()(i) = a_thread_buf[Number<
+                            a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k / 2 + i))>{}];
+                    });
+
+                    static_for<0, KPack, 1>{}([&](auto i) {
+                        b_thread_vec.template AsType<ComputeTypeB>()(i) = b_thread_buf[Number<
+                            b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
+                    });
+
+                    static_for<0, KPack / elems_per_idx, 1>{}([&](auto i) {
+                        idx_vec.template AsType<int32_t>()(i) = idx_buf[k / elems_per_idx + i];
+                    });
+
+                    // A is smaller because it's structurally sparse 2:4
+                    using mfma_input_type_a =
+                        typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops / 2>::type;
+                    using mfma_input_type_b =
+                        typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
+                    using mfma_input_type_idx = typename vector_type<int32_t, 1>::type;
+
+                    constexpr index_t c_offset =
+                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+
+                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type_a>(),
+                                    b_thread_vec.template AsType<mfma_input_type_b>(),
+                                    idx_vec.template AsType<mfma_input_type_idx>(),
+                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                });
+            });
+        });
+    }
+
+    protected:
+    // A[M0, M1, M2, KPerThread]
+    static constexpr auto a_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
+
+    // B[N0, N1, N2, KPerThread]
+    static constexpr auto b_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
+
+    // C[M, N, NumRegXdlops]
+    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
+
+    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ComputeTypeA,
+                                                         ComputeTypeA,
+                                                         decltype(a_block_desc_m0_m1_m2_k),
+                                                         decltype(a_thread_desc_),
+                                                         Sequence<1, 1, 1, KPerThread>,
+                                                         Sequence<0, 1, 2, 3>,
+                                                         3,
+                                                         A_K1,
+                                                         A_K1>;
+
+    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<ComputeTypeB,
+                                                         ComputeTypeB,
+                                                         decltype(b_block_desc_n0_n1_n2_k),
+                                                         decltype(b_thread_desc_),
+                                                         Sequence<1, 1, 1, KPerThread>,
+                                                         Sequence<0, 1, 2, 3>,
+                                                         3,
+                                                         B_K1,
+                                                         B_K1>;
+
+    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
+    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
+};
+
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp
index 33c07f34f7..a436afd395 100644
--- a/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp
+++ b/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp
@@ -35,10 +35,16 @@ struct smfmac
     static constexpr index_t k_per_blk   = 8;
     static constexpr bool is_k_reduction = true;
 
-    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
-    __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
+    template <index_t MPerXdlops,
+              index_t NPerXdlops,
+              index_t abid,
+              class FloatA,
+              class FloatB,
+              class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
     {
-        intrin_smfmac_f32_16x16x32f16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
+        intrin_smfmac_f32_16x16x32f16<MPerXdlops, NPerXdlops>::template Run<abid>(
+            a, b, idx, reg_c);
     }
 };
 
@@ -57,10 +63,16 @@ struct smfmac
     static constexpr index_t k_per_blk   = 16;
     static constexpr bool is_k_reduction = true;
 
-    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
-    __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
+    template <index_t MPerXdlops,
+              index_t NPerXdlops,
+              index_t abid,
+              class FloatA,
+              class FloatB,
+              class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
     {
-        intrin_smfmac_f32_32x32x16f16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
+        intrin_smfmac_f32_32x32x16f16<MPerXdlops, NPerXdlops>::template Run<abid>(
+            a, b, idx, reg_c);
     }
 };
 
@@ -79,10 +91,16 @@ struct smfmac
     static constexpr index_t k_per_blk   = 8;
     static constexpr bool is_k_reduction = true;
 
-    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
-    __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
+    template <index_t MPerXdlops,
+              index_t NPerXdlops,
+              index_t abid,
+              class FloatA,
+              class FloatB,
+              class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
     {
-        intrin_smfmac_f32_16x16x32bf16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
+        intrin_smfmac_f32_16x16x32bf16<MPerXdlops, NPerXdlops>::template Run<abid>(
+            a, b, idx, reg_c);
     }
 };
 
@@ -101,10 +119,16 @@ struct smfmac
     static constexpr index_t k_per_blk   = 16;
     static constexpr bool is_k_reduction = true;
 
-    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
-    __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
+    template <index_t MPerXdlops,
+              index_t NPerXdlops,
+              index_t abid,
+              class FloatA,
+              class FloatB,
+              class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
     {
-        intrin_smfmac_f32_32x32x16bf16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
+        intrin_smfmac_f32_32x32x16bf16<MPerXdlops, NPerXdlops>::template Run<abid>(
+            a, b, idx, reg_c);
     }
 };
 
@@ -305,8 +329,8 @@ struct SparseXdlopsGemm
                       "base_type must be half or bfloat16!");
 
         static_for<0, KPack / smfmac_instr.k_per_blk, 1>{}([&](auto k) {
-            smfmac_instr.template run<MPerXdlops, NPerXdlops>(
-                p_a_wave[k], p_b_wave[k], idx[k], p_c_thread);
+            smfmac_instr.template run<MPerXdlops, NPerXdlops, k % 4>(
+                p_a_wave[k], p_b_wave[k], idx[k / 4], p_c_thread);
         });
     }
 
diff --git a/include/ck/utility/amd_smfmac.hpp b/include/ck/utility/amd_smfmac.hpp
index abb8d9f5ef..8b6b094ff2 100644
--- a/include/ck/utility/amd_smfmac.hpp
+++ b/include/ck/utility/amd_smfmac.hpp
@@ -9,16 +9,18 @@ namespace ck {
 template <index_t MPerXdlops, index_t NPerXdlops>
 struct intrin_smfmac_f32_16x16x32f16;
 
+// For every smfmac instruction, if CBSZ[1:0] = 0, ABID[1:0] selects one of the four 8-bit sets
+// of sparse indices in reg_idx.
 template <>
 struct intrin_smfmac_f32_16x16x32f16<16, 16>
 {
-    template <class FloatC>
+    template <index_t abid, class FloatC>
     __device__ static void
-    Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
+    Run(const half4_t& reg_a, const half8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
     {
 #if defined(__gfx94__)
         reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16(
-            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, abid);
 #else
         ignore = reg_a;
         ignore = reg_b;
@@ -34,13 +36,13 @@ struct intrin_smfmac_f32_16x16x32bf16;
 template <>
 struct intrin_smfmac_f32_16x16x32bf16<16, 16>
 {
-    template <class FloatC>
+    template <index_t abid, class FloatC>
     __device__ static void
-    Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
+    Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
     {
 #if defined(__gfx94__)
         reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16(
-            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, abid);
 #else
         ignore = reg_a;
         ignore = reg_b;
@@ -56,13 +58,13 @@ struct intrin_smfmac_f32_32x32x16f16;
 template <>
 struct intrin_smfmac_f32_32x32x16f16<32, 32>
 {
-    template <class FloatC>
+    template <index_t abid, class FloatC>
     __device__ static void
-    Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
+    Run(const half4_t& reg_a, const half8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
     {
 #if defined(__gfx94__)
         reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16(
-            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
+            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, abid);
 #else
         ignore = reg_a;
         ignore = reg_b;
@@ -78,13 +80,13 @@ struct intrin_smfmac_f32_32x32x16bf16;
 template <>
 struct intrin_smfmac_f32_32x32x16bf16<32, 32>
 {
-    template <class FloatC>
+    template <index_t abid, class FloatC>
     __device__ static void
-    Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
+    Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
     {
 #if defined(__gfx94__)
         reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16(
-            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
+            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, abid);
 #else
         ignore = reg_a;
         ignore = reg_b;
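
The 2:4 squeeze-and-index bookkeeping that SetIdxSqueezeA performs, and the ABID byte selection that the intrinsics rely on, are easier to see outside the kernel. Below is a minimal host-side C++ sketch of the same scheme, assuming plain float data and a single 4-element group per call; Squeezed24, squeeze_group, and the abid constant are illustrative names invented for this sketch, not part of the patch, and the real kernel operates on half/bfloat16 vector registers rather than std::array.

#include <array>
#include <cstdint>
#include <cstdio>

// Mirrors the per-group logic of SetIdxSqueezeA: each 4-element group keeps at
// most 2 non-zeros, and each kept element's source position is encoded in 2 bits.
// Positions default to {2, 3} when fewer than two non-zeros are found.
struct Squeezed24
{
    std::array<float, 2> vals;
    uint8_t idx; // bits [1:0]: position of vals[0]; bits [3:2]: position of vals[1]
};

static Squeezed24 squeeze_group(const std::array<float, 4>& g)
{
    Squeezed24 out{{g[2], g[3]}, 0b1110}; // default: keep the last two elements
    int nonzero_pos = 0;
    for(int k = 0; k < 3 && nonzero_pos < 2; ++k)
    {
        if(g[k] != 0.0f)
        {
            out.vals[nonzero_pos] = g[k];
            out.idx &= static_cast<uint8_t>(~(0b11u << (2 * nonzero_pos))); // clear 2-bit slot
            out.idx |= static_cast<uint8_t>(k << (2 * nonzero_pos));        // record position
            ++nonzero_pos;
        }
    }
    return out;
}

int main()
{
    // One 2:4-sparse group with non-zeros at positions 0 and 2.
    const std::array<float, 4> group{1.5f, 0.0f, -2.0f, 0.0f};
    const Squeezed24 s = squeeze_group(group);
    // Expected: vals = {1.5, -2}, idx = 0b1000 (position 0, then position 2).
    std::printf("vals = {%g, %g}, idx = 0x%x\n", s.vals[0], s.vals[1],
                static_cast<unsigned>(s.idx));

    // Four 8-bit index groups share one 32-bit register; the smfmac ABID operand
    // selects which byte the instruction reads, matching idx[k / 4] with abid = k % 4.
    uint32_t idx_reg = 0;
    const int abid   = 2; // byte 2 of the register
    idx_reg = (idx_reg & ~(0xFFu << (8 * abid))) | (uint32_t{s.idx} << (8 * abid));
    std::printf("idx_reg = 0x%08x\n", idx_reg);
    return 0;
}

Running the sketch prints vals = {1.5, -2}, idx = 0x8 and idx_reg = 0x00080000: the same byte layout the blockwise gemm builds in idx_buf and hands to the smfmac intrinsics, one 8-bit group per ABID-selected byte.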