Wave Tile Transfer supporting global load with transpose (#3027)

* Initial implementation:

 - add new thread group transfer supporting transpose instruction
 - refactor AB transfer to switch between thread and wave tiles methods

* Add some comments and remove explicit wave and lane calculations

* Remove compiler option for performance

* fp16 example: use tuned instance

* Missing cleanup

* Integrate wave transfer in existing gemm and batched gemm instances

* Add fast instances

* Extend implementation for 8-bit datatypes

(packed types are not supported)

* Address review comments

* Optimize pipeline v1 and re-introduce compiler option

* Disable wave tile approach for b scale gemm

* Fix for clang20

* Avoid code duplication of amd_global_load_transpose_to_vgpr function
This commit is contained in:
Enrico Degregori
2025-10-16 20:33:56 +02:00
committed by GitHub
parent c4b2da9cbd
commit 440358c168
15 changed files with 1513 additions and 720 deletions

View File

@@ -0,0 +1,402 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_address_space.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp"
namespace ck {
// ABTransferThreadTiles
//
// Global-memory -> LDS block transfer of one A or B operand block of a WMMA
// GEMM using the classic per-thread tile method: each thread of the block
// copies its own slice of the (ABK0, MNPerBlock, ABK1) tile, distributed
// according to ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1. The actual
// copy is delegated to ThreadGroupTensorSliceTransfer v7r2 (multiple A/B
// tensors) or v4r1 (single tensor), see GetBlockTransfer().
//
// Template parameters (selection):
//   ABLayout       - memory layout of this operand (row- or column-major)
//   ABMajorLayout  - reference layout; when ABLayout differs from it the LDS
//                    descriptor uses the read-optimized (kfold/mpair) variant
//   LDSTypeAB      - element type as stored in LDS
//   ABK1Value      - innermost K vector length (K1)
//   UseBlockPaddingAB - pad LDS rows instead of XOR-swizzling to avoid bank
//                       conflicts (see GetBlockDescriptor)
//   PermuteAB      - operand is in the pre-shuffled ("permuted") weight layout
template <typename ABLayout,
          typename ABMajorLayout,
          typename LDSTypeAB,
          index_t BlockSize,
          index_t MNPerBlock,
          index_t KPerBlock,
          index_t MNPerWmma,
          index_t ABK1Value,
          bool UseBlockPaddingAB,
          bool PermuteAB,
          typename ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1,
          typename ABBlockTransferThreadClusterArrangeOrder,
          typename ABBlockTransferSrcAccessOrder,
          index_t ABBlockTransferSrcVectorDim,
          index_t ABBlockTransferSrcScalarPerVector,
          index_t ABBlockTransferDstScalarPerVector_ABK1,
          bool ABThreadTransferSrcResetCoordinateAfterRun>
struct ABTransferThreadTiles
{
    // K is split as K = ABK0 * ABK1 (K1 is the contiguous vector dimension).
    static constexpr auto ABK0Number = Number<KPerBlock / ABK1Value>{};
    static constexpr auto ABK1Number = Number<ABK1Value>{};
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    // pk_i4_t packs two 4-bit values per element, hence packed size 2.
    static constexpr index_t ABPackedSize = []() {
        if constexpr(is_same_v<remove_cvref_t<LDSTypeAB>, pk_i4_t>)
            return 2;
        else
            return 1;
    }();
    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
    // Builds the (ABK0, MN, ABK1) grid descriptor from the raw MN x K base
    // descriptor, applying right-padding on MN and/or K as requested by the
    // PadMN/PadK flags. When PermuteAB is set (no-padding case only), the
    // input is interpreted as the pre-shuffled weight layout
    // [K / KPerBlock, MN, KPerBlock / K1, K1] and merged back to [K / K1, MN, K1].
    template <bool PadMN, bool PadK, typename GridDescriptorBase>
    __host__ __device__ static auto MakeGridDescriptor(const GridDescriptorBase& ab_grid_desc,
                                                       index_t MN,
                                                       index_t MNPad,
                                                       index_t K,
                                                       index_t KPad,
                                                       index_t StrideAB,
                                                       index_t ABK0)
    {
        if constexpr(PadMN && PadK)
        {
            // pad both MN and K
            const auto ab_grid_desc_n_k =
                transform_tensor_descriptor(ab_grid_desc,
                                            make_tuple(make_right_pad_transform(MN, MNPad - MN),
                                                       make_right_pad_transform(K, KPad - K)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));
            const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                ab_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(ABK0, ABK1Value)),
                           make_pass_through_transform(MNPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
            return ab_grid_desc_bk0_n_bk1;
        }
        else if constexpr(PadMN && !PadK)
        {
            // pad MN, but not K
            const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                ab_grid_desc,
                make_tuple(make_unmerge_transform(make_tuple(ABK0, ABK1Value)),
                           make_right_pad_transform(MN, MNPad - MN)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
            return ab_grid_desc_bk0_n_bk1;
        }
        else if constexpr(!PadMN && PadK)
        {
            // pad K, but not MN
            const auto ab_grid_desc_n_k = transform_tensor_descriptor(
                ab_grid_desc,
                make_tuple(make_pass_through_transform(MN), make_right_pad_transform(K, KPad - K)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
            const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                ab_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(ABK0, ABK1Value)),
                           make_pass_through_transform(MN)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
            return ab_grid_desc_bk0_n_bk1;
        }
        else
        {
            if constexpr(!PermuteAB)
            {
                // not pad MN or K
                const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                    ab_grid_desc,
                    make_tuple(make_unmerge_transform(make_tuple(ABK0, ABK1Value)),
                               make_pass_through_transform(MN)),
                    make_tuple(Sequence<1>{}, Sequence<0>{}),
                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
                return ab_grid_desc_bk0_n_bk1;
            }
            else
            {
                // Pre-shuffled Weight
                // BGlobal[K / KPerBlock, MN, KPerBlock / K1, K1] -> BTile[K / K1, MN, K1]
                constexpr index_t ABK01 = KPerBlock / ABK1Value;
                const index_t ABK0_    = StrideAB / ABK1Value;
                const index_t ABK00    = ABK0_ / ABK01;
                const auto ab_grid_desc_abk00_mn_abk01_abk1_permute =
                    make_naive_tensor_descriptor_packed(make_tuple(ABK00, MN, ABK01, ABK1Value));
                const auto ab_grid_desc_abk0_mn_abk1_permute = transform_tensor_descriptor(
                    ab_grid_desc_abk00_mn_abk01_abk1_permute,
                    make_tuple(make_merge_transform(make_tuple(ABK00, ABK01)),
                               make_pass_through_transform(make_tuple(MN)),
                               make_pass_through_transform(ABK1Value)),
                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
                return ab_grid_desc_abk0_mn_abk1_permute;
            }
        }
    }
    // Returns the (ABK0, MNPerBlock, ABK1) LDS block descriptor (dst of the
    // blockwise copy). Three variants:
    //  1) UseBlockPaddingAB: naive layout with one extra MN row of padding,
    //  2) ABLayout == ABMajorLayout: XOR-swizzled layout (MNLdsLayer folding),
    //  3) otherwise: kfold/mpair permuted layout optimized for the read side.
    __device__ static constexpr auto GetBlockDescriptor()
    {
        // A matrix in LDS memory, dst of blockwise copy
        if constexpr(UseBlockPaddingAB)
        {
            // bank conflict when writing the data into LDS, but don't worry, we have whole entire
            // loop to hide it in v4. it may give you some benefit from less valu in compute address
            return make_naive_tensor_descriptor(
                make_tuple(ABK0Number, Number<MNPerBlock>{}, ABK1Number),
                make_tuple(Number<MNPerBlock + 1>{} * ABK1Number, ABK1Number, I1));
        }
        // xor tensor transformation request more unnecessary vgpr usage, would cause register spill
        // in some cases.
        else if constexpr(is_same<ABMajorLayout, ABLayout>::value)
        {
            // Fold MN so each LDS "layer" spans a 128-byte line (32 banks * 4 bytes).
            constexpr index_t LdsSize   = 32 * 4 / KPerBlock / sizeof(LDSTypeAB) / ABPackedSize;
            constexpr auto MNLdsLayer   = LdsSize < 1 ? 1 : LdsSize;
            constexpr auto ab_lds_block_desc = make_naive_tensor_descriptor(
                make_tuple(ABK0Number * Number<MNLdsLayer>{},
                           Number<MNPerBlock / MNLdsLayer>{},
                           ABK1Number),
                make_tuple(ABK1Number, Number<KPerBlock * MNLdsLayer>{}, I1));
            // XOR-swizzle MN against K0 to break bank conflicts on the read side.
            constexpr auto ab_lds_block_desc_permuted = transform_tensor_descriptor(
                ab_lds_block_desc,
                make_tuple(
                    make_xor_with_modulo_transform(make_tuple(Number<MNPerBlock / MNLdsLayer>{},
                                                              Number<ABK0Number * MNLdsLayer>{})),
                    make_pass_through_transform(ABK1Number)),
                make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
                make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
            constexpr auto ab_lds_block_desc_abk0_mnldslayer_mn_abk1 = transform_tensor_descriptor(
                ab_lds_block_desc_permuted,
                make_tuple(make_unmerge_transform(make_tuple(ABK0Number, Number<MNLdsLayer>{})),
                           make_pass_through_transform(Number<MNPerBlock / MNLdsLayer>{}),
                           make_pass_through_transform(ABK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));
            // Merge (MN, MNLdsLayer) back so callers see a plain ABK0 x MN x ABK1 view.
            constexpr auto ab_lds_block_desc_abk0_mn_abk1 = transform_tensor_descriptor(
                ab_lds_block_desc_abk0_mnldslayer_mn_abk1,
                make_tuple(make_pass_through_transform(ABK0Number),
                           make_merge_transform_v3_division_mod(
                               make_tuple(Number<MNPerBlock / MNLdsLayer>{}, Number<MNLdsLayer>{})),
                           make_pass_through_transform(ABK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
            return ab_lds_block_desc_abk0_mn_abk1;
        }
        else
        {
            // kfold and mpair dimension is not always required.
            // more dimension in merge_transform increase the difficulty of generating immarg offset
            // for compiler.
            constexpr auto MN0             = ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1{}.At(I1);
            constexpr auto MN1             = MNPerBlock / MN0;
            constexpr auto KThreadWrite    = ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1{}.At(I0);
            constexpr auto K0PerThreadWrite = ABK0Number / KThreadWrite;
            constexpr auto KThreadRead     = 64 / MNPerWmma;
            constexpr auto K0PerThreadRead = ABK0Number / KThreadRead;
            // Fold K so one write row covers a 128-byte LDS line.
            constexpr auto kfold = (ABK1Number * MN0 * sizeof(LDSTypeAB) > 128)
                                       ? 1
                                       : 128 / (ABK1Number * MN0 * sizeof(LDSTypeAB));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;
            // 1<=mpair<=n0
            constexpr auto mpair = (ABK1Number * MNPerWmma * sizeof(LDSTypeAB) > 128)
                                       ? 1
                                       : ((128 / (ABK1Number * MNPerWmma * sizeof(LDSTypeAB))) > MN0
                                              ? MN0
                                              : 128 / (ABK1Number * MNPerWmma * sizeof(LDSTypeAB)));
            constexpr auto ab_lds_block_desc = make_naive_tensor_descriptor_packed(
                make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
                           Number<K0PerThreadWrite>{},
                           Number<KThreadReadPerm * MN1>{},
                           Number<kfold * MN0 / mpair>{},
                           Number<mpair>{},
                           ABK1Number));
            // XOR-swizzle the (KThreadReadPerm*MN1, kfold*MN0/mpair) plane.
            constexpr auto ab_lds_block_desc_permuted = transform_tensor_descriptor(
                ab_lds_block_desc,
                make_tuple(
                    make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                    make_pass_through_transform(Number<K0PerThreadWrite>{}),
                    make_xor_with_modulo_transform(
                        make_tuple(Number<KThreadReadPerm * MN1>{}, Number<kfold * MN0 / mpair>{})),
                    make_pass_through_transform(Number<mpair>{}),
                    make_pass_through_transform(ABK1Number)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));
            constexpr auto ab_lds_block_desc_unmerged = transform_tensor_descriptor(
                ab_lds_block_desc_permuted,
                make_tuple(
                    make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                    make_pass_through_transform(Number<K0PerThreadWrite>{}),
                    make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<MN1>{})),
                    make_unmerge_transform(make_tuple(Number<kfold>{}, Number<MN0 / mpair>{})),
                    make_pass_through_transform(Number<mpair>{}),
                    make_pass_through_transform(ABK1Number)),
                make_tuple(Sequence<0>{},
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<3>{},
                           Sequence<4>{},
                           Sequence<5>{}),
                make_tuple(Sequence<1>{},
                           Sequence<2>{},
                           Sequence<0, 3>{},
                           Sequence<4, 5>{},
                           Sequence<6>{},
                           Sequence<7>{}));
            // Merge everything back into the canonical ABK0 x MN x ABK1 view.
            constexpr auto ab_lds_block_desc_abk0_mn_abk1 = transform_tensor_descriptor(
                ab_lds_block_desc_unmerged,
                make_tuple(make_merge_transform_v3_division_mod(
                               make_tuple(Number<KThreadReadPerm>{},
                                          Number<KThreadWrite / kfold / KThreadReadPerm>{},
                                          Number<kfold>{},
                                          Number<K0PerThreadWrite>{})),
                           make_merge_transform_v3_division_mod(
                               make_tuple(Number<MN0 / mpair>{}, Number<mpair>{}, Number<MN1>{})),
                           make_pass_through_transform(ABK1Number)),
                make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
            return ab_lds_block_desc_abk0_mn_abk1;
        }
    }
    // Constructs the blockwise global->LDS copy object for the block at
    // block_mn_id. Multi-tensor A/B (NumABTensor > 1) uses v7r2; the single
    // tensor case uses v4r1 with a PassThrough destination elementwise op.
    template <typename GridDescriptor,
              typename BlockDescriptor,
              typename ABsDataType,
              typename ABElementwiseOperation,
              index_t GlobalBufferNum>
    __device__ static auto GetBlockTransfer(GridDescriptor& grid_descriptor,
                                            BlockDescriptor& block_descriptor,
                                            ABElementwiseOperation& ab_element_op,
                                            const index_t block_mn_id)
    {
        constexpr index_t NumABTensor = ABsDataType::Size();
        // readfirstlane: the block offset is uniform across the wave.
        const index_t mn_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_mn_id * MNPerBlock);
        // workaround because v7r2 is not as general as v4r1
        if constexpr(NumABTensor > 1)
        {
            const auto idx_as_block_begin = generate_tuple(
                [&](auto) { return make_multi_index(0, mn_block_data_idx_on_grid, 0); },
                Number<NumABTensor>{});
            return ThreadGroupTensorSliceTransfer_v7r2<
                ThisThreadBlock,
                ABsDataType,
                Tuple<LDSTypeAB>,
                GridDescriptor,
                decltype(tie(block_descriptor)),
                ABElementwiseOperation,
                Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
                Sequence<ABK0Number, MNPerBlock, ABK1Number>,
                ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1,
                ABBlockTransferThreadClusterArrangeOrder,
                ABBlockTransferSrcAccessOrder,
                Sequence<1, 0, 2>,
                ABBlockTransferSrcVectorDim,
                2,
                ABBlockTransferSrcScalarPerVector,
                ABBlockTransferDstScalarPerVector_ABK1,
                uniform_sequence_gen_t<NumABTensor, ABThreadTransferSrcResetCoordinateAfterRun>,
                Sequence<true>,
                GlobalBufferNum>{grid_descriptor,
                                 idx_as_block_begin,
                                 tie(block_descriptor),
                                 make_tuple(make_multi_index(0, 0, 0)),
                                 ab_element_op};
        }
        else
        {
            return ThreadGroupTensorSliceTransfer_v4r1<
                ThisThreadBlock,
                ABElementwiseOperation,
                ck::tensor_operation::element_wise::PassThrough,
                InMemoryDataOperationEnum::Set,
                Sequence<ABK0Number, MNPerBlock, ABK1Number>,
                ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1,
                ABBlockTransferThreadClusterArrangeOrder,
                remove_cvref_t<tuple_element_t<0, ABsDataType>>,
                remove_cvref_t<tuple_element_t<0, ABsDataType>>,
                decltype(grid_descriptor[I0]),
                decltype(block_descriptor),
                ABBlockTransferSrcAccessOrder,
                Sequence<0, 1, 2>,
                ABBlockTransferSrcVectorDim,
                2,
                ABBlockTransferSrcScalarPerVector,
                ABBlockTransferDstScalarPerVector_ABK1,
                1,
                1,
                ABThreadTransferSrcResetCoordinateAfterRun,
                true,
                GlobalBufferNum>(grid_descriptor[I0],
                                 make_multi_index(0, mn_block_data_idx_on_grid, 0),
                                 ab_element_op,
                                 block_descriptor,
                                 make_multi_index(0, 0, 0),
                                 ck::tensor_operation::element_wise::PassThrough{});
        }
    }
    template <index_t MNRepeat, index_t MNWaves>
    __host__ __device__ static constexpr auto MakeWmmaTileDescriptor()
    {
        // This is a block descriptor used to read LDS memory into register
        // It's defined in a way consistent with the existing implementation to
        // avoid changes in the pipelines
        using BlockDesc = decltype(GetBlockDescriptor());
        // ABK0_MN_ABK1 -> ABK0_MNRepeat_MNWaves_KRow_MNPerWmma_ABK1
        constexpr auto ABK0 = BlockDesc{}.GetLength(I0);
        constexpr auto ABK1 = BlockDesc{}.GetLength(I2);
#ifdef __gfx12__
        constexpr auto KRow = I2;
#else
        constexpr auto KRow = I1;
#endif
        return transform_tensor_descriptor(
            BlockDesc{},
            make_tuple(make_unmerge_transform(make_tuple(Number<ABK0 / KRow>{}, KRow)),
                       make_unmerge_transform(
                           make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
                       make_pass_through_transform(Number<ABK1>{})),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
    }
    __device__ static constexpr auto GetBlockStep()
    {
        // Grid descriptor step (MoveSrcSliceWindow): advance one KPerBlock
        // slab along the K0 dimension.
        return make_multi_index(KPerBlock / ABK1Number, 0, 0);
    }
    template <typename GridDescriptor>
    __device__ static constexpr index_t GetKDimension(const GridDescriptor& grid_desc)
    {
        // K dimension size. This should always be called with the A matrix grid descriptor
        // because it doesn't work for B matrix when packed int4 is used
        return grid_desc.GetLength(I0) * grid_desc.GetLength(I2);
    }
};
} // namespace ck

View File

@@ -0,0 +1,343 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_address_space.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp"
#include "ck/utility/math.hpp"
namespace ck {
// ABTransferWaveTiles
//
// Global-memory -> LDS block transfer of one A or B GEMM operand using wave
// tiles: WMMA-shaped (MNPerWmma x KPack) tiles are distributed over the waves
// of the block and copied via ThreadGroupTransferGlobal, which (per the
// commit this belongs to) can use the global-load-with-transpose path when
// ABLayout differs from ABMajorLayout (ABDoTranspose).
//
// Limitations (also enforced below): no pk_i4_t, no padding, single A/B
// tensor, single global buffer, ABK1Value must be 8 (see VectorSize comment).
template <typename ABLayout,
          typename ABMajorLayout,
          typename LDSTypeAB,
          index_t BlockSize,
          index_t MNPerBlock,
          index_t KPerBlock,
          index_t MNPerWmma,
          index_t KPack,
          index_t ABK1Value,
          index_t WaveSize>
struct ABTransferWaveTiles
{
    static_assert(!(is_same_v<remove_cvref_t<LDSTypeAB>, pk_i4_t>),
                  "wave tile transfer method does not support pk_i4_t");
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    // Each WMMA tile is split into 2 subtiles ("rows" of lane groups),
    // see the MNKRow unmerge in MakeGridDescriptor.
    static constexpr index_t MNKRow = 2;
    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
    // Tiles distribution for global memory loading
    // Notes: support for not power of 2 needs to be reviewed later on
    // The tiles are distributed along the non-contiguous matrix dimension
    // Example 4 waves A row-major MPerBlock = 64, KPerBlock = 64
    // MRepeat = 1, KRepeat = 4
    // -------------
    // |W0| | | |
    // -------------
    // |W1| | | |
    // -------------
    // |W2| | | |
    // -------------
    // |W3| | | |
    // -------------
    // Example 4 waves A column-major MPerBlock = 64, KPerBlock = 64
    // MRepeat = 4, KRepeat = 1
    // -------------
    // |W0|W1|W2|W3|
    // -------------
    // | | | | |
    // -------------
    // | | | | |
    // -------------
    // | | | | |
    // -------------
    static constexpr index_t NumberOfWaves = BlockSize / WaveSize;
    // Waves along MN when MN is the major (non-contiguous) dimension; falls
    // back to 2 or 1 when the tile count does not divide evenly.
    static constexpr index_t MNMajorWaves_ =
        MNPerBlock / MNPerWmma % std::min(MNPerBlock / MNPerWmma, NumberOfWaves) == 0
            ? std::min(MNPerBlock / MNPerWmma, NumberOfWaves)
            : (MNPerBlock / MNPerWmma % 2 == 0 ? 2 : 1);
    // Same rule for the K direction.
    static constexpr index_t KMajorWaves_ =
        KPerBlock / KPack % std::min(KPerBlock / KPack, NumberOfWaves) == 0
            ? std::min(KPerBlock / KPack, NumberOfWaves)
            : (KPerBlock / KPack % 2 == 0 ? 2 : 1);
    // Transpose is needed when the operand layout differs from the reference
    // major layout.
    static constexpr bool ABDoTranspose = !is_same_v<ABLayout, ABMajorLayout>;
    static constexpr index_t MNWaves_ =
        ABDoTranspose ? NumberOfWaves / KMajorWaves_ : MNMajorWaves_;
    static constexpr index_t KWaves_ = ABDoTranspose ? KMajorWaves_ : NumberOfWaves / MNMajorWaves_;
    // How many tiles each wave covers in each direction.
    static constexpr index_t KRepeat_  = KPerBlock / (KWaves_ * KPack);
    static constexpr index_t MNRepeat_ = MNPerBlock / (MNWaves_ * MNPerWmma);
    // Builds the 4D (MNTiles, KTiles, MNKRow, LaneLocal) grid descriptor from
    // the raw MN x K base descriptor; the vector (ABK1) dimension is frozen to
    // element 0 since each lane loads a full chunk starting there.
    // MNPad/KPad/StrideAB/ABK0 parameters are unused (kept for interface
    // parity with ABTransferThreadTiles::MakeGridDescriptor).
    template <bool PadMN, bool PadK, typename GridDescriptorBase>
    __host__ __device__ static auto MakeGridDescriptor(GridDescriptorBase& base_desc,
                                                       index_t sizeMN,
                                                       index_t,
                                                       index_t sizeK,
                                                       index_t,
                                                       index_t,
                                                       index_t)
    {
        // Notes: padding is currently not supported
        static_assert(!PadMN && !PadK, "padding is currently not supported");
        // Divide the base descriptor MN_K into tiles
        const auto ab_grid_desc_mntiles_ktiles = transform_tensor_descriptor(
            base_desc,
            make_tuple(
                make_unmerge_transform(make_tuple(
                    math::integer_divide_ceil(sizeMN, Number<MNPerWmma>{}), Number<MNPerWmma>{})),
                make_unmerge_transform(make_tuple(math::integer_divide_ceil(sizeK, Number<KPack>{}),
                                                  Number<KPack>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
        // The distinction is needed to get the same global indices for both layouts
        // Divide each tile in 2 16x8 subtile
        // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize
        // MNKRow = 0-1
        // LaneLocal = 0-15
        // VectorSize must be 8
        if constexpr(!ABDoTranspose)
        {
            // Non-transpose: split the KPack dimension into (MNKRow, KPack/MNKRow).
            const auto ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1 =
                transform_tensor_descriptor(
                    ab_grid_desc_mntiles_ktiles,
                    make_tuple(make_pass_through_transform(
                                   math::integer_divide_ceil(sizeMN, Number<MNPerWmma>{})),
                               make_pass_through_transform(
                                   math::integer_divide_ceil(sizeK, Number<KPack>{})),
                               make_pass_through_transform(Number<MNPerWmma>{}),
                               make_unmerge_transform(
                                   make_tuple(Number<MNKRow>{}, Number<KPack / MNKRow>{}))),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}));
            // Freeze VectorSize to first element of the loading chunk (for convenience)
            // Swap MNPerWmma and MNKRow for consistency with transpose descriptor
            return transform_tensor_descriptor(
                ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1,
                make_tuple(
                    make_pass_through_transform(
                        math::integer_divide_ceil(sizeMN, Number<MNPerWmma>{})),
                    make_pass_through_transform(math::integer_divide_ceil(sizeK, Number<KPack>{})),
                    make_pass_through_transform(Number<MNPerWmma>{}),
                    make_pass_through_transform(Number<MNKRow>{}),
                    make_freeze_transform(I0)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<2>{}, Sequence<4>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<>{}));
        }
        else
        {
            // Transpose: split the MNPerWmma dimension into (MNKRow, MNPerWmma/MNKRow).
            const auto ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1 =
                transform_tensor_descriptor(
                    ab_grid_desc_mntiles_ktiles,
                    make_tuple(make_pass_through_transform(
                                   math::integer_divide_ceil(sizeMN, Number<MNPerWmma>{})),
                               make_pass_through_transform(
                                   math::integer_divide_ceil(sizeK, Number<KPack>{})),
                               make_unmerge_transform(
                                   make_tuple(Number<MNKRow>{}, Number<MNPerWmma / MNKRow>{})),
                               make_pass_through_transform(Number<KPack>{})),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
            // Freeze VectorSize to first element of the loading chunk (for convenience)
            return transform_tensor_descriptor(
                ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1,
                make_tuple(
                    make_pass_through_transform(
                        math::integer_divide_ceil(sizeMN, Number<MNPerWmma>{})),
                    make_pass_through_transform(math::integer_divide_ceil(sizeK, Number<KPack>{})),
                    make_pass_through_transform(Number<MNKRow>{}),
                    make_freeze_transform(I0),
                    make_pass_through_transform(Number<KPack>{})),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<>{}, Sequence<3>{}));
        }
    }
    // Returns the LDS-side block descriptor matching the grid descriptor
    // dimension order (MNTiles, KTiles, MNKRow, LaneLocal, frozen VectorSize).
    __device__ static constexpr auto GetBlockDescriptor()
    {
        // LDS memory layouts:
        // lanes within tiles stored contiguously in chunks of 8 elements
        // tiles are then stored first in K dimension
        // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize
        const auto a_grid_desc_mraw_kraw = [&]() {
            return make_naive_tensor_descriptor(
                make_tuple(Number<MNRepeat_ * MNWaves_>{},
                           Number<KRepeat_ * KWaves_>{},
                           Number<MNKRow>{},
                           Number<MNPerWmma>{},
                           Number<ABK1Value>{}),
                make_tuple(Number<KPack * MNPerWmma * KWaves_ * KRepeat_>{},
                           Number<KPack * MNPerWmma>{},
                           Number<ABK1Value * MNPerWmma>{},
                           Number<ABK1Value>{},
                           I1));
        }();
        // Freeze VectorSize to first element of the chunk (for convenience)
        return transform_tensor_descriptor(
            a_grid_desc_mraw_kraw,
            make_tuple(make_pass_through_transform(Number<MNRepeat_ * MNWaves_>{}),
                       make_pass_through_transform(Number<KRepeat_ * KWaves_>{}),
                       make_pass_through_transform(Number<MNKRow>{}),
                       make_pass_through_transform(Number<MNPerWmma>{}),
                       make_freeze_transform(I0)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<>{}));
    }
    // Decomposes the thread id into (MN wave index, K wave index, lane id).
    __device__ static auto GetWaveIdx()
    {
        const index_t thread_id = ThisThreadBlock::GetThreadId();
        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(MNWaves_, KWaves_, WaveSize))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));
        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
    }
    // Decomposes the lane id into (subtile row, lane-within-subtile) for the
    // LDS (block) side; the subtile width depends on the transpose direction.
    __device__ static auto GetBlockLaneIdx()
    {
        const index_t lane_id                = __lane_id();
        constexpr index_t LanesPerSubTile    = ABDoTranspose ? KPack : MNPerWmma;
        constexpr auto laneid_to_block_lane_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(MNKRow, LanesPerSubTile))),
            make_tuple(Sequence<0, 1>{}),
            make_tuple(Sequence<0>{}));
        return laneid_to_block_lane_idx_adaptor.CalculateBottomIndex(make_multi_index(lane_id));
    }
    // Decomposes the lane id into (subtile row, lane offset) for the grid
    // (global-memory) side. SubTilesCol scales with the element size
    // (4 / sizeof(ABDataType)), and the decomposition order is swapped for
    // the transpose path so both layouts yield the same global indices.
    template <typename ABDataType>
    __device__ static auto GetGridLaneIdx()
    {
        const index_t lane_id             = __lane_id();
        constexpr index_t SubTilesRow     = MNKRow;
        constexpr index_t SubTilesCol     = 4 / sizeof(ABDataType);
        constexpr index_t LanesPerSubTile =
            ABDoTranspose ? KPack / SubTilesCol : MNPerWmma / SubTilesCol;
        constexpr auto dims_tuple = ABDoTranspose
                                        ? make_tuple(SubTilesCol, SubTilesRow, LanesPerSubTile)
                                        : make_tuple(SubTilesRow, SubTilesCol, LanesPerSubTile);
        constexpr auto laneid_to_grid_lane_idx_adaptor =
            make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(dims_tuple)),
                                             make_tuple(Sequence<0, 1, 2>{}),
                                             make_tuple(Sequence<0>{}));
        const auto indices =
            laneid_to_grid_lane_idx_adaptor.CalculateBottomIndex(make_multi_index(lane_id));
        if constexpr(!ABDoTranspose)
        {
            return make_multi_index(indices[I0], indices[I1] * LanesPerSubTile + indices[I2]);
        }
        else
        {
            return make_multi_index(indices[I1], indices[I0] * LanesPerSubTile + indices[I2]);
        }
    }
    // Constructs the wave-tile global->LDS transfer object for the block at
    // block_mn_id, seeding per-thread source/destination coordinates from the
    // wave and lane decompositions above.
    template <typename GridDescriptor,
              typename BlockDescriptor,
              typename ABsDataType,
              typename ABElementwiseOperation,
              index_t GlobalBufferNum>
    __device__ static auto GetBlockTransfer(GridDescriptor& grid_descriptor,
                                            BlockDescriptor& block_descriptor,
                                            ABElementwiseOperation& ab_element_op,
                                            const index_t block_mn_id)
    {
        // Note: GlobalBufferNum is currently not used but it will be needed
        // once we add other pipelines. It is currently needed only for
        // consistency with the thread tiles approach
        static_assert(GlobalBufferNum == 1, "single global buffer is only supported");
        constexpr index_t NumABTensor = ABsDataType::Size();
        static_assert(NumABTensor == 1, "multiAB currently not supported");
        using ABDataType = remove_cvref_t<tuple_element_t<0, ABsDataType>>;
        const auto wave_idx = GetWaveIdx();
        index_t wave_idK    = wave_idx[I1];
        index_t wave_idMN   = wave_idx[I0];
        const auto grid_lane_id    = GetGridLaneIdx<ABDataType>();
        index_t lane_group_grid    = grid_lane_id[I0];
        index_t lane_local_id_grid = grid_lane_id[I1];
        const auto block_lane_id    = GetBlockLaneIdx();
        index_t lane_group_block    = block_lane_id[I0];
        index_t lane_local_id_block = block_lane_id[I1];
        return ThreadGroupTransferGlobal<decltype(grid_descriptor[I0]),
                                         BlockDescriptor,
                                         ABDataType,
                                         ABDataType,
                                         ABElementwiseOperation,
                                         Sequence<MNRepeat_, KRepeat_, I1, I1>,
                                         Sequence<MNWaves_, KWaves_, I1, I1>,
                                         Sequence<I0, I1, I2, I3>,
                                         ABK1Value,
                                         ABDoTranspose>(
            grid_descriptor[I0],
            block_descriptor,
            make_multi_index(block_mn_id * (MNRepeat_ * MNWaves_) + wave_idMN,
                             wave_idK,
                             lane_group_grid,
                             lane_local_id_grid),
            make_multi_index(wave_idMN, wave_idK, lane_group_block, lane_local_id_block),
            ab_element_op);
    }
    template <index_t MNRepeat, index_t MNWaves>
    __host__ __device__ static constexpr auto MakeWmmaTileDescriptor()
    {
        // This is a block descriptor used to read LDS memory into register
        // It's defined in a way consistent with the existing implementation to
        // avoid changes in the pipelines
        return make_naive_tensor_descriptor(make_tuple(Number<KPerBlock / KPack>{},
                                                       Number<MNRepeat>{},
                                                       Number<MNWaves>{},
                                                       Number<MNKRow>{},
                                                       Number<MNPerWmma>{},
                                                       Number<ABK1Value>{}),
                                            make_tuple(Number<KPack * MNPerWmma>{},
                                                       Number<KPerBlock * MNPerWmma * MNWaves>{},
                                                       Number<KPerBlock * MNPerWmma>{},
                                                       Number<MNPerWmma * ABK1Value>{},
                                                       Number<ABK1Value>{},
                                                       I1));
    }
    __device__ static constexpr auto GetBlockStep()
    {
        // Grid descriptor step (MoveSrcSliceWindow): advance one KPerBlock
        // slab along the KTiles dimension.
        return make_multi_index(I0, KWaves_ * KRepeat_, I0, I0);
    }
    template <typename GridDescriptor>
    __device__ static constexpr index_t GetKDimension(const GridDescriptor& grid_desc)
    {
        // Total K = number of K tiles * KPack.
        return grid_desc.GetLength(I1) * KPack;
    }
};
} // namespace ck

View File

@@ -175,7 +175,8 @@ template <typename ALayout,
typename ComputeTypeA,
typename ComputeTypeB,
bool PermuteA,
bool PermuteB>
bool PermuteB,
bool ForceThreadTileTransfer = false>
struct GridwiseGemm_wmma_cshuffle_v3
: GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
@@ -227,7 +228,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
ComputeTypeA,
ComputeTypeB,
PermuteA,
PermuteB>
PermuteB,
ForceThreadTileTransfer>
{
using Base = GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
@@ -279,7 +281,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
ComputeTypeA,
ComputeTypeB,
PermuteA,
PermuteB>;
PermuteB,
ForceThreadTileTransfer>;
using Base::I0;
using Base::I1;
@@ -318,9 +321,6 @@ struct GridwiseGemm_wmma_cshuffle_v3
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1;
using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1;
using Base::NumATensor;
using Base::NumBTensor;
using Base::NumDTensor;

View File

@@ -122,7 +122,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
ComputeTypeA,
ComputeTypeB,
PermuteA,
PermuteB>
PermuteB,
true>
{
using Base = GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
@@ -174,7 +175,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
ComputeTypeA,
ComputeTypeB,
PermuteA,
PermuteB>;
PermuteB,
true>;
using Base::I0;
using Base::I1;
@@ -213,9 +215,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1;
using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1;
using Base::NumATensor;
using Base::NumBTensor;
using Base::NumDTensor;

View File

@@ -14,10 +14,13 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -107,7 +110,8 @@ template <typename ALayout,
typename ComputeTypeA,
typename ComputeTypeB,
bool PermuteA,
bool PermuteB>
bool PermuteB,
bool ForceThreadTileTransfer = false> // only needed for convolution (limitation)
struct GridwiseGemm_wmma_cshuffle_v3_base
{
@@ -162,6 +166,101 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
return 1;
}();
// Limitations of the current implementation:
// - no multiAB
// - GemmSpecialization Default
// - pipeline v1 because v3 is buggy (fixed in the batched gemm implementation)
// AK1Value == 8 is not really a limitation but a requirement for the method so
// it will stay
#ifdef __gfx12__
static constexpr bool IsAWaveTransferApplicable =
!ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 &&
GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8;
static constexpr bool IsBWaveTransferApplicable =
!ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 &&
GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8;
#else
static constexpr bool IsAWaveTransferApplicable = false;
static constexpr bool IsBWaveTransferApplicable = false;
#endif
// Wavefront size of the WMMA instruction selected for these compute types
// and tile sizes (queried from WmmaSelector).
static constexpr index_t WaveSize =
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::selected_wmma
.wave_size;
// A-side LDS block padding is used either when explicitly requested
// (ABlockLdsExtraM) or when the v4 pipeline is selected.
static constexpr bool UseBlockPaddingA =
ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4;
// Compile-time selection of the A global->LDS transfer strategy:
// the wave-tile transfer (global load with transpose) when applicable,
// otherwise the classic per-thread tile transfer.
using ATransfer = typename std::conditional<
IsAWaveTransferApplicable,
ABTransferWaveTiles<ALayout,
tensor_layout::gemm::RowMajor,
LDSTypeA,
BlockSize,
MPerBlock,
KPerBlock,
MPerWmma,
KPack,
AK1Value,
WaveSize>,
ABTransferThreadTiles<ALayout,
tensor_layout::gemm::RowMajor,
LDSTypeA,
BlockSize,
MPerBlock,
KPerBlock,
MPerWmma,
AK1Value,
UseBlockPaddingA,
PermuteA,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
AThreadTransferSrcResetCoordinateAfterRun>>::type;
// B-side LDS block padding, analogous to UseBlockPaddingA above.
static constexpr bool UseBlockPaddingB =
BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4;
// Compile-time selection of the B global->LDS transfer strategy. B's major
// layout is ColumnMajor (N x K view), mirroring the A-side selection.
using BTransfer = typename std::conditional<
IsBWaveTransferApplicable,
ABTransferWaveTiles<BLayout,
tensor_layout::gemm::ColumnMajor,
LDSTypeB,
BlockSize,
NPerBlock,
KPerBlock,
NPerWmma,
KPack,
BK1Value,
WaveSize>,
ABTransferThreadTiles<BLayout,
tensor_layout::gemm::ColumnMajor,
LDSTypeB,
BlockSize,
NPerBlock,
KPerBlock,
NPerWmma,
BK1Value,
UseBlockPaddingB,
PermuteB,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
BThreadTransferSrcResetCoordinateAfterRun>>::type;
// Packed int4 stores two values per byte, so right-padding a dimension is
// not representable; only the Default (no-padding) specialization is legal.
static_assert(!(is_same_v<remove_cvref_t<LDSTypeB>, pk_i4_t> &&
GemmSpec != tensor_operation::device::GemmSpecialization::Default),
"pk_i4_t does not support padding");
static_assert(!PermuteA, "PermuteA is not supported");
// return block_id to C matrix tile idx (m0, n0) mapping
// if arch = gfx942
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
@@ -222,27 +321,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
return math::integer_divide_ceil(N, NPerBlock);
}
// Reshape a [K0, MN, K1] LDS block descriptor into the 6-d per-wave tile
// view consumed by the blockwise GEMM pipeline:
//   K0_MN_K1 -> K0_MNRepeat_MNWaves_KRow_MNPerWmma_K1
// MNRepeat: per-wave repeats along M (or N); MNWaves: waves along M (or N);
// MNPerWmma: rows/cols covered by one WMMA instruction.
template <index_t MNRepeat, index_t MNWaves, index_t MNPerWmma, typename BlockDesc>
__host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&)
{
// K0_MN_K1 -> K0_MNRepeat_MNWaves_KRow_MNPerWmma_K1
constexpr auto K0 = BlockDesc{}.GetLength(I0);
constexpr auto K1 = BlockDesc{}.GetLength(I2);
// KRow: factor by which K0 is split out as a separate dimension below.
// It is 2 on gfx12 and 1 elsewhere -- NOTE(review): presumably matches the
// K-row replication of gfx12 WMMA; confirm against the ISA reference.
#ifdef __gfx12__
constexpr auto KRow = I2;
#else
constexpr auto KRow = I1;
#endif
// Dim 0 (K0) unmerges to {K0/KRow, KRow} -> output dims {0, 3};
// dim 1 (MN) unmerges to {MNRepeat, MNWaves, MNPerWmma} -> {1, 2, 4};
// dim 2 (K1) passes through -> {5}.
return transform_tensor_descriptor(
BlockDesc{},
make_tuple(make_unmerge_transform(make_tuple(Number<K0 / KRow>{}, KRow)),
make_unmerge_transform(
make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
make_pass_through_transform(Number<K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
}
static constexpr auto MakeAsGridPointer()
{
return generate_tuple(
@@ -268,87 +346,27 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
using AsGridPointer = decltype(MakeAsGridPointer());
using BsGridPointer = decltype(MakeBsGridPointer());
__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
__host__ __device__ static auto MakeAGridDescriptor_M_K(index_t M, index_t K, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
// pad both M and K
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
// pad M, but not K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
}
__host__ __device__ static auto MakeBGridDescriptor_N_K(index_t N, index_t K, index_t StrideB)
{
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
// pad K, but not M
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
}
else
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
static_assert(!PermuteA, "PermuteA is not supported");
// not pad M or K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
}
}
@@ -360,123 +378,25 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
const std::array<index_t, NumATensor>& StrideAs,
const index_t AK0)
{
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
constexpr bool padM = GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding;
constexpr bool padK = GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding;
return generate_tuple(
[&](auto i) {
return MakeAGridDescriptor_AK0_M_AK1(M, MPad, K, KPad, StrideAs[i], AK0);
const auto base_desc = MakeAGridDescriptor_M_K(M, K, StrideAs[i]);
return ATransfer::template MakeGridDescriptor<padM, padK>(
base_desc, M, MPad, K, KPad, StrideAs[i], AK0);
},
Number<NumATensor>{});
}
// Build the global-memory descriptor for B with logical shape
// [BK0, N, BK1] (K split as K = BK0 * BK1). Depending on GemmSpec, N and/or
// K are right-padded to NPad/KPad; with PermuteB, B is assumed pre-shuffled
// in memory (see the final branch).
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
{
// Raw N x K view of B; stride placement depends on the memory layout.
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
// Sub-byte packed int4 cannot be right-padded, so padding specializations
// are rejected at compile time.
static_assert(!(is_same_v<remove_cvref_t<LDSTypeB>, pk_i4_t> &&
GemmSpec != GemmSpecialization::Default),
"pk_i4_t does not support padding");
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(N, NPad - N),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// Split padded K into {BK0, BK1} and reorder to [BK0, N, BK1].
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
if constexpr(!PermuteB)
{
// not pad N or K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// Pre-shuffled Weight
// BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1]
// NOTE(review): here StrideB is interpreted as K-extent-per-N in the
// shuffled layout (BK0_ = StrideB / BK1Value), not a plain stride.
constexpr index_t BK01 = KPerBlock / BK1Value;
const index_t BK0_ = StrideB / BK1Value;
const index_t BK00 = BK0_ / BK01;
const auto b_grid_desc_bk00_n_bk01_bk1_permute =
make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value));
// Merge the two K block factors back into a single BK0 dimension.
const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
b_grid_desc_bk00_n_bk01_bk1_permute,
make_tuple(make_merge_transform(make_tuple(BK00, BK01)),
make_pass_through_transform(make_tuple(N)),
make_pass_through_transform(BK1Value)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return b_grid_desc_bk0_n_bk1_permute;
}
}
}
__host__ __device__ static auto
MakeBsGridDescriptor_BK0_N_BK1(const index_t K,
const index_t KPad,
@@ -485,27 +405,36 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
const std::array<index_t, NumBTensor>& StrideBs,
const index_t BK0)
{
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
constexpr bool padN = GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding;
constexpr bool padK = GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding;
return generate_tuple(
[&](auto i) {
return MakeBGridDescriptor_BK0_N_BK1(K, KPad, N, NPad, StrideBs[i], BK0);
const auto base_desc = MakeBGridDescriptor_N_K(N, K, StrideBs[i]);
return BTransfer::template MakeGridDescriptor<padN, padK>(
base_desc, N, NPad, K, KPad, StrideBs[i], BK0);
},
Number<NumBTensor>{});
}
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto MakeAWmmaTileDescriptor(const ABlockDesc_AK0_M_AK1&)
__host__ __device__ static constexpr auto MakeAWmmaTileDescriptor()
{
constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
return MakeWmmaTileDescriptor<MRepeat, MWaves, MPerWmma>(ABlockDesc_AK0_M_AK1{});
return ATransfer::template MakeWmmaTileDescriptor<MRepeat, MWaves>();
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto MakeBWmmaTileDescriptor(const BBlockDesc_BK0_N_BK1&)
__host__ __device__ static constexpr auto MakeBWmmaTileDescriptor()
{
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
return MakeWmmaTileDescriptor<NRepeat, NWaves, NPerWmma>(BBlockDesc_BK0_N_BK1{});
return BTransfer::template MakeWmmaTileDescriptor<NRepeat, NWaves>();
}
template <typename DELayout>
@@ -610,278 +539,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
Number<NumDTensor>{});
}
// Build the LDS (shared-memory) descriptor for the A block, logical shape
// [AK0, MPerBlock, AK1]. Three layouts are chosen at compile time, trading
// LDS bank conflicts against address-computation VGPR cost:
//   1) plain padded layout (when extra-M padding or pipeline v4 is active),
//   2) xor-swizzled layout for row-major A,
//   3) kfold/permute layout for column-major A.
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
// bank conflict when writting the data into LDS, but don't worry, we have whole entire
// loop to hide it in v4. it may give you some benefit from less valu in compute address
return make_naive_tensor_descriptor(
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
make_tuple(Number<MPerBlock>{} * AK1Number, AK1Number, I1));
}
// xor tensor transformation request more unnecessary vgpr usage, would cause register spill
// in some cases.
else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
// MLdsLayer: number of M sub-tiles folded into one LDS "layer".
// NOTE(review): the 32 * 4 constant appears to target a 128-byte span;
// its exact derivation is not visible here - confirm before changing.
constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(LDSTypeA) / APackedSize;
constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(
AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
make_tuple(AK1Number, Number<KPerBlock * MLdsLayer>{}, I1));
// Xor-swizzle the M coordinate against the (AK0 * MLdsLayer) coordinate
// so that reads and writes spread across LDS banks.
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc,
make_tuple(make_xor_with_modulo_transform(make_tuple(
Number<MPerBlock / MLdsLayer>{}, Number<AK0Number * MLdsLayer>{})),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
// Split the fused (AK0 * MLdsLayer) dimension back apart ...
constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number<MLdsLayer>{})),
make_pass_through_transform(Number<MPerBlock / MLdsLayer>{}),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));
// ... and re-merge MLdsLayer into M to expose the canonical
// [AK0, MPerBlock, AK1] view.
constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_ak0_mldslayer_m_ak1,
make_tuple(make_pass_through_transform(AK0Number),
make_merge_transform_v3_division_mod(
make_tuple(Number<MPerBlock / MLdsLayer>{}, Number<MLdsLayer>{})),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return a_lds_block_desc_ak0_m_ak1;
}
else // ColumnMajor A
{
// kfold and mpair dimension is not always required.
// more dimension in merge_transform increase the difficulty of generating immarg offset
// for compiler.
// M0/M1: M split per the write-side thread cluster; KThreadWrite /
// KThreadRead: K-dim thread counts on the write and read sides.
constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
constexpr auto M1 = MPerBlock / M0;
constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
constexpr auto KThreadRead = 64 / MPerWmma;
constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
// kfold: fold factor chosen so one row segment spans at most 128 bytes.
constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
? 1
: 128 / (AK1Number * M0 * sizeof(LDSTypeA));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=mpair<=n0
constexpr auto mpair = (AK1Number * MPerWmma * sizeof(LDSTypeA) > 128)
? 1
: ((128 / (AK1Number * MPerWmma * sizeof(LDSTypeA))) > M0
? M0
: 128 / (AK1Number * MPerWmma * sizeof(LDSTypeA)));
// 6-d raw LDS layout used as the basis for the permutation below.
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<K0PerThreadWrite>{},
Number<KThreadReadPerm * M1>{},
Number<kfold * M0 / mpair>{},
Number<mpair>{},
AK1Number));
// Xor-swizzle dims 2 and 3 to decorrelate read/write bank patterns.
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_xor_with_modulo_transform(
make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
make_pass_through_transform(Number<mpair>{}),
make_pass_through_transform(AK1Number)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));
// Unmerge the fused dims into their K/M factors (8-d intermediate).
constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<M1>{})),
make_unmerge_transform(make_tuple(Number<kfold>{}, Number<M0 / mpair>{})),
make_pass_through_transform(Number<mpair>{}),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<1>{},
Sequence<2>{},
Sequence<0, 3>{},
Sequence<4, 5>{},
Sequence<6>{},
Sequence<7>{}));
// Merge everything back to the canonical [AK0, MPerBlock, AK1] view.
constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(Number<KThreadReadPerm>{},
Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<kfold>{},
Number<K0PerThreadWrite>{})),
make_merge_transform_v3_division_mod(
make_tuple(Number<M0 / mpair>{}, Number<mpair>{}, Number<M1>{})),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return a_lds_block_desc_ak0_m_ak1;
}
}
// Build the LDS descriptor for the B block, logical shape
// [BK0, NPerBlock, BK1]. Mirrors GetABlockDescriptor_* with M/A replaced by
// N/B and the layout roles swapped (column-major B gets the xor-swizzled
// layout, row-major B gets the kfold/permute layout).
__device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
// bank conflict when writting the data into LDS, but don't worry, we have whole entire
// loop to hide it in v4. it may give you some benefit from less valu in compute address
return make_naive_tensor_descriptor(
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1Number, BK1Number, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
// NLdsLayer * K0 as logical Bank
// NOTE(review): 32 * 4 appears to target a 128-byte span; derivation is
// not visible here - confirm before changing.
constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(LDSTypeB) / BPackedSize;
constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(
BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
make_tuple(BK1Number, Number<KPerBlock * NLdsLayer>{}, I1));
// Xor-swizzle the N coordinate against (BK0 * NLdsLayer) to spread
// accesses across LDS banks.
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc,
make_tuple(make_xor_with_modulo_transform(make_tuple(
Number<NPerBlock / NLdsLayer>{}, Number<BK0Number * NLdsLayer>{})),
make_pass_through_transform(BK1Number)),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
// Split the fused (BK0 * NLdsLayer) dimension back apart ...
constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number<NLdsLayer>{})),
make_pass_through_transform(Number<NPerBlock / NLdsLayer>{}),
make_pass_through_transform(BK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));
// ... and re-merge NLdsLayer into N for the canonical view.
constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_lds_block_desc_bk0_nldslayer_n_bk1,
make_tuple(make_pass_through_transform(BK0Number),
make_merge_transform_v3_division_mod(
make_tuple(Number<NPerBlock / NLdsLayer>{}, Number<NLdsLayer>{})),
make_pass_through_transform(BK1Number)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return b_lds_block_desc_bk0_n_bk1;
}
else // RowMajor B
{
// N0/N1: N split per the write-side thread cluster; KThreadWrite /
// KThreadRead: K-dim thread counts on the write and read sides.
constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
constexpr auto N1 = NPerBlock / N0;
constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
constexpr auto KThreadRead = 64 / NPerWmma;
constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
// kfold: fold factor chosen so one row segment spans at most 128 bytes.
constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128)
? 1
: 128 / (BK1Number * N0 * sizeof(LDSTypeB));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=npair<=n0
constexpr auto npair = (BK1Number * NPerWmma * sizeof(LDSTypeB) > 128)
? 1
: ((128 / (BK1Number * NPerWmma * sizeof(LDSTypeB))) > N0
? N0
: 128 / (BK1Number * NPerWmma * sizeof(LDSTypeB)));
// 6-d raw LDS layout used as the basis for the permutation below.
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<K0PerThreadWrite>{},
Number<KThreadReadPerm * N1>{},
Number<kfold * N0 / npair>{},
Number<npair>{},
BK1Number));
// Xor-swizzle dims 2 and 3 to decorrelate read/write bank patterns.
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_xor_with_modulo_transform(
make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
make_pass_through_transform(Number<npair>{}),
make_pass_through_transform(BK1Number)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));
// Unmerge the fused dims into their K/N factors (8-d intermediate).
constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<N1>{})),
make_unmerge_transform(make_tuple(Number<kfold>{}, Number<N0 / npair>{})),
make_pass_through_transform(Number<npair>{}),
make_pass_through_transform(BK1Number)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<1>{},
Sequence<2>{},
Sequence<0, 3>{},
Sequence<4, 5>{},
Sequence<6>{},
Sequence<7>{}));
// Merge everything back to the canonical [BK0, NPerBlock, BK1] view.
constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(Number<KThreadReadPerm>{},
Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<kfold>{},
Number<K0PerThreadWrite>{})),
make_merge_transform_v3_division_mod(
make_tuple(Number<N0 / npair>{}, Number<npair>{}, Number<N1>{})),
make_pass_through_transform(BK1Number)),
make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return b_lds_block_desc_bk0_n_bk1;
}
}
__host__ __device__ static constexpr auto
// *Caution Here repeat is shuffle repeat
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
@@ -899,28 +556,27 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
}
using BlockwiseGemmPipe = remove_cvref_t<
decltype(BlockGemmPipeline_Selector<
BlkGemmPipelineVer,
BlkGemmPipeSched,
BlockSize,
LDSTypeA,
LDSTypeB,
ComputeTypeA,
ComputeTypeB,
AccDataType,
decltype(MakeAWmmaTileDescriptor(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())),
decltype(MakeBWmmaTileDescriptor(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())),
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
KPack>())>;
using BlockwiseGemmPipe =
remove_cvref_t<decltype(BlockGemmPipeline_Selector<BlkGemmPipelineVer,
BlkGemmPipeSched,
BlockSize,
LDSTypeA,
LDSTypeB,
ComputeTypeA,
ComputeTypeB,
AccDataType,
decltype(MakeAWmmaTileDescriptor()),
decltype(MakeBWmmaTileDescriptor()),
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
KPack>())>;
template <typename DEGridDesc>
__device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -1168,8 +824,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
__device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
constexpr auto a_block_desc_ak0_m_ak1 = ATransfer::GetBlockDescriptor();
constexpr auto b_block_desc_bk0_n_bk1 = BTransfer::GetBlockDescriptor();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
@@ -1257,161 +913,32 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto a_block_desc_ak0_m_ak1 = ATransfer::GetBlockDescriptor();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
constexpr auto b_block_desc_bk0_n_bk1 = BTransfer::GetBlockDescriptor();
// A matrix blockwise copy
// workaround because v7r2 is not as general as v4r1
auto get_a_blockwise_transfer = [&]() {
if constexpr(NumATensor > 1)
{
const auto idx_as_block_begin = generate_tuple(
[&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
Number<NumATensor>{});
return ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
AsDataType,
Tuple<LDSTypeA>,
AGridDesc_AK0_M_K1,
decltype(tie(a_block_desc_ak0_m_ak1)),
AElementwiseOperation,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
uniform_sequence_gen_t<NumATensor, AThreadTransferSrcResetCoordinateAfterRun>,
Sequence<true>,
BlockwiseGemmPipe::GlobalBufferNum>{as_grid_desc_ak0_m_ak1,
idx_as_block_begin,
tie(a_block_desc_ak0_m_ak1),
make_tuple(make_multi_index(0, 0, 0)),
a_element_op};
}
else
{
return ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
remove_cvref_t<tuple_element_t<0, AsDataType>>,
remove_cvref_t<tuple_element_t<0, AsDataType>>,
decltype(as_grid_desc_ak0_m_ak1[I0]),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
as_grid_desc_ak0_m_ak1[I0],
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
}
};
auto a_blockwise_copy = get_a_blockwise_transfer();
auto a_blockwise_copy =
ATransfer::template GetBlockTransfer<AGridDesc_AK0_M_K1,
decltype(a_block_desc_ak0_m_ak1),
AsDataType,
AElementwiseOperation,
BlockwiseGemmPipe::GlobalBufferNum>(
as_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_element_op, block_m_id);
// B matrix blockwise copy
// workaround because v7r2 is not as general as v4r1
auto get_b_blockwise_transfer = [&]() {
if constexpr(NumBTensor > 1)
{
const auto idx_bs_block_begin = generate_tuple(
[&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
Number<NumBTensor>{});
return ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
BsDataType,
Tuple<LDSTypeB>,
BGridDesc_BK0_N_K1,
decltype(tie(b_block_desc_bk0_n_bk1)),
BElementwiseOperation,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
uniform_sequence_gen_t<NumBTensor, BThreadTransferSrcResetCoordinateAfterRun>,
Sequence<true>,
BlockwiseGemmPipe::GlobalBufferNum>{bs_grid_desc_bk0_n_bk1,
idx_bs_block_begin,
tie(b_block_desc_bk0_n_bk1),
make_tuple(make_multi_index(0, 0, 0)),
b_element_op};
}
else
{
return ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
remove_cvref_t<tuple_element_t<0, BsDataType>>,
remove_cvref_t<tuple_element_t<0, BsDataType>>,
decltype(bs_grid_desc_bk0_n_bk1[I0]),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
bs_grid_desc_bk0_n_bk1[I0],
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
}
};
auto b_blockwise_copy = get_b_blockwise_transfer();
auto b_blockwise_copy =
BTransfer::template GetBlockTransfer<BGridDesc_BK0_N_K1,
decltype(b_block_desc_bk0_n_bk1),
BsDataType,
BElementwiseOperation,
BlockwiseGemmPipe::GlobalBufferNum>(
bs_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_element_op, block_n_id);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
@@ -1427,8 +954,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
APackedSize),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
constexpr auto a_block_slice_copy_step = ATransfer::GetBlockStep();
constexpr auto b_block_slice_copy_step = BTransfer::GetBlockStep();
// Blockwise GEMM pipeline
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
@@ -1436,8 +963,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) /
KPerBlock);
ATransfer::GetKDimension(as_grid_desc_ak0_m_ak1[I0]) / KPerBlock);
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
get_first_element_workaround<NumATensor>(as_grid_desc_ak0_m_ak1),