mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 12:00:07 +00:00
Add grouped conv bwd weight wmma (#985)
* Add grouped conv bwd weight wmma
* Update README, changelog, profiler
* Minor fixes
* Fix grouped conv bwd wei dl kernel
* Minor fixes
* Minor stylistic fixes
[ROCm/composable_kernel commit: 16d7c4d2f7]
This commit is contained in:
@@ -1,11 +1,20 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
|
||||
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility)
|
||||
set(target 1)
|
||||
endif()
|
||||
if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp"
|
||||
|
||||
@@ -33,8 +34,9 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
|
||||
bool skip_case(const ck::utils::conv::ConvParam& params, const ck::index_t split_k)
|
||||
{
|
||||
// Odd K or C values are supported only by DL kernel (only applies to fp16)
|
||||
// DL kernel currently supports only `split_k=1`
|
||||
// Odd K or C values are supported only by DL and WMMA
|
||||
// kernels (only applies to fp16)
|
||||
// DL and WMMA kernels currently support only `split_k=1`
|
||||
if constexpr(std::is_same_v<InDataType, ck::half_t>)
|
||||
{
|
||||
if(split_k != 1 && (params.K_ % 2 != 0 || params.C_ % 2 != 0))
|
||||
@@ -53,6 +55,42 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
}
|
||||
}
|
||||
|
||||
const bool is_navi3x = ck::get_device_name() == "gfx1100" ||
|
||||
ck::get_device_name() == "gfx1101" ||
|
||||
ck::get_device_name() == "gfx1102";
|
||||
if(is_navi3x)
|
||||
{
|
||||
// on navi3x only support for 3d is implemented
|
||||
if constexpr(NDimSpatial{} != 3)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
// on navi3x only support for i8 and fp16 is implemented
|
||||
if constexpr(!((std::is_same_v<InDataType, int8_t> &&
|
||||
std::is_same_v<WeiDataType, int8_t> &&
|
||||
std::is_same_v<OutDataType, int8_t>) ||
|
||||
(std::is_same_v<InDataType, ck::half_t> &&
|
||||
std::is_same_v<WeiDataType, ck::half_t> &&
|
||||
std::is_same_v<OutDataType, ck::half_t>)))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
// WMMA kernel is only supported for split_k=1
|
||||
if(split_k != 1)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// support for i8 is only implemented on navi3x
|
||||
if constexpr(std::is_same_v<InDataType, int8_t> &&
|
||||
std::is_same_v<WeiDataType, int8_t> && std::is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -120,9 +158,11 @@ using KernelTypes3d = ::testing::Types<
|
||||
std::tuple<float, float, float, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
|
||||
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
|
||||
std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
|
||||
std::tuple<int8_t, int8_t, int8_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
|
||||
std::tuple<float, float, float, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
|
||||
std::tuple<ck::half_t, ck::half_t, ck::half_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
|
||||
std::tuple<ck::bhalf_t, float, ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>>;
|
||||
std::tuple<ck::bhalf_t, float, ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
|
||||
std::tuple<int8_t, int8_t, int8_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight1d, KernelTypes1d);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight2d, KernelTypes2d);
|
||||
|
||||
@@ -0,0 +1,191 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
|
||||
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
using ConvolutionBackwardWeightSpecialization =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault = ConvolutionBackwardWeightSpecialization::Default;
|
||||
static constexpr auto Filter1x1Stride1Pad0 =
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
template <typename Tuple, ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using OutLayout = std::tuple_element_t<0, Tuple>;
|
||||
using WeiLayout = std::tuple_element_t<1, Tuple>;
|
||||
using InLayout = std::tuple_element_t<2, Tuple>;
|
||||
static constexpr ck::index_t NDimSpatial = std::tuple_element_t<3, Tuple>{};
|
||||
|
||||
// clang-format off
|
||||
using GroupedConvBwdWeightDeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
//| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
//| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
//| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
<NDimSpatial, InLayout, WeiLayout, OutLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 8, 8, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>;
|
||||
// clang-format on
|
||||
|
||||
ck::utils::conv::ConvParam conv_param;
|
||||
|
||||
template <ck::index_t SplitK>
|
||||
bool Run()
|
||||
{
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
std::array<ck::index_t, NDimSpatial + 3> input_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> filter_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> output_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> input_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> weights_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> output_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads{};
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads{};
|
||||
|
||||
auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); };
|
||||
|
||||
range_copy(in_g_n_c_wis_desc.GetLengths(), begin(input_lengths));
|
||||
range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
|
||||
range_copy(wei_g_k_c_xs_desc.GetLengths(), begin(filter_lengths));
|
||||
range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides));
|
||||
range_copy(out_g_n_k_wos_desc.GetLengths(), begin(output_lengths));
|
||||
range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
|
||||
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
|
||||
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
|
||||
range_copy(conv_param.input_left_pads_, begin(input_left_pads));
|
||||
range_copy(conv_param.input_right_pads_, begin(input_right_pads));
|
||||
|
||||
auto conv = GroupedConvBwdWeightDeviceInstance{};
|
||||
|
||||
auto argument = conv.MakeArgument(nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
input_lengths,
|
||||
input_strides,
|
||||
filter_lengths,
|
||||
weights_strides,
|
||||
output_lengths,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
SplitK);
|
||||
return conv.IsSupportedArgument(argument);
|
||||
}
|
||||
};
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using KernelTypes3d = ::testing::Types<std::tuple<GNDHWK, GKZYXC, GNDHWC, ck::Number<3>>,
|
||||
std::tuple<NDHWGK, GKZYXC, NDHWGC, ck::Number<3>>>;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdWeightFilter1x13d
|
||||
: public TestGroupedConvndBwdWeight<Tuple, Filter1x1Stride1Pad0>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdWeightDefault3d
|
||||
: public TestGroupedConvndBwdWeight<Tuple, ConvBwdWeightDefault>
|
||||
{
|
||||
};
|
||||
|
||||
// Instantiate both 3D fixtures for every layout tuple in KernelTypes3d.
TYPED_TEST_SUITE(TestGroupedConvndBwdWeightFilter1x13d, KernelTypes3d);
TYPED_TEST_SUITE(TestGroupedConvndBwdWeightDefault3d, KernelTypes3d);
|
||||
|
||||
// With the Filter1x1Stride1Pad0 instance, problems that deviate from a 1x1x1
// filter, unit strides or zero padding are expected to be rejected, and a
// conforming problem is expected to be accepted.
TYPED_TEST(TestGroupedConvndBwdWeightFilter1x13d, SpecializationCheck)
{
    // Rejected: filter is 3x3x3 rather than 1x1x1.
    {
        this->conv_param = {
            3, 2, 4, 192, 192, {3, 3, 3}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
        const bool is_supported = this->template Run<1>();
        EXPECT_FALSE(is_supported);
    }

    // Rejected: strides are 2x2x2 rather than 1x1x1.
    {
        this->conv_param = {
            3, 2, 4, 192, 192, {1, 1, 1}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
        const bool is_supported = this->template Run<1>();
        EXPECT_FALSE(is_supported);
    }

    // Rejected: non-zero left and right padding.
    {
        this->conv_param = {
            3, 2, 4, 192, 192, {1, 1, 1}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}};
        const bool is_supported = this->template Run<1>();
        EXPECT_FALSE(is_supported);
    }

    // Accepted: 1x1x1 filter, unit strides, no padding.
    {
        this->conv_param = {
            3, 2, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
        const bool is_supported = this->template Run<1>();
        EXPECT_TRUE(is_supported);
    }
}
|
||||
|
||||
// Problems whose channel counts are odd are expected to be rejected by the
// Default-specialization instance.
TYPED_TEST(TestGroupedConvndBwdWeightDefault3d, VectorLoadCheck)
{
    // Vector load for A: the fourth scalar is 129 instead of 128
    // (presumably the K count -- confirm against ConvParam's constructor).
    {
        this->conv_param = {
            3, 2, 128, 129, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
        const bool is_supported = this->template Run<1>();
        EXPECT_FALSE(is_supported);
    }

    // Vector load for B: the fifth scalar is 257 instead of 256
    // (presumably the C count -- confirm against ConvParam's constructor).
    // NOTE(review): the original note also listed "E, Ds", which looks like
    // terminology carried over from another test -- confirm.
    {
        this->conv_param = {
            3, 2, 128, 128, 257, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
        const bool is_supported = this->template Run<1>();
        EXPECT_FALSE(is_supported);
    }
}
|
||||
|
||||
// The same problem is expected to be accepted with split-K 1 and rejected
// with split-K 2.
TYPED_TEST(TestGroupedConvndBwdWeightDefault3d, SplitKCheck)
{
    // Both cases use identical convolution parameters; only SplitK differs.
    const ck::utils::conv::ConvParam shared_params = {
        3, 2, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};

    // SplitK=1
    {
        this->conv_param        = shared_params;
        const bool is_supported = this->template Run<1>();
        EXPECT_TRUE(is_supported);
    }

    // SplitK=2
    {
        this->conv_param        = shared_params;
        const bool is_supported = this->template Run<2>();
        EXPECT_FALSE(is_supported);
    }
}
|
||||
Reference in New Issue
Block a user