[CK] Allow tensors larger than 2GB in grouped conv bwd weight (#3169)

* Take split_k into account when checking 2GB tensor limit.

* Revert "Take split_k into account when checking 2GB tensor limit."

This reverts commit adf35c91be.

* Optimize grouped conv bwd wei split_k off calc

(cherry picked from commit 6f61dd56c5)

* Update gridwise_gemm_xdl_cshuffle_conv_v3.hpp

(cherry picked from commit b33877c10f)

* Fix tensor descriptors and stride calculations

* Don't miss half of the elements

* Fix buffer size calculations

* Disable hack if stride not divisible by k_batch

* Clean up comments

* Disallow hack in non-contiguous edge cases

* Index -> Dim

* Fix broken test

* Refactor applicability checks into separate function

* fix missed variable name

* Fix variable name in info print

* update V3 2GB check

* No more regression, use templates instead

* Code deduplication

* Regression fix for cshuffle

* arch-guarded atomic_add implementations for gfx11

* Similar for half(4|8)_t as well

* Only use both offset hacks at the same time

* Revert "arch-guarded atomic_add implementations for gfx11"

This reverts commit 3883fe6935.
This reverts commit 5311ec608d.

* Reapply "arch-guarded atomic_add implementations for gfx11"

This reverts commit 1972adeddc.

* Only remove float4 atomic_add

* Refactor to single flag

* Consolidate template parameters

* Consolidate flag in transformers

---------

Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>
This commit is contained in:
Johannes Graner
2026-01-08 08:02:02 +01:00
committed by GitHub
parent bc497beffb
commit ee2c35b92d
9 changed files with 1286 additions and 202 deletions

View File

@@ -24,6 +24,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
@@ -60,13 +61,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
[[maybe_unused]] const index_t num_k_per_block)
[[maybe_unused]] const index_t num_k_per_block,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_hack)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0;
const long_index_t split_k_offset_b = split_k_offset_hack ? k_idx * split_k_stride_b : 0;
const long_index_t a_batch_offset = amd_wave_read_first_lane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
@@ -77,23 +84,29 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx);
DispatchSplitKHack<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_hack);
}
#else
ignore = karg;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
ignore = split_k_offset_hack;
#endif // end of if (defined(__gfx9__))
}
@@ -118,14 +131,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
[[maybe_unused]] const index_t num_k_per_block)
[[maybe_unused]] const index_t num_k_per_block,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_hack)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
// offset base pointer for each work-group
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0;
const long_index_t split_k_offset_b = split_k_offset_hack ? k_idx * split_k_stride_b : 0;
const long_index_t a_batch_offset = amd_wave_read_first_lane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
@@ -139,24 +158,30 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx);
DispatchSplitKHack_2Lds<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_hack);
}
#else
ignore = karg;
ignore = split_k_offset_hack;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
#endif // end of if (defined(__gfx9__))
}
@@ -693,7 +718,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
k_batch_ = split_k;
}
const auto descs =
// Create initial descriptors with hack=false to check compactness
const auto descs_initial =
conv_to_gemm_transformer_v2
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
@@ -709,11 +735,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_);
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
ce_grid_desc_m_n_ = descs[I2];
k_batch_,
false, // hack=false for initial check
true); // use_full_batch_kindex
ce_elementwise_grid_desc_m_n_ =
conv_to_gemm_transformer_v1
@@ -733,6 +757,67 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
input_right_pads,
k_batch_)[I2];
split_k_offset_hack_ =
SplitKHackEligibility<NDimSpatial, InLayout, WeiLayout, OutLayout>::Check(
descs_initial[I0],
descs_initial[I1],
k_batch_,
Conv_N_,
output_spatial_lengths_,
KPerBlock);
// Create final descriptors with correct hack flag
const auto descs =
conv_to_gemm_transformer_v2
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides_transposed,
e_g_k_c_xs_strides_transposed,
a_g_n_k_wos_strides_transposed,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_,
split_k_offset_hack_, // Use determined hack flag
true); // use_full_batch_kindex
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
ce_grid_desc_m_n_ = descs[I2];
// Step 5: Calculate stride using CalculateOffset on FINAL descriptors
if(split_k_offset_hack_)
{
const index_t k0_per_batch = a_grid_desc_k0_m_k1_.GetLength(I0) / k_batch_;
const auto idx_start = make_multi_index(0, 0, 0);
const auto idx_next = make_multi_index(k0_per_batch, 0, 0);
split_k_stride_a_ = a_grid_desc_k0_m_k1_.CalculateOffset(idx_next) -
a_grid_desc_k0_m_k1_.CalculateOffset(idx_start);
}
else
{
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
}
if(split_k_offset_hack_)
{
const index_t k0_per_batch = b_grid_desc_k0_n_k1_.GetLength(I0) / k_batch_;
const auto idx_start = make_multi_index(0, 0, 0);
const auto idx_next = make_multi_index(k0_per_batch, 0, 0);
split_k_stride_b_ = b_grid_desc_k0_n_k1_.CalculateOffset(idx_next) -
b_grid_desc_k0_n_k1_.CalculateOffset(idx_start);
}
else
{
split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize();
}
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
@@ -869,6 +954,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
long_index_t c_space_size_bytes;
bool split_k_offset_hack_;
long_index_t split_k_stride_a_, split_k_stride_b_;
};
// Invoker
@@ -971,7 +1059,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block);
num_k_per_block,
arg.split_k_stride_a_,
arg.split_k_stride_b_,
arg.split_k_offset_hack_);
}
else
{
@@ -987,7 +1078,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block);
num_k_per_block,
arg.split_k_stride_a_,
arg.split_k_stride_b_,
arg.split_k_offset_hack_);
}
};
@@ -1920,14 +2014,6 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
}
}
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
arg.ce_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}
return true;
}

