mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Disable hack if stride not divisible by k_batch
This commit is contained in:
@@ -759,6 +759,48 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
input_right_pads,
|
||||
k_batch_)[I2];
|
||||
|
||||
// Extract complex conditions to named boolean variables for clarity
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
const bool is_k_not_paded =
|
||||
(Conv_N_ * output_spatial_acum) % (KPerBlock * k_batch_) == 0;
|
||||
|
||||
const bool can_divide_n_spatial_by_k_batch =
|
||||
(Conv_N_ * output_spatial_acum) % k_batch_ == 0;
|
||||
|
||||
const bool can_divide_n_by_k_batch = Conv_N_ % k_batch_ == 0;
|
||||
|
||||
const bool is_correct_layout =
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
|
||||
|
||||
// NEW: Check if descriptor element space sizes are divisible by k_batch_
|
||||
// This prevents integer division truncation when calculating strides
|
||||
const bool is_a_stride_divisible =
|
||||
a_grid_desc_k0_m_k1_.GetElementSpaceSize() % k_batch_ == 0;
|
||||
|
||||
const bool is_b_stride_divisible =
|
||||
b_grid_desc_k0_n_k1_.GetElementSpaceSize() % k_batch_ == 0;
|
||||
|
||||
// Determine if we can safely use the split-k offset hack
|
||||
split_k_offset_a_hack_ = k_batch_ > 1 && can_divide_n_spatial_by_k_batch &&
|
||||
is_k_not_paded && is_correct_layout && is_a_stride_divisible;
|
||||
|
||||
split_k_offset_b_hack_ = k_batch_ > 1 && can_divide_n_by_k_batch && is_k_not_paded &&
|
||||
is_correct_layout && is_b_stride_divisible;
|
||||
|
||||
// Calculate stride from descriptor size
|
||||
// NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1,
|
||||
// so we need to divide by k_batch_ to get the per-batch stride when the hack is
|
||||
// enabled. The divisibility check above ensures no truncation occurs.
|
||||
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
|
||||
if(split_k_offset_a_hack_)
|
||||
split_k_stride_a_ /= k_batch_;
|
||||
|
||||
split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize();
|
||||
if(split_k_offset_b_hack_)
|
||||
split_k_stride_b_ /= k_batch_;
|
||||
|
||||
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
|
||||
|
||||
@@ -810,30 +852,6 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
e_in_transpose_desc_.GetLength(I1)}
|
||||
: Block2TileMapElementwise{ce_grid_desc_m_n_.GetLength(I0),
|
||||
ce_grid_desc_m_n_.GetLength(I1)};
|
||||
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
const bool is_k_not_paded =
|
||||
(Conv_N_ * output_spatial_acum) % (KPerBlock * k_batch_) == 0;
|
||||
// Check if there is KPading and we can divide N * OutSpatialDims by k_batch
|
||||
split_k_offset_a_hack_ =
|
||||
k_batch_ > 1 && (Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
|
||||
// Check if there is KPading and we can divide N by k_batch
|
||||
split_k_offset_b_hack_ =
|
||||
k_batch_ > 1 && Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
|
||||
|
||||
// Calculate stride from descriptor size
|
||||
// NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1,
|
||||
// so we need to divide by k_batch_ to get the per-batch stride when the hack is enabled
|
||||
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
|
||||
if(split_k_offset_a_hack_)
|
||||
split_k_stride_a_ /= k_batch_;
|
||||
|
||||
split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize();
|
||||
if(split_k_offset_b_hack_)
|
||||
split_k_stride_b_ /= k_batch_;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
|
||||
@@ -590,19 +590,59 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
|
||||
// Create descriptors first (with hack flags temporarily set to false)
|
||||
// so we can check if element space sizes are divisible by k_batch
|
||||
const auto descs_initial =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
Conv_N_,
|
||||
Conv_K_,
|
||||
Conv_C_,
|
||||
input_spatial_lengths_,
|
||||
filter_spatial_lengths_,
|
||||
output_spatial_lengths_,
|
||||
b_g_n_c_wis_strides_transposed,
|
||||
e_g_k_c_xs_strides_transposed,
|
||||
a_g_n_k_wos_strides_transposed,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
k_batch_,
|
||||
false, // split_k_offset_a_hack (temporary)
|
||||
false); // split_k_offset_b_hack (temporary)
|
||||
|
||||
// Extract complex conditions to named boolean variables for clarity
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
const bool is_k_not_paded =
|
||||
(Conv_N_ * output_spatial_acum) % (K0PerBlock * K1 * k_batch_) == 0;
|
||||
// Check if there is KPading and we can divide N * OutSpatialDims by k_batch
|
||||
split_k_offset_a_hack_ =
|
||||
k_batch_ > 1 && (Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
|
||||
// Check if there is KPading and we can divide N by k_batch
|
||||
split_k_offset_b_hack_ =
|
||||
k_batch_ > 1 && Conv_N_ % k_batch_ == 0 && is_k_not_paded &&
|
||||
|
||||
const bool can_divide_n_spatial_by_k_batch =
|
||||
(Conv_N_ * output_spatial_acum) % k_batch_ == 0;
|
||||
|
||||
const bool can_divide_n_by_k_batch = Conv_N_ % k_batch_ == 0;
|
||||
|
||||
const bool is_correct_layout =
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<InLayout, WeiLayout, OutLayout>();
|
||||
|
||||
// NEW: Check if descriptor element space sizes are divisible by k_batch_
|
||||
// This prevents integer division truncation when calculating strides
|
||||
const bool is_a_stride_divisible =
|
||||
descs_initial[I0].GetElementSpaceSize() % k_batch_ == 0;
|
||||
|
||||
const bool is_b_stride_divisible =
|
||||
descs_initial[I1].GetElementSpaceSize() % k_batch_ == 0;
|
||||
|
||||
// Determine if we can safely use the split-k offset hack
|
||||
split_k_offset_a_hack_ = k_batch_ > 1 && can_divide_n_spatial_by_k_batch &&
|
||||
is_k_not_paded && is_correct_layout && is_a_stride_divisible;
|
||||
|
||||
split_k_offset_b_hack_ = k_batch_ > 1 && can_divide_n_by_k_batch && is_k_not_paded &&
|
||||
is_correct_layout && is_b_stride_divisible;
|
||||
|
||||
// Now create descriptors with the correct hack flags
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
@@ -629,7 +669,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
|
||||
// Calculate stride from descriptor size
|
||||
// NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1,
|
||||
// so we need to divide by k_batch_ to get the per-batch stride when the hack is enabled
|
||||
// so we need to divide by k_batch_ to get the per-batch stride when the hack is
|
||||
// enabled. The divisibility check above ensures no truncation occurs.
|
||||
split_k_stride_a_ = a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize();
|
||||
if(split_k_offset_a_hack_)
|
||||
split_k_stride_a_ /= k_batch_;
|
||||
|
||||
Reference in New Issue
Block a user