Integrate universal gemm with conv bwd data and add SplitK (#1315)

* Integrate universal gemm with conv bwd data

* Fix multi d kernel

* Add splitK support

* instances refactor

* instances refactor

* refactor

* fixeS

* fixes

* 16x16 instnaces

* Fixes

* Fix

* Fix

* Fix

* Fix

* Fix

* Fixes

* fix

* fix
This commit is contained in:
Bartłomiej Kocot
2025-04-28 23:54:49 +02:00
committed by GitHub
parent d9786f3363
commit 4094ad158a
69 changed files with 2262 additions and 349 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -227,7 +227,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOp& a_element_op,
const BElementwiseOp& b_element_op,
const CDEElementwiseOp& cde_element_op)
const CDEElementwiseOp& cde_element_op,
const ck::index_t split_k = 1)
: p_a_grid_{static_cast<const ADataType*>(p_a)},
p_b_grid_{static_cast<const BDataType*>(p_b)},
p_ds_grid_{},
@@ -240,7 +241,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
input_right_pads_{input_right_pads},
k_batch_{split_k}
{
// populate Ds pointer
static_for<0, NumDTensor, 1>{}([&](auto i) {
@@ -445,6 +447,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_;
const index_t k_batch_;
};
// Invoker
@@ -534,6 +538,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(arg.k_batch_ != 1)
{
return false;
}
// check device
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
@@ -691,7 +700,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOp& a_element_op,
const BElementwiseOp& b_element_op,
const CDEElementwiseOp& cde_element_op)
const CDEElementwiseOp& cde_element_op,
const ck::index_t split_k = 1)
{
return Argument{p_a,
p_b,
@@ -711,7 +721,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
input_right_pads,
a_element_op,
b_element_op,
cde_element_op};
cde_element_op,
split_k};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -737,7 +748,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOp& a_element_op,
const BElementwiseOp& b_element_op,
const CDEElementwiseOp& cde_element_op) override
const CDEElementwiseOp& cde_element_op,
const ck::index_t split_k = 1) override
{
return std::make_unique<Argument>(p_a,
p_b,
@@ -757,7 +769,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
input_right_pads,
a_element_op,
b_element_op,
cde_element_op);
cde_element_op,
split_k);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override

View File

@@ -19,7 +19,7 @@
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

View File

@@ -17,7 +17,7 @@
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"