From 3500259bdc59f1dde58bcde653d1aa0f2a833f42 Mon Sep 17 00:00:00 2001 From: joye Date: Tue, 20 May 2025 13:26:49 +0800 Subject: [PATCH] update async load apis --- .../core/arch/amd_buffer_addressing.hpp | 54 ++++---------- .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 4 +- ...ine_ag_bg_cr_comp_async_default_policy.hpp | 71 +++++++++++++++++++ 3 files changed, 88 insertions(+), 41 deletions(-) create mode 100644 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 3586d4b5ca..ac051e3608 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1786,52 +1786,28 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, constexpr index_t bytes = sizeof(T) * N; static_assert(bytes == 4 || bytes == 12 || bytes == 16, "wrong! only support in dword, dwordx3, dwordx4"); + if constexpr(oob_conditional_check) { index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2]; - async_buffer_load{}(smem, - src_wave_buffer_resource, - v_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - 0, - bool_constant{}); + llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource, + smem, + bytes, + v_offset, + src_wave_addr_offset, + src_immediate_addr_offset, + static_cast(coherence)); } else { - async_buffer_load{}(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - 0, - bool_constant{}); + llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource, + smem, + sizeof(uint32_t), + src_thread_addr_offset, + src_wave_addr_offset, + src_immediate_addr_offset, + static_cast(coherence)); } - // #else - // static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); - - // if constexpr(oob_conditional_check) - // { - // index_t v_offset = flag ? 
src_thread_addr_offset : src_wave_buffer_resource[2]; - // llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource, - // smem, - // sizeof(uint32_t), - // v_offset, - // src_wave_addr_offset, - // src_immediate_addr_offset, - // static_cast(coherence)); - // } - // else - // { - // llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource, - // smem, - // sizeof(uint32_t), - // src_thread_addr_offset, - // src_wave_addr_offset, - // src_immediate_addr_offset, - // static_cast(coherence)); - // } - // #endif } template +template struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync { using Base = BaseGemmPipelineAgBgCrCompAsync; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp new file mode 100644 index 0000000000..c0c379c523 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" + +namespace ck_tile { +// Default policy for GemmPipelineAGmemBGmemCregComputeV4, except the block gemm method, it shares +// the same vector size implementation, SmemSize, Global memory tile distribution as the +// UniversalGemm Pipeline Policy. +// Default policy class should not be templated, put template on +// member functions instead. 
+struct GemmPipelineAgBgCrCompAsyncDefaultPolicy + : public UniversalGemmBasePolicy +{ + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetSmemPackA(); + + return make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetSmemPackB(); + + return make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using AccDataType = float; + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmMfmaDispatcher; + using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } +}; +} // namespace ck_tile