use PackedSize in slicing

This commit is contained in:
Sami Remes
2026-01-27 13:01:06 -05:00
parent 08ec1f4192
commit 30d4c25d5a
3 changed files with 6 additions and 3 deletions

View File

@@ -283,6 +283,7 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) const { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(number<I>) { TP_COM_(); return get<I>(); } // TODO: compatible
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(number<I>) const { TP_COM_(); return get<I>(); }
// below function should be used under tuple_array<> type, no extra check will perform here
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() { return reinterpret_cast<tuple_array<Tx, size()>&>(*this); }

View File

@@ -75,7 +75,9 @@ struct static_distributed_tensor
constexpr auto sliced_thread_tensor_desc =
make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));
thread_buffer<DataType, sliced_thread_tensor_desc.get_element_space_size()>
// divide element number by PackedSize to get the correct thread buffer size
/// TODO: check if this is correct
thread_buffer<DataType, sliced_thread_tensor_desc.get_element_space_size() / PackedSize>
sliced_thread_data;
static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {

View File

@@ -322,8 +322,8 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|| ScaleN::GranularityMN == -1, // or ScaleB is disable
"ScaleM and ScaleN should have the same GranularityK");
const auto& c_block_tile = MXGemmPipeline{}(a_block_window,
b_block_window,
const auto& c_block_tile = MXGemmPipeline{}(a_block_window[number<0>{}],
b_block_window[number<0>{}],
scale_a_block_window,
scale_b_block_window,
num_loop,