Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-14 10:09:41 +00:00)
[CK] Allow tensors larger than 2GB in grouped conv bwd weight (#3169)
* Take split_k into account when checking 2GB tensor limit.
* Revert "Take split_k into account when checking 2GB tensor limit." This reverts commit adf35c91be.
* Optimize grouped conv bwd wei split_k off calc (cherry picked from commit 2115642ee59050dabd81393c1b8f03b34adc05aa)
* Update gridwise_gemm_xdl_cshuffle_conv_v3.hpp (cherry picked from commit 900d4d4b466f5730ae1189370d3c96267c35ea69)
* Fix tensor descriptors and stride calculations
* Don't miss half of the elements
* Fix buffer size calculations
* Disable hack if stride not divisible by k_batch
* Clean up comments
* Disallow hack in non-contiguous edge cases
* Index -> Dim
* Fix broken test
* Refactor applicability checks into separate function
* Fix missed variable name
* Fix variable name in info print
* Update V3 2GB check
* No more regression, use templates instead
* Code deduplication
* Regression fix for cshuffle
* Arch-guarded atomic_add implementations for gfx11
* Similar for half(4|8)_t as well
* Only use both offset hacks at the same time
* Revert "arch-guarded atomic_add implementations for gfx11" This reverts commit 3883fe6935. This reverts commit 5311ec608d.
* Reapply "arch-guarded atomic_add implementations for gfx11" This reverts commit 1972adeddc.
* Only remove float4 atomic_add
* Refactor to single flag
* Consolidate template parameters
* Consolidate flag in transformers

---------

Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>

[ROCm/composable_kernel commit: ee2c35b92d]
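The mechanism behind the change, in brief: with split-K, each work-group only ever reads one K-partition of A and B, so the base pointer can be pre-offset on the host side by k_idx * split_k_stride and the kernel's buffer can be declared over a single partition. Only that partition then has to fit under the 2GB buffer-addressing limit. Below is a minimal standalone sketch of that arithmetic; SplitKView, partition_stride, base_offset, and fits_2gb are illustrative names, not the CK API.

#include <cassert>
#include <cstdint>

using long_index_t = std::int64_t;

// A tensor of `element_space` elements split into `k_batch` contiguous
// K-partitions. When the offset hack is active, each work-group addresses
// only its own partition, so the buffer size seen by 32-bit buffer
// addressing shrinks by a factor of k_batch.
struct SplitKView
{
    long_index_t element_space; // full tensor; in bytes it may exceed 2 GB
    long_index_t k_batch;

    long_index_t partition_stride() const { return element_space / k_batch; }

    // Host-side base offset for partition k_idx; the kernel then runs with
    // its K index forced to 0.
    long_index_t base_offset(long_index_t k_idx) const
    {
        return k_idx * partition_stride();
    }

    // The relaxed applicability check: only one partition has to fit.
    bool fits_2gb(long_index_t bytes_per_element) const
    {
        constexpr long_index_t TwoGB = long_index_t{1} << 31;
        return partition_stride() * bytes_per_element <= TwoGB;
    }
};

int main()
{
    SplitKView a{long_index_t{3} << 29, 4}; // ~1.6e9 elements, 4-way split-K
    assert(a.element_space % a.k_batch == 0); // divisibility precondition
    assert(!a.fits_2gb(16)); // sanity: an oversized element type still fails
    assert(a.fits_2gb(4));   // fp32: ~1.5 GB per partition, under the limit
    return 0;
}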
@@ -24,6 +24,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/host_utility/device_prop.hpp"
@@ -60,13 +61,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
[[maybe_unused]] const index_t num_k_per_block)
[[maybe_unused]] const index_t num_k_per_block,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_hack)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);

const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0;
const long_index_t split_k_offset_b = split_k_offset_hack ? k_idx * split_k_stride_b : 0;

const long_index_t a_batch_offset = amd_wave_read_first_lane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
@@ -77,23 +84,29 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)

__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx);
DispatchSplitKHack<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_hack);
}
#else
ignore = karg;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
ignore = split_k_offset_hack;
#endif // end of if (defined(__gfx9__))
}

@@ -118,14 +131,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
[[maybe_unused]] const index_t num_k_per_block)
[[maybe_unused]] const index_t num_k_per_block,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_hack)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
// offset base pointer for each work-group
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);

const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0;
const long_index_t split_k_offset_b = split_k_offset_hack ? k_idx * split_k_stride_b : 0;

const long_index_t a_batch_offset = amd_wave_read_first_lane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
@@ -139,24 +158,30 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx);
DispatchSplitKHack_2Lds<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_hack);
}
#else
ignore = karg;
ignore = split_k_offset_hack;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
#endif // end of if (defined(__gfx9__))
}

@@ -693,7 +718,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
k_batch_ = split_k;
}

const auto descs =
// Create initial descriptors with hack=false to check compactness
const auto descs_initial =
conv_to_gemm_transformer_v2
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
@@ -709,11 +735,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_);

a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
ce_grid_desc_m_n_ = descs[I2];
k_batch_,
false, // hack=false for initial check
true); // use_full_batch_kindex

ce_elementwise_grid_desc_m_n_ =
conv_to_gemm_transformer_v1
@@ -733,6 +757,67 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
input_right_pads,
k_batch_)[I2];

split_k_offset_hack_ =
SplitKHackEligibility<NDimSpatial, InLayout, WeiLayout, OutLayout>::Check(
descs_initial[I0],
descs_initial[I1],
k_batch_,
Conv_N_,
output_spatial_lengths_,
KPerBlock);

// Create final descriptors with correct hack flag
const auto descs =
conv_to_gemm_transformer_v2
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides_transposed,
e_g_k_c_xs_strides_transposed,
a_g_n_k_wos_strides_transposed,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_,
split_k_offset_hack_, // Use determined hack flag
true); // use_full_batch_kindex

a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
ce_grid_desc_m_n_ = descs[I2];

// Step 5: Calculate stride using CalculateOffset on FINAL descriptors
if(split_k_offset_hack_)
{
const index_t k0_per_batch = a_grid_desc_k0_m_k1_.GetLength(I0) / k_batch_;
const auto idx_start = make_multi_index(0, 0, 0);
const auto idx_next = make_multi_index(k0_per_batch, 0, 0);
split_k_stride_a_ = a_grid_desc_k0_m_k1_.CalculateOffset(idx_next) -
a_grid_desc_k0_m_k1_.CalculateOffset(idx_start);
}
else
{
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
}

if(split_k_offset_hack_)
{
const index_t k0_per_batch = b_grid_desc_k0_n_k1_.GetLength(I0) / k_batch_;
const auto idx_start = make_multi_index(0, 0, 0);
const auto idx_next = make_multi_index(k0_per_batch, 0, 0);
split_k_stride_b_ = b_grid_desc_k0_n_k1_.CalculateOffset(idx_next) -
b_grid_desc_k0_n_k1_.CalculateOffset(idx_start);
}
else
{
split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize();
}

const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);

@@ -869,6 +954,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
long_index_t c_space_size_bytes;

bool split_k_offset_hack_;
long_index_t split_k_stride_a_, split_k_stride_b_;
};

// Invoker
@@ -971,7 +1059,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block);
num_k_per_block,
arg.split_k_stride_a_,
arg.split_k_stride_b_,
arg.split_k_offset_hack_);
}
else
{
@@ -987,7 +1078,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block);
num_k_per_block,
arg.split_k_stride_a_,
arg.split_k_stride_b_,
arg.split_k_offset_hack_);
}
};

@@ -1920,14 +2014,6 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
}
}

constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
arg.ce_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}

return true;
}

@@ -21,6 +21,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

@@ -33,6 +34,74 @@ namespace ck {
namespace tensor_operation {
namespace device {

// Dispatch helper function for split-K hack - handles 2-way dispatch based on runtime flag
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename AGridDesc_B_K0_M_K1,
typename BGridDesc_B_K0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__device__ void DispatchBatchedGemmSplitKHack(const FloatA* p_a_grid,
const FloatB* p_b_grid,
FloatC* p_c_grid,
void* p_shared,
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_hack,
index_t k_batch)
{
if(split_k_offset_hack)
{
GridwiseGemm::template Run<HasMainKBlockLoop, true>(
p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map,
split_k_stride_a,
split_k_stride_b,
k_batch);
}
else
{
GridwiseGemm::template Run<HasMainKBlockLoop, false>(
p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map,
split_k_stride_a,
split_k_stride_b,
k_batch);
}
}

template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
@@ -62,7 +131,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_hack,
index_t k_batch)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
defined(__gfx12__)
@@ -79,17 +152,33 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)

__shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];

GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
DispatchBatchedGemmSplitKHack<GridwiseGemm,
FloatA,
FloatB,
FloatC,
AGridDesc_B_K0_M_K1,
BGridDesc_B_K0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
Block2CTileMap,
HasMainKBlockLoop>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map,
split_k_stride_a,
split_k_stride_b,
split_k_offset_hack,
k_batch);
}
#else
ignore = p_a_grid;
@@ -104,6 +193,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
ignore = batch_count;
ignore = block_2_ctile_map;
ignore = compute_ptr_offset_of_batch;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
ignore = split_k_offset_hack;
ignore = k_batch;

compute_ptr_offset_of_batch.GetAPtrOffset(0);
compute_ptr_offset_of_batch.GetBPtrOffset(0);
@@ -459,7 +552,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<DeviceOp::Block2CTileMap>,
ComputePtrOffsetOfStridedBatch<>,
false>, // Both true/false give the same occupancy.
false>, // HasMainKBlockLoop - both true/false give the same occupancy
BlockSize,
dynamic_smem_size));
return std::max(1, max_occupancy);
@@ -576,6 +669,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
k_batch_ = split_k;
}

// Create descriptors first (with hack flags temporarily set to false)
// so we can check if element space sizes are divisible by k_batch
const auto descs_initial =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides_transposed,
e_g_k_c_xs_strides_transposed,
a_g_n_k_wos_strides_transposed,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_,
false); // split_k_offset_b_hack (temporary)

split_k_offset_hack_ =
SplitKHackEligibility<NDimSpatial, InLayout, WeiLayout, OutLayout>::Check(
descs_initial[I0],
descs_initial[I1],
k_batch_,
Conv_N_,
output_spatial_lengths_,
K0PerBlock * K1);

// Now create descriptors with the correct hack flag
const auto descs =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -592,12 +716,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_);
k_batch_,
split_k_offset_hack_);

a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];

// Calculate stride using CalculateOffset method for accurate stride
// This works correctly for any descriptor transform pipeline
split_k_stride_a_ = a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize();
if(split_k_offset_hack_)
split_k_stride_a_ /= k_batch_;

split_k_stride_b_ = b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize();
if(split_k_offset_hack_)
split_k_stride_b_ /= k_batch_;

block_2_ctile_map_ =
GridwiseGemm64::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);

@@ -732,6 +867,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
long_index_t c_space_size_bytes;

bool split_k_offset_hack_;
long_index_t split_k_stride_a_, split_k_stride_b_;
};

// Invoker
@@ -878,7 +1016,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
arg.b_grid_desc_kbatch_k0_n_k1_,
c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.block_2_ctile_map_,
arg.compute_ptr_offset_of_batch_);
arg.compute_ptr_offset_of_batch_,
arg.split_k_stride_a_,
arg.split_k_stride_b_,
arg.split_k_offset_hack_,
arg.k_batch_);
};

if(has_main_k0_block_loop)

@@ -22,6 +22,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"

@@ -58,13 +59,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const index_t num_k_per_block)
const index_t num_k_per_block,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_hack)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);

const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0;
const long_index_t split_k_offset_b = split_k_offset_hack ? k_idx * split_k_stride_b : 0;

const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
@@ -74,20 +81,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));

__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx);

DispatchSplitKHack<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_hack);
}
#else
ignore = karg;
@@ -96,6 +107,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = compute_ptr_offset_of_batch;
ignore = num_k_per_block;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
ignore = split_k_offset_hack;

#endif // end of if (defined(__gfx9__)
}

@@ -119,14 +134,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const index_t num_k_per_block)
const index_t num_k_per_block,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_hack)
{
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
// offset base pointer for each work-group
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);

const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0;
const long_index_t split_k_offset_b = split_k_offset_hack ? k_idx * split_k_stride_b : 0;

const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
@@ -140,21 +161,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx);
DispatchSplitKHack_2Lds<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_hack);
}
#else
ignore = karg;
@@ -163,6 +187,9 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = compute_ptr_offset_of_batch;
ignore = num_k_per_block;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
ignore = split_k_offset_hack;
#endif // end of if (defined(__gfx9__)
}

@@ -490,8 +517,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
: p_a_grid_{p_out_grid},
p_b_grid_{p_in_grid},
p_c_grid_{p_wei_grid},
a_grid_desc_kbatch_k0_m_k1_{},
b_grid_desc_kbatch_k0_n_k1_{},
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
compute_ptr_offset_of_batch_{},
@@ -560,6 +587,38 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
k_batch_ = split_k;
}

// Create descriptors first (with hack flags temporarily set to false)
// so we can check if element space sizes match product of dimensions
const auto descs_initial =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides,
e_g_k_c_xs_strides,
a_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_,
false, // split_k_offset_b_hack (temporary)
true); // use_full_batch_kindex=true for V1-compatible descriptors

split_k_offset_hack_ =
SplitKHackEligibility<NDimSpatial, InLayout, WeiLayout, OutLayout>::Check(
descs_initial[I0],
descs_initial[I1],
k_batch_,
Conv_N_,
output_spatial_lengths_,
K0PerBlock);

// Now create descriptors with the correct hack flag
const auto descs =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -576,11 +635,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_);
k_batch_,
split_k_offset_hack_,
true); // use_full_batch_kindex=true for V1-compatible descriptors

a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];

// Calculate stride using CalculateOffset method for accurate stride
// This works correctly for any descriptor transform pipeline
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
if(split_k_offset_hack_)
split_k_stride_a_ /= k_batch_;

split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize();
if(split_k_offset_hack_)
split_k_stride_b_ /= k_batch_;

// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
@@ -591,8 +662,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
end(filter_spatial_lengths_),
index_t{1},
std::multiplies<>{});
const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);

c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -604,8 +675,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;

@@ -631,6 +702,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
long_index_t c_space_size_bytes;

bool split_k_offset_hack_;
long_index_t split_k_stride_a_, split_k_stride_b_;
};

// Invoker
@@ -640,17 +714,15 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3

void ShowInfo(const Argument& arg)
{
std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I3) << "}" << std::endl;

std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I3) << "}" << std::endl;

std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
@@ -659,10 +731,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
template <typename GridwiseGemm>
float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2);
const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
const index_t GemmK =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

const ADataType* p_a_grid = arg.p_a_grid_;
const BDataType* p_b_grid = arg.p_b_grid_;
@@ -680,7 +752,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);

const auto num_k_per_block =
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;

const auto clear_workspace = [&]() {
if(arg.k_batch_ > 1)
@@ -716,11 +788,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
dim3(BlockSize),
0,
gemm_arg_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block);
num_k_per_block,
arg.split_k_stride_a_,
arg.split_k_stride_b_,
arg.split_k_offset_hack_);
}
else
{
@@ -732,11 +807,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
dim3(BlockSize),
0,
gemm_arg,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block);
num_k_per_block,
arg.split_k_stride_a_,
arg.split_k_stride_b_,
arg.split_k_offset_hack_);
}
};

@@ -749,7 +827,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(gemm_arg.KBatch > 1)
if(arg.k_batch_ > 1)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
@@ -781,7 +859,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(gemm_arg.KBatch > 1)
if(arg.k_batch_ > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
@@ -1090,7 +1168,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(gemm_arg.KBatch > 1)
if(arg.k_batch_ > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
@@ -1159,7 +1237,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
}
else
{
if(gemm_arg.KBatch > 1)
if(arg.k_batch_ > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
@@ -1232,7 +1310,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(gemm_arg.KBatch > 1)
if(arg.k_batch_ > 1)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
@@ -1289,10 +1367,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
}
#endif

const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2);
const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
const index_t GemmK =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

if constexpr(is_same_v<ComputeTypeA, ck::tf32_t> || is_same_v<ComputeTypeB, ck::tf32_t>)
{
@@ -1423,9 +1501,17 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
}

constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) <= TwoGB))
const bool a_small_enough = arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() /
(arg.split_k_offset_hack_ ? arg.k_batch_ : 1) *
sizeof(ADataType) <=
TwoGB;
const bool b_small_enough = arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() /
(arg.split_k_offset_hack_ ? arg.k_batch_ : 1) *
sizeof(BDataType) <=
TwoGB;
const bool c_small_enough =
arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) <= TwoGB;
if(!(a_small_enough && b_small_enough && c_small_enough))
{
return false;
}

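For concreteness, a small standalone check of the arithmetic in the relaxed 2GB test above, with made-up sizes (a_space, k_batch, and bytes are hypothetical values, not taken from the commit):

#include <cassert>
#include <cstdint>

int main()
{
    using long_index_t = std::int64_t;
    constexpr long_index_t TwoGB = long_index_t{1} << 31;

    const long_index_t a_space = long_index_t{3} << 30; // ~3.2e9 elements
    const long_index_t k_batch = 4;
    const long_index_t bytes   = 2; // e.g. fp16

    // Old whole-tensor check: ~6.4 GB of fp16, rejected.
    assert(!(a_space * bytes <= TwoGB));

    // New per-partition check with the offset hack active (the element space
    // must be divisible by k_batch): ~1.6 GB per partition, accepted.
    assert(a_space % k_batch == 0);
    assert(a_space / k_batch * bytes <= TwoGB);
    return 0;
}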
@@ -0,0 +1,222 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <numeric>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// Check if a tensor descriptor has compact layout
// Compact means: GetElementSpaceSize() == product of all dimension lengths
// Non-compact descriptors have complex transform pipelines that may not support split-k hack
template <typename Descriptor>
bool IsDescriptorCompact(const Descriptor& desc)
{
// Calculate product of all dimensions
long_index_t dims_product = 1;
constexpr index_t num_dims = Descriptor::GetNumOfDimension();

// Use template recursion to multiply all dimension lengths
static_for<0, num_dims, 1>{}(
[&](auto i) { dims_product *= static_cast<long_index_t>(desc.GetLength(i)); });

return desc.GetElementSpaceSize() == dims_product;
}

// Determine split-k hack eligibility for descriptor pair
// This checks all the conditions required for safely using the split-k offset hack
template <index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
struct SplitKHackEligibility
{
template <typename ADescriptor, typename BDescriptor>
static bool
Check(const ADescriptor& a_desc,
const BDescriptor& b_desc,
index_t k_batch,
index_t Conv_N,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
index_t k_block_size) // K0PerBlock*K1 for v1, K0PerBlock for v3, KPerBlock for two-stage
{
// Only enable hack if k_batch > 1
if(k_batch <= 1)
{
return false;
}

// Calculate output spatial product
const index_t output_spatial_acum = std::accumulate(output_spatial_lengths.begin(),
output_spatial_lengths.end(),
index_t{1},
std::multiplies<index_t>());

// Check various divisibility and layout requirements
const bool is_k_not_paded = (Conv_N * output_spatial_acum) % (k_block_size * k_batch) == 0;

const bool can_divide_n_spatial_by_k_batch = (Conv_N * output_spatial_acum) % k_batch == 0;

const bool can_divide_n_by_k_batch = Conv_N % k_batch == 0;

const bool is_correct_layout =
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();

const bool is_a_stride_divisible = a_desc.GetElementSpaceSize() % k_batch == 0;

const bool is_b_stride_divisible = b_desc.GetElementSpaceSize() % k_batch == 0;

// Check descriptor compactness
const bool is_a_compact = IsDescriptorCompact(a_desc);
const bool is_b_compact = IsDescriptorCompact(b_desc);

// Require BOTH A and B to be eligible for the hack to avoid KBatch dimension mismatch
// The gridwise kernel's CheckValidity requires A.KBatch == B.KBatch, so we must
// apply the hack uniformly to both tensors to maintain kernel applicability
const bool eligible = can_divide_n_spatial_by_k_batch && can_divide_n_by_k_batch &&
is_k_not_paded && is_correct_layout && is_a_stride_divisible &&
is_b_stride_divisible && is_a_compact && is_b_compact;

return eligible;
}
};

// Helper function to dispatch split-K hack for standard kernel (single LDS)
// Reduces code duplication in device layer implementations
template <typename GridwiseGemm,
typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum,
typename ADataType,
typename BDataType,
typename CDataType>
__device__ void DispatchSplitKHack(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
void* p_shared,
const typename GridwiseGemm::Argument& karg,
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
index_t k_id,
index_t k_batch,
bool split_k_offset_hack)
{
if(split_k_offset_hack)
{
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
true>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
else
{
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
false>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
}

// Helper function to dispatch split-K hack for 2lds kernel
// Reduces code duplication in device layer implementations
template <typename GridwiseGemm,
typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum,
typename ADataType,
typename BDataType,
typename CDataType>
__device__ void DispatchSplitKHack_2Lds(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
void* p_shared_0,
void* p_shared_1,
const typename GridwiseGemm::Argument& karg,
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
index_t k_id,
index_t k_batch,
bool split_k_offset_hack)
{
if(split_k_offset_hack)
{
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
true>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
else
{
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
false>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id,
k_batch);
}
}

} // namespace device
} // namespace tensor_operation
} // namespace ck

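The compactness rule in IsDescriptorCompact above reduces, for plain strided layouts, to "element space equals the product of the lengths". A standalone sketch of that rule under this assumption follows; element_space and is_compact are hypothetical helpers, not the CK descriptor API.

#include <array>
#include <cstdint>

using long_index_t = std::int64_t;

template <std::size_t N>
long_index_t element_space(const std::array<long_index_t, N>& lengths,
                           const std::array<long_index_t, N>& strides)
{
    // Highest addressable element + 1 for a strided layout.
    long_index_t space = 1;
    for(std::size_t i = 0; i < N; ++i)
        space += (lengths[i] - 1) * strides[i];
    return space;
}

template <std::size_t N>
bool is_compact(const std::array<long_index_t, N>& lengths,
                const std::array<long_index_t, N>& strides)
{
    long_index_t product = 1;
    for(auto l : lengths)
        product *= l;
    return element_space(lengths, strides) == product;
}

int main()
{
    // Packed row-major 2x3x4: strides {12,4,1} -> space 24 == 2*3*4, compact.
    bool packed = is_compact<3>({2, 3, 4}, {12, 4, 1});
    // Rows padded to 5 elements: strides {15,5,1} -> space 29 != 24, not compact,
    // so the offset hack would be disabled for this layout.
    bool padded = is_compact<3>({2, 3, 4}, {15, 5, 1});
    return (packed && !padded) ? 0 : 1;
}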
@@ -663,7 +663,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
TailNumber TailNum = TailNumber::Odd,
bool SplitKOffsetHack = false>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
@@ -673,12 +674,16 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const index_t k_id = 0)
const index_t k_id = 0,
const index_t k_batch = 1)
{
const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1;

const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() / a_space_size_divisor);
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize() / b_space_size_divisor);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

@@ -744,7 +749,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(k_id, m_block_data_idx_on_grid, 0),
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
@@ -775,7 +780,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(k_id, n_block_data_idx_on_grid, 0),
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
@@ -1024,7 +1029,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
TailNumber TailNum = TailNumber::Odd,
bool SplitKOffsetHack = false>
__device__ static void Run_2Lds(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
@@ -1035,12 +1041,16 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const index_t k_id = 0)
const index_t k_id = 0,
const index_t k_batch = 1)
{
const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1;

const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() / a_space_size_divisor);
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize() / b_space_size_divisor);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

@@ -1106,7 +1116,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(k_id, m_block_data_idx_on_grid, 0),
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
@@ -1137,7 +1147,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(k_id, n_block_data_idx_on_grid, 0),
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),

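The SplitKOffsetHack template parameter added to Run/Run_2Lds above is fed from a runtime flag through the DispatchSplitKHack helpers, so the kernel body can branch with if constexpr at zero runtime cost. A minimal sketch of that pattern, with hypothetical names:

#include <cstdio>

// Kernel side: compile-time branch, the disabled path compiles away.
template <bool SplitKOffsetHack>
void run(int k_id)
{
    if constexpr(SplitKOffsetHack)
        std::printf("base pointer pre-offset; K index forced to 0\n");
    else
        std::printf("full buffer; K index = %d\n", k_id);
}

// Dispatch side: one runtime bool fans out to the two instantiations.
void dispatch(bool split_k_offset_hack, int k_id)
{
    if(split_k_offset_hack)
        run<true>(0); // the offset was already applied to the pointer
    else
        run<false>(k_id);
}

int main()
{
    dispatch(true, 3);
    dispatch(false, 3);
    return 0;
}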
@@ -14,6 +14,7 @@
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -149,7 +150,8 @@ template <typename GridwiseGemm,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename CBlockClusterAdaptor,
|
||||
bool HasMainKBlockLoop>
bool HasMainKBlockLoop,
bool SplitKOffsetHack>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -164,7 +166,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const CBlockClusterAdaptor c_block_cluster_adaptor)
const CBlockClusterAdaptor c_block_cluster_adaptor,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
index_t k_batch)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
defined(__gfx12__)
@@ -172,17 +177,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
c_block_cluster_adaptor);
GridwiseGemm::template Run<HasMainKBlockLoop, SplitKOffsetHack>(
p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
c_block_cluster_adaptor,
split_k_stride_a,
split_k_stride_b,
k_batch);
}
#else
ignore = p_a_grid;
@@ -195,6 +204,9 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
ignore = b_element_op;
ignore = c_element_op;
ignore = c_block_cluster_adaptor;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
ignore = k_batch;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || defined(__gfx12__))
}

@@ -536,7 +548,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
MRepeat,
NRepeat,
FloatC,
CGlobalMemoryDataOperation>();
CGlobalMemoryDataOperation_>();
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
@@ -646,6 +658,416 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));

template <bool HasMainKBlockLoop, bool SplitKOffsetHack = false>
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const CBlockClusterAdaptor& c_block_cluster_adaptor,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
index_t k_batch)
{
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);

// divide block work by [M, N]
const auto block_work_idx =
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

const index_t k_batch_id = block_work_idx[I0];

// Use compile-time branching based on template parameters
const long_index_t split_k_offset_a = SplitKOffsetHack ? k_batch_id * split_k_stride_a : 0;
const long_index_t split_k_offset_b = SplitKOffsetHack ? k_batch_id * split_k_stride_b : 0;
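// Sketch (illustrative, not part of this change): the strides above are expected to
// come from the device layer, e.g. via CalculateOffset on a full-batch descriptor,
// so they respect the transform pipeline and any non-compact strides:
//   split_k_stride_a = full_desc.CalculateOffset(make_multi_index(1, 0, 0, 0));
// With k_batch_id carried into long_index_t, k_batch_id * split_k_stride_a stays
// valid even when it exceeds the 2 GB (2^31) range of a 32-bit index_t.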

// When the hack is enabled, the buffer size equals the stride (calculated via the
// descriptor's CalculateOffset method in the device layer). This properly accounts
// for the descriptor's transform pipeline and non-compact strides.
// When the hack is disabled, use the full element space size.
const long_index_t a_buffer_size =
SplitKOffsetHack ? split_k_stride_a : a_b_k0_m_k1_grid_desc.GetElementSpaceSize();

const long_index_t b_buffer_size =
SplitKOffsetHack ? split_k_stride_b : b_b_k0_n_k1_grid_desc.GetElementSpaceSize();

ignore = k_batch; // k_batch value itself not used in this function

const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid + split_k_offset_a, a_buffer_size);
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid + split_k_offset_b, b_buffer_size);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
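// Note (assumption, not stated in this diff): the size handed to
// make_dynamic_buffer<AddressSpaceEnum::Global> bounds the buffer, and on AMD
// targets it typically feeds the buffer-resource range used for out-of-bounds
// checks; once the base pointer is pre-offset, the size must cover only the
// current split, which is presumably why the stride doubles as the size above.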

if(!c_block_cluster_adaptor.ValidCTileIndex(
make_tuple(block_work_idx[I1], block_work_idx[I2]),
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}

// HACK: this forces m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);

const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);

// lds max alignment
constexpr auto max_lds_align = K1;

// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_k0_m_k1_block_desc = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();

constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_k0_n_k1_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<1, K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatA,
FloatAAdjusted,
decltype(a_b_k0_m_k1_grid_desc),
decltype(a_b_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 2, 1, 3>,
ABlockTransferSrcVectorDim,
3,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_b_k0_m_k1_grid_desc,
make_multi_index(SplitKOffsetHack ? 0 : k_batch_id, 0, m_block_data_idx_on_grid, 0),
a_element_op,
a_b_k0_m_k1_block_desc,
make_multi_index(0, 0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
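// With SplitKOffsetHack the split has already been selected by pre-offsetting
// p_a_grid, so the copy starts at batch coordinate 0; otherwise the descriptor's
// KBatch dimension is indexed with k_batch_id. The B-matrix copy below mirrors this.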

// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<1, K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatB,
FloatBAdjusted,
decltype(b_b_k0_n_k1_grid_desc),
decltype(b_b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 2, 1, 3>,
BBlockTransferSrcVectorDim,
3,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_b_k0_n_k1_grid_desc,
make_multi_index(SplitKOffsetHack ? 0 : k_batch_id, 0, n_block_data_idx_on_grid, 0),
b_element_op,
b_b_k0_n_k1_block_desc,
make_multi_index(0, 0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});

// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
K1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && K1 <= 8) ||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
K1 < 32))
? true
: false;
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(K1,
MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
ComputeTypeB,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
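// KPack is the contiguous K-chunk each xdlops issue consumes: e.g. with K1 = 8 and
// an MFMA whose k_per_blk is 4, KPack = max(8, 4) = 8 (values here are illustrative;
// the actual k_per_blk comes from MfmaSelector).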

auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAAdjusted,
FloatBAdjusted,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
MPerXdl,
NPerXdl,
MRepeat,
NRepeat,
KPack,
ComputeTypeA,
ComputeTypeB>{};

auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);

auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());

auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatBAdjusted*>(p_shared) + a_block_space_size,
b_k0_n_k1_block_desc.GetElementSpaceSize());
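// LDS layout: the B tile starts right after the A tile, with the A footprint
// rounded up to max_lds_align (= K1) so the B base stays vector-aligned; e.g.
// (illustrative) an A size of 4100 elements with K1 = 8 rounds to 4104 before
// B is placed.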

// gridwise GEMM pipeline
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);

GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
a_b_k0_m_k1_block_desc,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_b_k0_n_k1_grid_desc,
b_b_k0_n_k1_block_desc,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
K0BlockMainLoop);

// output: register to global memory
{
constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);

constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
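// Per the static_asserts below: M1 = MWave, M2 * M3 * M4 = MPerXdl, N1 = NWave,
// N2 = NPerXdl; M0 and N0 correspond to the per-wave repeat counts (MRepeat / NRepeat).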

constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatC*>(p_shared),
c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

static_assert(M1 == MWave, "");
static_assert(N1 == NWave, "");
static_assert(M2 * M3 * M4 == MPerXdl, "");
static_assert(N2 == NPerXdl, "");

constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0), // freeze mblock
make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle,
M1,
M2,
M3,
M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
make_freeze_transform(I0), // freeze nblock
make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle,
N1,
N2))), // N1 = NWave, N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));

// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));

const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));

const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));

const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));

// VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatC,
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};

// LDS to global
auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // index_t BlockSize,
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMRepeatPerShuffle * MWave * MPerXdl,
1,
CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatC, // typename SrcData,
FloatC, // typename DstData,
decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun
{c_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0),
c_element_op};

constexpr auto mxdlperwave_forward_step =
make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0);
constexpr auto nxdlperwave_forward_step =
make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl);
constexpr auto nxdlperwave_backward_step =
make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl);

static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
constexpr auto mxdlperwave = mxdlperwave_iter;

static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
constexpr bool nxdlperwave_forward_sweep =
(mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);

constexpr index_t nxdlperwave_value =
nxdlperwave_forward_sweep
? nxdlperwave_iter
: (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);

constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
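// The N sweep is serpentine: forward on even M steps and backward on odd ones,
// so each MoveDstSliceWindow below advances by exactly one shuffle tile and the
// window never has to be rewound across the whole N extent.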

// make sure it's safe to do ds_write
block_sync_lds();

// VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_block_buf);

// make sure it's safe to do ds_read
block_sync_lds();

// LDS to global
c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock,
c_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);

// move on nxdlperwave dimension
if constexpr(nxdlperwave_forward_sweep &&
(nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
{
c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
nxdlperwave_forward_step);
}
else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
{
c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
nxdlperwave_backward_step);
}
});

// move on mxdlperwave dimension
if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
{
c_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step);
}
});
}
}

template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,

@@ -149,7 +149,8 @@ struct TransformConvBwdWeightToGemm
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k)
const index_t batch_k,
const bool split_k_offset_hack = false) // Optional; defaults to false so existing call sites keep the old behavior
{
using namespace ck;

@@ -172,7 +173,8 @@ struct TransformConvBwdWeightToGemm
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchDim = split_k_offset_hack ? 1 : GemmKBatch;
const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number;
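// Worked example (illustrative values): GemmKBatch = 4, GemmK0 = 64, GemmK1Number = 8
// normally gives GemmKPad = 2048; with the hack on, KBatchDim = 1, so the descriptor
// spans a single 512-element split and the remaining splits are reached through the
// base-pointer offset instead of the KBatch coordinate.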

if constexpr(ConvBackwardWeightSpecialization ==
device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
@@ -190,7 +192,7 @@ struct TransformConvBwdWeightToGemm

const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -208,7 +210,7 @@ struct TransformConvBwdWeightToGemm

const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -246,7 +248,7 @@ struct TransformConvBwdWeightToGemm

const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -285,7 +287,7 @@ struct TransformConvBwdWeightToGemm

const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -323,7 +325,8 @@ struct TransformConvBwdWeightToGemm
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k)
const index_t batch_k,
const bool split_k_offset_hack = false)
{
using namespace ck;

@@ -359,7 +362,8 @@ struct TransformConvBwdWeightToGemm
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchDim = split_k_offset_hack ? 1 : GemmKBatch;
const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number;

const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
@@ -378,7 +382,7 @@ struct TransformConvBwdWeightToGemm

const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -393,7 +397,7 @@ struct TransformConvBwdWeightToGemm

const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -422,7 +426,7 @@ struct TransformConvBwdWeightToGemm

const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -463,7 +467,7 @@ struct TransformConvBwdWeightToGemm

const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -497,7 +501,8 @@ struct TransformConvBwdWeightToGemm
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k)
const index_t batch_k,
const bool split_k_offset_hack = false)
{
using namespace ck;

@@ -540,7 +545,8 @@ struct TransformConvBwdWeightToGemm
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchDim = split_k_offset_hack ? 1 : GemmKBatch;
const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number;

const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
@@ -559,7 +565,7 @@ struct TransformConvBwdWeightToGemm

const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -574,7 +580,7 @@ struct TransformConvBwdWeightToGemm

const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -603,7 +609,7 @@ struct TransformConvBwdWeightToGemm

const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -653,7 +659,7 @@ struct TransformConvBwdWeightToGemm

const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

@@ -324,7 +324,9 @@ struct TransformConvBwdWeightToGemmV2
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k)
const index_t batch_k,
const bool split_k_offset_hack = false,
const bool use_full_batch_kindex = false)
{
using namespace ck;

@@ -353,7 +355,10 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
// When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise
// kernel compatibility
const index_t KBatchDim = (split_k_offset_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number;
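// V2 note: use_full_batch_kindex keeps the full GemmKBatch dimension in the
// descriptor even when the offset hack is active, so a gridwise kernel that still
// indexes the batch dimension directly (V1-style) sees the layout it expects.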

const auto out_grid_desc = make_out_grid_desc<NDim>(N, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Wi, C, input_strides);
@@ -373,7 +378,7 @@ struct TransformConvBwdWeightToGemmV2

const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -389,7 +394,7 @@ struct TransformConvBwdWeightToGemmV2

const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -419,7 +424,7 @@ struct TransformConvBwdWeightToGemmV2

const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -460,7 +465,7 @@ struct TransformConvBwdWeightToGemmV2

const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -495,7 +500,9 @@ struct TransformConvBwdWeightToGemmV2
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k)
const index_t batch_k,
const bool split_k_offset_hack = false,
const bool use_full_batch_kindex = false)
{
using namespace ck;

@@ -531,7 +538,10 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
// When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise
// kernel compatibility
const index_t KBatchDim = (split_k_offset_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number;

const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
@@ -551,7 +561,7 @@ struct TransformConvBwdWeightToGemmV2

const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -567,7 +577,7 @@ struct TransformConvBwdWeightToGemmV2

const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -597,7 +607,7 @@ struct TransformConvBwdWeightToGemmV2

const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -647,7 +657,7 @@ struct TransformConvBwdWeightToGemmV2

const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -681,7 +691,9 @@ struct TransformConvBwdWeightToGemmV2
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k)
const index_t batch_k,
const bool split_k_offset_hack = false,
const bool use_full_batch_kindex = false)
{
using namespace ck;

@@ -724,7 +736,10 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
// When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise
// kernel compatibility
const index_t KBatchDim = (split_k_offset_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number;

const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
@@ -744,7 +759,7 @@ struct TransformConvBwdWeightToGemmV2

const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -760,7 +775,7 @@ struct TransformConvBwdWeightToGemmV2

const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -790,7 +805,7 @@ struct TransformConvBwdWeightToGemmV2

const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -855,7 +870,7 @@ struct TransformConvBwdWeightToGemmV2

const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

@@ -111,6 +111,101 @@ __device__ double2_t atomic_add<double2_t>(double2_t* p_dst, const double2_t& x)
return vy.template AsType<double2_t>()[I0];
}

#if defined(__gfx11__)
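// These gfx11 specializations decompose the vector atomic into per-element adds:
// float8_t falls back to eight scalar atomicAdd(float*) calls, and the half-precision
// variants reuse the scalar half_t atomic_add specialization, so no wide atomic
// instruction is required on this target.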
template <>
__device__ float8_t atomic_add<float8_t>(float8_t* p_dst, const float8_t& x)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};

const vector_type<float, 8> vx{x};
vector_type<float, 8> vy{0};

vy.template AsType<float>()(I0) =
atomicAdd(c_style_pointer_cast<float*>(p_dst), vx.template AsType<float>()[I0]);
vy.template AsType<float>()(I1) =
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, vx.template AsType<float>()[I1]);
vy.template AsType<float>()(I2) =
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 2, vx.template AsType<float>()[I2]);
vy.template AsType<float>()(I3) =
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 3, vx.template AsType<float>()[I3]);
vy.template AsType<float>()(I4) =
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 4, vx.template AsType<float>()[I4]);
vy.template AsType<float>()(I5) =
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 5, vx.template AsType<float>()[I5]);
vy.template AsType<float>()(I6) =
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 6, vx.template AsType<float>()[I6]);
vy.template AsType<float>()(I7) =
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 7, vx.template AsType<float>()[I7]);

return vy.template AsType<float8_t>()[I0];
}

template <>
__device__ half4_t atomic_add<half4_t>(half4_t* p_dst, const half4_t& x)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

const vector_type<half_t, 4> vx{x};
vector_type<half_t, 4> vy{0};

vy.template AsType<half_t>()(I0) =
atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst), vx.template AsType<half_t>()[I0]);
vy.template AsType<half_t>()(I1) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 1,
vx.template AsType<half_t>()[I1]);
vy.template AsType<half_t>()(I2) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 2,
vx.template AsType<half_t>()[I2]);
vy.template AsType<half_t>()(I3) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 3,
vx.template AsType<half_t>()[I3]);

return vy.template AsType<half4_t>()[I0];
}

template <>
__device__ half8_t atomic_add<half8_t>(half8_t* p_dst, const half8_t& x)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};

const vector_type<half_t, 8> vx{x};
vector_type<half_t, 8> vy{0};

vy.template AsType<half_t>()(I0) =
atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst), vx.template AsType<half_t>()[I0]);
vy.template AsType<half_t>()(I1) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 1,
vx.template AsType<half_t>()[I1]);
vy.template AsType<half_t>()(I2) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 2,
vx.template AsType<half_t>()[I2]);
vy.template AsType<half_t>()(I3) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 3,
vx.template AsType<half_t>()[I3]);
vy.template AsType<half_t>()(I4) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 4,
vx.template AsType<half_t>()[I4]);
vy.template AsType<half_t>()(I5) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 5,
vx.template AsType<half_t>()[I5]);
vy.template AsType<half_t>()(I6) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 6,
vx.template AsType<half_t>()[I6]);
vy.template AsType<half_t>()(I7) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 7,
vx.template AsType<half_t>()[I7]);

return vy.template AsType<half8_t>()[I0];
}
#endif // defined(__gfx11__)

// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to make the implementation of atomic_max explicit for
// each datatype.