From c1c018bae76756f3077292e536fd0487df9fc03a Mon Sep 17 00:00:00 2001 From: Anthony Chang Date: Wed, 25 May 2022 10:55:22 +0800 Subject: [PATCH] Tensile-style block to C tile map (#239) * fix build * Revert "fix build" This reverts commit d73102384bfbb609e487d6d0cd04a3c8c9c4ec9e. * post PR #235 merge fix * amend * adds tensile-stype c-tile map * make it dynamic version * add k-split flavor tile map * apply tensile-style tile map to all xdl gridwise gemms * remove dead code Co-authored-by: Chao Liu [ROCm/composable_kernel commit: e579c9e5c6654c78a3829db8f0875462617d0452] --- .../gpu/device/device_grouped_gemm_xdl.hpp | 6 +- .../gpu/grid/block_to_ctile_map.hpp | 231 ++++++++++++++++++ .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 2 +- .../grid/gridwise_gemm_xdl_cshuffle_v1.hpp | 2 +- .../gpu/grid/gridwise_gemm_xdlops_v2r3.hpp | 8 +- .../gpu/grid/gridwise_gemm_xdlops_v2r4.hpp | 6 +- .../gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp | 6 +- .../gpu/grid/gridwise_gemm_xdlops_v3r1.hpp | 8 +- .../gpu/grid/gridwise_gemm_xdlops_v3r2.hpp | 8 +- .../gpu/grid/gridwise_gemm_xdlops_v3r3.hpp | 8 +- .../test_block_to_ctile_map.cpp | 230 ++++++++++++++++- 11 files changed, 481 insertions(+), 34 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp index bcf2ea703a..08a70823be 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp @@ -346,7 +346,6 @@ struct DeviceGroupedGemmXdl return block_2_ctile_map_.CheckValidity(c_grid_desc_m_n); } - private: typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; ck::index_t BlockStart_; }; @@ -418,9 +417,8 @@ struct DeviceGroupedGemmXdl DeviceGroupedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC); const index_t grid_size_grp = - typename GroupedGemmBlock2CTileMap::UnderlyingBlock2CTileMap( - c_grid_desc_m_n_, M01, N01) - .CalculateGridSize(c_grid_desc_m_n_); + GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, 0) + .block_2_ctile_map_.CalculateGridSize(c_grid_desc_m_n_); const index_t BlockStart = grid_size_; const index_t BlockEnd = grid_size_ + grid_size_grp; diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index 0fe08c9027..792060ca86 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -8,6 +8,237 @@ namespace ck { +// Rows of column-vectors +template +struct BlockToCTileMap_M00_N0_M01 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ BlockToCTileMap_M00_N0_M01() = default; + + __host__ __device__ BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 1) + : M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01)) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01_); + + const index_t grid_size = M00 * M01_ * N0; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return underlying_map_.CalculateBottomIndex(idx_top); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + if constexpr(DeviceCTileIndexCheck) + return DefaultValidCTileIndex(c_tile_idx, c_tile_dim); + else + return true; + } + + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + if constexpr(DeviceCTileIndexCheck) + return true; // validity check moved to kernel + + const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + if(M0 % M01_ == 0) + { + return true; + } + else + { + return false; + } + } + + private: + __host__ __device__ static constexpr auto + GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01) + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01); + + const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_insert_transform(1), + make_unmerge_transform(make_tuple(M00, M01)), + make_pass_through_transform(make_tuple(N0))), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{})); + + const auto cblockid_to_m00_n0_m01_block_cluster_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(1, M00, N0, M01))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto cblockid_to_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(m00_n0_m01_to_m0_n0_block_cluster_adaptor, + cblockid_to_m00_n0_m01_block_cluster_adaptor); + + return cblockid_to_m0_n0_block_cluster_adaptor; + } + + index_t M01_; + using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1)); + UnderlyingMap underlying_map_; +}; + +// Rows of column-vectors +// This C-tile map dynamically adjusts M01 when C-tile index is out of range +template +struct BlockToCTileMap_M00_N0_M01Adapt +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default; + + __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 8) + : M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const index_t grid_size = M0 * N0; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock); + + block_1d_id = block_1d_id % (M0 * N0); // swallow batch index + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } + + private: + index_t M01_; + CGridDesc_M_N c_grid_desc_m_n_; +}; + +// 2D slices of column-vectors in 3D space +// This C-tile map dynamically adjusts M01 when C-tile index is out of range +template +struct BlockToCTileMap_KSplit_M00_N0_M01Adapt +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt() = default; + + __host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 8, + index_t KSplit = 1) + : M01_(M01), KSplit_(KSplit), c_grid_desc_m_n_(c_grid_desc_m_n) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const index_t grid_size = M0 * N0 * KSplit_; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock); + + const index_t idx_ksplit = block_1d_id / (M0 * N0); + block_1d_id = block_1d_id % (M0 * N0); + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_ksplit, + idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } + + private: + index_t M01_; + index_t KSplit_; + CGridDesc_M_N c_grid_desc_m_n_; +}; + // Blocks of row-vectors template ( + return BlockToCTileMap_M00_N0_M01Adapt( c_grid_desc_m_n); } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index 78d30bfd55..55390dbc86 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -259,7 +259,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) { - return BlockToCTileMap_M00_N00_M01_N01( + return BlockToCTileMap_M00_N0_M01Adapt( c_grid_desc_m_n); } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index d60f8c4d07..974455fa3b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -288,11 +288,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 } // return block_id to C matrix tile idx (m0, n0) mapping - __host__ __device__ static constexpr auto - MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) { - return BlockToCTileMap_M00_N00_M01_N01( - c_grid_desc_m_n, M01, N01); + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); } using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp index 96ae9bbb45..a54906cfbc 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp @@ -265,10 +265,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 // return block_id to C matrix tile idx (m0, n0) mapping __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( - const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch) + const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch) { - return BlockToCTileMap_KSplit_M00_N00_M01_N01( - c_m_n_grid_desc, M01, N01, KBatch); + return BlockToCTileMap_KSplit_M00_N0_M01Adapt( + c_m_n_grid_desc, 8, KBatch); } using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index 6d138542f0..dbff1577e1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -239,10 +239,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 // return block_id to C matrix tile idx (m0, n0) mapping __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( - const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch) + const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch) { - return BlockToCTileMap_KSplit_M00_N00_M01_N01( - c_m_n_grid_desc, M01, N01, KBatch); + return BlockToCTileMap_KSplit_M00_N0_M01Adapt( + c_m_n_grid_desc, 8, KBatch); } __host__ __device__ static constexpr auto diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp index 22dfc613bf..ffa82a7570 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -300,11 +300,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 } // return block_id to C matrix tile idx (m0, n0) mapping - __host__ __device__ static constexpr auto - MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) { - return BlockToCTileMap_M00_N00_M01_N01( - c_grid_desc_m_n, M01, N01); + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); } using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = remove_cvref_t( - c_grid_desc_m_n, M01, N01); + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); } using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp index 108800f677..745dfde0ba 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp @@ -316,11 +316,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 } // return block_id to C matrix tile idx (m0, n0) mapping - __host__ __device__ static constexpr auto - MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) { - return BlockToCTileMap_M00_N00_M01_N01( - c_grid_desc_m_n, M01, N01); + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); } using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = remove_cvref_t{}; static auto I1 = Number<1>{}; +static auto I2 = Number<2>{}; TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck1) { @@ -20,7 +21,7 @@ TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck1 const index_t M01 = 4; const index_t N01 = 4; - auto c_grid_desc_m_n = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, I1)); + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); printf("(M, N, MPerBlock, NPerBlock, M01, N01) = (%d, %d, %d, %d, %d, %d)\n", M, @@ -37,7 +38,7 @@ TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck1 EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 16); // clang-format off - std::vector> expected = { + std::vector> expected_m0idx_n0idx_valid = { {0, 0, 1}, {0, 1, 1}, {0, 2, 1}, @@ -64,7 +65,7 @@ TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck1 std::cout << ", valid = " << tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock)) << std::endl; bool equal = - expected[i] == + expected_m0idx_n0idx_valid[i] == std::vector{m0n0_idx[I0], m0n0_idx[I1], tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock))}; @@ -78,12 +79,11 @@ TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck0 const index_t N = 384; const index_t MPerBlock = 128; const index_t NPerBlock = 128; - // const index_t MBlock = M / MPerBlock; - // const index_t NBlock = N / NPerBlock; + const index_t M01 = 4; const index_t N01 = 4; - auto c_grid_desc_m_n = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, I1)); + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); printf("(M, N, MPerBlock, NPerBlock, M01, N01) = (%d, %d, %d, %d, %d, %d)\n", M, @@ -98,3 +98,221 @@ TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck0 EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == false); } + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N0_M01_DeviceCTileIndexCheck1) +{ + const index_t M = 384; + const index_t N = 512; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + const index_t MBlock = M / MPerBlock; + const index_t NBlock = N / NPerBlock; + const index_t M01 = 4; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01); + + BlockToCTileMap_M00_N0_M01 tile_map( + c_grid_desc_m_n, M01); + + EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true); + EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 16); + + // clang-format off + std::vector> expected_m0idx_n0idx_valid = { + {0, 0, 1}, + {1, 0, 1}, + {2, 0, 1}, + {3, 0, 0}, + {0, 1, 1}, + {1, 1, 1}, + {2, 1, 1}, + {3, 1, 0}, + {0, 2, 1}, + {1, 2, 1}, + {2, 2, 1}, + {3, 2, 0}, + {0, 3, 1}, + {1, 3, 1}, + {2, 3, 1}, + {3, 3, 0} + }; + // clang-format on + + for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++) + { + auto m0n0_idx = tile_map.CalculateBottomIndex(make_multi_index(i)); + std::cout << "block_1d_id = " << i << ", m0, n0 = " << m0n0_idx[I0] << ", " << m0n0_idx[I1]; + std::cout << ", valid = " << tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock)) + << std::endl; + bool equal = + expected_m0idx_n0idx_valid[i] == + std::vector{m0n0_idx[I0], + m0n0_idx[I1], + tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock))}; + EXPECT_TRUE(equal); + } +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N0_M01_DeviceCTileIndexCheck0) +{ + const index_t M = 512; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + // clang-format off + std::vector> expected_m0_gridsize_validity = { + {5, 15, false}, + {4, 12, true}, + {3, 18, false}, + {2, 12, true}, + {1, 12, true} + }; + // clang-format on + + for(auto e : expected_m0_gridsize_validity) + { + const index_t M01 = std::get<0>(e); + + printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01); + + BlockToCTileMap_M00_N0_M01 tile_map( + c_grid_desc_m_n, M01); + + EXPECT_EQ(tile_map.CalculateGridSize(c_grid_desc_m_n), std::get<1>(e)); + EXPECT_EQ(tile_map.CheckValidity(c_grid_desc_m_n), std::get<2>(e)); + } +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N0_M01Adapt) +{ + const index_t M = 768; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + const index_t MBlock = M / MPerBlock; + const index_t NBlock = N / NPerBlock; + constexpr index_t M01 = 4; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01); + + BlockToCTileMap_M00_N0_M01Adapt tile_map( + c_grid_desc_m_n, M01); + + EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true); + EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 18); + + // clang-format off + std::vector> expected_m0idx_n0idx_valid = { + {0, 0, 1}, + {1, 0, 1}, + {2, 0, 1}, + {3, 0, 1}, + {0, 1, 1}, + {1, 1, 1}, + {2, 1, 1}, + {3, 1, 1}, + {0, 2, 1}, + {1, 2, 1}, + {2, 2, 1}, + {3, 2, 1}, + {4, 0, 1}, + {5, 0, 1}, + {4, 1, 1}, + {5, 1, 1}, + {4, 2, 1}, + {5, 2, 1}, + }; + // clang-format on + + for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++) + { + auto m0n0_idx = tile_map.CalculateBottomIndex(make_multi_index(i)); + std::cout << "block_1d_id = " << i << ", m0, n0 = " << m0n0_idx[I0] << ", " << m0n0_idx[I1]; + std::cout << ", valid = " << tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock)) + << std::endl; + bool equal = + expected_m0idx_n0idx_valid[i] == + std::vector{m0n0_idx[I0], + m0n0_idx[I1], + tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock))}; + EXPECT_TRUE(equal); + } +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_KSplit_M00_N0_M01Adapt) +{ + const index_t M = 768; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + const index_t MBlock = M / MPerBlock; + const index_t NBlock = N / NPerBlock; + constexpr index_t M01 = 4; + const index_t KSplit = 3; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01); + + BlockToCTileMap_KSplit_M00_N0_M01Adapt + tile_map(c_grid_desc_m_n, M01, KSplit); + + EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true); + EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 18 * KSplit); + + std::vector> expected_ksplitidx_m0idx_n0idx_valid = { + {0, 0, 0, 1}, {0, 1, 0, 1}, {0, 2, 0, 1}, {0, 3, 0, 1}, {0, 0, 1, 1}, {0, 1, 1, 1}, + {0, 2, 1, 1}, {0, 3, 1, 1}, {0, 0, 2, 1}, {0, 1, 2, 1}, {0, 2, 2, 1}, {0, 3, 2, 1}, + {0, 4, 0, 1}, {0, 5, 0, 1}, {0, 4, 1, 1}, {0, 5, 1, 1}, {0, 4, 2, 1}, {0, 5, 2, 1}, + {1, 0, 0, 1}, {1, 1, 0, 1}, {1, 2, 0, 1}, {1, 3, 0, 1}, {1, 0, 1, 1}, {1, 1, 1, 1}, + {1, 2, 1, 1}, {1, 3, 1, 1}, {1, 0, 2, 1}, {1, 1, 2, 1}, {1, 2, 2, 1}, {1, 3, 2, 1}, + {1, 4, 0, 1}, {1, 5, 0, 1}, {1, 4, 1, 1}, {1, 5, 1, 1}, {1, 4, 2, 1}, {1, 5, 2, 1}, + {2, 0, 0, 1}, {2, 1, 0, 1}, {2, 2, 0, 1}, {2, 3, 0, 1}, {2, 0, 1, 1}, {2, 1, 1, 1}, + {2, 2, 1, 1}, {2, 3, 1, 1}, {2, 0, 2, 1}, {2, 1, 2, 1}, {2, 2, 2, 1}, {2, 3, 2, 1}, + {2, 4, 0, 1}, {2, 5, 0, 1}, {2, 4, 1, 1}, {2, 5, 1, 1}, {2, 4, 2, 1}, {2, 5, 2, 1}, + }; + + for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++) + { + auto ksplitm0n0_idx = tile_map.CalculateBottomIndex(make_multi_index(i)); + std::cout << "block_1d_id = " << i << ", ksplit, m0, n0 = " << ksplitm0n0_idx[I0] << ", " + << ksplitm0n0_idx[I1] << ", " << ksplitm0n0_idx[I2]; + std::cout << ", valid = " + << tile_map.ValidCTileIndex(ksplitm0n0_idx, make_tuple(MBlock, NBlock)) + << std::endl; + bool equal = + expected_ksplitidx_m0idx_n0idx_valid[i] == + std::vector{ksplitm0n0_idx[I0], + ksplitm0n0_idx[I1], + ksplitm0n0_idx[I2], + tile_map.ValidCTileIndex(ksplitm0n0_idx, make_tuple(MBlock, NBlock))}; + EXPECT_TRUE(equal); + } +}