diff --git a/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp b/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp index 23c4ad583e..21ca470222 100644 --- a/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp +++ b/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp @@ -63,48 +63,15 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16 static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 8 static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 8/2=4 - static CK_TILE_DEVICE constexpr auto MakeCBlockDist() + private: + template + struct LdsStoreDescSelector; + + template + struct LdsStoreDescSelector= WarpSize)>> { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<2, 1>, // !! note here is different - sequence<0, 0>>{}; - - using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; - - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); - constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - return c_block_dstr; - } - - static CK_TILE_DEVICE constexpr auto MakeCBlockTile() - { - using CDataType = float; - constexpr auto c_block_dstr = MakeCBlockDist(); - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - return c_block_tensor; - } - - CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A() - { - // A async->LDS - // constexpr index_t Block_M = Problem::BlockShape::Block_M0; - // constexpr index_t Block_K = Problem::BlockShape::Block_K0; - // constexpr index_t BlockSize = Problem::BlockShape::BlockSize; - constexpr index_t WarpSize = ck_tile::get_warp_size(); - // constexpr index_t NumWarps = Problem::BlockShape::NumWarps; - - constexpr index_t KPack_ = 8; // GetSmemKPack_A(); // LDS - constexpr index_t KVector = 2; // GetAlignment_A(); // async copy 1 dword - constexpr index_t KPad = KPack_; // pad between warps - - static_assert(Block_K % KVector == 0); - constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K - if constexpr(LanesPerK >= WarpSize) + template + static CK_TILE_HOST_DEVICE constexpr auto MakeDesc() { // need multiple waves to load K static_assert(LanesPerK % WarpSize == 0); @@ -143,7 +110,13 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16 return lds_block_desc_issues_warps_lanes; } } - else + }; + + template + struct LdsStoreDescSelector> + { + template + static CK_TILE_HOST_DEVICE constexpr auto MakeDesc() { // lanes within a wave load different M but same K static_assert(WarpSize % LanesPerK == 0); @@ -175,6 +148,49 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16 return lds_block_desc_issues_warps_lanes; } + }; + + public: + static CK_TILE_DEVICE constexpr auto MakeCBlockDist() + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2, 1>, // !! note here is different + sequence<0, 0>>{}; + + using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + return c_block_dstr; + } + + static CK_TILE_DEVICE constexpr auto MakeCBlockTile() + { + using CDataType = float; + constexpr auto c_block_dstr = MakeCBlockDist(); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A() + { + // A async->LDS + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack_ = 8; // GetSmemKPack_A(); // LDS + constexpr index_t KVector = 2; // GetAlignment_A(); // async copy 1 dword + constexpr index_t KPad = KPack_; // pad between warps + + static_assert(Block_K % KVector == 0); + constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K + + return LdsStoreDescSelector:: + template MakeDesc(); } // template