This commit is contained in:
Sami Remes
2026-01-16 08:22:11 -05:00
parent f6f9931541
commit 16ca5cb532
7 changed files with 135 additions and 6 deletions

View File

@@ -672,6 +672,9 @@ struct UniversalGemmKernel
[&](auto i) {
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
using AiDataType = remove_cvref_t<std::tuple_element_t<i.value, AsDataType>>;
static_assert(GemmPipeline::GetVectorSizeA() == GemmPipeline::GetVectorSizeB(), "Vector size of A and B must be the same!");
static_assert(GemmPipeline::GetVectorSizeA() == 16, "Vector size of A must be 16!");
static_assert(GemmPipeline::GetVectorSizeB() == 16, "Vector size of B must be 16!");
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(

View File

@@ -314,6 +314,18 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
Policy::template MakeBDramTileDistribution<Problem>());
},
number<BsLayout::size()>{});
/// Check tile window traits for vector size
using ATileDstr = remove_cvref_t<decltype(Policy::template MakeADramTileDistribution<Problem>())>;
// static_assert(ATileDstr::LargestVec >= 16, "wrong! not implemented vector size");
// static_assert(ATileDstr::X1 >= 16, "wrong! not implemented vector size");
using BTileDstr = remove_cvref_t<decltype(Policy::template MakeBDramTileDistribution<Problem>())>;
// static_assert(BTileDstr::LargestVec >= 16, "wrong! not implemented vector size");
// static_assert(BTileDstr::X1 >= 16, "wrong! not implemented vector size");
using ATileType = remove_cvref_t<decltype(a_tile_windows[number<0>{}])>;
using BTileType = remove_cvref_t<decltype(b_tile_windows[number<0>{}])>;
static_assert(sizeof(typename ATileType::Traits::vector_t) == 16, "wrong! not implemented vector size");
static_assert(sizeof(typename BTileType::Traits::vector_t) == 16, "wrong! not implemented vector size");
////////////// MX Scale windows /////////////////
// Get WarpGemm configuration

View File

@@ -24,6 +24,115 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
static constexpr int NXdlPack = 1; // No N packing
static constexpr int KXdlPack = 4; // Pack 4 consecutive e8m0 scales in K = 4 bytes = 1 int32
// Override vector size methods to force 16-byte loads for async buffer operations
// Valid sizes for amd_async_buffer_load are 4, 12, or 16 bytes
template <typename Problem, bool IsWave32Host = false>
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeA()
{
    /// @brief Vector size (in elements) for A-tensor global loads.
    ///
    /// Forces 16-byte vectors, the widest size accepted by
    /// amd_async_buffer_load (valid widths are 4, 12, or 16 bytes).
    /// The element count scales with the element width so the byte
    /// width stays fixed at 16:
    ///   fp4 / fp8 (1 byte) -> 16 elements
    ///   fp16      (2 bytes) -> 8 elements
    using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
    using ADataType  = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
    // Guard against element types that cannot tile a 16-byte vector.
    static_assert(16 % sizeof(ADataType) == 0,
                  "A element size must evenly divide the 16-byte async load width!");
    return static_cast<index_t>(16 / sizeof(ADataType));
}
template <typename Problem, bool IsWave32Host = false>
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB()
{
    /// @brief Vector size (in elements) for B-tensor global loads.
    ///
    /// Forces 16-byte vectors, the widest size accepted by
    /// amd_async_buffer_load (valid widths are 4, 12, or 16 bytes).
    /// The element count scales with the element width so the byte
    /// width stays fixed at 16:
    ///   fp4 / fp8 (1 byte) -> 16 elements
    ///   fp16      (2 bytes) -> 8 elements
    using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
    using BDataType  = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
    // Guard against element types that cannot tile a 16-byte vector.
    static_assert(16 % sizeof(BDataType) == 0,
                  "B element size must evenly divide the 16-byte async load width!");
    return static_cast<index_t>(16 / sizeof(BDataType));
}
// Override DRAM tile distributions to use the constrained vector sizes
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{
    /// @brief Builds the DRAM->register tile distribution for the A tensor,
    ///        using the 16-byte-constrained vector size from GetVectorSizeA().
    ///
    /// Only row-major A is implemented; instantiating with a column-major
    /// layout is a compile-time error.
    constexpr index_t BlockSize     = Problem::kBlockSize;
    constexpr index_t MPerBlock     = Problem::BlockGemmShape::kM;
    constexpr index_t KPerBlock     = Problem::BlockGemmShape::kK;
    constexpr index_t VecLoadSize   = GetVectorSizeA<Problem>();
    constexpr index_t NumWaveGroups = Problem::NumWaveGroups;

    using ALayout = remove_cvref_t<
        std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::AsLayoutTuple>>>;

    if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
    {
        using TileEncodingPattern =
            tile_distribution_encoding_pattern_2d<BlockSize,
                                                  MPerBlock,
                                                  KPerBlock,
                                                  VecLoadSize,
                                                  getATileAccessPattern(),
                                                  NumWaveGroups>;
        return TileEncodingPattern::make_2d_static_tile_distribution();
    }
    else
    {
        // Column-major A (dims swapped to K x M) is not implemented yet.
        // NOTE: a plain `static_assert(false, ...)` in a discarded `if constexpr`
        // branch is ill-formed before C++23 (P2593); the condition below depends
        // on ALayout so the assert only fires when this branch is instantiated.
        static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>,
                      "Not implemented");
    }
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
{
    /// @brief Builds the DRAM->register tile distribution for the B tensor,
    ///        using the 16-byte-constrained vector size from GetVectorSizeB().
    ///
    /// Only column-major B is implemented; instantiating with a row-major
    /// layout is a compile-time error.
    constexpr index_t BlockSize     = Problem::kBlockSize;
    constexpr index_t NPerBlock     = Problem::BlockGemmShape::kN;
    constexpr index_t KPerBlock     = Problem::BlockGemmShape::kK;
    constexpr index_t VecLoadSize   = GetVectorSizeB<Problem>();
    constexpr index_t NumWaveGroups = Problem::NumWaveGroups;

    using BLayout = remove_cvref_t<
        std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;

    if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
    {
        // Row-major B (dims swapped to K x N) is not implemented yet.
        // NOTE: a plain `static_assert(false, ...)` in a discarded `if constexpr`
        // branch is ill-formed before C++23 (P2593); the condition below depends
        // on BLayout so the assert only fires when this branch is instantiated.
        static_assert(!std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>,
                      "Not implemented");
    }
    else
    {
        using TileEncodingPattern =
            tile_distribution_encoding_pattern_2d<BlockSize,
                                                  NPerBlock,
                                                  KPerBlock,
                                                  VecLoadSize,
                                                  getBTileAccessPattern(),
                                                  NumWaveGroups>;
        return TileEncodingPattern::make_2d_static_tile_distribution();
    }
}
template <typename Problem,
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
@@ -44,7 +153,8 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
}
else
{
constexpr index_t KPack = GetSmemPackA<Problem>();
// constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr index_t KPack = 16;
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
@@ -81,7 +191,8 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
}
else
{
constexpr index_t KPack = GetSmemPackB<Problem>();
// constexpr index_t KPack = GetSmemPackB<Problem>();
constexpr index_t KPack = 16;
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<NPerBlock>{}, number<KPack>{}),