[CK_TILE] Use read_tr in universal gemm (#2436)

* Use read_tr in universal gemm

* Enable all instances back

* Revert example37 changes

* Resolve comments

* resolve comments 2

* Fix assertion msg

* fix the gemm basic

* change index_t to bool for preshuffle variable

* Solve the comment

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Co-authored-by: AviralGoelAMD <aviral.goel@amd.com>

[ROCm/composable_kernel commit: f1d8ad2818]
This commit is contained in:
Yi DING
2025-07-17 14:56:22 +08:00
committed by GitHub
parent ac0a7f44e1
commit 3bb687d502
26 changed files with 1150 additions and 565 deletions

View File

@@ -15,9 +15,9 @@
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
@@ -29,14 +29,14 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"

View File

@@ -13,9 +13,9 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp"
@@ -44,10 +44,10 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"

View File

@@ -1,10 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace ck_tile {
@@ -15,6 +16,19 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
#if defined(__gfx950__)
constexpr bool is_a_load_tr = std::is_same_v<remove_cvref_t<typename Problem::ALayout>,
tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_load_tr = std::is_same_v<remove_cvref_t<typename Problem::BLayout>,
tensor_layout::gemm::RowMajor>;
#else
constexpr bool is_a_load_tr = false;
constexpr bool is_b_load_tr = false;
#endif
constexpr auto wg_attr_num_access = (is_a_load_tr || is_b_load_tr)
? WGAttrNumAccessEnum::Double
: WGAttrNumAccessEnum::Single;
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
@@ -40,14 +54,34 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
}
#else
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
using WG = WarpGemmMfmaDispatcher<ck_tile::half_t,
ck_tile::half_t,
float,
32,
32,
16,
true,
false,
false,
wg_attr_num_access>;
return make_tuple(WG{}, 4, 1);
#endif
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
using WG = WarpGemmMfmaDispatcher<ck_tile::bf16_t,
ck_tile::bf16_t,
float,
32,
32,
16,
true,
false,
false,
wg_attr_num_access>;
return make_tuple(WG{}, 4, 1);
}
else
{

View File

@@ -218,10 +218,16 @@ struct BlockUniversalGemmAsBsCr
BLdsTile b_warp_tile_;
// C += A * B
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
template <typename CBlockTensor,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
@@ -300,14 +306,23 @@ struct BlockUniversalGemmAsBsCr
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else if constexpr(ALoadTranspose)
{
a_warp_tile_ = load_tile_transpose(a_block_window);
}
else
{
load_tile(a_warp_tile_, a_block_window);
@@ -316,6 +331,10 @@ struct BlockUniversalGemmAsBsCr
{
load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else if constexpr(BLoadTranspose)
{
b_warp_tile_ = load_tile_transpose(b_block_window);
}
else
{
load_tile(b_warp_tile_, b_block_window);
@@ -323,10 +342,16 @@ struct BlockUniversalGemmAsBsCr
}
// C += A * B
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
template <typename CBlockTensor,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
[[maybe_unused]] ASmemBlockWindow& a_block_window,
[[maybe_unused]] BSmemBlockWindow& b_block_window)
[[maybe_unused]] BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
@@ -382,40 +407,73 @@ struct BlockUniversalGemmAsBsCr
static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread;
static constexpr auto ALdsTileDistr =
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
make_static_tile_distribution(MakeABlockDistributionEncode());
static constexpr auto BLdsTileDistr =
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
make_static_tile_distribution(MakeBBlockDistributionEncode());
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
ALdsTile a_warp_tile_;
ALdsTile b_warp_tile_;
BLdsTile b_warp_tile_;
template <index_t KIdx, typename ASmemBlockWindow, typename BSmemBlockWindow>
template <index_t KIdx,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(MakeABlockDistributionEncode());
constexpr auto b_lds_load_tile_distr =
make_static_tile_distribution(MakeBBlockDistributionEncode());
constexpr auto a_lds_load_distr = [&]() {
if constexpr(ALoadTranspose)
return make_static_tile_distribution(typename InputTileDistributionTraits<
decltype(MakeABlockDistributionEncode()),
ADataType>::TransposedDstrEncode{});
else
return make_static_tile_distribution(MakeABlockDistributionEncode());
}();
constexpr auto b_lds_load_distr = [&]() {
if constexpr(BLoadTranspose)
return make_static_tile_distribution(typename InputTileDistributionTraits<
decltype(MakeBBlockDistributionEncode()),
BDataType>::TransposedDstrEncode{});
else
return make_static_tile_distribution(MakeBBlockDistributionEncode());
}();
constexpr auto a_lds_shape = []() {
if constexpr(ALoadTranspose)
return make_tuple(number<KPerInnerLoop>{}, number<GemmTraits::MPerBlock>{});
else
return make_tuple(number<GemmTraits::MPerBlock>{}, number<KPerInnerLoop>{});
}();
constexpr auto b_lds_shape = []() {
if constexpr(BLoadTranspose)
return make_tuple(number<KPerInnerLoop>{}, number<GemmTraits::NPerBlock>{});
else
return make_tuple(number<GemmTraits::NPerBlock>{}, number<KPerInnerLoop>{});
}();
constexpr auto k_idx_offset = KIdx * KPerInnerLoop;
constexpr auto a_offset =
ALoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
constexpr auto b_offset =
BLoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
auto a_lds_gemm_window = make_tile_window(
a_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::MPerBlock>{}, number<KPerInnerLoop>{}),
{0, KIdx * KPerInnerLoop},
a_lds_load_tile_distr);
a_block_window.get_bottom_tensor_view(), a_lds_shape, a_offset, a_lds_load_distr);
auto b_lds_gemm_window = make_tile_window(
b_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::NPerBlock>{}, number<KPerInnerLoop>{}),
{0, KIdx * KPerInnerLoop},
b_lds_load_tile_distr);
b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else if constexpr(ALoadTranspose)
{
a_warp_tile_ = load_tile_transpose(a_lds_gemm_window);
}
else
{
load_tile(a_warp_tile_, a_lds_gemm_window);
@@ -424,6 +482,10 @@ struct BlockUniversalGemmAsBsCr
{
load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else if constexpr(BLoadTranspose)
{
b_warp_tile_ = load_tile_transpose(b_lds_gemm_window);
}
else
{
load_tile(b_warp_tile_, b_lds_gemm_window);
@@ -431,10 +493,16 @@ struct BlockUniversalGemmAsBsCr
}
// C += A * B
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
template <typename CBlockTensor,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
@@ -442,7 +510,7 @@ struct BlockUniversalGemmAsBsCr
// hot loop:
static_for<0, KRepeat, 1>{}([&](auto kIter) {
LocalPrefetch<kIter.value>(a_block_window, b_block_window);
LocalPrefetch<kIter.value>(a_block_window, b_block_window, a_load_tr, b_load_tr);
__builtin_amdgcn_sched_barrier(0);
// NOTE: Synchronize threads in a workgroup at the start of each MAC
// cluster, but except the first, as we can shorten non-MAC cluster a bit
@@ -543,29 +611,45 @@ struct BlockUniversalGemmAsBsCr
return c_block_tensor;
}
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window);
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
}
// C += A * B
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
template <typename CBlockTensor,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
block_gemm_impl_(c_block_tensor, a_block_window, b_block_window);
block_gemm_impl_(c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr);
}
// C = A * B
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE auto operator()(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
auto c_block_tensor = MakeCBlockTile();
block_gemm_impl_(c_block_tensor, a_block_window, b_block_window);
block_gemm_impl_(c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr);
return c_block_tensor;
}

