mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00
WIP: debugging...
This commit is contained in:
@@ -295,22 +295,46 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
|
||||
////////////// global window & register /////////////////
|
||||
// A DRAM tile window(s) for load
|
||||
|
||||
auto a_tile_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
// Get bottom tensor view and window origin: need to divide by APackedSize
|
||||
auto&& bottom_tensor_view = a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view();
|
||||
auto&& tensor_ptr = reinterpret_cast<const ADataType*>(&(bottom_tensor_view.get_buffer_view()(0)));
|
||||
auto&& tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
tensor_ptr,
|
||||
make_tuple(4096, 4096 / APackedSize),
|
||||
make_tuple(4096 / APackedSize, 1),
|
||||
number<32>{},
|
||||
number<1>{});
|
||||
const auto& origin = a_dram_block_window_tmp[number<idx>{}].get_window_origin();
|
||||
return make_tile_window(
|
||||
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
// a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
tensor_view,
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{}),
|
||||
a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
{origin[0], origin[1] / APackedSize},
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
},
|
||||
number<AsLayout::size()>{});
|
||||
// B DRAM window(s) for load
|
||||
auto b_tile_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
// Get bottom tensor view and window origin: need to divide by BPackedSize
|
||||
auto&& bottom_tensor_view = b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view();
|
||||
auto&& tensor_ptr = reinterpret_cast<const BDataType*>(&(bottom_tensor_view.get_buffer_view()(0)));
|
||||
auto&& tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
tensor_ptr,
|
||||
make_tuple(4096, 4096 / BPackedSize),
|
||||
make_tuple(4096 / BPackedSize, 1),
|
||||
number<32>{},
|
||||
number<1>{});
|
||||
const auto& origin = b_dram_block_window_tmp[number<idx>{}].get_window_origin();
|
||||
return make_tile_window(
|
||||
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
// b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
tensor_view,
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock / BPackedSize>{}),
|
||||
b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
// b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
{origin[0], origin[1] / BPackedSize},
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
},
|
||||
number<BsLayout::size()>{});
|
||||
@@ -397,9 +421,9 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
is_a_col_major ? make_array(KPerBlock / APackedSize, 0) : make_array(0, KPerBlock / APackedSize);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
is_b_row_major ? make_array(KPerBlock / BPackedSize, 0) : make_array(0, KPerBlock / BPackedSize);
|
||||
|
||||
// read A(0), B(0) from DRAM to LDS window(0)
|
||||
// and advance the DRAM windows
|
||||
@@ -420,10 +444,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
b_copy_lds_window1, b_tile_windows[number<0>{}], b_dram_tile_window_step);
|
||||
|
||||
// tile distribution for the register tiles
|
||||
// Use custom distributions that account for packed types
|
||||
constexpr auto ALdsTileDistr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
make_static_tile_distribution(Policy::template MakeALdsBlockDistributionEncode<Problem>());
|
||||
constexpr auto BLdsTileDistr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
make_static_tile_distribution(Policy::template MakeBLdsBlockDistributionEncode<Problem>());
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
|
||||
|
||||
@@ -281,6 +281,163 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
// Custom warp distribution encodings that account for packed types
|
||||
// For 16x16x128 MFMA with pk_fp4_t, the K dimension must use storage elements, not logical elements
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_AWarpDstrEncoding()
|
||||
{
|
||||
// For 16x16x128 MFMA with pk_fp4_t (PackedSize=2)
|
||||
// Physical layout in registers: [16 M-lanes, 4 K-lanes, 16 bytes per lane]
|
||||
// Each byte stores 2 fp4 values, so 16 bytes = 32 fp4 values
|
||||
// But we need to use STORAGE size (16) not LOGICAL size (32) in the distribution
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
|
||||
constexpr index_t kAMLane = 16;
|
||||
constexpr index_t kABKLane = 4;
|
||||
constexpr index_t kABKPerLane = 32 / APackedSize; // Storage elements, not logical!
|
||||
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kAMLane>, sequence<kABKLane, kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BWarpDstrEncoding()
|
||||
{
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
constexpr index_t kBNLane = 16;
|
||||
constexpr index_t kABKLane = 4;
|
||||
constexpr index_t kABKPerLane = 32 / BPackedSize; // Storage elements!
|
||||
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kBNLane>, sequence<kABKLane, kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
|
||||
// Custom LDS block distributions that account for packed types
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDistributionEncode()
|
||||
{
|
||||
using BlockGemmShape = typename Problem::BlockGemmShape;
|
||||
using BlockWarps = typename BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename BlockGemmShape::WarpTile;
|
||||
|
||||
constexpr index_t MWarp = BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = BlockWarps::at(number<1>{});
|
||||
constexpr index_t MPerXdl = WarpTile::at(number<0>{});
|
||||
constexpr index_t KPerXdl = WarpTile::at(number<2>{});
|
||||
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
// IMPORTANT: Use packed K for iteration count
|
||||
// LDS shape is [MPerBlock, KPerBlock / APackedSize]
|
||||
// WarpGemm expects [MPerXdl, KPerXdl / APackedSize] per warp per iteration
|
||||
constexpr index_t KIterPerWarp = (KPerBlock / APackedSize) / (KPerXdl / APackedSize);
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl);
|
||||
|
||||
constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1);
|
||||
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
// Use custom warp encoding that accounts for packed types
|
||||
return detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, MakeMX_AWarpDstrEncoding<Problem>());
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
// Use custom warp encoding that accounts for packed types
|
||||
return detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, MakeMX_AWarpDstrEncoding<Problem>());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDistributionEncode()
|
||||
{
|
||||
using BlockGemmShape = typename Problem::BlockGemmShape;
|
||||
using BlockWarps = typename BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename BlockGemmShape::WarpTile;
|
||||
|
||||
constexpr index_t MWarp = BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = BlockWarps::at(number<1>{});
|
||||
constexpr index_t NPerXdl = WarpTile::at(number<1>{});
|
||||
constexpr index_t KPerXdl = WarpTile::at(number<2>{});
|
||||
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
// IMPORTANT: Use packed K for iteration count
|
||||
// LDS shape is [NPerBlock, KPerBlock / BPackedSize]
|
||||
// WarpGemm expects [NPerXdl, KPerXdl / BPackedSize] per warp per iteration
|
||||
constexpr index_t KIterPerWarp = (KPerBlock / BPackedSize) / (KPerXdl / BPackedSize);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl);
|
||||
|
||||
constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1);
|
||||
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
// Use custom warp encoding that accounts for packed types
|
||||
return detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, MakeMX_BWarpDstrEncoding<Problem>());
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
// Use custom warp encoding that accounts for packed types
|
||||
return detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, MakeMX_BWarpDstrEncoding<Problem>());
|
||||
}
|
||||
}
|
||||
|
||||
// MX Scale tile distributions for loading from global memory
|
||||
// Using the proven "Flat" patterns from v1 policy
|
||||
template <typename Problem>
|
||||
|
||||
Reference in New Issue
Block a user