Refactor applicability checks into separate function

This commit is contained in:
Graner, Johannes
2025-12-01 12:26:32 +00:00
parent 9e3e1b6935
commit 32b3a538d9
4 changed files with 119 additions and 123 deletions

View File

@@ -24,6 +24,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_descriptor_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
@@ -716,7 +717,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
k_batch_ = split_k;
}
// Step 1: Create initial descriptors with hack=false to check compactness
// Create initial descriptors with hack=false to check compactness
const auto descs_initial =
conv_to_gemm_transformer_v2
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -738,17 +739,6 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
false, // hack=false for initial check
true); // use_full_batch_kindex
// Step 2: Check if descriptors are compact (element_space == product of dimensions)
const auto a_dims_product = static_cast<long_index_t>(descs_initial[I0].GetLength(I0)) *
static_cast<long_index_t>(descs_initial[I0].GetLength(I1)) *
static_cast<long_index_t>(descs_initial[I0].GetLength(I2));
const auto b_dims_product = static_cast<long_index_t>(descs_initial[I1].GetLength(I0)) *
static_cast<long_index_t>(descs_initial[I1].GetLength(I1)) *
static_cast<long_index_t>(descs_initial[I1].GetLength(I2));
const bool is_a_compact = (descs_initial[I0].GetElementSpaceSize() == a_dims_product);
const bool is_b_compact = (descs_initial[I1].GetElementSpaceSize() == b_dims_product);
ce_elementwise_grid_desc_m_n_ =
conv_to_gemm_transformer_v1
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -767,35 +757,16 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
input_right_pads,
k_batch_)[I2];
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
std::tie(split_k_offset_a_hack_, split_k_offset_b_hack_) =
SplitKHackEligibility<NDimSpatial, InLayout, WeiLayout, OutLayout>::Check(
descs_initial[I0],
descs_initial[I1],
k_batch_,
Conv_N_,
output_spatial_lengths_,
KPerBlock);
const bool is_k_not_paded =
(Conv_N_ * output_spatial_acum) % (KPerBlock * k_batch_) == 0;
const bool can_divide_n_spatial_by_k_batch =
(Conv_N_ * output_spatial_acum) % k_batch_ == 0;
const bool can_divide_n_by_k_batch = Conv_N_ % k_batch_ == 0;
const bool is_correct_layout =
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
const bool is_a_stride_divisible =
descs_initial[I0].GetElementSpaceSize() % k_batch_ == 0;
const bool is_b_stride_divisible =
descs_initial[I1].GetElementSpaceSize() % k_batch_ == 0;
// Step 3: Determine if hack can be enabled (only for compact layouts)
split_k_offset_a_hack_ = k_batch_ > 1 && can_divide_n_spatial_by_k_batch &&
is_k_not_paded && is_correct_layout && is_a_stride_divisible &&
is_a_compact;
split_k_offset_b_hack_ = k_batch_ > 1 && can_divide_n_by_k_batch && is_k_not_paded &&
is_correct_layout && is_b_stride_divisible && is_b_compact;
// Step 4: Create final descriptors with correct hack flags
// Create final descriptors with correct hack flags
const auto descs =
conv_to_gemm_transformer_v2
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(

View File

@@ -21,6 +21,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_descriptor_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -612,48 +613,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
false, // split_k_offset_a_hack (temporary)
false); // split_k_offset_b_hack (temporary)
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const bool is_k_not_paded =
(Conv_N_ * output_spatial_acum) % (K0PerBlock * K1 * k_batch_) == 0;
const bool can_divide_n_spatial_by_k_batch =
(Conv_N_ * output_spatial_acum) % k_batch_ == 0;
const bool can_divide_n_by_k_batch = Conv_N_ % k_batch_ == 0;
const bool is_correct_layout =
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
const bool is_a_stride_divisible =
descs_initial[I0].GetElementSpaceSize() % k_batch_ == 0;
const bool is_b_stride_divisible =
descs_initial[I1].GetElementSpaceSize() % k_batch_ == 0;
// Check if descriptor has compact layout (product of dimensions equals element space)
// Non-compact layouts have complex transform pipelines that don't support the hack
const auto a_dims_product = static_cast<long_index_t>(descs_initial[I0].GetLength(I0)) *
static_cast<long_index_t>(descs_initial[I0].GetLength(I1)) *
static_cast<long_index_t>(descs_initial[I0].GetLength(I2)) *
static_cast<long_index_t>(descs_initial[I0].GetLength(I3));
const auto b_dims_product = static_cast<long_index_t>(descs_initial[I1].GetLength(I0)) *
static_cast<long_index_t>(descs_initial[I1].GetLength(I1)) *
static_cast<long_index_t>(descs_initial[I1].GetLength(I2)) *
static_cast<long_index_t>(descs_initial[I1].GetLength(I3));
const bool is_a_compact = (descs_initial[I0].GetElementSpaceSize() == a_dims_product);
const bool is_b_compact = (descs_initial[I1].GetElementSpaceSize() == b_dims_product);
// Determine if we can safely use the split-k offset hack
// Only enable for compact layouts where element_space_size == product of dimensions
split_k_offset_a_hack_ = k_batch_ > 1 && can_divide_n_spatial_by_k_batch &&
is_k_not_paded && is_correct_layout && is_a_stride_divisible &&
is_a_compact;
split_k_offset_b_hack_ = k_batch_ > 1 && can_divide_n_by_k_batch && is_k_not_paded &&
is_correct_layout && is_b_stride_divisible && is_b_compact;
std::tie(split_k_offset_a_hack_, split_k_offset_b_hack_) =
SplitKHackEligibility<NDimSpatial, InLayout, WeiLayout, OutLayout>::Check(
descs_initial[I0],
descs_initial[I1],
k_batch_,
Conv_N_,
output_spatial_lengths_,
K0PerBlock * K1);
// Now create descriptors with the correct hack flags
const auto descs =

