mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] Split-K autodeduction (#3351)
* First version of split-K autodeduction.
* Fix circular dependency and kernel construction.
* Fix tolerance calculation for bwd weight example.
* Simplify kernel construction.
* Fix kernel launching bug for split-K autodeduce.
* Add split-K autodeduction support for the two stage example.
* Fix a corner case.
* Fix clang-format.
* Fix clang-format for inc files.
* Add missing header.
* Prevent too large split-K values.
* Fix formatting.
* Add unit tests for IsSupportedArgument in grouped bwd conv.
* clang-format.
* Fix merge conflicts.
* Address feedback from code review.
* clang-format
* Fix new tests after merge.
---------
Co-authored-by: Ville Pietilä <>
[ROCm/composable_kernel commit: fc22320d78]
This commit is contained in:
@@ -17,8 +17,8 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
typename DsDataType = ck_tile::tuple<>,
|
||||
typename DsLayout = ck_tile::tuple<>,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
static InvokerResult grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
// Implicit GEMM Traits
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
@@ -105,9 +105,9 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
@@ -130,7 +130,7 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
}
|
||||
|
||||
auto preprocess = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
if(kargs.k_batch > 1)
|
||||
{
|
||||
ck_tile::hip_check_error(
|
||||
hipMemsetAsync(kargs.wei_ptr,
|
||||
@@ -140,10 +140,14 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
}
|
||||
};
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
const auto ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
const auto split_k = kargs.k_batch;
|
||||
|
||||
return InvokerResult{ave_time, split_k};
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
|
||||
@@ -17,8 +17,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
typename DsDataType = ck_tile::tuple<>,
|
||||
typename DsLayout = ck_tile::tuple<>,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
static InvokerResult grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using WorkspaceDataType = float;
|
||||
|
||||
@@ -118,9 +118,9 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
sizeof(WorkspaceDataType));
|
||||
ck_tile::GroupedConvBwdWeightHostArgs ws_args =
|
||||
ck_tile::GroupedConvBwdWeightHostArgs(args);
|
||||
auto c_ptr = ws_args.wei_ptr;
|
||||
ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
|
||||
auto kargs = Kernel::MakeKernelArgs(ws_args);
|
||||
auto c_ptr = ws_args.wei_ptr;
|
||||
ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
|
||||
const auto kargs = Kernel::MakeKernelArgs(ws_args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
@@ -184,7 +184,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
}
|
||||
|
||||
auto preprocess = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
if(kargs.k_batch > 1)
|
||||
ck_tile::hip_check_error(
|
||||
hipMemsetAsync(ws_args.wei_ptr,
|
||||
0,
|
||||
@@ -192,7 +192,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
s.stream_id_));
|
||||
};
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
const auto ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
|
||||
@@ -206,6 +206,10 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
ck_tile::make_tuple(shape[1], 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<WeiDataType*>(c_ptr)));
|
||||
|
||||
const auto split_k = kargs.k_batch;
|
||||
|
||||
return InvokerResult{ave_time, split_k};
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
|
||||
@@ -132,3 +132,9 @@ auto create_args(int argc, char* argv[])
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
struct InvokerResult
|
||||
{
|
||||
float ave_time;
|
||||
ck_tile::index_t split_k;
|
||||
};
|
||||
|
||||
@@ -14,22 +14,22 @@ template <ck_tile::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
InvokerResult invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
float ave_time = Invoker::template grouped_conv_bwd_weight<NDimSpatial,
|
||||
ConvConfig,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(
|
||||
auto res = Invoker::template grouped_conv_bwd_weight<NDimSpatial,
|
||||
ConvConfig,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
return ave_time;
|
||||
return res;
|
||||
}
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
@@ -132,16 +132,17 @@ int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_par
|
||||
std::cout << "weight: " << weight.mDesc << std::endl;
|
||||
std::cout << "output: " << output.mDesc << std::endl;
|
||||
|
||||
float ave_time = invoke_grouped_conv_bwd_weight<NDimSpatial,
|
||||
ConvConfig,
|
||||
Invoker,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(args, n_warmup, n_repeat);
|
||||
auto res = invoke_grouped_conv_bwd_weight<NDimSpatial,
|
||||
ConvConfig,
|
||||
Invoker,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(args, n_warmup, n_repeat);
|
||||
const float ave_time = res.ave_time;
|
||||
|
||||
weight_dev_buf.FromDevice(weight.data());
|
||||
|
||||
@@ -172,9 +173,11 @@ int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_par
|
||||
const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(weight_host_ref.mData.begin(), weight_host_ref.mData.end());
|
||||
|
||||
const ck_tile::index_t split_k = res.split_k;
|
||||
const auto rtol_atol =
|
||||
calculate_rtol_atol<InDataType, WeiDataType, AccDataType, OutDataType>(
|
||||
GemmK, kbatch, max_accumulated_value);
|
||||
GemmK, split_k, max_accumulated_value);
|
||||
pass = ck_tile::check_err(weight,
|
||||
weight_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
|
||||
@@ -70,6 +70,24 @@ inline bool is_load_tr_supported()
|
||||
// Check if load transpose is supported.
|
||||
return get_device_name() == "gfx950";
|
||||
}
|
||||
|
||||
inline size_t get_num_cus()
|
||||
{
|
||||
hipDeviceProp_t props{};
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
return static_cast<size_t>(props.multiProcessorCount);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#endif
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp"
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
|
||||
|
||||
#include "ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp"
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp"
|
||||
#endif
|
||||
@@ -62,8 +64,6 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
@@ -104,11 +104,14 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
GemmK = a_grid_desc_k_m.get_length(number<0>{});
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -147,8 +150,6 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
|
||||
static_cast<index_t>(args.input_right_pads_[1])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
@@ -189,11 +190,14 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
GemmK = a_grid_desc_k_m.get_length(number<0>{});
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -239,8 +243,6 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
static_cast<index_t>(args.input_right_pads_[1]),
|
||||
static_cast<index_t>(args.input_right_pads_[2])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
@@ -281,11 +283,14 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
GemmK = a_grid_desc_k_m.get_length(number<0>{});
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -398,7 +403,6 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
using GroupedConvBwdWeightKernelArgsSpecialized =
|
||||
GroupedConvBwdWeightKernelArgs<GroupedConvTraitsType_>;
|
||||
|
||||
// TODO: Enable this
|
||||
static constexpr bool IsSplitKSupported = true;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
@@ -476,7 +480,24 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
std::cout << "NPerBlock: " << number<TilePartitioner::NPerBlock>{} << std::endl;
|
||||
std::cout << "KPerBlock: " << number<TilePartitioner::KPerBlock>{} << std::endl;
|
||||
}
|
||||
return GroupedConvBwdWeightKernelArgsSpecialized(hostArgs);
|
||||
|
||||
auto kernel_args = GroupedConvBwdWeightKernelArgsSpecialized(hostArgs);
|
||||
|
||||
using KernelImpl = GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType_,
|
||||
TilePartitioner_,
|
||||
GemmPipeline_,
|
||||
EpiloguePipeline_>;
|
||||
|
||||
// Negative k_batch value: split-K autodeduction.
|
||||
if(kernel_args.k_batch < 0)
|
||||
{
|
||||
const auto optimal_split_k =
|
||||
calculate_optimal_k_batch<GemmPipeline_::BlockSize, KernelImpl, TilePartitioner_>(
|
||||
kernel_args);
|
||||
kernel_args.k_batch = optimal_split_k;
|
||||
}
|
||||
|
||||
return kernel_args;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
@@ -514,15 +535,54 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
CK_TILE_HOST static bool
|
||||
IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
|
||||
{
|
||||
if(kargs.k_batch < 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"k_batch must be at least one. Ensure argument is created via MakeKernelArgs.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(EpiloguePipeline_::MemoryOperation == memory_operation_enum::atomic_add)
|
||||
{
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Atomic add epilogue only supports k_batch > 1.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!std::is_same_v<typename EpiloguePipeline::ODataType, float> &&
|
||||
!std::is_same_v<typename EpiloguePipeline::ODataType, double>)
|
||||
{
|
||||
// The epilogue performs atomic add related to split-K using the ODataType.
|
||||
// If the type is less accurate than float, large split-K values may lead to
|
||||
// accuracy issues. Hence, we limit the maximum split-K value to 128 in such cases.
|
||||
if(kargs.k_batch > 128)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"For epilogue output data type that is not float/double, we must have "
|
||||
"k_batch <= 128.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<WeiDataType, fp16_t, bf16_t>::value) ||
|
||||
!IsSplitKSupported)
|
||||
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
|
||||
CK_TILE_ERROR("Conditions not met for K_batch > 1!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include <numeric>
|
||||
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t BlockSize, typename KernelArgs, typename KernelImpl>
|
||||
CK_TILE_HOST index_t get_max_occupancy_for_kernel()
|
||||
{
|
||||
constexpr int dynamic_smem_size = 0;
|
||||
constexpr int min_blocks_per_cu = 1;
|
||||
|
||||
const auto kernel_ptr = kentry<min_blocks_per_cu, KernelImpl, KernelArgs>;
|
||||
|
||||
int max_occupancy = 0;
|
||||
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_occupancy, kernel_ptr, BlockSize, dynamic_smem_size));
|
||||
|
||||
return static_cast<index_t>(max_occupancy);
|
||||
}
|
||||
|
||||
CK_TILE_HOST index_t get_best_occupancy_k_batch_value(index_t max_occupancy, index_t grid_size)
|
||||
{
|
||||
static const index_t num_cus = get_num_cus();
|
||||
const index_t max_capacity = max_occupancy * num_cus;
|
||||
|
||||
index_t k_batch = 1;
|
||||
const auto optimal_split = static_cast<index_t>(std::floor((1.0 * max_capacity) / grid_size));
|
||||
if(optimal_split > 1)
|
||||
{
|
||||
k_batch = optimal_split;
|
||||
}
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel: "
|
||||
<< max_occupancy << std::endl;
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Output grid size: " << grid_size << std::endl;
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << std::endl;
|
||||
}
|
||||
return k_batch;
|
||||
}
|
||||
|
||||
template <index_t BlockSize, typename KernelArgs, typename KernelImpl>
|
||||
struct ActiveWorkgroupsPerCU
|
||||
{
|
||||
CK_TILE_HOST ActiveWorkgroupsPerCU()
|
||||
{
|
||||
max_occupancy_ = get_max_occupancy_for_kernel<BlockSize, KernelArgs, KernelImpl>();
|
||||
}
|
||||
index_t max_occupancy_{1};
|
||||
};
|
||||
|
||||
template <index_t BlockSize, typename KernelImpl, typename TilePartitioner, typename KernelArgs>
|
||||
CK_TILE_HOST index_t calculate_optimal_k_batch(const KernelArgs& kargs)
|
||||
{
|
||||
static ActiveWorkgroupsPerCU<BlockSize, KernelArgs, KernelImpl> active_workgroups_per_cu;
|
||||
|
||||
const auto grid_size = TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN) * kargs.GemmBatch;
|
||||
auto optimal_k_batch =
|
||||
get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, grid_size);
|
||||
|
||||
const auto max_allowed_k_batch = kargs.GemmK;
|
||||
optimal_k_batch = std::min(optimal_k_batch, max_allowed_k_batch);
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << optimal_k_batch << std::endl;
|
||||
}
|
||||
|
||||
return optimal_k_batch;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -38,3 +38,4 @@ add_subdirectory(atomic_add_op)
|
||||
add_subdirectory(fmha)
|
||||
add_subdirectory(gemm_tile_engine)
|
||||
add_subdirectory(pooling)
|
||||
add_subdirectory(grouped_conv)
|
||||
|
||||
7
test/ck_tile/grouped_conv/CMakeLists.txt
Normal file
7
test/ck_tile/grouped_conv/CMakeLists.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_gtest_executable(test_ck_tile_grouped_conv_bwd_weight test_ck_tile_grouped_conv_bwd_weight.cpp)
|
||||
endif()
|
||||
@@ -0,0 +1,249 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp"
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
struct TestConvConfig
|
||||
{
|
||||
static constexpr index_t VectorSizeA = 4;
|
||||
static constexpr index_t VectorSizeB = 8;
|
||||
static constexpr index_t VectorSizeC = 8;
|
||||
|
||||
static constexpr index_t M_Tile = 128;
|
||||
static constexpr index_t N_Tile = 128;
|
||||
static constexpr index_t K_Tile = 32;
|
||||
|
||||
static constexpr index_t M_Warp = 2;
|
||||
static constexpr index_t N_Warp = 2;
|
||||
static constexpr index_t K_Warp = 1;
|
||||
|
||||
static constexpr index_t M_Warp_Tile = 16;
|
||||
static constexpr index_t N_Warp_Tile = 16;
|
||||
static constexpr index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr GemmPipeline Pipeline = GemmPipeline::COMPUTE_V3;
|
||||
static constexpr index_t NumWaveGroups = 1;
|
||||
static constexpr index_t NumGroupsToMerge = 1;
|
||||
static constexpr auto Scheduler = GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
// Helper to build full kernel type
|
||||
template <typename PrecType,
|
||||
typename ConvConfig,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
memory_operation_enum MemOp = memory_operation_enum::set,
|
||||
index_t NDimSpatial = 2>
|
||||
struct BuildKernel
|
||||
{
|
||||
using GemmShape = TileGemmShape<
|
||||
sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
|
||||
sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
|
||||
sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>;
|
||||
|
||||
using ConvTraits = GroupedConvTraits<NDimSpatial,
|
||||
ConvolutionSpecialization::Default,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
tuple<>,
|
||||
OutLayout,
|
||||
ConvConfig::VectorSizeA,
|
||||
ConvConfig::VectorSizeB,
|
||||
ConvConfig::VectorSizeC,
|
||||
ConvConfig::NumGroupsToMerge>;
|
||||
|
||||
using TilePartitioner = GemmSpatiallyLocalTilePartitioner<GemmShape, 8, 4>;
|
||||
|
||||
using GemmUniversalTraits =
|
||||
TileGemmUniversalTraits<ConvTraits::FixedGemmParams::kPadM,
|
||||
ConvTraits::FixedGemmParams::kPadN,
|
||||
ConvTraits::FixedGemmParams::kPadK,
|
||||
ConvConfig::DoubleSmemBuffer,
|
||||
typename ConvTraits::AsLayoutBwdWeight,
|
||||
typename ConvTraits::BsLayoutBwdWeight,
|
||||
typename ConvTraits::CLayoutBwdWeight,
|
||||
ConvTraits::FixedGemmParams::TransposeC,
|
||||
ConvTraits::FixedGemmParams::UseStructuredSparsity,
|
||||
ConvTraits::FixedGemmParams::Persistent,
|
||||
ConvConfig::NumWaveGroups>;
|
||||
|
||||
using GemmPipelineProblem =
|
||||
GemmPipelineProblem<PrecType, // OutDataType (A in bwd weight)
|
||||
PrecType, // InDataType (B in bwd weight)
|
||||
float, // AccDataType
|
||||
GemmShape,
|
||||
typename ConvTraits::template GroupedConvImplicitGemmTraitsBwdWeight<
|
||||
ConvConfig::NumWaveGroups>,
|
||||
element_wise::PassThrough,
|
||||
element_wise::PassThrough,
|
||||
PrecType, // WeiDataType (C in bwd weight)
|
||||
ConvTraits::FixedGemmParams::FixedVectorSize,
|
||||
ConvTraits::VectorSizeA,
|
||||
ConvTraits::VectorSizeB>;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
UniversalGemmPipelineProblem<PrecType,
|
||||
PrecType,
|
||||
float,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
ConvConfig::Scheduler,
|
||||
element_wise::PassThrough,
|
||||
element_wise::PassThrough,
|
||||
PrecType,
|
||||
ConvTraits::FixedGemmParams::FixedVectorSize,
|
||||
ConvTraits::VectorSizeA,
|
||||
ConvTraits::VectorSizeB>;
|
||||
|
||||
using GemmPipeline = GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
|
||||
using EpilogueProblem = CShuffleEpilogueProblem<PrecType,
|
||||
PrecType,
|
||||
tuple<>,
|
||||
float,
|
||||
PrecType,
|
||||
typename ConvTraits::ImplicitGemmDsLayout,
|
||||
typename ConvTraits::FixedGemmParams::ELayout,
|
||||
element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
ConvTraits::FixedGemmParams::TransposeC,
|
||||
MemOp,
|
||||
ConvConfig::NumWaveGroups,
|
||||
ConvTraits::FixedGemmParams::FixedVectorSize,
|
||||
ConvTraits::VectorSizeC>;
|
||||
|
||||
using Epilogue = CShuffleEpilogue<EpilogueProblem>;
|
||||
|
||||
using type =
|
||||
GroupedConvolutionBackwardWeightKernel<ConvTraits, TilePartitioner, GemmPipeline, Epilogue>;
|
||||
};
|
||||
|
||||
// Helper to create 2D host args
|
||||
static GroupedConvBwdWeightHostArgs create_2d_host_args(index_t G,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t C,
|
||||
index_t Y,
|
||||
index_t X,
|
||||
index_t Hi,
|
||||
index_t Wi,
|
||||
index_t stride_y,
|
||||
index_t stride_x,
|
||||
index_t dilation_y,
|
||||
index_t dilation_x,
|
||||
index_t left_pad_y,
|
||||
index_t left_pad_x,
|
||||
index_t right_pad_y,
|
||||
index_t right_pad_x,
|
||||
index_t k_batch = 1)
|
||||
{
|
||||
auto conv_param = conv::ConvParam{2,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
{Y, X},
|
||||
{Hi, Wi},
|
||||
{stride_y, stride_x},
|
||||
{dilation_y, dilation_x},
|
||||
{left_pad_y, left_pad_x},
|
||||
{right_pad_y, right_pad_x}};
|
||||
|
||||
return GroupedConvBwdWeightHostArgs{conv_param, nullptr, nullptr, {}, nullptr, k_batch};
|
||||
}
|
||||
|
||||
static GroupedConvBwdWeightHostArgs create_2d_host_args(index_t k_batch)
|
||||
{
|
||||
return create_2d_host_args(2, 2, 8, 8, 3, 3, 7, 7, 1, 1, 1, 1, 1, 1, 1, 1, k_batch);
|
||||
}
|
||||
|
||||
class GroupedConvBwdWeightIsSupportedArgumentTest : public ::testing::Test
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, ValidKBatch)
|
||||
{
|
||||
using Kernel = typename BuildKernel<half_t,
|
||||
TestConvConfig,
|
||||
tensor_layout::convolution::NHWGC,
|
||||
tensor_layout::convolution::GKYXC,
|
||||
tensor_layout::convolution::NHWGK>::type;
|
||||
|
||||
auto host_args_kbatch_1 = create_2d_host_args(1);
|
||||
auto kargs_1 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_1);
|
||||
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_1));
|
||||
|
||||
auto host_args_kbatch_4 = create_2d_host_args(4);
|
||||
auto kargs_4 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_4);
|
||||
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_4));
|
||||
}
|
||||
|
||||
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, InvalidKBatchLessThanOne)
|
||||
{
|
||||
using Kernel = typename BuildKernel<half_t,
|
||||
TestConvConfig,
|
||||
tensor_layout::convolution::NHWGC,
|
||||
tensor_layout::convolution::GKYXC,
|
||||
tensor_layout::convolution::NHWGK>::type;
|
||||
|
||||
auto host_args_kbatch_0 = create_2d_host_args(0);
|
||||
auto kargs = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_0);
|
||||
EXPECT_FALSE(Kernel::IsSupportedArgument(kargs));
|
||||
}
|
||||
|
||||
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, AtomicAddRequiresKBatchGreaterThanOne)
|
||||
{
|
||||
using Kernel = typename BuildKernel<half_t,
|
||||
TestConvConfig,
|
||||
tensor_layout::convolution::NHWGC,
|
||||
tensor_layout::convolution::GKYXC,
|
||||
tensor_layout::convolution::NHWGK,
|
||||
memory_operation_enum::atomic_add>::type;
|
||||
|
||||
// k_batch = 1 should fail with atomic_add
|
||||
auto host_args_kbatch_1 = create_2d_host_args(1);
|
||||
auto kargs_1 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_1);
|
||||
EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_1));
|
||||
|
||||
// k_batch = 2 should pass
|
||||
auto host_args_kbatch_2 = create_2d_host_args(2);
|
||||
auto kargs_2 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_2);
|
||||
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_2));
|
||||
}
|
||||
|
||||
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, NonFloatDoubleOutputLimitsKBatch)
|
||||
{
|
||||
using Kernel = typename BuildKernel<half_t,
|
||||
TestConvConfig,
|
||||
tensor_layout::convolution::NHWGC,
|
||||
tensor_layout::convolution::GKYXC,
|
||||
tensor_layout::convolution::NHWGK>::type;
|
||||
|
||||
// k_batch = 128 should pass
|
||||
auto host_args_kbatch_128 = create_2d_host_args(128);
|
||||
auto kargs_128 =
|
||||
typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_128);
|
||||
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_128));
|
||||
|
||||
// k_batch = 129 should fail for half_t output
|
||||
auto host_args_kbatch_129 = create_2d_host_args(129);
|
||||
auto kargs_129 =
|
||||
typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_129);
|
||||
EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_129));
|
||||
}
|
||||
Reference in New Issue
Block a user