Fix tensor descriptors and stride calculations

This commit is contained in:
Graner, Johannes
2025-11-12 13:30:56 +00:00
parent 64ca8414d6
commit e63bc7a2ce
6 changed files with 186 additions and 119 deletions

View File

@@ -814,16 +814,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
(Conv_N_ * output_spatial_acum) % (KPerBlock * k_batch_) == 0;
// Check if there is KPading and we can divide N * OutSpatialDims by k_batch
split_k_offset_a_hack_ =
(Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
k_batch_ > 1 && (Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
// Check if there is KPading and we can divide N by k_batch
split_k_offset_b_hack_ =
Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
k_batch_ > 1 && Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
split_k_stride_a_ =
a_g_n_k_wos_strides[NDimSpatial + I2] * (Conv_N_ * output_spatial_acum) / k_batch_;
split_k_stride_b_ = b_g_n_c_wis_strides[I1] * Conv_N_ / k_batch_;
// Calculate stride from descriptor size
// NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1,
// so we need to divide by k_batch_ to get the per-batch stride when the hack is enabled
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
if(split_k_offset_a_hack_)
split_k_stride_a_ /= k_batch_;
split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize();
if(split_k_offset_b_hack_)
split_k_stride_b_ /= k_batch_;
}
std::size_t GetWorkspaceATensorSizeBytes() const

View File

@@ -65,7 +65,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_a_hack,
bool split_k_offset_b_hack)
bool split_k_offset_b_hack,
index_t k_batch)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
defined(__gfx12__)
@@ -96,7 +97,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
split_k_stride_a,
split_k_stride_b,
split_k_offset_a_hack,
split_k_offset_b_hack);
split_k_offset_b_hack,
k_batch);
}
#else
ignore = p_a_grid;
@@ -115,6 +117,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
ignore = split_k_stride_b;
ignore = split_k_offset_a_hack;
ignore = split_k_offset_b_hack;
ignore = k_batch;
compute_ptr_offset_of_batch.GetAPtrOffset(0);
compute_ptr_offset_of_batch.GetBPtrOffset(0);
@@ -587,6 +590,19 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
k_batch_ = split_k;
}
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const bool is_k_not_paded =
(Conv_N_ * output_spatial_acum) % (K0PerBlock * K1 * k_batch_) == 0;
// Check if there is KPading and we can divide N * OutSpatialDims by k_batch
split_k_offset_a_hack_ =
k_batch_ > 1 && (Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
// Check if there is KPading and we can divide N by k_batch
split_k_offset_b_hack_ =
k_batch_ > 1 && Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
const auto descs =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -603,12 +619,25 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_);
k_batch_,
split_k_offset_a_hack_,
split_k_offset_b_hack_);
a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
// Calculate stride from descriptor size
// NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1,
// so we need to divide by k_batch_ to get the per-batch stride when the hack is enabled
split_k_stride_a_ = a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize();
if(split_k_offset_a_hack_)
split_k_stride_a_ /= k_batch_;
split_k_stride_b_ = b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize();
if(split_k_offset_b_hack_)
split_k_stride_b_ /= k_batch_;
block_2_ctile_map_ =
GridwiseGemm64::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
@@ -650,23 +679,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapTranspose{
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
}
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const bool is_k_not_paded =
(Conv_N_ * output_spatial_acum) % (K0PerBlock * K1 * k_batch_) == 0;
// Check if there is KPading and we can divide N * OutSpatialDims by k_batch
split_k_offset_a_hack_ =
(Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
// Check if there is KPading and we can divide N by k_batch
split_k_offset_b_hack_ =
Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
split_k_stride_a_ =
a_g_n_k_wos_strides[NDimSpatial + I2] * (Conv_N_ * output_spatial_acum) / k_batch_;
split_k_stride_b_ = b_g_n_c_wis_strides[I1] * Conv_N_ / k_batch_;
}
std::size_t GetWorkspaceATensorSizeBytes() const
@@ -913,7 +925,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
arg.split_k_stride_a_,
arg.split_k_stride_b_,
arg.split_k_offset_a_hack_,
arg.split_k_offset_b_hack_);
arg.split_k_offset_b_hack_,
arg.k_batch_);
};
if(has_main_k0_block_loop)

