Automatic deduction of split-K value for grouped convolution (#2491)

* Split-K autodeduction for DeviceGroupedConvBwdWeight_Xdl_CShuffle and DeviceGroupedConvBwdWeight_Xdl_CShuffleV3.

* Split-K autodeduction for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle.

* Use simple best occupancy model to calculate the split-K.

* Handle split-K autodeduction in explicit gemm conv.

* Add unit tests for split-K autodeduction.

* Remove oversubscription.

* Small fixes.

* Added split-K autodeduction for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle.

* Run clang formatting.

* Fix error handling in the conv profiler.

* Add missing documentation for the autodeduced split-K values.

* Add split-K autodeduction to DeviceGroupedConvBwdWeight_Explicit_Xdl solver.

* Fix clang formatting and split-K profiler documentation.

* Rename max_occupancy value variable.

* Calculate grid size for split-K autodeduction directly from input array shapes and template params.

---------

Co-authored-by: Ville Pietilä <>
This commit is contained in:
Ville Pietilä
2025-07-31 13:08:45 +03:00
committed by GitHub
parent 7b074249f4
commit e962a41638
14 changed files with 544 additions and 72 deletions

View File

@@ -337,6 +337,60 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
}
};
// One-shot occupancy query for the batched multi-D GEMM kernel: asks the HIP
// runtime how many workgroups can be resident per compute unit. The result
// feeds split-K autodeduction.
struct ActiveWorkgroupsPerCU
{
ActiveWorkgroupsPerCU()
{
// The kernel requests no dynamically allocated LDS at launch.
constexpr int dynamic_smem_size = 0;
int max_occupancy = 0;
// Kernel template argument; chosen with rules that presumably mirror the
// launch path so the queried instantiation matches the one actually
// launched -- NOTE(review): confirm against the Invoker.
constexpr index_t minimum_occupancy = []() {
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
{
return 2;
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
}
else
{
return 1;
}
}();
// Pipeline v4 uses the dedicated 2-LDS kernel variant; every other
// pipeline version uses the regular multi-D kernel.
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&max_occupancy,
kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>,
BlockSize,
dynamic_smem_size));
}
else
{
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&max_occupancy,
kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>,
BlockSize,
dynamic_smem_size));
}
// Clamp so downstream split-K math never sees a non-positive occupancy.
max_occupancy_ = std::max(1, max_occupancy);
}
// Max active workgroups per compute unit reported by the runtime (>= 1).
int max_occupancy_;
};
// Invoker
struct Invoker : public BaseInvoker
{
@@ -1044,6 +1098,12 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
return str.str();
}
/// @brief Maximum number of active workgroups per CU for this device op's
/// GEMM kernel. The HIP occupancy query runs once (function-local static)
/// and the value is reused for the lifetime of the process.
static ck::index_t GetMaxOccupancy()
{
static const ActiveWorkgroupsPerCU occupancy_query{};
return occupancy_query.max_occupancy_;
}
};
} // namespace device

View File

@@ -13,6 +13,8 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
namespace ck {
namespace tensor_operation {
@@ -142,6 +144,20 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
end(e_g_k_c_xs_lengths),
begin(filter_spatial_lengths_));
if(split_k < 0)
{
const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) =
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
const index_t grid_size = gdx * gdy * gdz;
split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
}
else
{
split_k_ = split_k;
}
if constexpr(IsTwoStageNeeded)
{
const index_t merged_filter_dims = std::accumulate(begin(e_g_k_c_xs_lengths),
@@ -176,7 +192,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
out_element_op,
in_element_op,
wei_element_op,
split_k};
split_k_};
}
else
{
@@ -199,7 +215,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
out_element_op,
in_element_op,
wei_element_op,
split_k};
split_k_};
}
}
@@ -236,6 +252,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
bool is_filter_data_packed;
CElementwiseGridDesc elementwise_desc_;
Block2TileMapElementwise elementwise_block_2_ctile_map_;
ck::index_t split_k_;
};
// Invoker

View File

