WIP: debugging...

This commit is contained in:
Sami Remes
2026-01-26 11:33:45 -05:00
parent d2a7c2f041
commit 70c7fcda43
2 changed files with 190 additions and 8 deletions

View File

@@ -295,22 +295,46 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
////////////// global window & register /////////////////
// A DRAM tile window(s) for load
auto a_tile_windows = generate_tuple(
[&](auto idx) {
// Get bottom tensor view and window origin: need to divide by APackedSize
auto&& bottom_tensor_view = a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view();
auto&& tensor_ptr = reinterpret_cast<const ADataType*>(&(bottom_tensor_view.get_buffer_view()(0)));
auto&& tensor_view = make_naive_tensor_view<address_space_enum::global>(
tensor_ptr,
make_tuple(4096, 4096 / APackedSize),
make_tuple(4096 / APackedSize, 1),
number<32>{},
number<1>{});
const auto& origin = a_dram_block_window_tmp[number<idx>{}].get_window_origin();
return make_tile_window(
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
// a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
tensor_view,
make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{}),
a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
{origin[0], origin[1] / APackedSize},
Policy::template MakeADramTileDistribution<Problem>());
},
number<AsLayout::size()>{});
// B DRAM window(s) for load
auto b_tile_windows = generate_tuple(
[&](auto idx) {
// Get bottom tensor view and window origin: need to divide by BPackedSize
auto&& bottom_tensor_view = b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view();
auto&& tensor_ptr = reinterpret_cast<const BDataType*>(&(bottom_tensor_view.get_buffer_view()(0)));
auto&& tensor_view = make_naive_tensor_view<address_space_enum::global>(
tensor_ptr,
make_tuple(4096, 4096 / BPackedSize),
make_tuple(4096 / BPackedSize, 1),
number<32>{},
number<1>{});
const auto& origin = b_dram_block_window_tmp[number<idx>{}].get_window_origin();
return make_tile_window(
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
// b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
tensor_view,
make_tuple(number<NPerBlock>{}, number<KPerBlock / BPackedSize>{}),
b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
// b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
{origin[0], origin[1] / BPackedSize},
Policy::template MakeBDramTileDistribution<Problem>());
},
number<BsLayout::size()>{});
@@ -397,9 +421,9 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
is_a_col_major ? make_array(KPerBlock / APackedSize, 0) : make_array(0, KPerBlock / APackedSize);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
is_b_row_major ? make_array(KPerBlock / BPackedSize, 0) : make_array(0, KPerBlock / BPackedSize);
// read A(0), B(0) from DRAM to LDS window(0)
// and advance the DRAM windows
@@ -420,10 +444,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
b_copy_lds_window1, b_tile_windows[number<0>{}], b_dram_tile_window_step);
// tile distribution for the register tiles
// Use custom distributions that account for packed types
constexpr auto ALdsTileDistr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
make_static_tile_distribution(Policy::template MakeALdsBlockDistributionEncode<Problem>());
constexpr auto BLdsTileDistr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
make_static_tile_distribution(Policy::template MakeBLdsBlockDistributionEncode<Problem>());
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));

View File

