Fix buffer size calculations

This commit is contained in:
Graner, Johannes
2025-11-20 11:32:21 +00:00
parent fc86ec44f5
commit 350162728d
2 changed files with 16 additions and 7 deletions

View File

@@ -732,7 +732,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_);
k_batch_,
false, // Don't modify KBatch dimension
false, // Don't modify KBatch dimension
true); // use_full_batch_kindex: keep full KBatch*K0 dimension
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];

View File

@@ -693,15 +693,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
const long_index_t split_k_offset_b =
split_k_offset_b_hack ? k_batch_id * split_k_stride_b : 0;
const long_index_t a_space_size_divisor = split_k_offset_a_hack ? k_batch : 1;
const long_index_t b_space_size_divisor = split_k_offset_b_hack ? k_batch : 1;
// When hack is enabled, use GetElementSpaceSize() divided by k_batch for buffer size.
// This matches the stride calculation in the device layer and correctly accounts for
// the memory layout encoded in GetElementSpaceSize().
const long_index_t a_buffer_size =
split_k_offset_a_hack ? (a_b_k0_m_k1_grid_desc.GetElementSpaceSize() / k_batch)
: a_b_k0_m_k1_grid_desc.GetElementSpaceSize();
const long_index_t b_buffer_size =
split_k_offset_b_hack ? (b_b_k0_n_k1_grid_desc.GetElementSpaceSize() / k_batch)
: b_b_k0_n_k1_grid_desc.GetElementSpaceSize();
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid + split_k_offset_a,
a_b_k0_m_k1_grid_desc.GetElementSpaceSize() / a_space_size_divisor);
p_a_grid + split_k_offset_a, a_buffer_size);
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid + split_k_offset_b,
b_b_k0_n_k1_grid_desc.GetElementSpaceSize() / b_space_size_divisor);
p_b_grid + split_k_offset_b, b_buffer_size);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());