From 2cc0e3d0199c2638272b068db079b6ca7c1970a6 Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Fri, 30 Jan 2026 03:55:56 -0500 Subject: [PATCH] override base policys vector size with static_assert 4/12/16 bytes --- ...ine_ag_bg_cr_comp_async_default_policy.hpp | 38 ++++++++++++++----- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp index 7d5feecb8f..146d42abb2 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp @@ -24,30 +24,50 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy static constexpr int NXdlPack = 1; // No N packing static constexpr int KXdlPack = 4; // Pack 4 consecutive e8m0 scales in K = 4 bytes = 1 int32 - // Override vector size methods to force 16-byte loads for async buffer operations + // Override vector size methods to ensure compatibility with async buffer operations // Valid sizes for amd_async_buffer_load are 4, 12, or 16 bytes template CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeA() { - // Get packed sizes for A/B using AsDataType = remove_cvref_t; using ADataType = remove_cvref_t{}, AsDataType>>; constexpr index_t APackedSize = numeric_traits>::PackedSize; - // Return number of STORAGE elements to load 16 bytes - constexpr index_t vector_size_for_16_bytes = 16 / sizeof(ADataType) * APackedSize; - return vector_size_for_16_bytes; + + // Call base policy's dynamic vector size calculation + constexpr index_t vector_size = + UniversalGemmBasePolicy:: + template GetVectorSizeA(); + + // Calculate actual byte load size (storage bytes = logical elements / PackedSize * sizeof) + constexpr index_t byte_load_size = vector_size * sizeof(ADataType) / APackedSize; + + // Ensure async buffer load requirements: must be 4, 12, or 16 bytes + static_assert(byte_load_size == 4 || byte_load_size == 12 || byte_load_size == 16, + "Vector load size must be 4, 12, or 16 bytes for async buffer operations"); + + return vector_size; } template CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB() { - // Get packed sizes for A/B using BsDataType = remove_cvref_t; using BDataType = remove_cvref_t{}, BsDataType>>; constexpr index_t BPackedSize = numeric_traits>::PackedSize; - // Return number of STORAGE elements to load 16 bytes - constexpr index_t vector_size_for_16_bytes = 16 / sizeof(BDataType) * BPackedSize; - return vector_size_for_16_bytes; + + // Call base policy's dynamic vector size calculation + constexpr index_t vector_size = + UniversalGemmBasePolicy:: + template GetVectorSizeB(); + + // Calculate actual byte load size (storage bytes = logical elements / PackedSize * sizeof) + constexpr index_t byte_load_size = vector_size * sizeof(BDataType) / BPackedSize; + + // Ensure async buffer load requirements: must be 4, 12, or 16 bytes + static_assert(byte_load_size == 4 || byte_load_size == 12 || byte_load_size == 16, + "Vector load size must be 4, 12, or 16 bytes for async buffer operations"); + + return vector_size; } // DRAM tile distributions use STORAGE dimensions (for the storage tensor view)