diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp index 7f8176d5ec..4329d590b8 100644 --- a/include/ck_tile/core/container/tuple.hpp +++ b/include/ck_tile/core/container/tuple.hpp @@ -283,6 +283,7 @@ struct tuple : impl::tuple_base, T...> template CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number) { TP_COM_(); return get(); } template CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number) const { TP_COM_(); return get(); } template CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(number) { TP_COM_(); return get(); } // TODO: compatible + template CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(number) const { TP_COM_(); return get(); } // below function should be used under tuple_array<> type, no extra check will perform here template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() { return reinterpret_cast&>(*this); } diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp index 10c7587bcb..1994f345c0 100644 --- a/include/ck_tile/core/tensor/static_distributed_tensor.hpp +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -75,7 +75,9 @@ struct static_distributed_tensor constexpr auto sliced_thread_tensor_desc = make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...)); - thread_buffer + // divide element number by PackedSize to get the correct thread buffer size + /// TODO: check if this is correct + thread_buffer sliced_thread_data; static_ford>{}([&](auto idx) { diff --git a/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp b/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp index 24173a89dd..97e26e756f 100644 --- a/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp +++ b/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp @@ -322,8 +322,8 @@ struct MXGemmKernel : UniversalGemmKernel{}], + b_block_window[number<0>{}], scale_a_block_window, scale_b_block_window, num_loop,