Disable hack if stride not divisible by k_batch

This commit is contained in:
Graner, Johannes
2025-11-24 07:20:38 -05:00
parent 350162728d
commit 0e8ca436f3
2 changed files with 91 additions and 32 deletions

View File

@@ -759,6 +759,48 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
input_right_pads,
k_batch_)[I2];
// Extract complex conditions to named boolean variables for clarity
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const bool is_k_not_paded =
(Conv_N_ * output_spatial_acum) % (KPerBlock * k_batch_) == 0;
const bool can_divide_n_spatial_by_k_batch =
(Conv_N_ * output_spatial_acum) % k_batch_ == 0;
const bool can_divide_n_by_k_batch = Conv_N_ % k_batch_ == 0;
const bool is_correct_layout =
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
// NEW: Check if descriptor element space sizes are divisible by k_batch_
// This prevents integer division truncation when calculating strides
const bool is_a_stride_divisible =
a_grid_desc_k0_m_k1_.GetElementSpaceSize() % k_batch_ == 0;
const bool is_b_stride_divisible =
b_grid_desc_k0_n_k1_.GetElementSpaceSize() % k_batch_ == 0;
// Determine if we can safely use the split-k offset hack
split_k_offset_a_hack_ = k_batch_ > 1 && can_divide_n_spatial_by_k_batch &&
is_k_not_paded && is_correct_layout && is_a_stride_divisible;
split_k_offset_b_hack_ = k_batch_ > 1 && can_divide_n_by_k_batch && is_k_not_paded &&
is_correct_layout && is_b_stride_divisible;
// Calculate stride from descriptor size
// NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1,
// so we need to divide by k_batch_ to get the per-batch stride when the hack is
// enabled. The divisibility check above ensures no truncation occurs.
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
if(split_k_offset_a_hack_)
split_k_stride_a_ /= k_batch_;
split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize();
if(split_k_offset_b_hack_)
split_k_stride_b_ /= k_batch_;
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
@@ -810,30 +852,6 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
e_in_transpose_desc_.GetLength(I1)}
: Block2TileMapElementwise{ce_grid_desc_m_n_.GetLength(I0),
ce_grid_desc_m_n_.GetLength(I1)};
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const bool is_k_not_paded =
(Conv_N_ * output_spatial_acum) % (KPerBlock * k_batch_) == 0;
// Check if there is KPading and we can divide N * OutSpatialDims by k_batch
split_k_offset_a_hack_ =
k_batch_ > 1 && (Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
// Check if there is KPading and we can divide N by k_batch
split_k_offset_b_hack_ =
k_batch_ > 1 && Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
// Calculate stride from descriptor size
// NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1,
// so we need to divide by k_batch_ to get the per-batch stride when the hack is enabled
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
if(split_k_offset_a_hack_)
split_k_stride_a_ /= k_batch_;
split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize();
if(split_k_offset_b_hack_)
split_k_stride_b_ /= k_batch_;
}
std::size_t GetWorkspaceATensorSizeBytes() const

View File

@@ -590,19 +590,59 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
k_batch_ = split_k;
}
// Create descriptors first (with hack flags temporarily set to false)
// so we can check if element space sizes are divisible by k_batch
const auto descs_initial =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides_transposed,
e_g_k_c_xs_strides_transposed,
a_g_n_k_wos_strides_transposed,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_,
false, // split_k_offset_a_hack (temporary)
false); // split_k_offset_b_hack (temporary)
// Extract complex conditions to named boolean variables for clarity
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const bool is_k_not_paded =
(Conv_N_ * output_spatial_acum) % (K0PerBlock * K1 * k_batch_) == 0;
// Check if there is KPading and we can divide N * OutSpatialDims by k_batch
split_k_offset_a_hack_ =
k_batch_ > 1 && (Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
// Check if there is KPading and we can divide N by k_batch
split_k_offset_b_hack_ =
k_batch_ > 1 && Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
const bool can_divide_n_spatial_by_k_batch =
(Conv_N_ * output_spatial_acum) % k_batch_ == 0;
const bool can_divide_n_by_k_batch = Conv_N_ % k_batch_ == 0;
const bool is_correct_layout =
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
// NEW: Check if descriptor element space sizes are divisible by k_batch_
// This prevents integer division truncation when calculating strides
const bool is_a_stride_divisible =
descs_initial[I0].GetElementSpaceSize() % k_batch_ == 0;
const bool is_b_stride_divisible =
descs_initial[I1].GetElementSpaceSize() % k_batch_ == 0;
// Determine if we can safely use the split-k offset hack
split_k_offset_a_hack_ = k_batch_ > 1 && can_divide_n_spatial_by_k_batch &&
is_k_not_paded && is_correct_layout && is_a_stride_divisible;
split_k_offset_b_hack_ = k_batch_ > 1 && can_divide_n_by_k_batch && is_k_not_paded &&
is_correct_layout && is_b_stride_divisible;
// Now create descriptors with the correct hack flags
const auto descs =
conv_to_gemm_transformer
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
@@ -629,7 +669,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
// Calculate stride from descriptor size
// NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1,
// so we need to divide by k_batch_ to get the per-batch stride when the hack is enabled
// so we need to divide by k_batch_ to get the per-batch stride when the hack is
// enabled. The divisibility check above ensures no truncation occurs.
split_k_stride_a_ = a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize();
if(split_k_offset_a_hack_)
split_k_stride_a_ /= k_batch_;