View File

@@ -20,6 +20,13 @@ struct GemmPipelineAgBgCrImplBase
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
#if defined(__gfx950__)
static constexpr bool is_a_load_tr = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
static constexpr bool is_b_load_tr = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
#else
static constexpr bool is_a_load_tr = false;
static constexpr bool is_b_load_tr = false;
#endif
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
@@ -50,11 +57,15 @@ struct GemmPipelineAgBgCrImplBase
store_tile(lds_tile_window, block_tile_tmp);
}
template <typename DstBlockTile, typename SrcTileWindow>
template <typename DstBlockTile, typename SrcTileWindow, bool LoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(DstBlockTile& dst_block_tile,
const SrcTileWindow& lds_tile_window) const
const SrcTileWindow& lds_tile_window,
bool_constant<LoadTranspose> = {}) const
{
load_tile(dst_block_tile, lds_tile_window);
if constexpr(LoadTranspose)
dst_block_tile = load_tile_transpose(lds_tile_window);
else
load_tile(dst_block_tile, lds_tile_window);
}
CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
@@ -96,14 +107,25 @@ struct GemmPipelineAgBgCrImplBase
Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store
auto a_copy_lds_window = make_tile_window(
a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto a_lds_shape = []() {
if constexpr(is_a_load_tr)
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
else
return make_tuple(number<MPerBlock>{}, number<KPerBlock>{});
}();
auto a_copy_lds_window = make_tile_window(a_lds_block_view, a_lds_shape, {0, 0});
auto a_lds_load_tile_distr = []() {
if constexpr(is_a_load_tr)
return make_static_tile_distribution(
typename InputTileDistributionTraits<
typename ALdsLoadTileDistr::DstrEncode,
typename Problem::ADataType>::TransposedDstrEncode{});
else
return ALdsLoadTileDistr{};
}();
auto a_lds_gemm_window =
make_tile_window(a_lds_block_view,
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
ALdsLoadTileDistr{});
make_tile_window(a_lds_block_view, a_lds_shape, {0, 0}, a_lds_load_tile_distr);
return make_tuple(std::move(a_copy_dram_window),
std::move(a_copy_lds_window),
@@ -130,14 +152,25 @@ struct GemmPipelineAgBgCrImplBase
// TODO: Do we really need those two tile windows???
// They're exactly same...
// B LDS tile window for store
auto b_copy_lds_window = make_tile_window(
b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto b_lds_shape = []() {
if constexpr(is_b_load_tr)
return make_tuple(number<KPerBlock>{}, number<NPerBlock>{});
else
return make_tuple(number<NPerBlock>{}, number<KPerBlock>{});
}();
auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0});
auto b_lds_load_tile_distr = []() {
if constexpr(is_b_load_tr)
return make_static_tile_distribution(
typename InputTileDistributionTraits<
typename BLdsLoadTileDistr::DstrEncode,
typename Problem::BDataType>::TransposedDstrEncode{});
else
return BLdsLoadTileDistr{};
}();
auto b_lds_gemm_window =
make_tile_window(b_lds_block_view,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
BLdsLoadTileDistr{});
make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}, b_lds_load_tile_distr);
return make_tuple(std::move(b_copy_dram_window),
std::move(b_copy_lds_window),

View File

@@ -153,6 +153,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
Problem::TailNum; // Base::GetBlockLoopTailNum(Problem::num_loop);
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
using Base::PrefetchStages;
using Base::UsePersistentKernel;
@@ -467,7 +470,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -478,7 +481,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
@@ -494,7 +497,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
__builtin_amdgcn_sched_barrier(0);
@@ -506,7 +510,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
block_sync_lds();
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -517,7 +521,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
@@ -536,7 +540,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
@@ -578,7 +583,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
// __builtin_amdgcn_sched_barrier(0);

View File

@@ -141,6 +141,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
@@ -305,17 +308,23 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
auto a_copy_lds_window0 = make_tile_window(
a_lds_block0, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
constexpr auto a_lds_shape = []() {
if constexpr(is_a_load_tr_v())
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
else
return make_tuple(number<MPerBlock>{}, number<KPerBlock>{});
}();
auto a_copy_lds_window0 = make_tile_window(a_lds_block0, a_lds_shape, {0, 0});
auto a_copy_lds_window1 = make_tile_window(a_lds_block1, a_lds_shape, {0, 0});
auto a_copy_lds_window1 = make_tile_window(
a_lds_block1, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto b_copy_lds_window0 = make_tile_window(
b_lds_block0, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto b_copy_lds_window1 = make_tile_window(
b_lds_block1, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
constexpr auto b_lds_shape = []() {
if constexpr(is_b_load_tr_v())
return make_tuple(number<KPerBlock>{}, number<NPerBlock>{});
else
return make_tuple(number<NPerBlock>{}, number<KPerBlock>{});
}();
auto b_copy_lds_window0 = make_tile_window(b_lds_block0, b_lds_shape, {0, 0});
auto b_copy_lds_window1 = make_tile_window(b_lds_block1, b_lds_shape, {0, 0});
// Block GEMM
auto block_gemm = BlockGemm();
@@ -325,7 +334,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -336,7 +345,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
{
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
@@ -354,51 +363,53 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
block_sync_lds();
constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution(
BlockGemm::MakeABlockDistributionEncode())){};
constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(
BlockGemm::MakeBBlockDistributionEncode())){};
constexpr auto ALdsTileDistr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto BLdsTileDistr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
ALdsTile a_block_tile0, a_block_tile1;
BLdsTile b_block_tile0, b_block_tile1;
ALdsTile a_block_tile0;
ALdsTile a_block_tile1;
BLdsTile b_block_tile0;
BLdsTile b_block_tile1;
constexpr auto a_lds_input_tile_distr = [&]() {
if constexpr(is_a_load_tr_v())
return make_static_tile_distribution(
typename InputTileDistributionTraits<
decltype(BlockGemm::MakeABlockDistributionEncode()),
typename Problem::ADataType>::TransposedDstrEncode{});
else
return ALdsTileDistr;
}();
constexpr auto b_lds_input_tile_distr = [&]() {
if constexpr(is_b_load_tr_v())
return make_static_tile_distribution(
typename InputTileDistributionTraits<
decltype(BlockGemm::MakeBBlockDistributionEncode()),
typename Problem::BDataType>::TransposedDstrEncode{});
else
return BLdsTileDistr;
}();
auto a_lds_ld_window0 =
make_tile_window(a_lds_block0,
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
ALdsTileDistr);
make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
auto a_lds_ld_window1 =
make_tile_window(a_lds_block1,
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
ALdsTileDistr);
make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
auto b_lds_ld_window0 =
make_tile_window(b_lds_block0,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
BLdsTileDistr);
make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
auto b_lds_ld_window1 =
make_tile_window(b_lds_block1,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
BLdsTileDistr);
make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
static_assert(
!(is_tile_window_linear_v<decltype(a_lds_ld_window0)>)&&!(is_tile_window_linear_v<decltype(a_lds_ld_window1)>)&&!(
is_tile_window_linear_v<
decltype(b_lds_ld_window0)>)&&!(is_tile_window_linear_v<decltype(b_lds_ld_window1)>),
"LDS windows must not be linear");
static_assert(!is_tile_window_linear_v<decltype(a_lds_ld_window0)> &&
!is_tile_window_linear_v<decltype(a_lds_ld_window1)> &&
!is_tile_window_linear_v<decltype(b_lds_ld_window0)> &&
!is_tile_window_linear_v<decltype(b_lds_ld_window1)>,
"LDS windows must not be linear");
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -409,7 +420,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
{
Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
@@ -433,10 +444,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
// ping
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -448,7 +459,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
Base::LocalPrefill(
a_copy_lds_window0, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
@@ -473,10 +484,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
// pong
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -488,7 +499,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
Base::LocalPrefill(
a_copy_lds_window1, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
@@ -521,9 +532,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
// 3
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
if constexpr(is_a_col_major)
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -534,7 +545,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
{
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
@@ -550,8 +561,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
// 2
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
}
// 1
@@ -565,8 +576,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
// 2
{
block_sync_lds();
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
static_for<0, 8, 1>{}([&](auto i) {
ignore = i;

View File

@@ -21,15 +21,27 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
// using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr bool single_load_tr_length =
(DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType)) ==
(WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size());
constexpr auto wg_attr_num_access =
((is_a_load_tr<Problem> || is_b_load_tr<Problem>)&&!single_load_tr_length)
? WGAttrNumAccessEnum::Double
: WGAttrNumAccessEnum::Single;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType, // AccDataType
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
Problem::TransposeC,
false,
false,
wg_attr_num_access>;
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,

View File

@@ -196,6 +196,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
@@ -272,10 +275,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
auto& b_lds_block = ab_lds_blocks.at(I1{});
// Tile distribution for load from lds
constexpr auto a_lds_load_tile_distr = decltype(make_static_tile_distribution(
BlockGemm::MakeABlockDistributionEncode())){};
constexpr auto b_lds_load_tile_distr = decltype(make_static_tile_distribution(
BlockGemm::MakeBBlockDistributionEncode())){};
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto b_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
// A DRAM tile window for load
// A LDS tile window for store
@@ -332,7 +335,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -343,7 +346,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
@@ -373,12 +376,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{
static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -394,7 +398,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
@@ -427,12 +431,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -445,7 +450,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
a_block_tiles.get(number<prefetch_idx>{}),
a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
@@ -461,14 +466,16 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
});
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
};
if constexpr(TailNum == TailNumber::One)
{
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
else if constexpr(TailNum == TailNumber::Two)
@@ -558,10 +565,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
auto& b_lds_block = ab_lds_blocks.at(I1{});
// Tile distribution for load from lds
constexpr auto a_lds_load_tile_distr = decltype(make_static_tile_distribution(
BlockGemm::MakeABlockDistributionEncode())){};
constexpr auto b_lds_load_tile_distr = decltype(make_static_tile_distribution(
BlockGemm::MakeBBlockDistributionEncode())){};
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto b_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
// A DRAM tile window for load
// A LDS tile window for store
@@ -617,7 +624,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -628,7 +635,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
@@ -658,10 +665,14 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{
static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
block_sync_lds();
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile,
a_lds_gemm_window,
b_lds_gemm_window,
is_a_load_tr_v,
is_b_load_tr_v);
// no second block_sync_lds because it's interwave
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -677,7 +688,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
@@ -709,10 +720,14 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
auto HotLoopTail = [&](auto tail_num) {
static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
block_sync_lds();
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile,
a_lds_gemm_window,
b_lds_gemm_window,
is_a_load_tr_v,
is_b_load_tr_v);
// no second block_sync_lds because it's interwave
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -725,7 +740,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
a_block_tiles.get(number<prefetch_idx>{}),
a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
@@ -741,13 +756,21 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
});
block_sync_lds();
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile,
a_lds_gemm_window,
b_lds_gemm_window,
is_a_load_tr_v,
is_b_load_tr_v);
};
if constexpr(TailNum == TailNumber::One)
{
block_sync_lds();
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile,
a_lds_gemm_window,
b_lds_gemm_window,
is_a_load_tr_v,
is_b_load_tr_v);
}
else if constexpr(TailNum == TailNumber::Two)
{

View File

@@ -47,6 +47,8 @@ struct GemmPipelineAGmemBGmemCRegV1
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool Preshuffle = Problem::Preshuffle;
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
static constexpr index_t kLdsAlignmentInBytes = 16;

View File

@@ -49,6 +49,9 @@ struct GemmPipelineProblemBase
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
// In the base situation, the Preshuffle setting should be false.
static constexpr bool Preshuffle = false;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off

View File

@@ -12,6 +12,20 @@ namespace ck_tile {
template <typename Derived>
struct UniversalGemmBasePolicy
{
#if defined(__gfx950__)
template <typename Problem>
static constexpr bool is_a_load_tr =
std::is_same_v<remove_cvref_t<typename Problem::ALayout>, tensor_layout::gemm::ColumnMajor>;
template <typename Problem>
static constexpr bool is_b_load_tr =
std::is_same_v<remove_cvref_t<typename Problem::BLayout>, tensor_layout::gemm::RowMajor>;
#else
template <typename Problem>
static constexpr bool is_a_load_tr = false;
template <typename Problem>
static constexpr bool is_b_load_tr = false;
#endif
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
@@ -22,51 +36,65 @@ struct UniversalGemmBasePolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr auto DataTypeSize = sizeof(ADataType);
constexpr auto MLdsLayer =
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
if constexpr(is_a_load_tr<Problem>)
{
// TODO: better lds descriptor for performance
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
make_tuple(number<KPerBlock>{}, number<MPerBlock>{}),
make_tuple(number<MPerBlock>{}, number<1>{}),
number<MPerBlock>{},
number<1>{});
return a_lds_block_desc_0;
}
else
{
constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack * MLdsLayer>{},
number<MPerBlock / MLdsLayer>{},
number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock * MLdsLayer>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto DataTypeSize = sizeof(ADataType);
constexpr auto MLdsLayer =
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer>{},
number<KPerBlock / KPack * MLdsLayer>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack * MLdsLayer>{},
number<MPerBlock / MLdsLayer>{},
number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock * MLdsLayer>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(
make_tuple(number<MLdsLayer>{}, number<KPerBlock / KPack>{})),
make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer>{},
number<KPerBlock / KPack * MLdsLayer>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(
make_tuple(number<MLdsLayer>{}, number<KPerBlock / KPack>{})),
make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
return a_lds_block_desc;
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc;
}
}
/**
@@ -78,14 +106,24 @@ struct UniversalGemmBasePolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
// using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
#if 1
// if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
if constexpr(is_b_load_tr<Problem>)
{
// TODO: better lds descriptor for performance
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( //
make_tuple(number<KPerBlock>{}, number<NPerBlock>{}),
make_tuple(number<NPerBlock>{}, number<1>{}),
number<NPerBlock>{},
number<1>{});
return b_lds_block_desc_0;
}
else
// else if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t KPack = GetSmemPackB<Problem>();
constexpr auto BK0 = number<KPerBlock / KPack>{};
@@ -584,8 +622,18 @@ struct UniversalGemmPipelineAgBgCrPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr index_t vector_size = DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType);
constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size();
constexpr auto wg_attr_num_access =
!(is_a_load_tr<Problem> || is_b_load_tr<Problem>) ? WGAttrNumAccessEnum::Single
: vector_size == thread_elements ? WGAttrNumAccessEnum::Single
: vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
: WGAttrNumAccessEnum::Invalid;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
@@ -594,7 +642,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
WarpTile::at(I2),
Problem::TransposeC,
false,
Problem::UseStructuredSparsity>;
Problem::UseStructuredSparsity,
wg_attr_num_access>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,

View File

@@ -84,7 +84,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
using WarpTile = remove_cvref_t<typename BlockGemmShape::WarpTile>;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr index_t Preshuffle = Problem::Preshuffle;
static constexpr bool Preshuffle = Problem::Preshuffle;
using Base::UsePersistentKernel;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()

View File

@@ -21,22 +21,29 @@ using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
2,
AttrNumAccess>>;
#endif
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>>>;
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
2>>;
2,
AttrNumAccess>>;
#endif
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
@@ -56,25 +63,33 @@ using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution =
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
2,
AttrNumAccess>>;
#endif
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>>>;
WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
2>>;
2,
AttrNumAccess>>;
#endif
#if defined(__gfx950__)
@@ -123,22 +138,29 @@ using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
2,
AttrNumAccess>>;
#endif
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>>>;
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
2>>;
2,
AttrNumAccess>>;
#endif
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
@@ -159,25 +181,33 @@ using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution =
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
2,
AttrNumAccess>>;
#endif
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>>>;
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
2>>;
2,
AttrNumAccess>>;
#endif
#if defined(__gfx950__)
@@ -247,17 +277,25 @@ using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfma<
using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfma<
WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8<WGAttrCtlEnum::Default_>>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8<WGAttrCtlEnum::Default_>>>;
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_fp8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8<WGAttrCtlEnum::Default_>>>;
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_bf8_fp8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8<WGAttrCtlEnum::Default_>>>;
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_bf8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8<WGAttrCtlEnum::Default_>>>;
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -8,10 +8,22 @@
namespace ck_tile {
template <typename WarpGemmAttributeMfmaImpl_>
// Number of groups of consecutive elements to fill in a ABKLane
enum class WGAttrNumAccessEnum
{
Single = 1,
Double = 2,
Quad = 4,
Invalid = -1
};
template <typename WarpGemmAttributeMfmaImpl_,
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
struct WarpGemmAtrributeMfma
{
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
static constexpr auto AttrNumAccess = AttrNumAccess_;
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
using ADataType = typename Impl::ADataType;
using BDataType = typename Impl::BDataType;
@@ -31,21 +43,35 @@ struct WarpGemmAtrributeMfma
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
using BWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
template <index_t kMNLane>
static constexpr auto get_warp_dstr_encoding()
{
if constexpr(AttrNumAccessV == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else
{
static_assert(kKPerThread % AttrNumAccessV == 0,
"kKPerThread must be divisible by NumAccess");
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>,
sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
}
using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
using CWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
@@ -73,12 +99,16 @@ struct WarpGemmAtrributeMfma
}
};
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter>
template <typename WarpGemmAttributeMfmaImpl_,
index_t kKIter,
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
struct WarpGemmAtrributeMfmaIterateK
{
static_assert(kKIter > 0, "wrong!");
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
static constexpr auto AttrNumAccess = AttrNumAccess_;
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
using ADataType = typename Impl::ADataType;
using BDataType = typename Impl::BDataType;
@@ -104,17 +134,37 @@ struct WarpGemmAtrributeMfmaIterateK
{
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
if constexpr(AttrNumAccessV == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else
{
static_assert(kKPerThread % AttrNumAccessV == 0,
"kKPerThread must be divisible by NumAccess");
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane>,
sequence<AttrNumAccessV,
Impl::kABKLane,
Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
static_assert(AttrNumAccessV == 1,
"Multiple access is not supported when using multi-block");
// each M blocks share the same data
return tile_distribution_encoding<
sequence<Impl::kBNBlock>,
@@ -127,6 +177,8 @@ struct WarpGemmAtrributeMfmaIterateK
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
static_assert(AttrNumAccessV == 1,
"Multiple access is not supported when using multi-block");
// single block to multi-block thread mapping
return tile_distribution_encoding<
sequence<>,
@@ -143,17 +195,38 @@ struct WarpGemmAtrributeMfmaIterateK
{
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
if constexpr(AttrNumAccessV == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else
{
static_assert(kKPerThread % AttrNumAccessV == 0,
"kKPerThread must be divisible by NumAccess");
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>,
sequence<AttrNumAccessV,
Impl::kABKLane,
Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
static_assert(AttrNumAccessV == 1,
"Multiple access is not supported when using multi-block");
// single block to multi-block thread mapping
return tile_distribution_encoding<
sequence<>,
@@ -166,6 +239,8 @@ struct WarpGemmAtrributeMfmaIterateK
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
static_assert(AttrNumAccessV == 1,
"Multiple access is not supported when using multi-block");
// each N blocks share the same data
return tile_distribution_encoding<
sequence<Impl::kAMBlock>,
@@ -289,10 +364,13 @@ struct WarpGemmAtrributeMfmaIterateK
}
};
template <typename WarpGemmAttributeMfmaImpl_>
template <typename WarpGemmAttributeMfmaImpl_,
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
struct WarpGemmAtrributeMfmaTransposedCDistribution
{
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
static constexpr auto AttrNumAccess = AttrNumAccess_;
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
using ADataType = typename Impl::BDataType;
using BDataType = typename Impl::ADataType;
@@ -312,21 +390,35 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
using BWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
template <index_t kMNLane>
static constexpr auto get_warp_dstr_encoding()
{
if constexpr(AttrNumAccessV == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else
{
static_assert(kKPerThread % AttrNumAccessV == 0,
"kKPerThread must be divisible by NumAccess");
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>,
sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
}
using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
using CWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
@@ -450,10 +542,13 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
}
};
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter>
template <typename WarpGemmAttributeMfmaImpl_,
index_t kKIter,
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
{
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
static constexpr auto AttrNumAccess = AttrNumAccess_;
// swap A and B
using ADataType = typename Impl::BDataType;
@@ -478,80 +573,14 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
{
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
// single block to multi-block thread mapping
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNBlock, Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<1, 2, 1>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
// each N blocks share the same data
return tile_distribution_encoding<
sequence<Impl::kAMBlock>,
tuple<sequence<Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<0, 2, 1>>,
tuple<sequence<0, 0, 0>>,
sequence<2>,
sequence<1>>{};
}
return WarpGemmAtrributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::
get_bwarp_dstr_encoding();
}
CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
{
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
// each M blocks share the same data
return tile_distribution_encoding<
sequence<Impl::kBNBlock>,
tuple<sequence<Impl::kAMLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<0, 2, 1>>,
tuple<sequence<0, 0, 0>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
// single block to multi-block thread mapping
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMBlock, Impl::kAMLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<1, 2, 1>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<1>>{};
}
return WarpGemmAtrributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::
get_awarp_dstr_encoding();
}
CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()

View File

@@ -16,8 +16,9 @@ template <typename AType,
index_t NPerWave,
index_t KPerWave,
bool TransposeC,
bool SwizzleA = false,
bool UseStructuredSparsity = false>
bool SwizzleA = false,
bool UseStructuredSparsity = false,
WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
struct WarpGemmMfmaDispatcher;
// clang-format off
@@ -25,12 +26,20 @@ struct WarpGemmMfmaDispatcher;
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, false, WGAttrNumAccessEnum::Double> {
using Type = WarpGemmMfmaF16F16F32M32N32K16<WGAttrNumAccessEnum::Double>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true, false, false, WGAttrNumAccessEnum::Double> {
using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, false, WGAttrNumAccessEnum::Double> {
using Type = WarpGemmMfmaF16F16F32M16N16K32<WGAttrNumAccessEnum::Double>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true, false, false, WGAttrNumAccessEnum::Double> {
using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaF16F16F32M4N64K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaF16F16F32M64N4K16; };
@@ -46,12 +55,20 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, false, false, WGAttrNumAccessEnum::Double> {
using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<WGAttrNumAccessEnum::Double>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true, false, false, WGAttrNumAccessEnum::Double> {
using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false, false, false, WGAttrNumAccessEnum::Double> {
using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<WGAttrNumAccessEnum::Double>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true, false, false, WGAttrNumAccessEnum::Double> {
using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; };
@@ -80,10 +97,18 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float,
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<WGAttrNumAccessEnum::Quad>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<WGAttrNumAccessEnum::Quad>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<WGAttrNumAccessEnum::Quad>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<WGAttrNumAccessEnum::Quad>; };
// int8
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
@@ -102,8 +127,9 @@ template <typename AType,
index_t NPerWave,
index_t KPerWave,
bool TransposeC,
bool SwizzleA = false,
bool UseStructuredSparsity = false>
bool SwizzleA = false,
bool UseStructuredSparsity = false,
WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher<AType,
BType,
AccType,
@@ -112,6 +138,7 @@ using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher<AType,
KPerWave,
TransposeC,
SwizzleA,
UseStructuredSparsity>::Type;
UseStructuredSparsity,
AttrNumAccess>::Type;
} // namespace ck_tile