mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Absolute include path (#281)
* ad gelu and fast_gelu * added GeLU and fast GeLU * clean up * add gemm+fastgelu example * add gemm+gelu instances * update profiler * clean up * clean up * adding gemm+bias+activation * clean * adding bias * clean * adding gemm multiple d * debugging * add gemm bias add fastgelu * rename, clean * refactoring; add readme * refactor * refactor * refactor * refactor * refactor * refactor * fix * fix * update example * update example * rename * update example * add ckProfiler * clean * clean * clean * clean * add client app example * update readme * delete obselete files * remove old client app * delete old file * cleaning * clean * remove half * fix header path * fix header path * fix header path * fix header path * fix header path * fix header path for all examples * fix header path * fix header path * fix header path * fix header path * fix header path * fix header path * fix header path * fix header path * fix header path * revert client app example * clean build * fix build * temporary disable client test on Jenkins * clean * clean * clean
This commit is contained in:
@@ -1,54 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "stream_config.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device_base.hpp"
|
||||
|
||||
// Opaque, client-facing handle to one device convolution-forward instance.
// Uses the pimpl idiom: the concrete ck device op lives behind
// DeviceConvFwdPtrImpl (defined in the implementation TU) so client code does
// not need to include the kernel template headers.
struct DeviceConvFwdPtr_t
{
    using BaseArgument = ck::tensor_operation::device::BaseArgument;
    using BaseInvoker = ck::tensor_operation::device::BaseInvoker;

    struct DeviceConvFwdPtrImpl; // forward declaration; definition elsewhere
    std::unique_ptr<DeviceConvFwdPtrImpl> pImpl;

    DeviceConvFwdPtr_t();
    ~DeviceConvFwdPtr_t();
    DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&);
    DeviceConvFwdPtr_t(DeviceConvFwdPtrImpl&);
    // Copying is disallowed (unique ownership of the impl); moving is allowed.
    DeviceConvFwdPtr_t& operator=(DeviceConvFwdPtr_t&) = delete;
    DeviceConvFwdPtr_t& operator=(const DeviceConvFwdPtr_t&) = delete;

    // Build the argument object for a 2D forward convolution.
    // N/K/C are batch, output-channel and input-channel counts; the vectors
    // carry per-spatial-dimension sizes, strides, dilations and paddings.
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(void* in_ptr,
                        void* wei_ptr,
                        void* out_ptr,
                        size_t N,
                        size_t K,
                        size_t C,
                        std::vector<ck::index_t> input_spatial_lengths,
                        std::vector<ck::index_t> filter_spatial_lengths,
                        std::vector<ck::index_t> output_spatial_lengths,
                        std::vector<ck::index_t> conv_filter_strides,
                        std::vector<ck::index_t> conv_filter_dilations,
                        std::vector<ck::index_t> input_left_pads,
                        std::vector<ck::index_t> input_right_pads)
        const; // in, wei and out element ops are ignored for now since even if we change them, they
               // cant be linked
    std::unique_ptr<BaseInvoker>
    MakeInvokerPointer() const; // requires including BaseInvoker headers
    std::string GetTypeString();
    bool IsSupportedArgument(const BaseArgument* arg_ptr);
};
|
||||
|
||||
// Factory functions: each appends the pre-built 2D forward-convolution
// instances (NHWC input / KYXC weight / NHWK output layout, XDL path) of the
// named element type to `instances`. Implementations live in the instance
// library, not in this header.
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(
    std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(
    std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(
    std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(
    std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(
    std::vector<DeviceConvFwdPtr_t>& instances);
|
||||
@@ -1,7 +1,6 @@
|
||||
#ifndef CONV_COMMON_HPP
|
||||
#define CONV_COMMON_HPP
|
||||
#pragma once
|
||||
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
|
||||
template <typename... InDesc,
|
||||
typename... WeiDesc,
|
||||
@@ -73,18 +72,3 @@ calculate_convolution_flops(const InDesc&, const WeiDesc& wei_desc, const OutDes
|
||||
|
||||
return std::size_t(2) * N * K * Ho * Wo * C * Y * X;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline auto activ(T v, const ck::ActivTypeEnum activ_type)
|
||||
{
|
||||
const T alpha = 0.3;
|
||||
switch(activ_type)
|
||||
{
|
||||
case ck::ActivTypeEnum::None: return v;
|
||||
case ck::ActivTypeEnum::LeakyRelu: return (v >= 0 ? v : alpha * v);
|
||||
case ck::ActivTypeEnum::Sigmoid: return (1 / (1 + exp(-v)));
|
||||
default: throw std::runtime_error("unsupported activ type"); break;
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,123 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include <thread>
|
||||
#include <chrono>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
#include "stream_config.hpp"
|
||||
#include "ck/options.hpp"
|
||||
|
||||
template <typename T>
|
||||
__global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
|
||||
{
|
||||
for(uint64_t i = threadIdx.x; i < buffer_element_size; i += blockDim.x)
|
||||
{
|
||||
p[i] = x;
|
||||
}
|
||||
}
|
||||
|
||||
// Throw std::runtime_error if a HIP API call did not return hipSuccess.
// The message carries the HIP error string plus file / line / function.
inline void hip_check_error(hipError_t x)
{
    if(x != hipSuccess)
    {
        std::ostringstream ss;
        // Fix: a separating space is required after __LINE__; the original
        // message rendered as e.g. "...hpp: 123in function: foo".
        ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << __FILE__ << ": " << __LINE__
           << " in function: " << __func__;
        throw std::runtime_error(ss.str());
    }
}
|
||||
|
||||
// RAII owner of one raw device (GPU) allocation.
// Allocation/free and the host<->device copies are implemented out of line
// (not visible in this header).
struct DeviceMem
{
    DeviceMem() = delete;
    DeviceMem(std::size_t mem_size); // allocate mem_size bytes on the device
    void* GetDeviceBuffer();         // raw device pointer
    std::size_t GetBufferSize();
    void ToDevice(const void* p);    // copy from host buffer p to the device
    void FromDevice(void* p);        // copy from the device into host buffer p
    void SetZero();
    // Fill the whole allocation with x, viewing it as an array of T.
    // Throws if the allocation size is not a whole number of T elements.
    template <typename T>
    void SetValue(T x)
    {
        if(mMemSize % sizeof(T) != 0)
        {
            throw std::runtime_error("wrong! not entire DeviceMem will be set");
        }

        // Single-block launch; set_buffer_value strides over the buffer.
        set_buffer_value<T><<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, mMemSize / sizeof(T));
    }
    ~DeviceMem();

    void* mpDeviceBuf;    // device pointer (owned)
    std::size_t mMemSize; // allocation size in bytes
};
|
||||
|
||||
struct KernelTimerImpl; // hides the event/clock machinery from this header

// Pimpl-based timer for GPU kernel runs: Start()/End() bracket the timed
// region, GetElapsedTime() returns the elapsed time between them.
// NOTE(review): units are presumably milliseconds, judging by the "ms"
// reporting at call sites — confirm against the impl.
struct KernelTimer
{
    KernelTimer();
    ~KernelTimer();
    void Start();
    void End();
    float GetElapsedTime() const;

    std::unique_ptr<KernelTimerImpl> impl;
};
|
||||
|
||||
// Launch `kernel` on the stream carried by stream_config, optionally timing it.
// When CK_TIME_KERNEL is compiled in AND stream_config.time_kernel_ is set:
// one warm-up launch, then nrepeat timed launches; returns the average time
// per launch as reported by KernelTimer. Otherwise launches once and returns 0.
template <typename... Args, typename F>
float launch_and_time_kernel(const StreamConfig& stream_config,
                             F kernel,
                             dim3 grid_dim,
                             dim3 block_dim,
                             std::size_t lds_byte,
                             Args... args)
{
#if CK_TIME_KERNEL
    if(stream_config.time_kernel_)
    {
        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
               __func__,
               grid_dim.x,
               grid_dim.y,
               grid_dim.z,
               block_dim.x,
               block_dim.y,
               block_dim.z);

        const int nrepeat = 10; // fixed repeat count used for averaging

        printf("Warm up 1 time\n");

        // warm up
        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);

        printf("Start running %d times...\n", nrepeat);

        KernelTimer timer;
        timer.Start();

        // Timed launches; the launches are enqueued back-to-back on the same
        // stream, so the timer measures the whole batch.
        for(int i = 0; i < nrepeat; ++i)
        {
            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
        }

        timer.End();

        return timer.GetElapsedTime() / nrepeat;
    }
    else
    {
        // Timing disabled at runtime: single launch, nothing measured.
        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);

        return 0;
    }
#else
    // Timing compiled out: single launch, nothing measured.
    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);

    return 0;
#endif
}
|
||||
37
library/include/ck/library/host_tensor/device_memory.hpp
Normal file
37
library/include/ck/library/host_tensor/device_memory.hpp
Normal file
@@ -0,0 +1,37 @@
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
template <typename T>
|
||||
__global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
|
||||
{
|
||||
for(uint64_t i = threadIdx.x; i < buffer_element_size; i += blockDim.x)
|
||||
{
|
||||
p[i] = x;
|
||||
}
|
||||
}
|
||||
|
||||
// RAII owner of one raw device (GPU) allocation.
// Allocation/free and the host<->device copies are implemented out of line
// (not visible in this header).
struct DeviceMem
{
    DeviceMem() = delete;
    DeviceMem(std::size_t mem_size); // allocate mem_size bytes on the device
    void* GetDeviceBuffer();         // raw device pointer
    std::size_t GetBufferSize();
    void ToDevice(const void* p);    // copy from host buffer p to the device
    void FromDevice(void* p);        // copy from the device into host buffer p
    void SetZero();
    // Fill the whole allocation with x, viewing it as an array of T.
    // Throws if the allocation size is not a whole number of T elements.
    template <typename T>
    void SetValue(T x)
    {
        if(mMemSize % sizeof(T) != 0)
        {
            throw std::runtime_error("wrong! not entire DeviceMem will be set");
        }

        // Single-block launch; set_buffer_value strides over the buffer.
        set_buffer_value<T><<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, mMemSize / sizeof(T));
    }
    ~DeviceMem();

    void* mpDeviceBuf;    // device pointer (owned)
    std::size_t mMemSize; // allocation size in bytes
};
|
||||
@@ -1,8 +0,0 @@
|
||||
#pragma once
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
// Print a compile-time tensor descriptor to os by first converting it to a
// runtime HostTensorDescriptor. TensorDesc must be default-constructible;
// the value argument itself is unused (only its type matters).
template <typename TensorDesc>
void ostream_tensor_descriptor(TensorDesc, std::ostream& os = std::cout)
{
    ostream_HostTensorDescriptor(make_HostTensorDescriptor(TensorDesc{}), os);
}
|
||||
@@ -1,37 +1,11 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2020 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#ifndef GUARD_HOST_COMMON_UTIL_HPP
|
||||
#define GUARD_HOST_COMMON_UTIL_HPP
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
|
||||
#include "config.hpp"
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -95,8 +69,5 @@ static inline std::vector<T> getTypeValuesFromString(const char* cstr_values)
|
||||
return (values);
|
||||
}
|
||||
|
||||
}; // namespace host_common
|
||||
|
||||
}; // namespace ck
|
||||
|
||||
#endif
|
||||
} // namespace host_common
|
||||
} // namespace ck
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
template <typename AType,
|
||||
|
||||
@@ -1,42 +1,15 @@
|
||||
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2020 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#ifndef HOST_REDUCTION_HPP_
|
||||
#define HOST_REDUCTION_HPP_
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <array>
|
||||
#include <functional>
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_common.hpp"
|
||||
#include "host_common_util.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "data_type.hpp"
|
||||
#include "reduction_functions_accumulate.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/reduction_enums.hpp"
|
||||
#include "ck/utility/reduction_common.hpp"
|
||||
#include "ck/utility/reduction_functions_accumulate.hpp"
|
||||
#include "ck/library/host_tensor/host_common_util.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
|
||||
template <int NDim>
|
||||
static void get_all_indexes(const std::array<size_t, NDim>& dimLengths,
|
||||
@@ -400,5 +373,3 @@ struct ReductionHost
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#ifndef HOST_TENSOR_HPP
|
||||
#define HOST_TENSOR_HPP
|
||||
#pragma once
|
||||
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
@@ -8,7 +7,8 @@
|
||||
#include <utility>
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include "data_type.hpp"
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
template <typename Range>
|
||||
std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
|
||||
@@ -413,5 +413,3 @@ float check_error(const Tensor<T>& ref, const Tensor<T>& result)
|
||||
|
||||
return linf_error;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
|
||||
#include "config.hpp"
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
template <typename T>
|
||||
struct GeneratorTensor_0
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
#ifndef DEBUG_HPP
#define DEBUG_HPP

namespace debug {
namespace debug_driver_gemm_xdlops_v2r3 {

// these vars are on host, they control block_id to C matrix tile idx (m0, n0) mapping
// NOTE(review): `static` in a header gives every translation unit its own
// copy of M01/N01, so a change made in one TU is not seen by another —
// presumably acceptable for ad-hoc debugging, but worth confirming.
static ck::index_t M01 = 1;
static ck::index_t N01 = 1;

} // namespace debug_driver_gemm_xdlops_v2r3
} // namespace debug
#endif
|
||||
@@ -1,220 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp"
|
||||
|
||||
// Host driver for the v5r1 DLOPS fused conv + bias + activation + resize-add
// kernel (NC0HWC1 input / KC0YXC1 weight / NK0HWK1 output layouts).
// Uploads the tensors, runs the kernel 5 times for timing (nrepeat launches
// per Run), then does one final run and copies the add/output tensor back
// into add_n_k0_hox2_wox2_k1_out.
template <typename TInWei,
          typename TAcc,
          typename TOut,
          ck::ActivTypeEnum activ_type,
          typename InLengths,
          typename WeiLengths,
          typename AddLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1(
    const InLengths& in_n_c0_hi_wi_c1_lengths,
    const WeiLengths& wei_k_c0_y_x_c1_lengths,
    const AddLengths& add_n_k0_hox2_wox2_k1_lengths,
    const OutLengths& out_n_k0_ho_wo_k1_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TInWei>& in_n_c0_hi_wi_c1,
    const Tensor<TInWei>& wei_k_c0_y_x_c1,
    const Tensor<TOut>& bias_k0_k1,
    const Tensor<TOut>& add_n_k0_hox2_wox2_k1,
    Tensor<TOut>& add_n_k0_hox2_wox2_k1_out,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};

    // Unpack the problem sizes from the length tuples.
    const auto N  = out_n_k0_ho_wo_k1_lengths[I0];
    const auto K0 = out_n_k0_ho_wo_k1_lengths[I1];
    const auto Ho = out_n_k0_ho_wo_k1_lengths[I2];
    const auto Wo = out_n_k0_ho_wo_k1_lengths[I3];
    const auto K1 = out_n_k0_ho_wo_k1_lengths[I4];

    const auto C0 = in_n_c0_hi_wi_c1_lengths[I1];
    const auto Hi = in_n_c0_hi_wi_c1_lengths[I2];
    const auto Wi = in_n_c0_hi_wi_c1_lengths[I3];
    const auto C1 = in_n_c0_hi_wi_c1_lengths[I4];

    const auto K = wei_k_c0_y_x_c1_lengths[I0];
    const auto Y = wei_k_c0_y_x_c1_lengths[I2];
    const auto X = wei_k_c0_y_x_c1_lengths[I3];

    // The add tensor is 2x the conv output in both spatial dims (resize-add).
    const auto Hox2 = add_n_k0_hox2_wox2_k1_lengths[I2];
    const auto Wox2 = add_n_k0_hox2_wox2_k1_lengths[I3];

    // Allocate device buffers and upload all host tensors.
    DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) *
                                          in_n_c0_hi_wi_c1.mDesc.GetElementSpace());
    DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace());
    DeviceMem bias_k0_k1_device_buf(sizeof(TOut) * bias_k0_k1.mDesc.GetElementSpace());
    DeviceMem add_n_k0_hox2_wox2_k1_device_buf(sizeof(TOut) *
                                               add_n_k0_hox2_wox2_k1.mDesc.GetElementSpace());

    in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
    wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
    bias_k0_k1_device_buf.ToDevice(bias_k0_k1.mData.data());
    add_n_k0_hox2_wox2_k1_device_buf.ToDevice(add_n_k0_hox2_wox2_k1.mData.data());

    // Inputs/weights are passed to the kernel as vectors of this width, so C1
    // must be divisible by it.
    constexpr index_t InWeiVectorSize = 8;

    if(C1 % InWeiVectorSize != 0)
    {
        throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize");
    }

// Compile-time tuning configurations; exactly one branch is active (#elif 1).
#if 0
    constexpr index_t BlockSize = 256;

    constexpr index_t KPerBlock = 32;
    constexpr index_t HoPerBlock = 8;
    constexpr index_t WoPerBlock = 64;

    constexpr index_t E1 = C0 * 9;
    constexpr index_t E2 = 1;
    constexpr index_t E1PerBlock = C0;

    constexpr index_t KPerThread = 16;
    constexpr index_t HoPerThread = 2;
    constexpr index_t WoPerThread = 2;
    constexpr index_t EPerThread = 1;

    using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, E2>;
    using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, KPerBlock, 1>;

    constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
    constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;

    constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;

    constexpr index_t CThreadTransferDstScalarPerVector_K = K1;
#elif 1
    constexpr auto BlockSize = 64;

    constexpr auto KPerBlock = 8;
    constexpr auto HoPerBlock = 8;
    constexpr auto WoPerBlock = 32;

    constexpr auto E1 = 2 * 9;
    constexpr auto E2 = 1;
    constexpr auto K2 = 2;
    constexpr auto E1PerBlock = 2;

    constexpr auto KPerThread = KPerBlock;
    constexpr auto HoPerThread = 2;
    constexpr auto WoPerThread = 2;
    constexpr auto EPerThread = 1;

    using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, 1, E2>;
    using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 =
        Sequence<1, E1PerBlock, 1, KPerBlock, 1>;

    constexpr auto ABlockTransferSrcScalarPerVector_E2 = E2;
    constexpr auto ABlockTransferDstScalarPerVector_E2 = E2;
    constexpr auto BThreadTransferSrcScalarPerVector_E2 = E2;
    constexpr auto CThreadTransferDstScalarPerVector_K = InWeiVectorSize;
#endif

    // Packed (contiguous) descriptors for all four tensors.
    const auto in_n_c0_hi_wi_c1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2));
    const auto wei_k_c0_y_x_c1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, E2));
    const auto add_n_k0_hox2_wox2_k1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hox2, Wox2, K1));
    const auto out_n_k0_ho_wo_k1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));

    constexpr auto conv_driver =
        DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_add<
            BlockSize,
            typename vector_type<TInWei, InWeiVectorSize>::type,
            TAcc,
            TOut,
            E1,
            E2,
            K2,
            KPerBlock,
            HoPerBlock,
            WoPerBlock,
            E1PerBlock,
            KPerThread,
            HoPerThread,
            WoPerThread,
            EPerThread,
            ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
            ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
            ABlockTransferSrcScalarPerVector_E2,
            ABlockTransferDstScalarPerVector_E2,
            BThreadTransferSrcScalarPerVector_E2,
            CThreadTransferDstScalarPerVector_K,
            activ_type>{};

    std::cerr << "conv_bias_activ_resize_add_input_"
              << "n" << N << "c" << C0 << "h" << Hi << "w" << Wi << "c" << C1 << "_filter_k" << K
              << "c" << C0 << "y" << Y << "x" << X << "c" << C1 << "_addout_n" << N << "k" << K0
              << "h" << Ho * 2 << "w" << Wo * 2 << "k" << K1 << std::endl;

    // 5 timing passes; each Run launches nrepeat times and returns average ms.
    for(int i = 0; i < 5; i++)
    {

        const auto ave_time =
            conv_driver.Run(wei_k_c0_y_x_c1_desc,
                            in_n_c0_hi_wi_c1_desc,
                            out_n_k0_ho_wo_k1_desc,
                            add_n_k0_hox2_wox2_k1_desc,
                            conv_strides,
                            conv_dilations,
                            in_left_pads,
                            in_right_pads,
                            static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                                wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
                            static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                                in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
                            static_cast<TOut*>(bias_k0_k1_device_buf.GetDeviceBuffer()),
                            static_cast<TOut*>(add_n_k0_hox2_wox2_k1_device_buf.GetDeviceBuffer()),
                            nrepeat);

        {
            // FLOP count: 2 * N * K * Ho * Wo * C * Y * X with C = C0 * C1.
            float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C0 * C1 * Y * X) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
    }

    // Re-upload the original add tensor before the final run — the timed runs
    // above presumably updated the buffer in place (the kernel adds into it).
    add_n_k0_hox2_wox2_k1_device_buf.ToDevice(add_n_k0_hox2_wox2_k1.mData.data());

    // Final untimed run (nrepeat = 0) to produce the result copied back below.
    conv_driver.Run(wei_k_c0_y_x_c1_desc,
                    in_n_c0_hi_wi_c1_desc,
                    out_n_k0_ho_wo_k1_desc,
                    add_n_k0_hox2_wox2_k1_desc,
                    conv_strides,
                    conv_dilations,
                    in_left_pads,
                    in_right_pads,
                    static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                        wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
                    static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                        in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
                    static_cast<TOut*>(bias_k0_k1_device_buf.GetDeviceBuffer()),
                    static_cast<TOut*>(add_n_k0_hox2_wox2_k1_device_buf.GetDeviceBuffer()),
                    0);

    add_n_k0_hox2_wox2_k1_device_buf.FromDevice(add_n_k0_hox2_wox2_k1_out.mData.data());
}
|
||||
@@ -1,309 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
#include "debug.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
const Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
I0,
|
||||
I0,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto in_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: GemmN
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: GemmN
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
// clang-format off
|
||||
constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
//clang-format on
|
||||
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
|
||||
|
||||
constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(in_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerXDL,
|
||||
GemmNPerXDL,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<2, 0, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmM,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<1, 3, 7, 0, 2, 4, 5, 6>,
|
||||
6,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
false, // ABlockLdsExtraM
|
||||
false // BBlockLdsExtraN
|
||||
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
out_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
out_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data());
|
||||
}
|
||||
@@ -1,423 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
const Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmm
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-:
|
||||
// gemmk1
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmn
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
|
||||
|
||||
// clang-format off
|
||||
constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
// clang-format on
|
||||
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
const auto YTilde = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilde = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
|
||||
{
|
||||
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
|
||||
{
|
||||
const auto descs =
|
||||
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
out_n_ho_wo_k_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
i_ytilde,
|
||||
i_xtilde,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto in_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
const auto GemmK0 = out_gemmk0_gemmm_gemmk1_grid_desc.GetLength(I0);
|
||||
|
||||
if(GemmK0 != 0)
|
||||
{
|
||||
ave_time += driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(in_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<2, 0, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
#if 0
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
|
||||
#else
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
#endif
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
true, // CAccessOrderMRepeatNRepeat
|
||||
false, // ABlockLdsExtraM
|
||||
false // BBlockLdsExtraN
|
||||
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data());
|
||||
}
|
||||
@@ -1,389 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations&,
|
||||
const InLeftPads&,
|
||||
const InRightPads&,
|
||||
Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
const Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: gemmk0
|
||||
Sequence<0, 0, 0>{}, // 1+: gemmm
|
||||
Sequence<0, 0, 0>{}), // 2+: gemmk1
|
||||
make_tuple(Sequence<0, 0, 0>{}, // 0-: gemmk0
|
||||
Sequence<0, 0, 0>{}, // 1-: gemmm
|
||||
Sequence<0, 0, 0>{})); // 2-: gemmk1
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: gemmk0
|
||||
Sequence<0, 0, 0>{}, // 1+: gemmn
|
||||
Sequence<0, 0, 0>{}), // 2+: gemmk1
|
||||
make_tuple(Sequence<0, 0, 0>{}, // 0-: Gemmk0
|
||||
Sequence<0, 0, 0>{}, // 1-: Gemmn
|
||||
Sequence<0, 0, 0>{})); // 2-: Gemmk1
|
||||
|
||||
// clang-format off
|
||||
constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
// clang-format on
|
||||
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
const auto descs = transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1(
|
||||
out_n_ho_wo_k_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
conv_strides,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto in_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
float ave_time = driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(in_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<2, 0, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
#if 0
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
|
||||
#else
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
#endif
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
true, // CAccessOrderMRepeatNRepeat
|
||||
false, // ABlockLdsExtraM
|
||||
false // BBlockLdsExtraN
|
||||
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data());
|
||||
}
|
||||
@@ -1,256 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_gemm_xdlops_v2r4.hpp"
|
||||
|
||||
/// Profiles backward-weight convolution (dW) for NCHW / KCYX / NKHW layouts,
/// lowered to an implicit GEMM and executed with XDLOPS, combining partial
/// results with atomic adds across a GemmK batch split.
///
/// GEMM view of the problem (see the transform call below):
///   GemmM = K, GemmN = Y * X * C, GemmKTotal = N * Ho * Wo.
/// GemmKTotal is additionally split into GemmKBatch groups so that the
/// launched grid roughly reaches `desired_grid_size`; each group accumulates
/// its partial product into the weight buffer via
/// InMemoryDataOperationEnum::AtomicAdd.
///
/// The kernel is launched 5 times for timing, then once more on a freshly
/// re-initialized weight buffer to produce the result copied back to the host.
///
/// @param in_n_c_hi_wi_lengths   input tensor lengths  [N, C, Hi, Wi]
/// @param wei_k_c_y_x_lengths    weight tensor lengths [K, C, Y, X]
/// @param out_n_k_ho_wo_lengths  output tensor lengths [N, K, Ho, Wo]
/// @param conv_strides           convolution strides (forwarded to transform)
/// @param conv_dilations         convolution dilations
/// @param in_left_pads           left padding of the input
/// @param in_right_pads          right padding of the input
/// @param in_n_c_hi_wi           input tensor (read-only)
/// @param wei_k_c_y_x            weight-gradient tensor; overwritten with the
///                               computed result
/// @param out_n_k_ho_wo          output-gradient tensor (read-only)
/// @param desired_grid_size      target workgroup count used to pick GemmKBatch
/// @param nrepeat                number of timed kernel repetitions
template <typename TIn,
          typename TWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          typename GridSizeType>
void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw(
    const InLengths& in_n_c_hi_wi_lengths,
    const WeiLengths& wei_k_c_y_x_lengths,
    const OutLengths& out_n_k_ho_wo_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TIn>& in_n_c_hi_wi,
    Tensor<TWei>& wei_k_c_y_x,
    const Tensor<TOut>& out_n_k_ho_wo,
    GridSizeType desired_grid_size,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    // compile-time indices used to pick elements out of descriptor tuples
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    // allocate device buffers sized to each host tensor's element space
    DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace());
    DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace());
    DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());

    // upload host tensors; wei is uploaded too because the atomic accumulation
    // starts from the buffer's current contents
    in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
    wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
    out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());

    const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
    const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
    const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);

#if 1
    // Tuning configuration: [M, N, K0, K1] = [128, 128, 4, 8] for fp32
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 8;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;

    // block-transfer configuration for the A matrix (4-D: GemmB batch dim first)
    using GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 2, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 64, 1>;
    // using vector load 4, so config's wo*ho must be a multiple of 4
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;

    // block-transfer configuration for the B matrix
    using GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif

    const auto N = in_n_c_hi_wi_desc.GetLength(I0);
    const auto C = in_n_c_hi_wi_desc.GetLength(I1);
    const auto K = out_n_k_ho_wo_desc.GetLength(I1);

    const auto Ho = out_n_k_ho_wo_desc.GetLength(I2);
    const auto Wo = out_n_k_ho_wo_desc.GetLength(I3);

    const auto Y = wei_k_c_y_x_desc.GetLength(I2);
    const auto X = wei_k_c_y_x_desc.GetLength(I3);

    // GEMM dimensions of the dW problem
    const auto GemmM = K;
    const auto GemmN = Y * X * C;
    const auto GemmKTotal = N * Ho * Wo;

    // number of workgroups needed to cover the M x N tile grid once
    const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);
    // split GemmK over enough batches to (approximately) fill desired_grid_size;
    // at least one batch
    const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
    // round K0 up so GemmKBatch * GemmK0 * GemmK1 covers GemmKTotal and is a
    // multiple of GemmKPerBlock
    const index_t GemmK0 =
        math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock;
    const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;

    std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN
              << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad
              << std::endl;
    // build the GEMM-view tensor descriptors (padded to GemmKPad) for out (A),
    // in (B) and wei (C)
    const auto descs =
        transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad(
            wei_k_c_y_x_desc,
            in_n_c_hi_wi_desc,
            out_n_k_ho_wo_desc,
            conv_strides,
            conv_dilations,
            in_left_pads,
            in_right_pads,
            Number<GemmK1>{},
            GemmKBatch,
            GemmKPad);

    const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
    const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
    const auto wei_gemmm_gemmn_grid_desc = descs[I2];

    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmB
                              Sequence<0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0
                              Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmM
                              Sequence<0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1
                   make_tuple(Sequence<0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemB
                              Sequence<0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0
                              Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmM
                              Sequence<0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1

    constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmB
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmN
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmB
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GemmN
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1

    // C-matrix step hacks are all zero (plain index calculation)
    constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 1, 0, 0, 0, 0>{};

    constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{};

    // fully-instantiated driver; partial C-results of each GemmK batch are
    // combined on device with AtomicAdd
    const auto driver_gemm_xdlops =
        driver_gemm_xdlops_v2r4<BlockSize,
                                TIn,
                                TAcc,
                                TWei,
                                InMemoryDataOperationEnum::AtomicAdd,
                                decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
                                decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
                                decltype(wei_gemmm_gemmn_grid_desc),
                                GemmMPerBlock,
                                GemmNPerBlock,
                                GemmKPerBlock,
                                GemmMPerWave,
                                GemmNPerWave,
                                GemmK1,
                                MRepeat,
                                NRepeat,
                                GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1,
                                GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1,
                                Sequence<0, 2, 1, 3>,
                                Sequence<0, 2, 1, 3>,
                                3,
                                GemmABlockTransferSrcScalarPerVector_GemmK1,
                                GemmABlockTransferDstScalarPerVector_GemmK1,
                                false, // don't move back src coordinate after threadwise copy
                                GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1,
                                GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1,
                                Sequence<0, 2, 1, 3>,
                                Sequence<0, 2, 1, 3>,
                                3,
                                GemmBBlockTransferSrcScalarPerVector_GemmN,
                                GemmBBlockTransferDstScalarPerVector_GemmK1,
                                false, // don't move back src coordinate after threadwise copy
                                Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
                                7,
                                GemmCThreadTransferDstScalarPerVector,
                                decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
                                decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
                                decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
                                decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
                                decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
                                false,
                                true,
                                true>;

    // timing loop: 5 measurements, each averaging `nrepeat` kernel launches
    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time =
            driver_gemm_xdlops(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
                               static_cast<TIn*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
                               static_cast<TWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
                               out_gemmk0_gemmm_gemmk1_grid_desc,
                               in_gemmk0_gemmn_gemmk1_grid_desc,
                               wei_gemmm_gemmn_grid_desc,
                               debug::debug_driver_gemm_xdlops_v2r3::M01,
                               debug::debug_driver_gemm_xdlops_v2r3::N01,
                               out_gemmk0_gemmm_gemmk1_grid_step_hacks,
                               in_gemmk0_gemmn_gemmk1_grid_step_hacks,
                               wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                               out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
                               in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
                               nrepeat);

        float perf = static_cast<float>(calculate_convolution_flops(
                         in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
    }

    // Re-initialize the weight buffer and run one clean accumulation pass
    // (nrepeat = 0) so the result copied back was not accumulated repeatedly
    // by the timing loop.
    // NOTE(review): the buffer is reset to the host tensor's contents, not
    // zeroed — this assumes the caller provides a zero-initialized wei tensor;
    // confirm against callers.
    wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
    driver_gemm_xdlops(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
                       static_cast<TIn*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
                       static_cast<TWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
                       out_gemmk0_gemmm_gemmk1_grid_desc,
                       in_gemmk0_gemmn_gemmk1_grid_desc,
                       wei_gemmm_gemmn_grid_desc,
                       debug::debug_driver_gemm_xdlops_v2r3::M01,
                       debug::debug_driver_gemm_xdlops_v2r3::N01,
                       out_gemmk0_gemmm_gemmk1_grid_step_hacks,
                       in_gemmk0_gemmn_gemmk1_grid_step_hacks,
                       wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                       out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
                       in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
                       0);
    // copy result back to host
    wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data());
}
|
||||
@@ -1,234 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
/// Profiles backward-weight convolution (dW) for NCHW / KCYX / NKHW layouts,
/// lowered to an implicit GEMM and executed with XDLOPS.  Unlike the "atomic"
/// variant, the C matrix is written with InMemoryDataOperationEnum::Set and
/// there is no GemmK batch split, so a single kernel pass produces the final
/// weight gradient.
///
/// @param in_n_c_hi_wi_lengths   input tensor lengths  [N, C, Hi, Wi]
/// @param wei_k_c_y_x_lengths    weight tensor lengths [K, C, Y, X]
/// @param out_n_k_ho_wo_lengths  output tensor lengths [N, K, Ho, Wo]
/// @param conv_strides           convolution strides (forwarded to transform)
/// @param conv_dilations         convolution dilations
/// @param in_left_pads           left padding of the input
/// @param in_right_pads          right padding of the input
/// @param in_n_c_hi_wi           input tensor (read-only)
/// @param wei_k_c_y_x            weight-gradient tensor; overwritten with the
///                               computed result
/// @param out_n_k_ho_wo          output-gradient tensor (read-only)
/// @param nrepeat                number of timed kernel repetitions
template <typename TIn,
          typename TWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
    const InLengths& in_n_c_hi_wi_lengths,
    const WeiLengths& wei_k_c_y_x_lengths,
    const OutLengths& out_n_k_ho_wo_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TIn>& in_n_c_hi_wi,
    Tensor<TWei>& wei_k_c_y_x,
    const Tensor<TOut>& out_n_k_ho_wo,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    // compile-time indices used to pick elements out of descriptor tuples
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    // allocate device buffers sized to each host tensor's element space
    DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace());
    DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace());
    DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());

    in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
    wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
    out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());

    const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
    const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
    const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);

