diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index fc22f63e14..2f8d3c6053 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -214,22 +214,27 @@ CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a) uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel; + // ---- Lower 4 int4 values (even positions) ---- + // Extract dictionary indices: low 3 bits of each byte (values 0..7). uint32_t dict_sel = a & 0x07070707; - uint32_t sign = a >> 1; - asm volatile("v_and_or_b32 %0, %1, %2, %3" - : "=v"(final_sel) - : "v"(sign), "v"(0x04040404), "v"(0x03020100)); - - tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel); - tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel); + // sign bit is bit[2] of each nibble after bias; shift to isolate per-byte sign. + uint32_t sign = a >> 1; + // Build final selector: + // - bit 2 of each byte (0x04) selects negative vs positive table + // - 0x03020100 selects byte lanes [0,1,2,3] in order + final_sel = (sign & 0x04040404) | 0x03020100; + // Lookup positive and negative fp8 codes from the small register tables. + tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel); + tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel); + // Select per-lane between tmp_pos and tmp_neg using the sign-derived selector. tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel); + // ---- Upper 4 int4 values (odd positions) ---- + // Shift to bring the high-nibble int4s into place and repeat the process. a >>= 4; - dict_sel = a & 0x07070707; - sign = a >> 1; - asm volatile("v_and_or_b32 %0, %1, %2, %3" - : "=v"(final_sel) - : "v"(sign), "v"(0x04040404), "v"(0x03020100)); + dict_sel = a & 0x07070707; + sign = a >> 1; + final_sel = (sign & 0x04040404) | 0x03020100; tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel); tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel); @@ -306,22 +311,29 @@ CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a) uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel; + // ---- Lower 4 int4 values (even positions) ---- + // Extract dictionary indices: low 3 bits of each byte (values 0..7). uint32_t dict_sel = a & 0x07070707; - uint32_t sign = a >> 1; - asm volatile("v_and_or_b32 %0, %1, %2, %3" - : "=v"(final_sel) - : "v"(sign), "v"(0x04040404), "v"(0x03020100)); - tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel); - tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel); + // sign bit is bit[2] of each nibble after bias; shift to isolate per-byte sign. + uint32_t sign = a >> 1; + // Build final selector: + // - bit 2 of each byte (0x04) selects negative vs positive table + // - 0x03020100 selects byte lanes [0,1,2,3] in order + final_sel = (sign & 0x04040404) | 0x03020100; + + // Lookup positive and negative fp8 codes from the small register tables. + tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel); + tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel); + // Select per-lane between tmp_pos and tmp_neg using the sign-derived selector. tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel); + // ---- Upper 4 int4 values (odd positions) ---- + // Shift to bring the high-nibble int4s into place and repeat the process. a >>= 4; - dict_sel = a & 0x07070707; - sign = a >> 1; - asm volatile("v_and_or_b32 %0, %1, %2, %3" - : "=v"(final_sel) - : "v"(sign), "v"(0x04040404), "v"(0x03020100)); + dict_sel = a & 0x07070707; + sign = a >> 1; + final_sel = (sign & 0x04040404) | 0x03020100; tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel); tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index ff3f060770..f83462391c 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -30,7 +30,7 @@ struct BaseGemmPipelineAgBgCrCompV3 { if(BlockHasHotloop(num_loop)) { - return TailNumber::Full; + return TailNumber::Odd; } else { @@ -52,23 +52,27 @@ struct BaseGemmPipelineAgBgCrCompV3 // Handle all the valid cases. if(has_hot_loop) { - if(tail_number == TailNumber::Full) + if(tail_number == ck_tile::TailNumber::Odd) { - return run_func(bool_constant{}, - integral_constant{}); + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } } else { - if(tail_number == TailNumber::Odd) + + if(tail_number == ck_tile::TailNumber::Odd) { - return run_func(bool_constant{}, - integral_constant{}); + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } - else if(tail_number == TailNumber::Even) + else if(tail_number == ck_tile::TailNumber::Even) { - return run_func(bool_constant{}, - integral_constant{}); + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } } #if defined(__HIP_DEVICE_COMPILE__) @@ -76,16 +80,8 @@ struct BaseGemmPipelineAgBgCrCompV3 __builtin_unreachable(); #else // If execution reaches here, it's an invalid combination of arguments. - if(has_hot_loop) - { - throw std::logic_error("Invalid TailNumber: If has_hot_loop is true, tail_number must " - "be TailNumber::Full."); - } - else - { - throw std::logic_error("Invalid TailNumber: If has_hot_loop is false, tail_number must " - "be TailNumber::Odd or TailNumber::Even."); - } + throw std::logic_error("Invalid TailNumber value: must be " + "TailNumber::Odd or TailNumber::Even"); #endif } }; @@ -588,7 +584,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 } while(i < (num_loop - 1)); } // tail - if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) + if constexpr(TailNum == TailNumber::Odd) { // Leak last MFMA block to epilogue region, cover the potential lds-shuffle // latency diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 0e968010f3..012b53bbd4 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -786,8 +786,8 @@ struct QuantGemmKernel using QuantGroupSize = remove_cvref_t; return make_naive_tensor_view( bq_ptr, - make_tuple(kargs.QK_B, integer_divide_ceil(kargs.N, QuantGroupSize::kN)), - make_tuple(1, kargs.stride_BQ), + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B), + make_tuple(kargs.stride_BQ, 1), number{}, number<1>{}); } @@ -1030,9 +1030,9 @@ struct QuantGemmKernel using QuantGroupSize = remove_cvref_t; return make_tile_window( bq_pad_view, - make_tuple(number{}, - number{}), - {0, i_n / QuantGroupSize::kN}); + make_tuple(number{}, + number{}), + {i_n / QuantGroupSize::kN, 0}); } } else diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index 85f0472ef6..8d76ab934b 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -15,68 +15,9 @@ namespace ck_tile { -template -struct BaseAQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem -{ - CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) - { - if(num_loop % BaseGemmPipelineAgBgCrCompV3::PrefetchStages == 0) - { - return TailNumber::Even; - } - else - { - return TailNumber::Odd; - } - } - template - CK_TILE_HOST_DEVICE static auto - TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) - { - if(has_hot_loop) - { - if(tail_number == ck_tile::TailNumber::Odd) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_number == ck_tile::TailNumber::Even) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - throw std::runtime_error("Unsupported tail number for this operation !!!"); - } - } - else - { - - if(tail_number == ck_tile::TailNumber::Odd) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_number == ck_tile::TailNumber::Even) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - throw std::runtime_error("Unsupported tail number for this operation !!!"); - } - } - } -}; - +// ToDo: Change the Pipeline to actual memory pipeline. template -struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem +struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { using Base = BaseGemmPipelineAgBgCrMem; using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 20ce7da0ff..fcbac3ff66 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -14,74 +14,8 @@ namespace ck_tile { -// Compute optimized pipeline -// GlobalPrefetchStages: 2 -// LocalPreFillStages: 1 -// LocalPreFetchStages: 1 -// LocalSharedMemoryBuffer: 1 - -template -struct BaseAQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 -{ - template - CK_TILE_HOST_DEVICE static auto - TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) - { - if(has_hot_loop) - { - if(tail_number == ck_tile::TailNumber::Full) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_number == ck_tile::TailNumber::Odd) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_number == ck_tile::TailNumber::Even) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - throw std::runtime_error("Unsupported tail number for this operation !!!"); - } - } - else - { - if(tail_number == ck_tile::TailNumber::Full) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_number == ck_tile::TailNumber::Odd) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_number == ck_tile::TailNumber::Even) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - throw std::runtime_error("Unsupported tail number for this operation !!!"); - } - } - } -}; - template -struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV3 +struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { using Base = BaseGemmPipelineAgBgCrCompV3; using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp index dfc03d62da..870326cb9d 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp @@ -71,8 +71,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC tile_distribution_encoding_pattern_bq; return TileEncodingPattern::make_2d_static_tile_distribution(); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index edaa1896fb..8f4d4e0460 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -20,68 +20,8 @@ namespace ck_tile { // LocalPreFetchStages: 1 // LocalSharedMemoryBuffer: 1 -template -struct BaseBQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 -{ - template - CK_TILE_HOST_DEVICE static auto - TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) - { - if(has_hot_loop) - { - if(tail_number == ck_tile::TailNumber::Full) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_number == ck_tile::TailNumber::Odd) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_number == ck_tile::TailNumber::Even) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - throw std::runtime_error("Unsupported tail number for this operation !!!"); - } - } - else - { - if(tail_number == ck_tile::TailNumber::Full) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_number == ck_tile::TailNumber::Odd) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_number == ck_tile::TailNumber::Even) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - throw std::runtime_error("Unsupported tail number for this operation !!!"); - } - } - } -}; - template -struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV3 +struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { using Base = BaseGemmPipelineAgBgCrCompV3; using PipelineImplBase = GemmBQuantPipelineAgBgCrImplBase; @@ -318,8 +258,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV (PreshuffleQuant) ? make_array(ck_tile::integer_least_multiple(n, NPerBlock) / BlockGemmShape::WarpTile::at(number<1>{}), 0) - : is_bq_col_major ? make_array(KPerBlockBQ, 0) - : make_array(0, KPerBlockBQ); + : is_bq_col_major ? make_array(0, KPerBlockBQ) + : make_array(KPerBlockBQ, 0); // DRAM prefetch (global read 0) Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp index 9109f68ec5..dae099af4f 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp @@ -171,7 +171,7 @@ template struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern { @@ -231,39 +231,39 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding } else { - if constexpr(XPerQ < WarpGemm::kN) + if constexpr(YPerQ < WarpGemm::kN) { // Case 1: Fine-grained - multiple quantization scales within a single warp - constexpr index_t Y = YPerTile; // Full Y dimension of tile - constexpr index_t YR = 1; // No Y replication needed - constexpr index_t X0 = NIterPerWarp; // Iterations per warp in N-dim - constexpr index_t X1 = NWarps; // Number of warps in N-dim - constexpr index_t X2 = WarpGemm::kN / XPerQ; // Number of scales per warp - constexpr index_t XR = XPerQ; // Elements per quantization group + constexpr index_t X = XPerTile; // Full X dimension of tile + constexpr index_t XR = 1; // No Y replication needed + constexpr index_t Y0 = NIterPerWarp; // Iterations per warp in N-dim + constexpr index_t Y1 = NWarps; // Number of warps in N-dim + constexpr index_t Y2 = WarpGemm::kN / YPerQ; // Number of scales per warp + constexpr index_t YR = YPerQ; // Elements per quantization group - static_assert(X0 * X1 * X2 == XPerTile, - "X0, X1, X2 must cover the blocktile along X."); + static_assert(Y0 * Y1 * Y2 == YPerTile, + "Y0, Y1, Y2 must cover the blocktile along Y."); return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<0, 2, 0>>, + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 1, 0>>, tuple, sequence<1, 2, 2>>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}); } - else if constexpr(XPerQ <= WarpGemm::kN * NWarps) + else if constexpr(YPerQ <= WarpGemm::kN * NWarps) { // Case 2: Medium-grained - one quantization scale per warp - constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor - constexpr auto X1 = NWarps / XR; // Warps per unique scale - constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension + constexpr auto YR = YPerQ / WarpGemm::kN; // Scale replication factor + constexpr auto Y1 = NWarps / YR; // Warps per unique scale + constexpr auto Y0 = YPerTile / Y1; // Iterations to cover X dimension return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<0>>, + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0>>, tuple, sequence<2>>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}); } else // XPerQ > WarpGemm::kN * NWarps diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index efc1690b7c..d7129268c5 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -280,7 +280,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV } else { - move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0}); + move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); } // Prefill A0 auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); @@ -338,7 +338,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV } else { - move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0}); + move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); } // Prefill A(2i+1) @@ -390,7 +390,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV } else { - move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0}); + move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); } // Prefill A(2i+2) diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp index 34bdf4ea38..a75d871421 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp @@ -88,11 +88,7 @@ using BQuantTypes = ::testing::Types< std::tuple, std::tuple, std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp index 5d23e73146..4397668a5d 100644 --- a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp +++ b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp @@ -124,7 +124,12 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test using GemmPipelineProblem = ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + using BaseGemmPipeline = std::conditional_t< + Config::Pipeline_ == (PipelineType::Memory), + ck_tile::BaseGemmPipelineAgBgCrMem, + std::conditional_t, + ck_tile::BaseGemmPipelineAgBgCrCompV4>>; const ck_tile::index_t k_grain = gemm_descs[0].k_batch * Config::K_Tile_; const ck_tile::index_t K_split =