View File

@@ -21,6 +21,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -33,6 +34,74 @@ namespace ck {
namespace tensor_operation {
namespace device {
// Dispatch helper function for split-K hack - handles 2-way dispatch based on runtime flag
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename AGridDesc_B_K0_M_K1,
typename BGridDesc_B_K0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__device__ void DispatchBatchedGemmSplitKHack(const FloatA* p_a_grid,
const FloatB* p_b_grid,
FloatC* p_c_grid,
void* p_shared,
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_hack,
index_t k_batch)
{
if(split_k_offset_hack)
{
GridwiseGemm::template Run<HasMainKBlockLoop, true>(
p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map,
split_k_stride_a,
split_k_stride_b,
k_batch);
}
else
{
GridwiseGemm::template Run<HasMainKBlockLoop, false>(
p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map,
split_k_stride_a,
split_k_stride_b,
k_batch);
}
}
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
@@ -62,7 +131,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_hack,
index_t k_batch)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
defined(__gfx12__)
@@ -79,17 +152,33 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
__shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
DispatchBatchedGemmSplitKHack<GridwiseGemm,
FloatA,
FloatB,
FloatC,
AGridDesc_B_K0_M_K1,
BGridDesc_B_K0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
Block2CTileMap,
HasMainKBlockLoop>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map,
split_k_stride_a,
split_k_stride_b,
split_k_offset_hack,
k_batch);
}
#else
ignore = p_a_grid;
@@ -104,6 +193,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
ignore = batch_count;
ignore = block_2_ctile_map;
ignore = compute_ptr_offset_of_batch;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
ignore = split_k_offset_hack;
ignore = k_batch;
compute_ptr_offset_of_batch.GetAPtrOffset(0);
compute_ptr_offset_of_batch.GetBPtrOffset(0);
@@ -459,7 +552,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<DeviceOp::Block2CTileMap>,
ComputePtrOffsetOfStridedBatch<>,
false>, // Both true/false give the same occupancy.
false>, // HasMainKBlockLoop - both true/false give the same occupancy
BlockSize,
dynamic_smem_size));
return std::max(1, max_occupancy);
@@ -576,6 +669,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
k_batch_ = split_k;
}
// Create descriptors first (with hack flags temporarily set to false)
// so we can check if element space sizes are divisible by k_batch
const auto descs_initial =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides_transposed,
e_g_k_c_xs_strides_transposed,
a_g_n_k_wos_strides_transposed,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_,
false); // split_k_offset_b_hack (temporary)
split_k_offset_hack_ =
SplitKHackEligibility<NDimSpatial, InLayout, WeiLayout, OutLayout>::Check(
descs_initial[I0],
descs_initial[I1],
k_batch_,
Conv_N_,
output_spatial_lengths_,
K0PerBlock * K1);
// Now create descriptors with the correct hack flag
const auto descs =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -592,12 +716,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_);
k_batch_,
split_k_offset_hack_);
a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
// Calculate stride using CalculateOffset method for accurate stride
// This works correctly for any descriptor transform pipeline
split_k_stride_a_ = a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize();
if(split_k_offset_hack_)
split_k_stride_a_ /= k_batch_;
split_k_stride_b_ = b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize();
if(split_k_offset_hack_)
split_k_stride_b_ /= k_batch_;
block_2_ctile_map_ =
GridwiseGemm64::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
@@ -732,6 +867,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
long_index_t c_space_size_bytes;
bool split_k_offset_hack_;
long_index_t split_k_stride_a_, split_k_stride_b_;
};
// Invoker
@@ -878,7 +1016,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
arg.b_grid_desc_kbatch_k0_n_k1_,
c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.block_2_ctile_map_,
arg.compute_ptr_offset_of_batch_);
arg.compute_ptr_offset_of_batch_,
arg.split_k_stride_a_,
arg.split_k_stride_b_,
arg.split_k_offset_hack_,
arg.k_batch_);
};
if(has_main_k0_block_loop)