// Alternative compile-time tuning configurations; only the branch selected by
// the preprocessor (currently the #elif 1 branch) is active.
#if 0
    // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 8;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;

    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    // using vector load 4, so config's wo*ho must be a multiple of 4
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;

    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 256;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 8;

    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;

    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    // using vector load 4, so config's wo*ho must be a multiple of 4
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;

    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif

    // build the GEMM-view descriptors for out (A), in (B) and wei (C)
    const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
        wei_k_c_y_x_desc,
        in_n_c_hi_wi_desc,
        out_n_k_ho_wo_desc,
        conv_strides,
        conv_dilations,
        in_left_pads,
        in_right_pads,
        Number<GemmK1>{});

    const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
    const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
    const auto wei_gemmm_gemmn_grid_desc = descs[I2];

    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0>{}, // 0+: GemmK0
                              Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM
                              Sequence<0, 0, 1, 0, 0>{}), // 2+: GemmK1
                   make_tuple(Sequence<0, 0, 2, 0, 0>{}, // 0-: GemmK0
                              Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM
                              Sequence<0, 0, 2, 0, 0>{})); // 2-: GemmK1

    constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 1+: GemmN
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 1-: GemmN
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1

    // C-matrix step hacks are all zero (plain index calculation)
    constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 1, 0, 0>{};

    constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{};

    // timing loop: 5 measurements, each averaging `nrepeat` kernel launches;
    // because C is written with Set, rerunning does not corrupt the result
    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time = driver_gemm_xdlops_v2r3<
            BlockSize,
            TIn,
            TAcc,
            TWei,
            InMemoryDataOperationEnum::Set,
            decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
            decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
            decltype(wei_gemmm_gemmn_grid_desc),
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerWave,
            GemmNPerWave,
            GemmK1,
            MRepeat,
            NRepeat,
            GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
            GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
            Sequence<1, 0, 2>,
            Sequence<1, 0, 2>,
            2,
            GemmABlockTransferSrcScalarPerVector_GemmK1,
            GemmABlockTransferDstScalarPerVector_GemmK1,
            false, // don't move back src coordinate after threadwise copy
            GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
            GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
            Sequence<1, 0, 2>,
            Sequence<1, 0, 2>,
            2,
            GemmBBlockTransferSrcScalarPerVector_GemmN,
            GemmBBlockTransferDstScalarPerVector_GemmK1,
            false, // don't move back src coordinate after threadwise copy
            Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
            7,
            GemmCThreadTransferDstScalarPerVector,
            decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
            decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
            decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
            decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
            decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
            false, // CAccessOrderMRepeatNRepeat
            true,  // ABlockLdsExtraM
            true   // BBlockLdsExtraN
            >(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
              static_cast<TIn*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
              static_cast<TWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
              out_gemmk0_gemmm_gemmk1_grid_desc,
              in_gemmk0_gemmn_gemmk1_grid_desc,
              wei_gemmm_gemmn_grid_desc,
              debug::debug_driver_gemm_xdlops_v2r3::M01,
              debug::debug_driver_gemm_xdlops_v2r3::N01,
              out_gemmk0_gemmm_gemmk1_grid_step_hacks,
              in_gemmk0_gemmn_gemmk1_grid_step_hacks,
              wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
              out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
              in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
              nrepeat);

        float perf = static_cast<float>(calculate_convolution_flops(
                         in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
    }

    // copy result back to host
    wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data());
}
|
||||
@@ -1,288 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r4.hpp"
|
||||
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
typename GridSizeType>
|
||||
void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TIn>& in_n_hi_wi_c,
|
||||
Tensor<TWei>& wei_k_y_x_c,
|
||||
const Tensor<TOut>& out_n_ho_wo_k,
|
||||
GridSizeType desired_grid_size,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto N = in_n_hi_wi_c_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_desc.GetLength(I2);
|
||||
|
||||
const auto GemmM = Y * X * C;
|
||||
const auto GemmN = K;
|
||||
const auto GemmKTotal = N * Ho * Wo;
|
||||
|
||||
const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);
|
||||
const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;
|
||||
|
||||
std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN
|
||||
<< " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad
|
||||
<< std::endl;
|
||||
|
||||
const auto descs =
|
||||
transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad(
|
||||
in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{},
|
||||
GemmKBatch,
|
||||
GemmKPad);
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto wei_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmKBatch
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmKBatch
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1
|
||||
|
||||
constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN
|
||||
Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
|
||||
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
const auto driver_gemm_xdlops = driver_gemm_xdlops_v2r4<
|
||||
BlockSize,
|
||||
TIn,
|
||||
TAcc,
|
||||
TWei,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(wei_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerXDL,
|
||||
GemmNPerXDL,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmM,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
2,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
6,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
true,
|
||||
true>;
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops(static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
driver_gemm_xdlops(static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
0);
|
||||
// copy result back to host
|
||||
wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data());
|
||||
}
|
||||
@@ -1,276 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
#include "debug.hpp"
|
||||
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TIn>& in_n_hi_wi_c,
|
||||
Tensor<TWei>& wei_k_y_x_c,
|
||||
const Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
|
||||
in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto wei_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 1+: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 1-: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN
|
||||
Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
|
||||
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{};
|
||||
|
||||
constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TIn,
|
||||
TAcc,
|
||||
TWei,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(wei_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerXDL,
|
||||
GemmNPerXDL,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmM,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
true,
|
||||
true>(static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
out_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
in_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
out_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data());
|
||||
}
|
||||
@@ -1,456 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r4.hpp"
|
||||
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
typename GridSizeType>
|
||||
void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TIn>& in_n_hi_wi_c,
|
||||
Tensor<TWei>& wei_k_y_x_c,
|
||||
const Tensor<TOut>& out_n_ho_wo_k,
|
||||
GridSizeType desired_grid_size,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4], C 128, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4], C 128, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C 64, for fp32 and fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C 64, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [64, 128, 4, 8], C 64, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 64;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 16, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [64, 64, 4, 8], C 32, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 64;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto N = in_n_hi_wi_c_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_desc.GetLength(I2);
|
||||
|
||||
const auto GemmM = K;
|
||||
const auto GemmN = Y * X * C;
|
||||
const auto GemmKTotal = N * Ho * Wo;
|
||||
|
||||
const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);
|
||||
const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;
|
||||
|
||||
std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN
|
||||
<< " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad
|
||||
<< std::endl;
|
||||
|
||||
const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad(
|
||||
in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{},
|
||||
GemmKBatch,
|
||||
GemmKPad);
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto wei_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN
|
||||
Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
|
||||
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{};
|
||||
|
||||
const auto driver_gemm_xdlops = driver_gemm_xdlops_v2r4<
|
||||
BlockSize,
|
||||
TIn,
|
||||
TAcc,
|
||||
TWei,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(wei_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerXDL,
|
||||
GemmNPerXDL,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmM,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<0, 1, 3, 2>,
|
||||
2,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
true,
|
||||
true>;
|
||||
|
||||
// timing
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// verification
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
driver_gemm_xdlops(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
0);
|
||||
// copy result back to host
|
||||
wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data());
|
||||
}
|
||||
@@ -1,201 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_gemm_dlops_v1r2.hpp"
|
||||
|
||||
// Forward convolution (NCHW input, KCYX weight, NKHW output) executed as an
// "implicit GEMM": instead of materializing an im2col buffer, the tensor
// descriptors are transformed into GEMM matrix descriptors
// (transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad) and the
// dlops v1r2 grid-wise GEMM driver iterates through them directly.
//
// Template parameters:
//   TInWei      element type of input and weight tensors
//   TAcc        accumulator type used by the GEMM
//   TOut        element type of the output tensor
//   remaining   length/stride/pad container types (deduced from arguments)
//
// Arguments:
//   in_n_c_hi_wi_lengths / wei_k_c_y_x_lengths / out_n_k_ho_wo_lengths:
//       dimension lengths for the packed input, weight and output descriptors
//   conv_strides / conv_dilations / in_left_pads / in_right_pads:
//       convolution strides, dilations and input spatial paddings
//   in_n_c_hi_wi / wei_k_c_y_x:
//       host-side input and weight tensors (uploaded to device here)
//   out_n_k_ho_wo:
//       host-side output tensor; overwritten with the device result on return
//   nrepeat:
//       repetition count forwarded to the GEMM driver for kernel timing
template <typename TInWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
    const InLengths& in_n_c_hi_wi_lengths,
    const WeiLengths& wei_k_c_y_x_lengths,
    const OutLengths& out_n_k_ho_wo_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TInWei>& in_n_c_hi_wi,
    const Tensor<TInWei>& wei_k_c_y_x,
    Tensor<TOut>& out_n_k_ho_wo,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    // Allocate device buffers sized to each host tensor's element space.
    DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
    DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
    DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());

    // Copy all three host tensors (including the output) to the device.
    in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
    wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
    out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());

    // Packed (contiguous, no stride gaps) descriptors for the host-side layouts.
    const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
    const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
    const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);

    // Compile-time tuning parameters for the GEMM kernel: block/thread tile
    // sizes and the block-transfer (global -> LDS) slice/cluster shapes.
#if 1
    // cdata = 64, BlockSize = 256, 128x128x8
    constexpr index_t BlockSize = 256;

    // Per-workgroup GEMM tile: 128 x 128 output, 8-wide K step per main-loop iter.
    constexpr index_t GemmMPerBlockM1 = 128;
    constexpr index_t GemmNPerBlockN1 = 128;
    constexpr index_t GemmKPerBlock = 8;

    // Per-thread tile within the block tile.
    constexpr index_t GemmM1PerThreadM111 = 4;
    constexpr index_t GemmN1PerThreadN111 = 4;
    constexpr index_t GemmKPerThread = 1;

    // Two-level thread-cluster decomposition of the block tile.
    constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
    constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
    constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterN1101 = 2;

    // A-matrix (weight) block transfer: per-thread slice and thread-cluster
    // layout over the [K, M0, M1] dimensions.
    using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 1>;
    using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 128>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;

    // B-matrix (input) block transfer over [K, N0, N1].
    using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>;
    using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_N1 = 1;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;

    constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 1;
#endif

    // Transform the convolution problem into a GEMM: produces matrix
    // descriptors for weight (A), input (B) and output (C), with padding applied.
    const auto descs =
        transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
                                                                        in_n_c_hi_wi_desc,
                                                                        out_n_k_ho_wo_desc,
                                                                        conv_strides,
                                                                        conv_dilations,
                                                                        in_left_pads,
                                                                        in_right_pads);

    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    // Each Sequence is a per-hidden-dimension flag (0 = no hack, 1/2 = variant)
    // consumed by the coordinate-update machinery; the first tuple covers
    // forward (+) steps, the second backward (-) steps.
    constexpr auto wei_gemmk_gemmm0_gemmn1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{}));

    constexpr auto in_gemmk_gemmn0_gemmn1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}));

    constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 1, 0, 0>{},
                              Sequence<0, 0, 1, 0, 0>{},
                              Sequence<0, 0, 1, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 2, 0, 0>{},
                              Sequence<0, 0, 2, 0, 0>{},
                              Sequence<0, 0, 2, 0, 0>{}));

    // Hacks applied when the K-loop moves the A/B slice windows forward.
    constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0>{};

    constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};

    // A = weight [GemmK, GemmM], B = input [GemmK, GemmN], C = output [GemmM, GemmN].
    const auto wei_gemmk_gemmm_grid_desc = descs[I0];
    const auto in_gemmk_gemmn_grid_desc  = descs[I1];
    const auto out_gemmm_gemmn_grid_desc = descs[I2];

    // Run the GEMM 5 times, reporting average kernel time and throughput each run.
    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time = driver_gemm_dlops_v1r2<
            BlockSize,
            TInWei,
            TAcc,
            TOut,
            InMemoryDataOperationEnum::Set,
            decltype(wei_gemmk_gemmm_grid_desc),
            decltype(in_gemmk_gemmn_grid_desc),
            decltype(out_gemmm_gemmn_grid_desc),
            GemmMPerBlockM1,
            GemmNPerBlockN1,
            GemmKPerBlock,
            GemmM1PerThreadM111,
            GemmN1PerThreadN111,
            GemmKPerThread,
            GemmM11N11ThreadClusterM1100,
            GemmM11N11ThreadClusterN1100,
            GemmM11N11ThreadClusterM1101,
            GemmM11N11ThreadClusterN1101,
            GemmABlockTransferThreadSliceLengths_K_M0_M1,
            GemmABlockTransferThreadClusterLengths_K_M0_M1,
            Sequence<2, 1, 0>, // ABlockTransferThreadClusterArrangeOrder
            Sequence<2, 1, 0>, // ABlockTransferSrcAccessOrder
            0,                 // ABlockTransferSrcVectorDim
            GemmABlockTransferSrcScalarPerVector_K,
            GemmABlockTransferDstScalarPerVector_M1,
            false, // don't move back src coordinate after threadwise copy
            GemmBBlockTransferThreadSliceLengths_K_N0_N1,
            GemmBBlockTransferThreadClusterLengths_K_N0_N1,
            Sequence<0, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
            Sequence<0, 1, 2>, // BBlockTransferSrcAccessOrder
            2,                 // BBlockTransferSrcVectorDim
            GemmBBlockTransferSrcScalarPerVector_N1,
            GemmBBlockTransferDstScalarPerVector_N1,
            false, // don't move back src coordinate after threadwise copy
            Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
            5,                          // CThreadTransferSrcDstVectorDim
            GemmCThreadTransferDstScalarPerVector_N11,
            decltype(wei_gemmk_gemmm0_gemmn1_grid_step_hacks),
            decltype(in_gemmk_gemmn0_gemmn1_grid_step_hacks),
            decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks),
            decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks),
            decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks)>(
            static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
            static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
            static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
            wei_gemmk_gemmm_grid_desc,
            in_gemmk_gemmn_grid_desc,
            out_gemmm_gemmn_grid_desc,
            wei_gemmk_gemmm0_gemmn1_grid_step_hacks,
            in_gemmk_gemmn0_gemmn1_grid_step_hacks,
            out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks,
            wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks,
            in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks,
            nrepeat);

        // flop / 1e9 gives GFlop; divided by ave_time in ms this yields TFlop/s.
        float perf = static_cast<float>(calculate_convolution_flops(
                         in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
    }

    // copy result back to host
    out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
}
|
||||
@@ -1,273 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_dlops_v1r3.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [128, 128, 8, 1] for fp32
|
||||
// cdata = 64, BlockSize = 256
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlockM1 = 128;
|
||||
constexpr index_t GemmNPerBlockN1 = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmK1 = 1;
|
||||
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>;
|
||||
using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
|
||||
|
||||
using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>;
|
||||
using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 1>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;
|
||||
|
||||
using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>;
|
||||
using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 1>;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 8, 2] for fp16
|
||||
// cdata = 64, BlockSize = 256
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlockM1 = 128;
|
||||
constexpr index_t GemmNPerBlockN1 = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>;
|
||||
using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
|
||||
|
||||
using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>;
|
||||
using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 2>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;
|
||||
|
||||
using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>;
|
||||
using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 2>;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 8, 4] for i8
|
||||
// cdata = 64, BlockSize = 256
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlockM1 = 128;
|
||||
constexpr index_t GemmNPerBlockN1 = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>;
|
||||
using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
|
||||
|
||||
using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>;
|
||||
using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 4>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;
|
||||
|
||||
using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>;
|
||||
using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 4>;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GemmM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
|
||||
|
||||
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmM0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM10
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM11
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 3+: GemmN0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 4+: GemmN10
|
||||
Sequence<0, 0, 0, 0, 0>{}), // 5+: GemmN11
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmM0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM10
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 2-: GemmM11
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 3-: GemmN0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 4-: GemmN10
|
||||
Sequence<0, 0, 0, 0, 0>{})); // 5-: GemmN11
|
||||
|
||||
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_dlops_v1r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(out_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlockM1,
|
||||
GemmNPerBlockN1,
|
||||
GemmKPerBlock,
|
||||
GemmM1PerThreadM111,
|
||||
GemmN1PerThreadN111,
|
||||
GemmKPerThread,
|
||||
GemmM11N11ThreadClusterM110Xs,
|
||||
GemmM11N11ThreadClusterN110Xs,
|
||||
GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||
GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||
Sequence<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder
|
||||
Sequence<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder
|
||||
GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
||||
Sequence<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder
|
||||
GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
||||
GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||
GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||
Sequence<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder
|
||||
Sequence<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder
|
||||
GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
||||
Sequence<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder
|
||||
GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
|
||||
5, // CThreadTransferSrcDstVectorDim
|
||||
GemmCThreadTransferDstScalarPerVector_N11,
|
||||
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks),
|
||||
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks),
|
||||
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks)>(
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks,
|
||||
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks,
|
||||
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks,
|
||||
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks,
|
||||
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
|
||||
}
|
||||
@@ -1,228 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
// Host-side driver: forward convolution (NCHW input, KCYX weight, NKHW output)
// executed as an implicit GEMM (v4r4r2 transform) on XDL matrix-core hardware.
//
// The convolution is rewritten as a GEMM via
// transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad:
//   descs[I0] -> A: weight viewed as [GemmK0, GemmM, GemmK1]
//   descs[I1] -> B: input  viewed as [GemmK0, GemmN, GemmK1]
//   descs[I2] -> C: output viewed as [GemmM, GemmN]
// (exact mapping of conv dims onto GemmM/N/K lives in that transform — see its header)
//
// Parameters:
//   *_lengths                         - extents of input / weight / output tensors
//   conv_strides, conv_dilations,
//   in_left_pads, in_right_pads       - convolution geometry
//   in_n_c_hi_wi, wei_k_c_y_x         - host tensors, copied to device before the run
//   out_n_k_ho_wo                     - host tensor, overwritten with the device result
//   nrepeat                           - timing repeat count forwarded to the GEMM driver
//
// Side effects: allocates device buffers, launches kernels, prints average time
// and TFlop/s to stdout, and writes the result back into out_n_k_ho_wo.
template <typename TInWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
    const InLengths& in_n_c_hi_wi_lengths,
    const WeiLengths& wei_k_c_y_x_lengths,
    const OutLengths& out_n_k_ho_wo_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TInWei>& in_n_c_hi_wi,
    const Tensor<TInWei>& wei_k_c_y_x,
    Tensor<TOut>& out_n_k_ho_wo,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    // Device buffers sized to each host tensor's element space.
    DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
    DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
    DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());

    // Upload all three tensors (output too, so untouched elements keep host values).
    in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
    wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
    out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());

    // Packed (contiguous) descriptors of the raw tensors.
    const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
    const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
    const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);

