Add CTranspose optimization for NCHW cases just like in xdl cshuffle non-v3 device implementation.

This commit is contained in:
kiefer
2025-08-24 12:44:01 +00:00
parent b53c584eb9
commit 6ad73cd0cd
2 changed files with 364 additions and 420 deletions

View File

@@ -704,6 +704,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
const long_index_t b_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetBPtrOffset(n_idx));
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
@@ -717,7 +719,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
BsGridPointer p_bs_grid_;
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
p_bs_grid_(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset;
p_bs_grid_(i) =
static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset;
});
DsGridPointer p_ds_grid_grp;