Revert the custom LDS tile distributions; the regular BlockGemm ones should be usable instead

This commit is contained in:
Sami Remes
2026-01-28 10:37:13 -05:00
parent 30d4c25d5a
commit 0033748c62
2 changed files with 7 additions and 145 deletions

View File

@@ -437,11 +437,10 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
b_copy_lds_window1, b_tile_windows[number<0>{}], b_dram_tile_window_step);
// tile distribution for the register tiles
// Use custom distributions that account for packed types
constexpr auto ALdsTileDistr =
make_static_tile_distribution(Policy::template MakeALdsBlockDistributionEncode<Problem>());
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto BLdsTileDistr =
make_static_tile_distribution(Policy::template MakeBLdsBlockDistributionEncode<Problem>());
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
@@ -450,6 +449,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
ALdsTile a_block_tile0, a_block_tile1;
BLdsTile b_block_tile0, b_block_tile1;
static_assert(sizeof(ALdsTile) == MPerBlock * (KPerBlock * sizeof(ADataType) / APackedSize) * NWarp / BlockSize, "ALdsTile size is wrong!");
static_assert(sizeof(BLdsTile) == NPerBlock * (KPerBlock * sizeof(BDataType) / BPackedSize) * MWarp / BlockSize, "BLdsTile size is wrong!");
static_assert(Policy::template GetSmemSizeA<Problem>() == MPerBlock * (KPerBlock * sizeof(ADataType) / APackedSize), "SmemSizeA size is wrong!");
static_assert(Policy::template GetSmemSizeB<Problem>() == (KPerBlock * sizeof(BDataType) / BPackedSize) * NPerBlock, "SmemSizeB size is wrong!");
////////////// MX Scale register tiles (ping-pong buffers) /////////////////
// Calculate scale iterations: each scale covers 32 elements in K
// Each K iteration processes KPerXdl elements

View File

