mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
override base policys vector size with static_assert 4/12/16 bytes
This commit is contained in:
@@ -24,30 +24,50 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
static constexpr int NXdlPack = 1; // No N packing
|
||||
static constexpr int KXdlPack = 4; // Pack 4 consecutive e8m0 scales in K = 4 bytes = 1 int32
|
||||
|
||||
// Override vector size methods to force 16-byte loads for async buffer operations
|
||||
// Override vector size methods to ensure compatibility with async buffer operations
|
||||
// Valid sizes for amd_async_buffer_load are 4, 12, or 16 bytes
|
||||
template <typename Problem, bool IsWave32Host = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
// Get packed sizes for A/B
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
// Return number of STORAGE elements to load 16 bytes
|
||||
constexpr index_t vector_size_for_16_bytes = 16 / sizeof(ADataType) * APackedSize;
|
||||
return vector_size_for_16_bytes;
|
||||
|
||||
// Call base policy's dynamic vector size calculation
|
||||
constexpr index_t vector_size =
|
||||
UniversalGemmBasePolicy<MXGemmPipelineAgBgCrCompAsyncDefaultPolicy>::
|
||||
template GetVectorSizeA<Problem, IsWave32Host>();
|
||||
|
||||
// Calculate actual byte load size (storage bytes = logical elements / PackedSize * sizeof)
|
||||
constexpr index_t byte_load_size = vector_size * sizeof(ADataType) / APackedSize;
|
||||
|
||||
// Ensure async buffer load requirements: must be 4, 12, or 16 bytes
|
||||
static_assert(byte_load_size == 4 || byte_load_size == 12 || byte_load_size == 16,
|
||||
"Vector load size must be 4, 12, or 16 bytes for async buffer operations");
|
||||
|
||||
return vector_size;
|
||||
}
|
||||
|
||||
template <typename Problem, bool IsWave32Host = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
// Get packed sizes for A/B
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
// Return number of STORAGE elements to load 16 bytes
|
||||
constexpr index_t vector_size_for_16_bytes = 16 / sizeof(BDataType) * BPackedSize;
|
||||
return vector_size_for_16_bytes;
|
||||
|
||||
// Call base policy's dynamic vector size calculation
|
||||
constexpr index_t vector_size =
|
||||
UniversalGemmBasePolicy<MXGemmPipelineAgBgCrCompAsyncDefaultPolicy>::
|
||||
template GetVectorSizeB<Problem, IsWave32Host>();
|
||||
|
||||
// Calculate actual byte load size (storage bytes = logical elements / PackedSize * sizeof)
|
||||
constexpr index_t byte_load_size = vector_size * sizeof(BDataType) / BPackedSize;
|
||||
|
||||
// Ensure async buffer load requirements: must be 4, 12, or 16 bytes
|
||||
static_assert(byte_load_size == 4 || byte_load_size == 12 || byte_load_size == 16,
|
||||
"Vector load size must be 4, 12, or 16 bytes for async buffer operations");
|
||||
|
||||
return vector_size;
|
||||
}
|
||||
|
||||
// DRAM tile distributions use STORAGE dimensions (for the storage tensor view)
|
||||
|
||||
Reference in New Issue
Block a user