mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
[rocm-libraries] ROCm/rocm-libraries#6132 (commit e97065d)
[CK] Fix divide-by-zero crash for grouped conv kernels (#6132) ## Motivation During run pytorch unit tests for conv3d: `test_dtypes_nn_functional_conv3d_cuda`, `test_fake_crossref_backward_amp_nn_functional_conv3d_cuda_float32` found divide-by-zero crash during CK kernel selection. Refs ROCM-20764 ## Technical Details Add assert for K0PerBlock equal 0, also covered other potential places related with k_batch calculation. ## Test Plan Run miopen command extracted from mentioned test: `MIOpenDriver convfp16 --spatial_dim 3 -I NCDHW -O NCDHW -f NCDHW -n 1 -c 1 -k 1 -g 1 --in_d 4 -H 4 -W 4 --fil_d 4 -y 4 -x 4 --pad_d 0 -p 0 -q 0 --conv_stride_d 2 -u 2 -v 2 --dilation_d 1 -l 1 -j 1 -m conv -F 4 -t 1` ## Test Result Passed ## Submission Checklist - [X] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. Signed-off-by: Artem Kuzmitckii <artem.kuzmitckii@amd.com>
This commit is contained in:
committed by
assistant-librarian[bot]
parent
793a59736a
commit
281d1bf50b
@@ -19,6 +19,7 @@
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/description.hpp"
|
||||
@@ -853,6 +854,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
k_batch_ = clamp_gemm_k_batch(k_batch_);
|
||||
|
||||
const auto descs =
|
||||
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
|
||||
@@ -179,6 +179,7 @@ struct DeviceGroupedConvBwdWeight_Explicit
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
}
|
||||
k_batch_ = clamp_gemm_k_batch(k_batch_);
|
||||
|
||||
if constexpr(IsTwoStageNeeded)
|
||||
{
|
||||
|
||||
@@ -670,6 +670,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
k_batch_ = clamp_gemm_k_batch(k_batch_);
|
||||
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer
|
||||
|
||||
@@ -695,6 +695,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
k_batch_ = clamp_gemm_k_batch(k_batch_);
|
||||
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer
|
||||
|
||||
@@ -611,6 +611,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
k_batch_ = clamp_gemm_k_batch(k_batch_);
|
||||
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer_v2
|
||||
|
||||
@@ -717,6 +717,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
k_batch_ = clamp_gemm_k_batch(k_batch_);
|
||||
|
||||
// Create initial descriptors with hack=false to check compactness
|
||||
const auto descs_initial =
|
||||
|
||||
@@ -555,6 +555,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
k_batch_ = clamp_gemm_k_batch(k_batch_);
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
|
||||
|
||||
@@ -669,6 +669,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
k_batch_ = clamp_gemm_k_batch(k_batch_);
|
||||
|
||||
// Create descriptors first (with hack flags temporarily set to false)
|
||||
// so we can check if element space sizes are divisible by k_batch
|
||||
|
||||
@@ -638,6 +638,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
k_batch_ = clamp_gemm_k_batch(k_batch_);
|
||||
|
||||
// Create descriptors first (with hack flags temporarily set to false)
|
||||
// so we can check if element space sizes match product of dimensions
|
||||
|
||||
@@ -13,6 +13,13 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
/// Ensures GemmKBatch in conv to GEMM transforms is never 0 (would zero the divisor in
|
||||
/// integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch)).
|
||||
inline constexpr index_t clamp_gemm_k_batch(index_t k_batch) noexcept
|
||||
{
|
||||
return k_batch < 1 ? index_t{1} : k_batch;
|
||||
}
|
||||
|
||||
struct DeviceProperties
|
||||
{
|
||||
DeviceProperties()
|
||||
@@ -33,6 +40,10 @@ inline ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index
|
||||
const int max_capacity = max_occupancy * device_properties.num_cu_;
|
||||
|
||||
ck::index_t k_batch = 1;
|
||||
if(grid_size <= 0)
|
||||
{
|
||||
return k_batch;
|
||||
}
|
||||
const auto optimal_split =
|
||||
static_cast<ck::index_t>(std::floor((1.0 * max_capacity) / grid_size));
|
||||
if(optimal_split > 1)
|
||||
|
||||
@@ -21,6 +21,10 @@ template <index_t NDimSpatial,
|
||||
device::ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization>
|
||||
struct TransformConvBwdWeightToGemm
|
||||
{
|
||||
// Same contract as TransformConvBwdWeightToGemmV2 (non-zero K tile factors).
|
||||
static_assert(GemmK1Number > 0, "GemmK1Number must be positive");
|
||||
static_assert(K0PerBlock > 0, "K0PerBlock must be positive");
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
|
||||
@@ -31,6 +31,11 @@ template <index_t NDimSpatial,
|
||||
device::ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization>
|
||||
struct TransformConvBwdWeightToGemmV2
|
||||
{
|
||||
// Compile-time contract: divisor GemmK1Number * K0PerBlock * GemmKBatch in
|
||||
// integer_divide_ceil(GemmKTotal, ...) must stay non-zero (GemmKBatch clamped at runtime).
|
||||
static_assert(GemmK1Number > 0, "GemmK1Number must be positive");
|
||||
static_assert(K0PerBlock > 0, "K0PerBlock must be positive");
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user