// Compile-time tuning configuration: block/wave tile sizes and block-transfer
// layouts. Exactly one branch is active; switch branches to retune.
#if 0
    // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 8;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;

    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;

    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 256;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 8;

    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;

    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;

    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif

    // Rewrite the padded convolution as GEMM descriptors:
    // A = weight, B = input, C = output (see function-level comment).
    const auto descs =
        transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
                                                                          in_n_c_hi_wi_desc,
                                                                          out_n_k_ho_wo_desc,
                                                                          conv_strides,
                                                                          conv_dilations,
                                                                          in_left_pads,
                                                                          in_right_pads,
                                                                          Number<GemmK1>{});

    const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
    const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
    const auto out_gemmm_gemmn_grid_desc = descs[I2];

    // HACK: hacks that control index calculation when iterating over A, B, C matrix.
    // Each Sequence position corresponds to one hidden transform in the descriptor;
    // the first tuple is for positive steps ("+"), the second for negative ("-").
    constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0+: GemmK0
                              Sequence<0, 0, 0, 0, 0>{},   // 1+: GemmM
                              Sequence<0, 0, 0, 0, 0>{}),  // 2+: GemmK1
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0-: GemmK0
                              Sequence<0, 0, 0, 0, 0>{},   // 1-: GemmM
                              Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1

    constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},   // 0+: GemmK0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},   // 1+: GemmN
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),  // 2+: GemmK1
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},   // 0-: GemmK0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},   // 1-: GemmN
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1

    constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                              Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                              Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                              Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                              Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                              Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                              Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    // Hacks applied when sliding the K-window between main-loop iterations.
    constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0>{};

    constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};

    // Run the GEMM driver 5 times; each pass times the kernel (nrepeat launches)
    // and reports average time plus achieved TFlop/s.
    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time = driver_gemm_xdlops_v2r3<
            BlockSize,
            TInWei,
            TAcc,
            TOut,
            InMemoryDataOperationEnum::Set,
            decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
            decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
            decltype(out_gemmm_gemmn_grid_desc),
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerWave,
            GemmNPerWave,
            GemmK1,
            MRepeat,
            NRepeat,
            GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
            GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
            Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
            Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
            2,                 // ABlockTransferSrcVectorDim
            GemmABlockTransferSrcScalarPerVector_GemmK1,
            GemmABlockTransferDstScalarPerVector_GemmK1,
            false, // don't move back src coordinate after threadwise copy
            GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
            GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
            Sequence<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder
            Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
            1,                 // BBlockTransferSrcVectorDim
            GemmBBlockTransferSrcScalarPerVector_GemmN,
            GemmBBlockTransferDstScalarPerVector_GemmK1,
            false, // don't move back src coordinate after threadwise copy
            Sequence<3, 0, 1, 2, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder
            7,                                // CThreadTransferSrcDstVectorDim
            GemmCThreadTransferDstScalarPerVector,
            decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
            decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
            decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
            decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
            decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
            false>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
                   static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
                   static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
                   wei_gemmk0_gemmm_gemmk1_grid_desc,
                   in_gemmk0_gemmn_gemmk1_grid_desc,
                   out_gemmm_gemmn_grid_desc,
                   wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
                   in_gemmk0_gemmn_gemmk1_grid_step_hacks,
                   out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                   wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
                   in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
                   nrepeat);

        // TFlop/s = conv FLOPs / 1e9 / time-in-ms (ms * 1e9 == s * 1e12).
        float perf = static_cast<float>(calculate_convolution_flops(
                         in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
    }

    // copy result back to host
    out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
}
|
||||
@@ -1,600 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
#if 0
// Pads the A/B/C GEMM grid descriptors on the right so K0, M and N become
// multiples of K0PerBlock / MPerBlock / NPerBlock, as selected by the
// DoPad_K0 / DoPad_M / DoPad_N flags. Dimensions that are not padded are
// passed through unchanged.
//
// NOTE(review): this draft is disabled (#if 0) and relies on enclosing-scope
// names (AGridDesc_K0Raw_MRaw_K1, K0PerBlock, DoPad_*, I0..I2, math::...)
// that are not visible in this file — confirm them before enabling.
//
// Returns: tuple of (padded A desc [K0, M, K1], padded B desc [K0, N, K1],
//          padded C desc [M, N]).
__host__ __device__ static constexpr auto
MakePaddedGridDescriptors(const AGridDesc_K0Raw_MRaw_K1& a_grid_desc_k0raw_mraw_k1,
                          const BGridDesc_K0Raw_NRaw_K1& b_grid_desc_k0raw_nraw_k1,
                          const CGridDesc_MRaw_NRaw& c_grid_desc_mraw_nraw)
{
    const auto K0Raw = a_grid_desc_k0raw_mraw_k1.GetLength(I0);
    const auto K1    = a_grid_desc_k0raw_mraw_k1.GetLength(I2);
    const auto MRaw  = c_grid_desc_mraw_nraw.GetLength(I0);
    const auto NRaw  = c_grid_desc_mraw_nraw.GetLength(I1);

    // Right-padding needed to reach the next block-size multiple (0 if aligned).
    const auto K0Pad = math::integer_least_multiple(K0Raw, K0PerBlock) - K0Raw;
    const auto MPad  = math::integer_least_multiple(MRaw, MPerBlock) - MRaw;
    const auto NPad  = math::integer_least_multiple(NRaw, NPerBlock) - NRaw;

    // A: pad K0 and/or M as requested.
    // FIX: each transform must start from the *raw* input descriptor; the
    // original draft passed the lambda's own result variable into
    // transform_tensor_descriptor (use of an uninitialized variable).
    const auto a_grid_desc_k0_m_k1 = [&]() {
        if constexpr(DoPad_K0 && DoPad_M)
        {
            return transform_tensor_descriptor(
                a_grid_desc_k0raw_mraw_k1,
                make_tuple(make_right_pad_transform(K0Raw, K0Pad),
                           make_right_pad_transform(MRaw, MPad),
                           make_pass_through_transform(K1)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
        }
        else if constexpr(DoPad_K0 && !DoPad_M)
        {
            return transform_tensor_descriptor(
                a_grid_desc_k0raw_mraw_k1,
                make_tuple(make_right_pad_transform(K0Raw, K0Pad),
                           make_pass_through_transform(MRaw),
                           make_pass_through_transform(K1)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
        }
        else if constexpr(!DoPad_K0 && DoPad_M)
        {
            return transform_tensor_descriptor(
                a_grid_desc_k0raw_mraw_k1,
                make_tuple(make_pass_through_transform(K0Raw),
                           make_right_pad_transform(MRaw, MPad),
                           make_pass_through_transform(K1)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
        }
        else
        {
            return a_grid_desc_k0raw_mraw_k1;
        }
    }();

    // B: pad K0 and/or N as requested (same FIX as A: transform the raw desc).
    const auto b_grid_desc_k0_n_k1 = [&]() {
        if constexpr(DoPad_K0 && DoPad_N)
        {
            return transform_tensor_descriptor(
                b_grid_desc_k0raw_nraw_k1,
                make_tuple(make_right_pad_transform(K0Raw, K0Pad),
                           make_right_pad_transform(NRaw, NPad),
                           make_pass_through_transform(K1)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
        }
        else if constexpr(DoPad_K0 && !DoPad_N)
        {
            return transform_tensor_descriptor(
                b_grid_desc_k0raw_nraw_k1,
                make_tuple(make_right_pad_transform(K0Raw, K0Pad),
                           make_pass_through_transform(NRaw),
                           make_pass_through_transform(K1)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
        }
        else if constexpr(!DoPad_K0 && DoPad_N)
        {
            return transform_tensor_descriptor(
                b_grid_desc_k0raw_nraw_k1,
                make_tuple(make_pass_through_transform(K0Raw),
                           make_right_pad_transform(NRaw, NPad),
                           make_pass_through_transform(K1)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
        }
        else
        {
            return b_grid_desc_k0raw_nraw_k1;
        }
    }();

    // C: pad M and/or N as requested (same FIX; also fixes the `reutnr` typo
    // in the original unpadded branch).
    const auto c_grid_desc_m_n = [&]() {
        if constexpr(DoPad_M && DoPad_N)
        {
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
                                               make_tuple(make_right_pad_transform(MRaw, MPad),
                                                          make_right_pad_transform(NRaw, NPad)),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(DoPad_M && !DoPad_N)
        {
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(!DoPad_M && DoPad_N)
        {
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            return c_grid_desc_mraw_nraw;
        }
    }();

    // FIX: the original draft computed the three descriptors but never
    // returned them, deducing a void return type and discarding the work.
    return make_tuple(a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n);
}
#endif
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 128, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 256, 4, 8], C = 256, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
#if 0 // debug
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A matrix
|
||||
constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
#else
|
||||
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = descs[I0];
|
||||
|
||||
const auto GemmK0 = in_gemmk0_gemmmraw_gemmk1_grid_desc.GetLength(I0);
|
||||
const auto GemmMRaw = in_gemmk0_gemmmraw_gemmk1_grid_desc.GetLength(I1);
|
||||
const auto GemmMPad = math::integer_least_multiple(GemmMRaw, GemmMPerBlock) - GemmMRaw;
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(in_gemmk0_gemmmraw_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmK1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A matrix
|
||||
constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{};
|
||||
#endif
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN
|
||||
Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
|
||||
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
#if 0
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
#else
|
||||
const auto out_gemmmraw_gemmn_grid_desc = descs[I2];
|
||||
|
||||
const auto GemmN = out_gemmmraw_gemmn_grid_desc.GetLength(I1);
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
#endif
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(out_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerXDL,
|
||||
GemmNPerXDL,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
true, // ABlockLdsExtraM
|
||||
true // BBlockLdsExtraN
|
||||
>(static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
in_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
|
||||
}
|
||||
@@ -1,196 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
ck::ActivTypeEnum activ_type,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1(
|
||||
const InLengths& in_n_c0_hi_wi_c1_lengths,
|
||||
const WeiLengths& wei_k_c0_y_x_c1_lengths,
|
||||
const OutLengths& out_n_k0_ho_wo_k1_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c0_hi_wi_c1,
|
||||
const Tensor<TInWei>& wei_k_c0_y_x_c1,
|
||||
const Tensor<TOut>& bias_k0_k1,
|
||||
Tensor<TOut>& out_n_k0_ho_wo_k1,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
const auto N = out_n_k0_ho_wo_k1_lengths[I0];
|
||||
const auto K0 = out_n_k0_ho_wo_k1_lengths[I1];
|
||||
const auto Ho = out_n_k0_ho_wo_k1_lengths[I2];
|
||||
const auto Wo = out_n_k0_ho_wo_k1_lengths[I3];
|
||||
const auto K1 = out_n_k0_ho_wo_k1_lengths[I4];
|
||||
|
||||
const auto C0 = in_n_c0_hi_wi_c1_lengths[I1];
|
||||
const auto Hi = in_n_c0_hi_wi_c1_lengths[I2];
|
||||
const auto Wi = in_n_c0_hi_wi_c1_lengths[I3];
|
||||
const auto C1 = in_n_c0_hi_wi_c1_lengths[I4];
|
||||
|
||||
const auto K = wei_k_c0_y_x_c1_lengths[I0];
|
||||
const auto Y = wei_k_c0_y_x_c1_lengths[I2];
|
||||
const auto X = wei_k_c0_y_x_c1_lengths[I3];
|
||||
|
||||
DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) *
|
||||
in_n_c0_hi_wi_c1.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace());
|
||||
DeviceMem bias_k0_k1_device_buf(sizeof(TOut) * bias_k0_k1.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) *
|
||||
out_n_k0_ho_wo_k1.mDesc.GetElementSpace());
|
||||
in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
|
||||
wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
|
||||
bias_k0_k1_device_buf.ToDevice(bias_k0_k1.mData.data());
|
||||
|
||||
constexpr index_t InWeiVectorSize = 8;
|
||||
|
||||
if(C1 % InWeiVectorSize != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize");
|
||||
}
|
||||
|
||||
#if 0
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t KPerBlock = 32;
|
||||
constexpr index_t HoPerBlock = 8;
|
||||
constexpr index_t WoPerBlock = 64;
|
||||
|
||||
constexpr index_t E1 = C0 * 9;
|
||||
constexpr index_t E2 = 1;
|
||||
constexpr index_t E1PerBlock = C0;
|
||||
|
||||
constexpr index_t KPerThread = 16;
|
||||
constexpr index_t HoPerThread = 2;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
constexpr index_t EPerThread = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, E2>;
|
||||
using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, KPerBlock, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;
|
||||
|
||||
constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_K = K1;
|
||||
#elif 1
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t KPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 8;
|
||||
constexpr index_t WoPerBlock = 32;
|
||||
|
||||
constexpr index_t E1 = 2 * 9;
|
||||
constexpr index_t E2 = 1;
|
||||
constexpr index_t K2 = 2;
|
||||
constexpr index_t E1PerBlock = 2;
|
||||
|
||||
constexpr index_t KPerThread = KPerBlock;
|
||||
constexpr index_t HoPerThread = 2;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
constexpr index_t EPerThread = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, 1, E2>;
|
||||
using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 =
|
||||
Sequence<1, E1PerBlock, 1, KPerBlock, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;
|
||||
constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_K = InWeiVectorSize;
|
||||
#endif
|
||||
|
||||
if(KPerThread % InWeiVectorSize != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize");
|
||||
}
|
||||
|
||||
const auto in_n_c0_hi_wi_c1_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2));
|
||||
const auto wei_k_c0_y_x_c1_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, E2));
|
||||
const auto out_n_k0_ho_wo_k1_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));
|
||||
|
||||
constexpr auto conv_driver =
|
||||
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_outpad<
|
||||
BlockSize,
|
||||
typename vector_type<TInWei, InWeiVectorSize>::type,
|
||||
TAcc,
|
||||
TOut,
|
||||
E1,
|
||||
E2,
|
||||
K2,
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
E1PerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
|
||||
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
|
||||
ABlockTransferSrcScalarPerVector_E2,
|
||||
ABlockTransferDstScalarPerVector_E2,
|
||||
BThreadTransferSrcScalarPerVector_E2,
|
||||
CThreadTransferDstScalarPerVector_K,
|
||||
activ_type>{};
|
||||
|
||||
std::cerr << "conv_bias_activ_input_"
|
||||
<< "n" << N << "c" << C0 << "h" << Hi << "w" << Wi << "c" << C1 << "_filter_k" << K
|
||||
<< "c" << C0 << "y" << Y << "x" << X << "c" << C1 << "_convout_n" << N << "k" << K0
|
||||
<< "h" << Ho << "w" << Wo << "k" << K1 << std::endl;
|
||||
|
||||
for(int i = 0; i < 5; i++)
|
||||
{
|
||||
|
||||
const auto ave_time =
|
||||
conv_driver.Run(wei_k_c0_y_x_c1_desc,
|
||||
in_n_c0_hi_wi_c1_desc,
|
||||
out_n_k0_ho_wo_k1_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(bias_k0_k1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()),
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C0 * C1 * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
|
||||
}
|
||||
@@ -1,241 +0,0 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_contraction_dlops_v1r2.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
const auto in_desc_n_c_hi_wi = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
|
||||
const auto wei_desc_k_c_y_x = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
|
||||
const auto out_desc_n_k_ho_wo = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
|
||||
|
||||
#if 1
|
||||
// [8, 1, 128, 1] * [8, 4, 32, 1] = [1, 128, 4, 32] for fp32
|
||||
// cdata = 64, BlockSize = 256
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GN0 = 4;
|
||||
constexpr index_t GK1 = 1;
|
||||
|
||||
constexpr index_t GM1PerBlockGM11 = 128;
|
||||
constexpr index_t GN1PerBlockGN11 = 32;
|
||||
constexpr index_t GK0PerBlock = 8;
|
||||
|
||||
constexpr index_t BM1PerThreadBM11 = 4;
|
||||
constexpr index_t BN1PerThreadBN11 = 4;
|
||||
constexpr index_t BK0PerThread = 1;
|
||||
|
||||
using BM10BN10ThreadClusterBM10Xs = Sequence<8, 2>;
|
||||
using BM10BN10ThreadClusterBN10Xs = Sequence<8, 2>;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
|
||||
using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>;
|
||||
|
||||
using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
|
||||
using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 1>;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 1>;
|
||||
using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>;
|
||||
|
||||
using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
|
||||
using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1;
|
||||
#elif 1
|
||||
// [8, 1, 128, 2] * [8, 4, 32, 2] = [1, 128, 4, 32] for fp16
|
||||
// cdata = 64, BlockSize = 256
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GN0 = 4;
|
||||
constexpr index_t GK1 = 2;
|
||||
|
||||
constexpr index_t GM1PerBlockGM11 = 128;
|
||||
constexpr index_t GN1PerBlockGN11 = 32;
|
||||
constexpr index_t GK0PerBlock = 8;
|
||||
|
||||
constexpr index_t BM1PerThreadBM11 = 4;
|
||||
constexpr index_t BN1PerThreadBN11 = 4;
|
||||
constexpr index_t BK0PerThread = 1;
|
||||
|
||||
using BM10BN10ThreadClusterBM10Xs = Sequence<8, 2>;
|
||||
using BM10BN10ThreadClusterBN10Xs = Sequence<8, 2>;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 2>;
|
||||
using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>;
|
||||
|
||||
using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
|
||||
using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 2>;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 2>;
|
||||
using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>;
|
||||
|
||||
using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
|
||||
using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 2>;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_desc_k_c_y_x,
|
||||
in_desc_n_c_hi_wi,
|
||||
out_desc_n_k_ho_wo,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GN0>{},
|
||||
Number<GK1>{});
|
||||
|
||||
const auto wei_grid_desc_gk0_gm0_gm1_gk1 = descs[I0];
|
||||
const auto in_grid_desc_gk0_gn0_gn1_gk1 = descs[I1];
|
||||
const auto out_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1
|
||||
|
||||
constexpr auto in_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1
|
||||
|
||||
constexpr auto out_grid_step_hacks = make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 5-: GN1
|
||||
|
||||
constexpr auto wei_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_contraction_dlops_v1r2<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
decltype(wei_grid_desc_gk0_gm0_gm1_gk1),
|
||||
decltype(in_grid_desc_gk0_gn0_gn1_gk1),
|
||||
decltype(out_grid_desc_gm0_gm1_gn0_gn1),
|
||||
GM1PerBlockGM11,
|
||||
GN1PerBlockGN11,
|
||||
GK0PerBlock,
|
||||
BM1PerThreadBM11,
|
||||
BN1PerThreadBN11,
|
||||
BK0PerThread,
|
||||
BM10BN10ThreadClusterBM10Xs,
|
||||
BM10BN10ThreadClusterBN10Xs,
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
Sequence<1, 2, 3, 0, 4>, // ABlockTransferThreadClusterArrangeOrder
|
||||
Sequence<3, 2, 1, 0, 4>, // ABlockTransferSrcAccessOrder
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
Sequence<0, 1, 2, 3, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
Sequence<0, 4, 1, 2, 3>, // BBlockTransferThreadClusterArrangeOrder
|
||||
Sequence<4, 3, 2, 0, 1>, // BBlockTransferSrcAccessOrder
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
Sequence<0, 1, 2, 3, 4>, // BBlockTransferSrcVectorTensorContiguousDimOrder
|
||||
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
|
||||
5, // CThreadTransferSrcDstVectorDim
|
||||
CThreadTransferDstScalarPerVector_BN1,
|
||||
decltype(wei_grid_step_hacks),
|
||||
decltype(in_grid_step_hacks),
|
||||
decltype(out_grid_step_hacks),
|
||||
decltype(wei_grid_move_slice_window_step_hacks),
|
||||
decltype(in_grid_move_slice_window_step_hacks)>(
|
||||
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
wei_grid_desc_gk0_gm0_gm1_gk1,
|
||||
in_grid_desc_gk0_gn0_gn1_gk1,
|
||||
out_grid_desc_gm0_gm1_gn0_gn1,
|
||||
wei_grid_step_hacks,
|
||||
in_grid_step_hacks,
|
||||
out_grid_step_hacks,
|
||||
wei_grid_move_slice_window_step_hacks,
|
||||
in_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>(calculate_convolution_flops(
|
||||
in_desc_n_c_hi_wi, wei_desc_k_c_y_x, out_desc_n_k_ho_wo)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
||||
}
|
||||
@@ -1,212 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
ck::ActivTypeEnum activ_type,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename MaxLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1(
|
||||
const InLengths& in_n_c0_hi_wi_c1_lengths,
|
||||
const WeiLengths& wei_k_c0_y_x_c1_lengths,
|
||||
const MaxLengths& max_n_k0_hx_wx_k1_lengths,
|
||||
const OutLengths& out_n_k0_ho_wo_k1_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c0_hi_wi_c1,
|
||||
const Tensor<TInWei>& wei_k_c0_y_x_c1,
|
||||
const Tensor<TOut>& bias_k0_k1,
|
||||
Tensor<TOut>& out_n_k0_ho_wo_k1,
|
||||
Tensor<TOut>& max_n_k0_hx_wx_k1,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
const auto N = out_n_k0_ho_wo_k1_lengths[I0];
|
||||
const auto K0 = out_n_k0_ho_wo_k1_lengths[I1];
|
||||
const auto Ho = out_n_k0_ho_wo_k1_lengths[I2];
|
||||
const auto Wo = out_n_k0_ho_wo_k1_lengths[I3];
|
||||
const auto K1 = out_n_k0_ho_wo_k1_lengths[I4];
|
||||
|
||||
const auto C0 = in_n_c0_hi_wi_c1_lengths[I1];
|
||||
const auto Hi = in_n_c0_hi_wi_c1_lengths[I2];
|
||||
const auto Wi = in_n_c0_hi_wi_c1_lengths[I3];
|
||||
const auto C1 = in_n_c0_hi_wi_c1_lengths[I4];
|
||||
|
||||
const auto K = wei_k_c0_y_x_c1_lengths[I0];
|
||||
const auto Y = wei_k_c0_y_x_c1_lengths[I2];
|
||||
const auto X = wei_k_c0_y_x_c1_lengths[I3];
|
||||
|
||||
const auto Hx = max_n_k0_hx_wx_k1_lengths[I2];
|
||||
const auto Wx = max_n_k0_hx_wx_k1_lengths[I3];
|
||||
|
||||
DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) *
|
||||
in_n_c0_hi_wi_c1.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace());
|
||||
DeviceMem bias_k0_k1_device_buf(sizeof(TOut) * bias_k0_k1.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) *
|
||||
out_n_k0_ho_wo_k1.mDesc.GetElementSpace());
|
||||
DeviceMem max_n_k0_hx_wx_k1_device_buf(sizeof(TOut) *
|
||||
max_n_k0_hx_wx_k1.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
|
||||
wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
|
||||
bias_k0_k1_device_buf.ToDevice(bias_k0_k1.mData.data());
|
||||
max_n_k0_hx_wx_k1_device_buf.ToDevice(max_n_k0_hx_wx_k1.mData.data());
|
||||
|
||||
constexpr index_t InWeiVectorSize = 8;
|
||||
|
||||
if(C1 % InWeiVectorSize != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize");
|
||||
}
|
||||
|
||||
#if 0
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t KPerBlock = 32;
|
||||
constexpr index_t HoPerBlock = 8;
|
||||
constexpr index_t WoPerBlock = 64;
|
||||
|
||||
constexpr index_t E1 = C0 * 9;
|
||||
constexpr index_t E2 = 1;
|
||||
constexpr index_t E1PerBlock = C0;
|
||||
|
||||
constexpr index_t KPerThread = 16;
|
||||
constexpr index_t HoPerThread = 2;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
constexpr index_t EPerThread = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, E2>;
|
||||
using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, KPerBlock, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;
|
||||
|
||||
constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_K = K1;
|
||||
#elif 1
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t KPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 8;
|
||||
constexpr index_t WoPerBlock = 32;
|
||||
|
||||
constexpr index_t E1 = 2 * 9;
|
||||
constexpr index_t E2 = 1;
|
||||
constexpr index_t K2 = 2;
|
||||
constexpr index_t E1PerBlock = 2;
|
||||
|
||||
constexpr index_t KPerThread = KPerBlock;
|
||||
constexpr index_t HoPerThread = 2;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
constexpr index_t EPerThread = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, 1, E2>;
|
||||
using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 =
|
||||
Sequence<1, E1PerBlock, 1, KPerBlock, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;
|
||||
constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_K = InWeiVectorSize;
|
||||
#endif
|
||||
|
||||
if(KPerThread % InWeiVectorSize != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize");
|
||||
}
|
||||
|
||||
const auto in_n_c0_hi_wi_c1_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2));
|
||||
const auto wei_k_c0_y_x_c1_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, E2));
|
||||
const auto max_n_k0_hx_wx_k1_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hx, Wx, K1));
|
||||
const auto out_n_k0_ho_wo_k1_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));
|
||||
|
||||
constexpr auto conv_driver =
|
||||
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_maxpool<
|
||||
BlockSize,
|
||||
typename vector_type<TInWei, InWeiVectorSize>::type,
|
||||
TAcc,
|
||||
TOut,
|
||||
E1,
|
||||
E2,
|
||||
K2,
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
E1PerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
|
||||
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
|
||||
ABlockTransferSrcScalarPerVector_E2,
|
||||
ABlockTransferDstScalarPerVector_E2,
|
||||
BThreadTransferSrcScalarPerVector_E2,
|
||||
CThreadTransferDstScalarPerVector_K,
|
||||
activ_type>{};
|
||||
|
||||
std::cerr << "conv_bias_activ_maxpool_input_"
|
||||
<< "n" << N << "c" << C0 << "h" << Hi << "w" << Wi << "c" << C1 << "_filter_k" << K
|
||||
<< "c" << C0 << "y" << Y << "x" << X << "c" << C1 << "_convout_n" << N << "k" << K0
|
||||
<< "h" << Ho << "w" << Wo << "k" << K1 << "_maxpoolout_n" << N << "k" << K0 << "h"
|
||||
<< Ho / 2 << "w" << Wo / 2 << "k" << K1 << std::endl;
|
||||
|
||||
for(int i = 0; i < 5; i++)
|
||||
{
|
||||
|
||||
const auto ave_time =
|
||||
conv_driver.Run(wei_k_c0_y_x_c1_desc,
|
||||
in_n_c0_hi_wi_c1_desc,
|
||||
out_n_k0_ho_wo_k1_desc,
|
||||
max_n_k0_hx_wx_k1_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(bias_k0_k1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(max_n_k0_hx_wx_k1_device_buf.GetDeviceBuffer()),
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C0 * C1 * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
|
||||
max_n_k0_hx_wx_k1_device_buf.FromDevice(max_n_k0_hx_wx_k1.mData.data());
|
||||
}
|
||||
@@ -1,463 +0,0 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_km_kn_mn(const Tensor<ABType>& a_k_m,
|
||||
const Tensor<ABType>& b_k_n,
|
||||
Tensor<CType>& c_m_n,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());
|
||||
|
||||
a_k_m_device_buf.ToDevice(a_k_m.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
c_m_n_device_buf.ToDevice(c_m_n.mData.data());
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4], C = 128, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 1;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 1;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 1;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 1;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto K = a_k_m.mDesc.GetLengths()[0];
|
||||
const auto M = a_k_m.mDesc.GetLengths()[1];
|
||||
const auto N = b_k_n.mDesc.GetLengths()[1];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0],
|
||||
a_k_m.mDesc.GetStrides()[1],
|
||||
a_k_m.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0],
|
||||
b_k_n.mDesc.GetStrides()[1],
|
||||
b_k_n.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops_v2r3<BlockSize,
|
||||
ABType,
|
||||
AccType,
|
||||
CType,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(c_m_n_grid_desc),
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
ABlockTransferSrcScalarPerVector_M,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
BBlockTransferSrcScalarPerVector_N,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
|
||||
7,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_step_hacks),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
true, // ABlockLdsExtraM
|
||||
true // BBlockLdsExtraN
|
||||
>(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hacks,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
c_m_n_device_buf.FromDevice(c_m_n.mData.data());
|
||||
}
|
||||
@@ -1,263 +0,0 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_km_kn_nm(const Tensor<ABType>& a_k_m,
|
||||
const Tensor<ABType>& b_k_n,
|
||||
Tensor<CType>& c_n_m,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace());
|
||||
|
||||
a_k_m_device_buf.ToDevice(a_k_m.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
c_n_m_device_buf.ToDevice(c_n_m.mData.data());
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto K = a_k_m.mDesc.GetLengths()[0];
|
||||
const auto M = a_k_m.mDesc.GetLengths()[1];
|
||||
const auto N = b_k_n.mDesc.GetLengths()[1];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0],
|
||||
a_k_m.mDesc.GetStrides()[1],
|
||||
a_k_m.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0],
|
||||
b_k_n.mDesc.GetStrides()[1],
|
||||
b_k_n.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops_v2r3<BlockSize,
|
||||
ABType,
|
||||
AccType,
|
||||
CType,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(c_m_n_grid_desc),
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
ABlockTransferSrcScalarPerVector_M,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
BBlockTransferSrcScalarPerVector_N,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
6,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_step_hacks),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hacks,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
c_n_m_device_buf.FromDevice(c_n_m.mData.data());
|
||||
}
|
||||
@@ -1,463 +0,0 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_km_nk_mn(const Tensor<ABType>& a_k_m,
|
||||
const Tensor<ABType>& b_n_k,
|
||||
Tensor<CType>& c_m_n,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace());
|
||||
DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());
|
||||
|
||||
a_k_m_device_buf.ToDevice(a_k_m.mData.data());
|
||||
b_n_k_device_buf.ToDevice(b_n_k.mData.data());
|
||||
c_m_n_device_buf.ToDevice(c_m_n.mData.data());
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 1;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 1;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto K = a_k_m.mDesc.GetLengths()[0];
|
||||
const auto M = a_k_m.mDesc.GetLengths()[1];
|
||||
const auto N = b_n_k.mDesc.GetLengths()[0];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0],
|
||||
a_k_m.mDesc.GetStrides()[1],
|
||||
a_k_m.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
|
||||
b_n_k.mDesc.GetStrides()[0],
|
||||
b_n_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops_v2r3<BlockSize,
|
||||
ABType,
|
||||
AccType,
|
||||
CType,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(c_m_n_grid_desc),
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
ABlockTransferSrcScalarPerVector_M,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector_K1,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
|
||||
7,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_step_hacks),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
true, // ABlockLdsExtraM
|
||||
true // BBlockLdsExtraN
|
||||
>(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hacks,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
c_m_n_device_buf.FromDevice(c_m_n.mData.data());
|
||||
}
|
||||
@@ -1,263 +0,0 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_km_nk_nm(const Tensor<ABType>& a_k_m,
|
||||
const Tensor<ABType>& b_n_k,
|
||||
Tensor<CType>& c_n_m,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace());
|
||||
DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
|
||||
DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace());
|
||||
|
||||
a_k_m_device_buf.ToDevice(a_k_m.mData.data());
|
||||
b_n_k_device_buf.ToDevice(b_n_k.mData.data());
|
||||
c_n_m_device_buf.ToDevice(c_n_m.mData.data());
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto K = a_k_m.mDesc.GetLengths()[0];
|
||||
const auto M = a_k_m.mDesc.GetLengths()[1];
|
||||
const auto N = b_n_k.mDesc.GetLengths()[0];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0],
|
||||
a_k_m.mDesc.GetStrides()[1],
|
||||
a_k_m.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
|
||||
b_n_k.mDesc.GetStrides()[0],
|
||||
b_n_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops_v2r3<BlockSize,
|
||||
ABType,
|
||||
AccType,
|
||||
CType,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(c_m_n_grid_desc),
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
ABlockTransferSrcScalarPerVector_M,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector_K1,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
6,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_step_hacks),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hacks,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
c_n_m_device_buf.FromDevice(c_n_m.mData.data());
|
||||
}
|
||||
@@ -1,463 +0,0 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
|
||||
const Tensor<ABType>& b_k_n,
|
||||
Tensor<CType>& c_m_n,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
c_m_n_device_buf.ToDevice(c_m_n.mData.data());
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 1;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 1;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto K = a_m_k.mDesc.GetLengths()[1];
|
||||
const auto M = a_m_k.mDesc.GetLengths()[0];
|
||||
const auto N = b_k_n.mDesc.GetLengths()[1];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
|
||||
a_m_k.mDesc.GetStrides()[0],
|
||||
a_m_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0],
|
||||
b_k_n.mDesc.GetStrides()[1],
|
||||
b_k_n.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops_v2r3<BlockSize,
|
||||
ABType,
|
||||
AccType,
|
||||
CType,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(c_m_n_grid_desc),
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector_K1,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
BBlockTransferSrcScalarPerVector_N,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
|
||||
7,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_step_hacks),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
true, // ABlockLdsExtraM
|
||||
true // BBlockLdsExtraN
|
||||
>(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hacks,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
c_m_n_device_buf.FromDevice(c_m_n.mData.data());
|
||||
}
|
||||
@@ -1,291 +0,0 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_mk_kn_nm(const Tensor<ABType>& a_m_k,
|
||||
const Tensor<ABType>& b_k_n,
|
||||
Tensor<CType>& c_n_m,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
c_n_m_device_buf.ToDevice(c_n_m.mData.data());
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto K = a_m_k.mDesc.GetLengths()[1];
|
||||
const auto M = a_m_k.mDesc.GetLengths()[0];
|
||||
const auto N = b_k_n.mDesc.GetLengths()[1];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
|
||||
a_m_k.mDesc.GetStrides()[0],
|
||||
a_m_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0],
|
||||
b_k_n.mDesc.GetStrides()[1],
|
||||
b_k_n.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops_v2r3<BlockSize,
|
||||
ABType,
|
||||
AccType,
|
||||
CType,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(c_m_n_grid_desc),
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector_K1,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
BBlockTransferSrcScalarPerVector_N,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
6,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_step_hacks),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hacks,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
c_n_m_device_buf.FromDevice(c_n_m.mData.data());
|
||||
}
|
||||
@@ -1,564 +0,0 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_mk_nk_mn(const Tensor<ABType>& a_m_k,
|
||||
const Tensor<ABType>& b_n_k,
|
||||
Tensor<CType>& c_m_n,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_n_k_device_buf.ToDevice(b_n_k.mData.data());
|
||||
c_m_n_device_buf.ToDevice(c_m_n.mData.data());
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [64, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto K = a_m_k.mDesc.GetLengths()[1];
|
||||
const auto M = a_m_k.mDesc.GetLengths()[0];
|
||||
const auto N = b_n_k.mDesc.GetLengths()[0];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
#if 1
|
||||
// non-padded GEMM
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
|
||||
a_m_k.mDesc.GetStrides()[0],
|
||||
a_m_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
|
||||
b_n_k.mDesc.GetStrides()[0],
|
||||
b_n_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
#else
|
||||
// padded GEMM
|
||||
const auto a_k0_m_k1_grid_desc_tmp =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
|
||||
a_m_k.mDesc.GetStrides()[0],
|
||||
a_m_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto MRightPad = math::integer_divide_ceil(M, MPerBlock) * MPerBlock - M;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
transform_tensor_descriptor(a_k0_m_k1_grid_desc_tmp,
|
||||
make_tuple(make_pass_through_transform(K0),
|
||||
make_right_pad_transform(M, MRightPad),
|
||||
make_pass_through_transform(K1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
|
||||
b_n_k.mDesc.GetStrides()[0],
|
||||
b_n_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto c_m_n_grid_desc_tmp = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto c_m_n_grid_desc = transform_tensor_descriptor(
|
||||
c_m_n_grid_desc_tmp,
|
||||
make_tuple(make_right_pad_transform(M, MRightPad), make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0>{}, // 0+: K0
|
||||
Sequence<0, 0, 0, 0>{}, // 1+: M
|
||||
Sequence<0, 0, 0, 0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0, 0, 0, 0>{}, // 0-: K0
|
||||
Sequence<0, 0, 0, 0>{}, // 1-: M
|
||||
Sequence<0, 0, 0, 0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
#endif
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops_v2r3<BlockSize,
|
||||
ABType,
|
||||
AccType,
|
||||
CType,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(c_m_n_grid_desc),
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector_K1,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector_K1,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
|
||||
7,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_step_hacks),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
true, // ABlockLdsExtraM
|
||||
true // BBlockLdsExtraN
|
||||
>(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hacks,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
c_m_n_device_buf.FromDevice(c_m_n.mData.data());
|
||||
}
|
||||
@@ -1,347 +0,0 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
// Host-side launcher for an XDLOPS (MFMA) GEMM with host layouts A = [M, K],
// B = [N, K] and C = [N, M] (the output tensor is the transpose of the M-by-N
// result the kernel computes).  It copies the host tensors to device memory,
// builds the [K0, M/N, K1]-split grid descriptors the grid-wise kernel
// expects, runs the kernel in a 5-iteration measurement loop (each iteration
// itself timed over `nrepeat` launches by the driver), prints the achieved
// TFlop/s, and copies the result back into c_n_m.
//
// ABType  : element type of A and B (e.g. fp16/fp32 depending on the active
//           tuning section below)
// AccType : accumulator type used by the MFMA instructions
// CType   : element type of C
// nrepeat : launch-repeat count forwarded to driver_gemm_xdlops_v2r3 for timing
//
// NOTE(review): exactly one of the #if/#elif tuning sections below must be
// active (currently the last, [64, 128, 4, 8] fp16); each section hard-codes a
// block tile shape and the A/B LDS-copy thread layouts for one data type.
template <typename ABType, typename AccType, typename CType>
void device_gemm_xdlops_mk_nk_nm(const Tensor<ABType>& a_m_k,
                                 const Tensor<ABType>& b_n_k,
                                 Tensor<CType>& c_n_m,
                                 ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    // Allocate device buffers sized by element space (covers padded strides)
    // and upload all three tensors; C is uploaded too so untouched elements
    // keep their host values.
    DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
    DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
    DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace());

    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
    b_n_k_device_buf.ToDevice(b_n_k.mData.data());
    c_n_m_device_buf.ToDevice(c_n_m.mData.data());

// Compile-time tuning parameters.  Naming convention:
//   [M, N, K0, K1] = per-block tile (K = K0 * K1), C = threads per block.
// The *ThreadSliceLengths / *ThreadClusterLengths sequences describe how the
// block of threads cooperatively copies an A (K0 x M x K1) or B (K0 x N x K1)
// tile into LDS; *ScalarPerVector_K1 is the vector width of that copy along K1.
#if 0
    // [M, N, K0, K1] = [256, 128, 4, 4] for fp32
    constexpr index_t BlockSize = 256;

    constexpr index_t MPerBlock = 256;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;

    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 4;

    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;

    using ABlockTransferThreadSliceLengths_K0_M_K1   = Sequence<1, 4, 4>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;

    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;

    using BBlockTransferThreadSliceLengths_K0_N_K1   = Sequence<1, 2, 4>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;

    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;

    constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 0
    // [M, N, K0, K1] = [128, 256, 4, 4] for fp32
    constexpr index_t BlockSize = 256;

    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 256;
    constexpr index_t KPerBlock = 4;

    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 4;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 4;

    using ABlockTransferThreadSliceLengths_K0_M_K1   = Sequence<1, 2, 4>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;

    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;

    using BBlockTransferThreadSliceLengths_K0_N_K1   = Sequence<1, 4, 4>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;

    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;

    constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 0
    // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t MPerBlock = 256;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;

    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;

    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;

    using ABlockTransferThreadSliceLengths_K0_M_K1   = Sequence<1, 4, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;

    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;

    using BBlockTransferThreadSliceLengths_K0_N_K1   = Sequence<1, 2, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;

    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;

    constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 0
    // [M, N, K0, K1] = [128, 256, 4, 8] for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 256;
    constexpr index_t KPerBlock = 4;

    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 4;

    using ABlockTransferThreadSliceLengths_K0_M_K1   = Sequence<1, 2, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;

    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;

    using BBlockTransferThreadSliceLengths_K0_N_K1   = Sequence<1, 4, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;

    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;

    constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 0
    // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
    constexpr index_t BlockSize = 128;

    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;

    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;

    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;

    using ABlockTransferThreadSliceLengths_K0_M_K1   = Sequence<1, 4, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;

    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;

    using BBlockTransferThreadSliceLengths_K0_N_K1   = Sequence<1, 4, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;

    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;

    constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 0
    // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t MPerBlock = 128;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;

    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;

    using ABlockTransferThreadSliceLengths_K0_M_K1   = Sequence<1, 2, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;

    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;

    using BBlockTransferThreadSliceLengths_K0_N_K1   = Sequence<1, 2, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;

    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;

    constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif 1
    // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t MPerBlock = 64;
    constexpr index_t NPerBlock = 128;
    constexpr index_t KPerBlock = 4;

    constexpr index_t MPerXDL = 32;
    constexpr index_t NPerXDL = 32;
    constexpr index_t K1 = 8;

    constexpr index_t MRepeat = 1;
    constexpr index_t NRepeat = 2;

    using ABlockTransferThreadSliceLengths_K0_M_K1   = Sequence<1, 1, 8>;
    using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;

    constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;

    using BBlockTransferThreadSliceLengths_K0_N_K1   = Sequence<1, 2, 8>;
    using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;

    constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
    constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;

    constexpr index_t CThreadTransferDstScalarPerVector = 4;
#endif

    // Problem sizes from the host descriptors: A is [M, K], B is [N, K].
    const auto K = a_m_k.mDesc.GetLengths()[1];
    const auto M = a_m_k.mDesc.GetLengths()[0];
    const auto N = b_n_k.mDesc.GetLengths()[0];

    // Split K into K0 * K1.  NOTE(review): assumes K is divisible by K1 —
    // no padding path exists in this function; confirm callers guarantee it.
    constexpr auto K1Number = Number<K1>{};
    const auto K0 = K / K1Number;

    // View A as a [K0, M, K1] tensor over the same memory: stepping K0 jumps
    // K1 elements along K, stepping K1 steps one element along K.
    const auto a_k0_m_k1_grid_desc =
        make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
                                     make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
                                                a_m_k.mDesc.GetStrides()[0],
                                                a_m_k.mDesc.GetStrides()[1]));

    // Same [K0, N, K1] view over B.
    const auto b_k0_n_k1_grid_desc =
        make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
                                     make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
                                                b_n_k.mDesc.GetStrides()[0],
                                                b_n_k.mDesc.GetStrides()[1]));

    // The kernel computes an [M, N] output; the descriptor swaps the host
    // [N, M] tensor's strides so writes land in the transposed c_n_m buffer.
    const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
        make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));

    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    // (all-zero sequences = no index-calculation shortcuts; plain descriptors
    // need none since there are no padding/embed transforms here).
    constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
                                                                     Sequence<0>{},   // 1+: M
                                                                     Sequence<0>{}),  // 2+: K1
                                                          make_tuple(Sequence<0>{},   // 0-: K0
                                                                     Sequence<0>{},   // 1-: M
                                                                     Sequence<0>{})); // 2-: K1

    constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
                                                                     Sequence<0>{},   // 1+: N
                                                                     Sequence<0>{}),  // 2+: K1
                                                          make_tuple(Sequence<0>{},   // 0-: K0
                                                                     Sequence<0>{},   // 1-: N
                                                                     Sequence<0>{})); // 2-: K1

    constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};

    constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};

    // Five independent timing measurements; each call internally launches the
    // kernel `nrepeat` times and returns the average kernel time in ms.
    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time =
            driver_gemm_xdlops_v2r3<BlockSize,
                                    ABType,
                                    AccType,
                                    CType,
                                    InMemoryDataOperationEnum::Set,
                                    decltype(a_k0_m_k1_grid_desc),
                                    decltype(b_k0_n_k1_grid_desc),
                                    decltype(c_m_n_grid_desc),
                                    MPerBlock,
                                    NPerBlock,
                                    KPerBlock,
                                    MPerXDL,
                                    NPerXDL,
                                    K1,
                                    MRepeat,
                                    NRepeat,
                                    ABlockTransferThreadSliceLengths_K0_M_K1,
                                    ABlockTransferThreadClusterLengths_K0_M_K1,
                                    Sequence<1, 0, 2>,
                                    Sequence<1, 0, 2>,
                                    2,
                                    ABlockTransferSrcScalarPerVector_K1,
                                    ABlockTransferDstScalarPerVector_K1,
                                    false, // don't move back src coordinate after threadwise copy
                                    BBlockTransferThreadSliceLengths_K0_N_K1,
                                    BBlockTransferThreadClusterLengths_K0_N_K1,
                                    Sequence<1, 0, 2>,
                                    Sequence<1, 0, 2>,
                                    2,
                                    BBlockTransferSrcScalarPerVector_K1,
                                    BBlockTransferDstScalarPerVector_K1,
                                    false, // don't move back src coordinate after threadwise copy
                                    // C thread-copy access order and the
                                    // dimension (6) it vectorizes along;
                                    // must stay consistent with the kernel's
                                    // M0_N0_M1_N1_M2_M3_M4_N2 output layout.
                                    Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
                                    6,
                                    CThreadTransferDstScalarPerVector,
                                    decltype(a_k0_m_k1_grid_step_hacks),
                                    decltype(b_k0_n_k1_grid_step_hacks),
                                    decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
                                    decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
                                    decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
                                    false // CAccessOrderMRepeatNRepeat
                                    >(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
                                      static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
                                      static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
                                      a_k0_m_k1_grid_desc,
                                      b_k0_n_k1_grid_desc,
                                      c_m_n_grid_desc,
                                      a_k0_m_k1_grid_step_hacks,
                                      b_k0_n_k1_grid_step_hacks,
                                      c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                                      a_k0_m_k1_grid_move_slice_window_step_hacks,
                                      b_k0_n_k1_grid_move_slice_window_step_hacks,
                                      nrepeat);

        // 2*M*N*K flops; dividing GFlop by milliseconds yields TFlop/s.
        float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
    }

    // copy result back to host
    c_n_m_device_buf.FromDevice(c_n_m.mData.data());
}
|
||||
@@ -1,286 +0,0 @@
|
||||
#ifndef DRIVER_CONTRACTION_DLOPS_V1R2_HPP
|
||||
#define DRIVER_CONTRACTION_DLOPS_V1R2_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_contraction_dlops_v1r2.hpp"
|
||||
|
||||
template <ck::index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
ck::InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
typename AGridDesc_GK0_GM0_GM1_GK1,
|
||||
typename BGridDesc_GK0_GN0_GN1_GK1,
|
||||
typename CGridDesc_GM0_GM1_GN0_GN1,
|
||||
ck::index_t GM1PerBlockGM11,
|
||||
ck::index_t GN1PerBlockGN11,
|
||||
ck::index_t GK0PerBlock,
|
||||
ck::index_t BM1PerThreadBM11,
|
||||
ck::index_t BN1PerThreadBN11,
|
||||
ck::index_t BK0PerThread,
|
||||
typename BM10BN10ThreadClusterBM10Xs,
|
||||
typename BM10BN10ThreadClusterBN10Xs,
|
||||
typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
typename ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
typename BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
typename BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
typename BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
typename BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridStepHacks,
|
||||
typename BGridStepHacks,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks>
|
||||
__host__ float
|
||||
driver_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
||||
const FloatAB* p_b_grid,
|
||||
FloatC* p_c_grid,
|
||||
const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
|
||||
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
|
||||
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
ck::index_t nrepeat)
|
||||
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
// GEMM
|
||||
using GridwiseContraction =
|
||||
GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
CGlobalMemoryDataOperation,
|
||||
AGridDesc_GK0_GM0_GM1_GK1,
|
||||
BGridDesc_GK0_GN0_GN1_GK1,
|
||||
CGridDesc_GM0_GM1_GN0_GN1,
|
||||
GM1PerBlockGM11,
|
||||
GN1PerBlockGN11,
|
||||
GK0PerBlock,
|
||||
BM1PerThreadBM11,
|
||||
BN1PerThreadBN11,
|
||||
BK0PerThread,
|
||||
BM10BN10ThreadClusterBM10Xs,
|
||||
BM10BN10ThreadClusterBN10Xs,
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks>;
|
||||
|
||||
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
|
||||
|
||||
if(!GridwiseContraction::CheckValidity(
|
||||
a_grid_desc_gk0_gm0_gm1_gk1, b_grid_desc_gk0_gn0_gn1_gk1, c_grid_desc_gm0_gm1_gn0_gn1))
|
||||
{
|
||||
throw std::runtime_error("wrong! "
|
||||
"GridwiseContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_"
|
||||
"GM0_GM1_GN0_GN1 has invalid setting");
|
||||
}
|
||||
|
||||
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 =
|
||||
GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(a_grid_desc_gk0_gm0_gm1_gk1);
|
||||
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 =
|
||||
GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(b_grid_desc_gk0_gn0_gn1_gk1);
|
||||
|
||||
using AGridDesc_GK0_GM0_GM10_GM11_GK1 = decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1);
|
||||
using BGridDesc_GK0_GN0_GN10_GN11_GK1 = decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1);
|
||||
|
||||
// c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1
|
||||
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
|
||||
GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1);
|
||||
|
||||
using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1);
|
||||
|
||||
// c_grid_block_cluster_blockid_to_gm10_gn10
|
||||
const auto c_grid_block_cluster_blockid_to_gm10_gn10 =
|
||||
GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1);
|
||||
|
||||
using CGridBlockCluster_BlockId_To_GM10_GN10 =
|
||||
decltype(c_grid_block_cluster_blockid_to_gm10_gn10);
|
||||
|
||||
const index_t grid_size = GridwiseContraction::CalculateGridSize(c_grid_desc_gm0_gm1_gn0_gn1);
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(GK0);
|
||||
|
||||
const bool has_double_tail_k_block_loop =
|
||||
GridwiseContraction::CalculateHasDoubleTailKBlockLoop(GK0);
|
||||
|
||||
{
|
||||
std::cout << "a_grid_desc_gk0_gm0_gm10_gm11_gk1{"
|
||||
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0) << ", "
|
||||
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I1) << ", "
|
||||
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I2) << ", "
|
||||
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I3) << ", "
|
||||
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I4) << "}" << std::endl;
|
||||
|
||||
std::cout << "b_grid_desc_gk0_gn0_gn10_gn11_gk1{"
|
||||
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I0) << ", "
|
||||
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I1) << ", "
|
||||
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I2) << ", "
|
||||
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I3) << ", "
|
||||
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I4) << "}" << std::endl;
|
||||
|
||||
std::cout << "c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1{ "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I0) << ", "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I1) << ", "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I2) << ", "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I3) << ", "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I4) << ", "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I5) << "}" << std::endl;
|
||||
}
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_contraction_dlops_v1r2<
|
||||
GridwiseContraction,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
|
||||
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
|
||||
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
|
||||
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
|
||||
true,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10);
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_contraction_dlops_v1r2<
|
||||
GridwiseContraction,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
|
||||
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
|
||||
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
|
||||
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
|
||||
true,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10);
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_contraction_dlops_v1r2<
|
||||
GridwiseContraction,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
|
||||
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
|
||||
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
|
||||
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
|
||||
false,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_contraction_dlops_v1r2<
|
||||
GridwiseContraction,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
|
||||
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
|
||||
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
|
||||
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
|
||||
false,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
#endif
|
||||
@@ -1,429 +0,0 @@
|
||||
#ifndef DRIVER_CONVOLUTION_ADD_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP
|
||||
#define DRIVER_CONVOLUTION_ADD_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_dlops_v3.hpp"
|
||||
|
||||
template <ck::index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
ck::index_t E1_,
|
||||
ck::index_t E2_,
|
||||
ck::index_t K2_,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t HoPerBlock,
|
||||
ck::index_t WoPerBlock,
|
||||
ck::index_t E1PerBlock,
|
||||
ck::index_t KPerThread,
|
||||
ck::index_t HoPerThread,
|
||||
ck::index_t WoPerThread,
|
||||
ck::index_t EPerThread,
|
||||
typename ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
|
||||
typename ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector_E2,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_E2,
|
||||
ck::index_t BThreadTransferSrcScalarPerVector_E2,
|
||||
ck::index_t CThreadTransferDstScalarPerVector_K,
|
||||
ck::ActivTypeEnum activ_type>
|
||||
struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_add
|
||||
{
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Add,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ float Run(const ck::TensorDescriptor<Wei...>& wei_k_c0_y_x_c1_global_desc,
|
||||
const ck::TensorDescriptor<In...>& in_n_c0_hi_wi_c1_global_desc,
|
||||
const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
|
||||
const ck::TensorDescriptor<Add...>& add_n_k0_hox2_wox2_k1_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
const FloatC* __restrict__ p_bias_grid,
|
||||
FloatC* __restrict__ p_d_grid,
|
||||
const int nrepeat) const
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
const auto N = in_n_c0_hi_wi_c1_global_desc.GetLength(I0);
|
||||
const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1);
|
||||
const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3);
|
||||
// const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4);
|
||||
|
||||
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
|
||||
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
|
||||
const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
|
||||
|
||||
const auto Hox2 = add_n_k0_hox2_wox2_k1_global_desc.GetLength(I2);
|
||||
const auto Wox2 = add_n_k0_hox2_wox2_k1_global_desc.GetLength(I3);
|
||||
|
||||
const auto K = wei_k_c0_y_x_c1_global_desc.GetLength(I0);
|
||||
const auto Y = wei_k_c0_y_x_c1_global_desc.GetLength(I2);
|
||||
const auto X = wei_k_c0_y_x_c1_global_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
|
||||
const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
|
||||
|
||||
const auto OutRightPadH = Hop - Ho;
|
||||
const auto OutRightPadW = Wop - Wo;
|
||||
|
||||
const auto OutRightPadHx = OutRightPadH * 2;
|
||||
const auto OutRightPadWx = OutRightPadW * 2;
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH;
|
||||
const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW;
|
||||
|
||||
const auto E = C0 * Y * X;
|
||||
|
||||
constexpr auto E1 = Number<E1_>{};
|
||||
constexpr auto E2 = Number<E2_>{};
|
||||
constexpr auto K2 = Number<K2_>{};
|
||||
|
||||
const auto E0 = E / E1;
|
||||
|
||||
// weight tensor
|
||||
const auto a_e_k_e2_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C0 * Y * X, E2)),
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_pass_through_transform(C0 * Y * X),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}));
|
||||
|
||||
const auto a_e0_e1_k_e2_grid_desc =
|
||||
transform_tensor_descriptor(a_e_k_e2_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
|
||||
make_pass_through_transform(K),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c0_hip_wip_e2_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2)),
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C0),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
const auto in_n_c0_y_ho_x_wo_e2_global_desc = transform_tensor_descriptor(
|
||||
in_n_c0_hip_wip_e2_global_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(C0),
|
||||
make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6>{}));
|
||||
|
||||
const auto in_e_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
|
||||
in_n_c0_y_ho_x_wo_e2_global_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C0, Y, X)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Hop),
|
||||
make_pass_through_transform(Wop),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(
|
||||
Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
const auto b_e0_e1_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
|
||||
in_e_n_ho_wo_e2_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Hop),
|
||||
make_pass_through_transform(Wop),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));
|
||||
|
||||
// output tensor
|
||||
const auto c_k_n_hop_wop_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pad_transform(Ho, I0, OutRightPadH),
|
||||
make_pad_transform(Wo, I0, OutRightPadW)),
|
||||
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// add tensor
|
||||
const auto d_k_n_hopx2_wopx2_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hox2, Wox2, K1)),
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pad_transform(Hox2, I0, OutRightPadHx),
|
||||
make_pad_transform(Wox2, I0, OutRightPadWx)),
|
||||
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;
|
||||
|
||||
if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 &&
|
||||
(E1 % E1PerBlock) == 0))
|
||||
{
|
||||
throw std::runtime_error("wrong! GEMM size no divisible");
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
|
||||
// hack to control index calculation when iterating over a_e0_e1_k_e2_global tensor
|
||||
constexpr auto a_e0_e1_k_e2_global_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global tensor
|
||||
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks =
|
||||
make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})
|
||||
);
|
||||
|
||||
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over c_k0_k1_n_h0_h1_h2_w0_w1_w2_global tensor
|
||||
constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
// clang-format on
|
||||
|
||||
// GEMM
|
||||
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
decltype(a_e0_e1_k_e2_grid_desc),
|
||||
decltype(b_e0_e1_n_ho_wo_e2_grid_desc),
|
||||
decltype(c_k_n_hop_wop_grid_desc),
|
||||
decltype(d_k_n_hopx2_wopx2_grid_desc),
|
||||
E1,
|
||||
E2,
|
||||
K2,
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
E1PerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
|
||||
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
|
||||
Sequence<2, 3, 0, 1, 4>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
4,
|
||||
ABlockTransferSrcScalarPerVector_E2,
|
||||
ABlockTransferDstScalarPerVector_E2,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, // E0, E1, N, H0, H1, H2, W0, W1, W2, E2
|
||||
9,
|
||||
BThreadTransferSrcScalarPerVector_E2,
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused with
|
||||
// MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>, // K0, K1, N, H0, H1, I2, H2, W0, W1, I2, W2
|
||||
1,
|
||||
CThreadTransferDstScalarPerVector_K,
|
||||
decltype(a_e0_e1_k_e2_global_step_hacks),
|
||||
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks),
|
||||
decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
|
||||
decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks),
|
||||
decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
|
||||
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack)>;
|
||||
|
||||
const auto a_e0_e1_k0_k1_e2_grid_desc =
|
||||
GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc);
|
||||
const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
|
||||
GridwiseGemm::MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(b_e0_e1_n_ho_wo_e2_grid_desc);
|
||||
const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
|
||||
GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc);
|
||||
const auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc =
|
||||
GridwiseGemm::MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd(
|
||||
d_k_n_hopx2_wopx2_grid_desc);
|
||||
|
||||
using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc);
|
||||
using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
|
||||
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc);
|
||||
using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
|
||||
using DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2 =
|
||||
decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc);
|
||||
|
||||
const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
|
||||
|
||||
const bool has_main_e0_block_loop = E0 > 1;
|
||||
|
||||
std::cerr << "has_main_e0_block_loop = " << has_main_e0_block_loop << std::endl;
|
||||
|
||||
const auto cblockid_to_k_n_h_w_block_cluster_adaptor =
|
||||
GridwiseGemm::MakeCBlockIdToKNHoWoBlockClusterAdaptor(c_k_n_hop_wop_grid_desc);
|
||||
|
||||
using CBlockIdToBlockClusterAdaptor_K_N_H_W =
|
||||
decltype(cblockid_to_k_n_h_w_block_cluster_adaptor);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_e0_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_dlops_v3_resize_add<
|
||||
GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
|
||||
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
|
||||
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
|
||||
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2>,
|
||||
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
|
||||
true,
|
||||
activ_type>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_bias_grid,
|
||||
p_d_grid,
|
||||
a_e0_e1_k0_k1_e2_grid_desc,
|
||||
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
|
||||
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
|
||||
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
|
||||
cblockid_to_k_n_h_w_block_cluster_adaptor);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_dlops_v3_resize_add<
|
||||
GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
|
||||
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
|
||||
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
|
||||
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2>,
|
||||
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
|
||||
false,
|
||||
activ_type>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_bias_grid,
|
||||
p_d_grid,
|
||||
a_e0_e1_k0_k1_e2_grid_desc,
|
||||
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
|
||||
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
|
||||
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
|
||||
cblockid_to_k_n_h_w_block_cluster_adaptor);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
@@ -1,386 +0,0 @@
|
||||
#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP
|
||||
#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_dlops_v3.hpp"
|
||||
|
||||
template <ck::index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
ck::index_t E1_,
|
||||
ck::index_t E2_,
|
||||
ck::index_t K2_,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t HoPerBlock,
|
||||
ck::index_t WoPerBlock,
|
||||
ck::index_t E1PerBlock,
|
||||
ck::index_t KPerThread,
|
||||
ck::index_t HoPerThread,
|
||||
ck::index_t WoPerThread,
|
||||
ck::index_t EPerThread,
|
||||
typename ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
|
||||
typename ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector_E2,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_E2,
|
||||
ck::index_t BThreadTransferSrcScalarPerVector_E2,
|
||||
ck::index_t CThreadTransferDstScalarPerVector_K,
|
||||
ck::ActivTypeEnum activ_type>
|
||||
struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_outpad
|
||||
{
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ float Run(const ck::TensorDescriptor<Wei...>& wei_k_c0_y_x_c1_global_desc,
|
||||
const ck::TensorDescriptor<In...>& in_n_c0_hi_wi_c1_global_desc,
|
||||
const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
const FloatC* __restrict__ p_bias_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const int nrepeat) const
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
const auto N = in_n_c0_hi_wi_c1_global_desc.GetLength(I0);
|
||||
const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1);
|
||||
const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3);
|
||||
// const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4);
|
||||
|
||||
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
|
||||
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
|
||||
const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
|
||||
|
||||
const auto K = wei_k_c0_y_x_c1_global_desc.GetLength(I0);
|
||||
const auto Y = wei_k_c0_y_x_c1_global_desc.GetLength(I2);
|
||||
const auto X = wei_k_c0_y_x_c1_global_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR
|
||||
const auto Hop = Number<(Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock>{};
|
||||
const auto Wop = Number<(Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock>{};
|
||||
#else
|
||||
const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
|
||||
const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
|
||||
#endif
|
||||
|
||||
const auto OutRightPadH = Hop - Ho;
|
||||
const auto OutRightPadW = Wop - Wo;
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH;
|
||||
const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW;
|
||||
|
||||
const auto E = C0 * Y * X;
|
||||
|
||||
constexpr auto E1 = Number<E1_>{};
|
||||
constexpr auto E2 = Number<E2_>{};
|
||||
constexpr auto K2 = Number<K2_>{};
|
||||
|
||||
const auto E0 = E / E1;
|
||||
|
||||
// weight tensor
|
||||
const auto a_e_k_e2_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C0 * Y * X, E2)),
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_pass_through_transform(C0 * Y * X),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}));
|
||||
|
||||
const auto a_e0_e1_k_e2_grid_desc =
|
||||
transform_tensor_descriptor(a_e_k_e2_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
|
||||
make_pass_through_transform(K),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c0_hip_wip_e2_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2)),
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C0),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
const auto in_n_c0_y_ho_x_wo_e2_global_desc = transform_tensor_descriptor(
|
||||
in_n_c0_hip_wip_e2_global_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(C0),
|
||||
make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6>{}));
|
||||
|
||||
const auto in_e_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
|
||||
in_n_c0_y_ho_x_wo_e2_global_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C0, Y, X)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Hop),
|
||||
make_pass_through_transform(Wop),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(
|
||||
Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
const auto b_e0_e1_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
|
||||
in_e_n_ho_wo_e2_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Hop),
|
||||
make_pass_through_transform(Wop),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));
|
||||
|
||||
// output tensor
|
||||
const auto c_k_n_hop_wop_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pad_transform(Ho, I0, OutRightPadH),
|
||||
make_pad_transform(Wo, I0, OutRightPadW)),
|
||||
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;
|
||||
|
||||
if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 &&
|
||||
(E1 % E1PerBlock) == 0))
|
||||
{
|
||||
throw std::runtime_error("wrong! GEMM size no divisible");
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
|
||||
// hack to control index calculation when iterating over a_e0_e1_k_e2_global tensor
|
||||
constexpr auto a_e0_e1_k_e2_global_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global tensor
|
||||
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks =
|
||||
make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})
|
||||
);
|
||||
|
||||
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over c_k0_k1_n_h0_h1_h2_w0_w1_w2_global tensor
|
||||
constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
// clang-format on
|
||||
|
||||
// GEMM
|
||||
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
decltype(a_e0_e1_k_e2_grid_desc),
|
||||
decltype(b_e0_e1_n_ho_wo_e2_grid_desc),
|
||||
decltype(c_k_n_hop_wop_grid_desc),
|
||||
decltype(c_k_n_hop_wop_grid_desc),
|
||||
E1,
|
||||
E2,
|
||||
K2,
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
E1PerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
|
||||
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
|
||||
Sequence<2, 3, 0, 1, 4>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
4,
|
||||
ABlockTransferSrcScalarPerVector_E2,
|
||||
ABlockTransferDstScalarPerVector_E2,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, // E0, E1, N, H0, H1, H2, W0, W1, W2, E2
|
||||
9,
|
||||
BThreadTransferSrcScalarPerVector_E2,
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused with
|
||||
// MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>, // K0, K1, N, H0, H1, H2, W0, W1, W2
|
||||
1,
|
||||
CThreadTransferDstScalarPerVector_K,
|
||||
decltype(a_e0_e1_k_e2_global_step_hacks),
|
||||
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks),
|
||||
decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
|
||||
decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
|
||||
decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
|
||||
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack)>;
|
||||
|
||||
const auto a_e0_e1_k0_k1_e2_grid_desc =
|
||||
GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc);
|
||||
const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
|
||||
GridwiseGemm::MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(b_e0_e1_n_ho_wo_e2_grid_desc);
|
||||
const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
|
||||
GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc);
|
||||
|
||||
using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc);
|
||||
using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
|
||||
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc);
|
||||
using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
|
||||
|
||||
const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
|
||||
|
||||
const bool has_main_e0_block_loop = E0 > 1;
|
||||
|
||||
std::cerr << "has_main_e0_block_loop = " << has_main_e0_block_loop << std::endl;
|
||||
|
||||
const auto cblockid_to_k_n_h_w_block_cluster_adaptor =
|
||||
GridwiseGemm::MakeCBlockIdToKNHoWoBlockClusterAdaptor(c_k_n_hop_wop_grid_desc);
|
||||
|
||||
using CBlockIdToBlockClusterAdaptor_K_N_H_W =
|
||||
decltype(cblockid_to_k_n_h_w_block_cluster_adaptor);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_e0_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
|
||||
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
|
||||
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
|
||||
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
|
||||
true,
|
||||
activ_type>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_bias_grid,
|
||||
p_c_grid,
|
||||
a_e0_e1_k0_k1_e2_grid_desc,
|
||||
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
|
||||
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
|
||||
cblockid_to_k_n_h_w_block_cluster_adaptor);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
|
||||
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
|
||||
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
|
||||
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
|
||||
false,
|
||||
activ_type>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_bias_grid,
|
||||
p_c_grid,
|
||||
a_e0_e1_k0_k1_e2_grid_desc,
|
||||
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
|
||||
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
|
||||
cblockid_to_k_n_h_w_block_cluster_adaptor);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
@@ -1,440 +0,0 @@
|
||||
#ifndef DRIVER_CONVOLUTION_MAXPOOL_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP
|
||||
#define DRIVER_CONVOLUTION_MAXPOOL_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_dlops_v3.hpp"
|
||||
|
||||
template <ck::index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
ck::index_t E1_,
|
||||
ck::index_t E2_,
|
||||
ck::index_t K2_,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t HoPerBlock,
|
||||
ck::index_t WoPerBlock,
|
||||
ck::index_t E1PerBlock,
|
||||
ck::index_t KPerThread,
|
||||
ck::index_t HoPerThread,
|
||||
ck::index_t WoPerThread,
|
||||
ck::index_t EPerThread,
|
||||
typename ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
|
||||
typename ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector_E2,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_E2,
|
||||
ck::index_t BThreadTransferSrcScalarPerVector_E2,
|
||||
ck::index_t CThreadTransferDstScalarPerVector_K,
|
||||
ck::ActivTypeEnum activ_type>
|
||||
struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_maxpool
|
||||
{
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... MaxPool,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ float Run(const ck::TensorDescriptor<Wei...>& wei_k_c0_y_x_c1_global_desc,
|
||||
const ck::TensorDescriptor<In...>& in_n_c0_hi_wi_c1_global_desc,
|
||||
const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
|
||||
const ck::TensorDescriptor<MaxPool...>& max_n_k0_hx_wx_k1_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
const FloatC* __restrict__ p_bias_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatC* __restrict__ p_d_grid,
|
||||
const int nrepeat) const
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
const auto N = in_n_c0_hi_wi_c1_global_desc.GetLength(I0);
|
||||
const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1);
|
||||
const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3);
|
||||
// const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4);
|
||||
|
||||
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
|
||||
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
|
||||
const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
|
||||
|
||||
const auto Hx = max_n_k0_hx_wx_k1_global_desc.GetLength(I2);
|
||||
const auto Wx = max_n_k0_hx_wx_k1_global_desc.GetLength(I3);
|
||||
|
||||
const auto K = wei_k_c0_y_x_c1_global_desc.GetLength(I0);
|
||||
const auto Y = wei_k_c0_y_x_c1_global_desc.GetLength(I2);
|
||||
const auto X = wei_k_c0_y_x_c1_global_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR
|
||||
const auto Hop = Number<(Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock>{};
|
||||
const auto Wop = Number<(Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock>{};
|
||||
|
||||
const auto OutRightPadH = Hop - Ho;
|
||||
const auto OutRightPadW = Wop - Wo;
|
||||
|
||||
const auto OutRightPadHx = Number<OutRightPadH / 2>{};
|
||||
const auto OutRightPadWx = Number<OutRightPadW / 2>{};
|
||||
#else
|
||||
const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
|
||||
const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
|
||||
|
||||
const auto OutRightPadH = Hop - Ho;
|
||||
const auto OutRightPadW = Wop - Wo;
|
||||
|
||||
const auto OutRightPadHx = OutRightPadH / 2;
|
||||
const auto OutRightPadWx = OutRightPadW / 2;
|
||||
#endif
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH;
|
||||
const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW;
|
||||
|
||||
const auto E = C0 * Y * X;
|
||||
|
||||
constexpr auto E1 = Number<E1_>{};
|
||||
constexpr auto E2 = Number<E2_>{};
|
||||
constexpr auto K2 = Number<K2_>{};
|
||||
|
||||
const auto E0 = E / E1;
|
||||
|
||||
// weight tensor
|
||||
const auto a_e_k_e2_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C0 * Y * X, E2)),
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_pass_through_transform(C0 * Y * X),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}));
|
||||
|
||||
const auto a_e0_e1_k_e2_grid_desc =
|
||||
transform_tensor_descriptor(a_e_k_e2_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
|
||||
make_pass_through_transform(K),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c0_hip_wip_e2_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2)),
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C0),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
const auto in_n_c0_y_ho_x_wo_e2_global_desc = transform_tensor_descriptor(
|
||||
in_n_c0_hip_wip_e2_global_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(C0),
|
||||
make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6>{}));
|
||||
|
||||
const auto in_e_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
|
||||
in_n_c0_y_ho_x_wo_e2_global_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C0, Y, X)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Hop),
|
||||
make_pass_through_transform(Wop),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(
|
||||
Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
const auto b_e0_e1_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
|
||||
in_e_n_ho_wo_e2_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Hop),
|
||||
make_pass_through_transform(Wop),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));
|
||||
|
||||
// output tensor
|
||||
const auto c_k_n_hop_wop_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pad_transform(Ho, I0, OutRightPadH),
|
||||
make_pad_transform(Wo, I0, OutRightPadW)),
|
||||
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// max tensor
|
||||
const auto d_k_n_hx_wx_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hx, Wx, K1)),
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pad_transform(Hx, I0, OutRightPadHx),
|
||||
make_pad_transform(Wx, I0, OutRightPadWx)),
|
||||
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;
|
||||
|
||||
if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 &&
|
||||
(E1 % E1PerBlock) == 0))
|
||||
{
|
||||
throw std::runtime_error("wrong! GEMM size no divisible");
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
|
||||
// hack to control index calculation when iterating over a_e0_e1_k_e2_global tensor
|
||||
constexpr auto a_e0_e1_k_e2_global_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global tensor
|
||||
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks =
|
||||
make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})
|
||||
);
|
||||
|
||||
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
// clang-format on
|
||||
|
||||
// GEMM
|
||||
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
decltype(a_e0_e1_k_e2_grid_desc),
|
||||
decltype(b_e0_e1_n_ho_wo_e2_grid_desc),
|
||||
decltype(c_k_n_hop_wop_grid_desc),
|
||||
decltype(d_k_n_hx_wx_grid_desc),
|
||||
E1,
|
||||
E2,
|
||||
K2,
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
E1PerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
|
||||
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
|
||||
Sequence<2, 3, 0, 1, 4>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
4,
|
||||
ABlockTransferSrcScalarPerVector_E2,
|
||||
ABlockTransferDstScalarPerVector_E2,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, // E0, E1, N, H0, H1, H2, W0, W1, W2, E2
|
||||
9,
|
||||
BThreadTransferSrcScalarPerVector_E2,
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused
|
||||
// with MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>, // K0, K1, N, H0, H1, I2, H2, W0, W1, I2, W2
|
||||
1,
|
||||
CThreadTransferDstScalarPerVector_K,
|
||||
decltype(a_e0_e1_k_e2_global_step_hacks),
|
||||
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks),
|
||||
decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
|
||||
decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks),
|
||||
decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
|
||||
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack)>;
|
||||
|
||||
const auto a_e0_e1_k0_k1_e2_grid_desc =
|
||||
GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc);
|
||||
const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
|
||||
GridwiseGemm::MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(b_e0_e1_n_ho_wo_e2_grid_desc);
|
||||
const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
|
||||
GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc);
|
||||
const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc =
|
||||
GridwiseGemm::MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool(d_k_n_hx_wx_grid_desc);
|
||||
|
||||
using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc);
|
||||
using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
|
||||
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc);
|
||||
using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
|
||||
using DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx = decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc);
|
||||
|
||||
const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
|
||||
|
||||
const bool has_main_e0_block_loop = E0 > 1;
|
||||
|
||||
std::cerr << "has_main_e0_block_loop = " << has_main_e0_block_loop << std::endl;
|
||||
|
||||
const auto cblockid_to_k_n_h_w_block_cluster_adaptor =
|
||||
GridwiseGemm::MakeCBlockIdToKNHoWoBlockClusterAdaptor(c_k_n_hop_wop_grid_desc);
|
||||
|
||||
using CBlockIdToBlockClusterAdaptor_K_N_H_W =
|
||||
decltype(cblockid_to_k_n_h_w_block_cluster_adaptor);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_e0_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_dlops_v3_maxpool<
|
||||
GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
|
||||
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
|
||||
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
|
||||
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx>,
|
||||
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
|
||||
true,
|
||||
activ_type>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_bias_grid,
|
||||
p_c_grid,
|
||||
p_d_grid,
|
||||
a_e0_e1_k0_k1_e2_grid_desc,
|
||||
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
|
||||
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
|
||||
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
|
||||
cblockid_to_k_n_h_w_block_cluster_adaptor);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_dlops_v3_maxpool<
|
||||
GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
|
||||
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
|
||||
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
|
||||
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx>,
|
||||
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
|
||||
false,
|
||||
activ_type>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_bias_grid,
|
||||
p_c_grid,
|
||||
p_d_grid,
|
||||
a_e0_e1_k0_k1_e2_grid_desc,
|
||||
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
|
||||
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
|
||||
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
|
||||
cblockid_to_k_n_h_w_block_cluster_adaptor);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
@@ -1,278 +0,0 @@
|
||||
#ifndef DRIVER_GEMM_DLOPS_V1R2
|
||||
#define DRIVER_GEMM_DLOPS_V1R2
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_dlops_v1r2.hpp"
|
||||
|
||||
// Host-side launcher for the blockwise DLOPS GEMM kernel, revision v1r2.
//
// Given raw K-M (A), K-N (B) and M-N (C) grid descriptors, this driver:
//   1. instantiates the GridwiseGemmDlops_km_kn_mn_v1r2 compile-time configuration,
//   2. validates the problem/tuning-parameter combination (throws on mismatch),
//   3. derives the transformed grid descriptors used by the kernel,
//   4. selects one of four kernel instantiations based on the runtime K extent
//      (main-loop / double-tail-loop compile-time flags), and
//   5. launches it `nrepeat` times, returning the average time reported by
//      launch_and_time_kernel.
//
// The trailing unnamed tag parameters (step hacks / move-slice-window hacks) carry
// types only; their values are unused.
//
// @return average kernel execution time (units as reported by launch_and_time_kernel)
// @throws std::runtime_error if GridwiseGemm::CheckValidity rejects the setting
template <ck::index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          ck::InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AKMGridDesc,
          typename BKNGridDesc,
          typename CMNGridDesc,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t M1PerThread,
          ck::index_t N1PerThread,
          ck::index_t KPerThread,
          ck::index_t M1N1ThreadClusterM10,
          ck::index_t M1N1ThreadClusterN10,
          ck::index_t M1N1ThreadClusterM11,
          ck::index_t M1N1ThreadClusterN11,
          typename ABlockTransferThreadSliceLengths_K_M0_M1,
          typename ABlockTransferThreadClusterLengths_K_M0_M1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_M1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_K_N0_N1,
          typename BBlockTransferThreadClusterLengths_K_N0_N1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_N1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          ck::index_t CThreadTransferSrcDstVectorDim,
          ck::index_t CThreadTransferDstScalarPerVector,
          typename AGridStepHacks,
          typename BGridStepHacks,
          typename CGridStepHacks,
          typename AGridMoveSliceWindowStepHacks,
          typename BGridMoveSliceWindowStepHacks>