@@ -19,6 +19,8 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
@@ -542,7 +544,36 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
using Block2CTileMap =
decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
struct Argument : public BaseArgument
// One-shot occupancy query for kernel_batched_gemm_xdlops_bwd_weight; the
// result (max active workgroups per CU) drives split-K autodeduction.
struct ActiveWorkgroupsPerCU
{
ActiveWorkgroupsPerCU()
{
// No dynamically allocated LDS is requested at launch time.
constexpr int dynamic_smem_size = 0;
int max_occupancy = 0;
// Template arguments should match the instantiation launched by the
// Invoker -- NOTE(review): keep in sync if the launch site changes.
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&max_occupancy,
kernel_batched_gemm_xdlops_bwd_weight<
GridwiseGemm,
ADataType,
BDataType,
AccDataType,
OutElementwiseOperation,
InElementwiseOperation,
element_wise::PassThrough,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<DeviceOp::Block2CTileMap>,
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
true>,
BlockSize,
dynamic_smem_size));
// Clamp so split-K deduction never divides by a non-positive occupancy.
max_occupancy_ = std::max(1, max_occupancy);
}
// Max active workgroups per compute unit (>= 1).
int max_occupancy_;
};
struct Argument : public BaseArgument, public ArgumentSplitK
{
Argument(
const InDataType* p_in_grid,
@@ -591,9 +622,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
output_spatial_lengths_{},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads},
k_batch_{split_k}
input_right_pads_{input_right_pads}
{
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
c_space_size_bytes =
ck::accumulate_n<long_index_t>(
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
@@ -610,6 +642,22 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
if(split_k < 0)
{
ck::index_t gemmM, gemmN;
std::tie(gemmM, gemmN, std::ignore) =
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
const auto grid_size =
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
grid_size);
}
else
{
k_batch_ = split_k;
}
const auto descs =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -712,7 +760,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
const index_t k_batch_;
long_index_t c_space_size_bytes;
};

View File

@@ -22,6 +22,8 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
@@ -504,7 +506,55 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}, 1, 1));
struct Argument : public BaseArgument
// One-shot occupancy query for the two-stage bwd-weight GEMM kernel; the
// result (max active workgroups per CU) drives split-K autodeduction.
struct ActiveWorkgroupsPerCU
{
ActiveWorkgroupsPerCU()
{
// No dynamically allocated LDS is requested at launch time.
constexpr int dynamic_smem_size = 0;
// Kernel template argument; presumably mirrors the value used by the
// launch path -- NOTE(review): confirm against the Invoker.
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
int max_occupancy = 0;
// Pipeline v4 uses the dedicated 2-LDS kernel variant; every other
// pipeline version uses the regular kernel.
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&max_occupancy,
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>,
BlockSize,
dynamic_smem_size));
}
else
{
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&max_occupancy,
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>,
BlockSize,
dynamic_smem_size));
}
// Clamp so split-K deduction never divides by a non-positive occupancy.
max_occupancy_ = std::max(1, max_occupancy);
}
// Max active workgroups per compute unit (>= 1).
int max_occupancy_;
};
struct Argument : public BaseArgument, public ArgumentSplitK
{
Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid,
@@ -547,9 +597,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
output_spatial_lengths_{},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads},
k_batch_{split_k}
input_right_pads_{input_right_pads}
{
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
c_space_size_bytes =
ck::accumulate_n<long_index_t>(
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
@@ -576,6 +627,35 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
e_g_k_c_xs_strides);
if(split_k < 0)
{
ck::index_t gemmM, gemmN, gemmK;
std::tie(gemmM, gemmN, gemmK) =
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
const auto grid_size = calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) *
Conv_G_ / NumGroupsToMerge;
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
grid_size);
// Ensure that k_batch_ does not exceed the maximum value
// for the GEMM pipeline.
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / KPerBlock);
k_batch_ = std::min(k_batch_, k_batch_max);
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
<< std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << k_batch_
<< std::endl;
}
}
else
{
k_batch_ = split_k;
}
const auto descs =
conv_to_gemm_transformer_v2
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -751,7 +831,6 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
const index_t k_batch_;
long_index_t c_space_size_bytes;
};

View File

