Implement batched gemm add relu gemm add for rdna4 (#3391)

* wip: test suite for batched gemm multiple d gemm multiple d, working on gridwise implementation

* wip: many fixes in implementation of batched gemm gemm multiple d

* wip: batched gemm gemm multiple d gridwise op compiling, not working yet

* fix: incorrect d0 grid indexing in batched gemm gemm multiple d

* feat: add instances for batched gemm add relu gemm add

* chore: configure instance with low vector transfer size for odd sizes

* chore: add some more validation to device batched gemm gemm multiple d, and removed template parameter that didn't really make sense

* fix: update device_batched_gemm_gemm_wmma to work with new gridwise changes

* fix: disable odd size tests on XDL archs

* chore: removed temporary logging

* chore: update some references to C tensor to E tensor

* Tentative fix for example template params

* Tentative fix for non-multi-D batched gemm gemm device impl.

* Tentative fix for xdl example template params

* Tentative fix for profiler build on gfx90a

* chore: improve device batched gemm gemm multi D comment to include all ops and dimensions

* chore: explicitly call ck::make_tuple to prevent issues when std::make_tuple would apply

* fix: make the gemm1 data types match what happens in the device op

* feat: add d0s/d1s datatypes and layouts to the device op type string

* chore: change element-wise op so addition happens in fp32

* chore: add static asserts for gemm0/gemm1 calculated wave sizes

* chore: also updated other element-wise ops to use fp32 calculations

* chore: log number of supported instances

* chore: update instance comment

* chore: disable kernel timing in example by default

* fix: gemm1 wave size calculation

* fix: make sure batched gemm multiple d gemm multiple d profiler performs correct type conversions

* chore: remove increased tolerance in batched gemm gemm multiple d example

* chore: add comment explaining that verification fails for certain input values

* chore: clarify instance comment

---------

Co-authored-by: kiefer <kiefer.van.teutem@streamhpc.com>
This commit is contained in:
Erwin Terpstra
2026-01-20 22:06:59 +01:00
committed by GitHub
parent 91b4102a59
commit d5ae81b292
22 changed files with 2956 additions and 499 deletions

View File

@@ -20,6 +20,7 @@
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/utility/tuple.hpp"
namespace ck {
namespace tensor_operation {
@@ -51,12 +52,16 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
arg.p_a_grid + a_batch_offset,
arg.p_b0_grid + b0_batch_offset,
Tuple<>{}, // p_d0s_grid
arg.p_b1_grid + b1_batch_offset,
Tuple<>{}, // p_d1s_grid
arg.p_c_grid + c_batch_offset,
p_shared,
arg.a_grid_desc,
arg.b0_grid_desc,
Tuple<>{}, // D0sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
arg.b1_grid_desc,
Tuple<>{}, // D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op,
arg.b0_element_op,
@@ -240,8 +245,10 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
// DataType Family
ADataType,
B0DataType,
Tuple<>, // Ds0DataType
AccDataType, // Acc0DataType
B1DataType,
Tuple<>, // Ds1DataType
AccDataType, // Acc1DataType
CShuffleDataType,
CDataType,
@@ -255,7 +262,9 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
// InMemory Data Descriptor
AGridDesc,
B0GridDesc,
Tuple<>, // Ds0GridDesc
B1GridDesc,
Tuple<>, // Ds1GridDesc
CGridDesc_M_N,
// Tiling Family
MPerBlock,
@@ -290,6 +299,7 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
B0BlockTransferDstScalarPerVector_K1,
true,
B0BlockLdsAddExtraL,
1, // CDE0BlockTransferSrcScalarPerVector
B1BlockTransferThreadClusterLengths_L0_N_L1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
@@ -369,8 +379,8 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides);
c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(c_g_m_o_lengths, c_g_m_o_strides);
c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
block_2_ctile_map = GridwiseOp::MakeDefaultBlock2ETileMap(c_grid_desc_m_n, 1, 1);
}
// Pointers
const ADataType* p_a_grid;
@@ -405,10 +415,10 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
B0GridDesc b0_grid_desc;
B1GridDesc b1_grid_desc;
CGridDesc_M_N c_grid_desc_m_n;
typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock;
typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map;
typename GridwiseOp::DefaultBlock2ETileMap block_2_ctile_map;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
};
@@ -500,7 +510,9 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
arg.b0_grid_desc,
Tuple<>{},
arg.b1_grid_desc,
Tuple<>{},
arg.c_grid_desc_m_n,
arg.block_2_ctile_map))
{

View File

@@ -825,6 +825,11 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
{
if(!ck::is_xdl_wmma_supported<A0DataType, B0DataType, Gemm0MPerXdl, Gemm0NPerXdl>())
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "wrong! XDL/WMMA not supported for these datatypes or operation sizes."
<< std::endl;
}
return false;
}
@@ -843,6 +848,11 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
CheckDLayout<tensor_layout::gemm::RowMajor, D1sLayout, NumD1Tensor>() &&
is_same_v<tensor_layout::gemm::RowMajor, E1Layout>))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "wrong! Unsupported tensor layout combination." << std::endl;
}
return false;
}