@@ -244,148 +244,6 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
}
// Warp-level tile distribution encoding for the A operand of the MX GEMM.
//
// The lane layout is hard-coded: 16 lanes along M, 4 lanes along K and 32
// K-elements owned by each lane, i.e. a 16 x 128 (M x K) warp tile.
// NOTE(review): per the original author these numbers target the 16x16x128
// MFMA with pk_fp4_t (PackedSize = 2), and 32 is the LOGICAL per-lane element
// count (not divided by PackedSize); the storage-based value 16 was also
// tried. Confirm which one the consuming WarpGemm expects.
// `Problem` is not consulted by this function.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_AWarpDstrEncoding()
{
    constexpr index_t kLanesAlongM   = 16;
    constexpr index_t kLanesAlongK   = 4;
    constexpr index_t kKElemsPerLane = 32; // logical elements, see note above

    using MDimLengths = sequence<kLanesAlongM>;
    using KDimLengths = sequence<kLanesAlongK, kKElemsPerLane>;

    // No replicated dimension. The single P entry is built from the
    // (major, minor) pairs (2, 0) and (1, 0); the single Y entry is (2, 1).
    // Presumably, following ck_tile's convention, major 1 / 2 select the
    // M / K length sequences above, so lanes cover the K-lane and M-lane
    // dimensions and Y walks the per-lane K run.
    using WarpEncoding = tile_distribution_encoding<sequence<>,
                                                    tuple<MDimLengths, KDimLengths>,
                                                    tuple<sequence<2, 1>>,
                                                    tuple<sequence<0, 0>>,
                                                    sequence<2>,
                                                    sequence<1>>;
    return WarpEncoding{};
}
// Warp-level tile distribution encoding for the B operand of the MX GEMM.
//
// The lane layout is hard-coded: 16 lanes along N, 4 lanes along K and 32
// K-elements owned by each lane, i.e. a 16 x 128 (N x K) warp tile.
// NOTE(review): per the original author 32 is the LOGICAL per-lane element
// count (not divided by PackedSize); the value 16 was also tried.
// `Problem` is not consulted by this function.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BWarpDstrEncoding()
{
    constexpr index_t kLanesAlongN   = 16;
    constexpr index_t kLanesAlongK   = 4;
    constexpr index_t kKElemsPerLane = 32; // logical elements, see note above

    using NDimLengths = sequence<kLanesAlongN>;
    using KDimLengths = sequence<kLanesAlongK, kKElemsPerLane>;

    // No replicated dimension. The single P entry is built from the
    // (major, minor) pairs (2, 0) and (1, 0); the single Y entry is (2, 1).
    using WarpEncoding = tile_distribution_encoding<sequence<>,
                                                    tuple<NDimLengths, KDimLengths>,
                                                    tuple<sequence<2, 1>>,
                                                    tuple<sequence<0, 0>>,
                                                    sequence<2>,
                                                    sequence<1>>;
    return WarpEncoding{};
}
// Block-level tile distribution encoding for the A operand tile.
//
// An outer (block-level) encoding is built from the warp grid and the
// per-warp iteration counts of Problem::BlockGemmShape, then combined with
// the MX-specific A warp encoding through
// detail::make_embed_tile_distribution_encoding.
//
// Iteration counts are derived from the block/warp tile sizes as given
// (logical element counts); no PackedSize division is applied here.
// NOTE(review): the original author states the LDS storage shape is
// [MPerBlock, KPerBlock / APackedSize] - not verifiable from this function.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDistributionEncode()
{
    using Shape    = typename Problem::BlockGemmShape;
    using WarpGrid = typename Shape::BlockWarps;
    using XdlTile  = typename Shape::WarpTile;

    constexpr index_t kWarpsM = WarpGrid::at(number<0>{});
    constexpr index_t kWarpsN = WarpGrid::at(number<1>{});

    // How many times each warp repeats its warp tile along M and along K.
    constexpr index_t kIterM = Shape::kM / (kWarpsM * XdlTile::at(number<0>{}));
    constexpr index_t kIterK = Shape::kK / XdlTile::at(number<2>{});

    // The inner (warp-level) encoding is the same for both outer variants.
    constexpr auto warp_encoding = MakeMX_AWarpDstrEncoding<Problem>();

    if constexpr(Problem::NumWaveGroups == 1)
    {
        // Single wave group: the M warp count is part of the M lengths and
        // the N warp count is the replicated dimension; one P entry refers
        // to both of them.
        using OuterEncoding =
            tile_distribution_encoding<sequence<kWarpsN>,
                                       tuple<sequence<kIterM, kWarpsM>, sequence<kIterK>>,
                                       tuple<sequence<1, 0>>,
                                       tuple<sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>;
        return detail::make_embed_tile_distribution_encoding(OuterEncoding{}, warp_encoding);
    }
    else
    {
        // More than one wave group: only the iteration counts appear in the
        // lengths and the outer encoding carries no P entries.
        using OuterEncoding =
            tile_distribution_encoding<sequence<kWarpsN>,
                                       tuple<sequence<kIterM>, sequence<kIterK>>,
                                       tuple<>,
                                       tuple<>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>;
        return detail::make_embed_tile_distribution_encoding(OuterEncoding{}, warp_encoding);
    }
}
// Block-level tile distribution encoding for the B operand tile.
//
// An outer (block-level) encoding is built from the warp grid and the
// per-warp iteration counts of Problem::BlockGemmShape, then combined with
// the MX-specific B warp encoding through
// detail::make_embed_tile_distribution_encoding.
//
// Iteration counts are derived from the block/warp tile sizes as given
// (logical element counts); no PackedSize division is applied here.
// NOTE(review): the original author states the LDS storage shape is
// [NPerBlock, KPerBlock / BPackedSize] - not verifiable from this function.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDistributionEncode()
{
    using Shape    = typename Problem::BlockGemmShape;
    using WarpGrid = typename Shape::BlockWarps;
    using XdlTile  = typename Shape::WarpTile;

    constexpr index_t kWarpsM = WarpGrid::at(number<0>{});
    constexpr index_t kWarpsN = WarpGrid::at(number<1>{});

    // How many times each warp repeats its warp tile along N and along K.
    constexpr index_t kIterN = Shape::kN / (kWarpsN * XdlTile::at(number<1>{}));
    constexpr index_t kIterK = Shape::kK / XdlTile::at(number<2>{});

    // The inner (warp-level) encoding is the same for both outer variants.
    constexpr auto warp_encoding = MakeMX_BWarpDstrEncoding<Problem>();

    if constexpr(Problem::NumWaveGroups == 1)
    {
        // Single wave group: the N warp count is part of the N lengths and
        // the M warp count is the replicated dimension; one P entry refers
        // to both of them.
        using OuterEncoding =
            tile_distribution_encoding<sequence<kWarpsM>,
                                       tuple<sequence<kIterN, kWarpsN>, sequence<kIterK>>,
                                       tuple<sequence<0, 1>>,
                                       tuple<sequence<0, 1>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>;
        return detail::make_embed_tile_distribution_encoding(OuterEncoding{}, warp_encoding);
    }
    else
    {
        // More than one wave group: only the iteration counts appear in the
        // lengths and the outer encoding carries no P entries.
        using OuterEncoding =
            tile_distribution_encoding<sequence<kWarpsM>,
                                       tuple<sequence<kIterN>, sequence<kIterK>>,
                                       tuple<>,
                                       tuple<>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>;
        return detail::make_embed_tile_distribution_encoding(OuterEncoding{}, warp_encoding);
    }
}
// MX Scale tile distributions for loading from global memory
// Using the proven "Flat" patterns from v1 policy
template <typename Problem>