@@ -19,6 +19,8 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -419,7 +421,36 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
using Block2CTileMap =
decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
struct Argument : public BaseArgument
// One-shot occupancy query for kernel_batched_gemm_xdlops_bwd_weight; the
// result (max active workgroups per CU) drives split-K autodeduction.
struct ActiveWorkgroupsPerCU
{
ActiveWorkgroupsPerCU()
{
// No dynamically allocated LDS is requested at launch time.
constexpr int dynamic_smem_size = 0;
int max_occupancy = 0;
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&max_occupancy,
kernel_batched_gemm_xdlops_bwd_weight<
GridwiseGemm,
ADataType,
BDataType,
CDataType,
OutElementwiseOperation,
InElementwiseOperation,
WeiElementwiseOperation,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<DeviceOp::Block2CTileMap>,
ComputePtrOffsetOfStridedBatch<>,
false>, // Both true/false give the same occupancy.
BlockSize,
dynamic_smem_size));
// Clamp so split-K deduction never divides by a non-positive occupancy.
max_occupancy_ = std::max(1, max_occupancy);
}
// Max active workgroups per compute unit (>= 1).
int max_occupancy_;
};
struct Argument : public BaseArgument, public ArgumentSplitK
{
Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid,
@@ -463,9 +494,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
output_spatial_lengths_{},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads},
k_batch_{split_k}
input_right_pads_{input_right_pads}
{
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
c_space_size_bytes =
ck::accumulate_n<long_index_t>(
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
@@ -491,6 +523,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
std::array<index_t, NDimSpatial + 3> e_g_k_c_xs_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
e_g_k_c_xs_strides);
if(split_k < 0)
{
ck::index_t gemmM, gemmN;
std::tie(gemmM, gemmN, std::ignore) =
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
const auto grid_size =
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
grid_size);
}
else
{
k_batch_ = split_k;
}
const auto descs =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -656,7 +705,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
const index_t k_batch_;
long_index_t c_space_size_bytes;
};

View File

@@ -20,6 +20,8 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
@@ -381,7 +383,53 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}, 1, 1));
struct Argument : public BaseArgument
// One-shot occupancy query for the CShuffleV3 bwd-weight GEMM kernel; the
// result (max active workgroups per CU) drives split-K autodeduction.
struct ActiveWorkgroupsPerCU
{
ActiveWorkgroupsPerCU()
{
// No dynamically allocated LDS is requested at launch time.
constexpr int dynamic_smem_size = 0;
// Kernel template argument; presumably mirrors the value used by the
// launch path -- NOTE(review): confirm against the Invoker.
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
int max_occupancy = 0;
// Pipeline v4 uses the dedicated 2-LDS kernel variant; every other
// pipeline version uses the regular kernel.
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&max_occupancy,
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>,
BlockSize,
dynamic_smem_size));
}
else
{
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&max_occupancy,
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>,
BlockSize,
dynamic_smem_size));
}
// Clamp so split-K deduction never divides by a non-positive occupancy.
max_occupancy_ = std::max(1, max_occupancy);
}
// Max active workgroups per compute unit (>= 1).
int max_occupancy_;
};
struct Argument : public BaseArgument, public ArgumentSplitK
{
Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid,
@@ -424,9 +472,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
output_spatial_lengths_{},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads},
k_batch_{split_k}
input_right_pads_{input_right_pads}
{
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
c_space_size_bytes =
ck::accumulate_n<long_index_t>(
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
@@ -443,6 +492,35 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
if(split_k < 0)
{
ck::index_t gemmM, gemmN, gemmK;
std::tie(gemmM, gemmN, gemmK) =
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
const auto grid_size =
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
grid_size);
// Ensure that k_batch_ does not exceed the maximum value
// for the GEMM pipeline.
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / K0PerBlock);
k_batch_ = std::min(k_batch_, k_batch_max);
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
<< std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << k_batch_
<< std::endl;
}
}
else
{
k_batch_ = split_k;
}
const auto descs =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -513,7 +591,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
const index_t k_batch_;
long_index_t c_space_size_bytes;
};

View File

