diff --git a/include/ck_tile/core/tensor/transpose_tile.hpp b/include/ck_tile/core/tensor/transpose_tile.hpp index f34efe5c2f..5b65b79c1a 100644 --- a/include/ck_tile/core/tensor/transpose_tile.hpp +++ b/include/ck_tile/core/tensor/transpose_tile.hpp @@ -83,9 +83,6 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor, constexpr index_t num_vec_in = vec_length_out; constexpr index_t num_vec_out = vec_length_in; - using InVec = array; - using OutVec = array; - // SFC constexpr auto scalars_per_access_arr = generate_array( [&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; }, @@ -101,51 +98,84 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor, static_assert(num_access > 0, "wrong! num_access should be larger than 0"); - // in/out vectors to be transposed - thread_buffer in_vectors; - thread_buffer out_vectors; + if constexpr(num_vec_in == 1 || num_vec_out == 1) + { + // loop over SFC + static_for<0, num_access, 1>{}([&](auto iAccess) { + // data index [y0, y1, ...] in the order of input tensor + constexpr auto idx_y = SFC_Y::get_index(iAccess); - // loop over SFC and do transpose - static_for<0, num_access, 1>{}([&](auto iAccess) { - // data index [y0, y1, ...] in the order of input tensor - constexpr auto idx_y_start = SFC_Y::get_index(iAccess); + constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y); + constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y); - // get input vectors - static_for<0, num_vec_in, 1>{}([&](auto i) { - constexpr auto idx_y_in = generate_tuple( - [&](auto ii) { - return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii]; - }, - number{}); - - constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in); - static_assert(in_offset % vec_length_in == 0); - - in_vectors(i).template get_as()(I0) = - in_tensor.get_thread_buffer() - .template get_as()[number{}]; + if constexpr(vec_length_in == 1) + { + out_tensor.get_thread_buffer()[number{}] = + in_tensor.get_thread_buffer()[number{}]; + } + else + { + using Vec = array; + out_tensor.get_thread_buffer().template get_as( + number{}) = + in_tensor.get_thread_buffer().template get_as( + number{}); + } }); + } + else + { + using InVec = array; + using OutVec = array; - // transpose - transpose_vectors{}(in_vectors, out_vectors); + // in/out vectors to be transposed + thread_buffer in_vectors; + thread_buffer out_vectors; - // set output vectors - static_for<0, num_vec_out, 1>{}([&](auto i) { - constexpr auto idx_y_out_tmp = generate_array( - [&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; }, - number{}); + // loop over SFC and do transpose + static_for<0, num_access, 1>{}([&](auto iAccess) { + // data index [y0, y1, ...] in the order of input tensor + constexpr auto idx_y_start = SFC_Y::get_index(iAccess); - constexpr auto idx_y_out = - container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in); + // get input vectors + static_for<0, num_vec_in, 1>{}([&](auto i) { + constexpr auto idx_y_in = generate_tuple( + [&](auto ii) { + return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii]; + }, + number{}); - constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out); - static_assert(out_offset % vec_length_out == 0); + constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in); + static_assert(in_offset % vec_length_in == 0); - out_tensor.get_thread_buffer().template set_as( - number{}, - out_vectors[i].template get_as()[I0]); + in_vectors(i).template get_as()(I0) = + in_tensor.get_thread_buffer() + .template get_as()[number{}]; + }); + + // transpose + transpose_vectors{}(in_vectors, out_vectors); + + // set output vectors + static_for<0, num_vec_out, 1>{}([&](auto i) { + constexpr auto idx_y_out_tmp = generate_array( + [&](auto ii) { + return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; + }, + number{}); + + constexpr auto idx_y_out = + container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in); + + constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out); + static_assert(out_offset % vec_length_out == 0); + + out_tensor.get_thread_buffer().template set_as( + number{}, + out_vectors[i].template get_as()[I0]); + }); }); - }); + } } } // namespace detail