__host__ float driver_gemm_dlops_v1r2(const FloatAB* p_a_grid,
                                      const FloatAB* p_b_grid,
                                      FloatC* p_c_grid,
                                      const AKMGridDesc& a_k_m_grid_desc,
                                      const BKNGridDesc& b_k_n_grid_desc,
                                      const CMNGridDesc& c_m_n_grid_desc,
                                      AGridStepHacks,
                                      BGridStepHacks,
                                      CGridStepHacks,
                                      AGridMoveSliceWindowStepHacks,
                                      BGridMoveSliceWindowStepHacks,
                                      ck::index_t nrepeat)

{
    using namespace ck;

    // Compile-time dimension indices used when querying descriptor lengths.
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};

    // GEMM
    using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
                                                         FloatAB,
                                                         FloatAcc,
                                                         FloatC,
                                                         CGlobalMemoryDataOperation,
                                                         AKMGridDesc,
                                                         BKNGridDesc,
                                                         CMNGridDesc,
                                                         MPerBlock,
                                                         NPerBlock,
                                                         KPerBlock,
                                                         M1PerThread,
                                                         N1PerThread,
                                                         KPerThread,
                                                         M1N1ThreadClusterM10,
                                                         M1N1ThreadClusterN10,
                                                         M1N1ThreadClusterM11,
                                                         M1N1ThreadClusterN11,
                                                         ABlockTransferThreadSliceLengths_K_M0_M1,
                                                         ABlockTransferThreadClusterLengths_K_M0_M1,
                                                         ABlockTransferThreadClusterArrangeOrder,
                                                         ABlockTransferSrcAccessOrder,
                                                         ABlockTransferSrcVectorDim,
                                                         ABlockTransferSrcScalarPerVector,
                                                         ABlockTransferDstScalarPerVector_M1,
                                                         AThreadTransferSrcResetCoordinateAfterRun,
                                                         BBlockTransferThreadSliceLengths_K_N0_N1,
                                                         BBlockTransferThreadClusterLengths_K_N0_N1,
                                                         BBlockTransferThreadClusterArrangeOrder,
                                                         BBlockTransferSrcAccessOrder,
                                                         BBlockTransferSrcVectorDim,
                                                         BBlockTransferSrcScalarPerVector,
                                                         BBlockTransferDstScalarPerVector_N1,
                                                         BThreadTransferSrcResetCoordinateAfterRun,
                                                         CThreadTransferSrcDstAccessOrder,
                                                         CThreadTransferSrcDstVectorDim,
                                                         CThreadTransferDstScalarPerVector,
                                                         AGridStepHacks,
                                                         BGridStepHacks,
                                                         CGridStepHacks,
                                                         AGridMoveSliceWindowStepHacks,
                                                         BGridMoveSliceWindowStepHacks>;

    // Problem extents read from the raw descriptors: A is K x M, B is K x N.
    const auto M = a_k_m_grid_desc.GetLength(I1);
    const auto N = b_k_n_grid_desc.GetLength(I1);
    const auto K = a_k_m_grid_desc.GetLength(I0);

    // Reject problem/tuning combinations the gridwise GEMM cannot handle.
    if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc))
    {
        throw std::runtime_error("wrong! GridwiseGemmDlops_km_kn_mn_v1r2 has invalid setting");
    }

    // Transformed descriptors consumed by the kernel (M and N split into sub-dims).
    const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
    const auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);

    using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc);
    using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc);

    // c_m0_m10_m11_n0_n10_n11_grid_desc
    const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
        GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);

    using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);

    // cblockid_to_m0_n0_block_cluster_adaptor: maps flat block id -> (M0, N0) tile.
    const auto cblockid_to_m0_n0_block_cluster_adaptor =
        GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);

    using CBlockIdToM0N0BlockClusterAdaptor = decltype(cblockid_to_m0_n0_block_cluster_adaptor);

    const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);

    // Runtime predicates on K that pick which compile-time kernel variant to launch.
    const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);

    const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K);

    // Debug dump of the transformed descriptor extents.
    {
        std::cout << "a_k_m0_m1_grid_desc{" << a_k_m0_m1_grid_desc.GetLength(I0) << ", "
                  << a_k_m0_m1_grid_desc.GetLength(I1) << ", " << a_k_m0_m1_grid_desc.GetLength(I2)
                  << "}" << std::endl;

        std::cout << "b_k_n0_n1_grid_desc{" << b_k_n0_n1_grid_desc.GetLength(I0) << ", "
                  << b_k_n0_n1_grid_desc.GetLength(I1) << ", " << b_k_n0_n1_grid_desc.GetLength(I2)
                  << "}" << std::endl;

        std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl;
    }

    float ave_time = 0;

    // Dispatch to the kernel instantiation whose two trailing bool template
    // arguments (HasMainKBlockLoop, HasDoubleTailKBlockLoop) match the runtime
    // predicates computed above; only the launch bodies' flags differ.
    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_gemm_dlops_v1r2<GridwiseGemm,
                                   FloatAB,
                                   FloatC,
                                   remove_reference_t<AKM0M1GridDesc>,
                                   remove_reference_t<BKN0N1GridDesc>,
                                   remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                   remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                   true,
                                   true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k_m0_m1_grid_desc,
                                          b_k_n0_n1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          cblockid_to_m0_n0_block_cluster_adaptor);
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_gemm_dlops_v1r2<GridwiseGemm,
                                   FloatAB,
                                   FloatC,
                                   remove_reference_t<AKM0M1GridDesc>,
                                   remove_reference_t<BKN0N1GridDesc>,
                                   remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                   remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                   true,
                                   false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k_m0_m1_grid_desc,
                                          b_k_n0_n1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          cblockid_to_m0_n0_block_cluster_adaptor);
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_gemm_dlops_v1r2<GridwiseGemm,
                                   FloatAB,
                                   FloatC,
                                   remove_reference_t<AKM0M1GridDesc>,
                                   remove_reference_t<BKN0N1GridDesc>,
                                   remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                   remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                   false,
                                   true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k_m0_m1_grid_desc,
                                          b_k_n0_n1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          cblockid_to_m0_n0_block_cluster_adaptor);
    }
    else
    {
        const auto kernel =
            kernel_gemm_dlops_v1r2<GridwiseGemm,
                                   FloatAB,
                                   FloatC,
                                   remove_reference_t<AKM0M1GridDesc>,
                                   remove_reference_t<BKN0N1GridDesc>,
                                   remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                   remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                   false,
                                   false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k_m0_m1_grid_desc,
                                          b_k_n0_n1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          cblockid_to_m0_n0_block_cluster_adaptor);
    }

    return ave_time;
}
|
||||
#endif
|
||||
@@ -1,275 +0,0 @@
|
||||
#ifndef DRIVER_GEMM_DLOPS_V1R3
|
||||
#define DRIVER_GEMM_DLOPS_V1R3
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_dlops_v1r3.hpp"
|
||||
|
||||
// Host-side launcher for the blockwise DLOPS GEMM kernel, revision v1r3.
//
// Differs from v1r2 in its input layout: A and B arrive as K0 x {M,N} x K1
// descriptors (K split into K0*K1), and the block-transfer tuning parameters are
// expressed as 4-D vector-tensor lengths/dim-orders instead of scalar-per-vector
// values. The driver validates the configuration, builds the transformed
// descriptors, then dispatches one of four kernel_gemm_dlops_v1r3 instantiations
// chosen from the runtime K0 extent (main-loop / double-tail-loop flags).
//
// @return average kernel execution time (units as reported by launch_and_time_kernel)
// @throws std::runtime_error if GridwiseGemm::CheckValidity rejects the setting
template <ck::index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          ck::InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AK0MK1GridDesc,
          typename BK0NK1GridDesc,
          typename CMNGridDesc,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t M1PerThread,
          ck::index_t N1PerThread,
          ck::index_t KPerThread,
          typename M1N1ThreadClusterM1Xs,
          typename M1N1ThreadClusterN1Xs,
          typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
          typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
          typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
          typename BBlockTransferSrcVectorTensorContiguousDimOrder,
          typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
          typename CThreadTransferSrcDstAccessOrder,
          ck::index_t CThreadTransferSrcDstVectorDim,
          ck::index_t CThreadTransferDstScalarPerVector,
          typename AGridStepHacks,
          typename BGridStepHacks,
          typename CGridStepHacks,
          typename AGridMoveSliceWindowStepHacks,
          typename BGridMoveSliceWindowStepHacks>
