mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
[CK_TILE] MX FLATMM Fix M Padding (#3489)
* Fix M Padding * Fix tensor desc ele space size
This commit is contained in:
@@ -517,7 +517,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
"wrong!");
|
||||
|
||||
// constexpr auto MIter_2nd_last = max(0, MIterPerWarp - 2);
|
||||
static_assert(NWarp == 4);
|
||||
static_assert(MWarp == 1);
|
||||
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
|
||||
@@ -113,11 +113,13 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
const auto col_lens = make_tuple(K0, number<K1>{}, number<K2>{});
|
||||
|
||||
constexpr index_t M1 = 4; // so that we can use imm offset to load lds
|
||||
const index_t M0 = rows / M1;
|
||||
const index_t M0 = integer_divide_ceil(rows, M1);
|
||||
const auto row_lens = make_tuple(M0, number<M1>{});
|
||||
|
||||
const auto desc_0 =
|
||||
make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
|
||||
const auto d0 = make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
|
||||
const auto desc_0 = decltype(d0)( // set correct size (without padding)
|
||||
d0.get_transforms(),
|
||||
tensor_view_tmp.get_tensor_descriptor().get_element_space_size());
|
||||
const auto desc_1 = transform_tensor_descriptor(
|
||||
desc_0,
|
||||
make_tuple(make_pass_through_transform(M0),
|
||||
|
||||
Reference in New Issue
Block a user