compiles again using get_y_sliced_thread_data in warpgemm loop

This commit is contained in:
Sami Remes
2026-01-23 11:01:43 -05:00
parent f09e10936d
commit d2a7c2f041
2 changed files with 45 additions and 31 deletions

View File

@@ -299,7 +299,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
[&](auto idx) {
return make_tile_window(
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{}),
a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
},
@@ -309,7 +309,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
[&](auto idx) {
return make_tile_window(
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
make_tuple(number<NPerBlock>{}, number<KPerBlock / BPackedSize>{}),
b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
},
@@ -364,6 +364,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
scale_b_dram_window.get_load_offset(tuple<number<0>, number<NWarp * NPerXdl>>{}));
// this pipeline has a pair of LDS buffers per logical tile
// TODO: check for packed size - are these blocks too big?
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
@@ -372,14 +373,14 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
if constexpr(is_a_load_tr_v)
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
else
return make_tuple(number<MPerBlock>{}, number<KPerBlock>{});
return make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{});
}();
constexpr auto b_lds_shape = []() {
if constexpr(is_b_load_tr_v)
return make_tuple(number<KPerBlock>{}, number<NPerBlock>{});
else
return make_tuple(number<NPerBlock>{}, number<KPerBlock>{});
return make_tuple(number<NPerBlock>{}, number<KPerBlock / BPackedSize>{});
}();
// LDS tile windows for storing, one per LDS buffer
@@ -439,6 +440,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl);
constexpr index_t ScaleKPackedPerIter = (KIterPerWarp * KPerXdl) / (ScaleBlockSize * KXdlPack);
static_assert(ScaleKPackedPerIter > 0, "ScaleKPackedPerIter is wrong!");
// Load a sample scale tile to get the type
auto scale_a_sample = load_tile_with_offset(scale_a_dram_window, tuple<number<0>, number<0>>{});
@@ -520,7 +522,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
});
// Warp GEMM loop with MX scaling
auto warp_gemm_loop = [&](auto& a_block_tile, auto& b_block_tile, auto& scale_a, auto& scale_b) {
auto warp_gemm_loop = [&](const auto& a_block_tile, const auto& b_block_tile, const auto& scale_a, const auto& scale_b) {
// Extract A/B values from block tiles to warp iteration structure
constexpr auto a_warp_y_lengths =
to_sequence(typename WarpGemm::AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
@@ -537,25 +539,22 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
static_for<0, MIterPerWarp, 1>{}([&](auto m_iter) {
constexpr auto OpSelA = kScaleInPack;
// read A warp tensor from A block tensor
typename WarpGemm::AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<m_iter, k_iter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto n_iter) {
constexpr auto OpSelB = kScaleInPack;
// Extract A/B values for this iteration - create warp tensors
typename WarpGemm::AWarpTensor a_warp_tensor{};
const auto a_thread_data = a_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<m_iter, k_iter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, a_warp_tensor.get_thread_buffer_size(), 1>{}([&](auto i) {
a_warp_tensor.get_thread_buffer()(i) = a_thread_data[i];
});
typename WarpGemm::BWarpTensor b_warp_tensor{};
const auto b_thread_data = b_block_tile.get_y_sliced_thread_data(
// read B warp tensor from B block tensor
typename WarpGemm::BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<n_iter, k_iter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
static_for<0, b_warp_tensor.get_thread_buffer_size(), 1>{}([&](auto i) {
b_warp_tensor.get_thread_buffer()(i) = b_thread_data[i];
});
WarpGemm{}.template operator()<OpSelA, OpSelB>(
c_warp_tensors(m_iter)(n_iter),

View File

@@ -29,9 +29,9 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
template <typename Problem, bool IsWave32Host = false>
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeA()
{
// Get packed sizes for A/B
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
// Force 16-byte vector loads for optimal async buffer performance
// For fp4 (1 byte): 16 elements = 16 bytes
// For fp8 (1 byte): 16 elements = 16 bytes
@@ -53,9 +53,9 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
template <typename Problem, bool IsWave32Host = false>
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB()
{
// Get packed sizes for A/B
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
// Force 16-byte vector loads for optimal async buffer performance
// For fp4 (1 byte): 16 elements = 16 bytes
// For fp8 (1 byte): 16 elements = 16 bytes
@@ -86,13 +86,17 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
using ALayout = remove_cvref_t<
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::AsLayoutTuple>>>;
// Get packed sizes for A/B
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
MPerBlock,
KPerBlock,
KPerBlock / APackedSize,
VecLoadSize,
getATileAccessPattern(),
NumWaveGroups>;
@@ -123,6 +127,11 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
using BLayout = remove_cvref_t<
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
// Get packed sizes for A/B
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
@@ -141,7 +150,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
NPerBlock,
KPerBlock,
KPerBlock / BPackedSize,
VecLoadSize,
getBTileAccessPattern(),
NumWaveGroups>;
@@ -153,8 +162,13 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
// Get packed sizes for A/B
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / APackedSize;
if constexpr(is_a_load_tr<Problem>)
{
// TODO: better LDS descriptor for performance
@@ -191,8 +205,13 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
// Get packed sizes for A/B
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / BPackedSize;
if constexpr(is_b_load_tr<Problem>)
{
// TODO: better LDS descriptor for performance
@@ -300,10 +319,6 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
constexpr index_t NWarp = BlockWarps::at(number<1>{});
constexpr index_t NPerXdl = WarpTile::at(number<1>{});
constexpr index_t K_Lane = get_warp_size() / NPerXdl; // 4 for 16x16 mfma
static_assert(K_Lane == 4, "K_Lane must be 4 for 16x16 mfma");
static_assert(NPerXdl == 16, "NPerXdl must be 16 for 16x16 mfma");
static_assert(MWarp == 1, "MWarp must be 1 for 16x16 mfma");
// Scale B: [K/32/KXdlPack, NWarp * NPerXdl] for warp-level tile
// Layout is [K, N] where K is packed int32