View File

@@ -101,6 +101,15 @@ struct GemmGemmPadder
b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
}
// D0[M, N]
// Forwards the raw D0 descriptor to PadTensorDescriptor together with the tile lengths
// (MPerTile_, NPerTile_) and the pad switches Sequence<PadM, PadN>.
// NOTE(review): despite the "_N_K" suffix in the function name, the tile sizes and pad
// flags passed here are the M and N ones, consistent with the D0[M, N] note above.
template <typename D0Desc_MRaw_NRaw>
__host__ __device__ constexpr auto
PadD0Descriptor_N_K(const D0Desc_MRaw_NRaw& d0_desc_mraw_nraw) const
{
return PadTensorDescriptor(
d0_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
}
// B1[Gemm1N, Gemm1K] = B1[O, N]
template <typename B1Desc_NRaw_KRaw>
__host__ __device__ constexpr auto

View File

@@ -13,31 +13,35 @@
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L]
// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N]
// Gemm0: AccOp(A [M x K] x B0 [K x L], D0) = Acc [M x L]
// Gemm1: CDEOp1(Acc [M x L] x B1 [L x N], D1) = E [M x N]
template <typename ADataType,
typename B0DataType,
typename D0sDataType,
typename Acc0DataType,
typename B1DataType,
typename D1sDataType,
typename Acc1DataType,
typename CShuffleDataType,
typename CDataType,
typename E1DataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc,
typename B0GridDesc,
typename D0sGridDesc,
typename B1GridDesc,
typename CGridDesc_M_N,
typename D1sGridDesc,
typename E1GridDesc,
index_t MPerBlock,
index_t LPerBlock,
index_t KPerBlock,
@@ -69,6 +73,7 @@ template <typename ADataType,
index_t B0BlockTransferDstScalarPerVector_K1,
bool B0ThreadTransferSrcResetCoordinateAfterRun,
bool B0BlockLdsExtraL,
index_t CDE0BlockTransferSrcScalarPerVector,
typename B1BlockTransferThreadClusterLengths_L0_N_L1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
@@ -79,8 +84,8 @@ template <typename ADataType,
bool B1BlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
bool PadN,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1>
@@ -94,6 +99,9 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr index_t NumD0Tensor = D0sDataType::Size();
static constexpr index_t NumD1Tensor = D1sDataType::Size();
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
@@ -105,9 +113,19 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
static constexpr auto BL0 = Number<L0PerBlock>{};
static constexpr auto BL1 = Number<L1Value>{};
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WaveSize0 = BlockSize / (MWaves * LWaves);
static constexpr auto WaveSize1 = BlockSize / (MWaves * NWaves);
static constexpr auto WaveSize = WaveSize0;
static_assert(
WaveSize0 == 32 || WaveSize0 == 64,
"Misconfigured wave parameters: BlockSize / (MWaves * LWaves) != 32/64 threads per wave");
static_assert(
WaveSize1 == 32 || WaveSize1 == 64,
"Misconfigured wave parameters: BlockSize / (MWaves * NWaves) != 32/64 threads per wave");
static constexpr index_t KPerWmmaBlk =
WmmaSelector<ADataType, B0DataType, Acc0DataType, MPerWmma, LPerWmma>::selected_wmma
@@ -212,6 +230,52 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
return b1_block_copy_step;
}
// ck::Tuple<const D0DataType1*, const D0DataType2*, ...>
static constexpr auto MakeD0sGridPointer()
{
return generate_tuple(
[&](auto i) {
using D0iDataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
return static_cast<const D0iDataType*>(nullptr);
},
Number<NumD0Tensor>{});
}
// ck::Tuple<const D1DataType1*, const D1DataType2*, ...>
static constexpr auto MakeD1sGridPointer()
{
return generate_tuple(
[&](auto i) {
using D1iDataType = remove_cvref_t<tuple_element_t<i.value, D1sDataType>>;
return static_cast<const D1iDataType*>(nullptr);
},
Number<NumD1Tensor>{});
}
// Maps the value of get_thread_local_1d_id() to a 3-component index over
// (MWaves, LWaves, WaveSize) by evaluating the bottom index of a single merge transform.
__device__ static auto GetGemm0WaveIdx()
{
    constexpr auto merge_mwave_lwave_lane =
        make_merge_transform(make_tuple(MWaves, LWaves, WaveSize));

    constexpr auto flat_tid_to_wave_idx =
        make_single_stage_tensor_adaptor(make_tuple(merge_mwave_lwave_lane),
                                         make_tuple(Sequence<0, 1, 2>{}),
                                         make_tuple(Sequence<0>{}));

    const index_t flat_tid = get_thread_local_1d_id();

    return flat_tid_to_wave_idx.CalculateBottomIndex(make_multi_index(flat_tid));
}
// Maps the given id to a 2-component index over (WaveSize / LPerWmma, LPerWmma) by
// evaluating the bottom index of a single merge transform.
// NOTE(review): presumably thread_id is the id within a wave (in [0, WaveSize)) —
// confirm against callers.
__device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
{
    constexpr auto merge_subgroup_lane =
        make_merge_transform(make_tuple(WaveSize / LPerWmma, LPerWmma));

    constexpr auto lane_to_mn_idx =
        make_single_stage_tensor_adaptor(make_tuple(merge_subgroup_lane),
                                         make_tuple(Sequence<0, 1>{}),
                                         make_tuple(Sequence<0>{}));

    return lane_to_mn_idx.CalculateBottomIndex(make_multi_index(thread_id));
}
template <index_t MNRepeat, index_t MNWaves, index_t MNPerWmma, typename BlockDesc>
__host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&)
{
@@ -369,14 +433,14 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
// *Caution Here repeat is shuffle repeat
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
{
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
constexpr auto c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
I1,
Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));
return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
return c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
@@ -432,12 +496,14 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
true>())>; // TransposeC (must be true to work), C' = B' x A'
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
template <typename Block2CTileMap>
template <typename Block2ETileMap>
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
const B0GridDesc& b0_grid_desc,
const D0sGridDesc& d0s_grid_desc,
const B1GridDesc& b1_grid_desc,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
const D1sGridDesc& d1s_grid_desc,
const E1GridDesc& c_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
{
// Print lambda with env check and printf() style formatting.
const char* curFunc = __func__;
@@ -482,6 +548,44 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
return false;
}
bool d0s_desc_valid = true;
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
if(!(M == d0s_grid_desc[i].GetLength(I0) && L == d0s_grid_desc[i].GetLength(I1)))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
print("GridwiseOp: M/L Length err, A_M/B0_L = %d, %d | D0s_M/N = %d, %d\n",
M,
L,
d0s_grid_desc[i].GetLength(I0),
d0s_grid_desc[i].GetLength(I1));
}
d0s_desc_valid = false;
}
});
bool d1s_desc_valid = true;
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
if(!(M == d1s_grid_desc[i].GetLength(I0) && N == d1s_grid_desc[i].GetLength(I1)))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
print("GridwiseOp: M/N Length err, A_M/N = %d, %d | D1s_M/N = %d, %d\n",
M,
N,
d1s_grid_desc[i].GetLength(I0),
d1s_grid_desc[i].GetLength(I1));
}
d1s_desc_valid = false;
}
});
if(!(d0s_desc_valid && d1s_desc_valid))
{
return false;
}
if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
@@ -513,11 +617,11 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
return false;
}
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
if(!block_2_etile_map.CheckValidity(c_grid_desc_m_n))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
print("GridwiseOp: invalid block_2_ctile_map\n");
print("GridwiseOp: invalid block_2_etile_map\n");
}
return false;
}
@@ -539,37 +643,94 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const E1GridDesc& e_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n,
const auto e1_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
e_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return c_grid_desc_mblock_mperblock_nblock_nperblock;
return e1_grid_desc_mblock_mperblock_nblock_nperblock;
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
// D0 desc for source in blockwise copy
template <typename D0GridDesc_M_N>
__host__ __device__ static constexpr auto
MakeD0GridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(
const D0GridDesc_M_N& d0_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
const auto M = d0_grid_desc_m_n.GetLength(I0);
const auto N = d0_grid_desc_m_n.GetLength(I1);
constexpr auto wmma =
WmmaSelector<ADataType, B0DataType, Acc0DataType, MPerWmma, LPerWmma>::selected_wmma;
return transform_tensor_descriptor(
d0_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / MPerBlock, MRepeat, MWaves, MPerWmma)),
make_unmerge_transform(make_tuple(N / LPerBlock,
LRepeat,
LWaves,
WaveSize / LPerWmma,
wmma.num_acc_vgprs_per_wave))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 3, 4>{}, Sequence<1, 5, 6, 7, 8>{}));
}
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
// D0s desc for source in blockwise copy
// Applies the single-tensor D0 descriptor builder to each of the NumD0Tensor entries of
// the incoming descriptor tuple and returns the results as a tuple.
__host__ __device__ static constexpr auto
MakeD0sGridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(
    const D0sGridDesc& ds_grid_desc_m_n)
{
    const auto transform_one_d0 = [&](auto d0_idx) {
        const auto& d0_desc_m_n = ds_grid_desc_m_n[d0_idx];
        return MakeD0GridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(
            d0_desc_m_n);
    };

    return generate_tuple(transform_one_d0, Number<NumD0Tensor>{});
}
// D1s desc for source in blockwise copy
// Applies MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock to each of the
// NumD1Tensor entries of the incoming descriptor tuple and returns the results as a tuple.
template <typename DsGridDescriptor_M_N>
__host__ __device__ static constexpr auto
MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
    const DsGridDescriptor_M_N& ds_grid_desc_m_n)
{
    const auto transform_one_d1 = [&](auto d1_idx) {
        const auto& d1_desc_m_n = ds_grid_desc_m_n[d1_idx];
        return MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(d1_desc_m_n);
    };

    return generate_tuple(transform_one_d1, Number<NumD1Tensor>{});
}
// return block_id to E matrix tile idx (m0, n0) mapping
// The two trailing index_t arguments (M01, N01) are accepted but unused.
__host__ __device__ static constexpr auto MakeDefaultBlock2ETileMap(
    const E1GridDesc& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
{
    using TileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, E1GridDesc>;

    return TileMap(c_grid_desc_m_n);
}
using E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
E1GridDesc{}))>;
using D0sGridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs =
remove_cvref_t<
decltype(MakeD0sGridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(
D0sGridDesc{}))>;
using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
D1sGridDesc{}))>;
using DefaultBlock2ETileMap =
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(E1GridDesc{}, 1, 1))>;
struct SharedMemTrait
{
@@ -600,45 +761,69 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
.GetElementSpaceSize();
};
using D0sGridPointer = decltype(MakeD0sGridPointer());
using D1sGridPointer = decltype(MakeD1sGridPointer());
template <bool HasMainKBlockLoop,
TailNumber TailNum,
typename Block2CTileMap = DefaultBlock2CTileMap>
typename Block2ETileMap = DefaultBlock2ETileMap>
__device__ static void Run(const ADataType* __restrict__ p_a_grid,
const B0DataType* __restrict__ p_b0_grid,
D0sGridPointer p_d0s_grid,
const B1DataType* __restrict__ p_b1_grid,
CDataType* __restrict__ p_c_grid,
D1sGridPointer p_d1s_grid,
E1DataType* __restrict__ p_e1_grid,
void* __restrict__ p_shared,
const AGridDesc& a_grid_desc,
const B0GridDesc& b0_grid_desc,
const D0sGridDesc& d0s_grid_desc,
const B1GridDesc& b1_grid_desc,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
d1s_grid_desc_mblock_mperblock_nblock_nperblock,
const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e1_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation& a_element_op,
const B0ElementwiseOperation& b0_element_op,
const AccElementwiseOperation& acc_element_op,
const B1ElementwiseOperation& b1_element_op,
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map)
const CDEElementwiseOperation& c_element_op,
const Block2ETileMap& block_2_etile_map)
{
// clang-format off
/*******************************************************************************/
// Memory buffer zone.
const auto d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
MakeD0sGridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(d0s_grid_desc);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc.GetElementSpaceSize());
const auto b0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b0_grid, b0_grid_desc.GetElementSpaceSize());
const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b1_grid, b1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto e1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e1_grid, e1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto d0s_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0s_grid[i],
d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i].GetElementSpaceSize());
},
Number<NumD0Tensor>{});
const auto d1s_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d1s_grid[i],
d1s_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumD1Tensor>{});
/*******************************************************************************/
// BlockIdx.x -> [BlockId.m, BlockId.n]
const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
const auto block_work_idx = block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_etile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
make_tuple(e1_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e1_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{ return; }
// Store BlockId into SGPR
@@ -757,6 +942,72 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
constexpr auto lsubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
constexpr auto laccvgprs = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);
// d0 matrix threadwise copy
constexpr auto d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
make_naive_tensor_descriptor_packed(make_tuple(
I1, // MBlockId
I1, // NBlockID
mrepeat,
mwave,
mthreadpersubgroup,
lrepeat,
lwave,
lsubgroup,
laccvgprs));
auto d0s_thread_buf = generate_tuple(
[&](auto i) {
using D0iDataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
return StaticBuffer<
AddressSpaceEnum::Vgpr,
D0iDataType,
d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetElementSpaceSize(),
true>{};
},
Number<NumD0Tensor>{});
const auto wave_id = GetGemm0WaveIdx(); // I0: MWaves, I1: LWaves, I2: WaveSize
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I0: WaveSize / LPerWmma, I1: LPerWmma
static_assert(CDE0BlockTransferSrcScalarPerVector <= laccvgprs,
"vector load must be not greater than n4");
static_assert(laccvgprs % CDE0BlockTransferSrcScalarPerVector == 0);
auto d0s_threadwise_copy = generate_tuple(
[&](auto i) {
using D0iDataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
return ThreadwiseTensorSliceTransfer_v2<
D0iDataType,
D0iDataType,
decltype(d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i]),
decltype(d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
Sequence<I1, // MBlockId
I1, // NBlockID
mrepeat,
mwave,
mthreadpersubgroup,
lrepeat,
lwave,
lsubgroup,
laccvgprs>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>,
8, // NOTE: XDL has this exposed as CDE0BlockTransferSrcVectorDim.
// But as the grid descriptor is built internally, the parameter doesn't really make sense to configure per instance
CDE0BlockTransferSrcScalarPerVector,
1,
false>(d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i],
make_multi_index(block_work_idx[I0], // MBlockId
0, // NBlockId
0, // mrepeat
wave_id[I0], // mwave
wave_m_n_id[I1], // mthreadpersubgroup
0, // nrepeat
wave_id[I1], // nwave
wave_m_n_id[I0], // nsubgroup
0)); // register number
},
Number<NumD0Tensor>{});
constexpr auto acc0_thread_desc_l0perblock_mperblock_l1 = transform_tensor_descriptor(
acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(lrepeat, lwave, lsubgroup)),
@@ -924,9 +1175,44 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
b_scale_struct,
KBlockMainLoop,
1); // num_k_block_per_scale
// multiple d
if constexpr(NumD0Tensor)
{
constexpr auto d0s_thread_buf_size = d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetElementSpaceSize();
static_for<0, acc0_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); });
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).Run(d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i],
d0s_grid_buf[i],
d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
d0s_thread_buf(i));
});
static_for<0, d0s_thread_buf_size, 1>{}([&](auto i) {
// get reference to src data
const auto src_data_refs = generate_tie(
// return type should be lvalue
[&](auto iSrc) -> const auto& { return d0s_thread_buf[iSrc][i]; },
Number<NumD0Tensor>{});
// get reference to dst data
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto) -> auto& { return acc0_thread_buf(i); },
Number<2>{});
unpack2(acc_element_op, dst_data_refs, src_data_refs);
});
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i],
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0));
});
}
else
{
static_for<0, acc0_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); });
}
block_sync_lds();
@@ -995,15 +1281,15 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
}
} // end gemm1
constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
constexpr auto c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
constexpr auto c_mrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0);
constexpr auto c_mwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1);
constexpr auto c_mthreadpersubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2);
constexpr auto c_nrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3);
constexpr auto c_nwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4);
constexpr auto c_nsubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
constexpr auto c_naccvgprs = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);
constexpr auto c_mrepeat = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0);
constexpr auto c_mwave = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1);
constexpr auto c_mthreadpersubgroup = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2);
constexpr auto c_nrepeat = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3);
constexpr auto c_nwave = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4);
constexpr auto c_nsubgroup = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
constexpr auto c_naccvgprs = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);
constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(c_mrepeat * c_mwave * c_mthreadpersubgroup,
@@ -1032,29 +1318,29 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
/*******************************************************************************/
// write out to C, implement shuffle
{
constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
constexpr auto c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
// This API Provide All dimension (size) you need
constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp =
constexpr auto c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp =
blockwise_gemm1.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
constexpr auto MWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I1);
constexpr auto MThreadPerSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I2);
constexpr auto NWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I4);
constexpr auto NSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I5);
constexpr auto NAccVgprs = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I6);
constexpr auto MWave = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I1);
constexpr auto MThreadPerSubGroup = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I2);
constexpr auto NWave = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I4);
constexpr auto NSubGroup = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I5);
constexpr auto NAccVgprs = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I6);
// LDS descriptor, shuffle and write out in MRepeat x NRepeat times
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
constexpr auto c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
auto c1_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());
c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());
constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = transform_tensor_descriptor(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
@@ -1097,10 +1383,10 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
auto c1_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<Acc1DataType,
CShuffleDataType,
decltype(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
decltype(c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
decltype(c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMRepeatPerShuffle,
@@ -1125,36 +1411,68 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
n_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I3]),
ck::tensor_operation::element_wise::PassThrough{}};
// tuple of reference to C/Ds tensor descriptors
const auto e1_d1s_desc_refs = concat_tuple_of_reference(
tie(c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return d1s_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumD1Tensor>{}));
// tuple of reference to C/Ds tensor buffers
const auto c1_d1s_buf_refs = concat_tuple_of_reference(
tie(c1_shuffle_block_buf),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return d1s_grid_buf[i]; },
Number<NumD1Tensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c1_d1s_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
},
Number<NumD1Tensor>{}));
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
auto cde1_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), D1sDataType{})),
Tuple<E1DataType>,
decltype(e1_d1s_desc_refs),
decltype(tie(e1_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation,
Sequence<static_cast<index_t>(CGlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitrary
// type
Sequence<1,
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1,
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CDataType, // typename DstData,
decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumD1Tensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{e1_d1s_desc_refs,
idx_c1_d1s_block_begin,
tie(e1_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
c_element_op};
// space filling curve for local reg & global memory
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
constexpr auto sfc_c1_vgpr =
SpaceFillingCurve<Sequence<MRepeat, 1, 1, NRepeat, 1, 1, NAccVgprs>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
Sequence<CShuffleMRepeatPerShuffle,
@@ -1166,7 +1484,7 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
NAccVgprs>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
constexpr auto sfc_e1_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
@@ -1174,37 +1492,44 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
1,
CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
constexpr index_t num_access = sfc_c1_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
static_assert(num_access == sfc_e1_global.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c1_thread_copy_vgpr_to_lds.Run(c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
sfc_c1_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
c_shuffle_block_buf);
c1_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
cde1_shuffle_block_copy_lds_to_global.Run(
e1_d1s_desc_refs,
c1_d1s_buf_refs,
tie(e1_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e1_grid_buf));
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
constexpr auto e1_global_step = sfc_e1_global.GetForwardStep(access_id);
// move on D1s
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
cde1_shuffle_block_copy_lds_to_global.MoveSrcSliceWindow(
e1_d1s_desc_refs, i + I1, e1_global_step);
});
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
cde1_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
tie(e1_grid_desc_mblock_mperblock_nblock_nperblock), I0, e1_global_step);
}
});
}

