diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 125d32aad8..bc7d2323d0 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -517,7 +517,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, number{}); constexpr index_t M1 = 4; // so that we can use imm offset to load lds - const index_t M0 = rows / M1; + const index_t M0 = integer_divide_ceil(rows, M1); const auto row_lens = make_tuple(M0, number{}); - const auto desc_0 = - make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens)); + const auto d0 = make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens)); + const auto desc_0 = decltype(d0)( // set correct size (without padding) + d0.get_transforms(), + tensor_view_tmp.get_tensor_descriptor().get_element_space_size()); const auto desc_1 = transform_tensor_descriptor( desc_0, make_tuple(make_pass_through_transform(M0),