@@ -0,0 +1,17 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
// Mixin carrying the split-K factor. Device-op Argument types inherit from it
// so generic code (e.g. the profiler) can read the deduced value through a
// dynamic_cast without knowing the concrete Argument type.
struct ArgumentSplitK
{
// Number of K-batches the GEMM K dimension is split into (1 = no split-K).
index_t k_batch_{1};
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,93 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <numeric>
#include <hip/hip_runtime.h>
#include "ck/utility/env.hpp"
#include "ck/utility/number.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/ck.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Caches properties of the active HIP device, queried once at construction.
struct DeviceProperties
{
    DeviceProperties()
    {
        hipDeviceProp_t dev_prop;
        hipDevice_t dev;
        hip_check_error(hipGetDevice(&dev));
        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
        num_cu_ = dev_prop.multiProcessorCount;
    } // NOTE: removed stray ';' that previously followed this body.
    // Number of compute units (multiprocessors) on the active device.
    int num_cu_;
};
// Picks the split-K factor that best fills the GPU: when the unsplit M x N
// output grid alone cannot occupy every CU at the kernel's max occupancy,
// split the K dimension by the floored ratio of device capacity to grid size.
//
// @param max_occupancy max active workgroups per CU for the GEMM kernel
// @param grid_size     workgroup count of the unsplit (k_batch = 1) grid
// @return split-K factor, always >= 1
inline ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index_t grid_size)
{
    static DeviceProperties device_properties;
    const int max_capacity = max_occupancy * device_properties.num_cu_;
    ck::index_t k_batch = 1;
    // Integer division equals floor for non-negative operands, so the previous
    // std::floor round-trip through double is unnecessary; also guard against
    // a degenerate empty grid to avoid division by zero.
    if(grid_size > 0)
    {
        const auto optimal_split = static_cast<ck::index_t>(max_capacity / grid_size);
        if(optimal_split > 1)
        {
            k_batch = optimal_split;
        }
    }
    if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
    {
        std::cout << "[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel: "
                  << max_occupancy << std::endl;
        std::cout << "[SPLIT-K AUTODEDUCE] Output grid size: " << grid_size << std::endl;
        std::cout << "[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << std::endl;
    }
    return k_batch;
}
template <ck::index_t NDimSpatial>
inline auto
get_bwd_weight_gemm_sizes(const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths)
{
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
// The input array has elements in the order: G, N, K, Do, Ho, Wo
// GemmK = N * Do * Ho * Wo for the BWD weight pass.
constexpr index_t spatial_offset = 3;
const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
end(a_g_n_k_wos_lengths),
index_t{1},
std::multiplies<>{});
const auto gemmK = a_g_n_k_wos_lengths[I1] * DoHoWo;
// The GEMM M dimension is the number of output channels.
const auto gemmM = e_g_k_c_xs_lengths[I1];
// The output array has elements in the order: G, K, C, X, Y, Z
// GemmN = C * X * Y * Z for the BWD weight pass.
const index_t XYZ = std::accumulate(begin(e_g_k_c_xs_lengths) + spatial_offset,
end(e_g_k_c_xs_lengths),
index_t{1},
std::multiplies<>{});
const auto gemmN = e_g_k_c_xs_lengths[I2] * XYZ;
return std::make_tuple(gemmM, gemmN, gemmK);
}
// Number of M x N output tiles for one GEMM (no split-K, single group):
// partial tiles at the right/bottom edges still need a full workgroup.
template <ck::index_t MPerBlock, ck::index_t NPerBlock>
inline ck::index_t calculate_mn_grid_size(ck::index_t gemmM, ck::index_t gemmN)
{
    // Ceiling division spelled out explicitly.
    const ck::index_t m_tiles = (gemmM + MPerBlock - 1) / MPerBlock;
    const ck::index_t n_tiles = (gemmN + NPerBlock - 1) / NPerBlock;
    return m_tiles * n_tiles;
}
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -148,7 +148,7 @@
# <dilations>, (ie Dy, Dx for 2D)
# <left padding>, (ie LeftPy, LeftPx for 2D)
# <right padding>, (ie RightPy, RightPx for 2D)
# SplitK
# SplitK (-1 for internally computed split-K value, positive value to set k batches explicitly, or 'all' to test all internal split-K values)
################ op datatype layout verify init log time Ndims G N K C Y X Hi Wi Sy Sx Dy Dx LeftPy LeftPx RightPy RightPx SplitK
./bin/ckProfiler grouped_conv_bwd_weight 1 1 0 1 0 1 2 32 256 256 512 3 3 28 28 1 1 1 1 1 0 0 0 1

View File