View File

@@ -514,8 +514,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
: p_a_grid_{p_out_grid},
p_b_grid_{p_in_grid},
p_c_grid_{p_wei_grid},
a_grid_desc_kbatch_k0_m_k1_{},
b_grid_desc_kbatch_k0_n_k1_{},
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
compute_ptr_offset_of_batch_{},
@@ -584,6 +584,19 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
k_batch_ = split_k;
}
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const bool is_k_not_paded =
(Conv_N_ * output_spatial_acum) % (K0PerBlock * k_batch_) == 0;
// Check if there is KPading and we can divide N * OutSpatialDims by k_batch
split_k_offset_a_hack_ =
k_batch_ > 1 && (Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
// Check if there is KPading and we can divide N by k_batch
split_k_offset_b_hack_ =
k_batch_ > 1 && Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
const auto descs =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -600,11 +613,24 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_);
k_batch_,
split_k_offset_a_hack_,
split_k_offset_b_hack_);
a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
// Calculate stride from descriptor size
// NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1,
// so we need to divide by k_batch_ to get the per-batch stride when the hack is enabled
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
if(split_k_offset_a_hack_)
split_k_stride_a_ /= k_batch_;
split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize();
if(split_k_offset_b_hack_)
split_k_stride_b_ /= k_batch_;
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
@@ -615,38 +641,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
end(filter_spatial_lengths_),
index_t{1},
std::multiplies<>{});
const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_,
GridwiseGemm64::CalculateMBlock(GemmM),
GridwiseGemm64::CalculateNBlock(GemmN));
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const bool is_k_not_paded =
(Conv_N_ * output_spatial_acum) % (K0PerBlock * k_batch_) == 0;
// Check if there is KPading and we can divide N * OutSpatialDims by k_batch
split_k_offset_a_hack_ =
(Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
// Check if there is KPading and we can divide N by k_batch
split_k_offset_b_hack_ =
Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
split_k_stride_a_ =
a_g_n_k_wos_strides[NDimSpatial + I2] * (Conv_N_ * output_spatial_acum) / k_batch_;
split_k_stride_b_ = b_g_n_c_wis_strides[I1] * Conv_N_ / k_batch_;
}
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
@@ -685,16 +694,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
void ShowInfo(const Argument& arg)
{
std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
<< arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
<< arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
@@ -703,10 +712,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
template <typename GridwiseGemm>
float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2);
const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
const index_t GemmK =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
const ADataType* p_a_grid = arg.p_a_grid_;
const BDataType* p_b_grid = arg.p_b_grid_;
@@ -724,7 +733,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto num_k_per_block =
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
const auto clear_workspace = [&]() {
if(arg.k_batch_ > 1)
@@ -760,8 +769,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
dim3(BlockSize),
0,
gemm_arg_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block,
@@ -780,8 +789,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
dim3(BlockSize),
0,
gemm_arg,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block,
@@ -1341,10 +1350,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
}
#endif
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2);
const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
const index_t GemmK =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
if constexpr(is_same_v<ComputeTypeA, ck::tf32_t> || is_same_v<ComputeTypeB, ck::tf32_t>)
{
@@ -1475,8 +1484,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
}
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
if(!(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) <= TwoGB))
{
return false;

View File

@@ -164,7 +164,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const CBlockClusterAdaptor c_block_cluster_adaptor)
const CBlockClusterAdaptor c_block_cluster_adaptor,
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_a_hack,
bool split_k_offset_b_hack,
index_t k_batch)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
defined(__gfx12__)
@@ -182,7 +187,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
a_element_op,
b_element_op,
c_element_op,
c_block_cluster_adaptor);
c_block_cluster_adaptor,
split_k_stride_a,
split_k_stride_b,
split_k_offset_a_hack,
split_k_offset_b_hack,
k_batch);
}
#else
ignore = p_a_grid;
@@ -195,6 +205,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
ignore = b_element_op;
ignore = c_element_op;
ignore = c_block_cluster_adaptor;
ignore = split_k_stride_a;
ignore = split_k_stride_b;
ignore = split_k_offset_a_hack;
ignore = split_k_offset_b_hack;
ignore = k_batch;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
@@ -662,7 +677,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
const long_index_t split_k_stride_a,
const long_index_t split_k_stride_b,
bool split_k_offset_a_hack,
bool split_k_offset_b_hack)
bool split_k_offset_b_hack,
index_t k_batch)
{
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
@@ -677,10 +693,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
const long_index_t split_k_offset_b =
split_k_offset_b_hack ? k_batch_id * split_k_stride_b : 0;
const long_index_t a_space_size_divisor =
split_k_offset_a_hack ? a_b_k0_m_k1_grid_desc.GetLength(I0) : 1;
const long_index_t b_space_size_divisor =
split_k_offset_b_hack ? a_b_k0_m_k1_grid_desc.GetLength(I0) : 1;
const long_index_t a_space_size_divisor = split_k_offset_a_hack ? k_batch : 1;
const long_index_t b_space_size_divisor = split_k_offset_b_hack ? k_batch : 1;
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid + split_k_offset_a,

View File

@@ -150,7 +150,9 @@ struct TransformConvBwdWeightToGemm
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k)
const index_t batch_k,
const bool split_k_offset_a_hack = false,
const bool split_k_offset_b_hack = false)
{
using namespace ck;
@@ -173,7 +175,9 @@ struct TransformConvBwdWeightToGemm
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
if constexpr(ConvBackwardWeightSpecialization ==
device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
@@ -191,7 +195,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -209,7 +213,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -247,7 +251,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -286,7 +290,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -324,7 +328,9 @@ struct TransformConvBwdWeightToGemm
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k)
const index_t batch_k,
const bool split_k_offset_a_hack = false,
const bool split_k_offset_b_hack = false)
{
using namespace ck;
@@ -360,7 +366,9 @@ struct TransformConvBwdWeightToGemm
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
@@ -379,7 +387,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -394,7 +402,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -423,7 +431,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -464,7 +472,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -498,7 +506,9 @@ struct TransformConvBwdWeightToGemm
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k)
const index_t batch_k,
const bool split_k_offset_a_hack = false,
const bool split_k_offset_b_hack = false)
{
using namespace ck;
@@ -541,7 +551,9 @@ struct TransformConvBwdWeightToGemm
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
@@ -560,7 +572,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -575,7 +587,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -604,7 +616,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -654,7 +666,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

View File

@@ -324,7 +324,9 @@ struct TransformConvBwdWeightToGemmV2
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k)
const index_t batch_k,
const bool split_k_offset_a_hack = false,
const bool split_k_offset_b_hack = false)
{
using namespace ck;
@@ -353,7 +355,9 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Wi, C, input_strides);
@@ -373,7 +377,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -389,7 +393,7 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -419,7 +423,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -460,7 +464,7 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -495,7 +499,9 @@ struct TransformConvBwdWeightToGemmV2
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k)
const index_t batch_k,
const bool split_k_offset_a_hack = false,
const bool split_k_offset_b_hack = false)
{
using namespace ck;
@@ -531,7 +537,9 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
@@ -551,7 +559,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -567,7 +575,7 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -597,7 +605,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -647,7 +655,7 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -681,7 +689,9 @@ struct TransformConvBwdWeightToGemmV2
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k)
const index_t batch_k,
const bool split_k_offset_a_hack = false,
const bool split_k_offset_b_hack = false)
{
using namespace ck;
@@ -724,7 +734,9 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch;
const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
@@ -744,7 +756,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -760,7 +772,7 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -790,7 +802,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -855,7 +867,7 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));