Integrate universal gemm with conv bwd data and add SplitK (#1315)

* Integrate universal gemm with conv bwd data

* Fix multi d kernel

* Add splitK support

* instances refactor

* instances refactor

* refactor

* fixes

* fixes

* 16x16 instances

* Fixes

* Fix

* Fix

* Fix

* Fix

* Fix

* Fixes

* fix

* fix

[ROCm/composable_kernel commit: 4094ad158a]
This commit is contained in:
Bartłomiej Kocot
2025-04-28 23:54:49 +02:00
committed by GitHub
parent 02ef8bcfb1
commit 05f9b2dde3
69 changed files with 2262 additions and 349 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -59,7 +59,8 @@ struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) = 0;
const CDEElementwiseOperation& cde_element_op,
const ck::index_t split_k = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -227,7 +227,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOp& a_element_op,
const BElementwiseOp& b_element_op,
const CDEElementwiseOp& cde_element_op)
const CDEElementwiseOp& cde_element_op,
const ck::index_t split_k = 1)
: p_a_grid_{static_cast<const ADataType*>(p_a)},
p_b_grid_{static_cast<const BDataType*>(p_b)},
p_ds_grid_{},
@@ -240,7 +241,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
input_right_pads_{input_right_pads},
k_batch_{split_k}
{
// populate Ds pointer
static_for<0, NumDTensor, 1>{}([&](auto i) {
@@ -445,6 +447,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_;
const index_t k_batch_;
};
// Invoker
@@ -534,6 +538,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(arg.k_batch_ != 1)
{
return false;
}
// check device
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
@@ -691,7 +700,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOp& a_element_op,
const BElementwiseOp& b_element_op,
const CDEElementwiseOp& cde_element_op)
const CDEElementwiseOp& cde_element_op,
const ck::index_t split_k = 1)
{
return Argument{p_a,
p_b,
@@ -711,7 +721,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
input_right_pads,
a_element_op,
b_element_op,
cde_element_op};
cde_element_op,
split_k};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -737,7 +748,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOp& a_element_op,
const BElementwiseOp& b_element_op,
const CDEElementwiseOp& cde_element_op) override
const CDEElementwiseOp& cde_element_op,
const ck::index_t split_k = 1) override
{
return std::make_unique<Argument>(p_a,
p_b,
@@ -757,7 +769,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
input_right_pads,
a_element_op,
b_element_op,
cde_element_op);
cde_element_op,
split_k);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override

View File

@@ -19,7 +19,7 @@
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

View File

@@ -17,7 +17,7 @@
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -187,7 +187,8 @@ struct TransformConvBwdDataToGemm_v1
WTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.WTilde_)},
ZDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ZDot_)},
YDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.YDot_)},
XDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.XDot_)}
XDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.XDot_)},
batch_k_{transform_conv_bwd_data_to_gemm_base.batch_k_}
{
}
@@ -203,7 +204,8 @@ struct TransformConvBwdDataToGemm_v1
const ConvSpatialDimsType& conv_filter_dilations,
const ConvSpatialDimsType& input_left_pads,
const ConvSpatialDimsType& input_right_pads,
const ConvSpatialDimsType& tildes)
const ConvSpatialDimsType& tildes,
const index_t batch_k = 1)
: Hi_{c_g_n_c_wis_lengths[HIdx]},
Wi_{c_g_n_c_wis_lengths[WIdx]},
Ho_{a_g_n_k_wos_lengths[HIdx]},
@@ -231,7 +233,8 @@ struct TransformConvBwdDataToGemm_v1
InRightPadH_{input_right_pads[HIdx - NonSpatialDimsNum]},
InRightPadW_{input_right_pads[WIdx - NonSpatialDimsNum]},
IdxYTilde_{tildes[YIdx - NonSpatialDimsNum]},
IdxXTilde_{tildes[XIdx - NonSpatialDimsNum]}
IdxXTilde_{tildes[XIdx - NonSpatialDimsNum]},
batch_k_{batch_k}
{
static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
@@ -616,20 +619,22 @@ struct TransformConvBwdDataToGemm_v1
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
const index_t AK0 = math::integer_divide_ceil(K_, AK1);
const index_t K0PerBlock = GemmKPerBlock / AK1;
const index_t AK0 =
math::integer_divide_ceil(K_, AK1 * K0PerBlock * batch_k_) * K0PerBlock;
// A: output tensor
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_),
make_unmerge_transform(make_tuple(AK0, AK1))),
make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
const auto out_gemmak0_gemmm_gemmak1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmak0_gemmmraw_gemmak1_grid_desc,
make_tuple(AK0, GemmMPerBlock, AK1),
make_tuple(AK0 * batch_k_, GemmMPerBlock, AK1),
Sequence<false, DoPadGemmM, false>{});
return out_gemmak0_gemmm_gemmak1_grid_desc;
@@ -719,11 +724,15 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(GemmKPerBlock, GemmMPerBlock),
Sequence<true, DoPadGemmM>{});
const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1;
const index_t K0PerBlock = GemmKPerBlock / AK1;
const index_t AK0 =
math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0),
AK1 * K0PerBlock * batch_k_) *
K0PerBlock;
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
out_gemmk_gemmm_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)),
make_pass_through_transform(
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
@@ -816,11 +825,15 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(GemmKPerBlock, GemmMPerBlock),
Sequence<true, DoPadGemmM>{});
const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1;
const index_t K0PerBlock = GemmKPerBlock / AK1;
const index_t AK0 =
math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0),
AK1 * K0PerBlock * batch_k_) *
K0PerBlock;
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
out_gemmk_gemmm_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)),
make_pass_through_transform(
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
@@ -850,21 +863,23 @@ struct TransformConvBwdDataToGemm_v1
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
const index_t BK0 = math::integer_divide_ceil(K_, BK1);
const index_t K0PerBlock = GemmKPerBlock / BK1;
const index_t BK0 =
math::integer_divide_ceil(K_, BK1 * K0PerBlock * batch_k_) * K0PerBlock;
// B: weight tensor
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc =
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K_, C_)),
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K_, C_)),
make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, C_), make_tuple(I0, I1));
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
make_tuple(BK0, GemmNPerBlock, BK1),
make_tuple(BK0 * batch_k_, GemmNPerBlock, BK1),
Sequence<false, DoPadGemmN, false>{});
return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
@@ -925,11 +940,15 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(GemmKPerBlock, GemmNPerBlock),
Sequence<true, DoPadGemmN>{});
const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1;
const index_t K0PerBlock = GemmKPerBlock / BK1;
const index_t BK0 =
math::integer_divide_ceil(wei_gemmk_gemmn_padded_grid_desc.GetLength(I0),
BK1 * K0PerBlock * batch_k_) *
K0PerBlock;
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemmn_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)),
make_pass_through_transform(
wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
@@ -1006,11 +1025,15 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(GemmKPerBlock, GemmNPerBlock),
Sequence<true, DoPadGemmN>{});
const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1;
const index_t K0PerBlock = GemmKPerBlock / BK1;
const index_t BK0 =
math::integer_divide_ceil(wei_gemmk_gemmn_padded_grid_desc.GetLength(I0),
BK1 * K0PerBlock * batch_k_) *
K0PerBlock;
const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemmn_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)),
make_pass_through_transform(
wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
@@ -1355,6 +1378,7 @@ struct TransformConvBwdDataToGemm_v1
IndexType ZTilde_, YTilde_, XTilde_;
IndexType DTilde_, HTilde_, WTilde_;
IndexType ZDot_, YDot_, XDot_;
index_t batch_k_;
};
} // namespace tensor_operation