Automatic deduction of split-K value for grouped convolution (#2491)

* Split-K autodeduction for DeviceGroupedConvBwdWeight_Xdl_CShuffle and DeviceGroupedConvBwdWeight_Xdl_CShuffleV3.

* Split-K autodeduction for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle.

* Use simple best occupancy model to calculate the split-K.

* Handle split-K autodeduction in explicit gemm conv.

* Add unit tests for split-K autodeduction.

* Remove oversubscription.

* Small fixes.

* Added split-K autodeduction for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle.

* Run clang formatting.

* Fix error handling in the conv profiler.

* Add missing documentation for the autodeduced split-K values.

* Add split-K autodeduction to DeviceGroupedConvBwdWeight_Explicit_Xdl solver.

* Fix clang formatting and split-K profiler documentation.

* Rename max_occupancy value variable.

* Calculate grid size for split-K autodeduction directly from input array shapes and template params.

---------

Co-authored-by: Ville Pietilä <>
This commit is contained in:
Ville Pietilä
2025-07-31 13:08:45 +03:00
committed by GitHub
parent 7b074249f4
commit e962a41638
14 changed files with 544 additions and 72 deletions

View File

@@ -337,6 +337,60 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
}
};
struct ActiveWorkgroupsPerCU
{
ActiveWorkgroupsPerCU()
{
constexpr int dynamic_smem_size = 0;
int max_occupancy = 0;
constexpr index_t minimum_occupancy = []() {
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
{
return 2;
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
}
else
{
return 1;
}
}();
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&max_occupancy,
kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>,
BlockSize,
dynamic_smem_size));
}
else
{
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&max_occupancy,
kernel_batched_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
Argument,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>,
BlockSize,
dynamic_smem_size));
}
max_occupancy_ = std::max(1, max_occupancy);
}
int max_occupancy_;
};
// Invoker
struct Invoker : public BaseInvoker
{
@@ -1044,6 +1098,12 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
return str.str();
}
// Returns the cached maximum number of resident workgroups per CU for this
// device op's GEMM kernel. The HIP occupancy query runs exactly once
// (thread-safe function-local static) and is reused afterwards.
static ck::index_t GetMaxOccupancy()
{
    static const ActiveWorkgroupsPerCU occupancy_query{};
    return occupancy_query.max_occupancy_;
}
};
} // namespace device

View File

@@ -13,6 +13,8 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
namespace ck {
namespace tensor_operation {
@@ -142,6 +144,20 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
end(e_g_k_c_xs_lengths),
begin(filter_spatial_lengths_));
if(split_k < 0)
{
const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) =
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
const index_t grid_size = gdx * gdy * gdz;
split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
}
else
{
split_k_ = split_k;
}
if constexpr(IsTwoStageNeeded)
{
const index_t merged_filter_dims = std::accumulate(begin(e_g_k_c_xs_lengths),
@@ -176,7 +192,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
out_element_op,
in_element_op,
wei_element_op,
split_k};
split_k_};
}
else
{
@@ -199,7 +215,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
out_element_op,
in_element_op,
wei_element_op,
split_k};
split_k_};
}
}
@@ -236,6 +252,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
bool is_filter_data_packed;
CElementwiseGridDesc elementwise_desc_;
Block2TileMapElementwise elementwise_block_2_ctile_map_;
ck::index_t split_k_;
};
// Invoker

View File

