Fix broken test

This commit is contained in:
Graner, Johannes
2025-11-28 09:35:03 +00:00
parent c168426885
commit 9e3e1b6935
6 changed files with 187 additions and 94 deletions

View File

@@ -716,7 +716,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
k_batch_ = split_k;
}
const auto descs =
// Step 1: Create initial descriptors with hack=false to check compactness
const auto descs_initial =
conv_to_gemm_transformer_v2
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
@@ -733,13 +734,20 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
input_left_pads,
input_right_pads,
k_batch_,
false, // Don't modify KBatch dimension
false, // Don't modify KBatch dimension
true); // use_full_batch_kindex: keep full KBatch*K0 dimension
false, // hack=false for initial check
false, // hack=false for initial check
true); // use_full_batch_kindex
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
ce_grid_desc_m_n_ = descs[I2];
// Step 2: Check if descriptors are compact (element_space == product of dimensions)
const auto a_dims_product = static_cast<long_index_t>(descs_initial[I0].GetLength(I0)) *
static_cast<long_index_t>(descs_initial[I0].GetLength(I1)) *
static_cast<long_index_t>(descs_initial[I0].GetLength(I2));
const auto b_dims_product = static_cast<long_index_t>(descs_initial[I1].GetLength(I0)) *
static_cast<long_index_t>(descs_initial[I1].GetLength(I1)) *
static_cast<long_index_t>(descs_initial[I1].GetLength(I2));
const bool is_a_compact = (descs_initial[I0].GetElementSpaceSize() == a_dims_product);
const bool is_b_compact = (descs_initial[I1].GetElementSpaceSize() == b_dims_product);
ce_elementwise_grid_desc_m_n_ =
conv_to_gemm_transformer_v1
@@ -774,43 +782,53 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
const bool is_a_stride_divisible =
a_grid_desc_k0_m_k1_.GetElementSpaceSize() % k_batch_ == 0;
descs_initial[I0].GetElementSpaceSize() % k_batch_ == 0;
const bool is_b_stride_divisible =
b_grid_desc_k0_n_k1_.GetElementSpaceSize() % k_batch_ == 0;
descs_initial[I1].GetElementSpaceSize() % k_batch_ == 0;
// Check if descriptor is contiguous (no padding in layout)
// For a contiguous (K0, M, K1) descriptor, ElementSpaceSize should equal K0*M*K1
const long_index_t expected_size_a = static_cast<long_index_t>(
a_grid_desc_k0_m_k1_.GetLength(I0) * a_grid_desc_k0_m_k1_.GetLength(I1) *
a_grid_desc_k0_m_k1_.GetLength(I2));
const bool is_a_contiguous =
a_grid_desc_k0_m_k1_.GetElementSpaceSize() == expected_size_a;
const long_index_t expected_size_b = static_cast<long_index_t>(
b_grid_desc_k0_n_k1_.GetLength(I0) * b_grid_desc_k0_n_k1_.GetLength(I1) *
b_grid_desc_k0_n_k1_.GetLength(I2));
const bool is_b_contiguous =
b_grid_desc_k0_n_k1_.GetElementSpaceSize() == expected_size_b;
// Determine if we can safely use the split-k offset hack
// The hack requires contiguous descriptor layout for correct stride calculation
// Step 3: Determine if hack can be enabled (only for compact layouts)
split_k_offset_a_hack_ = k_batch_ > 1 && can_divide_n_spatial_by_k_batch &&
is_k_not_paded && is_correct_layout && is_a_stride_divisible &&
is_a_contiguous;
is_a_compact;
split_k_offset_b_hack_ = k_batch_ > 1 && can_divide_n_by_k_batch && is_k_not_paded &&
is_correct_layout && is_b_stride_divisible && is_b_contiguous;
is_correct_layout && is_b_stride_divisible && is_b_compact;
// Calculate stride for split-k offset hack
// The descriptor has shape (K0_full, M, K1) where K0_full = k_batch_ * K0_per_batch
// To advance from k_idx=0 to k_idx=1, we need to skip K0_per_batch slices
// Each K0 slice occupies M*K1 elements, so stride = (K0_full/k_batch_) * M * K1
// Step 4: Create final descriptors with correct hack flags
const auto descs =
conv_to_gemm_transformer_v2
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides_transposed,
e_g_k_c_xs_strides_transposed,
a_g_n_k_wos_strides_transposed,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_,
split_k_offset_a_hack_, // Use determined hack flag
split_k_offset_b_hack_, // Use determined hack flag
true); // use_full_batch_kindex
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
ce_grid_desc_m_n_ = descs[I2];
// Step 5: Calculate stride using CalculateOffset on FINAL descriptors
if(split_k_offset_a_hack_)
{
const index_t k0_per_batch = a_grid_desc_k0_m_k1_.GetLength(I0) / k_batch_;
split_k_stride_a_ = k0_per_batch * a_grid_desc_k0_m_k1_.GetLength(I1) *
a_grid_desc_k0_m_k1_.GetLength(I2);
const auto idx_start = make_multi_index(0, 0, 0);
const auto idx_next = make_multi_index(k0_per_batch, 0, 0);
split_k_stride_a_ = a_grid_desc_k0_m_k1_.CalculateOffset(idx_next) -
a_grid_desc_k0_m_k1_.CalculateOffset(idx_start);
}
else
{
@@ -820,8 +838,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
if(split_k_offset_b_hack_)
{
const index_t k0_per_batch = b_grid_desc_k0_n_k1_.GetLength(I0) / k_batch_;
split_k_stride_b_ = k0_per_batch * b_grid_desc_k0_n_k1_.GetLength(I1) *
b_grid_desc_k0_n_k1_.GetLength(I2);
const auto idx_start = make_multi_index(0, 0, 0);
const auto idx_next = make_multi_index(k0_per_batch, 0, 0);
split_k_stride_b_ = b_grid_desc_k0_n_k1_.CalculateOffset(idx_next) -
b_grid_desc_k0_n_k1_.CalculateOffset(idx_start);
}
else
{

View File

@@ -632,12 +632,28 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const bool is_b_stride_divisible =
descs_initial[I1].GetElementSpaceSize() % k_batch_ == 0;
// Check if descriptor has compact layout (product of dimensions equals element space)
// Non-compact layouts have complex transform pipelines that don't support the hack
const auto a_dims_product = static_cast<long_index_t>(descs_initial[I0].GetLength(I0)) *
static_cast<long_index_t>(descs_initial[I0].GetLength(I1)) *
static_cast<long_index_t>(descs_initial[I0].GetLength(I2)) *
static_cast<long_index_t>(descs_initial[I0].GetLength(I3));
const auto b_dims_product = static_cast<long_index_t>(descs_initial[I1].GetLength(I0)) *
static_cast<long_index_t>(descs_initial[I1].GetLength(I1)) *
static_cast<long_index_t>(descs_initial[I1].GetLength(I2)) *
static_cast<long_index_t>(descs_initial[I1].GetLength(I3));
const bool is_a_compact = (descs_initial[I0].GetElementSpaceSize() == a_dims_product);
const bool is_b_compact = (descs_initial[I1].GetElementSpaceSize() == b_dims_product);
// Determine if we can safely use the split-k offset hack
// Only enable for compact layouts where element_space_size == product of dimensions
split_k_offset_a_hack_ = k_batch_ > 1 && can_divide_n_spatial_by_k_batch &&
is_k_not_paded && is_correct_layout && is_a_stride_divisible;
is_k_not_paded && is_correct_layout && is_a_stride_divisible &&
is_a_compact;
split_k_offset_b_hack_ = k_batch_ > 1 && can_divide_n_by_k_batch && is_k_not_paded &&
is_correct_layout && is_b_stride_divisible;
is_correct_layout && is_b_stride_divisible && is_b_compact;
// Now create descriptors with the correct hack flags
const auto descs =
@@ -664,10 +680,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
// Calculate stride from descriptor size
// NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1,
// so we need to divide by k_batch_ to get the per-batch stride when the hack is
// enabled. The divisibility check above ensures no truncation occurs.
// Calculate stride using CalculateOffset method for accurate stride
// This works correctly for any descriptor transform pipeline
split_k_stride_a_ = a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize();
if(split_k_offset_a_hack_)
split_k_stride_a_ /= k_batch_;

View File

@@ -584,19 +584,72 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
k_batch_ = split_k;
}
// Create descriptors first (with hack flags temporarily set to false)
// so we can check if element space sizes match product of dimensions
const auto descs_initial =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides,
e_g_k_c_xs_strides,
a_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_,
false, // split_k_offset_a_hack (temporary)
false, // split_k_offset_b_hack (temporary)
true); // use_full_batch_kindex=true for V1-compatible descriptors
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const bool is_k_not_paded =
(Conv_N_ * output_spatial_acum) % (K0PerBlock * k_batch_) == 0;
// Check if there is KPading and we can divide N * OutSpatialDims by k_batch
split_k_offset_a_hack_ =
k_batch_ > 1 && (Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
// Check if there is KPading and we can divide N by k_batch
split_k_offset_b_hack_ =
k_batch_ > 1 && Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
const bool can_divide_n_spatial_by_k_batch =
(Conv_N_ * output_spatial_acum) % k_batch_ == 0;
const bool can_divide_n_by_k_batch = Conv_N_ % k_batch_ == 0;
const bool is_correct_layout =
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
const bool is_a_stride_divisible =
descs_initial[I0].GetElementSpaceSize() % k_batch_ == 0;
const bool is_b_stride_divisible =
descs_initial[I1].GetElementSpaceSize() % k_batch_ == 0;
// Check if descriptor has compact layout (product of dimensions equals element space)
// Non-compact layouts have complex transform pipelines that don't support the hack
// Note: CShuffleV3 descriptors are 3D [K0, M, K1], not 4D like CShuffle
const auto a_dims_product = static_cast<long_index_t>(descs_initial[I0].GetLength(I0)) *
static_cast<long_index_t>(descs_initial[I0].GetLength(I1)) *
static_cast<long_index_t>(descs_initial[I0].GetLength(I2));
const auto b_dims_product = static_cast<long_index_t>(descs_initial[I1].GetLength(I0)) *
static_cast<long_index_t>(descs_initial[I1].GetLength(I1)) *
static_cast<long_index_t>(descs_initial[I1].GetLength(I2));
const bool is_a_compact = (descs_initial[I0].GetElementSpaceSize() == a_dims_product);
const bool is_b_compact = (descs_initial[I1].GetElementSpaceSize() == b_dims_product);
// Determine if we can safely use the split-k offset hack
// Only enable for compact layouts where element_space_size == product of dimensions
split_k_offset_a_hack_ = k_batch_ > 1 && can_divide_n_spatial_by_k_batch &&
is_k_not_paded && is_correct_layout && is_a_stride_divisible &&
is_a_compact;
split_k_offset_b_hack_ = k_batch_ > 1 && can_divide_n_by_k_batch && is_k_not_paded &&
is_correct_layout && is_b_stride_divisible && is_b_compact;
// Now create descriptors with the correct hack flags
const auto descs =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -622,9 +675,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
b_grid_desc_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
// Calculate stride from descriptor size
// With use_full_batch_kindex=true, descriptors contain full k-batch dimension
// so we divide by k_batch_ to get per-batch stride
// Calculate stride using CalculateOffset method for accurate stride
// This works correctly for any descriptor transform pipeline
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
if(split_k_offset_a_hack_)
split_k_stride_a_ /= k_batch_;

View File

@@ -693,16 +693,17 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
const long_index_t split_k_offset_b =
split_k_offset_b_hack ? k_batch_id * split_k_stride_b : 0;
// When hack is enabled, use GetElementSpaceSize() divided by k_batch for buffer size.
// This matches the stride calculation in the device layer and correctly accounts for
// the memory layout encoded in GetElementSpaceSize().
// When hack is enabled, buffer size equals the stride (calculated from descriptor's
// CalculateOffset method in the device layer). This properly accounts for the
// descriptor's transform pipeline and non-compact strides.
// When hack is disabled, use the full element space size.
const long_index_t a_buffer_size =
split_k_offset_a_hack ? (a_b_k0_m_k1_grid_desc.GetElementSpaceSize() / k_batch)
: a_b_k0_m_k1_grid_desc.GetElementSpaceSize();
split_k_offset_a_hack ? split_k_stride_a : a_b_k0_m_k1_grid_desc.GetElementSpaceSize();
const long_index_t b_buffer_size =
split_k_offset_b_hack ? (b_b_k0_n_k1_grid_desc.GetElementSpaceSize() / k_batch)
: b_b_k0_n_k1_grid_desc.GetElementSpaceSize();
split_k_offset_b_hack ? split_k_stride_b : b_b_k0_n_k1_grid_desc.GetElementSpaceSize();
ignore = k_batch; // k_batch value itself not used in this function
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid + split_k_offset_a, a_buffer_size);

