mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
revert custom ldstile, should be able to use the regular ones
This commit is contained in:
@@ -437,11 +437,10 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
b_copy_lds_window1, b_tile_windows[number<0>{}], b_dram_tile_window_step);
|
||||
|
||||
// tile distribution for the register tiles
|
||||
// Use custom distributions that account for packed types
|
||||
constexpr auto ALdsTileDistr =
|
||||
make_static_tile_distribution(Policy::template MakeALdsBlockDistributionEncode<Problem>());
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto BLdsTileDistr =
|
||||
make_static_tile_distribution(Policy::template MakeBLdsBlockDistributionEncode<Problem>());
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
|
||||
@@ -450,6 +449,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
ALdsTile a_block_tile0, a_block_tile1;
|
||||
BLdsTile b_block_tile0, b_block_tile1;
|
||||
|
||||
static_assert(sizeof(ALdsTile) == MPerBlock * (KPerBlock * sizeof(ADataType) / APackedSize) * NWarp / BlockSize, "ALdsTile size is wrong!");
|
||||
static_assert(sizeof(BLdsTile) == NPerBlock * (KPerBlock * sizeof(BDataType) / BPackedSize) * MWarp / BlockSize, "BLdsTile size is wrong!");
|
||||
static_assert(Policy::template GetSmemSizeA<Problem>() == MPerBlock * (KPerBlock * sizeof(ADataType) / APackedSize), "SmemSizeA size is wrong!");
|
||||
static_assert(Policy::template GetSmemSizeB<Problem>() == (KPerBlock * sizeof(BDataType) / BPackedSize) * NPerBlock, "SmemSizeB size is wrong!");
|
||||
|
||||
////////////// MX Scale register tiles (ping-pong buffers) /////////////////
|
||||
// Calculate scale iterations: each scale covers 32 elements in K
|
||||
// Each K iteration processes KPerXdl elements
|
||||
|
||||
@@ -244,148 +244,6 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
// Custom warp distribution encodings that account for packed types
|
||||
// For 16x16x128 MFMA with pk_fp4_t, the K dimension must use storage elements, not logical elements
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_AWarpDstrEncoding()
|
||||
{
|
||||
// For 16x16x128 MFMA with pk_fp4_t (PackedSize=2)
|
||||
// Physical layout in registers: [16 M-lanes, 4 K-lanes, 16 bytes per lane]
|
||||
// Each byte stores 2 fp4 values, so 16 bytes = 32 fp4 values
|
||||
// WarpGemm expects LOGICAL dimensions, so use 32 (logical fp4), not 16 (storage)
|
||||
constexpr index_t kAMLane = 16;
|
||||
constexpr index_t kABKLane = 4;
|
||||
constexpr index_t kABKPerLane = 32; // LOGICAL elements (not divided by PackedSize)!
|
||||
// have also tried 16 here
|
||||
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kAMLane>, sequence<kABKLane, kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BWarpDstrEncoding()
|
||||
{
|
||||
constexpr index_t kBNLane = 16;
|
||||
constexpr index_t kABKLane = 4;
|
||||
constexpr index_t kABKPerLane = 32; // LOGICAL elements (not divided by PackedSize)! have also tried 16
|
||||
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kBNLane>, sequence<kABKLane, kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
|
||||
// Custom LDS block distributions that account for packed types
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDistributionEncode()
|
||||
{
|
||||
using BlockGemmShape = typename Problem::BlockGemmShape;
|
||||
using BlockWarps = typename BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename BlockGemmShape::WarpTile;
|
||||
|
||||
constexpr index_t MWarp = BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = BlockWarps::at(number<1>{});
|
||||
constexpr index_t MPerXdl = WarpTile::at(number<0>{});
|
||||
constexpr index_t KPerXdl = WarpTile::at(number<2>{});
|
||||
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
// Use LOGICAL dimensions for iteration count (matches WarpGemm expectations)
|
||||
// LDS shape is [MPerBlock, KPerBlock / APackedSize] in storage (bytes)
|
||||
constexpr index_t KIterPerWarp = KPerBlock / KPerXdl; // Logical K iterations
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl);
|
||||
|
||||
constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1);
|
||||
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
// here the iters don't get affected by PackedSize
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
// Use custom warp encoding that accounts for packed types
|
||||
return detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, MakeMX_AWarpDstrEncoding<Problem>());
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
// Use custom warp encoding that accounts for packed types
|
||||
return detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, MakeMX_AWarpDstrEncoding<Problem>());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDistributionEncode()
|
||||
{
|
||||
using BlockGemmShape = typename Problem::BlockGemmShape;
|
||||
using BlockWarps = typename BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename BlockGemmShape::WarpTile;
|
||||
|
||||
constexpr index_t MWarp = BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = BlockWarps::at(number<1>{});
|
||||
constexpr index_t NPerXdl = WarpTile::at(number<1>{});
|
||||
constexpr index_t KPerXdl = WarpTile::at(number<2>{});
|
||||
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
// Use LOGICAL dimensions for iteration count (matches WarpGemm expectations)
|
||||
// LDS shape is [NPerBlock, KPerBlock / BPackedSize] in storage
|
||||
// But distributions work in logical space
|
||||
constexpr index_t KIterPerWarp = KPerBlock / KPerXdl; // Logical K iterations
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl);
|
||||
|
||||
constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1);
|
||||
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
// Use custom warp encoding that accounts for packed types
|
||||
return detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, MakeMX_BWarpDstrEncoding<Problem>());
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
// Use custom warp encoding that accounts for packed types
|
||||
return detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, MakeMX_BWarpDstrEncoding<Problem>());
|
||||
}
|
||||
}
|
||||
|
||||
// MX Scale tile distributions for loading from global memory
|
||||
// Using the proven "Flat" patterns from v1 policy
|
||||
template <typename Problem>
|
||||
|
||||
Reference in New Issue
Block a user