mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Implement batched gemm add relu gemm add for rdna4 (#3391)
* wip: test suite for batched gemm multiple d gemm multiple d, working on gridwise implenentation * wip: many fixes in implementation of batched gemm gemm multiple d * wip: batched gemm gemm multiple d gridwise op compiling, not working yet * fix: incorrect d0 grid indexing in batched gemm gemm multipled * feat: add instances for batched gemm add relu gemm add * chore: configure instance with low vector transfer size for odd sizes * chore: add some more validation to device batched gemm gemm multiple d, and removed template parameter that didn't really make sense * fix: upate device_batched_gemm_gemm_wmma to work with new gridwise changes * fix: disable odd size tests on XDL archs * chore: removed temporary logging * chore: update some references to C tensor to E tensor * Tentative fix for example template params * Tentative fix for non-multi-D batched gemm gemm device impl. * Tentative fix for xdl example template params * Tentative fix for profiler build on gfx90a * chore: improve device batched gemm gemm multi D comment to include all ops and dimensions * chore: explicitly call ck::make_tuple to prevent issues when std::make_tuple would apply * fix: make the gemm1 data types match what happens in the device op * feat: add d0s/d1s datatypes and layouts to the device op type string * chore: change element-wise op so addition happens in fp32 * chore: add static asserts for gemm0/gemm1 calculated wave sizes * chore: also updated other element-wise ops to use fp32 calculations * chore: log number of supported instances * chore: update instance comment * chore: disable kernel timing in example by default * fix: gemm1 wave size calculation * fix: make sure batched gemm multiple d gemm multiple d profiler performs correct type conversions * chore: remove increased tolerance in batched gemm gemm multiple d example * chore: add comment explaining that verification fails for certain input values * chore: clarify instance comment --------- Co-authored-by: kiefer <kiefer.van.teutem@streamhpc.com>
This commit is contained in:
@@ -20,6 +20,7 @@
|
||||
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -51,12 +52,16 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
|
||||
arg.p_a_grid + a_batch_offset,
|
||||
arg.p_b0_grid + b0_batch_offset,
|
||||
Tuple<>{}, // p_d0s_grid
|
||||
arg.p_b1_grid + b1_batch_offset,
|
||||
Tuple<>{}, // p_d1s_grid
|
||||
arg.p_c_grid + c_batch_offset,
|
||||
p_shared,
|
||||
arg.a_grid_desc,
|
||||
arg.b0_grid_desc,
|
||||
Tuple<>{}, // D0sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
arg.b1_grid_desc,
|
||||
Tuple<>{}, // D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.a_element_op,
|
||||
arg.b0_element_op,
|
||||
@@ -240,8 +245,10 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
|
||||
// DataType Family
|
||||
ADataType,
|
||||
B0DataType,
|
||||
Tuple<>, // Ds0DataType
|
||||
AccDataType, // Acc0DataType
|
||||
B1DataType,
|
||||
Tuple<>, // Ds1DataType
|
||||
AccDataType, // Acc1DataType
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
@@ -255,7 +262,9 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
|
||||
// InMemory Data Descriptor
|
||||
AGridDesc,
|
||||
B0GridDesc,
|
||||
Tuple<>, // Ds0GridDesc
|
||||
B1GridDesc,
|
||||
Tuple<>, // Ds1GridDesc
|
||||
CGridDesc_M_N,
|
||||
// Tiling Family
|
||||
MPerBlock,
|
||||
@@ -290,6 +299,7 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
|
||||
B0BlockTransferDstScalarPerVector_K1,
|
||||
true,
|
||||
B0BlockLdsAddExtraL,
|
||||
1, // CDE0BlockTransferSrcScalarPerVector
|
||||
B1BlockTransferThreadClusterLengths_L0_N_L1,
|
||||
B1BlockTransferThreadClusterArrangeOrder,
|
||||
B1BlockTransferSrcAccessOrder,
|
||||
@@ -369,8 +379,8 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
|
||||
b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides);
|
||||
c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(c_g_m_o_lengths, c_g_m_o_strides);
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
|
||||
block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
|
||||
GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
|
||||
block_2_ctile_map = GridwiseOp::MakeDefaultBlock2ETileMap(c_grid_desc_m_n, 1, 1);
|
||||
}
|
||||
// Pointers
|
||||
const ADataType* p_a_grid;
|
||||
@@ -405,10 +415,10 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
|
||||
B0GridDesc b0_grid_desc;
|
||||
B1GridDesc b1_grid_desc;
|
||||
CGridDesc_M_N c_grid_desc_m_n;
|
||||
typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map;
|
||||
typename GridwiseOp::DefaultBlock2ETileMap block_2_ctile_map;
|
||||
|
||||
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
|
||||
};
|
||||
@@ -500,7 +510,9 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
|
||||
|
||||
if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
|
||||
arg.b0_grid_desc,
|
||||
Tuple<>{},
|
||||
arg.b1_grid_desc,
|
||||
Tuple<>{},
|
||||
arg.c_grid_desc_m_n,
|
||||
arg.block_2_ctile_map))
|
||||
{
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -825,6 +825,11 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
{
|
||||
if(!ck::is_xdl_wmma_supported<A0DataType, B0DataType, Gemm0MPerXdl, Gemm0NPerXdl>())
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "wrong! XDL/WMMA not supported for these datatypes or operation sizes."
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -843,6 +848,11 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
CheckDLayout<tensor_layout::gemm::RowMajor, D1sLayout, NumD1Tensor>() &&
|
||||
is_same_v<tensor_layout::gemm::RowMajor, E1Layout>))
|
||||
{
|
||||
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "wrong! Unsupported tensor layout combination." << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -101,6 +101,15 @@ struct GemmGemmPadder
|
||||
b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
|
||||
}
|
||||
|
||||
// D0[M, N]
|
||||
template <typename D0Desc_MRaw_NRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadD0Descriptor_N_K(const D0Desc_MRaw_NRaw& d0_desc_mraw_nraw) const
|
||||
{
|
||||
return PadTensorDescriptor(
|
||||
d0_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
|
||||
}
|
||||
|
||||
// B1[Gemm1N, Gemm1K] = B1[O, N]
|
||||
template <typename B1Desc_NRaw_KRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
|
||||
@@ -13,31 +13,35 @@
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L]
|
||||
// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N]
|
||||
// Gemm0: AccOp(A [M x K] x B0 [K x L], D0) = Acc [M x L]
|
||||
// Gemm1: CDEOp1(Acc [M x L] x B1 [L x N], D1) = E [M x N]
|
||||
template <typename ADataType,
|
||||
typename B0DataType,
|
||||
typename D0sDataType,
|
||||
typename Acc0DataType,
|
||||
typename B1DataType,
|
||||
typename D1sDataType,
|
||||
typename Acc1DataType,
|
||||
typename CShuffleDataType,
|
||||
typename CDataType,
|
||||
typename E1DataType,
|
||||
typename AElementwiseOperation,
|
||||
typename B0ElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
typename B1ElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
typename AGridDesc,
|
||||
typename B0GridDesc,
|
||||
typename D0sGridDesc,
|
||||
typename B1GridDesc,
|
||||
typename CGridDesc_M_N,
|
||||
typename D1sGridDesc,
|
||||
typename E1GridDesc,
|
||||
index_t MPerBlock,
|
||||
index_t LPerBlock,
|
||||
index_t KPerBlock,
|
||||
@@ -69,6 +73,7 @@ template <typename ADataType,
|
||||
index_t B0BlockTransferDstScalarPerVector_K1,
|
||||
bool B0ThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool B0BlockLdsExtraL,
|
||||
index_t CDE0BlockTransferSrcScalarPerVector,
|
||||
typename B1BlockTransferThreadClusterLengths_L0_N_L1,
|
||||
typename B1BlockTransferThreadClusterArrangeOrder,
|
||||
typename B1BlockTransferSrcAccessOrder,
|
||||
@@ -79,8 +84,8 @@ template <typename ADataType,
|
||||
bool B1BlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
typename CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
bool PadN,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1>
|
||||
@@ -94,6 +99,9 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
|
||||
static constexpr index_t NumD0Tensor = D0sDataType::Size();
|
||||
static constexpr index_t NumD1Tensor = D1sDataType::Size();
|
||||
|
||||
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
|
||||
static constexpr auto AK1 = Number<AK1Value>{};
|
||||
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
|
||||
@@ -105,9 +113,19 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
static constexpr auto BL0 = Number<L0PerBlock>{};
|
||||
static constexpr auto BL1 = Number<L1Value>{};
|
||||
|
||||
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
static constexpr auto WaveSize0 = BlockSize / (MWaves * LWaves);
|
||||
static constexpr auto WaveSize1 = BlockSize / (MWaves * NWaves);
|
||||
static constexpr auto WaveSize = WaveSize0;
|
||||
|
||||
static_assert(
|
||||
WaveSize0 == 32 || WaveSize0 == 64,
|
||||
"Misconfigured wave parameters: BlockSize / (MWaves * LWaves) != 32/64 threads per wave");
|
||||
static_assert(
|
||||
WaveSize1 == 32 || WaveSize1 == 64,
|
||||
"Misconfigured wave parameters: BlockSize / (MWaves * NWaves) != 32/64 threads per wave");
|
||||
|
||||
static constexpr index_t KPerWmmaBlk =
|
||||
WmmaSelector<ADataType, B0DataType, Acc0DataType, MPerWmma, LPerWmma>::selected_wmma
|
||||
@@ -212,6 +230,52 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
return b1_block_copy_step;
|
||||
}
|
||||
|
||||
// ck::Tuple<const D0DataType1*, const D0DataType2*, ...>
|
||||
static constexpr auto MakeD0sGridPointer()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using D0iDataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
|
||||
|
||||
return static_cast<const D0iDataType*>(nullptr);
|
||||
},
|
||||
Number<NumD0Tensor>{});
|
||||
}
|
||||
|
||||
// ck::Tuple<const D1DataType1*, const D1DataType2*, ...>
|
||||
static constexpr auto MakeD1sGridPointer()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using D1iDataType = remove_cvref_t<tuple_element_t<i.value, D1sDataType>>;
|
||||
|
||||
return static_cast<const D1iDataType*>(nullptr);
|
||||
},
|
||||
Number<NumD1Tensor>{});
|
||||
}
|
||||
|
||||
__device__ static auto GetGemm0WaveIdx()
|
||||
{
|
||||
const index_t thread_id = get_thread_local_1d_id();
|
||||
|
||||
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(MWaves, LWaves, WaveSize))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
|
||||
}
|
||||
|
||||
__device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
|
||||
{
|
||||
constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(WaveSize / LPerWmma, LPerWmma))),
|
||||
make_tuple(Sequence<0, 1>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
|
||||
}
|
||||
|
||||
template <index_t MNRepeat, index_t MNWaves, index_t MNPerWmma, typename BlockDesc>
|
||||
__host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&)
|
||||
{
|
||||
@@ -369,14 +433,14 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
// *Caution Here repeat is shuffle repeat
|
||||
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
|
||||
{
|
||||
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
|
||||
constexpr auto c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1,
|
||||
Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
|
||||
I1,
|
||||
Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));
|
||||
|
||||
return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
|
||||
return c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
@@ -432,12 +496,14 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
true>())>; // TransposeC (must be true to work), C' = B' x A'
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
|
||||
template <typename Block2CTileMap>
|
||||
template <typename Block2ETileMap>
|
||||
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
|
||||
const B0GridDesc& b0_grid_desc,
|
||||
const D0sGridDesc& d0s_grid_desc,
|
||||
const B1GridDesc& b1_grid_desc,
|
||||
const CGridDesc_M_N& c_grid_desc_m_n,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
const D1sGridDesc& d1s_grid_desc,
|
||||
const E1GridDesc& c_grid_desc_m_n,
|
||||
const Block2ETileMap& block_2_etile_map)
|
||||
{
|
||||
// Print lambda with env check and printf() style formmating.
|
||||
const char* curFunc = __func__;
|
||||
@@ -482,6 +548,44 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
return false;
|
||||
}
|
||||
|
||||
bool d0s_desc_valid = true;
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
if(!(M == d0s_grid_desc[i].GetLength(I0) && L == d0s_grid_desc[i].GetLength(I1)))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
print("GridwiseOp: M/L Length err, A_M/B0_L = %d, %d | D0s_M/N = %d, %d\n",
|
||||
M,
|
||||
L,
|
||||
d0s_grid_desc[i].GetLength(I0),
|
||||
d0s_grid_desc[i].GetLength(I1));
|
||||
}
|
||||
|
||||
d0s_desc_valid = false;
|
||||
}
|
||||
});
|
||||
|
||||
bool d1s_desc_valid = true;
|
||||
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
|
||||
if(!(M == d1s_grid_desc[i].GetLength(I0) && N == d1s_grid_desc[i].GetLength(I1)))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
print("GridwiseOp: M/N Length err, A_M/N = %d, %d | D1s_M/N = %d, %d\n",
|
||||
M,
|
||||
N,
|
||||
d1s_grid_desc[i].GetLength(I0),
|
||||
d1s_grid_desc[i].GetLength(I1));
|
||||
}
|
||||
d1s_desc_valid = false;
|
||||
}
|
||||
});
|
||||
|
||||
if(!(d0s_desc_valid && d1s_desc_valid))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
@@ -513,11 +617,11 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
|
||||
if(!block_2_etile_map.CheckValidity(c_grid_desc_m_n))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
print("GridwiseOp: invalid block_2_ctile_map\n");
|
||||
print("GridwiseOp: invalid block_2_etile_map\n");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@@ -539,37 +643,94 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const E1GridDesc& e_grid_desc_m_n)
|
||||
{
|
||||
const auto M = c_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = c_grid_desc_m_n.GetLength(I1);
|
||||
const auto M = e_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = e_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
const auto MBlock = M / MPerBlock;
|
||||
const auto NBlock = N / NPerBlock;
|
||||
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
const auto e1_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
|
||||
e_grid_desc_m_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
|
||||
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
|
||||
|
||||
return c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
return e1_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
}
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
__host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
|
||||
const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
|
||||
// D0 desc for source in blockwise copy
|
||||
template <typename D0GridDesc_M_N>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeD0GridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(
|
||||
const D0GridDesc_M_N& d0_grid_desc_m_n)
|
||||
{
|
||||
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
|
||||
c_grid_desc_m_n);
|
||||
const auto M = d0_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = d0_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
constexpr auto wmma =
|
||||
WmmaSelector<ADataType, B0DataType, Acc0DataType, MPerWmma, LPerWmma>::selected_wmma;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
d0_grid_desc_m_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(M / MPerBlock, MRepeat, MWaves, MPerWmma)),
|
||||
make_unmerge_transform(make_tuple(N / LPerBlock,
|
||||
LRepeat,
|
||||
LWaves,
|
||||
WaveSize / LPerWmma,
|
||||
wmma.num_acc_vgprs_per_wave))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2, 3, 4>{}, Sequence<1, 5, 6, 7, 8>{}));
|
||||
}
|
||||
|
||||
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
CGridDesc_M_N{}))>;
|
||||
using DefaultBlock2CTileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
|
||||
// D0s desc for source in blockwise copy
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeD0sGridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(
|
||||
const D0sGridDesc& ds_grid_desc_m_n)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeD0GridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(
|
||||
ds_grid_desc_m_n[i]);
|
||||
},
|
||||
Number<NumD0Tensor>{});
|
||||
}
|
||||
// Ds desc for source in blockwise copy
|
||||
template <typename DsGridDescriptor_M_N>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const DsGridDescriptor_M_N& ds_grid_desc_m_n)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
|
||||
},
|
||||
Number<NumD1Tensor>{});
|
||||
}
|
||||
|
||||
// return block_id to E matrix tile idx (m0, n0) mapping
|
||||
__host__ __device__ static constexpr auto MakeDefaultBlock2ETileMap(
|
||||
const E1GridDesc& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
|
||||
{
|
||||
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, E1GridDesc>(c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
using E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
E1GridDesc{}))>;
|
||||
|
||||
using D0sGridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs =
|
||||
remove_cvref_t<
|
||||
decltype(MakeD0sGridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(
|
||||
D0sGridDesc{}))>;
|
||||
|
||||
using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
D1sGridDesc{}))>;
|
||||
using DefaultBlock2ETileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(E1GridDesc{}, 1, 1))>;
|
||||
|
||||
struct SharedMemTrait
|
||||
{
|
||||
@@ -600,45 +761,69 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
.GetElementSpaceSize();
|
||||
};
|
||||
|
||||
using D0sGridPointer = decltype(MakeD0sGridPointer());
|
||||
using D1sGridPointer = decltype(MakeD1sGridPointer());
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
TailNumber TailNum,
|
||||
typename Block2CTileMap = DefaultBlock2CTileMap>
|
||||
typename Block2ETileMap = DefaultBlock2ETileMap>
|
||||
__device__ static void Run(const ADataType* __restrict__ p_a_grid,
|
||||
const B0DataType* __restrict__ p_b0_grid,
|
||||
D0sGridPointer p_d0s_grid,
|
||||
const B1DataType* __restrict__ p_b1_grid,
|
||||
CDataType* __restrict__ p_c_grid,
|
||||
D1sGridPointer p_d1s_grid,
|
||||
E1DataType* __restrict__ p_e1_grid,
|
||||
void* __restrict__ p_shared,
|
||||
const AGridDesc& a_grid_desc,
|
||||
const B0GridDesc& b0_grid_desc,
|
||||
const D0sGridDesc& d0s_grid_desc,
|
||||
const B1GridDesc& b1_grid_desc,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
d1s_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
e1_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const B0ElementwiseOperation& b0_element_op,
|
||||
const AccElementwiseOperation& acc_element_op,
|
||||
const B1ElementwiseOperation& b1_element_op,
|
||||
const CElementwiseOperation& c_element_op,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
const CDEElementwiseOperation& c_element_op,
|
||||
const Block2ETileMap& block_2_etile_map)
|
||||
{
|
||||
// clang-format off
|
||||
/*******************************************************************************/
|
||||
// Memory buffer zone.
|
||||
const auto d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
|
||||
MakeD0sGridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(d0s_grid_desc);
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc.GetElementSpaceSize());
|
||||
const auto b0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b0_grid, b0_grid_desc.GetElementSpaceSize());
|
||||
const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b1_grid, b1_grid_desc.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
auto e1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_e1_grid, e1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
const auto d0s_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_d0s_grid[i],
|
||||
d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumD0Tensor>{});
|
||||
const auto d1s_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_d1s_grid[i],
|
||||
d1s_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumD1Tensor>{});
|
||||
|
||||
/*******************************************************************************/
|
||||
// BlockIdx.x -> [BlockId.m, BlockId.n]
|
||||
const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
const auto block_work_idx = block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
if(!block_2_etile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
make_tuple(e1_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
e1_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
{ return; }
|
||||
|
||||
// Store BlockId into SGPR
|
||||
@@ -757,6 +942,72 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
constexpr auto lsubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
|
||||
constexpr auto laccvgprs = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);
|
||||
|
||||
// d0 matrix threadwise copy
|
||||
constexpr auto d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(
|
||||
I1, // MBlockId
|
||||
I1, // NBlockID
|
||||
mrepeat,
|
||||
mwave,
|
||||
mthreadpersubgroup,
|
||||
lrepeat,
|
||||
lwave,
|
||||
lsubgroup,
|
||||
laccvgprs));
|
||||
|
||||
auto d0s_thread_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
using D0iDataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
|
||||
return StaticBuffer<
|
||||
AddressSpaceEnum::Vgpr,
|
||||
D0iDataType,
|
||||
d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetElementSpaceSize(),
|
||||
true>{};
|
||||
},
|
||||
Number<NumD0Tensor>{});
|
||||
|
||||
const auto wave_id = GetGemm0WaveIdx(); // I0: MWaves, I1: LWaves, I2: WaveSize
|
||||
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I0: WaveSize / LPerWmma, I1: LPerWmma
|
||||
|
||||
static_assert(CDE0BlockTransferSrcScalarPerVector <= laccvgprs,
|
||||
"vector load must be not greater than n4");
|
||||
static_assert(laccvgprs % CDE0BlockTransferSrcScalarPerVector == 0);
|
||||
|
||||
auto d0s_threadwise_copy = generate_tuple(
|
||||
[&](auto i) {
|
||||
using D0iDataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
|
||||
return ThreadwiseTensorSliceTransfer_v2<
|
||||
D0iDataType,
|
||||
D0iDataType,
|
||||
decltype(d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i]),
|
||||
decltype(d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
|
||||
Sequence<I1, // MBlockId
|
||||
I1, // NBlockID
|
||||
mrepeat,
|
||||
mwave,
|
||||
mthreadpersubgroup,
|
||||
lrepeat,
|
||||
lwave,
|
||||
lsubgroup,
|
||||
laccvgprs>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>,
|
||||
8, // NOTE: XDL has this exposed as CDE0BlockTransferSrcVectorDim.
|
||||
// But as the grid descriptor is built internally, the parameter doesn't really make sense to configure per instance
|
||||
CDE0BlockTransferSrcScalarPerVector,
|
||||
1,
|
||||
false>(d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i],
|
||||
make_multi_index(block_work_idx[I0], // MBlockId
|
||||
0, // NBlockId
|
||||
0, // mrepeat
|
||||
wave_id[I0], // mwave
|
||||
wave_m_n_id[I1], // mthreadpersubgroup
|
||||
0, // nrepeat
|
||||
wave_id[I1], // nwave
|
||||
wave_m_n_id[I0], // nsubgroup
|
||||
0)); // register number
|
||||
},
|
||||
Number<NumD0Tensor>{});
|
||||
|
||||
constexpr auto acc0_thread_desc_l0perblock_mperblock_l1 = transform_tensor_descriptor(
|
||||
acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
|
||||
make_tuple(make_merge_transform_v3_division_mod(make_tuple(lrepeat, lwave, lsubgroup)),
|
||||
@@ -924,9 +1175,44 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
b_scale_struct,
|
||||
KBlockMainLoop,
|
||||
1); // num_k_block_per_scale
|
||||
// multiple d
|
||||
if constexpr(NumD0Tensor)
|
||||
{
|
||||
constexpr auto d0s_thread_buf_size = d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetElementSpaceSize();
|
||||
|
||||
static_for<0, acc0_thread_buf.Size(), 1>{}(
|
||||
[&](auto i) { acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); });
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
d0s_threadwise_copy(i).Run(d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i],
|
||||
d0s_grid_buf[i],
|
||||
d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
d0s_thread_buf(i));
|
||||
});
|
||||
static_for<0, d0s_thread_buf_size, 1>{}([&](auto i) {
|
||||
// get reference to src data
|
||||
const auto src_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iSrc) -> const auto& { return d0s_thread_buf[iSrc][i]; },
|
||||
Number<NumD0Tensor>{});
|
||||
|
||||
// get reference to dst data
|
||||
auto dst_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto) -> auto& { return acc0_thread_buf(i); },
|
||||
Number<2>{});
|
||||
|
||||
unpack2(acc_element_op, dst_data_refs, src_data_refs);
|
||||
});
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
d0s_threadwise_copy(i).MoveSrcSliceWindow(
|
||||
d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i],
|
||||
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0));
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, acc0_thread_buf.Size(), 1>{}(
|
||||
[&](auto i) { acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); });
|
||||
}
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
@@ -995,15 +1281,15 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
}
|
||||
} // end gemm1
|
||||
|
||||
constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
|
||||
constexpr auto c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
|
||||
blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
|
||||
constexpr auto c_mrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0);
|
||||
constexpr auto c_mwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1);
|
||||
constexpr auto c_mthreadpersubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2);
|
||||
constexpr auto c_nrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3);
|
||||
constexpr auto c_nwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4);
|
||||
constexpr auto c_nsubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
|
||||
constexpr auto c_naccvgprs = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);
|
||||
constexpr auto c_mrepeat = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0);
|
||||
constexpr auto c_mwave = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1);
|
||||
constexpr auto c_mthreadpersubgroup = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2);
|
||||
constexpr auto c_nrepeat = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3);
|
||||
constexpr auto c_nwave = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4);
|
||||
constexpr auto c_nsubgroup = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
|
||||
constexpr auto c_naccvgprs = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);
|
||||
|
||||
constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(c_mrepeat * c_mwave * c_mthreadpersubgroup,
|
||||
@@ -1032,29 +1318,29 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
/*******************************************************************************/
|
||||
// write out to C, implement shuffle
|
||||
{
|
||||
constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
|
||||
constexpr auto c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
|
||||
blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
|
||||
|
||||
// This API Provide All dimension (size) you need
|
||||
constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp =
|
||||
constexpr auto c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp =
|
||||
blockwise_gemm1.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
|
||||
|
||||
constexpr auto MWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I1);
|
||||
constexpr auto MThreadPerSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I2);
|
||||
constexpr auto NWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I4);
|
||||
constexpr auto NSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I5);
|
||||
constexpr auto NAccVgprs = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I6);
|
||||
constexpr auto MWave = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I1);
|
||||
constexpr auto MThreadPerSubGroup = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I2);
|
||||
constexpr auto NWave = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I4);
|
||||
constexpr auto NSubGroup = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I5);
|
||||
constexpr auto NAccVgprs = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I6);
|
||||
|
||||
// LDS descriptor, shuffle and write out in MRepeat x NRepeat times
|
||||
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
|
||||
constexpr auto c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
|
||||
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
|
||||
|
||||
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
auto c1_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<CShuffleDataType*>(p_shared),
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());
|
||||
c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());
|
||||
|
||||
constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = transform_tensor_descriptor(
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
|
||||
c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
|
||||
make_tuple(
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
@@ -1097,10 +1383,10 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
make_multi_index(n_thread_data_on_block));
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
auto c1_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<Acc1DataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
|
||||
decltype(c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
|
||||
decltype(c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMRepeatPerShuffle,
|
||||
@@ -1125,36 +1411,68 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
n_thread_data_on_block_idx[I2],
|
||||
n_thread_data_on_block_idx[I3]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
|
||||
// tuple of reference to C/Ds tensor descriptors
|
||||
const auto e1_d1s_desc_refs = concat_tuple_of_reference(
|
||||
tie(c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
|
||||
generate_tie([&](auto i) -> const auto& // return type should be reference
|
||||
{ return d1s_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
|
||||
Number<NumD1Tensor>{}));
|
||||
|
||||
// tuple of reference to C/Ds tensor descriptors
|
||||
const auto c1_d1s_buf_refs = concat_tuple_of_reference(
|
||||
tie(c1_shuffle_block_buf),
|
||||
generate_tie([&](auto i) -> const auto& // return type should be reference
|
||||
{ return d1s_grid_buf[i]; },
|
||||
Number<NumD1Tensor>{}));
|
||||
|
||||
// tuple of starting index of C/Ds blockwise copy
|
||||
const auto idx_c1_d1s_block_begin = container_concat(
|
||||
make_tuple(make_multi_index(0, 0, 0, 0)),
|
||||
generate_tuple(
|
||||
[&](auto) {
|
||||
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
|
||||
},
|
||||
Number<NumD1Tensor>{}));
|
||||
|
||||
|
||||
// shuffle: blockwise copy C from LDS to global
|
||||
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
|
||||
ThisThreadBlock, // ThreadGroup
|
||||
CElementwiseOperation, // ElementwiseOperation,
|
||||
CGlobalMemoryDataOperation, // DstInMemOp,
|
||||
auto cde1_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v7<
|
||||
ThisThreadBlock,
|
||||
decltype(container_concat(make_tuple(CShuffleDataType{}), D1sDataType{})),
|
||||
Tuple<E1DataType>,
|
||||
decltype(e1_d1s_desc_refs),
|
||||
decltype(tie(e1_grid_desc_mblock_mperblock_nblock_nperblock)),
|
||||
CDEElementwiseOperation,
|
||||
Sequence<static_cast<index_t>(CGlobalMemoryDataOperation)>, // FIXME: make Sequence
|
||||
// support arbitray
|
||||
// type
|
||||
|
||||
Sequence<1,
|
||||
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
|
||||
CShuffleDataType, // typename SrcData,
|
||||
CDataType, // typename DstData,
|
||||
decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
|
||||
3, // index_t VectorDim,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
|
||||
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
false> // bool ThreadTransferDstResetCoordinateAfterRun>
|
||||
{c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
|
||||
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
|
||||
3, // index_t VectorDim,
|
||||
CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
sequence_merge_t<
|
||||
Sequence<true>,
|
||||
uniform_sequence_gen_t<NumD1Tensor,
|
||||
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
|
||||
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
{e1_d1s_desc_refs,
|
||||
idx_c1_d1s_block_begin,
|
||||
tie(e1_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
|
||||
c_element_op};
|
||||
|
||||
|
||||
// space filling curve for local reg & global memory
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
constexpr auto sfc_c1_vgpr =
|
||||
SpaceFillingCurve<Sequence<MRepeat, 1, 1, NRepeat, 1, 1, NAccVgprs>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
Sequence<CShuffleMRepeatPerShuffle,
|
||||
@@ -1166,7 +1484,7 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
NAccVgprs>>{};
|
||||
|
||||
// space filling curve for shuffled blockwise C in global mem
|
||||
constexpr auto sfc_c_global =
|
||||
constexpr auto sfc_e1_global =
|
||||
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
Sequence<1,
|
||||
@@ -1174,37 +1492,44 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};
|
||||
|
||||
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
|
||||
constexpr index_t num_access = sfc_c1_vgpr.GetNumOfAccess();
|
||||
|
||||
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
|
||||
static_assert(num_access == sfc_e1_global.GetNumOfAccess(), "wrong!");
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
block_sync_lds();
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
c1_thread_copy_vgpr_to_lds.Run(c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
|
||||
sfc_c1_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
c_thread_buf,
|
||||
c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
|
||||
c_shuffle_block_buf);
|
||||
c1_shuffle_block_buf);
|
||||
|
||||
// make sure it's safe to read from LDS
|
||||
block_sync_lds();
|
||||
|
||||
// each block copy its data from LDS to global
|
||||
c_shuffle_block_copy_lds_to_global.Run(
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
|
||||
c_shuffle_block_buf,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
cde1_shuffle_block_copy_lds_to_global.Run(
|
||||
e1_d1s_desc_refs,
|
||||
c1_d1s_buf_refs,
|
||||
tie(e1_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
tie(e1_grid_buf));
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
|
||||
constexpr auto e1_global_step = sfc_e1_global.GetForwardStep(access_id);
|
||||
|
||||
// move on D1s
|
||||
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
|
||||
cde1_shuffle_block_copy_lds_to_global.MoveSrcSliceWindow(
|
||||
e1_d1s_desc_refs, i + I1, e1_global_step);
|
||||
});
|
||||
|
||||
// move on C
|
||||
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
|
||||
cde1_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
|
||||
tie(e1_grid_desc_mblock_mperblock_nblock_nperblock), I0, e1_global_step);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -219,6 +219,30 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
//
|
||||
// D0
|
||||
//
|
||||
static auto MakeD0GridDescriptorPair(const std::vector<index_t>& d0_gs_ms_ns_lengths_vec,
|
||||
const std::vector<index_t>& d0_gs_ms_ns_strides_vec)
|
||||
{
|
||||
return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimN, CSpec>(d0_gs_ms_ns_lengths_vec,
|
||||
d0_gs_ms_ns_strides_vec);
|
||||
}
|
||||
|
||||
// TODO: rename to G_MRaw_NRaw
|
||||
static auto MakeD0GridDescriptor_G_M_N(const std::vector<index_t>& d0_gs_ms_ns_lengths_vec,
|
||||
const std::vector<index_t>& d0_gs_ms_ns_strides_vec)
|
||||
{
|
||||
return MakeD0GridDescriptorPair(d0_gs_ms_ns_lengths_vec, d0_gs_ms_ns_strides_vec).first;
|
||||
}
|
||||
|
||||
static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d0_gs_ms_ns_lengths_vec,
|
||||
const std::vector<index_t>& d0_gs_ms_ns_strides_vec)
|
||||
{
|
||||
return matrix_padder.PadD0Descriptor_M_N(
|
||||
MakeD0GridDescriptorPair(d0_gs_ms_ns_lengths_vec, d0_gs_ms_ns_strides_vec).second);
|
||||
}
|
||||
|
||||
//
|
||||
// B1
|
||||
//
|
||||
|
||||
Reference in New Issue
Block a user