mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
N-D Tensor Contraction example, instance, and client example (#270)
* adding contraction * add contraction example * update examle * update example * format * update readme * clean header * clean header * contraction with multiple D * rename * fix naming issue; add instances for contraction+bilinear * change assumed virtual layout of contraction; add client example * update example * update * contraction+scale * use type_convert * rename
This commit is contained in:
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Tensor Contraction:
|
||||
// input : A
|
||||
// input : B
|
||||
// input : D0, D1, ...
|
||||
// output : E
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// Assume:
|
||||
// A[M0, M1, M2, ..., K0, K1, K2, ...]
|
||||
// B[N0, N1, N2, ..., K0, K1, K2, ...]
|
||||
// D[M0, M1, M2, ..., N0, N1, N2, ...]
|
||||
// E[M0, M1, M2, ..., N0, N1, N2, ...]
|
||||
template <index_t NumDimM,
|
||||
index_t NumDimN,
|
||||
index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceContractionMultipleD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
std::vector<index_t> a_ms_ks_lengths,
|
||||
std::vector<index_t> a_ms_ks_strides,
|
||||
std::vector<index_t> b_ns_ks_lengths,
|
||||
std::vector<index_t> b_ns_ks_strides,
|
||||
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_lengths,
|
||||
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_strides,
|
||||
std::vector<index_t> e_ms_ns_lengths,
|
||||
std::vector<index_t> e_ms_ns_strides,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,981 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/device_utility/device_prop.hpp"
|
||||
#include "ck/device_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatDsPointer,
|
||||
typename FloatE,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2ETileMap,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_contraction_multiple_d_xdl_cshuffle(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatDsPointer p_ds_grid,
|
||||
FloatE* __restrict__ p_e_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_etile_map);
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_ds_grid;
|
||||
ignore = p_e_grid;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = cde_element_op;
|
||||
ignore = a_grid_desc_ak0_m_ak1;
|
||||
ignore = b_grid_desc_bk0_n_bk1;
|
||||
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = block_2_etile_map;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Tensor Contraction:
|
||||
// input : A
|
||||
// input : B
|
||||
// input : D0, D1, ...
|
||||
// output : E
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// Assume:
|
||||
// A[M0, M1, M2, ..., K0, K1, K2, ...]
|
||||
// B[N0, N1, N2, ..., K0, K1, K2, ...]
|
||||
// D[M0, M1, M2, ..., N0, N1, N2, ...]
|
||||
// E[M0, M1, M2, ..., N0, N1, N2, ...]
|
||||
template <index_t NumDimM,
|
||||
index_t NumDimN,
|
||||
index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
: public DeviceContractionMultipleD<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
// Assume: A[M0, M1, M2, ..., K0, K1, K2, ...]
|
||||
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_ms_ks_lengths_vec,
|
||||
const std::vector<index_t>& a_ms_ks_strides_vec)
|
||||
{
|
||||
assert(a_ms_ks_lengths_vec.size() == NumDimM + NumDimK &&
|
||||
a_ms_ks_strides_vec.size() == NumDimM + NumDimK);
|
||||
|
||||
const auto to_tuple = [&](auto& vec, auto num) {
|
||||
return generate_tuple([&](auto i) { return vec[i]; }, num);
|
||||
};
|
||||
|
||||
const auto a_ms_ns_lengths = to_tuple(a_ms_ks_lengths_vec, Number<NumDimM + NumDimK>{});
|
||||
const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number<NumDimM + NumDimK>{});
|
||||
|
||||
// dimension Ids for M0, M1, ...
|
||||
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
|
||||
|
||||
// dimension Ids for K0, K1, ...
|
||||
constexpr auto kDimIds =
|
||||
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimK, 1>::type{};
|
||||
|
||||
// lengths for M0, M1, ...
|
||||
const auto mLengths = get_container_subset(a_ms_ns_lengths, mDimIds);
|
||||
|
||||
// lengths for K0, K1, ...
|
||||
const auto kLengths = get_container_subset(a_ms_ns_lengths, kDimIds);
|
||||
|
||||
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
|
||||
const auto a_grid_desc_ms_ks =
|
||||
make_naive_tensor_descriptor(a_ms_ns_lengths, a_ms_ks_strides);
|
||||
|
||||
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
|
||||
const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
|
||||
a_grid_desc_ms_ks,
|
||||
make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)),
|
||||
make_tuple(mDimIds, kDimIds),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto MRaw = a_grid_desc_mraw_kraw.GetLength(I0);
|
||||
const auto KRaw = a_grid_desc_mraw_kraw.GetLength(I1);
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
|
||||
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto KPad = K - KRaw;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad both M and K
|
||||
assert(K % AK1 == 0);
|
||||
|
||||
const auto AK0 = K / AK1;
|
||||
|
||||
const auto a_grid_desc_m_k =
|
||||
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad),
|
||||
make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
|
||||
GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
// pad M, but not K
|
||||
assert(KRaw % AK1 == 0);
|
||||
|
||||
const auto AK0 = KRaw / AK1;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_right_pad_transform(MRaw, MPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding)
|
||||
{
|
||||
// pad K, but not M
|
||||
assert(K % AK1 == 0);
|
||||
|
||||
const auto AK0 = K / AK1;
|
||||
|
||||
const auto a_grid_desc_m_k = transform_tensor_descriptor(
|
||||
a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(MRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad M or K
|
||||
assert(KRaw % AK1 == 0);
|
||||
|
||||
const auto AK0 = KRaw / AK1;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(MRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
}
|
||||
|
||||
// Assume: B[N0, N1, N2, ..., K0, K1, K2, ...]
|
||||
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_ns_ks_lengths_vec,
|
||||
const std::vector<index_t>& b_ns_ks_strides_vec)
|
||||
{
|
||||
assert(b_ns_ks_lengths_vec.size() == NumDimN + NumDimK &&
|
||||
b_ns_ks_strides_vec.size() == NumDimN + NumDimK);
|
||||
|
||||
const auto to_tuple = [&](auto& vec, auto num) {
|
||||
return generate_tuple([&](auto i) { return vec[i]; }, num);
|
||||
};
|
||||
|
||||
const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_vec, Number<NumDimN + NumDimK>{});
|
||||
const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_vec, Number<NumDimN + NumDimK>{});
|
||||
|
||||
// dimension Ids for N0, N1, ...
|
||||
constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{};
|
||||
|
||||
// dimension Ids for K0, K1, ...
|
||||
constexpr auto kDimIds =
|
||||
typename arithmetic_sequence_gen<NumDimN, NumDimN + NumDimK, 1>::type{};
|
||||
|
||||
// lengths for K0, K1, ...
|
||||
const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds);
|
||||
|
||||
// lengths for N0, N1, ...
|
||||
const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);
|
||||
|
||||
// naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
|
||||
const auto b_grid_desc_ns_ks =
|
||||
make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);
|
||||
|
||||
// transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
|
||||
const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
|
||||
b_grid_desc_ns_ks,
|
||||
make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
|
||||
make_tuple(nDimIds, kDimIds),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto NRaw = b_grid_desc_nraw_kraw.GetLength(I0);
|
||||
const auto KRaw = b_grid_desc_nraw_kraw.GetLength(I1);
|
||||
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
|
||||
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
|
||||
|
||||
const auto NPad = N - NRaw;
|
||||
const auto KPad = K - KRaw;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad both N and K
|
||||
assert(K % BK1 == 0);
|
||||
|
||||
const auto BK0 = K / BK1;
|
||||
|
||||
const auto b_grid_desc_n_k =
|
||||
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_right_pad_transform(NRaw, NPad),
|
||||
make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
|
||||
GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
// pad N, but not K
|
||||
assert(KRaw % BK1 == 0);
|
||||
|
||||
const auto BK0 = KRaw / BK1;
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
|
||||
GemmSpec == GemmSpecialization::MKPadding)
|
||||
{
|
||||
// pad K, but not N
|
||||
assert(K % BK1 == 0);
|
||||
|
||||
const auto BK0 = K / BK1;
|
||||
|
||||
const auto b_grid_desc_n_k = transform_tensor_descriptor(
|
||||
b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad N or K
|
||||
assert(KRaw % BK1 == 0);
|
||||
|
||||
const auto BK0 = KRaw / BK1;
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
}
|
||||
|
||||
// assume E[M0, M1, M2, ..., N0, N1, N2...]
|
||||
static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_ms_ns_lengths_vec,
|
||||
const std::vector<index_t>& e_ms_ns_strides_vec)
|
||||
{
|
||||
assert(e_ms_ns_lengths_vec.size() == NumDimM + NumDimN &&
|
||||
e_ms_ns_strides_vec.size() == NumDimM + NumDimN);
|
||||
|
||||
const auto to_tuple = [&](auto& vec, auto num) {
|
||||
return generate_tuple([&](auto i) { return vec[i]; }, num);
|
||||
};
|
||||
|
||||
const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_vec, Number<NumDimM + NumDimN>{});
|
||||
const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_vec, Number<NumDimM + NumDimN>{});
|
||||
|
||||
// dimension Ids for M0, M1, ...
|
||||
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
|
||||
|
||||
// dimension Ids for N0, N1, ...
|
||||
constexpr auto nDimIds =
|
||||
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
|
||||
|
||||
// lengths for M0, M1, ...
|
||||
const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds);
|
||||
|
||||
// lengths for K0, K1, ...
|
||||
const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds);
|
||||
|
||||
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
|
||||
const auto e_grid_desc_ms_ns =
|
||||
make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
|
||||
// transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
|
||||
const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor(
|
||||
e_grid_desc_ms_ns,
|
||||
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
|
||||
make_tuple(mDimIds, nDimIds),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto MRaw = e_grid_desc_mraw_nraw.GetLength(I0);
|
||||
const auto NRaw = e_grid_desc_mraw_nraw.GetLength(I1);
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto NPad = N - NRaw;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad M and N
|
||||
return transform_tensor_descriptor(e_grid_desc_mraw_nraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad),
|
||||
make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
|
||||
GemmSpec == GemmSpecialization::MKPadding)
|
||||
{
|
||||
// pad M, but not N
|
||||
return transform_tensor_descriptor(
|
||||
e_grid_desc_mraw_nraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding)
|
||||
{
|
||||
// pad N, but not M
|
||||
return transform_tensor_descriptor(
|
||||
e_grid_desc_mraw_nraw,
|
||||
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad M or N
|
||||
return e_grid_desc_mraw_nraw;
|
||||
}
|
||||
}
|
||||
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
decltype(MakeAGridDescriptor_AK0_M_AK1(std::vector<index_t>{}, std::vector<index_t>{}));
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
decltype(MakeBGridDescriptor_BK0_N_BK1(std::vector<index_t>{}, std::vector<index_t>{}));
|
||||
using EGridDesc_M_N =
|
||||
decltype(MakeEGridDescriptor_M_N(std::vector<index_t>{}, std::vector<index_t>{}));
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
EGridDesc_M_N,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const void* p_a_grid,
|
||||
const void* p_b_grid,
|
||||
std::array<const void*, NumDTensor> p_ds_grid,
|
||||
void* p_e_grid,
|
||||
std::vector<index_t> a_ms_ns_lengths,
|
||||
std::vector<index_t> a_ms_ks_strides,
|
||||
std::vector<index_t> b_ns_ks_lengths,
|
||||
std::vector<index_t> b_ns_ks_strides,
|
||||
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_lengths,
|
||||
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_strides,
|
||||
std::vector<index_t> e_ms_ns_lengths,
|
||||
std::vector<index_t> e_ms_ns_strides,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
|
||||
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
|
||||
p_ds_grid_{}, // FIXME
|
||||
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
|
||||
a_grid_desc_ak0_m_ak1_{
|
||||
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_ms_ns_lengths, a_ms_ks_strides)},
|
||||
b_grid_desc_bk0_n_bk1_{
|
||||
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_ns_ks_lengths, b_ns_ks_strides)},
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op},
|
||||
a_mz_stride_{},
|
||||
a_kz_stride_{},
|
||||
b_nz_stride_{},
|
||||
b_kz_stride_{},
|
||||
ds_nz_stride_{},
|
||||
e_nz_stride_{}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
|
||||
b_grid_desc_bk0_n_bk1_,
|
||||
e_grid_desc_m_n_,
|
||||
block_2_etile_map_))
|
||||
{
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n_);
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
|
||||
|
||||
const auto d_grid_desc_m_n =
|
||||
DeviceOp::MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
|
||||
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
d_grid_desc_m_n);
|
||||
});
|
||||
}
|
||||
|
||||
// for sanity check of vector memory access
|
||||
a_mz_stride_ = a_ms_ks_strides[NumDimM - 1];
|
||||
a_kz_stride_ = a_ms_ks_strides[NumDimM + NumDimK - 1];
|
||||
|
||||
b_nz_stride_ = b_ns_ks_strides[NumDimN - 1];
|
||||
b_kz_stride_ = b_ns_ks_strides[NumDimN + NumDimK - 1];
|
||||
|
||||
for(index_t i = 0; i < NumDTensor; ++i)
|
||||
{
|
||||
ds_nz_stride_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1];
|
||||
}
|
||||
|
||||
e_nz_stride_ = e_ms_ns_strides[NumDimM + NumDimN - 1];
|
||||
}
|
||||
|
||||
// private:
|
||||
// pointers
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
typename GridwiseGemm::DsGridPointer p_ds_grid_;
|
||||
EDataType* p_e_grid_;
|
||||
|
||||
// tensor descriptors
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
StaticallyIndexedArray<
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
NumDTensor>
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different
|
||||
// type from E
|
||||
EGridDesc_M_N e_grid_desc_m_n_;
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
// block-to-e-tile map
|
||||
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
|
||||
// Strides for the last M/N/K dimensions of A/B/Ds/E
|
||||
// for sanity check of vector load/store
|
||||
index_t a_mz_stride_;
|
||||
index_t a_kz_stride_;
|
||||
index_t b_nz_stride_;
|
||||
index_t b_kz_stride_;
|
||||
std::array<index_t, NumDTensor> ds_nz_stride_;
|
||||
index_t e_mz_stride_;
|
||||
index_t e_nz_stride_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if 0
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
|
||||
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
|
||||
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
|
||||
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
|
||||
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
|
||||
const auto kernel = kernel_contraction_multiple_d_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
ck::StaticallyIndexedArray<
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
NumDTensor>,
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::DefaultBlock2ETileMap,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_);
|
||||
};
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
ave_time = launch_kernel(integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = launch_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
// Compile-time validity hook for this instantiation's template parameters.
// Currently a stub: it performs no checks and unconditionally reports success.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool is_valid = true;
    return is_valid;
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector access
|
||||
static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) &&
|
||||
(BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
|
||||
"wrong!");
|
||||
|
||||
// vector memory access of A: could be on M or AK1 dimension
|
||||
if constexpr(ABlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
if(!(arg.a_mz_stride_ == 1 &&
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(!(arg.a_kz_stride_ == 1 &&
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// vector memory access of B: could be on N or BK1 dimension
|
||||
if constexpr(BBlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
if(!(arg.b_nz_stride_ == 1 &&
|
||||
arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(!(arg.b_kz_stride_ == 1 &&
|
||||
arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// vector memory access of Ds: always on NPerBlock dimension
|
||||
bool valid_d_access = true;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
if(!(arg.ds_nz_stride_[i] == 1 &&
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) %
|
||||
CDEBlockTransferScalarPerVector_NPerBlock ==
|
||||
0))
|
||||
{
|
||||
valid_d_access = false;
|
||||
}
|
||||
});
|
||||
|
||||
if(valid_d_access == false)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// vector memory access of E: always on NPerBlock dimension
|
||||
if(!(arg.e_nz_stride_ == 1 &&
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) %
|
||||
CDEBlockTransferScalarPerVector_NPerBlock ==
|
||||
0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
std::vector<index_t> a_ms_ns_lengths,
|
||||
std::vector<index_t> a_ms_ks_strides,
|
||||
std::vector<index_t> b_ns_ks_lengths,
|
||||
std::vector<index_t> b_ns_ks_strides,
|
||||
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_lengths,
|
||||
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_strides,
|
||||
std::vector<index_t> e_ms_ns_lengths,
|
||||
std::vector<index_t> e_ms_ns_strides,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_ms_ns_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
ds_ms_ns_lengths,
|
||||
ds_ms_ns_strides,
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
// Returns a default-constructed Invoker by value.
static auto MakeInvoker()
{
    return Invoker{};
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
std::vector<index_t> a_ms_ns_lengths,
|
||||
std::vector<index_t> a_ms_ks_strides,
|
||||
std::vector<index_t> b_ns_ks_lengths,
|
||||
std::vector<index_t> b_ns_ks_strides,
|
||||
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_lengths,
|
||||
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_strides,
|
||||
std::vector<index_t> e_ms_ns_lengths,
|
||||
std::vector<index_t> e_ms_ns_strides,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_ms_ns_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
ds_ms_ns_lengths,
|
||||
ds_ms_ns_strides,
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceContractionMultipleD_Xdl_CShuffle"
|
||||
<< "<"
|
||||
<< NumDimM << ", "
|
||||
<< NumDimN << ", "
|
||||
<< NumDimK << ", "
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< ABlockTransferSrcVectorDim << ", "
|
||||
<< BBlockTransferSrcVectorDim
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -2,10 +2,11 @@
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
@@ -11,11 +11,14 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// input : A[M, K], B[K, N],
|
||||
// input : D0[M, N], D1[M, N], ...
|
||||
// output : E[M, N]
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// GEMM:
|
||||
// input : A[M, K], B[K, N],
|
||||
// input : D0[M, N], D1[M, N], ...
|
||||
// output : E[M, N]
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// Assume:
|
||||
// D0, D1, ... and E have the same layout
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DELayout,
|
||||
|
||||
@@ -88,12 +88,15 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// input : A[M, K], or A[K, N]
|
||||
// input : B[K, N], or A[N, K]
|
||||
// input : D0[M, N], D1[M, N], ...
|
||||
// output : E[M, N]
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// GEMM:
|
||||
// input : A[AK0, M, AK1]
|
||||
// input : B[BK0, N, BK1]
|
||||
// input : D0[M, N], D1[M, N], ...
|
||||
// output : E[M, N]
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// Assume:
|
||||
// D0, D1, ... and E have the same layout
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DELayout,
|
||||
@@ -363,7 +366,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
|
||||
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
|
||||
{
|
||||
const auto c_grid_desc_mraw_nraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, DELayout>::value)
|
||||
@@ -423,7 +426,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
|
||||
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
|
||||
using EGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1));
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
|
||||
@@ -496,7 +499,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
|
||||
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
e_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideE)},
|
||||
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
a_element_op_{a_element_op},
|
||||
@@ -518,7 +521,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
|
||||
|
||||
const auto d_grid_desc_m_n =
|
||||
DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]);
|
||||
DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]);
|
||||
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
@@ -527,23 +530,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
}
|
||||
}
|
||||
|
||||
// ck::Tuple<const DsDataType*...>
|
||||
static constexpr auto MakeDsGridPointer()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DDataType = remove_cv_t<decltype(DsDataType{}.At(i))>;
|
||||
|
||||
return static_cast<const DDataType*>(nullptr);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
// private:
|
||||
// pointers
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
typename GridwiseGemm::DsGridPointer p_ds_grid_;
|
||||
EDataType* p_e_grid_;
|
||||
|
||||
// tensor descriptors
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
StaticallyIndexedArray<
|
||||
@@ -554,7 +548,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
EGridDesc_M_N e_grid_desc_m_n_;
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
// block-to-e-tile map
|
||||
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
|
||||
Reference in New Issue
Block a user