Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-06-06 15:54:31 +00:00)
Disallow the split-k offset hack in non-contiguous edge cases
This commit is contained in:
@@ -779,24 +779,54 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
const bool is_b_stride_divisible =
|
||||
b_grid_desc_k0_n_k1_.GetElementSpaceSize() % k_batch_ == 0;
|
||||
|
||||
// Check if descriptor is contiguous (no padding in layout)
|
||||
// For a contiguous (K0, M, K1) descriptor, ElementSpaceSize should equal K0*M*K1
|
||||
const long_index_t expected_size_a = static_cast<long_index_t>(
|
||||
a_grid_desc_k0_m_k1_.GetLength(I0) * a_grid_desc_k0_m_k1_.GetLength(I1) *
|
||||
a_grid_desc_k0_m_k1_.GetLength(I2));
|
||||
const bool is_a_contiguous =
|
||||
a_grid_desc_k0_m_k1_.GetElementSpaceSize() == expected_size_a;
|
||||
|
||||
const long_index_t expected_size_b = static_cast<long_index_t>(
|
||||
b_grid_desc_k0_n_k1_.GetLength(I0) * b_grid_desc_k0_n_k1_.GetLength(I1) *
|
||||
b_grid_desc_k0_n_k1_.GetLength(I2));
|
||||
const bool is_b_contiguous =
|
||||
b_grid_desc_k0_n_k1_.GetElementSpaceSize() == expected_size_b;
|
||||
|
||||
// Determine if we can safely use the split-k offset hack
|
||||
// The hack requires contiguous descriptor layout for correct stride calculation
|
||||
split_k_offset_a_hack_ = k_batch_ > 1 && can_divide_n_spatial_by_k_batch &&
|
||||
is_k_not_paded && is_correct_layout && is_a_stride_divisible;
|
||||
is_k_not_paded && is_correct_layout && is_a_stride_divisible &&
|
||||
is_a_contiguous;
|
||||
|
||||
split_k_offset_b_hack_ = k_batch_ > 1 && can_divide_n_by_k_batch && is_k_not_paded &&
|
||||
is_correct_layout && is_b_stride_divisible;
|
||||
is_correct_layout && is_b_stride_divisible && is_b_contiguous;
|
||||
|
||||
// Calculate stride from descriptor size
|
||||
// NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1,
|
||||
// so we need to divide by k_batch_ to get the per-batch stride when the hack is
|
||||
// enabled. The divisibility check above ensures no truncation occurs.
|
||||
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
|
||||
// Calculate stride for split-k offset hack
|
||||
// The descriptor has shape (K0_full, M, K1) where K0_full = k_batch_ * K0_per_batch
|
||||
// To advance from k_idx=0 to k_idx=1, we need to skip K0_per_batch slices
|
||||
// Each K0 slice occupies M*K1 elements, so stride = (K0_full/k_batch_) * M * K1
|
||||
if(split_k_offset_a_hack_)
|
||||
split_k_stride_a_ /= k_batch_;
|
||||
{
|
||||
const index_t k0_per_batch = a_grid_desc_k0_m_k1_.GetLength(I0) / k_batch_;
|
||||
split_k_stride_a_ = k0_per_batch * a_grid_desc_k0_m_k1_.GetLength(I1) *
|
||||
a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
}
|
||||
else
|
||||
{
|
||||
split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize();
|
||||
if(split_k_offset_b_hack_)
|
||||
split_k_stride_b_ /= k_batch_;
|
||||
{
|
||||
const index_t k0_per_batch = b_grid_desc_k0_n_k1_.GetLength(I0) / k_batch_;
|
||||
split_k_stride_b_ = k0_per_batch * b_grid_desc_k0_n_k1_.GetLength(I1) *
|
||||
b_grid_desc_k0_n_k1_.GetLength(I2);
|
||||
}
|
||||
else
|
||||
{
|
||||
split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
|
||||
|
||||
Reference in New Issue
Block a user