Disallow hack in non-contiguous edge cases

This commit is contained in:
Graner, Johannes
2025-11-25 13:29:37 +00:00
parent 7ad122d1c1
commit 5d1d298e2b

View File

@@ -779,24 +779,54 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
const bool is_b_stride_divisible =
b_grid_desc_k0_n_k1_.GetElementSpaceSize() % k_batch_ == 0;
// Check if descriptor is contiguous (no padding in layout)
// For a contiguous (K0, M, K1) descriptor, ElementSpaceSize should equal K0*M*K1
const long_index_t expected_size_a = static_cast<long_index_t>(
a_grid_desc_k0_m_k1_.GetLength(I0) * a_grid_desc_k0_m_k1_.GetLength(I1) *
a_grid_desc_k0_m_k1_.GetLength(I2));
const bool is_a_contiguous =
a_grid_desc_k0_m_k1_.GetElementSpaceSize() == expected_size_a;
const long_index_t expected_size_b = static_cast<long_index_t>(
b_grid_desc_k0_n_k1_.GetLength(I0) * b_grid_desc_k0_n_k1_.GetLength(I1) *
b_grid_desc_k0_n_k1_.GetLength(I2));
const bool is_b_contiguous =
b_grid_desc_k0_n_k1_.GetElementSpaceSize() == expected_size_b;
// Determine whether we can safely use the split-k offset hack.
// The hack requires a contiguous descriptor layout for correct stride calculation.
split_k_offset_a_hack_ = k_batch_ > 1 && can_divide_n_spatial_by_k_batch &&
is_k_not_paded && is_correct_layout && is_a_stride_divisible;
is_k_not_paded && is_correct_layout && is_a_stride_divisible &&
is_a_contiguous;
split_k_offset_b_hack_ = k_batch_ > 1 && can_divide_n_by_k_batch && is_k_not_paded &&
is_correct_layout && is_b_stride_divisible;
is_correct_layout && is_b_stride_divisible && is_b_contiguous;
// Calculate stride from descriptor size
// NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1,
// so we need to divide by k_batch_ to get the per-batch stride when the hack is
// enabled. The divisibility check above ensures no truncation occurs.
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
// Calculate stride for split-k offset hack
// The descriptor has shape (K0_full, M, K1) where K0_full = k_batch_ * K0_per_batch
// To advance from k_idx=0 to k_idx=1, we need to skip K0_per_batch slices
// Each K0 slice occupies M*K1 elements, so stride = (K0_full/k_batch_) * M * K1
if(split_k_offset_a_hack_)
split_k_stride_a_ /= k_batch_;
{
const index_t k0_per_batch = a_grid_desc_k0_m_k1_.GetLength(I0) / k_batch_;
split_k_stride_a_ = k0_per_batch * a_grid_desc_k0_m_k1_.GetLength(I1) *
a_grid_desc_k0_m_k1_.GetLength(I2);
}
else
{
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
}
split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize();
if(split_k_offset_b_hack_)
split_k_stride_b_ /= k_batch_;
{
const index_t k0_per_batch = b_grid_desc_k0_n_k1_.GetLength(I0) / k_batch_;
split_k_stride_b_ = k0_per_batch * b_grid_desc_k0_n_k1_.GetLength(I1) *
b_grid_desc_k0_n_k1_.GetLength(I2);
}
else
{
split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize();
}
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);