__host__ float driver_gemm_dlops_v1r3(const FloatAB* p_a_grid,
                                      const FloatAB* p_b_grid,
                                      FloatC* p_c_grid,
                                      const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
                                      const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
                                      const CMNGridDesc& c_m_n_grid_desc,
                                      AGridStepHacks,
                                      BGridStepHacks,
                                      CGridStepHacks,
                                      AGridMoveSliceWindowStepHacks,
                                      BGridMoveSliceWindowStepHacks,
                                      ck::index_t nrepeat)

{
    using namespace ck;

    // Compile-time dimension indices used when querying descriptor lengths.
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};

    // GEMM
    using GridwiseGemm =
        GridwiseGemmDlops_km_kn_mn_v1r3<BlockSize,
                                        FloatAB,
                                        FloatAcc,
                                        FloatC,
                                        CGlobalMemoryDataOperation,
                                        AK0MK1GridDesc,
                                        BK0NK1GridDesc,
                                        CMNGridDesc,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        M1PerThread,
                                        N1PerThread,
                                        KPerThread,
                                        M1N1ThreadClusterM1Xs,
                                        M1N1ThreadClusterN1Xs,
                                        ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
                                        ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
                                        ABlockTransferThreadClusterArrangeOrder,
                                        ABlockTransferSrcAccessOrder,
                                        ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
                                        ABlockTransferSrcVectorTensorContiguousDimOrder,
                                        ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
                                        BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
                                        BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
                                        BBlockTransferThreadClusterArrangeOrder,
                                        BBlockTransferSrcAccessOrder,
                                        BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
                                        BBlockTransferSrcVectorTensorContiguousDimOrder,
                                        BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
                                        CThreadTransferSrcDstAccessOrder,
                                        CThreadTransferSrcDstVectorDim,
                                        CThreadTransferDstScalarPerVector,
                                        AGridStepHacks,
                                        BGridStepHacks,
                                        CGridStepHacks,
                                        AGridMoveSliceWindowStepHacks,
                                        BGridMoveSliceWindowStepHacks>;

    // Problem extents: A is K0 x M x K1, B is K0 x N x K1.
    const auto M  = a_k0_m_k1_grid_desc.GetLength(I1);
    const auto N  = b_k0_n_k1_grid_desc.GetLength(I1);
    const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);

    // Reject problem/tuning combinations the gridwise GEMM cannot handle.
    if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
    {
        throw std::runtime_error("wrong! GridwiseGemmDlops_km_kn_mn_v1r3 has invalid setting");
    }

    // Transformed descriptors consumed by the kernel (M and N split into sub-dims).
    const auto a_k0_m0_m1_k1_grid_desc =
        GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_k0_m_k1_grid_desc);
    const auto b_k0_n0_n1_k1_grid_desc =
        GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_k0_n_k1_grid_desc);

    using AK0M0M1K1GridDesc = decltype(a_k0_m0_m1_k1_grid_desc);
    using BK0N0N1K1GridDesc = decltype(b_k0_n0_n1_k1_grid_desc);

    // c_m0_m10_m11_n0_n10_n11_grid_desc
    const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
        GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);

    using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);

    // cblockid_to_m0_n0_block_cluster_adaptor: maps flat block id -> (M0, N0) tile.
    const auto cblockid_to_m0_n0_block_cluster_adaptor =
        GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);

    using CBlockIdToM0N0BlockClusterAdaptor = decltype(cblockid_to_m0_n0_block_cluster_adaptor);

    const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);

    // Runtime predicates on K0 that pick which compile-time kernel variant to launch.
    const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);

    const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);

    // Debug dump of the transformed descriptor extents.
    {
        std::cout << "a_k0_m0_m1_k1_grid_desc{" << a_k0_m0_m1_k1_grid_desc.GetLength(I0) << ", "
                  << a_k0_m0_m1_k1_grid_desc.GetLength(I1) << ", "
                  << a_k0_m0_m1_k1_grid_desc.GetLength(I2) << ", "
                  << a_k0_m0_m1_k1_grid_desc.GetLength(I3) << "}" << std::endl;

        std::cout << "b_k0_n0_n1_k1_grid_desc{" << b_k0_n0_n1_k1_grid_desc.GetLength(I0) << ", "
                  << b_k0_n0_n1_k1_grid_desc.GetLength(I1) << ", "
                  << b_k0_n0_n1_k1_grid_desc.GetLength(I2) << ", "
                  << b_k0_n0_n1_k1_grid_desc.GetLength(I3) << "}" << std::endl;

        std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl;
    }

    float ave_time = 0;

    // Dispatch to the kernel instantiation whose two trailing bool template
    // arguments (HasMainKBlockLoop, HasDoubleTailKBlockLoop) match the runtime
    // predicates computed above; only the launch bodies' flags differ.
    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_gemm_dlops_v1r3<GridwiseGemm,
                                   FloatAB,
                                   FloatC,
                                   remove_reference_t<AK0M0M1K1GridDesc>,
                                   remove_reference_t<BK0N0N1K1GridDesc>,
                                   remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                   remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                   true,
                                   true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k0_m0_m1_k1_grid_desc,
                                          b_k0_n0_n1_k1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          cblockid_to_m0_n0_block_cluster_adaptor);
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_gemm_dlops_v1r3<GridwiseGemm,
                                   FloatAB,
                                   FloatC,
                                   remove_reference_t<AK0M0M1K1GridDesc>,
                                   remove_reference_t<BK0N0N1K1GridDesc>,
                                   remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                   remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                   true,
                                   false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k0_m0_m1_k1_grid_desc,
                                          b_k0_n0_n1_k1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          cblockid_to_m0_n0_block_cluster_adaptor);
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_gemm_dlops_v1r3<GridwiseGemm,
                                   FloatAB,
                                   FloatC,
                                   remove_reference_t<AK0M0M1K1GridDesc>,
                                   remove_reference_t<BK0N0N1K1GridDesc>,
                                   remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                   remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                   false,
                                   true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k0_m0_m1_k1_grid_desc,
                                          b_k0_n0_n1_k1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          cblockid_to_m0_n0_block_cluster_adaptor);
    }
    else
    {
        const auto kernel =
            kernel_gemm_dlops_v1r3<GridwiseGemm,
                                   FloatAB,
                                   FloatC,
                                   remove_reference_t<AK0M0M1K1GridDesc>,
                                   remove_reference_t<BK0N0N1K1GridDesc>,
                                   remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                   remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                   false,
                                   false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k0_m0_m1_k1_grid_desc,
                                          b_k0_n0_n1_k1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          cblockid_to_m0_n0_block_cluster_adaptor);
    }

    return ave_time;
}
|
||||
#endif
|
||||
@@ -1,220 +0,0 @@
|
||||
#ifndef DRIVER_GEMM_XDLOPS_V2R3_HPP
|
||||
#define DRIVER_GEMM_XDLOPS_V2R3_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v2r3.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
|
||||
// Host-side launcher for the XDLOPS GEMM kernel, revision v2r3.
//
// Inputs arrive as K0 x M x K1 (A), K0 x N x K1 (B) and M x N (C) descriptors.
// All three elementwise operations passed to the gridwise GEMM are fixed to
// PassThrough here. The driver validates the configuration (including the
// M01/N01 block-cluster re-ordering factors), builds the transformed C output
// descriptor and the block-to-C-tile map, then launches one of two
// kernel_gemm_xdlops_v2r3 instantiations selected by the runtime K0 extent.
//
// NOTE(review): several template parameters are accepted but visibly not
// forwarded to GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 in this version —
// ABlockTransferThreadSliceLengths_K0_M_K1, BBlockTransferThreadSliceLengths_K0_N_K1,
// all five *StepHacks types, and CAccessOrderMRepeatNRepeat. Presumably kept
// for interface compatibility with sibling drivers — confirm before removing.
//
// @return average kernel execution time (units as reported by launch_and_time_kernel)
// @throws std::runtime_error if GridwiseGemm::CheckValidity rejects the setting
template <ck::index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          ck::InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AGridDesc_K0_M_K1,
          typename BGridDesc_K0_N_K,
          typename CMNGridDesc,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t MPerXDL,
          ck::index_t NPerXDL,
          ck::index_t K1,
          ck::index_t MRepeat,
          ck::index_t NRepeat,
          typename ABlockTransferThreadSliceLengths_K0_M_K1,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_K0_N_K1,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          ck::index_t CThreadTransferSrcDstVectorDim,
          ck::index_t CThreadTransferDstScalarPerVector,
          typename AGridStepHacks,
          typename BGridStepHacks,
          typename CGridStepHacks,
          typename AGridMoveSliceWindowStepHacks,
          typename BGridMoveSliceWindowStepHacks,
          bool CAccessOrderMRepeatNRepeat,
          bool ABlockLdsAddExtraM,
          bool BBlockLdsAddExtraN>
