mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Candidate fix 11
This commit is contained in:
@@ -567,157 +567,61 @@ CK_TILE_DEVICE void load_tile_transpose_convert_with_offset(
|
||||
index_t offset)
|
||||
{
|
||||
using OutputDataType = typename DistributedTensor_::DataType;
|
||||
using OutTensor = remove_cvref_t<DistributedTensor_>;
|
||||
|
||||
auto trans_tensor = tile_window.template load_transpose_with_offset<Policy>(offset);
|
||||
|
||||
using InTensor = remove_cvref_t<decltype(trans_tensor)>;
|
||||
using OutTensor = remove_cvref_t<DistributedTensor_>;
|
||||
using InDstrEncode = typename InTensor::StaticTileDistribution::DstrEncode;
|
||||
using OutDstrEncode = typename OutTensor::StaticTileDistribution::DstrEncode;
|
||||
using InTensor = remove_cvref_t<decltype(trans_tensor)>;
|
||||
|
||||
constexpr auto input_distr = typename InTensor::StaticTileDistribution{};
|
||||
constexpr auto output_distr = typename OutTensor::StaticTileDistribution{};
|
||||
|
||||
constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor();
|
||||
constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();
|
||||
constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor();
|
||||
constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();
|
||||
constexpr auto y_in_lengths = to_sequence(y_in_desc.get_lengths());
|
||||
constexpr auto y_out_lengths = to_sequence(y_out_desc.get_lengths());
|
||||
|
||||
constexpr index_t NDimYIn = input_distr.get_num_of_dimension_y();
|
||||
constexpr index_t NDimYOut = output_distr.get_num_of_dimension_y();
|
||||
|
||||
static_assert(NDimYIn == NDimYOut,
|
||||
"Mixed precision transpose conversion requires same Y rank.");
|
||||
constexpr index_t total_elems_in = reduce_on_sequence(y_in_lengths, multiplies<>{}, number<1>{});
|
||||
constexpr index_t total_elems_out = reduce_on_sequence(y_out_lengths, multiplies<>{}, number<1>{});
|
||||
|
||||
constexpr auto y_in_lengths = to_sequence(y_in_desc.get_lengths());
|
||||
constexpr auto y_out_lengths = to_sequence(y_out_desc.get_lengths());
|
||||
|
||||
constexpr auto y_in_element_space_size = y_in_desc.get_element_space_size();
|
||||
constexpr auto y_out_element_space_size = y_out_desc.get_element_space_size();
|
||||
|
||||
// For mixed precision: element space size must be the same (total bytes match)
|
||||
static_assert(y_in_element_space_size == y_out_element_space_size,
|
||||
"For mixed precision transpose, input and output element space size must match!");
|
||||
|
||||
constexpr index_t total_elems_in =
|
||||
reduce_on_sequence(y_in_lengths, multiplies<>{}, number<1>{});
|
||||
constexpr index_t total_elems_out =
|
||||
reduce_on_sequence(y_out_lengths, multiplies<>{}, number<1>{});
|
||||
static_assert(total_elems_in == total_elems_out,
|
||||
"For mixed precision transpose, input/output element counts must match!");
|
||||
|
||||
using InDimAccessOrderY = typename arithmetic_sequence_gen<0, NDimYIn, 1>::type;
|
||||
using OutDimAccessOrderY = typename arithmetic_sequence_gen<0, NDimYOut, 1>::type;
|
||||
using InScalarsPerElemY = typename uniform_sequence_gen<NDimYIn, 1>::type;
|
||||
using OutScalarsPerElemY = typename uniform_sequence_gen<NDimYOut, 1>::type;
|
||||
|
||||
using InSFC_Y =
|
||||
space_filling_curve<decltype(y_in_lengths), InDimAccessOrderY, InScalarsPerElemY, false>;
|
||||
using OutSFC_Y = space_filling_curve<decltype(y_out_lengths),
|
||||
OutDimAccessOrderY,
|
||||
OutScalarsPerElemY,
|
||||
false>;
|
||||
|
||||
static_assert(InSFC_Y::get_num_of_access() == total_elems_in,
|
||||
"Unexpected input SFC access count for mixed precision transpose.");
|
||||
static_assert(OutSFC_Y::get_num_of_access() == total_elems_out,
|
||||
"Unexpected output SFC access count for mixed precision transpose.");
|
||||
|
||||
constexpr auto in_ys_rhs_major = InDstrEncode::ys_to_rhs_major_;
|
||||
constexpr auto in_ys_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
|
||||
constexpr auto out_ys_rhs_major = OutDstrEncode::ys_to_rhs_major_;
|
||||
constexpr auto out_ys_rhs_minor = OutDstrEncode::ys_to_rhs_minor_;
|
||||
|
||||
constexpr auto in_rank = [&] {
|
||||
array<index_t, NDimYIn> rank{};
|
||||
static_for<0, NDimYIn, 1>{}([&](auto d) {
|
||||
index_t r = 0;
|
||||
static_for<0, NDimYIn, 1>{}([&](auto e) {
|
||||
constexpr bool less =
|
||||
(in_ys_rhs_major[e] < in_ys_rhs_major[d]) ||
|
||||
((in_ys_rhs_major[e] == in_ys_rhs_major[d]) &&
|
||||
(in_ys_rhs_minor[e] < in_ys_rhs_minor[d])) ||
|
||||
((in_ys_rhs_major[e] == in_ys_rhs_major[d]) &&
|
||||
(in_ys_rhs_minor[e] == in_ys_rhs_minor[d]) && (e < d));
|
||||
if constexpr(less)
|
||||
{
|
||||
++r;
|
||||
}
|
||||
});
|
||||
rank(d) = r;
|
||||
});
|
||||
return rank;
|
||||
}();
|
||||
|
||||
constexpr auto out_rank = [&] {
|
||||
array<index_t, NDimYOut> rank{};
|
||||
static_for<0, NDimYOut, 1>{}([&](auto d) {
|
||||
index_t r = 0;
|
||||
static_for<0, NDimYOut, 1>{}([&](auto e) {
|
||||
constexpr bool less =
|
||||
(out_ys_rhs_major[e] < out_ys_rhs_major[d]) ||
|
||||
((out_ys_rhs_major[e] == out_ys_rhs_major[d]) &&
|
||||
(out_ys_rhs_minor[e] < out_ys_rhs_minor[d])) ||
|
||||
((out_ys_rhs_major[e] == out_ys_rhs_major[d]) &&
|
||||
(out_ys_rhs_minor[e] == out_ys_rhs_minor[d]) && (e < d));
|
||||
if constexpr(less)
|
||||
{
|
||||
++r;
|
||||
}
|
||||
});
|
||||
rank(d) = r;
|
||||
});
|
||||
return rank;
|
||||
}();
|
||||
|
||||
constexpr auto in_order = [&] {
|
||||
array<index_t, NDimYIn> order{};
|
||||
static_for<0, NDimYIn, 1>{}([&](auto d) { order(in_rank[d]) = d; });
|
||||
return order;
|
||||
}();
|
||||
|
||||
constexpr auto out_order = [&] {
|
||||
array<index_t, NDimYOut> order{};
|
||||
static_for<0, NDimYOut, 1>{}([&](auto d) { order(out_rank[d]) = d; });
|
||||
return order;
|
||||
}();
|
||||
|
||||
constexpr auto y_in_lens = [&] {
|
||||
array<index_t, NDimYIn> lens{};
|
||||
static_for<0, NDimYIn, 1>{}([&](auto i) { lens(i) = y_in_lengths[i]; });
|
||||
return lens;
|
||||
}();
|
||||
|
||||
constexpr auto y_out_lens = [&] {
|
||||
array<index_t, NDimYOut> lens{};
|
||||
static_for<0, NDimYOut, 1>{}([&](auto i) { lens(i) = y_out_lengths[i]; });
|
||||
return lens;
|
||||
}();
|
||||
using OutSFC_Y =
|
||||
space_filling_curve<decltype(y_out_lengths), OutDimAccessOrderY, OutScalarsPerElemY, false>;
|
||||
|
||||
static_for<0, total_elems_out, 1>{}([&](auto i_out) {
|
||||
constexpr auto out_idx_y_seq = OutSFC_Y::get_index(i_out);
|
||||
constexpr auto idx_out = OutSFC_Y::get_index(i_out);
|
||||
constexpr index_t out_off = y_out_desc.calculate_offset(idx_out);
|
||||
|
||||
array<index_t, NDimYOut> out_idx_y{};
|
||||
static_for<0, NDimYOut, 1>{}([&](auto iy) { out_idx_y(iy) = out_idx_y_seq[iy]; });
|
||||
constexpr index_t linear = [&] {
|
||||
index_t v = 0;
|
||||
static_for<0, NDimYOut, 1>{}([&](auto d) {
|
||||
v = v * y_out_lengths[d] + idx_out[d];
|
||||
});
|
||||
return v;
|
||||
}();
|
||||
|
||||
index_t linear = 0;
|
||||
static_for<0, NDimYOut, 1>{}([&](auto k) {
|
||||
const index_t y = out_order[k];
|
||||
linear = linear * y_out_lens[y] + out_idx_y[y];
|
||||
});
|
||||
const auto idx_in = [&] {
|
||||
array<index_t, NDimYIn> idx{};
|
||||
index_t remain = linear;
|
||||
static_for<NDimYIn - 1, -1, -1>{}([&](auto d) {
|
||||
constexpr index_t len = y_in_lengths[d];
|
||||
idx(d) = remain % len;
|
||||
remain /= len;
|
||||
});
|
||||
return idx;
|
||||
}();
|
||||
|
||||
array<index_t, NDimYIn> in_idx_y{};
|
||||
index_t remain = linear;
|
||||
static_for<NDimYIn - 1, -1, -1>{}([&](auto k_rev) {
|
||||
const index_t y = in_order[k_rev];
|
||||
const index_t l = y_in_lens[y];
|
||||
in_idx_y[y] = remain % l;
|
||||
remain /= l;
|
||||
});
|
||||
const index_t in_off = y_in_desc.calculate_offset(idx_in);
|
||||
|
||||
const index_t in_offset = y_in_desc.calculate_offset(in_idx_y);
|
||||
const index_t out_offset = y_out_desc.calculate_offset(out_idx_y);
|
||||
|
||||
out_tensor.get_thread_buffer()[out_offset] =
|
||||
type_convert<OutputDataType>(trans_tensor.get_thread_buffer()[in_offset]);
|
||||
out_tensor.get_thread_buffer()[out_off] =
|
||||
type_convert<OutputDataType>(trans_tensor.get_thread_buffer()[in_off]);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user