View File

@@ -22,6 +22,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
@@ -58,13 +59,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const index_t num_k_per_block)
const index_t num_k_per_block,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_hack)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0;
const long_index_t split_k_offset_b = split_k_offset_hack ? k_idx * split_k_stride_b : 0;
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
@@ -74,20 +81,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx);
DispatchSplitKHack<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_hack);
}
#else
ignore = karg;
@@ -96,6 +107,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = compute_ptr_offset_of_batch;
ignore = num_k_per_block;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
ignore = split_k_offset_hack;
#endif // end of if (defined(__gfx9__)
}
@@ -119,14 +134,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const index_t num_k_per_block)
const index_t num_k_per_block,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_hack)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
// offset base pointer for each work-group
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0;
const long_index_t split_k_offset_b = split_k_offset_hack ? k_idx * split_k_stride_b : 0;
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
@@ -140,21 +161,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx);
DispatchSplitKHack_2Lds<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_hack);
}
#else
ignore = karg;
@@ -163,6 +187,9 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = compute_ptr_offset_of_batch;
ignore = num_k_per_block;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
ignore = split_k_offset_hack;
#endif // end of if (defined(__gfx9__)
}
@@ -490,8 +517,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
: p_a_grid_{p_out_grid},
p_b_grid_{p_in_grid},
p_c_grid_{p_wei_grid},
a_grid_desc_kbatch_k0_m_k1_{},
b_grid_desc_kbatch_k0_n_k1_{},
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
compute_ptr_offset_of_batch_{},
@@ -560,6 +587,38 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
k_batch_ = split_k;
}
// Create descriptors first (with hack flags temporarily set to false)
// so we can check if element space sizes match product of dimensions
const auto descs_initial =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides,
e_g_k_c_xs_strides,
a_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_,
false, // split_k_offset_b_hack (temporary)
true); // use_full_batch_kindex=true for V1-compatible descriptors
split_k_offset_hack_ =
SplitKHackEligibility<NDimSpatial, InLayout, WeiLayout, OutLayout>::Check(
descs_initial[I0],
descs_initial[I1],
k_batch_,
Conv_N_,
output_spatial_lengths_,
K0PerBlock);
// Now create descriptors with the correct hack flag
const auto descs =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -576,11 +635,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_);
k_batch_,
split_k_offset_hack_,
true); // use_full_batch_kindex=true for V1-compatible descriptors
a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
// Calculate stride using CalculateOffset method for accurate stride
// This works correctly for any descriptor transform pipeline
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
if(split_k_offset_hack_)
split_k_stride_a_ /= k_batch_;
split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize();
if(split_k_offset_hack_)
split_k_stride_b_ /= k_batch_;
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
@@ -591,8 +662,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
end(filter_spatial_lengths_),
index_t{1},
std::multiplies<>{});
const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -604,8 +675,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
@@ -631,6 +702,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
long_index_t c_space_size_bytes;
bool split_k_offset_hack_;
long_index_t split_k_stride_a_, split_k_stride_b_;
};
// Invoker
@@ -640,17 +714,15 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
void ShowInfo(const Argument& arg)
{
std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
@@ -659,10 +731,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
template <typename GridwiseGemm>
float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2);
const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
const index_t GemmK =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
const ADataType* p_a_grid = arg.p_a_grid_;
const BDataType* p_b_grid = arg.p_b_grid_;
@@ -680,7 +752,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto num_k_per_block =
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
const auto clear_workspace = [&]() {
if(arg.k_batch_ > 1)
@@ -716,11 +788,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
dim3(BlockSize),
0,
gemm_arg_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block);
num_k_per_block,
arg.split_k_stride_a_,
arg.split_k_stride_b_,
arg.split_k_offset_hack_);
}
else
{
@@ -732,11 +807,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
dim3(BlockSize),
0,
gemm_arg,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block);
num_k_per_block,
arg.split_k_stride_a_,
arg.split_k_stride_b_,
arg.split_k_offset_hack_);
}
};
@@ -749,7 +827,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(gemm_arg.KBatch > 1)
if(arg.k_batch_ > 1)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
@@ -781,7 +859,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(gemm_arg.KBatch > 1)
if(arg.k_batch_ > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
@@ -1090,7 +1168,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(gemm_arg.KBatch > 1)
if(arg.k_batch_ > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
@@ -1159,7 +1237,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
}
else
{
if(gemm_arg.KBatch > 1)
if(arg.k_batch_ > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
@@ -1232,7 +1310,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(gemm_arg.KBatch > 1)
if(arg.k_batch_ > 1)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
@@ -1289,10 +1367,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
}
#endif
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2);
const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
const index_t GemmK =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
if constexpr(is_same_v<ComputeTypeA, ck::tf32_t> || is_same_v<ComputeTypeB, ck::tf32_t>)
{
@@ -1423,9 +1501,17 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
}
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) <= TwoGB))
const bool a_small_enough = arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() /
(arg.split_k_offset_hack_ ? arg.k_batch_ : 1) *
sizeof(ADataType) <=
TwoGB;
const bool b_small_enough = arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() /
(arg.split_k_offset_hack_ ? arg.k_batch_ : 1) *
sizeof(BDataType) <=
TwoGB;
const bool c_small_enough =
arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) <= TwoGB;
if(!(a_small_enough && b_small_enough && c_small_enough))
{
return false;
}

