[CK_TILE] MX FLATMM Fix M Padding (#3489)

* Fix M Padding

* Fix tensor descriptor element space size
This commit is contained in:
Yi DING
2025-12-29 09:09:12 +08:00
committed by GitHub
parent a3916a8d16
commit b0ea67e377
2 changed files with 6 additions and 4 deletions

View File

@@ -517,7 +517,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
"wrong!");
// constexpr auto MIter_2nd_last = max(0, MIterPerWarp - 2);
static_assert(NWarp == 4);
static_assert(MWarp == 1);
using CWarpTensor = typename WG::CWarpTensor;

View File

@@ -113,11 +113,13 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
const auto col_lens = make_tuple(K0, number<K1>{}, number<K2>{});
constexpr index_t M1 = 4; // so that we can use imm offset to load lds
const index_t M0 = rows / M1;
const index_t M0 = integer_divide_ceil(rows, M1);
const auto row_lens = make_tuple(M0, number<M1>{});
const auto desc_0 =
make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
const auto d0 = make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
const auto desc_0 = decltype(d0)( // set correct size (without padding)
d0.get_transforms(),
tensor_view_tmp.get_tensor_descriptor().get_element_space_size());
const auto desc_1 = transform_tensor_descriptor(
desc_0,
make_tuple(make_pass_through_transform(M0),