[CK] Allow tensors larger than 2GB in grouped conv bwd weight (#3169)

* Take split_k into account when checking 2GB tensor limit.

* Revert "Take split_k into account when checking 2GB tensor limit."

This reverts commit adf35c91be.

* Optimize grouped conv bwd wei split_k off calc

(cherry picked from commit 2115642ee59050dabd81393c1b8f03b34adc05aa)

* Update gridwise_gemm_xdl_cshuffle_conv_v3.hpp

(cherry picked from commit 900d4d4b466f5730ae1189370d3c96267c35ea69)

* Fix tensor descriptors and stride calculations

* Don't miss half of the elements

* Fix buffer size calculations

* Disable hack if stride not divisible by k_batch

* Clean up comments

* Disallow hack in non-contiguous edge cases

* Index -> Dim

* Fix broken test

* Refactor applicability checks into separate function

* fix missed variable name

* Fix variable name in info print

* update V3 2GB check

* No more regression, use templates instead

* Code deduplication

* Regression fix for cshuffle

* arch-guarded atomic_add implementations for gfx11

* Similar for half(4|8)_t as well

* Only use both offset hacks at the same time

* Revert "arch-guarded atomic_add implementations for gfx11"

This reverts commit 3883fe6935.
This reverts commit 5311ec608d.

* Reapply "arch-guarded atomic_add implementations for gfx11"

This reverts commit 1972adeddc.

* Only remove float4 atomic_add

* Refactor to single flag

* Consolidate template parameters

* Consolidate flag in transformers

---------

Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>

[ROCm/composable_kernel commit: ee2c35b92d]
This commit is contained in:
Johannes Graner
2026-01-08 08:02:02 +01:00
committed by GitHub
parent 5e0d3e77b9
commit 9d6add54e5
9 changed files with 1286 additions and 202 deletions

View File

@@ -24,6 +24,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
@@ -60,13 +61,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
[[maybe_unused]] const index_t num_k_per_block)
[[maybe_unused]] const index_t num_k_per_block,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_hack)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0;
const long_index_t split_k_offset_b = split_k_offset_hack ? k_idx * split_k_stride_b : 0;
const long_index_t a_batch_offset = amd_wave_read_first_lane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
@@ -77,23 +84,29 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx);
DispatchSplitKHack<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_hack);
}
#else
ignore = karg;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
ignore = split_k_offset_hack;
#endif // end of if (defined(__gfx9__))
}
@@ -118,14 +131,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
[[maybe_unused]] const index_t num_k_per_block)
[[maybe_unused]] const index_t num_k_per_block,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_hack)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
// offset base pointer for each work-group
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0;
const long_index_t split_k_offset_b = split_k_offset_hack ? k_idx * split_k_stride_b : 0;
const long_index_t a_batch_offset = amd_wave_read_first_lane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
@@ -139,24 +158,30 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx);
DispatchSplitKHack_2Lds<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_hack);
}
#else
ignore = karg;
ignore = split_k_offset_hack;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
#endif // end of if (defined(__gfx9__))
}
@@ -693,7 +718,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
k_batch_ = split_k;
}
const auto descs =
// Create initial descriptors with hack=false to check compactness
const auto descs_initial =
conv_to_gemm_transformer_v2
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
@@ -709,11 +735,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_);
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
ce_grid_desc_m_n_ = descs[I2];
k_batch_,
false, // hack=false for initial check
true); // use_full_batch_kindex
ce_elementwise_grid_desc_m_n_ =
conv_to_gemm_transformer_v1
@@ -733,6 +757,67 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
input_right_pads,
k_batch_)[I2];
split_k_offset_hack_ =
SplitKHackEligibility<NDimSpatial, InLayout, WeiLayout, OutLayout>::Check(
descs_initial[I0],
descs_initial[I1],
k_batch_,
Conv_N_,
output_spatial_lengths_,
KPerBlock);
// Create final descriptors with correct hack flag
const auto descs =
conv_to_gemm_transformer_v2
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides_transposed,
e_g_k_c_xs_strides_transposed,
a_g_n_k_wos_strides_transposed,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_,
split_k_offset_hack_, // Use determined hack flag
true); // use_full_batch_kindex
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
ce_grid_desc_m_n_ = descs[I2];
// Step 5: Calculate stride using CalculateOffset on FINAL descriptors
if(split_k_offset_hack_)
{
const index_t k0_per_batch = a_grid_desc_k0_m_k1_.GetLength(I0) / k_batch_;
const auto idx_start = make_multi_index(0, 0, 0);
const auto idx_next = make_multi_index(k0_per_batch, 0, 0);
split_k_stride_a_ = a_grid_desc_k0_m_k1_.CalculateOffset(idx_next) -
a_grid_desc_k0_m_k1_.CalculateOffset(idx_start);
}
else
{
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
}
if(split_k_offset_hack_)
{
const index_t k0_per_batch = b_grid_desc_k0_n_k1_.GetLength(I0) / k_batch_;
const auto idx_start = make_multi_index(0, 0, 0);
const auto idx_next = make_multi_index(k0_per_batch, 0, 0);
split_k_stride_b_ = b_grid_desc_k0_n_k1_.CalculateOffset(idx_next) -
b_grid_desc_k0_n_k1_.CalculateOffset(idx_start);
}
else
{
split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize();
}
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
@@ -869,6 +954,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
long_index_t c_space_size_bytes;
bool split_k_offset_hack_;
long_index_t split_k_stride_a_, split_k_stride_b_;
};
// Invoker
@@ -971,7 +1059,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block);
num_k_per_block,
arg.split_k_stride_a_,
arg.split_k_stride_b_,
arg.split_k_offset_hack_);
}
else
{
@@ -987,7 +1078,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block);
num_k_per_block,
arg.split_k_stride_a_,
arg.split_k_stride_b_,
arg.split_k_offset_hack_);
}
};
@@ -1920,14 +2014,6 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
}
}
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
arg.ce_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}
return true;
}