View File

@@ -0,0 +1,222 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <numeric>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Check if a tensor descriptor has compact layout
// Compact means: GetElementSpaceSize() == product of all dimension lengths
// Non-compact descriptors have complex transform pipelines that may not support split-k hack
template <typename Descriptor>
bool IsDescriptorCompact(const Descriptor& desc)
{
// Calculate product of all dimensions
long_index_t dims_product = 1;
constexpr index_t num_dims = Descriptor::GetNumOfDimension();
// Use template recursion to multiply all dimension lengths
static_for<0, num_dims, 1>{}(
[&](auto i) { dims_product *= static_cast<long_index_t>(desc.GetLength(i)); });
return desc.GetElementSpaceSize() == dims_product;
}
// Eligibility test for the split-k offset hack on an (A, B) descriptor pair.
//
// The hack shifts per-k-batch base offsets into the grid descriptors, which is
// only safe when every k-batch slice starts at a clean, evenly divisible
// offset in a supported layout. All conditions below must hold; otherwise the
// regular (non-hacked) path must be used.
template <index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
struct SplitKHackEligibility
{
    template <typename ADescriptor, typename BDescriptor>
    static bool
    Check(const ADescriptor& a_desc,
          const BDescriptor& b_desc,
          index_t k_batch,
          index_t Conv_N,
          const std::array<index_t, NDimSpatial>& output_spatial_lengths,
          index_t k_block_size) // K0PerBlock*K1 for v1, K0PerBlock for v3, KPerBlock for two-stage
    {
        // Without an actual split there is nothing to offset.
        if(k_batch <= 1)
        {
            return false;
        }

        // Total number of output spatial elements.
        const index_t spatial_size = std::accumulate(output_spatial_lengths.begin(),
                                                     output_spatial_lengths.end(),
                                                     index_t{1},
                                                     std::multiplies<index_t>());
        const index_t n_times_spatial = Conv_N * spatial_size;

        // The reduction dimension must split evenly: no K padding, and both the
        // N*spatial product and N itself must divide cleanly by k_batch.
        const bool reduction_splits_evenly = (n_times_spatial % (k_block_size * k_batch) == 0) &&
                                             (n_times_spatial % k_batch == 0) &&
                                             (Conv_N % k_batch == 0);
        if(!reduction_splits_evenly)
        {
            return false;
        }

        // Only the NSpatialGC / GKSpatial / NSpatialGK layout combination is supported.
        if(!is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>())
        {
            return false;
        }

        // Per-k-batch slice strides must be a whole number of elements for A and B.
        if(a_desc.GetElementSpaceSize() % k_batch != 0 ||
           b_desc.GetElementSpaceSize() % k_batch != 0)
        {
            return false;
        }

        // Both descriptors must be compact; non-compact descriptors carry transform
        // pipelines the hack cannot honor. The hack is applied to both tensors or
        // neither, since the gridwise CheckValidity requires A.KBatch == B.KBatch.
        return IsDescriptorCompact(a_desc) && IsDescriptorCompact(b_desc);
    }
};
// Dispatch helper for the standard (single-LDS) gridwise kernel that maps the
// runtime split-k offset-hack flag onto the corresponding compile-time template
// instantiation of GridwiseGemm::Run. Centralizing the branch here avoids
// duplicating it in every device-layer implementation.
template <typename GridwiseGemm,
          typename AGridDesc_AK0_M_K1,
          typename BGridDesc_BK0_N_K1,
          typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          TailNumber TailNum,
          typename ADataType,
          typename BDataType,
          typename CDataType>