View File

@@ -22,6 +22,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_descriptor_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
@@ -607,47 +608,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
false, // split_k_offset_b_hack (temporary)
true); // use_full_batch_kindex=true for V1-compatible descriptors
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const bool is_k_not_paded =
(Conv_N_ * output_spatial_acum) % (K0PerBlock * k_batch_) == 0;
const bool can_divide_n_spatial_by_k_batch =
(Conv_N_ * output_spatial_acum) % k_batch_ == 0;
const bool can_divide_n_by_k_batch = Conv_N_ % k_batch_ == 0;
const bool is_correct_layout =
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
const bool is_a_stride_divisible =
descs_initial[I0].GetElementSpaceSize() % k_batch_ == 0;
const bool is_b_stride_divisible =
descs_initial[I1].GetElementSpaceSize() % k_batch_ == 0;
// Check if descriptor has compact layout (product of dimensions equals element space)
// Non-compact layouts have complex transform pipelines that don't support the hack
// Note: CShuffleV3 descriptors are 3D [K0, M, K1], not 4D like CShuffle
const auto a_dims_product = static_cast<long_index_t>(descs_initial[I0].GetLength(I0)) *
static_cast<long_index_t>(descs_initial[I0].GetLength(I1)) *
static_cast<long_index_t>(descs_initial[I0].GetLength(I2));
const auto b_dims_product = static_cast<long_index_t>(descs_initial[I1].GetLength(I0)) *
static_cast<long_index_t>(descs_initial[I1].GetLength(I1)) *
static_cast<long_index_t>(descs_initial[I1].GetLength(I2));
const bool is_a_compact = (descs_initial[I0].GetElementSpaceSize() == a_dims_product);
const bool is_b_compact = (descs_initial[I1].GetElementSpaceSize() == b_dims_product);
// Determine if we can safely use the split-k offset hack
// Only enable for compact layouts where element_space_size == product of dimensions
split_k_offset_a_hack_ = k_batch_ > 1 && can_divide_n_spatial_by_k_batch &&
is_k_not_paded && is_correct_layout && is_a_stride_divisible &&
is_a_compact;
split_k_offset_b_hack_ = k_batch_ > 1 && can_divide_n_by_k_batch && is_k_not_paded &&
is_correct_layout && is_b_stride_divisible && is_b_compact;
std::tie(split_k_offset_a_hack_, split_k_offset_b_hack_) =
SplitKHackEligibility<NDimSpatial, InLayout, WeiLayout, OutLayout>::Check(
descs_initial[I0],
descs_initial[I1],
k_batch_,
Conv_N_,
output_spatial_lengths_,
K0PerBlock);
// Now create descriptors with the correct hack flags
const auto descs =

View File

@@ -0,0 +1,90 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <numeric>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Check if a tensor descriptor has compact layout
// Compact means: GetElementSpaceSize() == product of all dimension lengths
// Non-compact descriptors have complex transform pipelines that may not support split-k hack
template <typename Descriptor>
bool IsDescriptorCompact(const Descriptor& desc)
{
// Calculate product of all dimensions
long_index_t dims_product = 1;
constexpr index_t num_dims = Descriptor::GetNumOfDimension();
// Use template recursion to multiply all dimension lengths
static_for<0, num_dims, 1>{}(
[&](auto i) { dims_product *= static_cast<long_index_t>(desc.GetLength(i)); });
return desc.GetElementSpaceSize() == dims_product;
}
// Determine split-k hack eligibility for descriptor pair
// This checks all the conditions required for safely using the split-k offset hack
template <index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
struct SplitKHackEligibility
{
template <typename ADescriptor, typename BDescriptor>
static auto
Check(const ADescriptor& a_desc,
const BDescriptor& b_desc,
index_t k_batch,
index_t Conv_N,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
index_t k_block_size) // K0PerBlock*K1 for v1, K0PerBlock for v3, KPerBlock for two-stage
{
// Only enable hack if k_batch > 1
if(k_batch <= 1)
{
return std::make_pair(false, false);
}
// Calculate output spatial product
const index_t output_spatial_acum = std::accumulate(output_spatial_lengths.begin(),
output_spatial_lengths.end(),
index_t{1},
std::multiplies<index_t>());
// Check various divisibility and layout requirements
const bool is_k_not_paded = (Conv_N * output_spatial_acum) % (k_block_size * k_batch) == 0;
const bool can_divide_n_spatial_by_k_batch = (Conv_N * output_spatial_acum) % k_batch == 0;
const bool can_divide_n_by_k_batch = Conv_N % k_batch == 0;
const bool is_correct_layout =
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
const bool is_a_stride_divisible = a_desc.GetElementSpaceSize() % k_batch == 0;
const bool is_b_stride_divisible = b_desc.GetElementSpaceSize() % k_batch == 0;
// Check descriptor compactness
const bool is_a_compact = IsDescriptorCompact(a_desc);
const bool is_b_compact = IsDescriptorCompact(b_desc);
// Determine hack flags based on all conditions
const bool split_k_offset_a_hack = can_divide_n_spatial_by_k_batch && is_k_not_paded &&
is_correct_layout && is_a_stride_divisible &&
is_a_compact;
const bool split_k_offset_b_hack = can_divide_n_by_k_batch && is_k_not_paded &&
is_correct_layout && is_b_stride_divisible &&
is_b_compact;
return std::make_pair(split_k_offset_a_hack, split_k_offset_b_hack);
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck