Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-04 13:41:24 +00:00.
Compile for gfx908 and gfx90a (#130)
* adding compilation for multiple targets
* fix build
* clean
* update Jenkinsfile
* update readme
* update Jenkins
* use ck::half_t instead of ushort for bf16
* rename enum classes
* clean
* rename
* clean
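The bulk of the diff below is one mechanical rename: every ck enum class drops its C-style "_t" suffix, and all use sites follow. A minimal sketch of the pattern, with enumerators taken from the hunks below (illustrative only, not the full definitions):

    // Old spelling: the enum class carried a C-style "_t" suffix.
    enum struct InMemoryDataOperationEnum_t { Set, AtomicAdd };

    // New spelling: the suffix is gone, so use sites change from
    // InMemoryDataOperationEnum_t::Set to InMemoryDataOperationEnum::Set.
    enum struct InMemoryDataOperationEnum { Set, AtomicAdd };

The same rename applies to AddressSpaceEnum, ConvolutionForwardSpecialization, ConvolutionBackwardDataSpecialization, GemmSpecialization (whose template parameter is also shortened from GemmSpecialization to GemmSpec), and ReduceTensorOp.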
@@ -207,9 +207,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
 CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0,
 "wrong");

-auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatA>(
+auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
 a_k_m0_m1_thread_desc_.GetElementSpaceSize());
-auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatB>(
+auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
 b_k_n0_n1_thread_desc_.GetElementSpaceSize());

 constexpr auto threadwise_gemm =

@@ -220,9 +220,9 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
 CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0,
 "wrong");

-auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatA>(
+auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
 a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
-auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatB>(
+auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
 b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());

 constexpr auto threadwise_contraction =

@@ -119,7 +119,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
 constexpr auto a_block_mtx = ABlockDesc_E1_K1_E2{};

 // thread A buffer for GEMM
-StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
+StaticBuffer<AddressSpaceEnum::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
 a_thread_buf;

 constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,

@@ -42,7 +42,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
 static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
 static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);

-StaticBufferTupleOfVector<AddressSpaceEnum_t::Vgpr,
+StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
 FloatAcc,
 MRepeat * NRepeat,
 xdlops_gemm.GetRegSizePerXdlops(),

@@ -250,9 +250,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
 const BBlockBuffer& b_block_buf,
 CThreadBuffer& c_thread_buf) const
 {
-auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
+auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
 a_thread_desc_.GetElementSpaceSize());
-auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
+auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
 b_thread_desc_.GetElementSpaceSize());

 static_for<0, MRepeat, 1>{}([&](auto m0) {
@@ -16,7 +16,7 @@ namespace ck {
 template <index_t BlockSize,
 typename SrcElementwiseOperation,
 typename DstElementwiseOperation,
-InMemoryDataOperationEnum_t DstInMemOp,
+InMemoryDataOperationEnum DstInMemOp,
 typename BlockSliceLengths,
 typename ThreadClusterLengths,
 typename ThreadClusterArrangeOrder,

@@ -14,7 +14,7 @@ namespace ck {
 // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
 // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
 template <index_t BlockSize,
-InMemoryDataOperationEnum_t DstInMemOp,
+InMemoryDataOperationEnum DstInMemOp,
 typename BlockSliceLengths,
 typename ThreadSliceLengths,
 typename ThreadClusterLengths,

@@ -15,7 +15,7 @@ namespace ck {
 // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
 template <index_t BlockSize,
 typename ElementwiseOperation,
-InMemoryDataOperationEnum_t DstInMemOp,
+InMemoryDataOperationEnum DstInMemOp,
 typename BlockSliceLengths,
 typename ThreadClusterLengths,
 typename ThreadClusterArrangeOrder,

@@ -15,7 +15,7 @@ namespace ck {
 // 3. Run() does not construct new tensor coordinate
 template <index_t BlockSize,
 typename ElementwiseOperation,
-InMemoryDataOperationEnum_t DstInMemOp,
+InMemoryDataOperationEnum DstInMemOp,
 typename BlockSliceLengths,
 typename ThreadClusterLengths,
 typename ThreadClusterArrangeOrder,

@@ -15,7 +15,7 @@ namespace ck {
 // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
 template <index_t BlockSize,
 typename ElementwiseOperation,
-InMemoryDataOperationEnum_t DstInMemOp,
+InMemoryDataOperationEnum DstInMemOp,
 typename BlockSliceLengths,
 typename ThreadClusterLengths,
 typename ThreadClusterArrangeOrder,
@@ -5,7 +5,7 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

-enum struct ConvolutionBackwardDataSpecialization_t
+enum struct ConvolutionBackwardDataSpecialization
 {
 Default,
 Filter1x1Stride1Pad0,

@@ -7,7 +7,7 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

-enum struct ConvolutionForwardSpecialization_t
+enum struct ConvolutionForwardSpecialization
 {
 Default,
 Filter1x1Pad0,

@@ -15,14 +15,14 @@ enum struct ConvolutionForwardSpecialization_t
 OddC,
 };

-inline std::string getConvFwdSpecializationStr(const ConvolutionForwardSpecialization_t& s)
+inline std::string getConvFwdSpecializationStr(const ConvolutionForwardSpecialization& s)
 {
 switch(s)
 {
-case ConvolutionForwardSpecialization_t::Default: return "Default";
-case ConvolutionForwardSpecialization_t::Filter1x1Pad0: return "Filter1x1Pad0";
-case ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
-case ConvolutionForwardSpecialization_t::OddC: return "OddC";
+case ConvolutionForwardSpecialization::Default: return "Default";
+case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
+case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
+case ConvolutionForwardSpecialization::OddC: return "OddC";
 default: return "Unrecognized specialization!";
 }
 }
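Together, the hunks above rename the forward-convolution specialization enum and its string helper. A self-contained sketch of the renamed pair, mirrored from the diff (enumerator order is assumed):

    #include <string>

    enum struct ConvolutionForwardSpecialization
    {
        Default,
        Filter1x1Pad0,
        Filter1x1Stride1Pad0,
        OddC,
    };

    inline std::string getConvFwdSpecializationStr(const ConvolutionForwardSpecialization& s)
    {
        switch(s)
        {
        case ConvolutionForwardSpecialization::Default: return "Default";
        case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
        case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
        case ConvolutionForwardSpecialization::OddC: return "OddC";
        default: return "Unrecognized specialization!";
        }
    }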
@@ -105,7 +105,7 @@ template <typename ALayout,
 typename CElementwiseOperation,
 typename D0ReduceOperation,
 typename D1ReduceOperation,
-GemmSpecialization_t GemmSpecialization,
+GemmSpecialization GemmSpec,
 index_t NumGemmKPrefetchStage,
 index_t BlockSize,
 index_t MPerBlock,

@@ -171,8 +171,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
 const auto MPad = M - MRaw;
 const auto KPad = K - KRaw;

-if constexpr(GemmSpecialization == GemmSpecialization_t::MKPadding ||
-GemmSpecialization == GemmSpecialization_t::MNKPadding)
+if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
+GemmSpec == GemmSpecialization::MNKPadding)
 {
 // pad both M and K
 assert(K % AK1 == 0);

@@ -195,8 +195,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi

 return a_grid_desc_ak0_m_ak1;
 }
-else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding ||
-GemmSpecialization == GemmSpecialization_t::MNPadding)
+else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+GemmSpec == GemmSpecialization::MNPadding)
 {
 // pad M, but not K
 assert(KRaw % AK1 == 0);

@@ -212,8 +212,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi

 return a_grid_desc_ak0_m_ak1;
 }
-else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding ||
-GemmSpecialization == GemmSpecialization_t::NKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
+GemmSpec == GemmSpecialization::NKPadding)
 {
 // pad K, but not M
 assert(K % AK1 == 0);

@@ -274,8 +274,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
 const auto NPad = N - NRaw;
 const auto KPad = K - KRaw;

-if constexpr(GemmSpecialization == GemmSpecialization_t::NKPadding ||
-GemmSpecialization == GemmSpecialization_t::MNKPadding)
+if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
+GemmSpec == GemmSpecialization::MNKPadding)
 {
 // pad both N and K
 assert(K % BK1 == 0);

@@ -298,8 +298,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi

 return b_grid_desc_bk0_n_bk1;
 }
-else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding ||
-GemmSpecialization == GemmSpecialization_t::MNPadding)
+else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
+GemmSpec == GemmSpecialization::MNPadding)
 {
 // pad N, but not K
 assert(KRaw % BK1 == 0);

@@ -315,8 +315,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi

 return b_grid_desc_bk0_n_bk1;
 }
-else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding ||
-GemmSpecialization == GemmSpecialization_t::MKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
+GemmSpec == GemmSpecialization::MKPadding)
 {
 // pad K, but not N
 assert(K % BK1 == 0);

@@ -377,8 +377,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
 const auto MPad = M - MRaw;
 const auto NPad = N - NRaw;

-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding ||
-GemmSpecialization == GemmSpecialization_t::MNKPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
+GemmSpec == GemmSpecialization::MNKPadding)
 {
 // pad M and N
 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,

@@ -387,8 +387,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
 make_tuple(Sequence<0>{}, Sequence<1>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}));
 }
-else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding ||
-GemmSpecialization == GemmSpecialization_t::MKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+GemmSpec == GemmSpecialization::MKPadding)
 {
 // pad M, but not N
 return transform_tensor_descriptor(

@@ -397,8 +397,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
 make_tuple(Sequence<0>{}, Sequence<1>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}));
 }
-else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding ||
-GemmSpecialization == GemmSpecialization_t::NKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
+GemmSpec == GemmSpecialization::NKPadding)
 {
 // pad N, but not M
 return transform_tensor_descriptor(

@@ -422,10 +422,10 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
 const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
 const auto MPad = M - MRaw;

-if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding ||
-GemmSpecialization == GemmSpecialization_t::MNPadding ||
-GemmSpecialization == GemmSpecialization_t::MKPadding ||
-GemmSpecialization == GemmSpecialization_t::MNKPadding)
+if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+GemmSpec == GemmSpecialization::MNPadding ||
+GemmSpec == GemmSpecialization::MKPadding ||
+GemmSpec == GemmSpecialization::MNKPadding)
 {
 // pad M
 return transform_tensor_descriptor(d_grid_desc_mraw,

@@ -544,8 +544,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
 CElementwiseOperation,
 D0ReduceOperation,
 D1ReduceOperation,
-InMemoryDataOperationEnum_t::Set,
-InMemoryDataOperationEnum_t::AtomicAdd,
+InMemoryDataOperationEnum::Set,
+InMemoryDataOperationEnum::AtomicAdd,
 AGridDesc_AK0_M_AK1,
 BGridDesc_BK0_N_BK1,
 CGridDesc_M_N,
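Every padding branch above is the same compile-time dispatch on the renamed GemmSpec parameter. A hedged, self-contained sketch of the predicate behind the M-padding branches (PadsM is a hypothetical helper, not a ck function; enumerator order is assumed; the descriptor transforms are elided):

    enum struct GemmSpecialization
    {
        Default,
        MPadding,
        NPadding,
        KPadding,
        MNPadding,
        MKPadding,
        NKPadding,
        MNKPadding,
    };

    // True for every specialization whose branch pads M.
    template <GemmSpecialization GemmSpec>
    constexpr bool PadsM()
    {
        return GemmSpec == GemmSpecialization::MPadding ||
               GemmSpec == GemmSpecialization::MNPadding ||
               GemmSpec == GemmSpecialization::MKPadding ||
               GemmSpec == GemmSpecialization::MNKPadding;
    }

    static_assert(PadsM<GemmSpecialization::MNKPadding>(), "MNK pads M");
    static_assert(!PadsM<GemmSpecialization::NKPadding>(), "NK does not pad M");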
@@ -277,7 +277,7 @@ struct DeviceBatchedGemmXdl
 ADataType, // TODO: distinguish A/B datatype
 AccDataType,
 CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
 AGridDesc_K0_M_K1,
 BGridDesc_K0_N_K1,
 CGridDesc_M_N,

@@ -209,7 +209,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
 ADataType, // TODO: distinguish A/B datatype
 AccDataType,
 CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
 AGridDesc_K0_M_K1,
 BGridDesc_K0_N_K1,
 CGridDesc_M_N,

@@ -250,7 +250,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
 ADataType, // TODO: distinguish A/B datatype
 AccDataType,
 CDataType,
-InMemoryDataOperationEnum_t::AtomicAdd,
+InMemoryDataOperationEnum::AtomicAdd,
 AGridDesc_K0_M_K1,
 BGridDesc_K0_N_K1,
 CGridDesc_M_N,
@@ -25,7 +25,7 @@ template <typename InDataType,
 typename InElementwiseOperation,
 typename WeiElementwiseOperation,
 typename OutElementwiseOperation,
-ConvolutionBackwardDataSpecialization_t ConvBackwardDataSpecialization,
+ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization,
 ck::index_t BlockSize,
 ck::index_t MPerBlock,
 ck::index_t NPerBlock,

@@ -131,7 +131,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));

 if constexpr(ConvBackwardDataSpecialization ==
-ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
 {
 // A: output tensor
 const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(

@@ -368,7 +368,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 ABDataType, // TODO: distinguish A/B datatype
 AccDataType,
 CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
 AGridDesc_K0_M_K1,
 BGridDesc_K0_N_K1,
 CGridDesc_M_N,

@@ -671,7 +671,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 static bool IsSupportedArgument(const Argument& arg)
 {
 if constexpr(ConvBackwardDataSpecialization ==
-ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
 {
 // check if it's 1x1, stride=1 pad = 0 conv
 if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
@@ -27,7 +27,7 @@ template <
 typename InElementwiseOperation,
 typename WeiElementwiseOperation,
 typename OutElementwiseOperation,
-ConvolutionForwardSpecialization_t ConvForwardSpecialization,
+ConvolutionForwardSpecialization ConvForwardSpecialization,
 ck::index_t BlockSize,
 ck::index_t MPerBlock,
 ck::index_t NPerBlock,

@@ -125,7 +125,7 @@ struct
 const auto GemmMPad = GemmM - GemmMRaw;

 if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
 { // 1x1, stride=1, pad=0
 const index_t GemmK = Y * X * C;
 assert(GemmK % GemmK1Number == 0);

@@ -179,7 +179,7 @@ struct
 resi_grid_desc_gemmm_gemmn);
 }
 else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
 { // 1x1, pad=0
 const index_t GemmK = Y * X * C;
 assert(GemmK % GemmK1Number == 0);

@@ -249,7 +249,7 @@ struct
 bias_grid_desc_gemmm_gemmn,
 resi_grid_desc_gemmm_gemmn);
 }
-else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::OddC)
+else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC)
 { // C = odd value
 const index_t GemmKRaw = Y * X * C;
 const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);

@@ -466,7 +466,7 @@ struct
 ABDataType, // TODO: distinguish A/B datatype
 AccDataType,
 CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
 AGridDesc_K0_M_K1,
 BGridDesc_K0_N_K1,
 CGridDesc_M_N,

@@ -811,7 +811,7 @@ struct
 static bool IsSupportedArgument(const Argument& arg)
 {
 if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
 {
 // check if it's 1x1, stride=1 conv
 if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&

@@ -823,7 +823,7 @@ struct
 }
 }
 else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
 {
 // check if it's 1x1 conv
 if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
@@ -27,8 +27,8 @@ template <
 typename InElementwiseOperation,
 typename WeiElementwiseOperation,
 typename OutElementwiseOperation,
-InMemoryDataOperationEnum_t OutGlobalMemoryDataOperation,
-ConvolutionForwardSpecialization_t ConvForwardSpecialization,
+InMemoryDataOperationEnum OutGlobalMemoryDataOperation,
+ConvolutionForwardSpecialization ConvForwardSpecialization,
 ck::index_t BlockSize,
 ck::index_t MPerBlock,
 ck::index_t NPerBlock,

@@ -124,7 +124,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
 const auto GemmMPad = GemmM - GemmMRaw;

 if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
 { // 1x1, stride=1, pad=0
 const index_t GemmK = Y * X * C;
 assert(GemmK % GemmK1Number == 0);

@@ -174,7 +174,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
 bias_grid_desc_gemmm_gemmn);
 }
 else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
 { // 1x1, pad=0
 const index_t GemmK = Y * X * C;
 assert(GemmK % GemmK1Number == 0);

@@ -240,7 +240,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
 out_gemmm_gemmn_grid_desc,
 bias_grid_desc_gemmm_gemmn);
 }
-else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::OddC)
+else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC)
 { // C = odd value
 const index_t GemmKRaw = Y * X * C;
 const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);

@@ -763,7 +763,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
 static bool IsSupportedArgument(const Argument& arg)
 {
 if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
 {
 // check if it's 1x1, stride=1 conv
 if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&

@@ -775,7 +775,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
 }
 }
 else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
 {
 // check if it's 1x1 conv
 if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
@@ -26,7 +26,7 @@ template <
 typename InElementwiseOperation,
 typename WeiElementwiseOperation,
 typename OutElementwiseOperation,
-ConvolutionForwardSpecialization_t ConvForwardSpecialization,
+ConvolutionForwardSpecialization ConvForwardSpecialization,
 ck::index_t BlockSize,
 ck::index_t MPerBlock,
 ck::index_t NPerBlock,

@@ -120,7 +120,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
 const auto GemmMPad = GemmM - GemmMRaw;

 if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
 { // 1x1, stride=1, pad=0
 const index_t GemmK = Y * X * C;
 assert(GemmK % GemmK1Number == 0);

@@ -165,7 +165,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
 out_gemmm_gemmn_grid_desc);
 }
 else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
 { // 1x1, pad=0
 const index_t GemmK = Y * X * C;
 assert(GemmK % GemmK1Number == 0);

@@ -226,7 +226,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
 wei_gemmk0_gemmn_gemmk1_grid_desc,
 out_gemmm_gemmn_grid_desc);
 }
-else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::OddC)
+else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC)
 { // C = odd value
 const index_t GemmKRaw = Y * X * C;
 const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);

@@ -424,7 +424,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
 AccDataType,
 CDataType, // TODO: Add ShuffleType for DeviceConv2d
 CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
 AGridDesc_K0_M_K1,
 BGridDesc_K0_N_K1,
 CGridDesc_M_N,

@@ -733,7 +733,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
 static bool IsSupportedArgument(const Argument& arg)
 {
 if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
 {
 // check if it's 1x1, stride=1 conv
 if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&

@@ -745,7 +745,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
 }
 }
 else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
 {
 // check if it's 1x1 conv
 if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
@@ -25,7 +25,7 @@ template <typename InDataType,
 typename InElementwiseOperation,
 typename WeiElementwiseOperation,
 typename OutElementwiseOperation,
-ConvolutionForwardSpecialization_t ConvForwardSpecialization,
+ConvolutionForwardSpecialization ConvForwardSpecialization,
 ck::index_t BlockSize,
 ck::index_t MPerBlock,
 ck::index_t NPerBlock,

@@ -119,7 +119,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 const index_t GemmK0 = GemmK / GemmK1Number;

 if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
 {
 // A: input tensor
 const auto in_gemmmraw_gemmk_grid_desc =

@@ -159,7 +159,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 out_gemmm_gemmn_grid_desc);
 }
 else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
 {
 // A: input tensor
 const auto in_n_hi_wi_c_grid_desc =

@@ -316,7 +316,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 ABDataType, // TODO: distinguish A/B datatype
 AccDataType,
 CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
 AGridDesc_K0_M_K1,
 BGridDesc_K0_N_K1,
 CGridDesc_M_N,

@@ -565,7 +565,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 static bool IsSupportedArgument(const Argument& arg)
 {
 if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
 {
 // check if it's 1x1, stride=1 conv
 if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&

@@ -577,7 +577,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 }
 }
 else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
 {
 // check if it's 1x1 conv
 if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
@@ -83,7 +83,7 @@ template <typename InDataType,
 typename InElementwiseOperation,
 typename WeiElementwiseOperation,
 typename OutElementwiseOperation,
-ConvolutionForwardSpecialization_t ConvForwardSpecialization,
+ConvolutionForwardSpecialization ConvForwardSpecialization,
 ck::index_t BlockSize,
 ck::index_t MPerBlock,
 ck::index_t NPerBlock,

@@ -207,7 +207,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
 const index_t Ho = output_spatial_lengths[1];
 const index_t Wo = output_spatial_lengths[2];

-static_assert(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Default,
+static_assert(ConvForwardSpecialization == ConvolutionForwardSpecialization::Default,
 "Wrong! This specialization not implemented!");

 const auto in_desc_n_di_hi_wi_c =

@@ -287,7 +287,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
 InDataType,
 AccDataType,
 OutDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
 AGridDesc_K0_M_K1,
 BGridDesc_K0_N_K1,
 CGridDesc_M_N,
@@ -25,7 +25,7 @@ template <typename InDataType,
 typename InElementwiseOperation,
 typename WeiElementwiseOperation,
 typename OutElementwiseOperation,
-ConvolutionBackwardDataSpecialization_t ConvBackwardDataSpecialization,
+ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization,
 ck::index_t NumDimSpatial,
 ck::index_t BlockSize,
 ck::index_t MPerBlock,

@@ -116,7 +116,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
 const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));

 if constexpr(ConvBackwardDataSpecialization ==
-ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
 {
 // A: output tensor
 const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(

@@ -336,7 +336,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
 make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));

 if constexpr(ConvBackwardDataSpecialization ==
-ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
 {
 // A: output tensor
 const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(

@@ -618,7 +618,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
 make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));

 if constexpr(ConvBackwardDataSpecialization ==
-ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
 {
 // A: output tensor
 const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(

@@ -959,7 +959,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
 ABDataType, // TODO: distinguish A/B datatype
 AccDataType,
 CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
 AGridDesc_K0_M_K1,
 BGridDesc_K0_N_K1,
 CGridDesc_M_N,

@@ -1385,7 +1385,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
 static bool IsSupportedArgument(const Argument& arg)
 {
 if constexpr(ConvBackwardDataSpecialization ==
-ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
 {
 // check if it's 1x1, stride=1 pad = 0 conv
 for(int i = 0; i < NumDimSpatial; i++)

@@ -1527,7 +1527,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
 << K0PerBlock
 << ">";
 if constexpr(ConvBackwardDataSpecialization ==
-ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0){
+ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0){

 str<< " Filter1x1Stride1Pad0";
 }
@@ -44,7 +44,7 @@ template <typename InDataType,
 typename InElementwiseOperation,
 typename WeiElementwiseOperation,
 typename OutElementwiseOperation,
-ConvolutionForwardSpecialization_t ConvForwardSpecialization,
+ConvolutionForwardSpecialization ConvForwardSpecialization,
 ck::index_t NumDimSpatial,
 ck::index_t BlockSize,
 ck::index_t MPerBlock,

@@ -142,7 +142,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 const index_t ConvStrideW = conv_filter_strides[0];

 if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
 {
 const auto in_gemmmraw_gemmk_grid_desc =
 make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));

@@ -156,7 +156,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
 }
 else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
 {
 const auto in_n_wi_c_grid_desc =
 make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));

@@ -262,7 +262,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 const index_t ConvStrideW = conv_filter_strides[1];

 if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
 {
 const auto in_gemmmraw_gemmk_grid_desc =
 make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));

@@ -276,7 +276,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
 }
 else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
 {
 const auto in_n_hi_wi_c_grid_desc =
 make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));

@@ -395,7 +395,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 const index_t ConvStrideW = conv_filter_strides[2];

 if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
 {
 const auto in_gemmmraw_gemmk_grid_desc =
 make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));

@@ -409,7 +409,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
 }
 else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
 {
 const auto in_n_di_hi_wi_c_grid_desc =
 make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));

@@ -613,7 +613,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 ABDataType, // TODO: distinguish A/B datatype
 AccDataType,
 CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
 AGridDesc_K0_M_K1,
 BGridDesc_K0_N_K1,
 CGridDesc_M_N,

@@ -878,7 +878,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 }

 if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
 {
 // check if it's 1x1, stride=1 conv
 for(ck::index_t i = 0; i < NumDimSpatial; ++i)

@@ -891,7 +891,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 }
 }
 else if constexpr(ConvForwardSpecialization ==
-ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+ConvolutionForwardSpecialization::Filter1x1Pad0)
 {
 // check if it's 1x1 conv
 for(ck::index_t i = 0; i < NumDimSpatial; ++i)
@@ -29,7 +29,7 @@ template <typename ALayout,
 typename CElementwiseOperation,
 typename D0ReduceOperation,
 typename D1ReduceOperation,
-GemmSpecialization_t GemmSpecialization,
+GemmSpecialization GemmSpec,
 index_t NumGemmKPrefetchStage,
 index_t BlockSize,
 index_t MPerBlock,

@@ -95,8 +95,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
 const auto MPad = M - MRaw;
 const auto KPad = K - KRaw;

-if constexpr(GemmSpecialization == GemmSpecialization_t::MKPadding ||
-GemmSpecialization == GemmSpecialization_t::MNKPadding)
+if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
+GemmSpec == GemmSpecialization::MNKPadding)
 {
 // pad both M and K
 assert(K % AK1 == 0);

@@ -119,8 +119,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera

 return a_grid_desc_ak0_m_ak1;
 }
-else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding ||
-GemmSpecialization == GemmSpecialization_t::MNPadding)
+else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+GemmSpec == GemmSpecialization::MNPadding)
 {
 // pad M, but not K
 assert(KRaw % AK1 == 0);

@@ -136,8 +136,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera

 return a_grid_desc_ak0_m_ak1;
 }
-else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding ||
-GemmSpecialization == GemmSpecialization_t::NKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
+GemmSpec == GemmSpecialization::NKPadding)
 {
 // pad K, but not M
 assert(K % AK1 == 0);

@@ -198,8 +198,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
 const auto NPad = N - NRaw;
 const auto KPad = K - KRaw;

-if constexpr(GemmSpecialization == GemmSpecialization_t::NKPadding ||
-GemmSpecialization == GemmSpecialization_t::MNKPadding)
+if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
+GemmSpec == GemmSpecialization::MNKPadding)
 {
 // pad both N and K
 assert(K % BK1 == 0);

@@ -222,8 +222,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera

 return b_grid_desc_bk0_n_bk1;
 }
-else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding ||
-GemmSpecialization == GemmSpecialization_t::MNPadding)
+else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
+GemmSpec == GemmSpecialization::MNPadding)
 {
 // pad N, but not K
 assert(KRaw % BK1 == 0);

@@ -239,8 +239,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera

 return b_grid_desc_bk0_n_bk1;
 }
-else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding ||
-GemmSpecialization == GemmSpecialization_t::MKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
+GemmSpec == GemmSpecialization::MKPadding)
 {
 // pad K, but not N
 assert(K % BK1 == 0);

@@ -301,8 +301,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
 const auto MPad = M - MRaw;
 const auto NPad = N - NRaw;

-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding ||
-GemmSpecialization == GemmSpecialization_t::MNKPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
+GemmSpec == GemmSpecialization::MNKPadding)
 {
 // pad M and N
 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,

@@ -311,8 +311,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
 make_tuple(Sequence<0>{}, Sequence<1>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}));
 }
-else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding ||
-GemmSpecialization == GemmSpecialization_t::MKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+GemmSpec == GemmSpecialization::MKPadding)
 {
 // pad M, but not N
 return transform_tensor_descriptor(

@@ -321,8 +321,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
 make_tuple(Sequence<0>{}, Sequence<1>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}));
 }
-else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding ||
-GemmSpecialization == GemmSpecialization_t::NKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
+GemmSpec == GemmSpecialization::NKPadding)
 {
 // pad N, but not M
 return transform_tensor_descriptor(

@@ -346,10 +346,10 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
 const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
 const auto MPad = M - MRaw;

-if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding ||
-GemmSpecialization == GemmSpecialization_t::MNPadding ||
-GemmSpecialization == GemmSpecialization_t::MKPadding ||
-GemmSpecialization == GemmSpecialization_t::MNKPadding)
+if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+GemmSpec == GemmSpecialization::MNPadding ||
+GemmSpec == GemmSpecialization::MKPadding ||
+GemmSpec == GemmSpecialization::MNKPadding)
 {
 // pad M
 return transform_tensor_descriptor(d_grid_desc_mraw,

@@ -382,8 +382,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
 CElementwiseOperation,
 D0ReduceOperation,
 D1ReduceOperation,
-InMemoryDataOperationEnum_t::Set,
-InMemoryDataOperationEnum_t::AtomicAdd,
+InMemoryDataOperationEnum::Set,
+InMemoryDataOperationEnum::AtomicAdd,
 AGridDesc_AK0_M_AK1,
 BGridDesc_BK0_N_BK1,
 CGridDesc_M_N,
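The descriptor builders above keep repeating one piece of arithmetic: round the raw extent up to a multiple of the block size, then record the pad. A standalone sketch of that math (the concrete numbers are invented for illustration):

    #include <cstdio>

    int main()
    {
        const int MRaw      = 1000;
        const int MPerBlock = 128;
        // math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock in the hunks above
        const int M    = (MRaw + MPerBlock - 1) / MPerBlock * MPerBlock;
        const int MPad = M - MRaw; // rows of padding an MPadding branch adds
        std::printf("M=%d MPad=%d\n", M, MPad); // prints M=1024 MPad=24
    }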
@@ -27,7 +27,7 @@ template <typename ADataType,
 typename AElementwiseOperation,
 typename BElementwiseOperation,
 typename CElementwiseOperation,
-GemmSpecialization_t GemmSpecialization,
+GemmSpecialization GemmSpec,
 ck::index_t BlockSize,
 ck::index_t MPerBlock,
 ck::index_t NPerBlock,

@@ -80,7 +80,7 @@ struct DeviceGemmXdl
 }
 }();

-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
 {
 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;

@@ -119,7 +119,7 @@ struct DeviceGemmXdl
 }
 }();

-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
 {
 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

@@ -154,7 +154,7 @@ struct DeviceGemmXdl
 }
 }();

-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
 {
 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

@@ -186,7 +186,7 @@ struct DeviceGemmXdl
 ADataType, // TODO: distinguish A/B datatype
 AccDataType,
 CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
 AGridDesc_K0_M_K1,
 BGridDesc_K0_N_K1,
 CGridDesc_M_N,
@@ -138,7 +138,7 @@ struct DeviceGemmXdl_C_Shuffle
 AccDataType,
 CShuffleDataType,
 CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
 AGridDesc_K0_M_K1,
 BGridDesc_K0_N_K1,
 CGridDesc_M_N,

@@ -139,7 +139,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
 ADataType, // TODO: distinguish A/B datatype
 AccDataType,
 CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
 AGridDesc_K0_M_K1,
 BGridDesc_K0_N_K1,
 CGridDesc_M_N,

@@ -147,7 +147,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
 ADataType, // TODO: distinguish A/B datatype
 AccDataType,
 CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
 AGridDesc_K0_M_K1,
 BGridDesc_K0_N_K1,
 CGridDesc_M_N,

@@ -169,7 +169,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
 ADataType, // TODO: distinguish A/B datatype
 AccDataType,
 CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
 AGridDesc_K0_M_K1,
 BGridDesc_K0_N_K1,
 CGridDesc_M_N,
@@ -24,7 +24,7 @@ template <typename ALayout,
 typename AElementwiseOperation,
 typename BElementwiseOperation,
 typename CElementwiseOperation,
-GemmSpecialization_t GemmSpecialization,
+GemmSpecialization GemmSpec,
 index_t NumGemmKPrefetchStage,
 index_t BlockSize,
 index_t MPerBlock,

@@ -84,8 +84,8 @@ struct DeviceGemm_Xdl_CShuffle
 const auto MPad = M - MRaw;
 const auto KPad = K - KRaw;

-if constexpr(GemmSpecialization == GemmSpecialization_t::MKPadding ||
-GemmSpecialization == GemmSpecialization_t::MNKPadding)
+if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
+GemmSpec == GemmSpecialization::MNKPadding)
 {
 // pad both M and K
 assert(K % AK1 == 0);

@@ -108,8 +108,8 @@ struct DeviceGemm_Xdl_CShuffle

 return a_grid_desc_ak0_m_ak1;
 }
-else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding ||
-GemmSpecialization == GemmSpecialization_t::MNPadding)
+else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+GemmSpec == GemmSpecialization::MNPadding)
 {
 // pad M, but not K
 assert(KRaw % AK1 == 0);

@@ -125,8 +125,8 @@ struct DeviceGemm_Xdl_CShuffle

 return a_grid_desc_ak0_m_ak1;
 }
-else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding ||
-GemmSpecialization == GemmSpecialization_t::NKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
+GemmSpec == GemmSpecialization::NKPadding)
 {
 // pad K, but not M
 assert(K % AK1 == 0);

@@ -187,8 +187,8 @@ struct DeviceGemm_Xdl_CShuffle
 const auto NPad = N - NRaw;
 const auto KPad = K - KRaw;

-if constexpr(GemmSpecialization == GemmSpecialization_t::NKPadding ||
-GemmSpecialization == GemmSpecialization_t::MNKPadding)
+if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
+GemmSpec == GemmSpecialization::MNKPadding)
 {
 // pad both N and K
 assert(K % BK1 == 0);

@@ -211,8 +211,8 @@ struct DeviceGemm_Xdl_CShuffle

 return b_grid_desc_bk0_n_bk1;
 }
-else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding ||
-GemmSpecialization == GemmSpecialization_t::MNPadding)
+else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
+GemmSpec == GemmSpecialization::MNPadding)
 {
 // pad N, but not K
 assert(KRaw % BK1 == 0);

@@ -228,8 +228,8 @@ struct DeviceGemm_Xdl_CShuffle

 return b_grid_desc_bk0_n_bk1;
 }
-else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding ||
-GemmSpecialization == GemmSpecialization_t::MKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
+GemmSpec == GemmSpecialization::MKPadding)
 {
 // pad K, but not N
 assert(K % BK1 == 0);

@@ -290,8 +290,8 @@ struct DeviceGemm_Xdl_CShuffle
 const auto MPad = M - MRaw;
 const auto NPad = N - NRaw;

-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding ||
-GemmSpecialization == GemmSpecialization_t::MNKPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
+GemmSpec == GemmSpecialization::MNKPadding)
 {
 // pad M and N
 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,

@@ -300,8 +300,8 @@ struct DeviceGemm_Xdl_CShuffle
 make_tuple(Sequence<0>{}, Sequence<1>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}));
 }
-else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding ||
-GemmSpecialization == GemmSpecialization_t::MKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+GemmSpec == GemmSpecialization::MKPadding)
 {
 // pad M, but not N
 return transform_tensor_descriptor(

@@ -310,8 +310,8 @@ struct DeviceGemm_Xdl_CShuffle
 make_tuple(Sequence<0>{}, Sequence<1>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}));
 }
-else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding ||
-GemmSpecialization == GemmSpecialization_t::NKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
+GemmSpec == GemmSpecialization::NKPadding)
 {
 // pad N, but not M
 return transform_tensor_descriptor(

@@ -340,7 +340,7 @@ struct DeviceGemm_Xdl_CShuffle
 AElementwiseOperation,
 BElementwiseOperation,
 CElementwiseOperation,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
 AGridDesc_AK0_M_AK1,
 BGridDesc_BK0_N_BK1,
 CGridDesc_M_N,
@@ -31,7 +31,7 @@ template <typename ADataType,
 typename AElementwiseOperation,
 typename BElementwiseOperation,
 typename CElementwiseOperation,
-GemmSpecialization_t GemmSpecialization,
+GemmSpecialization GemmSpec,
 ck::index_t BlockSize,
 ck::index_t MPerBlock,
 ck::index_t NPerBlock,

@@ -91,7 +91,7 @@ struct DeviceGemmXdlSplitK
 make_tuple(Sequence<0>{}, Sequence<1>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}));

-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
 {
 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
 return transform_tensor_descriptor(

@@ -136,7 +136,7 @@ struct DeviceGemmXdlSplitK
 make_tuple(Sequence<0>{}, Sequence<1>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}));

-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
 {
 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
 return transform_tensor_descriptor(

@@ -170,7 +170,7 @@ struct DeviceGemmXdlSplitK
 }
 }();

-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
 {
 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

@@ -209,7 +209,7 @@ struct DeviceGemmXdlSplitK
 ADataType, // TODO: distinguish A/B datatype
 AccDataType,
 CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
 AGridDesc_K0_M_K1,
 BGridDesc_K0_N_K1,
 CGridDesc_M_N,

@@ -250,7 +250,7 @@ struct DeviceGemmXdlSplitK
 ADataType, // TODO: distinguish A/B datatype
 AccDataType,
 CDataType,
-InMemoryDataOperationEnum_t::AtomicAdd,
+InMemoryDataOperationEnum::AtomicAdd,
 AGridDesc_K0_M_K1,
 BGridDesc_K0_N_K1,
 CGridDesc_M_N,
@@ -31,7 +31,7 @@ template <typename ADataType,
 typename AElementwiseOperation,
 typename BElementwiseOperation,
 typename CElementwiseOperation,
-GemmSpecialization_t GemmSpecialization,
+GemmSpecialization GemmSpec,
 ck::index_t BlockSize,
 ck::index_t MPerBlock,
 ck::index_t NPerBlock,

@@ -93,7 +93,7 @@ struct DeviceGemmXdlSplitKCShuffle
 make_tuple(Sequence<0>{}, Sequence<1>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}));

-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
 {
 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
 return transform_tensor_descriptor(

@@ -138,7 +138,7 @@ struct DeviceGemmXdlSplitKCShuffle
 make_tuple(Sequence<0>{}, Sequence<1>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}));

-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
 {
 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
 return transform_tensor_descriptor(

@@ -172,7 +172,7 @@ struct DeviceGemmXdlSplitKCShuffle
 }
 }();

-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
 {
 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

@@ -211,7 +211,7 @@ struct DeviceGemmXdlSplitKCShuffle
 ADataType, // TODO: distinguish A/B datatype
 AccDataType,
 CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
 AGridDesc_K0_M_K1,
 BGridDesc_K0_N_K1,
 CGridDesc_M_N,

@@ -253,7 +253,7 @@ struct DeviceGemmXdlSplitKCShuffle
 ADataType, // TODO: distinguish A/B datatype
 AccDataType,
 CDataType,
-InMemoryDataOperationEnum_t::AtomicAdd,
+InMemoryDataOperationEnum::AtomicAdd,
 AGridDesc_K0_M_K1,
 BGridDesc_K0_N_K1,
 CGridDesc_M_N,
@@ -27,7 +27,7 @@ template <typename ADataType,
 typename AElementwiseOperation,
 typename BElementwiseOperation,
 typename CElementwiseOperation,
-GemmSpecialization_t GemmSpecialization,
+GemmSpecialization GemmSpec,
 ck::index_t BlockSize,
 ck::index_t MPerBlock,
 ck::index_t NPerBlock,

@@ -81,7 +81,7 @@ struct DeviceGroupedGemmXdl
 }
 }();

-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
 {
 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;

@@ -120,7 +120,7 @@ struct DeviceGroupedGemmXdl
 }
 }();

-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
 {
 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

@@ -155,7 +155,7 @@ struct DeviceGroupedGemmXdl
 }
 }();

-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
 {
 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

@@ -187,7 +187,7 @@ struct DeviceGroupedGemmXdl
 ADataType, // TODO: distinguish A/B datatype
 AccDataType,
 CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
 AGridDesc_K0_M_K1,
 BGridDesc_K0_N_K1,
 CGridDesc_M_N,
@@ -10,7 +10,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <ck::ReduceTensorOp_t ReduceOpId>
|
||||
template <ck::ReduceTensorOp ReduceOpId>
|
||||
struct DevicePool2dFwd : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
@@ -29,7 +29,7 @@ struct DevicePool2dFwd : public BaseOperator
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <ck::ReduceTensorOp_t ReduceOpId>
|
||||
template <ck::ReduceTensorOp ReduceOpId>
|
||||
using DevicePool2dFwdPtr = std::unique_ptr<DevicePool2dFwd<ReduceOpId>>;
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -16,7 +16,7 @@ namespace device {
template <typename InDataType,
typename OutDataType,
typename AccDataType,
ck::ReduceTensorOp_t ReduceOpId,
ck::ReduceTensorOp ReduceOpId,
bool NeedIndices,
ck::index_t BlockSize,
ck::index_t ReduceMThreadClusterSize,
@@ -181,7 +181,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
reduce_lowest_length_ = window_spatial_lengths[1];

// TODO: is this correct?
if constexpr(ReduceOpId == ck::ReduceTensorOp_t::AVG)
if constexpr(ReduceOpId == ck::ReduceTensorOp::AVG)
{
ck::index_t divider = window_spatial_lengths[0] * window_spatial_lengths[1];
in_element_op_ = InElementwiseOperation{divider};
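The AVG branch above configures average pooling as an ADD reduction whose input elementwise operation carries the window area as a divider. A scalar sketch of that idea, under the assumption that the elementwise op scales by 1/divider (ScaleBy and reduce_window are hypothetical names, not the library's API):

#include <vector>

// Hypothetical elementwise op: scales each element by 1/divider so that an
// ADD reduction over a pooling window yields the window average.
struct ScaleBy
{
    explicit ScaleBy(int divider) : scale_(1.0f / static_cast<float>(divider)) {}
    float operator()(float x) const { return x * scale_; }
    float scale_;
};

// ADD-reduce one pooling window after applying the elementwise op.
float reduce_window(const std::vector<float>& window, const ScaleBy& op)
{
    float acc = 0.0f; // identity element of the ADD reduction
    for(float x : window)
        acc += op(x);
    return acc; // equals the average of the window
}
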
@@ -5,7 +5,7 @@ namespace ck {
namespace tensor_operation {
namespace device {

enum struct GemmSpecialization_t
enum struct GemmSpecialization
{
Default,
MPadding,
@@ -37,11 +37,11 @@ namespace ck {
// The boolean member "indexable" is also provided in reduce_binary_operator for
// easier checking by the upper-layer codes in the kernels.

template <typename T, ReduceTensorOp_t Op>
template <typename T, ReduceTensorOp Op>
struct reduce_binary_operator;

template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::ADD>
struct reduce_binary_operator<T, ReduceTensorOp::ADD>
{
using opType = reduce::Add<T>;
using dataType = T;
@@ -50,7 +50,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::ADD>
};

template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::MUL>
struct reduce_binary_operator<T, ReduceTensorOp::MUL>
{
using opType = reduce::Mul<T>;
using dataType = T;
@@ -59,7 +59,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MUL>
};

template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::MIN>
struct reduce_binary_operator<T, ReduceTensorOp::MIN>
{
using opType = reduce::Min<T>;
using dataType = T;
@@ -68,7 +68,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MIN>
};

template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::MAX>
struct reduce_binary_operator<T, ReduceTensorOp::MAX>
{
using opType = reduce::Max<T>;
using dataType = T;
@@ -77,7 +77,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MAX>
};

template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::AMAX>
struct reduce_binary_operator<T, ReduceTensorOp::AMAX>
{
using opType = reduce::AMax<T>;
using dataType = T;
@@ -86,7 +86,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::AMAX>
};

template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::AVG>
struct reduce_binary_operator<T, ReduceTensorOp::AVG>
{
using opType = reduce::Add<T>;
using dataType = T;
@@ -95,7 +95,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::AVG>
};

template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::NORM1>
struct reduce_binary_operator<T, ReduceTensorOp::NORM1>
{
using opType = reduce::Add<T>;
using dataType = T;
@@ -104,7 +104,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM1>
};

template <typename T>
struct reduce_binary_operator<T, ReduceTensorOp_t::NORM2>
struct reduce_binary_operator<T, ReduceTensorOp::NORM2>
{
using opType = reduce::Add<T>;
using dataType = T;
@@ -115,7 +115,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM2>
// The templated struct reduce_unary_operator maps the enum Ids of Reduce operators to two unary
// functor classes.
// The two unary functors are called before and after the Reduction is executed, respectively.
template <typename T, ReduceTensorOp_t Op, bool IsFirstReduce, bool IsLastReduce>
template <typename T, ReduceTensorOp Op, bool IsFirstReduce, bool IsLastReduce>
struct reduce_unary_operator
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
@@ -123,42 +123,42 @@ struct reduce_unary_operator
};

template <typename T, bool IsFirstReduce>
struct reduce_unary_operator<T, ReduceTensorOp_t::AVG, IsFirstReduce, true>
struct reduce_unary_operator<T, ReduceTensorOp::AVG, IsFirstReduce, true>
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T, true>;
};

template <typename T, bool IsLastReduce>
struct reduce_unary_operator<T, ReduceTensorOp_t::NORM1, true, IsLastReduce>
struct reduce_unary_operator<T, ReduceTensorOp::NORM1, true, IsLastReduce>
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
};

template <typename T, bool IsLastReduce>
struct reduce_unary_operator<T, ReduceTensorOp_t::AMAX, true, IsLastReduce>
struct reduce_unary_operator<T, ReduceTensorOp::AMAX, true, IsLastReduce>
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
};

template <typename T>
struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, true, false>
struct reduce_unary_operator<T, ReduceTensorOp::NORM2, true, false>
{
using InElementwiseOperation = tensor_operation::element_wise::UnarySquare<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
};

template <typename T>
struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, true, true>
struct reduce_unary_operator<T, ReduceTensorOp::NORM2, true, true>
{
using InElementwiseOperation = tensor_operation::element_wise::UnarySquare<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt<T, T>;
};

template <typename T>
struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, false, true>
struct reduce_unary_operator<T, ReduceTensorOp::NORM2, false, true>
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt<T, T>;
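For orientation, the NORM2 specializations above compose to sqrt(sum(x * x)): UnarySquare runs on each element of the first reduction pass, reduce::Add accumulates, and UnarySqrt runs on the accumulator of the last pass. A scalar model of that composition (free-standing sketch, not the kernel code):

#include <cmath>
#include <vector>

// Scalar model of the NORM2 mapping: square before the ADD reduction and
// take the square root after it, mirroring UnarySquare / reduce::Add /
// UnarySqrt in the specializations above.
float norm2(const std::vector<float>& xs)
{
    float acc = 0.0f;      // reduction identity for ADD
    for(float x : xs)
        acc += x * x;      // InElementwiseOperation: UnarySquare
    return std::sqrt(acc); // AccElementwiseOperation: UnarySqrt
}
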
@@ -227,21 +227,18 @@ struct GridwiseReduction_mk_to_m_blockwise

const auto zeroVal = ReduceOperation::GetReductionZeroVal();

const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());

auto block_reduce_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_buffer, BlockSize);
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_buffer, BlockSize);

StaticBuffer<AddressSpaceEnum_t::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;

StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });

@@ -336,7 +333,7 @@ struct GridwiseReduction_mk_to_m_blockwise
{
if(!float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;

auto threadwise_dst_load =
@@ -376,7 +373,7 @@ struct GridwiseReduction_mk_to_m_blockwise
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m,
@@ -422,30 +419,26 @@ struct GridwiseReduction_mk_to_m_blockwise

const auto zeroVal = ReduceOperation::GetReductionZeroVal();

const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize());

auto block_reduce_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_val_buffer, BlockSize);
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_val_buffer, BlockSize);
auto block_reduce_idx_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_idx_buffer, BlockSize);
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_idx_buffer, BlockSize);

StaticBuffer<AddressSpaceEnum_t::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;

StaticBuffer<AddressSpaceEnum_t::Vgpr, index_t, MThreadSliceSize * KThreadSliceSize, true>
StaticBuffer<AddressSpaceEnum::Vgpr, index_t, MThreadSliceSize * KThreadSliceSize, true>
in_thread_idx_buf;

StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true>
accu_index_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;

const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

@@ -561,7 +554,7 @@ struct GridwiseReduction_mk_to_m_blockwise
{
if(!float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;

auto threadwise_dst_load =
@@ -601,7 +594,7 @@ struct GridwiseReduction_mk_to_m_blockwise
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_desc_m,
@@ -619,7 +612,7 @@ struct GridwiseReduction_mk_to_m_blockwise
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_desc_m,
@@ -678,36 +671,32 @@ struct GridwiseReduction_mk_to_m_blockwise
const auto zeroVal = ReduceOperation::GetReductionZeroVal();

const auto src_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Global>(p_ws_values_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
make_dynamic_buffer<AddressSpaceEnum::Global>(p_ws_values_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_indices_global, in_grid_desc_m_k.GetElementSpaceSize());
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize());

auto block_reduce_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_val_buffer, BlockSize);
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_val_buffer, BlockSize);
auto block_reduce_idx_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_idx_buffer, BlockSize);
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_idx_buffer, BlockSize);

StaticBuffer<AddressSpaceEnum_t::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;

StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType,
MThreadSliceSize * KThreadSliceSize,
true>
in_thread_idx_buf;

StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true>
accu_index_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;

const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

@@ -835,7 +824,7 @@ struct GridwiseReduction_mk_to_m_blockwise
{
if(!float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;

auto threadwise_dst_load =
@@ -875,7 +864,7 @@ struct GridwiseReduction_mk_to_m_blockwise
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m,
@@ -893,7 +882,7 @@ struct GridwiseReduction_mk_to_m_blockwise
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m,
@@ -140,21 +140,18 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
// LDS
__shared__ AccDataType p_block_reduce_buffer[BlockSize];

const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());

auto block_reduce_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_buffer, BlockSize);
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_buffer, BlockSize);

StaticBuffer<AddressSpaceEnum_t::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;

StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });

@@ -259,7 +256,7 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum_t::AtomicAdd,
InMemoryDataOperationEnum::AtomicAdd,
1,
true>(
out_grid_desc_m,
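The only functional difference from the Set-based stores elsewhere in this commit is the destination memory operation: with AtomicAdd, each workgroup accumulates its partial reduction into global memory instead of overwriting it, so several blocks can combine results in one output buffer. A host-side model of the two write policies (illustrative only; the GPU kernel uses hardware atomics, not std::atomic):

#include <atomic>

// Set overwrites the destination; AtomicAdd accumulates into it so that
// multiple writers (blocks, split-K batches) can combine partial results.
void write_set(float& dst, float src) { dst = src; }

void write_atomic_add(std::atomic<float>& dst, float src)
{
    float old = dst.load();
    // CAS loop: retry until our addition lands on top of the latest value.
    while(!dst.compare_exchange_weak(old, old + src)) {}
}
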
@@ -163,22 +163,19 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
__shared__ AccDataType p_block_reduce_buffer[BlockSize];

const auto in_global_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Global>(p_src_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize());

auto block_reduce_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_buffer, BlockSize);
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_buffer, BlockSize);

StaticBuffer<AddressSpaceEnum_t::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;

StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });

@@ -272,7 +269,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
true>(
workspace_desc_m_k,
@@ -322,33 +319,29 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
__shared__ index_t p_block_reduce_idx_buffer[BlockSize];

const auto in_global_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Global>(p_src_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize());
auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_indices_global, workspace_desc_m_k.GetElementSpaceSize());

auto block_reduce_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_val_buffer, BlockSize);
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_val_buffer, BlockSize);
auto block_reduce_idx_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_idx_buffer, BlockSize);
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_idx_buffer, BlockSize);

StaticBuffer<AddressSpaceEnum_t::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType,
MThreadSliceSize * KThreadSliceSize,
true>
in_thread_idx_buf;

StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true>
accu_index_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;

const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
@@ -461,7 +454,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
true>(
workspace_desc_m_k,
@@ -480,7 +473,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
true>(
workspace_desc_m_k,
@@ -132,18 +132,15 @@ struct GridwiseReduction_mk_to_m_threadwise

const auto zeroVal = ReduceOperation::GetReductionZeroVal();

const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());

StaticBuffer<AddressSpaceEnum_t::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;

StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });

@@ -223,7 +220,7 @@ struct GridwiseReduction_mk_to_m_threadwise
true>(
out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));

StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValue_buf;

threadwise_dst_load.Run(out_grid_desc_m,
@@ -248,7 +245,7 @@ struct GridwiseReduction_mk_to_m_threadwise
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_desc_m,
@@ -277,22 +274,18 @@ struct GridwiseReduction_mk_to_m_threadwise

const auto zeroVal = ReduceOperation::GetReductionZeroVal();

const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize());

StaticBuffer<AddressSpaceEnum_t::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;

StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true>
accu_index_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;

static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal;
@@ -382,7 +375,7 @@ struct GridwiseReduction_mk_to_m_threadwise
false>(
out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));

StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValue_buf;

threadwise_dst_load.Run(out_grid_desc_m,
@@ -407,7 +400,7 @@ struct GridwiseReduction_mk_to_m_threadwise
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_desc_m,
@@ -424,7 +417,7 @@ struct GridwiseReduction_mk_to_m_threadwise
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_desc_m,
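The beta path above loads the prior destination values into priorDstValue_buf only when beta is nonzero and folds them into the final result before the Set store. The blend itself lies outside the hunks shown; the sketch below assumes the conventional BLAS-style form dst = reduce(src) + beta * dst:

// Scalar model of the beta handling (an assumption, not the kernel's code):
// the extra global load is skipped entirely when beta == 0.
float finalize(float reduced, float beta, float prior_dst)
{
    return (beta != 0.0f) ? reduced + beta * prior_dst : reduced;
}
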
@@ -55,7 +55,7 @@ template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_GK0_GM0_GM1_GK1,
typename BGridDesc_GK0_GN0_GN1_GK1,
typename CGridDesc_GM0_GM1_GN0_GN1,
@@ -329,11 +329,11 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize());

const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0);
@@ -383,7 +383,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
@@ -407,7 +407,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
@@ -467,7 +467,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());

ThreadwiseTensorSliceSet_v1<FloatAcc,
@@ -481,15 +481,15 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);

auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());

auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
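The even/odd LDS buffers above implement ping-pong double buffering: while the GEMM consumes one buffer, the blockwise copy prefetches the next K-slice into the other, overlapping global loads with math. A minimal model of the loop structure (illustrative; block-level synchronization between load and compute is omitted):

// cur selects the "even" (0) or "odd" (1) buffer; load() and compute() stand
// in for the blockwise copy and the blockwise GEMM on that buffer.
template <typename Load, typename Compute>
void double_buffered_loop(int num_iters, Load load, Compute compute)
{
    int cur = 0;
    load(cur, 0); // prime the first buffer
    for(int i = 0; i < num_iters; ++i)
    {
        if(i + 1 < num_iters)
            load(1 - cur, i + 1); // prefetch the next slice into the other buffer
        compute(cur);             // consume the current buffer
        cur = 1 - cur;            // swap roles
    }
}
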
@@ -55,7 +55,7 @@ template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AKMGridDesc,
typename BKNGridDesc,
typename CMNGridDesc,
@@ -268,11 +268,11 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());

const auto K = a_k_m0_m1_grid_desc.GetLength(I0);
@@ -315,7 +315,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, MPerBlockM1>,
ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1,
@@ -341,7 +341,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, NPerBlockN1>,
BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1,
@@ -403,7 +403,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());

ThreadwiseTensorSliceSet_v1<FloatAcc,
@@ -428,15 +428,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
constexpr auto b_k_n0_n1_global_move_slice_window_step_hack =
BGridMoveSliceWindowStepHacks{};

auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize());

auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_k_m0_m1_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_k_n0_n1_block_desc.GetElementSpaceSize());
@@ -55,7 +55,7 @@ template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CMNGridDesc,
@@ -275,11 +275,11 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());

// divide block work by [M, N]
@@ -325,7 +325,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
@@ -349,7 +349,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
@@ -409,7 +409,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());

ThreadwiseTensorSliceSet_v1<FloatAcc,
@@ -423,15 +423,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);

auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());

auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
@@ -15,7 +15,7 @@ template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
@@ -84,11 +84,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e_k_global_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());

constexpr auto E = EPerBlock * 3 * 3;
@@ -181,7 +181,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<E, KPerBlock>,
ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K,
@@ -221,11 +221,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
b_e_n_ho_wo_global_desc,
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));

auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_shared_block, a_e_k_desc.GetElementSpaceSize());

// register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
@@ -250,7 +250,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
BGlobalMoveSliceWindowStepHacks{};

// double register buffer for b
StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB,
b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
@@ -20,7 +20,7 @@ template <typename GridwiseGemm,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
ActivTypeEnum_t ActivType>
ActivTypeEnum ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -50,7 +50,7 @@ __global__ void
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{},
integral_constant<ActivTypeEnum_t, ActivType>{});
integral_constant<ActivTypeEnum, ActivType>{});
}

template <typename GridwiseGemm,
@@ -62,7 +62,7 @@ template <typename GridwiseGemm,
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
ActivTypeEnum_t ActivType>
ActivTypeEnum ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -94,7 +94,7 @@ __global__ void
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{},
integral_constant<ActivTypeEnum_t, ActivType>{});
integral_constant<ActivTypeEnum, ActivType>{});
}

template <typename GridwiseGemm,
@@ -106,7 +106,7 @@ template <typename GridwiseGemm,
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
ActivTypeEnum_t ActivType>
ActivTypeEnum ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -140,14 +140,14 @@ __global__ void
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{},
integral_constant<ActivTypeEnum_t, ActivType>{});
integral_constant<ActivTypeEnum, ActivType>{});
}

template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_E0_E1_K_E2,
typename BGridDesc_E0_E1_N_Ho_Wo_E2,
typename CGridDesc_K_N_Ho_Wo,
@@ -559,7 +559,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
constexpr auto bias_k0_k1_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<KPerThread>{}));

StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatC,
bias_k0_k1_thread_desc.GetElementSpaceSize(),
true>
@@ -602,10 +602,10 @@ struct GridwiseGemmDlops_km_kn_mn_v3
});
}

template <typename CThreadBuff, typename CThreadDesc_K1_N_H2_W2, ActivTypeEnum_t activ_type_>
template <typename CThreadBuff, typename CThreadDesc_K1_N_H2_W2, ActivTypeEnum activ_type_>
__device__ static void Activation(CThreadBuff& c_thread_buf,
const CThreadDesc_K1_N_H2_W2&,
integral_constant<ActivTypeEnum_t, activ_type_>)
integral_constant<ActivTypeEnum, activ_type_>)
{
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};

@@ -737,7 +737,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
I1,
Number<WoPerThread_2>{}));

StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatC,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.GetElementSpaceSize(),
true>
@@ -783,7 +783,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
make_multi_index(k_block_work_id,
@@ -843,7 +843,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
I1,
Number<WoPerThreadx2>{}));

StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatC,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.GetElementSpaceSize(),
true>
@@ -874,7 +874,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
InMemoryDataOperationEnum_t::Add,
InMemoryDataOperationEnum::Add,
1,
true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
make_multi_index(k_block_work_id,
@@ -964,7 +964,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<I1, E1, I1, KPerBlock, E2>,
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
@@ -1023,11 +1023,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
0,
0));

auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_shared_block, a_e0_e1_k0_k1_e2_block_copy_desc.GetElementSpaceSize());

//// register allocation for output
// StaticBuffer<AddressSpaceEnum_t::Vgpr,
// StaticBuffer<AddressSpaceEnum::Vgpr,
// FloatAcc,
// c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
// true>
@@ -1050,7 +1050,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks = BGlobalStepHacks{};

// double register buffer for b
StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc.GetElementSpaceSize(),
true>
@@ -1294,21 +1294,21 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);

const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());

constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();

// register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true>
@@ -1344,7 +1344,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
ActivTypeEnum_t ActivType>
ActivTypeEnum ActivType>
__device__ static void ConvBiasActiv(
const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
@@ -1356,26 +1356,26 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>,
integral_constant<ActivTypeEnum_t, ActivType>)
integral_constant<ActivTypeEnum, ActivType>)
{
static constexpr auto activ_type = integral_constant<ActivTypeEnum_t, ActivType>{};
static constexpr auto activ_type = integral_constant<ActivTypeEnum, ActivType>{};

const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);

const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());

constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();

// register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true>
@@ -1423,7 +1423,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
ActivTypeEnum_t ActivType>
ActivTypeEnum ActivType>
__device__ static void ConvBiasActivMaxpool(
const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
@@ -1437,28 +1437,28 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>,
integral_constant<ActivTypeEnum_t, ActivType>)
integral_constant<ActivTypeEnum, ActivType>)
{
static constexpr auto activ_type = integral_constant<ActivTypeEnum_t, ActivType>{};
static constexpr auto activ_type = integral_constant<ActivTypeEnum, ActivType>{};

const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);

const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());

constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();

// register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true>
@@ -1514,7 +1514,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop,
ActivTypeEnum_t ActivType>
ActivTypeEnum ActivType>
__device__ static void ConvBiasActivResizeAdd(
const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
@@ -1527,26 +1527,26 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>,
integral_constant<ActivTypeEnum_t, ActivType>)
integral_constant<ActivTypeEnum, ActivType>)
{
static constexpr auto activ_type = integral_constant<ActivTypeEnum_t, ActivType>{};
static constexpr auto activ_type = integral_constant<ActivTypeEnum, ActivType>{};

const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);

const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());

constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();

// register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr,
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true>
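Activation and the ConvBiasActiv* entry points above select the activation at compile time: the kernel passes integral_constant<ActivTypeEnum, ActivType>{}, so each variant resolves to a dedicated instantiation rather than a runtime branch. A free-standing sketch of that tag-dispatch pattern (the ActivKind enumerators are assumptions, not the library's list):

#include <type_traits>

enum struct ActivKind { None, ReLU, LeakyReLU }; // illustrative enumerators

// Compile-time dispatch on an enum via integral_constant, mirroring how
// Activation() receives integral_constant<ActivTypeEnum, ActivType>.
template <ActivKind kind>
float activate(float x, std::integral_constant<ActivKind, kind>)
{
    if constexpr(kind == ActivKind::ReLU)
        return x > 0.0f ? x : 0.0f;
    else if constexpr(kind == ActivKind::LeakyReLU)
        return x > 0.0f ? x : 0.01f * x;
    else
        return x; // ActivKind::None
}

// Usage: activate(v, std::integral_constant<ActivKind, ActivKind::ReLU>{})
// resolves the branch at compile time; no runtime switch remains.
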
@@ -79,8 +79,8 @@ template <typename FloatAB,
|
||||
typename CElementwiseOperation,
|
||||
typename D0ReduceOperation,
|
||||
typename D1ReduceOperation,
|
||||
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
||||
InMemoryDataOperationEnum_t DGlobalMemoryDataOperation,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
InMemoryDataOperationEnum DGlobalMemoryDataOperation,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename CGridDesc_M_N,

@@ -363,15 +363,15 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
auto d1_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto d1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d1_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());

// divide block work by [M, N]
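
make_dynamic_buffer pairs a raw pointer with an element count and bakes the address space into the buffer's type, which is how the same transfer code can lower to global, LDS, or register accesses. A simplified sketch of the factory's shape (the real CK implementation carries more template machinery than shown here):

template <AddressSpaceEnum BufferAddressSpace, typename T, typename Size>
__device__ constexpr auto make_dynamic_buffer(T* p, Size element_space_size)
{
    // the address space travels in the type, so Get/Set can pick the
    // right load/store instructions at compile time
    return DynamicBuffer<BufferAddressSpace, T, Size>{p, element_space_size};
}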

@@ -399,7 +399,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
@@ -430,7 +430,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
@@ -484,10 +484,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
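
Both block tiles share the single dynamic LDS allocation `p_shared`: the B tile simply starts `a_block_space_size_aligned` elements past the A tile, with `integer_least_multiple` rounding A's footprint up to `max_lds_align`. A standalone illustration of the same carving, with hypothetical sizes:

// a 98-element A tile rounded up to a multiple of 4 -> B starts at offset 100
constexpr index_t a_size_aligned = math::integer_least_multiple(98, 4); // == 100
float* p_a_lds = static_cast<float*>(p_shared);
float* p_b_lds = static_cast<float*>(p_shared) + a_size_aligned;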

@@ -563,7 +563,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

@@ -632,7 +632,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
@@ -723,13 +723,13 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));

// TODO: this should be implemented as a blockwise reduction
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatCShuffle>(
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());

auto d0_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatCShuffle>(
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());

auto d1_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatCShuffle>(
auto d1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());

// reduce: threadwise copy from LDS to VGPR
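
After this copy, each thread holds its slice of the shuffled C tile in `c_reduce_thread_buf` and folds it into the per-row `d0_thread_buf`/`d1_thread_buf` registers; the TODO above notes this is per-thread work rather than a cooperative blockwise reduction. A rough sketch of that fold under stated assumptions (`nreduce_per_thread` is a hypothetical bound, and `reduce0` stands in for an instance of D0ReduceOperation):

static_for<0, mreduce_per_thread, 1>{}([&](auto m) {
    static_for<0, nreduce_per_thread, 1>{}([&](auto n) {
        constexpr auto offset = c_reduce_thread_desc_mperblock_nperblock
                                    .CalculateOffset(make_tuple(m, n));
        // accumulate one C value into the running per-m reduction
        reduce0(d0_thread_buf(m), c_reduce_thread_buf[Number<offset>{}]);
    });
});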

@@ -60,7 +60,7 @@ template <typename FloatAB,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
@@ -316,11 +316,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

// divide block work by [M, N]
@@ -348,7 +348,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
@@ -379,7 +379,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
@@ -433,10 +433,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());

@@ -512,7 +512,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

@@ -581,7 +581,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,

@@ -132,7 +132,7 @@ template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N,
@@ -426,11 +426,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());

const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
@@ -460,7 +460,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
@@ -491,7 +491,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
@@ -543,10 +543,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto a_block_space_size_aligned =
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());

auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_k0_n_k1.GetElementSpaceSize());

@@ -59,7 +59,7 @@ template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename ABK0MK1GridDesc,
typename BBK0NK1GridDesc,
typename CMNGridDesc,
@@ -316,11 +316,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
const CElementwiseOperation& c_element_op,
const CBlockClusterAdaptor& c_block_cluster_adaptor)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());

const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
@@ -410,7 +410,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<1, K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
@@ -440,7 +440,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<1, K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
@@ -497,9 +497,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);

auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());

// preload data into LDS

@@ -61,7 +61,7 @@ template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_B_K0_M_K1,
typename BGridDesc_B_K0_N_K1,
typename CMNGridDesc,
@@ -305,11 +305,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
const CElementwiseOperation& c_element_op,
const CBlockClusterAdaptor& c_block_cluster_adaptor)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
@@ -399,7 +399,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<1, K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
@@ -429,7 +429,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<1, K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
@@ -486,9 +486,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);

auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());

// preload data into LDS
@@ -560,7 +560,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatC*>(p_shared_block),
c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

@@ -632,7 +632,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,

@@ -64,7 +64,7 @@ template <
typename FloatAcc,
typename FloatCShuffle,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
@@ -369,11 +369,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid,
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetElementSpaceSize());
@@ -403,7 +403,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
@@ -434,7 +434,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
@@ -488,10 +488,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());

@@ -567,7 +567,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl();

auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetElementSpaceSize());
@@ -644,7 +644,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,

@@ -68,7 +68,7 @@ template <
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N,
@@ -382,15 +382,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid,
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetElementSpaceSize());
auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c0_grid,
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetElementSpaceSize());
@@ -422,7 +422,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
@@ -453,7 +453,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
@@ -505,10 +505,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
constexpr auto a_block_space_size_aligned =
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());

auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_k0_n_k1.GetElementSpaceSize());

@@ -582,7 +582,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl();

auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatC*>(p_shared),
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetElementSpaceSize());
@@ -661,7 +661,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,

@@ -74,7 +74,7 @@ template <
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N,
@@ -397,19 +397,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid,
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetElementSpaceSize());
auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c0_grid,
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetElementSpaceSize());
auto c1_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto c1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c1_grid,
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetElementSpaceSize());
@@ -441,7 +441,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
@@ -471,7 +471,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
@@ -522,10 +522,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
constexpr auto a_block_space_size_aligned =
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());

auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_k0_n_k1.GetElementSpaceSize());

@@ -599,7 +599,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl();

auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatC*>(p_shared),
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetElementSpaceSize());
@@ -678,7 +678,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,

@@ -45,13 +45,13 @@ __global__ void kernel_buffer_set_value(const Grid1dBufferDescType grid_1d_buffe

const index_t thread_global_id = block_global_id * BlockSize + thread_local_id;

StaticBuffer<AddressSpaceEnum_t::Vgpr, DataType, 1, true> value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, DataType, 1, true> value_buf;

value_buf(I0) = value;

constexpr auto val_buff_desc = make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));

auto global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
auto global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_global, grid_1d_buffer_desc.GetElementSpaceSize());

if(thread_global_id < grid_1d_buffer_desc.GetElementSize())
@@ -65,7 +65,7 @@ __global__ void kernel_buffer_set_value(const Grid1dBufferDescType grid_1d_buffe
Sequence<0>,
0,
1,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
1,
true>(
grid_1d_buffer_desc, make_multi_index(thread_global_id), PassThroughOp{});

@@ -56,7 +56,7 @@ template <typename SrcData,
typename DimAccessOrder,
index_t DstVectorDim,
index_t DstScalarPerVector,
InMemoryDataOperationEnum_t DstInMemOp,
InMemoryDataOperationEnum DstInMemOp,
index_t DstScalarStrideInVector,
bool DstResetCoordinateAfterRun,
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
@@ -407,7 +407,7 @@ struct ThreadwiseTensorSliceTransfer_v2
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
// 4. Use thread buffer
template <typename SliceLengths,
InMemoryDataOperationEnum_t DstInMemOp,
InMemoryDataOperationEnum DstInMemOp,
typename SrcData,
typename DstData,
typename SrcDesc,
@@ -464,8 +464,8 @@ struct ThreadwiseTensorSliceTransfer_v3
__device__ void
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
{
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"wrong!");

static_assert(
@@ -621,8 +621,8 @@ struct ThreadwiseTensorSliceTransfer_v3
__device__ void
RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
{
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"wrong!");

static_assert(
@@ -979,7 +979,7 @@ struct ThreadwiseTensorSliceTransfer_v3

static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();

StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_, true> buffer_;
StaticBuffer<AddressSpaceEnum::Vgpr, SrcData, buffer_size_, true> buffer_;

SrcCoord src_coord_;
DstCoord dst_coord_;

@@ -1,523 +0,0 @@
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V1R4_HPP
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V1R4_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

// Do the following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// and sometimes useless instructions:
// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument
// instead
// 2. Don't construct a new tensor coordinate every time when using it, update and reuse the same
// tensor coordinate instead
// 3. Don't use a pointer to VGPR buffer, use vector instead
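//
// Illustration of point 1 above (a hedged sketch, not code from this repo):
// storing the descriptor by reference forces the compiler to keep it in
// memory, which is what produces "alloca" scratch traffic, while passing it
// as a (usually empty, compile-time) argument keeps it a free constant.
//
//   struct Bad
//   {
//       const DstDesc& dst_desc_;                  // reference member -> alloca risk
//       __device__ void Run() { use(dst_desc_); }  // reads through memory
//   };
//
//   struct Good
//   {
//       template <typename DstDesc>
//       __device__ void Run(const DstDesc& dst_desc) { use(dst_desc); } // stays compile-time
//   };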

// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
// TODO: fix this
// Assume:
// 1. src:
// 1. SrcDesc is known at compile-time
// 2. SrcBuffer is StaticBuffer
// 3. SrcSliceOriginIdx is known at compile-time
// 2. dst:
// 1. DstDesc is not known at compile-time
// 2. DstBuffer is DynamicBuffer
// 3. DstSliceOriginIdx is not known at compile-time
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename Dst0Desc, // this is really one of sources, but it has same shape as DstDesc
typename Dst1Desc, // this is really one of sources, but it has same shape as DstDesc
typename DstElementwiseOperation,
typename SliceLengths,
typename DimAccessOrder,
index_t DstVectorDim,
index_t DstScalarPerVector,
InMemoryDataOperationEnum_t DstInMemOp,
index_t DstScalarStrideInVector,
bool DstResetCoordinateAfterRun,
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v1r4
{
static constexpr index_t nDim = SliceLengths::Size();

using Index = MultiIndex<nDim>;

using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using Dst0Coord = decltype(make_tensor_coordinate(Dst0Desc{}, Index{}));
using Dst1Coord = decltype(make_tensor_coordinate(Dst1Desc{}, Index{}));

using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
using Dst0CoordStep = decltype(make_tensor_coordinate_step(Dst0Desc{}, Index{}));
using Dst1CoordStep = decltype(make_tensor_coordinate_step(Dst1Desc{}, Index{}));

__device__ constexpr ThreadwiseTensorSliceTransfer_v1r4(
const DstDesc& dst_desc,
const Dst0Desc& dst0_desc,
const Dst1Desc& dst1_desc,
const Index& dst_slice_origin_idx,
const DstElementwiseOperation& dst_element_op)
: dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
dst0_coord_(make_tensor_coordinate(dst0_desc, dst_slice_origin_idx)),
dst1_coord_(make_tensor_coordinate(dst1_desc, dst_slice_origin_idx)),
dst_element_op_{dst_element_op}
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
}

__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
}

template <typename SrcSliceOriginIdx,
typename SrcBuffer,
typename DstBuffer,
typename Dst0Buffer,
typename Dst1Buffer,
typename DstStepHacks,
typename Dst0StepHacks,
typename Dst1StepHacks>
__device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
const DstStepHacks& dst_step_hacks,
const Dst0Desc& dst0_desc,
const Dst0Buffer& dst0_buf,
const Dst0StepHacks& dst0_step_hacks,
const Dst1Desc& dst1_desc,
const Dst1Buffer& dst1_buf,
const Dst1StepHacks& dst1_step_hacks)
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");

static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
"wrong! SrcSliceOrigin need to known at compile-time");

static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");

// SrcDesc and src_slice_origin_idx are known at compile-time
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});

constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};

// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access;

constexpr auto dim_access_order = DimAccessOrder{};

constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);

// make forward steps: dst
const auto dst_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;

static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});

return make_tensor_coordinate_step(
dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
},
Number<nDim>{});

// make forward steps: dst0
// WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
// DstScalarPerVector
// TODO: fix this
const auto dst0_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;

static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});

return make_tensor_coordinate_step(
dst0_desc, forward_step_idx, dst0_step_hacks[I0][i]);
},
Number<nDim>{});

// make forward steps: dst1
// WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
// DstScalarPerVector
// TODO: fix this
const auto dst1_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;

static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});

return make_tensor_coordinate_step(
dst1_desc, forward_step_idx, dst1_step_hacks[I0][i]);
},
Number<nDim>{});

// make backward steps: dst
const auto dst_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;

static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});

return make_tensor_coordinate_step(
dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
},
Number<nDim>{});

// make backward steps: dst0
// WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
// DstScalarPerVector
// TODO: fix this
const auto dst0_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;

static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});

return make_tensor_coordinate_step(
dst0_desc, backward_step_idx, dst0_step_hacks[I1][i]);
},
Number<nDim>{});

// make backward steps: dst1
// WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
// DstScalarPerVector
// TODO: fix this
const auto dst1_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;

static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});

return make_tensor_coordinate_step(
dst1_desc, backward_step_idx, dst1_step_hacks[I1][i]);
},
Number<nDim>{});

// loop over tensor and copy
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;

forward_sweep_(I0) = true;

static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_idx[I0];

static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
});

forward_sweep_(i) = tmp % 2 == 0;
});

return forward_sweep_;
}();
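
// Worked example of the zig-zag ("snake") sweep this parity check encodes,
// for ordered_access_lengths = [2, 3] (an illustrative shape): dim 0 always
// sweeps forward; dim 1 runs forward while the dim-0 prefix index is even
// and backward while it is odd, so consecutive accesses stay adjacent:
//   (0,0) (0,1) (0,2) (1,2) (1,1) (1,0)
// Each coordinate update is then a single forward/backward step, never a wrap.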

// calculate dst data index
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;

static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i]
? ordered_access_idx[i]
: ordered_access_lengths[i] - 1 - ordered_access_idx[i];
});

return container_reorder_given_old2new(ordered_idx, dim_access_order) *
dst_scalar_per_access;
}();

typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;

using dst_vector_t =
typename vector_type_maker<DstData, DstScalarPerVector>::type::type;

// load dst0 and dst1 and apply elementwise operation
{
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
// TODO: fix this
static_assert(DstScalarPerVector == 1, "wrong!");

// copy data from src_buf into dst_vector_src_data
constexpr index_t src_offset =
src_desc.CalculateOffset(src_slice_origin_idx + dst_data_idx);

const SrcData src_v = src_buf[Number<src_offset>{}];

// load dst0 and dst1
const bool is_dst0_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst0_desc,
dst0_coord_);
const bool is_dst1_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst1_desc,
dst1_coord_);

const DstData dst0_v =
dst0_buf.template Get<DstData>(dst0_coord_.GetOffset(), is_dst0_valid);
const DstData dst1_v =
dst1_buf.template Get<DstData>(dst1_coord_.GetOffset(), is_dst1_valid);

#if !CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R4_TYPE_CONVERT_ISSUE
// apply element-wise operation in SrcData type
const SrcData dst_v = dst_element_op_(
src_v, type_convert<SrcData>(dst0_v), type_convert<SrcData>(dst1_v));

// apply type convert
dst_vector.template AsType<DstData>()(Number<0>{}) = type_convert<DstData>(dst_v);
#else
// apply element-wise operation in DstData type
DstData dst_v;

dst_element_op_(dst_v, src_v, dst0_v, dst1_v);

dst_vector.template AsType<DstData>()(Number<0>{}) = dst_v;
#endif
}

const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

// copy data from dst_vector into dst_buf
if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set)
{
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
}
else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd)
{
dst_buf.template AtomicAdd<dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
}
else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Add)
{
typename vector_type_maker<DstData, DstScalarPerVector>::type tmp;
tmp.template AsType<dst_vector_t>()(Number<0>{}) =
dst_buf.template Get<dst_vector_t>(dst_coord_.GetOffset(), is_dst_valid);

static_for<0, DstScalarPerVector, 1>{}([&](auto t) {
dst_vector.template AsType<DstData>()(t) += tmp.template AsType<DstData>()[t];
});

dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
}
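
// Summary of the three write modes above: Set overwrites the destination
// element, AtomicAdd accumulates with a hardware atomic (safe when several
// workitems target the same offset), and Add performs a non-atomic
// read-add-write through registers. For one element with dst == 2 and an
// incoming value of 3 (illustrative numbers): Set stores 3; AtomicAdd and
// Add both store 5, differing only in atomicity.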

constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;

static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;

static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
});
});

return move_on_dim_;
}
();

// move
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]);

// dst0
move_tensor_coordinate(
dst0_desc, dst0_coord_, dst0_forward_steps[dim_access_order[i]]);

// dst1
move_tensor_coordinate(
dst1_desc, dst1_coord_, dst1_forward_steps[dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]);

// dst0
move_tensor_coordinate(
dst0_desc, dst0_coord_, dst0_backward_steps[dim_access_order[i]]);

// dst1
move_tensor_coordinate(
dst1_desc, dst1_coord_, dst1_backward_steps[dim_access_order[i]]);
}
}
});
});

// move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_reset_step =
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());

move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
}
}

template <typename SrcSliceOriginIdx,
typename SrcBuffer,
typename DstBuffer,
typename Dst0Buffer,
typename Dst1Buffer>
__device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
const Dst0Desc& dst0_desc,
const Dst0Buffer& dst0_buf,
const Dst1Desc& dst1_desc,
const Dst1Buffer& dst1_buf)
{
auto f_step_hacks = [&](auto desc) {
constexpr index_t ntransform = decltype(desc)::GetNumOfTransform();

constexpr auto zeros = typename uniform_sequence_gen<ntransform, 0>::type{};

constexpr auto step_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

return step_hacks;
};

Run(SrcDesc{},
SrcSliceOriginIdx{},
src_buf,
dst_desc,
dst_buf,
f_step_hacks(dst_desc),
dst0_desc,
dst0_buf,
f_step_hacks(dst0_desc),
dst1_desc,
dst1_buf,
f_step_hacks(dst1_desc));
}

__device__ static constexpr auto GetDstCoordinateResetStep()
{
constexpr auto I0 = Number<0>{};

// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access;

constexpr auto dim_access_order = DimAccessOrder{};

constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);

// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;

forward_sweep_(I0) = true;

static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_lengths[I0] - 1;

static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
});

forward_sweep_(i) = tmp % 2 == 0;
});

return forward_sweep_;
}();

// calculate dst data index after last iteration in Run(), if it has not been reset by
// RunWrite()
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;

static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
});

return container_reorder_given_old2new(ordered_idx, dim_access_order) *
dst_scalar_per_access;
}();

//
constexpr auto reset_dst_data_step = [&]() {
Index reset_dst_data_step_;

static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });

return reset_dst_data_step_;
}();

return reset_dst_data_step;
}

// dst_slice_origin_step_idx needs to be known at compile-time, for performance reasons
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by Run(), then need to adjust the step here
const auto adjusted_step_idx =
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();

// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);

move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
}

private:
DstCoord dst_coord_;
Dst0Coord dst0_coord_;
Dst1Coord dst1_coord_;
const DstElementwiseOperation dst_element_op_;
}; // struct ThreadwiseTensorSliceTransfer_v1r4

} // namespace ck
#endif
|
||||
@@ -1,453 +0,0 @@
|
||||
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V1R5_HPP
|
||||
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V1R5_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
|
||||
// and sometimes useless instructions:
|
||||
// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument
|
||||
// instead
|
||||
// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same
|
||||
// tensor coordinate instead
|
||||
// 3. Don't use a pointer to VGPR buffer, use vector instead
|
||||
|
||||
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
|
||||
// TODO: fix this
|
||||
// Assume:
|
||||
// 1. src:
|
||||
// 1. SrcDesc is known at compile-time
|
||||
// 2. SrcBuffer is StaticBuffer
|
||||
// 3. SrcSliceOrginIdx is known at compile-time
|
||||
// 2. dst:
|
||||
// 1. DstDesc is not known at compile-time
|
||||
// 2. DstBuffer is DynamicBuffer
|
||||
// 3. DstSliceOrginIdx is not known at compile time
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename Dst0Desc, // this is really one of sources, but it has same shape as DstDesc
|
||||
typename DstElementwiseOperation,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
index_t DstVectorDim,
|
||||
index_t DstScalarPerVector,
|
||||
InMemoryDataOperationEnum_t DstInMemOp,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool DstResetCoordinateAfterRun,
|
||||
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
struct ThreadwiseTensorSliceTransfer_v1r5
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||
using Dst0Coord = decltype(make_tensor_coordinate(Dst0Desc{}, Index{}));
|
||||
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
using Dst0CoordStep = decltype(make_tensor_coordinate_step(Dst0Desc{}, Index{}));
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r5(
|
||||
const DstDesc& dst_desc,
|
||||
const Dst0Desc& dst0_desc,
|
||||
const Index& dst_slice_origin_idx,
|
||||
const DstElementwiseOperation& dst_element_op)
|
||||
: dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
|
||||
dst0_coord_(make_tensor_coordinate(dst0_desc, dst_slice_origin_idx)),
|
||||
dst_element_op_{dst_element_op}
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
||||
{
|
||||
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||
}
|
||||
|
||||
template <typename SrcSliceOriginIdx,
|
||||
typename SrcBuffer,
|
||||
typename DstBuffer,
|
||||
typename Dst0Buffer,
|
||||
typename DstStepHacks,
|
||||
typename Dst0StepHacks>
|
||||
__device__ void Run(const SrcDesc&,
|
||||
const SrcSliceOriginIdx&,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf,
|
||||
const DstStepHacks& dst_step_hacks,
|
||||
const Dst0Desc& dst0_desc,
|
||||
const Dst0Buffer& dst0_buf,
|
||||
const Dst0StepHacks& dst0_step_hacks)
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
|
||||
"wrong! SrcSliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
|
||||
|
||||
// SrcDesc and src_slice_origin_idx are known at compile-time
|
||||
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
constexpr auto dim_access_order = DimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_access_lengths =
|
||||
container_reorder_given_new2old(access_lengths, dim_access_order);
|
||||
|
||||
        // make forward steps: dst
        const auto dst_forward_steps = generate_tuple(
            [&](auto i) {
                Index forward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
                    forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
                });

                return make_tensor_coordinate_step(
                    dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
            },
            Number<nDim>{});

        // make forward steps: dst0
        // WARNING: this logic is only correct if DstScalarPerVector == 1
        // TODO: fix this
        const auto dst0_forward_steps = generate_tuple(
            [&](auto i) {
                Index forward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
                    forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
                });

                return make_tensor_coordinate_step(
                    dst0_desc, forward_step_idx, dst0_step_hacks[I0][i]);
            },
            Number<nDim>{});

        // make backward steps: dst
        const auto dst_backward_steps = generate_tuple(
            [&](auto i) {
                Index backward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
                    backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
                });

                return make_tensor_coordinate_step(
                    dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
            },
            Number<nDim>{});

        // make backward steps: dst0
        // WARNING: this logic is only correct if DstScalarPerVector == 1
        // TODO: fix this
        const auto dst0_backward_steps = generate_tuple(
            [&](auto i) {
                Index backward_step_idx;

                static_for<0, nDim, 1>{}([&](auto j) {
                    backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
                });

                return make_tensor_coordinate_step(
                    dst0_desc, backward_step_idx, dst0_step_hacks[I1][i]);
            },
            Number<nDim>{});
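
        // Example: for nDim == 2 with dst_scalar_per_access = [1, 1], the forward
        // steps correspond to the index deltas (1, 0) and (0, 1) and the backward
        // steps to (-1, 0) and (0, -1); precomputing them as coordinate steps keeps
        // the per-iteration coordinate updates cheap.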

        // loop over the tensor and copy
        static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
            // decide whether to move forward or backward on each dimension
            constexpr auto forward_sweep = [&]() {
                StaticallyIndexedArray<bool, nDim> forward_sweep_;

                forward_sweep_(I0) = true;

                static_for<1, nDim, 1>{}([&](auto i) {
                    index_t tmp = ordered_access_idx[I0];

                    static_for<1, i, 1>{}([&](auto j) {
                        tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
                    });

                    forward_sweep_(i) = tmp % 2 == 0;
                });

                return forward_sweep_;
            }();
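
            // In other words: a dimension sweeps forward when the linearized index of
            // all outer dimensions is even and backward when it is odd, which yields a
            // snake-like traversal (see the standalone sketch after this struct).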

            // calculate dst data index
            constexpr auto dst_data_idx = [&]() {
                Index ordered_idx;

                static_for<0, nDim, 1>{}([&](auto i) {
                    ordered_idx(i) = forward_sweep[i]
                                         ? ordered_access_idx[i]
                                         : ordered_access_lengths[i] - 1 - ordered_access_idx[i];
                });

                return container_reorder_given_old2new(ordered_idx, dim_access_order) *
                       dst_scalar_per_access;
            }();

            typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;

            using dst_vector_t =
                typename vector_type_maker<DstData, DstScalarPerVector>::type::type;

            // load dst0 and apply the elementwise operation
            {
                // WARNING: this logic is only correct if DstScalarPerVector == 1
                // TODO: fix this
                static_assert(DstScalarPerVector == 1, "wrong!");

                // copy data from src_buf into dst_vector
                constexpr index_t src_offset =
                    src_desc.CalculateOffset(src_slice_origin_idx + dst_data_idx);

                const SrcData src_v = src_buf[Number<src_offset>{}];

                // load dst0
                const bool is_dst0_valid =
                    coordinate_has_valid_offset_assuming_visible_index_is_valid(dst0_desc,
                                                                                dst0_coord_);
                const DstData dst0_v =
                    dst0_buf.template Get<DstData>(dst0_coord_.GetOffset(), is_dst0_valid);

#if !CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R5_TYPE_CONVERT_ISSUE
                // apply the element-wise operation in SrcData type
                const SrcData dst_v = dst_element_op_(src_v, type_convert<SrcData>(dst0_v));

                // apply type conversion
                dst_vector.template AsType<DstData>()(Number<0>{}) = type_convert<DstData>(dst_v);
#else
                // apply the element-wise operation in DstData type
                const DstData dst_v = dst_element_op_(src_v, dst0_v);

                dst_vector.template AsType<DstData>()(Number<0>{}) = dst_v;
#endif
            }
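
            // Example (illustrative operation, not from this commit): with
            // dst_element_op_ = [](auto src, auto dst0) { return src + dst0; },
            // the block above performs a fused output update dst = src + dst0 for
            // the single scalar handled per access.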

            const bool is_dst_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

            // copy data from dst_vector into dst_buf
            if constexpr(DstInMemOp == InMemoryDataOperationEnum::Set)
            {
                dst_buf.template Set<dst_vector_t>(
                    dst_coord_.GetOffset(),
                    is_dst_valid,
                    dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
            }
            else if constexpr(DstInMemOp == InMemoryDataOperationEnum::AtomicAdd)
            {
                dst_buf.template AtomicAdd<dst_vector_t>(
                    dst_coord_.GetOffset(),
                    is_dst_valid,
                    dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
            }
            else if constexpr(DstInMemOp == InMemoryDataOperationEnum::Add)
            {

                typename vector_type_maker<DstData, DstScalarPerVector>::type tmp;
                tmp.template AsType<dst_vector_t>()(Number<0>{}) =
                    dst_buf.template Get<dst_vector_t>(dst_coord_.GetOffset(), is_dst_valid);

                static_for<0, DstScalarPerVector, 1>{}([&](auto t) {
                    dst_vector.template AsType<DstData>()(t) += tmp.template AsType<DstData>()[t];
                });

                dst_buf.template Set<dst_vector_t>(
                    dst_coord_.GetOffset(),
                    is_dst_valid,
                    dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
            }
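
            // Note: the Add branch is a plain read-modify-write (load the current dst
            // value, accumulate, store back), whereas AtomicAdd delegates the
            // accumulation to the memory system and Set simply overwrites.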

            constexpr auto move_on_dim = [&]() constexpr
            {
                StaticallyIndexedArray<bool, nDim> move_on_dim_;

                static_for<0, nDim, 1>{}([&](auto i) {
                    move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;

                    static_for<i + 1, nDim, 1>{}([&](auto j) {
                        move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
                    });
                });

                return move_on_dim_;
            }
            ();
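
            // Example: with ordered_access_lengths = [2, 3] at ordered_access_idx = (0, 2),
            // dimension 1 is exhausted while dimension 0 still has room, so move_on_dim is
            // (true, false): the next step advances dimension 0 once, like an odometer that
            // only carries when all faster dimensions have rolled over.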

            // move
            static_for<0, nDim, 1>{}([&](auto i) {
                if constexpr(move_on_dim[i])
                {
                    if constexpr(forward_sweep[i])
                    {
                        move_tensor_coordinate(
                            dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]);

                        // dst0
                        move_tensor_coordinate(
                            dst0_desc, dst0_coord_, dst0_forward_steps[dim_access_order[i]]);
                    }
                    else
                    {
                        move_tensor_coordinate(
                            dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]);

                        // dst0
                        move_tensor_coordinate(
                            dst0_desc, dst0_coord_, dst0_backward_steps[dim_access_order[i]]);
                    }
                }
            });
        });

        // move dst coordinate back to slice origin (or not)
        if constexpr(DstResetCoordinateAfterRun)
        {
            const auto dst_reset_step =
                make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());

            move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
        }
    }
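
    // Note: after the full sweep dst_coord_ is left at the last access position;
    // the optional reset above returns it to the slice origin so that a subsequent
    // Run() or MoveDstSliceWindow() starts from a known coordinate.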

    template <typename SrcSliceOriginIdx,
              typename SrcBuffer,
              typename DstBuffer,
              typename Dst0Buffer>
    __device__ void Run(const SrcDesc&,
                        const SrcSliceOriginIdx&,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf,
                        const Dst0Desc& dst0_desc,
                        const Dst0Buffer& dst0_buf)
    {
        auto f_step_hacks = [&](auto desc) {
            constexpr index_t ntransform = decltype(desc)::GetNumOfTransform();

            constexpr auto zeros = typename uniform_sequence_gen<ntransform, 0>::type{};

            constexpr auto step_hacks =
                make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                           generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

            return step_hacks;
        };

        Run(SrcDesc{},
            SrcSliceOriginIdx{},
            src_buf,
            dst_desc,
            dst_buf,
            f_step_hacks(dst_desc),
            dst0_desc,
            dst0_buf,
            f_step_hacks(dst0_desc));
    }
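
    // In other words, this overload builds "no-op" step hacks, a (forward, backward)
    // pair of nDim tuples whose entries are all-zero sequences with one zero per
    // transform of the descriptor, and forwards to the hacked Run() above.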

    __device__ static constexpr auto GetDstCoordinateResetStep()
    {
        constexpr auto I0 = Number<0>{};

        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access;

        constexpr auto dim_access_order = DimAccessOrder{};

        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, dim_access_order);

        // judge whether each dimension moved forward or backward during the last iteration
        constexpr auto forward_sweep = [&]() {
            StaticallyIndexedArray<bool, nDim> forward_sweep_;

            forward_sweep_(I0) = true;

            static_for<1, nDim, 1>{}([&](auto i) {
                index_t tmp = ordered_access_lengths[I0] - 1;

                static_for<1, i, 1>{}([&](auto j) {
                    tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
                });

                forward_sweep_(i) = tmp % 2 == 0;
            });

            return forward_sweep_;
        }();

        // calculate the dst data index after the last iteration in Run(), if it has
        // not been reset by Run()
        constexpr auto dst_data_idx = [&]() {
            Index ordered_idx;

            static_for<0, nDim, 1>{}([&](auto i) {
                ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
            });

            return container_reorder_given_old2new(ordered_idx, dim_access_order) *
                   dst_scalar_per_access;
        }();

        // the reset step is the negative of that index
        constexpr auto reset_dst_data_step = [&]() {
            Index reset_dst_data_step_;

            static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });

            return reset_dst_data_step_;
        }();

        return reset_dst_data_step;
    }
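
    // Example: for SliceLengths = [2, 3], DstScalarPerVector = 1 and identity
    // DimAccessOrder, the snake path ends at index (1, 0), so this returns the step
    // (-1, 0), which moves the coordinate back to the slice origin.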

    // dst_slice_origin_step_idx needs to be known at compile-time, for performance reasons
    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
                                       const Index& dst_slice_origin_step_idx)
    {
        // if the dst coord was not reset by Run(), then the step needs to be adjusted here
        const auto adjusted_step_idx =
            DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
                                       : dst_slice_origin_step_idx + GetDstCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);

        move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
    }

    private:
    DstCoord dst_coord_;
    Dst0Coord dst0_coord_;
    const DstElementwiseOperation dst_element_op_;
}; // struct ThreadwiseTensorSliceTransfer_v1r5

} // namespace ck
#endif
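
The zig-zag sweep implemented by forward_sweep above is easier to see outside the
template machinery. Below is a minimal host-side sketch (not part of the commit;
the sizes are illustrative) that reproduces the visit order for a 2-D access grid:

#include <cstdio>

// Snake-order traversal as used by the threadwise transfer above: dimension 0
// always sweeps forward, and the inner dimension reverses direction each time
// the outer dimension advances, so consecutive visits differ by exactly one
// step in exactly one dimension.
int main()
{
    const int len0 = 2, len1 = 3; // access lengths per dimension

    for(int i0 = 0; i0 < len0; ++i0)
    {
        const bool forward = (i0 % 2 == 0); // forward_sweep parity for dimension 1

        for(int i1 = 0; i1 < len1; ++i1)
        {
            const int j1 = forward ? i1 : len1 - 1 - i1;
            std::printf("(%d, %d)\n", i0, j1);
        }
    }
    // prints: (0, 0) (0, 1) (0, 2) (1, 2) (1, 1) (1, 0)
    return 0;
}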

@@ -48,7 +48,7 @@ struct lambda_scalar_per_access_for_src_and_dst
template <typename SliceLengths,
          typename SrcElementwiseOperation,
          typename DstElementwiseOperation,
          InMemoryDataOperationEnum_t DstInMemOp,
          InMemoryDataOperationEnum DstInMemOp,
          typename SrcData,
          typename DstData,
          typename SrcDesc,
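
Each of the hunks above and below applies the same rename: the _t suffix is dropped
from the scoped enums. A minimal sketch of the pattern and of how such a value
typically drives a compile-time dispatch (illustrative definitions, not the
library's actual ones):

// Scoped enum without the typedef-style "_t" suffix, selected via if constexpr.
enum class InMemoryDataOperationEnum
{
    Set,
    AtomicAdd,
    Add
};

template <InMemoryDataOperationEnum Op>
__device__ void store_scalar(float* dst, float v)
{
    if constexpr(Op == InMemoryDataOperationEnum::Set)
    {
        *dst = v; // plain overwrite
    }
    else if constexpr(Op == InMemoryDataOperationEnum::AtomicAdd)
    {
        atomicAdd(dst, v); // accumulation done by the memory system
    }
    else // InMemoryDataOperationEnum::Add
    {
        *dst += v; // read-modify-write in the thread
    }
}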

@@ -110,8 +110,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                    const SrcBuffer& src_buf,
                    Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
                          SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
        static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
                          SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
                      "wrong!");

        static_assert(

@@ -271,7 +271,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
        static_ford<SliceLengths>{}([&](auto idx) {
            // convert from SrcData to DstData here
            dst_thread_scratch_(idx) =
                type_convert<DstData>(src_thread_scratch_tuple[thread_scratch_id][idx]);
                type_convert<DstData>(src_thread_scratch_tuple_[thread_scratch_id][idx]);
        });
#else
        // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_

@@ -361,8 +361,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
        // TODO move this elsewhere
        TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id);

        static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
                          DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
        static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
                          DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
                      "wrong!");

        static_assert(

@@ -763,13 +763,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
    static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
    static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};

    using SrcThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum_t::Vgpr,
    using SrcThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                                             SrcData,
                                                             SrcScalarPerVector,
                                                             decltype(src_thread_scratch_desc_),
                                                             true>;

    using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum_t::Vgpr,
    using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                                             DstData,
                                                             DstScalarPerVector,
                                                             decltype(dst_thread_scratch_desc_),

@@ -48,7 +48,7 @@ struct lambda_scalar_per_access_for_src_and_dst
template <typename SliceLengths,
          typename SrcElementwiseOperation,
          typename DstElementwiseOperation,
          InMemoryDataOperationEnum_t DstInMemOp,
          InMemoryDataOperationEnum DstInMemOp,
          typename SrcData,
          typename DstData,
          typename SrcDesc,

@@ -120,8 +120,8 @@ struct ThreadwiseTensorSliceTransfer_v3r3
    template <typename SrcBuffer>
    __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
    {
        static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
                          SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
        static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
                          SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
                      "wrong!");

        static_assert(

@@ -369,8 +369,8 @@ struct ThreadwiseTensorSliceTransfer_v3r3
        // TODO move this elsewhere
        TransferDataFromSrcThreadScratchToDstThreadScratch();

        static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
                          DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
        static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
                          DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
                      "wrong!");

        static_assert(

@@ -859,14 +859,14 @@ struct ThreadwiseTensorSliceTransfer_v3r3
    static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
    static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};

    StaticTensorTupleOfVectorBuffer<AddressSpaceEnum_t::Vgpr,
    StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                    SrcData,
                                    SrcScalarPerVector,
                                    decltype(src_thread_scratch_desc_),
                                    true>
        src_thread_scratch_;

    StaticTensorTupleOfVectorBuffer<AddressSpaceEnum_t::Vgpr,
    StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                    DstData,
                                    DstScalarPerVector,
                                    decltype(dst_thread_scratch_desc_),

@@ -13,7 +13,7 @@ namespace ck {
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
// 4. Use thread buffer
template <typename SliceLengths,
          InMemoryDataOperationEnum_t DstInMemOp,
          InMemoryDataOperationEnum DstInMemOp,
          typename SrcData,
          typename DstData,
          typename SrcDesc,

@@ -76,8 +76,8 @@ struct ThreadwiseTensorSliceTransfer_v5r1
    __device__ void
    RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
    {
        static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
                          SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
        static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
                          SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
                      "wrong!");

        static_assert(

@@ -244,8 +244,8 @@ struct ThreadwiseTensorSliceTransfer_v5r1
    __device__ void
    RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
    {
        static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
                          DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
        static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
                          DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
                      "wrong!");

        static_assert(

@@ -602,7 +602,7 @@ struct ThreadwiseTensorSliceTransfer_v5r1

    static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();

    StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_, true> buffer_;
    StaticBuffer<AddressSpaceEnum::Vgpr, SrcData, buffer_size_, true> buffer_;

    SrcCoord src_coord_;
    DstCoord dst_coord_;

@@ -29,7 +29,7 @@ template <typename SrcData,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          InMemoryDataOperationEnum_t DstInMemOp,
          InMemoryDataOperationEnum DstInMemOp,
          bool SrcResetCoordinateAfterRun,
          bool DstResetCoordinateAfterRun>
struct ThreadwiseTensorSliceTransfer_v6r1

@@ -31,7 +31,7 @@ template <typename Src0Data,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          InMemoryDataOperationEnum_t DstInMemOp,
          InMemoryDataOperationEnum DstInMemOp,
          bool Src0ResetCoordinateAfterRun,
          bool Src1ResetCoordinateAfterRun,
          bool DstResetCoordinateAfterRun>

@@ -33,7 +33,7 @@ template <typename Src0Data,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          InMemoryDataOperationEnum_t DstInMemOp,
          InMemoryDataOperationEnum DstInMemOp,
          bool Src0ResetCoordinateAfterRun,
          bool Src1ResetCoordinateAfterRun,
          bool Src2ResetCoordinateAfterRun,

@@ -476,7 +476,7 @@ struct MfmaSelector
    template <>
    static constexpr auto GetMfma<bhalf_t, 32, 32>()
    {
#if defined(CK_AMD_GPU_GFX90A)
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
        return MfmaInstr::mfma_f32_32x32x8bf16_1k;
#else
        return MfmaInstr::mfma_f32_32x32x4bf16;

@@ -486,7 +486,7 @@ struct MfmaSelector
    template <>
    static constexpr auto GetMfma<bhalf_t, 16, 16>()
    {
#if defined(CK_AMD_GPU_GFX90A)
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
        return MfmaInstr::mfma_f32_16x16x16bf16_1k;
#else
        return MfmaInstr::mfma_f32_16x16x8bf16;
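
These guards pick the gfx90a-only "1K" bf16 MFMA variants when
CK_USE_AMD_MFMA_BF16_1K_OP is enabled and fall back to the earlier instructions
otherwise. A condensed sketch of the visible selection logic (illustrative helper,
outside the real MfmaSelector machinery):

// Condensed sketch: the 1K variant doubles K per instruction on gfx90a; without
// the flag (or on other targets) the original instruction is used.
constexpr auto select_bf16_mfma_32x32()
{
#if defined(CK_AMD_GPU_GFX90A) && defined(CK_USE_AMD_MFMA_BF16_1K_OP)
    return MfmaInstr::mfma_f32_32x32x8bf16_1k; // 32x32 tile, K = 8 per instruction
#else
    return MfmaInstr::mfma_f32_32x32x4bf16; // 32x32 tile, K = 4 per instruction
#endif
}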