mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 13:48:30 +00:00
Add CTranspose optimization for NCHW cases just like in xdl cshuffle non-v3 device implementation.
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -704,6 +704,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
const long_index_t a_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
|
||||
const long_index_t b_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetBPtrOffset(n_idx));
|
||||
const long_index_t e_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
|
||||
|
||||
@@ -717,7 +719,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
BsGridPointer p_bs_grid_;
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
p_bs_grid_(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset;
|
||||
p_bs_grid_(i) =
|
||||
static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset;
|
||||
});
|
||||
|
||||
DsGridPointer p_ds_grid_grp;
|
||||
|
||||
Reference in New Issue
Block a user