@@ -19,6 +19,8 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
@@ -542,7 +544,36 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
using Block2CTileMap =
decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
struct Argument : public BaseArgument
// Computes, once at construction, the maximum number of workgroups of the
// backward-weight GEMM kernel that can be simultaneously resident on one CU.
// The value feeds the split-K auto-deduction in the Argument constructor.
struct ActiveWorkgroupsPerCU
{
    ActiveWorkgroupsPerCU()
    {
        // No dynamically-sized LDS is requested at launch for this kernel.
        constexpr int dynamic_smem_size = 0;
        int max_occupancy = 0;
        // Ask the HIP runtime for the per-CU occupancy of the exact kernel
        // instantiation this device op launches.
        hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
            &max_occupancy,
            kernel_batched_gemm_xdlops_bwd_weight<
                GridwiseGemm,
                ADataType,
                BDataType,
                AccDataType,
                OutElementwiseOperation,
                InElementwiseOperation,
                element_wise::PassThrough,
                remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
                remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
                remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
                remove_reference_t<DeviceOp::Block2CTileMap>,
                ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
                true>,
            BlockSize,
            dynamic_smem_size));
        // Clamp to at least 1 in case the runtime reports zero occupancy.
        max_occupancy_ = std::max(1, max_occupancy);
    }
    // Maximum active workgroups per CU, always >= 1.
    int max_occupancy_;
};
struct Argument : public BaseArgument, public ArgumentSplitK
{
Argument(
const InDataType* p_in_grid,
@@ -591,9 +622,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
output_spatial_lengths_{},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads},
k_batch_{split_k}
input_right_pads_{input_right_pads}
{
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
c_space_size_bytes =
ck::accumulate_n<long_index_t>(
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
@@ -610,6 +642,22 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
if(split_k < 0)
{
ck::index_t gemmM, gemmN;
std::tie(gemmM, gemmN, std::ignore) =
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
const auto grid_size =
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
grid_size);
}
else
{
k_batch_ = split_k;
}
const auto descs =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -712,7 +760,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
const index_t k_batch_;
long_index_t c_space_size_bytes;
};

View File

@@ -22,6 +22,8 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
@@ -504,7 +506,55 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}, 1, 1));
struct Argument : public BaseArgument
// Computes, once at construction, the maximum number of workgroups of the
// two-stage backward-weight GEMM kernel that can be simultaneously resident
// on one CU. The value feeds the split-K auto-deduction in Argument.
struct ActiveWorkgroupsPerCU
{
    ActiveWorkgroupsPerCU()
    {
        // No dynamically-sized LDS is requested at launch for this kernel.
        constexpr int dynamic_smem_size = 0;
        // Pipeline-scheduler-dependent minimum-occupancy hint baked into the
        // kernel instantiation; must match the value used at launch time.
        constexpr index_t minimum_occupancy =
            BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
        int max_occupancy = 0;
        if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
        {
            // The v4 pipeline launches the 2-LDS kernel variant.
            hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
                &max_occupancy,
                kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds<
                    GridwiseGemm,
                    remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
                    remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
                    remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
                    ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
                    NumGroupsToMerge,
                    true,
                    InMemoryDataOperationEnum::AtomicAdd,
                    minimum_occupancy>,
                BlockSize,
                dynamic_smem_size));
        }
        else
        {
            hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
                &max_occupancy,
                kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
                    GridwiseGemm,
                    remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
                    remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
                    remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
                    ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
                    NumGroupsToMerge,
                    true,
                    InMemoryDataOperationEnum::AtomicAdd,
                    minimum_occupancy>,
                BlockSize,
                dynamic_smem_size));
        }
        // Clamp to at least 1 in case the runtime reports zero occupancy.
        max_occupancy_ = std::max(1, max_occupancy);
    }
    // Maximum active workgroups per CU, always >= 1.
    int max_occupancy_;
};
struct Argument : public BaseArgument, public ArgumentSplitK
{
Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid,
@@ -547,9 +597,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
output_spatial_lengths_{},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads},
k_batch_{split_k}
input_right_pads_{input_right_pads}
{
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
c_space_size_bytes =
ck::accumulate_n<long_index_t>(
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
@@ -576,6 +627,35 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
e_g_k_c_xs_strides);
if(split_k < 0)
{
ck::index_t gemmM, gemmN, gemmK;
std::tie(gemmM, gemmN, gemmK) =
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
const auto grid_size = calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) *
Conv_G_ / NumGroupsToMerge;
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
grid_size);
// Ensure that k_batch_ does not exceed the maximum value
// for the GEMM pipeline.
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / KPerBlock);
k_batch_ = std::min(k_batch_, k_batch_max);
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
<< std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << k_batch_
<< std::endl;
}
}
else
{
k_batch_ = split_k;
}
const auto descs =
conv_to_gemm_transformer_v2
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -751,7 +831,6 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
const index_t k_batch_;
long_index_t c_space_size_bytes;
};

View File

