mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] Split-K autodeduction (#3351)
* First version of split-K autodeduction. * Fix circular dependency and kernel construction. * Fix tolerance calculation for bwd weight example. * Simplify kernel construction. * Fix kernel launching bug for split-K autodeduce. * Add split-K autodeduction support for the two stage example. * Fix a corner case. * Fix clang-format. * Fix clang-format for inc files. * Add missing header. * Prevent too large split-K values. * Fix formatting. * Add unit tests for IsSupportedArgument in grouped bwd conv. * clang-format. * Fix merge conflicts. * Address feedback from code review. * clang-format * Fix new tests after merge. --------- Co-authored-by: Ville Pietilä <>
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp"
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
|
||||
|
||||
#include "ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp"
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp"
|
||||
#endif
|
||||
@@ -62,8 +64,6 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
@@ -104,11 +104,14 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
GemmK = a_grid_desc_k_m.get_length(number<0>{});
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -147,8 +150,6 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
|
||||
static_cast<index_t>(args.input_right_pads_[1])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
@@ -189,11 +190,14 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
GemmK = a_grid_desc_k_m.get_length(number<0>{});
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -239,8 +243,6 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
static_cast<index_t>(args.input_right_pads_[1]),
|
||||
static_cast<index_t>(args.input_right_pads_[2])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
@@ -281,11 +283,14 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
GemmK = a_grid_desc_k_m.get_length(number<0>{});
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -398,7 +403,6 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
using GroupedConvBwdWeightKernelArgsSpecialized =
|
||||
GroupedConvBwdWeightKernelArgs<GroupedConvTraitsType_>;
|
||||
|
||||
// TODO: Enable this
|
||||
static constexpr bool IsSplitKSupported = true;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
@@ -476,7 +480,24 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
std::cout << "NPerBlock: " << number<TilePartitioner::NPerBlock>{} << std::endl;
|
||||
std::cout << "KPerBlock: " << number<TilePartitioner::KPerBlock>{} << std::endl;
|
||||
}
|
||||
return GroupedConvBwdWeightKernelArgsSpecialized(hostArgs);
|
||||
|
||||
auto kernel_args = GroupedConvBwdWeightKernelArgsSpecialized(hostArgs);
|
||||
|
||||
using KernelImpl = GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType_,
|
||||
TilePartitioner_,
|
||||
GemmPipeline_,
|
||||
EpiloguePipeline_>;
|
||||
|
||||
// Negative k_batch value: split-K autodeduction.
|
||||
if(kernel_args.k_batch < 0)
|
||||
{
|
||||
const auto optimal_split_k =
|
||||
calculate_optimal_k_batch<GemmPipeline_::BlockSize, KernelImpl, TilePartitioner_>(
|
||||
kernel_args);
|
||||
kernel_args.k_batch = optimal_split_k;
|
||||
}
|
||||
|
||||
return kernel_args;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
@@ -514,15 +535,54 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
CK_TILE_HOST static bool
|
||||
IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
|
||||
{
|
||||
if(kargs.k_batch < 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"k_batch must be at least one. Ensure argument is created via MakeKernelArgs.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(EpiloguePipeline_::MemoryOperation == memory_operation_enum::atomic_add)
|
||||
{
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Atomic add epilogue only supports k_batch > 1.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!std::is_same_v<typename EpiloguePipeline::ODataType, float> &&
|
||||
!std::is_same_v<typename EpiloguePipeline::ODataType, double>)
|
||||
{
|
||||
// The epilogue performs atomic add related to split-K using the ODataType.
|
||||
// If the type is less accurate than float, large split-K values may lead to
|
||||
// accuracy issues. Hence, we limit the maximum split-K value to 128 in such cases.
|
||||
if(kargs.k_batch > 128)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"For epilogue output data type that is not float/double, we must have "
|
||||
"k_batch <= 128.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<WeiDataType, fp16_t, bf16_t>::value) ||
|
||||
!IsSplitKSupported)
|
||||
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
|
||||
CK_TILE_ERROR("Conditions not met for K_batch > 1!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include <numeric>
|
||||
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t BlockSize, typename KernelArgs, typename KernelImpl>
|
||||
CK_TILE_HOST index_t get_max_occupancy_for_kernel()
|
||||
{
|
||||
constexpr int dynamic_smem_size = 0;
|
||||
constexpr int min_blocks_per_cu = 1;
|
||||
|
||||
const auto kernel_ptr = kentry<min_blocks_per_cu, KernelImpl, KernelArgs>;
|
||||
|
||||
int max_occupancy = 0;
|
||||
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_occupancy, kernel_ptr, BlockSize, dynamic_smem_size));
|
||||
|
||||
return static_cast<index_t>(max_occupancy);
|
||||
}
|
||||
|
||||
/// @brief Pick the split-K factor that fills the whole device.
///
/// @param max_occupancy  Max concurrently-resident thread blocks per CU for
///                       the kernel (from hipOccupancyMaxActiveBlocksPerMultiprocessor).
/// @param grid_size      Number of thread blocks the GEMM launches without split-K.
/// @return               floor(device capacity / grid_size), but never less than 1.
CK_TILE_HOST index_t get_best_occupancy_k_batch_value(index_t max_occupancy, index_t grid_size)
{
    // Device capacity in thread blocks. get_num_cus() is cached: the CU count
    // cannot change during the process lifetime.
    static const index_t num_cus = get_num_cus();
    const index_t max_capacity   = max_occupancy * num_cus;

    // How many copies of the output grid fit on the device at once — that is
    // how far K can be split while still keeping every CU busy. Plain integer
    // division replaces the previous floor(double) round-trip: it is exact for
    // all values and does not rely on <cmath>, which this header never
    // included. Guard grid_size == 0 (degenerate problem) against UB.
    index_t k_batch = 1;
    const index_t optimal_split = grid_size > 0 ? max_capacity / grid_size : 1;
    if(optimal_split > 1)
    {
        k_batch = optimal_split;
    }

    if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
    {
        std::cout << "[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel: "
                  << max_occupancy << std::endl;
        std::cout << "[SPLIT-K AUTODEDUCE] Output grid size: " << grid_size << std::endl;
        std::cout << "[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << std::endl;
    }
    return k_batch;
}
|
||||
|
||||
template <index_t BlockSize, typename KernelArgs, typename KernelImpl>
|
||||
struct ActiveWorkgroupsPerCU
|
||||
{
|
||||
CK_TILE_HOST ActiveWorkgroupsPerCU()
|
||||
{
|
||||
max_occupancy_ = get_max_occupancy_for_kernel<BlockSize, KernelArgs, KernelImpl>();
|
||||
}
|
||||
index_t max_occupancy_{1};
|
||||
};
|
||||
|
||||
template <index_t BlockSize, typename KernelImpl, typename TilePartitioner, typename KernelArgs>
|
||||
CK_TILE_HOST index_t calculate_optimal_k_batch(const KernelArgs& kargs)
|
||||
{
|
||||
static ActiveWorkgroupsPerCU<BlockSize, KernelArgs, KernelImpl> active_workgroups_per_cu;
|
||||
|
||||
const auto grid_size = TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN) * kargs.GemmBatch;
|
||||
auto optimal_k_batch =
|
||||
get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, grid_size);
|
||||
|
||||
const auto max_allowed_k_batch = kargs.GemmK;
|
||||
optimal_k_batch = std::min(optimal_k_batch, max_allowed_k_batch);
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << optimal_k_batch << std::endl;
|
||||
}
|
||||
|
||||
return optimal_k_batch;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user