mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_TILE] Use read_tr in universal gemm (#2436)
* Use read_tr in universal gemm * Enable all instances back * Revert example37 changes * Resolve comments * resolve comments 2 * Fix assertion msg * fix the gemm basic * change index_t to bool for preshuffle variable * Solve the comment --------- Co-authored-by: Thomas Ning <Thomas.Ning@amd.com> Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com> Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Co-authored-by: AviralGoelAMD <aviral.goel@amd.com>
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -15,6 +16,19 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
constexpr bool is_a_load_tr = std::is_same_v<remove_cvref_t<typename Problem::ALayout>,
|
||||
tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_load_tr = std::is_same_v<remove_cvref_t<typename Problem::BLayout>,
|
||||
tensor_layout::gemm::RowMajor>;
|
||||
#else
|
||||
constexpr bool is_a_load_tr = false;
|
||||
constexpr bool is_b_load_tr = false;
|
||||
#endif
|
||||
constexpr auto wg_attr_num_access = (is_a_load_tr || is_b_load_tr)
|
||||
? WGAttrNumAccessEnum::Double
|
||||
: WGAttrNumAccessEnum::Single;
|
||||
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
@@ -40,14 +54,34 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
|
||||
}
|
||||
#else
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
using WG = WarpGemmMfmaDispatcher<ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
float,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
wg_attr_num_access>;
|
||||
return make_tuple(WG{}, 4, 1);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
using WG = WarpGemmMfmaDispatcher<ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
float,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
wg_attr_num_access>;
|
||||
return make_tuple(WG{}, 4, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -218,10 +218,16 @@ struct BlockUniversalGemmAsBsCr
|
||||
BLdsTile b_warp_tile_;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
template <typename CBlockTensor,
|
||||
typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> = {},
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
@@ -300,14 +306,23 @@ struct BlockUniversalGemmAsBsCr
|
||||
ALdsTile a_warp_tile_;
|
||||
BLdsTile b_warp_tile_;
|
||||
|
||||
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
template <typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> = {},
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else if constexpr(ALoadTranspose)
|
||||
{
|
||||
a_warp_tile_ = load_tile_transpose(a_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(a_warp_tile_, a_block_window);
|
||||
@@ -316,6 +331,10 @@ struct BlockUniversalGemmAsBsCr
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else if constexpr(BLoadTranspose)
|
||||
{
|
||||
b_warp_tile_ = load_tile_transpose(b_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(b_warp_tile_, b_block_window);
|
||||
@@ -323,10 +342,16 @@ struct BlockUniversalGemmAsBsCr
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
template <typename CBlockTensor,
|
||||
typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
[[maybe_unused]] ASmemBlockWindow& a_block_window,
|
||||
[[maybe_unused]] BSmemBlockWindow& b_block_window)
|
||||
[[maybe_unused]] BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> = {},
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
@@ -382,40 +407,73 @@ struct BlockUniversalGemmAsBsCr
|
||||
static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread;
|
||||
|
||||
static constexpr auto ALdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
|
||||
make_static_tile_distribution(MakeABlockDistributionEncode());
|
||||
static constexpr auto BLdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
make_static_tile_distribution(MakeBBlockDistributionEncode());
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_warp_tile_;
|
||||
ALdsTile b_warp_tile_;
|
||||
BLdsTile b_warp_tile_;
|
||||
|
||||
template <index_t KIdx, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
template <index_t KIdx,
|
||||
typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> = {},
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(MakeBBlockDistributionEncode());
|
||||
constexpr auto a_lds_load_distr = [&]() {
|
||||
if constexpr(ALoadTranspose)
|
||||
return make_static_tile_distribution(typename InputTileDistributionTraits<
|
||||
decltype(MakeABlockDistributionEncode()),
|
||||
ADataType>::TransposedDstrEncode{});
|
||||
else
|
||||
return make_static_tile_distribution(MakeABlockDistributionEncode());
|
||||
}();
|
||||
constexpr auto b_lds_load_distr = [&]() {
|
||||
if constexpr(BLoadTranspose)
|
||||
return make_static_tile_distribution(typename InputTileDistributionTraits<
|
||||
decltype(MakeBBlockDistributionEncode()),
|
||||
BDataType>::TransposedDstrEncode{});
|
||||
else
|
||||
return make_static_tile_distribution(MakeBBlockDistributionEncode());
|
||||
}();
|
||||
constexpr auto a_lds_shape = []() {
|
||||
if constexpr(ALoadTranspose)
|
||||
return make_tuple(number<KPerInnerLoop>{}, number<GemmTraits::MPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<GemmTraits::MPerBlock>{}, number<KPerInnerLoop>{});
|
||||
}();
|
||||
constexpr auto b_lds_shape = []() {
|
||||
if constexpr(BLoadTranspose)
|
||||
return make_tuple(number<KPerInnerLoop>{}, number<GemmTraits::NPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<GemmTraits::NPerBlock>{}, number<KPerInnerLoop>{});
|
||||
}();
|
||||
constexpr auto k_idx_offset = KIdx * KPerInnerLoop;
|
||||
constexpr auto a_offset =
|
||||
ALoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
|
||||
constexpr auto b_offset =
|
||||
BLoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
|
||||
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<GemmTraits::MPerBlock>{}, number<KPerInnerLoop>{}),
|
||||
{0, KIdx * KPerInnerLoop},
|
||||
a_lds_load_tile_distr);
|
||||
a_block_window.get_bottom_tensor_view(), a_lds_shape, a_offset, a_lds_load_distr);
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<GemmTraits::NPerBlock>{}, number<KPerInnerLoop>{}),
|
||||
{0, KIdx * KPerInnerLoop},
|
||||
b_lds_load_tile_distr);
|
||||
b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else if constexpr(ALoadTranspose)
|
||||
{
|
||||
a_warp_tile_ = load_tile_transpose(a_lds_gemm_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(a_warp_tile_, a_lds_gemm_window);
|
||||
@@ -424,6 +482,10 @@ struct BlockUniversalGemmAsBsCr
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else if constexpr(BLoadTranspose)
|
||||
{
|
||||
b_warp_tile_ = load_tile_transpose(b_lds_gemm_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(b_warp_tile_, b_lds_gemm_window);
|
||||
@@ -431,10 +493,16 @@ struct BlockUniversalGemmAsBsCr
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
template <typename CBlockTensor,
|
||||
typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> a_load_tr = {},
|
||||
bool_constant<BLoadTranspose> b_load_tr = {})
|
||||
{
|
||||
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
@@ -442,7 +510,7 @@ struct BlockUniversalGemmAsBsCr
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KRepeat, 1>{}([&](auto kIter) {
|
||||
LocalPrefetch<kIter.value>(a_block_window, b_block_window);
|
||||
LocalPrefetch<kIter.value>(a_block_window, b_block_window, a_load_tr, b_load_tr);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// NOTE: Synchronize threads in a workgroup at the start of each MAC
|
||||
// cluster, but except the first, as we can shorten non-MAC cluster a bit
|
||||
@@ -543,29 +611,45 @@ struct BlockUniversalGemmAsBsCr
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
template <typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> a_load_tr = {},
|
||||
bool_constant<BLoadTranspose> b_load_tr = {})
|
||||
{
|
||||
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window);
|
||||
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
template <typename CBlockTensor,
|
||||
typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> a_load_tr = {},
|
||||
bool_constant<BLoadTranspose> b_load_tr = {})
|
||||
{
|
||||
block_gemm_impl_(c_block_tensor, a_block_window, b_block_window);
|
||||
block_gemm_impl_(c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr);
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
template <typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE auto operator()(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> a_load_tr = {},
|
||||
bool_constant<BLoadTranspose> b_load_tr = {})
|
||||
{
|
||||
auto c_block_tensor = MakeCBlockTile();
|
||||
block_gemm_impl_(c_block_tensor, a_block_window, b_block_window);
|
||||
block_gemm_impl_(c_block_tensor, a_block_window, b_block_window, a_load_tr, b_load_tr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user