override base policy's vector size with static_assert 4/12/16 bytes

This commit is contained in:
Sami Remes
2026-01-30 03:55:56 -05:00
parent 409a7d8edb
commit 2cc0e3d019

View File

@@ -24,30 +24,50 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
static constexpr int NXdlPack = 1; // Packing factor along N: 1 means no packing in N
static constexpr int KXdlPack = 4; // Packing factor along K: 4 consecutive e8m0 scales = 4 bytes = 1 int32
// Override vector size methods to ensure compatibility with async buffer operations.
// Valid sizes for amd_async_buffer_load are 4, 12, or 16 bytes.
//
// GetVectorSizeA: vector size used for loading A.
// Returns the value computed by UniversalGemmBasePolicy::GetVectorSizeA unchanged, after
// checking at compile time that the per-load byte count it implies is one of the sizes
// listed above. An unsupported size is a build error instead of a silently different
// vector width.
template <typename Problem, bool IsWave32Host = false>
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeA()
{
    // First entry of the A data type tuple and its packing factor.
    using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
    using ADataType  = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
    constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;

    // Call base policy's dynamic vector size calculation.
    constexpr index_t vector_size =
        UniversalGemmBasePolicy<MXGemmPipelineAgBgCrCompAsyncDefaultPolicy>::
            template GetVectorSizeA<Problem, IsWave32Host>();

    // Actual byte load size (storage bytes = logical elements / PackedSize * sizeof).
    constexpr index_t byte_load_size = vector_size * sizeof(ADataType) / APackedSize;

    // Ensure async buffer load requirements: must be 4, 12, or 16 bytes.
    static_assert(byte_load_size == 4 || byte_load_size == 12 || byte_load_size == 16,
                  "Vector load size must be 4, 12, or 16 bytes for async buffer operations");

    return vector_size;
}
// GetVectorSizeB: vector size used for loading B.
// Returns the value computed by UniversalGemmBasePolicy::GetVectorSizeB unchanged, after
// checking at compile time that the per-load byte count it implies is 4, 12, or 16 bytes
// (the sizes accepted by async buffer loads). An unsupported size is a build error instead
// of a silently different vector width.
template <typename Problem, bool IsWave32Host = false>
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB()
{
    // First entry of the B data type tuple and its packing factor.
    using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
    using BDataType  = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
    constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;

    // Call base policy's dynamic vector size calculation.
    constexpr index_t vector_size =
        UniversalGemmBasePolicy<MXGemmPipelineAgBgCrCompAsyncDefaultPolicy>::
            template GetVectorSizeB<Problem, IsWave32Host>();

    // Actual byte load size (storage bytes = logical elements / PackedSize * sizeof).
    constexpr index_t byte_load_size = vector_size * sizeof(BDataType) / BPackedSize;

    // Ensure async buffer load requirements: must be 4, 12, or 16 bytes.
    static_assert(byte_load_size == 4 || byte_load_size == 12 || byte_load_size == 16,
                  "Vector load size must be 4, 12, or 16 bytes for async buffer operations");

    return vector_size;
}
// DRAM tile distributions use STORAGE dimensions (for the storage tensor view)