From c0cb4d036ddc701270af0f4517303ec7fb1a867a Mon Sep 17 00:00:00 2001 From: coderfeli Date: Wed, 6 Aug 2025 02:45:31 +0000 Subject: [PATCH] fix split k --- include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index f3e7d8e336..2c01731a9f 100755 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -336,6 +336,7 @@ struct FlatmmKernel template __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z) { + constexpr auto N1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<1>{}); constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); const index_t K_t = kargs.k_batch * K1; const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; @@ -351,11 +352,11 @@ struct FlatmmKernel if constexpr(std::is_same_v) { - b_k_split_offset = k_id * KRead * kargs.stride_B; + b_k_split_offset = k_id * KRead * kargs.stride_B * N1; } else if constexpr(std::is_same_v) { - b_k_split_offset = k_id * KRead; + b_k_split_offset = k_id * KRead * N1; } if(k_id < static_cast(kargs.k_batch - 1)) @@ -557,8 +558,8 @@ struct FlatmmKernel } }(); - index_t kFlatK = FlatmmPipeline::flatKPerWarp * (splitk_batch_offset.splitted_k / - BlockGemmShape::WarpTile::at(number<2>{})); + index_t kFlatK = FlatmmPipeline::flatKPerWarp * (kargs.K / + BlockGemmShape::WarpTile::at(I2)); index_t kFlatN = kargs.N * kargs.K / kFlatK; const auto& b_flat_tensor_view = [&]() { return make_naive_tensor_view( @@ -598,7 +599,7 @@ struct FlatmmKernel const auto& e_tensor_view = [&]() { if constexpr(std::is_same_v) { - return make_naive_tensor_view( + return make_naive_tensor_view( e_ptr, make_tuple(kargs.M, kargs.N), make_tuple(kargs.stride_E, 1), @@ -607,7 +608,7 @@ struct FlatmmKernel } else { - return make_naive_tensor_view( + return make_naive_tensor_view( e_ptr, make_tuple(kargs.N, kargs.M), make_tuple(kargs.stride_E, 1),