__host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
                                       const FloatAB* p_b_grid,
                                       FloatC* p_c_grid,
                                       const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                                       const BGridDesc_K0_N_K& b_grid_desc_k0_n_k1,
                                       const CMNGridDesc& c_grid_desc_m_n,
                                       ck::index_t M01,
                                       ck::index_t N01,
                                       AGridStepHacks,
                                       BGridStepHacks,
                                       CGridStepHacks,
                                       AGridMoveSliceWindowStepHacks,
                                       BGridMoveSliceWindowStepHacks,
                                       ck::index_t nrepeat)
{
    using namespace ck;

    // Compile-time dimension indices used when querying descriptor lengths.
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    // This driver hard-wires all three elementwise ops (A, B, C) to PassThrough.
    using ElementwiseOperation = ck::tensor_operation::element_wise::PassThrough;

    using GridwiseGemm =
        GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
                                                FloatAB,
                                                FloatAcc,
                                                FloatC,
                                                CGlobalMemoryDataOperation,
                                                AGridDesc_K0_M_K1,
                                                BGridDesc_K0_N_K,
                                                CMNGridDesc,
                                                ElementwiseOperation,
                                                ElementwiseOperation,
                                                ElementwiseOperation,
                                                MPerBlock,
                                                NPerBlock,
                                                KPerBlock,
                                                MPerXDL,
                                                NPerXDL,
                                                K1,
                                                MRepeat,
                                                NRepeat,
                                                ABlockTransferThreadClusterLengths_K0_M_K1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                ABlockTransferSrcAccessOrder,
                                                ABlockTransferSrcVectorDim,
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_K1,
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                ABlockLdsAddExtraM,
                                                BBlockTransferThreadClusterLengths_K0_N_K1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                BBlockTransferSrcAccessOrder,
                                                BBlockTransferSrcVectorDim,
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_K1,
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                BBlockLdsAddExtraN,
                                                CThreadTransferSrcDstAccessOrder,
                                                CThreadTransferSrcDstVectorDim,
                                                CThreadTransferDstScalarPerVector>;

    // Debug dump of the raw descriptor extents.
    {
        std::cout << "a_grid_desc_k0_m_k1{" << a_grid_desc_k0_m_k1.GetLength(I0) << ", "
                  << a_grid_desc_k0_m_k1.GetLength(I1) << ", " << a_grid_desc_k0_m_k1.GetLength(I2)
                  << "}" << std::endl;

        std::cout << "b_grid_desc_k0_n_k1{" << b_grid_desc_k0_n_k1.GetLength(I0) << ", "
                  << b_grid_desc_k0_n_k1.GetLength(I1) << ", " << b_grid_desc_k0_n_k1.GetLength(I2)
                  << "}" << std::endl;

        std::cout << "c_grid_desc_m_n{ " << c_grid_desc_m_n.GetLength(I0) << ", "
                  << c_grid_desc_m_n.GetLength(I1) << "}" << std::endl;
    }

    // Reject problem/tuning combinations (including M01/N01) the gridwise GEMM
    // cannot handle.
    if(!GridwiseGemm::CheckValidity(
           a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n, M01, N01))
    {
        throw std::runtime_error(
            "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
    }

    // Transformed 8-D C output descriptor consumed by the kernel.
    const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
        GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);

    using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);

    // Map from flat block id to the C tile it computes, re-ordered by M01/N01.
    const auto block_2_ctile_map =
        GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, M01, N01);

    using Block2CTileMap = decltype(block_2_ctile_map);

    const index_t grid_size = GridwiseGemm::CalculateGridSize(c_grid_desc_m_n);

    const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

    // Runtime predicate on K0 that picks which compile-time kernel variant to launch.
    const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);

    float ave_time = 0;

    auto element_op_ = ElementwiseOperation{};

    // Dispatch to the kernel instantiation whose trailing bool template argument
    // (HasMainK0BlockLoop) matches the runtime predicate; the two launch bodies
    // differ only in that flag.
    if(has_main_k0_block_loop)
    {
        const auto kernel =
            kernel_gemm_xdlops_v2r3<GridwiseGemm,
                                    FloatAB,
                                    FloatC,
                                    remove_reference_t<AGridDesc_K0_M_K1>,
                                    remove_reference_t<BGridDesc_K0_N_K>,
                                    remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                                    ElementwiseOperation,
                                    ElementwiseOperation,
                                    ElementwiseOperation,
                                    remove_reference_t<Block2CTileMap>,
                                    true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_grid_desc_k0_m_k1,
                                          b_grid_desc_k0_n_k1,
                                          c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                          element_op_,
                                          element_op_,
                                          element_op_,
                                          block_2_ctile_map);
    }
    else
    {
        const auto kernel =
            kernel_gemm_xdlops_v2r3<GridwiseGemm,
                                    FloatAB,
                                    FloatC,
                                    remove_reference_t<AGridDesc_K0_M_K1>,
                                    remove_reference_t<BGridDesc_K0_N_K>,
                                    remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                                    ElementwiseOperation,
                                    ElementwiseOperation,
                                    ElementwiseOperation,
                                    remove_reference_t<Block2CTileMap>,
                                    false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_grid_desc_k0_m_k1,
                                          b_grid_desc_k0_n_k1,
                                          c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                          element_op_,
                                          element_op_,
                                          element_op_,
                                          block_2_ctile_map);
    }
    return ave_time;
}
|
||||
#endif
|
||||
@@ -1,213 +0,0 @@
|
||||
#ifndef DRIVER_GEMM_XDLOPS_V2R4
|
||||
#define DRIVER_GEMM_XDLOPS_V2R4
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v2r4.hpp"
|
||||
|
||||
template <ck::index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
ck::InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
typename ABK0MK1GridDesc,
|
||||
typename BBK0NK1GridDesc,
|
||||
typename CMNGridDesc,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t K1,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridStepHacks,
|
||||
typename BGridStepHacks,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks,
|
||||
bool CAccessOrderMRepeatNRepeat,
|
||||
bool ABlockLdsAddExtraM,
|
||||
bool BBlockLdsAddExtraN>
|
||||
__host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
|
||||
const FloatAB* p_b_grid,
|
||||
FloatC* p_c_grid,
|
||||
const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
|
||||
const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
|
||||
const CMNGridDesc& c_m_n_grid_desc,
|
||||
ck::index_t M01,
|
||||
ck::index_t N01,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
ck::index_t nrepeat)
|
||||
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
CGlobalMemoryDataOperation,
|
||||
ABK0MK1GridDesc,
|
||||
BBK0NK1GridDesc,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
CAccessOrderMRepeatNRepeat,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockLdsAddExtraN>;
|
||||
|
||||
{
|
||||
std::cout << "a_b_k0_m_k1_grid_desc{" << a_b_k0_m_k1_grid_desc.GetLength(I0) << ", "
|
||||
<< a_b_k0_m_k1_grid_desc.GetLength(I1) << ", "
|
||||
<< a_b_k0_m_k1_grid_desc.GetLength(I2) << ", "
|
||||
<< a_b_k0_m_k1_grid_desc.GetLength(I3) << "}" << std::endl;
|
||||
|
||||
std::cout << "b_b_k0_n_k1_grid_desc{" << b_b_k0_n_k1_grid_desc.GetLength(I0) << ", "
|
||||
<< b_b_k0_n_k1_grid_desc.GetLength(I1) << ", "
|
||||
<< b_b_k0_n_k1_grid_desc.GetLength(I2) << ", "
|
||||
<< b_b_k0_n_k1_grid_desc.GetLength(I3) << "}" << std::endl;
|
||||
|
||||
std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
|
||||
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(
|
||||
a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_m_n_grid_desc, M01, N01))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r4 has invalid setting");
|
||||
}
|
||||
|
||||
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
|
||||
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
|
||||
|
||||
using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
|
||||
|
||||
const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
|
||||
|
||||
const auto c_block_cluster_adaptor =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, M01, N01, KBatch);
|
||||
|
||||
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc, KBatch);
|
||||
{
|
||||
std::cout << "gridSize : " << grid_size << std::endl;
|
||||
}
|
||||
|
||||
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
float ave_time = 0;
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<ABK0MK1GridDesc>,
|
||||
remove_reference_t<BBK0NK1GridDesc>,
|
||||
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
|
||||
remove_reference_t<CBlockClusterAdaptor>,
|
||||
true>;
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<ABK0MK1GridDesc>,
|
||||
remove_reference_t<BBK0NK1GridDesc>,
|
||||
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
|
||||
remove_reference_t<CBlockClusterAdaptor>,
|
||||
false>;
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
#endif
|
||||
@@ -1,10 +1,10 @@
|
||||
#ifndef REFERENCE_BATCHED_GEMM_HPP
|
||||
#define REFERENCE_BATCHED_GEMM_HPP
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device_base.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -132,4 +132,3 @@ struct ReferenceBatchedGemm : public device::BaseOperator
|
||||
} // namespace host
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -1,33 +1,10 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2022 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device_base.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device_base.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
#ifndef REFERENCE_CONV_BWD_DATA_HPP
|
||||
#define REFERENCE_CONV_BWD_DATA_HPP
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device_base.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -351,4 +352,3 @@ struct ReferenceConvBwdData : public device::BaseOperator
|
||||
} // namespace host
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -4,9 +4,8 @@
|
||||
#include <type_traits>
|
||||
#include <sstream>
|
||||
|
||||
#include "stream_config.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#ifndef REFERENCE_CONV_FWD_BIAS_ACTIVATION_HPP
|
||||
#define REFERENCE_CONV_FWD_BIAS_ACTIVATION_HPP
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device_base.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -187,4 +187,3 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
|
||||
} // namespace host
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#ifndef REFERENCE_CONV2D_FWD_BIAS_ACTIVATION_ADD_HPP
|
||||
#define REFERENCE_CONV2D_FWD_BIAS_ACTIVATION_ADD_HPP
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device_base.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -195,4 +195,3 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator
|
||||
} // namespace host
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device_base.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#ifndef REFERENCE_GEMM_BIAS_BIAS_2D_HPP
|
||||
#define REFERENCE_GEMM_BIAS_BIAS_2D_HPP
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device_base.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -131,4 +131,3 @@ struct ReferenceGemmBias2D : public device::BaseOperator
|
||||
} // namespace host
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
#ifndef REFERENCE_GEMM_BIAS_ACTIVATION_HPP
|
||||
#define REFERENCE_GEMM_BIAS_ACTIVATION_HPP
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device_base.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -134,4 +135,3 @@ struct ReferenceGemmBiasActivation : public device::BaseOperator
|
||||
} // namespace host
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
#ifndef REFERENCE_GEMM_BIAS_ACTIVATION_ADD_HPP
|
||||
#define REFERENCE_GEMM_BIAS_ACTIVATION_ADD_HPP
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device_base.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -142,4 +143,3 @@ struct ReferenceGemmBiasActivationAdd : public device::BaseOperator
|
||||
} // namespace host
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "device_base.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor_generator.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
@@ -1,26 +1,23 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANTCE_HPP
|
||||
#define DEVICE_REDUCE_INSTANTCE_HPP
|
||||
#pragma once
|
||||
|
||||
#include "device_reduce_instance_blockwise_f16_f16_f16.hpp"
|
||||
#include "device_reduce_instance_blockwise_f16_f32_f16.hpp"
|
||||
#include "device_reduce_instance_blockwise_f32_f32_f32.hpp"
|
||||
#include "device_reduce_instance_blockwise_f32_f64_f32.hpp"
|
||||
#include "device_reduce_instance_blockwise_f64_f64_f64.hpp"
|
||||
#include "device_reduce_instance_blockwise_i8_i8_i8.hpp"
|
||||
#include "device_reduce_instance_blockwise_i8_i32_i8.hpp"
|
||||
#include "device_reduce_instance_blockwise_b16_f32_b16.hpp"
|
||||
#include "device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp"
|
||||
#include "device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp"
|
||||
#include "device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp"
|
||||
#include "device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp"
|
||||
#include "device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp"
|
||||
#include "device_reduce_instance_threadwise_f16_f16_f16.hpp"
|
||||
#include "device_reduce_instance_threadwise_f16_f32_f16.hpp"
|
||||
#include "device_reduce_instance_threadwise_f32_f32_f32.hpp"
|
||||
#include "device_reduce_instance_threadwise_f32_f64_f32.hpp"
|
||||
#include "device_reduce_instance_threadwise_f64_f64_f64.hpp"
|
||||
#include "device_reduce_instance_threadwise_i8_i8_i8.hpp"
|
||||
#include "device_reduce_instance_threadwise_i8_i32_i8.hpp"
|
||||
#include "device_reduce_instance_threadwise_b16_f32_b16.hpp"
|
||||
|
||||
#endif
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.hpp"
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_HPP
|
||||
#pragma once
|
||||
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_impl_common.hpp"
|
||||
#include "device_reduce_multiblock.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -175,7 +174,4 @@ void add_device_reduce_instance_blockwise(
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP
|
||||
#pragma once
|
||||
|
||||
#include "data_type.hpp"
|
||||
#include "device_reduce_instance_blockwise.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -53,7 +53,4 @@ ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1);
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP
|
||||
#pragma once
|
||||
|
||||
#include "data_type.hpp"
|
||||
#include "device_reduce_instance_blockwise.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -40,7 +40,4 @@ ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP
|
||||
#pragma once
|
||||
|
||||
#include "data_type.hpp"
|
||||
#include "device_reduce_instance_blockwise.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -28,7 +28,4 @@ ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP
|
||||
#pragma once
|
||||
|
||||
#include "device_reduce_instance_blockwise.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -51,7 +52,4 @@ ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1);
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP
|
||||
#pragma once
|
||||
|
||||
#include "device_reduce_instance_blockwise.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -27,7 +28,4 @@ ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1);
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP
|
||||
#pragma once
|
||||
|
||||
#include "device_reduce_instance_blockwise.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -51,7 +52,4 @@ ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1);
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I32_I8_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I32_I8_HPP
|
||||
#pragma once
|
||||
|
||||
#include "device_reduce_instance_blockwise.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -23,7 +24,4 @@ ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1);
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I8_I8_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I8_I8_HPP
|
||||
#pragma once
|
||||
|
||||
#include "device_reduce_instance_blockwise.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -39,7 +40,4 @@ ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_IMPL_COMMON_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_IMPL_COMMON_HPP
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -35,7 +34,4 @@ struct ReductionConfiguration_2
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_HPP
|
||||
#pragma once
|
||||
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_impl_common.hpp"
|
||||
#include "device_reduce_multiblock.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -193,7 +193,4 @@ void add_device_reduce_instance_multiblock_atomic_add(
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_B16_F32_F32_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_B16_F32_F32_HPP
|
||||
#pragma once
|
||||
|
||||
#include "data_type.hpp"
|
||||
#include "device_reduce_instance_multiblock_atomic_add.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -24,7 +24,4 @@ ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(bhalf_t, float, float, 5, 0, 0, 2, 1);
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F16_F32_F32_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F16_F32_F32_HPP
|
||||
#pragma once
|
||||
|
||||
#include "data_type.hpp"
|
||||
#include "device_reduce_instance_multiblock_atomic_add.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -24,7 +24,4 @@ ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(half_t, float, float, 5, 0, 0, 2, 1);
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F32_F32_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F32_F32_HPP
|
||||
#pragma once
|
||||
|
||||
#include "device_reduce_instance_multiblock_atomic_add.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -23,7 +24,4 @@ ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1);
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F64_F32_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F64_F32_HPP
|
||||
#pragma once
|
||||
|
||||
#include "device_reduce_instance_multiblock_atomic_add.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -23,7 +24,4 @@ ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(float, double, float, 5, 0, 0, 2, 1);
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F64_F64_F64_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F64_F64_F64_HPP
|
||||
#pragma once
|
||||
|
||||
#include "device_reduce_instance_multiblock_atomic_add.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -23,7 +24,4 @@ ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1);
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_THREADWISE_HPP
|
||||
#pragma once
|
||||
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "device_reduce_instance_impl_common.hpp"
|
||||
#include "device_reduce_threadwise.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_reduce_threadwise.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -152,7 +151,4 @@ void add_device_reduce_instance_threadwise(
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_B16_F32_B16_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_THREADWISE_B16_F32_B16_HPP
|
||||
#pragma once
|
||||
|
||||
#include "data_type.hpp"
|
||||
#include "device_reduce_instance_threadwise.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -53,7 +53,4 @@ ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1);
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F16_F16_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F16_F16_HPP
|
||||
#pragma once
|
||||
|
||||
#include "data_type.hpp"
|
||||
#include "device_reduce_instance_threadwise.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -40,7 +40,4 @@ ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F32_F16_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F32_F16_HPP
|
||||
#pragma once
|
||||
|
||||
#include "data_type.hpp"
|
||||
#include "device_reduce_instance_threadwise.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -28,7 +28,4 @@ ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F32_F32_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F32_F32_HPP
|
||||
#pragma once
|
||||
|
||||
#include "device_reduce_instance_threadwise.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -51,7 +52,4 @@ ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1);
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F64_F32_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F64_F32_HPP
|
||||
#pragma once
|
||||
|
||||
#include "device_reduce_instance_threadwise.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -27,7 +28,4 @@ ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1);
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F64_F64_F64_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_THREADWISE_F64_F64_F64_HPP
|
||||
#pragma once
|
||||
|
||||
#include "device_reduce_instance_threadwise.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -51,7 +52,4 @@ ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1);
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_I8_I32_I8_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_THREADWISE_I8_I32_I8_HPP
|
||||
#pragma once
|
||||
|
||||
#include "device_reduce_instance_threadwise.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -23,7 +24,4 @@ ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1);
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_I8_I8_I8_HPP
|
||||
#define DEVICE_REDUCE_INSTANCE_THREADWISE_I8_I8_I8_HPP
|
||||
#pragma once
|
||||
|
||||
#include "device_reduce_instance_threadwise.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -39,7 +40,4 @@ ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
|
||||
} // namespace device_reduce_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <half.hpp>
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include <iterator>
|
||||
@@ -11,7 +10,7 @@
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "data_type.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace utils {
|
||||
@@ -107,8 +106,7 @@ check_err(const std::vector<T>& out,
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename std::enable_if<std::is_same<T, half_t>::value || std::is_same<T, half_float::half>::value,
|
||||
bool>::type
|
||||
typename std::enable_if<std::is_same<T, half_t>::value, bool>::type
|
||||
check_err(const std::vector<T>& out,
|
||||
const std::vector<T>& ref,
|
||||
const std::string& msg = "Error: Incorrect results!",
|
||||
|
||||
@@ -9,17 +9,17 @@
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
#include "device_conv_fwd.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "fill.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "op_instance_engine.hpp"
|
||||
#include "reference_conv_fwd.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/library/utility/op_instance_engine.hpp"
|
||||
#include "ck/library/host_tensor/device_memory.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include <cmath>
|
||||
#include <random>
|
||||
|
||||
#include "data_type.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace utils {
|
||||
|
||||
@@ -9,9 +9,12 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "functional2.hpp"
|
||||
#include "ck/utility/functional2.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/host_tensor/device_memory.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace utils {
|
||||
|
||||
@@ -1,12 +1,6 @@
|
||||
## host_tensor
|
||||
include_directories(BEFORE
|
||||
${PROJECT_SOURCE_DIR}/include/ck
|
||||
${PROJECT_SOURCE_DIR}/include/ck/utility
|
||||
${PROJECT_SOURCE_DIR}/library/include/ck/library/host_tensor
|
||||
)
|
||||
|
||||
set(HOST_TENSOR_SOURCE
|
||||
device.cpp
|
||||
device_memory.cpp
|
||||
host_tensor.cpp
|
||||
)
|
||||
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
#include "device.hpp"
|
||||
|
||||
DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
|
||||
{
|
||||
hip_check_error(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
|
||||
}
|
||||
|
||||
void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; }
|
||||
|
||||
std::size_t DeviceMem::GetBufferSize() { return mMemSize; }
|
||||
|
||||
void DeviceMem::ToDevice(const void* p)
|
||||
{
|
||||
hip_check_error(hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
void DeviceMem::FromDevice(void* p)
|
||||
{
|
||||
hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
|
||||
}
|
||||
|
||||
void DeviceMem::SetZero() { hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize)); }
|
||||
|
||||
DeviceMem::~DeviceMem() { hip_check_error(hipFree(mpDeviceBuf)); }
|
||||
|
||||
struct KernelTimerImpl
|
||||
{
|
||||
KernelTimerImpl()
|
||||
{
|
||||
hip_check_error(hipEventCreate(&mStart));
|
||||
hip_check_error(hipEventCreate(&mEnd));
|
||||
}
|
||||
|
||||
~KernelTimerImpl()
|
||||
{
|
||||
hip_check_error(hipEventDestroy(mStart));
|
||||
hip_check_error(hipEventDestroy(mEnd));
|
||||
}
|
||||
|
||||
void Start()
|
||||
{
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
hip_check_error(hipEventRecord(mStart, nullptr));
|
||||
}
|
||||
|
||||
void End()
|
||||
{
|
||||
hip_check_error(hipEventRecord(mEnd, nullptr));
|
||||
hip_check_error(hipEventSynchronize(mEnd));
|
||||
}
|
||||
|
||||
float GetElapsedTime() const
|
||||
{
|
||||
float time;
|
||||
hip_check_error(hipEventElapsedTime(&time, mStart, mEnd));
|
||||
return time;
|
||||
}
|
||||
|
||||
hipEvent_t mStart, mEnd;
|
||||
};
|
||||
|
||||
KernelTimer::KernelTimer() : impl(new KernelTimerImpl()) {}
|
||||
|
||||
KernelTimer::~KernelTimer() {}
|
||||
|
||||
void KernelTimer::Start() { impl->Start(); }
|
||||
|
||||
void KernelTimer::End() { impl->End(); }
|
||||
|
||||
float KernelTimer::GetElapsedTime() const { return impl->GetElapsedTime(); }
|
||||
25
library/src/host_tensor/device_memory.cpp
Normal file
25
library/src/host_tensor/device_memory.cpp
Normal file
@@ -0,0 +1,25 @@
|
||||
#include "ck/device_utility/hip_check_error.hpp"
|
||||
#include "ck/library/host_tensor/device_memory.hpp"
|
||||
|
||||
DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
|
||||
{
|
||||
hip_check_error(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
|
||||
}
|
||||
|
||||
void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; }
|
||||
|
||||
std::size_t DeviceMem::GetBufferSize() { return mMemSize; }
|
||||
|
||||
void DeviceMem::ToDevice(const void* p)
|
||||
{
|
||||
hip_check_error(hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
void DeviceMem::FromDevice(void* p)
|
||||
{
|
||||
hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
|
||||
}
|
||||
|
||||
void DeviceMem::SetZero() { hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize)); }
|
||||
|
||||
DeviceMem::~DeviceMem() { hip_check_error(hipFree(mpDeviceBuf)); }
|
||||
@@ -1,5 +1,6 @@
|
||||
#include <cassert>
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
|
||||
void HostTensorDescriptor::CalculateStrides()
|
||||
{
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
include_directories(BEFORE
|
||||
include
|
||||
${PROJECT_SOURCE_DIR}/host/host_tensor/include
|
||||
${PROJECT_SOURCE_DIR}/host/device/include
|
||||
${PROJECT_SOURCE_DIR}/host/solver/include
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/utility
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
|
||||
${PROJECT_SOURCE_DIR}/external/rocm/include
|
||||
)
|
||||
|
||||
set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp)
|
||||
set(CONV_FWD_DRIVER_OFFLINE_NCHWC_SOURCE src/conv_fwd_driver_offline_nchwc.cpp)
|
||||
set(CONV_ADD_FWD_DRIVER_OFFLINE_NCHWC_SOURCE src/conv_add_fwd_driver_offline_nchwc.cpp)
|
||||
set(CONV_MAXPOOL_FWD_DRIVER_OFFLINE_NCHWC_SOURCE src/conv_maxpool_fwd_driver_offline_nchwc.cpp)
|
||||
set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp)
|
||||
set(CONV_WRW_DRIVER_OFFLINE_SOURCE src/conv_wrw_driver_offline.cpp)
|
||||
set(GEMM_DRIVER_OFFLINE_SOURCE src/gemm_driver_offline.cpp)
|
||||
|
||||
add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE})
|
||||
add_executable(conv_fwd_driver_offline_nchwc ${CONV_FWD_DRIVER_OFFLINE_NCHWC_SOURCE})
|
||||
add_executable(conv_add_fwd_driver_offline_nchwc ${CONV_ADD_FWD_DRIVER_OFFLINE_NCHWC_SOURCE})
|
||||
add_executable(conv_maxpool_fwd_driver_offline_nchwc ${CONV_MAXPOOL_FWD_DRIVER_OFFLINE_NCHWC_SOURCE})
|
||||
add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE})
|
||||
add_executable(conv_wrw_driver_offline ${CONV_WRW_DRIVER_OFFLINE_SOURCE})
|
||||
add_executable(gemm_driver_offline ${GEMM_DRIVER_OFFLINE_SOURCE})
|
||||
|
||||
target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor)
|
||||
target_link_libraries(conv_fwd_driver_offline_nchwc PRIVATE host_tensor)
|
||||
target_link_libraries(conv_add_fwd_driver_offline_nchwc PRIVATE host_tensor)
|
||||
target_link_libraries(conv_maxpool_fwd_driver_offline_nchwc PRIVATE host_tensor)
|
||||
target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor)
|
||||
target_link_libraries(conv_wrw_driver_offline PRIVATE host_tensor)
|
||||
target_link_libraries(gemm_driver_offline PRIVATE host_tensor)
|
||||
@@ -1,416 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "debug.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "conv_common.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp"
|
||||
|
||||
#define USE_DYNAMIC_MODE 0
|
||||
#define USE_CONV_FWD_V5R1_NCHWC 1
|
||||
|
||||
enum ConvForwardAlgo
|
||||
{
|
||||
V5R1NCHWC // 0
|
||||
};
|
||||
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void host_direct_convolution_add_nchwc(const Tensor<TIn>& in,
|
||||
const Tensor<TWei>& wei,
|
||||
const Tensor<TOut>& add,
|
||||
const Tensor<TOut>& bias,
|
||||
Tensor<TOut>& add_host,
|
||||
Tensor<TOut>& out_host,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads&,
|
||||
const ck::ActivTypeEnum activ_type)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
auto f_nchw = [&](auto n, auto k0, auto ho, auto wo, auto k1) {
|
||||
double v = 0;
|
||||
auto k = k0 * out_host.mDesc.GetLengths()[4] + k1;
|
||||
|
||||
for(int c0 = 0; c0 < wei.mDesc.GetLengths()[1]; ++c0)
|
||||
{
|
||||
for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
|
||||
{
|
||||
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
|
||||
for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x)
|
||||
{
|
||||
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[3])
|
||||
{
|
||||
|
||||
for(int c1 = 0; c1 < wei.mDesc.GetLengths()[4]; ++c1)
|
||||
{
|
||||
v += static_cast<const double>(in(n, c0, hi, wi, c1)) *
|
||||
static_cast<const double>(wei(k, c0, y, x, c1));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
v += bias(k0, k1);
|
||||
v = activ(v, activ_type);
|
||||
|
||||
const int hox2 = ho * 2;
|
||||
const int wox2 = wo * 2;
|
||||
|
||||
out_host(n, k0, ho, wo, k1) = v;
|
||||
|
||||
add_host(n, k0, hox2, wox2, k1) = v + add(n, k0, hox2, wox2, k1);
|
||||
add_host(n, k0, hox2, wox2 + 1, k1) = v + add(n, k0, hox2, wox2 + 1, k1);
|
||||
add_host(n, k0, hox2 + 1, wox2, k1) = v + add(n, k0, hox2 + 1, wox2, k1);
|
||||
add_host(n, k0, hox2 + 1, wox2 + 1, k1) = v + add(n, k0, hox2 + 1, wox2 + 1, k1);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
out_host.mDesc.GetLengths()[0],
|
||||
out_host.mDesc.GetLengths()[1],
|
||||
out_host.mDesc.GetLengths()[2],
|
||||
out_host.mDesc.GetLengths()[3],
|
||||
out_host.mDesc.GetLengths()[4])(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
#if USE_DYNAMIC_MODE
|
||||
// dynamic mode
|
||||
if(argc != 23)
|
||||
{
|
||||
printf("arg1 to 5: algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
printf("rest: N, K0, K1, C0, C1, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
|
||||
"RightPx\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::LeakyRelu;
|
||||
|
||||
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(std::stoi(argv[1]));
|
||||
const bool do_verification = std::stoi(argv[2]);
|
||||
const int init_method = std::stoi(argv[3]);
|
||||
const bool do_log = std::stoi(argv[4]);
|
||||
const int nrepeat = std::stoi(argv[5]);
|
||||
|
||||
const index_t N = std::stoi(argv[6]);
|
||||
const index_t K0 = std::stoi(argv[7]);
|
||||
const index_t K1 = std::stoi(argv[8]);
|
||||
const index_t C0 = std::stoi(argv[9]);
|
||||
const index_t C1 = std::stoi(argv[10]);
|
||||
const index_t Y = std::stoi(argv[11]);
|
||||
const index_t X = std::stoi(argv[12]);
|
||||
const index_t Hi = std::stoi(argv[13]);
|
||||
const index_t Wi = std::stoi(argv[14]);
|
||||
|
||||
const index_t conv_stride_h = std::stoi(argv[15]);
|
||||
const index_t conv_stride_w = std::stoi(argv[16]);
|
||||
const index_t conv_dilation_h = std::stoi(argv[17]);
|
||||
const index_t conv_dilation_w = std::stoi(argv[18]);
|
||||
const index_t in_left_pad_h = std::stoi(argv[19]);
|
||||
const index_t in_left_pad_w = std::stoi(argv[20]);
|
||||
const index_t in_right_pad_h = std::stoi(argv[21]);
|
||||
const index_t in_right_pad_w = std::stoi(argv[22]);
|
||||
|
||||
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
||||
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
||||
|
||||
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
|
||||
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
|
||||
const auto Hox2 = Ho * 2;
|
||||
const auto Wox2 = Wo * 2;
|
||||
#else
|
||||
// static mode
|
||||
if(argc < 6)
|
||||
{
|
||||
printf("arg1 to 5: algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(std::stoi(argv[1]));
|
||||
|
||||
const bool do_verification = std::stoi(argv[2]);
|
||||
const int init_method = std::stoi(argv[3]);
|
||||
const bool do_log = std::stoi(argv[4]);
|
||||
const int nrepeat = std::stoi(argv[5]);
|
||||
|
||||
constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::LeakyRelu;
|
||||
|
||||
#if 0
|
||||
constexpr auto N = Number<1>{};
|
||||
constexpr auto Hi = Number<1080>{};
|
||||
constexpr auto Wi = Number<1920>{};
|
||||
constexpr auto Y = Number<3>{};
|
||||
constexpr auto X = Number<3>{};
|
||||
constexpr auto C0 = Number<2>{};
|
||||
constexpr auto C1 = Number<8>{};
|
||||
constexpr auto K1 = Number<8>{};
|
||||
constexpr auto K0 = Number<8>{};
|
||||
#elif 0
|
||||
constexpr auto N = Number<1>{};
|
||||
constexpr auto Hi = Number<540>{};
|
||||
constexpr auto Wi = Number<960>{};
|
||||
constexpr auto Y = Number<3>{};
|
||||
constexpr auto X = Number<3>{};
|
||||
constexpr auto C0 = Number<2>{};
|
||||
constexpr auto C1 = Number<8>{};
|
||||
constexpr auto K0 = Number<2>{};
|
||||
constexpr auto K1 = Number<8>{};
|
||||
#elif 0
|
||||
constexpr auto N = Number<1>{};
|
||||
constexpr auto Hi = Number<270>{};
|
||||
constexpr auto Wi = Number<480>{};
|
||||
constexpr auto Y = Number<3>{};
|
||||
constexpr auto X = Number<3>{};
|
||||
constexpr auto C0 = Number<2>{};
|
||||
constexpr auto C1 = Number<8>{};
|
||||
constexpr auto K0 = Number<2>{};
|
||||
constexpr auto K1 = Number<8>{};
|
||||
#elif 1
|
||||
constexpr auto N = Number<128>{};
|
||||
constexpr auto Hi = Number<135>{};
|
||||
constexpr auto Wi = Number<240>{};
|
||||
constexpr auto Y = Number<3>{};
|
||||
constexpr auto X = Number<3>{};
|
||||
constexpr auto C0 = Number<2>{};
|
||||
constexpr auto C1 = Number<8>{};
|
||||
constexpr auto K0 = Number<2>{};
|
||||
constexpr auto K1 = Number<8>{};
|
||||
#elif 1
|
||||
constexpr auto N = Number<1>{};
|
||||
constexpr auto Hi = Number<32>{};
|
||||
constexpr auto Wi = Number<32>{};
|
||||
constexpr auto Y = Number<3>{};
|
||||
constexpr auto X = Number<3>{};
|
||||
constexpr auto C0 = Number<2>{};
|
||||
constexpr auto C1 = Number<8>{};
|
||||
constexpr auto K1 = Number<8>{};
|
||||
constexpr auto K0 = Number<8>{};
|
||||
#endif
|
||||
|
||||
constexpr auto conv_stride_h = I1;
|
||||
constexpr auto conv_stride_w = I1;
|
||||
constexpr auto conv_dilation_h = I1;
|
||||
constexpr auto conv_dilation_w = I1;
|
||||
constexpr auto in_left_pad_h = I1;
|
||||
constexpr auto in_left_pad_w = I1;
|
||||
constexpr auto in_right_pad_h = I1;
|
||||
constexpr auto in_right_pad_w = I1;
|
||||
|
||||
constexpr auto YEff = (Y - I1) * conv_dilation_h + I1;
|
||||
constexpr auto XEff = (X - I1) * conv_dilation_w + I1;
|
||||
|
||||
constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1;
|
||||
constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1;
|
||||
|
||||
constexpr auto Hox2 = Number<Ho * 2>{};
|
||||
constexpr auto Wox2 = Number<Wo * 2>{};
|
||||
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
using in_data_t = float;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
#elif 1
|
||||
using in_data_t = half_t;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = half_t;
|
||||
#elif 1
|
||||
using in_data_t = int8_t;
|
||||
using acc_data_t = int32_t;
|
||||
using out_data_t = int8_t;
|
||||
#endif
|
||||
|
||||
std::vector<std::size_t> in_lengths_host(5), wei_lengths_host(5), out_lengths_host(5),
|
||||
add_lengths_host(5), bias_lengths_host(2);
|
||||
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(C0);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(Wi);
|
||||
in_lengths_host[4] = static_cast<std::size_t>(C1);
|
||||
|
||||
wei_lengths_host[0] = static_cast<std::size_t>(K0 * K1);
|
||||
wei_lengths_host[1] = static_cast<std::size_t>(C0);
|
||||
wei_lengths_host[2] = static_cast<std::size_t>(Y);
|
||||
wei_lengths_host[3] = static_cast<std::size_t>(X);
|
||||
wei_lengths_host[4] = static_cast<std::size_t>(C1);
|
||||
|
||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
out_lengths_host[1] = static_cast<std::size_t>(K0);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Ho);
|
||||
out_lengths_host[3] = static_cast<std::size_t>(Wo);
|
||||
out_lengths_host[4] = static_cast<std::size_t>(K1);
|
||||
|
||||
add_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
add_lengths_host[1] = static_cast<std::size_t>(K0);
|
||||
add_lengths_host[2] = static_cast<std::size_t>(Hox2);
|
||||
add_lengths_host[3] = static_cast<std::size_t>(Wox2);
|
||||
add_lengths_host[4] = static_cast<std::size_t>(K1);
|
||||
|
||||
bias_lengths_host[0] = static_cast<std::size_t>(K0);
|
||||
bias_lengths_host[1] = static_cast<std::size_t>(K1);
|
||||
|
||||
Tensor<in_data_t> in(in_lengths_host);
|
||||
Tensor<in_data_t> wei(wei_lengths_host);
|
||||
Tensor<in_data_t> add(add_lengths_host);
|
||||
Tensor<in_data_t> add_device(add_lengths_host);
|
||||
Tensor<in_data_t> add_host(add_lengths_host);
|
||||
Tensor<out_data_t> bias(bias_lengths_host);
|
||||
Tensor<out_data_t> out_host(out_lengths_host);
|
||||
|
||||
ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: ");
|
||||
ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: ");
|
||||
ostream_HostTensorDescriptor(add.mDesc, std::cout << "add: ");
|
||||
|
||||
print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w));
|
||||
print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w));
|
||||
print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w));
|
||||
print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w));
|
||||
|
||||
std::size_t num_thread = 1;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
// no initialization
|
||||
break;
|
||||
case 1:
|
||||
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
case 3:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 4:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
case 5:
|
||||
in.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
|
||||
|
||||
auto gen_wei = [](auto... is) {
|
||||
return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
|
||||
};
|
||||
wei.GenerateTensorValue(gen_wei, num_thread);
|
||||
}
|
||||
|
||||
bias.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
add.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
|
||||
auto f_make_for_device_nchwc = [&]() {
|
||||
const auto in_lengths_dev = make_tuple(N, C0, Hi, Wi, C1);
|
||||
const auto wei_lengths_dev = make_tuple(K0 * K1, C0, Y, X, C1);
|
||||
const auto add_lengths_dev = make_tuple(N, K0, Hox2, Wox2, K1);
|
||||
const auto out_lengths_dev = make_tuple(N, K0, Ho, Wo, K1);
|
||||
const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
|
||||
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
|
||||
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
|
||||
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
|
||||
|
||||
return make_tuple(in_lengths_dev,
|
||||
wei_lengths_dev,
|
||||
add_lengths_dev,
|
||||
out_lengths_dev,
|
||||
conv_strides_dev,
|
||||
conv_dilations_dev,
|
||||
in_left_pads_dev,
|
||||
in_right_pads_dev);
|
||||
};
|
||||
|
||||
#if USE_CONV_FWD_V5R1_NCHWC
|
||||
if(algo == ConvForwardAlgo::V5R1NCHWC)
|
||||
{
|
||||
const auto tmp = f_make_for_device_nchwc();
|
||||
|
||||
device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t,
|
||||
activ_type>(
|
||||
tmp[I0], // in_lengths_dev
|
||||
tmp[I1], // wei_lengths_dev
|
||||
tmp[I2], // add_lengths_dev
|
||||
tmp[I3], // out_lengths_dev
|
||||
tmp[I4], // conv_strides_dev
|
||||
tmp[I5], // conv_dilations_dev
|
||||
tmp[I6], // in_left_pads_dev
|
||||
tmp[I7], // in_right_pads_dev
|
||||
in,
|
||||
wei,
|
||||
bias,
|
||||
add,
|
||||
add_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_direct_convolution_add_nchwc(in,
|
||||
wei,
|
||||
add,
|
||||
bias,
|
||||
add_host,
|
||||
out_host,
|
||||
make_tuple(conv_stride_h, conv_stride_w),
|
||||
make_tuple(conv_dilation_h, conv_dilation_w),
|
||||
make_tuple(in_left_pad_h, in_left_pad_w),
|
||||
make_tuple(in_right_pad_h, in_right_pad_w),
|
||||
activ_type);
|
||||
|
||||
ck::utils::check_err(add_device.mData, add_host.mData);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "wei: ", wei.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "add_host: ", add_host.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "add_device: ", add_device.mData, ",") << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,488 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "debug.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "conv_common.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp"
|
||||
|
||||
#define USE_MODE 1
|
||||
#define USE_CONV_BWD_V4R1_XDL_NHWC 0
|
||||
#define USE_CONV_BWD_V4R1R2_XDL_NHWC 1
|
||||
|
||||
enum ConvTensorLayout
|
||||
{
|
||||
NCHW,
|
||||
NHWC,
|
||||
CHWN,
|
||||
NCHWc,
|
||||
NHWCc
|
||||
};
|
||||
|
||||
enum ConvBackwardDataAlgo
|
||||
{
|
||||
V4R1XDLNHWC, // 0
|
||||
V4R1R2XDLNHWC, // 1
|
||||
};
|
||||
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void host_convolution_backward_data(Tensor<TIn>& in,
|
||||
const Tensor<TWei>& wei,
|
||||
const Tensor<TOut>& out,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& /* in_right_pads */,
|
||||
const ConvTensorLayout layout = ConvTensorLayout::NCHW)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
|
||||
std::size_t K = wei.mDesc.GetLengths()[I0];
|
||||
std::size_t Y = wei.mDesc.GetLengths()[I2];
|
||||
std::size_t X = wei.mDesc.GetLengths()[I3];
|
||||
|
||||
std::size_t Ho = out.mDesc.GetLengths()[I2];
|
||||
std::size_t Wo = out.mDesc.GetLengths()[I3];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int y = 0; y < Y; ++y)
|
||||
{
|
||||
int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0];
|
||||
|
||||
if(h_tmp % conv_strides[I0] == 0)
|
||||
{
|
||||
int ho = h_tmp / conv_strides[I0];
|
||||
|
||||
if(ho >= 0 && ho < Ho)
|
||||
{
|
||||
for(int x = 0; x < X; ++x)
|
||||
{
|
||||
int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1];
|
||||
|
||||
if(w_tmp % conv_strides[I1] == 0)
|
||||
{
|
||||
int wo = w_tmp / conv_strides[I1];
|
||||
|
||||
if(wo >= 0 && wo < Wo)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += out(n, k, ho, wo) * wei(k, c, y, x);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
in(n, c, hi, wi) = v;
|
||||
};
|
||||
|
||||
auto f_nhwc = [&](auto n, auto hi, auto wi, auto c) {
|
||||
std::size_t K = wei.mDesc.GetLengths()[I0];
|
||||
std::size_t Y = wei.mDesc.GetLengths()[I1];
|
||||
std::size_t X = wei.mDesc.GetLengths()[I2];
|
||||
|
||||
std::size_t Ho = out.mDesc.GetLengths()[I1];
|
||||
std::size_t Wo = out.mDesc.GetLengths()[I2];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int y = 0; y < Y; ++y)
|
||||
{
|
||||
int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0];
|
||||
|
||||
if(h_tmp % conv_strides[I0] == 0)
|
||||
{
|
||||
int ho = h_tmp / conv_strides[I0];
|
||||
|
||||
if(ho >= 0 && ho < Ho)
|
||||
{
|
||||
for(int x = 0; x < X; ++x)
|
||||
{
|
||||
int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1];
|
||||
|
||||
if(w_tmp % conv_strides[I1] == 0)
|
||||
{
|
||||
int wo = w_tmp / conv_strides[I1];
|
||||
|
||||
if(wo >= 0 && wo < Wo)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += out(n, ho, wo, k) * wei(k, y, x, c);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
in(n, hi, wi, c) = v;
|
||||
};
|
||||
|
||||
if(layout == ConvTensorLayout::NCHW)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
in.mDesc.GetLengths()[0],
|
||||
in.mDesc.GetLengths()[1],
|
||||
in.mDesc.GetLengths()[2],
|
||||
in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == ConvTensorLayout::NHWC)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_nhwc,
|
||||
in.mDesc.GetLengths()[0],
|
||||
in.mDesc.GetLengths()[1],
|
||||
in.mDesc.GetLengths()[2],
|
||||
in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! not supported layout");
|
||||
}
|
||||
}
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
|
||||
#if USE_MODE
|
||||
// dynamic mode
|
||||
if(argc != 22)
|
||||
{
|
||||
printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(std::stoi(argv[1]));
|
||||
const ConvBackwardDataAlgo algo = static_cast<ConvBackwardDataAlgo>(std::stoi(argv[2]));
|
||||
const bool do_verification = std::stoi(argv[3]);
|
||||
const int init_method = std::stoi(argv[4]);
|
||||
const bool do_log = std::stoi(argv[5]);
|
||||
const int nrepeat = std::stoi(argv[6]);
|
||||
|
||||
const index_t N = std::stoi(argv[7]);
|
||||
const index_t K = std::stoi(argv[8]);
|
||||
const index_t C = std::stoi(argv[9]);
|
||||
const index_t Y = std::stoi(argv[10]);
|
||||
const index_t X = std::stoi(argv[11]);
|
||||
const index_t Hi = std::stoi(argv[12]);
|
||||
const index_t Wi = std::stoi(argv[13]);
|
||||
|
||||
const index_t conv_stride_h = std::stoi(argv[14]);
|
||||
const index_t conv_stride_w = std::stoi(argv[15]);
|
||||
const index_t conv_dilation_h = std::stoi(argv[16]);
|
||||
const index_t conv_dilation_w = std::stoi(argv[17]);
|
||||
const index_t in_left_pad_h = std::stoi(argv[18]);
|
||||
const index_t in_left_pad_w = std::stoi(argv[19]);
|
||||
const index_t in_right_pad_h = std::stoi(argv[20]);
|
||||
const index_t in_right_pad_w = std::stoi(argv[21]);
|
||||
|
||||
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
||||
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
||||
|
||||
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
|
||||
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
#else
|
||||
// static mode
|
||||
if(argc < 7)
|
||||
{
|
||||
printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(std::stoi(argv[1]));
|
||||
const ConvBackwardDataAlgo algo = static_cast<ConvBackwardDataAlgo>(std::stoi(argv[2]));
|
||||
const bool do_verification = std::stoi(argv[3]);
|
||||
const int init_method = std::stoi(argv[4]);
|
||||
const bool do_log = std::stoi(argv[5]);
|
||||
const int nrepeat = std::stoi(argv[6]);
|
||||
|
||||
constexpr auto N = Number<128>{};
|
||||
constexpr auto C = Number<192>{};
|
||||
constexpr auto Hi = Number<71>{};
|
||||
constexpr auto Wi = Number<71>{};
|
||||
constexpr auto K = Number<256>{};
|
||||
constexpr auto Y = Number<3>{};
|
||||
constexpr auto X = Number<3>{};
|
||||
|
||||
constexpr auto conv_stride_h = I2;
|
||||
constexpr auto conv_stride_w = I2;
|
||||
constexpr auto conv_dilation_h = I1;
|
||||
constexpr auto conv_dilation_w = I1;
|
||||
constexpr auto in_left_pad_h = I1;
|
||||
constexpr auto in_left_pad_w = I1;
|
||||
constexpr auto in_right_pad_h = I1;
|
||||
constexpr auto in_right_pad_w = I1;
|
||||
|
||||
constexpr auto YEff = (Y - I1) * conv_dilation_h + I1;
|
||||
constexpr auto XEff = (X - I1) * conv_dilation_w + I1;
|
||||
|
||||
constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1;
|
||||
constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1;
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
using in_data_t = float;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
#elif 1
|
||||
using in_data_t = half_t;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = half_t;
|
||||
#endif
|
||||
|
||||
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
|
||||
|
||||
if(layout == ConvTensorLayout::NCHW)
|
||||
{
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(C);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(Wi);
|
||||
wei_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
wei_lengths_host[1] = static_cast<std::size_t>(C);
|
||||
wei_lengths_host[2] = static_cast<std::size_t>(Y);
|
||||
wei_lengths_host[3] = static_cast<std::size_t>(X);
|
||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
out_lengths_host[1] = static_cast<std::size_t>(K);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Ho);
|
||||
out_lengths_host[3] = static_cast<std::size_t>(Wo);
|
||||
}
|
||||
else if(layout == ConvTensorLayout::NHWC)
|
||||
{
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Wi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(C);
|
||||
wei_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
wei_lengths_host[1] = static_cast<std::size_t>(Y);
|
||||
wei_lengths_host[2] = static_cast<std::size_t>(X);
|
||||
wei_lengths_host[3] = static_cast<std::size_t>(C);
|
||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
out_lengths_host[1] = static_cast<std::size_t>(Ho);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Wo);
|
||||
out_lengths_host[3] = static_cast<std::size_t>(K);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! not implemented");
|
||||
}
|
||||
|
||||
Tensor<in_data_t> in_host(in_lengths_host);
|
||||
Tensor<in_data_t> in_device(in_lengths_host);
|
||||
Tensor<in_data_t> wei(wei_lengths_host);
|
||||
Tensor<out_data_t> out(out_lengths_host);
|
||||
|
||||
std::cout << "layout: " << layout << std::endl;
|
||||
ostream_HostTensorDescriptor(in_host.mDesc, std::cout << "in: ");
|
||||
ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: ");
|
||||
ostream_HostTensorDescriptor(out.mDesc, std::cout << "out: ");
|
||||
print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w));
|
||||
print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w));
|
||||
print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w));
|
||||
print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w));
|
||||
|
||||
std::size_t num_thread = 1;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
// no initialization
|
||||
break;
|
||||
case 1:
|
||||
out.GenerateTensorValue(GeneratorTensor_1<out_data_t>{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
out.GenerateTensorValue(GeneratorTensor_1<out_data_t>{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
|
||||
break;
|
||||
case 3:
|
||||
out.GenerateTensorValue(GeneratorTensor_2<out_data_t>{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
|
||||
break;
|
||||
case 4:
|
||||
out.GenerateTensorValue(GeneratorTensor_2<out_data_t>{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
|
||||
break;
|
||||
case 5:
|
||||
out.GenerateTensorValue(GeneratorTensor_3<out_data_t>{0.0, 1.0}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<in_data_t>{-0.5, 0.5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
out.GenerateTensorValue(GeneratorTensor_2<out_data_t>{1, 5}, num_thread);
|
||||
|
||||
auto gen_wei = [](auto... is) {
|
||||
return GeneratorTensor_2<in_data_t>{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
|
||||
};
|
||||
wei.GenerateTensorValue(gen_wei, num_thread);
|
||||
}
|
||||
|
||||
auto f_make_for_device_nhwc = [&]() {
|
||||
#if USE_MODE
|
||||
const auto in_lengths_dev = make_tuple(N, Hi, Wi, C);
|
||||
const auto wei_lengths_dev = make_tuple(K, Y, X, C);
|
||||
const auto out_lengths_dev = make_tuple(N, Ho, Wo, K);
|
||||
const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
|
||||
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
|
||||
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
|
||||
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
|
||||
#else
|
||||
const auto in_lengths_dev =
|
||||
make_tuple(Number<N>{}, Number<Hi>{}, Number<Wi>{}, Number<C>{});
|
||||
const auto wei_lengths_dev = make_tuple(Number<K>{}, Number<Y>{}, Number<X>{}, Number<C>{});
|
||||
const auto out_lengths_dev =
|
||||
make_tuple(Number<N>{}, Number<Ho>{}, Number<Wo>{}, Number<K>{});
|
||||
const auto conv_strides_dev = make_tuple(Number<conv_stride_h>{}, Number<conv_stride_w>{});
|
||||
const auto conv_dilations_dev =
|
||||
make_tuple(Number<conv_dilation_h>{}, Number<conv_dilation_w>{});
|
||||
const auto in_left_pads_dev = make_tuple(Number<in_left_pad_h>{}, Number<in_left_pad_w>{});
|
||||
const auto in_right_pads_dev =
|
||||
make_tuple(Number<in_right_pad_h>{}, Number<in_right_pad_w>{});
|
||||
#endif
|
||||
|
||||
return make_tuple(in_lengths_dev,
|
||||
wei_lengths_dev,
|
||||
out_lengths_dev,
|
||||
conv_strides_dev,
|
||||
conv_dilations_dev,
|
||||
in_left_pads_dev,
|
||||
in_right_pads_dev);
|
||||
};
|
||||
|
||||
#if USE_CONV_BWD_V4R1_XDL_NHWC
|
||||
if(algo == ConvBackwardDataAlgo::V4R1XDLNHWC)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NHWC)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nhwc();
|
||||
|
||||
device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in_device,
|
||||
wei,
|
||||
out,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_BWD_V4R1R2_XDL_NHWC
|
||||
if(algo == ConvBackwardDataAlgo::V4R1R2XDLNHWC)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NHWC)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nhwc();
|
||||
|
||||
if(Y == 1 && X == 1 && in_left_pad_h == 0 && in_left_pad_w == 0 && in_right_pad_h == 0 &&
|
||||
in_right_pad_w == 0)
|
||||
{
|
||||
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1<
|
||||
in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in_device,
|
||||
wei,
|
||||
out,
|
||||
nrepeat);
|
||||
}
|
||||
else
|
||||
{
|
||||
#if 1
|
||||
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in_device,
|
||||
wei,
|
||||
out,
|
||||
nrepeat);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_convolution_backward_data(in_host,
|
||||
wei,
|
||||
out,
|
||||
make_tuple(conv_stride_h, conv_stride_w),
|
||||
make_tuple(conv_dilation_h, conv_dilation_w),
|
||||
make_tuple(in_left_pad_h, in_left_pad_w),
|
||||
make_tuple(in_right_pad_h, in_right_pad_w),
|
||||
layout);
|
||||
|
||||
ck::utils::check_err(in_device.mData, in_host.mData);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "out : ", out.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "wei: ", wei.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "in_host : ", in_host.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "in_device: ", in_device.mData, ",") << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,549 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "debug.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "conv_common.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
#define USE_DYNAMIC_MODE 1
|
||||
#define USE_CONV_FWD_V4R4_NCHW 0
|
||||
#define USE_CONV_FWD_V4R4R2_NHWC 0
|
||||
#define USE_CONV_FWD_V6R1_NCHW 0
|
||||
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
|
||||
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1
|
||||
|
||||
enum ConvTensorLayout
|
||||
{
|
||||
NCHW,
|
||||
NHWC,
|
||||
CHWN,
|
||||
NCHWc,
|
||||
NHWCc
|
||||
};
|
||||
|
||||
enum ConvForwardAlgo
|
||||
{
|
||||
V4R4NCHW, // 0
|
||||
V4R4R2NHWC, // 1
|
||||
V6R1NCHW, // 2
|
||||
V4R4R2XDLNCHW, // 3
|
||||
V4R4R4XDLNHWC // 4
|
||||
};
|
||||
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void host_convolution_forward(const Tensor<TIn>& in,
|
||||
const Tensor<TWei>& wei,
|
||||
Tensor<TOut>& out,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads&,
|
||||
const ConvTensorLayout layout = ConvTensorLayout::NCHW)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
double v = 0;
|
||||
for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c)
|
||||
{
|
||||
for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
|
||||
{
|
||||
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
|
||||
for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x)
|
||||
{
|
||||
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[3])
|
||||
{
|
||||
if constexpr(is_same<TIn, bhalf_t>::value)
|
||||
{
|
||||
v += ck::type_convert<float>(in(n, c, hi, wi)) *
|
||||
ck::type_convert<float>(wei(k, c, y, x));
|
||||
}
|
||||
else
|
||||
{
|
||||
v += static_cast<const double>(in(n, c, hi, wi)) *
|
||||
static_cast<const double>(wei(k, c, y, x));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same<TOut, bhalf_t>::value)
|
||||
{
|
||||
out(n, k, ho, wo) = ck::type_convert<bhalf_t>(static_cast<float>(v));
|
||||
}
|
||||
else
|
||||
{
|
||||
out(n, k, ho, wo) = v;
|
||||
}
|
||||
};
|
||||
|
||||
auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) {
|
||||
double v = 0;
|
||||
for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c)
|
||||
{
|
||||
for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y)
|
||||
{
|
||||
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
|
||||
for(int x = 0; x < wei.mDesc.GetLengths()[2]; ++x)
|
||||
{
|
||||
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[2])
|
||||
{
|
||||
if constexpr(is_same<TIn, bhalf_t>::value)
|
||||
{
|
||||
v += ck::type_convert<float>(in(n, hi, wi, c)) *
|
||||
ck::type_convert<float>(wei(k, y, x, c));
|
||||
}
|
||||
else
|
||||
{
|
||||
v += static_cast<const double>(in(n, hi, wi, c)) *
|
||||
static_cast<const double>(wei(k, y, x, c));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if constexpr(is_same<TOut, bhalf_t>::value)
|
||||
{
|
||||
out(n, ho, wo, k) = ck::type_convert<bhalf_t>(static_cast<float>(v));
|
||||
}
|
||||
else
|
||||
{
|
||||
out(n, ho, wo, k) = v;
|
||||
}
|
||||
};
|
||||
|
||||
if(layout == ConvTensorLayout::NCHW)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
out.mDesc.GetLengths()[0],
|
||||
out.mDesc.GetLengths()[1],
|
||||
out.mDesc.GetLengths()[2],
|
||||
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == ConvTensorLayout::NHWC)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_nhwc,
|
||||
out.mDesc.GetLengths()[0],
|
||||
out.mDesc.GetLengths()[1],
|
||||
out.mDesc.GetLengths()[2],
|
||||
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! not supported layout");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
|
||||
#if USE_DYNAMIC_MODE
|
||||
// dynamic mode
|
||||
if(argc != 22)
|
||||
{
|
||||
printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(std::stoi(argv[1]));
|
||||
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(std::stoi(argv[2]));
|
||||
const bool do_verification = std::stoi(argv[3]);
|
||||
const int init_method = std::stoi(argv[4]);
|
||||
const bool do_log = std::stoi(argv[5]);
|
||||
const int nrepeat = std::stoi(argv[6]);
|
||||
|
||||
const index_t N = std::stoi(argv[7]);
|
||||
const index_t K = std::stoi(argv[8]);
|
||||
const index_t C = std::stoi(argv[9]);
|
||||
const index_t Y = std::stoi(argv[10]);
|
||||
const index_t X = std::stoi(argv[11]);
|
||||
const index_t Hi = std::stoi(argv[12]);
|
||||
const index_t Wi = std::stoi(argv[13]);
|
||||
|
||||
const index_t conv_stride_h = std::stoi(argv[14]);
|
||||
const index_t conv_stride_w = std::stoi(argv[15]);
|
||||
const index_t conv_dilation_h = std::stoi(argv[16]);
|
||||
const index_t conv_dilation_w = std::stoi(argv[17]);
|
||||
const index_t in_left_pad_h = std::stoi(argv[18]);
|
||||
const index_t in_left_pad_w = std::stoi(argv[19]);
|
||||
const index_t in_right_pad_h = std::stoi(argv[20]);
|
||||
const index_t in_right_pad_w = std::stoi(argv[21]);
|
||||
|
||||
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
||||
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
||||
|
||||
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
|
||||
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
#else
|
||||
// static mode
|
||||
if(argc < 7)
|
||||
{
|
||||
printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(std::stoi(argv[1]));
|
||||
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(std::stoi(argv[2]));
|
||||
const bool do_verification = std::stoi(argv[3]);
|
||||
const int init_method = std::stoi(argv[4]);
|
||||
const bool do_log = std::stoi(argv[5]);
|
||||
const int nrepeat = std::stoi(argv[6]);
|
||||
|
||||
constexpr auto N = Number<128>{};
|
||||
constexpr auto C = Number<192>{};
|
||||
constexpr auto Hi = Number<71>{};
|
||||
constexpr auto Wi = Number<71>{};
|
||||
constexpr auto K = Number<256>{};
|
||||
constexpr auto Y = Number<3>{};
|
||||
constexpr auto X = Number<3>{};
|
||||
|
||||
constexpr auto conv_stride_h = I1;
|
||||
constexpr auto conv_stride_w = I1;
|
||||
constexpr auto conv_dilation_h = I1;
|
||||
constexpr auto conv_dilation_w = I1;
|
||||
constexpr auto in_left_pad_h = I1;
|
||||
constexpr auto in_left_pad_w = I1;
|
||||
constexpr auto in_right_pad_h = I1;
|
||||
constexpr auto in_right_pad_w = I1;
|
||||
|
||||
constexpr auto YEff = (Y - I1) * conv_dilation_h + I1;
|
||||
constexpr auto XEff = (X - I1) * conv_dilation_w + I1;
|
||||
|
||||
constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1;
|
||||
constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1;
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
using in_data_t = float;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
#elif 1
|
||||
using in_data_t = half_t;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = half_t;
|
||||
#elif 0
|
||||
using in_data_t = bhalf_t;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = bhalf_t;
|
||||
#elif 1
|
||||
using in_data_t = int8_t;
|
||||
using acc_data_t = int32_t;
|
||||
using out_data_t = int8_t;
|
||||
#endif
|
||||
|
||||
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
|
||||
|
||||
if(layout == ConvTensorLayout::NCHW)
|
||||
{
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(C);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(Wi);
|
||||
wei_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
wei_lengths_host[1] = static_cast<std::size_t>(C);
|
||||
wei_lengths_host[2] = static_cast<std::size_t>(Y);
|
||||
wei_lengths_host[3] = static_cast<std::size_t>(X);
|
||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
out_lengths_host[1] = static_cast<std::size_t>(K);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Ho);
|
||||
out_lengths_host[3] = static_cast<std::size_t>(Wo);
|
||||
}
|
||||
else if(layout == ConvTensorLayout::NHWC)
|
||||
{
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Wi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(C);
|
||||
wei_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
wei_lengths_host[1] = static_cast<std::size_t>(Y);
|
||||
wei_lengths_host[2] = static_cast<std::size_t>(X);
|
||||
wei_lengths_host[3] = static_cast<std::size_t>(C);
|
||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
out_lengths_host[1] = static_cast<std::size_t>(Ho);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Wo);
|
||||
out_lengths_host[3] = static_cast<std::size_t>(K);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::runtime_error("wrong! not implemented");
|
||||
}
|
||||
|
||||
Tensor<in_data_t> in(in_lengths_host);
|
||||
Tensor<in_data_t> wei(wei_lengths_host);
|
||||
Tensor<out_data_t> out_host(out_lengths_host);
|
||||
Tensor<out_data_t> out_device(out_lengths_host);
|
||||
|
||||
std::cout << "layout: " << layout << std::endl;
|
||||
ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: ");
|
||||
ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: ");
|
||||
ostream_HostTensorDescriptor(out_host.mDesc, std::cout << "out: ");
|
||||
print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w));
|
||||
print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w));
|
||||
print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w));
|
||||
print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w));
|
||||
|
||||
std::size_t num_thread = 1;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
// no initialization
|
||||
break;
|
||||
case 1:
|
||||
in.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
in.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
|
||||
break;
|
||||
case 3:
|
||||
in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
|
||||
break;
|
||||
case 4:
|
||||
in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
|
||||
break;
|
||||
case 5:
|
||||
in.GenerateTensorValue(GeneratorTensor_3<in_data_t>{0.0, 1.0}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<in_data_t>{-0.5, 0.5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{1, 5}, num_thread);
|
||||
|
||||
auto gen_wei = [](auto... is) {
|
||||
return GeneratorTensor_2<in_data_t>{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
|
||||
};
|
||||
wei.GenerateTensorValue(gen_wei, num_thread);
|
||||
}
|
||||
|
||||
auto f_make_for_device_nchw = [&]() {
|
||||
const auto in_lengths_dev = make_tuple(N, C, Hi, Wi);
|
||||
const auto wei_lengths_dev = make_tuple(K, C, Y, X);
|
||||
const auto out_lengths_dev = make_tuple(N, K, Ho, Wo);
|
||||
const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
|
||||
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
|
||||
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
|
||||
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
|
||||
|
||||
return make_tuple(in_lengths_dev,
|
||||
wei_lengths_dev,
|
||||
out_lengths_dev,
|
||||
conv_strides_dev,
|
||||
conv_dilations_dev,
|
||||
in_left_pads_dev,
|
||||
in_right_pads_dev);
|
||||
};
|
||||
|
||||
auto f_make_for_device_nhwc = [&]() {
|
||||
const auto in_lengths_dev = make_tuple(N, Hi, Wi, C);
|
||||
const auto wei_lengths_dev = make_tuple(K, Y, X, C);
|
||||
const auto out_lengths_dev = make_tuple(N, Ho, Wo, K);
|
||||
const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
|
||||
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
|
||||
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
|
||||
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
|
||||
|
||||
return make_tuple(in_lengths_dev,
|
||||
wei_lengths_dev,
|
||||
out_lengths_dev,
|
||||
conv_strides_dev,
|
||||
conv_dilations_dev,
|
||||
in_left_pads_dev,
|
||||
in_right_pads_dev);
|
||||
};
|
||||
|
||||
#if USE_CONV_FWD_V4R4_NCHW
|
||||
if(algo == ConvForwardAlgo::V4R4NCHW)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NCHW)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nchw();
|
||||
|
||||
device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V4R4R2_NHWC
|
||||
if(algo == ConvForwardAlgo::V4R4R2NHWC)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NHWC)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nhwc();
|
||||
|
||||
device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V6R1_NCHW
|
||||
if(algo == ConvForwardAlgo::V6R1NCHW)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NCHW)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nchw();
|
||||
|
||||
device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V4R4R2_XDL_NCHW
|
||||
if(algo == ConvForwardAlgo::V4R4R2XDLNCHW)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NCHW)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nchw();
|
||||
|
||||
device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V4R4R4_XDL_NHWC
|
||||
if(algo == ConvForwardAlgo::V4R4R4XDLNHWC)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NHWC)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nhwc();
|
||||
|
||||
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_convolution_forward(in,
|
||||
wei,
|
||||
out_host,
|
||||
make_tuple(conv_stride_h, conv_stride_w),
|
||||
make_tuple(conv_dilation_h, conv_dilation_w),
|
||||
make_tuple(in_left_pad_h, in_left_pad_w),
|
||||
make_tuple(in_right_pad_h, in_right_pad_w),
|
||||
layout);
|
||||
|
||||
ck::utils::check_err(out_device.mData, out_host.mData);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "wei: ", wei.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "out_device: ", out_device.mData, ",") << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,393 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "debug.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "conv_common.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp"
|
||||
|
||||
#define USE_DYNAMIC_MODE 0
|
||||
#define USE_CONV_FWD_V5R1_NCHWC 1
|
||||
|
||||
enum ConvForwardAlgo
|
||||
{
|
||||
V5R1NCHWC // 0
|
||||
};
|
||||
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void host_direct_convolution_nchwc(const Tensor<TIn>& in,
|
||||
const Tensor<TWei>& wei,
|
||||
const Tensor<TOut>& bias,
|
||||
Tensor<TOut>& out,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads&,
|
||||
const ck::ActivTypeEnum activ_type)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
auto f_nchw = [&](auto n, auto k0, auto ho, auto wo, auto k1) {
|
||||
double v = 0;
|
||||
const int k = k0 * out.mDesc.GetLengths()[4] + k1;
|
||||
|
||||
for(int c0 = 0; c0 < wei.mDesc.GetLengths()[1]; ++c0)
|
||||
{
|
||||
for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
|
||||
{
|
||||
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
|
||||
for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x)
|
||||
{
|
||||
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[3])
|
||||
{
|
||||
for(int c1 = 0; c1 < wei.mDesc.GetLengths()[4]; ++c1)
|
||||
{
|
||||
v += static_cast<const double>(in(n, c0, hi, wi, c1)) *
|
||||
static_cast<const double>(wei(k, c0, y, x, c1));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
v += bias(k0, k1);
|
||||
out(n, k0, ho, wo, k1) = activ(v, activ_type);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
out.mDesc.GetLengths()[0],
|
||||
out.mDesc.GetLengths()[1],
|
||||
out.mDesc.GetLengths()[2],
|
||||
out.mDesc.GetLengths()[3],
|
||||
out.mDesc.GetLengths()[4])(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
|
||||
#if USE_DYNAMIC_MODE
|
||||
// dynamic mode
|
||||
if(argc != 23)
|
||||
{
|
||||
printf("arg1 to 5: algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
printf("rest: N, K0, K1, C0, C1, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
|
||||
"RightPx\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::LeakyRelu;
|
||||
|
||||
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(std::stoi(argv[1]));
|
||||
const bool do_verification = std::stoi(argv[2]);
|
||||
const int init_method = std::stoi(argv[3]);
|
||||
const bool do_log = std::stoi(argv[4]);
|
||||
const int nrepeat = std::stoi(argv[5]);
|
||||
|
||||
const index_t N = std::stoi(argv[6]);
|
||||
const index_t K0 = std::stoi(argv[7]);
|
||||
const index_t K1 = std::stoi(argv[8]);
|
||||
const index_t C0 = std::stoi(argv[9]);
|
||||
const index_t C1 = std::stoi(argv[10]);
|
||||
const index_t Y = std::stoi(argv[11]);
|
||||
const index_t X = std::stoi(argv[12]);
|
||||
const index_t Hi = std::stoi(argv[13]);
|
||||
const index_t Wi = std::stoi(argv[14]);
|
||||
|
||||
const index_t conv_stride_h = std::stoi(argv[15]);
|
||||
const index_t conv_stride_w = std::stoi(argv[16]);
|
||||
const index_t conv_dilation_h = std::stoi(argv[17]);
|
||||
const index_t conv_dilation_w = std::stoi(argv[18]);
|
||||
const index_t in_left_pad_h = std::stoi(argv[19]);
|
||||
const index_t in_left_pad_w = std::stoi(argv[20]);
|
||||
const index_t in_right_pad_h = std::stoi(argv[21]);
|
||||
const index_t in_right_pad_w = std::stoi(argv[22]);
|
||||
|
||||
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
||||
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
||||
|
||||
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
|
||||
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
#else
|
||||
// static mode
|
||||
if(argc < 6)
|
||||
{
|
||||
printf("arg1 to 5: algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(std::stoi(argv[1]));
|
||||
|
||||
const bool do_verification = std::stoi(argv[2]);
|
||||
const int init_method = std::stoi(argv[3]);
|
||||
const bool do_log = std::stoi(argv[4]);
|
||||
const int nrepeat = std::stoi(argv[5]);
|
||||
|
||||
// constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::Sigmoid;
|
||||
constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::LeakyRelu;
|
||||
|
||||
#if 0
|
||||
constexpr auto N = Number<1>{};
|
||||
constexpr auto Hi = Number<1080>{};
|
||||
constexpr auto Wi = Number<1920>{};
|
||||
constexpr auto Y = Number<3>{};
|
||||
constexpr auto X = Number<3>{};
|
||||
constexpr auto C0 = Number<2>{};
|
||||
constexpr auto C1 = Number<8>{};
|
||||
constexpr auto K0 = Number<1>{};
|
||||
constexpr auto K1 = Number<4>{};
|
||||
#elif 1
|
||||
constexpr auto N = Number<1>{};
|
||||
constexpr auto Hi = Number<1080>{};
|
||||
constexpr auto Wi = Number<1920>{};
|
||||
constexpr auto Y = Number<3>{};
|
||||
constexpr auto X = Number<3>{};
|
||||
constexpr auto C0 = Number<2>{};
|
||||
constexpr auto C1 = Number<8>{};
|
||||
constexpr auto K0 = Number<2>{};
|
||||
constexpr auto K1 = Number<8>{};
|
||||
#elif 0
|
||||
constexpr auto N = Number<1>{};
|
||||
constexpr auto Hi = Number<1080>{};
|
||||
constexpr auto Wi = Number<1920>{};
|
||||
constexpr auto Y = Number<1>{};
|
||||
constexpr auto X = Number<1>{};
|
||||
constexpr auto C0 = Number<2>{};
|
||||
constexpr auto C1 = Number<8>{};
|
||||
constexpr auto K0 = Number<2>{};
|
||||
constexpr auto K1 = Number<8>{};
|
||||
#elif 0
|
||||
constexpr auto N = Number<1>{};
|
||||
constexpr auto Hi = Number<540>{};
|
||||
constexpr auto Wi = Number<960>{};
|
||||
constexpr auto Y = Number<1>{};
|
||||
constexpr auto X = Number<1>{};
|
||||
constexpr auto C0 = Number<2>{};
|
||||
constexpr auto C1 = Number<8>{};
|
||||
constexpr auto K0 = Number<2>{};
|
||||
constexpr auto K1 = Number<8>{};
|
||||
#elif 0
|
||||
constexpr auto N = Number<128>{};
|
||||
constexpr auto Hi = Number<270>{};
|
||||
constexpr auto Wi = Number<480>{};
|
||||
constexpr auto Y = Number<1>{};
|
||||
constexpr auto X = Number<1>{};
|
||||
constexpr auto C0 = Number<2>{};
|
||||
constexpr auto C1 = Number<8>{};
|
||||
constexpr auto K0 = Number<2>{};
|
||||
constexpr auto K1 = Number<8>{};
|
||||
#endif
|
||||
|
||||
constexpr auto conv_stride_h = I1;
|
||||
constexpr auto conv_stride_w = I1;
|
||||
constexpr auto conv_dilation_h = I1;
|
||||
constexpr auto conv_dilation_w = I1;
|
||||
|
||||
#if 1
|
||||
constexpr auto in_left_pad_h = I1;
|
||||
constexpr auto in_left_pad_w = I1;
|
||||
constexpr auto in_right_pad_h = I1;
|
||||
constexpr auto in_right_pad_w = I1;
|
||||
#else
|
||||
constexpr auto in_left_pad_h = I0;
|
||||
constexpr auto in_left_pad_w = I0;
|
||||
constexpr auto in_right_pad_h = I0;
|
||||
constexpr auto in_right_pad_w = I0;
|
||||
#endif
|
||||
|
||||
constexpr auto YEff = (Y - I1) * conv_dilation_h + I1;
|
||||
constexpr auto XEff = (X - I1) * conv_dilation_w + I1;
|
||||
|
||||
constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1;
|
||||
constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1;
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
using in_data_t = float;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
#elif 1
|
||||
using in_data_t = half_t;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = half_t;
|
||||
#elif 1
|
||||
using in_data_t = int8_t;
|
||||
using acc_data_t = int32_t;
|
||||
using out_data_t = int8_t;
|
||||
#endif
|
||||
|
||||
std::vector<std::size_t> in_lengths_host(5), wei_lengths_host(5), out_lengths_host(5),
|
||||
bias_lengths_host(2);
|
||||
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(C0);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(Wi);
|
||||
in_lengths_host[4] = static_cast<std::size_t>(C1);
|
||||
|
||||
wei_lengths_host[0] = static_cast<std::size_t>(K0 * K1);
|
||||
wei_lengths_host[1] = static_cast<std::size_t>(C0);
|
||||
wei_lengths_host[2] = static_cast<std::size_t>(Y);
|
||||
wei_lengths_host[3] = static_cast<std::size_t>(X);
|
||||
wei_lengths_host[4] = static_cast<std::size_t>(C1);
|
||||
|
||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
out_lengths_host[1] = static_cast<std::size_t>(K0);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Ho);
|
||||
out_lengths_host[3] = static_cast<std::size_t>(Wo);
|
||||
out_lengths_host[4] = static_cast<std::size_t>(K1);
|
||||
|
||||
bias_lengths_host[0] = static_cast<std::size_t>(K0);
|
||||
bias_lengths_host[1] = static_cast<std::size_t>(K1);
|
||||
|
||||
Tensor<in_data_t> in(in_lengths_host);
|
||||
Tensor<in_data_t> wei(wei_lengths_host);
|
||||
Tensor<out_data_t> bias(bias_lengths_host);
|
||||
Tensor<out_data_t> out_host(out_lengths_host);
|
||||
Tensor<out_data_t> out_device(out_lengths_host);
|
||||
|
||||
ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: ");
|
||||
ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: ");
|
||||
ostream_HostTensorDescriptor(bias.mDesc, std::cout << "bias: ");
|
||||
ostream_HostTensorDescriptor(out_host.mDesc, std::cout << "out: ");
|
||||
|
||||
print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w));
|
||||
print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w));
|
||||
print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w));
|
||||
print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w));
|
||||
|
||||
std::size_t num_thread = 1;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
// no initialization
|
||||
break;
|
||||
case 1:
|
||||
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
bias.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
case 3:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
bias.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 4:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
case 5:
|
||||
in.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
|
||||
bias.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
|
||||
|
||||
auto gen_wei = [](auto... is) {
|
||||
return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
|
||||
};
|
||||
wei.GenerateTensorValue(gen_wei, num_thread);
|
||||
}
|
||||
|
||||
auto f_make_for_device_nchwc = [&]() {
|
||||
const auto in_lengths_dev = make_tuple(N, C0, Hi, Wi, C1);
|
||||
const auto wei_lengths_dev = make_tuple(K0 * K1, C0, Y, X, C1);
|
||||
const auto out_lengths_dev = make_tuple(N, K0, Ho, Wo, K1);
|
||||
const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
|
||||
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
|
||||
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
|
||||
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
|
||||
|
||||
return make_tuple(in_lengths_dev,
|
||||
wei_lengths_dev,
|
||||
out_lengths_dev,
|
||||
conv_strides_dev,
|
||||
conv_dilations_dev,
|
||||
in_left_pads_dev,
|
||||
in_right_pads_dev);
|
||||
};
|
||||
|
||||
#if USE_CONV_FWD_V5R1_NCHWC
|
||||
if(algo == ConvForwardAlgo::V5R1NCHWC)
|
||||
{
|
||||
const auto tmp = f_make_for_device_nchwc();
|
||||
|
||||
device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t,
|
||||
activ_type>(
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei,
|
||||
bias,
|
||||
out_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_direct_convolution_nchwc(in,
|
||||
wei,
|
||||
bias,
|
||||
out_host,
|
||||
make_tuple(conv_stride_h, conv_stride_w),
|
||||
make_tuple(conv_dilation_h, conv_dilation_w),
|
||||
make_tuple(in_left_pad_h, in_left_pad_w),
|
||||
make_tuple(in_right_pad_h, in_right_pad_w),
|
||||
activ_type);
|
||||
|
||||
ck::utils::check_err(out_device.mData, out_host.mData);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "wei: ", wei.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "bias: ", bias.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "out_device: ", out_device.mData, ",") << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,415 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "debug.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "conv_common.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp"
|
||||
|
||||
#define USE_DYNAMIC_MODE 0
|
||||
#define USE_CONV_FWD_V5R1_NCHWC 1
|
||||
|
||||
enum ConvForwardAlgo
|
||||
{
|
||||
V5R1NCHWC // 0
|
||||
};
|
||||
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void host_direct_convolution_maxpool_nchwc(const Tensor<TIn>& in,
|
||||
const Tensor<TWei>& wei,
|
||||
const Tensor<TOut>& bias,
|
||||
Tensor<TOut>& out_host,
|
||||
Tensor<TOut>& max_host,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads&,
|
||||
const ck::ActivTypeEnum activ_type)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
auto f_nchw = [&](auto n, auto k0, auto ho, auto wo, auto k1) {
|
||||
double v = 0;
|
||||
auto k = k0 * out_host.mDesc.GetLengths()[4] + k1;
|
||||
|
||||
for(int c0 = 0; c0 < wei.mDesc.GetLengths()[1]; ++c0)
|
||||
{
|
||||
for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
|
||||
{
|
||||
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
|
||||
for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x)
|
||||
{
|
||||
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[3])
|
||||
{
|
||||
for(int c1 = 0; c1 < wei.mDesc.GetLengths()[4]; ++c1)
|
||||
{
|
||||
v += static_cast<const double>(in(n, c0, hi, wi, c1)) *
|
||||
static_cast<const double>(wei(k, c0, y, x, c1));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
v += bias(k0, k1);
|
||||
v = activ(v, activ_type);
|
||||
|
||||
out_host(n, k0, ho, wo, k1) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
out_host.mDesc.GetLengths()[0],
|
||||
out_host.mDesc.GetLengths()[1],
|
||||
out_host.mDesc.GetLengths()[2],
|
||||
out_host.mDesc.GetLengths()[3],
|
||||
out_host.mDesc.GetLengths()[4])(std::thread::hardware_concurrency());
|
||||
|
||||
auto maxpool_nchw = [&](auto n, auto k0, auto ho, auto wo, auto k1) {
|
||||
auto hx = ho * 2;
|
||||
auto wx = wo * 2;
|
||||
|
||||
auto v0 = out_host(n, k0, hx, wx, k1);
|
||||
auto v1 = out_host(n, k0, hx, wx + 1, k1);
|
||||
auto v2 = out_host(n, k0, hx + 1, wx, k1);
|
||||
auto v3 = out_host(n, k0, hx + 1, wx + 1, k1);
|
||||
|
||||
max_host(n, k0, ho, wo, k1) = std::max({v0, v1, v2, v3});
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(maxpool_nchw,
|
||||
max_host.mDesc.GetLengths()[0],
|
||||
max_host.mDesc.GetLengths()[1],
|
||||
max_host.mDesc.GetLengths()[2],
|
||||
max_host.mDesc.GetLengths()[3],
|
||||
max_host.mDesc.GetLengths()[4])(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
#if USE_DYNAMIC_MODE
|
||||
// dynamic mode
|
||||
if(argc != 23)
|
||||
{
|
||||
printf("arg1 to 5: algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
printf("rest: N, K0, K1, C0, C1, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
|
||||
"RightPx\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::LeakyRelu;
|
||||
|
||||
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(std::stoi(argv[1]));
|
||||
const bool do_verification = std::stoi(argv[2]);
|
||||
const int init_method = std::stoi(argv[3]);
|
||||
const bool do_log = std::stoi(argv[4]);
|
||||
const int nrepeat = std::stoi(argv[5]);
|
||||
|
||||
const index_t N = std::stoi(argv[6]);
|
||||
const index_t K0 = std::stoi(argv[7]);
|
||||
const index_t K1 = std::stoi(argv[8]);
|
||||
const index_t C0 = std::stoi(argv[9]);
|
||||
const index_t C1 = std::stoi(argv[10]);
|
||||
const index_t Y = std::stoi(argv[11]);
|
||||
const index_t X = std::stoi(argv[12]);
|
||||
const index_t Hi = std::stoi(argv[13]);
|
||||
const index_t Wi = std::stoi(argv[14]);
|
||||
|
||||
const index_t conv_stride_h = std::stoi(argv[15]);
|
||||
const index_t conv_stride_w = std::stoi(argv[16]);
|
||||
const index_t conv_dilation_h = std::stoi(argv[17]);
|
||||
const index_t conv_dilation_w = std::stoi(argv[18]);
|
||||
const index_t in_left_pad_h = std::stoi(argv[19]);
|
||||
const index_t in_left_pad_w = std::stoi(argv[20]);
|
||||
const index_t in_right_pad_h = std::stoi(argv[21]);
|
||||
const index_t in_right_pad_w = std::stoi(argv[22]);
|
||||
|
||||
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
||||
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
||||
|
||||
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
|
||||
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
|
||||
const index_t Ho_2 = Ho / 2;
|
||||
const index_t Wo_2 = Wo / 2;
|
||||
#else
|
||||
// static mode
|
||||
if(argc < 6)
|
||||
{
|
||||
printf("arg1 to 5: algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(std::stoi(argv[1]));
|
||||
|
||||
const bool do_verification = std::stoi(argv[2]);
|
||||
const int init_method = std::stoi(argv[3]);
|
||||
const bool do_log = std::stoi(argv[4]);
|
||||
const int nrepeat = std::stoi(argv[5]);
|
||||
|
||||
constexpr ck::ActivTypeEnum activ_type = ActivTypeEnum::LeakyRelu;
|
||||
|
||||
#if 1
|
||||
constexpr auto N = Number<1>{};
|
||||
constexpr auto Hi = Number<1080>{};
|
||||
constexpr auto Wi = Number<1920>{};
|
||||
constexpr auto Y = Number<3>{};
|
||||
constexpr auto X = Number<3>{};
|
||||
constexpr auto C0 = Number<2>{};
|
||||
constexpr auto C1 = Number<8>{};
|
||||
constexpr auto K0 = Number<2>{};
|
||||
constexpr auto K1 = Number<8>{};
|
||||
#elif 0
|
||||
constexpr auto N = Number<1>{};
|
||||
constexpr auto Hi = Number<1080>{};
|
||||
constexpr auto Wi = Number<1920>{};
|
||||
constexpr auto Y = Number<3>{};
|
||||
constexpr auto X = Number<3>{};
|
||||
constexpr auto C0 = Number<3>{};
|
||||
constexpr auto C1 = Number<4>{};
|
||||
constexpr auto K0 = Number<2>{};
|
||||
constexpr auto K1 = Number<8>{};
|
||||
#elif 0
|
||||
constexpr auto N = Number<1>{};
|
||||
constexpr auto Hi = Number<540>{};
|
||||
constexpr auto Wi = Number<960>{};
|
||||
constexpr auto Y = Number<3>{};
|
||||
constexpr auto X = Number<3>{};
|
||||
constexpr auto C0 = Number<2>{};
|
||||
constexpr auto C1 = Number<8>{};
|
||||
constexpr auto K0 = Number<2>{};
|
||||
constexpr auto K1 = Number<8>{};
|
||||
#elif 0
|
||||
constexpr auto N = Number<128>{};
|
||||
constexpr auto Hi = Number<270>{};
|
||||
constexpr auto Wi = Number<480>{};
|
||||
constexpr auto Y = Number<3>{};
|
||||
constexpr auto X = Number<3>{};
|
||||
constexpr auto C0 = Number<2>{};
|
||||
constexpr auto C1 = Number<8>{};
|
||||
constexpr auto K0 = Number<2>{};
|
||||
constexpr auto K1 = Number<8>{};
|
||||
#endif
|
||||
|
||||
constexpr auto conv_stride_h = I1;
|
||||
constexpr auto conv_stride_w = I1;
|
||||
constexpr auto conv_dilation_h = I1;
|
||||
constexpr auto conv_dilation_w = I1;
|
||||
constexpr auto in_left_pad_h = I1;
|
||||
constexpr auto in_left_pad_w = I1;
|
||||
constexpr auto in_right_pad_h = I1;
|
||||
constexpr auto in_right_pad_w = I1;
|
||||
|
||||
constexpr auto YEff = (Y - I1) * conv_dilation_h + I1;
|
||||
constexpr auto XEff = (X - I1) * conv_dilation_w + I1;
|
||||
|
||||
constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1;
|
||||
constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1;
|
||||
|
||||
constexpr auto Ho_2 = Number<Ho / 2>{};
|
||||
constexpr auto Wo_2 = Number<Wo / 2>{};
|
||||
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
using in_data_t = float;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
#elif 1
|
||||
using in_data_t = half_t;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = half_t;
|
||||
#elif 1
|
||||
using in_data_t = int8_t;
|
||||
using acc_data_t = int32_t;
|
||||
using out_data_t = int8_t;
|
||||
#endif
|
||||
|
||||
std::vector<std::size_t> in_lengths_host(5), wei_lengths_host(5), out_lengths_host(5),
|
||||
max_lengths_host(5), bias_lengths_host(2);
|
||||
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(C0);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(Wi);
|
||||
in_lengths_host[4] = static_cast<std::size_t>(C1);
|
||||
|
||||
wei_lengths_host[0] = static_cast<std::size_t>(K0 * K1);
|
||||
wei_lengths_host[1] = static_cast<std::size_t>(C0);
|
||||
wei_lengths_host[2] = static_cast<std::size_t>(Y);
|
||||
wei_lengths_host[3] = static_cast<std::size_t>(X);
|
||||
wei_lengths_host[4] = static_cast<std::size_t>(C1);
|
||||
|
||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
out_lengths_host[1] = static_cast<std::size_t>(K0);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Ho);
|
||||
out_lengths_host[3] = static_cast<std::size_t>(Wo);
|
||||
out_lengths_host[4] = static_cast<std::size_t>(K1);
|
||||
|
||||
max_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
max_lengths_host[1] = static_cast<std::size_t>(K0);
|
||||
max_lengths_host[2] = static_cast<std::size_t>(Ho_2);
|
||||
max_lengths_host[3] = static_cast<std::size_t>(Wo_2);
|
||||
max_lengths_host[4] = static_cast<std::size_t>(K1);
|
||||
|
||||
bias_lengths_host[0] = static_cast<std::size_t>(K0);
|
||||
bias_lengths_host[1] = static_cast<std::size_t>(K1);
|
||||
|
||||
Tensor<in_data_t> in(in_lengths_host);
|
||||
Tensor<in_data_t> wei(wei_lengths_host);
|
||||
Tensor<out_data_t> bias(bias_lengths_host);
|
||||
Tensor<out_data_t> out_device(out_lengths_host);
|
||||
Tensor<out_data_t> out_host(out_lengths_host);
|
||||
Tensor<in_data_t> max_device(max_lengths_host);
|
||||
Tensor<in_data_t> max_host(max_lengths_host);
|
||||
|
||||
ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: ");
|
||||
ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: ");
|
||||
|
||||
print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w));
|
||||
print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w));
|
||||
print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w));
|
||||
print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w));
|
||||
|
||||
std::size_t num_thread = 1;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
// no initialization
|
||||
break;
|
||||
case 1:
|
||||
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
case 3:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 4:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
case 5:
|
||||
in.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
|
||||
|
||||
auto gen_wei = [](auto... is) {
|
||||
return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
|
||||
};
|
||||
wei.GenerateTensorValue(gen_wei, num_thread);
|
||||
}
|
||||
|
||||
bias.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
|
||||
auto f_make_for_device_nchwc = [&]() {
|
||||
const auto in_lengths_dev = make_tuple(N, C0, Hi, Wi, C1);
|
||||
const auto wei_lengths_dev = make_tuple(K0 * K1, C0, Y, X, C1);
|
||||
const auto max_lengths_dev = make_tuple(N, K0, Ho_2, Wo_2, K1);
|
||||
const auto out_lengths_dev = make_tuple(N, K0, Ho, Wo, K1);
|
||||
const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
|
||||
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
|
||||
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
|
||||
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
|
||||
|
||||
return make_tuple(in_lengths_dev,
|
||||
wei_lengths_dev,
|
||||
max_lengths_dev,
|
||||
out_lengths_dev,
|
||||
conv_strides_dev,
|
||||
conv_dilations_dev,
|
||||
in_left_pads_dev,
|
||||
in_right_pads_dev);
|
||||
};
|
||||
|
||||
#if USE_CONV_FWD_V5R1_NCHWC
|
||||
if(algo == ConvForwardAlgo::V5R1NCHWC)
|
||||
{
|
||||
const auto tmp = f_make_for_device_nchwc();
|
||||
|
||||
device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1<
|
||||
in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t,
|
||||
activ_type>(tmp[I0], // in_lengths_dev
|
||||
tmp[I1], // wei_lengths_dev
|
||||
tmp[I2], // max_lengths_dev
|
||||
tmp[I3], // out_lengths_dev
|
||||
tmp[I4], // conv_strides_dev
|
||||
tmp[I5], // conv_dilations_dev
|
||||
tmp[I6], // in_left_pads_dev
|
||||
tmp[I7], // in_right_pads_dev
|
||||
in,
|
||||
wei,
|
||||
bias,
|
||||
out_device,
|
||||
max_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_direct_convolution_maxpool_nchwc(in,
|
||||
wei,
|
||||
bias,
|
||||
out_host,
|
||||
max_host,
|
||||
make_tuple(conv_stride_h, conv_stride_w),
|
||||
make_tuple(conv_dilation_h, conv_dilation_w),
|
||||
make_tuple(in_left_pad_h, in_left_pad_w),
|
||||
make_tuple(in_right_pad_h, in_right_pad_w),
|
||||
activ_type);
|
||||
|
||||
ck::utils::check_err(out_device.mData, out_host.mData);
|
||||
ck::utils::check_err(max_device.mData, max_host.mData);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
// LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
|
||||
// LogRangeAsType<float>(std::cout << "wei: ", wei.mData, ",") << std::endl;
|
||||
// LogRangeAsType<float>(std::cout << "out_device: ", out_device.mData, ",") <<
|
||||
// std::endl;
|
||||
LogRangeAsType<float>(std::cout << "max_host: ", max_host.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "max_device: ", max_device.mData, ",") << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,532 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "debug.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "conv_common.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
// Memory layouts of the convolution tensors (N = batch, C = channels,
// H/W = spatial extents). A plain (unscoped) enum on purpose: the value is
// parsed from an integer command-line argument via static_cast and streamed
// directly with operator<< later in this file.
enum ConvTensorLayout
{
    NCHW,
    NHWC,
    CHWN,
    NCHWc, // NOTE(review): the trailing 'c' presumably denotes a split /
    NHWCc  // vectorized channel dimension (cf. the C0/C1 split used by the
           // NCHWC examples) — confirm against the device kernel headers.
};
|
||||
|
||||
#define USE_DYNAMIC_MODE 1
|
||||
#define USE_CONV_WRW_V4R4R2_XDL_NCHW 0
|
||||
#define USE_CONV_WRW_V4R4R4_XDL_NHWC 0
|
||||
#define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW 0
|
||||
#define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC 0
|
||||
#define USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC 1
|
||||
|
||||
// Device backward-weight kernel selected via argv[2]. Each enumerator pairs
// with one USE_CONV_WRW_* compile-time switch above, and the ATOMIC variants
// additionally consume the `desired_grid_size` command-line argument.
enum ConvBackwardWeightAlgo
{
    V4R4R2XDLNCHW,       // 0
    V4R4R4XDLNHWC,       // 1
    V4R4R2XDLATOMICNCHW, // 2
    V4R4R4XDLATOMICNHWC, // 3
    V4R4R5XDLATOMICNHWC, // 4
};
|
||||
|
||||
template <typename TOut,
|
||||
typename TIn,
|
||||
typename TWei,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void host_convolution_backward_weight(const Tensor<TOut>& out,
|
||||
const Tensor<TIn>& in,
|
||||
Tensor<TWei>& wei,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads&,
|
||||
const ConvTensorLayout layout = ConvTensorLayout::NCHW)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
|
||||
double v = 0;
|
||||
for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n)
|
||||
{
|
||||
for(int ho = 0; ho < out.mDesc.GetLengths()[2]; ++ho)
|
||||
{
|
||||
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
|
||||
for(int wo = 0; wo < out.mDesc.GetLengths()[3]; ++wo)
|
||||
{
|
||||
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[3])
|
||||
{
|
||||
v += static_cast<const double>(in(n, c, hi, wi)) *
|
||||
static_cast<const double>(out(n, k, ho, wo));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
wei(k, c, y, x) = v;
|
||||
};
|
||||
|
||||
auto f_kyxc = [&](auto k, auto y, auto x, auto c) {
|
||||
double v = 0;
|
||||
for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n)
|
||||
{
|
||||
for(int ho = 0; ho < out.mDesc.GetLengths()[1]; ++ho)
|
||||
{
|
||||
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
|
||||
for(int wo = 0; wo < out.mDesc.GetLengths()[2]; ++wo)
|
||||
{
|
||||
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[2])
|
||||
{
|
||||
v += static_cast<const double>(in(n, hi, wi, c)) *
|
||||
static_cast<const double>(out(n, ho, wo, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
wei(k, y, x, c) = v;
|
||||
};
|
||||
|
||||
if(layout == ConvTensorLayout::NCHW)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_kcyx,
|
||||
wei.mDesc.GetLengths()[0],
|
||||
wei.mDesc.GetLengths()[1],
|
||||
wei.mDesc.GetLengths()[2],
|
||||
wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == ConvTensorLayout::NHWC)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_kyxc,
|
||||
wei.mDesc.GetLengths()[0],
|
||||
wei.mDesc.GetLengths()[1],
|
||||
wei.mDesc.GetLengths()[2],
|
||||
wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! not supported layout");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
|
||||
#if USE_DYNAMIC_MODE
|
||||
// dynamic mode
|
||||
if(argc != 23)
|
||||
{
|
||||
printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n");
|
||||
printf("additional: desired_grid_size\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(std::stoi(argv[1]));
|
||||
const ConvBackwardWeightAlgo algo = static_cast<ConvBackwardWeightAlgo>(std::stoi(argv[2]));
|
||||
const bool do_verification = std::stoi(argv[3]);
|
||||
const int init_method = std::stoi(argv[4]);
|
||||
const bool do_log = std::stoi(argv[5]);
|
||||
const int nrepeat = std::stoi(argv[6]);
|
||||
|
||||
const index_t N = std::stoi(argv[7]);
|
||||
const index_t K = std::stoi(argv[8]);
|
||||
const index_t C = std::stoi(argv[9]);
|
||||
const index_t Y = std::stoi(argv[10]);
|
||||
const index_t X = std::stoi(argv[11]);
|
||||
const index_t Hi = std::stoi(argv[12]);
|
||||
const index_t Wi = std::stoi(argv[13]);
|
||||
|
||||
const index_t conv_stride_h = std::stoi(argv[14]);
|
||||
const index_t conv_stride_w = std::stoi(argv[15]);
|
||||
const index_t conv_dilation_h = std::stoi(argv[16]);
|
||||
const index_t conv_dilation_w = std::stoi(argv[17]);
|
||||
const index_t in_left_pad_h = std::stoi(argv[18]);
|
||||
const index_t in_left_pad_w = std::stoi(argv[19]);
|
||||
const index_t in_right_pad_h = std::stoi(argv[20]);
|
||||
const index_t in_right_pad_w = std::stoi(argv[21]);
|
||||
|
||||
const index_t desired_grid_size = std::stoi(argv[22]);
|
||||
|
||||
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
||||
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
||||
|
||||
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
|
||||
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
#else
|
||||
// static mode
|
||||
if(argc < 7)
|
||||
{
|
||||
printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(std::stoi(argv[1]));
|
||||
const ConvBackwardWeightAlgo algo = static_cast<ConvBackwardWeightAlgo>(std::stoi(argv[2]));
|
||||
const bool do_verification = std::stoi(argv[3]);
|
||||
const int init_method = std::stoi(argv[4]);
|
||||
const bool do_log = std::stoi(argv[5]);
|
||||
const int nrepeat = std::stoi(argv[6]);
|
||||
|
||||
constexpr auto N = Number<128>{};
|
||||
constexpr auto C = Number<128>{};
|
||||
constexpr auto Hi = Number<14>{};
|
||||
constexpr auto Wi = Number<14>{};
|
||||
constexpr auto K = Number<256>{};
|
||||
constexpr auto Y = Number<3>{};
|
||||
constexpr auto X = Number<3>{};
|
||||
|
||||
constexpr auto conv_stride_h = I1;
|
||||
constexpr auto conv_stride_w = I1;
|
||||
constexpr auto conv_dilation_h = I1;
|
||||
constexpr auto conv_dilation_w = I1;
|
||||
constexpr auto in_left_pad_h = I1;
|
||||
constexpr auto in_left_pad_w = I1;
|
||||
constexpr auto in_right_pad_h = I1;
|
||||
constexpr auto in_right_pad_w = I1;
|
||||
|
||||
constexpr auto YEff = (Y - I1) * conv_dilation_h + I1;
|
||||
constexpr auto XEff = (X - I1) * conv_dilation_w + I1;
|
||||
|
||||
constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1;
|
||||
constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1;
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
using in_data_t = float;
|
||||
using wei_data_t = float;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
#elif 1
|
||||
using in_data_t = half_t;
|
||||
using out_data_t = half_t;
|
||||
using acc_data_t = float;
|
||||
using wei_data_t = float;
|
||||
#elif 1
|
||||
using in_data_t = int8_t;
|
||||
using out_data_t = int8_t;
|
||||
using acc_data_t = int32_t;
|
||||
using wei_data_t = int8_t;
|
||||
#endif
|
||||
|
||||
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
|
||||
|
||||
if(layout == ConvTensorLayout::NCHW)
|
||||
{
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(C);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(Wi);
|
||||
wei_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
wei_lengths_host[1] = static_cast<std::size_t>(C);
|
||||
wei_lengths_host[2] = static_cast<std::size_t>(Y);
|
||||
wei_lengths_host[3] = static_cast<std::size_t>(X);
|
||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
out_lengths_host[1] = static_cast<std::size_t>(K);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Ho);
|
||||
out_lengths_host[3] = static_cast<std::size_t>(Wo);
|
||||
}
|
||||
else if(layout == ConvTensorLayout::NHWC)
|
||||
{
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Wi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(C);
|
||||
wei_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
wei_lengths_host[1] = static_cast<std::size_t>(Y);
|
||||
wei_lengths_host[2] = static_cast<std::size_t>(X);
|
||||
wei_lengths_host[3] = static_cast<std::size_t>(C);
|
||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
out_lengths_host[1] = static_cast<std::size_t>(Ho);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Wo);
|
||||
out_lengths_host[3] = static_cast<std::size_t>(K);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::runtime_error("wrong! not implemented");
|
||||
}
|
||||
|
||||
Tensor<in_data_t> in(in_lengths_host);
|
||||
Tensor<wei_data_t> wei_device(wei_lengths_host);
|
||||
Tensor<wei_data_t> wei_host(wei_lengths_host);
|
||||
Tensor<out_data_t> out(out_lengths_host);
|
||||
|
||||
std::cout << "layout: " << layout << std::endl;
|
||||
ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: ");
|
||||
ostream_HostTensorDescriptor(wei_host.mDesc, std::cout << "wei: ");
|
||||
ostream_HostTensorDescriptor(out.mDesc, std::cout << "out: ");
|
||||
print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w));
|
||||
print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w));
|
||||
print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w));
|
||||
print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w));
|
||||
|
||||
std::size_t num_thread = 1;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
// no initialization
|
||||
break;
|
||||
case 1:
|
||||
in.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_1<out_data_t>{}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
in.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_2<out_data_t>{-5, 5}, num_thread);
|
||||
break;
|
||||
case 3:
|
||||
in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_1<out_data_t>{}, num_thread);
|
||||
break;
|
||||
case 4:
|
||||
in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_2<out_data_t>{-5, 5}, num_thread);
|
||||
break;
|
||||
case 5:
|
||||
in.GenerateTensorValue(GeneratorTensor_3<in_data_t>{-0.1, 0.1}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_3<out_data_t>{-0.1, 0.1}, num_thread);
|
||||
break;
|
||||
default:
|
||||
in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{1, 5}, num_thread);
|
||||
|
||||
auto gen_out = [](auto... is) {
|
||||
return GeneratorTensor_2<out_data_t>{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
|
||||
};
|
||||
out.GenerateTensorValue(gen_out, num_thread);
|
||||
}
|
||||
|
||||
auto f_make_for_device_nchw = [&]() {
|
||||
const auto in_lengths_dev = make_tuple(N, C, Hi, Wi);
|
||||
const auto wei_lengths_dev = make_tuple(K, C, Y, X);
|
||||
const auto out_lengths_dev = make_tuple(N, K, Ho, Wo);
|
||||
const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
|
||||
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
|
||||
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
|
||||
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
|
||||
|
||||
return make_tuple(in_lengths_dev,
|
||||
wei_lengths_dev,
|
||||
out_lengths_dev,
|
||||
conv_strides_dev,
|
||||
conv_dilations_dev,
|
||||
in_left_pads_dev,
|
||||
in_right_pads_dev);
|
||||
};
|
||||
|
||||
auto f_make_for_device_nhwc = [&]() {
|
||||
const auto in_lengths_dev = make_tuple(N, Hi, Wi, C);
|
||||
const auto wei_lengths_dev = make_tuple(K, Y, X, C);
|
||||
const auto out_lengths_dev = make_tuple(N, Ho, Wo, K);
|
||||
const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
|
||||
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
|
||||
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
|
||||
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
|
||||
|
||||
return make_tuple(in_lengths_dev,
|
||||
wei_lengths_dev,
|
||||
out_lengths_dev,
|
||||
conv_strides_dev,
|
||||
conv_dilations_dev,
|
||||
in_left_pads_dev,
|
||||
in_right_pads_dev);
|
||||
};
|
||||
|
||||
// set zero to wei_device
|
||||
wei_device.GenerateTensorValue(GeneratorTensor_0{}, num_thread);
|
||||
#if USE_CONV_WRW_V4R4R2_XDL_NCHW
|
||||
if(algo == ConvBackwardWeightAlgo::V4R4R2XDLNCHW)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NCHW)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nchw();
|
||||
|
||||
device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw<in_data_t,
|
||||
wei_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei_device,
|
||||
out,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_WRW_V4R4R4_XDL_NHWC
|
||||
if(algo == ConvBackwardWeightAlgo::V4R4R4XDLNHWC)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NHWC)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nhwc();
|
||||
|
||||
device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk<in_data_t,
|
||||
wei_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei_device,
|
||||
out,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW
|
||||
if(algo == ConvBackwardWeightAlgo::V4R4R2XDLATOMICNCHW)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NCHW)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nchw();
|
||||
|
||||
device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw<
|
||||
in_data_t,
|
||||
wei_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei_device,
|
||||
out,
|
||||
desired_grid_size,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC
|
||||
if(algo == ConvBackwardWeightAlgo::V4R4R4XDLATOMICNHWC)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NHWC)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nhwc();
|
||||
|
||||
device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk<
|
||||
in_data_t,
|
||||
wei_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei_device,
|
||||
out,
|
||||
desired_grid_size,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC
|
||||
if(algo == ConvBackwardWeightAlgo::V4R4R5XDLATOMICNHWC)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NHWC)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nhwc();
|
||||
|
||||
device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk<
|
||||
in_data_t,
|
||||
wei_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei_device,
|
||||
out,
|
||||
desired_grid_size,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_convolution_backward_weight(out,
|
||||
in,
|
||||
wei_host,
|
||||
make_tuple(conv_stride_h, conv_stride_w),
|
||||
make_tuple(conv_dilation_h, conv_dilation_w),
|
||||
make_tuple(in_left_pad_h, in_left_pad_w),
|
||||
make_tuple(in_right_pad_h, in_right_pad_w),
|
||||
layout);
|
||||
|
||||
ck::utils::check_err(wei_device.mData, wei_host.mData);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "out: ", out.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "wei_device: ", wei_device.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "wei_host : ", wei_host.mData, ",") << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,456 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "debug.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "host_gemm.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_gemm_xdlops_mk_kn_mn.hpp"
|
||||
#include "device_gemm_xdlops_mk_nk_mn.hpp"
|
||||
#include "device_gemm_xdlops_km_kn_mn.hpp"
|
||||
#include "device_gemm_xdlops_km_nk_mn.hpp"
|
||||
#include "device_gemm_xdlops_mk_kn_nm.hpp"
|
||||
#include "device_gemm_xdlops_mk_nk_nm.hpp"
|
||||
#include "device_gemm_xdlops_km_kn_nm.hpp"
|
||||
#include "device_gemm_xdlops_km_nk_nm.hpp"
|
||||
|
||||
#define USE_GEMM_XDL_MK_KN_MN 1
|
||||
#define USE_GEMM_XDL_MK_NK_MN 1
|
||||
#define USE_GEMM_XDL_KM_KN_MN 1
|
||||
#define USE_GEMM_XDL_KM_NK_MN 1
|
||||
#define USE_GEMM_XDL_MK_KN_NM 0
|
||||
#define USE_GEMM_XDL_MK_NK_NM 0
|
||||
#define USE_GEMM_XDL_KM_KN_NM 0
|
||||
#define USE_GEMM_XDL_KM_NK_NM 0
|
||||
|
||||
// Storage-layout triple "<A>_<B>_<C>", parsed from argv[1]:
//   MK / KM : A stored as M x K / transposed as K x M
//   KN / NK : B stored as K x N / transposed as N x K
//   MN / NM : C stored as M x N / transposed as N x M
enum struct GemmMatrixLayout
{
    MK_KN_MN, // 0
    MK_NK_MN, // 1
    KM_KN_MN, // 2
    KM_NK_MN, // 3
    MK_KN_NM, // 4
    MK_NK_NM, // 5
    KM_KN_NM, // 6
    KM_NK_NM  // 7
};
|
||||
|
||||
// Device GEMM kernel selected via argv[2]. Each enumerator pairs with one
// USE_GEMM_XDL_* compile-time switch and must match the corresponding
// GemmMatrixLayout value (the dispatch below throws on a mismatch).
enum struct GemmAlgo
{
    Xdl_MK_KN_MN, // 0
    Xdl_MK_NK_MN, // 1
    Xdl_KM_KN_MN, // 2
    Xdl_KM_NK_MN, // 3
    Xdl_MK_KN_NM, // 4
    Xdl_MK_NK_NM, // 5
    Xdl_KM_KN_NM, // 6
    Xdl_KM_NK_NM, // 7
};
|
||||
|
||||
template <typename AType, typename BType, typename CType>
|
||||
void host_gemm(const Tensor<AType>& a,
|
||||
const Tensor<BType>& b,
|
||||
Tensor<CType>& c,
|
||||
const GemmMatrixLayout layout)
|
||||
{
|
||||
if(layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
const int K = a.mDesc.GetLengths()[1];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(m, k)) * static_cast<const double>(b(k, n));
|
||||
}
|
||||
|
||||
c(m, n) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
auto f_mk_nk_mn = [&](auto m, auto n) {
|
||||
const int K = a.mDesc.GetLengths()[1];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(m, k)) * static_cast<const double>(b(n, k));
|
||||
}
|
||||
|
||||
c(m, n) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
auto f_km_kn_mn = [&](auto m, auto n) {
|
||||
const int K = a.mDesc.GetLengths()[0];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(k, m)) * static_cast<const double>(b(k, n));
|
||||
}
|
||||
|
||||
c(m, n) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
auto f_km_nk_mn = [&](auto m, auto n) {
|
||||
const int K = a.mDesc.GetLengths()[0];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(k, m)) * static_cast<const double>(b(n, k));
|
||||
}
|
||||
|
||||
c(m, n) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::MK_KN_NM)
|
||||
{
|
||||
auto f_mk_kn_nm = [&](auto n, auto m) {
|
||||
const int K = a.mDesc.GetLengths()[1];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(m, k)) * static_cast<const double>(b(k, n));
|
||||
}
|
||||
|
||||
c(n, m) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::MK_NK_NM)
|
||||
{
|
||||
auto f_mk_nk_nm = [&](auto n, auto m) {
|
||||
const int K = a.mDesc.GetLengths()[1];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(m, k)) * static_cast<const double>(b(n, k));
|
||||
}
|
||||
|
||||
c(n, m) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_KN_NM)
|
||||
{
|
||||
auto f_km_kn_nm = [&](auto n, auto m) {
|
||||
const int K = a.mDesc.GetLengths()[0];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(k, m)) * static_cast<const double>(b(k, n));
|
||||
}
|
||||
|
||||
c(n, m) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_NK_NM)
|
||||
{
|
||||
auto f_km_nk_nm = [&](auto n, auto m) {
|
||||
const int K = a.mDesc.GetLengths()[0];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(k, m)) * static_cast<const double>(b(n, k));
|
||||
}
|
||||
|
||||
c(n, m) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! not supported layout");
|
||||
}
|
||||
}
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
if(argc != 12)
|
||||
{
|
||||
printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
printf("rest: M, N, K\n");
|
||||
printf("debug_driver_gemm_xdlops_v2r3::M01, debug_driver_gemm_xdlops_v2r3::N01\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
|
||||
const auto algo = static_cast<GemmAlgo>(std::stoi(argv[2]));
|
||||
const bool do_verification = std::stoi(argv[3]);
|
||||
const int init_method = std::stoi(argv[4]);
|
||||
const bool do_log = std::stoi(argv[5]);
|
||||
const int nrepeat = std::stoi(argv[6]);
|
||||
|
||||
const index_t M = std::stoi(argv[7]);
|
||||
const index_t N = std::stoi(argv[8]);
|
||||
const index_t K = std::stoi(argv[9]);
|
||||
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01 = std::stoi(argv[10]);
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01 = std::stoi(argv[11]);
|
||||
|
||||
#if 0
|
||||
using ab_data_t = float;
|
||||
using acc_data_t = float;
|
||||
using c_data_t = float;
|
||||
#elif 1
|
||||
using ab_data_t = half_t;
|
||||
using acc_data_t = float;
|
||||
using c_data_t = half_t;
|
||||
#elif 1
|
||||
using ab_data_t = int8_t;
|
||||
using acc_data_t = int32_t;
|
||||
using c_data_t = int8_t;
|
||||
#endif
|
||||
|
||||
std::vector<std::size_t> a_lengths_host(2), b_lengths_host(2), c_lengths_host(2);
|
||||
std::vector<std::size_t> a_strides_host(2), b_strides_host(2), c_strides_host(2);
|
||||
|
||||
// A
|
||||
if(layout == GemmMatrixLayout::MK_KN_MN || layout == GemmMatrixLayout::MK_NK_MN ||
|
||||
layout == GemmMatrixLayout::MK_KN_NM || layout == GemmMatrixLayout::MK_NK_NM)
|
||||
{
|
||||
a_lengths_host[0] = static_cast<std::size_t>(M);
|
||||
a_lengths_host[1] = static_cast<std::size_t>(K);
|
||||
a_strides_host[0] = static_cast<std::size_t>(K);
|
||||
a_strides_host[1] = static_cast<std::size_t>(1);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
a_lengths_host[1] = static_cast<std::size_t>(M);
|
||||
a_strides_host[0] = static_cast<std::size_t>(M);
|
||||
a_strides_host[1] = static_cast<std::size_t>(1);
|
||||
}
|
||||
|
||||
// B
|
||||
if(layout == GemmMatrixLayout::MK_NK_MN || layout == GemmMatrixLayout::KM_NK_MN ||
|
||||
layout == GemmMatrixLayout::MK_NK_NM || layout == GemmMatrixLayout::KM_NK_NM)
|
||||
{
|
||||
b_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
b_lengths_host[1] = static_cast<std::size_t>(K);
|
||||
b_strides_host[0] = static_cast<std::size_t>(K);
|
||||
b_strides_host[1] = static_cast<std::size_t>(1);
|
||||
}
|
||||
else
|
||||
{
|
||||
b_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
b_lengths_host[1] = static_cast<std::size_t>(N);
|
||||
b_strides_host[0] = static_cast<std::size_t>(N);
|
||||
b_strides_host[1] = static_cast<std::size_t>(1);
|
||||
}
|
||||
|
||||
// C
|
||||
if(layout == GemmMatrixLayout::MK_KN_MN || layout == GemmMatrixLayout::KM_KN_MN ||
|
||||
layout == GemmMatrixLayout::MK_NK_MN || layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
c_lengths_host[0] = static_cast<std::size_t>(M);
|
||||
c_lengths_host[1] = static_cast<std::size_t>(N);
|
||||
c_strides_host[0] = static_cast<std::size_t>(N);
|
||||
c_strides_host[1] = static_cast<std::size_t>(1);
|
||||
}
|
||||
else
|
||||
{
|
||||
c_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
c_lengths_host[1] = static_cast<std::size_t>(M);
|
||||
c_strides_host[0] = static_cast<std::size_t>(M);
|
||||
c_strides_host[1] = static_cast<std::size_t>(1);
|
||||
}
|
||||
|
||||
Tensor<ab_data_t> a(a_lengths_host, a_strides_host);
|
||||
Tensor<ab_data_t> b(b_lengths_host, b_strides_host);
|
||||
Tensor<c_data_t> c_host(c_lengths_host, c_strides_host);
|
||||
Tensor<c_data_t> c_device(c_lengths_host, c_strides_host);
|
||||
|
||||
std::cout << "layout: " << layout << std::endl;
|
||||
ostream_HostTensorDescriptor(a.mDesc, std::cout << "a: ");
|
||||
ostream_HostTensorDescriptor(b.mDesc, std::cout << "b: ");
|
||||
ostream_HostTensorDescriptor(c_host.mDesc, std::cout << "c: ");
|
||||
|
||||
std::size_t num_thread = 1;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
// no initialization
|
||||
break;
|
||||
case 1:
|
||||
a.GenerateTensorValue(GeneratorTensor_1<ab_data_t>{}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_1<ab_data_t>{}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
a.GenerateTensorValue(GeneratorTensor_1<ab_data_t>{}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_2<ab_data_t>{-5, 5}, num_thread);
|
||||
break;
|
||||
case 3:
|
||||
a.GenerateTensorValue(GeneratorTensor_2<ab_data_t>{-5, 5}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_1<ab_data_t>{}, num_thread);
|
||||
break;
|
||||
case 4:
|
||||
a.GenerateTensorValue(GeneratorTensor_2<ab_data_t>{-5, 5}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_2<ab_data_t>{-5, 5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
a.GenerateTensorValue(GeneratorTensor_3<ab_data_t>{0.0, 1.0}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_3<ab_data_t>{-0.5, 0.5}, num_thread);
|
||||
}
|
||||
|
||||
#if USE_GEMM_XDL_MK_KN_MN
|
||||
if(algo == GemmAlgo::Xdl_MK_KN_MN)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
device_gemm_xdlops_mk_kn_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_GEMM_XDL_MK_NK_MN
|
||||
if(algo == GemmAlgo::Xdl_MK_NK_MN)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
device_gemm_xdlops_mk_nk_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_GEMM_XDL_KM_KN_MN
|
||||
if(algo == GemmAlgo::Xdl_KM_KN_MN)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
device_gemm_xdlops_km_kn_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_GEMM_XDL_KM_NK_MN
|
||||
if(algo == GemmAlgo::Xdl_KM_NK_MN)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
device_gemm_xdlops_km_nk_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_GEMM_XDL_MK_KN_NM
|
||||
if(algo == GemmAlgo::Xdl_MK_KN_NM)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::MK_KN_NM)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
device_gemm_xdlops_mk_kn_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_GEMM_XDL_MK_NK_NM
|
||||
if(algo == GemmAlgo::Xdl_MK_NK_NM)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::MK_NK_NM)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
device_gemm_xdlops_mk_nk_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_GEMM_XDL_KM_KN_NM
|
||||
if(algo == GemmAlgo::Xdl_KM_KN_NM)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::KM_KN_NM)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
device_gemm_xdlops_km_kn_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_GEMM_XDL_KM_NK_NM
|
||||
if(algo == GemmAlgo::Xdl_KM_NK_NM)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::KM_NK_NM)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
device_gemm_xdlops_km_nk_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_gemm(a, b, c_host, layout);
|
||||
|
||||
ck::utils::check_err(c_device.mData, c_host.mData);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "a : ", a.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "b: ", b.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "c_host : ", c_host.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "c_device: ", c_device.mData, ",") << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,23 +1,3 @@
|
||||
include_directories(BEFORE
|
||||
${PROJECT_SOURCE_DIR}/include/ck
|
||||
${PROJECT_SOURCE_DIR}/include/ck/utility
|
||||
${PROJECT_SOURCE_DIR}/include/ck/host_utility
|
||||
${PROJECT_SOURCE_DIR}/include/ck/tensor_description
|
||||
${PROJECT_SOURCE_DIR}/include/ck/tensor
|
||||
${PROJECT_SOURCE_DIR}/include/ck/problem_transform
|
||||
${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/device
|
||||
${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/grid
|
||||
${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/block
|
||||
${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/warp
|
||||
${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/thread
|
||||
${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/element
|
||||
${PROJECT_SOURCE_DIR}/library/include/ck/library/host_tensor
|
||||
${PROJECT_SOURCE_DIR}/library/include/ck/library/host
|
||||
${PROJECT_SOURCE_DIR}/library/include/ck/library/tensor_operation_instance
|
||||
${PROJECT_SOURCE_DIR}/library/include/ck/library/tensor_operation_instance/gpu/reduce
|
||||
${PROJECT_SOURCE_DIR}/external/include/half
|
||||
)
|
||||
|
||||
function(add_instance_library INSTANCE_NAME)
|
||||
message("adding instance ${INSTANCE_NAME}")
|
||||
add_library(${INSTANCE_NAME} OBJECT ${ARGN})
|
||||
@@ -37,7 +17,6 @@ add_subdirectory(conv2d_fwd)
|
||||
add_subdirectory(conv3d_fwd)
|
||||
add_subdirectory(conv2d_fwd_bias_relu)
|
||||
add_subdirectory(conv2d_fwd_bias_relu_add)
|
||||
add_subdirectory(conv2d_fwd_bias_relu_atomic_add)
|
||||
add_subdirectory(conv2d_bwd_data)
|
||||
add_subdirectory(reduce)
|
||||
add_subdirectory(convnd_bwd_data)
|
||||
@@ -53,7 +32,6 @@ add_library(device_operations STATIC
|
||||
$<TARGET_OBJECTS:device_conv2d_fwd_instance>
|
||||
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance>
|
||||
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_add_instance>
|
||||
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_atomic_add_instance>
|
||||
$<TARGET_OBJECTS:device_gemm_instance>
|
||||
$<TARGET_OBJECTS:device_gemm_bias_relu_instance>
|
||||
$<TARGET_OBJECTS:device_gemm_bias_relu_add_instance>
|
||||
@@ -65,7 +43,6 @@ add_library(device_operations STATIC
|
||||
$<TARGET_OBJECTS:device_batched_gemm_reduce_instance>
|
||||
$<TARGET_OBJECTS:device_conv3d_fwd_instance>
|
||||
$<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
|
||||
device_conv2d.cpp
|
||||
)
|
||||
add_library(composablekernels::device_operations ALIAS device_operations)
|
||||
|
||||
@@ -73,8 +50,8 @@ add_library(composablekernels::device_operations ALIAS device_operations)
|
||||
set(DEV_OPS_INC_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/ck/
|
||||
${PROJECT_SOURCE_DIR}/library/include/ck/
|
||||
${PROJECT_SOURCE_DIR}/external/include/
|
||||
)
|
||||
|
||||
target_compile_features(device_operations PUBLIC)
|
||||
set_target_properties(device_operations PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
target_include_directories(device_operations PUBLIC
|
||||
@@ -93,7 +70,6 @@ target_include_directories(device_operations PUBLIC
|
||||
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/host>
|
||||
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance>
|
||||
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu/reduce>
|
||||
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/half>
|
||||
)
|
||||
|
||||
#once new arches are enabled make this an option on the main cmake file
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_batched_gemm_xdl.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_batched_gemm_xdl.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user