mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
example for convnd bwd weight bf16 splitk (#265)
* add GetWorkSpaceSize to base arg and make an example on convnd_bwd_weight
* add bwd weight for bf16: init
* remove redundant compute
* use datatype and split k to check whether a workspace is used
* remove unused computation for work space size
* add some code for bfp16
* add device/grid unary op
* add unary type convert to bwd-weight example
* support bf16 splitk kernel for convnd bwd weight
* 1. remove comments. 2. add checkvalidity. 3. add gridsize computation
* add workspace size check
* fix format
* change function name
This commit is contained in:
@@ -1,2 +1,4 @@
|
|||||||
add_example_executable(example_convnd_bwd_weight_xdl convnd_bwd_weight_xdl.cpp)
|
add_example_executable(example_convnd_bwd_weight_xdl convnd_bwd_weight_xdl.cpp)
|
||||||
target_link_libraries(example_convnd_bwd_weight_xdl PRIVATE conv_util)
|
add_example_executable(example_convnd_bwd_weight_xdl_bf16_splitk convnd_bwd_weight_xdl_bf16_splitk.cpp)
|
||||||
|
target_link_libraries(example_convnd_bwd_weight_xdl PRIVATE conv_util)
|
||||||
|
target_link_libraries(example_convnd_bwd_weight_xdl_bf16_splitk PRIVATE conv_util)
|
||||||
@@ -297,52 +297,15 @@ int main(int argc, char* argv[])
|
|||||||
split_k);
|
split_k);
|
||||||
|
|
||||||
// alloc work space
|
// alloc work space
|
||||||
size_t bwd_weight_workspace_size = conv->GetWorkSpaceSize(argument.get());
|
float ave_time = 0.f;
|
||||||
float ave_time = 0.f;
|
if(!conv->IsSupportedArgument(argument.get()))
|
||||||
if(std::is_same<InDataType, ck::bhalf_t>::value && split_k > 1)
|
|
||||||
{
|
{
|
||||||
DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size);
|
std::cout << "wrong! device_conv with the specified compilation parameters does "
|
||||||
wei_work_space_device_buf.SetZero();
|
"not support this Conv problem"
|
||||||
argument = conv->MakeArgumentPointer(
|
<< std::endl;
|
||||||
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
return 1;
|
||||||
static_cast<AccDataType*>(wei_work_space_device_buf.GetDeviceBuffer()),
|
|
||||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
|
||||||
params.N_,
|
|
||||||
params.K_,
|
|
||||||
params.C_,
|
|
||||||
params.input_spatial_lengths_,
|
|
||||||
params.filter_spatial_lengths_,
|
|
||||||
output_spatial_lengths,
|
|
||||||
params.conv_filter_strides_,
|
|
||||||
params.conv_filter_dilations_,
|
|
||||||
params.input_left_pads_,
|
|
||||||
params.input_right_pads_,
|
|
||||||
InElementOp{},
|
|
||||||
WeiElementOp{},
|
|
||||||
OutElementOp{},
|
|
||||||
split_k);
|
|
||||||
|
|
||||||
if(!conv->IsSupportedArgument(argument.get()))
|
|
||||||
{
|
|
||||||
std::cout << "wrong! device_conv with the specified compilation parameters does "
|
|
||||||
"not support this Conv problem"
|
|
||||||
<< std::endl;
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
if(!conv->IsSupportedArgument(argument.get()))
|
|
||||||
{
|
|
||||||
std::cout << "wrong! device_conv with the specified compilation parameters does "
|
|
||||||
"not support this Conv problem"
|
|
||||||
<< std::endl;
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
|
||||||
}
|
}
|
||||||
|
ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||||
|
|
||||||
std::size_t flop = ck::utils::conv::get_flops(
|
std::size_t flop = ck::utils::conv::get_flops(
|
||||||
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
|
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
|
||||||
|
|||||||
@@ -0,0 +1,427 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include <numeric>
|
||||||
|
#include <initializer_list>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <half.hpp>
|
||||||
|
|
||||||
|
#include "check_err.hpp"
|
||||||
|
#include "conv_util.hpp"
|
||||||
|
#include "config.hpp"
|
||||||
|
#include "print.hpp"
|
||||||
|
#include "device.hpp"
|
||||||
|
#include "host_tensor.hpp"
|
||||||
|
#include "host_tensor_generator.hpp"
|
||||||
|
#include "device_tensor.hpp"
|
||||||
|
#include "tensor_layout.hpp"
|
||||||
|
#include "element_wise_operation.hpp"
|
||||||
|
#include "device_unary_elementwise.hpp"
|
||||||
|
#include "device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp"
|
||||||
|
#include "reference_conv_backward_weight.hpp"
|
||||||
|
|
||||||
|
using InDataType = ck::bhalf_t;
|
||||||
|
using WeiDataType = ck::bhalf_t;
|
||||||
|
using OutDataType = ck::bhalf_t;
|
||||||
|
using AccDataType = float;
|
||||||
|
|
||||||
|
template <ck::index_t... Is>
|
||||||
|
using S = ck::Sequence<Is...>;
|
||||||
|
|
||||||
|
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||||
|
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||||
|
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||||
|
|
||||||
|
using UnaryTypeConvert = ck::tensor_operation::element_wise::UnaryTypeConvert<ck::bhalf_t, float>;
|
||||||
|
|
||||||
|
using DeviceUnaryElementwiseTypeConvertInstance = ck::tensor_operation::device::
|
||||||
|
DeviceUnaryElementwise<AccDataType, WeiDataType, UnaryTypeConvert, 1, 4>;
|
||||||
|
|
||||||
|
static constexpr auto ConvBwdWeightDefault =
|
||||||
|
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||||
|
|
||||||
|
using DeviceConvBwdWeightBasePtr =
|
||||||
|
ck::tensor_operation::device::DeviceConvBwdWeightPtr<InElementOp, WeiElementOp, OutElementOp>;
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
template <ck::index_t NumDimSpatial>
|
||||||
|
using DeviceConvndBwdWeightInstance_bf16_splitk = ck::tensor_operation::device::
|
||||||
|
DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
|
||||||
|
InDataType, // InDataType
|
||||||
|
AccDataType, // WeiDataType
|
||||||
|
OutDataType, // OutDataType
|
||||||
|
AccDataType, // AccDataType
|
||||||
|
InElementOp, // InElementwiseOperation
|
||||||
|
WeiElementOp, // WeiElementwiseOperation
|
||||||
|
OutElementOp, // OutElementwiseOperation
|
||||||
|
ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
|
||||||
|
NumDimSpatial, // NumDimSpatial
|
||||||
|
256, // BlockSize
|
||||||
|
128, // MPerBlock
|
||||||
|
128, // NPerBlock
|
||||||
|
4, // K0PerBlock
|
||||||
|
8, // K1
|
||||||
|
32, // MPerXdl
|
||||||
|
32, // NPerXdl
|
||||||
|
2, // MXdlPerWave
|
||||||
|
2, // NXdlPerWave
|
||||||
|
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||||
|
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||||
|
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
|
||||||
|
2, // ABlockTransferSrcVectorDim
|
||||||
|
8, // ABlockTransferSrcScalarPerVector
|
||||||
|
2, // ABlockTransferDstScalarPerVector_K1
|
||||||
|
true, // ABlockLdsAddExtraM
|
||||||
|
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||||
|
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||||
|
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
|
||||||
|
2, // BBlockTransferSrcVectorDim
|
||||||
|
8, // BBlockTransferSrcScalarPerVector
|
||||||
|
2, // BBlockTransferDstScalarPerVector_K1
|
||||||
|
true, // BBlockLdsAddExtraN
|
||||||
|
1, // CShuffleMXdlPerWavePerShuffle
|
||||||
|
1, // CShuffleNXdlPerWavePerShuffle
|
||||||
|
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||||
|
4>; // CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
template <ck::index_t NumDimSpatial>
|
||||||
|
using ReferenceConvBwdWeightInstance =
|
||||||
|
ck::tensor_operation::host::ReferenceConvBwdWeight<InDataType,
|
||||||
|
WeiDataType,
|
||||||
|
OutDataType,
|
||||||
|
InElementOp,
|
||||||
|
WeiElementOp,
|
||||||
|
OutElementOp,
|
||||||
|
NumDimSpatial>;
|
||||||
|
|
||||||
|
/// Applies `functor` to each of the first prod(shape) elements of `A.mData`
/// and writes the results to the same positions of `B.mData`.
///
/// @param B       destination tensor; only its flat `mData` storage is written.
/// @param A       source tensor; only its flat `mData` storage is read.
/// @param shape   extents whose product gives the number of elements processed.
/// @param functor unary callable invoked as `functor(A.mData[n])`.
///
/// Both `mData` containers must hold at least prod(shape) elements; no bounds
/// checking is performed here.
template <typename HostTensorB, typename HostTensorA, typename Functor>
void host_elementwise(HostTensorB& B,
                      const HostTensorA& A,
                      const std::vector<std::size_t>& shape,
                      Functor functor)
{
    // Seed and multiply in std::size_t. With an `int` seed and
    // std::multiplies<int>, every extent is narrowed to int and the running
    // product can overflow for large tensors.
    const std::size_t tensor_size = std::accumulate(
        shape.begin(), shape.end(), std::size_t{1}, std::multiplies<std::size_t>{});

    // The former debug print of `A.mData[0]` was dropped: it read element 0
    // even when the tensor was empty.
    for(std::size_t n = 0; n < tensor_size; ++n)
    {
        B.mData[n] = functor(A.mData[n]);
    }
}
|
||||||
|
|
||||||
|
/// Prints the positional command-line usage of this example to std::cout.
///
/// Called by main() whenever the argument count is inconsistent or split-k is
/// not larger than 1.
void print_use_msg()
{
    std::cout << "arg1: verification (0=no, 1=yes)\n"
              << "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n"
              // Fixed typo in the option legend: was "0=n0".
              << "arg3: time kernel (0=no, 1=yes)\n"
              << "arg4: is show log (0=no, 1=yes)\n"
              << "arg5: split-k : in this example split-k must be larger than 1\n"
              << "arg6: N spatial dimensions (default 2)\n"
              << "Following arguments (depending on number of spatial dims):\n"
              << " N, K, C, \n"
              << " <filter spatial dimensions>, (ie Y, X for 2D)\n"
              << " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
              << " <strides>, (ie Sy, Sx for 2D)\n"
              << " <dilations>, (ie Dy, Dx for 2D)\n"
              << " <left padding>, (ie LeftPy, LeftPx for 2D)\n"
              << " <right padding>, (ie RightPy, RightPx for 2D)\n"
              << std::endl;
}
|
||||||
|
|
||||||
|
/// Builds a ConvParams from the positional command-line arguments.
///
/// Reading starts at argv[7] and consumes, in order: N, K, C, then
/// `num_dim_spatial` integers each for filter lengths, input lengths, strides,
/// dilations, left pads and right pads (3 + 6 * num_dim_spatial values).
///
/// No bounds check is done on argv here; the caller must have validated the
/// argument count beforehand.
///
/// @param num_dim_spatial number of spatial dimensions to parse per group.
/// @param argv            raw argument vector as received by main().
/// @return the populated convolution parameters.
ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[])
{
    ck::utils::conv::ConvParams conv_params;
    int cursor = 7;

    // Consumes the next argument as an integer.
    const auto next_int = [&]() { return std::stoi(argv[cursor++]); };

    // Sizes one per-dimension vector and fills it from consecutive arguments.
    const auto fill_spatial = [&](auto& values) {
        values.resize(num_dim_spatial);
        for(auto& value : values)
        {
            value = next_int();
        }
    };

    conv_params.num_dim_spatial_ = num_dim_spatial;
    conv_params.N_               = next_int();
    conv_params.K_               = next_int();
    conv_params.C_               = next_int();

    fill_spatial(conv_params.filter_spatial_lengths_);
    fill_spatial(conv_params.input_spatial_lengths_);
    fill_spatial(conv_params.conv_filter_strides_);
    fill_spatial(conv_params.conv_filter_dilations_);
    fill_spatial(conv_params.input_left_pads_);
    fill_spatial(conv_params.input_right_pads_);

    return conv_params;
}
|
||||||
|
|
||||||
|
/// Creates the bf16 split-k backward-weight device operation for the requested
/// number of spatial dimensions.
///
/// @param num_dim_spatial 1, 2 or 3.
/// @return owning pointer to the matching device-op instance.
/// @throws std::runtime_error for any other dimension count.
DeviceConvBwdWeightBasePtr get_conv_instance(int num_dim_spatial)
{
    if(num_dim_spatial == 1)
    {
        return std::make_unique<DeviceConvndBwdWeightInstance_bf16_splitk<1>>();
    }
    if(num_dim_spatial == 2)
    {
        return std::make_unique<DeviceConvndBwdWeightInstance_bf16_splitk<2>>();
    }
    if(num_dim_spatial == 3)
    {
        return std::make_unique<DeviceConvndBwdWeightInstance_bf16_splitk<3>>();
    }

    throw std::runtime_error("Unsupported number of spatial dimensions provided!");
}
|
||||||
|
|
||||||
|
int main(int argc, char* argv[])
|
||||||
|
{
|
||||||
|
bool do_verification = true;
|
||||||
|
int init_method = 1;
|
||||||
|
bool time_kernel = false;
|
||||||
|
int num_dim_spatial = 2;
|
||||||
|
int do_log = 0;
|
||||||
|
int split_k = 2;
|
||||||
|
|
||||||
|
ck::utils::conv::ConvParams params;
|
||||||
|
params.C_ = 128;
|
||||||
|
|
||||||
|
if(argc == 6)
|
||||||
|
{
|
||||||
|
do_verification = std::stoi(argv[1]);
|
||||||
|
init_method = std::stoi(argv[2]);
|
||||||
|
time_kernel = std::stoi(argv[3]);
|
||||||
|
do_log = std::stoi(argv[4]);
|
||||||
|
split_k = std::stoi(argv[5]);
|
||||||
|
}
|
||||||
|
else if(argc > 6)
|
||||||
|
{
|
||||||
|
do_verification = std::stoi(argv[1]);
|
||||||
|
init_method = std::stoi(argv[2]);
|
||||||
|
time_kernel = std::stoi(argv[3]);
|
||||||
|
do_log = std::stoi(argv[4]);
|
||||||
|
split_k = std::stoi(argv[5]);
|
||||||
|
num_dim_spatial = std::stoi(argv[6]);
|
||||||
|
// check args number
|
||||||
|
int conv_args = 3 + num_dim_spatial * 6;
|
||||||
|
int cmdline_nargs = conv_args + 7;
|
||||||
|
if(cmdline_nargs != argc)
|
||||||
|
{
|
||||||
|
print_use_msg();
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
params = parse_conv_params(num_dim_spatial, argv);
|
||||||
|
}
|
||||||
|
else if(argc != 1)
|
||||||
|
{
|
||||||
|
print_use_msg();
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(split_k <= 1)
|
||||||
|
{
|
||||||
|
print_use_msg();
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
|
||||||
|
static_cast<std::size_t>(params.C_)};
|
||||||
|
input_dims.insert(std::end(input_dims),
|
||||||
|
std::begin(params.input_spatial_lengths_),
|
||||||
|
std::end(params.input_spatial_lengths_));
|
||||||
|
|
||||||
|
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
|
||||||
|
static_cast<std::size_t>(params.C_)};
|
||||||
|
filter_dims.insert(std::end(filter_dims),
|
||||||
|
std::begin(params.filter_spatial_lengths_),
|
||||||
|
std::end(params.filter_spatial_lengths_));
|
||||||
|
|
||||||
|
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
|
||||||
|
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
|
||||||
|
static_cast<std::size_t>(params.K_)};
|
||||||
|
output_dims.insert(std::end(output_dims),
|
||||||
|
std::begin(output_spatial_lengths),
|
||||||
|
std::end(output_spatial_lengths));
|
||||||
|
|
||||||
|
Tensor<InDataType> in_n_c_hi_wi(
|
||||||
|
ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial));
|
||||||
|
Tensor<WeiDataType> wei_k_c_y_x_host_result(
|
||||||
|
ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial));
|
||||||
|
Tensor<WeiDataType> wei_k_c_y_x_device_result(
|
||||||
|
ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial));
|
||||||
|
Tensor<OutDataType> out_n_k_ho_wo(
|
||||||
|
ck::utils::conv::get_output_host_tensor_descriptor(output_dims, num_dim_spatial));
|
||||||
|
|
||||||
|
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
|
||||||
|
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_device_result.mDesc << std::endl;
|
||||||
|
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl;
|
||||||
|
|
||||||
|
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
|
||||||
|
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_host_result.mDesc << std::endl;
|
||||||
|
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl;
|
||||||
|
|
||||||
|
switch(init_method)
|
||||||
|
{
|
||||||
|
case 0: break;
|
||||||
|
case 1:
|
||||||
|
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
|
||||||
|
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
|
||||||
|
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
|
||||||
|
}
|
||||||
|
|
||||||
|
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||||
|
DeviceMem wei_device_buf(sizeof(WeiDataType) *
|
||||||
|
wei_k_c_y_x_device_result.mDesc.GetElementSpace());
|
||||||
|
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||||
|
|
||||||
|
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||||
|
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||||
|
// reset input to zero
|
||||||
|
wei_device_buf.SetZero();
|
||||||
|
|
||||||
|
// do GEMM
|
||||||
|
auto conv = get_conv_instance(num_dim_spatial);
|
||||||
|
auto invoker = conv->MakeInvokerPointer();
|
||||||
|
auto argument =
|
||||||
|
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||||
|
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||||
|
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||||
|
params.N_,
|
||||||
|
params.K_,
|
||||||
|
params.C_,
|
||||||
|
params.input_spatial_lengths_,
|
||||||
|
params.filter_spatial_lengths_,
|
||||||
|
output_spatial_lengths,
|
||||||
|
params.conv_filter_strides_,
|
||||||
|
params.conv_filter_dilations_,
|
||||||
|
params.input_left_pads_,
|
||||||
|
params.input_right_pads_,
|
||||||
|
InElementOp{},
|
||||||
|
WeiElementOp{},
|
||||||
|
OutElementOp{},
|
||||||
|
split_k);
|
||||||
|
|
||||||
|
// alloc work space
|
||||||
|
size_t bwd_weight_workspace_size = conv->GetWorkSpaceSize(argument.get());
|
||||||
|
if(bwd_weight_workspace_size <= 0)
|
||||||
|
{
|
||||||
|
print_use_msg();
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
float conv_ave_time = 0.f;
|
||||||
|
|
||||||
|
DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size);
|
||||||
|
wei_work_space_device_buf.SetZero();
|
||||||
|
conv->SetWorkSpacePointer(argument.get(), wei_work_space_device_buf.GetDeviceBuffer());
|
||||||
|
|
||||||
|
if(!conv->IsSupportedArgument(argument.get()))
|
||||||
|
{
|
||||||
|
std::cout << "wrong! device_conv with the specified compilation parameters does "
|
||||||
|
"not support this Conv problem"
|
||||||
|
<< std::endl;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
conv_ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||||
|
|
||||||
|
std::size_t flop = ck::utils::conv::get_flops(
|
||||||
|
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
|
||||||
|
std::size_t num_btype = ck::utils::conv::get_btype<InDataType, WeiDataType, OutDataType>(
|
||||||
|
params.N_,
|
||||||
|
params.C_,
|
||||||
|
params.K_,
|
||||||
|
params.input_spatial_lengths_,
|
||||||
|
params.filter_spatial_lengths_,
|
||||||
|
output_spatial_lengths);
|
||||||
|
|
||||||
|
float tflops = static_cast<float>(flop) / 1.E9 / conv_ave_time;
|
||||||
|
|
||||||
|
float gb_per_sec = num_btype / 1.E6 / conv_ave_time;
|
||||||
|
|
||||||
|
std::cout << "Perf: conv: " << conv_ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||||
|
<< " GB/s" << std::endl;
|
||||||
|
|
||||||
|
if(do_verification)
|
||||||
|
{
|
||||||
|
auto verify_f = [&](const auto& ref_conv) {
|
||||||
|
auto ref_invoker = ref_conv.MakeInvoker();
|
||||||
|
|
||||||
|
auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi,
|
||||||
|
wei_k_c_y_x_host_result,
|
||||||
|
out_n_k_ho_wo,
|
||||||
|
params.conv_filter_strides_,
|
||||||
|
params.conv_filter_dilations_,
|
||||||
|
params.input_left_pads_,
|
||||||
|
params.input_right_pads_,
|
||||||
|
InElementOp{},
|
||||||
|
WeiElementOp{},
|
||||||
|
OutElementOp{});
|
||||||
|
|
||||||
|
ref_invoker.Run(ref_argument);
|
||||||
|
|
||||||
|
wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data());
|
||||||
|
|
||||||
|
if(do_log)
|
||||||
|
{
|
||||||
|
LogRangeAsType<float>(std::cout << "out: ", out_n_k_ho_wo.mData, ",") << std::endl;
|
||||||
|
LogRangeAsType<float>(std::cout << "in : ", in_n_c_hi_wi.mData, ",") << std::endl;
|
||||||
|
LogRangeAsType<float>(
|
||||||
|
std::cout << "wei_device(after): ", wei_k_c_y_x_device_result.mData, ",")
|
||||||
|
<< std::endl;
|
||||||
|
LogRangeAsType<float>(
|
||||||
|
std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",")
|
||||||
|
<< std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
return ck::utils::check_err(wei_k_c_y_x_device_result.mData,
|
||||||
|
wei_k_c_y_x_host_result.mData)
|
||||||
|
? 0
|
||||||
|
: 1;
|
||||||
|
};
|
||||||
|
|
||||||
|
switch(num_dim_spatial)
|
||||||
|
{
|
||||||
|
case 3: {
|
||||||
|
auto ref_conv = ReferenceConvBwdWeightInstance<3>();
|
||||||
|
verify_f(ref_conv);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: {
|
||||||
|
auto ref_conv = ReferenceConvBwdWeightInstance<2>();
|
||||||
|
verify_f(ref_conv);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1: {
|
||||||
|
auto ref_conv = ReferenceConvBwdWeightInstance<1>();
|
||||||
|
verify_f(ref_conv);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
@@ -11,6 +11,7 @@
|
|||||||
#include "tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
#include "gridwise_gemm_xdlops_bwd_weight.hpp"
|
#include "gridwise_gemm_xdlops_bwd_weight.hpp"
|
||||||
|
#include "gridwise_unary_elementwise_1d.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
namespace tensor_operation {
|
namespace tensor_operation {
|
||||||
@@ -628,6 +629,54 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
|||||||
1);
|
1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// type convert descs
|
||||||
|
template <typename Desc_M0>
|
||||||
|
static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
|
||||||
|
{
|
||||||
|
const auto m0 = desc_m0.GetLength(I0);
|
||||||
|
const index_t loop_step = gridSize * blockSize * 4;
|
||||||
|
const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
|
||||||
|
const auto desc_m0_pad =
|
||||||
|
transform_tensor_descriptor(desc_m0,
|
||||||
|
make_tuple(make_right_pad_transform(m0, pad)),
|
||||||
|
make_tuple(Sequence<0>{}),
|
||||||
|
make_tuple(Sequence<0>{}));
|
||||||
|
return desc_m0_pad;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <index_t Dim>
|
||||||
|
static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
|
||||||
|
const std::vector<index_t>& stride,
|
||||||
|
index_t gridSize,
|
||||||
|
index_t blockSize)
|
||||||
|
{
|
||||||
|
auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number<Dim>{});
|
||||||
|
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<Dim>{});
|
||||||
|
|
||||||
|
// nd desc - [s0, s1, s2, ...]
|
||||||
|
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
|
||||||
|
|
||||||
|
// merge nd to 1d desc - [s0 * s1 * ...]
|
||||||
|
if constexpr(Dim > 1)
|
||||||
|
{
|
||||||
|
const auto desc_m0 = transform_tensor_descriptor(
|
||||||
|
desc,
|
||||||
|
make_tuple(make_merge_transform(tupleOfShape)),
|
||||||
|
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
|
||||||
|
make_tuple(Sequence<0>{}));
|
||||||
|
|
||||||
|
return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
return PadDescriptor_M0_1d(desc, gridSize, blockSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
using TypeConvertFunctor =
|
||||||
|
ck::tensor_operation::element_wise::UnaryTypeConvert<ck::bhalf_t, float>;
|
||||||
|
using GridDesc_M0 = decltype(MakeDescriptor_M0<1>({1}, {1}, 1, 1));
|
||||||
|
using GridwiseUEltwise =
|
||||||
|
GridwiseUnaryElementwise_1D<AccDataType, InDataType, GridDesc_M0, TypeConvertFunctor, 4>;
|
||||||
|
|
||||||
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());
|
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());
|
||||||
|
|
||||||
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
|
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
|
||||||
@@ -733,6 +782,55 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
|||||||
true,
|
true,
|
||||||
true>;
|
true>;
|
||||||
|
|
||||||
|
using GridwiseGemmAtomicAddFloatBf16Splitk = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
|
||||||
|
BlockSize,
|
||||||
|
ADataType, // TODO: distinguish A/B datatype
|
||||||
|
AccDataType,
|
||||||
|
AccDataType,
|
||||||
|
InMemoryDataOperationEnum::AtomicAdd,
|
||||||
|
AGridDesc_K0_M_K1,
|
||||||
|
BGridDesc_K0_N_K1,
|
||||||
|
CGridDesc_M_N,
|
||||||
|
AElementwiseOperation,
|
||||||
|
BElementwiseOperation,
|
||||||
|
CElementwiseOperation,
|
||||||
|
MPerBlock,
|
||||||
|
NPerBlock,
|
||||||
|
K0PerBlock,
|
||||||
|
MPerXdl,
|
||||||
|
NPerXdl,
|
||||||
|
K1,
|
||||||
|
MXdlPerWave,
|
||||||
|
NXdlPerWave,
|
||||||
|
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||||
|
ABlockTransferThreadClusterArrangeOrder,
|
||||||
|
ABlockTransferSrcAccessOrder,
|
||||||
|
ABlockTransferSrcVectorDim,
|
||||||
|
ABlockTransferSrcScalarPerVector,
|
||||||
|
ABlockTransferDstScalarPerVector_K1,
|
||||||
|
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||||
|
ABlockLdsAddExtraM,
|
||||||
|
ABlockLdsM1PerBlock,
|
||||||
|
ABlockLdsM0PerBlock,
|
||||||
|
ABlockLdsM1Padding,
|
||||||
|
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||||
|
BBlockTransferThreadClusterArrangeOrder,
|
||||||
|
BBlockTransferSrcAccessOrder,
|
||||||
|
BBlockTransferSrcVectorDim,
|
||||||
|
BBlockTransferSrcScalarPerVector,
|
||||||
|
BBlockTransferDstScalarPerVector_K1,
|
||||||
|
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||||
|
BBlockLdsAddExtraN,
|
||||||
|
BBlockLdsN1PerBlock,
|
||||||
|
BBlockLdsN0PerBlock,
|
||||||
|
BBlockLdsN1Padding,
|
||||||
|
CShuffleMXdlPerWavePerShuffle,
|
||||||
|
CShuffleNXdlPerWavePerShuffle,
|
||||||
|
CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||||
|
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||||
|
true,
|
||||||
|
true>;
|
||||||
|
|
||||||
// Argument
|
// Argument
|
||||||
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||||
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
|
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
|
||||||
@@ -802,6 +900,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
|||||||
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
|
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
|
||||||
c_grid_desc_m_n_ = descs[I2];
|
c_grid_desc_m_n_ = descs[I2];
|
||||||
|
|
||||||
|
// init work space
|
||||||
|
p_c_workspace_grid_ = nullptr;
|
||||||
|
|
||||||
block_2_ctile_map_ =
|
block_2_ctile_map_ =
|
||||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
|
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
|
||||||
|
|
||||||
@@ -838,6 +939,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
|||||||
std::vector<index_t> input_left_pads_;
|
std::vector<index_t> input_left_pads_;
|
||||||
std::vector<index_t> input_right_pads_;
|
std::vector<index_t> input_right_pads_;
|
||||||
index_t k_batch_;
|
index_t k_batch_;
|
||||||
|
|
||||||
|
// external work space
|
||||||
|
void* p_c_workspace_grid_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Invoker
|
// Invoker
|
||||||
@@ -910,41 +1014,159 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
|||||||
arg.block_2_ctile_map_);
|
arg.block_2_ctile_map_);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// run kernel for bf16 with splitk
|
||||||
|
const auto run_bf16_splitk = [&](const auto& kernel) {
|
||||||
|
hipGetErrorString(hipMemset(
|
||||||
|
arg.p_c_workspace_grid_,
|
||||||
|
0,
|
||||||
|
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
|
||||||
|
sizeof(AccDataType)));
|
||||||
|
|
||||||
|
ave_time =
|
||||||
|
launch_and_time_kernel(stream_config,
|
||||||
|
kernel,
|
||||||
|
dim3(grid_size),
|
||||||
|
dim3(BlockSize),
|
||||||
|
0,
|
||||||
|
arg.p_a_grid_,
|
||||||
|
arg.p_b_grid_,
|
||||||
|
static_cast<AccDataType*>(arg.p_c_workspace_grid_),
|
||||||
|
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||||
|
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||||
|
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||||
|
arg.a_element_op_,
|
||||||
|
arg.b_element_op_,
|
||||||
|
arg.c_element_op_,
|
||||||
|
arg.block_2_ctile_map_);
|
||||||
|
};
|
||||||
|
|
||||||
|
// kernel for type conversion
|
||||||
|
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(arg.Conv_K_),
|
||||||
|
static_cast<std::size_t>(arg.Conv_C_)};
|
||||||
|
|
||||||
|
filter_dims.insert(std::end(filter_dims),
|
||||||
|
std::begin(arg.filter_spatial_lengths_),
|
||||||
|
std::end(arg.filter_spatial_lengths_));
|
||||||
|
|
||||||
|
int tensor_size =
|
||||||
|
std::accumulate(filter_dims.begin(), filter_dims.end(), 1, std::multiplies<int>{});
|
||||||
|
|
||||||
|
const index_t type_convert_grid_size = GridwiseUEltwise::CalculateGridSize(tensor_size);
|
||||||
|
GridDesc_M0 a_grid_desc_m0_ =
|
||||||
|
MakeDescriptor_M0<1>({tensor_size}, {1}, type_convert_grid_size, 256);
|
||||||
|
GridDesc_M0 b_grid_desc_m0_ =
|
||||||
|
MakeDescriptor_M0<1>({tensor_size}, {1}, type_convert_grid_size, 256);
|
||||||
|
|
||||||
|
if(!GridwiseUEltwise::CheckValidity(a_grid_desc_m0_, b_grid_desc_m0_))
|
||||||
|
{
|
||||||
|
throw std::runtime_error("wrong! GridwiseUnaryElementwise_1D has invalid setting");
|
||||||
|
}
|
||||||
|
|
||||||
|
// run kernel for type conversion
|
||||||
|
void* p_c_grid_tmp_ = static_cast<void*>(arg.p_c_grid_);
|
||||||
|
InDataType* p_c_grid_tmp_bf16_ = static_cast<InDataType*>(p_c_grid_tmp_);
|
||||||
|
const auto Run_type_convert = [&](const auto& kernel) {
|
||||||
|
float elapsed_time =
|
||||||
|
launch_and_time_kernel(stream_config,
|
||||||
|
kernel,
|
||||||
|
dim3(type_convert_grid_size),
|
||||||
|
dim3(256),
|
||||||
|
0,
|
||||||
|
static_cast<AccDataType*>(arg.p_c_workspace_grid_),
|
||||||
|
p_c_grid_tmp_bf16_,
|
||||||
|
a_grid_desc_m0_,
|
||||||
|
b_grid_desc_m0_,
|
||||||
|
TypeConvertFunctor{});
|
||||||
|
return elapsed_time;
|
||||||
|
};
|
||||||
|
|
||||||
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
|
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
|
||||||
{
|
{
|
||||||
if(has_main_k0_block_loop)
|
if(has_main_k0_block_loop)
|
||||||
{
|
{
|
||||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
if(kbatch == 1)
|
||||||
GridwiseGemm,
|
{
|
||||||
ADataType, // TODO: distiguish A/B datatype
|
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||||
CDataType,
|
GridwiseGemm,
|
||||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
ADataType, // TODO: distiguish A/B datatype
|
||||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
CDataType,
|
||||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||||
OutElementwiseOperation,
|
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||||
InElementwiseOperation,
|
remove_reference_t<
|
||||||
WeiElementwiseOperation,
|
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
OutElementwiseOperation,
|
||||||
true>;
|
InElementwiseOperation,
|
||||||
|
WeiElementwiseOperation,
|
||||||
|
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||||
|
true>;
|
||||||
|
|
||||||
Run(kernel);
|
Run(kernel);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
const auto kernel_type_convert =
|
||||||
|
kernel_unary_elementwise_1d<GridwiseUEltwise,
|
||||||
|
AccDataType,
|
||||||
|
InDataType,
|
||||||
|
GridDesc_M0,
|
||||||
|
TypeConvertFunctor>;
|
||||||
|
|
||||||
|
const auto kernel_conv = kernel_gemm_xdlops_bwd_weight<
|
||||||
|
GridwiseGemmAtomicAddFloatBf16Splitk,
|
||||||
|
ADataType, // TODO: distiguish A/B datatype
|
||||||
|
AccDataType,
|
||||||
|
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||||
|
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||||
|
remove_reference_t<
|
||||||
|
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||||
|
OutElementwiseOperation,
|
||||||
|
InElementwiseOperation,
|
||||||
|
WeiElementwiseOperation,
|
||||||
|
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||||
|
true>;
|
||||||
|
|
||||||
|
run_bf16_splitk(kernel_conv);
|
||||||
|
ave_time += Run_type_convert(kernel_type_convert);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
if(kbatch == 1)
|
||||||
GridwiseGemm,
|
{
|
||||||
ADataType, // TODO: distiguish A/B datatype
|
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||||
CDataType,
|
GridwiseGemm,
|
||||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
ADataType, // TODO: distiguish A/B datatype
|
||||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
CDataType,
|
||||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||||
OutElementwiseOperation,
|
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||||
InElementwiseOperation,
|
remove_reference_t<
|
||||||
WeiElementwiseOperation,
|
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
OutElementwiseOperation,
|
||||||
false>;
|
InElementwiseOperation,
|
||||||
|
WeiElementwiseOperation,
|
||||||
|
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||||
|
false>;
|
||||||
|
|
||||||
Run(kernel);
|
Run(kernel);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||||
|
GridwiseGemmAtomicAddFloatBf16Splitk,
|
||||||
|
ADataType, // TODO: distiguish A/B datatype
|
||||||
|
AccDataType,
|
||||||
|
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||||
|
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||||
|
remove_reference_t<
|
||||||
|
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||||
|
OutElementwiseOperation,
|
||||||
|
InElementwiseOperation,
|
||||||
|
WeiElementwiseOperation,
|
||||||
|
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||||
|
false>;
|
||||||
|
|
||||||
|
run_bf16_splitk(kernel);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
@@ -1226,6 +1448,11 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
|||||||
{
|
{
|
||||||
return GetWorkSpaceSize<NumDimSpatial>(*dynamic_cast<const Argument*>(p_arg));
|
return GetWorkSpaceSize<NumDimSpatial>(*dynamic_cast<const Argument*>(p_arg));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Records the caller-supplied workspace buffer in the argument's
// p_c_workspace_grid_ member.
//
// p_arg         : type-erased argument; it is down-cast to this operator's Argument.
// workspace_ptr : buffer pointer to store; it is not inspected here.
//
// NOTE(review): the result of the dynamic_cast is dereferenced without a null
// check, so passing an argument of a different type is undefined behaviour --
// confirm that callers always pass this operator's own Argument.
void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override
{
    auto* const typed_arg          = dynamic_cast<Argument*>(p_arg);
    typed_arg->p_c_workspace_grid_ = workspace_ptr;
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace device
|
} // namespace device
|
||||||
|
|||||||
@@ -0,0 +1,178 @@
|
|||||||
|
#pragma once
|
||||||
|
#include <iostream>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "device.hpp"
|
||||||
|
#include "device_base.hpp"
|
||||||
|
#include "gridwise_unary_elementwise_1d.hpp"
|
||||||
|
|
||||||
|
namespace ck {
|
||||||
|
namespace tensor_operation {
|
||||||
|
namespace device {
|
||||||
|
|
||||||
|
template <typename ADataType,
|
||||||
|
typename BDataType,
|
||||||
|
typename ElementwiseFunctor,
|
||||||
|
index_t Dim,
|
||||||
|
index_t ScalarPerVector>
|
||||||
|
struct DeviceUnaryElementwise : public BaseOperator
|
||||||
|
{
|
||||||
|
static constexpr auto I0 = Number<0>{};
|
||||||
|
|
||||||
|
template <typename Desc_M0>
|
||||||
|
static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
|
||||||
|
{
|
||||||
|
const auto m0 = desc_m0.GetLength(I0);
|
||||||
|
const index_t loop_step = gridSize * blockSize * ScalarPerVector;
|
||||||
|
const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
|
||||||
|
const auto desc_m0_pad =
|
||||||
|
transform_tensor_descriptor(desc_m0,
|
||||||
|
make_tuple(make_right_pad_transform(m0, pad)),
|
||||||
|
make_tuple(Sequence<0>{}),
|
||||||
|
make_tuple(Sequence<0>{}));
|
||||||
|
return desc_m0_pad;
|
||||||
|
}
|
||||||
|
|
||||||
|
static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
|
||||||
|
const std::vector<index_t>& stride,
|
||||||
|
index_t gridSize,
|
||||||
|
index_t blockSize)
|
||||||
|
{
|
||||||
|
auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number<Dim>{});
|
||||||
|
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<Dim>{});
|
||||||
|
|
||||||
|
// nd desc - [s0, s1, s2, ...]
|
||||||
|
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
|
||||||
|
|
||||||
|
// merge nd to 1d desc - [s0 * s1 * ...]
|
||||||
|
if constexpr(Dim > 1)
|
||||||
|
{
|
||||||
|
const auto desc_m0 = transform_tensor_descriptor(
|
||||||
|
desc,
|
||||||
|
make_tuple(make_merge_transform(tupleOfShape)),
|
||||||
|
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
|
||||||
|
make_tuple(Sequence<0>{}));
|
||||||
|
|
||||||
|
return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
return PadDescriptor_M0_1d(desc, gridSize, blockSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
|
||||||
|
using GridwiseUEltwise = GridwiseUnaryElementwise_1D<ADataType,
|
||||||
|
BDataType,
|
||||||
|
GridDesc_M0,
|
||||||
|
ElementwiseFunctor,
|
||||||
|
ScalarPerVector>;
|
||||||
|
|
||||||
|
struct Argument : public BaseArgument
|
||||||
|
{
|
||||||
|
Argument(const ADataType* p_a,
|
||||||
|
BDataType* p_b,
|
||||||
|
const std::vector<index_t>& shape,
|
||||||
|
const std::vector<index_t>& stride_a,
|
||||||
|
const std::vector<index_t>& stride_b,
|
||||||
|
ElementwiseFunctor functor)
|
||||||
|
: p_a_(p_a),
|
||||||
|
p_b_(p_b),
|
||||||
|
shape_(shape),
|
||||||
|
functor_(functor),
|
||||||
|
blockSize_(256) // FIXME - Calculate the grid size by number of CU in the future
|
||||||
|
{
|
||||||
|
index_t tensor_size =
|
||||||
|
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>{});
|
||||||
|
gridSize_ = GridwiseUEltwise::CalculateGridSize(tensor_size);
|
||||||
|
a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize_);
|
||||||
|
b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize_);
|
||||||
|
}
|
||||||
|
|
||||||
|
const ADataType* p_a_;
|
||||||
|
BDataType* p_b_;
|
||||||
|
std::vector<int> shape_;
|
||||||
|
GridDesc_M0 a_grid_desc_m0_;
|
||||||
|
GridDesc_M0 b_grid_desc_m0_;
|
||||||
|
ElementwiseFunctor functor_;
|
||||||
|
index_t blockSize_;
|
||||||
|
index_t gridSize_;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Invoker : public BaseInvoker
|
||||||
|
{
|
||||||
|
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||||
|
{
|
||||||
|
const auto kernel = kernel_unary_elementwise_1d<GridwiseUEltwise,
|
||||||
|
ADataType,
|
||||||
|
BDataType,
|
||||||
|
GridDesc_M0,
|
||||||
|
ElementwiseFunctor>;
|
||||||
|
|
||||||
|
float elapsed_time = launch_and_time_kernel(stream_config,
|
||||||
|
kernel,
|
||||||
|
dim3(arg.gridSize_),
|
||||||
|
dim3(arg.blockSize_),
|
||||||
|
0,
|
||||||
|
arg.p_a_,
|
||||||
|
arg.p_b_,
|
||||||
|
arg.a_grid_desc_m0_,
|
||||||
|
arg.b_grid_desc_m0_,
|
||||||
|
arg.functor_);
|
||||||
|
return elapsed_time;
|
||||||
|
}
|
||||||
|
|
||||||
|
// polymorphic
|
||||||
|
float Run(const BaseArgument* p_arg,
|
||||||
|
const StreamConfig& stream_config = StreamConfig{}) override
|
||||||
|
{
|
||||||
|
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||||
|
{
|
||||||
|
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||||
|
|
||||||
|
if(pArg == nullptr)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
if(pArg->shape_.back() % ScalarPerVector != 0)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||||
|
void* p_b,
|
||||||
|
std::vector<index_t> shape,
|
||||||
|
std::vector<index_t> stride_a,
|
||||||
|
std::vector<index_t> stride_b,
|
||||||
|
ElementwiseFunctor functor)
|
||||||
|
{
|
||||||
|
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||||
|
static_cast<BDataType*>(p_b),
|
||||||
|
shape,
|
||||||
|
stride_a,
|
||||||
|
stride_b,
|
||||||
|
functor);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
|
||||||
|
|
||||||
|
std::string GetTypeString() const override
|
||||||
|
{
|
||||||
|
auto str = std::stringstream();
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
str << "DeviceBinaryElementwise"
|
||||||
|
<< "<"
|
||||||
|
<< "ScalarPerVector = " << ScalarPerVector
|
||||||
|
<< ">";
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
return str.str();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace device
|
||||||
|
} // namespace tensor_operation
|
||||||
|
} // namespace ck
|
||||||
@@ -346,6 +346,27 @@ struct UnarySqrt<double, double>
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename Y, typename X>
|
||||||
|
struct UnaryTypeConvert;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct UnaryTypeConvert<float, ck::bhalf_t>
|
||||||
|
{
|
||||||
|
__host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
|
||||||
|
{
|
||||||
|
y = ck::type_convert<float, ck::bhalf_t>(x);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct UnaryTypeConvert<ck::bhalf_t, float>
|
||||||
|
{
|
||||||
|
__host__ __device__ void operator()(ck::bhalf_t& y, float& x) const
|
||||||
|
{
|
||||||
|
y = ck::type_convert<ck::bhalf_t, float>(x);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace element_wise
|
} // namespace element_wise
|
||||||
} // namespace tensor_operation
|
} // namespace tensor_operation
|
||||||
} // namespace ck
|
} // namespace ck
|
||||||
|
|||||||
@@ -791,8 +791,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
|||||||
constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
|
constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
|
||||||
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
|
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
|
||||||
|
|
||||||
|
void* p_shared = static_cast<void*>(p_shared_block);
|
||||||
|
|
||||||
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||||
static_cast<FloatC*>(p_shared_block),
|
static_cast<FloatC*>(p_shared),
|
||||||
c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||||
|
|
||||||
static_assert(M1 == MWave, "");
|
static_assert(M1 == MWave, "");
|
||||||
|
|||||||
@@ -0,0 +1,129 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cluster_descriptor.hpp"
|
||||||
|
#include "data_type.hpp"
|
||||||
|
#include "element_wise_operation.hpp"
|
||||||
|
#include "threadwise_tensor_slice_transfer.hpp"
|
||||||
|
|
||||||
|
namespace ck {
|
||||||
|
|
||||||
|
template <typename GridwiseUEltwise,
|
||||||
|
typename ADataType,
|
||||||
|
typename BDataType,
|
||||||
|
typename GridDesc_M0,
|
||||||
|
typename ElementwiseFunctor>
|
||||||
|
__global__ void kernel_unary_elementwise_1d(const ADataType* __restrict__ p_a_global,
|
||||||
|
BDataType* __restrict__ p_b_global,
|
||||||
|
const GridDesc_M0 a_grid_desc_m0,
|
||||||
|
const GridDesc_M0 b_grid_desc_m0,
|
||||||
|
const ElementwiseFunctor functor)
|
||||||
|
{
|
||||||
|
GridwiseUEltwise::Run(p_a_global, p_b_global, a_grid_desc_m0, b_grid_desc_m0, functor);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ADataType,
|
||||||
|
typename BDataType,
|
||||||
|
typename GridDesc_M0,
|
||||||
|
typename ElementwiseFunctor,
|
||||||
|
index_t ScalarPerVector>
|
||||||
|
struct GridwiseUnaryElementwise_1D
|
||||||
|
{
|
||||||
|
static constexpr auto I0 = Number<0>{};
|
||||||
|
static constexpr auto thread_desc_m0 =
|
||||||
|
make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));
|
||||||
|
|
||||||
|
using PassThrough = tensor_operation::element_wise::PassThrough;
|
||||||
|
|
||||||
|
static __device__ auto CalculateElementwiseIndex()
|
||||||
|
{
|
||||||
|
const index_t global_thread_id = get_thread_global_1d_id();
|
||||||
|
return make_multi_index(global_thread_id * ScalarPerVector);
|
||||||
|
}
|
||||||
|
|
||||||
|
__host__ __device__ static constexpr bool CheckValidity(const GridDesc_M0 a_grid_desc_m0,
|
||||||
|
const GridDesc_M0 b_grid_desc_m0)
|
||||||
|
{
|
||||||
|
return a_grid_desc_m0.GetLength(I0) == b_grid_desc_m0.GetLength(I0);
|
||||||
|
}
|
||||||
|
|
||||||
|
__host__ __device__ static constexpr index_t CalculateGridSize(const index_t tensor_size)
|
||||||
|
{
|
||||||
|
const index_t grid_size = math::integer_divide_ceil(tensor_size, 256 * ScalarPerVector);
|
||||||
|
|
||||||
|
return grid_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ static void Run(const ADataType* __restrict__ p_a_global,
|
||||||
|
BDataType* __restrict__ p_b_global,
|
||||||
|
const GridDesc_M0 a_grid_desc_m0,
|
||||||
|
const GridDesc_M0 b_grid_desc_m0,
|
||||||
|
const ElementwiseFunctor functor)
|
||||||
|
{
|
||||||
|
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||||
|
p_a_global, a_grid_desc_m0.GetElementSpaceSize());
|
||||||
|
auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||||
|
p_b_global, b_grid_desc_m0.GetElementSpaceSize());
|
||||||
|
|
||||||
|
StaticBuffer<AddressSpaceEnum::Vgpr, ADataType, ScalarPerVector, true> a_thread_buf;
|
||||||
|
StaticBuffer<AddressSpaceEnum::Vgpr, BDataType, ScalarPerVector, true> b_thread_buf;
|
||||||
|
|
||||||
|
const auto thread_store_global_offset = CalculateElementwiseIndex();
|
||||||
|
|
||||||
|
auto a_global_load =
|
||||||
|
ThreadwiseTensorSliceTransfer_v2<ADataType,
|
||||||
|
ADataType,
|
||||||
|
GridDesc_M0,
|
||||||
|
decltype(thread_desc_m0),
|
||||||
|
Sequence<ScalarPerVector>, // SliceLengths
|
||||||
|
Sequence<0>, // DimAccessOrder
|
||||||
|
0, // SrcVectorDim
|
||||||
|
ScalarPerVector,
|
||||||
|
1, // SrcScalarStrideInVector
|
||||||
|
false>{a_grid_desc_m0, thread_store_global_offset};
|
||||||
|
|
||||||
|
auto b_global_write =
|
||||||
|
ThreadwiseTensorSliceTransfer_v1r3<BDataType,
|
||||||
|
BDataType,
|
||||||
|
decltype(thread_desc_m0),
|
||||||
|
GridDesc_M0,
|
||||||
|
PassThrough,
|
||||||
|
Sequence<ScalarPerVector>, // SliceLengths
|
||||||
|
Sequence<0>, // DimAccessOrder
|
||||||
|
0, // DstVectorDim
|
||||||
|
ScalarPerVector,
|
||||||
|
InMemoryDataOperationEnum::Set,
|
||||||
|
1, // DstScalarStrideInVector
|
||||||
|
false>{
|
||||||
|
b_grid_desc_m0, thread_store_global_offset, PassThrough{}};
|
||||||
|
|
||||||
|
const index_t blockSize = get_block_size();
|
||||||
|
const index_t blockPerGrid = get_grid_size();
|
||||||
|
const auto m0 = b_grid_desc_m0.GetLength(I0);
|
||||||
|
const index_t loop_step = blockPerGrid * blockSize * ScalarPerVector;
|
||||||
|
const auto loop_step_index = make_multi_index(loop_step);
|
||||||
|
|
||||||
|
index_t num_iter = m0 / (loop_step);
|
||||||
|
do
|
||||||
|
{
|
||||||
|
// read and process ScalarPerVector elements
|
||||||
|
a_global_load.Run(
|
||||||
|
a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf);
|
||||||
|
|
||||||
|
static_for<0, ScalarPerVector, 1>{}([&](auto m) {
|
||||||
|
constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m));
|
||||||
|
functor(b_thread_buf(Number<offset>{}), a_thread_buf(Number<offset>{}));
|
||||||
|
});
|
||||||
|
|
||||||
|
b_global_write.Run(thread_desc_m0,
|
||||||
|
make_tuple(I0), // SrcSliceOriginIdx
|
||||||
|
b_thread_buf,
|
||||||
|
b_grid_desc_m0,
|
||||||
|
b_global_buf);
|
||||||
|
|
||||||
|
a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index);
|
||||||
|
b_global_write.MoveDstSliceWindow(b_grid_desc_m0, loop_step_index);
|
||||||
|
} while(--num_iter);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ck
|
||||||
Reference in New Issue
Block a user