@@ -19,6 +19,8 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -419,7 +421,36 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
using Block2CTileMap =
decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
struct Argument : public BaseArgument
// Computes, once at construction, the maximum number of workgroups of the
// backward-weight GEMM kernel that can be simultaneously resident on one CU.
// The value feeds the split-K auto-deduction in the Argument constructor.
struct ActiveWorkgroupsPerCU
{
    ActiveWorkgroupsPerCU()
    {
        // No dynamically-sized LDS is requested at launch for this kernel.
        constexpr int dynamic_smem_size = 0;
        int max_occupancy = 0;
        // Ask the HIP runtime for the per-CU occupancy of the exact kernel
        // instantiation this device op launches.
        hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
            &max_occupancy,
            kernel_batched_gemm_xdlops_bwd_weight<
                GridwiseGemm,
                ADataType,
                BDataType,
                CDataType,
                OutElementwiseOperation,
                InElementwiseOperation,
                WeiElementwiseOperation,
                remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
                remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
                remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
                remove_reference_t<DeviceOp::Block2CTileMap>,
                ComputePtrOffsetOfStridedBatch<>,
                false>, // Both true/false give the same occupancy.
            BlockSize,
            dynamic_smem_size));
        // Clamp to at least 1 in case the runtime reports zero occupancy.
        max_occupancy_ = std::max(1, max_occupancy);
    }
    // Maximum active workgroups per CU, always >= 1.
    int max_occupancy_;
};
struct Argument : public BaseArgument, public ArgumentSplitK
{
Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid,
@@ -463,9 +494,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
output_spatial_lengths_{},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads},
k_batch_{split_k}
input_right_pads_{input_right_pads}
{
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
c_space_size_bytes =
ck::accumulate_n<long_index_t>(
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
@@ -491,6 +523,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
std::array<index_t, NDimSpatial + 3> e_g_k_c_xs_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
e_g_k_c_xs_strides);
if(split_k < 0)
{
ck::index_t gemmM, gemmN;
std::tie(gemmM, gemmN, std::ignore) =
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
const auto grid_size =
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
grid_size);
}
else
{
k_batch_ = split_k;
}
const auto descs =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -656,7 +705,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
const index_t k_batch_;
long_index_t c_space_size_bytes;
};

View File

@@ -20,6 +20,8 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
@@ -381,7 +383,53 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}, 1, 1));
struct Argument : public BaseArgument
// Computes, once at construction, the maximum number of workgroups of the
// backward-weight GEMM v3 kernel that can be simultaneously resident on one
// CU. The value feeds the split-K auto-deduction in the Argument constructor.
struct ActiveWorkgroupsPerCU
{
    ActiveWorkgroupsPerCU()
    {
        // No dynamically-sized LDS is requested at launch for this kernel.
        constexpr int dynamic_smem_size = 0;
        // Pipeline-scheduler-dependent minimum-occupancy hint baked into the
        // kernel instantiation; must match the value used at launch time.
        constexpr index_t minimum_occupancy =
            BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
        int max_occupancy = 0;
        if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
        {
            // The v4 pipeline launches the 2-LDS kernel variant.
            hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
                &max_occupancy,
                kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds<
                    GridwiseGemm,
                    remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
                    remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
                    remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
                    ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
                    true,
                    InMemoryDataOperationEnum::AtomicAdd,
                    minimum_occupancy>,
                BlockSize,
                dynamic_smem_size));
        }
        else
        {
            hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
                &max_occupancy,
                kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
                    GridwiseGemm,
                    remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
                    remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
                    remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
                    ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
                    true,
                    InMemoryDataOperationEnum::AtomicAdd,
                    minimum_occupancy>,
                BlockSize,
                dynamic_smem_size));
        }
        // Clamp to at least 1 in case the runtime reports zero occupancy.
        max_occupancy_ = std::max(1, max_occupancy);
    }
    // Maximum active workgroups per CU, always >= 1.
    int max_occupancy_;
};
struct Argument : public BaseArgument, public ArgumentSplitK
{
Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid,
@@ -424,9 +472,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
output_spatial_lengths_{},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads},
k_batch_{split_k}
input_right_pads_{input_right_pads}
{
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
c_space_size_bytes =
ck::accumulate_n<long_index_t>(
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
@@ -443,6 +492,35 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
if(split_k < 0)
{
ck::index_t gemmM, gemmN, gemmK;
std::tie(gemmM, gemmN, gemmK) =
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
const auto grid_size =
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
grid_size);
// Ensure that k_batch_ does not exceed the maximum value
// for the GEMM pipeline.
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / K0PerBlock);
k_batch_ = std::min(k_batch_, k_batch_max);
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
<< std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << k_batch_
<< std::endl;
}
}
else
{
k_batch_ = split_k;
}
const auto descs =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -513,7 +591,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
const index_t k_batch_;
long_index_t c_space_size_bytes;
};