View File

@@ -219,6 +219,30 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
//
// D0
//
// Builds the D0 descriptor pair by forwarding the D0 lengths and strides to the
// generic MakeGridDescriptorPair helper, instantiated with NumDimG / NumDimM /
// NumDimN and the CSpec template argument.
// The parameter names suggest the vectors are ordered as (Gs, Ms, Ns).
static auto MakeD0GridDescriptorPair(const std::vector<index_t>& d0_gs_ms_ns_lengths_vec,
                                     const std::vector<index_t>& d0_gs_ms_ns_strides_vec)
{
    auto d0_grid_desc_pair = MakeGridDescriptorPair<NumDimG, NumDimM, NumDimN, CSpec>(
        d0_gs_ms_ns_lengths_vec, d0_gs_ms_ns_strides_vec);

    return d0_grid_desc_pair;
}
// Returns the `.first` element of the pair produced by MakeD0GridDescriptorPair.
// NOTE(review): judging by the function name and the TODO below, this is presumably
// the raw (unpadded) G x M x N descriptor -- confirm against MakeGridDescriptorPair.
// TODO: rename to G_MRaw_NRaw
static auto MakeD0GridDescriptor_G_M_N(const std::vector<index_t>& d0_gs_ms_ns_lengths_vec,
                                       const std::vector<index_t>& d0_gs_ms_ns_strides_vec)
{
    const auto d0_grid_desc_pair =
        MakeD0GridDescriptorPair(d0_gs_ms_ns_lengths_vec, d0_gs_ms_ns_strides_vec);

    return d0_grid_desc_pair.first;
}
// Takes the `.second` element of the pair produced by MakeD0GridDescriptorPair and
// returns it after running it through matrix_padder.PadD0Descriptor_M_N.
static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d0_gs_ms_ns_lengths_vec,
                                     const std::vector<index_t>& d0_gs_ms_ns_strides_vec)
{
    // Descriptor as produced by the pair helper, before the padder is applied.
    const auto d0_grid_desc_m_n_unpadded =
        MakeD0GridDescriptorPair(d0_gs_ms_ns_lengths_vec, d0_gs_ms_ns_strides_vec).second;

    return matrix_padder.PadD0Descriptor_M_N(d0_grid_desc_m_n_unpadded);
}
//
// B1
//