mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
use PackedSize in slicing
This commit is contained in:
@@ -283,6 +283,7 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
|
||||
// Mutable subscript by compile-time index: returns get<I>() for the index carried by number<I>.
// decltype(auto) preserves the exact type (including reference-ness) that get<I>() yields.
// NOTE(review): TP_COM_() is a macro defined outside this view; its effect is not visible here.
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) { TP_COM_(); return get<I>(); }
|
||||
// Const subscript by compile-time index: returns get<I>() on a const tuple for the index carried by number<I>.
// decltype(auto) preserves the exact type (including reference-ness) that get<I>() yields.
// NOTE(review): TP_COM_() is a macro defined outside this view; its effect is not visible here.
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) const { TP_COM_(); return get<I>(); }
|
||||
// Call-syntax element access (mutable): t(number<I>{}) returns get<I>(), the same body as the subscript form.
// decltype(auto) preserves the exact type (including reference-ness) that get<I>() yields.
// NOTE(review): TP_COM_() is a macro defined outside this view; its effect is not visible here.
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(number<I>) { TP_COM_(); return get<I>(); } // TODO: compatible
|
||||
// Call-syntax element access (const): t(number<I>{}) returns get<I>() on a const tuple.
// decltype(auto) preserves the exact type (including reference-ness) that get<I>() yields.
// NOTE(review): TP_COM_() is a macro defined outside this view; its effect is not visible here.
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(number<I>) const { TP_COM_(); return get<I>(); }
|
||||
|
||||
// The functions below should only be used on tuple_array<> types; no extra check is performed here.
|
||||
// Returns *this reinterpreted as a reference to tuple_array<Tx, size()>.
// The cast is an unchecked reinterpret_cast: nothing here verifies that this tuple's
// element types or layout match tuple_array<Tx, size()>.
// NOTE(review): presumably the caller must guarantee a compatible layout — confirm.
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() { return reinterpret_cast<tuple_array<Tx, size()>&>(*this); }
|
||||
|
||||
@@ -75,7 +75,9 @@ struct static_distributed_tensor
|
||||
constexpr auto sliced_thread_tensor_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));
|
||||
|
||||
thread_buffer<DataType, sliced_thread_tensor_desc.get_element_space_size()>
|
||||
// divide element number by PackedSize to get the correct thread buffer size
|
||||
// TODO: verify that dividing the element space size by PackedSize yields the correct buffer size
|
||||
thread_buffer<DataType, sliced_thread_tensor_desc.get_element_space_size() / PackedSize>
|
||||
sliced_thread_data;
|
||||
|
||||
static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
|
||||
|
||||
@@ -322,8 +322,8 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
|| ScaleN::GranularityMN == -1, // or ScaleB is disable
|
||||
"ScaleM and ScaleN should have the same GranularityK");
|
||||
|
||||
const auto& c_block_tile = MXGemmPipeline{}(a_block_window,
|
||||
b_block_window,
|
||||
const auto& c_block_tile = MXGemmPipeline{}(a_block_window[number<0>{}],
|
||||
b_block_window[number<0>{}],
|
||||
scale_a_block_window,
|
||||
scale_b_block_window,
|
||||
num_loop,
|
||||
|
||||
Reference in New Issue
Block a user