View File

@@ -21,6 +21,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -33,6 +34,74 @@ namespace ck {
namespace tensor_operation {
namespace device {
// Dispatch helper function for split-K hack - handles 2-way dispatch based on runtime flag
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename AGridDesc_B_K0_M_K1,
typename BGridDesc_B_K0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__device__ void DispatchBatchedGemmSplitKHack(const FloatA* p_a_grid,
const FloatB* p_b_grid,
FloatC* p_c_grid,
void* p_shared,
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_hack,
index_t k_batch)
{
if(split_k_offset_hack)
{
GridwiseGemm::template Run<HasMainKBlockLoop, true>(
p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map,
split_k_stride_a,
split_k_stride_b,
k_batch);
}
else
{
GridwiseGemm::template Run<HasMainKBlockLoop, false>(
p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map,
split_k_stride_a,
split_k_stride_b,
k_batch);
}
}
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
@@ -62,7 +131,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_hack,
index_t k_batch)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
defined(__gfx12__)
@@ -79,17 +152,33 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
__shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
DispatchBatchedGemmSplitKHack<GridwiseGemm,
FloatA,
FloatB,
FloatC,
AGridDesc_B_K0_M_K1,
BGridDesc_B_K0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
Block2CTileMap,
HasMainKBlockLoop>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map,
split_k_stride_a,
split_k_stride_b,
split_k_offset_hack,
k_batch);
}
#else
ignore = p_a_grid;
@@ -104,6 +193,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
ignore = batch_count;
ignore = block_2_ctile_map;
ignore = compute_ptr_offset_of_batch;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
ignore = split_k_offset_hack;
ignore = k_batch;
compute_ptr_offset_of_batch.GetAPtrOffset(0);
compute_ptr_offset_of_batch.GetBPtrOffset(0);
@@ -459,7 +552,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<DeviceOp::Block2CTileMap>,
ComputePtrOffsetOfStridedBatch<>,
false>, // Both true/false give the same occupancy.
false>, // HasMainKBlockLoop - both true/false give the same occupancy
BlockSize,
dynamic_smem_size));
return std::max(1, max_occupancy);
@@ -576,6 +669,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
k_batch_ = split_k;
}
// Create descriptors first (with hack flags temporarily set to false)
// so we can check if element space sizes are divisible by k_batch
const auto descs_initial =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides_transposed,
e_g_k_c_xs_strides_transposed,
a_g_n_k_wos_strides_transposed,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_,
false); // split_k_offset_b_hack (temporary)
split_k_offset_hack_ =
SplitKHackEligibility<NDimSpatial, InLayout, WeiLayout, OutLayout>::Check(
descs_initial[I0],
descs_initial[I1],
k_batch_,
Conv_N_,
output_spatial_lengths_,
K0PerBlock * K1);
// Now create descriptors with the correct hack flag
const auto descs =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -592,12 +716,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_);
k_batch_,
split_k_offset_hack_);
a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
// Calculate stride using CalculateOffset method for accurate stride
// This works correctly for any descriptor transform pipeline
split_k_stride_a_ = a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize();
if(split_k_offset_hack_)
split_k_stride_a_ /= k_batch_;
split_k_stride_b_ = b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize();
if(split_k_offset_hack_)
split_k_stride_b_ /= k_batch_;
block_2_ctile_map_ =
GridwiseGemm64::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
@@ -732,6 +867,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
long_index_t c_space_size_bytes;
bool split_k_offset_hack_;
long_index_t split_k_stride_a_, split_k_stride_b_;
};
// Invoker
@@ -878,7 +1016,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
arg.b_grid_desc_kbatch_k0_n_k1_,
c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.block_2_ctile_map_,
arg.compute_ptr_offset_of_batch_);
arg.compute_ptr_offset_of_batch_,
arg.split_k_stride_a_,
arg.split_k_stride_b_,
arg.split_k_offset_hack_,
arg.k_batch_);
};
if(has_main_k0_block_loop)

