mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Wmma support for grouped convolution bwd weight (#2947)
* Convolution bwd weight device implementation
* Merge branch 'grouped_conv_bwd_weight_device_impl_wmma' into 'feature/conv_bwd_weight_wmma'
Convolution bwd weight device implementation
See merge request amd/ai/composable_kernel!38
* Fix bug and disable splitK=-1 tests for wmma
* Add generic instances for bf16 f32 bf16
* check gridwise level validity in device impl for 1 stage D0
* Fix bugs in device implementation:
- rdna3 compilation error
- gridwise layouts (need to be correct to ensure that CheckValidaity()
works correctly)
* Add padding in conv to gemm transformers for 1x1Stride1Pad0 specialization
* Remove workaround for 1x1Stride1Pad0 conv specialization
* Add instances for xdl parity (for pipeline v1)
* Add two stage instances (xdl parity)
* Add multiple Ds instances
* Add examples
* Uncomment scale instances
* Fix copyright
* Fix examples compilation
* Add atomic add float4
* Fix compilation error
* Fix instances
* Compute tolerances in examples instead of using default ones
* Compute tolerances instead of using default ones in bilinear and scale tests
* Merge branch 'grouped_conv_bwd_weight_instances_examples' into 'feature/conv_bwd_weight_wmma'
Grouped conv: Instances and example bwd weight
See merge request amd/ai/composable_kernel!47
* Device implementation of explicit gemm for grouped conv bwd weight
Based on batched gemm multiple D
* Add instances for pipeline v1 and v3
* Add support for occupancy-based splitk
* Fix ckProfiler dependencies
* Review fixes
* Merge branch 'explicit_bwd_weight' into 'feature/conv_bwd_weight_wmma'
Device implementation of explicit gemm for grouped conv bwd weight
See merge request amd/ai/composable_kernel!52
* Fix cmake file for tests
* fix clang format
* fix instance factory error
* Adapt all grouped conv bwd weight vanilla Xdl instances to 16x16. MRepeat doubled for all but 12 of them (some static assert failure). Also added custom reduced profiler target for building grouped conv bwd weight vanilla only profiler. Verified with gtest test.
* Revert "Adapt all grouped conv bwd weight vanilla Xdl instances to 16x16. MRepeat doubled for all but 12 of them (some static assert failure). Also added custom reduced profiler target for building grouped conv bwd weight vanilla only profiler. Verified with gtest test."
This reverts commit da8e4cfb7917d45d46339ec74eb72e2f585f14cf.
* Disable splitk for 2stage xdl on rdna (bug to be fixed)
* Fix add_test_executable
* Always ForceThreadTileTransfer for now, WaveTileTransfer does not work for convolution yet.
* Grab device and gridwise files from bkp branch, this should enable splitK support for convolution and also we no longer ForceThreadTileTransfer for explicit gemm. Also grab some updates from 7e7243783008b11e904f127ecf1df55ef95e9af2 to fix building on clang20.
* Fix bug in various bwd wei device implementations / profiler where the occupancy based split_k value could not be found because the Argument did not derive from ArgumentSplitK, leading to incorrect error tolerances.
* Actually print the reason when a device implementation is not supported.
* Print number of valid instances in profiler and tests.
* Fix clang format for Two Stage implementation
* Fix copyright
* Address review comments
* Fix explicit conv bwd weight struct
* Fix gridwise common
* Fix gridwise ab scale
* Remove autodeduce 1 stage
* Restore example tolerance calculation
* Fix compilation error
* Fix gridwise common
* Fix gridwise gemm
* Fix typo
* Fix splitk
* Fix splitk ab scale
* Adapt all grouped conv bwd weight vanilla Xdl instances to 16x16. MRepeat doubled for all but 12 of them (some static assert failure). Also added custom reduced profiler target for building grouped conv bwd weight vanilla only profiler. Verified with gtest test.
* Reduce instances to only the tuned wmma V3 ones for implicit v1 intra and explicit v1 intra pad/nopad.
* Add explicit oddMN support with custom tuned instances
* Add two stage instances based on the parameters from the tuned cshuffle V3 instances. CShuffleBlockTranserScalarPerVector adapted to 4, and mergegroups fixed to 1 for now. No more special instance lists.
* Replace cshuffle non-v3 lists with v3 lists, making sure to not have duplications. Also removing stride1pad0 support for NHWGC since we can use explicit for those cases.
* Remove some instances that give incorrect results (f16 NHWGC)
* Add bf16 f32 bf16 instances based on tuned b16 NHWGC GKYXC instances.
* Add back some generic instances to make sure we have the same shape / layout / datatype support as before the instance selection process.
* Add instances for scale and bilinear based on the bf16 NHWGC GKYXC tuning. Keep generic instances for support.
* Disable two stage f16 instances which produce incorrect results.
* Remove more instances which fail verification, for bf16_f32_bf16 and for f16 scale / bilinear.
* Disable all non-generic two-stage instances in the instance lists for NHWGC. They are never faster and support is already carried by CShuffleV3 and Explicit.
* Remove unused instance lists and related add_x_instance() functions, fwd declarations, cmakelists entries. Also merge the "wmma" and "wmma v3" instance list files, which are both v3.
* Re-enable all xdl instances (un-16x16-adapted) and dl instances. Remove custom ckProfiler target.
* Remove straggler comments
* Remove [[maybe_unused]]
* Fix clang format
* Remove unwanted instances. This includes all instances which are not NHWGCxGKYXC and F16 or BF16 (no mixed in-out types).
* Add comment
---------
Co-authored-by: kiefer <kiefer.van.teutem@streamhpc.com>
Co-authored-by: Kiefer van Teutem <50830967+krithalith@users.noreply.github.com>
[ROCm/composable_kernel commit: 87dd073887]
This commit is contained in:
@@ -11,8 +11,11 @@ add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bw
|
||||
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16)
|
||||
add_example_executable(example_grouped_conv_bwd_weight_v3_wmma_fp16 grouped_conv_bwd_weight_v3_wmma_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_v3_wmma_fp16)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_weight_v3_wmma_bf16 grouped_conv_bwd_weight_v3_wmma_bf16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_v3_wmma_bf16)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_dl_fp16)
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp"
|
||||
|
||||
using InDataType = BF16;
|
||||
// bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory
|
||||
using WeiDataType = F32;
|
||||
using OutDataType = BF16;
|
||||
using AccDataType = F32;
|
||||
|
||||
using InElementOp = PassThrough;
|
||||
using WeiElementOp = PassThrough;
|
||||
using OutElementOp = PassThrough;
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using DeviceConvBwdWeightInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3<
|
||||
NDimSpatial,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWC,
|
||||
ck::tensor_layout::convolution::NHWGC,
|
||||
ck::tensor_layout::convolution::NDHWGC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GKXC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GKZYXC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWK,
|
||||
ck::tensor_layout::convolution::NHWGK,
|
||||
ck::tensor_layout::convolution::NDHWGK>>,
|
||||
InDataType, // InDataType
|
||||
WeiDataType, // WeiDataType
|
||||
OutDataType, // OutDataType
|
||||
AccDataType, // AccDataType
|
||||
InElementOp, // InElementwiseOperation
|
||||
WeiElementOp, // WeiElementwiseOperation
|
||||
OutElementOp, // OutElementwiseOperation
|
||||
ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // K1
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
4, // MRepeat
|
||||
2, // NRepeat
|
||||
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
1, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
2, // ABlockTransferDstScalarPerVector_K1
|
||||
true, // ABlockLdsAddExtraM
|
||||
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
1, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
2, // BBlockTransferDstScalarPerVector_K1
|
||||
true, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMRepeatPerShuffle
|
||||
1, // CShuffleNRepeatPerShuffle
|
||||
S<1, 32, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
4>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
|
||||
#include "run_grouped_conv_bwd_weight_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExecutionConfig config;
|
||||
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
|
||||
|
||||
if(!parse_cmd_args(argc, argv, config, conv_param))
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
switch(conv_param.num_dim_spatial_)
|
||||
{
|
||||
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
|
||||
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
|
||||
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
|
||||
default: break;
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp"
|
||||
|
||||
using InDataType = F16;
|
||||
using WeiDataType = F16;
|
||||
@@ -16,11 +16,20 @@ using OutElementOp = PassThrough;
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using DeviceConvBwdWeightInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle<
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3<
|
||||
NDimSpatial,
|
||||
ck::tensor_layout::convolution::GNDHWC,
|
||||
ck::tensor_layout::convolution::GKZYXC,
|
||||
ck::tensor_layout::convolution::GNDHWK,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWC,
|
||||
ck::tensor_layout::convolution::NHWGC,
|
||||
ck::tensor_layout::convolution::NDHWGC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GKXC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GKZYXC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWK,
|
||||
ck::tensor_layout::convolution::NHWGK,
|
||||
ck::tensor_layout::convolution::NDHWGK>>,
|
||||
InDataType, // InDataType
|
||||
WeiDataType, // WeiDataType
|
||||
OutDataType, // OutDataType
|
||||
@@ -32,30 +41,30 @@ using DeviceConvBwdWeightInstance =
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
4, // K0PerBlock
|
||||
32, // KPerBlock
|
||||
8, // K1
|
||||
16, // MPerWMMA
|
||||
16, // NPerWMMA
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
4, // MRepeat
|
||||
2, // NRepeat
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<0, 2, 1>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<0, 2, 1>, // ABlockTransferSrcAccessOrder
|
||||
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
1, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
true, // ABlockLdsExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<0, 2, 1>, // BBlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferDstScalarPerVector_K1
|
||||
false, // ABlockLdsAddExtraM
|
||||
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
1, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
true, // BBlockLdsExtraN
|
||||
4,
|
||||
2,
|
||||
S<1, 32, 1, 8>,
|
||||
1>;
|
||||
2, // BBlockTransferDstScalarPerVector_K1
|
||||
false, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMRepeatPerShuffle
|
||||
1, // CShuffleNRepeatPerShuffle
|
||||
S<1, 32, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
4>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
|
||||
@@ -80,6 +89,8 @@ int main(int argc, char* argv[])
|
||||
|
||||
switch(conv_param.num_dim_spatial_)
|
||||
{
|
||||
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
|
||||
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
|
||||
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
|
||||
default: break;
|
||||
}
|
||||
@@ -5,7 +5,7 @@ template <ck::index_t NDimSpatial>
|
||||
bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
|
||||
const ck::utils::conv::ConvParam& conv_param)
|
||||
{
|
||||
// Dl and WMMA ops don't support split_k > 1
|
||||
// Dl ops don't support split_k > 1
|
||||
constexpr ck::index_t split_k = 1;
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
@@ -131,7 +131,21 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
|
||||
|
||||
wei_device_buf.FromDevice(wei_device_result.mData.data());
|
||||
|
||||
return ck::utils::check_err(wei_device_result.mData, wei_host_result.mData);
|
||||
float max_accumulated_value =
|
||||
*std::max_element(wei_host_result.mData.begin(), wei_host_result.mData.end());
|
||||
|
||||
const ck::index_t num_accums = out.GetElementSize() / conv_param.K_;
|
||||
const ck::index_t num_accums_split_k = split_k;
|
||||
double rtol = ck::utils::get_relative_threshold<InDataType, WeiDataType, AccDataType>(
|
||||
num_accums / num_accums_split_k);
|
||||
double atol = ck::utils::get_absolute_threshold<InDataType, WeiDataType, AccDataType>(
|
||||
max_accumulated_value / num_accums_split_k, num_accums / num_accums_split_k);
|
||||
|
||||
return ck::utils::check_err(wei_device_result.mData,
|
||||
wei_host_result.mData,
|
||||
"Error: Incorrect results!",
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
else if(config.do_verification == 2)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user