mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Support multiple D in GridwiseGemm_wmma_cshuffle_v3
DeviceGemm_Wmma_CShuffleV3 is changed for new template parameters.
This commit is contained in:
@@ -180,11 +180,13 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
Tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
Tuple<>, // DsDataType
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
@@ -219,7 +221,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
@@ -294,7 +296,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_c_grid,
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
@@ -312,7 +314,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(arg.p_c_grid,
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(arg.p_e_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
@@ -468,11 +470,25 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation)
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
std::array<const void*, 0>{}, // p_ds_grid_
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
@@ -488,20 +504,25 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation) override
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
std::array<const void*, 0>{}, // p_ds_grid_
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
StrideC,
|
||||
KBatch);
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
@@ -19,7 +19,7 @@ namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
@@ -31,22 +31,26 @@ __global__ void
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_c_grid)>>;
|
||||
if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<c_data_type, ck::half_t> ||
|
||||
std::is_same_v<c_data_type, ck::bhalf_t>)))
|
||||
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
|
||||
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<e_data_type, ck::half_t> ||
|
||||
std::is_same_v<e_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
|
||||
karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset,
|
||||
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
|
||||
p_shared,
|
||||
karg);
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op);
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
#endif
|
||||
@@ -59,8 +63,8 @@ __global__ void
|
||||
///
|
||||
/// @par Overview
|
||||
/// This GEMM kernel is carrying out following mathematical equation:
|
||||
/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N}))
|
||||
/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are
|
||||
/// E{M,N} = CDE_op(A_op(A{M,K}) * B_op(B{K,N}), Ds{M,N}...)
|
||||
/// Where A, B, Ds are input tensors and E is the output tensor. The A/B/CDE_op are
|
||||
/// elementwise operations that could be applied on each tensor respectively.
|
||||
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
|
||||
/// scenarios. That's why it's called \"universal\". It's universal through it's design
|
||||
@@ -73,18 +77,19 @@ __global__ void
|
||||
///
|
||||
/// @tparam ALayout A tensor data layout.
|
||||
/// @tparam BLayout B tensor data layout.
|
||||
/// @tparam CLayout C tensor data layout.
|
||||
/// @tparam DsLayout D tensors data layout.
|
||||
/// @tparam ELayout E tensor data layout.
|
||||
/// @tparam ADataType A tensor data type.
|
||||
/// @tparam BDataType B tensor data type.
|
||||
/// @tparam AccDataType The accumulation data type related to the hardware
|
||||
/// matrix-multiplication instruction.
|
||||
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
|
||||
/// LDS memory during \"CShuffle\" data layout optimization.
|
||||
/// @tparam CDataType C tensor data type.
|
||||
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
|
||||
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
|
||||
/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor
|
||||
/// (after GEMM).
|
||||
/// @tparam EDataType E tensor data type.
|
||||
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
|
||||
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
|
||||
/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after
|
||||
/// GEMM) and D input tensors.
|
||||
/// @tparam GemmSpec Determines used "padding" version.
|
||||
/// @tparam BlockSize The number of threads within workgroup.
|
||||
/// @tparam MPerBlock The input/output data tile size in the M dimension.
|
||||
@@ -142,11 +147,12 @@ __global__ void
|
||||
/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions
|
||||
/// results to process per wave per iteration of CShuffle
|
||||
/// in N dimension.
|
||||
/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
|
||||
/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
|
||||
/// thread distribution used for storing data into output
|
||||
/// tensor across output data layout dimensions.
|
||||
/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access.
|
||||
/// Used when storing data to output tensor.
|
||||
/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access.
|
||||
/// Used when loading data from D tensors and storing data
|
||||
/// to output tensor.
|
||||
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
|
||||
/// intrawave).
|
||||
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
|
||||
@@ -160,15 +166,17 @@ __global__ void
|
||||
/// in global memory (pre-shuffled).
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename CDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
@@ -198,8 +206,8 @@ template <typename ALayout,
|
||||
index_t BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename ComputeTypeA,
|
||||
@@ -217,6 +225,10 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
static constexpr auto I7 = Number<7>{};
|
||||
|
||||
// TODO: remove
|
||||
static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock =
|
||||
CDEShuffleBlockTransferScalarPerVectors{}[I0];
|
||||
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
|
||||
static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
|
||||
@@ -530,17 +542,18 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
return MakeWmmaTileDescriptor<NRepeat, NWaves, NPerWmma>(BBlockDesc_BK0_N_BK1{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
|
||||
template <typename DELayout>
|
||||
__device__ static auto
|
||||
MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE)
|
||||
{
|
||||
const auto c_grid_desc_mraw_nraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, DELayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideDE, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DELayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideDE));
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -593,6 +606,44 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
#endif
|
||||
}
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto MakeDsGridPointer()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
return static_cast<const DDataType*>(nullptr);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
using DsGridPointer = decltype(MakeDsGridPointer());
|
||||
|
||||
__device__ static auto MakeDsGridDescriptor_M_N(
|
||||
index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
return MakeDEGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
template <typename DsGridDesc>
|
||||
__device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n[i], MBlock, NBlock);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
struct Problem
|
||||
{
|
||||
__host__ Problem(index_t M_,
|
||||
@@ -600,14 +651,16 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t KBatch_)
|
||||
: M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
StrideA{StrideA_},
|
||||
StrideB{StrideB_},
|
||||
StrideC{StrideC_},
|
||||
StrideDs{StrideDs_},
|
||||
StrideE{StrideE_},
|
||||
KBatch{KBatch_},
|
||||
MPadded{CalculateMPadded(M_)},
|
||||
NPadded{CalculateNPadded(N_)},
|
||||
@@ -627,8 +680,16 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
<< "N:" << N << ", "
|
||||
<< "K:" << K << ", "
|
||||
<< "SA:" << StrideA << ", "
|
||||
<< "SB:" << StrideB << ", "
|
||||
<< "SC:" << StrideC << ", "
|
||||
<< "SB:" << StrideB << ", ";
|
||||
if constexpr(NumDTensor > 0)
|
||||
{
|
||||
std::cout << "SDs: { ";
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : "");
|
||||
});
|
||||
std::cout << " }, ";
|
||||
}
|
||||
std::cout << "SE:" << StrideE << ", "
|
||||
<< "MP:" << MPadded << ", "
|
||||
<< "NP:" << NPadded << ", "
|
||||
<< "KRead:" << KRead << ", "
|
||||
@@ -644,7 +705,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
index_t K;
|
||||
index_t StrideA;
|
||||
index_t StrideB;
|
||||
index_t StrideC;
|
||||
std::array<index_t, NumDTensor> StrideDs;
|
||||
index_t StrideE;
|
||||
index_t KBatch;
|
||||
index_t MPadded;
|
||||
index_t NPadded;
|
||||
@@ -661,21 +723,35 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
{
|
||||
__host__ Argument(const ADataType* p_a_grid_,
|
||||
const BDataType* p_b_grid_,
|
||||
CDataType* p_c_grid_,
|
||||
std::array<const void*, NumDTensor> p_ds_grid_,
|
||||
EDataType* p_e_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t k_batch_,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
CDEElementwiseOperation cde_element_op_,
|
||||
bool is_reduce_ = false)
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, k_batch_},
|
||||
p_a_grid{p_a_grid_},
|
||||
p_b_grid{p_b_grid_},
|
||||
p_c_grid{p_c_grid_},
|
||||
p_ds_grid{},
|
||||
p_e_grid{p_e_grid_},
|
||||
a_element_op{a_element_op_},
|
||||
b_element_op{b_element_op_},
|
||||
cde_element_op{cde_element_op_},
|
||||
is_reduce(is_reduce_)
|
||||
{
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
|
||||
});
|
||||
}
|
||||
|
||||
__host__ __device__ inline bool IsReduceAdd() const
|
||||
@@ -690,42 +766,49 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
const ADataType* p_a_grid;
|
||||
const BDataType* p_b_grid;
|
||||
CDataType* p_c_grid;
|
||||
DsGridPointer p_ds_grid;
|
||||
EDataType* p_e_grid;
|
||||
|
||||
const AElementwiseOperation a_element_op;
|
||||
const BElementwiseOperation b_element_op;
|
||||
const CDEElementwiseOperation cde_element_op;
|
||||
|
||||
// TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd
|
||||
bool is_reduce;
|
||||
};
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
|
||||
__device__ SplitKBatchOffset(Argument& karg)
|
||||
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
|
||||
{
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
|
||||
a_k_split_offset = k_id * karg.KRead / APackedSize;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
|
||||
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
|
||||
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
if constexpr(!PermuteB)
|
||||
{
|
||||
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
|
||||
b_k_split_offset = k_id * karg.KRead / BPackedSize;
|
||||
}
|
||||
else
|
||||
{
|
||||
const int k0_offset = karg.KRead * karg.N;
|
||||
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
|
||||
b_k_split_offset = k_id * k0_offset / BPackedSize;
|
||||
}
|
||||
}
|
||||
|
||||
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
|
||||
if(k_id < karg.KBatch - 1)
|
||||
{
|
||||
karg.K = karg.KRead;
|
||||
}
|
||||
@@ -736,7 +819,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
if(karg.IsReduceAdd())
|
||||
{
|
||||
c_reduce_offset = blockIdx.z * karg.M * karg.N;
|
||||
c_reduce_offset = k_id * karg.M * karg.N;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1143,7 +1226,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
|
||||
std::cout << "Arg K value is not a multiple of K_Batch * KPerBlock! K: "
|
||||
<< karg.K << " " << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
@@ -1219,7 +1302,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
|
||||
{
|
||||
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
@@ -1252,23 +1335,20 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
|
||||
is_same<remove_cvref_t<CDataType>, float>::value ||
|
||||
is_same<remove_cvref_t<CDataType>, bhalf_t>::value ||
|
||||
is_same<remove_cvref_t<CDataType>, int32_t>::value))
|
||||
if constexpr(!(is_same<remove_cvref_t<EDataType>, half_t>::value ||
|
||||
is_same<remove_cvref_t<EDataType>, float>::value ||
|
||||
is_same<remove_cvref_t<EDataType>, bhalf_t>::value ||
|
||||
is_same<remove_cvref_t<EDataType>, int32_t>::value))
|
||||
{
|
||||
if(!karg.IsReduceAdd())
|
||||
if(karg.IsAtomicAdd() && karg.KBatch > 1)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet"
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
if(karg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported for this "
|
||||
<< "destination type (EDataType) " << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1301,18 +1381,18 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
|
||||
}
|
||||
|
||||
template <typename CGridDesc>
|
||||
__host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
|
||||
template <typename DEGridDesc>
|
||||
__device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock)
|
||||
{
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
const auto de_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
|
||||
de_grid_desc_m_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
|
||||
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
|
||||
|
||||
return c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
return de_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
}
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
@@ -1322,30 +1402,40 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
template <typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
__device__ static void Run(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
DsGridPointer p_ds_grid,
|
||||
EDataType* p_e_grid,
|
||||
void* p_shared,
|
||||
const Problem& problem,
|
||||
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock)
|
||||
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
const AElementwiseOperation a_element_op{};
|
||||
const BElementwiseOperation b_element_op{};
|
||||
const CElementwiseOperation c_element_op{};
|
||||
const auto ds_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_ds_grid[i],
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
|
||||
@@ -1355,8 +1445,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
{
|
||||
return;
|
||||
}
|
||||
@@ -1483,7 +1573,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
|
||||
// shuffle C and write out
|
||||
// Epilogue: shuffle C for better memory access pattern, apply elementwise operation to
|
||||
// C and Ds, write out result E to global memory
|
||||
{
|
||||
// C mapping in single thread.
|
||||
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
|
||||
@@ -1601,31 +1692,60 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
m_thread_data_on_block_idx[I3]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
// shuffle: blockwise copy C from LDS to global
|
||||
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
|
||||
ThisThreadBlock, // ThreadGroup
|
||||
CElementwiseOperation, // ElementwiseOperation,
|
||||
CGlobalMemoryDataOperation, // DstInMemOp,
|
||||
// tuple of reference to C/Ds tensor descriptors
|
||||
const auto c_ds_desc_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
|
||||
generate_tie(
|
||||
[&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// tuple of reference to C/Ds tensor buffers
|
||||
const auto c_ds_buf_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_buf),
|
||||
generate_tie(
|
||||
[&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_buf[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// tuple of starting index of C/Ds blockwise copy
|
||||
const auto idx_c_ds_block_begin = container_concat(
|
||||
make_tuple(make_multi_index(0, 0, 0, 0)),
|
||||
generate_tuple(
|
||||
[&](auto) {
|
||||
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
|
||||
},
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// blockwise copy which loads C from LDS, D from global, applies elementwise
|
||||
// operation and stores result E to global
|
||||
auto cde_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v7<
|
||||
ThisThreadBlock, // ThreadGroup
|
||||
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
|
||||
Tuple<EDataType>,
|
||||
decltype(c_ds_desc_refs),
|
||||
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
|
||||
CDEElementwiseOperation, // ElementwiseOperation,
|
||||
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOp,
|
||||
Sequence<1,
|
||||
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
|
||||
CShuffleDataType, // typename SrcData,
|
||||
CDataType, // typename DstData,
|
||||
decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
|
||||
3, // index_t VectorDim,
|
||||
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
|
||||
3, // index_t VectorDim,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
|
||||
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
false> // bool ThreadTransferDstResetCoordinateAfterRun>
|
||||
{c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
|
||||
c_element_op};
|
||||
sequence_merge_t<Sequence<true>,
|
||||
uniform_sequence_gen_t<
|
||||
NumDTensor,
|
||||
false>>, // bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
Sequence<false>> // bool ThreadTransferDstResetCoordinateAfterRun>
|
||||
{c_ds_desc_refs,
|
||||
idx_c_ds_block_begin,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
|
||||
cde_element_op};
|
||||
|
||||
// space filling curve for local reg & global memory
|
||||
// space filling curve for threadwise C in VGPR
|
||||
@@ -1641,7 +1761,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
MAccVgprs>>{};
|
||||
|
||||
// space filling curve for shuffled blockwise C in global mem
|
||||
constexpr auto sfc_c_global =
|
||||
constexpr auto sfc_cde_global =
|
||||
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
Sequence<1,
|
||||
@@ -1651,7 +1771,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
|
||||
|
||||
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
|
||||
static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
@@ -1668,57 +1788,78 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
// make sure it's safe to read from LDS
|
||||
block_sync_lds();
|
||||
|
||||
// each block copy its data from LDS to global
|
||||
c_shuffle_block_copy_lds_to_global.Run(
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
|
||||
c_shuffle_block_buf,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
// each block loads its C data from LDS, D from global, applies elementwise
|
||||
// operation and stores result E to global
|
||||
cde_shuffle_block_copy_lds_to_global.Run(
|
||||
c_ds_desc_refs,
|
||||
c_ds_buf_refs,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
tie(e_grid_buf));
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
|
||||
constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
|
||||
// move on Ds
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
cde_shuffle_block_copy_lds_to_global.MoveSrcSliceWindow(
|
||||
c_ds_desc_refs, i + I1, cde_global_step);
|
||||
});
|
||||
|
||||
// move on C
|
||||
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
|
||||
// move on E
|
||||
cde_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), I0, cde_global_step);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
__device__ static void Run(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
EDataType* p_e_grid,
|
||||
void* p_shared,
|
||||
const Problem& problem)
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
|
||||
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
|
||||
const auto e_grid_desc_m_n = MakeDEGridDescriptor_M_N<ELayout>(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
|
||||
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
|
||||
Run<decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
problem,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock);
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user