mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
Integrate universal gemm with conv bwd data and add SplitK (#1315)
* Integrate universal gemm with conv bwd data
* Fix multi d kernel
* Add splitK support
* instances refactor
* instances refactor
* refactor
* Fixes
* fixes
* 16x16 instances
* Fixes
* Fix
* Fix
* Fix
* Fix
* Fix
* Fixes
* fix
* fix
[ROCm/composable_kernel commit: 4094ad158a]
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -59,7 +59,8 @@ struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) = 0;
|
||||
const CDEElementwiseOperation& cde_element_op,
|
||||
const ck::index_t split_k = 1) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -227,7 +227,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOp& a_element_op,
|
||||
const BElementwiseOp& b_element_op,
|
||||
const CDEElementwiseOp& cde_element_op)
|
||||
const CDEElementwiseOp& cde_element_op,
|
||||
const ck::index_t split_k = 1)
|
||||
: p_a_grid_{static_cast<const ADataType*>(p_a)},
|
||||
p_b_grid_{static_cast<const BDataType*>(p_b)},
|
||||
p_ds_grid_{},
|
||||
@@ -240,7 +241,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
{
|
||||
// populate Ds pointer
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
@@ -445,6 +447,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_;
|
||||
|
||||
const index_t k_batch_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -534,6 +538,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(arg.k_batch_ != 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check device
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
@@ -691,7 +700,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOp& a_element_op,
|
||||
const BElementwiseOp& b_element_op,
|
||||
const CDEElementwiseOp& cde_element_op)
|
||||
const CDEElementwiseOp& cde_element_op,
|
||||
const ck::index_t split_k = 1)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
@@ -711,7 +721,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
input_right_pads,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
cde_element_op,
|
||||
split_k};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
@@ -737,7 +748,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOp& a_element_op,
|
||||
const BElementwiseOp& b_element_op,
|
||||
const CDEElementwiseOp& cde_element_op) override
|
||||
const CDEElementwiseOp& cde_element_op,
|
||||
const ck::index_t split_k = 1) override
|
||||
{
|
||||
return std::make_unique<Argument>(p_a,
|
||||
p_b,
|
||||
@@ -757,7 +769,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
input_right_pads,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
cde_element_op,
|
||||
split_k);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -19,7 +19,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp"
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp"
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -187,7 +187,8 @@ struct TransformConvBwdDataToGemm_v1
|
||||
WTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.WTilde_)},
|
||||
ZDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ZDot_)},
|
||||
YDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.YDot_)},
|
||||
XDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.XDot_)}
|
||||
XDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.XDot_)},
|
||||
batch_k_{transform_conv_bwd_data_to_gemm_base.batch_k_}
|
||||
{
|
||||
}
|
||||
|
||||
@@ -203,7 +204,8 @@ struct TransformConvBwdDataToGemm_v1
|
||||
const ConvSpatialDimsType& conv_filter_dilations,
|
||||
const ConvSpatialDimsType& input_left_pads,
|
||||
const ConvSpatialDimsType& input_right_pads,
|
||||
const ConvSpatialDimsType& tildes)
|
||||
const ConvSpatialDimsType& tildes,
|
||||
const index_t batch_k = 1)
|
||||
: Hi_{c_g_n_c_wis_lengths[HIdx]},
|
||||
Wi_{c_g_n_c_wis_lengths[WIdx]},
|
||||
Ho_{a_g_n_k_wos_lengths[HIdx]},
|
||||
@@ -231,7 +233,8 @@ struct TransformConvBwdDataToGemm_v1
|
||||
InRightPadH_{input_right_pads[HIdx - NonSpatialDimsNum]},
|
||||
InRightPadW_{input_right_pads[WIdx - NonSpatialDimsNum]},
|
||||
IdxYTilde_{tildes[YIdx - NonSpatialDimsNum]},
|
||||
IdxXTilde_{tildes[XIdx - NonSpatialDimsNum]}
|
||||
IdxXTilde_{tildes[XIdx - NonSpatialDimsNum]},
|
||||
batch_k_{batch_k}
|
||||
{
|
||||
static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
|
||||
is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
|
||||
@@ -616,20 +619,22 @@ struct TransformConvBwdDataToGemm_v1
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
Filter1x1Stride1Pad0)
|
||||
{
|
||||
const index_t AK0 = math::integer_divide_ceil(K_, AK1);
|
||||
const index_t K0PerBlock = GemmKPerBlock / AK1;
|
||||
const index_t AK0 =
|
||||
math::integer_divide_ceil(K_, AK1 * K0PerBlock * batch_k_) * K0PerBlock;
|
||||
|
||||
// A: output tensor
|
||||
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_),
|
||||
make_unmerge_transform(make_tuple(AK0, AK1))),
|
||||
make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
|
||||
|
||||
const auto out_gemmak0_gemmm_gemmak1_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
out_gemmak0_gemmmraw_gemmak1_grid_desc,
|
||||
make_tuple(AK0, GemmMPerBlock, AK1),
|
||||
make_tuple(AK0 * batch_k_, GemmMPerBlock, AK1),
|
||||
Sequence<false, DoPadGemmM, false>{});
|
||||
|
||||
return out_gemmak0_gemmm_gemmak1_grid_desc;
|
||||
@@ -719,11 +724,15 @@ struct TransformConvBwdDataToGemm_v1
|
||||
make_tuple(GemmKPerBlock, GemmMPerBlock),
|
||||
Sequence<true, DoPadGemmM>{});
|
||||
|
||||
const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1;
|
||||
const index_t K0PerBlock = GemmKPerBlock / AK1;
|
||||
const index_t AK0 =
|
||||
math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0),
|
||||
AK1 * K0PerBlock * batch_k_) *
|
||||
K0PerBlock;
|
||||
|
||||
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmk_gemmm_padded_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)),
|
||||
make_pass_through_transform(
|
||||
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
@@ -816,11 +825,15 @@ struct TransformConvBwdDataToGemm_v1
|
||||
make_tuple(GemmKPerBlock, GemmMPerBlock),
|
||||
Sequence<true, DoPadGemmM>{});
|
||||
|
||||
const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1;
|
||||
const index_t K0PerBlock = GemmKPerBlock / AK1;
|
||||
const index_t AK0 =
|
||||
math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0),
|
||||
AK1 * K0PerBlock * batch_k_) *
|
||||
K0PerBlock;
|
||||
|
||||
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmk_gemmm_padded_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)),
|
||||
make_pass_through_transform(
|
||||
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
@@ -850,21 +863,23 @@ struct TransformConvBwdDataToGemm_v1
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
Filter1x1Stride1Pad0)
|
||||
{
|
||||
const index_t BK0 = math::integer_divide_ceil(K_, BK1);
|
||||
const index_t K0PerBlock = GemmKPerBlock / BK1;
|
||||
const index_t BK0 =
|
||||
math::integer_divide_ceil(K_, BK1 * K0PerBlock * batch_k_) * K0PerBlock;
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc =
|
||||
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K_, C_)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K_, C_)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, C_), make_tuple(I0, I1));
|
||||
|
||||
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
|
||||
make_tuple(BK0, GemmNPerBlock, BK1),
|
||||
make_tuple(BK0 * batch_k_, GemmNPerBlock, BK1),
|
||||
Sequence<false, DoPadGemmN, false>{});
|
||||
|
||||
return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
|
||||
@@ -925,11 +940,15 @@ struct TransformConvBwdDataToGemm_v1
|
||||
make_tuple(GemmKPerBlock, GemmNPerBlock),
|
||||
Sequence<true, DoPadGemmN>{});
|
||||
|
||||
const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1;
|
||||
const index_t K0PerBlock = GemmKPerBlock / BK1;
|
||||
const index_t BK0 =
|
||||
math::integer_divide_ceil(wei_gemmk_gemmn_padded_grid_desc.GetLength(I0),
|
||||
BK1 * K0PerBlock * batch_k_) *
|
||||
K0PerBlock;
|
||||
|
||||
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemmn_padded_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)),
|
||||
make_pass_through_transform(
|
||||
wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
@@ -1006,11 +1025,15 @@ struct TransformConvBwdDataToGemm_v1
|
||||
make_tuple(GemmKPerBlock, GemmNPerBlock),
|
||||
Sequence<true, DoPadGemmN>{});
|
||||
|
||||
const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1;
|
||||
const index_t K0PerBlock = GemmKPerBlock / BK1;
|
||||
const index_t BK0 =
|
||||
math::integer_divide_ceil(wei_gemmk_gemmn_padded_grid_desc.GetLength(I0),
|
||||
BK1 * K0PerBlock * batch_k_) *
|
||||
K0PerBlock;
|
||||
|
||||
const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemmn_padded_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)),
|
||||
make_pass_through_transform(
|
||||
wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
@@ -1355,6 +1378,7 @@ struct TransformConvBwdDataToGemm_v1
|
||||
IndexType ZTilde_, YTilde_, XTilde_;
|
||||
IndexType DTilde_, HTilde_, WTilde_;
|
||||
IndexType ZDot_, YDot_, XDot_;
|
||||
index_t batch_k_;
|
||||
};
|
||||
|
||||
} // namespace tensor_operation
|
||||
|
||||
Reference in New Issue
Block a user