__device__ void DispatchSplitKHack(const ADataType* p_a_grid,
                                   const BDataType* p_b_grid,
                                   CDataType* p_c_grid,
                                   void* p_shared,
                                   const typename GridwiseGemm::Argument& karg,
                                   const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
                                   const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
                                   const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                       c_grid_desc_mblock_mperblock_nblock_nperblock,
                                   index_t k_id,
                                   index_t k_batch,
                                   bool split_k_offset_hack)
{
    // Instantiate Run with the compile-time hack flag carried by the tag type.
    auto launch = [&](auto use_hack_tag) {
        constexpr bool kUseOffsetHack = decltype(use_hack_tag)::value;
        GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
                                   BGridDesc_BK0_N_K1,
                                   CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                                   HasMainKBlockLoop,
                                   CGlobalMemoryDataOperation,
                                   TailNum,
                                   kUseOffsetHack>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
                                                   p_shared,
                                                   karg,
                                                   a_grid_desc_ak0_m_ak1,
                                                   b_grid_desc_bk0_n_bk1,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                   k_id,
                                                   k_batch);
    };

    if(split_k_offset_hack)
    {
        launch(integral_constant<bool, true>{});
    }
    else
    {
        launch(integral_constant<bool, false>{});
    }
}
// Dispatch helper for the double-LDS (2lds) gridwise kernel that maps the
// runtime split-k offset-hack flag onto the corresponding compile-time template
// instantiation of GridwiseGemm::Run_2Lds. Mirrors DispatchSplitKHack but
// forwards two shared-memory buffers.
template <typename GridwiseGemm,
          typename AGridDesc_AK0_M_K1,
          typename BGridDesc_BK0_N_K1,
          typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          TailNumber TailNum,
          typename ADataType,
          typename BDataType,
          typename CDataType>
__device__ void DispatchSplitKHack_2Lds(const ADataType* p_a_grid,
                                        const BDataType* p_b_grid,
                                        CDataType* p_c_grid,
                                        void* p_shared_0,
                                        void* p_shared_1,
                                        const typename GridwiseGemm::Argument& karg,
                                        const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
                                        const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
                                        const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                            c_grid_desc_mblock_mperblock_nblock_nperblock,
                                        index_t k_id,
                                        index_t k_batch,
                                        bool split_k_offset_hack)
{
    // Instantiate Run_2Lds with the compile-time hack flag carried by the tag type.
    auto launch = [&](auto use_hack_tag) {
        constexpr bool kUseOffsetHack = decltype(use_hack_tag)::value;
        GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
                                        BGridDesc_BK0_N_K1,
                                        CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                                        HasMainKBlockLoop,
                                        CGlobalMemoryDataOperation,
                                        TailNum,
                                        kUseOffsetHack>(
            p_a_grid,
            p_b_grid,
            p_c_grid,
            p_shared_0,
            p_shared_1,
            karg,
            a_grid_desc_ak0_m_ak1,
            b_grid_desc_bk0_n_bk1,
            c_grid_desc_mblock_mperblock_nblock_nperblock,
            k_id,
            k_batch);
    };

    if(split_k_offset_hack)
    {
        launch(integral_constant<bool, true>{});
    }
    else
    {
        launch(integral_constant<bool, false>{});
    }
}
} // namespace device
} // namespace tensor_operation
} // namespace ck