update async load apis

This commit is contained in:
joye
2025-05-20 13:26:49 +08:00
parent 4e7f427b88
commit 3500259bdc
3 changed files with 88 additions and 41 deletions

View File

@@ -1786,52 +1786,28 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
constexpr index_t bytes = sizeof(T) * N;
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
"wrong! only support in dword, dwordx3, dwordx4");
if constexpr(oob_conditional_check)
{
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
async_buffer_load<bytes, false>{}(smem,
src_wave_buffer_resource,
v_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
0,
bool_constant<false>{});
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
smem,
bytes,
v_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
else
{
async_buffer_load<bytes, false>{}(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
0,
bool_constant<false>{});
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
smem,
sizeof(uint32_t),
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
// #else
// static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
// if constexpr(oob_conditional_check)
// {
// index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
// llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
// smem,
// sizeof(uint32_t),
// v_offset,
// src_wave_addr_offset,
// src_immediate_addr_offset,
// static_cast<index_t>(coherence));
// }
// else
// {
// llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
// smem,
// sizeof(uint32_t),
// src_thread_addr_offset,
// src_wave_addr_offset,
// src_immediate_addr_offset,
// static_cast<index_t>(coherence));
// }
// #endif
}
template <index_t N,

View File

@@ -4,7 +4,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp"
namespace ck_tile {
@@ -49,7 +49,7 @@ struct BaseGemmPipelineAgBgCrCompAsync
* It is particularly more efficient for large matrices where M, N, and K are greater than 8K,
* even when Compute Version 3's block size is twice that of Compute Version 4.
*/
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompV4DefaultPolicy>
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompAsyncDefaultPolicy>
struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompAsync<Problem>;

View File

@@ -0,0 +1,71 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
namespace ck_tile {
// Default policy for GemmPipelineAgBgCrCompAsync. Except for the block gemm method and the A/B
// LDS block descriptors defined here, it shares the same vector size implementation, SmemSize,
// and global memory tile distribution as the UniversalGemm Pipeline Policy.
// Default policy class should not be templated, put template on
// member functions instead.
struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
    : public UniversalGemmBasePolicy<GemmPipelineAgBgCrCompAsyncDefaultPolicy>
{
    // Builds the LDS block descriptor for the A tile.
    // Lengths are (kK / KPack, kM, KPack) and strides are (KPack, kK, 1): neighbouring
    // entries along the kM dimension are kK elements apart, and the KPack-wide groups of
    // one such entry are laid out back to back.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
    {
        constexpr index_t kTileM = Problem::BlockGemmShape::kM;
        constexpr index_t kTileK = Problem::BlockGemmShape::kK;
        constexpr index_t kPackA = GetSmemPackA<Problem>();

        constexpr auto a_lengths =
            make_tuple(number<kTileK / kPackA>{}, number<kTileM>{}, number<kPackA>{});
        constexpr auto a_strides = make_tuple(number<kPackA>{}, number<kTileK>{}, number<1>{});

        // NOTE(review): the two trailing arguments are presumably the guaranteed vector
        // length / stride of the last dimension -- confirm against make_naive_tensor_descriptor.
        return make_naive_tensor_descriptor(a_lengths, a_strides, number<kPackA>{}, number<1>{});
    }

    // Builds the LDS block descriptor for the B tile.
    // Same construction as the A descriptor, with kN in place of kM and the B pack size:
    // lengths (kK / KPack, kN, KPack), strides (KPack, kK, 1).
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
    {
        constexpr index_t kTileN = Problem::BlockGemmShape::kN;
        constexpr index_t kTileK = Problem::BlockGemmShape::kK;
        constexpr index_t kPackB = GetSmemPackB<Problem>();

        constexpr auto b_lengths =
            make_tuple(number<kTileK / kPackB>{}, number<kTileN>{}, number<kPackB>{});
        constexpr auto b_strides = make_tuple(number<kPackB>{}, number<kTileK>{}, number<1>{});

        // NOTE(review): the two trailing arguments are presumably the guaranteed vector
        // length / stride of the last dimension -- confirm against make_naive_tensor_descriptor.
        return make_naive_tensor_descriptor(b_lengths, b_strides, number<kPackB>{}, number<1>{});
    }

    // Returns a default-constructed block gemm object for this pipeline.
    // The warp gemm is selected by WarpGemmMfmaDispatcher from the problem's A/B data types,
    // a float accumulator, the three WarpTile extents and Problem::TransposeC; the result is
    // wrapped in BlockGemmARegBRegCRegV1 through its custom policy.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
    {
        using Shape = typename Problem::BlockGemmShape;
        using ATy   = typename Problem::ADataType;
        using BTy   = typename Problem::BDataType;
        using CTy   = typename Problem::CDataType;

        // Accumulation type is fixed to float regardless of the problem's data types.
        using AccTy = float;

        using WarpLayout = typename Shape::BlockWarps;
        using WarpShape  = typename Shape::WarpTile;

        using SelectedWarpGemm = WarpGemmMfmaDispatcher<ATy,
                                                        BTy,
                                                        AccTy,
                                                        WarpShape::at(I0),
                                                        WarpShape::at(I1),
                                                        WarpShape::at(I2),
                                                        Problem::TransposeC>;

        using CustomPolicy =
            BlockGemmARegBRegCRegV1CustomPolicy<ATy, BTy, CTy, WarpLayout, SelectedWarpGemm>;

        using BlockGemm = BlockGemmARegBRegCRegV1<Problem, CustomPolicy>;
        return BlockGemm{};
    }
};
} // namespace ck_tile