View File

@@ -175,9 +175,10 @@ struct TransformConvBwdWeightToGemm
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchDimA = split_k_offset_a_hack ? 1 : GemmKBatch;
const index_t KBatchDimB = split_k_offset_b_hack ? 1 : GemmKBatch;
const index_t GemmKPadA = KBatchDimA * GemmK0 * GemmK1Number;
const index_t GemmKPadB = KBatchDimB * GemmK0 * GemmK1Number;
if constexpr(ConvBackwardWeightSpecialization ==
device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
@@ -188,7 +189,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPadA - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -206,7 +207,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPadB - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -244,7 +245,7 @@ struct TransformConvBwdWeightToGemm
// A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPadA - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -283,7 +284,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPadB - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -366,9 +367,10 @@ struct TransformConvBwdWeightToGemm
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchDimA = split_k_offset_a_hack ? 1 : GemmKBatch;
const index_t KBatchDimB = split_k_offset_b_hack ? 1 : GemmKBatch;
const index_t GemmKPadA = KBatchDimA * GemmK0 * GemmK1Number;
const index_t GemmKPadB = KBatchDimB * GemmK0 * GemmK1Number;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
@@ -380,7 +382,7 @@ struct TransformConvBwdWeightToGemm
// A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPadA - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -395,7 +397,7 @@ struct TransformConvBwdWeightToGemm
// B: input tensor
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPadB - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -424,7 +426,7 @@ struct TransformConvBwdWeightToGemm
// A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPadA - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -465,7 +467,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPadB - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -551,9 +553,10 @@ struct TransformConvBwdWeightToGemm
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const index_t KBatchDimA = split_k_offset_a_hack ? 1 : GemmKBatch;
const index_t KBatchDimB = split_k_offset_b_hack ? 1 : GemmKBatch;
const index_t GemmKPadA = KBatchDimA * GemmK0 * GemmK1Number;
const index_t GemmKPadB = KBatchDimB * GemmK0 * GemmK1Number;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
@@ -565,7 +568,7 @@ struct TransformConvBwdWeightToGemm
// A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPadA - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -580,7 +583,7 @@ struct TransformConvBwdWeightToGemm
// B: input tensor
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPadB - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -609,7 +612,7 @@ struct TransformConvBwdWeightToGemm
// A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPadA - GemmKTotal),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -659,7 +662,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPadB - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));

View File

@@ -356,13 +356,14 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
// When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise
// kernel compatibility
const index_t KBatchDimA =
(split_k_offset_a_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
const index_t KBatchDimB =
(split_k_offset_b_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
const index_t GemmKPadA = KBatchDimA * GemmK0 * GemmK1Number;
const index_t GemmKPadB = KBatchDimB * GemmK0 * GemmK1Number;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Wi, C, input_strides);
@@ -375,7 +376,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_right_pad_transform(GemmKTotal, GemmKPadA - GemmKTotal),
make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -391,7 +392,7 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_right_pad_transform(GemmKTotal, GemmKPadB - GemmKTotal),
make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -421,7 +422,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_right_pad_transform(GemmKTotal, GemmKPadA - GemmKTotal),
make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -462,7 +463,7 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPadB - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -543,13 +544,14 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
// When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise
// kernel compatibility
const index_t KBatchIndexA =
const index_t KBatchDimA =
(split_k_offset_a_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
const index_t KBatchIndexB =
const index_t KBatchDimB =
(split_k_offset_b_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
const index_t GemmKPadA = KBatchDimA * GemmK0 * GemmK1Number;
const index_t GemmKPadB = KBatchDimB * GemmK0 * GemmK1Number;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
@@ -562,14 +564,14 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_right_pad_transform(GemmKTotal, GemmKPadA - GemmKTotal),
make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimA * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -578,14 +580,14 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_right_pad_transform(GemmKTotal, GemmKPadB - GemmKTotal),
make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimB * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -608,14 +610,14 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_right_pad_transform(GemmKTotal, GemmKPadA - GemmKTotal),
make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimA * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -658,14 +660,14 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPadB - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)),
make_tuple(make_unmerge_transform(make_tuple(KBatchDimB * GemmK0, GemmK1Number)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -745,13 +747,14 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
// When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise
// kernel compatibility
const index_t KBatchDimA =
(split_k_offset_a_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
const index_t KBatchDimB =
(split_k_offset_b_hack && !use_full_batch_kindex) ? 1 : GemmKBatch;
const index_t GemmKPadA = KBatchIndexA * GemmK0 * GemmK1Number;
const index_t GemmKPadB = KBatchIndexB * GemmK0 * GemmK1Number;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
@@ -764,7 +767,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_right_pad_transform(GemmKTotal, GemmKPadA - GemmKTotal),
make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -780,7 +783,7 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_right_pad_transform(GemmKTotal, GemmKPadB - GemmKTotal),
make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -810,7 +813,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_right_pad_transform(GemmKTotal, GemmKPadA - GemmKTotal),
make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -875,7 +878,7 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPadB - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));