Reorganize files, Part 1 (#119)

* delete obsolete files

* move files

* build

* update cmake

* update cmake

* fix build

* reorg examples

* update cmake for example and test

[ROCm/composable_kernel commit: 5d37d7bff4]
This commit is contained in:
Chao Liu
2022-03-08 21:46:36 -06:00
committed by GitHub
parent c1ac480787
commit 82ad74304e
422 changed files with 388 additions and 3326 deletions

View File

@@ -0,0 +1,90 @@
#ifndef CONV_COMMON_HPP
#define CONV_COMMON_HPP
#include "tensor_descriptor.hpp"
/// Builds the packed 4-d output descriptor of a 2-d convolution.
///
/// The output lengths are {in[0], wei[0], Ho, Wo} where, per spatial axis,
///   out = (in + left_pad + right_pad - ((filter - 1) * dilation + 1)) / stride + 1.
///
/// @param in_desc        input descriptor; dimensions 2 and 3 are read as the spatial extents
/// @param wei_desc       weight descriptor; dimensions 2 and 3 are read as the filter extents
/// @param conv_strides   per-axis strides, indexed with Number<0>/Number<1>
/// @param conv_dilations per-axis dilations, indexed with Number<0>/Number<1>
/// @param left_pads      per-axis leading padding
/// @param right_pads     per-axis trailing padding
/// @return result of make_naive_tensor_descriptor_packed for the four output lengths
template <typename... InDesc,
          typename... WeiDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename LeftPads,
          typename RightPads>
constexpr auto get_convolution_output_default_4d_tensor_descriptor(
    const ck::TensorDescriptor<InDesc...>& in_desc,
    const ck::TensorDescriptor<WeiDesc...>& wei_desc,
    const ConvStrides& conv_strides,
    const ConvDilations conv_dilations,
    const LeftPads& left_pads,
    const RightPads& right_pads)
{
    using namespace ck;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    // Both operands must be 4-d and agree on dimension 1.
    assert(in_desc.GetNumOfDimension() == 4);
    assert(wei_desc.GetNumOfDimension() == 4);
    assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1));

    // The two leading output lengths are copied straight from the operands.
    const auto batch        = in_desc.GetLength(I0);
    const auto out_channels = wei_desc.GetLength(I0);

    // Extent covered by the filter once dilation is applied.
    const auto filter_h_eff = (wei_desc.GetLength(I2) - I1) * conv_dilations[I0] + I1;
    const auto filter_w_eff = (wei_desc.GetLength(I3) - I1) * conv_dilations[I1] + I1;

    // Input extent including padding on both sides.
    const auto padded_h = in_desc.GetLength(I2) + left_pads[I0] + right_pads[I0];
    const auto padded_w = in_desc.GetLength(I3) + left_pads[I1] + right_pads[I1];

    const auto out_h = (padded_h - filter_h_eff) / conv_strides[I0] + I1;
    const auto out_w = (padded_w - filter_w_eff) / conv_strides[I1] + I1;

    return make_naive_tensor_descriptor_packed(make_tuple(batch, out_channels, out_h, out_w));
}
template <class InDesc, class WeiDesc, class OutDesc>
constexpr std::size_t
calculate_convolution_flops(const InDesc&, const WeiDesc& wei_desc, const OutDesc& out_desc)
{
using namespace ck;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const index_t N = out_desc.GetLength(I0);
const index_t K = out_desc.GetLength(I1);
const index_t Ho = out_desc.GetLength(I2);
const index_t Wo = out_desc.GetLength(I3);
const index_t C = wei_desc.GetLength(I1);
const index_t Y = wei_desc.GetLength(I2);
const index_t X = wei_desc.GetLength(I3);
return std::size_t(2) * N * K * Ho * Wo * C * Y * X;
}
/// Applies the activation selected by `activ_type` to `v`.
///
///  - None:      returns `v` unchanged
///  - LeakyRelu: returns `v` when `v >= 0`, otherwise `0.3 * v`
///  - Sigmoid:   returns `1 / (1 + exp(-v))`
///
/// The return type is spelled out as `T` (it used to be `auto`). Every return statement
/// is converted to `T` explicitly, so the function no longer fails to compile when the
/// arithmetic promotes to a wider type (for example when only `exp(double)` is visible).
///
/// @throws std::runtime_error for any other enumerator
template <typename T>
inline T activ(T v, const ck::ActivTypeEnum_t activ_type)
{
    const T alpha = 0.3; // negative-side slope used by LeakyRelu

    switch(activ_type)
    {
    case ck::ActivTypeEnum_t::None: return v;
    case ck::ActivTypeEnum_t::LeakyRelu: return (v >= 0 ? v : static_cast<T>(alpha * v));
    case ck::ActivTypeEnum_t::Sigmoid: return static_cast<T>(1 / (1 + exp(-v)));
    default: throw std::runtime_error("unsupported activ type");
    }
}
#endif

View File

@@ -0,0 +1,90 @@
#ifndef DEVICE_HPP
#define DEVICE_HPP
#include <memory>
#include <functional>
#include <thread>
#include <chrono>
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
// Handle pairing a raw buffer pointer (mpDeviceBuf) with its size in bytes (mMemSize).
// Only declarations are visible here; the member functions are defined elsewhere.
// NOTE(review): the struct declares a destructor but leaves the implicit copy operations
// enabled. If the destructor releases mpDeviceBuf, copying a DeviceMem would release the
// buffer twice -- confirm against the definition.
struct DeviceMem
{
DeviceMem() = delete;
// Constructs the handle from a size; presumably allocates mem_size bytes -- TODO confirm.
DeviceMem(std::size_t mem_size);
// Returns a pointer to the buffer (presumably mpDeviceBuf -- TODO confirm).
void* GetDeviceBuffer();
// Transfer helpers taking a caller-provided pointer `p`; by their names they copy
// mMemSize bytes host->device and device->host respectively -- verify in the definition.
void ToDevice(const void* p);
void FromDevice(void* p);
~DeviceMem();
void* mpDeviceBuf;    // raw buffer pointer
std::size_t mMemSize; // buffer size
};
// Implementation type of KernelTimer; only forward-declared in this header.
struct KernelTimerImpl;
// Start/End timer used by launch_and_time_kernel() below to bracket a batch of kernel
// launches. The timing mechanism lives in KernelTimerImpl and is not visible here.
struct KernelTimer
{
KernelTimer();
// Declared explicitly because `impl` points to a type that is incomplete in this header.
~KernelTimer();
// Marks the beginning / end of the measured interval.
void Start();
void End();
// Elapsed time between Start() and End(); the unit is not visible here
// (presumably milliseconds -- TODO confirm).
float GetElapsedTime() const;
std::unique_ptr<KernelTimerImpl> impl;
};
using device_stream_t = hipStream_t;
template <typename... Args, typename F>
void launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
{
hipStream_t stream_id = nullptr;
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
}
/// Launches `kernel` once as a warm-up, then `nrepeat` more times between
/// KernelTimer::Start() and KernelTimer::End(), and returns the timer's elapsed value
/// divided by `nrepeat` (i.e. the average per launch).
///
/// All launches use a null stream handle. Progress is reported with printf.
///
/// @param kernel    kernel entry point
/// @param nrepeat   number of timed launches; should be > 0, since the result is divided by it
/// @param grid_dim  grid dimensions
/// @param block_dim block dimensions
/// @param lds_byte  LDS byte count forwarded to each launch
/// @param args      kernel arguments, forwarded by value
/// @return average elapsed time per launch, in the unit of KernelTimer::GetElapsedTime()
template <typename... Args, typename F>
float launch_and_time_kernel(
    F kernel, int nrepeat, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
{
#if 1
    KernelTimer timer;

    // dim3 components are unsigned, so they are printed with %u (was %d).
    printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
           __func__,
           grid_dim.x,
           grid_dim.y,
           grid_dim.z,
           block_dim.x,
           block_dim.y,
           block_dim.z);

    printf("Warm up\n");

    hipStream_t stream_id = nullptr;

    // warm up: one untimed launch before the measured batch
    hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);

    printf("Start running %d times...\n", nrepeat);

    timer.Start();

    for(int i = 0; i < nrepeat; ++i)
    {
        hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
    }

    timer.End();

    // std::this_thread::sleep_for (std::chrono::microseconds(10));

    return timer.GetElapsedTime() / nrepeat;
#else
    // Timing disabled: launch once and report zero.
    launch_kernel(kernel, grid_dim, block_dim, lds_byte, args...);

    return 0;
#endif
}
#endif

View File

@@ -0,0 +1,9 @@
#pragma once
#include "host_tensor.hpp"
#include "common_header.hpp"
/// Prints a tensor descriptor type to `os` (std::cout by default).
///
/// The argument is used only for its type: a default-constructed `TensorDesc` is
/// converted with make_HostTensorDescriptor and handed to ostream_HostTensorDescriptor.
template <typename TensorDesc>
void ostream_tensor_descriptor(TensorDesc, std::ostream& os = std::cout)
{
    ostream_HostTensorDescriptor(make_HostTensorDescriptor(TensorDesc{}),
                                 os);
}

View File

@@ -0,0 +1,149 @@
#pragma once
#include "host_tensor.hpp"
#include "conv_common.hpp"
/// Host reference 2-d convolution.
///
/// For every output element (n, k, ho, wo) the products in(n, c, hi, wi) * wei(k, c, y, x)
/// are accumulated in float over all c, y, x, where
///   hi = ho * stride[0] + y * dilation[0] - left_pad[0]
///   wi = wo * stride[1] + x * dilation[1] - left_pad[1]
/// Taps whose (hi, wi) fall outside the input are skipped. Output elements are
/// distributed over std::thread::hardware_concurrency() threads.
///
/// The right-padding argument is accepted but not used.
template <typename TIn,
          typename TWei,
          typename TOut,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void host_conv_nchw_kcyx_nkhw(const Tensor<TIn>& in,
                              const Tensor<TWei>& wei,
                              Tensor<TOut>& out,
                              const ConvStrides& conv_strides,
                              const ConvDilations& conv_dilations,
                              const InLeftPads& in_left_pads,
                              const InRightPads&)
{
    constexpr auto I0 = ck::Number<0>{};
    constexpr auto I1 = ck::Number<1>{};

    const auto& in_lens  = in.mDesc.GetLengths();
    const auto& wei_lens = wei.mDesc.GetLengths();
    const auto& out_lens = out.mDesc.GetLengths();

    auto compute_output_element = [&](auto n, auto k, auto ho, auto wo) {
        float acc = 0;

        for(int c = 0; c < wei_lens[1]; ++c)
        {
            for(int y = 0; y < wei_lens[2]; ++y)
            {
                const int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];

                for(int x = 0; x < wei_lens[3]; ++x)
                {
                    const int wi =
                        wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];

                    // Skip taps that land in the (implicit) zero padding.
                    const bool inside_input =
                        hi >= 0 && hi < in_lens[2] && wi >= 0 && wi < in_lens[3];
                    if(!inside_input)
                    {
                        continue;
                    }

                    acc += ck::type_convert<float>(in(n, c, hi, wi)) *
                           ck::type_convert<float>(wei(k, c, y, x));
                }
            }
        }

        out(n, k, ho, wo) = ck::type_convert<TOut>(acc);
    };

    make_ParallelTensorFunctor(
        compute_output_element, out_lens[0], out_lens[1], out_lens[2], out_lens[3])(
        std::thread::hardware_concurrency());
}
/// Host reference 3-d convolution that walks the raw data buffers directly.
///
/// Lengths are read as: input dims 1..3 = Di, Hi, Wi; weight dims 1..4 = Z, Y, X, C.
/// Offsets are computed assuming both buffers are packed in exactly that order
/// (n * Di * Hi * Wi * C for the input, k * Z * Y * X * C for the weights).
///
/// For each output element (n, do, ho, wo, k) the valid filter-tap range
/// [zmin, zmax) x [ymin, ymax) x [xmin, xmax) is computed up front so the inner loops
/// need no bounds checks; products are accumulated in double.
///
/// Fixes relative to the previous revision:
///  - The exclusive upper tap bounds (zmax/ymax/xmax) use a ceiling division. A tap z is
///    valid while z * dilation < pad - do * stride + Di, so the first invalid tap is
///    ceil(.../dilation); the previous floor division dropped the last valid tap
///    whenever dilation > 1. The bounds are also evaluated in signed arithmetic.
///  - The worker-thread count no longer underflows when the machine reports four or
///    fewer hardware threads; at least one thread is always used.
///
/// The right-padding argument is accepted but not used.
template <typename TIn,
          typename TWei,
          typename TOut,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void host_conv3d_ndhwc_kzyxc_ndhwk(const Tensor<TIn>& in,
                                   const Tensor<TWei>& wei,
                                   Tensor<TOut>& out,
                                   const ConvStrides& conv_strides,
                                   const ConvDilations& conv_dilations,
                                   const InLeftPads& in_left_pads,
                                   const InRightPads&)
{
    using namespace ck;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    const auto Di = in.mDesc.GetLengths()[1];
    const auto Hi = in.mDesc.GetLengths()[2];
    const auto Wi = in.mDesc.GetLengths()[3];
    const auto Z  = wei.mDesc.GetLengths()[1];
    const auto Y  = wei.mDesc.GetLengths()[2];
    const auto X  = wei.mDesc.GetLengths()[3];
    const auto C  = wei.mDesc.GetLengths()[4];

    auto f_ndhwc = [&](auto n, auto do_tmp, auto ho_tmp, auto wo_tmp, auto k) {
        // The output coordinates must be signed: the tap bounds below subtract
        // coordinate * stride from the padding and can legitimately go negative.
        const int do_ = static_cast<int>(do_tmp);
        const int ho  = static_cast<int>(ho_tmp);
        const int wo  = static_cast<int>(wo_tmp);

        // First tap whose input coordinate is >= 0: ceil((pad - o * stride) / dilation),
        // clamped at 0.
        const int zmin =
            std::max(0,
                     (in_left_pads[I0] - do_ * conv_strides[I0] + conv_dilations[I0] - 1) /
                         conv_dilations[I0]);
        const int ymin =
            std::max(0,
                     (in_left_pads[I1] - ho * conv_strides[I1] + conv_dilations[I1] - 1) /
                         conv_dilations[I1]);
        const int xmin =
            std::max(0,
                     (in_left_pads[I2] - wo * conv_strides[I2] + conv_dilations[I2] - 1) /
                         conv_dilations[I2]);

        // One past the last tap whose input coordinate is < input length:
        // ceil((pad - o * stride + length) / dilation), clamped to the filter size.
        // A non-positive value simply yields an empty loop because the min bound is >= 0.
        const int zmax = std::min(static_cast<int>(Z),
                                  static_cast<int>((in_left_pads[I0] - do_ * conv_strides[I0] +
                                                    static_cast<int>(Di) + conv_dilations[I0] -
                                                    1) /
                                                   conv_dilations[I0]));
        const int ymax = std::min(static_cast<int>(Y),
                                  static_cast<int>((in_left_pads[I1] - ho * conv_strides[I1] +
                                                    static_cast<int>(Hi) + conv_dilations[I1] -
                                                    1) /
                                                   conv_dilations[I1]));
        const int xmax = std::min(static_cast<int>(X),
                                  static_cast<int>((in_left_pads[I2] - wo * conv_strides[I2] +
                                                    static_cast<int>(Wi) + conv_dilations[I2] -
                                                    1) /
                                                   conv_dilations[I2]));

        // Input coordinates that correspond to the first valid tap on each axis.
        const int di_min = do_ * conv_strides[I0] + zmin * conv_dilations[I0] - in_left_pads[I0];
        const int hi_min = ho * conv_strides[I1] + ymin * conv_dilations[I1] - in_left_pads[I1];
        const int wi_min = wo * conv_strides[I2] + xmin * conv_dilations[I2] - in_left_pads[I2];

        double v = 0;

        const TIn* in_n   = in.mData.data() + n * Di * Hi * Wi * C;
        const TWei* wei_k = wei.mData.data() + k * Z * Y * X * C;

        int di = di_min;
        for(int z = zmin; z < zmax; ++z, di += conv_dilations[I0])
        {
            const TIn* in_n_di  = in_n + di * Hi * Wi * C;
            const TWei* wei_k_z = wei_k + z * Y * X * C;

            int hi = hi_min;
            for(int y = ymin; y < ymax; ++y, hi += conv_dilations[I1])
            {
                const TIn* in_n_di_hi  = in_n_di + hi * Wi * C;
                const TWei* wei_k_z_y = wei_k_z + y * X * C;

                int wi = wi_min;
                for(int x = xmin; x < xmax; ++x, wi += conv_dilations[I2])
                {
                    const TIn* in_n_di_hi_wi  = in_n_di_hi + wi * C;
                    const TWei* wei_k_z_y_x = wei_k_z_y + x * C;

                    for(int c = 0; c < C; ++c)
                    {
                        v += static_cast<double>(in_n_di_hi_wi[c]) *
                             static_cast<double>(wei_k_z_y_x[c]);
                    }
                }
            }
        }

        out(n, do_, ho, wo, k) = v;
    };

    // Leave four hardware threads free when possible, but never drop below one worker:
    // hardware_concurrency() is unsigned (and may be 0), so a plain "- 4" can wrap.
    const unsigned hw_threads    = std::thread::hardware_concurrency();
    const std::size_t num_thread = hw_threads > 4 ? hw_threads - 4 : 1;

    make_ParallelTensorFunctor(f_ndhwc,
                               out.mDesc.GetLengths()[0],
                               out.mDesc.GetLengths()[1],
                               out.mDesc.GetLengths()[2],
                               out.mDesc.GetLengths()[3],
                               out.mDesc.GetLengths()[4])(num_thread);
}

View File

@@ -0,0 +1,43 @@
#pragma once
#include "host_tensor.hpp"
/// Host reference GEMM with element-wise operators.
///
/// For every (m, n):
///   c(m, n) = c_op( sum_k a_op(a(m, k)) * b_op(b(k, n)) )
/// where each operator is invoked as op(result, input) and writes its result into the
/// first argument. Inputs are converted to float and the sum is accumulated in float.
/// Output elements are distributed over std::thread::hardware_concurrency() threads.
template <typename AType,
          typename BType,
          typename CType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
                        const Tensor<BType>& b_k_n,
                        Tensor<CType>& c_m_n,
                        const AElementwiseOperation& a_element_op,
                        const BElementwiseOperation& b_element_op,
                        const CElementwiseOperation& c_element_op)
{
    auto compute_c_element = [&](auto m, auto n) {
        // Reduction length is the second dimension of A.
        const int reduce_len = a_m_k.mDesc.GetLengths()[1];

        float acc = 0;

        for(int k = 0; k < reduce_len; ++k)
        {
            float a_val;
            float b_val;

            a_element_op(a_val, static_cast<float>(a_m_k(m, k)));
            b_element_op(b_val, static_cast<float>(b_k_n(k, n)));

            acc += a_val * b_val;
        }

        float c_val;
        c_element_op(c_val, acc);

        c_m_n(m, n) = c_val;
    };

    const auto& c_lens = c_m_n.mDesc.GetLengths();

    make_ParallelTensorFunctor(compute_c_element, c_lens[0], c_lens[1])(
        std::thread::hardware_concurrency());
}

View File

@@ -0,0 +1,424 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef HOST_GENERIC_REDUCTION_HPP_
#define HOST_GENERIC_REDUCTION_HPP_
#include <vector>
#include <functional>
#include <limits>
#include <type_traits>
#include <cassert>
#include <cmath>
#include "reduction_enums.hpp"
#include "host_reduce_util.hpp"
using float16 = half_float::half;
namespace ck {
namespace host_reduce {
/// Enumerates every multi-index of the space described by `dimLengths`, starting at
/// dimension `dim`, in row-major order (last dimension varies fastest).
///
/// Called with `dim == 0` and an empty `indexes`, the result is the full Cartesian
/// product [0, dimLengths[0]) x ... x [0, dimLengths.back()). Called with `dim > 0`,
/// every entry already in `indexes` is extended by the remaining dimensions.
/// If `dim` is not a valid dimension (including negative values), `indexes` is left
/// untouched.
///
/// Implemented iteratively: each level builds the extended set once and swaps it into
/// place, instead of recursing and deep-copying the whole set per dimension.
///
/// @param dimLengths length of each dimension
/// @param dim        first dimension to expand
/// @param indexes    in/out list of multi-indexes
template <typename T>
static void
get_all_indexes(const std::vector<T>& dimLengths, int dim, std::vector<std::vector<T>>& indexes)
{
    // A negative `dim` converts to a huge unsigned value, so the loop body never runs --
    // the same outcome as the former signed/unsigned comparison.
    for(std::size_t d = static_cast<std::size_t>(dim); d < dimLengths.size(); ++d)
    {
        std::vector<std::vector<T>> extended;

        if(d == 0)
        {
            assert(indexes.size() == 0);
            assert(dimLengths[d] > 0);

            // Seed the set with one single-element index per position of dimension 0.
            for(T i = 0; i < dimLengths[d]; i++)
            {
                extended.push_back(std::vector<T>{i});
            }
        }
        else
        {
            if(dimLengths[d] > 0)
            {
                extended.reserve(indexes.size() * static_cast<std::size_t>(dimLengths[d]));
            }

            // Append every position of dimension d to every existing prefix.
            for(const auto& prefix : indexes)
            {
                for(T i = 0; i < dimLengths[d]; i++)
                {
                    extended.push_back(prefix);
                    extended.back().push_back(i);
                }
            }
        }

        // Publish this level's result without copying it.
        indexes.swap(extended);
    }
}
/// Returns the linear offset sum(strides[i] * index[i]) of a multi-index.
/// `strides` and `index` must have the same number of entries.
template <typename T>
static T get_offset_from_index(const std::vector<T>& strides, const std::vector<T>& index)
{
    assert(strides.size() == index.size());

    T offset = 0;

    std::size_t d = 0;
    for(const T& coord : index)
    {
        offset += strides[d] * static_cast<T>(coord);
        ++d;
    }

    return offset;
}
/// Returns the row-major (packed) linear offset of `index` in a space with the given
/// `lengths`: the last dimension has stride 1 and each earlier dimension's stride is
/// the product of all later lengths. `lengths` must be non-empty and match `index`
/// in size.
template <typename T>
static inline T get_flatten_offset(const std::vector<T>& lengths, const std::vector<T>& index)
{
    assert(lengths.size() == index.size() && lengths.size() > 0);

    T offset = 0;
    T stride = 1;

    // Walk from the innermost dimension outwards; dimension 0 is handled after the
    // loop so its length is never multiplied into the stride.
    int d = static_cast<int>(lengths.size()) - 1;
    while(d > 0)
    {
        offset += stride * static_cast<T>(index[d]);
        stride *= lengths[d];
        --d;
    }

    offset += stride * static_cast<T>(index[0]);

    return offset;
}
// Host-side reduction of a strided input buffer into an output buffer.
//
// Template parameters:
//   InDataType   element type read from `in_data`
//   AccDataType  type every value is converted to while reducing
//   OutDataType  element type written to `out_data`
//   ReduceOpId   selects the reduce / pre-unary / post-unary functions and the initial value
//   PropagateNan forwarded to binop_with_nan_check / binop_with_nan_check2
//   NeedIndices  when true, Run() also writes the flattened index of the selected element
//
// The dimensions of the input are split into `invariantDims` (kept) and `toReduceDims`
// (reduced away); when `invariantDims` is empty everything is reduced into out_data[0].
template <typename InDataType,
typename AccDataType,
typename OutDataType,
ck::ReduceTensorOp_t ReduceOpId,
bool PropagateNan,
bool NeedIndices>
class ReductionHost
{
public:
ReductionHost() = default;
// Captures lengths/strides of both descriptors (converted to int) and the dimension
// split. The descriptors are only read here.
// NOTE(review): the assert below requires input and output to have the same rank, while
// the Run paths index outStrides with invariantDims.size() entries and
// get_offset_from_index asserts equal sizes -- confirm which output layout is intended.
ReductionHost(HostTensorDescriptor& inDesc,
HostTensorDescriptor& outDesc,
const std::vector<int>& invariantDims_,
const std::vector<int>& toReduceDims_)
{
this->inLengths = to_int_vector(inDesc.GetLengths());
this->outLengths = to_int_vector(outDesc.GetLengths());
this->inStrides = to_int_vector(inDesc.GetStrides());
this->outStrides = to_int_vector(outDesc.GetStrides());
this->invariantDims = invariantDims_;
this->toReduceDims = toReduceDims_;
assert(this->inLengths.size() == this->outLengths.size());
assert(!this->toReduceDims.empty());
// Gather the input lengths of each group of dimensions, in the order given.
for(const auto dim : this->invariantDims)
this->invariantLengths.push_back(this->inLengths[dim]);
for(const auto dim : this->toReduceDims)
toReduceLengths.push_back(this->inLengths[dim]);
// No invariant dimension means the whole input collapses to a single value.
this->reduceAllDims = this->invariantDims.empty();
};
~ReductionHost(){};
// Computes out = alpha * reduce(in) + beta * out. `indices` is written only when
// NeedIndices is true and is otherwise ignored.
void
Run(float alpha, const InDataType* in_data, float beta, OutDataType* out_data, int* indices)
{
if constexpr(NeedIndices)
RunImpl_with_indices(alpha, in_data, beta, out_data, indices);
else
RunImpl_no_indices(alpha, in_data, beta, out_data);
};
private:
std::vector<int> inLengths;
std::vector<int> outLengths;
std::vector<int> inStrides;
std::vector<int> outStrides;
std::vector<int> invariantLengths;
std::vector<int> toReduceLengths;
std::vector<int> invariantDims;
std::vector<int> toReduceDims;
bool reduceAllDims;
// Reduction that also records, per output element, the flattened index (within the
// reduced sub-space) of the element that last changed the accumulator.
void RunImpl_with_indices(
float alpha, const InDataType* in_data, float beta, OutDataType* out_data, int* indices)
{
using ck::host_reduce::binop_with_nan_check;
using ck::host_reduce::binop_with_nan_check2;
using ck::host_reduce::float_equal_one;
using ck::host_reduce::float_equal_zero;
using ck::host_reduce::PosUnaryOpFn;
using ck::host_reduce::PreUnaryOpFn;
using ck::host_reduce::ReduceOpFn2;
using ck::host_reduce::ReduceOpZeroVal;
auto opReduce = ReduceOpFn2<AccDataType, ReduceOpId>();
// Number of elements folded into each output value.
int divider = 1;
for(int i = 0; i < toReduceLengths.size(); i++)
divider *= toReduceLengths[i];
auto PreUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider);
// NOTE(review): PosUnaryOp is created here but never applied in this function.
auto PosUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider);
if(reduceAllDims)
{
std::vector<std::vector<int>> indexes_1;
get_all_indexes(inLengths, 0, indexes_1); // generate the input indexes space
auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
int accuIndex = 0;
// go through every index of the input (all dimensions are reduced here)
for(const auto& src_index : indexes_1)
{
auto src_offset = get_offset_from_index(this->inStrides, src_index);
auto currVal = static_cast<AccDataType>(in_data[src_offset]);
// unary operation before reducing, needed by AMAX. For MIN/MAX, nothing is actually
// done
PreUnaryOp(currVal);
auto currIndex = get_flatten_offset(inLengths, src_index);
binop_with_nan_check2<AccDataType, PropagateNan>(
opReduce, accuVal, currVal, accuIndex, currIndex);
};
// scale the accumulated value
if(!float_equal_one(alpha))
accuVal *= static_cast<AccDataType>(alpha);
// scale the prior dst value and add it to the accumulated value
if(!float_equal_zero(beta))
accuVal += static_cast<AccDataType>(out_data[0]) * static_cast<AccDataType>(beta);
// store the reduced value to dst location
out_data[0] = static_cast<OutDataType>(accuVal);
indices[0] = accuIndex;
}
else
{
std::vector<std::vector<int>> indexes_1, indexes_2;
get_all_indexes(
this->invariantLengths, 0, indexes_1); // generate the invariant indexes space
get_all_indexes(
this->toReduceLengths, 0, indexes_2); // generate the toReduce indexes space
// go through indexes of the invariant dimensions
for(const auto& index_1 : indexes_1)
{
std::vector<int> src_index;
std::vector<int> dst_index;
src_index.resize(this->inLengths.size());
// generate the part of src index belonging to invariant dims
for(int k = 0; k < invariantDims.size(); k++)
src_index[invariantDims[k]] = index_1[k];
// the dst index consists of the invariant coordinates only
for(int k = 0; k < invariantDims.size(); k++)
dst_index.push_back(index_1[k]);
int dst_offset = get_offset_from_index(this->outStrides, dst_index);
AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
int accuIndex = 0;
// go through indexes of the toReduce dimensions
for(const auto& index_2 : indexes_2)
{
// generate the part of src index belonging to toReduce dims
for(int k = 0; k < toReduceDims.size(); k++)
src_index[toReduceDims[k]] = index_2[k];
auto src_offset = get_offset_from_index(this->inStrides, src_index);
auto currVal = static_cast<AccDataType>(in_data[src_offset]);
// unary operation before reducing, needed by AMAX. For MIN/MAX, nothing is
// actually done
PreUnaryOp(currVal);
// the recorded index is relative to the reduced sub-space, not the whole input
auto currIndex = get_flatten_offset(toReduceLengths, index_2);
binop_with_nan_check2<AccDataType, PropagateNan>(
opReduce, accuVal, currVal, accuIndex, currIndex);
};
// scale the accumulated value
if(!float_equal_one(alpha))
accuVal *= static_cast<AccDataType>(alpha);
// scale the prior dst value and add it to the accumulated value
if(!float_equal_zero(beta))
accuVal += static_cast<AccDataType>(out_data[dst_offset]) *
static_cast<AccDataType>(beta);
// store the reduced value to dst location
out_data[dst_offset] = static_cast<OutDataType>(accuVal);
indices[dst_offset] = accuIndex;
};
};
}; // end of RunImpl_with_indices()
// Plain reduction: pre-unary op on each element, binary reduce, post-unary op on the
// result, then alpha/beta blending with the previous output value.
void
RunImpl_no_indices(float alpha, const InDataType* in_data, float beta, OutDataType* out_data)
{
using ck::host_reduce::binop_with_nan_check;
using ck::host_reduce::binop_with_nan_check2;
using ck::host_reduce::float_equal_one;
using ck::host_reduce::float_equal_zero;
using ck::host_reduce::PosUnaryOpFn;
using ck::host_reduce::PreUnaryOpFn;
using ck::host_reduce::ReduceOpFn;
using ck::host_reduce::ReduceOpZeroVal;
auto opReduce = ReduceOpFn<AccDataType, ReduceOpId>();
// Number of elements folded into each output value (passed to the unary op factories).
int divider = 1;
for(int i = 0; i < toReduceLengths.size(); i++)
divider *= toReduceLengths[i];
auto PreUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider);
auto PosUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider);
if(reduceAllDims)
{
std::vector<std::vector<int>> indexes_1;
get_all_indexes(inLengths, 0, indexes_1); // generate the input indexes space
auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
// go through every index of the input (all dimensions are reduced here)
for(const auto& src_index : indexes_1)
{
auto src_offset = get_offset_from_index(this->inStrides, src_index);
auto currVal = static_cast<AccDataType>(in_data[src_offset]);
PreUnaryOp(currVal);
binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
};
PosUnaryOp(accuVal);
// scale the accumulated value
if(!float_equal_one(alpha))
accuVal *= static_cast<AccDataType>(alpha);
// scale the prior dst value and add it to the accumulated value
if(!float_equal_zero(beta))
accuVal += static_cast<AccDataType>(out_data[0]) * static_cast<AccDataType>(beta);
// store the reduced value to dst location
out_data[0] = static_cast<OutDataType>(accuVal);
}
else
{
std::vector<std::vector<int>> indexes_1, indexes_2;
get_all_indexes(
this->invariantLengths, 0, indexes_1); // generate the invariant indexes space
get_all_indexes(
this->toReduceLengths, 0, indexes_2); // generate the toReduce indexes space
// go through indexes of the invariant dimensions
for(const auto& index_1 : indexes_1)
{
std::vector<int> src_index;
std::vector<int> dst_index;
src_index.resize(this->inLengths.size());
// the dst index consists of the invariant coordinates only
for(int k = 0; k < invariantDims.size(); k++)
dst_index.push_back(index_1[k]);
int dst_offset = get_offset_from_index(this->outStrides, dst_index);
// generate the part of src index belonging to invariant dims
for(int k = 0; k < invariantDims.size(); k++)
src_index[invariantDims[k]] = index_1[k];
AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
// go through indexes of the toReduce dimensions
for(const auto& index_2 : indexes_2)
{
// generate the part of src index belonging to toReduce dims
for(int k = 0; k < toReduceDims.size(); k++)
src_index[toReduceDims[k]] = index_2[k];
auto src_offset = get_offset_from_index(this->inStrides, src_index);
auto currVal = static_cast<AccDataType>(in_data[src_offset]);
PreUnaryOp(currVal);
binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
};
PosUnaryOp(accuVal);
// scale the accumulated value
if(!float_equal_one(alpha))
accuVal *= static_cast<AccDataType>(alpha);
// scale the prior dst value and add it to the accumulated value
if(!float_equal_zero(beta))
accuVal += static_cast<AccDataType>(out_data[dst_offset]) *
static_cast<AccDataType>(beta);
// store the reduced value to dst location
out_data[dst_offset] = static_cast<OutDataType>(accuVal);
};
};
}; // end of RunImpl_no_indices()
};
}; // end of namespace host_reduce
}; // end of namespace ck
#endif

View File

@@ -0,0 +1,291 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef GUARD_HOST_REDUCE_UTIL_HPP
#define GUARD_HOST_REDUCE_UTIL_HPP
#include <half.hpp>
#include <limits>
#include <cmath>
#include <cassert>
#include <stdexcept>
#include <string>
#include "reduction_enums.hpp"
namespace ck {
namespace host_reduce {
using ck::NanPropagation_t;
using ck::ReduceTensorOp_t;
// Exact comparison against one, used to skip the alpha scaling step.
// The primary template is only declared; float, double and half overloads follow.
template <typename T>
static inline bool float_equal_one(T);

static inline bool float_equal_one(float x)
{
    const float one = 1.0f;
    return x == one;
}

static inline bool float_equal_one(double x)
{
    const double one = 1.0;
    return x == one;
}

static inline bool float_equal_one(half_float::half x)
{
    const half_float::half one = static_cast<half_float::half>(1.0f);
    return x == one;
}
// Exact comparison against zero, used to skip the beta blending step.
// The primary template is only declared; float, double and half overloads follow.
template <typename T>
static inline bool float_equal_zero(T x);

static inline bool float_equal_zero(float x)
{
    const float zero = 0.0f;
    return x == zero;
}

static inline bool float_equal_zero(double x)
{
    const double zero = 0.0;
    return x == zero;
}

static inline bool float_equal_zero(half_float::half x)
{
    const half_float::half zero = static_cast<half_float::half>(0.0f);
    return x == zero;
}
// Returns the in-place operation applied to every element before it is reduced:
//   NORM1, AMAX -> absolute value
//   NORM2       -> square
//   otherwise   -> no-op
// The int argument is accepted for symmetry with PosUnaryOpFn and is not used.
template <typename compType, ReduceTensorOp_t ReduceOpId>
__host__ static inline std::function<void(compType&)> PreUnaryOpFn(int)
{
    using std::abs;

    if constexpr(ReduceOpId == ReduceTensorOp_t::NORM1 || ReduceOpId == ReduceTensorOp_t::AMAX)
    {
        return [](compType& value) { value = abs(value); };
    }
    else if constexpr(ReduceOpId == ReduceTensorOp_t::NORM2)
    {
        return [](compType& value) { value = value * value; };
    }
    else
    {
        // AVG, ADD, MUL, MIN, MAX: elements are reduced as they are.
        return [](compType&) {};
    }
}
// Returns the in-place operation applied to the accumulated value after reducing:
//   NORM2     -> square root
//   AVG       -> division by `divider` (converted via float to compType)
//   otherwise -> no-op
template <typename compType, ReduceTensorOp_t ReduceOpId>
__host__ static inline std::function<void(compType&)> PosUnaryOpFn(int divider)
{
    using std::sqrt;

    if constexpr(ReduceOpId == ReduceTensorOp_t::AVG)
    {
        // `divider` is copied into the closure, so the returned function is self-contained.
        return [divider](compType& value) {
            value = value / static_cast<compType>(static_cast<float>(divider));
        };
    }
    else if constexpr(ReduceOpId == ReduceTensorOp_t::NORM2)
    {
        return [](compType& value) { value = sqrt(value); };
    }
    else
    {
        // ADD, NORM1, MUL, MIN, MAX, AMAX: the accumulated value is final.
        return [](compType&) {};
    }
}
// Returns the binary reduction `accu = op(accu, value)` for the given operation:
//   ADD, AVG, NORM1, NORM2 -> sum
//   MUL                    -> product
//   MIN                    -> keep the smaller value
//   MAX, AMAX              -> keep the larger value
// Any other operation yields an empty std::function (as ReduceOpFn2 does for the
// operations it does not support), so invoking it throws std::bad_function_call.
// Previously that path fell off the end of the function without returning.
template <typename compType, ReduceTensorOp_t ReduceOpId>
__host__ static inline std::function<void(compType&, compType)> ReduceOpFn()
{
    if constexpr(ReduceOpId == ReduceTensorOp_t::ADD || ReduceOpId == ReduceTensorOp_t::AVG ||
                 ReduceOpId == ReduceTensorOp_t::NORM1 || ReduceOpId == ReduceTensorOp_t::NORM2)
    {
        return ([](compType& a_, compType b_) { a_ = a_ + b_; });
    }
    else if constexpr(ReduceOpId == ReduceTensorOp_t::MUL)
    {
        return ([](compType& a_, compType b_) { a_ = a_ * b_; });
    }
    else if constexpr(ReduceOpId == ReduceTensorOp_t::MIN)
    {
        return ([](compType& a_, compType b_) {
            if(a_ > b_)
                a_ = b_;
        });
    }
    else if constexpr(ReduceOpId == ReduceTensorOp_t::MAX || ReduceOpId == ReduceTensorOp_t::AMAX)
    {
        return ([](compType& a_, compType b_) {
            if(a_ < b_)
                a_ = b_;
        });
    }
    else
    {
        // Unsupported operation: return an empty function instead of flowing off the end.
        return (std::function<void(compType&, compType)>{});
    }
}
// Indexed variant of the binary reduction: besides updating the accumulator it reports
// through `changed` whether the new value replaced it.
//   MIN       -> replace when the accumulator is greater than the value
//   MAX, AMAX -> replace when the accumulator is less than the value
// Every other operation yields an empty std::function.
template <typename compType, ReduceTensorOp_t ReduceOpId>
__host__ static inline std::function<void(compType&, compType, bool& changed)> ReduceOpFn2()
{
    if constexpr(ReduceOpId == ReduceTensorOp_t::MIN)
    {
        return [](compType& accu, compType value, bool& changed) {
            changed = (accu > value);
            if(changed)
                accu = value;
        };
    }
    else if constexpr(ReduceOpId == ReduceTensorOp_t::MAX || ReduceOpId == ReduceTensorOp_t::AMAX)
    {
        return [](compType& accu, compType value, bool& changed) {
            changed = (accu < value);
            if(changed)
                accu = value;
        };
    }
    else
    {
        // ADD, MUL, AVG, NORM1, NORM2 have no indexed form.
        return std::function<void(compType&, compType, bool&)>{};
    }
}
// Returns the value the accumulator starts from for the given operation:
//   MUL       -> 1
//   MIN       -> largest representable value
//   MAX       -> lowest representable value
//   otherwise -> 0 (AMAX, ADD, AVG, NORM1, NORM2)
template <typename compType, ReduceTensorOp_t ReduceOpId>
__host__ static inline compType ReduceOpZeroVal()
{
    if constexpr(ReduceOpId == ReduceTensorOp_t::MIN)
    {
        return std::numeric_limits<compType>::max();
    }
    else if constexpr(ReduceOpId == ReduceTensorOp_t::MAX)
    {
        return std::numeric_limits<compType>::lowest();
    }
    else if constexpr(ReduceOpId == ReduceTensorOp_t::MUL)
    {
        return static_cast<compType>(1.0f);
    }
    else
    {
        return static_cast<compType>(0.0f);
    }
}
// Folds `currVal` into `accuVal` with `opReduce`.
// When PropagateNan is true and `currVal` is NaN, the accumulator is overwritten with
// that NaN and `opReduce` is not called.
template <typename compType, bool PropagateNan>
__host__ static inline void binop_with_nan_check(std::function<void(compType&, compType)> opReduce,
                                                 compType& accuVal,
                                                 compType currVal)
{
    using std::isnan;

    if constexpr(PropagateNan)
    {
        if(isnan(currVal))
        {
            accuVal = currVal;
            return;
        }
    }

    opReduce(accuVal, currVal);
}
// Indexed variant of binop_with_nan_check: folds `currVal` into `accuVal` and, whenever
// `opReduce` reports a change, records `currIndex` in `accuIndex`.
// When PropagateNan is true and `currVal` is NaN, both the accumulator and the index are
// overwritten and `opReduce` is not called.
template <typename compType, bool PropagateNan>
__host__ static inline void
binop_with_nan_check2(std::function<void(compType&, compType, bool&)> opReduce,
                      compType& accuVal,
                      compType currVal,
                      int& accuIndex,
                      int currIndex)
{
    using std::isnan;

    if constexpr(PropagateNan)
    {
        if(isnan(currVal))
        {
            accuVal   = currVal;
            accuIndex = currIndex;
            return;
        }
    }

    bool replaced;
    opReduce(accuVal, currVal, replaced);

    if(replaced)
    {
        accuIndex = currIndex;
    }
}
}; // namespace host_reduce
// Converts a vector of size_t values to a vector of int, element by element.
// Each value is narrowed with static_cast<int>; values that do not fit in int are not
// detected. The result is sized up front so the conversion performs one allocation.
static inline std::vector<int> to_int_vector(const std::vector<size_t>& inData)
{
    std::vector<int> outData;
    outData.reserve(inData.size());

    for(const auto elem : inData)
    {
        outData.push_back(static_cast<int>(elem));
    }

    return outData;
}
}; // namespace ck
#endif

View File

@@ -0,0 +1,383 @@
#ifndef HOST_TENSOR_HPP
#define HOST_TENSOR_HPP
#include <thread>
#include <vector>
#include <numeric>
#include <algorithm>
#include <utility>
#include <cassert>
#include <iostream>
#include "data_type.hpp"
// Streams every element of `range` to `os`, writing `delim` between consecutive
// elements (not before the first nor after the last). Returns `os`.
template <typename Range>
std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
{
    bool need_delim = false;

    for(auto&& elem : range)
    {
        if(need_delim)
        {
            os << delim;
        }
        os << elem;
        need_delim = true;
    }

    return os;
}
// Like LogRange, but each element is converted to `T` with static_cast before it is
// streamed (e.g. to print char-sized values as numbers). Returns `os`.
template <typename T, typename Range>
std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
{
    bool need_delim = false;

    for(auto&& elem : range)
    {
        if(need_delim)
        {
            os << delim;
        }
        os << static_cast<T>(elem);
        need_delim = true;
    }

    return os;
}
// Identifiers for element types; used as the value of the DataType<T> trait below.
typedef enum
{
Half = 0,  // NOTE(review): no DataType<> specialization in view maps a type to Half
Float = 1, // value of DataType<float>
} DataType_t;
// Trait mapping a C++ element type to its DataType_t value. The primary template is
// left undefined, so using it with a type that has no specialization is a compile error.
template <typename T>
struct DataType;
// float maps to DataType_t::Float (exposed as DataType<float>::value).
template <>
struct DataType<float> : std::integral_constant<DataType_t, DataType_t::Float>
{
};
// Helper for call_f_unpack_args: calls `f` with std::get<Is>(args) for each index.
template <typename F, typename T, std::size_t... Is>
auto call_f_unpack_args_impl(F f, T args, std::index_sequence<Is...>)
{
    return f(std::get<Is>(args)...);
}

// Calls `f` with the elements of the tuple-like `args` (std::tuple, std::array, ...)
// expanded into separate arguments, and returns whatever `f` returns.
template <typename F, typename T>
auto call_f_unpack_args(F f, T args)
{
    using AllIndices = std::make_index_sequence<std::tuple_size<T>::value>;

    return call_f_unpack_args_impl(f, args, AllIndices{});
}
// Helper for construct_f_unpack_args: constructs F from std::get<Is>(args) for each index.
template <typename F, typename T, std::size_t... Is>
auto construct_f_unpack_args_impl(T args, std::index_sequence<Is...>)
{
    return F(std::get<Is>(args)...);
}

// Constructs an object of type F from the elements of the tuple-like `args`.
// The first argument only selects the type F; its value is ignored.
template <typename F, typename T>
auto construct_f_unpack_args(F, T args)
{
    using AllIndices = std::make_index_sequence<std::tuple_size<T>::value>;

    return construct_f_unpack_args_impl<F>(args, AllIndices{});
}
// Describes a host tensor by a list of dimension lengths and a list of strides
// (both stored as std::size_t). Members without a body are defined elsewhere.
struct HostTensorDescriptor
{
HostTensorDescriptor() = delete;
// Construct from std::vector lengths (and optionally strides); defined elsewhere.
template <typename X>
HostTensorDescriptor(std::vector<X> lens);
template <typename X, typename Y>
HostTensorDescriptor(std::vector<X> lens, std::vector<Y> strides);
// Fills mStrides from mLens (definition not visible here; presumably packed
// row-major strides -- TODO confirm).
void CalculateStrides();
// Construct from any range exposing begin()/end(); strides are derived from the lengths.
template <typename Range>
HostTensorDescriptor(const Range& lens) : mLens(lens.begin(), lens.end())
{
this->CalculateStrides();
}
// Construct from ranges of lengths and strides; the strides are stored as given.
template <typename Range1, typename Range2>
HostTensorDescriptor(const Range1& lens, const Range2& strides)
: mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
{
}
std::size_t GetNumOfDimension() const;
std::size_t GetElementSize() const;
std::size_t GetElementSpace() const;
const std::vector<std::size_t>& GetLengths() const;
const std::vector<std::size_t>& GetStrides() const;
// Returns sum(is[i] * mStrides[i]) for the given coordinates. The number of
// coordinates must equal the number of dimensions (checked by assert only).
template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const
{
assert(sizeof...(Is) == this->GetNumOfDimension());
// Convert every coordinate to std::size_t before the dot product with the strides.
std::initializer_list<std::size_t> iss{static_cast<std::size_t>(is)...};
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
}
friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc);
private:
std::vector<std::size_t> mLens;    // length of each dimension
std::vector<std::size_t> mStrides; // stride of each dimension
};
// std::thread that joins in its destructor instead of terminating the program when it
// is destroyed while still joinable. Move-only, like std::thread.
struct joinable_thread : std::thread
{
    // Forwards all arguments to std::thread; with no arguments this yields a
    // thread object that does not represent a thread.
    template <typename... Xs>
    joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
    {
    }

    joinable_thread(joinable_thread&&) = default;
    joinable_thread& operator=(joinable_thread&&) = default;

    // Waits for the thread to finish if there is one to wait for.
    ~joinable_thread()
    {
        if(!joinable())
        {
            return;
        }
        join();
    }
};
// Runs a functor over every index of an N-d index space, splitting the
// flattened range into contiguous chunks that are processed by host threads.
// The dimension lengths are given as the constructor arguments `xs`.
template <typename F, typename... Xs>
struct ParallelTensorFunctor
{
    F mF;
    static constexpr std::size_t NDIM = sizeof...(Xs);
    std::array<std::size_t, NDIM> mLens;
    std::array<std::size_t, NDIM> mStrides;
    std::size_t mN1d; // total number of indices (product of all lengths)

    ParallelTensorFunctor(F f, Xs... xs) : mF(f), mLens({static_cast<std::size_t>(xs)...})
    {
        // Packed strides, last dimension fastest: mStrides[last] = 1 and
        // mStrides[i] = mStrides[i + 1] * mLens[i + 1].
        mStrides.back() = 1;
        std::partial_sum(mLens.rbegin(),
                         mLens.rend() - 1,
                         mStrides.rbegin() + 1,
                         std::multiplies<std::size_t>());
        mN1d = mStrides[0] * mLens[0];
    }

    // Converts the flattened index `i` back into one index per dimension.
    std::array<std::size_t, NDIM> GetNdIndices(std::size_t i) const
    {
        std::array<std::size_t, NDIM> indices;

        for(std::size_t idim = 0; idim < NDIM; ++idim)
        {
            indices[idim] = i / mStrides[idim];
            i -= indices[idim] * mStrides[idim];
        }

        return indices;
    }

    // Calls mF once per N-d index using `num_thread` threads; returns after
    // all of them have been joined (the joinable_thread vector is destroyed
    // at the end of this function).
    void operator()(std::size_t num_thread = std::thread::hardware_concurrency()) const
    {
        // hardware_concurrency() may return 0 when the value is unknown;
        // fall back to one thread instead of dividing by zero below.
        if(num_thread == 0)
            num_thread = 1;

        std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread;

        std::vector<joinable_thread> threads(num_thread);

        for(std::size_t it = 0; it < num_thread; ++it)
        {
            std::size_t iw_begin = it * work_per_thread;
            std::size_t iw_end   = std::min((it + 1) * work_per_thread, mN1d);

            auto f = [=] {
                for(std::size_t iw = iw_begin; iw < iw_end; ++iw)
                {
                    call_f_unpack_args(mF, GetNdIndices(iw));
                }
            };
            threads[it] = joinable_thread(f);
        }
    }
};
template <typename F, typename... Xs>
auto make_ParallelTensorFunctor(F f, Xs... xs)
{
return ParallelTensorFunctor<F, Xs...>(f, xs...);
}
// Host tensor: a HostTensorDescriptor plus a std::vector<T> holding
// mDesc.GetElementSpace() elements.
template <typename T>
struct Tensor
{
    // Lengths only; the descriptor derives the strides.
    template <typename X>
    Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpace())
    {
    }

    // Lengths only; the descriptor derives the strides.
    template <typename X>
    Tensor(std::vector<X> lens) : mDesc(lens), mData(mDesc.GetElementSpace())
    {
    }

    // Explicit lengths and strides.
    template <typename X, typename Y>
    Tensor(std::vector<X> lens, std::vector<Y> strides)
        : mDesc(lens, strides), mData(mDesc.GetElementSpace())
    {
    }

    // Copies an existing descriptor.
    Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {}

    // Assigns g(i0, i1, ...) to every element, for tensors of 1 to 5
    // dimensions, using `num_thread` host threads.
    // `g` is captured by reference and invoked from several threads.
    // Throws std::runtime_error for any other number of dimensions.
    template <typename G>
    void GenerateTensorValue(G g, std::size_t num_thread = std::thread::hardware_concurrency())
    {
        const auto& lens = mDesc.GetLengths();

        switch(mDesc.GetNumOfDimension())
        {
        case 1: {
            auto f = [&](auto i) { (*this)(i) = g(i); };
            make_ParallelTensorFunctor(f, lens[0])(num_thread);
            break;
        }
        case 2: {
            auto f = [&](auto i0, auto i1) { (*this)(i0, i1) = g(i0, i1); };
            make_ParallelTensorFunctor(f, lens[0], lens[1])(num_thread);
            break;
        }
        case 3: {
            auto f = [&](auto i0, auto i1, auto i2) { (*this)(i0, i1, i2) = g(i0, i1, i2); };
            make_ParallelTensorFunctor(f, lens[0], lens[1], lens[2])(num_thread);
            break;
        }
        case 4: {
            auto f = [&](auto i0, auto i1, auto i2, auto i3) {
                (*this)(i0, i1, i2, i3) = g(i0, i1, i2, i3);
            };
            make_ParallelTensorFunctor(f, lens[0], lens[1], lens[2], lens[3])(num_thread);
            break;
        }
        case 5: {
            auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4) {
                (*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4);
            };
            make_ParallelTensorFunctor(f, lens[0], lens[1], lens[2], lens[3], lens[4])(
                num_thread);
            break;
        }
        default: throw std::runtime_error("unsupported dimension");
        }
    }

    // Element access by multi-index (one index per dimension).
    template <typename... Is>
    T& operator()(Is... is)
    {
        return mData[mDesc.GetOffsetFromMultiIndex(is...)];
    }

    template <typename... Is>
    const T& operator()(Is... is) const
    {
        return mData[mDesc.GetOffsetFromMultiIndex(is...)];
    }

    // Iteration over the underlying storage, in storage order.
    typename std::vector<T>::iterator begin() { return mData.begin(); }

    typename std::vector<T>::iterator end() { return mData.end(); }

    typename std::vector<T>::const_iterator begin() const { return mData.begin(); }

    typename std::vector<T>::const_iterator end() const { return mData.end(); }

    // mDesc must stay declared before mData: mData is sized from mDesc in
    // every constructor's initializer list.
    HostTensorDescriptor mDesc;
    std::vector<T> mData;
};
// Builds a descriptor from lengths only and then computes the strides via
// CalculateStrides(). The lengths are converted element by element, so any
// element type X convertible to std::size_t is accepted (not only
// std::vector<std::size_t>).
template <typename X>
HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens) : mLens(lens.begin(), lens.end())
{
    this->CalculateStrides();
}
// Builds a descriptor from explicit lengths and strides; nothing is
// recomputed. Both vectors are converted element by element, so any element
// types convertible to std::size_t are accepted.
template <typename X, typename Y>
HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens, std::vector<Y> strides)
    : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
{
}
// Declaration only (defined elsewhere). Takes a descriptor and an output
// stream that defaults to std::cout.
// NOTE(review): presumably writes a textual form of `desc` to `os` - confirm against the definition.
void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout);
// Declaration only (defined elsewhere). Scalar overload: takes one ck::bhalf_t
// and returns a float; used by check_error to compare bhalf_t tensors.
float bf16_to_f32_(ck::bhalf_t src_val);
// Declaration only (defined elsewhere). Tensor overload: reads `src`, writes `dst`.
// NOTE(review): presumably converts element-wise and expects matching shapes - confirm.
void bf16_to_f32_(const Tensor<ck::bhalf_t>& src, Tensor<float>& dst);
template <typename T>
void check_error(const Tensor<T>& ref, const Tensor<T>& result)
{
float error = 0;
float max_diff = -1;
float ref_value = 0, result_value = 0;
if constexpr(std::is_same<ck::bhalf_t, T>::value)
{
for(int i = 0; i < ref.mData.size(); ++i)
{
error += std::abs(bf16_to_f32_(ref.mData[i]) - bf16_to_f32_(result.mData[i]));
float diff = std::abs(bf16_to_f32_(ref.mData[i]) - bf16_to_f32_(result.mData[i]));
if(max_diff < diff)
{
max_diff = diff;
ref_value = bf16_to_f32_(ref.mData[i]);
result_value = bf16_to_f32_(result.mData[i]);
}
}
}
else
{
for(int i = 0; i < ref.mData.size(); ++i)
{
error += std::abs(double(ref.mData[i]) - double(result.mData[i]));
float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
if(max_diff < diff)
{
max_diff = diff;
ref_value = ref.mData[i];
result_value = result.mData[i];
}
}
}
std::cout << "error: " << error << std::endl;
std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl;
}
// Compares two tensors for exact element-wise equality.
// Each mismatch is reported on std::cerr with its flat position and both
// values; reporting stops after 20 mismatches. If there are no mismatches a
// success message is printed to std::cout.
// Iterates over ref.mData.size() elements, so `result` must hold at least as
// many elements as `ref`.
template <typename T>
void check_indices(const Tensor<T>& ref, const Tensor<T>& result)
{
    constexpr int max_reported_errors = 20;

    int error_count = 0;

    // std::size_t index: matches the type of size() and cannot overflow for
    // tensors with more than INT_MAX elements.
    for(std::size_t i = 0; i < ref.mData.size(); ++i)
    {
        if(ref.mData[i] != result.mData[i])
        {
            std::cerr << std::endl
                      << "Indices different at position " << i << " (ref: " << ref.mData[i]
                      << ", result: " << result.mData[i] << ")" << std::endl;

            if(++error_count == max_reported_errors)
                break;
        }
    }

    if(error_count == 0)
        std::cout << std::endl << "Indices result is completely accurate!" << std::endl;
}
#endif

View File

@@ -0,0 +1,167 @@
#ifndef HOST_TENSOR_GENERATOR_HPP
#define HOST_TENSOR_GENERATOR_HPP
#include <cmath>
#include "config.hpp"
// Generator that yields T{0} for every index; the indices are ignored.
template <typename T>
struct GeneratorTensor_0
{
    template <typename... Indices>
    T operator()(Indices...)
    {
        const T zero{0};
        return zero;
    }
};
// Generator that yields the same constant for every index: `value`
// (default 1) converted to T with ck::type_convert. The indices are ignored.
template <typename T>
struct GeneratorTensor_1
{
    int value = 1;

    template <typename... Indices>
    T operator()(Indices...)
    {
        const T converted = ck::type_convert<T>(value);
        return converted;
    }
};
// ck::bhalf_t specialization: the constant is stored as float (default 1.0)
// and converted with ck::type_convert on every call. The indices are ignored.
template <>
struct GeneratorTensor_1<ck::bhalf_t>
{
    float value = 1.0;

    template <typename... Indices>
    ck::bhalf_t operator()(Indices...)
    {
        const ck::bhalf_t converted = ck::type_convert<ck::bhalf_t>(value);
        return converted;
    }
};
// int8_t specialization: the constant is stored as int8_t (default 1) and
// returned without conversion. The indices are ignored.
template <>
struct GeneratorTensor_1<int8_t>
{
    int8_t value = 1;

    template <typename... Indices>
    int8_t operator()(Indices...)
    {
        const int8_t constant = value;
        return constant;
    }
};
// Generator of pseudo-random integers in [min_value, max_value) drawn with
// std::rand() and cast to T. The indices are ignored.
// max_value must be greater than min_value (the range width is a modulus).
template <typename T>
struct GeneratorTensor_2
{
    int min_value = 0;
    int max_value = 1;

    template <typename... Indices>
    T operator()(Indices...)
    {
        const int range_width = max_value - min_value;
        const int draw        = std::rand() % range_width;
        return static_cast<T>(draw + min_value);
    }
};
// ck::bhalf_t specialization: draws an integer in [min_value, max_value) with
// std::rand(), stores it as float and converts it with ck::type_convert.
// The indices are ignored; max_value must be greater than min_value.
template <>
struct GeneratorTensor_2<ck::bhalf_t>
{
    int min_value = 0;
    int max_value = 1;

    template <typename... Indices>
    ck::bhalf_t operator()(Indices...)
    {
        const int range_width = max_value - min_value;
        float as_float        = (std::rand() % range_width) + min_value;
        return ck::type_convert<ck::bhalf_t>(as_float);
    }
};
// int8_t specialization: draws an integer in [min_value, max_value) with
// std::rand() and narrows it to int8_t on return.
// The indices are ignored; max_value must be greater than min_value.
template <>
struct GeneratorTensor_2<int8_t>
{
    int min_value = 0;
    int max_value = 1;

    template <typename... Indices>
    int8_t operator()(Indices...)
    {
        const int range_width = max_value - min_value;
        const int draw        = (std::rand() % range_width) + min_value;
        return draw;
    }
};
// Generator of pseudo-random values between min_value and max_value: a
// std::rand() draw scaled to [0, 1] as float, then mapped linearly onto the
// range and cast to T. The indices are ignored.
template <typename T>
struct GeneratorTensor_3
{
    T min_value = 0;
    T max_value = 1;

    template <typename... Indices>
    T operator()(Indices...)
    {
        const float unit_sample =
            static_cast<float>(std::rand()) / static_cast<float>(RAND_MAX);
        const auto range_width = max_value - min_value;
        return static_cast<T>(min_value + unit_sample * range_width);
    }
};
// ck::bhalf_t specialization: the range is stored as float, the sample is
// computed in float and converted with ck::type_convert.
// The indices are ignored.
template <>
struct GeneratorTensor_3<ck::bhalf_t>
{
    float min_value = 0;
    float max_value = 1;

    template <typename... Indices>
    ck::bhalf_t operator()(Indices...)
    {
        const float unit_sample =
            static_cast<float>(std::rand()) / static_cast<float>(RAND_MAX);
        const float sample_f32 = min_value + unit_sample * (max_value - min_value);
        return ck::type_convert<ck::bhalf_t>(sample_f32);
    }
};
// int8_t specialization: the float bounds are first cast to int8_t, then an
// integer in [min, max) is drawn with std::rand().
// The indices are ignored; the cast bounds must differ (their difference is
// a modulus).
template <>
struct GeneratorTensor_3<int8_t>
{
    float min_value = 0;
    float max_value = 1;

    template <typename... Indices>
    int8_t operator()(Indices...)
    {
        const int8_t lower = static_cast<int8_t>(min_value);
        const int8_t upper = static_cast<int8_t>(max_value);
        const int draw     = std::rand() % (upper - lower);
        return draw + lower;
    }
};
struct GeneratorTensor_Checkboard
{
template <typename... Ts>
float operator()(Ts... Xs) const
{
std::array<ck::index_t, sizeof...(Ts)> dims = {static_cast<ck::index_t>(Xs)...};
return std::accumulate(dims.begin(),
dims.end(),
true,
[](bool init, ck::index_t x) -> int { return init != (x % 2); })
? 1
: -1;
}
};
// Generator returning the index along dimension `Dim`, converted to float.
// At least Dim + 1 indices must be passed.
template <ck::index_t Dim>
struct GeneratorTensor_Sequential
{
    template <typename... Ts>
    float operator()(Ts... Xs) const
    {
        const std::array<ck::index_t, sizeof...(Ts)> coords = {
            {static_cast<ck::index_t>(Xs)...}};
        const ck::index_t selected = coords[Dim];
        return selected;
    }
};
#endif

View File

@@ -0,0 +1,13 @@
#ifndef DEBUG_HPP
#define DEBUG_HPP
// Host-side debug knobs for driver_gemm_xdlops_v2r3. Both are `static`, so
// every translation unit that includes this header gets its own copy.
namespace debug::debug_driver_gemm_xdlops_v2r3 {
// these vars are on host, they control block_id to C matrix tile idx (m0, n0) mapping
static ck::index_t M01 = 1;
static ck::index_t N01 = 1;
} // namespace debug::debug_driver_gemm_xdlops_v2r3
#endif

View File

@@ -0,0 +1,220 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp"
// Host-side launcher for the v5r1 DLOPS forward convolution driver with the
// "_add" suffix (NC0HWC1 input, KC0YXC1 weight, NK0HWK1 output).
//
// Steps performed here:
//   1. Allocate device buffers for input, weight, bias and the add tensor and
//      upload the host tensors.
//   2. Select the compile-time tuning parameters (the active `#elif 1` branch).
//   3. Build packed tensor descriptors and instantiate the driver.
//   4. Run the driver 5 times with `nrepeat`, printing the returned time and a
//      derived throughput figure each time.
//   5. Re-upload the add tensor, run once more with a repeat count of 0 and
//      copy the add buffer back into `add_n_k0_hox2_wox2_k1_out`.
//
// Throws std::runtime_error when C1 is not a multiple of InWeiVectorSize (8).
template <typename TInWei,
typename TAcc,
typename TOut,
ck::ActivTypeEnum_t activ_type,
typename InLengths,
typename WeiLengths,
typename AddLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1(
const InLengths& in_n_c0_hi_wi_c1_lengths,
const WeiLengths& wei_k_c0_y_x_c1_lengths,
const AddLengths& add_n_k0_hox2_wox2_k1_lengths,
const OutLengths& out_n_k0_ho_wo_k1_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_c0_hi_wi_c1,
const Tensor<TInWei>& wei_k_c0_y_x_c1,
const Tensor<TOut>& bias_k0_k1,
const Tensor<TOut>& add_n_k0_hox2_wox2_k1,
Tensor<TOut>& add_n_k0_hox2_wox2_k1_out,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
// Problem sizes, read from the length containers by dimension position.
const auto N = out_n_k0_ho_wo_k1_lengths[I0];
const auto K0 = out_n_k0_ho_wo_k1_lengths[I1];
const auto Ho = out_n_k0_ho_wo_k1_lengths[I2];
const auto Wo = out_n_k0_ho_wo_k1_lengths[I3];
const auto K1 = out_n_k0_ho_wo_k1_lengths[I4];
const auto C0 = in_n_c0_hi_wi_c1_lengths[I1];
const auto Hi = in_n_c0_hi_wi_c1_lengths[I2];
const auto Wi = in_n_c0_hi_wi_c1_lengths[I3];
const auto C1 = in_n_c0_hi_wi_c1_lengths[I4];
const auto K = wei_k_c0_y_x_c1_lengths[I0];
const auto Y = wei_k_c0_y_x_c1_lengths[I2];
const auto X = wei_k_c0_y_x_c1_lengths[I3];
const auto Hox2 = add_n_k0_hox2_wox2_k1_lengths[I2];
const auto Wox2 = add_n_k0_hox2_wox2_k1_lengths[I3];
// Device buffers sized from each host tensor's element space.
DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) *
in_n_c0_hi_wi_c1.mDesc.GetElementSpace());
DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace());
DeviceMem bias_k0_k1_device_buf(sizeof(TOut) * bias_k0_k1.mDesc.GetElementSpace());
DeviceMem add_n_k0_hox2_wox2_k1_device_buf(sizeof(TOut) *
add_n_k0_hox2_wox2_k1.mDesc.GetElementSpace());
// Upload all host tensors.
in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
bias_k0_k1_device_buf.ToDevice(bias_k0_k1.mData.data());
add_n_k0_hox2_wox2_k1_device_buf.ToDevice(add_n_k0_hox2_wox2_k1.mData.data());
// Input and weight buffers are passed to the driver as vectors of
// InWeiVectorSize TInWei elements, so C1 must be a multiple of that size.
constexpr index_t InWeiVectorSize = 8;
if(C1 % InWeiVectorSize != 0)
{
throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize");
}
// Tuning parameters. The `#if 0` branch is disabled; the `#elif 1` branch is
// the configuration that is compiled.
#if 0
constexpr index_t BlockSize = 256;
constexpr index_t KPerBlock = 32;
constexpr index_t HoPerBlock = 8;
constexpr index_t WoPerBlock = 64;
constexpr index_t E1 = C0 * 9;
constexpr index_t E2 = 1;
constexpr index_t E1PerBlock = C0;
constexpr index_t KPerThread = 16;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = 1;
using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, E2>;
using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, KPerBlock, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;
constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;
constexpr index_t CThreadTransferDstScalarPerVector_K = K1;
#elif 1
constexpr auto BlockSize = 64;
constexpr auto KPerBlock = 8;
constexpr auto HoPerBlock = 8;
constexpr auto WoPerBlock = 32;
constexpr auto E1 = 2 * 9;
constexpr auto E2 = 1;
constexpr auto K2 = 2;
constexpr auto E1PerBlock = 2;
constexpr auto KPerThread = KPerBlock;
constexpr auto HoPerThread = 2;
constexpr auto WoPerThread = 2;
constexpr auto EPerThread = 1;
using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, 1, E2>;
using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 =
Sequence<1, E1PerBlock, 1, KPerBlock, 1>;
constexpr auto ABlockTransferSrcScalarPerVector_E2 = E2;
constexpr auto ABlockTransferDstScalarPerVector_E2 = E2;
constexpr auto BThreadTransferSrcScalarPerVector_E2 = E2;
constexpr auto CThreadTransferDstScalarPerVector_K = InWeiVectorSize;
#endif
// Packed descriptors. The input and weight descriptors use E2 (not C1) as the
// innermost length.
// NOTE(review): presumably because C1 is carried by the vector element type - confirm.
const auto in_n_c0_hi_wi_c1_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2));
const auto wei_k_c0_y_x_c1_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, E2));
const auto add_n_k0_hox2_wox2_k1_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hox2, Wox2, K1));
const auto out_n_k0_ho_wo_k1_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));
// Driver instantiated with the tuning parameters selected above.
constexpr auto conv_driver =
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_add<
BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type,
TAcc,
TOut,
E1,
E2,
K2,
KPerBlock,
HoPerBlock,
WoPerBlock,
E1PerBlock,
KPerThread,
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
ABlockTransferSrcScalarPerVector_E2,
ABlockTransferDstScalarPerVector_E2,
BThreadTransferSrcScalarPerVector_E2,
CThreadTransferDstScalarPerVector_K,
activ_type>{};
// Problem signature on stderr. The add-output extents are printed as Ho * 2
// and Wo * 2 rather than Hox2 / Wox2.
std::cerr << "conv_bias_activ_resize_add_input_"
<< "n" << N << "c" << C0 << "h" << Hi << "w" << Wi << "c" << C1 << "_filter_k" << K
<< "c" << C0 << "y" << Y << "x" << X << "c" << C1 << "_addout_n" << N << "k" << K0
<< "h" << Ho * 2 << "w" << Wo * 2 << "k" << K1 << std::endl;
// Five timing rounds; each passes `nrepeat` to the driver and prints the
// value it returns.
for(int i = 0; i < 5; i++)
{
const auto ave_time =
conv_driver.Run(wei_k_c0_y_x_c1_desc,
in_n_c0_hi_wi_c1_desc,
out_n_k0_ho_wo_k1_desc,
add_n_k0_hox2_wox2_k1_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(bias_k0_k1_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(add_n_k0_hox2_wox2_k1_device_buf.GetDeviceBuffer()),
nrepeat);
{
// 2 * N * K * Ho * Wo * C0 * C1 * Y * X operations, divided by 1e9 and by
// ave_time (printed as milliseconds), reported as TFlop/s.
float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C0 * C1 * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// Restore the original add tensor on the device before the final run, whose
// result is the one copied back to the host.
add_n_k0_hox2_wox2_k1_device_buf.ToDevice(add_n_k0_hox2_wox2_k1.mData.data());
conv_driver.Run(wei_k_c0_y_x_c1_desc,
in_n_c0_hi_wi_c1_desc,
out_n_k0_ho_wo_k1_desc,
add_n_k0_hox2_wox2_k1_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(bias_k0_k1_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(add_n_k0_hox2_wox2_k1_device_buf.GetDeviceBuffer()),
0);
// Copy the add buffer (written by the driver) back into the output tensor.
add_n_k0_hox2_wox2_k1_device_buf.FromDevice(add_n_k0_hox2_wox2_k1_out.mData.data());
}

View File

@@ -0,0 +1,309 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
#include "debug.hpp"
// Host-side launcher for backward-data convolution (NHWC input, KYXC weight,
// NHWK output) through driver_gemm_xdlops_v2r3.
//
// Steps performed here:
//   1. Allocate device buffers for input, weight and output and upload the
//      host tensors.
//   2. Build packed descriptors from the length containers and transform them
//      into GEMM descriptors with
//      transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.
//   3. Select the compile-time GEMM tuning parameters (first `#elif 1` branch).
//   4. Run the GEMM driver 5 times with `nrepeat`, printing the returned time
//      and a derived throughput figure each time.
//   5. Copy the input buffer (the GEMM's C operand) back into `in_n_hi_wi_c`.
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Tensor<TInWei>& in_n_hi_wi_c,
const Tensor<TInWei>& wei_k_y_x_c,
const Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
// Device buffers sized from each host tensor's element space, then uploaded.
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
// Packed descriptors built directly from the length containers.
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
// Tuning parameters. Only the first branch whose condition is 1 is compiled
// (the [256, 128, 4, 8] configuration); the later `#elif 1` branch is
// therefore never selected.
#if 0
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#endif
// Transform the convolution descriptors into the three GEMM descriptors:
// element 0 is the weight (GEMM A), element 1 the output (GEMM B) and
// element 2 the input (GEMM C).
const auto descs =
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(wei_k_y_x_c_desc,
out_n_ho_wo_k_desc,
in_n_hi_wi_c_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
I0,
I0,
Number<GemmK1>{});
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto in_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// Each tuple holds one group of Sequences for the "+" direction and one for
// the "-" direction, with one Sequence per GEMM dimension.
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmM
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmM
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1
constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1
// clang-format off
constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
// clang-format on
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
// Five timing rounds; each passes `nrepeat` to the GEMM driver and prints
// the value it returns.
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_gemm_xdlops_v2r3<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
decltype(in_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerXDL,
GemmNPerXDL,
GemmK1,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<2, 0, 1>,
Sequence<0, 2, 1>,
1,
GemmABlockTransferSrcScalarPerVector_GemmM,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmBBlockTransferSrcScalarPerVector_GemmK1,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
Sequence<1, 3, 7, 0, 2, 4, 5, 6>,
6,
GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false, // CAccessOrderMRepeatNRepeat
false, // ABlockLdsExtraM
false // BBlockLdsExtraN
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
wei_gemmk0_gemmm_gemmk1_grid_desc,
out_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc,
debug::debug_driver_gemm_xdlops_v2r3::M01,
debug::debug_driver_gemm_xdlops_v2r3::N01,
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
out_gemmk0_gemmn_gemmk1_grid_step_hacks,
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
{
// Problem sizes for the throughput figure, read by dimension position from
// the NHWK output lengths and the KYXC weight lengths.
const auto N = out_n_ho_wo_k_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];
// 2 * N * K * Ho * Wo * C * Y * X operations, divided by 1e9 and by ave_time
// (printed as milliseconds), reported as TFlop/s.
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// copy result back to host
in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data());
}

View File

@@ -0,0 +1,423 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Tensor<TInWei>& in_n_hi_wi_c,
const Tensor<TInWei>& wei_k_y_x_c,
const Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmm
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-:
// gemmk1
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmn
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
// clang-format off
constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
// clang-format on
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
for(index_t i = 0; i < 5; ++i)
{
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilda = ConvStrideH / GcdStrideDilationH;
const auto XTilda = ConvStrideW / GcdStrideDilationW;
float ave_time = 0;
for(index_t i_ytilda = 0; i_ytilda < YTilda; ++i_ytilda)
{
for(index_t i_xtilda = 0; i_xtilda < XTilda; ++i_xtilda)
{
const auto descs =
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
out_n_ho_wo_k_desc,
wei_k_y_x_c_desc,
in_n_hi_wi_c_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
i_ytilda,
i_xtilda,
Number<GemmK1>{});
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto in_gemmm_gemmn_grid_desc = descs[I2];
const auto GemmK0 = out_gemmk0_gemmm_gemmk1_grid_desc.GetLength(I0);
if(GemmK0 != 0)
{
ave_time += driver_gemm_xdlops_v2r3<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(in_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmK1,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmABlockTransferSrcScalarPerVector_GemmK1,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<2, 0, 1>,
Sequence<0, 2, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
#if 0
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
#else
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
#endif
7,
GemmCThreadTransferDstScalarPerVector,
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
true, // CAccessOrderMRepeatNRepeat
false, // ABlockLdsExtraM
false // BBlockLdsExtraN
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc,
debug::debug_driver_gemm_xdlops_v2r3::M01,
debug::debug_driver_gemm_xdlops_v2r3::N01,
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
}
}
}
{
const auto N = out_n_ho_wo_k_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// copy result back to host
in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data());
}

View File

@@ -0,0 +1,389 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
// Host-side driver that runs and times the 1x1 variant of the v4r1r2 XDLOPS
// implicit-GEMM backward-data convolution (names indicate NHWC input, KYXC
// weight, NHWK output).
//
// What the body does:
//  1. prints its own name, allocates three device buffers and uploads all three
//     host tensors (in_n_hi_wi_c is uploaded too, although it is the result);
//  2. builds packed descriptors from the length tuples and selects one
//     compile-time tuning configuration through the #if/#elif chain;
//  3. five times: builds the GEMM descriptors, calls driver_gemm_xdlops_v2r3 and
//     prints the returned time together with a derived throughput figure;
//  4. copies the device input buffer back into in_n_hi_wi_c.
//
// Parameters:
//  in_n_hi_wi_c_lengths, wei_k_y_x_c_lengths, out_n_ho_wo_k_lengths
//      length tuples used for the descriptors and for the throughput figure.
//  conv_strides
//      forwarded to the descriptor transform.
//  (unnamed) ConvDilations, InLeftPads, InRightPads
//      accepted but never read here; presumably kept so the signature lines up
//      with the general (non-1x1) driver -- confirm against the callers.
//  in_n_hi_wi_c
//      host tensor overwritten with the device result before returning.
//  wei_k_y_x_c, out_n_ho_wo_k
//      read-only host tensors.
//  nrepeat
//      forwarded unchanged to driver_gemm_xdlops_v2r3; presumably the number of
//      timed launches -- confirm against the driver.
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations&,
const InLeftPads&,
const InRightPads&,
Tensor<TInWei>& in_n_hi_wi_c,
const Tensor<TInWei>& wei_k_y_x_c,
const Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
// One device buffer per tensor, sized as element size times the element space
// reported by the host tensor's descriptor.
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
// Upload every host tensor, including the one that will be overwritten.
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
// Packed descriptors built straight from the caller-supplied length tuples.
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
// Compile-time tuning configuration. Exactly one branch of the chain below is
// compiled: the first whose condition is 1 (currently the [128, 256, 4, 8]
// fp16 configuration). To switch configurations, flip the 0/1 conditions.
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// For the output and weight descriptors every hack entry is zero in this 1x1
// variant; only the input (C matrix) table below carries non-zero entries.
// NOTE(review): each Sequence length presumably has to match the number of
// hidden dimensions of the corresponding transformed descriptor -- confirm
// against the descriptor transform before editing these tables.
constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: gemmk0
Sequence<0, 0, 0>{}, // 1+: gemmm
Sequence<0, 0, 0>{}), // 2+: gemmk1
make_tuple(Sequence<0, 0, 0>{}, // 0-: gemmk0
Sequence<0, 0, 0>{}, // 1-: gemmm
Sequence<0, 0, 0>{})); // 2-: gemmk1
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: gemmk0
Sequence<0, 0, 0>{}, // 1+: gemmn
Sequence<0, 0, 0>{}), // 2+: gemmk1
make_tuple(Sequence<0, 0, 0>{}, // 0-: Gemmk0
Sequence<0, 0, 0>{}, // 1-: Gemmn
Sequence<0, 0, 0>{})); // 2-: Gemmk1
// clang-format off
constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
// clang-format on
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
// Repeat the whole measurement five times; each pass rebuilds the GEMM
// descriptors, runs the driver and prints its own timing line.
for(index_t i = 0; i < 5; ++i)
{
const auto descs = transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1(
out_n_ho_wo_k_desc,
wei_k_y_x_c_desc,
in_n_hi_wi_c_desc,
conv_strides,
Number<GemmK1>{});
// descs holds, in order: the output-side, weight-side and input-side GEMM
// grid descriptors.
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto in_gemmm_gemmn_grid_desc = descs[I2];
// The template arguments below are positional; their order must match the
// declaration of driver_gemm_xdlops_v2r3 (not visible here).
float ave_time = driver_gemm_xdlops_v2r3<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(in_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmK1,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmABlockTransferSrcScalarPerVector_GemmK1,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<2, 0, 1>,
Sequence<0, 2, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
#if 0
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
#else
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
#endif
7,
GemmCThreadTransferDstScalarPerVector,
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
true, // CAccessOrderMRepeatNRepeat
false, // ABlockLdsExtraM
false // BBlockLdsExtraN
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc,
debug::debug_driver_gemm_xdlops_v2r3::M01,
debug::debug_driver_gemm_xdlops_v2r3::N01,
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
{
// Problem sizes read back from the length tuples: output is indexed as
// [N, Ho, Wo, K] and weight as [K, Y, X, C].
const auto N = out_n_ho_wo_k_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];
// 2 * N * K * Ho * Wo * C * Y * X operations, divided by 1e9 and by ave_time.
// With ave_time in milliseconds (as the printed label states) the quotient is
// in TFlop/s. The product is accumulated in std::size_t before the float cast.
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// copy result back to host
in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data());
}

View File

@@ -0,0 +1,256 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp"
#include "driver_gemm_xdlops_v2r4.hpp"
template <typename TIn,
typename TWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
typename GridSizeType>
void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TIn>& in_n_c_hi_wi,
Tensor<TWei>& wei_k_c_y_x,
const Tensor<TOut>& out_n_k_ho_wo,
GridSizeType desired_grid_size,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
#if 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 64, 1>;
// using vector load 4, so config's wo*ho must be a multiple of 4
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
const auto N = in_n_c_hi_wi_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_desc.GetLength(I1);
const auto Ho = out_n_k_ho_wo_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_desc.GetLength(I2);
const auto X = wei_k_c_y_x_desc.GetLength(I3);
const auto GemmM = K;
const auto GemmN = Y * X * C;
const auto GemmKTotal = N * Ho * Wo;
const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);
const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;
std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN
<< " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad
<< std::endl;
const auto descs =
transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{},
GemmKBatch,
GemmKPad);
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto wei_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmB
Sequence<0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmM
Sequence<0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1
make_tuple(Sequence<0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemB
Sequence<0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmM
Sequence<0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmB
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmB
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1
constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 1, 0, 0, 0, 0>{};
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{};
const auto driver_gemm_xdlops =
driver_gemm_xdlops_v2r4<BlockSize,
TIn,
TAcc,
TWei,
InMemoryDataOperationEnum_t::AtomicAdd,
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(wei_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmK1,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1,
Sequence<0, 2, 1, 3>,
Sequence<0, 2, 1, 3>,
3,
GemmABlockTransferSrcScalarPerVector_GemmK1,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1,
Sequence<0, 2, 1, 3>,
Sequence<0, 2, 1, 3>,
3,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
7,
GemmCThreadTransferDstScalarPerVector,
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false,
true,
true>;
for(index_t i = 0; i < 5; ++i)
{
float ave_time =
driver_gemm_xdlops(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
static_cast<TIn*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
out_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc,
debug::debug_driver_gemm_xdlops_v2r3::M01,
debug::debug_driver_gemm_xdlops_v2r3::N01,
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
float perf = static_cast<float>(calculate_convolution_flops(
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
driver_gemm_xdlops(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
static_cast<TIn*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
out_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc,
debug::debug_driver_gemm_xdlops_v2r3::M01,
debug::debug_driver_gemm_xdlops_v2r3::N01,
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
0);
// copy result back to host
wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data());
}

View File

@@ -0,0 +1,234 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
// Host-side benchmark driver: backward-weight convolution for NCHW input, KCYX weight and
// NKHW output, executed as an implicit GEMM through driver_gemm_xdlops_v2r3.
//
// What this function does, in order:
//   1. allocates device buffers for input, weight and output and uploads the host tensors;
//   2. builds packed tensor descriptors from the three length tuples;
//   3. transforms the convolution into three GEMM grid descriptors (out, in, wei);
//   4. launches the GEMM driver five times, printing average time and throughput each time;
//   5. downloads the weight buffer into wei_k_c_y_x.
//
// Parameters:
//   in_n_c_hi_wi_lengths / wei_k_c_y_x_lengths / out_n_k_ho_wo_lengths - tensor lengths used to
//       build the packed descriptors.
//   conv_strides, conv_dilations, in_left_pads, in_right_pads - forwarded unchanged to the
//       convolution-to-GEMM transform.
//   in_n_c_hi_wi, out_n_k_ho_wo - host tensors that are only read (uploaded to the device).
//   wei_k_c_y_x - host tensor that is uploaded first and overwritten with the device result.
//   nrepeat - forwarded to the GEMM driver (presumably the number of timed launches; confirm
//       against driver_gemm_xdlops_v2r3).
//
// NOTE(review): the driver is instantiated with TIn as its first element type while the first
// pointer argument is a TOut*; this presumably relies on TIn and TOut being the same type.
template <typename TIn,
typename TWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TIn>& in_n_c_hi_wi,
Tensor<TWei>& wei_k_c_y_x,
const Tensor<TOut>& out_n_k_ho_wo,
ck::index_t nrepeat)
{
using namespace ck;
// print the driver name so benchmark logs identify which variant ran
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
// device buffers are sized from each host tensor's element space
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
// upload all three host tensors (the weight upload provides the buffer's initial contents)
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
// packed descriptors built from the caller-supplied lengths
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
// Tuning configuration: exactly one of the preprocessor branches below is compiled in.
#if 0
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
// using vector load 4, so config's wo*ho must be a multiple of 4
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
// using vector load 4, so config's wo*ho must be a multiple of 4
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
// Convert the convolution problem into GEMM grid descriptors; the returned tuple is
// unpacked below as (out, in, wei).
const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{});
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto wei_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// Each table is a pair of tuples: the first holds the "+" entries and the second the "-"
// entries, one Sequence per GEMM dimension (see the per-line labels).
constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM
Sequence<0, 0, 1, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 2, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM
Sequence<0, 0, 2, 0, 0>{})); // 2-: GemmK1
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 1+: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 1-: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1
// all-zero entries: no step hack is requested for any dimension of the weight (C) descriptor
constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
// hacks applied when the slice window itself is moved (one Sequence per descriptor)
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 1, 0, 0>{};
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{};
// Run the GEMM five times; each pass prints its own timing line. The weight buffer is written
// with InMemoryDataOperationEnum_t::Set on every pass.
for(index_t i = 0; i < 5; ++i)
{
// NOTE: the template arguments below are positional; their order must match the
// driver_gemm_xdlops_v2r3 declaration exactly.
float ave_time = driver_gemm_xdlops_v2r3<
BlockSize,
TIn,
TAcc,
TWei,
InMemoryDataOperationEnum_t::Set,
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(wei_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmK1,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmABlockTransferSrcScalarPerVector_GemmK1,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
7,
GemmCThreadTransferDstScalarPerVector,
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false, // CAccessOrderMRepeatNRepeat
true, // ABlockLdsExtraM
true // BBlockLdsExtraN
>(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
static_cast<TIn*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
out_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc,
debug::debug_driver_gemm_xdlops_v2r3::M01,
debug::debug_driver_gemm_xdlops_v2r3::N01,
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
// flop count / 1e9 / time-in-ms gives the TFlop/s figure printed below
float perf = static_cast<float>(calculate_convolution_flops(
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data());
}

View File

@@ -0,0 +1,288 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r4.hpp"
// Host-side benchmark driver: backward-weight convolution for NHWC input, KYXC weight and
// NHWK output, executed as an implicit GEMM through driver_gemm_xdlops_v2r4 with
// InMemoryDataOperationEnum_t::AtomicAdd on the weight (C) tensor. The GEMM K dimension
// (N * Ho * Wo) is split into GemmKBatch batches.
//
// What this function does, in order:
//   1. allocates device buffers for input, weight and output and uploads the host tensors;
//   2. builds packed tensor descriptors from the three length tuples;
//   3. derives GemmKBatch from desired_grid_size and pads the K dimension accordingly;
//   4. transforms the convolution into three GEMM grid descriptors (in, out, wei);
//   5. launches the GEMM driver five times, printing average time and throughput each time;
//   6. re-uploads the host weight tensor, launches the driver once more with a repeat
//      argument of 0, and downloads the weight buffer into wei_k_y_x_c.
//
// Parameters:
//   in_n_hi_wi_c_lengths / wei_k_y_x_c_lengths / out_n_ho_wo_k_lengths - tensor lengths used to
//       build the packed descriptors.
//   conv_strides, conv_dilations, in_left_pads, in_right_pads - forwarded unchanged to the
//       convolution-to-GEMM transform.
//   in_n_hi_wi_c, out_n_ho_wo_k - host tensors that are only read (uploaded to the device).
//   wei_k_y_x_c - host tensor that is uploaded before the launches and overwritten with the
//       device result; because the kernel accumulates with AtomicAdd, its incoming contents
//       are the starting value of the accumulation.
//   desired_grid_size - target used to choose GemmKBatch (desired_grid_size / GridMN, min 1).
//   nrepeat - forwarded to the GEMM driver for the timed launches (presumably the number of
//       timed kernel runs; confirm against driver_gemm_xdlops_v2r4).
//
// NOTE(review): the driver is instantiated with TIn as its first element type while the second
// pointer argument is a TOut*; this presumably relies on TIn and TOut being the same type.
template <typename TIn,
          typename TWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          typename GridSizeType>
void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk(
    const InLengths& in_n_hi_wi_c_lengths,
    const WeiLengths& wei_k_y_x_c_lengths,
    const OutLengths& out_n_ho_wo_k_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TIn>& in_n_hi_wi_c,
    Tensor<TWei>& wei_k_y_x_c,
    const Tensor<TOut>& out_n_ho_wo_k,
    GridSizeType desired_grid_size,
    ck::index_t nrepeat)
{
    using namespace ck;

    // print the driver name so benchmark logs identify which variant ran
    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    // device buffers are sized from each host tensor's element space
    DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace());
    DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace());
    DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());

    in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
    out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());

    const auto in_n_hi_wi_c_desc  = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
    const auto wei_k_y_x_c_desc   = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
    const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);

    // Tuning configuration: exactly one of the preprocessor branches below is compiled in.
    // The slice/cluster Sequences have four entries although the alias names list only three
    // dimensions; the leading entry presumably corresponds to GemmKBatch.
#if 0
    // [M, N, K0, K1] = [128, 256, 4, 4] for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 256;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 4;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
    // [M, N, K0, K1] = [128, 128, 4, 4] for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif

    // problem sizes read back from the NHWC / KYXC / NHWK descriptors
    const auto N  = in_n_hi_wi_c_desc.GetLength(I0);
    const auto C  = in_n_hi_wi_c_desc.GetLength(I3);
    const auto K  = out_n_ho_wo_k_desc.GetLength(I3);
    const auto Ho = out_n_ho_wo_k_desc.GetLength(I1);
    const auto Wo = out_n_ho_wo_k_desc.GetLength(I2);
    const auto Y  = wei_k_y_x_c_desc.GetLength(I1);
    const auto X  = wei_k_y_x_c_desc.GetLength(I2);

    const auto GemmM      = Y * X * C;
    const auto GemmN      = K;
    const auto GemmKTotal = N * Ho * Wo;

    // number of M x N block tiles (truncating division)
    const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);

    // GridMN truncates to 0 when the problem is smaller than a single M x N block tile; dividing
    // desired_grid_size by it would then be an integer division by zero. Use a divisor of at
    // least 1 so the batch computation stays well defined (GridMN itself is still reported
    // unmodified in the diagnostic line below).
    const auto GridMNDivisor = GridMN > 0 ? GridMN : 1;

    // split K so that GridMN * GemmKBatch is roughly desired_grid_size, with at least one batch
    const index_t GemmKBatch = std::max(desired_grid_size / GridMNDivisor, 1);

    // K0 per batch, rounded up so that each batch is a whole number of GemmKPerBlock steps
    const index_t GemmK0 =
        math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock;

    // padded total K covered by all batches
    const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;

    std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN
              << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad
              << std::endl;

    // Convert the convolution problem into GEMM grid descriptors; the returned tuple is
    // unpacked below as (in, out, wei).
    const auto descs =
        transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad(
            in_n_hi_wi_c_desc,
            wei_k_y_x_c_desc,
            out_n_ho_wo_k_desc,
            conv_strides,
            conv_dilations,
            in_left_pads,
            in_right_pads,
            Number<GemmK1>{},
            GemmKBatch,
            GemmKPad);

    const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc  = descs[I0];
    const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
    const auto wei_gemmm_gemmn_grid_desc                    = descs[I2];

    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    // Each table is a pair of tuples: the first holds the "+" entries and the second the "-"
    // entries, one Sequence per GEMM dimension (see the per-line labels).
    constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},   // 0+: GemmKBatch
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},   // 1+: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},   // 2+: GemmM
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}),  // 3+: GemmK1
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},   // 0-: GemmKBatch
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},   // 1-: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},   // 2-: GemmM
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1

    constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0+: GemmKBatch
                              Sequence<0, 0, 0, 0, 0>{},   // 1+: GemmK0
                              Sequence<0, 0, 0, 0, 0>{},   // 2+: GemmN
                              Sequence<0, 0, 0, 0, 0>{}),  // 3+: GemmK1
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0-: GemmKBatch
                              Sequence<0, 0, 0, 0, 0>{},   // 1-: GemmK0
                              Sequence<0, 0, 0, 0, 0>{},   // 2-: GemmN
                              Sequence<0, 0, 0, 0, 0>{})); // 3-: GemmK1

    // all-zero entries: no step hack is requested for any dimension of the weight (C) descriptor
    constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    // hacks applied when the slice window itself is moved (one Sequence per descriptor)
    constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{};
    constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0>{};

    // Fully specified driver instantiation, reused for the timed launches and the final launch.
    // NOTE: the template arguments are positional; their order must match the
    // driver_gemm_xdlops_v2r4 declaration exactly.
    const auto driver_gemm_xdlops = driver_gemm_xdlops_v2r4<
        BlockSize,
        TIn,
        TAcc,
        TWei,
        InMemoryDataOperationEnum_t::AtomicAdd,
        decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc),
        decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc),
        decltype(wei_gemmm_gemmn_grid_desc),
        GemmMPerBlock,
        GemmNPerBlock,
        GemmKPerBlock,
        GemmMPerXDL,
        GemmNPerXDL,
        GemmK1,
        MRepeat,
        NRepeat,
        GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
        GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
        Sequence<0, 1, 2, 3>,
        Sequence<0, 1, 2, 3>,
        2,
        GemmABlockTransferSrcScalarPerVector_GemmM,
        GemmABlockTransferDstScalarPerVector_GemmK1,
        false, // don't move back src coordinate after threadwise copy
        GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
        GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
        Sequence<0, 1, 2, 3>,
        Sequence<0, 1, 2, 3>,
        2,
        GemmBBlockTransferSrcScalarPerVector_GemmN,
        GemmBBlockTransferDstScalarPerVector_GemmK1,
        false, // don't move back src coordinate after threadwise copy
        Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
        6,
        GemmCThreadTransferDstScalarPerVector,
        decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks),
        decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks),
        decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
        decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
        decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
        false, // CAccessOrderMRepeatNRepeat
        true,
        true>;

    // Timed launches. Each one adds into the weight buffer (AtomicAdd), so the buffer contents
    // after this loop are not the final result; see the re-upload below.
    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time =
            driver_gemm_xdlops(static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
                               static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
                               static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
                               in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                               out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                               wei_gemmm_gemmn_grid_desc,
                               debug::debug_driver_gemm_xdlops_v2r3::M01,
                               debug::debug_driver_gemm_xdlops_v2r3::N01,
                               in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
                               out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
                               wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                               in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
                               out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
                               nrepeat);

        {
            // 2 * N * K * Ho * Wo * C * Y * X flops; / 1e9 / time-in-ms gives TFlop/s
            float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
    }

    // Restore the weight buffer to the host tensor's contents, then launch once more with a
    // repeat argument of 0 so the downloaded result reflects a single accumulation pass
    // (presumably 0 makes the driver run the kernel once without timing; confirm in the driver).
    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());

    driver_gemm_xdlops(static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
                       static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
                       static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
                       in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                       out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                       wei_gemmm_gemmn_grid_desc,
                       debug::debug_driver_gemm_xdlops_v2r3::M01,
                       debug::debug_driver_gemm_xdlops_v2r3::N01,
                       in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
                       out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
                       wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                       in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
                       out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
                       0);

    // copy result back to host
    wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data());
}

View File

@@ -0,0 +1,276 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
#include "debug.hpp"
template <typename TIn,
typename TWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TIn>& in_n_hi_wi_c,
Tensor<TWei>& wei_k_y_x_c,
const Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
in_n_hi_wi_c_desc,
wei_k_y_x_c_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{});
const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto wei_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 1+: GemmM
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 1-: GemmM
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1
constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN
Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{};
constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_gemm_xdlops_v2r3<
BlockSize,
TIn,
TAcc,
TWei,
InMemoryDataOperationEnum_t::Set,
decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
decltype(wei_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerXDL,
GemmNPerXDL,
GemmK1,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<0, 2, 1>,
Sequence<0, 2, 1>,
1,
GemmABlockTransferSrcScalarPerVector_GemmM,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<0, 2, 1>,
Sequence<0, 2, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
7,
GemmCThreadTransferDstScalarPerVector,
decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false, // CAccessOrderMRepeatNRepeat
true,
true>(static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
in_gemmk0_gemmm_gemmk1_grid_desc,
out_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc,
debug::debug_driver_gemm_xdlops_v2r3::M01,
debug::debug_driver_gemm_xdlops_v2r3::N01,
in_gemmk0_gemmm_gemmk1_grid_step_hacks,
out_gemmk0_gemmn_gemmk1_grid_step_hacks,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
{
const auto N = out_n_ho_wo_k_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// copy result back to host
wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data());
}

View File

@@ -0,0 +1,456 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r4.hpp"
// Host-side driver that runs a backward-weight convolution (NHWC input, KYXC weight,
// NHWK output, per the parameter names) as an implicit GEMM through
// driver_gemm_xdlops_v2r4, with the weight written using
// InMemoryDataOperationEnum_t::AtomicAdd.
//
// What the function does, in order:
//   1. allocates DeviceMem buffers for in / wei / out and uploads the host tensors,
//   2. selects one compile-time tuning configuration through the #if/#elif chain,
//   3. derives GemmKBatch, GemmK0 and GemmKPad from desired_grid_size,
//   4. builds the GEMM grid descriptors with
//      transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad,
//   5. calls the driver 5 times with nrepeat, printing the returned time and the
//      derived TFlop/s each time,
//   6. re-uploads the host weight tensor, calls the driver once more with 0 as the
//      last argument, and copies the device weight buffer back into wei_k_y_x_c.
//
// Parameters:
//   in_n_hi_wi_c_lengths, wei_k_y_x_c_lengths, out_n_ho_wo_k_lengths
//       lengths used to build packed naive tensor descriptors
//   conv_strides, conv_dilations, in_left_pads, in_right_pads
//       forwarded unchanged to the descriptor transform
//   in_n_hi_wi_c, out_n_ho_wo_k
//       host tensors uploaded to the device; not modified here
//   wei_k_y_x_c
//       host tensor uploaded to the device before the runs and overwritten with the
//       device weight buffer at the end
//   desired_grid_size
//       divided by the number of M x N tiles to choose GemmKBatch
//   nrepeat
//       forwarded as the last argument of the timed driver calls
template <typename TIn,
typename TWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
typename GridSizeType>
void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TIn>& in_n_hi_wi_c,
Tensor<TWei>& wei_k_y_x_c,
const Tensor<TOut>& out_n_ho_wo_k,
GridSizeType desired_grid_size,
ck::index_t nrepeat)
{
using namespace ck;
// Print the function name so the log shows which driver variant ran.
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
// One device buffer per tensor, sized from the host tensor's element space.
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
// Upload all three host tensors, including the current contents of the weight tensor.
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
// Packed descriptors built from the caller-supplied lengths.
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
// Tuning configurations. Only the first branch whose condition is non-zero is compiled:
// that is the first "#elif 1" below ([128, 128, 4, 4]); every later "#elif 1" branch is
// skipped by the preprocessor. To try another configuration, set the earlier conditions
// to 0.
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4], C 128, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 4], C 128, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// ACTIVE configuration (first branch with a non-zero condition).
// [M, N, K0, K1] = [128, 128, 4, 4], C 64, for fp32 and fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// Inactive while an earlier branch is enabled.
// [M, N, K0, K1] = [256, 128, 4, 8], C 128, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// Inactive while an earlier branch is enabled.
// [M, N, K0, K1] = [128, 128, 4, 8], C 64, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 8], C 64, for fp16
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// Inactive while an earlier branch is enabled.
// [M, N, K0, K1] = [64, 128, 4, 8], C 64, for fp16
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 64;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 16, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// Inactive while an earlier branch is enabled.
// [M, N, K0, K1] = [64, 64, 4, 8], C 32, for fp16
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 64;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
// Convolution sizes read back from the packed descriptors (index meaning follows the
// n_hi_wi_c / n_ho_wo_k / k_y_x_c naming of the descriptors).
const auto N = in_n_hi_wi_c_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_desc.GetLength(I3);
const auto Ho = out_n_ho_wo_k_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_desc.GetLength(I1);
const auto X = wei_k_y_x_c_desc.GetLength(I2);
// GEMM problem sizes derived from the convolution sizes.
const auto GemmM = K;
const auto GemmN = Y * X * C;
const auto GemmKTotal = N * Ho * Wo;
// Number of GemmMPerBlock x GemmNPerBlock tiles in the M x N plane. Assuming integer
// lengths the division truncates, so this is 0 when GemmM * GemmN is smaller than one
// tile.
const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);
// Spread desired_grid_size over the M x N tiles, using at least one K batch.
// NOTE(review): this divides by GridMN; a GridMN of 0 would be a division by zero.
// Callers presumably only pass problems covering at least one full tile - confirm.
const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
// K0 extent per batch: GemmKTotal divided by (GemmK1 * GemmKPerBlock * GemmKBatch),
// rounded up, then scaled so that GemmK0 is a multiple of GemmKPerBlock.
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock;
// Total K extent after the round-up above; it is >= GemmKTotal.
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;
std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN
<< " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad
<< std::endl;
// Build the three GEMM grid descriptors from the convolution descriptors and the
// K-batching parameters computed above.
const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad(
in_n_hi_wi_c_desc,
wei_k_y_x_c_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{},
GemmKBatch,
GemmKPad);
// descs[I0..I2] are consumed below as the out, in and wei GEMM views respectively.
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto wei_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// Each *_step_hacks value is a pair of tuples: the first for "+" steps, the second for
// "-" steps, with one Sequence per dimension. The per-line dimension labels below follow
// the order spelled out in each variable's name.
constexpr auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmKBatch
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM
Sequence<0, 0, 0, 0, 0>{}), // 3+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmKBatch
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 2-: GemmM
Sequence<0, 0, 0, 0, 0>{})); // 3-: GemmK1
constexpr auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmKBatch
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmKBatch
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1
constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{};
constexpr auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{};
// Bind the fully specified driver instantiation once so the timing loop and the final
// run below call exactly the same function.
// NOTE(review): TIn is passed as the element-type template argument while the first
// pointer argument is a TOut*; presumably TIn and TOut are the same type - confirm.
const auto driver_gemm_xdlops = driver_gemm_xdlops_v2r4<
BlockSize,
TIn,
TAcc,
TWei,
InMemoryDataOperationEnum_t::AtomicAdd,
decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc),
decltype(wei_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerXDL,
GemmNPerXDL,
GemmK1,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<0, 1, 2, 3>,
Sequence<0, 1, 2, 3>,
2,
GemmABlockTransferSrcScalarPerVector_GemmM,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<0, 1, 2, 3>,
Sequence<0, 1, 3, 2>,
2,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
7,
GemmCThreadTransferDstScalarPerVector,
decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false, // CAccessOrderMRepeatNRepeat
true,
true>;
// timing: five driver calls, each reporting its own time and throughput
for(index_t i = 0; i < 5; ++i)
{
// NOTE(review): M01 / N01 come from the debug_driver_gemm_xdlops_v2r3 namespace although
// the driver used here is v2r4 - confirm this is intended.
float ave_time =
driver_gemm_xdlops(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc,
debug::debug_driver_gemm_xdlops_v2r3::M01,
debug::debug_driver_gemm_xdlops_v2r3::N01,
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
{
// 2 * N * K * Ho * Wo * C * Y * X operations, divided by 1e9 and by the time that is
// printed as milliseconds, which yields the value labelled TFlop/s.
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// verification
// Restore the device weight buffer from the host tensor before the final run.
// NOTE(review): presumably needed because the AtomicAdd timing runs above accumulated
// into the weight buffer - confirm. The final call passes 0 instead of nrepeat.
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
driver_gemm_xdlops(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc,
debug::debug_driver_gemm_xdlops_v2r3::M01,
debug::debug_driver_gemm_xdlops_v2r3::N01,
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
0);
// copy result back to host
wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data());
}

View File

@@ -0,0 +1,201 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "driver_gemm_dlops_v1r2.hpp"
// Host-side driver that runs a forward convolution (NCHW input, KCYX weight, NKHW
// output, per the parameter names) as an implicit GEMM through driver_gemm_dlops_v1r2,
// with the output written using InMemoryDataOperationEnum_t::Set.
//
// What the function does, in order:
//   1. allocates DeviceMem buffers for in / wei / out and uploads the host tensors,
//   2. defines the single compile-time tuning configuration,
//   3. builds the GEMM grid descriptors with
//      transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad,
//   4. calls the driver 5 times with nrepeat, printing the returned time and the
//      derived TFlop/s each time,
//   5. copies the device output buffer back into out_n_k_ho_wo.
//
// Parameters:
//   in_n_c_hi_wi_lengths, wei_k_c_y_x_lengths, out_n_k_ho_wo_lengths
//       lengths used to build packed naive tensor descriptors
//   conv_strides, conv_dilations, in_left_pads, in_right_pads
//       forwarded unchanged to the descriptor transform
//   in_n_c_hi_wi, wei_k_c_y_x
//       host tensors uploaded to the device; not modified here
//   out_n_k_ho_wo
//       host tensor uploaded to the device before the runs and overwritten with the
//       device output buffer at the end
//   nrepeat
//       forwarded as the last argument of every driver call
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_c_hi_wi,
const Tensor<TInWei>& wei_k_c_y_x,
Tensor<TOut>& out_n_k_ho_wo,
ck::index_t nrepeat)
{
using namespace ck;
// Print the function name so the log shows which driver variant ran.
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
// One device buffer per tensor, sized from the host tensor's element space.
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
// Upload all three host tensors, including the current contents of the output tensor.
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
// Packed descriptors built from the caller-supplied lengths.
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
// Tuning configuration; this is the only one defined for this driver.
#if 1
// cdata = 64, BlockSize = 256, 128x128x8
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlockM1 = 128;
constexpr index_t GemmNPerBlockN1 = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 1>;
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 128>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>;
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_N1 = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 1;
#endif
// Build the three GEMM grid descriptors. Note the argument order: the weight
// descriptor is passed first, then input, then output.
const auto descs =
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads);
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// Each *_step_hacks value is a pair of tuples with one Sequence per dimension.
// NOTE(review): by analogy with other drivers in this file the first tuple is presumably
// for "+" steps and the second for "-" steps - confirm.
// NOTE(review): the weight hack name ends in "gemmn1" while its move-slice-window
// counterpart below says "gemmm1" (presumably a typo), and each of its tuples holds four
// Sequences where the input tuples hold three - confirm against the driver.
constexpr auto wei_gemmk_gemmm0_gemmn1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));
constexpr auto in_gemmk_gemmn0_gemmn1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}));
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{};
constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
// descs[I0..I2] are consumed below as the weight, input and output GEMM views.
const auto wei_gemmk_gemmm_grid_desc = descs[I0];
const auto in_gemmk_gemmn_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2];
// Five driver calls, each reporting its own time and throughput.
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_gemm_dlops_v1r2<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk_gemmm_grid_desc),
decltype(in_gemmk_gemmn_grid_desc),
decltype(out_gemmm_gemmn_grid_desc),
GemmMPerBlockM1,
GemmNPerBlockN1,
GemmKPerBlock,
GemmM1PerThreadM111,
GemmN1PerThreadN111,
GemmKPerThread,
GemmM11N11ThreadClusterM1100,
GemmM11N11ThreadClusterN1100,
GemmM11N11ThreadClusterM1101,
GemmM11N11ThreadClusterN1101,
GemmABlockTransferThreadSliceLengths_K_M0_M1,
GemmABlockTransferThreadClusterLengths_K_M0_M1,
Sequence<2, 1, 0>, // ABlockTransferThreadClusterArrangeOrder
Sequence<2, 1, 0>, // ABlockTransferSrcAccessOrder
0, // ABlockTransferSrcVectorDim
GemmABlockTransferSrcScalarPerVector_K,
GemmABlockTransferDstScalarPerVector_M1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_K_N0_N1,
GemmBBlockTransferThreadClusterLengths_K_N0_N1,
Sequence<0, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
Sequence<0, 1, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
GemmBBlockTransferSrcScalarPerVector_N1,
GemmBBlockTransferDstScalarPerVector_N1,
false, // don't move back src coordinate after threadwise copy
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
5, // CThreadTransferSrcDstVectorDim
GemmCThreadTransferDstScalarPerVector_N11,
decltype(wei_gemmk_gemmm0_gemmn1_grid_step_hacks),
decltype(in_gemmk_gemmn0_gemmn1_grid_step_hacks),
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks),
decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks),
decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks)>(
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
wei_gemmk_gemmm_grid_desc,
in_gemmk_gemmn_grid_desc,
out_gemmm_gemmn_grid_desc,
wei_gemmk_gemmm0_gemmn1_grid_step_hacks,
in_gemmk_gemmn0_gemmn1_grid_step_hacks,
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks,
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks,
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks,
nrepeat);
// Operation count from calculate_convolution_flops, divided by 1e9 and by the time that
// is printed as milliseconds, which yields the value labelled TFlop/s.
float perf = static_cast<float>(calculate_convolution_flops(
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
}

View File

@@ -0,0 +1,273 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_dlops_v1r3.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_hi_wi_c,
const Tensor<TInWei>& wei_k_y_x_c,
Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 0
// [M, N, K0, K1] = [128, 128, 8, 1] for fp32
// cdata = 64, BlockSize = 256
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlockM1 = 128;
constexpr index_t GemmNPerBlockN1 = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmK1 = 1;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>;
using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>;
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>;
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>;
using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 1>;
using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>;
using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;
using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>;
using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 1>;
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
#elif 1
// [M, N, K0, K1] = [128, 128, 8, 2] for fp16
// cdata = 64, BlockSize = 256
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlockM1 = 128;
constexpr index_t GemmNPerBlockN1 = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmK1 = 2;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>;
using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>;
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>;
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>;
using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 2>;
using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>;
using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;
using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>;
using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 2>;
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
#elif 1
// [M, N, K0, K1] = [128, 128, 8, 4] for i8
// cdata = 64, BlockSize = 256
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlockM1 = 128;
constexpr index_t GemmNPerBlockN1 = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmK1 = 4;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>;
using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>;
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>;
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>;
using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 4>;
using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>;
using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;
using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>;
using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 4>;
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
#endif
const auto descs =
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(in_n_hi_wi_c_desc,
wei_k_y_x_c_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{});
const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GemmM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN1
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN1
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmM0
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM10
Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM11
Sequence<0, 0, 0, 0, 0>{}, // 3+: GemmN0
Sequence<0, 0, 0, 0, 0>{}, // 4+: GemmN10
Sequence<0, 0, 0, 0, 0>{}), // 5+: GemmN11
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmM0
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM10
Sequence<0, 0, 0, 0, 0>{}, // 2-: GemmM11
Sequence<0, 0, 0, 0, 0>{}, // 3-: GemmN0
Sequence<0, 0, 0, 0, 0>{}, // 4-: GemmN10
Sequence<0, 0, 0, 0, 0>{})); // 5-: GemmN11
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{};
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_gemm_dlops_v1r3<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc),
GemmMPerBlockM1,
GemmNPerBlockN1,
GemmKPerBlock,
GemmM1PerThreadM111,
GemmN1PerThreadN111,
GemmKPerThread,
GemmM11N11ThreadClusterM110Xs,
GemmM11N11ThreadClusterN110Xs,
GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1,
GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1,
Sequence<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder
Sequence<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder
GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
Sequence<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder
GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1,
GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1,
Sequence<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder
Sequence<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder
GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
Sequence<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder
GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
5, // CThreadTransferSrcDstVectorDim
GemmCThreadTransferDstScalarPerVector_N11,
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks),
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks),
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks),
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks),
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks)>(
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks,
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks,
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks,
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks,
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
{
const auto N = out_n_ho_wo_k_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];
float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// copy result back to host
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
}

View File

@@ -0,0 +1,228 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_c_hi_wi,
const Tensor<TInWei>& wei_k_c_y_x,
Tensor<TOut>& out_n_k_ho_wo,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
#if 0
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
const auto descs =
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{});
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM
Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1
constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{};
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_gemm_xdlops_v2r3<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmK1,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmABlockTransferSrcScalarPerVector_GemmK1,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<0, 2, 1>,
Sequence<1, 0, 2>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
7,
GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
wei_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
float perf = static_cast<float>(calculate_convolution_flops(
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
}

View File

@@ -0,0 +1,600 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
#if 0
// NOTE: disabled reference code, kept under `#if 0`.
// It uses names that are not declared anywhere in this file (AGridDesc_K0Raw_MRaw_K1,
// BGridDesc_K0Raw_NRaw_K1, CGridDesc_MRaw_NRaw, DoPad_K0, DoPad_M, DoPad_N, K0PerBlock, MPerBlock,
// NPerBlock, I0/I1/I2), so it presumably has to be placed inside a class/template that provides
// them before it can be enabled.
//
// Right-pads the raw A [K0, M, K1], B [K0, N, K1] and C [M, N] grid descriptors so that K0, M and
// N become multiples of K0PerBlock, MPerBlock and NPerBlock respectively. Each dimension is padded
// only when the matching DoPad_* flag is set; otherwise it is passed through unchanged.
//
// Returns a tuple (A, B, C) holding the resulting descriptors.
__host__ __device__ static constexpr auto
MakePaddedGridDescriptors(const AGridDesc_K0Raw_MRaw_K1& a_grid_desc_k0raw_mraw_k1,
                          const BGridDesc_K0Raw_NRaw_K1& b_grid_desc_k0raw_nraw_k1,
                          const CGridDesc_MRaw_NRaw& c_grid_desc_mraw_nraw)
{
    const auto K0Raw = a_grid_desc_k0raw_mraw_k1.GetLength(I0);
    const auto K1    = a_grid_desc_k0raw_mraw_k1.GetLength(I2);
    const auto MRaw  = c_grid_desc_mraw_nraw.GetLength(I0);
    const auto NRaw  = c_grid_desc_mraw_nraw.GetLength(I1);

    // amount of right padding needed to reach the next multiple of the per-block tile size
    const auto K0Pad = math::integer_least_multiple(K0Raw, K0PerBlock) - K0Raw;
    const auto MPad  = math::integer_least_multiple(MRaw, MPerBlock) - MRaw;
    const auto NPad  = math::integer_least_multiple(NRaw, NPerBlock) - NRaw;

    // A
    // Every branch transforms the *raw* input descriptor; the variable being initialized must not
    // be referenced inside its own initializer.
    const auto a_grid_desc_k0_m_k1 = [&]() {
        if constexpr(DoPad_K0 && DoPad_M)
        {
            return transform_tensor_descriptor(
                a_grid_desc_k0raw_mraw_k1,
                make_tuple(make_right_pad_transform(K0Raw, K0Pad),
                           make_right_pad_transform(MRaw, MPad),
                           make_pass_through_transform(K1)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
        }
        else if constexpr(DoPad_K0 && !DoPad_M)
        {
            return transform_tensor_descriptor(
                a_grid_desc_k0raw_mraw_k1,
                make_tuple(make_right_pad_transform(K0Raw, K0Pad),
                           make_pass_through_transform(MRaw),
                           make_pass_through_transform(K1)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
        }
        else if constexpr(!DoPad_K0 && DoPad_M)
        {
            return transform_tensor_descriptor(
                a_grid_desc_k0raw_mraw_k1,
                make_tuple(make_pass_through_transform(K0Raw),
                           make_right_pad_transform(MRaw, MPad),
                           make_pass_through_transform(K1)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
        }
        else
        {
            return a_grid_desc_k0raw_mraw_k1;
        }
    }();

    // B
    const auto b_grid_desc_k0_n_k1 = [&]() {
        if constexpr(DoPad_K0 && DoPad_N)
        {
            return transform_tensor_descriptor(
                b_grid_desc_k0raw_nraw_k1,
                make_tuple(make_right_pad_transform(K0Raw, K0Pad),
                           make_right_pad_transform(NRaw, NPad),
                           make_pass_through_transform(K1)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
        }
        else if constexpr(DoPad_K0 && !DoPad_N)
        {
            return transform_tensor_descriptor(
                b_grid_desc_k0raw_nraw_k1,
                make_tuple(make_right_pad_transform(K0Raw, K0Pad),
                           make_pass_through_transform(NRaw),
                           make_pass_through_transform(K1)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
        }
        else if constexpr(!DoPad_K0 && DoPad_N)
        {
            return transform_tensor_descriptor(
                b_grid_desc_k0raw_nraw_k1,
                make_tuple(make_pass_through_transform(K0Raw),
                           make_right_pad_transform(NRaw, NPad),
                           make_pass_through_transform(K1)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
        }
        else
        {
            return b_grid_desc_k0raw_nraw_k1;
        }
    }();

    // C
    const auto c_grid_desc_m_n = [&]() {
        if constexpr(DoPad_M && DoPad_N)
        {
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
                                               make_tuple(make_right_pad_transform(MRaw, MPad),
                                                          make_right_pad_transform(NRaw, NPad)),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(DoPad_M && !DoPad_N)
        {
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(!DoPad_M && DoPad_N)
        {
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            return c_grid_desc_mraw_nraw;
        }
    }();

    // hand the three (possibly padded) descriptors back to the caller
    return make_tuple(a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n);
}
#endif
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_hi_wi_c,
const Tensor<TInWei>& wei_k_y_x_c,
Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 4], C = 128, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [256, 256, 4, 8], C = 256, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 4;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
const auto descs =
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(in_n_hi_wi_c_desc,
wei_k_y_x_c_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{});
#if 0 // debug
const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
// HACK: hacks that control index calculation when iterating over A matrix
constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmM
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmM
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
#else
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = descs[I0];
const auto GemmK0 = in_gemmk0_gemmmraw_gemmk1_grid_desc.GetLength(I0);
const auto GemmMRaw = in_gemmk0_gemmmraw_gemmk1_grid_desc.GetLength(I1);
const auto GemmMPad = math::integer_least_multiple(GemmMRaw, GemmMPerBlock) - GemmMRaw;
const auto in_gemmk0_gemmm_gemmk1_grid_desc =
transform_tensor_descriptor(in_gemmk0_gemmmraw_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmMRaw, GemmMPad),
make_pass_through_transform(GemmK1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// HACK: hacks that control index calculation when iterating over A matrix
constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{};
#endif
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN
Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{};
#if 0
const auto out_gemmm_gemmn_grid_desc = descs[I2];
constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
#else
const auto out_gemmmraw_gemmn_grid_desc = descs[I2];
const auto GemmN = out_gemmmraw_gemmn_grid_desc.GetLength(I1);
const auto out_gemmm_gemmn_grid_desc =
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
#endif
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_gemm_xdlops_v2r3<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerXDL,
GemmNPerXDL,
GemmK1,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmABlockTransferSrcScalarPerVector_GemmK1,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmBBlockTransferSrcScalarPerVector_GemmK1,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
7,
GemmCThreadTransferDstScalarPerVector,
decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false, // CAccessOrderMRepeatNRepeat
true, // ABlockLdsExtraM
true // BBlockLdsExtraN
>(static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
debug::debug_driver_gemm_xdlops_v2r3::M01,
debug::debug_driver_gemm_xdlops_v2r3::N01,
in_gemmk0_gemmm_gemmk1_grid_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
{
const auto N = out_n_ho_wo_k_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// copy result back to host
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
}

View File

@@ -0,0 +1,196 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp"
/// Host-side launcher for the v5r1 DLOPS implicit-GEMM forward convolution with fused bias and
/// activation, on tensors laid out as N-C0-H-W-C1 (input), K-C0-Y-X-C1 (weight) and
/// N-K0-H-W-K1 (output).
///
/// The function allocates device buffers, uploads input / weight / bias, invokes the kernel
/// driver five times (printing the reported average time and the derived throughput after each
/// invocation) and finally copies the output buffer back into `out_n_k0_ho_wo_k1`.
///
/// @tparam TInWei      element type of the input and weight host tensors
/// @tparam TAcc        accumulator type forwarded to the kernel driver
/// @tparam TOut        element type of the bias and output host tensors
/// @tparam activ_type  activation selector forwarded to the kernel driver
///
/// @param in_n_c0_hi_wi_c1_lengths   input lengths; entries 1..4 are read as C0, Hi, Wi, C1
/// @param wei_k_c0_y_x_c1_lengths    weight lengths; entries 0, 2, 3 are read as K, Y, X
/// @param out_n_k0_ho_wo_k1_lengths  output lengths; entries 0..4 are read as N, K0, Ho, Wo, K1
/// @param conv_strides               forwarded unchanged to the kernel driver
/// @param conv_dilations             forwarded unchanged to the kernel driver
/// @param in_left_pads               forwarded unchanged to the kernel driver
/// @param in_right_pads              forwarded unchanged to the kernel driver
/// @param in_n_c0_hi_wi_c1           host input tensor (uploaded to device)
/// @param wei_k_c0_y_x_c1            host weight tensor (uploaded to device)
/// @param bias_k0_k1                 host bias tensor (uploaded to device)
/// @param out_n_k0_ho_wo_k1          host output tensor; overwritten with the device result
/// @param nrepeat                    repeat count forwarded to the kernel driver
///
/// @throws std::runtime_error if C1 or KPerThread is not a multiple of InWeiVectorSize
template <typename TInWei,
          typename TAcc,
          typename TOut,
          ck::ActivTypeEnum_t activ_type,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1(
    const InLengths& in_n_c0_hi_wi_c1_lengths,
    const WeiLengths& wei_k_c0_y_x_c1_lengths,
    const OutLengths& out_n_k0_ho_wo_k1_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TInWei>& in_n_c0_hi_wi_c1,
    const Tensor<TInWei>& wei_k_c0_y_x_c1,
    const Tensor<TOut>& bias_k0_k1,
    Tensor<TOut>& out_n_k0_ho_wo_k1,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};

    // Problem sizes, pulled out of the three length tuples.
    const auto N  = out_n_k0_ho_wo_k1_lengths[I0];
    const auto K0 = out_n_k0_ho_wo_k1_lengths[I1];
    const auto Ho = out_n_k0_ho_wo_k1_lengths[I2];
    const auto Wo = out_n_k0_ho_wo_k1_lengths[I3];
    const auto K1 = out_n_k0_ho_wo_k1_lengths[I4];

    const auto C0 = in_n_c0_hi_wi_c1_lengths[I1];
    const auto Hi = in_n_c0_hi_wi_c1_lengths[I2];
    const auto Wi = in_n_c0_hi_wi_c1_lengths[I3];
    const auto C1 = in_n_c0_hi_wi_c1_lengths[I4];

    const auto K = wei_k_c0_y_x_c1_lengths[I0];
    const auto Y = wei_k_c0_y_x_c1_lengths[I2];
    const auto X = wei_k_c0_y_x_c1_lengths[I3];

    // Device-side storage, sized from each host tensor's element space.
    DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) *
                                          in_n_c0_hi_wi_c1.mDesc.GetElementSpace());
    DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace());
    DeviceMem bias_k0_k1_device_buf(sizeof(TOut) * bias_k0_k1.mDesc.GetElementSpace());
    DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) *
                                           out_n_k0_ho_wo_k1.mDesc.GetElementSpace());

    in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
    wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
    bias_k0_k1_device_buf.ToDevice(bias_k0_k1.mData.data());

    // Input and weight buffers are handed to the driver as vectors of this many TInWei elements.
    constexpr index_t InWeiVectorSize = 8;

    if(C1 % InWeiVectorSize != 0)
    {
        throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize");
    }

// Tuning parameters. Only the "#elif 1" set is compiled; the "#if 0" set is kept for reference
// and does not define K2, which the driver instantiation below requires.
#if 0
    constexpr index_t BlockSize = 256;

    constexpr index_t KPerBlock  = 32;
    constexpr index_t HoPerBlock = 8;
    constexpr index_t WoPerBlock = 64;

    constexpr index_t E1         = C0 * 9;
    constexpr index_t E2         = 1;
    constexpr index_t E1PerBlock = C0;

    constexpr index_t KPerThread  = 16;
    constexpr index_t HoPerThread = 2;
    constexpr index_t WoPerThread = 2;
    constexpr index_t EPerThread  = 1;

    using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2   = Sequence<1, 9, 1, E2>;
    using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, KPerBlock, 1>;

    constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
    constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;

    constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;

    constexpr index_t CThreadTransferDstScalarPerVector_K = K1;
#elif 1
    constexpr index_t BlockSize = 64;

    constexpr index_t KPerBlock  = 8;
    constexpr index_t HoPerBlock = 8;
    constexpr index_t WoPerBlock = 32;

    constexpr index_t E1         = 2 * 9;
    constexpr index_t E2         = 1;
    constexpr index_t K2         = 2;
    constexpr index_t E1PerBlock = 2;

    constexpr index_t KPerThread  = KPerBlock;
    constexpr index_t HoPerThread = 2;
    constexpr index_t WoPerThread = 2;
    constexpr index_t EPerThread  = 1;

    using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, 1, E2>;
    using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 =
        Sequence<1, E1PerBlock, 1, KPerBlock, 1>;

    constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
    constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;

    constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;

    constexpr index_t CThreadTransferDstScalarPerVector_K = InWeiVectorSize;
#endif

    // The condition is on KPerThread, so the diagnostic must name KPerThread (the message
    // previously repeated the C1 text from the check above).
    if(KPerThread % InWeiVectorSize != 0)
    {
        throw std::runtime_error("wrong! KPerThread cannot be divided by InWeiVectorSize");
    }

    // Packed descriptors handed to the driver. The innermost input/weight length is E2, not C1:
    // the device pointers below are cast to vector_type<TInWei, InWeiVectorSize>.
    const auto in_n_c0_hi_wi_c1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2));
    const auto wei_k_c0_y_x_c1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, E2));
    const auto out_n_k0_ho_wo_k1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));

    constexpr auto conv_driver =
        DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_outpad<
            BlockSize,
            typename vector_type<TInWei, InWeiVectorSize>::type,
            TAcc,
            TOut,
            E1,
            E2,
            K2,
            KPerBlock,
            HoPerBlock,
            WoPerBlock,
            E1PerBlock,
            KPerThread,
            HoPerThread,
            WoPerThread,
            EPerThread,
            ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
            ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
            ABlockTransferSrcScalarPerVector_E2,
            ABlockTransferDstScalarPerVector_E2,
            BThreadTransferSrcScalarPerVector_E2,
            CThreadTransferDstScalarPerVector_K,
            activ_type>{};

    // One-line problem signature on stderr.
    std::cerr << "conv_bias_activ_input_"
              << "n" << N << "c" << C0 << "h" << Hi << "w" << Wi << "c" << C1 << "_filter_k" << K
              << "c" << C0 << "y" << Y << "x" << X << "c" << C1 << "_convout_n" << N << "k" << K0
              << "h" << Ho << "w" << Wo << "k" << K1 << std::endl;

    for(int i = 0; i < 5; i++)
    {
        // Note the argument order: weight first, then input.
        const auto ave_time =
            conv_driver.Run(wei_k_c0_y_x_c1_desc,
                            in_n_c0_hi_wi_c1_desc,
                            out_n_k0_ho_wo_k1_desc,
                            conv_strides,
                            conv_dilations,
                            in_left_pads,
                            in_right_pads,
                            static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                                wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
                            static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                                in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
                            static_cast<TOut*>(bias_k0_k1_device_buf.GetDeviceBuffer()),
                            static_cast<TOut*>(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()),
                            nrepeat);

        {
            // 2 * N * K * Ho * Wo * C0 * C1 * Y * X operations; dividing by 1e9 and by the
            // time in ms gives the TFlop/s figure that is printed.
            float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C0 * C1 * Y * X) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
    }

    // copy result back to host
    out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
}

View File

@@ -0,0 +1,241 @@
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
#include "driver_contraction_dlops_v1r2.hpp"
// Host-side launcher for the v6r1 DLOPS forward convolution on NCHW input / KCYX weight /
// NKHW output, implemented as a tensor contraction via driver_contraction_dlops_v1r2.
//
// Steps performed:
//   1. allocate device buffers and upload input, weight and the current output contents;
//   2. build packed descriptors from the three length tuples;
//   3. transform them into contraction grid descriptors (weight, input, output);
//   4. launch the contraction driver 5 times, printing average time and throughput each time;
//   5. copy the output buffer back into out_n_k_ho_wo.
//
// Template parameters:
//   TInWei - element type of the input and weight host tensors
//   TAcc   - accumulator type forwarded to the contraction driver
//   TOut   - element type of the output host tensor
//
// Parameters:
//   in_n_c_hi_wi_lengths / wei_k_c_y_x_lengths / out_n_k_ho_wo_lengths
//                        - length tuples used to build the packed descriptors
//   conv_strides, conv_dilations, in_left_pads, in_right_pads
//                        - forwarded unchanged to the descriptor transform
//   in_n_c_hi_wi, wei_k_c_y_x - host tensors uploaded to device
//   out_n_k_ho_wo        - host output tensor; uploaded before the run, overwritten afterwards
//   nrepeat              - repeat count forwarded to the contraction driver
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_c_hi_wi,
const Tensor<TInWei>& wei_k_c_y_x,
Tensor<TOut>& out_n_k_ho_wo,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
// Device-side storage, sized from each host tensor's element space.
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
// The output tensor's current host contents are uploaded as well.
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
const auto in_desc_n_c_hi_wi = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
const auto wei_desc_k_c_y_x = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
const auto out_desc_n_k_ho_wo = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
// Tuning parameters. Because the first branch is "#if 1", the "#elif 1" (fp16) branch below
// is never compiled; switch the first condition to 0 to select it.
#if 1
// [8, 1, 128, 1] * [8, 4, 32, 1] = [1, 128, 4, 32] for fp32
// cdata = 64, BlockSize = 256
constexpr index_t BlockSize = 256;
constexpr index_t GN0 = 4;
constexpr index_t GK1 = 1;
constexpr index_t GM1PerBlockGM11 = 128;
constexpr index_t GN1PerBlockGN11 = 32;
constexpr index_t GK0PerBlock = 8;
constexpr index_t BM1PerThreadBM11 = 4;
constexpr index_t BN1PerThreadBN11 = 4;
constexpr index_t BK0PerThread = 1;
using BM10BN10ThreadClusterBM10Xs = Sequence<8, 2>;
using BM10BN10ThreadClusterBN10Xs = Sequence<8, 2>;
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>;
using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 1>;
using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 1>;
using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>;
using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1;
#elif 1
// [8, 1, 128, 2] * [8, 4, 32, 2] = [1, 128, 4, 32] for fp16
// cdata = 64, BlockSize = 256
constexpr index_t BlockSize = 256;
constexpr index_t GN0 = 4;
constexpr index_t GK1 = 2;
constexpr index_t GM1PerBlockGM11 = 128;
constexpr index_t GN1PerBlockGN11 = 32;
constexpr index_t GK0PerBlock = 8;
constexpr index_t BM1PerThreadBM11 = 4;
constexpr index_t BN1PerThreadBN11 = 4;
constexpr index_t BK0PerThread = 1;
using BM10BN10ThreadClusterBM10Xs = Sequence<8, 2>;
using BM10BN10ThreadClusterBN10Xs = Sequence<8, 2>;
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 2>;
using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>;
using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 2>;
using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 2>;
using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>;
using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 2>;
constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1;
#endif
// Turn the three convolution descriptors into contraction grid descriptors. The weight
// descriptor is passed first; the result is indexed as [0] weight, [1] input, [2] output.
const auto descs =
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_desc_k_c_y_x,
in_desc_n_c_hi_wi,
out_desc_n_k_ho_wo,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GN0>{},
Number<GK1>{});
const auto wei_grid_desc_gk0_gm0_gm1_gk1 = descs[I0];
const auto in_grid_desc_gk0_gn0_gn1_gk1 = descs[I1];
const auto out_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// Each tuple holds one Sequence per visible dimension: the first inner tuple is for forward
// ("+") steps, the second for backward ("-") steps, as labelled on each line.
// NOTE(review): the Sequence lengths (7 / 17 / 20) presumably have to match the number of
// hidden dimensions of the corresponding transformed descriptor - confirm before editing.
constexpr auto wei_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11
Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
Sequence<0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1
constexpr auto in_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1
constexpr auto out_grid_step_hacks = make_tuple(
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 5-: GN1
constexpr auto wei_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0>{};
constexpr auto in_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>{};
// Launch the driver 5 times; each launch reports its own average time.
for(index_t i = 0; i < 5; ++i)
{
// The template arguments are positional. Weight is passed in the first (A) slot and input
// in the second (B) slot, for descriptors, hacks and device pointers alike.
float ave_time = driver_contraction_dlops_v1r2<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(wei_grid_desc_gk0_gm0_gm1_gk1),
decltype(in_grid_desc_gk0_gn0_gn1_gk1),
decltype(out_grid_desc_gm0_gm1_gn0_gn1),
GM1PerBlockGM11,
GN1PerBlockGN11,
GK0PerBlock,
BM1PerThreadBM11,
BN1PerThreadBN11,
BK0PerThread,
BM10BN10ThreadClusterBM10Xs,
BM10BN10ThreadClusterBN10Xs,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
Sequence<1, 2, 3, 0, 4>, // ABlockTransferThreadClusterArrangeOrder
Sequence<3, 2, 1, 0, 4>, // ABlockTransferSrcAccessOrder
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
Sequence<0, 1, 2, 3, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
Sequence<0, 4, 1, 2, 3>, // BBlockTransferThreadClusterArrangeOrder
Sequence<4, 3, 2, 0, 1>, // BBlockTransferSrcAccessOrder
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
Sequence<0, 1, 2, 3, 4>, // BBlockTransferSrcVectorTensorContiguousDimOrder
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
5, // CThreadTransferSrcDstVectorDim
CThreadTransferDstScalarPerVector_BN1,
decltype(wei_grid_step_hacks),
decltype(in_grid_step_hacks),
decltype(out_grid_step_hacks),
decltype(wei_grid_move_slice_window_step_hacks),
decltype(in_grid_move_slice_window_step_hacks)>(
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
wei_grid_desc_gk0_gm0_gm1_gk1,
in_grid_desc_gk0_gn0_gn1_gk1,
out_grid_desc_gm0_gm1_gn0_gn1,
wei_grid_step_hacks,
in_grid_step_hacks,
out_grid_step_hacks,
wei_grid_move_slice_window_step_hacks,
in_grid_move_slice_window_step_hacks,
nrepeat);
// Operation count divided by 1e9 and by the time in ms gives the TFlop/s figure printed below.
float perf = static_cast<float>(calculate_convolution_flops(
in_desc_n_c_hi_wi, wei_desc_k_c_y_x, out_desc_n_k_ho_wo)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
}

View File

@@ -0,0 +1,212 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp"
/// Host-side launcher for the v5r1 DLOPS implicit-GEMM forward convolution with fused bias,
/// activation and max-pooling, on tensors laid out as N-C0-H-W-C1 (input), K-C0-Y-X-C1
/// (weight) and N-K0-H-W-K1 (convolution output and max-pool output).
///
/// The function allocates device buffers, uploads input / weight / bias / max-pool tensors,
/// invokes the kernel driver five times (printing the reported average time and the derived
/// throughput after each invocation) and finally copies both the convolution output and the
/// max-pool output back to the host tensors.
///
/// @tparam TInWei      element type of the input and weight host tensors
/// @tparam TAcc        accumulator type forwarded to the kernel driver
/// @tparam TOut        element type of the bias, output and max-pool host tensors
/// @tparam activ_type  activation selector forwarded to the kernel driver
///
/// @param in_n_c0_hi_wi_c1_lengths   input lengths; entries 1..4 are read as C0, Hi, Wi, C1
/// @param wei_k_c0_y_x_c1_lengths    weight lengths; entries 0, 2, 3 are read as K, Y, X
/// @param max_n_k0_hx_wx_k1_lengths  max-pool lengths; entries 2, 3 are read as Hx, Wx
/// @param out_n_k0_ho_wo_k1_lengths  output lengths; entries 0..4 are read as N, K0, Ho, Wo, K1
/// @param conv_strides               forwarded unchanged to the kernel driver
/// @param conv_dilations             forwarded unchanged to the kernel driver
/// @param in_left_pads               forwarded unchanged to the kernel driver
/// @param in_right_pads              forwarded unchanged to the kernel driver
/// @param in_n_c0_hi_wi_c1           host input tensor (uploaded to device)
/// @param wei_k_c0_y_x_c1            host weight tensor (uploaded to device)
/// @param bias_k0_k1                 host bias tensor (uploaded to device)
/// @param out_n_k0_ho_wo_k1          host convolution output; overwritten with the device result
/// @param max_n_k0_hx_wx_k1          host max-pool output; uploaded before the run and
///                                   overwritten with the device result afterwards
/// @param nrepeat                    repeat count forwarded to the kernel driver
///
/// @throws std::runtime_error if C1 or KPerThread is not a multiple of InWeiVectorSize
template <typename TInWei,
          typename TAcc,
          typename TOut,
          ck::ActivTypeEnum_t activ_type,
          typename InLengths,
          typename WeiLengths,
          typename MaxLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1(
    const InLengths& in_n_c0_hi_wi_c1_lengths,
    const WeiLengths& wei_k_c0_y_x_c1_lengths,
    const MaxLengths& max_n_k0_hx_wx_k1_lengths,
    const OutLengths& out_n_k0_ho_wo_k1_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TInWei>& in_n_c0_hi_wi_c1,
    const Tensor<TInWei>& wei_k_c0_y_x_c1,
    const Tensor<TOut>& bias_k0_k1,
    Tensor<TOut>& out_n_k0_ho_wo_k1,
    Tensor<TOut>& max_n_k0_hx_wx_k1,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};

    // Problem sizes, pulled out of the four length tuples.
    const auto N  = out_n_k0_ho_wo_k1_lengths[I0];
    const auto K0 = out_n_k0_ho_wo_k1_lengths[I1];
    const auto Ho = out_n_k0_ho_wo_k1_lengths[I2];
    const auto Wo = out_n_k0_ho_wo_k1_lengths[I3];
    const auto K1 = out_n_k0_ho_wo_k1_lengths[I4];

    const auto C0 = in_n_c0_hi_wi_c1_lengths[I1];
    const auto Hi = in_n_c0_hi_wi_c1_lengths[I2];
    const auto Wi = in_n_c0_hi_wi_c1_lengths[I3];
    const auto C1 = in_n_c0_hi_wi_c1_lengths[I4];

    const auto K = wei_k_c0_y_x_c1_lengths[I0];
    const auto Y = wei_k_c0_y_x_c1_lengths[I2];
    const auto X = wei_k_c0_y_x_c1_lengths[I3];

    const auto Hx = max_n_k0_hx_wx_k1_lengths[I2];
    const auto Wx = max_n_k0_hx_wx_k1_lengths[I3];

    // Device-side storage, sized from each host tensor's element space.
    DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) *
                                          in_n_c0_hi_wi_c1.mDesc.GetElementSpace());
    DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace());
    DeviceMem bias_k0_k1_device_buf(sizeof(TOut) * bias_k0_k1.mDesc.GetElementSpace());
    DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) *
                                           out_n_k0_ho_wo_k1.mDesc.GetElementSpace());
    DeviceMem max_n_k0_hx_wx_k1_device_buf(sizeof(TOut) *
                                           max_n_k0_hx_wx_k1.mDesc.GetElementSpace());

    in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
    wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
    bias_k0_k1_device_buf.ToDevice(bias_k0_k1.mData.data());
    // The max-pool tensor's current host contents are uploaded as well.
    max_n_k0_hx_wx_k1_device_buf.ToDevice(max_n_k0_hx_wx_k1.mData.data());

    // Input and weight buffers are handed to the driver as vectors of this many TInWei elements.
    constexpr index_t InWeiVectorSize = 8;

    if(C1 % InWeiVectorSize != 0)
    {
        throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize");
    }

// Tuning parameters. Only the "#elif 1" set is compiled; the "#if 0" set is kept for reference
// and does not define K2, which the driver instantiation below requires.
#if 0
    constexpr index_t BlockSize = 256;

    constexpr index_t KPerBlock  = 32;
    constexpr index_t HoPerBlock = 8;
    constexpr index_t WoPerBlock = 64;

    constexpr index_t E1         = C0 * 9;
    constexpr index_t E2         = 1;
    constexpr index_t E1PerBlock = C0;

    constexpr index_t KPerThread  = 16;
    constexpr index_t HoPerThread = 2;
    constexpr index_t WoPerThread = 2;
    constexpr index_t EPerThread  = 1;

    using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2   = Sequence<1, 9, 1, E2>;
    using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, KPerBlock, 1>;

    constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
    constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;

    constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;

    constexpr index_t CThreadTransferDstScalarPerVector_K = K1;
#elif 1
    constexpr index_t BlockSize = 64;

    constexpr index_t KPerBlock  = 8;
    constexpr index_t HoPerBlock = 8;
    constexpr index_t WoPerBlock = 32;

    constexpr index_t E1         = 2 * 9;
    constexpr index_t E2         = 1;
    constexpr index_t K2         = 2;
    constexpr index_t E1PerBlock = 2;

    constexpr index_t KPerThread  = KPerBlock;
    constexpr index_t HoPerThread = 2;
    constexpr index_t WoPerThread = 2;
    constexpr index_t EPerThread  = 1;

    using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, 1, E2>;
    using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 =
        Sequence<1, E1PerBlock, 1, KPerBlock, 1>;

    constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
    constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;

    constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;

    constexpr index_t CThreadTransferDstScalarPerVector_K = InWeiVectorSize;
#endif

    // The condition is on KPerThread, so the diagnostic must name KPerThread (the message
    // previously repeated the C1 text from the check above).
    if(KPerThread % InWeiVectorSize != 0)
    {
        throw std::runtime_error("wrong! KPerThread cannot be divided by InWeiVectorSize");
    }

    // Packed descriptors handed to the driver. The innermost input/weight length is E2, not C1:
    // the device pointers below are cast to vector_type<TInWei, InWeiVectorSize>.
    const auto in_n_c0_hi_wi_c1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2));
    const auto wei_k_c0_y_x_c1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, E2));
    const auto max_n_k0_hx_wx_k1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hx, Wx, K1));
    const auto out_n_k0_ho_wo_k1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));

    constexpr auto conv_driver =
        DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_maxpool<
            BlockSize,
            typename vector_type<TInWei, InWeiVectorSize>::type,
            TAcc,
            TOut,
            E1,
            E2,
            K2,
            KPerBlock,
            HoPerBlock,
            WoPerBlock,
            E1PerBlock,
            KPerThread,
            HoPerThread,
            WoPerThread,
            EPerThread,
            ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
            ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
            ABlockTransferSrcScalarPerVector_E2,
            ABlockTransferDstScalarPerVector_E2,
            BThreadTransferSrcScalarPerVector_E2,
            CThreadTransferDstScalarPerVector_K,
            activ_type>{};

    // One-line problem signature on stderr.
    // NOTE(review): the max-pool extent is printed as Ho / 2 and Wo / 2 rather than the Hx / Wx
    // taken from max_n_k0_hx_wx_k1_lengths - confirm the two always agree.
    std::cerr << "conv_bias_activ_maxpool_input_"
              << "n" << N << "c" << C0 << "h" << Hi << "w" << Wi << "c" << C1 << "_filter_k" << K
              << "c" << C0 << "y" << Y << "x" << X << "c" << C1 << "_convout_n" << N << "k" << K0
              << "h" << Ho << "w" << Wo << "k" << K1 << "_maxpoolout_n" << N << "k" << K0 << "h"
              << Ho / 2 << "w" << Wo / 2 << "k" << K1 << std::endl;

    for(int i = 0; i < 5; i++)
    {
        // Note the argument order: weight first, then input.
        const auto ave_time =
            conv_driver.Run(wei_k_c0_y_x_c1_desc,
                            in_n_c0_hi_wi_c1_desc,
                            out_n_k0_ho_wo_k1_desc,
                            max_n_k0_hx_wx_k1_desc,
                            conv_strides,
                            conv_dilations,
                            in_left_pads,
                            in_right_pads,
                            static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                                wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
                            static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                                in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
                            static_cast<TOut*>(bias_k0_k1_device_buf.GetDeviceBuffer()),
                            static_cast<TOut*>(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()),
                            static_cast<TOut*>(max_n_k0_hx_wx_k1_device_buf.GetDeviceBuffer()),
                            nrepeat);

        {
            // 2 * N * K * Ho * Wo * C0 * C1 * Y * X operations; dividing by 1e9 and by the
            // time in ms gives the TFlop/s figure that is printed.
            float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C0 * C1 * Y * X) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
    }

    // copy both results back to host
    out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
    max_n_k0_hx_wx_k1_device_buf.FromDevice(max_n_k0_hx_wx_k1.mData.data());
}

View File

@@ -0,0 +1,463 @@
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template <typename ABType, typename AccType, typename CType>
void device_gemm_xdlops_km_kn_mn(const Tensor<ABType>& a_k_m,
const Tensor<ABType>& b_k_n,
Tensor<CType>& c_m_n,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());
a_k_m_device_buf.ToDevice(a_k_m.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
c_m_n_device_buf.ToDevice(c_m_n.mData.data());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 4], C = 128, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 256;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 64;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 1;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 64;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 256;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 128;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 64;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 1;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 64;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#endif
const auto K = a_k_m.mDesc.GetLengths()[0];
const auto M = a_k_m.mDesc.GetLengths()[1];
const auto N = b_k_n.mDesc.GetLengths()[1];
constexpr auto K1Number = Number<K1>{};
const auto K0 = K / K1Number;
const auto a_k0_m_k1_grid_desc =
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0],
a_k_m.mDesc.GetStrides()[1],
a_k_m.mDesc.GetStrides()[0]));
const auto b_k0_n_k1_grid_desc =
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0],
b_k_n.mDesc.GetStrides()[1],
b_k_n.mDesc.GetStrides()[0]));
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
Sequence<0>{}, // 1+: M
Sequence<0>{}), // 2+: K1
make_tuple(Sequence<0>{}, // 0-: K0
Sequence<0>{}, // 1-: M
Sequence<0>{})); // 2-: K1
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
Sequence<0>{}, // 1+: N
Sequence<0>{}), // 2+: K1
make_tuple(Sequence<0>{}, // 0-: K0
Sequence<0>{}, // 1-: N
Sequence<0>{})); // 2-: K1
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time =
driver_gemm_xdlops_v2r3<BlockSize,
ABType,
AccType,
CType,
InMemoryDataOperationEnum_t::Set,
decltype(a_k0_m_k1_grid_desc),
decltype(b_k0_n_k1_grid_desc),
decltype(c_m_n_grid_desc),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
Sequence<0, 2, 1>,
Sequence<0, 2, 1>,
1,
ABlockTransferSrcScalarPerVector_M,
ABlockTransferDstScalarPerVector_K1,
false, // don't move back src coordinate after threadwise copy
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
Sequence<0, 2, 1>,
Sequence<0, 2, 1>,
1,
BBlockTransferSrcScalarPerVector_N,
BBlockTransferDstScalarPerVector_K1,
false, // don't move back src coordinate after threadwise copy
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
7,
CThreadTransferDstScalarPerVector,
decltype(a_k0_m_k1_grid_step_hacks),
decltype(b_k0_n_k1_grid_step_hacks),
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
false, // CAccessOrderMRepeatNRepeat
true, // ABlockLdsExtraM
true // BBlockLdsExtraN
>(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m_n_grid_desc,
debug::debug_driver_gemm_xdlops_v2r3::M01,
debug::debug_driver_gemm_xdlops_v2r3::N01,
a_k0_m_k1_grid_step_hacks,
b_k0_n_k1_grid_step_hacks,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
a_k0_m_k1_grid_move_slice_window_step_hacks,
b_k0_n_k1_grid_move_slice_window_step_hacks,
nrepeat);
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
c_m_n_device_buf.FromDevice(c_m_n.mData.data());
}

View File

@@ -0,0 +1,263 @@
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
// Host-side driver for an XDLOPS GEMM with A given as [K, M], B as [K, N] and C as [N, M].
//
// Steps performed:
//   1. Create DeviceMem buffers sized from each tensor's element space and upload A, B and C.
//   2. Pick one compile-time tuning configuration from the #if/#elif chain below.
//   3. Build [K0, M, K1] / [K0, N, K1] / [M, N] descriptors from the host tensors' strides.
//   4. Call driver_gemm_xdlops_v2r3 five times, printing the returned time and the
//      throughput derived from it after every call.
//   5. Copy C from the device buffer back into c_n_m.
//
// Template parameters:
//   ABType  - element type of A and B (used to size and cast their device buffers).
//   AccType - forwarded to the driver; presumably the accumulation type - TODO confirm.
//   CType   - element type of C (used to size and cast its device buffer).
//
// Parameters:
//   a_k_m   - input A; lengths[0] is read as K and lengths[1] as M.
//   b_k_n   - input B; lengths[1] is read as N.
//   c_n_m   - output C; overwritten with the device result before returning.
//   nrepeat - forwarded unchanged to driver_gemm_xdlops_v2r3; presumably the number of
//             kernel launches timed per call - TODO confirm against the driver.
template <typename ABType, typename AccType, typename CType>
void device_gemm_xdlops_km_kn_nm(const Tensor<ABType>& a_k_m,
const Tensor<ABType>& b_k_n,
Tensor<CType>& c_n_m,
ck::index_t nrepeat)
{
using namespace ck;
// Log which layout variant is running.
std::cout << __func__ << std::endl;
// One device buffer per tensor, sized in bytes from the tensor's element space.
DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace());
// Upload all three tensors; note that the output tensor C is uploaded as well.
a_k_m_device_buf.ToDevice(a_k_m.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
c_n_m_device_buf.ToDevice(c_n_m.mData.data());
// Compile-time tuning configurations. Only the first branch of this chain whose
// condition is non-zero is compiled; flip the 0/1 literals to select another one.
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 256;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 1
// ACTIVE configuration (first branch with a true condition).
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 1
// NOTE(review): this branch is never compiled as written, because the preceding
// `#elif 1` already matches. Set that one to 0 to select this configuration.
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 128;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#endif
// Problem sizes are taken from the host tensors: A is [K, M], B is [K, N].
const auto K = a_k_m.mDesc.GetLengths()[0];
const auto M = a_k_m.mDesc.GetLengths()[1];
const auto N = b_k_n.mDesc.GetLengths()[1];
constexpr auto K1Number = Number<K1>{};
// K is split into K0 x K1. NOTE(review): nothing here checks that K is a multiple
// of K1; presumably the caller or the driver guarantees it - confirm.
const auto K0 = K / K1Number;
// A viewed as [K0, M, K1]: stepping K0 advances K1 rows of the original K dimension,
// stepping K1 advances one row, and M keeps its original stride.
const auto a_k0_m_k1_grid_desc =
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0],
a_k_m.mDesc.GetStrides()[1],
a_k_m.mDesc.GetStrides()[0]));
// B viewed as [K0, N, K1], built the same way from b_k_n's strides.
const auto b_k0_n_k1_grid_desc =
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0],
b_k_n.mDesc.GetStrides()[1],
b_k_n.mDesc.GetStrides()[0]));
// C is described as [M, N] although the host tensor is [N, M]: the two host strides
// are swapped (M uses stride[1], N uses stride[0]) instead of moving any data.
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// Each tuple holds one entry per descriptor dimension for the "+" direction and one
// for the "-" direction (see the trailing labels). All entries are zero in this variant.
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
Sequence<0>{}, // 1+: M
Sequence<0>{}), // 2+: K1
make_tuple(Sequence<0>{}, // 0-: K0
Sequence<0>{}, // 1-: M
Sequence<0>{})); // 2-: K1
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
Sequence<0>{}, // 1+: N
Sequence<0>{}), // 2+: K1
make_tuple(Sequence<0>{}, // 0-: K0
Sequence<0>{}, // 1-: N
Sequence<0>{})); // 2-: K1
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
// Run the whole measurement five times; every iteration prints its own result.
for(index_t i = 0; i < 5; ++i)
{
// NOTE(review): device_gemm_xdlops_km_kn_mn calls the same driver with two extra
// trailing template arguments (ABlockLdsExtraM, BBlockLdsExtraN) and with M01/N01
// runtime arguments after the C descriptor. This call passes neither - confirm it
// still matches the current driver_gemm_xdlops_v2r3 signature.
float ave_time =
driver_gemm_xdlops_v2r3<BlockSize,
ABType,
AccType,
CType,
InMemoryDataOperationEnum_t::Set,
decltype(a_k0_m_k1_grid_desc),
decltype(b_k0_n_k1_grid_desc),
decltype(c_m_n_grid_desc),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
Sequence<0, 2, 1>,
Sequence<0, 2, 1>,
1,
ABlockTransferSrcScalarPerVector_M,
ABlockTransferDstScalarPerVector_K1,
false, // don't move back src coordinate after threadwise copy
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
Sequence<0, 2, 1>,
Sequence<0, 2, 1>,
1,
BBlockTransferSrcScalarPerVector_N,
BBlockTransferDstScalarPerVector_K1,
false, // don't move back src coordinate after threadwise copy
// Presumably the C access order and vector dimension; index 6 is labelled M4 in
// the C step hacks above - TODO confirm against the driver's parameter list.
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
6,
CThreadTransferDstScalarPerVector,
decltype(a_k0_m_k1_grid_step_hacks),
decltype(b_k0_n_k1_grid_step_hacks),
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
false // CAccessOrderMRepeatNRepeat
>(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m_n_grid_desc,
a_k0_m_k1_grid_step_hacks,
b_k0_n_k1_grid_step_hacks,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
a_k0_m_k1_grid_move_slice_window_step_hacks,
b_k0_n_k1_grid_move_slice_window_step_hacks,
nrepeat);
// 2*M*N*K operations, divided by 1e9 and by the time that is printed in ms,
// gives the figure that is printed as TFlop/s.
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
c_n_m_device_buf.FromDevice(c_n_m.mData.data());
}

View File

@@ -0,0 +1,463 @@
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template <typename ABType, typename AccType, typename CType>
void device_gemm_xdlops_km_nk_mn(const Tensor<ABType>& a_k_m,
const Tensor<ABType>& b_n_k,
Tensor<CType>& c_m_n,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace());
DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());
a_k_m_device_buf.ToDevice(a_k_m.mData.data());
b_n_k_device_buf.ToDevice(b_n_k.mData.data());
c_m_n_device_buf.ToDevice(c_m_n.mData.data());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 256;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 64;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 64;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 256;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 128;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 64;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 64;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#endif
const auto K = a_k_m.mDesc.GetLengths()[0];
const auto M = a_k_m.mDesc.GetLengths()[1];
const auto N = b_n_k.mDesc.GetLengths()[0];
constexpr auto K1Number = Number<K1>{};
const auto K0 = K / K1Number;
const auto a_k0_m_k1_grid_desc =
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0],
a_k_m.mDesc.GetStrides()[1],
a_k_m.mDesc.GetStrides()[0]));
const auto b_k0_n_k1_grid_desc =
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
b_n_k.mDesc.GetStrides()[0],
b_n_k.mDesc.GetStrides()[1]));
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
Sequence<0>{}, // 1+: M
Sequence<0>{}), // 2+: K1
make_tuple(Sequence<0>{}, // 0-: K0
Sequence<0>{}, // 1-: M
Sequence<0>{})); // 2-: K1
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
Sequence<0>{}, // 1+: N
Sequence<0>{}), // 2+: K1
make_tuple(Sequence<0>{}, // 0-: K0
Sequence<0>{}, // 1-: N
Sequence<0>{})); // 2-: K1
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time =
driver_gemm_xdlops_v2r3<BlockSize,
ABType,
AccType,
CType,
InMemoryDataOperationEnum_t::Set,
decltype(a_k0_m_k1_grid_desc),
decltype(b_k0_n_k1_grid_desc),
decltype(c_m_n_grid_desc),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
Sequence<0, 2, 1>,
Sequence<0, 2, 1>,
1,
ABlockTransferSrcScalarPerVector_M,
ABlockTransferDstScalarPerVector_K1,
false, // don't move back src coordinate after threadwise copy
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
BBlockTransferSrcScalarPerVector_K1,
BBlockTransferDstScalarPerVector_K1,
false, // don't move back src coordinate after threadwise copy
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
7,
CThreadTransferDstScalarPerVector,
decltype(a_k0_m_k1_grid_step_hacks),
decltype(b_k0_n_k1_grid_step_hacks),
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
false, // CAccessOrderMRepeatNRepeat
true, // ABlockLdsExtraM
true // BBlockLdsExtraN
>(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m_n_grid_desc,
debug::debug_driver_gemm_xdlops_v2r3::M01,
debug::debug_driver_gemm_xdlops_v2r3::N01,
a_k0_m_k1_grid_step_hacks,
b_k0_n_k1_grid_step_hacks,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
a_k0_m_k1_grid_move_slice_window_step_hacks,
b_k0_n_k1_grid_move_slice_window_step_hacks,
nrepeat);
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
c_m_n_device_buf.FromDevice(c_m_n.mData.data());
}

View File

@@ -0,0 +1,263 @@
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template <typename ABType, typename AccType, typename CType>
void device_gemm_xdlops_km_nk_nm(const Tensor<ABType>& a_k_m,
const Tensor<ABType>& b_n_k,
Tensor<CType>& c_n_m,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace());
DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace());
a_k_m_device_buf.ToDevice(a_k_m.mData.data());
b_n_k_device_buf.ToDevice(b_n_k.mData.data());
c_n_m_device_buf.ToDevice(c_n_m.mData.data());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 256;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 128;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#endif
const auto K = a_k_m.mDesc.GetLengths()[0];
const auto M = a_k_m.mDesc.GetLengths()[1];
const auto N = b_n_k.mDesc.GetLengths()[0];
constexpr auto K1Number = Number<K1>{};
const auto K0 = K / K1Number;
const auto a_k0_m_k1_grid_desc =
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0],
a_k_m.mDesc.GetStrides()[1],
a_k_m.mDesc.GetStrides()[0]));
const auto b_k0_n_k1_grid_desc =
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
b_n_k.mDesc.GetStrides()[0],
b_n_k.mDesc.GetStrides()[1]));
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
Sequence<0>{}, // 1+: M
Sequence<0>{}), // 2+: K1
make_tuple(Sequence<0>{}, // 0-: K0
Sequence<0>{}, // 1-: M
Sequence<0>{})); // 2-: K1
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
Sequence<0>{}, // 1+: N
Sequence<0>{}), // 2+: K1
make_tuple(Sequence<0>{}, // 0-: K0
Sequence<0>{}, // 1-: N
Sequence<0>{})); // 2-: K1
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time =
driver_gemm_xdlops_v2r3<BlockSize,
ABType,
AccType,
CType,
InMemoryDataOperationEnum_t::Set,
decltype(a_k0_m_k1_grid_desc),
decltype(b_k0_n_k1_grid_desc),
decltype(c_m_n_grid_desc),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
Sequence<0, 2, 1>,
Sequence<0, 2, 1>,
1,
ABlockTransferSrcScalarPerVector_M,
ABlockTransferDstScalarPerVector_K1,
false, // don't move back src coordinate after threadwise copy
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
BBlockTransferSrcScalarPerVector_K1,
BBlockTransferDstScalarPerVector_K1,
false, // don't move back src coordinate after threadwise copy
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
6,
CThreadTransferDstScalarPerVector,
decltype(a_k0_m_k1_grid_step_hacks),
decltype(b_k0_n_k1_grid_step_hacks),
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
false // CAccessOrderMRepeatNRepeat
>(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m_n_grid_desc,
a_k0_m_k1_grid_step_hacks,
b_k0_n_k1_grid_step_hacks,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
a_k0_m_k1_grid_move_slice_window_step_hacks,
b_k0_n_k1_grid_move_slice_window_step_hacks,
nrepeat);
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
c_n_m_device_buf.FromDevice(c_n_m.mData.data());
}

View File

@@ -0,0 +1,463 @@
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template <typename ABType, typename AccType, typename CType>
void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
const Tensor<ABType>& b_k_n,
Tensor<CType>& c_m_n,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
c_m_n_device_buf.ToDevice(c_m_n.mData.data());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 256;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 64;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 1;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 64;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 256;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 128;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 64;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 1;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 64;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#endif
const auto K = a_m_k.mDesc.GetLengths()[1];
const auto M = a_m_k.mDesc.GetLengths()[0];
const auto N = b_k_n.mDesc.GetLengths()[1];
constexpr auto K1Number = Number<K1>{};
const auto K0 = K / K1Number;
const auto a_k0_m_k1_grid_desc =
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
a_m_k.mDesc.GetStrides()[0],
a_m_k.mDesc.GetStrides()[1]));
const auto b_k0_n_k1_grid_desc =
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0],
b_k_n.mDesc.GetStrides()[1],
b_k_n.mDesc.GetStrides()[0]));
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
Sequence<0>{}, // 1+: M
Sequence<0>{}), // 2+: K1
make_tuple(Sequence<0>{}, // 0-: K0
Sequence<0>{}, // 1-: M
Sequence<0>{})); // 2-: K1
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
Sequence<0>{}, // 1+: N
Sequence<0>{}), // 2+: K1
make_tuple(Sequence<0>{}, // 0-: K0
Sequence<0>{}, // 1-: N
Sequence<0>{})); // 2-: K1
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time =
driver_gemm_xdlops_v2r3<BlockSize,
ABType,
AccType,
CType,
InMemoryDataOperationEnum_t::Set,
decltype(a_k0_m_k1_grid_desc),
decltype(b_k0_n_k1_grid_desc),
decltype(c_m_n_grid_desc),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
ABlockTransferSrcScalarPerVector_K1,
ABlockTransferDstScalarPerVector_K1,
false, // don't move back src coordinate after threadwise copy
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
Sequence<0, 2, 1>,
Sequence<0, 2, 1>,
1,
BBlockTransferSrcScalarPerVector_N,
BBlockTransferDstScalarPerVector_K1,
false, // don't move back src coordinate after threadwise copy
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
7,
CThreadTransferDstScalarPerVector,
decltype(a_k0_m_k1_grid_step_hacks),
decltype(b_k0_n_k1_grid_step_hacks),
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
false, // CAccessOrderMRepeatNRepeat
true, // ABlockLdsExtraM
true // BBlockLdsExtraN
>(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m_n_grid_desc,
debug::debug_driver_gemm_xdlops_v2r3::M01,
debug::debug_driver_gemm_xdlops_v2r3::N01,
a_k0_m_k1_grid_step_hacks,
b_k0_n_k1_grid_step_hacks,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
a_k0_m_k1_grid_move_slice_window_step_hacks,
b_k0_n_k1_grid_move_slice_window_step_hacks,
nrepeat);
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
c_m_n_device_buf.FromDevice(c_m_n.mData.data());
}

View File

@@ -0,0 +1,291 @@
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template <typename ABType, typename AccType, typename CType>
void device_gemm_xdlops_mk_kn_nm(const Tensor<ABType>& a_m_k,
const Tensor<ABType>& b_k_n,
Tensor<CType>& c_n_m,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
c_n_m_device_buf.ToDevice(c_n_m.mData.data());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 256;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 256;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 128;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#endif
const auto K = a_m_k.mDesc.GetLengths()[1];
const auto M = a_m_k.mDesc.GetLengths()[0];
const auto N = b_k_n.mDesc.GetLengths()[1];
constexpr auto K1Number = Number<K1>{};
const auto K0 = K / K1Number;
const auto a_k0_m_k1_grid_desc =
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
a_m_k.mDesc.GetStrides()[0],
a_m_k.mDesc.GetStrides()[1]));
const auto b_k0_n_k1_grid_desc =
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0],
b_k_n.mDesc.GetStrides()[1],
b_k_n.mDesc.GetStrides()[0]));
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
Sequence<0>{}, // 1+: M
Sequence<0>{}), // 2+: K1
make_tuple(Sequence<0>{}, // 0-: K0
Sequence<0>{}, // 1-: M
Sequence<0>{})); // 2-: K1
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
Sequence<0>{}, // 1+: N
Sequence<0>{}), // 2+: K1
make_tuple(Sequence<0>{}, // 0-: K0
Sequence<0>{}, // 1-: N
Sequence<0>{})); // 2-: K1
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time =
driver_gemm_xdlops_v2r3<BlockSize,
ABType,
AccType,
CType,
InMemoryDataOperationEnum_t::Set,
decltype(a_k0_m_k1_grid_desc),
decltype(b_k0_n_k1_grid_desc),
decltype(c_m_n_grid_desc),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
ABlockTransferSrcScalarPerVector_K1,
ABlockTransferDstScalarPerVector_K1,
false, // don't move back src coordinate after threadwise copy
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
Sequence<0, 2, 1>,
Sequence<0, 2, 1>,
1,
BBlockTransferSrcScalarPerVector_N,
BBlockTransferDstScalarPerVector_K1,
false, // don't move back src coordinate after threadwise copy
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
6,
CThreadTransferDstScalarPerVector,
decltype(a_k0_m_k1_grid_step_hacks),
decltype(b_k0_n_k1_grid_step_hacks),
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
false // CAccessOrderMRepeatNRepeat
>(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m_n_grid_desc,
a_k0_m_k1_grid_step_hacks,
b_k0_n_k1_grid_step_hacks,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
a_k0_m_k1_grid_move_slice_window_step_hacks,
b_k0_n_k1_grid_move_slice_window_step_hacks,
nrepeat);
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
c_n_m_device_buf.FromDevice(c_n_m.mData.data());
}

View File

@@ -0,0 +1,564 @@
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
// Host-side launcher/benchmark for the XDLOPS GEMM driver with A indexed as [M, K], B indexed as
// [N, K] and C indexed as [M, N] (as read from the lengths/strides used below).
//
// Steps performed:
//   1. allocate device buffers of sizeof(T) * GetElementSpace() bytes and upload A, B and C,
//   2. select exactly one compile-time tuning configuration through the #if/#elif chain,
//   3. describe A as [K0, M, K1], B as [K0, N, K1] and C as [M, N] on top of the host strides,
//   4. call driver_gemm_xdlops_v2r3 five times, printing the returned time and a derived TFlop/s,
//   5. download the C device buffer into c_m_n.mData.
//
// Template parameters:
//   ABType  - element type of A and B (the device pointers are cast to ABType*)
//   AccType - forwarded to driver_gemm_xdlops_v2r3; presumably the accumulation type - confirm
//   CType   - element type of C (the device pointer is cast to CType*)
//
// Parameters:
//   a_m_k   - A tensor; lengths [0] and [1] are read as M and K
//   b_n_k   - B tensor; length [0] is read as N
//   c_m_n   - C tensor; uploaded before the runs and overwritten with the device result afterwards
//   nrepeat - forwarded unchanged to driver_gemm_xdlops_v2r3; presumably the number of timed
//             launches per call - confirm against the driver
//
// NOTE(review): K is taken from a_m_k only; b_n_k's second length is never compared against it
// inside this function.
template <typename ABType, typename AccType, typename CType>
void device_gemm_xdlops_mk_nk_mn(const Tensor<ABType>& a_m_k,
const Tensor<ABType>& b_n_k,
Tensor<CType>& c_m_n,
ck::index_t nrepeat)
{
using namespace ck;
// log which layout variant is running
std::cout << __func__ << std::endl;
// device allocations sized in bytes from each tensor's element space
DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());
// upload all three tensors (C's current host contents are uploaded as well)
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_n_k_device_buf.ToDevice(b_n_k.mData.data());
c_m_n_device_buf.ToDevice(c_m_n.mData.data());
// Tuning configurations. Only the first branch of this chain whose condition is non-zero is
// compiled; every branch defines the same set of names so the code after #endif is unchanged.
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 256;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 64;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 64;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
// ACTIVE configuration: this is the first branch of the chain whose condition is 1.
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 256;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 128;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 8], C = 64, for fp16
constexpr index_t BlockSize = 128;
constexpr index_t MPerBlock = 64;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
// NOTE: condition is 1, but this branch is never compiled while an earlier "#elif 1" in the
// same chain is enabled (the preprocessor takes the first matching branch only).
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 64;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
// NOTE: condition is 1, but this branch is never compiled while an earlier "#elif 1" in the
// same chain is enabled (the preprocessor takes the first matching branch only).
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 64;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#endif
// problem sizes read from the host tensor descriptors: A is [M, K], B is [N, K]
const auto K = a_m_k.mDesc.GetLengths()[1];
const auto M = a_m_k.mDesc.GetLengths()[0];
const auto N = b_n_k.mDesc.GetLengths()[0];
// K is split as K0 * K1, with K1 fixed at compile time by the selected configuration.
// NOTE(review): assumes K is a multiple of K1; nothing in this function checks it - confirm
// whether the driver validates this.
constexpr auto K1Number = Number<K1>{};
const auto K0 = K / K1Number;
#if 1
// non-padded GEMM
// A viewed as [K0, M, K1]: stepping K0 advances K1 elements along A's K stride
const auto a_k0_m_k1_grid_desc =
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
a_m_k.mDesc.GetStrides()[0],
a_m_k.mDesc.GetStrides()[1]));
// B viewed as [K0, N, K1], built the same way from B's [N, K] strides
const auto b_k0_n_k1_grid_desc =
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
b_n_k.mDesc.GetStrides()[0],
b_n_k.mDesc.GetStrides()[1]));
// C viewed as [M, N] using the host tensor's strides in their stored order
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// (all entries are zero in this branch)
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
Sequence<0>{}, // 1+: M
Sequence<0>{}), // 2+: K1
make_tuple(Sequence<0>{}, // 0-: K0
Sequence<0>{}, // 1-: M
Sequence<0>{})); // 2-: K1
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
Sequence<0>{}, // 1+: N
Sequence<0>{}), // 2+: K1
make_tuple(Sequence<0>{}, // 0-: K0
Sequence<0>{}, // 1-: N
Sequence<0>{})); // 2-: K1
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
#else
// padded GEMM
// (disabled alternative: right-pads M up to the next multiple of MPerBlock in A and C;
// B is described exactly as in the non-padded branch)
const auto a_k0_m_k1_grid_desc_tmp =
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
a_m_k.mDesc.GetStrides()[0],
a_m_k.mDesc.GetStrides()[1]));
// number of rows to add so that M + MRightPad is a multiple of MPerBlock
const auto MRightPad = math::integer_divide_ceil(M, MPerBlock) * MPerBlock - M;
const auto a_k0_m_k1_grid_desc =
transform_tensor_descriptor(a_k0_m_k1_grid_desc_tmp,
make_tuple(make_pass_through_transform(K0),
make_right_pad_transform(M, MRightPad),
make_pass_through_transform(K1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto b_k0_n_k1_grid_desc =
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
b_n_k.mDesc.GetStrides()[0],
b_n_k.mDesc.GetStrides()[1]));
const auto c_m_n_grid_desc_tmp = make_naive_tensor_descriptor(
make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
const auto c_m_n_grid_desc = transform_tensor_descriptor(
c_m_n_grid_desc_tmp,
make_tuple(make_right_pad_transform(M, MRightPad), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// (the A and C sequences are longer here than in the non-padded branch; B's are unchanged)
constexpr auto a_k0_m_k1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0>{}, // 0+: K0
Sequence<0, 0, 0, 0>{}, // 1+: M
Sequence<0, 0, 0, 0>{}), // 2+: K1
make_tuple(Sequence<0, 0, 0, 0>{}, // 0-: K0
Sequence<0, 0, 0, 0>{}, // 1-: M
Sequence<0, 0, 0, 0>{})); // 2-: K1
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
Sequence<0>{}, // 1+: N
Sequence<0>{}), // 2+: K1
make_tuple(Sequence<0>{}, // 0-: K0
Sequence<0>{}, // 1-: N
Sequence<0>{})); // 2-: K1
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0>{};
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
#endif
// Run the driver five times and report each run separately.
for(index_t i = 0; i < 5; ++i)
{
// The template arguments are positional: their order must match the parameter list of
// driver_gemm_xdlops_v2r3, which is not visible in this file (presumably declared in
// driver_gemm_xdlops_v2r3.hpp - confirm there before reordering or inserting anything).
float ave_time =
driver_gemm_xdlops_v2r3<BlockSize,
ABType,
AccType,
CType,
InMemoryDataOperationEnum_t::Set,
decltype(a_k0_m_k1_grid_desc),
decltype(b_k0_n_k1_grid_desc),
decltype(c_m_n_grid_desc),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
K1,
MRepeat,
NRepeat,
// A block-transfer settings
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
ABlockTransferSrcScalarPerVector_K1,
ABlockTransferDstScalarPerVector_K1,
false, // don't move back src coordinate after threadwise copy
// B block-transfer settings (same two Sequence<1, 0, 2> values and index 2 as for A)
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
BBlockTransferSrcScalarPerVector_K1,
BBlockTransferDstScalarPerVector_K1,
false, // don't move back src coordinate after threadwise copy
// C thread-transfer settings
// NOTE(review): the meaning of this sequence and of the 7 is defined by the driver's
// parameter list - confirm there
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
7,
CThreadTransferDstScalarPerVector,
decltype(a_k0_m_k1_grid_step_hacks),
decltype(b_k0_n_k1_grid_step_hacks),
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
false, // CAccessOrderMRepeatNRepeat
true, // ABlockLdsExtraM
true // BBlockLdsExtraN
>(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m_n_grid_desc,
// M01 / N01 are taken from the debug namespace, defined outside this function
debug::debug_driver_gemm_xdlops_v2r3::M01,
debug::debug_driver_gemm_xdlops_v2r3::N01,
a_k0_m_k1_grid_step_hacks,
b_k0_n_k1_grid_step_hacks,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
a_k0_m_k1_grid_move_slice_window_step_hacks,
b_k0_n_k1_grid_move_slice_window_step_hacks,
nrepeat);
// 2 * M * N * K operations per GEMM; dividing by 1e9 and by a time in milliseconds gives
// TFlop/s (assumes ave_time is in ms, as the printed label states).
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
c_m_n_device_buf.FromDevice(c_m_n.mData.data());
}

View File

@@ -0,0 +1,347 @@
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template <typename ABType, typename AccType, typename CType>
void device_gemm_xdlops_mk_nk_nm(const Tensor<ABType>& a_m_k,
const Tensor<ABType>& b_n_k,
Tensor<CType>& c_n_m,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_n_k_device_buf.ToDevice(b_n_k.mData.data());
c_n_m_device_buf.ToDevice(c_n_m.mData.data());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 256;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 0
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 256;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 128;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 64;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 8;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#endif
const auto K = a_m_k.mDesc.GetLengths()[1];
const auto M = a_m_k.mDesc.GetLengths()[0];
const auto N = b_n_k.mDesc.GetLengths()[0];
constexpr auto K1Number = Number<K1>{};
const auto K0 = K / K1Number;
const auto a_k0_m_k1_grid_desc =
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
a_m_k.mDesc.GetStrides()[0],
a_m_k.mDesc.GetStrides()[1]));
const auto b_k0_n_k1_grid_desc =
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
b_n_k.mDesc.GetStrides()[0],
b_n_k.mDesc.GetStrides()[1]));
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
Sequence<0>{}, // 1+: M
Sequence<0>{}), // 2+: K1
make_tuple(Sequence<0>{}, // 0-: K0
Sequence<0>{}, // 1-: M
Sequence<0>{})); // 2-: K1
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
Sequence<0>{}, // 1+: N
Sequence<0>{}), // 2+: K1
make_tuple(Sequence<0>{}, // 0-: K0
Sequence<0>{}, // 1-: N
Sequence<0>{})); // 2-: K1
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time =
driver_gemm_xdlops_v2r3<BlockSize,
ABType,
AccType,
CType,
InMemoryDataOperationEnum_t::Set,
decltype(a_k0_m_k1_grid_desc),
decltype(b_k0_n_k1_grid_desc),
decltype(c_m_n_grid_desc),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
ABlockTransferSrcScalarPerVector_K1,
ABlockTransferDstScalarPerVector_K1,
false, // don't move back src coordinate after threadwise copy
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
BBlockTransferSrcScalarPerVector_K1,
BBlockTransferDstScalarPerVector_K1,
false, // don't move back src coordinate after threadwise copy
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
6,
CThreadTransferDstScalarPerVector,
decltype(a_k0_m_k1_grid_step_hacks),
decltype(b_k0_n_k1_grid_step_hacks),
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
false // CAccessOrderMRepeatNRepeat
>(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m_n_grid_desc,
a_k0_m_k1_grid_step_hacks,
b_k0_n_k1_grid_step_hacks,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
a_k0_m_k1_grid_move_slice_window_step_hacks,
b_k0_n_k1_grid_move_slice_window_step_hacks,
nrepeat);
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
c_n_m_device_buf.FromDevice(c_n_m.mData.data());
}

View File

@@ -0,0 +1,286 @@
#ifndef DRIVER_CONTRACTION_DLOPS_V1R2_HPP
#define DRIVER_CONTRACTION_DLOPS_V1R2_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_contraction_dlops_v1r2.hpp"
// Host-side driver for the DLOPS tensor-contraction kernel (v1r2).
//
// Instantiates the gridwise contraction with the given tuning parameters,
// asks it to validate the A/B/C grid descriptors, derives the blocked
// descriptors and the block-id mapping that the kernel takes, prints the
// blocked descriptor lengths to std::cout, and finally launches the kernel
// instantiation that matches the K-loop shape via launch_and_time_kernel.
//
// The five "hack" arguments are unnamed: only their types are consumed.
//
// Returns the value reported by launch_and_time_kernel for the launched kernel.
// Throws std::runtime_error when GridwiseContraction::CheckValidity rejects
// the descriptor combination.
template <ck::index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
          typename AGridDesc_GK0_GM0_GM1_GK1,
          typename BGridDesc_GK0_GN0_GN1_GK1,
          typename CGridDesc_GM0_GM1_GN0_GN1,
          ck::index_t GM1PerBlockGM11,
          ck::index_t GN1PerBlockGN11,
          ck::index_t GK0PerBlock,
          ck::index_t BM1PerThreadBM11,
          ck::index_t BN1PerThreadBN11,
          ck::index_t BK0PerThread,
          typename BM10BN10ThreadClusterBM10Xs,
          typename BM10BN10ThreadClusterBN10Xs,
          typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
          typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          typename ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
          typename ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
          typename BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
          typename BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          typename BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
          typename BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
          typename BBlockTransferSrcVectorTensorContiguousDimOrder,
          typename CThreadTransferSrcDstAccessOrder,
          ck::index_t CThreadTransferSrcDstVectorDim,
          ck::index_t CThreadTransferDstScalarPerVector,
          typename AGridStepHacks,
          typename BGridStepHacks,
          typename CGridStepHacks,
          typename AGridMoveSliceWindowStepHacks,
          typename BGridMoveSliceWindowStepHacks>
__host__ float
driver_contraction_dlops_v1r2(const FloatAB* p_a_grid,
                              const FloatAB* p_b_grid,
                              FloatC* p_c_grid,
                              const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
                              const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
                              const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1,
                              AGridStepHacks,
                              BGridStepHacks,
                              CGridStepHacks,
                              AGridMoveSliceWindowStepHacks,
                              BGridMoveSliceWindowStepHacks,
                              ck::index_t nrepeat)
{
    using namespace ck;

    // Compile-time dimension indices for the GetLength() calls below.
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};

    // Gridwise contraction specialised with this driver's tuning parameters.
    using Contraction =
        GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
            BlockSize,
            FloatAB,
            FloatAcc,
            FloatC,
            CGlobalMemoryDataOperation,
            AGridDesc_GK0_GM0_GM1_GK1,
            BGridDesc_GK0_GN0_GN1_GK1,
            CGridDesc_GM0_GM1_GN0_GN1,
            GM1PerBlockGM11,
            GN1PerBlockGN11,
            GK0PerBlock,
            BM1PerThreadBM11,
            BN1PerThreadBN11,
            BK0PerThread,
            BM10BN10ThreadClusterBM10Xs,
            BM10BN10ThreadClusterBN10Xs,
            ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
            ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
            ABlockTransferThreadClusterArrangeOrder,
            ABlockTransferSrcAccessOrder,
            ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
            ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
            ABlockTransferSrcVectorTensorContiguousDimOrder,
            BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
            BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
            BBlockTransferThreadClusterArrangeOrder,
            BBlockTransferSrcAccessOrder,
            BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
            BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
            BBlockTransferSrcVectorTensorContiguousDimOrder,
            CThreadTransferSrcDstAccessOrder,
            CThreadTransferSrcDstVectorDim,
            CThreadTransferDstScalarPerVector,
            AGridStepHacks,
            BGridStepHacks,
            CGridStepHacks,
            AGridMoveSliceWindowStepHacks,
            BGridMoveSliceWindowStepHacks>;

    // Bail out early on descriptor combinations the gridwise contraction rejects.
    if(!Contraction::CheckValidity(
           a_grid_desc_gk0_gm0_gm1_gk1, b_grid_desc_gk0_gn0_gn1_gk1, c_grid_desc_gm0_gm1_gn0_gn1))
    {
        throw std::runtime_error("wrong! "
                                 "GridwiseContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_"
                                 "GM0_GM1_GN0_GN1 has invalid setting");
    }

    // Blocked views of A, B and C, plus the block-id -> (GM10, GN10) mapping.
    const auto a_blocked_desc =
        Contraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(a_grid_desc_gk0_gm0_gm1_gk1);
    const auto b_blocked_desc =
        Contraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(b_grid_desc_gk0_gn0_gn1_gk1);
    const auto c_blocked_desc =
        Contraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(c_grid_desc_gm0_gm1_gn0_gn1);
    const auto block_id_map =
        Contraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(c_grid_desc_gm0_gm1_gn0_gn1);

    // Types handed to the kernel template (reference-free).
    using AKernelDesc   = remove_reference_t<decltype(a_blocked_desc)>;
    using BKernelDesc   = remove_reference_t<decltype(b_blocked_desc)>;
    using CKernelDesc   = remove_reference_t<decltype(c_blocked_desc)>;
    using BlockIdMapper = remove_reference_t<decltype(block_id_map)>;

    const index_t grid_size = Contraction::CalculateGridSize(c_grid_desc_gm0_gm1_gn0_gn1);

    // The K-loop shape is derived from the GK0 length of A and selects the kernel variant.
    const auto gk0_length   = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
    const bool main_loop    = Contraction::CalculateHasMainKBlockLoop(gk0_length);
    const bool double_tail  = Contraction::CalculateHasDoubleTailKBlockLoop(gk0_length);

    // Diagnostic dump of the blocked descriptor lengths.
    std::cout << "a_grid_desc_gk0_gm0_gm10_gm11_gk1{" << a_blocked_desc.GetLength(I0) << ", "
              << a_blocked_desc.GetLength(I1) << ", " << a_blocked_desc.GetLength(I2) << ", "
              << a_blocked_desc.GetLength(I3) << ", " << a_blocked_desc.GetLength(I4) << "}"
              << std::endl;

    std::cout << "b_grid_desc_gk0_gn0_gn10_gn11_gk1{" << b_blocked_desc.GetLength(I0) << ", "
              << b_blocked_desc.GetLength(I1) << ", " << b_blocked_desc.GetLength(I2) << ", "
              << b_blocked_desc.GetLength(I3) << ", " << b_blocked_desc.GetLength(I4) << "}"
              << std::endl;

    std::cout << "c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1{ " << c_blocked_desc.GetLength(I0) << ", "
              << c_blocked_desc.GetLength(I1) << ", " << c_blocked_desc.GetLength(I2) << ", "
              << c_blocked_desc.GetLength(I3) << ", " << c_blocked_desc.GetLength(I4) << ", "
              << c_blocked_desc.GetLength(I5) << "}" << std::endl;

    // Every kernel variant is launched with the same configuration and arguments;
    // only the kernel instantiation itself differs.
    const auto run_timed = [&](const auto kernel) {
        return launch_and_time_kernel(kernel,
                                      nrepeat,
                                      dim3(grid_size),
                                      dim3(BlockSize),
                                      0,
                                      p_a_grid,
                                      p_b_grid,
                                      p_c_grid,
                                      a_blocked_desc,
                                      b_blocked_desc,
                                      c_blocked_desc,
                                      block_id_map);
    };

    if(main_loop)
    {
        if(double_tail)
        {
            const auto kernel = kernel_contraction_dlops_v1r2<Contraction,
                                                              FloatAB,
                                                              FloatC,
                                                              AKernelDesc,
                                                              BKernelDesc,
                                                              CKernelDesc,
                                                              BlockIdMapper,
                                                              true,
                                                              true>;
            return run_timed(kernel);
        }

        const auto kernel = kernel_contraction_dlops_v1r2<Contraction,
                                                          FloatAB,
                                                          FloatC,
                                                          AKernelDesc,
                                                          BKernelDesc,
                                                          CKernelDesc,
                                                          BlockIdMapper,
                                                          true,
                                                          false>;
        return run_timed(kernel);
    }

    if(double_tail)
    {
        const auto kernel = kernel_contraction_dlops_v1r2<Contraction,
                                                          FloatAB,
                                                          FloatC,
                                                          AKernelDesc,
                                                          BKernelDesc,
                                                          CKernelDesc,
                                                          BlockIdMapper,
                                                          false,
                                                          true>;
        return run_timed(kernel);
    }

    const auto kernel = kernel_contraction_dlops_v1r2<Contraction,
                                                      FloatAB,
                                                      FloatC,
                                                      AKernelDesc,
                                                      BKernelDesc,
                                                      CKernelDesc,
                                                      BlockIdMapper,
                                                      false,
                                                      false>;
    return run_timed(kernel);
}
#endif

View File

@@ -0,0 +1,429 @@
#ifndef DRIVER_CONVOLUTION_ADD_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP
#define DRIVER_CONVOLUTION_ADD_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_dlops_v3.hpp"
// Host-side driver for the v5r1 DLOPS forward-convolution implicit GEMM whose
// kernel (kernel_gemm_dlops_v3_resize_add) additionally takes a fourth "add"
// tensor descriptor with Hox2 x Wox2 spatial lengths.
//
// The template parameters are tuning knobs forwarded to
// GridwiseGemmDlops_km_kn_mn_v3. E1_, E2_ and K2_ are wrapped into Number<>
// constants before being forwarded; activ_type is forwarded to the kernel.
template <ck::index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          ck::index_t E1_,
          ck::index_t E2_,
          ck::index_t K2_,
          ck::index_t KPerBlock,
          ck::index_t HoPerBlock,
          ck::index_t WoPerBlock,
          ck::index_t E1PerBlock,
          ck::index_t KPerThread,
          ck::index_t HoPerThread,
          ck::index_t WoPerThread,
          ck::index_t EPerThread,
          typename ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
          typename ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
          ck::index_t ABlockTransferSrcScalarPerVector_E2,
          ck::index_t ABlockTransferDstScalarPerVector_E2,
          ck::index_t BThreadTransferSrcScalarPerVector_E2,
          ck::index_t CThreadTransferDstScalarPerVector_K,
          ck::ActivTypeEnum_t activ_type>
struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_add
{
    // Builds the A (weight), B (input), C (output) and D (add) grid descriptors,
    // validates the tile divisibility requirements, and launches the kernel
    // through launch_and_time_kernel.
    //
    // Only lengths are read from the incoming descriptors; every tensor is
    // re-described with make_naive_tensor_descriptor_packed.
    //   wei_k_c0_y_x_c1_global_desc       : K (dim 0), Y (dim 2), X (dim 3) are read
    //   in_n_c0_hi_wi_c1_global_desc      : N, C0, Hi, Wi (dims 0..3) are read
    //   out_n_k0_ho_wo_k1_global_desc     : K0, Ho, Wo, K1 (dims 1..4) are read
    //   add_n_k0_hox2_wox2_k1_global_desc : Hox2, Wox2 (dims 2, 3) are read
    //   conv_strides / conv_dilations     : element 0 is used for H, element 1 for W
    //   in_left_pads / in_right_pads      : element 0 is used for H, element 1 for W
    //   p_a_grid, p_b_grid, p_bias_grid, p_d_grid : forwarded to the kernel in this order
    //   nrepeat                           : forwarded to launch_and_time_kernel
    //
    // Returns the value reported by launch_and_time_kernel.
    // Throws std::runtime_error when a divisibility requirement is not met.
    template <typename... Wei,
              typename... In,
              typename... Add,
              typename... Out,
              typename ConvStrides,
              typename ConvDilations,
              typename InLeftPads,
              typename InRightPads>
    __host__ float Run(const ck::TensorDescriptor<Wei...>& wei_k_c0_y_x_c1_global_desc,
                       const ck::TensorDescriptor<In...>& in_n_c0_hi_wi_c1_global_desc,
                       const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
                       const ck::TensorDescriptor<Add...>& add_n_k0_hox2_wox2_k1_global_desc,
                       const ConvStrides& conv_strides,
                       const ConvDilations& conv_dilations,
                       const InLeftPads& in_left_pads,
                       const InRightPads& in_right_pads,
                       const FloatAB* __restrict__ p_a_grid,
                       const FloatAB* __restrict__ p_b_grid,
                       const FloatC* __restrict__ p_bias_grid,
                       FloatC* __restrict__ p_d_grid,
                       const int nrepeat) const
    {
        using namespace ck;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};
        constexpr auto I4 = Number<4>{};

        const auto N  = in_n_c0_hi_wi_c1_global_desc.GetLength(I0);
        const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1);
        const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2);
        const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3);
        // const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4);

        const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
        const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
        const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
        const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);

        const auto Hox2 = add_n_k0_hox2_wox2_k1_global_desc.GetLength(I2);
        const auto Wox2 = add_n_k0_hox2_wox2_k1_global_desc.GetLength(I3);

        const auto K = wei_k_c0_y_x_c1_global_desc.GetLength(I0);
        const auto Y = wei_k_c0_y_x_c1_global_desc.GetLength(I2);
        const auto X = wei_k_c0_y_x_c1_global_desc.GetLength(I3);

        const auto ConvStrideH = conv_strides[I0];
        const auto ConvStrideW = conv_strides[I1];

        const auto ConvDilationH = conv_dilations[I0];
        const auto ConvDilationW = conv_dilations[I1];

        // Round the output extent up to a whole number of blocks; the difference
        // becomes right padding on the output (and, doubled, on the add tensor).
        const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
        const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;

        const auto OutRightPadH = Hop - Ho;
        const auto OutRightPadW = Wop - Wo;

        const auto OutRightPadHx = OutRightPadH * 2;
        const auto OutRightPadWx = OutRightPadW * 2;

        const auto InLeftPadH = in_left_pads[I0];
        const auto InLeftPadW = in_left_pads[I1];

        // The input needs extra right padding so the padded output rows/columns
        // have something to read from.
        const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH;
        const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW;

        // Reduction length and its split into E0 x E1 (E2 is the innermost length).
        const auto E = C0 * Y * X;

        constexpr auto E1 = Number<E1_>{};
        constexpr auto E2 = Number<E2_>{};
        constexpr auto K2 = Number<K2_>{};

        // Integer division: E must be a multiple of E1 (validated below).
        const auto E0 = E / E1;

        // weight tensor
        const auto a_e_k_e2_grid_desc = transform_tensor_descriptor(
            make_naive_tensor_descriptor_packed(make_tuple(K, C0 * Y * X, E2)),
            make_tuple(make_pass_through_transform(K),
                       make_pass_through_transform(C0 * Y * X),
                       make_pass_through_transform(E2)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}));

        const auto a_e0_e1_k_e2_grid_desc =
            transform_tensor_descriptor(a_e_k_e2_grid_desc,
                                        make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
                                                   make_pass_through_transform(K),
                                                   make_pass_through_transform(E2)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                                        make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));

        // input tensor
        const auto in_n_c0_hip_wip_e2_global_desc = transform_tensor_descriptor(
            make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2)),
            make_tuple(make_pass_through_transform(N),
                       make_pass_through_transform(C0),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(E2)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        const auto in_n_c0_y_ho_x_wo_e2_global_desc = transform_tensor_descriptor(
            in_n_c0_hip_wip_e2_global_desc,
            make_tuple(
                make_pass_through_transform(N),
                make_pass_through_transform(C0),
                make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)),
                make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW)),
                make_pass_through_transform(E2)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6>{}));

        const auto in_e_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
            in_n_c0_y_ho_x_wo_e2_global_desc,
            make_tuple(make_merge_transform(make_tuple(C0, Y, X)),
                       make_pass_through_transform(N),
                       make_pass_through_transform(Hop),
                       make_pass_through_transform(Wop),
                       make_pass_through_transform(E2)),
            make_tuple(
                Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}, Sequence<6>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        const auto b_e0_e1_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
            in_e_n_ho_wo_e2_grid_desc,
            make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
                       make_pass_through_transform(N),
                       make_pass_through_transform(Hop),
                       make_pass_through_transform(Wop),
                       make_pass_through_transform(E2)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));

        // output tensor
        const auto c_k_n_hop_wop_grid_desc = transform_tensor_descriptor(
            make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
            make_tuple(make_merge_transform(make_tuple(K0, K1)),
                       make_pass_through_transform(N),
                       make_pad_transform(Ho, I0, OutRightPadH),
                       make_pad_transform(Wo, I0, OutRightPadW)),
            make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // add tensor
        const auto d_k_n_hopx2_wopx2_grid_desc = transform_tensor_descriptor(
            make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hox2, Wox2, K1)),
            make_tuple(make_merge_transform(make_tuple(K0, K1)),
                       make_pass_through_transform(N),
                       make_pad_transform(Hox2, I0, OutRightPadHx),
                       make_pad_transform(Wox2, I0, OutRightPadWx)),
            make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;

        if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 &&
             (E1 % E1PerBlock) == 0))
        {
            throw std::runtime_error("wrong! GEMM size no divisible");
        }

        // E0 = E / E1 truncates. If E is not a multiple of E1, the (E0, E1) unmerge
        // above covers only E0 * E1 < E elements and the tail of the reduction
        // dimension would be skipped without any diagnostic, so reject it here.
        if(!((E % E1) == 0))
        {
            throw std::runtime_error("wrong! E = C0 * Y * X is not divisible by E1");
        }

        // clang-format off

        // hack to control index calculation when iterating over a_e0_e1_k_e2_global tensor
        constexpr auto a_e0_e1_k_e2_global_step_hacks =
            make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
                       make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));

        constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack =
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};

        // hack to control index calculation when iterating over b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global tensor
        constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks =
            make_tuple(
                make_tuple(
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
                make_tuple(
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})
            );

        constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack =
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};

        // hack to control index calculation when iterating over c_k0_k1_n_h0_h1_h2_w0_w1_w2_global tensor
        constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks =
            make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
                       make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));

        // same hack layout, applied to the d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2 (add) tensor
        constexpr auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks =
            make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
                       make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));

        // clang-format on

        // GEMM
        using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3<
            BlockSize,
            FloatAB,
            FloatAcc,
            FloatC,
            InMemoryDataOperationEnum_t::Set,
            decltype(a_e0_e1_k_e2_grid_desc),
            decltype(b_e0_e1_n_ho_wo_e2_grid_desc),
            decltype(c_k_n_hop_wop_grid_desc),
            decltype(d_k_n_hopx2_wopx2_grid_desc),
            E1,
            E2,
            K2,
            KPerBlock,
            HoPerBlock,
            WoPerBlock,
            E1PerBlock,
            KPerThread,
            HoPerThread,
            WoPerThread,
            EPerThread,
            ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
            ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
            Sequence<2, 3, 0, 1, 4>,
            Sequence<0, 1, 2, 3, 4>,
            4,
            ABlockTransferSrcScalarPerVector_E2,
            ABlockTransferDstScalarPerVector_E2,
            false, // don't move back src coordinate after threadwise copy
            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, // E0, E1, N, H0, H1, H2, W0, W1, W2, E2
            9,
            BThreadTransferSrcScalarPerVector_E2,
            false, // don't move back src coordinate after threadwise copy, which will be fused with
                   // MoveSrcSliceWindow() to save addr computation
            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>, // K0, K1, N, H0, H1, H2, W0, W1, W2 (nine
                                                 // dimensions, matching the C descriptor name)
            1,
            CThreadTransferDstScalarPerVector_K,
            decltype(a_e0_e1_k_e2_global_step_hacks),
            decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks),
            decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
            decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks),
            decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
            decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack)>;

        // Descriptors in the blocked layouts the kernel takes.
        const auto a_e0_e1_k0_k1_e2_grid_desc =
            GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc);
        const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
            GridwiseGemm::MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(b_e0_e1_n_ho_wo_e2_grid_desc);
        const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
            GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc);
        const auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc =
            GridwiseGemm::MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd(
                d_k_n_hopx2_wopx2_grid_desc);

        using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc);
        using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
            decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc);
        using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
        using DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2 =
            decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc);

        // One workgroup per (K block, Ho block, Wo block, N) combination.
        const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;

        // Selects the kernel instantiation: the E0 loop exists only when E0 > 1.
        const bool has_main_e0_block_loop = E0 > 1;

        std::cerr << "has_main_e0_block_loop = " << has_main_e0_block_loop << std::endl;

        const auto cblockid_to_k_n_h_w_block_cluster_adaptor =
            GridwiseGemm::MakeCBlockIdToKNHoWoBlockClusterAdaptor(c_k_n_hop_wop_grid_desc);

        using CBlockIdToBlockClusterAdaptor_K_N_H_W =
            decltype(cblockid_to_k_n_h_w_block_cluster_adaptor);

        float ave_time = 0;

        if(has_main_e0_block_loop)
        {
            const auto kernel = kernel_gemm_dlops_v3_resize_add<
                GridwiseGemm,
                FloatAB,
                FloatC,
                remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
                remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
                remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
                remove_reference_t<DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2>,
                remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
                true,
                activ_type>;

            ave_time = launch_and_time_kernel(kernel,
                                              nrepeat,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              p_a_grid,
                                              p_b_grid,
                                              p_bias_grid,
                                              p_d_grid,
                                              a_e0_e1_k0_k1_e2_grid_desc,
                                              b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                                              c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
                                              d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
                                              cblockid_to_k_n_h_w_block_cluster_adaptor);
        }
        else
        {
            const auto kernel = kernel_gemm_dlops_v3_resize_add<
                GridwiseGemm,
                FloatAB,
                FloatC,
                remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
                remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
                remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
                remove_reference_t<DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2>,
                remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
                false,
                activ_type>;

            ave_time = launch_and_time_kernel(kernel,
                                              nrepeat,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              p_a_grid,
                                              p_b_grid,
                                              p_bias_grid,
                                              p_d_grid,
                                              a_e0_e1_k0_k1_e2_grid_desc,
                                              b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                                              c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
                                              d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
                                              cblockid_to_k_n_h_w_block_cluster_adaptor);
        }

        return ave_time;
    }
};
#endif

View File

@@ -0,0 +1,386 @@
#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP
#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_dlops_v3.hpp"
/// Host-side driver for the v5r1 DLOPS implicit-GEMM forward convolution
/// (N-C0-Hi-Wi-C1 input, K-C0-Y-X-C1 weight, N-K0-Ho-Wo-K1 output) in which the output
/// height/width are padded up to whole HoPerBlock/WoPerBlock tiles.
///
/// All tuning parameters are forwarded unchanged to GridwiseGemmDlops_km_kn_mn_v3.
/// E1_, E2_ and K2_ are wrapped into Number<> constants; E2_ is additionally used as the
/// innermost length of the weight and input tensors when their descriptors are rebuilt.
template <ck::index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          ck::index_t E1_,
          ck::index_t E2_,
          ck::index_t K2_,
          ck::index_t KPerBlock,
          ck::index_t HoPerBlock,
          ck::index_t WoPerBlock,
          ck::index_t E1PerBlock,
          ck::index_t KPerThread,
          ck::index_t HoPerThread,
          ck::index_t WoPerThread,
          ck::index_t EPerThread,
          typename ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
          typename ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
          ck::index_t ABlockTransferSrcScalarPerVector_E2,
          ck::index_t ABlockTransferDstScalarPerVector_E2,
          ck::index_t BThreadTransferSrcScalarPerVector_E2,
          ck::index_t CThreadTransferDstScalarPerVector_K,
          ck::ActivTypeEnum_t activ_type>
struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_outpad
{
    /// Builds the A/B/C grid descriptors, validates the tiling, launches
    /// kernel_gemm_dlops_v3 and returns whatever launch_and_time_kernel reports.
    ///
    /// Only the *lengths* of the three descriptor arguments are read; each tensor is
    /// re-described as a packed layout, so non-packed strides in the arguments are ignored.
    ///
    /// @param wei_k_c0_y_x_c1_global_desc   weight descriptor; lengths 0, 2, 3 give K, Y, X
    /// @param in_n_c0_hi_wi_c1_global_desc  input descriptor; lengths 0..3 give N, C0, Hi, Wi
    /// @param out_n_k0_ho_wo_k1_global_desc output descriptor; lengths 1..4 give K0, Ho, Wo, K1
    /// @param conv_strides    (H, W) convolution strides
    /// @param conv_dilations  (H, W) convolution dilations
    /// @param in_left_pads    (H, W) left padding of the input
    /// @param in_right_pads   (H, W) right padding of the input
    /// @param p_a_grid     A-operand buffer (presumably the weights - confirm with caller)
    /// @param p_b_grid     B-operand buffer (presumably the input - confirm with caller)
    /// @param p_bias_grid  bias buffer, forwarded to the kernel untouched
    /// @param p_c_grid     C-operand (output) buffer
    /// @param nrepeat      forwarded to launch_and_time_kernel (presumably number of timed runs)
    /// @return the value returned by launch_and_time_kernel (presumably average kernel time)
    /// @throws std::runtime_error if K, Hop, Wop, E1 or E is not a multiple of its tile size
    template <typename... Wei,
              typename... In,
              typename... Out,
              typename ConvStrides,
              typename ConvDilations,
              typename InLeftPads,
              typename InRightPads>
    __host__ float Run(const ck::TensorDescriptor<Wei...>& wei_k_c0_y_x_c1_global_desc,
                       const ck::TensorDescriptor<In...>& in_n_c0_hi_wi_c1_global_desc,
                       const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
                       const ConvStrides& conv_strides,
                       const ConvDilations& conv_dilations,
                       const InLeftPads& in_left_pads,
                       const InRightPads& in_right_pads,
                       const FloatAB* __restrict__ p_a_grid,
                       const FloatAB* __restrict__ p_b_grid,
                       const FloatC* __restrict__ p_bias_grid,
                       FloatC* __restrict__ p_c_grid,
                       const int nrepeat) const
    {
        using namespace ck;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};
        constexpr auto I4 = Number<4>{};

        const auto N  = in_n_c0_hi_wi_c1_global_desc.GetLength(I0);
        const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1);
        const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2);
        const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3);
        // The innermost input length (C1) is not read from the descriptor; E2 is used instead.

        const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
        const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
        const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
        const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);

        const auto K = wei_k_c0_y_x_c1_global_desc.GetLength(I0);
        const auto Y = wei_k_c0_y_x_c1_global_desc.GetLength(I2);
        const auto X = wei_k_c0_y_x_c1_global_desc.GetLength(I3);

        const auto ConvStrideH = conv_strides[I0];
        const auto ConvStrideW = conv_strides[I1];

        const auto ConvDilationH = conv_dilations[I0];
        const auto ConvDilationW = conv_dilations[I1];

        // Round the output extent up to a whole number of blocks.
#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR
        const auto Hop = Number<(Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock>{};
        const auto Wop = Number<(Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock>{};
#else
        const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
        const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
#endif

        const auto OutRightPadH = Hop - Ho;
        const auto OutRightPadW = Wop - Wo;

        const auto InLeftPadH = in_left_pads[I0];
        const auto InLeftPadW = in_left_pads[I1];

        // Each padded output row/column needs ConvStride extra input rows/columns so that the
        // embed transform below stays inside the padded input.
        const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH;
        const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW;

        // GEMM reduction length and its split into E0 x E1.
        const auto E = C0 * Y * X;

        constexpr auto E1 = Number<E1_>{};
        constexpr auto E2 = Number<E2_>{};
        constexpr auto K2 = Number<K2_>{};

        const auto E0 = E / E1;

        // weight tensor: [K, E, E2] -> [E, K, E2] -> [E0, E1, K, E2]
        const auto a_e_k_e2_grid_desc = transform_tensor_descriptor(
            make_naive_tensor_descriptor_packed(make_tuple(K, C0 * Y * X, E2)),
            make_tuple(make_pass_through_transform(K),
                       make_pass_through_transform(C0 * Y * X),
                       make_pass_through_transform(E2)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}));

        const auto a_e0_e1_k_e2_grid_desc =
            transform_tensor_descriptor(a_e_k_e2_grid_desc,
                                        make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
                                                   make_pass_through_transform(K),
                                                   make_pass_through_transform(E2)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                                        make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));

        // input tensor: pad H/W, embed (Y, Hop) and (X, Wop), merge (C0, Y, X) into E, then
        // split E into (E0, E1)
        const auto in_n_c0_hip_wip_e2_global_desc = transform_tensor_descriptor(
            make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2)),
            make_tuple(make_pass_through_transform(N),
                       make_pass_through_transform(C0),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(E2)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        const auto in_n_c0_y_ho_x_wo_e2_global_desc = transform_tensor_descriptor(
            in_n_c0_hip_wip_e2_global_desc,
            make_tuple(
                make_pass_through_transform(N),
                make_pass_through_transform(C0),
                make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)),
                make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW)),
                make_pass_through_transform(E2)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6>{}));

        const auto in_e_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
            in_n_c0_y_ho_x_wo_e2_global_desc,
            make_tuple(make_merge_transform(make_tuple(C0, Y, X)),
                       make_pass_through_transform(N),
                       make_pass_through_transform(Hop),
                       make_pass_through_transform(Wop),
                       make_pass_through_transform(E2)),
            make_tuple(
                Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}, Sequence<6>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        const auto b_e0_e1_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
            in_e_n_ho_wo_e2_grid_desc,
            make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
                       make_pass_through_transform(N),
                       make_pass_through_transform(Hop),
                       make_pass_through_transform(Wop),
                       make_pass_through_transform(E2)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));

        // output tensor: merge (K0, K1) into K and right-pad Ho/Wo up to Hop/Wop
        const auto c_k_n_hop_wop_grid_desc = transform_tensor_descriptor(
            make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
            make_tuple(make_merge_transform(make_tuple(K0, K1)),
                       make_pass_through_transform(N),
                       make_pad_transform(Ho, I0, OutRightPadH),
                       make_pad_transform(Wo, I0, OutRightPadW)),
            make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        std::cerr << "Hop = " << Hop << " Wop = " << Wop << '\n';

        // E must split exactly into E0 x E1: E0 = E / E1 truncates, and a remainder would
        // silently drop the tail of the reduction dimension from both A and B views.
        if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 &&
             (E1 % E1PerBlock) == 0 && (E % E1) == 0))
        {
            throw std::runtime_error("wrong! GEMM size no divisible");
        }

        // clang-format off

        // hack to control index calculation when iterating over a_e0_e1_k_e2_global tensor
        constexpr auto a_e0_e1_k_e2_global_step_hacks =
            make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
                       make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));

        constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack =
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};

        // hack to control index calculation when iterating over b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global tensor
        constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks =
            make_tuple(
                make_tuple(
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
                make_tuple(
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})
                );

        constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack =
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};

        // hack to control index calculation when iterating over c_k0_k1_n_h0_h1_h2_w0_w1_w2_global tensor
        constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks =
            make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
                       make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));

        // clang-format on

        // GEMM
        // The C descriptor and C step hacks are passed twice: the gridwise GEMM takes a
        // second output-side descriptor/hack slot, which this driver fills with the C ones.
        using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3<
            BlockSize,
            FloatAB,
            FloatAcc,
            FloatC,
            InMemoryDataOperationEnum_t::Set,
            decltype(a_e0_e1_k_e2_grid_desc),
            decltype(b_e0_e1_n_ho_wo_e2_grid_desc),
            decltype(c_k_n_hop_wop_grid_desc),
            decltype(c_k_n_hop_wop_grid_desc),
            E1,
            E2,
            K2,
            KPerBlock,
            HoPerBlock,
            WoPerBlock,
            E1PerBlock,
            KPerThread,
            HoPerThread,
            WoPerThread,
            EPerThread,
            ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
            ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
            Sequence<2, 3, 0, 1, 4>,
            Sequence<0, 1, 2, 3, 4>,
            4,
            ABlockTransferSrcScalarPerVector_E2,
            ABlockTransferDstScalarPerVector_E2,
            false, // don't move back src coordinate after threadwise copy
            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, // E0, E1, N, H0, H1, H2, W0, W1, W2, E2
            9,
            BThreadTransferSrcScalarPerVector_E2,
            false, // don't move back src coordinate after threadwise copy, which will be fused with
                   // MoveSrcSliceWindow() to save addr computation
            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>, // K0, K1, N, H0, H1, H2, W0, W1, W2
            1,
            CThreadTransferDstScalarPerVector_K,
            decltype(a_e0_e1_k_e2_global_step_hacks),
            decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks),
            decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
            decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
            decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
            decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack)>;

        const auto a_e0_e1_k0_k1_e2_grid_desc =
            GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc);
        const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
            GridwiseGemm::MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(b_e0_e1_n_ho_wo_e2_grid_desc);
        const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
            GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc);

        using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc);
        using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
            decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc);
        using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);

        // One workgroup per (K tile, Hop tile, Wop tile, batch element).
        const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;

        const bool has_main_e0_block_loop = E0 > 1;

        std::cerr << "has_main_e0_block_loop = " << has_main_e0_block_loop << '\n';

        const auto cblockid_to_k_n_h_w_block_cluster_adaptor =
            GridwiseGemm::MakeCBlockIdToKNHoWoBlockClusterAdaptor(c_k_n_hop_wop_grid_desc);

        using CBlockIdToBlockClusterAdaptor_K_N_H_W =
            decltype(cblockid_to_k_n_h_w_block_cluster_adaptor);

        // The kernel takes the main-loop flag as a template argument, so the runtime bool is
        // lifted to a compile-time constant; both instantiations share this launch sequence.
        const auto launch_kernel = [&](auto has_main_loop) -> float {
            const auto kernel =
                kernel_gemm_dlops_v3<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
                                     remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
                                     remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
                                     remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
                                     decltype(has_main_loop)::value,
                                     activ_type>;

            return launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_bias_grid,
                                          p_c_grid,
                                          a_e0_e1_k0_k1_e2_grid_desc,
                                          b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                                          c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
                                          cblockid_to_k_n_h_w_block_cluster_adaptor);
        };

        const float ave_time = has_main_e0_block_loop
                                   ? launch_kernel(ck::integral_constant<bool, true>{})
                                   : launch_kernel(ck::integral_constant<bool, false>{});

        return ave_time;
    }
};
#endif

View File

@@ -0,0 +1,440 @@
#ifndef DRIVER_CONVOLUTION_MAXPOOL_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP
#define DRIVER_CONVOLUTION_MAXPOOL_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_dlops_v3.hpp"
/// Host-side driver for the v5r1 DLOPS implicit-GEMM forward convolution with a fused
/// max-pool output (N-C0-Hi-Wi-C1 input, K-C0-Y-X-C1 weight, N-K0-Ho-Wo-K1 convolution
/// output, N-K0-Hx-Wx-K1 pooled output).
///
/// All tuning parameters are forwarded unchanged to GridwiseGemmDlops_km_kn_mn_v3.
/// E1_, E2_ and K2_ are wrapped into Number<> constants; E2_ is additionally used as the
/// innermost length of the weight and input tensors when their descriptors are rebuilt.
template <ck::index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          ck::index_t E1_,
          ck::index_t E2_,
          ck::index_t K2_,
          ck::index_t KPerBlock,
          ck::index_t HoPerBlock,
          ck::index_t WoPerBlock,
          ck::index_t E1PerBlock,
          ck::index_t KPerThread,
          ck::index_t HoPerThread,
          ck::index_t WoPerThread,
          ck::index_t EPerThread,
          typename ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
          typename ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
          ck::index_t ABlockTransferSrcScalarPerVector_E2,
          ck::index_t ABlockTransferDstScalarPerVector_E2,
          ck::index_t BThreadTransferSrcScalarPerVector_E2,
          ck::index_t CThreadTransferDstScalarPerVector_K,
          ck::ActivTypeEnum_t activ_type>
struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_maxpool
{
    /// Builds the A/B/C/D grid descriptors, validates the tiling, launches
    /// kernel_gemm_dlops_v3_maxpool and returns whatever launch_and_time_kernel reports.
    ///
    /// Only the *lengths* of the descriptor arguments are read; each tensor is re-described
    /// as a packed layout, so non-packed strides in the arguments are ignored.
    ///
    /// @param wei_k_c0_y_x_c1_global_desc   weight descriptor; lengths 0, 2, 3 give K, Y, X
    /// @param in_n_c0_hi_wi_c1_global_desc  input descriptor; lengths 0..3 give N, C0, Hi, Wi
    /// @param out_n_k0_ho_wo_k1_global_desc conv-output descriptor; lengths 1..4 give K0, Ho, Wo, K1
    /// @param max_n_k0_hx_wx_k1_global_desc pooled-output descriptor; lengths 2, 3 give Hx, Wx
    /// @param conv_strides    (H, W) convolution strides
    /// @param conv_dilations  (H, W) convolution dilations
    /// @param in_left_pads    (H, W) left padding of the input
    /// @param in_right_pads   (H, W) right padding of the input
    /// @param p_a_grid     A-operand buffer (presumably the weights - confirm with caller)
    /// @param p_b_grid     B-operand buffer (presumably the input - confirm with caller)
    /// @param p_bias_grid  bias buffer, forwarded to the kernel untouched
    /// @param p_c_grid     C-operand (convolution output) buffer
    /// @param p_d_grid     D-operand (pooled output) buffer
    /// @param nrepeat      forwarded to launch_and_time_kernel (presumably number of timed runs)
    /// @return the value returned by launch_and_time_kernel (presumably average kernel time)
    /// @throws std::runtime_error if K, Hop, Wop, E1 or E is not a multiple of its tile size
    template <typename... Wei,
              typename... In,
              typename... MaxPool,
              typename... Out,
              typename ConvStrides,
              typename ConvDilations,
              typename InLeftPads,
              typename InRightPads>
    __host__ float Run(const ck::TensorDescriptor<Wei...>& wei_k_c0_y_x_c1_global_desc,
                       const ck::TensorDescriptor<In...>& in_n_c0_hi_wi_c1_global_desc,
                       const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
                       const ck::TensorDescriptor<MaxPool...>& max_n_k0_hx_wx_k1_global_desc,
                       const ConvStrides& conv_strides,
                       const ConvDilations& conv_dilations,
                       const InLeftPads& in_left_pads,
                       const InRightPads& in_right_pads,
                       const FloatAB* __restrict__ p_a_grid,
                       const FloatAB* __restrict__ p_b_grid,
                       const FloatC* __restrict__ p_bias_grid,
                       FloatC* __restrict__ p_c_grid,
                       FloatC* __restrict__ p_d_grid,
                       const int nrepeat) const
    {
        using namespace ck;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};
        constexpr auto I4 = Number<4>{};

        const auto N  = in_n_c0_hi_wi_c1_global_desc.GetLength(I0);
        const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1);
        const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2);
        const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3);
        // The innermost input length (C1) is not read from the descriptor; E2 is used instead.

        const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
        const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
        const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
        const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);

        const auto Hx = max_n_k0_hx_wx_k1_global_desc.GetLength(I2);
        const auto Wx = max_n_k0_hx_wx_k1_global_desc.GetLength(I3);

        const auto K = wei_k_c0_y_x_c1_global_desc.GetLength(I0);
        const auto Y = wei_k_c0_y_x_c1_global_desc.GetLength(I2);
        const auto X = wei_k_c0_y_x_c1_global_desc.GetLength(I3);

        const auto ConvStrideH = conv_strides[I0];
        const auto ConvStrideW = conv_strides[I1];

        const auto ConvDilationH = conv_dilations[I0];
        const auto ConvDilationW = conv_dilations[I1];

        // Round the convolution output up to whole blocks. The pooled output gets half of
        // that padding (NOTE(review): presumably a 2x2 / stride-2 pool - confirm with kernel).
#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR
        const auto Hop = Number<(Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock>{};
        const auto Wop = Number<(Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock>{};

        const auto OutRightPadH = Hop - Ho;
        const auto OutRightPadW = Wop - Wo;

        const auto OutRightPadHx = Number<OutRightPadH / 2>{};
        const auto OutRightPadWx = Number<OutRightPadW / 2>{};
#else
        const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
        const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;

        const auto OutRightPadH = Hop - Ho;
        const auto OutRightPadW = Wop - Wo;

        const auto OutRightPadHx = OutRightPadH / 2;
        const auto OutRightPadWx = OutRightPadW / 2;
#endif

        const auto InLeftPadH = in_left_pads[I0];
        const auto InLeftPadW = in_left_pads[I1];

        // Each padded output row/column needs ConvStride extra input rows/columns so that the
        // embed transform below stays inside the padded input.
        const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH;
        const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW;

        // GEMM reduction length and its split into E0 x E1.
        const auto E = C0 * Y * X;

        constexpr auto E1 = Number<E1_>{};
        constexpr auto E2 = Number<E2_>{};
        constexpr auto K2 = Number<K2_>{};

        const auto E0 = E / E1;

        // weight tensor: [K, E, E2] -> [E, K, E2] -> [E0, E1, K, E2]
        const auto a_e_k_e2_grid_desc = transform_tensor_descriptor(
            make_naive_tensor_descriptor_packed(make_tuple(K, C0 * Y * X, E2)),
            make_tuple(make_pass_through_transform(K),
                       make_pass_through_transform(C0 * Y * X),
                       make_pass_through_transform(E2)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}));

        const auto a_e0_e1_k_e2_grid_desc =
            transform_tensor_descriptor(a_e_k_e2_grid_desc,
                                        make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
                                                   make_pass_through_transform(K),
                                                   make_pass_through_transform(E2)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                                        make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));

        // input tensor: pad H/W, embed (Y, Hop) and (X, Wop), merge (C0, Y, X) into E, then
        // split E into (E0, E1)
        const auto in_n_c0_hip_wip_e2_global_desc = transform_tensor_descriptor(
            make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2)),
            make_tuple(make_pass_through_transform(N),
                       make_pass_through_transform(C0),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(E2)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        const auto in_n_c0_y_ho_x_wo_e2_global_desc = transform_tensor_descriptor(
            in_n_c0_hip_wip_e2_global_desc,
            make_tuple(
                make_pass_through_transform(N),
                make_pass_through_transform(C0),
                make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)),
                make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW)),
                make_pass_through_transform(E2)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6>{}));

        const auto in_e_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
            in_n_c0_y_ho_x_wo_e2_global_desc,
            make_tuple(make_merge_transform(make_tuple(C0, Y, X)),
                       make_pass_through_transform(N),
                       make_pass_through_transform(Hop),
                       make_pass_through_transform(Wop),
                       make_pass_through_transform(E2)),
            make_tuple(
                Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}, Sequence<6>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        const auto b_e0_e1_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
            in_e_n_ho_wo_e2_grid_desc,
            make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
                       make_pass_through_transform(N),
                       make_pass_through_transform(Hop),
                       make_pass_through_transform(Wop),
                       make_pass_through_transform(E2)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));

        // output tensor: merge (K0, K1) into K and right-pad Ho/Wo up to Hop/Wop
        const auto c_k_n_hop_wop_grid_desc = transform_tensor_descriptor(
            make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
            make_tuple(make_merge_transform(make_tuple(K0, K1)),
                       make_pass_through_transform(N),
                       make_pad_transform(Ho, I0, OutRightPadH),
                       make_pad_transform(Wo, I0, OutRightPadW)),
            make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // max tensor: same layout as the output tensor, with the pooled extents and padding
        const auto d_k_n_hx_wx_grid_desc = transform_tensor_descriptor(
            make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hx, Wx, K1)),
            make_tuple(make_merge_transform(make_tuple(K0, K1)),
                       make_pass_through_transform(N),
                       make_pad_transform(Hx, I0, OutRightPadHx),
                       make_pad_transform(Wx, I0, OutRightPadWx)),
            make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        std::cerr << "Hop = " << Hop << " Wop = " << Wop << '\n';

        // E must split exactly into E0 x E1: E0 = E / E1 truncates, and a remainder would
        // silently drop the tail of the reduction dimension from both A and B views.
        if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 &&
             (E1 % E1PerBlock) == 0 && (E % E1) == 0))
        {
            throw std::runtime_error("wrong! GEMM size no divisible");
        }

        // clang-format off

        // hack to control index calculation when iterating over a_e0_e1_k_e2_global tensor
        constexpr auto a_e0_e1_k_e2_global_step_hacks =
            make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
                       make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));

        constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack =
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};

        // hack to control index calculation when iterating over b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global tensor
        constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks =
            make_tuple(
                make_tuple(
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
                make_tuple(
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})
                );

        constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack =
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};

        // hack to control index calculation when iterating over c_k0_k1_n_h0_h1_h2_w0_w1_w2_global tensor
        constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks =
            make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
                       make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));

        // hack to control index calculation when iterating over d_k0_k1_n_h0_h1_hx_w0_w1_wx_global tensor
        constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks =
            make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
                       make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));

        // clang-format on

        // GEMM
        using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3<
            BlockSize,
            FloatAB,
            FloatAcc,
            FloatC,
            InMemoryDataOperationEnum_t::Set,
            decltype(a_e0_e1_k_e2_grid_desc),
            decltype(b_e0_e1_n_ho_wo_e2_grid_desc),
            decltype(c_k_n_hop_wop_grid_desc),
            decltype(d_k_n_hx_wx_grid_desc),
            E1,
            E2,
            K2,
            KPerBlock,
            HoPerBlock,
            WoPerBlock,
            E1PerBlock,
            KPerThread,
            HoPerThread,
            WoPerThread,
            EPerThread,
            ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
            ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
            Sequence<2, 3, 0, 1, 4>,
            Sequence<0, 1, 2, 3, 4>,
            4,
            ABlockTransferSrcScalarPerVector_E2,
            ABlockTransferDstScalarPerVector_E2,
            false, // don't move back src coordinate after threadwise copy
            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, // E0, E1, N, H0, H1, H2, W0, W1, W2, E2
            9,
            BThreadTransferSrcScalarPerVector_E2,
            false, // don't move back src coordinate after threadwise copy, which will be fused
                   // with MoveSrcSliceWindow() to save addr computation
            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>, // K0, K1, N, H0, H1, H2, W0, W1, W2
            1,
            CThreadTransferDstScalarPerVector_K,
            decltype(a_e0_e1_k_e2_global_step_hacks),
            decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks),
            decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
            decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks),
            decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
            decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack)>;

        const auto a_e0_e1_k0_k1_e2_grid_desc =
            GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc);
        const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
            GridwiseGemm::MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(b_e0_e1_n_ho_wo_e2_grid_desc);
        const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
            GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc);
        const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc =
            GridwiseGemm::MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool(d_k_n_hx_wx_grid_desc);

        using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc);
        using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
            decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc);
        using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
        using DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx = decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc);

        // One workgroup per (K tile, Hop tile, Wop tile, batch element).
        const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;

        const bool has_main_e0_block_loop = E0 > 1;

        std::cerr << "has_main_e0_block_loop = " << has_main_e0_block_loop << '\n';

        const auto cblockid_to_k_n_h_w_block_cluster_adaptor =
            GridwiseGemm::MakeCBlockIdToKNHoWoBlockClusterAdaptor(c_k_n_hop_wop_grid_desc);

        using CBlockIdToBlockClusterAdaptor_K_N_H_W =
            decltype(cblockid_to_k_n_h_w_block_cluster_adaptor);

        // The kernel takes the main-loop flag as a template argument, so the runtime bool is
        // lifted to a compile-time constant; both instantiations share this launch sequence.
        const auto launch_kernel = [&](auto has_main_loop) -> float {
            const auto kernel = kernel_gemm_dlops_v3_maxpool<
                GridwiseGemm,
                FloatAB,
                FloatC,
                remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
                remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
                remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
                remove_reference_t<DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx>,
                remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
                decltype(has_main_loop)::value,
                activ_type>;

            return launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_bias_grid,
                                          p_c_grid,
                                          p_d_grid,
                                          a_e0_e1_k0_k1_e2_grid_desc,
                                          b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                                          c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
                                          d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
                                          cblockid_to_k_n_h_w_block_cluster_adaptor);
        };

        const float ave_time = has_main_e0_block_loop
                                   ? launch_kernel(ck::integral_constant<bool, true>{})
                                   : launch_kernel(ck::integral_constant<bool, false>{});

        return ave_time;
    }
};
#endif

View File

@@ -0,0 +1,278 @@
#ifndef DRIVER_GEMM_DLOPS_V1R2
#define DRIVER_GEMM_DLOPS_V1R2
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_dlops_v1r2.hpp"
template <ck::index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AKMGridDesc,
typename BKNGridDesc,
typename CMNGridDesc,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t M1PerThread,
ck::index_t N1PerThread,
ck::index_t KPerThread,
ck::index_t M1N1ThreadClusterM10,
ck::index_t M1N1ThreadClusterN10,
ck::index_t M1N1ThreadClusterM11,
ck::index_t M1N1ThreadClusterN11,
typename ABlockTransferThreadSliceLengths_K_M0_M1,
typename ABlockTransferThreadClusterLengths_K_M0_M1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_M1,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N0_N1,
typename BBlockTransferThreadClusterLengths_K_N0_N1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_N1,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector,
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks>
__host__ float driver_gemm_dlops_v1r2(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const AKMGridDesc& a_k_m_grid_desc,
const BKNGridDesc& b_k_n_grid_desc,
const CMNGridDesc& c_m_n_grid_desc,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks,
ck::index_t nrepeat)
{
using namespace ck;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
// GEMM
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AKMGridDesc,
BKNGridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks>;
const auto M = a_k_m_grid_desc.GetLength(I1);
const auto N = b_k_n_grid_desc.GetLength(I1);
const auto K = a_k_m_grid_desc.GetLength(I0);
if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc))
{
throw std::runtime_error("wrong! GridwiseGemmDlops_km_kn_mn_v1r2 has invalid setting");
}
const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
const auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc);
using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc);
// c_m0_m10_m11_n0_n10_n11_grid_desc
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);
// cblockid_to_m0_n0_block_cluster_adaptor
const auto cblockid_to_m0_n0_block_cluster_adaptor =
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
using CBlockIdToM0N0BlockClusterAdaptor = decltype(cblockid_to_m0_n0_block_cluster_adaptor);
const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);
const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K);
{
std::cout << "a_k_m0_m1_grid_desc{" << a_k_m0_m1_grid_desc.GetLength(I0) << ", "
<< a_k_m0_m1_grid_desc.GetLength(I1) << ", " << a_k_m0_m1_grid_desc.GetLength(I2)
<< "}" << std::endl;
std::cout << "b_k_n0_n1_grid_desc{" << b_k_n0_n1_grid_desc.GetLength(I0) << ", "
<< b_k_n0_n1_grid_desc.GetLength(I1) << ", " << b_k_n0_n1_grid_desc.GetLength(I2)
<< "}" << std::endl;
std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl;
}
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dlops_v1r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
cblockid_to_m0_n0_block_cluster_adaptor);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dlops_v1r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
cblockid_to_m0_n0_block_cluster_adaptor);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dlops_v1r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
cblockid_to_m0_n0_block_cluster_adaptor);
}
else
{
const auto kernel =
kernel_gemm_dlops_v1r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
cblockid_to_m0_n0_block_cluster_adaptor);
}
return ave_time;
}
#endif

View File

@@ -0,0 +1,275 @@
#ifndef DRIVER_GEMM_DLOPS_V1R3
#define DRIVER_GEMM_DLOPS_V1R3
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_dlops_v1r3.hpp"
// Host-side driver for GridwiseGemmDlops_km_kn_mn_v1r3.
//
// Steps performed:
//   1. instantiate the gridwise GEMM from the tuning template parameters,
//   2. reject the problem with std::runtime_error if GridwiseGemm::CheckValidity() fails,
//   3. derive the A/B/C grid descriptors and the block-id -> (m0, n0) adaptor consumed by
//      the kernel, and print their lengths to std::cout,
//   4. pick one of the four kernel instantiations according to the two K-loop flags
//      computed from K0, and launch it through launch_and_time_kernel().
//
// The five unnamed "hack" parameters are not read at run time; only their types are
// forwarded to the gridwise GEMM.
//
// Returns whatever launch_and_time_kernel() returns (presumably the average kernel time
// over `nrepeat` runs - confirm against launch_and_time_kernel()).
template <ck::index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
          typename AK0MK1GridDesc,
          typename BK0NK1GridDesc,
          typename CMNGridDesc,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t M1PerThread,
          ck::index_t N1PerThread,
          ck::index_t KPerThread,
          typename M1N1ThreadClusterM1Xs,
          typename M1N1ThreadClusterN1Xs,
          typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
          typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
          typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
          typename BBlockTransferSrcVectorTensorContiguousDimOrder,
          typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
          typename CThreadTransferSrcDstAccessOrder,
          ck::index_t CThreadTransferSrcDstVectorDim,
          ck::index_t CThreadTransferDstScalarPerVector,
          typename AGridStepHacks,
          typename BGridStepHacks,
          typename CGridStepHacks,
          typename AGridMoveSliceWindowStepHacks,
          typename BGridMoveSliceWindowStepHacks>
__host__ float driver_gemm_dlops_v1r3(const FloatAB* p_a_grid,
                                      const FloatAB* p_b_grid,
                                      FloatC* p_c_grid,
                                      const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
                                      const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
                                      const CMNGridDesc& c_m_n_grid_desc,
                                      AGridStepHacks,
                                      BGridStepHacks,
                                      CGridStepHacks,
                                      AGridMoveSliceWindowStepHacks,
                                      BGridMoveSliceWindowStepHacks,
                                      ck::index_t nrepeat)
{
    using namespace ck;

    // Compile-time dimension indices used to query descriptor lengths.
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};

    // Gridwise GEMM configured with every tuning parameter of this driver.
    using GridwiseGemm =
        GridwiseGemmDlops_km_kn_mn_v1r3<BlockSize,
                                        FloatAB,
                                        FloatAcc,
                                        FloatC,
                                        CGlobalMemoryDataOperation,
                                        AK0MK1GridDesc,
                                        BK0NK1GridDesc,
                                        CMNGridDesc,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        M1PerThread,
                                        N1PerThread,
                                        KPerThread,
                                        M1N1ThreadClusterM1Xs,
                                        M1N1ThreadClusterN1Xs,
                                        ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
                                        ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
                                        ABlockTransferThreadClusterArrangeOrder,
                                        ABlockTransferSrcAccessOrder,
                                        ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
                                        ABlockTransferSrcVectorTensorContiguousDimOrder,
                                        ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
                                        BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
                                        BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
                                        BBlockTransferThreadClusterArrangeOrder,
                                        BBlockTransferSrcAccessOrder,
                                        BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
                                        BBlockTransferSrcVectorTensorContiguousDimOrder,
                                        BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
                                        CThreadTransferSrcDstAccessOrder,
                                        CThreadTransferSrcDstVectorDim,
                                        CThreadTransferDstScalarPerVector,
                                        AGridStepHacks,
                                        BGridStepHacks,
                                        CGridStepHacks,
                                        AGridMoveSliceWindowStepHacks,
                                        BGridMoveSliceWindowStepHacks>;

    // Problem extents: dimension 1 of A and B, and dimension 0 of A.
    const auto m_length  = a_k0_m_k1_grid_desc.GetLength(I1);
    const auto n_length  = b_k0_n_k1_grid_desc.GetLength(I1);
    const auto k0_length = a_k0_m_k1_grid_desc.GetLength(I0);

    if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
    {
        throw std::runtime_error("wrong! GridwiseGemmDlops_km_kn_mn_v1r3 has invalid setting");
    }

    // Descriptors and block-id adaptor in the layout the kernel takes as arguments.
    const auto a_k0_m0_m1_k1_grid_desc =
        GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_k0_m_k1_grid_desc);
    const auto b_k0_n0_n1_k1_grid_desc =
        GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_k0_n_k1_grid_desc);
    const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
        GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
    const auto cblockid_to_m0_n0_block_cluster_adaptor =
        GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);

    // Kernel template arguments derived from the objects above.
    using AGridDescArg = remove_reference_t<decltype(a_k0_m0_m1_k1_grid_desc)>;
    using BGridDescArg = remove_reference_t<decltype(b_k0_n0_n1_k1_grid_desc)>;
    using CGridDescArg = remove_reference_t<decltype(c_m0_m10_m11_n0_n10_n11_grid_desc)>;
    using BlockClusterAdaptorArg =
        remove_reference_t<decltype(cblockid_to_m0_n0_block_cluster_adaptor)>;

    const index_t grid_size = GridwiseGemm::CalculateGridSize(m_length, n_length);

    const bool main_k_loop   = GridwiseGemm::CalculateHasMainKBlockLoop(k0_length);
    const bool double_tail_k = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(k0_length);

    // Report the descriptor lengths that are about to be handed to the kernel.
    std::cout << "a_k0_m0_m1_k1_grid_desc{" << a_k0_m0_m1_k1_grid_desc.GetLength(I0) << ", "
              << a_k0_m0_m1_k1_grid_desc.GetLength(I1) << ", "
              << a_k0_m0_m1_k1_grid_desc.GetLength(I2) << ", "
              << a_k0_m0_m1_k1_grid_desc.GetLength(I3) << "}" << std::endl;

    std::cout << "b_k0_n0_n1_k1_grid_desc{" << b_k0_n0_n1_k1_grid_desc.GetLength(I0) << ", "
              << b_k0_n0_n1_k1_grid_desc.GetLength(I1) << ", "
              << b_k0_n0_n1_k1_grid_desc.GetLength(I2) << ", "
              << b_k0_n0_n1_k1_grid_desc.GetLength(I3) << "}" << std::endl;

    std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ "
              << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", "
              << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", "
              << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", "
              << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", "
              << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", "
              << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl;

    // Every kernel variant is launched with the same launch configuration and arguments;
    // only the kernel instantiation differs.
    const auto launch_timed = [&](const auto kernel) {
        return launch_and_time_kernel(kernel,
                                      nrepeat,
                                      dim3(grid_size),
                                      dim3(BlockSize),
                                      0,
                                      p_a_grid,
                                      p_b_grid,
                                      p_c_grid,
                                      a_k0_m0_m1_k1_grid_desc,
                                      b_k0_n0_n1_k1_grid_desc,
                                      c_m0_m10_m11_n0_n10_n11_grid_desc,
                                      cblockid_to_m0_n0_block_cluster_adaptor);
    };

    // The last two kernel template arguments mirror (main_k_loop, double_tail_k).
    if(main_k_loop)
    {
        if(double_tail_k)
        {
            const auto kernel = kernel_gemm_dlops_v1r3<GridwiseGemm,
                                                       FloatAB,
                                                       FloatC,
                                                       AGridDescArg,
                                                       BGridDescArg,
                                                       CGridDescArg,
                                                       BlockClusterAdaptorArg,
                                                       true,
                                                       true>;
            return launch_timed(kernel);
        }

        const auto kernel = kernel_gemm_dlops_v1r3<GridwiseGemm,
                                                   FloatAB,
                                                   FloatC,
                                                   AGridDescArg,
                                                   BGridDescArg,
                                                   CGridDescArg,
                                                   BlockClusterAdaptorArg,
                                                   true,
                                                   false>;
        return launch_timed(kernel);
    }

    if(double_tail_k)
    {
        const auto kernel = kernel_gemm_dlops_v1r3<GridwiseGemm,
                                                   FloatAB,
                                                   FloatC,
                                                   AGridDescArg,
                                                   BGridDescArg,
                                                   CGridDescArg,
                                                   BlockClusterAdaptorArg,
                                                   false,
                                                   true>;
        return launch_timed(kernel);
    }

    const auto kernel = kernel_gemm_dlops_v1r3<GridwiseGemm,
                                               FloatAB,
                                               FloatC,
                                               AGridDescArg,
                                               BGridDescArg,
                                               CGridDescArg,
                                               BlockClusterAdaptorArg,
                                               false,
                                               false>;
    return launch_timed(kernel);
}
#endif

View File

@@ -0,0 +1,220 @@
#ifndef DRIVER_GEMM_XDLOPS_V2R3_HPP
#define DRIVER_GEMM_XDLOPS_V2R3_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"
#include "element_wise_operation.hpp"
// Host-side driver for GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3.
//
// Steps performed:
//   1. instantiate the gridwise GEMM (all three element-wise operations are PassThrough),
//   2. print the lengths of the A/B/C grid descriptors to std::cout,
//   3. throw std::runtime_error if GridwiseGemm::CheckValidity() rejects the descriptors
//      together with M01/N01,
//   4. build the C descriptor and the block -> C-tile map consumed by the kernel,
//   5. launch the kernel instantiation selected by the "has main K0 block loop" flag
//      through launch_and_time_kernel().
//
// Several template parameters (the *ThreadSliceLengths*, the *Hacks types and
// CAccessOrderMRepeatNRepeat) and the five unnamed "hack" function parameters are part of
// the signature but are not used by this function.
//
// Returns whatever launch_and_time_kernel() returns (presumably the average kernel time
// over `nrepeat` runs - confirm against launch_and_time_kernel()).
template <ck::index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
          typename AGridDesc_K0_M_K1,
          typename BGridDesc_K0_N_K,
          typename CMNGridDesc,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t MPerXDL,
          ck::index_t NPerXDL,
          ck::index_t K1,
          ck::index_t MRepeat,
          ck::index_t NRepeat,
          typename ABlockTransferThreadSliceLengths_K0_M_K1,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_K0_N_K1,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          ck::index_t CThreadTransferSrcDstVectorDim,
          ck::index_t CThreadTransferDstScalarPerVector,
          typename AGridStepHacks,
          typename BGridStepHacks,
          typename CGridStepHacks,
          typename AGridMoveSliceWindowStepHacks,
          typename BGridMoveSliceWindowStepHacks,
          bool CAccessOrderMRepeatNRepeat,
          bool ABlockLdsAddExtraM,
          bool BBlockLdsAddExtraN>
__host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
                                       const FloatAB* p_b_grid,
                                       FloatC* p_c_grid,
                                       const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                                       const BGridDesc_K0_N_K& b_grid_desc_k0_n_k1,
                                       const CMNGridDesc& c_grid_desc_m_n,
                                       ck::index_t M01,
                                       ck::index_t N01,
                                       AGridStepHacks,
                                       BGridStepHacks,
                                       CGridStepHacks,
                                       AGridMoveSliceWindowStepHacks,
                                       BGridMoveSliceWindowStepHacks,
                                       ck::index_t nrepeat)
{
    using namespace ck;

    // Compile-time dimension indices used to query descriptor lengths.
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    // A, B and C all use the same pass-through element-wise operation.
    using ElementwiseOperation = ck::tensor_operation::element_wise::PassThrough;

    using GridwiseGemm =
        GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
                                                FloatAB,
                                                FloatAcc,
                                                FloatC,
                                                CGlobalMemoryDataOperation,
                                                AGridDesc_K0_M_K1,
                                                BGridDesc_K0_N_K,
                                                CMNGridDesc,
                                                ElementwiseOperation,
                                                ElementwiseOperation,
                                                ElementwiseOperation,
                                                MPerBlock,
                                                NPerBlock,
                                                KPerBlock,
                                                MPerXDL,
                                                NPerXDL,
                                                K1,
                                                MRepeat,
                                                NRepeat,
                                                ABlockTransferThreadClusterLengths_K0_M_K1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                ABlockTransferSrcAccessOrder,
                                                ABlockTransferSrcVectorDim,
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_K1,
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                ABlockLdsAddExtraM,
                                                BBlockTransferThreadClusterLengths_K0_N_K1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                BBlockTransferSrcAccessOrder,
                                                BBlockTransferSrcVectorDim,
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_K1,
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                BBlockLdsAddExtraN,
                                                CThreadTransferSrcDstAccessOrder,
                                                CThreadTransferSrcDstVectorDim,
                                                CThreadTransferDstScalarPerVector>;

    // Report the problem before validating it, so a rejected problem is still visible.
    {
        std::cout << "a_grid_desc_k0_m_k1{" << a_grid_desc_k0_m_k1.GetLength(I0) << ", "
                  << a_grid_desc_k0_m_k1.GetLength(I1) << ", "
                  << a_grid_desc_k0_m_k1.GetLength(I2) << "}" << std::endl;

        std::cout << "b_grid_desc_k0_n_k1{" << b_grid_desc_k0_n_k1.GetLength(I0) << ", "
                  << b_grid_desc_k0_n_k1.GetLength(I1) << ", "
                  << b_grid_desc_k0_n_k1.GetLength(I2) << "}" << std::endl;

        std::cout << "c_grid_desc_m_n{ " << c_grid_desc_m_n.GetLength(I0) << ", "
                  << c_grid_desc_m_n.GetLength(I1) << "}" << std::endl;
    }

    if(!GridwiseGemm::CheckValidity(
           a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n, M01, N01))
    {
        // The message names the class that is actually instantiated above.
        throw std::runtime_error(
            "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
    }

    const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
        GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);

    using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);

    const auto block_2_ctile_map =
        GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, M01, N01);

    using Block2CTileMap = decltype(block_2_ctile_map);

    const index_t grid_size = GridwiseGemm::CalculateGridSize(c_grid_desc_m_n);

    const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

    const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);

    auto element_op_ = ElementwiseOperation{};

    // Both kernel variants share the launch configuration and argument list; only the
    // kernel instantiation differs, so the launch is written once.
    const auto launch_timed = [&](const auto kernel) {
        return launch_and_time_kernel(kernel,
                                      nrepeat,
                                      dim3(grid_size),
                                      dim3(BlockSize),
                                      0,
                                      p_a_grid,
                                      p_b_grid,
                                      p_c_grid,
                                      a_grid_desc_k0_m_k1,
                                      b_grid_desc_k0_n_k1,
                                      c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                      element_op_,
                                      element_op_,
                                      element_op_,
                                      block_2_ctile_map);
    };

    // The last kernel template argument mirrors has_main_k0_block_loop.
    if(has_main_k0_block_loop)
    {
        const auto kernel =
            kernel_gemm_xdlops_v2r3<GridwiseGemm,
                                    FloatAB,
                                    FloatC,
                                    remove_reference_t<AGridDesc_K0_M_K1>,
                                    remove_reference_t<BGridDesc_K0_N_K>,
                                    remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                                    ElementwiseOperation,
                                    ElementwiseOperation,
                                    ElementwiseOperation,
                                    remove_reference_t<Block2CTileMap>,
                                    true>;
        return launch_timed(kernel);
    }

    const auto kernel =
        kernel_gemm_xdlops_v2r3<GridwiseGemm,
                                FloatAB,
                                FloatC,
                                remove_reference_t<AGridDesc_K0_M_K1>,
                                remove_reference_t<BGridDesc_K0_N_K>,
                                remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                                ElementwiseOperation,
                                ElementwiseOperation,
                                ElementwiseOperation,
                                remove_reference_t<Block2CTileMap>,
                                false>;
    return launch_timed(kernel);
}
#endif

View File

@@ -0,0 +1,213 @@
#ifndef DRIVER_GEMM_XDLOPS_V2R4
#define DRIVER_GEMM_XDLOPS_V2R4
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r4.hpp"
// Host-side driver for GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4.
//
// The A and B descriptors are four-dimensional; dimension 0 is read as KBatch and
// dimension 1 as K0.
//
// Steps performed:
//   1. instantiate the gridwise GEMM from the tuning template parameters,
//   2. print the lengths of the A/B/C grid descriptors to std::cout,
//   3. throw std::runtime_error if GridwiseGemm::CheckValidity() rejects the descriptors
//      together with M01/N01,
//   4. build the C descriptor and the block cluster adaptor consumed by the kernel, and
//      print the resulting grid size,
//   5. launch the kernel instantiation selected by the "has main K0 block loop" flag
//      through launch_and_time_kernel().
//
// The five unnamed "hack" parameters are not read at run time; only their types are
// forwarded to the gridwise GEMM.
//
// Returns whatever launch_and_time_kernel() returns (presumably the average kernel time
// over `nrepeat` runs - confirm against launch_and_time_kernel()).
template <ck::index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
          typename ABK0MK1GridDesc,
          typename BBK0NK1GridDesc,
          typename CMNGridDesc,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t MPerXDL,
          ck::index_t NPerXDL,
          ck::index_t K1,
          ck::index_t MRepeat,
          ck::index_t NRepeat,
          typename ABlockTransferThreadSliceLengths_K0_M_K1,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_K0_N_K1,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          ck::index_t CThreadTransferSrcDstVectorDim,
          ck::index_t CThreadTransferDstScalarPerVector,
          typename AGridStepHacks,
          typename BGridStepHacks,
          typename CGridStepHacks,
          typename AGridMoveSliceWindowStepHacks,
          typename BGridMoveSliceWindowStepHacks,
          bool CAccessOrderMRepeatNRepeat,
          bool ABlockLdsAddExtraM,
          bool BBlockLdsAddExtraN>
__host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
                                       const FloatAB* p_b_grid,
                                       FloatC* p_c_grid,
                                       const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
                                       const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
                                       const CMNGridDesc& c_m_n_grid_desc,
                                       ck::index_t M01,
                                       ck::index_t N01,
                                       AGridStepHacks,
                                       BGridStepHacks,
                                       CGridStepHacks,
                                       AGridMoveSliceWindowStepHacks,
                                       BGridMoveSliceWindowStepHacks,
                                       ck::index_t nrepeat)
{
    using namespace ck;

    // Compile-time dimension indices used to query descriptor lengths.
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    using GridwiseGemm =
        GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4<BlockSize,
                                                  FloatAB,
                                                  FloatAcc,
                                                  FloatC,
                                                  CGlobalMemoryDataOperation,
                                                  ABK0MK1GridDesc,
                                                  BBK0NK1GridDesc,
                                                  CMNGridDesc,
                                                  MPerBlock,
                                                  NPerBlock,
                                                  KPerBlock,
                                                  MPerXDL,
                                                  NPerXDL,
                                                  K1,
                                                  MRepeat,
                                                  NRepeat,
                                                  ABlockTransferThreadSliceLengths_K0_M_K1,
                                                  ABlockTransferThreadClusterLengths_K0_M_K1,
                                                  ABlockTransferThreadClusterArrangeOrder,
                                                  ABlockTransferSrcAccessOrder,
                                                  ABlockTransferSrcVectorDim,
                                                  ABlockTransferSrcScalarPerVector,
                                                  ABlockTransferDstScalarPerVector_K1,
                                                  AThreadTransferSrcResetCoordinateAfterRun,
                                                  BBlockTransferThreadSliceLengths_K0_N_K1,
                                                  BBlockTransferThreadClusterLengths_K0_N_K1,
                                                  BBlockTransferThreadClusterArrangeOrder,
                                                  BBlockTransferSrcAccessOrder,
                                                  BBlockTransferSrcVectorDim,
                                                  BBlockTransferSrcScalarPerVector,
                                                  BBlockTransferDstScalarPerVector_K1,
                                                  BThreadTransferSrcResetCoordinateAfterRun,
                                                  CThreadTransferSrcDstAccessOrder,
                                                  CThreadTransferSrcDstVectorDim,
                                                  CThreadTransferDstScalarPerVector,
                                                  AGridStepHacks,
                                                  BGridStepHacks,
                                                  CGridStepHacks,
                                                  AGridMoveSliceWindowStepHacks,
                                                  BGridMoveSliceWindowStepHacks,
                                                  CAccessOrderMRepeatNRepeat,
                                                  ABlockLdsAddExtraM,
                                                  BBlockLdsAddExtraN>;

    // Report the problem before validating it, so a rejected problem is still visible.
    {
        std::cout << "a_b_k0_m_k1_grid_desc{" << a_b_k0_m_k1_grid_desc.GetLength(I0) << ", "
                  << a_b_k0_m_k1_grid_desc.GetLength(I1) << ", "
                  << a_b_k0_m_k1_grid_desc.GetLength(I2) << ", "
                  << a_b_k0_m_k1_grid_desc.GetLength(I3) << "}" << std::endl;

        std::cout << "b_b_k0_n_k1_grid_desc{" << b_b_k0_n_k1_grid_desc.GetLength(I0) << ", "
                  << b_b_k0_n_k1_grid_desc.GetLength(I1) << ", "
                  << b_b_k0_n_k1_grid_desc.GetLength(I2) << ", "
                  << b_b_k0_n_k1_grid_desc.GetLength(I3) << "}" << std::endl;

        std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
                  << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
    }

    if(!GridwiseGemm::CheckValidity(
           a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_m_n_grid_desc, M01, N01))
    {
        // The message names the class that is actually instantiated above.
        throw std::runtime_error(
            "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 has invalid setting");
    }

    const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
        GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);

    using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);

    // Dimension 0 of the A descriptor is the K-batch count.
    const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);

    const auto c_block_cluster_adaptor =
        GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, M01, N01, KBatch);

    using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);

    const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc, KBatch);

    {
        std::cout << "gridSize : " << grid_size << std::endl;
    }

    const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);

    const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);

    // Both kernel variants share the launch configuration and argument list; only the
    // kernel instantiation differs, so the launch is written once.
    const auto launch_timed = [&](const auto kernel) {
        return launch_and_time_kernel(kernel,
                                      nrepeat,
                                      dim3(grid_size),
                                      dim3(BlockSize),
                                      0,
                                      p_a_grid,
                                      p_b_grid,
                                      p_c_grid,
                                      a_b_k0_m_k1_grid_desc,
                                      b_b_k0_n_k1_grid_desc,
                                      c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                      c_block_cluster_adaptor);
    };

    // The last kernel template argument mirrors has_main_k0_block_loop.
    if(has_main_k0_block_loop)
    {
        const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
                                                    FloatAB,
                                                    FloatC,
                                                    remove_reference_t<ABK0MK1GridDesc>,
                                                    remove_reference_t<BBK0NK1GridDesc>,
                                                    remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
                                                    remove_reference_t<CBlockClusterAdaptor>,
                                                    true>;
        return launch_timed(kernel);
    }

    const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
                                                FloatAB,
                                                FloatC,
                                                remove_reference_t<ABK0MK1GridDesc>,
                                                remove_reference_t<BBK0NK1GridDesc>,
                                                remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
                                                remove_reference_t<CBlockClusterAdaptor>,
                                                false>;
    return launch_timed(kernel);
}
#endif

View File

@@ -0,0 +1,134 @@
#ifndef REFERENCE_BATCHED_GEMM_HPP
#define REFERENCE_BATCHED_GEMM_HPP
#include <iostream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
// Reference batched GEMM operating on Tensor objects.
//
// For every (g, m, n):
//   c[g, m, n] = c_op( sum over k of a_op(a[g, m, k]) * b_op(b[g, k, n]) )
// with the accumulation carried out in float.
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct ReferenceBatchedGemm : public device::BaseOperator
{
    // Bundles the operand tensors and the three element-wise operations.
    // The tensors are held by reference, so they must outlive the Argument.
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<ADataType>& a_g_m_k,
                 const Tensor<BDataType>& b_g_k_n,
                 Tensor<CDataType>& c_g_m_n,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
            : a_g_m_k_{a_g_m_k},
              b_g_k_n_{b_g_k_n},
              c_g_m_n_{c_g_m_n},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
        {
        }

        const Tensor<ADataType>& a_g_m_k_;
        const Tensor<BDataType>& b_g_k_n_;
        Tensor<CDataType>& c_g_m_n_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
    };

    // Executes the reference computation described by an Argument.
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceBatchedGemm::Argument;

        // Fills arg.c_g_m_n_ and always returns 0.
        float Run(const Argument& arg)
        {
            // Computes a single output element c[g, m, n].
            auto compute_c_element = [&](auto g, auto m, auto n) {
                // Reduction length: dimension 2 of A.
                const int K = arg.a_g_m_k_.mDesc.GetLengths()[2];

                float acc = 0;

                for(int k = 0; k < K; ++k)
                {
                    float a_val;
                    float b_val;

                    // Each element-wise op receives its result slot first, then the input.
                    arg.a_element_op_(a_val, static_cast<const float>(arg.a_g_m_k_(g, m, k)));
                    arg.b_element_op_(b_val, static_cast<const float>(arg.b_g_k_n_(g, k, n)));

                    acc += a_val * b_val;
                }

                float c_val;

                arg.c_element_op_(c_val, acc);

                arg.c_g_m_n_(g, m, n) = c_val;
            };

            // The functor is applied over the full extent of C; the thread count passed
            // to it is std::thread::hardware_concurrency().
            const auto& c_lengths = arg.c_g_m_n_.mDesc.GetLengths();

            make_ParallelTensorFunctor(
                compute_c_element, c_lengths[0], c_lengths[1], c_lengths[2])(
                std::thread::hardware_concurrency());

            return 0;
        }

        // Type-erased entry point; p_arg must point to a ReferenceBatchedGemm::Argument.
        float Run(const device::BaseArgument* p_arg, int) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    // Every argument is accepted.
    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    static auto MakeArgument(const Tensor<ADataType>& a_g_m_k,
                             const Tensor<BDataType>& b_g_k_n,
                             Tensor<CDataType>& c_g_m_n,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{a_g_m_k, b_g_k_n, c_g_m_n, a_element_op, b_element_op, c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // Name of this operator, terminated by a newline.
    std::string GetTypeString() const override
    {
        std::stringstream name_stream;

        // clang-format off
        name_stream << "ReferenceBatchedGemm"
                    << std::endl;
        // clang-format on

        return name_stream.str();
    }
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,177 @@
#ifndef REFERENCE_CONV_WRW_HPP
#define REFERENCE_CONV_WRW_HPP
#include <iostream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X]
template <typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
struct ReferenceConvWrw : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<InDataType>& in_n_c_hi_wi,
Tensor<WeiDataType>& wei_k_c_y_x,
const Tensor<OutDataType>& out_n_k_ho_wo,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
: in_n_c_hi_wi_{in_n_c_hi_wi},
wei_k_c_y_x_{wei_k_c_y_x},
out_n_k_ho_wo_{out_n_k_ho_wo},
conv_strides_{conv_filter_strides},
conv_dilations_{conv_filter_dilations},
in_left_pads_{input_left_pads},
in_right_pads_{input_right_pads},
in_element_op_{in_element_op},
wei_element_op_{wei_element_op},
out_element_op_{out_element_op}
{
}
const Tensor<InDataType>& in_n_c_hi_wi_;
Tensor<WeiDataType>& wei_k_c_y_x_;
const Tensor<OutDataType>& out_n_k_ho_wo_;
std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_;
std::vector<index_t> in_left_pads_;
std::vector<index_t> in_right_pads_;
InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_;
OutElementwiseOperation out_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceConvWrw::Argument;
float Run(const Argument& arg)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
float v_acc = 0;
for(int n = 0; n < arg.out_n_k_ho_wo_.mDesc.GetLengths()[0]; ++n)
{
for(int ho = 0; ho < arg.out_n_k_ho_wo_.mDesc.GetLengths()[2]; ++ho)
{
int hi = ho * arg.conv_strides_[I0] + y * arg.conv_dilations_[I0] -
arg.in_left_pads_[I0];
for(int wo = 0; wo < arg.out_n_k_ho_wo_.mDesc.GetLengths()[3]; ++wo)
{
int wi = wo * arg.conv_strides_[I1] + x * arg.conv_dilations_[I1] -
arg.in_left_pads_[I1];
if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 &&
wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
{
float v_out;
float v_in;
arg.out_element_op_(
v_out,
ck::type_convert<float>(arg.out_n_k_ho_wo_(n, k, ho, wo)));
arg.in_element_op_(
v_in, ck::type_convert<float>(arg.in_n_c_hi_wi_(n, c, hi, wi)));
v_acc += v_out * v_in;
}
}
}
}
float v_wei;
arg.wei_element_op_(v_wei, v_acc);
arg.wei_k_c_y_x_(k, c, y, x) = ck::type_convert<OutDataType>(v_wei);
};
make_ParallelTensorFunctor(f_kcyx,
arg.wei_k_c_y_x_.mDesc.GetLengths()[0],
arg.wei_k_c_y_x_.mDesc.GetLengths()[1],
arg.wei_k_c_y_x_.mDesc.GetLengths()[2],
arg.wei_k_c_y_x_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const device::BaseArgument* p_arg, int) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
// Bundles the tensors, convolution parameters and elementwise functors into
// an Argument. The arguments are forwarded to Argument's constructor in the
// same order they are received.
static auto MakeArgument(const Tensor<InDataType>& in_n_c_hi_wi,
                         Tensor<WeiDataType>& wei_k_c_y_x,
                         const Tensor<OutDataType>& out_n_k_ho_wo,
                         std::vector<ck::index_t> conv_filter_strides,
                         std::vector<ck::index_t> conv_filter_dilations,
                         std::vector<ck::index_t> input_left_pads,
                         std::vector<ck::index_t> input_right_pads,
                         InElementwiseOperation in_element_op,
                         WeiElementwiseOperation wei_element_op,
                         OutElementwiseOperation out_element_op)
{
    return Argument(in_n_c_hi_wi,
                    wei_k_c_y_x,
                    out_n_k_ho_wo,
                    conv_filter_strides,
                    conv_filter_dilations,
                    input_left_pads,
                    input_right_pads,
                    in_element_op,
                    wei_element_op,
                    out_element_op);
}
// Returns a default-constructed Invoker by value.
static auto MakeInvoker()
{
    return Invoker();
}
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
// Returns this operator's name followed by a line break.
// The name matches the enclosing struct (ReferenceConvWrw); it previously
// reported the forward-convolution operator's name.
std::string GetTypeString() const override
{
    auto str = std::stringstream();

    // clang-format off
    str << "ReferenceConvWrw"
        << std::endl;
    // clang-format on

    return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,192 @@
#ifndef REFERENCE_CONV_BWD_DATA_HPP
#define REFERENCE_CONV_BWD_DATA_HPP
#include <iostream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
// Host reference for 2D convolution backward-data:
//   in[N, C, Hi, Wi] = out[N, K, Ho, Wo] * wei[K, C, Y, X]
// i.e. the input gradient is computed from the output gradient and the weights.
template <typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation>
struct ReferenceConvBwdData : public device::BaseOperator
{
    // Argument
    //
    // Holds references to the three tensors (the caller must keep them alive
    // while the Argument is in use) plus copies of the convolution parameters
    // and elementwise functors. in_n_c_hi_wi is the tensor that Run() writes.
    struct Argument : public device::BaseArgument
    {
        Argument(Tensor<InDataType>& in_n_c_hi_wi,
                 const Tensor<WeiDataType>& wei_k_c_y_x,
                 const Tensor<OutDataType>& out_n_k_ho_wo,
                 std::vector<ck::index_t> conv_filter_strides,
                 std::vector<ck::index_t> conv_filter_dilations,
                 std::vector<ck::index_t> input_left_pads,
                 std::vector<ck::index_t> input_right_pads,
                 InElementwiseOperation in_element_op,
                 WeiElementwiseOperation wei_element_op,
                 OutElementwiseOperation out_element_op)
            : in_n_c_hi_wi_{in_n_c_hi_wi},
              wei_k_c_y_x_{wei_k_c_y_x},
              out_n_k_ho_wo_{out_n_k_ho_wo},
              conv_strides_{conv_filter_strides},
              conv_dilations_{conv_filter_dilations},
              in_left_pads_{input_left_pads},
              in_right_pads_{input_right_pads},
              in_element_op_{in_element_op},
              wei_element_op_{wei_element_op},
              out_element_op_{out_element_op}
        {
        }

        Tensor<InDataType>& in_n_c_hi_wi_;
        const Tensor<WeiDataType>& wei_k_c_y_x_;
        const Tensor<OutDataType>& out_n_k_ho_wo_;

        std::vector<index_t> conv_strides_;
        std::vector<index_t> conv_dilations_;
        std::vector<index_t> in_left_pads_;
        std::vector<index_t> in_right_pads_;

        InElementwiseOperation in_element_op_;
        WeiElementwiseOperation wei_element_op_;
        OutElementwiseOperation out_element_op_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceConvBwdData::Argument;

        // For every input element (n, c, hi, wi), accumulates in float
        //   sum over (y, x, k) of out_op(out(n, k, ho, wo)) * wei_op(wei(k, c, y, x))
        // over the filter taps (y, x) for which
        //   h_tmp = hi + left_pad[0] - y * dilation[0] is a multiple of stride[0],
        //   ho    = h_tmp / stride[0] lies in [0, Ho)
        // (and likewise for the width axis). The sum is passed through
        // in_element_op_ and stored, converted to InDataType, in in_n_c_hi_wi_.
        //
        // in_right_pads_ is not read here.
        //
        // @return always 0 (no timing is measured).
        float Run(const Argument& arg)
        {
            // Extents are loop-invariant; read them once as int so the loop
            // comparisons below do not mix signed and unsigned operands.
            const auto& wei_lengths = arg.wei_k_c_y_x_.mDesc.GetLengths();
            const auto& out_lengths = arg.out_n_k_ho_wo_.mDesc.GetLengths();
            const auto& in_lengths  = arg.in_n_c_hi_wi_.mDesc.GetLengths();

            const int K  = static_cast<int>(wei_lengths[0]);
            const int Y  = static_cast<int>(wei_lengths[2]);
            const int X  = static_cast<int>(wei_lengths[3]);
            const int Ho = static_cast<int>(out_lengths[2]);
            const int Wo = static_cast<int>(out_lengths[3]);

            auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
                float v_acc = 0;

                for(int y = 0; y < Y; ++y)
                {
                    int h_tmp = hi + arg.in_left_pads_[0] - y * arg.conv_dilations_[0];

                    // Only taps that align with the stride grid map to an output row.
                    if(h_tmp % arg.conv_strides_[0] != 0)
                    {
                        continue;
                    }

                    int ho = h_tmp / arg.conv_strides_[0];
                    if(ho < 0 || ho >= Ho)
                    {
                        continue;
                    }

                    for(int x = 0; x < X; ++x)
                    {
                        int w_tmp = wi + arg.in_left_pads_[1] - x * arg.conv_dilations_[1];

                        if(w_tmp % arg.conv_strides_[1] != 0)
                        {
                            continue;
                        }

                        int wo = w_tmp / arg.conv_strides_[1];
                        if(wo < 0 || wo >= Wo)
                        {
                            continue;
                        }

                        for(int k = 0; k < K; ++k)
                        {
                            float v_out = 0;
                            float v_wei = 0;

                            arg.out_element_op_(
                                v_out, ck::type_convert<float>(arg.out_n_k_ho_wo_(n, k, ho, wo)));
                            arg.wei_element_op_(
                                v_wei, ck::type_convert<float>(arg.wei_k_c_y_x_(k, c, y, x)));

                            v_acc += v_out * v_wei;
                        }
                    }
                }

                float v_in;
                arg.in_element_op_(v_in, v_acc);
                arg.in_n_c_hi_wi_(n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
            };

            // One work item per input element; the trailing argument is
            // presumably the number of worker threads -- confirm in host_tensor.hpp.
            make_ParallelTensorFunctor(
                f_nchw, in_lengths[0], in_lengths[1], in_lengths[2], in_lengths[3])(
                std::thread::hardware_concurrency());

            return 0;
        }

        // Type-erased entry point. p_arg must point to an Argument of this
        // operator: the dynamic_cast result is dereferenced without a null
        // check. The int parameter is unused.
        float Run(const device::BaseArgument* p_arg, int) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    // Placeholder: no compile-time check exists yet, so this is always true.
    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    // The argument is not inspected; every argument is reported as supported.
    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    // Forwards everything to Argument's constructor in the same order.
    static auto MakeArgument(Tensor<InDataType>& in_n_c_hi_wi,
                             const Tensor<WeiDataType>& wei_k_c_y_x,
                             const Tensor<OutDataType>& out_n_k_ho_wo,
                             std::vector<ck::index_t> conv_filter_strides,
                             std::vector<ck::index_t> conv_filter_dilations,
                             std::vector<ck::index_t> input_left_pads,
                             std::vector<ck::index_t> input_right_pads,
                             InElementwiseOperation in_element_op,
                             WeiElementwiseOperation wei_element_op,
                             OutElementwiseOperation out_element_op)
    {
        return Argument{in_n_c_hi_wi,
                        wei_k_c_y_x,
                        out_n_k_ho_wo,
                        conv_filter_strides,
                        conv_filter_dilations,
                        input_left_pads,
                        input_right_pads,
                        in_element_op,
                        wei_element_op,
                        out_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>();
    }

    // Returns the operator name followed by a line break.
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceConvBwdData"
            << std::endl;
        // clang-format on

        return str.str();
    }
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,234 @@
#ifndef REFERENCE_CONV_FWD_HPP
#define REFERENCE_CONV_FWD_HPP
#include <iostream>
#include <type_traits>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
//
// @brief      Reference implementation for forward convolution.
//
// @paragraph  Supported tensor layouts. Tensors are indexed logically as
//             input[N, C, spatial...], weights[K, C, filter...] and
//             output[N, K, spatial...]:
//               1D: NCWi    / KCX   / NKWo
//               2D: NCHiWi  / KCYX  / NKHoWo
//               3D: NCDiHiWi / KCZYX / NKDoHoWo
//
// @tparam     InDataType               Input tensor data type.
// @tparam     WeiDataType              Weights tensor data type.
// @tparam     OutDataType              Output tensor data type.
// @tparam     InElementwiseOperation   Functor for input tensor elementwise
//                                      operation.
// @tparam     WeiElementwiseOperation  Functor for weights tensor elementwise
//                                      operation.
// @tparam     OutElementwiseOperation  Functor for output tensor elementwise
//                                      operation.
// @tparam     NumDimSpatial            Number of spatial dimensions (1, 2 or 3).
//
template <typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation,
          ck::index_t NumDimSpatial                                                    = 2,
          typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
struct ReferenceConvFwd : public device::BaseOperator
{
    // Argument
    //
    // Holds references to the three tensors (the caller must keep them alive
    // while the Argument is in use) plus copies of the convolution parameters
    // and elementwise functors. `output` is the tensor that Run() writes.
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<InDataType>& input,
                 const Tensor<WeiDataType>& weight,
                 Tensor<OutDataType>& output,
                 std::vector<ck::index_t> conv_filter_strides,
                 std::vector<ck::index_t> conv_filter_dilations,
                 std::vector<ck::index_t> input_left_pads,
                 std::vector<ck::index_t> input_right_pads,
                 InElementwiseOperation in_element_op,
                 WeiElementwiseOperation wei_element_op,
                 OutElementwiseOperation out_element_op)
            : input_{input},
              weight_{weight},
              output_{output},
              conv_strides_{conv_filter_strides},
              conv_dilations_{conv_filter_dilations},
              in_left_pads_{input_left_pads},
              in_right_pads_{input_right_pads},
              in_element_op_{in_element_op},
              wei_element_op_{wei_element_op},
              out_element_op_{out_element_op}
        {
        }

        const Tensor<InDataType>& input_;
        const Tensor<WeiDataType>& weight_;
        Tensor<OutDataType>& output_;

        std::vector<index_t> conv_strides_;
        std::vector<index_t> conv_dilations_;
        std::vector<index_t> in_left_pads_;
        std::vector<index_t> in_right_pads_;

        InElementwiseOperation in_element_op_;
        WeiElementwiseOperation wei_element_op_;
        OutElementwiseOperation out_element_op_;
    };

    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceConvFwd::Argument;

        // For every output element, accumulates in float the sum over the
        // channel and filter-tap indices of in_op(input) * wei_op(weight),
        // where each input coordinate is
        //   out_coord * stride[d] + tap * dilation[d] - left_pad[d]
        // and taps that land outside the input extents are skipped. The sum
        // is passed through out_element_op_ and stored as OutDataType.
        //
        // in_right_pads_ is not read here.
        //
        // @return always 0 (no timing is measured).
        float Run(const Argument& arg)
        {
            // Extents are loop-invariant; read them once as int so loop
            // comparisons do not mix signed and unsigned operands.
            const auto& in_lengths  = arg.input_.mDesc.GetLengths();
            const auto& wei_lengths = arg.weight_.mDesc.GetLengths();
            const auto& out_lengths = arg.output_.mDesc.GetLengths();

            const int C = static_cast<int>(wei_lengths[1]);

            if constexpr(NumDimSpatial == 1)
            {
                const int X  = static_cast<int>(wei_lengths[2]);
                const int Wi = static_cast<int>(in_lengths[2]);

                auto f_nkw = [&](auto n, auto k, auto wo) {
                    float v_acc = 0;

                    for(int c = 0; c < C; ++c)
                    {
                        for(int x = 0; x < X; ++x)
                        {
                            int wi = wo * arg.conv_strides_[0] + x * arg.conv_dilations_[0] -
                                     arg.in_left_pads_[0];

                            if(wi >= 0 && wi < Wi)
                            {
                                float v_in;
                                float v_wei;

                                // ck::type_convert (not static_cast) so the 1D path
                                // converts element types the same way as the 2D path.
                                arg.in_element_op_(
                                    v_in, ck::type_convert<float>(arg.input_(n, c, wi)));
                                arg.wei_element_op_(
                                    v_wei, ck::type_convert<float>(arg.weight_(k, c, x)));

                                v_acc += v_in * v_wei;
                            }
                        }
                    }

                    float v_out;
                    arg.out_element_op_(v_out, v_acc);
                    arg.output_(n, k, wo) = ck::type_convert<OutDataType>(v_out);
                };

                make_ParallelTensorFunctor(f_nkw, out_lengths[0], out_lengths[1], out_lengths[2])(
                    std::thread::hardware_concurrency());

                return 0;
            }
            else if constexpr(NumDimSpatial == 2)
            {
                const int Y  = static_cast<int>(wei_lengths[2]);
                const int X  = static_cast<int>(wei_lengths[3]);
                const int Hi = static_cast<int>(in_lengths[2]);
                const int Wi = static_cast<int>(in_lengths[3]);

                auto f_nkhw = [&](auto n, auto k, auto ho, auto wo) {
                    float v_acc = 0;

                    for(int c = 0; c < C; ++c)
                    {
                        for(int y = 0; y < Y; ++y)
                        {
                            int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
                                     arg.in_left_pads_[0];

                            for(int x = 0; x < X; ++x)
                            {
                                int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
                                         arg.in_left_pads_[1];

                                if(hi >= 0 && hi < Hi && wi >= 0 && wi < Wi)
                                {
                                    float v_in;
                                    float v_wei;

                                    arg.in_element_op_(
                                        v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi)));
                                    arg.wei_element_op_(
                                        v_wei, ck::type_convert<float>(arg.weight_(k, c, y, x)));

                                    v_acc += v_in * v_wei;
                                }
                            }
                        }
                    }

                    float v_out;
                    arg.out_element_op_(v_out, v_acc);
                    arg.output_(n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
                };

                make_ParallelTensorFunctor(
                    f_nkhw, out_lengths[0], out_lengths[1], out_lengths[2], out_lengths[3])(
                    std::thread::hardware_concurrency());

                return 0;
            }
            else
            {
                // The enable_if on the struct restricts NumDimSpatial to 1..3,
                // so this branch is the 3D case. Without it, Run() would reach
                // the end of a non-void function for NumDimSpatial == 3.
                static_assert(NumDimSpatial == 3, "unsupported number of spatial dimensions");

                const int Z  = static_cast<int>(wei_lengths[2]);
                const int Y  = static_cast<int>(wei_lengths[3]);
                const int X  = static_cast<int>(wei_lengths[4]);
                const int Di = static_cast<int>(in_lengths[2]);
                const int Hi = static_cast<int>(in_lengths[3]);
                const int Wi = static_cast<int>(in_lengths[4]);

                auto f_nkdhw = [&](auto n, auto k, auto d_o, auto ho, auto wo) {
                    float v_acc = 0;

                    for(int c = 0; c < C; ++c)
                    {
                        for(int z = 0; z < Z; ++z)
                        {
                            int di = d_o * arg.conv_strides_[0] + z * arg.conv_dilations_[0] -
                                     arg.in_left_pads_[0];

                            for(int y = 0; y < Y; ++y)
                            {
                                int hi = ho * arg.conv_strides_[1] + y * arg.conv_dilations_[1] -
                                         arg.in_left_pads_[1];

                                for(int x = 0; x < X; ++x)
                                {
                                    int wi = wo * arg.conv_strides_[2] +
                                             x * arg.conv_dilations_[2] - arg.in_left_pads_[2];

                                    if(di >= 0 && di < Di && hi >= 0 && hi < Hi && wi >= 0 &&
                                       wi < Wi)
                                    {
                                        float v_in;
                                        float v_wei;

                                        arg.in_element_op_(
                                            v_in,
                                            ck::type_convert<float>(arg.input_(n, c, di, hi, wi)));
                                        arg.wei_element_op_(
                                            v_wei,
                                            ck::type_convert<float>(arg.weight_(k, c, z, y, x)));

                                        v_acc += v_in * v_wei;
                                    }
                                }
                            }
                        }
                    }

                    float v_out;
                    arg.out_element_op_(v_out, v_acc);
                    arg.output_(n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
                };

                make_ParallelTensorFunctor(f_nkdhw,
                                           out_lengths[0],
                                           out_lengths[1],
                                           out_lengths[2],
                                           out_lengths[3],
                                           out_lengths[4])(std::thread::hardware_concurrency());

                return 0;
            }
        }

        // Type-erased entry point. p_arg must point to an Argument of this
        // operator: the dynamic_cast result is dereferenced without a null
        // check. The int parameter is unused.
        float Run(const device::BaseArgument* p_arg, int) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    // Placeholder: no compile-time check exists yet, so this is always true.
    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    // The argument is not inspected; every argument is reported as supported.
    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    // Forwards everything to Argument's constructor in the same order.
    static auto MakeArgument(const Tensor<InDataType>& input,
                             const Tensor<WeiDataType>& weight,
                             Tensor<OutDataType>& output,
                             std::vector<ck::index_t> conv_filter_strides,
                             std::vector<ck::index_t> conv_filter_dilations,
                             std::vector<ck::index_t> input_left_pads,
                             std::vector<ck::index_t> input_right_pads,
                             InElementwiseOperation in_element_op,
                             WeiElementwiseOperation wei_element_op,
                             OutElementwiseOperation out_element_op)
    {
        return Argument{input,
                        weight,
                        output,
                        conv_filter_strides,
                        conv_filter_dilations,
                        input_left_pads,
                        input_right_pads,
                        in_element_op,
                        wei_element_op,
                        out_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>();
    }

    // Returns the operator name followed by a line break.
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceConvFwd"
            << std::endl;
        // clang-format on

        return str.str();
    }
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,182 @@
#ifndef REFERENCE_CONV_FWD_BIAS_ACTIVATION_HPP
#define REFERENCE_CONV_FWD_BIAS_ACTIVATION_HPP
#include <iostream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
// Host reference for 2D forward convolution with a per-output-channel bias:
//   out[N, K, Ho, Wo] =
//       out_element_op(in[N, C, Hi, Wi] * wei[K, C, Y, X], bias[K])
// How the bias (and any activation) is applied is defined by
// OutElementwiseOperation. Tensors are indexed in the logical order shown.
template <typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation>
struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
{
    // Argument
    //
    // Holds references to the tensors (the caller must keep them alive while
    // the Argument is in use) plus copies of the convolution parameters and
    // elementwise functors. out_n_k_ho_wo is the tensor that Run() writes.
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<InDataType>& in_n_c_hi_wi,
                 const Tensor<WeiDataType>& wei_k_c_y_x,
                 Tensor<OutDataType>& out_n_k_ho_wo,
                 const Tensor<OutDataType>& bias_k,
                 std::vector<ck::index_t> conv_filter_strides,
                 std::vector<ck::index_t> conv_filter_dilations,
                 std::vector<ck::index_t> input_left_pads,
                 std::vector<ck::index_t> input_right_pads,
                 InElementwiseOperation in_element_op,
                 WeiElementwiseOperation wei_element_op,
                 OutElementwiseOperation out_element_op)
            : in_n_c_hi_wi_{in_n_c_hi_wi},
              wei_k_c_y_x_{wei_k_c_y_x},
              out_n_k_ho_wo_{out_n_k_ho_wo},
              bias_k_{bias_k},
              conv_strides_{conv_filter_strides},
              conv_dilations_{conv_filter_dilations},
              in_left_pads_{input_left_pads},
              in_right_pads_{input_right_pads},
              in_element_op_{in_element_op},
              wei_element_op_{wei_element_op},
              out_element_op_{out_element_op}
        {
        }

        const Tensor<InDataType>& in_n_c_hi_wi_;
        const Tensor<WeiDataType>& wei_k_c_y_x_;
        Tensor<OutDataType>& out_n_k_ho_wo_;
        const Tensor<OutDataType>& bias_k_;

        std::vector<index_t> conv_strides_;
        std::vector<index_t> conv_dilations_;
        std::vector<index_t> in_left_pads_;
        std::vector<index_t> in_right_pads_;

        InElementwiseOperation in_element_op_;
        WeiElementwiseOperation wei_element_op_;
        OutElementwiseOperation out_element_op_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceConvFwd_Bias_Activation::Argument;

        // For every output element (n, k, ho, wo), accumulates in float
        //   sum over (c, y, x) of in_op(in(n, c, hi, wi)) * wei_op(wei(k, c, y, x))
        // with hi = ho * stride[0] + y * dilation[0] - left_pad[0] (likewise
        // for wi), skipping taps outside the input extents. The sum and
        // bias(k) are handed to out_element_op_, whose result is stored as
        // OutDataType.
        //
        // in_right_pads_ is not read here.
        //
        // @return always 0 (no timing is measured).
        float Run(const Argument& arg)
        {
            // Extents are loop-invariant; read them once as int so loop
            // comparisons do not mix signed and unsigned operands.
            const auto& in_lengths  = arg.in_n_c_hi_wi_.mDesc.GetLengths();
            const auto& wei_lengths = arg.wei_k_c_y_x_.mDesc.GetLengths();
            const auto& out_lengths = arg.out_n_k_ho_wo_.mDesc.GetLengths();

            const int C  = static_cast<int>(wei_lengths[1]);
            const int Y  = static_cast<int>(wei_lengths[2]);
            const int X  = static_cast<int>(wei_lengths[3]);
            const int Hi = static_cast<int>(in_lengths[2]);
            const int Wi = static_cast<int>(in_lengths[3]);

            auto f_nkhw = [&](auto n, auto k, auto ho, auto wo) {
                float v_acc = 0;

                for(int c = 0; c < C; ++c)
                {
                    for(int y = 0; y < Y; ++y)
                    {
                        int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
                                 arg.in_left_pads_[0];

                        for(int x = 0; x < X; ++x)
                        {
                            int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
                                     arg.in_left_pads_[1];

                            if(hi >= 0 && hi < Hi && wi >= 0 && wi < Wi)
                            {
                                float v_in;
                                float v_wei;

                                // ck::type_convert rather than static_cast, so element
                                // types with a dedicated conversion are handled properly.
                                arg.in_element_op_(
                                    v_in, ck::type_convert<float>(arg.in_n_c_hi_wi_(n, c, hi, wi)));
                                arg.wei_element_op_(
                                    v_wei, ck::type_convert<float>(arg.wei_k_c_y_x_(k, c, y, x)));

                                v_acc += v_in * v_wei;
                            }
                        }
                    }
                }

                float v_out;
                arg.out_element_op_(v_out, v_acc, ck::type_convert<float>(arg.bias_k_(k)));
                arg.out_n_k_ho_wo_(n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
            };

            // One work item per output element; the trailing argument is
            // presumably the number of worker threads -- confirm in host_tensor.hpp.
            make_ParallelTensorFunctor(
                f_nkhw, out_lengths[0], out_lengths[1], out_lengths[2], out_lengths[3])(
                std::thread::hardware_concurrency());

            return 0;
        }

        // Type-erased entry point. p_arg must point to an Argument of this
        // operator: the dynamic_cast result is dereferenced without a null
        // check. The int parameter is unused.
        float Run(const device::BaseArgument* p_arg, int) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    // Placeholder: no compile-time check exists yet, so this is always true.
    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    // The argument is not inspected; every argument is reported as supported.
    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    // Forwards everything to Argument's constructor in the same order.
    static auto MakeArgument(const Tensor<InDataType>& in_n_c_hi_wi,
                             const Tensor<WeiDataType>& wei_k_c_y_x,
                             Tensor<OutDataType>& out_n_k_ho_wo,
                             const Tensor<OutDataType>& bias_k,
                             std::vector<ck::index_t> conv_filter_strides,
                             std::vector<ck::index_t> conv_filter_dilations,
                             std::vector<ck::index_t> input_left_pads,
                             std::vector<ck::index_t> input_right_pads,
                             InElementwiseOperation in_element_op,
                             WeiElementwiseOperation wei_element_op,
                             OutElementwiseOperation out_element_op)
    {
        return Argument{in_n_c_hi_wi,
                        wei_k_c_y_x,
                        out_n_k_ho_wo,
                        bias_k,
                        conv_filter_strides,
                        conv_filter_dilations,
                        input_left_pads,
                        input_right_pads,
                        in_element_op,
                        wei_element_op,
                        out_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>();
    }

    // Returns the operator name followed by a line break.
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceConvFwd_Bias_Activation"
            << std::endl;
        // clang-format on

        return str.str();
    }
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,190 @@
#ifndef REFERENCE_CONV2D_FWD_BIAS_ACTIVATION_ADD_HPP
#define REFERENCE_CONV2D_FWD_BIAS_ACTIVATION_ADD_HPP
#include <iostream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
// Host reference for 2D forward convolution with a per-output-channel bias
// and a residual tensor of the output's shape:
//   out[N, K, Ho, Wo] =
//       out_element_op(in[N, C, Hi, Wi] * wei[K, C, Y, X], bias[K], residual[N, K, Ho, Wo])
// How bias, activation and residual are combined is defined by
// OutElementwiseOperation. Tensors are indexed in the logical order shown.
template <typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation>
struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator
{
    // Argument
    //
    // Holds references to the tensors (the caller must keep them alive while
    // the Argument is in use) plus copies of the convolution parameters and
    // elementwise functors. out_n_k_ho_wo is the tensor that Run() writes.
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<InDataType>& in_n_c_hi_wi,
                 const Tensor<WeiDataType>& wei_k_c_y_x,
                 Tensor<OutDataType>& out_n_k_ho_wo,
                 const Tensor<OutDataType>& bias_k,
                 const Tensor<OutDataType>& resi_n_k_ho_wo,
                 std::vector<ck::index_t> conv_filter_strides,
                 std::vector<ck::index_t> conv_filter_dilations,
                 std::vector<ck::index_t> input_left_pads,
                 std::vector<ck::index_t> input_right_pads,
                 InElementwiseOperation in_element_op,
                 WeiElementwiseOperation wei_element_op,
                 OutElementwiseOperation out_element_op)
            : in_n_c_hi_wi_{in_n_c_hi_wi},
              wei_k_c_y_x_{wei_k_c_y_x},
              out_n_k_ho_wo_{out_n_k_ho_wo},
              bias_k_{bias_k},
              resi_n_k_ho_wo_{resi_n_k_ho_wo},
              conv_strides_{conv_filter_strides},
              conv_dilations_{conv_filter_dilations},
              in_left_pads_{input_left_pads},
              in_right_pads_{input_right_pads},
              in_element_op_{in_element_op},
              wei_element_op_{wei_element_op},
              out_element_op_{out_element_op}
        {
        }

        const Tensor<InDataType>& in_n_c_hi_wi_;
        const Tensor<WeiDataType>& wei_k_c_y_x_;
        Tensor<OutDataType>& out_n_k_ho_wo_;
        const Tensor<OutDataType>& bias_k_;
        const Tensor<OutDataType>& resi_n_k_ho_wo_;

        std::vector<index_t> conv_strides_;
        std::vector<index_t> conv_dilations_;
        std::vector<index_t> in_left_pads_;
        std::vector<index_t> in_right_pads_;

        InElementwiseOperation in_element_op_;
        WeiElementwiseOperation wei_element_op_;
        OutElementwiseOperation out_element_op_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceConvFwd_Bias_Activation_Add::Argument;

        // For every output element (n, k, ho, wo), accumulates in float
        //   sum over (c, y, x) of in_op(in(n, c, hi, wi)) * wei_op(wei(k, c, y, x))
        // with hi = ho * stride[0] + y * dilation[0] - left_pad[0] (likewise
        // for wi), skipping taps outside the input extents. The sum, bias(k)
        // and residual(n, k, ho, wo) are handed to out_element_op_, whose
        // result is stored as OutDataType.
        //
        // in_right_pads_ is not read here.
        //
        // @return always 0 (no timing is measured).
        float Run(const Argument& arg)
        {
            // Extents are loop-invariant; read them once as int so loop
            // comparisons do not mix signed and unsigned operands.
            const auto& in_lengths  = arg.in_n_c_hi_wi_.mDesc.GetLengths();
            const auto& wei_lengths = arg.wei_k_c_y_x_.mDesc.GetLengths();
            const auto& out_lengths = arg.out_n_k_ho_wo_.mDesc.GetLengths();

            const int C  = static_cast<int>(wei_lengths[1]);
            const int Y  = static_cast<int>(wei_lengths[2]);
            const int X  = static_cast<int>(wei_lengths[3]);
            const int Hi = static_cast<int>(in_lengths[2]);
            const int Wi = static_cast<int>(in_lengths[3]);

            auto f_nkhw = [&](auto n, auto k, auto ho, auto wo) {
                float v_acc = 0;

                for(int c = 0; c < C; ++c)
                {
                    for(int y = 0; y < Y; ++y)
                    {
                        int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
                                 arg.in_left_pads_[0];

                        for(int x = 0; x < X; ++x)
                        {
                            int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
                                     arg.in_left_pads_[1];

                            if(hi >= 0 && hi < Hi && wi >= 0 && wi < Wi)
                            {
                                float v_in;
                                float v_wei;

                                // ck::type_convert rather than static_cast, so element
                                // types with a dedicated conversion are handled properly.
                                arg.in_element_op_(
                                    v_in, ck::type_convert<float>(arg.in_n_c_hi_wi_(n, c, hi, wi)));
                                arg.wei_element_op_(
                                    v_wei, ck::type_convert<float>(arg.wei_k_c_y_x_(k, c, y, x)));

                                v_acc += v_in * v_wei;
                            }
                        }
                    }
                }

                float v_out;
                arg.out_element_op_(v_out,
                                    v_acc,
                                    ck::type_convert<float>(arg.bias_k_(k)),
                                    ck::type_convert<float>(arg.resi_n_k_ho_wo_(n, k, ho, wo)));
                arg.out_n_k_ho_wo_(n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
            };

            // One work item per output element; the trailing argument is
            // presumably the number of worker threads -- confirm in host_tensor.hpp.
            make_ParallelTensorFunctor(
                f_nkhw, out_lengths[0], out_lengths[1], out_lengths[2], out_lengths[3])(
                std::thread::hardware_concurrency());

            return 0;
        }

        // Type-erased entry point. p_arg must point to an Argument of this
        // operator: the dynamic_cast result is dereferenced without a null
        // check. The int parameter is unused.
        float Run(const device::BaseArgument* p_arg, int) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    // Placeholder: no compile-time check exists yet, so this is always true.
    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    // The argument is not inspected; every argument is reported as supported.
    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    // Forwards everything to Argument's constructor in the same order.
    static auto MakeArgument(const Tensor<InDataType>& in_n_c_hi_wi,
                             const Tensor<WeiDataType>& wei_k_c_y_x,
                             Tensor<OutDataType>& out_n_k_ho_wo,
                             const Tensor<OutDataType>& bias_k,
                             const Tensor<OutDataType>& resi_n_k_ho_wo,
                             std::vector<ck::index_t> conv_filter_strides,
                             std::vector<ck::index_t> conv_filter_dilations,
                             std::vector<ck::index_t> input_left_pads,
                             std::vector<ck::index_t> input_right_pads,
                             InElementwiseOperation in_element_op,
                             WeiElementwiseOperation wei_element_op,
                             OutElementwiseOperation out_element_op)
    {
        return Argument{in_n_c_hi_wi,
                        wei_k_c_y_x,
                        out_n_k_ho_wo,
                        bias_k,
                        resi_n_k_ho_wo,
                        conv_filter_strides,
                        conv_filter_dilations,
                        input_left_pads,
                        input_right_pads,
                        in_element_op,
                        wei_element_op,
                        out_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>();
    }

    // Returns the operator name followed by a line break.
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceConvFwd_Bias_Activation_Add"
            << std::endl;
        // clang-format on

        return str.str();
    }
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,132 @@
#ifndef REFERENCE_GEMM_HPP
#define REFERENCE_GEMM_HPP
#include <iostream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
// Host reference GEMM:
//   c(m, n) = c_op( sum over k of a_op(a(m, k)) * b_op(b(k, n)) )
// with the sum accumulated in float.
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct ReferenceGemm : public device::BaseOperator
{
    // Argument
    //
    // Holds references to the three tensors (the caller must keep them alive
    // while the Argument is in use) plus copies of the elementwise functors.
    // c_m_n is the tensor that Run() writes.
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<ADataType>& a_m_k,
                 const Tensor<BDataType>& b_k_n,
                 Tensor<CDataType>& c_m_n,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
            : a_m_k_{a_m_k},
              b_k_n_{b_k_n},
              c_m_n_{c_m_n},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
        {
        }

        const Tensor<ADataType>& a_m_k_;
        const Tensor<BDataType>& b_k_n_;
        Tensor<CDataType>& c_m_n_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceGemm::Argument;

        // Computes every c(m, n); the reduction length K is taken from the
        // second extent of a_m_k_.
        //
        // @return always 0 (no timing is measured).
        float Run(const Argument& arg)
        {
            // K is invariant across work items, so read it once.
            const int K = static_cast<int>(arg.a_m_k_.mDesc.GetLengths()[1]);

            auto f_mk_kn_mn = [&](auto m, auto n) {
                float v_acc = 0;

                for(int k = 0; k < K; ++k)
                {
                    float v_a;
                    float v_b;

                    // ck::type_convert rather than static_cast, so element types
                    // with a dedicated conversion are handled properly.
                    arg.a_element_op_(v_a, ck::type_convert<float>(arg.a_m_k_(m, k)));
                    arg.b_element_op_(v_b, ck::type_convert<float>(arg.b_k_n_(k, n)));

                    v_acc += v_a * v_b;
                }

                float v_c;
                arg.c_element_op_(v_c, v_acc);

                arg.c_m_n_(m, n) = ck::type_convert<CDataType>(v_c);
            };

            // One work item per output element; the trailing argument is
            // presumably the number of worker threads -- confirm in host_tensor.hpp.
            make_ParallelTensorFunctor(
                f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
                std::thread::hardware_concurrency());

            return 0;
        }

        // Type-erased entry point. p_arg must point to an Argument of this
        // operator: the dynamic_cast result is dereferenced without a null
        // check. The int parameter is unused.
        float Run(const device::BaseArgument* p_arg, int) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    // Placeholder: no compile-time check exists yet, so this is always true.
    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    // The argument is not inspected; every argument is reported as supported.
    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    // Forwards everything to Argument's constructor in the same order.
    static auto MakeArgument(const Tensor<ADataType>& a_m_k,
                             const Tensor<BDataType>& b_k_n,
                             Tensor<CDataType>& c_m_n,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>();
    }

    // Returns the operator name followed by a line break.
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceGemm"
            << std::endl;
        // clang-format on

        return str.str();
    }
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,133 @@
#ifndef REFERENCE_GEMM_BIAS_BIAS_2D_HPP
#define REFERENCE_GEMM_BIAS_BIAS_2D_HPP
#include <iostream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
// Host reference GEMM with a 2D bias tensor of the output's shape:
//   c(m, n) = c_op( CDataType(sum over k of a_op(a(m, k)) * b_op(b(k, n))), c0(m, n) )
// with the sum accumulated in AccDataType.
template <typename ADataType,
          typename BDataType,
          typename C0DataType,
          typename CDataType,
          typename AccDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct ReferenceGemmBias2D : public device::BaseOperator
{
    // Argument
    //
    // Holds references to the tensors (the caller must keep them alive while
    // the Argument is in use) plus copies of the elementwise functors.
    // c_m_n is the tensor that Run() writes.
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<ADataType>& a_m_k,
                 const Tensor<BDataType>& b_k_n,
                 const Tensor<C0DataType>& c0_m_n,
                 Tensor<CDataType>& c_m_n,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
            : a_m_k_{a_m_k},
              b_k_n_{b_k_n},
              c0_m_n_{c0_m_n},
              c_m_n_{c_m_n},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
        {
        }

        const Tensor<ADataType>& a_m_k_;
        const Tensor<BDataType>& b_k_n_;
        // Element type is C0DataType, matching the constructor parameter this
        // reference is bound to (it was declared with CDataType, which only
        // worked when both types were the same).
        const Tensor<C0DataType>& c0_m_n_;
        Tensor<CDataType>& c_m_n_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceGemmBias2D::Argument;

        // Computes every c(m, n); the reduction length K is taken from the
        // second extent of a_m_k_. c_element_op_ writes directly into the
        // output element.
        //
        // @return always 0 (no timing is measured).
        float Run(const Argument& arg)
        {
            // K is invariant across work items, so read it once.
            const int K = static_cast<int>(arg.a_m_k_.mDesc.GetLengths()[1]);

            auto f_mk_kn_mn = [&](auto m, auto n) {
                AccDataType a   = 0;
                AccDataType b   = 0;
                AccDataType acc = 0;

                for(int k = 0; k < K; ++k)
                {
                    arg.a_element_op_(a, arg.a_m_k_(m, k));
                    arg.b_element_op_(b, arg.b_k_n_(k, n));
                    acc += a * b;
                }

                // The accumulator is narrowed to CDataType before the bias is
                // combined by c_element_op_.
                CDataType cast_acc = static_cast<CDataType>(acc);
                arg.c_element_op_(arg.c_m_n_(m, n), cast_acc, arg.c0_m_n_(m, n));
            };

            // One work item per output element; the trailing argument is
            // presumably the number of worker threads -- confirm in host_tensor.hpp.
            make_ParallelTensorFunctor(
                f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
                std::thread::hardware_concurrency());

            return 0;
        }

        // Type-erased entry point. p_arg must point to an Argument of this
        // operator: the dynamic_cast result is dereferenced without a null
        // check. The int parameter is unused.
        float Run(const device::BaseArgument* p_arg, int) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    // Placeholder: no compile-time check exists yet, so this is always true.
    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    // The argument is not inspected; every argument is reported as supported.
    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    // Forwards everything to Argument's constructor in the same order.
    static auto MakeArgument(const Tensor<ADataType>& a_m_k,
                             const Tensor<BDataType>& b_k_n,
                             const Tensor<C0DataType>& c0_m_n,
                             Tensor<CDataType>& c_m_n,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{a_m_k, b_k_n, c0_m_n, c_m_n, a_element_op, b_element_op, c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>();
    }

    // Returns the operator name followed by a line break.
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceGemmBias2D"
            << std::endl;
        // clang-format on

        return str.str();
    }
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,136 @@
#ifndef REFERENCE_GEMM_BIAS_ACTIVATION_HPP
#define REFERENCE_GEMM_BIAS_ACTIVATION_HPP
#include <iostream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
// Host reference GEMM with a per-column bias vector:
//   c(m, n) = c_op( sum over k of a_op(a(m, k)) * b_op(b(k, n)), c0(n) )
// with the sum accumulated in float. How the bias (and any activation) is
// applied is defined by CElementwiseOperation.
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct ReferenceGemmBiasActivation : public device::BaseOperator
{
    // Argument
    //
    // Holds references to the tensors (the caller must keep them alive while
    // the Argument is in use) plus copies of the elementwise functors.
    // c_m_n is the tensor that Run() writes.
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<ADataType>& a_m_k,
                 const Tensor<BDataType>& b_k_n,
                 Tensor<CDataType>& c_m_n,
                 const Tensor<CDataType>& c0_n,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
            : a_m_k_{a_m_k},
              b_k_n_{b_k_n},
              c_m_n_{c_m_n},
              c0_n_{c0_n},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
        {
        }

        const Tensor<ADataType>& a_m_k_;
        const Tensor<BDataType>& b_k_n_;
        Tensor<CDataType>& c_m_n_;
        const Tensor<CDataType>& c0_n_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceGemmBiasActivation::Argument;

        // Computes every c(m, n); the reduction length K is taken from the
        // second extent of a_m_k_.
        //
        // @return always 0 (no timing is measured).
        float Run(const Argument& arg)
        {
            // K is invariant across work items, so read it once.
            const int K = static_cast<int>(arg.a_m_k_.mDesc.GetLengths()[1]);

            auto f_mk_kn_mn = [&](auto m, auto n) {
                float v_acc = 0;

                for(int k = 0; k < K; ++k)
                {
                    float v_a;
                    float v_b;

                    // ck::type_convert rather than static_cast, so element types
                    // with a dedicated conversion are handled properly.
                    arg.a_element_op_(v_a, ck::type_convert<float>(arg.a_m_k_(m, k)));
                    arg.b_element_op_(v_b, ck::type_convert<float>(arg.b_k_n_(k, n)));

                    v_acc += v_a * v_b;
                }

                float v_c;
                arg.c_element_op_(v_c, v_acc, ck::type_convert<float>(arg.c0_n_(n)));

                arg.c_m_n_(m, n) = ck::type_convert<CDataType>(v_c);
            };

            // One work item per output element; the trailing argument is
            // presumably the number of worker threads -- confirm in host_tensor.hpp.
            make_ParallelTensorFunctor(
                f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
                std::thread::hardware_concurrency());

            return 0;
        }

        // Type-erased entry point. p_arg must point to an Argument of this
        // operator: the dynamic_cast result is dereferenced without a null
        // check. The int parameter is unused.
        float Run(const device::BaseArgument* p_arg, int) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    // Placeholder: no compile-time check exists yet, so this is always true.
    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    // The argument is not inspected; every argument is reported as supported.
    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    // Forwards everything to Argument's constructor in the same order.
    static auto MakeArgument(const Tensor<ADataType>& a_m_k,
                             const Tensor<BDataType>& b_k_n,
                             Tensor<CDataType>& c_m_n,
                             const Tensor<CDataType>& c0_n,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{a_m_k, b_k_n, c_m_n, c0_n, a_element_op, b_element_op, c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>();
    }

    // Returns the operator name followed by a line break.
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceGemmBiasActivation"
            << std::endl;
        // clang-format on

        return str.str();
    }
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,144 @@
#ifndef REFERENCE_GEMM_BIAS_ACTIVATION_ADD_HPP
#define REFERENCE_GEMM_BIAS_ACTIVATION_ADD_HPP
#include <iostream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct ReferenceGemmBiasActivationAdd : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
Tensor<CDataType>& c_m_n,
const Tensor<CDataType>& c0_n,
const Tensor<CDataType>& c1_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: a_m_k_{a_m_k},
b_k_n_{b_k_n},
c_m_n_{c_m_n},
c0_n_{c0_n},
c1_m_n_{c1_m_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
}
const Tensor<ADataType>& a_m_k_;
const Tensor<BDataType>& b_k_n_;
Tensor<CDataType>& c_m_n_;
const Tensor<CDataType>& c0_n_;
const Tensor<CDataType>& c1_m_n_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceGemmBiasActivationAdd::Argument;
float Run(const Argument& arg)
{
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = arg.a_m_k_.mDesc.GetLengths()[1];
float v_acc = 0;
for(int k = 0; k < K; ++k)
{
float v_a;
float v_b;
arg.a_element_op_(v_a, static_cast<const float>(arg.a_m_k_(m, k)));
arg.b_element_op_(v_b, static_cast<const float>(arg.b_k_n_(k, n)));
v_acc += v_a * v_b;
}
float v_c;
arg.c_element_op_(v_c,
v_acc,
static_cast<float>(arg.c0_n_(n)),
static_cast<float>(arg.c1_m_n_(m, n)));
arg.c_m_n_(m, n) = v_c;
};
make_ParallelTensorFunctor(
f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const device::BaseArgument* p_arg, int) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
// Convenience factory: builds an Argument from the given tensors and functors.
// The returned Argument refers to the tensors by reference, so they must
// outlive it.
static auto MakeArgument(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
Tensor<CDataType>& c_m_n,
const Tensor<CDataType>& c0_n,
const Tensor<CDataType>& c1_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{
a_m_k, b_k_n, c_m_n, c0_n, c1_m_n, a_element_op, b_element_op, c_element_op};
}
// Returns a default-constructed Invoker by value.
static auto MakeInvoker() { return Invoker{}; }
// Returns a heap-allocated Invoker behind the BaseInvoker interface, for
// callers that only know the type-erased base classes.
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
// Returns the operation's name followed by a newline.
std::string GetTypeString() const override
{
    std::stringstream name_stream;

    name_stream << "ReferenceGemmBiasActivationAdd" << '\n';

    return name_stream.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,122 @@
#ifndef NAIVE_CONV_FWD_HPP
#define NAIVE_CONV_FWD_HPP
namespace ck {
namespace ref {
/*
 * \brief naive implementation of 3D convolution. Layout is (NDHWC, KZYXC, NDHWK).
 *
 * Each thread walks the flattened output tensor in a grid-stride loop and
 * computes one output element per iteration by summing over (Z, Y, X, C).
 * Filter taps that fall outside the input (i.e. into the padding) are skipped.
 *
 * All flat sizes, strides and offsets are computed in long_index_t so that
 * tensors whose element count does not fit into index_t are handled without
 * intermediate overflow (assumes long_index_t is wider than index_t).
 *
 * \param p_in pointer to the input tensor (NDHWC)
 * \param p_wei pointer to the weight tensor (KZYXC)
 * \param p_out pointer to the output tensor (NDHWK)
 * \param N number of batches
 * \param K number of filters
 * \param C number of channels of weight
 * \param (Di, Hi, Wi) depth, height and width dimension of data
 * \param (Z, Y, X) depth, height and width dimensions of weights
 * \param (Do, Ho, Wo) depth, height and width dimension of output
 * \param (stride_z, stride_y, stride_x) strides
 * \param (dilation_z, dilation_y, dilation_x) dilations
 * \param (pad_z, pad_y, pad_x) pads
 */
template <typename TIn,
          typename TWei,
          typename TOut,
          typename TAcc,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation>
__global__ void naive_conv_fwd_ndhwc_kzyxc_ndhwk(const TIn* __restrict__ p_in,
                                                 const TWei* __restrict__ p_wei,
                                                 TOut* __restrict__ p_out,
                                                 index_t N,
                                                 index_t K,
                                                 index_t C,
                                                 index_t Di,
                                                 index_t Hi,
                                                 index_t Wi,
                                                 index_t Z,
                                                 index_t Y,
                                                 index_t X,
                                                 index_t Do,
                                                 index_t Ho,
                                                 index_t Wo,
                                                 index_t stride_z,
                                                 index_t stride_y,
                                                 index_t stride_x,
                                                 index_t dilation_z,
                                                 index_t dilation_y,
                                                 index_t dilation_x,
                                                 index_t pad_z,
                                                 index_t pad_y,
                                                 index_t pad_x)
{
    const index_t tid         = blockIdx.x * blockDim.x + threadIdx.x;
    const index_t num_threads = blockDim.x * gridDim.x;

    // Widen the first operand so every product below is evaluated in
    // long_index_t instead of overflowing in index_t before the conversion.
    const long_index_t output_length = static_cast<long_index_t>(N) * Do * Ho * Wo * K;

    const long_index_t out_strides[] = {static_cast<long_index_t>(Do) * Ho * Wo * K,
                                        static_cast<long_index_t>(Ho) * Wo * K,
                                        static_cast<long_index_t>(Wo) * K,
                                        static_cast<long_index_t>(K)};
    const long_index_t in_strides[]  = {static_cast<long_index_t>(Di) * Hi * Wi * C,
                                        static_cast<long_index_t>(Hi) * Wi * C,
                                        static_cast<long_index_t>(Wi) * C,
                                        static_cast<long_index_t>(C)};
    const long_index_t wei_strides[] = {static_cast<long_index_t>(Z) * Y * X * C,
                                        static_cast<long_index_t>(Y) * X * C,
                                        static_cast<long_index_t>(X) * C,
                                        static_cast<long_index_t>(C)};

    constexpr auto in_op  = InElementwiseOperation{};
    constexpr auto wei_op = WeiElementwiseOperation{};
    constexpr auto out_op = OutElementwiseOperation{};

    TIn in_val;
    TWei wei_val;
    TOut out_val;

    for(long_index_t ii = tid; ii < output_length; ii += num_threads)
    {
        // Decompose the flat output index into (n, dO, ho, wo, k). The running
        // remainder stays in long_index_t; each coordinate fits into index_t
        // because it is bounded by the corresponding tensor extent.
        const index_t n  = static_cast<index_t>(ii / out_strides[0]);
        long_index_t rem = ii - n * out_strides[0];
        const index_t dO = static_cast<index_t>(rem / out_strides[1]);
        rem -= dO * out_strides[1];
        const index_t ho = static_cast<index_t>(rem / out_strides[2]);
        rem -= ho * out_strides[2];
        const index_t wo = static_cast<index_t>(rem / out_strides[3]);
        rem -= wo * out_strides[3];
        const index_t k = static_cast<index_t>(rem);

        TAcc acc = static_cast<TAcc>(0);

        const TIn* in_n   = p_in + n * in_strides[0];
        const TWei* wei_k = p_wei + k * wei_strides[0];

        for(index_t z = 0; z < Z; ++z)
        {
            // Input coordinates may be negative or past the end when the tap
            // lies in the padding; the pointers derived from them are only
            // dereferenced after the bounds check in the innermost spatial loop.
            index_t di          = stride_z * dO - pad_z + dilation_z * z;
            const TIn* in_n_di  = in_n + di * in_strides[1];
            const TWei* wei_k_z = wei_k + z * wei_strides[1];

            for(index_t y = 0; y < Y; ++y)
            {
                index_t hi            = stride_y * ho - pad_y + dilation_y * y;
                const TIn* in_n_di_hi = in_n_di + hi * in_strides[2];
                const TWei* wei_k_z_y = wei_k_z + y * wei_strides[2];

                for(index_t x = 0; x < X; ++x)
                {
                    index_t wi               = stride_x * wo - pad_x + dilation_x * x;
                    const TIn* in_n_di_hi_wi = in_n_di_hi + wi * in_strides[3];
                    const TWei* wei_k_z_y_x  = wei_k_z_y + x * wei_strides[3];

                    if(di >= 0 && di < Di && hi >= 0 && hi < Hi && wi >= 0 && wi < Wi)
                    {
                        for(index_t c = 0; c < C; ++c)
                        {
                            in_op(in_val, in_n_di_hi_wi[c]);
                            wei_op(wei_val, wei_k_z_y_x[c]);
                            acc += in_val * wei_val;
                        }
                    }
                }
            }
        }

        out_op(out_val, static_cast<TOut>(acc));
        p_out[ii] = out_val;
    }
}
} // namespace ref
} // namespace ck
#endif

View File

@@ -0,0 +1,26 @@
#ifndef CK_DEVICE_OPERATION_INSTANCE_HPP
#define CK_DEVICE_OPERATION_INSTANCE_HPP
#include <stdlib.h>
namespace ck {
namespace tensor_operation {
namespace device {
// Appends a heap-allocated copy of every element of the tuple-like
// `new_op_instances` to `op_instances`, preserving tuple order.
//
// \param op_instances destination vector; each new entry owns its own copy
// \param new_op_instances tuple of operation instances to copy from
template <typename OpInstance, typename NewOpInstances>
void add_device_operation_instances(std::vector<std::unique_ptr<OpInstance>>& op_instances,
                                    const NewOpInstances& new_op_instances)
{
    ck::static_for<0, std::tuple_size_v<NewOpInstances>, 1>{}([&](auto i) {
        // Bind by reference so the only copy made is the one make_unique
        // performs when constructing the heap object.
        const auto& new_op_instance = std::get<i>(new_op_instances);

        using NewOpInstance = remove_cvref_t<decltype(new_op_instance)>;

        op_instances.push_back(std::make_unique<NewOpInstance>(new_op_instance));
    });
}
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,28 @@
#ifndef DEVICE_REDUCE_INSTANTCE_HPP
#define DEVICE_REDUCE_INSTANTCE_HPP
#include "device_reduce_instance_blockwise_f16_f16_f16.hpp"
#include "device_reduce_instance_blockwise_f16_f32_f16.hpp"
#include "device_reduce_instance_blockwise_f32_f32_f32.hpp"
#include "device_reduce_instance_blockwise_f32_f64_f32.hpp"
#include "device_reduce_instance_blockwise_f64_f64_f64.hpp"
#include "device_reduce_instance_blockwise_second_call_f16_f16_f16.hpp"
#include "device_reduce_instance_blockwise_second_call_f32_f32_f16.hpp"
#include "device_reduce_instance_blockwise_second_call_f32_f32_f32.hpp"
#include "device_reduce_instance_blockwise_second_call_f64_f64_f32.hpp"
#include "device_reduce_instance_blockwise_second_call_f64_f64_f64.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.hpp"
#include "device_reduce_instance_threadwise_f16_f16_f16.hpp"
#include "device_reduce_instance_threadwise_f16_f32_f16.hpp"
#include "device_reduce_instance_threadwise_f32_f32_f32.hpp"
#include "device_reduce_instance_threadwise_f32_f64_f32.hpp"
#include "device_reduce_instance_threadwise_f64_f64_f64.hpp"
#endif

View File

@@ -0,0 +1,168 @@
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_HPP
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_impl_common.hpp"
#include "device_reduce_blockwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// List of ReductionConfiguration_2 entries used when generating blockwise
// reduction instances. Each entry supplies the vector-access and per-thread
// slice parameters; the column order is given by the header comment inside
// each list. Defining QUICK_REDUCE_TEST selects a much shorter list.
#ifdef QUICK_REDUCE_TEST
using reduce_configuration_2_instances_blockwise = std::tuple<
// clang-format off
// InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
ReductionConfiguration_2<0, 2, 2, 2, 1>,
ReductionConfiguration_2<0, 1, 1, 2, 1>,
ReductionConfiguration_2<1, 2, 1, 1, 2>,
ReductionConfiguration_2<1, 2, 2, 1, 2>,
ReductionConfiguration_2<0, 1, 1, 3, 1>,
ReductionConfiguration_2<1, 1, 1, 1, 3>
// clang-format on
>;
#else
// Full list: vectorized entries first, then scalar entries with odd slice sizes.
using reduce_configuration_2_instances_blockwise = std::tuple<
// clang-format off
// InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
ReductionConfiguration_2<0, 4, 4, 8, 1>,
ReductionConfiguration_2<0, 4, 4, 4, 1>,
ReductionConfiguration_2<0, 2, 2, 2, 1>,
ReductionConfiguration_2<1, 4, 1, 1, 8>,
ReductionConfiguration_2<1, 4, 1, 1, 4>,
ReductionConfiguration_2<1, 2, 1, 1, 2>,
// special instances
ReductionConfiguration_2<0, 1, 1, 3, 1>,
ReductionConfiguration_2<0, 1, 1, 5, 1>,
ReductionConfiguration_2<0, 1, 1, 7, 1>,
ReductionConfiguration_2<0, 1, 1, 11, 1>,
ReductionConfiguration_2<1, 1, 1, 1, 3>,
ReductionConfiguration_2<1, 1, 1, 1, 5>,
ReductionConfiguration_2<1, 1, 1, 1, 7>,
ReductionConfiguration_2<1, 1, 1, 1, 11>
// clang-format on
>;
#endif
// Pointer type stored in the instance vector filled by
// add_device_reduce_instance_blockwise: a DeviceReducePtr parameterized on the
// element-wise operations that reduce_unary_operator<AccDataType, ReduceOpId, true, true>
// selects for the given accumulation type and reduce operation.
template <typename AccDataType, ReduceTensorOp_t ReduceOpId>
using deviceReduceBlockWisePtrType = DeviceReducePtr<
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation,
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation>;
// Appends one DeviceReduceBlockWise instance to `device_op_instances` for every
// pairing of an entry from reduce_configuration_1_instances (block size and
// thread-cluster shape) with an entry from
// reduce_configuration_2_instances_blockwise (thread-slice and vector sizes).
//
// Index output is only enabled for MIN / MAX / AMAX and only when IndicesOpt
// requests it; NaN propagation is enabled unless NanOpt is NOT_PROPAGATE_NAN.
template <typename InDataType,
          typename AccDataType,
          typename OutDataType,
          int Rank,
          typename ReduceDims,
          ReduceTensorOp_t ReduceOpId,
          NanPropagation_t NanOpt,
          ReduceTensorIndices_t IndicesOpt>
void add_device_reduce_instance_blockwise(
    std::vector<deviceReduceBlockWisePtrType<AccDataType, ReduceOpId>>& device_op_instances)
{
    using UnaryOps = reduce_unary_operator<AccDataType, ReduceOpId, true, true>;

    using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
    using InElementwiseOperation  = typename UnaryOps::InElementwiseOperation;
    using AccElementwiseOperation = typename UnaryOps::AccElementwiseOperation;

    constexpr bool Indexable =
        (ReduceOpId == ReduceTensorOp_t::MIN || ReduceOpId == ReduceTensorOp_t::MAX ||
         ReduceOpId == ReduceTensorOp_t::AMAX);
    constexpr bool NeedIndices  = Indexable && (IndicesOpt != ReduceTensorIndices_t::NO_INDICES);
    constexpr bool PropagateNan = (NanOpt != NanPropagation_t::NOT_PROPAGATE_NAN);

    constexpr auto num_cluster_cfgs = std::tuple_size<reduce_configuration_1_instances>::value;
    constexpr auto num_slice_cfgs =
        std::tuple_size<reduce_configuration_2_instances_blockwise>::value;

    // Outer loop: block / thread-cluster configurations.
    static_for<0, num_cluster_cfgs, 1>{}([&](auto cluster_idx) {
        using cfg1 = remove_cvref_t<decltype(
            std::get<cluster_idx.value>(reduce_configuration_1_instances{}))>;

        // Inner loop: thread-slice / vector configurations.
        static_for<0, num_slice_cfgs, 1>{}([&](auto slice_idx) {
            using cfg2 = remove_cvref_t<decltype(
                std::get<slice_idx.value>(reduce_configuration_2_instances_blockwise{}))>;

            using ReduceOpInstance = DeviceReduceBlockWise<InDataType,
                                                           AccDataType,
                                                           OutDataType,
                                                           Rank,
                                                           ReduceDims,
                                                           ReduceOperation,
                                                           InElementwiseOperation,
                                                           AccElementwiseOperation,
                                                           PropagateNan,
                                                           NeedIndices,
                                                           cfg1::BlockSize_,
                                                           cfg1::MThreadClusterSize_,
                                                           cfg1::KThreadClusterSize_,
                                                           cfg2::MThreadSliceSize_,
                                                           cfg2::KThreadSliceSize_,
                                                           cfg2::InSrcVectorDim_,
                                                           cfg2::InSrcVectorSize_,
                                                           cfg2::OutDstVectorSize_>;

            device_op_instances.push_back(
                std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
        });
    });
}
// Explicit instantiation definition of add_device_reduce_instance_blockwise for
// one combination of data types, reduce operation, NaN/indices options and
// rank. The variadic arguments are the reduced dimension indices and are
// packed into Sequence<...>. The enum arguments must already have enum type.
#define ADD_BLOCKWISE_INST_BY_TYPE(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
template void add_device_reduce_instance_blockwise<inT, \
compT, \
outT, \
Rank, \
Sequence<__VA_ARGS__>, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances)
// Same as ADD_BLOCKWISE_INST_BY_TYPE, but takes the reduce operation, NaN
// option and indices option as integer ids and casts them to the enum types.
#define ADD_BLOCKWISE_INST_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
ADD_BLOCKWISE_INST_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp_t>(ReduceOpId), \
static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \
__VA_ARGS__)
// Explicit instantiation declaration (extern template) matching
// ADD_BLOCKWISE_INST_BY_TYPE. The vector element type is spelled out as
// DeviceReducePtr<...> here; it is the same type that
// deviceReduceBlockWisePtrType<compT, ReduceOpId> names.
#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
extern template void add_device_reduce_instance_blockwise<inT, \
compT, \
outT, \
Rank, \
Sequence<__VA_ARGS__>, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
std::vector<DeviceReducePtr< \
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
AccElementwiseOperation>> & \
device_op_instances)
// Integer-id front end for ADD_BLOCKWISE_INST_REF_BY_TYPE (casts the ids to the enum types).
#define ADD_BLOCKWISE_INST_REF_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
ADD_BLOCKWISE_INST_REF_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp_t>(ReduceOpId), \
static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \
__VA_ARGS__)
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,41 @@
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// Extern-template declarations of the blockwise reduce instance factories for
// half_t input, half_t accumulation and half_t output. Arguments follow the
// macro's parameter order; everything after Rank is the list of reduced dimensions.
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
// IndicesOpt = 0
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
// IndicesOpt = 1
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,32 @@
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// Extern-template declarations of the blockwise reduce instance factories for
// half_t input, float accumulation and half_t output. Arguments follow the
// macro's parameter order; everything after Rank is the list of reduced dimensions.
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,50 @@
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// Extern-template declarations of the blockwise reduce instance factories for
// float input, float accumulation and float output. Arguments follow the
// macro's parameter order; everything after Rank is the list of reduced dimensions.
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
// IndicesOpt = 0
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
// IndicesOpt = 1
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,32 @@
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// Extern-template declarations of the blockwise reduce instance factories for
// float input, double accumulation and float output. Arguments follow the
// macro's parameter order; everything after Rank is the list of reduced dimensions.
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0);
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,50 @@
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// Extern-template declarations of the blockwise reduce instance factories for
// double input, double accumulation and double output. Arguments follow the
// macro's parameter order; everything after Rank is the list of reduced dimensions.
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
// IndicesOpt = 0
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
// IndicesOpt = 1
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,167 @@
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_HPP
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_impl_common.hpp"
#include "device_reduce_blockwise_second_call.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// List of ReductionConfiguration_2 entries used when generating the
// "second call" blockwise reduction instances. Every entry here has
// InSrcVectorDim = 1; the column order is given by the header comment inside
// each list. Defining QUICK_REDUCE_TEST selects a shorter list.
#ifdef QUICK_REDUCE_TEST
using reduce_configuration_2_instances_blockwise_second_call = std::tuple<
// clang-format off
// InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
ReductionConfiguration_2<1, 2, 1, 1, 2>,
ReductionConfiguration_2<1, 2, 2, 1, 2>,
ReductionConfiguration_2<1, 1, 1, 1, 3>,
ReductionConfiguration_2<1, 1, 2, 1, 3>
// clang-format on
>;
#else
// Full list: vectorized entries first, then scalar entries with odd slice sizes.
using reduce_configuration_2_instances_blockwise_second_call = std::tuple<
// clang-format off
// InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
ReductionConfiguration_2<1, 4, 1, 1, 8>,
ReductionConfiguration_2<1, 4, 1, 1, 4>,
ReductionConfiguration_2<1, 2, 1, 1, 2>,
ReductionConfiguration_2<1, 1, 1, 1, 3>,
ReductionConfiguration_2<1, 1, 1, 1, 5>,
ReductionConfiguration_2<1, 1, 1, 1, 7>,
ReductionConfiguration_2<1, 1, 1, 1, 11>
// clang-format on
>;
#endif
// Pointer type stored in the instance vector filled by
// add_device_reduce_instance_blockwise_second_call: a DeviceReducePtr
// parameterized on the element-wise operations that
// reduce_unary_operator<AccDataType, ReduceOpId, false, true> selects.
template <typename AccDataType, ReduceTensorOp_t ReduceOpId>
using deviceReduceBlockWiseSecondCallPtrType = DeviceReducePtr<
typename reduce_unary_operator<AccDataType, ReduceOpId, false, true>::InElementwiseOperation,
typename reduce_unary_operator<AccDataType, ReduceOpId, false, true>::AccElementwiseOperation>;
// Appends one DeviceReduceBlockWiseSecondCall instance to `device_op_instances`
// for every pairing of an entry from reduce_configuration_1_instances (block
// size and thread-cluster shape) with an entry from
// reduce_configuration_2_instances_blockwise_second_call (thread-slice and
// vector sizes).
//
// InDataType must equal AccDataType (enforced by the static_assert below).
// Index output is only enabled for MIN / MAX / AMAX and only when IndicesOpt
// requests it; NaN propagation is enabled unless NanOpt is NOT_PROPAGATE_NAN.
template <typename InDataType,
          typename AccDataType,
          typename OutDataType,
          int Rank,
          typename ReduceDims,
          ReduceTensorOp_t ReduceOpId,
          NanPropagation_t NanOpt,
          ReduceTensorIndices_t IndicesOpt>
void add_device_reduce_instance_blockwise_second_call(
    std::vector<deviceReduceBlockWiseSecondCallPtrType<AccDataType, ReduceOpId>>&
        device_op_instances)
{
    using UnaryOps = reduce_unary_operator<AccDataType, ReduceOpId, false, true>;

    using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
    using InElementwiseOperation  = typename UnaryOps::InElementwiseOperation;
    using AccElementwiseOperation = typename UnaryOps::AccElementwiseOperation;

    constexpr bool Indexable =
        (ReduceOpId == ReduceTensorOp_t::MIN || ReduceOpId == ReduceTensorOp_t::MAX ||
         ReduceOpId == ReduceTensorOp_t::AMAX);
    constexpr bool NeedIndices  = Indexable && (IndicesOpt != ReduceTensorIndices_t::NO_INDICES);
    constexpr bool PropagateNan = (NanOpt != NanPropagation_t::NOT_PROPAGATE_NAN);

    static_assert(std::is_same<InDataType, AccDataType>::value,
                  "InDataType and AccDataType should be the same to use "
                  "add_device_reduce_instance_blockwise_second_call!");

    constexpr auto num_cluster_cfgs = std::tuple_size<reduce_configuration_1_instances>::value;
    constexpr auto num_slice_cfgs =
        std::tuple_size<reduce_configuration_2_instances_blockwise_second_call>::value;

    // Outer loop: block / thread-cluster configurations.
    static_for<0, num_cluster_cfgs, 1>{}([&](auto cluster_idx) {
        using cfg1 = remove_cvref_t<decltype(
            std::get<cluster_idx.value>(reduce_configuration_1_instances{}))>;

        // Inner loop: thread-slice / vector configurations.
        static_for<0, num_slice_cfgs, 1>{}([&](auto slice_idx) {
            using cfg2 = remove_cvref_t<decltype(std::get<slice_idx.value>(
                reduce_configuration_2_instances_blockwise_second_call{}))>;

            using ReduceOpInstance = DeviceReduceBlockWiseSecondCall<InDataType,
                                                                     AccDataType,
                                                                     OutDataType,
                                                                     Rank,
                                                                     ReduceDims,
                                                                     ReduceOperation,
                                                                     InElementwiseOperation,
                                                                     AccElementwiseOperation,
                                                                     PropagateNan,
                                                                     NeedIndices,
                                                                     cfg1::BlockSize_,
                                                                     cfg1::MThreadClusterSize_,
                                                                     cfg1::KThreadClusterSize_,
                                                                     cfg2::MThreadSliceSize_,
                                                                     cfg2::KThreadSliceSize_,
                                                                     cfg2::InSrcVectorDim_,
                                                                     cfg2::InSrcVectorSize_,
                                                                     cfg2::OutDstVectorSize_>;

            device_op_instances.push_back(
                std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
        });
    });
}
// Explicit instantiation definition of
// add_device_reduce_instance_blockwise_second_call for one combination of data
// types, reduce operation, NaN/indices options and rank. The variadic
// arguments are the reduced dimension indices and are packed into
// Sequence<...>. The enum arguments must already have enum type.
#define ADD_BLOCKWISE_SECOND_CALL_INST_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
template void add_device_reduce_instance_blockwise_second_call<inT, \
compT, \
outT, \
Rank, \
Sequence<__VA_ARGS__>, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
std::vector<deviceReduceBlockWiseSecondCallPtrType<compT, ReduceOpId>> & \
device_op_instances)
// Same as ADD_BLOCKWISE_SECOND_CALL_INST_BY_TYPE, but takes the reduce
// operation, NaN option and indices option as integer ids and casts them to
// the enum types.
#define ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
ADD_BLOCKWISE_SECOND_CALL_INST_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp_t>(ReduceOpId), \
static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \
__VA_ARGS__)
// Explicit instantiation declaration (extern template) matching
// ADD_BLOCKWISE_SECOND_CALL_INST_BY_TYPE. The vector element type is spelled
// out as DeviceReducePtr<...> here; it is the same type that
// deviceReduceBlockWiseSecondCallPtrType<compT, ReduceOpId> names.
#define ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
extern template void add_device_reduce_instance_blockwise_second_call<inT, \
compT, \
outT, \
Rank, \
Sequence<__VA_ARGS__>, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
std::vector< \
DeviceReducePtr<typename reduce_unary_operator<compT, ReduceOpId, false, true>:: \
InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, false, true>:: \
AccElementwiseOperation>> & \
device_op_instances)
// Integer-id front end for ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_TYPE (casts
// the ids to the enum types).
#define ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp_t>(ReduceOpId), \
static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \
__VA_ARGS__)
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,41 @@
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F16_F16_F16_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F16_F16_F16_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise_second_call.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// Extern-template declarations of the second-call blockwise reduce instance
// factories for half_t input, half_t accumulation and half_t output. Arguments
// follow the macro's parameter order; everything after Rank is the list of
// reduced dimensions.
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
// IndicesOpt = 0
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
// IndicesOpt = 1
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,32 @@
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F32_F32_F16_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F32_F32_F16_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise_second_call.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// Extern-template declarations of the second-call blockwise reduce instance
// factories for float input, float accumulation and half_t output. Arguments
// follow the macro's parameter order; everything after Rank is the list of
// reduced dimensions.
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 0);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,50 @@
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F32_F32_F32_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F32_F32_F32_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise_second_call.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// Blockwise second-call reduction instances for float input, float accumulation and
// float output. Each row names one instance; its arguments follow the column header
// below, with the reduced dimensions given as the trailing arguments after Rank.
// MIN / MAX / AMAX appear twice: once with IndicesOpt = 0 and once with IndicesOpt = 1.
// NOTE(review): ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID is not defined in this header;
// presumably it comes from device_reduce_instance_blockwise_second_call.hpp and emits an
// extern template declaration like the multiblock *_INST_REF_BY_ID macros - confirm there.
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
// IndicesOpt = 0 variants of the index-capable operations
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
// IndicesOpt = 1 variants of the same operations
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,32 @@
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F64_F64_F32_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F64_F64_F32_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise_second_call.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// Blockwise second-call reduction instances for double input, double accumulation and
// float output. Each row names one instance; its arguments follow the column header
// below, with the reduced dimensions given as the trailing arguments after Rank.
// NOTE(review): ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID is not defined in this header;
// presumably it comes from device_reduce_instance_blockwise_second_call.hpp and emits an
// extern template declaration like the multiblock *_INST_REF_BY_ID macros - confirm there.
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 0);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,50 @@
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F64_F64_F64_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F64_F64_F64_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise_second_call.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// Blockwise second-call reduction instances for double input, double accumulation and
// double output. Each row names one instance; its arguments follow the column header
// below, with the reduced dimensions given as the trailing arguments after Rank.
// MIN / MAX / AMAX appear twice: once with IndicesOpt = 0 and once with IndicesOpt = 1.
// NOTE(review): ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID is not defined in this header;
// presumably it comes from device_reduce_instance_blockwise_second_call.hpp and emits an
// extern template declaration like the multiblock *_INST_REF_BY_ID macros - confirm there.
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
// IndicesOpt = 0 variants of the index-capable operations
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
// IndicesOpt = 1 variants of the same operations
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,55 @@
#ifndef DEVICE_REDUCE_INSTANCE_IMPL_COMMON_HPP
#define DEVICE_REDUCE_INSTANCE_IMPL_COMMON_HPP

// std::tuple is used below; include it here so this header compiles on its own instead of
// relying on whatever the including file happened to pull in first.
#include <tuple>

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// Compile-time description of a thread-block shape.
// The three template arguments are re-exported as the static members BlockSize_,
// MThreadClusterSize_ and KThreadClusterSize_. Instantiation fails unless
// BlockSize == MThreadClusterSize * KThreadClusterSize.
template <int BlockSize, int MThreadClusterSize, int KThreadClusterSize>
struct ReductionConfiguration_1
{
    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, "Invalid Configuration!");

    static constexpr int BlockSize_          = BlockSize;
    static constexpr int MThreadClusterSize_ = MThreadClusterSize;
    static constexpr int KThreadClusterSize_ = KThreadClusterSize;
};

// Compile-time description of the per-thread access shape.
// The five template arguments are re-exported unchanged as static members with a trailing
// underscore; no relationship between them is checked here.
// NOTE(review): in the instance lists that use this struct, InSrcVectorDim 0 is paired with
// MThreadSliceSize > 1 and InSrcVectorDim 1 with KThreadSliceSize > 1, so it presumably
// selects the M or K dimension for vectorized input access - confirm in the device op.
template <int InSrcVectorDim,
          int InSrcVectorSize,
          int OutDstVectorSize,
          int MThreadSliceSize,
          int KThreadSliceSize>
struct ReductionConfiguration_2
{
    static constexpr int InSrcVectorDim_   = InSrcVectorDim;
    static constexpr int InSrcVectorSize_  = InSrcVectorSize;
    static constexpr int OutDstVectorSize_ = OutDstVectorSize;
    static constexpr int MThreadSliceSize_ = MThreadSliceSize;
    static constexpr int KThreadSliceSize_ = KThreadSliceSize;
};

// The block shapes every reduction instance list iterates over: a fixed block of 256
// threads split as M x K = 128x2, 64x4, ..., 1x256.
using reduce_configuration_1_instances = std::tuple<
    // clang-format off
    // BlockSize | MThreadClusterSize | KThreadClusterSize
    ReductionConfiguration_1<256, 128, 2>,
    ReductionConfiguration_1<256, 64, 4>,
    ReductionConfiguration_1<256, 32, 8>,
    ReductionConfiguration_1<256, 16, 16>,
    ReductionConfiguration_1<256, 8, 32>,
    ReductionConfiguration_1<256, 4, 64>,
    ReductionConfiguration_1<256, 2, 128>,
    ReductionConfiguration_1<256, 1, 256>
    // clang-format on
    >;

// Headers that include this file test this macro with #ifdef to pick their shorter
// ReductionConfiguration_2 list. Because it is defined unconditionally here, the shorter
// lists are the ones compiled unless this line is removed.
#define QUICK_REDUCE_TEST 1

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,192 @@
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_HPP
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_HPP
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_impl_common.hpp"
#include "device_reduce_multiblock_atomic_add.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// Per-thread access shapes combined with every reduce_configuration_1_instances entry when
// the multi-block atomic-add instances are generated.
// QUICK_REDUCE_TEST is defined by device_reduce_instance_impl_common.hpp (included above),
// so the short list in the first branch is the one compiled unless that definition is removed.
#ifdef QUICK_REDUCE_TEST
using reduce_configuration_2_instances_multiblock_atomic_add = std::tuple<
// clang-format off
// InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
ReductionConfiguration_2<0, 2, 2, 2, 1>,
ReductionConfiguration_2<0, 1, 1, 2, 1>,
ReductionConfiguration_2<1, 2, 1, 1, 2>,
ReductionConfiguration_2<1, 2, 2, 1, 2>,
ReductionConfiguration_2<0, 1, 1, 3, 1>,
ReductionConfiguration_2<1, 1, 1, 1, 3>
// clang-format on
>;
#else
// Full list: power-of-two slice sizes first, then odd slice sizes with vector size 1.
using reduce_configuration_2_instances_multiblock_atomic_add = std::tuple<
// clang-format off
// InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
ReductionConfiguration_2<0, 4, 4, 8, 1>,
ReductionConfiguration_2<0, 4, 4, 4, 1>,
ReductionConfiguration_2<0, 2, 2, 2, 1>,
ReductionConfiguration_2<1, 4, 1, 1, 8>,
ReductionConfiguration_2<1, 4, 1, 1, 4>,
ReductionConfiguration_2<1, 2, 1, 1, 2>,
// special instances
ReductionConfiguration_2<0, 1, 1, 3, 1>,
ReductionConfiguration_2<0, 1, 1, 5, 1>,
ReductionConfiguration_2<0, 1, 1, 7, 1>,
ReductionConfiguration_2<0, 1, 1, 11, 1>,
ReductionConfiguration_2<1, 1, 1, 1, 3>,
ReductionConfiguration_2<1, 1, 1, 1, 5>,
ReductionConfiguration_2<1, 1, 1, 1, 7>,
ReductionConfiguration_2<1, 1, 1, 1, 11>
// clang-format on
>;
#endif
// Element type of the vector filled by add_device_reduce_instance_multiblock_atomic_add:
// a DeviceReducePtr parameterized on the input and accumulator element-wise operations that
// reduce_unary_operator<AccDataType, ReduceOpId, true, true> selects.
template <typename AccDataType, ReduceTensorOp_t ReduceOpId>
using deviceReduceMultiBlockAtomicAddPtrType = DeviceReducePtr<
    typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation,
    typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation>;
// Appends one DeviceReduceMultiBlockAtomicAdd instance to device_op_instances for every
// (reduce_configuration_1_instances, reduce_configuration_2_instances_multiblock_atomic_add)
// pair, in that nesting order.
//
// Compile-time constraints:
//   - IndicesOpt must be ReduceTensorIndices_t::NO_INDICES (static_assert).
//   - Nothing is appended unless ReduceOpId is ADD, MUL, AVG or NORM1 and OutDataType is
//     float or double; for any other combination the function body is empty.
//
// device_op_instances: output vector; existing elements are left in place.
template <typename InDataType,
          typename AccDataType,
          typename OutDataType,
          int Rank,
          typename ReduceDims,
          ReduceTensorOp_t ReduceOpId,
          NanPropagation_t NanOpt,
          ReduceTensorIndices_t IndicesOpt>
void add_device_reduce_instance_multiblock_atomic_add(
    std::vector<deviceReduceMultiBlockAtomicAddPtrType<AccDataType, ReduceOpId>>&
        device_op_instances)
{
    using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
    using UnaryOperators  = reduce_unary_operator<AccDataType, ReduceOpId, true, true>;
    using InElementwiseOperation  = typename UnaryOperators::InElementwiseOperation;
    using AccElementwiseOperation = typename UnaryOperators::AccElementwiseOperation;

    constexpr bool Indexable =
        (ReduceOpId == ReduceTensorOp_t::MIN || ReduceOpId == ReduceTensorOp_t::MAX ||
         ReduceOpId == ReduceTensorOp_t::AMAX);
    constexpr bool NeedIndices = Indexable && (IndicesOpt != ReduceTensorIndices_t::NO_INDICES);
    constexpr bool PropagateNan = (NanOpt != NanPropagation_t::NOT_PROPAGATE_NAN);

    static_assert(IndicesOpt == ReduceTensorIndices_t::NO_INDICES,
                  "AtomicAdd can only be used with reduction operations without indices!");

    constexpr bool op_acceptable =
        (ReduceOpId == ReduceTensorOp_t::ADD || ReduceOpId == ReduceTensorOp_t::MUL ||
         ReduceOpId == ReduceTensorOp_t::AVG || ReduceOpId == ReduceTensorOp_t::NORM1);
    constexpr bool out_type_acceptable =
        (std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value);

    if constexpr(op_acceptable && out_type_acceptable)
    {
        constexpr auto num_cfg1 = std::tuple_size<reduce_configuration_1_instances>::value;
        constexpr auto num_cfg2 =
            std::tuple_size<reduce_configuration_2_instances_multiblock_atomic_add>::value;

        static_for<0, num_cfg1, 1>{}([&](auto i) {
            // Read the configuration type straight from the tuple type; no tuple object is
            // constructed just to ask for its element type.
            using cfg1 = std::tuple_element_t<i.value, reduce_configuration_1_instances>;

            static_for<0, num_cfg2, 1>{}([&](auto j) {
                using cfg2 = std::tuple_element_t<
                    j.value,
                    reduce_configuration_2_instances_multiblock_atomic_add>;

                using ReduceOpInstance = DeviceReduceMultiBlockAtomicAdd<InDataType,
                                                                         AccDataType,
                                                                         OutDataType,
                                                                         Rank,
                                                                         ReduceDims,
                                                                         ReduceOperation,
                                                                         InElementwiseOperation,
                                                                         AccElementwiseOperation,
                                                                         PropagateNan,
                                                                         NeedIndices,
                                                                         cfg1::BlockSize_,
                                                                         cfg1::MThreadClusterSize_,
                                                                         cfg1::KThreadClusterSize_,
                                                                         cfg2::MThreadSliceSize_,
                                                                         cfg2::KThreadSliceSize_,
                                                                         cfg2::InSrcVectorDim_,
                                                                         cfg2::InSrcVectorSize_,
                                                                         cfg2::OutDstVectorSize_>;

                // Construct in place on the heap rather than building a temporary and
                // copying/moving it into the allocation.
                device_op_instances.push_back(std::make_unique<ReduceOpInstance>());
            });
        });
    }
}
// Explicit instantiation definition of add_device_reduce_instance_multiblock_atomic_add.
// ReduceOpId / NanOpt / IndicesOpt must already be enum values; the variadic tail becomes
// the Sequence<...> passed as ReduceDims.
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE(                                          \
    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...)                         \
    template void add_device_reduce_instance_multiblock_atomic_add<                      \
        inT, compT, outT, Rank, Sequence<__VA_ARGS__>, ReduceOpId, NanOpt, IndicesOpt>(  \
        std::vector<deviceReduceMultiBlockAtomicAddPtrType<compT, ReduceOpId>> &         \
        device_op_instances)
// Same as ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE, but takes the operation, NaN and indices
// options as integers and casts them to their enum types.
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(                                    \
    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...)                 \
    ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE(                                      \
        inT,                                                                     \
        compT,                                                                   \
        outT,                                                                    \
        static_cast<ReduceTensorOp_t>(ReduceOpId),                               \
        static_cast<NanPropagation_t>(NanOpt),                                   \
        static_cast<ReduceTensorIndices_t>(IndicesOpt),                          \
        Rank,                                                                    \
        __VA_ARGS__)
// Explicit instantiation declaration (extern template) matching
// ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE. The parameter type is written through
// deviceReduceMultiBlockAtomicAddPtrType, which names the same DeviceReducePtr<...> type.
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE(                                      \
    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...)                         \
    extern template void add_device_reduce_instance_multiblock_atomic_add<               \
        inT, compT, outT, Rank, Sequence<__VA_ARGS__>, ReduceOpId, NanOpt, IndicesOpt>(  \
        std::vector<deviceReduceMultiBlockAtomicAddPtrType<compT, ReduceOpId>> &         \
        device_op_instances)
// Same as ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE, but takes the operation, NaN and
// indices options as integers and casts them to their enum types.
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(                                \
    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...)                 \
    ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE(                                  \
        inT,                                                                     \
        compT,                                                                   \
        outT,                                                                    \
        static_cast<ReduceTensorOp_t>(ReduceOpId),                               \
        static_cast<NanPropagation_t>(NanOpt),                                   \
        static_cast<ReduceTensorIndices_t>(IndicesOpt),                          \
        Rank,                                                                    \
        __VA_ARGS__)
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,29 @@
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F16_F32_F32_HPP
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F16_F32_F32_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_multiblock_atomic_add.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// Multi-block atomic-add reduction instances for half_t input, float accumulation and
// float output. Each row names one instance; its arguments follow the column header
// below, with the reduced dimensions given as the trailing arguments after Rank.
// ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID (from the included
// device_reduce_instance_multiblock_atomic_add.hpp) casts the three numeric ids to their enum
// types and emits an extern template declaration of
// add_device_reduce_instance_multiblock_atomic_add.
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 4, 0);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 4, 0); //
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,29 @@
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F32_F32_HPP
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F32_F32_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_multiblock_atomic_add.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// Multi-block atomic-add reduction instances for float input, float accumulation and
// float output. Each row names one instance; its arguments follow the column header
// below, with the reduced dimensions given as the trailing arguments after Rank.
// ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID (from the included
// device_reduce_instance_multiblock_atomic_add.hpp) casts the three numeric ids to their enum
// types and emits an extern template declaration of
// add_device_reduce_instance_multiblock_atomic_add.
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,29 @@
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F64_F32_HPP
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F64_F32_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_multiblock_atomic_add.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// Multi-block atomic-add reduction instances for float input, double accumulation and
// float output. Each row names one instance; its arguments follow the column header
// below, with the reduced dimensions given as the trailing arguments after Rank.
// ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID (from the included
// device_reduce_instance_multiblock_atomic_add.hpp) casts the three numeric ids to their enum
// types and emits an extern template declaration of
// add_device_reduce_instance_multiblock_atomic_add.
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0); //
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,175 @@
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_HPP
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_HPP
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_impl_common.hpp"
#include "device_reduce_multiblock_partial_reduce.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// Per-thread access shapes combined with every reduce_configuration_1_instances entry when
// the multi-block partial-reduce instances are generated. OutDstVectorSize is 1 in every row.
// QUICK_REDUCE_TEST is defined by device_reduce_instance_impl_common.hpp (included above),
// so the short list in the first branch is the one compiled unless that definition is removed.
#ifdef QUICK_REDUCE_TEST
using reduce_configuration_2_instances_multiblock_partial_reduce = std::tuple<
// clang-format off
// InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
ReductionConfiguration_2<0, 1, 1, 2, 1>,
ReductionConfiguration_2<1, 2, 1, 1, 2>,
ReductionConfiguration_2<0, 1, 1, 3, 1>,
ReductionConfiguration_2<1, 1, 1, 1, 3>
// clang-format on
>;
#else
// Full list: power-of-two slice sizes first, then odd slice sizes with vector size 1.
using reduce_configuration_2_instances_multiblock_partial_reduce = std::tuple<
// clang-format off
// InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
ReductionConfiguration_2<0, 4, 1, 8, 1>,
ReductionConfiguration_2<0, 4, 1, 4, 1>,
ReductionConfiguration_2<0, 2, 1, 2, 1>,
ReductionConfiguration_2<1, 4, 1, 1, 8>,
ReductionConfiguration_2<1, 4, 1, 1, 4>,
ReductionConfiguration_2<1, 2, 1, 1, 2>,
// special instances
ReductionConfiguration_2<0, 1, 1, 3, 1>,
ReductionConfiguration_2<0, 1, 1, 5, 1>,
ReductionConfiguration_2<0, 1, 1, 7, 1>,
ReductionConfiguration_2<0, 1, 1, 11, 1>,
// NOTE(review): the four rows below slice along K (KThreadSliceSize > 1) but keep
// InSrcVectorDim = 0, whereas the quick list above and the atomic-add list use
// InSrcVectorDim = 1 for their K-sliced rows. Possibly a copy-paste slip - confirm intent.
ReductionConfiguration_2<0, 1, 1, 1, 3>,
ReductionConfiguration_2<0, 1, 1, 1, 5>,
ReductionConfiguration_2<0, 1, 1, 1, 7>,
ReductionConfiguration_2<0, 1, 1, 1, 11>
// clang-format on
>;
#endif
// Element type of the vector filled by add_device_reduce_instance_multiblock_partial_reduce:
// a DeviceReducePtr parameterized on the input and accumulator element-wise operations that
// reduce_unary_operator<AccDataType, ReduceOpId, true, false> selects.
template <typename AccDataType, ReduceTensorOp_t ReduceOpId>
using deviceReduceMultiBlockPartialReducePtrType =
    DeviceReducePtr<typename reduce_unary_operator<AccDataType, ReduceOpId, true, false>::
                        InElementwiseOperation,
                    typename reduce_unary_operator<AccDataType, ReduceOpId, true, false>::
                        AccElementwiseOperation>;
// Appends one DeviceReduceMultiBlockPartialReduce instance to device_op_instances for every
// (reduce_configuration_1_instances,
//  reduce_configuration_2_instances_multiblock_partial_reduce) pair, in that nesting order.
//
// NeedIndices is enabled only when ReduceOpId is MIN, MAX or AMAX and IndicesOpt is not
// NO_INDICES; PropagateNan is enabled for every NanOpt other than NOT_PROPAGATE_NAN.
//
// device_op_instances: output vector; existing elements are left in place.
template <typename InDataType,
          typename AccDataType,
          typename OutDataType,
          int Rank,
          typename ReduceDims,
          ReduceTensorOp_t ReduceOpId,
          NanPropagation_t NanOpt,
          ReduceTensorIndices_t IndicesOpt>
void add_device_reduce_instance_multiblock_partial_reduce(
    std::vector<deviceReduceMultiBlockPartialReducePtrType<AccDataType, ReduceOpId>>&
        device_op_instances)
{
    using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
    using UnaryOperators  = reduce_unary_operator<AccDataType, ReduceOpId, true, false>;
    using InElementwiseOperation  = typename UnaryOperators::InElementwiseOperation;
    using AccElementwiseOperation = typename UnaryOperators::AccElementwiseOperation;

    constexpr bool Indexable =
        (ReduceOpId == ReduceTensorOp_t::MIN || ReduceOpId == ReduceTensorOp_t::MAX ||
         ReduceOpId == ReduceTensorOp_t::AMAX);
    constexpr bool NeedIndices = Indexable && (IndicesOpt != ReduceTensorIndices_t::NO_INDICES);
    constexpr bool PropagateNan = (NanOpt != NanPropagation_t::NOT_PROPAGATE_NAN);

    constexpr auto num_cfg1 = std::tuple_size<reduce_configuration_1_instances>::value;
    constexpr auto num_cfg2 =
        std::tuple_size<reduce_configuration_2_instances_multiblock_partial_reduce>::value;

    static_for<0, num_cfg1, 1>{}([&](auto i) {
        // Read the configuration type straight from the tuple type; no tuple object is
        // constructed just to ask for its element type.
        using cfg1 = std::tuple_element_t<i.value, reduce_configuration_1_instances>;

        static_for<0, num_cfg2, 1>{}([&](auto j) {
            using cfg2 = std::tuple_element_t<
                j.value,
                reduce_configuration_2_instances_multiblock_partial_reduce>;

            using ReduceOpInstance = DeviceReduceMultiBlockPartialReduce<InDataType,
                                                                         AccDataType,
                                                                         OutDataType,
                                                                         Rank,
                                                                         ReduceDims,
                                                                         ReduceOperation,
                                                                         InElementwiseOperation,
                                                                         AccElementwiseOperation,
                                                                         PropagateNan,
                                                                         NeedIndices,
                                                                         cfg1::BlockSize_,
                                                                         cfg1::MThreadClusterSize_,
                                                                         cfg1::KThreadClusterSize_,
                                                                         cfg2::MThreadSliceSize_,
                                                                         cfg2::KThreadSliceSize_,
                                                                         cfg2::InSrcVectorDim_,
                                                                         cfg2::InSrcVectorSize_,
                                                                         cfg2::OutDstVectorSize_>;

            // Construct in place on the heap rather than building a temporary and
            // copying/moving it into the allocation.
            device_op_instances.push_back(std::make_unique<ReduceOpInstance>());
        });
    });
}
// Explicit instantiation definition of add_device_reduce_instance_multiblock_partial_reduce.
// ReduceOpId / NanOpt / IndicesOpt must already be enum values; the variadic tail becomes
// the Sequence<...> passed as ReduceDims.
#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_TYPE(                                      \
    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...)                         \
    template void add_device_reduce_instance_multiblock_partial_reduce<                  \
        inT, compT, outT, Rank, Sequence<__VA_ARGS__>, ReduceOpId, NanOpt, IndicesOpt>(  \
        std::vector<deviceReduceMultiBlockPartialReducePtrType<compT, ReduceOpId>> &     \
        device_op_instances)
// Same as ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_TYPE, but takes the operation, NaN and
// indices options as integers and casts them to their enum types.
#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(                                \
    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...)                 \
    ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_TYPE(                                  \
        inT,                                                                     \
        compT,                                                                   \
        outT,                                                                    \
        static_cast<ReduceTensorOp_t>(ReduceOpId),                               \
        static_cast<NanPropagation_t>(NanOpt),                                   \
        static_cast<ReduceTensorIndices_t>(IndicesOpt),                          \
        Rank,                                                                    \
        __VA_ARGS__)
// Explicit instantiation declaration (extern template) matching
// ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_TYPE. The parameter type is written through
// deviceReduceMultiBlockPartialReducePtrType, which names the same DeviceReducePtr<...> type.
#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_TYPE(                                  \
    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...)                         \
    extern template void add_device_reduce_instance_multiblock_partial_reduce<           \
        inT, compT, outT, Rank, Sequence<__VA_ARGS__>, ReduceOpId, NanOpt, IndicesOpt>(  \
        std::vector<deviceReduceMultiBlockPartialReducePtrType<compT, ReduceOpId>> &     \
        device_op_instances)
// Same as ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_TYPE, but takes the operation, NaN and
// indices options as integers and casts them to their enum types.
#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(                            \
    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...)                 \
    ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_TYPE(                              \
        inT,                                                                     \
        compT,                                                                   \
        outT,                                                                    \
        static_cast<ReduceTensorOp_t>(ReduceOpId),                               \
        static_cast<NanPropagation_t>(NanOpt),                                   \
        static_cast<ReduceTensorIndices_t>(IndicesOpt),                          \
        Rank,                                                                    \
        __VA_ARGS__)
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,41 @@
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F16_F16_F16_HPP
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F16_F16_F16_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_multiblock_partial_reduce.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// Multi-block partial-reduce instances for half_t input, half_t accumulation and half_t
// output. Each row names one instance; its arguments follow the column header below, with
// the reduced dimensions given as the trailing arguments after Rank.
// MIN / MAX / AMAX appear twice: once with IndicesOpt = 0 and once with IndicesOpt = 1.
// ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID (from the included
// device_reduce_instance_multiblock_partial_reduce.hpp) casts the three numeric ids to their
// enum types and emits an extern template declaration of
// add_device_reduce_instance_multiblock_partial_reduce.
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
// IndicesOpt = 0
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
// IndicesOpt = 1
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,32 @@
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F16_F32_F16_HPP
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F16_F32_F16_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_multiblock_partial_reduce.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// Multi-block partial-reduce instances for half_t input, float accumulation and half_t
// output. Each row names one instance; its arguments follow the column header below, with
// the reduced dimensions given as the trailing arguments after Rank.
// ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID (from the included
// device_reduce_instance_multiblock_partial_reduce.hpp) casts the three numeric ids to their
// enum types and emits an extern template declaration of
// add_device_reduce_instance_multiblock_partial_reduce.
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,45 @@
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F32_F32_F32_HPP
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F32_F32_F32_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_multiblock_partial_reduce.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// Multi-block partial-reduce instances for float input, float accumulation and float
// output. Each row names one instance; its arguments follow the column header below, with
// the reduced dimensions given as the trailing arguments after Rank.
// MIN / MAX / AMAX appear twice: once with IndicesOpt = 0 and once with IndicesOpt = 1.
// ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID (from the included
// device_reduce_instance_multiblock_partial_reduce.hpp) casts the three numeric ids to their
// enum types and emits an extern template declaration of
// add_device_reduce_instance_multiblock_partial_reduce.
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
// IndicesOpt = 0
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
// IndicesOpt = 1
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
// NORM2 is declared with IndicesOpt = 0 only
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,26 @@
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F32_F64_F32_HPP
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F32_F64_F32_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_multiblock_partial_reduce.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// Multi-block partial-reduce instances for float input, double accumulation and float
// output; only NORM2 is declared for this type combination. Each row names one instance;
// its arguments follow the column header below, with the reduced dimensions given as the
// trailing arguments after Rank.
// ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID (from the included
// device_reduce_instance_multiblock_partial_reduce.hpp) casts the three numeric ids to their
// enum types and emits an extern template declaration of
// add_device_reduce_instance_multiblock_partial_reduce.
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,53 @@
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F64_F64_F64_HPP
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F64_F64_F64_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_multiblock_partial_reduce.hpp"
// Instance references for the multi-block partial reduction with double input, double
// accumulation and double output.
// NOTE(review): ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID is not defined in this header;
// presumably it comes from device_reduce_instance_multiblock_partial_reduce.hpp and emits one
// extern template declaration per line -- confirm there.
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
// Every argument after Rank is one reduced dimension, e.g. "4, 0, 1, 2" is rank 4 with
// dimensions 0, 1 and 2 reduced, and "2, 1" is rank 2 with dimension 1 reduced.
// The first group uses IndicesOpt = 0, the second group repeats MIN/MAX/AMAX with IndicesOpt = 1.
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
// The ADD and AVG instances below are planned to move to the MultiBlockAtomicAdd method.
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,164 @@
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_HPP
#define DEVICE_REDUCE_INSTANCE_THREADWISE_HPP
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_impl_common.hpp"
#include "device_reduce_threadwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// List of ReductionConfiguration_2 entries that add_device_reduce_instance_threadwise walks:
// one DeviceReduceThreadWise instance is created per tuple element.
// Defining QUICK_REDUCE_TEST selects the shorter six-entry list; otherwise the full
// fourteen-entry list is used.
#ifdef QUICK_REDUCE_TEST
using reduce_configuration_2_instances_threadwise = std::tuple<
// clang-format off
// InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
ReductionConfiguration_2<0, 2, 2, 2, 1>,
ReductionConfiguration_2<0, 1, 1, 2, 1>,
ReductionConfiguration_2<1, 2, 1, 1, 2>,
ReductionConfiguration_2<1, 2, 2, 1, 2>,
ReductionConfiguration_2<0, 1, 1, 3, 1>,
ReductionConfiguration_2<1, 1, 1, 1, 3>
// clang-format on
>;
#else
using reduce_configuration_2_instances_threadwise = std::tuple<
// clang-format off
// InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
ReductionConfiguration_2<0, 4, 4, 8, 1>,
ReductionConfiguration_2<0, 4, 4, 4, 1>,
ReductionConfiguration_2<0, 2, 2, 2, 1>,
ReductionConfiguration_2<1, 4, 1, 1, 8>,
ReductionConfiguration_2<1, 4, 1, 1, 4>,
ReductionConfiguration_2<1, 2, 1, 1, 2>,
// special instances: vector size 1 combined with odd slice sizes (3, 5, 7, 11)
// NOTE(review): presumably these cover lengths the power-of-two entries cannot -- confirm.
ReductionConfiguration_2<0, 1, 1, 3, 1>,
ReductionConfiguration_2<0, 1, 1, 5, 1>,
ReductionConfiguration_2<0, 1, 1, 7, 1>,
ReductionConfiguration_2<0, 1, 1, 11, 1>,
ReductionConfiguration_2<1, 1, 1, 1, 3>,
ReductionConfiguration_2<1, 1, 1, 1, 5>,
ReductionConfiguration_2<1, 1, 1, 1, 7>,
ReductionConfiguration_2<1, 1, 1, 1, 11>
// clang-format on
>;
#endif
// Element type of the instance vector filled by add_device_reduce_instance_threadwise:
// a DeviceReducePtr parameterized on the input and accumulator element-wise operations that
// reduce_unary_operator<AccDataType, ReduceOpId, true, true> selects for the given reduce op.
template <typename AccDataType, ReduceTensorOp_t ReduceOpId>
using deviceReduceThreadWisePtrType = DeviceReducePtr<
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation,
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation>;
/// Appends one DeviceReduceThreadWise instance to `device_op_instances` for every entry of
/// reduce_configuration_2_instances_threadwise.
///
/// @tparam InDataType   element type of the input tensor
/// @tparam AccDataType  element type used for accumulation
/// @tparam OutDataType  element type of the output tensor
/// @tparam Rank         rank forwarded to DeviceReduceThreadWise
/// @tparam ReduceDims   type describing the reduced dimensions (the instantiation macros pass
///                      a Sequence<...>)
/// @tparam ReduceOpId   reduce operation; selects the binary and unary operators
/// @tparam NanOpt       NaN propagation option
/// @tparam IndicesOpt   indices option; only honoured for MIN, MAX and AMAX
/// @param device_op_instances  vector that receives the newly created instances
template <typename InDataType,
          typename AccDataType,
          typename OutDataType,
          int Rank,
          typename ReduceDims,
          ReduceTensorOp_t ReduceOpId,
          NanPropagation_t NanOpt,
          ReduceTensorIndices_t IndicesOpt>
void add_device_reduce_instance_threadwise(
    std::vector<deviceReduceThreadWisePtrType<AccDataType, ReduceOpId>>& device_op_instances)
{
    using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
    using InElementwiseOperation =
        typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
    using AccElementwiseOperation =
        typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
            AccElementwiseOperation;

    // Only MIN, MAX and AMAX can report indices; for any other op IndicesOpt is ignored.
    constexpr bool Indexable =
        (ReduceOpId == ReduceTensorOp_t::MIN || ReduceOpId == ReduceTensorOp_t::MAX ||
         ReduceOpId == ReduceTensorOp_t::AMAX);
    constexpr bool NeedIndices = Indexable && (IndicesOpt != ReduceTensorIndices_t::NO_INDICES);

    constexpr bool PropagateNan = (NanOpt != NanPropagation_t::NOT_PROPAGATE_NAN);

    // Fixed first-level configuration; its BlockSize_, MThreadClusterSize_ and
    // KThreadClusterSize_ members are forwarded below.
    using cfg1 = ReductionConfiguration_1<256, 256, 1>;

    static_for<0, std::tuple_size<reduce_configuration_2_instances_threadwise>::value, 1>{}(
        [&](auto j) {
            // Type of the j-th entry in the configuration tuple.
            using cfg2 = remove_cvref_t<decltype(
                std::get<j.value>(reduce_configuration_2_instances_threadwise{}))>;

            using ReduceOpInstance = DeviceReduceThreadWise<InDataType,
                                                            AccDataType,
                                                            OutDataType,
                                                            Rank,
                                                            ReduceDims,
                                                            ReduceOperation,
                                                            InElementwiseOperation,
                                                            AccElementwiseOperation,
                                                            PropagateNan,
                                                            NeedIndices,
                                                            cfg1::BlockSize_,
                                                            cfg1::MThreadClusterSize_,
                                                            cfg1::KThreadClusterSize_,
                                                            cfg2::MThreadSliceSize_,
                                                            cfg2::KThreadSliceSize_,
                                                            cfg2::InSrcVectorDim_,
                                                            cfg2::InSrcVectorSize_,
                                                            cfg2::OutDstVectorSize_>;

            // Value-initialize directly on the heap instead of building a temporary and
            // copy/move-constructing from it.
            device_op_instances.push_back(std::make_unique<ReduceOpInstance>());
        });
}
// Expands to an explicit instantiation definition of add_device_reduce_instance_threadwise for
// one combination of data types, reduce operation and options. ReduceOpId, NanOpt and IndicesOpt
// must already be values of their enum types. The variadic tail is wrapped into
// Sequence<__VA_ARGS__> and passed as ReduceDims. The expansion carries no trailing semicolon;
// the caller supplies it.
#define ADD_THREADWISE_INST_BY_TYPE(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
template void add_device_reduce_instance_threadwise<inT, \
compT, \
outT, \
Rank, \
Sequence<__VA_ARGS__>, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
std::vector<deviceReduceThreadWisePtrType<compT, ReduceOpId>> & device_op_instances)
// Same as ADD_THREADWISE_INST_BY_TYPE, but takes ReduceOpId, NanOpt and IndicesOpt as integer
// ids and static_casts them to ReduceTensorOp_t, NanPropagation_t and ReduceTensorIndices_t.
#define ADD_THREADWISE_INST_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
ADD_THREADWISE_INST_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp_t>(ReduceOpId), \
static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \
__VA_ARGS__)
// Expands to an explicit instantiation declaration (extern template) of
// add_device_reduce_instance_threadwise, the counterpart of ADD_THREADWISE_INST_BY_TYPE.
// ReduceOpId, NanOpt and IndicesOpt must already be values of their enum types. The variadic
// tail is wrapped into Sequence<__VA_ARGS__> and passed as ReduceDims. The parameter type is
// spelled with deviceReduceThreadWisePtrType, exactly as in the definition macro, so the
// declaration and the definition cannot drift apart. The expansion carries no trailing
// semicolon; the caller supplies it.
#define ADD_THREADWISE_INST_REF_BY_TYPE(                                                  \
    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...)                          \
    extern template void add_device_reduce_instance_threadwise<inT,                       \
                                                               compT,                     \
                                                               outT,                      \
                                                               Rank,                      \
                                                               Sequence<__VA_ARGS__>,     \
                                                               ReduceOpId,                \
                                                               NanOpt,                    \
                                                               IndicesOpt>(               \
        std::vector<deviceReduceThreadWisePtrType<compT, ReduceOpId>> & device_op_instances)
// Same as ADD_THREADWISE_INST_REF_BY_TYPE, but takes ReduceOpId, NanOpt and IndicesOpt as
// integer ids and static_casts them to ReduceTensorOp_t, NanPropagation_t and
// ReduceTensorIndices_t.
#define ADD_THREADWISE_INST_REF_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
ADD_THREADWISE_INST_REF_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp_t>(ReduceOpId), \
static_cast<NanPropagation_t>(NanOpt), \
static_cast<ReduceTensorIndices_t>(IndicesOpt), \
Rank, \
__VA_ARGS__)
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,41 @@
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F16_F16_HPP
#define DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F16_F16_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_threadwise.hpp"
// Extern template declarations of add_device_reduce_instance_threadwise for half_t input,
// half_t accumulation and half_t output (MIN, MAX and AMAX only).
// NOTE(review): the matching explicit instantiation definitions presumably live in a separate
// translation unit -- confirm.
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
// Every argument after Rank is one reduced dimension, e.g. "4, 0, 1, 2" is rank 4 with
// dimensions 0, 1 and 2 reduced, and "2, 1" is rank 2 with dimension 1 reduced.
// The first group uses IndicesOpt = 0, the second group repeats the same ops with IndicesOpt = 1.
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,32 @@
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F32_F16_HPP
#define DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F32_F16_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_threadwise.hpp"
// Extern template declarations of add_device_reduce_instance_threadwise for half_t input,
// float accumulation and half_t output (ADD, AVG and NORM2).
// NOTE(review): the matching explicit instantiation definitions presumably live in a separate
// translation unit -- confirm.
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
// Every argument after Rank is one reduced dimension, e.g. "4, 0, 1, 2" is rank 4 with
// dimensions 0, 1 and 2 reduced, and "2, 1" is rank 2 with dimension 1 reduced.
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,50 @@
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F32_F32_HPP
#define DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F32_F32_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_threadwise.hpp"
// Extern template declarations of add_device_reduce_instance_threadwise for float input,
// float accumulation and float output.
// NOTE(review): the matching explicit instantiation definitions presumably live in a separate
// translation unit -- confirm.
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
// Every argument after Rank is one reduced dimension, e.g. "4, 0, 1, 2" is rank 4 with
// dimensions 0, 1 and 2 reduced, and "2, 1" is rank 2 with dimension 1 reduced.
// MIN, MAX and AMAX appear twice: first with IndicesOpt = 0, then with IndicesOpt = 1.
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 0);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,32 @@
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F64_F32_HPP
#define DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F64_F32_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_threadwise.hpp"
// Extern template declarations of add_device_reduce_instance_threadwise for float input,
// double accumulation and float output (ADD, AVG and NORM2).
// NOTE(review): the matching explicit instantiation definitions presumably live in a separate
// translation unit -- confirm.
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
// Every argument after Rank is one reduced dimension, e.g. "4, 0, 1, 2" is rank 4 with
// dimensions 0, 1 and 2 reduced, and "2, 1" is rank 2 with dimension 1 reduced.
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 4, 0);
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,50 @@
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F64_F64_F64_HPP
#define DEVICE_REDUCE_INSTANCE_THREADWISE_F64_F64_F64_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_threadwise.hpp"
// Extern template declarations of add_device_reduce_instance_threadwise for double input,
// double accumulation and double output.
// NOTE(review): the matching explicit instantiation definitions presumably live in a separate
// translation unit -- confirm.
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
// Every argument after Rank is one reduced dimension, e.g. "4, 0, 1, 2" is rank 4 with
// dimensions 0, 1 and 2 reduced, and "2, 1" is rank 2 with dimension 1 reduced.
// MIN, MAX and AMAX appear twice: first with IndicesOpt = 0, then with IndicesOpt = 1.
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0, 1, 2); // for ADD
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 0);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0, 1, 2); // for AVG
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0, 1, 2); // for NORM2
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0, 1, 2); // for MIN
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0, 1, 2); // for MAX
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0, 1, 2); // for AMAX
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 0); //
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); //
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif