mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 20:27:42 +00:00
update async load apis
This commit is contained in:
@@ -1786,52 +1786,28 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
constexpr index_t bytes = sizeof(T) * N;
|
||||
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
|
||||
"wrong! only support in dword, dwordx3, dwordx4");
|
||||
|
||||
if constexpr(oob_conditional_check)
|
||||
{
|
||||
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
|
||||
async_buffer_load<bytes, false>{}(smem,
|
||||
src_wave_buffer_resource,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
0,
|
||||
bool_constant<false>{});
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
smem,
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else
|
||||
{
|
||||
async_buffer_load<bytes, false>{}(smem,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
0,
|
||||
bool_constant<false>{});
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
smem,
|
||||
sizeof(uint32_t),
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
// #else
|
||||
// static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
|
||||
|
||||
// if constexpr(oob_conditional_check)
|
||||
// {
|
||||
// index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
|
||||
// llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
// smem,
|
||||
// sizeof(uint32_t),
|
||||
// v_offset,
|
||||
// src_wave_addr_offset,
|
||||
// src_immediate_addr_offset,
|
||||
// static_cast<index_t>(coherence));
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
// smem,
|
||||
// sizeof(uint32_t),
|
||||
// src_thread_addr_offset,
|
||||
// src_wave_addr_offset,
|
||||
// src_immediate_addr_offset,
|
||||
// static_cast<index_t>(coherence));
|
||||
// }
|
||||
// #endif
|
||||
}
|
||||
|
||||
template <index_t N,
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -49,7 +49,7 @@ struct BaseGemmPipelineAgBgCrCompAsync
|
||||
* It is particularly more efficient for large matrices where M, N, and K are greater than 8K,
|
||||
* even when Compute Version 3's block size is twice that of Compute Version 4.
|
||||
*/
|
||||
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompV4DefaultPolicy>
|
||||
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompAsyncDefaultPolicy>
|
||||
struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAgBgCrCompAsync<Problem>;
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// Default policy for GemmPipelineAgBgCrCompAsync, except the block gemm method, it shares
|
||||
// the same vector size implementation, SmemSize, Global memory tile distribution as the
|
||||
// UniversalGemm Pipeline Policy.
|
||||
// Default policy class should not be templated, put template on
|
||||
// member functions instead.
|
||||
struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
: public UniversalGemmBasePolicy<GemmPipelineAgBgCrCompAsyncDefaultPolicy>
|
||||
{
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackB<Problem>();
|
||||
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<NPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
AccDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user