mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK] Allow tensors larger than 2GB in grouped conv bwd weight (#3169)
* Take split_k into account when checking 2GB tensor limit. * Revert "Take split_k into account when checking 2GB tensor limit." This reverts commitadf35c91be. * Optimize grouped conv bwd wei split_k off calc (cherry picked from commit6f61dd56c5) * Update gridwise_gemm_xdl_cshuffle_conv_v3.hpp (cherry picked from commitb33877c10f) * Fix tensor descriptors and stride calculations * Don't miss half of the elements * Fix buffer size calculations * Disable hack if stride not divisible by k_batch * Clean up comments * Disallow hack in non-contiguous edge cases * Index -> Dim * Fix broken test * Refactor applicability checks into separate function * fix missed variable name * Fix variable name in info print * update V3 2GB check * No more regression, use templates instead * Code deduplication * Regression fix for cshuffle * arch-guarded atomic_add implementations for gfx11 * Similar for half(4|8)_t as well * Only use both offset hacks at the same time * Revert "arch-guarded atomic_add implementations for gfx11" This reverts commit3883fe6935. This reverts commit5311ec608d. * Reapply "arch-guarded atomic_add implementations for gfx11" This reverts commit1972adeddc. * Only remove float4 atomic_add * Refactor to single flag * Consolidate template parameters * Consolidate flag in transformers --------- Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>
This commit is contained in:
@@ -24,6 +24,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
@@ -60,13 +61,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
[[maybe_unused]] const index_t num_k_per_block)
|
||||
[[maybe_unused]] const index_t num_k_per_block,
|
||||
const long_index_t split_k_stride_a,
|
||||
const long_index_t split_k_stride_b,
|
||||
bool split_k_offset_hack)
|
||||
{
|
||||
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
|
||||
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
|
||||
{
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
|
||||
const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0;
|
||||
const long_index_t split_k_offset_b = split_k_offset_hack ? k_idx * split_k_stride_b : 0;
|
||||
|
||||
const long_index_t a_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
@@ -77,23 +84,29 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx);
|
||||
DispatchSplitKHack<GridwiseGemm,
|
||||
AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
|
||||
karg.p_b_grid + b_batch_offset + split_k_offset_b,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx * num_k_per_block,
|
||||
gridDim.y,
|
||||
split_k_offset_hack);
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
ignore = split_k_stride_a;
|
||||
ignore = split_k_stride_b;
|
||||
ignore = split_k_offset_hack;
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
}
|
||||
|
||||
@@ -118,14 +131,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
[[maybe_unused]] const index_t num_k_per_block)
|
||||
[[maybe_unused]] const index_t num_k_per_block,
|
||||
const long_index_t split_k_stride_a,
|
||||
const long_index_t split_k_stride_b,
|
||||
bool split_k_offset_hack)
|
||||
{
|
||||
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
|
||||
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
|
||||
{
|
||||
// offset base pointer for each work-group
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
|
||||
const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0;
|
||||
const long_index_t split_k_offset_b = split_k_offset_hack ? k_idx * split_k_stride_b : 0;
|
||||
|
||||
const long_index_t a_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
@@ -139,24 +158,30 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx);
|
||||
DispatchSplitKHack_2Lds<GridwiseGemm,
|
||||
AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
|
||||
karg.p_b_grid + b_batch_offset + split_k_offset_b,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx * num_k_per_block,
|
||||
gridDim.y,
|
||||
split_k_offset_hack);
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
ignore = split_k_offset_hack;
|
||||
ignore = split_k_stride_a;
|
||||
ignore = split_k_stride_b;
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
}
|
||||
|
||||
@@ -693,7 +718,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
|
||||
const auto descs =
|
||||
// Create initial descriptors with hack=false to check compactness
|
||||
const auto descs_initial =
|
||||
conv_to_gemm_transformer_v2
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
Conv_N_,
|
||||
@@ -709,11 +735,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
k_batch_);
|
||||
|
||||
a_grid_desc_k0_m_k1_ = descs[I0];
|
||||
b_grid_desc_k0_n_k1_ = descs[I1];
|
||||
ce_grid_desc_m_n_ = descs[I2];
|
||||
k_batch_,
|
||||
false, // hack=false for initial check
|
||||
true); // use_full_batch_kindex
|
||||
|
||||
ce_elementwise_grid_desc_m_n_ =
|
||||
conv_to_gemm_transformer_v1
|
||||
@@ -733,6 +757,67 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
input_right_pads,
|
||||
k_batch_)[I2];
|
||||
|
||||
split_k_offset_hack_ =
|
||||
SplitKHackEligibility<NDimSpatial, InLayout, WeiLayout, OutLayout>::Check(
|
||||
descs_initial[I0],
|
||||
descs_initial[I1],
|
||||
k_batch_,
|
||||
Conv_N_,
|
||||
output_spatial_lengths_,
|
||||
KPerBlock);
|
||||
|
||||
// Create final descriptors with correct hack flag
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer_v2
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
Conv_N_,
|
||||
Conv_K_,
|
||||
Conv_C_,
|
||||
input_spatial_lengths_,
|
||||
filter_spatial_lengths_,
|
||||
output_spatial_lengths_,
|
||||
b_g_n_c_wis_strides_transposed,
|
||||
e_g_k_c_xs_strides_transposed,
|
||||
a_g_n_k_wos_strides_transposed,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
k_batch_,
|
||||
split_k_offset_hack_, // Use determined hack flag
|
||||
true); // use_full_batch_kindex
|
||||
|
||||
a_grid_desc_k0_m_k1_ = descs[I0];
|
||||
b_grid_desc_k0_n_k1_ = descs[I1];
|
||||
ce_grid_desc_m_n_ = descs[I2];
|
||||
|
||||
// Step 5: Calculate stride using CalculateOffset on FINAL descriptors
|
||||
if(split_k_offset_hack_)
|
||||
{
|
||||
const index_t k0_per_batch = a_grid_desc_k0_m_k1_.GetLength(I0) / k_batch_;
|
||||
const auto idx_start = make_multi_index(0, 0, 0);
|
||||
const auto idx_next = make_multi_index(k0_per_batch, 0, 0);
|
||||
split_k_stride_a_ = a_grid_desc_k0_m_k1_.CalculateOffset(idx_next) -
|
||||
a_grid_desc_k0_m_k1_.CalculateOffset(idx_start);
|
||||
}
|
||||
else
|
||||
{
|
||||
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
if(split_k_offset_hack_)
|
||||
{
|
||||
const index_t k0_per_batch = b_grid_desc_k0_n_k1_.GetLength(I0) / k_batch_;
|
||||
const auto idx_start = make_multi_index(0, 0, 0);
|
||||
const auto idx_next = make_multi_index(k0_per_batch, 0, 0);
|
||||
split_k_stride_b_ = b_grid_desc_k0_n_k1_.CalculateOffset(idx_next) -
|
||||
b_grid_desc_k0_n_k1_.CalculateOffset(idx_start);
|
||||
}
|
||||
else
|
||||
{
|
||||
split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
|
||||
|
||||
@@ -869,6 +954,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
long_index_t c_space_size_bytes;
|
||||
|
||||
bool split_k_offset_hack_;
|
||||
long_index_t split_k_stride_a_, split_k_stride_b_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -971,7 +1059,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
num_k_per_block);
|
||||
num_k_per_block,
|
||||
arg.split_k_stride_a_,
|
||||
arg.split_k_stride_b_,
|
||||
arg.split_k_offset_hack_);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -987,7 +1078,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
num_k_per_block);
|
||||
num_k_per_block,
|
||||
arg.split_k_stride_a_,
|
||||
arg.split_k_stride_b_,
|
||||
arg.split_k_offset_hack_);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1920,14 +2014,6 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
if(!(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
|
||||
arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
|
||||
arg.ce_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
@@ -33,6 +34,74 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Dispatch helper function for split-K hack - handles 2-way dispatch based on runtime flag
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename AGridDesc_B_K0_M_K1,
|
||||
typename BGridDesc_B_K0_N_K1,
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename Block2CTileMap,
|
||||
bool HasMainKBlockLoop>
|
||||
__device__ void DispatchBatchedGemmSplitKHack(const FloatA* p_a_grid,
|
||||
const FloatB* p_b_grid,
|
||||
FloatC* p_c_grid,
|
||||
void* p_shared,
|
||||
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
|
||||
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CElementwiseOperation& c_element_op,
|
||||
const Block2CTileMap& block_2_ctile_map,
|
||||
const long_index_t split_k_stride_a,
|
||||
const long_index_t split_k_stride_b,
|
||||
bool split_k_offset_hack,
|
||||
index_t k_batch)
|
||||
{
|
||||
if(split_k_offset_hack)
|
||||
{
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, true>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
block_2_ctile_map,
|
||||
split_k_stride_a,
|
||||
split_k_stride_b,
|
||||
k_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, false>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
block_2_ctile_map,
|
||||
split_k_stride_a,
|
||||
split_k_stride_b,
|
||||
k_batch);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
@@ -62,7 +131,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const long_index_t split_k_stride_a,
|
||||
const long_index_t split_k_stride_b,
|
||||
bool split_k_offset_hack,
|
||||
index_t k_batch)
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__)
|
||||
@@ -79,17 +152,33 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
|
||||
__shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_c_grid + c_batch_offset,
|
||||
p_shared,
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
block_2_ctile_map);
|
||||
DispatchBatchedGemmSplitKHack<GridwiseGemm,
|
||||
FloatA,
|
||||
FloatB,
|
||||
FloatC,
|
||||
AGridDesc_B_K0_M_K1,
|
||||
BGridDesc_B_K0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
Block2CTileMap,
|
||||
HasMainKBlockLoop>(
|
||||
p_a_grid + a_batch_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_c_grid + c_batch_offset,
|
||||
p_shared,
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
block_2_ctile_map,
|
||||
split_k_stride_a,
|
||||
split_k_stride_b,
|
||||
split_k_offset_hack,
|
||||
k_batch);
|
||||
}
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
@@ -104,6 +193,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
ignore = batch_count;
|
||||
ignore = block_2_ctile_map;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
ignore = split_k_stride_a;
|
||||
ignore = split_k_stride_b;
|
||||
ignore = split_k_offset_hack;
|
||||
ignore = k_batch;
|
||||
|
||||
compute_ptr_offset_of_batch.GetAPtrOffset(0);
|
||||
compute_ptr_offset_of_batch.GetBPtrOffset(0);
|
||||
@@ -459,7 +552,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
ComputePtrOffsetOfStridedBatch<>,
|
||||
false>, // Both true/false give the same occupancy.
|
||||
false>, // HasMainKBlockLoop - both true/false give the same occupancy
|
||||
BlockSize,
|
||||
dynamic_smem_size));
|
||||
return std::max(1, max_occupancy);
|
||||
@@ -576,6 +669,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
|
||||
// Create descriptors first (with hack flags temporarily set to false)
|
||||
// so we can check if element space sizes are divisible by k_batch
|
||||
const auto descs_initial =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
Conv_N_,
|
||||
Conv_K_,
|
||||
Conv_C_,
|
||||
input_spatial_lengths_,
|
||||
filter_spatial_lengths_,
|
||||
output_spatial_lengths_,
|
||||
b_g_n_c_wis_strides_transposed,
|
||||
e_g_k_c_xs_strides_transposed,
|
||||
a_g_n_k_wos_strides_transposed,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
k_batch_,
|
||||
false); // split_k_offset_b_hack (temporary)
|
||||
|
||||
split_k_offset_hack_ =
|
||||
SplitKHackEligibility<NDimSpatial, InLayout, WeiLayout, OutLayout>::Check(
|
||||
descs_initial[I0],
|
||||
descs_initial[I1],
|
||||
k_batch_,
|
||||
Conv_N_,
|
||||
output_spatial_lengths_,
|
||||
K0PerBlock * K1);
|
||||
|
||||
// Now create descriptors with the correct hack flag
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
@@ -592,12 +716,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
k_batch_);
|
||||
k_batch_,
|
||||
split_k_offset_hack_);
|
||||
|
||||
a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
|
||||
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
|
||||
c_grid_desc_m_n_ = descs[I2];
|
||||
|
||||
// Calculate stride using CalculateOffset method for accurate stride
|
||||
// This works correctly for any descriptor transform pipeline
|
||||
split_k_stride_a_ = a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize();
|
||||
if(split_k_offset_hack_)
|
||||
split_k_stride_a_ /= k_batch_;
|
||||
|
||||
split_k_stride_b_ = b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize();
|
||||
if(split_k_offset_hack_)
|
||||
split_k_stride_b_ /= k_batch_;
|
||||
|
||||
block_2_ctile_map_ =
|
||||
GridwiseGemm64::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
|
||||
|
||||
@@ -732,6 +867,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
long_index_t c_space_size_bytes;
|
||||
|
||||
bool split_k_offset_hack_;
|
||||
long_index_t split_k_stride_a_, split_k_stride_b_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -878,7 +1016,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.block_2_ctile_map_,
|
||||
arg.compute_ptr_offset_of_batch_);
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.split_k_stride_a_,
|
||||
arg.split_k_stride_b_,
|
||||
arg.split_k_offset_hack_,
|
||||
arg.k_batch_);
|
||||
};
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
|
||||
@@ -58,13 +59,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const index_t num_k_per_block)
|
||||
const index_t num_k_per_block,
|
||||
const long_index_t split_k_stride_a,
|
||||
const long_index_t split_k_stride_b,
|
||||
bool split_k_offset_hack)
|
||||
{
|
||||
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
|
||||
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
|
||||
{
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
|
||||
const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0;
|
||||
const long_index_t split_k_offset_b = split_k_offset_hack ? k_idx * split_k_stride_b : 0;
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
@@ -74,20 +81,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx);
|
||||
|
||||
DispatchSplitKHack<GridwiseGemm,
|
||||
AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
|
||||
karg.p_b_grid + b_batch_offset + split_k_offset_b,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx * num_k_per_block,
|
||||
gridDim.y,
|
||||
split_k_offset_hack);
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
@@ -96,6 +107,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
ignore = num_k_per_block;
|
||||
ignore = split_k_stride_a;
|
||||
ignore = split_k_stride_b;
|
||||
ignore = split_k_offset_hack;
|
||||
|
||||
#endif // end of if (defined(__gfx9__)
|
||||
}
|
||||
|
||||
@@ -119,14 +134,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const index_t num_k_per_block)
|
||||
const index_t num_k_per_block,
|
||||
const long_index_t split_k_stride_a,
|
||||
const long_index_t split_k_stride_b,
|
||||
bool split_k_offset_hack)
|
||||
{
|
||||
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
|
||||
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
|
||||
{
|
||||
// offset base pointer for each work-group
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
|
||||
const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0;
|
||||
const long_index_t split_k_offset_b = split_k_offset_hack ? k_idx * split_k_stride_b : 0;
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
@@ -140,21 +161,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx);
|
||||
DispatchSplitKHack_2Lds<GridwiseGemm,
|
||||
AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
|
||||
karg.p_b_grid + b_batch_offset + split_k_offset_b,
|
||||
karg.p_c_grid + e_batch_offset,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_idx * num_k_per_block,
|
||||
gridDim.y,
|
||||
split_k_offset_hack);
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
@@ -163,6 +187,9 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
ignore = num_k_per_block;
|
||||
ignore = split_k_stride_a;
|
||||
ignore = split_k_stride_b;
|
||||
ignore = split_k_offset_hack;
|
||||
#endif // end of if (defined(__gfx9__)
|
||||
}
|
||||
|
||||
@@ -490,8 +517,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
: p_a_grid_{p_out_grid},
|
||||
p_b_grid_{p_in_grid},
|
||||
p_c_grid_{p_wei_grid},
|
||||
a_grid_desc_kbatch_k0_m_k1_{},
|
||||
b_grid_desc_kbatch_k0_n_k1_{},
|
||||
a_grid_desc_k0_m_k1_{},
|
||||
b_grid_desc_k0_n_k1_{},
|
||||
c_grid_desc_m_n_{},
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
compute_ptr_offset_of_batch_{},
|
||||
@@ -560,6 +587,38 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
|
||||
// Create descriptors first (with hack flags temporarily set to false)
|
||||
// so we can check if element space sizes match product of dimensions
|
||||
const auto descs_initial =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
Conv_N_,
|
||||
Conv_K_,
|
||||
Conv_C_,
|
||||
input_spatial_lengths_,
|
||||
filter_spatial_lengths_,
|
||||
output_spatial_lengths_,
|
||||
b_g_n_c_wis_strides,
|
||||
e_g_k_c_xs_strides,
|
||||
a_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
k_batch_,
|
||||
false, // split_k_offset_b_hack (temporary)
|
||||
true); // use_full_batch_kindex=true for V1-compatible descriptors
|
||||
|
||||
split_k_offset_hack_ =
|
||||
SplitKHackEligibility<NDimSpatial, InLayout, WeiLayout, OutLayout>::Check(
|
||||
descs_initial[I0],
|
||||
descs_initial[I1],
|
||||
k_batch_,
|
||||
Conv_N_,
|
||||
output_spatial_lengths_,
|
||||
K0PerBlock);
|
||||
|
||||
// Now create descriptors with the correct hack flag
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
@@ -576,11 +635,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
k_batch_);
|
||||
k_batch_,
|
||||
split_k_offset_hack_,
|
||||
true); // use_full_batch_kindex=true for V1-compatible descriptors
|
||||
|
||||
a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
|
||||
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
|
||||
c_grid_desc_m_n_ = descs[I2];
|
||||
a_grid_desc_k0_m_k1_ = descs[I0];
|
||||
b_grid_desc_k0_n_k1_ = descs[I1];
|
||||
c_grid_desc_m_n_ = descs[I2];
|
||||
|
||||
// Calculate stride using CalculateOffset method for accurate stride
|
||||
// This works correctly for any descriptor transform pipeline
|
||||
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
|
||||
if(split_k_offset_hack_)
|
||||
split_k_stride_a_ /= k_batch_;
|
||||
|
||||
split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize();
|
||||
if(split_k_offset_hack_)
|
||||
split_k_stride_b_ /= k_batch_;
|
||||
|
||||
// A/B/C Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
|
||||
@@ -591,8 +662,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
end(filter_spatial_lengths_),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
|
||||
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
|
||||
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
@@ -604,8 +675,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
|
||||
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
@@ -631,6 +702,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
long_index_t c_space_size_bytes;
|
||||
|
||||
bool split_k_offset_hack_;
|
||||
long_index_t split_k_stride_a_, split_k_stride_b_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -640,17 +714,15 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
|
||||
void ShowInfo(const Argument& arg)
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I3) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
|
||||
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
|
||||
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_.GetLength(I3) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
@@ -659,10 +731,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
template <typename GridwiseGemm>
|
||||
float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
|
||||
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2);
|
||||
const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
|
||||
const index_t GemmK =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
const ADataType* p_a_grid = arg.p_a_grid_;
|
||||
const BDataType* p_b_grid = arg.p_b_grid_;
|
||||
@@ -680,7 +752,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto num_k_per_block =
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
|
||||
|
||||
const auto clear_workspace = [&]() {
|
||||
if(arg.k_batch_ > 1)
|
||||
@@ -716,11 +788,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gemm_arg_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
num_k_per_block);
|
||||
num_k_per_block,
|
||||
arg.split_k_stride_a_,
|
||||
arg.split_k_stride_b_,
|
||||
arg.split_k_offset_hack_);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -732,11 +807,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gemm_arg,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
num_k_per_block);
|
||||
num_k_per_block,
|
||||
arg.split_k_stride_a_,
|
||||
arg.split_k_stride_b_,
|
||||
arg.split_k_offset_hack_);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -749,7 +827,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(gemm_arg.KBatch > 1)
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
@@ -781,7 +859,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
// Tail number could be One to Seven
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
|
||||
{
|
||||
if(gemm_arg.KBatch > 1)
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
|
||||
{
|
||||
@@ -1090,7 +1168,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
// Tail number could be Odd or Even
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
if(gemm_arg.KBatch > 1)
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
@@ -1159,7 +1237,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
}
|
||||
else
|
||||
{
|
||||
if(gemm_arg.KBatch > 1)
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
@@ -1232,7 +1310,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(gemm_arg.KBatch > 1)
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
@@ -1289,10 +1367,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
}
|
||||
#endif
|
||||
|
||||
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
|
||||
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2);
|
||||
const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
|
||||
const index_t GemmK =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
if constexpr(is_same_v<ComputeTypeA, ck::tf32_t> || is_same_v<ComputeTypeB, ck::tf32_t>)
|
||||
{
|
||||
@@ -1423,9 +1501,17 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
}
|
||||
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
if(!(arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
|
||||
arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) <= TwoGB))
|
||||
const bool a_small_enough = arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() /
|
||||
(arg.split_k_offset_hack_ ? arg.k_batch_ : 1) *
|
||||
sizeof(ADataType) <=
|
||||
TwoGB;
|
||||
const bool b_small_enough = arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() /
|
||||
(arg.split_k_offset_hack_ ? arg.k_batch_ : 1) *
|
||||
sizeof(BDataType) <=
|
||||
TwoGB;
|
||||
const bool c_small_enough =
|
||||
arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) <= TwoGB;
|
||||
if(!(a_small_enough && b_small_enough && c_small_enough))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,222 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <numeric>
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Check if a tensor descriptor has compact layout
|
||||
// Compact means: GetElementSpaceSize() == product of all dimension lengths
|
||||
// Non-compact descriptors have complex transform pipelines that may not support split-k hack
|
||||
template <typename Descriptor>
|
||||
bool IsDescriptorCompact(const Descriptor& desc)
|
||||
{
|
||||
// Calculate product of all dimensions
|
||||
long_index_t dims_product = 1;
|
||||
constexpr index_t num_dims = Descriptor::GetNumOfDimension();
|
||||
|
||||
// Use template recursion to multiply all dimension lengths
|
||||
static_for<0, num_dims, 1>{}(
|
||||
[&](auto i) { dims_product *= static_cast<long_index_t>(desc.GetLength(i)); });
|
||||
|
||||
return desc.GetElementSpaceSize() == dims_product;
|
||||
}
|
||||
|
||||
// Determine split-k hack eligibility for descriptor pair
|
||||
// This checks all the conditions required for safely using the split-k offset hack
|
||||
template <index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
struct SplitKHackEligibility
|
||||
{
|
||||
template <typename ADescriptor, typename BDescriptor>
|
||||
static bool
|
||||
Check(const ADescriptor& a_desc,
|
||||
const BDescriptor& b_desc,
|
||||
index_t k_batch,
|
||||
index_t Conv_N,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
index_t k_block_size) // K0PerBlock*K1 for v1, K0PerBlock for v3, KPerBlock for two-stage
|
||||
{
|
||||
// Only enable hack if k_batch > 1
|
||||
if(k_batch <= 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Calculate output spatial product
|
||||
const index_t output_spatial_acum = std::accumulate(output_spatial_lengths.begin(),
|
||||
output_spatial_lengths.end(),
|
||||
index_t{1},
|
||||
std::multiplies<index_t>());
|
||||
|
||||
// Check various divisibility and layout requirements
|
||||
const bool is_k_not_paded = (Conv_N * output_spatial_acum) % (k_block_size * k_batch) == 0;
|
||||
|
||||
const bool can_divide_n_spatial_by_k_batch = (Conv_N * output_spatial_acum) % k_batch == 0;
|
||||
|
||||
const bool can_divide_n_by_k_batch = Conv_N % k_batch == 0;
|
||||
|
||||
const bool is_correct_layout =
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
|
||||
|
||||
const bool is_a_stride_divisible = a_desc.GetElementSpaceSize() % k_batch == 0;
|
||||
|
||||
const bool is_b_stride_divisible = b_desc.GetElementSpaceSize() % k_batch == 0;
|
||||
|
||||
// Check descriptor compactness
|
||||
const bool is_a_compact = IsDescriptorCompact(a_desc);
|
||||
const bool is_b_compact = IsDescriptorCompact(b_desc);
|
||||
|
||||
// Require BOTH A and B to be eligible for the hack to avoid KBatch dimension mismatch
|
||||
// The gridwise kernel's CheckValidity requires A.KBatch == B.KBatch, so we must
|
||||
// apply the hack uniformly to both tensors to maintain kernel applicability
|
||||
const bool eligible = can_divide_n_spatial_by_k_batch && can_divide_n_by_k_batch &&
|
||||
is_k_not_paded && is_correct_layout && is_a_stride_divisible &&
|
||||
is_b_stride_divisible && is_a_compact && is_b_compact;
|
||||
|
||||
return eligible;
|
||||
}
|
||||
};
|
||||
|
||||
// Helper function to dispatch split-K hack for standard kernel (single LDS)
|
||||
// Reduces code duplication in device layer implementations
|
||||
template <typename GridwiseGemm,
|
||||
typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType>
|
||||
__device__ void DispatchSplitKHack(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
void* p_shared,
|
||||
const typename GridwiseGemm::Argument& karg,
|
||||
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
index_t k_id,
|
||||
index_t k_batch,
|
||||
bool split_k_offset_hack)
|
||||
{
|
||||
if(split_k_offset_hack)
|
||||
{
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
true>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id,
|
||||
k_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
false>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id,
|
||||
k_batch);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to dispatch split-K hack for 2lds kernel
|
||||
// Reduces code duplication in device layer implementations
|
||||
template <typename GridwiseGemm,
|
||||
typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType>
|
||||
__device__ void DispatchSplitKHack_2Lds(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
void* p_shared_0,
|
||||
void* p_shared_1,
|
||||
const typename GridwiseGemm::Argument& karg,
|
||||
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
index_t k_id,
|
||||
index_t k_batch,
|
||||
bool split_k_offset_hack)
|
||||
{
|
||||
if(split_k_offset_hack)
|
||||
{
|
||||
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
true>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id,
|
||||
k_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
|
||||
BGridDesc_BK0_N_K1,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
false>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id,
|
||||
k_batch);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user