mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Fix buffer size calculations
This commit is contained in:
@@ -732,7 +732,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
k_batch_);
|
||||
k_batch_,
|
||||
false, // Don't modify KBatch dimension
|
||||
false, // Don't modify KBatch dimension
|
||||
true); // use_full_batch_kindex: keep full KBatch*K0 dimension
|
||||
|
||||
a_grid_desc_k0_m_k1_ = descs[I0];
|
||||
b_grid_desc_k0_n_k1_ = descs[I1];
|
||||
|
||||
@@ -693,15 +693,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
const long_index_t split_k_offset_b =
|
||||
split_k_offset_b_hack ? k_batch_id * split_k_stride_b : 0;
|
||||
|
||||
const long_index_t a_space_size_divisor = split_k_offset_a_hack ? k_batch : 1;
|
||||
const long_index_t b_space_size_divisor = split_k_offset_b_hack ? k_batch : 1;
|
||||
// When the split-K offset hack is enabled, the base pointer is already advanced by
|
||||
// k_batch_id * split_k_stride, so each buffer is sized as GetElementSpaceSize() / k_batch.
|
||||
// This matches the stride calculation in the device layer and accounts for the memory
|
||||
// layout encoded in GetElementSpaceSize().
|
||||
const long_index_t a_buffer_size =
|
||||
split_k_offset_a_hack ? (a_b_k0_m_k1_grid_desc.GetElementSpaceSize() / k_batch)
|
||||
: a_b_k0_m_k1_grid_desc.GetElementSpaceSize();
|
||||
|
||||
const long_index_t b_buffer_size =
|
||||
split_k_offset_b_hack ? (b_b_k0_n_k1_grid_desc.GetElementSpaceSize() / k_batch)
|
||||
: b_b_k0_n_k1_grid_desc.GetElementSpaceSize();
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid + split_k_offset_a,
|
||||
a_b_k0_m_k1_grid_desc.GetElementSpaceSize() / a_space_size_divisor);
|
||||
p_a_grid + split_k_offset_a, a_buffer_size);
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid + split_k_offset_b,
|
||||
b_b_k0_n_k1_grid_desc.GetElementSpaceSize() / b_space_size_divisor);
|
||||
p_b_grid + split_k_offset_b, b_buffer_size);
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
|
||||
Reference in New Issue
Block a user