View File

@@ -22,6 +22,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
@@ -58,13 +59,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const index_t num_k_per_block)
const index_t num_k_per_block,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_hack)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0;
const long_index_t split_k_offset_b = split_k_offset_hack ? k_idx * split_k_stride_b : 0;
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
@@ -74,20 +81,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx);
DispatchSplitKHack<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_hack);
}
#else
ignore = karg;
@@ -96,6 +107,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = compute_ptr_offset_of_batch;
ignore = num_k_per_block;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
ignore = split_k_offset_hack;
#endif // end of if (defined(__gfx9__)
}
@@ -119,14 +134,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const index_t num_k_per_block)
const index_t num_k_per_block,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_hack)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
// offset base pointer for each work-group
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0;
const long_index_t split_k_offset_b = split_k_offset_hack ? k_idx * split_k_stride_b : 0;
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
@@ -140,21 +161,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx);
DispatchSplitKHack_2Lds<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_hack);
}
#else
ignore = karg;
@@ -163,6 +187,9 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = compute_ptr_offset_of_batch;
ignore = num_k_per_block;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
ignore = split_k_offset_hack;
#endif // end of if (defined(__gfx9__)
}
@@ -490,8 +517,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
: p_a_grid_{p_out_grid},
p_b_grid_{p_in_grid},
p_c_grid_{p_wei_grid},
a_grid_desc_kbatch_k0_m_k1_{},
b_grid_desc_kbatch_k0_n_k1_{},
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
compute_ptr_offset_of_batch_{},
@@ -560,6 +587,38 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
k_batch_ = split_k;
}
// Create descriptors first (with hack flags temporarily set to false)
// so we can check if element space sizes match product of dimensions
const auto descs_initial =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides,
e_g_k_c_xs_strides,
a_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_,
false, // split_k_offset_b_hack (temporary)
true); // use_full_batch_kindex=true for V1-compatible descriptors
split_k_offset_hack_ =
SplitKHackEligibility<NDimSpatial, InLayout, WeiLayout, OutLayout>::Check(
descs_initial[I0],
descs_initial[I1],
k_batch_,
Conv_N_,
output_spatial_lengths_,
K0PerBlock);
// Now create descriptors with the correct hack flag
const auto descs =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -576,11 +635,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_);
k_batch_,
split_k_offset_hack_,
true); // use_full_batch_kindex=true for V1-compatible descriptors
a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
// Calculate stride using CalculateOffset method for accurate stride
// This works correctly for any descriptor transform pipeline
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
if(split_k_offset_hack_)
split_k_stride_a_ /= k_batch_;
split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize();
if(split_k_offset_hack_)
split_k_stride_b_ /= k_batch_;
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
@@ -591,8 +662,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
end(filter_spatial_lengths_),
index_t{1},
std::multiplies<>{});
const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -604,8 +675,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
@@ -631,6 +702,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
long_index_t c_space_size_bytes;
bool split_k_offset_hack_;
long_index_t split_k_stride_a_, split_k_stride_b_;
};
// Invoker
@@ -640,17 +714,15 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
void ShowInfo(const Argument& arg)
{
std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
@@ -659,10 +731,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
template <typename GridwiseGemm>
float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2);
const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
const index_t GemmK =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
const ADataType* p_a_grid = arg.p_a_grid_;
const BDataType* p_b_grid = arg.p_b_grid_;
@@ -680,7 +752,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto num_k_per_block =
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
const auto clear_workspace = [&]() {
if(arg.k_batch_ > 1)
@@ -716,11 +788,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
dim3(BlockSize),
0,
gemm_arg_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block);
num_k_per_block,
arg.split_k_stride_a_,
arg.split_k_stride_b_,
arg.split_k_offset_hack_);
}
else
{
@@ -732,11 +807,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
dim3(BlockSize),
0,
gemm_arg,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block);
num_k_per_block,
arg.split_k_stride_a_,
arg.split_k_stride_b_,
arg.split_k_offset_hack_);
}
};
@@ -749,7 +827,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(gemm_arg.KBatch > 1)
if(arg.k_batch_ > 1)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
@@ -781,7 +859,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(gemm_arg.KBatch > 1)
if(arg.k_batch_ > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
@@ -1090,7 +1168,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(gemm_arg.KBatch > 1)
if(arg.k_batch_ > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
@@ -1159,7 +1237,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
}
else
{
if(gemm_arg.KBatch > 1)
if(arg.k_batch_ > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
@@ -1232,7 +1310,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(gemm_arg.KBatch > 1)
if(arg.k_batch_ > 1)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
@@ -1289,10 +1367,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
}
#endif
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2);
const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
const index_t GemmK =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
if constexpr(is_same_v<ComputeTypeA, ck::tf32_t> || is_same_v<ComputeTypeB, ck::tf32_t>)
{
@@ -1423,9 +1501,17 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
}
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) <= TwoGB))
const bool a_small_enough = arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() /
(arg.split_k_offset_hack_ ? arg.k_batch_ : 1) *
sizeof(ADataType) <=
TwoGB;
const bool b_small_enough = arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() /
(arg.split_k_offset_hack_ ? arg.k_batch_ : 1) *
sizeof(BDataType) <=
TwoGB;
const bool c_small_enough =
arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) <= TwoGB;
if(!(a_small_enough && b_small_enough && c_small_enough))
{
return false;
}

