diff --git a/example/20_convnd_bwd_weight_xdl/CMakeLists.txt b/example/20_convnd_bwd_weight_xdl/CMakeLists.txt index 1a644d9479..66fdef625a 100644 --- a/example/20_convnd_bwd_weight_xdl/CMakeLists.txt +++ b/example/20_convnd_bwd_weight_xdl/CMakeLists.txt @@ -1,2 +1,4 @@ add_example_executable(example_convnd_bwd_weight_xdl convnd_bwd_weight_xdl.cpp) -target_link_libraries(example_convnd_bwd_weight_xdl PRIVATE conv_util) \ No newline at end of file +add_example_executable(example_convnd_bwd_weight_xdl_bf16_splitk convnd_bwd_weight_xdl_bf16_splitk.cpp) +target_link_libraries(example_convnd_bwd_weight_xdl PRIVATE conv_util) +target_link_libraries(example_convnd_bwd_weight_xdl_bf16_splitk PRIVATE conv_util) \ No newline at end of file diff --git a/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp b/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp index 65725d3ae8..f917c2c3ac 100644 --- a/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp +++ b/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp @@ -297,52 +297,15 @@ int main(int argc, char* argv[]) split_k); // alloc work space - size_t bwd_weight_workspace_size = conv->GetWorkSpaceSize(argument.get()); - float ave_time = 0.f; - if(std::is_same::value && split_k > 1) + float ave_time = 0.f; + if(!conv->IsSupportedArgument(argument.get())) { - DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size); - wei_work_space_device_buf.SetZero(); - argument = conv->MakeArgumentPointer( - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_work_space_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - params.N_, - params.K_, - params.C_, - params.input_spatial_lengths_, - params.filter_spatial_lengths_, - output_spatial_lengths, - params.conv_filter_strides_, - params.conv_filter_dilations_, - params.input_left_pads_, - params.input_right_pads_, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}, - split_k); - - if(!conv->IsSupportedArgument(argument.get())) - { - std::cout << "wrong! device_conv with the specified compilation parameters does " - "not support this Conv problem" - << std::endl; - return 1; - } - - ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); - } - else - { - if(!conv->IsSupportedArgument(argument.get())) - { - std::cout << "wrong! device_conv with the specified compilation parameters does " - "not support this Conv problem" - << std::endl; - return 1; - } - ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + std::cout << "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + return 1; } + ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); std::size_t flop = ck::utils::conv::get_flops( params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); diff --git a/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl_bf16_splitk.cpp b/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl_bf16_splitk.cpp new file mode 100644 index 0000000000..43f0cdb7ec --- /dev/null +++ b/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl_bf16_splitk.cpp @@ -0,0 +1,427 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "conv_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "tensor_layout.hpp" +#include "element_wise_operation.hpp" +#include "device_unary_elementwise.hpp" +#include "device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "reference_conv_backward_weight.hpp" + +using InDataType = ck::bhalf_t; +using WeiDataType = ck::bhalf_t; +using OutDataType = ck::bhalf_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +using UnaryTypeConvert = ck::tensor_operation::element_wise::UnaryTypeConvert; + +using DeviceUnaryElementwiseTypeConvertInstance = ck::tensor_operation::device:: + DeviceUnaryElementwise; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +using DeviceConvBwdWeightBasePtr = + ck::tensor_operation::device::DeviceConvBwdWeightPtr; + +// clang-format off +template +using DeviceConvndBwdWeightInstance_bf16_splitk = ck::tensor_operation::device:: + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + InDataType, // InDataType + AccDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + NumDimSpatial, // NumDimSpatial + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 2, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 2, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 4>; // CBlockTransferScalarPerVector_NWaveNPerXdl +// clang-format on + +template +using ReferenceConvBwdWeightInstance = + ck::tensor_operation::host::ReferenceConvBwdWeight; + +template +void host_elementwise(HostTensorB& B, + const HostTensorA& A, + const std::vector& shape, + Functor functor) +{ + size_t tensor_size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies{}); + std::cout << __LINE__ << ":" << tensor_size << ", " << A.mData[0] << std::endl; + for(std::size_t n = 0; n < tensor_size; ++n) + { + B.mData[n] = functor(A.mData[n]); + } +} + +void print_use_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n" + << "arg3: time kernel (0=n0, 1=yes)\n" + << "arg4: is show log (0=no, 1=yes)\n" + << "arg5: split-k : in this example split-k must be larger than 1\n" + << "arg6: N spatial dimensions (default 2)\n" + << "Following arguments (depending on number of spatial dims):\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; +} + +ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[]) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + ck::utils::conv::ConvParams params; + int arg_idx = 7; + + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +DeviceConvBwdWeightBasePtr get_conv_instance(int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 3: { + return std::make_unique>(); + } + case 2: { + return std::make_unique>(); + } + case 1: { + return std::make_unique>(); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int num_dim_spatial = 2; + int do_log = 0; + int split_k = 2; + + ck::utils::conv::ConvParams params; + params.C_ = 128; + + if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + do_log = std::stoi(argv[4]); + split_k = std::stoi(argv[5]); + } + else if(argc > 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + do_log = std::stoi(argv[4]); + split_k = std::stoi(argv[5]); + num_dim_spatial = std::stoi(argv[6]); + // check args number + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + 7; + if(cmdline_nargs != argc) + { + print_use_msg(); + exit(1); + } + + params = parse_conv_params(num_dim_spatial, argv); + } + else if(argc != 1) + { + print_use_msg(); + exit(1); + } + + if(split_k <= 1) + { + print_use_msg(); + exit(1); + } + + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); + + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor in_n_c_hi_wi( + ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor wei_k_c_y_x_host_result( + ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor wei_k_c_y_x_device_result( + ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor out_n_k_ho_wo( + ck::utils::conv::get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_device_result.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_host_result.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1{1}); + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * + wei_k_c_y_x_device_result.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + // reset input to zero + wei_device_buf.SetZero(); + + // do GEMM + auto conv = get_conv_instance(num_dim_spatial); + auto invoker = conv->MakeInvokerPointer(); + auto argument = + conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}, + split_k); + + // alloc work space + size_t bwd_weight_workspace_size = conv->GetWorkSpaceSize(argument.get()); + if(bwd_weight_workspace_size <= 0) + { + print_use_msg(); + exit(1); + } + + float conv_ave_time = 0.f; + + DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size); + wei_work_space_device_buf.SetZero(); + conv->SetWorkSpacePointer(argument.get(), wei_work_space_device_buf.GetDeviceBuffer()); + + if(!conv->IsSupportedArgument(argument.get())) + { + std::cout << "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + return 1; + } + + conv_ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = ck::utils::conv::get_flops( + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); + std::size_t num_btype = ck::utils::conv::get_btype( + params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths); + + float tflops = static_cast(flop) / 1.E9 / conv_ave_time; + + float gb_per_sec = num_btype / 1.E6 / conv_ave_time; + + std::cout << "Perf: conv: " << conv_ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + + if(do_verification) + { + auto verify_f = [&](const auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x_host_result, + out_n_k_ho_wo, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + + wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data()); + + if(do_log) + { + LogRangeAsType(std::cout << "out: ", out_n_k_ho_wo.mData, ",") << std::endl; + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") << std::endl; + LogRangeAsType( + std::cout << "wei_device(after): ", wei_k_c_y_x_device_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",") + << std::endl; + } + + return ck::utils::check_err(wei_k_c_y_x_device_result.mData, + wei_k_c_y_x_host_result.mData) + ? 0 + : 1; + }; + + switch(num_dim_spatial) + { + case 3: { + auto ref_conv = ReferenceConvBwdWeightInstance<3>(); + verify_f(ref_conv); + break; + } + case 2: { + auto ref_conv = ReferenceConvBwdWeightInstance<2>(); + verify_f(ref_conv); + break; + } + case 1: { + auto ref_conv = ReferenceConvBwdWeightInstance<1>(); + verify_f(ref_conv); + break; + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } + } + return 0; +} diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index dde9e0f873..5920232038 100644 --- a/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -11,6 +11,7 @@ #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" #include "gridwise_gemm_xdlops_bwd_weight.hpp" +#include "gridwise_unary_elementwise_1d.hpp" namespace ck { namespace tensor_operation { @@ -628,6 +629,54 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ 1); } + // type convert descs + template + static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize) + { + const auto m0 = desc_m0.GetLength(I0); + const index_t loop_step = gridSize * blockSize * 4; + const auto pad = math::integer_least_multiple(m0, loop_step) - m0; + const auto desc_m0_pad = + transform_tensor_descriptor(desc_m0, + make_tuple(make_right_pad_transform(m0, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m0_pad; + } + + template + static auto MakeDescriptor_M0(const std::vector& shape, + const std::vector& stride, + index_t gridSize, + index_t blockSize) + { + auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number{}); + auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + + // merge nd to 1d desc - [s0 * s1 * ...] + if constexpr(Dim > 1) + { + const auto desc_m0 = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleOfShape)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{})), + make_tuple(Sequence<0>{})); + + return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize); + } + else + return PadDescriptor_M0_1d(desc, gridSize, blockSize); + } + + using TypeConvertFunctor = + ck::tensor_operation::element_wise::UnaryTypeConvert; + using GridDesc_M0 = decltype(MakeDescriptor_M0<1>({1}, {1}, 1, 1)); + using GridwiseUEltwise = + GridwiseUnaryElementwise_1D; + using ABCGridDescs = decltype(GetABCGridDesc()); using AGridDesc_K0_M_K1 = remove_cvref_t; @@ -733,6 +782,55 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ true, true>; + using GridwiseGemmAtomicAddFloatBf16Splitk = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + AccDataType, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true>; + // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); @@ -802,6 +900,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; c_grid_desc_m_n_ = descs[I2]; + // init work space + p_c_workspace_grid_ = nullptr; + block_2_ctile_map_ = GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); @@ -838,6 +939,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ std::vector input_left_pads_; std::vector input_right_pads_; index_t k_batch_; + + // external work space + void* p_c_workspace_grid_; }; // Invoker @@ -910,41 +1014,159 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ arg.block_2_ctile_map_); }; + // run kernel for bf16 with splitk + const auto run_bf16_splitk = [&](const auto& kernel) { + hipGetErrorString(hipMemset( + arg.p_c_workspace_grid_, + 0, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(AccDataType))); + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + static_cast(arg.p_c_workspace_grid_), + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + }; + + // kernel for type conversion + std::vector filter_dims{static_cast(arg.Conv_K_), + static_cast(arg.Conv_C_)}; + + filter_dims.insert(std::end(filter_dims), + std::begin(arg.filter_spatial_lengths_), + std::end(arg.filter_spatial_lengths_)); + + int tensor_size = + std::accumulate(filter_dims.begin(), filter_dims.end(), 1, std::multiplies{}); + + const index_t type_convert_grid_size = GridwiseUEltwise::CalculateGridSize(tensor_size); + GridDesc_M0 a_grid_desc_m0_ = + MakeDescriptor_M0<1>({tensor_size}, {1}, type_convert_grid_size, 256); + GridDesc_M0 b_grid_desc_m0_ = + MakeDescriptor_M0<1>({tensor_size}, {1}, type_convert_grid_size, 256); + + if(!GridwiseUEltwise::CheckValidity(a_grid_desc_m0_, b_grid_desc_m0_)) + { + throw std::runtime_error("wrong! GridwiseUnaryElementwise_1D has invalid setting"); + } + + // run kernel for type conversion + void* p_c_grid_tmp_ = static_cast(arg.p_c_grid_); + InDataType* p_c_grid_tmp_bf16_ = static_cast(p_c_grid_tmp_); + const auto Run_type_convert = [&](const auto& kernel) { + float elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(type_convert_grid_size), + dim3(256), + 0, + static_cast(arg.p_c_workspace_grid_), + p_c_grid_tmp_bf16_, + a_grid_desc_m0_, + b_grid_desc_m0_, + TypeConvertFunctor{}); + return elapsed_time; + }; + if constexpr(std::is_same::value) { if(has_main_k0_block_loop) { - const auto kernel = kernel_gemm_xdlops_bwd_weight< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - OutElementwiseOperation, - InElementwiseOperation, - WeiElementwiseOperation, - remove_reference_t, - true>; + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + true>; - Run(kernel); + Run(kernel); + } + else + { + const auto kernel_type_convert = + kernel_unary_elementwise_1d; + + const auto kernel_conv = kernel_gemm_xdlops_bwd_weight< + GridwiseGemmAtomicAddFloatBf16Splitk, + ADataType, // TODO: distiguish A/B datatype + AccDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + true>; + + run_bf16_splitk(kernel_conv); + ave_time += Run_type_convert(kernel_type_convert); + } } else { - const auto kernel = kernel_gemm_xdlops_bwd_weight< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - OutElementwiseOperation, - InElementwiseOperation, - WeiElementwiseOperation, - remove_reference_t, - false>; + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + false>; - Run(kernel); + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemmAtomicAddFloatBf16Splitk, + ADataType, // TODO: distiguish A/B datatype + AccDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + false>; + + run_bf16_splitk(kernel); + } } } else @@ -1226,6 +1448,11 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ { return GetWorkSpaceSize(*dynamic_cast(p_arg)); } + + void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override + { + dynamic_cast(p_arg)->p_c_workspace_grid_ = workspace_ptr; + } }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/device_unary_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_unary_elementwise.hpp new file mode 100644 index 0000000000..4fcad7004f --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_unary_elementwise.hpp @@ -0,0 +1,178 @@ +#pragma once +#include +#include + +#include "device.hpp" +#include "device_base.hpp" +#include "gridwise_unary_elementwise_1d.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceUnaryElementwise : public BaseOperator +{ + static constexpr auto I0 = Number<0>{}; + + template + static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize) + { + const auto m0 = desc_m0.GetLength(I0); + const index_t loop_step = gridSize * blockSize * ScalarPerVector; + const auto pad = math::integer_least_multiple(m0, loop_step) - m0; + const auto desc_m0_pad = + transform_tensor_descriptor(desc_m0, + make_tuple(make_right_pad_transform(m0, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m0_pad; + } + + static auto MakeDescriptor_M0(const std::vector& shape, + const std::vector& stride, + index_t gridSize, + index_t blockSize) + { + auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number{}); + auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + + // merge nd to 1d desc - [s0 * s1 * ...] + if constexpr(Dim > 1) + { + const auto desc_m0 = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleOfShape)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{})), + make_tuple(Sequence<0>{})); + + return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize); + } + else + return PadDescriptor_M0_1d(desc, gridSize, blockSize); + } + + using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1)); + using GridwiseUEltwise = GridwiseUnaryElementwise_1D; + + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a, + BDataType* p_b, + const std::vector& shape, + const std::vector& stride_a, + const std::vector& stride_b, + ElementwiseFunctor functor) + : p_a_(p_a), + p_b_(p_b), + shape_(shape), + functor_(functor), + blockSize_(256) // FIXME - Calculate the grid size by number of CU in the future + { + index_t tensor_size = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies{}); + gridSize_ = GridwiseUEltwise::CalculateGridSize(tensor_size); + a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize_); + b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize_); + } + + const ADataType* p_a_; + BDataType* p_b_; + std::vector shape_; + GridDesc_M0 a_grid_desc_m0_; + GridDesc_M0 b_grid_desc_m0_; + ElementwiseFunctor functor_; + index_t blockSize_; + index_t gridSize_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto kernel = kernel_unary_elementwise_1d; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(arg.gridSize_), + dim3(arg.blockSize_), + 0, + arg.p_a_, + arg.p_b_, + arg.a_grid_desc_m0_, + arg.b_grid_desc_m0_, + arg.functor_); + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if(pArg == nullptr) + return false; + + if(pArg->shape_.back() % ScalarPerVector != 0) + return false; + + return true; + }; + + std::unique_ptr MakeArgumentPointer(const void* p_a, + void* p_b, + std::vector shape, + std::vector stride_a, + std::vector stride_b, + ElementwiseFunctor functor) + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + shape, + stride_a, + stride_b, + functor); + } + + std::unique_ptr MakeInvokerPointer() { return std::make_unique(); } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBinaryElementwise" + << "<" + << "ScalarPerVector = " << ScalarPerVector + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index d899bdc967..70d773fb13 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -346,6 +346,27 @@ struct UnarySqrt }; }; +template +struct UnaryTypeConvert; + +template <> +struct UnaryTypeConvert +{ + __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const + { + y = ck::type_convert(x); + }; +}; + +template <> +struct UnaryTypeConvert +{ + __host__ __device__ void operator()(ck::bhalf_t& y, float& x) const + { + y = ck::type_convert(x); + }; +}; + } // namespace element_wise } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 0d3f8ddefb..b1f3779802 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -791,8 +791,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + void* p_shared = static_cast(p_shared_block); + auto c_block_buf = make_dynamic_buffer( - static_cast(p_shared_block), + static_cast(p_shared), c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); static_assert(M1 == MWave, ""); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp new file mode 100644 index 0000000000..5773068756 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp @@ -0,0 +1,129 @@ +#pragma once + +#include "cluster_descriptor.hpp" +#include "data_type.hpp" +#include "element_wise_operation.hpp" +#include "threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +__global__ void kernel_unary_elementwise_1d(const ADataType* __restrict__ p_a_global, + BDataType* __restrict__ p_b_global, + const GridDesc_M0 a_grid_desc_m0, + const GridDesc_M0 b_grid_desc_m0, + const ElementwiseFunctor functor) +{ + GridwiseUEltwise::Run(p_a_global, p_b_global, a_grid_desc_m0, b_grid_desc_m0, functor); +} + +template +struct GridwiseUnaryElementwise_1D +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto thread_desc_m0 = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using PassThrough = tensor_operation::element_wise::PassThrough; + + static __device__ auto CalculateElementwiseIndex() + { + const index_t global_thread_id = get_thread_global_1d_id(); + return make_multi_index(global_thread_id * ScalarPerVector); + } + + __host__ __device__ static constexpr bool CheckValidity(const GridDesc_M0 a_grid_desc_m0, + const GridDesc_M0 b_grid_desc_m0) + { + return a_grid_desc_m0.GetLength(I0) == b_grid_desc_m0.GetLength(I0); + } + + __host__ __device__ static constexpr index_t CalculateGridSize(const index_t tensor_size) + { + const index_t grid_size = math::integer_divide_ceil(tensor_size, 256 * ScalarPerVector); + + return grid_size; + } + + __device__ static void Run(const ADataType* __restrict__ p_a_global, + BDataType* __restrict__ p_b_global, + const GridDesc_M0 a_grid_desc_m0, + const GridDesc_M0 b_grid_desc_m0, + const ElementwiseFunctor functor) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_grid_desc_m0.GetElementSpaceSize()); + auto b_global_buf = make_dynamic_buffer( + p_b_global, b_grid_desc_m0.GetElementSpaceSize()); + + StaticBuffer a_thread_buf; + StaticBuffer b_thread_buf; + + const auto thread_store_global_offset = CalculateElementwiseIndex(); + + auto a_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + ScalarPerVector, + 1, // SrcScalarStrideInVector + false>{a_grid_desc_m0, thread_store_global_offset}; + + auto b_global_write = + ThreadwiseTensorSliceTransfer_v1r3, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // DstVectorDim + ScalarPerVector, + InMemoryDataOperationEnum::Set, + 1, // DstScalarStrideInVector + false>{ + b_grid_desc_m0, thread_store_global_offset, PassThrough{}}; + + const index_t blockSize = get_block_size(); + const index_t blockPerGrid = get_grid_size(); + const auto m0 = b_grid_desc_m0.GetLength(I0); + const index_t loop_step = blockPerGrid * blockSize * ScalarPerVector; + const auto loop_step_index = make_multi_index(loop_step); + + index_t num_iter = m0 / (loop_step); + do + { + // read and process ScalarPerVector elements + a_global_load.Run( + a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf); + + static_for<0, ScalarPerVector, 1>{}([&](auto m) { + constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m)); + functor(b_thread_buf(Number{}), a_thread_buf(Number{})); + }); + + b_global_write.Run(thread_desc_m0, + make_tuple(I0), // SrcSliceOriginIdx + b_thread_buf, + b_grid_desc_m0, + b_global_buf); + + a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index); + b_global_write.MoveDstSliceWindow(b_grid_desc_m0, loop_step_index); + } while(--num_iter); + } +}; + +} // namespace ck