View File

@@ -0,0 +1,17 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
// Mix-in base carrying the split-K factor shared by the conv bwd-weight
// device-op Argument structs. Defaults to 1 (no K-dimension split).
struct ArgumentSplitK
{
    index_t k_batch_ = 1;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,93 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <numeric>

#include <hip/hip_runtime.h>

#include "ck/ck.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/utility/env.hpp"
#include "ck/utility/number.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Caches immutable properties of the currently active HIP device, queried
// once from the runtime.
struct DeviceProperties
{
    DeviceProperties()
    {
        hipDeviceProp_t dev_prop;
        hipDevice_t dev;
        hip_check_error(hipGetDevice(&dev));
        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
        num_cu_ = dev_prop.multiProcessorCount;
    } // NOTE: stray ';' after the ctor body removed.

    // Number of compute units (multiprocessors) on the active device.
    int num_cu_;
};
/// Deduces a split-K factor that fills the GPU: when the M*N output grid
/// needs fewer workgroups than the device can keep resident at once
/// (per-CU occupancy times CU count), the spare capacity is spent on
/// splitting the K dimension.
///
/// @param max_occupancy maximum active workgroups per CU for the kernel.
/// @param grid_size     number of workgroups in the unsplit (M*N) grid.
/// @return split-K factor, always >= 1 (1 means no split).
inline ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index_t grid_size)
{
    // Device properties are queried once and cached for the process lifetime.
    static DeviceProperties device_properties;
    const int max_capacity = max_occupancy * device_properties.num_cu_;

    ck::index_t k_batch = 1;
    // Guard against a degenerate grid; integer division of the positive
    // operands is exact, unlike the floating-point floor it replaces.
    if(grid_size > 0)
    {
        const auto optimal_split = static_cast<ck::index_t>(max_capacity / grid_size);
        if(optimal_split > 1)
        {
            k_batch = optimal_split;
        }
    }

    if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
    {
        std::cout << "[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel: "
                  << max_occupancy << std::endl;
        std::cout << "[SPLIT-K AUTODEDUCE] Output grid size: " << grid_size << std::endl;
        std::cout << "[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << std::endl;
    }
    return k_batch;
}
template <ck::index_t NDimSpatial>
inline auto
get_bwd_weight_gemm_sizes(const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths)
{
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
// The input array has elements in the order: G, N, K, Do, Ho, Wo
// GemmK = N * Do * Ho * Wo for the BWD weight pass.
constexpr index_t spatial_offset = 3;
const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
end(a_g_n_k_wos_lengths),
index_t{1},
std::multiplies<>{});
const auto gemmK = a_g_n_k_wos_lengths[I1] * DoHoWo;
// The GEMM M dimension is the number of output channels.
const auto gemmM = e_g_k_c_xs_lengths[I1];
// The output array has elements in the order: G, K, C, X, Y, Z
// GemmN = C * X * Y * Z for the BWD weight pass.
const index_t XYZ = std::accumulate(begin(e_g_k_c_xs_lengths) + spatial_offset,
end(e_g_k_c_xs_lengths),
index_t{1},
std::multiplies<>{});
const auto gemmN = e_g_k_c_xs_lengths[I2] * XYZ;
return std::make_tuple(gemmM, gemmN, gemmK);
}
/// Number of workgroups required to tile a gemmM x gemmN output with
/// MPerBlock x NPerBlock tiles; partial edge tiles are rounded up.
template <ck::index_t MPerBlock, ck::index_t NPerBlock>
inline ck::index_t calculate_mn_grid_size(ck::index_t gemmM, ck::index_t gemmN)
{
    return math::integer_divide_ceil(gemmM, MPerBlock) *
           math::integer_divide_ceil(gemmN, NPerBlock);
}
} // namespace device
} // namespace tensor_operation
} // namespace ck