From addcd203eb48ce5a0b74f52022f3c2df30cb8747 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Fri, 11 Apr 2025 12:18:26 +0200 Subject: [PATCH] [CK_TILE] Add 2:4 structured sparsity support for fp16 gemm (#1957) * add structured sparsity fp16 support for gemm * added reviewer suggestions * update changelog * update changelog * add reviewers suggestions * Minor fix * clang fix * fix doxygen [ROCm/composable_kernel commit: 6c61f4d237a9841c5b5d8b4380eaf9c2af14947e] --- CHANGELOG.md | 1 + example/ck_tile/03_gemm/gemm_utils.hpp | 3 +- example/ck_tile/03_gemm/run_gemm_example.inc | 24 ++-- example/ck_tile/03_gemm/universal_gemm.cpp | 3 +- include/ck_tile/host/fill.hpp | 43 +++++++ .../gemm/pipeline/gemm_pipeline_problem.hpp | 3 +- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 4 +- .../ops/gemm/pipeline/tile_gemm_traits.hpp | 8 +- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 13 +- .../gemm/warp/warp_gemm_attribute_smfmac.hpp | 80 ++++++++++++ .../warp/warp_gemm_attribute_smfmac_impl.hpp | 114 ++++++++++++++++++ .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 15 ++- .../ops/gemm/warp/warp_gemm_smfmac_impl.hpp | 110 +++++++++++++++++ 13 files changed, 401 insertions(+), 20 deletions(-) create mode 100644 include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp create mode 100644 include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp create mode 100644 include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 49ef2998eb..e3d7971c71 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW). * Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW). * Added support for Stream-K version of mixed fp8/bf16 GEMM +* Added support for FP16 2:4 structured sparsity to universal GEMM. 
### Optimized diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 3254a407fd..973006196b 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -93,7 +93,8 @@ struct GemmConfig static constexpr bool PermuteA = false; static constexpr bool PermuteB = false; - static constexpr bool TransposeC = false; + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; static constexpr int kBlockPerCu = 1; static constexpr ck_tile::index_t TileParitionerGroupNum = 8; diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index c3b4ec609c..b4ea5d22c0 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -55,7 +55,8 @@ void permute_tensor_b(Tensor& tensor) ALayout, BLayout, CLayout, - GemmConfig::TransposeC>; + GemmConfig::TransposeC, + GemmConfig::UseStructuredSparsity>; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem(flop) / 1.E9 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time; - std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K - << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C - << " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name - << " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits::name - << " B Type = " << DataTypeTraits::name - << " C Type = " << DataTypeTraits::name << " : " << ave_time << " ms, " - << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K + << " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C + << " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name + << " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits::name + << " B_Type=" << DataTypeTraits::name + << " C_Type=" << 
DataTypeTraits::name + << " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off") + << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << std::endl; return ave_time; } @@ -259,6 +262,11 @@ int run_gemm_example_with_layouts(int argc, b_k_n.SetZero(); } + if(GemmConfig::UseStructuredSparsity) + { + ck_tile::AdjustToStructuredSparsity{}(a_m_k); + } + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index eef8d3b60e..2ba16ca89d 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -46,7 +46,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ALayout, BLayout, CLayout, - GemmConfig::TransposeC>; + GemmConfig::TransposeC, + GemmConfig::UseStructuredSparsity>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index 006026470b..d90c0cf6cf 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -364,6 +364,49 @@ struct FillConstant } }; +//---------------------------------------------------------------------------------------------- +/// @brief Transforms given input to fit 2:4 structured sparsity pattern so +/// every subgroup of 4 elements contain at most 2 non-zero elements +template +struct AdjustToStructuredSparsity +{ + size_t start{0}; + // masks represent all valid 2:4 structured sparsity permutations + // clang-format off + static constexpr int32_t masks[] = {0, 0, 1, 1, + 0, 1, 0, 1, + 0, 1, 1, 0, + 1, 0, 0, 1, + 1, 0, 1, 0, + 1, 1, 0, 0, + 0, 0, 0, 1, + 0, 0, 1, 0, + 0, 1, 0, 0, + 1, 0, 0, 0}; + // clang-format on + + template + 
void operator()(ForwardIter first, ForwardIter last) const + { + std::transform(first, last, first, [=, index = start](T val) mutable { + auto tmp = val * masks[index % (sizeof(masks) / sizeof(int32_t))]; + index += 1; + + return type_convert(tmp); + }); + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + template struct FillTrigValue { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index f833ccc849..cba3677332 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -194,7 +194,8 @@ struct UniversalGemmPipelineProblem static constexpr auto HasHotLoop = HasHotLoop_; static constexpr auto TailNum = TailNum_; - static constexpr bool TransposeC = Traits::TransposeC; + static constexpr bool TransposeC = Traits::TransposeC; + static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index c504a51ad0..b555cf75e0 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -580,7 +580,9 @@ struct UniversalGemmPipelineAgBgCrPolicy WarpTile::at(I0), WarpTile::at(I1), WarpTile::at(I2), - Problem::TransposeC>; + Problem::TransposeC, + false, + Problem::UseStructuredSparsity>; using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy + bool TransposeC_ = false, + bool UseStructuredSparsity_ = false> struct TileGemmUniversalTraits { static constexpr bool kPadM = kPadM_; @@ -49,7 +50,8 @@ 
struct TileGemmUniversalTraits using BLayout = BLayout_; using CLayout = CLayout_; - static constexpr bool TransposeC = TransposeC_; + static constexpr bool TransposeC = TransposeC_; + static constexpr bool UseStructuredSparsity = UseStructuredSparsity_; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 1fd12973f6..33f3dde256 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,9 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp" + namespace ck_tile { // fp16 @@ -64,6 +67,14 @@ using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl, 4>>; +// fp16 2:4 structured sparsity + +using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmSmfmacImpl>>; + +using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmSmfmacImpl>>; + // bf16 using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl< diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp new file mode 100644 index 0000000000..adf548aaca --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp @@ -0,0 +1,80 @@ +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp" + +namespace ck_tile { + +/** + * @brief Class describing structured sparsity mfma instructions. 
+ * + * @paragraph Overview "Overview" + * Currently only 2:4 structured sparsity is supported, which requires that in every group of + * four contiguous elements at most two are non-zero, so the smfmac instruction processes only + * half of the elements. Because of structured sparsity, the A vector passed to the smfmac + * instruction is smaller than the B vector by a factor of CompressionRatio. The indices of the + * non-zero elements are stored in `idx`, which is an additional parameter to the instruction. + * Each pair of 2-bit indices records which two elements of the current group of 4 values are + * non-zero and should be used by the smfmac instruction. Structured sparsity is currently + * supported only for the A matrix. + */ +template +struct WarpGemmAttributeSmfmac +{ + using Impl = remove_cvref_t; + + using ADataType = typename Impl::ADataType; + using BDataType = typename Impl::BDataType; + using IdxDataType = typename Impl::IdxDataType; + using CDataType = typename Impl::CDataType; + + using AVecType = typename Impl::AVecType; + using BVecType = typename Impl::BVecType; + using CVecType = typename Impl::CVecType; + + static constexpr index_t kM = Impl::kM; + static constexpr index_t kN = Impl::kN; + static constexpr index_t kK = Impl::kK; + static constexpr index_t kKPerThread = Impl::kABKPerLane; + static constexpr index_t kCompressionRatio = Impl::CompressionRatio; + + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } + + static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, + "Multi-block WarpGemmAttributeSmfmacImpl is not supported"); + + using AWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using BWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using CWarpDstrEncoding = tile_distribution_encoding< + 
sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 1>, + sequence<0, 2>>; + + // c_vec += a_vec * b_vec[idx] + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + const int32_t& idx, + bool_constant = {}) const + { + Impl{}(c_vec, a_vec, b_vec, idx, bool_constant{}); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp new file mode 100644 index 0000000000..97fd2a8742 --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "warp_gemm_attribute_mfma_impl.hpp" + +namespace ck_tile { + +// fp16 2:4 structured sparsity + +template +struct WarpGemmAttributeSmfmacImplF16F16F32M32N32K16 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = fp16_t; + using BDataType = fp16_t; + using IdxDataType = int32_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 32; + static constexpr index_t kN = 32; + static constexpr index_t kK = 16; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 32; + static constexpr index_t kBNLane = 32; + static constexpr index_t kABKLane = 2; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 32; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 4; + + static constexpr index_t CompressionRatio = 2; + + // c_vec += a_vec * b_vec[idx] + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + const int32_t& idx, + 
bool_constant = {}) const + { +#if defined(__gfx9__) + c_vec = __builtin_amdgcn_smfmac_f32_32x32x16_f16(a_vec, b_vec, c_vec, idx, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = idx; +#endif + } +}; + +template +struct WarpGemmAttributeSmfmacImplF16F16F32M16N16K32 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = fp16_t; + using BDataType = fp16_t; + using IdxDataType = int32_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 32; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + static constexpr index_t CompressionRatio = 2; + + // c_vec += a_vec * b_vec[idx] + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + const int32_t& idx, + bool_constant = {}) const + { +#if defined(__gfx9__) + c_vec = __builtin_amdgcn_smfmac_f32_16x16x32_f16(a_vec, b_vec, c_vec, idx, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = idx; +#endif + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 9c319b5e5f..6320b33598 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 
2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -16,7 +16,8 @@ template + bool SwizzleA = false, + bool UseStructuredSparsity = false> struct WarpGemmMfmaDispatcher; // clang-format off @@ -35,6 +36,10 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; +// fp16 2:4 structured sparsity +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmSmfmacF16F16F32M32N32K16; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmSmfmacF16F16F32M16N16K32; }; + // bf16 template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; }; @@ -70,7 +75,8 @@ template + bool SwizzleA = false, + bool UseStructuredSparsity = false> using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher::Type; + SwizzleA, + UseStructuredSparsity>::Type; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp new file mode 100644 index 0000000000..9e028ddab0 --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +namespace ck_tile { + +template +struct WarpGemmSmfmacImpl +{ + using WarpGemmAttribute = remove_cvref_t; + + static constexpr index_t kM = WarpGemmAttribute::kM; + static constexpr index_t kN = WarpGemmAttribute::kN; + static constexpr index_t kK = WarpGemmAttribute::kK; + /// @brief The number of elements in K dimension processed by a single thread in a wavefront. + /// + /// @note WarpGemm may run the MFMA instruction multiple times (on different K). + /// In such a situation this value reflects that fact. + static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread; + + using ADataType = typename WarpGemmAttribute::ADataType; + using BDataType = typename WarpGemmAttribute::BDataType; + using CDataType = typename WarpGemmAttribute::CDataType; + + using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding; + using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding; + using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding; + + using AWarpDstr = remove_cvref_t; + using BWarpDstr = remove_cvref_t; + using CWarpDstr = remove_cvref_t; + + using AWarpTensor = static_distributed_tensor; + using BWarpTensor = static_distributed_tensor; + using CWarpTensor = static_distributed_tensor; + + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() + { + return WarpGemmAttribute_::get_num_of_access(); + } + + //---------------------------------------------------------------------------------------------- + /// @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero + /// elements into the lower part of a_vec, halving its effective size. + /// + /// @param a_vec Vector to be compressed. 
+ /// + /// @return Four 2-bit indexes of non-zero elements locations + /// + template + CK_TILE_DEVICE int32_t compress_a(AVec& a_vec) const + { + int32_t idx = 0b11101110; + + static_for<0, 2, 1>{}([&](auto i) { + ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]}; + int32_t non_zero_pos = 0; + + static_for<0, 3, 1>{}([&](auto j) { + if(a_vec[i * 4 + j] != 0.0f) + { + nonzero_elems[non_zero_pos] = a_vec[i * 4 + j]; + idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos)); + idx |= j << 2 * (i * 2 + non_zero_pos); + ++non_zero_pos; + } + }); + a_vec[i * 2] = nonzero_elems[0]; + a_vec[i * 2 + 1] = nonzero_elems[1]; + }); + + return idx; + } + + template + CK_TILE_DEVICE void + operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant = {}) const + { + static_assert(detail::is_similiar_distributed_tensor_v && + detail::is_similiar_distributed_tensor_v && + detail::is_similiar_distributed_tensor_v); + constexpr auto CompressionRatio = WarpGemmAttribute::kCompressionRatio; + + using AVec = ext_vector_t; + using AVecCompressed = + ext_vector_t; + using BVec = ext_vector_t; + using CVec = ext_vector_t; + + constexpr auto I0 = number<0>{}; + + auto a_vec = a.get_thread_buffer().template get_as()[I0]; + const auto b_vec = b.get_thread_buffer().template get_as()[I0]; + auto c_vec = c.get_thread_buffer().template get_as()[I0]; + + const int32_t idx = compress_a(a_vec); + + // @TODO can we simply set a_vec_pruned to a_vec[0:3]? + const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]}; + + // c_vec += a_vec * b_vec[idx] + WarpGemmAttribute{}(c_vec, a_vec_pruned, b_vec, idx, bool_constant{}); + + c.get_thread_buffer().template set_as(I0, c_vec); + } +}; + +} // namespace ck_tile