View File

@@ -0,0 +1,222 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <numeric>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Check if a tensor descriptor has compact layout
// Compact means: GetElementSpaceSize() == product of all dimension lengths
// Non-compact descriptors have complex transform pipelines that may not support split-k hack
template <typename Descriptor>
bool IsDescriptorCompact(const Descriptor& desc)
{
// Calculate product of all dimensions
long_index_t dims_product = 1;
constexpr index_t num_dims = Descriptor::GetNumOfDimension();
// Use template recursion to multiply all dimension lengths
static_for<0, num_dims, 1>{}(
[&](auto i) { dims_product *= static_cast<long_index_t>(desc.GetLength(i)); });
return desc.GetElementSpaceSize() == dims_product;
}
// Determine split-k hack eligibility for descriptor pair
// This checks all the conditions required for safely using the split-k offset hack
template <index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
struct SplitKHackEligibility
{
template <typename ADescriptor, typename BDescriptor>
static bool
Check(const ADescriptor& a_desc,
const BDescriptor& b_desc,
index_t k_batch,
index_t Conv_N,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
index_t k_block_size) // K0PerBlock*K1 for v1, K0PerBlock for v3, KPerBlock for two-stage
{
// Only enable hack if k_batch > 1
if(k_batch <= 1)
{
return false;
}
// Calculate output spatial product
const index_t output_spatial_acum = std::accumulate(output_spatial_lengths.begin(),
output_spatial_lengths.end(),
index_t{1},
std::multiplies<index_t>());
// Check various divisibility and layout requirements
const bool is_k_not_paded = (Conv_N * output_spatial_acum) % (k_block_size * k_batch) == 0;
const bool can_divide_n_spatial_by_k_batch = (Conv_N * output_spatial_acum) % k_batch == 0;
const bool can_divide_n_by_k_batch = Conv_N % k_batch == 0;
const bool is_correct_layout =
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
const bool is_a_stride_divisible = a_desc.GetElementSpaceSize() % k_batch == 0;
const bool is_b_stride_divisible = b_desc.GetElementSpaceSize() % k_batch == 0;
// Check descriptor compactness
const bool is_a_compact = IsDescriptorCompact(a_desc);
const bool is_b_compact = IsDescriptorCompact(b_desc);
// Require BOTH A and B to be eligible for the hack to avoid KBatch dimension mismatch
// The gridwise kernel's CheckValidity requires A.KBatch == B.KBatch, so we must
// apply the hack uniformly to both tensors to maintain kernel applicability
const bool eligible = can_divide_n_spatial_by_k_batch && can_divide_n_by_k_batch &&
is_k_not_paded && is_correct_layout && is_a_stride_divisible &&
is_b_stride_divisible && is_a_compact && is_b_compact;
return eligible;
}
};
// Helper function to dispatch split-K hack for standard kernel (single LDS)
// Reduces code duplication in device layer implementations
template <typename GridwiseGemm,
typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum,
typename ADataType,
typename BDataType,
typename CDataType>
__device__ void DispatchSplitKHack(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
void* p_shared,
const typename GridwiseGemm::Argument& karg,
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
index_t k_id,
index_t k_batch,
bool split_k_offset_hack)
{
if(split_k_offset_hack)
{
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
true>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
else
{
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
false>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
}
// Helper function to dispatch split-K hack for 2lds kernel
// Reduces code duplication in device layer implementations
template <typename GridwiseGemm,
typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum,
typename ADataType,
typename BDataType,
typename CDataType>
__device__ void DispatchSplitKHack_2Lds(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
void* p_shared_0,
void* p_shared_1,
const typename GridwiseGemm::Argument& karg,
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
index_t k_id,
index_t k_batch,
bool split_k_offset_hack)
{
if(split_k_offset_hack)
{
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
true>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
else
{
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
false>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
}
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -663,7 +663,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
TailNumber TailNum = TailNumber::Odd,
bool SplitKOffsetHack = false>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
@@ -673,12 +674,16 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const index_t k_id = 0)
const index_t k_id = 0,
const index_t k_batch = 1)
{
const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() / a_space_size_divisor);
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize() / b_space_size_divisor);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
@@ -744,7 +749,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(k_id, m_block_data_idx_on_grid, 0),
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
@@ -775,7 +780,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(k_id, n_block_data_idx_on_grid, 0),
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
@@ -1024,7 +1029,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
TailNumber TailNum = TailNumber::Odd,
bool SplitKOffsetHack = false>
__device__ static void Run_2Lds(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
@@ -1035,12 +1041,16 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const index_t k_id = 0)
const index_t k_id = 0,
const index_t k_batch = 1)
{
const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() / a_space_size_divisor);
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize() / b_space_size_divisor);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
@@ -1106,7 +1116,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(k_id, m_block_data_idx_on_grid, 0),
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
@@ -1137,7 +1147,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(k_id, n_block_data_idx_on_grid, 0),
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),