@@ -281,6 +281,163 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
}
// Warp-level tile distribution encoding for the A operand when the element type may pack
// several logical values into one storage element (numeric_traits<T>::PackedSize > 1,
// e.g. pk_fp4_t with PackedSize = 2).
//
// The lane shape is hard-coded for the 16x16x128 MFMA: 16 lanes along M and 4 lanes along
// K. Each lane covers 32 logical K values; the encoding expresses that run in STORAGE
// elements (32 / PackedSize), because that is what actually sits in the registers.
// NOTE(review): other warp-tile shapes are not handled here - confirm callers only use
// 16x16x128.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_AWarpDstrEncoding()
{
    // The element type is taken from the first entry of the A data-type tuple.
    using ATypes = remove_cvref_t<typename Problem::AsDataTypeTuple>;
    using AElem  = remove_cvref_t<std::tuple_element_t<number<0>{}, ATypes>>;

    // Logical values per storage element of AElem.
    constexpr index_t kPack = numeric_traits<AElem>::PackedSize;

    // H-dimension lengths: M is a single level of 16 lanes; K is split into 4 lanes times
    // the per-lane run measured in storage elements.
    using MLanes = sequence<16>;
    using KLanes = sequence<4, 32 / kPack>;

    return tile_distribution_encoding<sequence<>,
                                      tuple<MLanes, KLanes>,
                                      tuple<sequence<2, 1>>,
                                      tuple<sequence<0, 0>>,
                                      sequence<2>,
                                      sequence<1>>{};
}
// Warp-level tile distribution encoding for the B operand; same layout as the A-operand
// encoding but with N in place of M: 16 lanes along N, 4 lanes along K, and a per-lane K
// run of 32 logical values expressed in storage elements (32 / PackedSize).
// NOTE(review): the lane constants are hard-coded - presumably for the 16x16x128 MFMA,
// confirm before using with other warp tiles.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BWarpDstrEncoding()
{
    // The element type is taken from the first entry of the B data-type tuple.
    using BTypes = remove_cvref_t<typename Problem::BsDataTypeTuple>;
    using BElem  = remove_cvref_t<std::tuple_element_t<number<0>{}, BTypes>>;

    // Logical values per storage element of BElem.
    constexpr index_t kPack = numeric_traits<BElem>::PackedSize;

    // H-dimension lengths: N is a single level of 16 lanes; K is split into 4 lanes times
    // the per-lane run measured in storage elements.
    using NLanes = sequence<16>;
    using KLanes = sequence<4, 32 / kPack>;

    return tile_distribution_encoding<sequence<>,
                                      tuple<NLanes, KLanes>,
                                      tuple<sequence<2, 1>>,
                                      tuple<sequence<0, 0>>,
                                      sequence<2>,
                                      sequence<1>>{};
}
// Block-level distribution encoding for the A register tile, built by embedding the
// packed-type-aware warp encoding (MakeMX_AWarpDstrEncoding) into a per-block outer
// encoding. Both the block and the warp K extents are converted to storage elements
// (divided by PackedSize) before the K iteration count is derived.
//
// The outer encoding depends on Problem::NumWaveGroups:
//   * == 1 : MWarp appears in the M dimension next to the M iteration count, and NWarp is
//            the replication length.
//   * != 1 : only the iteration counts appear in M and K; NWarp is still the replication
//            length.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDistributionEncode()
{
    using Shape = typename Problem::BlockGemmShape;
    using Warps = typename Shape::BlockWarps;
    using WTile = typename Shape::WarpTile;

    // The element type is taken from the first entry of the A data-type tuple.
    using ATypes = remove_cvref_t<typename Problem::AsDataTypeTuple>;
    using AElem  = remove_cvref_t<std::tuple_element_t<number<0>{}, ATypes>>;
    constexpr index_t kPack = numeric_traits<AElem>::PackedSize;

    constexpr index_t kMWarps = Warps::at(number<0>{});
    constexpr index_t kNWarps = Warps::at(number<1>{});

    // K extents in storage elements, for the whole block and for one warp-level MFMA.
    constexpr index_t kPackedKPerBlock = Shape::kK / kPack;
    constexpr index_t kPackedKPerXdl   = WTile::at(number<2>{}) / kPack;

    // Number of warp-level steps each warp performs along K and along M.
    constexpr index_t kKIters = kPackedKPerBlock / kPackedKPerXdl;
    constexpr index_t kMIters = Shape::kM / (kMWarps * WTile::at(number<0>{}));

    constexpr bool kSingleWaveGroup = (Problem::NumWaveGroups == 1);

    if constexpr(kSingleWaveGroup)
    {
        constexpr auto outer_encoding =
            tile_distribution_encoding<sequence<kNWarps>,
                                       tuple<sequence<kMIters, kMWarps>, sequence<kKIters>>,
                                       tuple<sequence<1, 0>>,
                                       tuple<sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        return detail::make_embed_tile_distribution_encoding(
            outer_encoding, MakeMX_AWarpDstrEncoding<Problem>());
    }
    else
    {
        constexpr auto outer_encoding =
            tile_distribution_encoding<sequence<kNWarps>,
                                       tuple<sequence<kMIters>, sequence<kKIters>>,
                                       tuple<>,
                                       tuple<>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        return detail::make_embed_tile_distribution_encoding(
            outer_encoding, MakeMX_AWarpDstrEncoding<Problem>());
    }
}
// Block-level distribution encoding for the B register tile, built by embedding the
// packed-type-aware warp encoding (MakeMX_BWarpDstrEncoding) into a per-block outer
// encoding. Both the block and the warp K extents are converted to storage elements
// (divided by PackedSize) before the K iteration count is derived.
//
// The outer encoding depends on Problem::NumWaveGroups:
//   * == 1 : NWarp appears in the N dimension next to the N iteration count, and MWarp is
//            the replication length.
//   * != 1 : only the iteration counts appear in N and K; MWarp is still the replication
//            length.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDistributionEncode()
{
    using Shape = typename Problem::BlockGemmShape;
    using Warps = typename Shape::BlockWarps;
    using WTile = typename Shape::WarpTile;

    // The element type is taken from the first entry of the B data-type tuple.
    using BTypes = remove_cvref_t<typename Problem::BsDataTypeTuple>;
    using BElem  = remove_cvref_t<std::tuple_element_t<number<0>{}, BTypes>>;
    constexpr index_t kPack = numeric_traits<BElem>::PackedSize;

    constexpr index_t kMWarps = Warps::at(number<0>{});
    constexpr index_t kNWarps = Warps::at(number<1>{});

    // K extents in storage elements, for the whole block and for one warp-level MFMA.
    constexpr index_t kPackedKPerBlock = Shape::kK / kPack;
    constexpr index_t kPackedKPerXdl   = WTile::at(number<2>{}) / kPack;

    // Number of warp-level steps each warp performs along K and along N.
    constexpr index_t kKIters = kPackedKPerBlock / kPackedKPerXdl;
    constexpr index_t kNIters = Shape::kN / (kNWarps * WTile::at(number<1>{}));

    constexpr bool kSingleWaveGroup = (Problem::NumWaveGroups == 1);

    if constexpr(kSingleWaveGroup)
    {
        constexpr auto outer_encoding =
            tile_distribution_encoding<sequence<kMWarps>,
                                       tuple<sequence<kNIters, kNWarps>, sequence<kKIters>>,
                                       tuple<sequence<0, 1>>,
                                       tuple<sequence<0, 1>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        return detail::make_embed_tile_distribution_encoding(
            outer_encoding, MakeMX_BWarpDstrEncoding<Problem>());
    }
    else
    {
        constexpr auto outer_encoding =
            tile_distribution_encoding<sequence<kMWarps>,
                                       tuple<sequence<kNIters>, sequence<kKIters>>,
                                       tuple<>,
                                       tuple<>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        return detail::make_embed_tile_distribution_encoding(
            outer_encoding, MakeMX_BWarpDstrEncoding<Problem>());
    }
}
// MX Scale tile distributions for loading from global memory
// Using the proven "Flat" patterns from v1 policy
template <typename Problem>