@@ -11,6 +11,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp"
@@ -40,7 +41,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
bool do_log,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
ck::index_t split_k)
const std::string& split_k)
{
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
@@ -138,10 +139,10 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
float best_avg_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
ck::index_t best_split_k = 1;
float best_avg_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
std::string best_split_k("1");
// profile device Conv instances
bool all_pass = true;
@@ -170,11 +171,20 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
range_copy(conv_param.input_left_pads_, begin(input_left_pads));
range_copy(conv_param.input_right_pads_, begin(input_right_pads));
std::vector<ck::index_t> split_k_list = {1, 2, 4, 8, 16, 32, 64, 128};
std::vector<ck::index_t> split_k_list = {/*auto deduce value*/ -1, 1, 2, 4, 8, 16, 32, 64, 128};
if(split_k > 0)
if(split_k != "all")
{
split_k_list = {split_k};
try
{
ck::index_t split_k_value = std::stoi(split_k);
split_k_list = {split_k_value};
}
catch(const std::exception& e)
{
std::cerr << e.what() << '\n';
exit(EXIT_FAILURE);
}
}
for(auto& op_ptr : op_ptrs)
@@ -200,6 +210,16 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
out_element_op,
split_k_list[split_k_id]);
auto split_k_value = split_k_list[split_k_id];
auto split_k_param_str = std::to_string(split_k_value);
auto* split_k_arg =
dynamic_cast<ck::tensor_operation::device::ArgumentSplitK*>(argument_ptr.get());
if(split_k_arg && split_k_value < 0)
{
split_k_value = split_k_arg->k_batch_;
split_k_param_str = std::to_string(split_k_value) + " (best occupancy)";
}
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
@@ -222,7 +242,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", SplitK "
<< split_k_list[split_k_id] << std::endl;
<< split_k_param_str << std::endl;
if(tflops > best_tflops)
{
@@ -230,7 +250,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
best_tflops = tflops;
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
best_split_k = split_k_list[split_k_id];
best_split_k = split_k_param_str;
}
if(do_verification)
@@ -244,7 +264,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
using AccDataType =
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
const index_t num_accums = output.GetElementSize() / conv_param.K_;
const index_t num_accums_split_k = split_k_list[split_k_id];
const index_t num_accums_split_k = split_k_value;
// Calculate thresholds
auto rtol =
ck::utils::get_relative_threshold<ComputeType, WeiDataType, AccDataType>(

View File

@@ -56,7 +56,9 @@ static void print_helper_msg()
<< "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0: no, 1: yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << " SplitK\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg()
<< " SplitK (-1 for internally computed split-K value, positive value to set k "
"batches explicitly, or 'all' to test all internal split-K values)\n"
<< std::endl;
}
@@ -88,7 +90,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 9, argv);
ck::index_t split_k = std::stoi(argv[8 + 1 + 4 + 6 * num_dim_spatial]);
const auto& split_k = std::string(argv[8 + 1 + 4 + 6 * num_dim_spatial]);
using F32 = float;
using F16 = ck::half_t;

View File

@@ -30,7 +30,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
using NDimSpatial = std::tuple_element_t<6, Tuple>;
std::vector<ck::utils::conv::ConvParam> conv_params;
std::vector<ck::index_t> split_ks{1, 2};
std::vector<ck::index_t> split_ks{-1, 1, 2};
bool skip_case(const ck::index_t split_k)
{
@@ -108,7 +108,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
false, // do_log
false, // time_kernel
param,
split_k);
std::to_string(split_k));
}
}
}

View File

@@ -52,7 +52,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
// clang-format on
ck::utils::conv::ConvParam conv_param;
ck::index_t split_k{2};
std::vector<ck::index_t> split_ks{-1, 2};
template <ck::index_t NDimSpatial>
bool Run()
@@ -96,24 +96,30 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
auto conv = GroupedConvBwdWeightDeviceInstance{};
auto argument = conv.MakeArgument(nullptr,
nullptr,
nullptr,
input_lengths,
input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
PassThrough{},
split_k);
return conv.IsSupportedArgument(argument);
bool is_supported = true;
for(const auto split_k : split_ks)
{
auto argument = conv.MakeArgument(nullptr,
nullptr,
nullptr,
input_lengths,
input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
PassThrough{},
split_k);
is_supported &= conv.IsSupportedArgument(argument);
}
return is_supported;
}
};

View File

@@ -52,7 +52,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
// clang-format on
ck::utils::conv::ConvParam conv_param;
ck::index_t split_k{2};
std::vector<ck::index_t> split_ks{-1, 2};
template <ck::index_t NDimSpatial>
bool Run()
@@ -96,24 +96,30 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
auto conv = GroupedConvBwdWeightDeviceInstance{};
auto argument = conv.MakeArgument(nullptr,
nullptr,
nullptr,
input_lengths,
input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
PassThrough{},
split_k);
return conv.IsSupportedArgument(argument);
bool is_supported = true;
for(const auto split_k : split_ks)
{
auto argument = conv.MakeArgument(nullptr,
nullptr,
nullptr,
input_lengths,
input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
PassThrough{},
split_k);
is_supported &= conv.IsSupportedArgument(argument);
}
return is_supported;
}
};