View File

@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
@@ -149,7 +150,8 @@ template <typename GridwiseGemm,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename CBlockClusterAdaptor,
bool HasMainKBlockLoop>
bool HasMainKBlockLoop,
bool SplitKOffsetHack>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -164,7 +166,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const CBlockClusterAdaptor c_block_cluster_adaptor)
const CBlockClusterAdaptor c_block_cluster_adaptor,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
index_t k_batch)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
defined(__gfx12__)
@@ -172,17 +177,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
c_block_cluster_adaptor);
GridwiseGemm::template Run<HasMainKBlockLoop, SplitKOffsetHack>(
p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
c_block_cluster_adaptor,
split_k_stride_a,
split_k_stride_b,
k_batch);
}
#else
ignore = p_a_grid;
@@ -195,6 +204,9 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
ignore = b_element_op;
ignore = c_element_op;
ignore = c_block_cluster_adaptor;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
ignore = k_batch;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
@@ -536,7 +548,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
MRepeat,
NRepeat,
FloatC,
CGlobalMemoryDataOperation>();
CGlobalMemoryDataOperation_>();
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
@@ -646,6 +658,416 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
template <bool HasMainKBlockLoop, bool SplitKOffsetHack = false>
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const CBlockClusterAdaptor& c_block_cluster_adaptor,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
index_t k_batch)
{
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
// divide block work by [M, N]
const auto block_work_idx =
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t k_batch_id = block_work_idx[I0];
// Use compile-time branching based on template parameters
const long_index_t split_k_offset_a = SplitKOffsetHack ? k_batch_id * split_k_stride_a : 0;
const long_index_t split_k_offset_b = SplitKOffsetHack ? k_batch_id * split_k_stride_b : 0;
// When hack is enabled, buffer size equals the stride (calculated from descriptor's
// CalculateOffset method in the device layer). This properly accounts for the
// descriptor's transform pipeline and non-compact strides.
// When hack is disabled, use the full element space size.
const long_index_t a_buffer_size =
SplitKOffsetHack ? split_k_stride_a : a_b_k0_m_k1_grid_desc.GetElementSpaceSize();
const long_index_t b_buffer_size =
SplitKOffsetHack ? split_k_stride_b : b_b_k0_n_k1_grid_desc.GetElementSpaceSize();
ignore = k_batch; // k_batch value itself not used in this function
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid + split_k_offset_a, a_buffer_size);
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid + split_k_offset_b, b_buffer_size);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
if(!c_block_cluster_adaptor.ValidCTileIndex(
make_tuple(block_work_idx[I1], block_work_idx[I2]),
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_k0_m_k1_block_desc = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_k0_n_k1_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<1, K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatA,
FloatAAdjusted,
decltype(a_b_k0_m_k1_grid_desc),
decltype(a_b_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 2, 1, 3>,
ABlockTransferSrcVectorDim,
3,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_b_k0_m_k1_grid_desc,
make_multi_index(SplitKOffsetHack ? 0 : k_batch_id, 0, m_block_data_idx_on_grid, 0),
a_element_op,
a_b_k0_m_k1_block_desc,
make_multi_index(0, 0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<1, K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatB,
FloatBAdjusted,
decltype(b_b_k0_n_k1_grid_desc),
decltype(b_b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 2, 1, 3>,
BBlockTransferSrcVectorDim,
3,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_b_k0_n_k1_grid_desc,
make_multi_index(SplitKOffsetHack ? 0 : k_batch_id, 0, n_block_data_idx_on_grid, 0),
b_element_op,
b_b_k0_n_k1_block_desc,
make_multi_index(0, 0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
K1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && K1 <= 8) ||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
K1 < 32))
? true
: false;
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(K1,
MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
ComputeTypeB,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAAdjusted,
FloatBAdjusted,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
MPerXdl,
NPerXdl,
MRepeat,
NRepeat,
KPack,
ComputeTypeA,
ComputeTypeB>{};
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatBAdjusted*>(p_shared) + a_block_space_size,
b_k0_n_k1_block_desc.GetElementSpaceSize());
// gridwise GEMM pipeline
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
a_b_k0_m_k1_block_desc,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_b_k0_n_k1_grid_desc,
b_b_k0_n_k1_block_desc,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
K0BlockMainLoop);
// output: register to global memory
{
constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatC*>(p_shared),
c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
static_assert(M1 == MWave, "");
static_assert(N1 == NWave, "");
static_assert(M2 * M3 * M4 == MPerXdl, "");
static_assert(N2 == NPerXdl, "");
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0), // freeze mblock
make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle,
M1,
M2,
M3,
M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
make_freeze_transform(I0), // freeze nblock
make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle,
N1,
N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatC,
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// LDS to global
auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // index_t BlockSize,
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMRepeatPerShuffle * MWave * MPerXdl,
1,
CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatC, // typename SrcData,
FloatC, // typename DstData,
decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun
{c_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0),
c_element_op};
constexpr auto mxdlperwave_forward_step =
make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0);
constexpr auto nxdlperwave_forward_step =
make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl);
constexpr auto nxdlperwave_backward_step =
make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl);
static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
constexpr auto mxdlperwave = mxdlperwave_iter;
static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
constexpr bool nxdlperwave_forward_sweep =
(mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);
constexpr index_t nxdlperwave_value =
nxdlperwave_forward_sweep
? nxdlperwave_iter
: (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);
constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
// make sure it's safe to do ds_write
block_sync_lds();
// VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_block_buf);
// make sure it's safe to do ds_read
block_sync_lds();
// LDS to global
c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock,
c_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
// move on nxdlperwave dimension
if constexpr(nxdlperwave_forward_sweep &&
(nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
{
c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
nxdlperwave_forward_step);
}
else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
{
c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
nxdlperwave_backward_step);
}
});
// move on mxdlperwave dimension
if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
{
c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step);
}
});
}
}
template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,

