[CK_TILE] Fix GEMM Memory Pipeline (#2034)

* [CK_TILE] Fix GEMM Memory Pipeline

* Fix transpose tile

* Add comments

[ROCm/composable_kernel commit: 80aae6119b]
This commit is contained in:
aledudek
2025-04-08 12:40:04 +02:00
committed by GitHub
parent cba2a5d45d
commit 6dbaeb5fe8

View File

@@ -83,9 +83,6 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor,
constexpr index_t num_vec_in = vec_length_out;
constexpr index_t num_vec_out = vec_length_in;
using InVec = array<DataType, vec_length_in>;
using OutVec = array<DataType, vec_length_out>;
// SFC
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; },
@@ -101,51 +98,84 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor,
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
// in/out vectors to be transposed
thread_buffer<InVec, num_vec_in> in_vectors;
thread_buffer<OutVec, num_vec_out> out_vectors;
if constexpr(num_vec_in == 1 || num_vec_out == 1)
{
// loop over SFC
static_for<0, num_access, 1>{}([&](auto iAccess) {
// data index [y0, y1, ...] in the order of input tensor
constexpr auto idx_y = SFC_Y::get_index(iAccess);
// loop over SFC and do transpose
static_for<0, num_access, 1>{}([&](auto iAccess) {
// data index [y0, y1, ...] in the order of input tensor
constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y);
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y);
// get input vectors
static_for<0, num_vec_in, 1>{}([&](auto i) {
constexpr auto idx_y_in = generate_tuple(
[&](auto ii) {
return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
},
number<NDimY>{});
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
static_assert(in_offset % vec_length_in == 0);
in_vectors(i).template get_as<InVec>()(I0) =
in_tensor.get_thread_buffer()
.template get_as<InVec>()[number<in_offset / vec_length_in>{}];
if constexpr(vec_length_in == 1)
{
out_tensor.get_thread_buffer()[number<out_offset>{}] =
in_tensor.get_thread_buffer()[number<in_offset>{}];
}
else
{
using Vec = array<DataType, vec_length_in>;
out_tensor.get_thread_buffer().template get_as<Vec>(
number<out_offset / vec_length_in>{}) =
in_tensor.get_thread_buffer().template get_as<Vec>(
number<in_offset / vec_length_in>{});
}
});
}
else
{
using InVec = array<DataType, vec_length_in>;
using OutVec = array<DataType, vec_length_out>;
// transpose
transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);
// in/out vectors to be transposed
thread_buffer<InVec, num_vec_in> in_vectors;
thread_buffer<OutVec, num_vec_out> out_vectors;
// set output vectors
static_for<0, num_vec_out, 1>{}([&](auto i) {
constexpr auto idx_y_out_tmp = generate_array(
[&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
number<NDimY>{});
// loop over SFC and do transpose
static_for<0, num_access, 1>{}([&](auto iAccess) {
// data index [y0, y1, ...] in the order of input tensor
constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
constexpr auto idx_y_out =
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
// get input vectors
static_for<0, num_vec_in, 1>{}([&](auto i) {
constexpr auto idx_y_in = generate_tuple(
[&](auto ii) {
return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
},
number<NDimY>{});
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
static_assert(out_offset % vec_length_out == 0);
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
static_assert(in_offset % vec_length_in == 0);
out_tensor.get_thread_buffer().template set_as<OutVec>(
number<out_offset / vec_length_out>{},
out_vectors[i].template get_as<OutVec>()[I0]);
in_vectors(i).template get_as<InVec>()(I0) =
in_tensor.get_thread_buffer()
.template get_as<InVec>()[number<in_offset / vec_length_in>{}];
});
// transpose
transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);
// set output vectors
static_for<0, num_vec_out, 1>{}([&](auto i) {
constexpr auto idx_y_out_tmp = generate_array(
[&](auto ii) {
return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii];
},
number<NDimY>{});
constexpr auto idx_y_out =
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
static_assert(out_offset % vec_length_out == 0);
out_tensor.get_thread_buffer().template set_as<OutVec>(
number<out_offset / vec_length_out>{},
out_vectors[i].template get_as<OutVec>()[I0]);
});
});
});
}
}
} // namespace detail