View File

@@ -149,7 +149,8 @@ struct TransformConvBwdWeightToGemm
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k)
const index_t batch_k,
const bool split_k_offset_hack = false) // Deprecated parameter for backward compatibility
{
using namespace ck;
@@ -172,7 +173,8 @@ struct TransformConvBwdWeightToGemm
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchDim = split_k_offset_hack ? 1 : GemmKBatch;
const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number;
if constexpr(ConvBackwardWeightSpecialization ==
device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
@@ -190,7 +192,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -208,7 +210,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -246,7 +248,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -285,7 +287,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -323,7 +325,8 @@ struct TransformConvBwdWeightToGemm
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k)
const index_t batch_k,
const bool split_k_offset_hack = false)
{
using namespace ck;
@@ -359,7 +362,8 @@ struct TransformConvBwdWeightToGemm
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchDim = split_k_offset_hack ? 1 : GemmKBatch;
const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
@@ -378,7 +382,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -393,7 +397,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -422,7 +426,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -463,7 +467,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -497,7 +501,8 @@ struct TransformConvBwdWeightToGemm
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k)
const index_t batch_k,
const bool split_k_offset_hack = false)
{
using namespace ck;
@@ -540,7 +545,8 @@ struct TransformConvBwdWeightToGemm
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchDim = split_k_offset_hack ? 1 : GemmKBatch;
const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
@@ -559,7 +565,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -574,7 +580,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -603,7 +609,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -653,7 +659,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

View File

@@ -324,7 +324,9 @@ struct TransformConvBwdWeightToGemmV2
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k)
const index_t batch_k,
const bool split_k_offset_hack = false,
const bool use_full_batch_kindex = false)
{
using namespace ck;
@@ -353,7 +355,10 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
// When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise
// kernel compatibility
const index_t KBatchDim = (split_k_offset_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Wi, C, input_strides);
@@ -373,7 +378,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -389,7 +394,7 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -419,7 +424,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -460,7 +465,7 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -495,7 +500,9 @@ struct TransformConvBwdWeightToGemmV2
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k)
const index_t batch_k,
const bool split_k_offset_hack = false,
const bool use_full_batch_kindex = false)
{
using namespace ck;
@@ -531,7 +538,10 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
// When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise
// kernel compatibility
const index_t KBatchDim = (split_k_offset_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
@@ -551,7 +561,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -567,7 +577,7 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -597,7 +607,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -647,7 +657,7 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -681,7 +691,9 @@ struct TransformConvBwdWeightToGemmV2
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k)
const index_t batch_k,
const bool split_k_offset_hack = false,
const bool use_full_batch_kindex = false)
{
using namespace ck;
@@ -724,7 +736,10 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
// When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise
// kernel compatibility
const index_t KBatchDim = (split_k_offset_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
@@ -744,7 +759,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -760,7 +775,7 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -790,7 +805,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -855,7 +870,7 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

View File

@@ -111,6 +111,101 @@ __device__ double2_t atomic_add<double2_t>(double2_t* p_dst, const double2_t& x)
return vy.template AsType<double2_t>()[I0];
}
#if defined(__gfx11__)
template <>
__device__ float8_t atomic_add<float8_t>(float8_t* p_dst, const float8_t& x)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
const vector_type<float, 8> vx{x};
vector_type<float, 8> vy{0};
vy.template AsType<float>()(I0) =
atomicAdd(c_style_pointer_cast<float*>(p_dst), vx.template AsType<float>()[I0]);
vy.template AsType<float>()(I1) =
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, vx.template AsType<float>()[I1]);
vy.template AsType<float>()(I2) =
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 2, vx.template AsType<float>()[I2]);
vy.template AsType<float>()(I3) =
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 3, vx.template AsType<float>()[I3]);
vy.template AsType<float>()(I4) =
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 4, vx.template AsType<float>()[I4]);
vy.template AsType<float>()(I5) =
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 5, vx.template AsType<float>()[I5]);
vy.template AsType<float>()(I6) =
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 6, vx.template AsType<float>()[I6]);
vy.template AsType<float>()(I7) =
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 7, vx.template AsType<float>()[I7]);
return vy.template AsType<float8_t>()[I0];
}
template <>
__device__ half4_t atomic_add<half4_t>(half4_t* p_dst, const half4_t& x)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const vector_type<half_t, 4> vx{x};
vector_type<half_t, 4> vy{0};
vy.template AsType<half_t>()(I0) =
atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst), vx.template AsType<half_t>()[I0]);
vy.template AsType<half_t>()(I1) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 1,
vx.template AsType<half_t>()[I1]);
vy.template AsType<half_t>()(I2) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 2,
vx.template AsType<half_t>()[I2]);
vy.template AsType<half_t>()(I3) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 3,
vx.template AsType<half_t>()[I3]);
return vy.template AsType<half4_t>()[I0];
}
template <>
__device__ half8_t atomic_add<half8_t>(half8_t* p_dst, const half8_t& x)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
const vector_type<half_t, 8> vx{x};
vector_type<half_t, 8> vy{0};
vy.template AsType<half_t>()(I0) =
atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst), vx.template AsType<half_t>()[I0]);
vy.template AsType<half_t>()(I1) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 1,
vx.template AsType<half_t>()[I1]);
vy.template AsType<half_t>()(I2) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 2,
vx.template AsType<half_t>()[I2]);
vy.template AsType<half_t>()(I3) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 3,
vx.template AsType<half_t>()[I3]);
vy.template AsType<half_t>()(I4) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 4,
vx.template AsType<half_t>()[I4]);
vy.template AsType<half_t>()(I5) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 5,
vx.template AsType<half_t>()[I5]);
vy.template AsType<half_t>()(I6) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 6,
vx.template AsType<half_t>()[I6]);
vy.template AsType<half_t>()(I7) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 7,
vx.template AsType<half_t>()[I7]);
return vy.template AsType<half8_t>()[I0];
}
#endif // defined(__gfx11__)
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to make the implementation of atomic_max explicit for