Candidate fix 11

This commit is contained in:
Sami Aario
2026-02-23 13:54:53 +00:00
parent 3dad12583b
commit f35c9da001

View File

@@ -567,157 +567,61 @@ CK_TILE_DEVICE void load_tile_transpose_convert_with_offset(
index_t offset)
{
using OutputDataType = typename DistributedTensor_::DataType;
using OutTensor = remove_cvref_t<DistributedTensor_>;
auto trans_tensor = tile_window.template load_transpose_with_offset<Policy>(offset);
using InTensor = remove_cvref_t<decltype(trans_tensor)>;
using OutTensor = remove_cvref_t<DistributedTensor_>;
using InDstrEncode = typename InTensor::StaticTileDistribution::DstrEncode;
using OutDstrEncode = typename OutTensor::StaticTileDistribution::DstrEncode;
using InTensor = remove_cvref_t<decltype(trans_tensor)>;
constexpr auto input_distr = typename InTensor::StaticTileDistribution{};
constexpr auto output_distr = typename OutTensor::StaticTileDistribution{};
constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor();
constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();
constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor();
constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();
constexpr auto y_in_lengths = to_sequence(y_in_desc.get_lengths());
constexpr auto y_out_lengths = to_sequence(y_out_desc.get_lengths());
constexpr index_t NDimYIn = input_distr.get_num_of_dimension_y();
constexpr index_t NDimYOut = output_distr.get_num_of_dimension_y();
static_assert(NDimYIn == NDimYOut,
"Mixed precision transpose conversion requires same Y rank.");
constexpr index_t total_elems_in = reduce_on_sequence(y_in_lengths, multiplies<>{}, number<1>{});
constexpr index_t total_elems_out = reduce_on_sequence(y_out_lengths, multiplies<>{}, number<1>{});
constexpr auto y_in_lengths = to_sequence(y_in_desc.get_lengths());
constexpr auto y_out_lengths = to_sequence(y_out_desc.get_lengths());
constexpr auto y_in_element_space_size = y_in_desc.get_element_space_size();
constexpr auto y_out_element_space_size = y_out_desc.get_element_space_size();
// For mixed precision: the input and output Y descriptors must have the same element space size
static_assert(y_in_element_space_size == y_out_element_space_size,
"For mixed precision transpose, input and output element space size must match!");
constexpr index_t total_elems_in =
reduce_on_sequence(y_in_lengths, multiplies<>{}, number<1>{});
constexpr index_t total_elems_out =
reduce_on_sequence(y_out_lengths, multiplies<>{}, number<1>{});
static_assert(total_elems_in == total_elems_out,
"For mixed precision transpose, input/output element counts must match!");
using InDimAccessOrderY = typename arithmetic_sequence_gen<0, NDimYIn, 1>::type;
using OutDimAccessOrderY = typename arithmetic_sequence_gen<0, NDimYOut, 1>::type;
using InScalarsPerElemY = typename uniform_sequence_gen<NDimYIn, 1>::type;
using OutScalarsPerElemY = typename uniform_sequence_gen<NDimYOut, 1>::type;
using InSFC_Y =
space_filling_curve<decltype(y_in_lengths), InDimAccessOrderY, InScalarsPerElemY, false>;
using OutSFC_Y = space_filling_curve<decltype(y_out_lengths),
OutDimAccessOrderY,
OutScalarsPerElemY,
false>;
static_assert(InSFC_Y::get_num_of_access() == total_elems_in,
"Unexpected input SFC access count for mixed precision transpose.");
static_assert(OutSFC_Y::get_num_of_access() == total_elems_out,
"Unexpected output SFC access count for mixed precision transpose.");
constexpr auto in_ys_rhs_major = InDstrEncode::ys_to_rhs_major_;
constexpr auto in_ys_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
constexpr auto out_ys_rhs_major = OutDstrEncode::ys_to_rhs_major_;
constexpr auto out_ys_rhs_minor = OutDstrEncode::ys_to_rhs_minor_;
constexpr auto in_rank = [&] {
array<index_t, NDimYIn> rank{};
static_for<0, NDimYIn, 1>{}([&](auto d) {
index_t r = 0;
static_for<0, NDimYIn, 1>{}([&](auto e) {
constexpr bool less =
(in_ys_rhs_major[e] < in_ys_rhs_major[d]) ||
((in_ys_rhs_major[e] == in_ys_rhs_major[d]) &&
(in_ys_rhs_minor[e] < in_ys_rhs_minor[d])) ||
((in_ys_rhs_major[e] == in_ys_rhs_major[d]) &&
(in_ys_rhs_minor[e] == in_ys_rhs_minor[d]) && (e < d));
if constexpr(less)
{
++r;
}
});
rank(d) = r;
});
return rank;
}();
constexpr auto out_rank = [&] {
array<index_t, NDimYOut> rank{};
static_for<0, NDimYOut, 1>{}([&](auto d) {
index_t r = 0;
static_for<0, NDimYOut, 1>{}([&](auto e) {
constexpr bool less =
(out_ys_rhs_major[e] < out_ys_rhs_major[d]) ||
((out_ys_rhs_major[e] == out_ys_rhs_major[d]) &&
(out_ys_rhs_minor[e] < out_ys_rhs_minor[d])) ||
((out_ys_rhs_major[e] == out_ys_rhs_major[d]) &&
(out_ys_rhs_minor[e] == out_ys_rhs_minor[d]) && (e < d));
if constexpr(less)
{
++r;
}
});
rank(d) = r;
});
return rank;
}();
constexpr auto in_order = [&] {
array<index_t, NDimYIn> order{};
static_for<0, NDimYIn, 1>{}([&](auto d) { order(in_rank[d]) = d; });
return order;
}();
constexpr auto out_order = [&] {
array<index_t, NDimYOut> order{};
static_for<0, NDimYOut, 1>{}([&](auto d) { order(out_rank[d]) = d; });
return order;
}();
constexpr auto y_in_lens = [&] {
array<index_t, NDimYIn> lens{};
static_for<0, NDimYIn, 1>{}([&](auto i) { lens(i) = y_in_lengths[i]; });
return lens;
}();
constexpr auto y_out_lens = [&] {
array<index_t, NDimYOut> lens{};
static_for<0, NDimYOut, 1>{}([&](auto i) { lens(i) = y_out_lengths[i]; });
return lens;
}();
using OutSFC_Y =
space_filling_curve<decltype(y_out_lengths), OutDimAccessOrderY, OutScalarsPerElemY, false>;
static_for<0, total_elems_out, 1>{}([&](auto i_out) {
constexpr auto out_idx_y_seq = OutSFC_Y::get_index(i_out);
constexpr auto idx_out = OutSFC_Y::get_index(i_out);
constexpr index_t out_off = y_out_desc.calculate_offset(idx_out);
array<index_t, NDimYOut> out_idx_y{};
static_for<0, NDimYOut, 1>{}([&](auto iy) { out_idx_y(iy) = out_idx_y_seq[iy]; });
constexpr index_t linear = [&] {
index_t v = 0;
static_for<0, NDimYOut, 1>{}([&](auto d) {
v = v * y_out_lengths[d] + idx_out[d];
});
return v;
}();
index_t linear = 0;
static_for<0, NDimYOut, 1>{}([&](auto k) {
const index_t y = out_order[k];
linear = linear * y_out_lens[y] + out_idx_y[y];
});
const auto idx_in = [&] {
array<index_t, NDimYIn> idx{};
index_t remain = linear;
static_for<NDimYIn - 1, -1, -1>{}([&](auto d) {
constexpr index_t len = y_in_lengths[d];
idx(d) = remain % len;
remain /= len;
});
return idx;
}();
array<index_t, NDimYIn> in_idx_y{};
index_t remain = linear;
static_for<NDimYIn - 1, -1, -1>{}([&](auto k_rev) {
const index_t y = in_order[k_rev];
const index_t l = y_in_lens[y];
in_idx_y[y] = remain % l;
remain /= l;
});
const index_t in_off = y_in_desc.calculate_offset(idx_in);
const index_t in_offset = y_in_desc.calculate_offset(in_idx_y);
const index_t out_offset = y_out_desc.calculate_offset(out_idx_y);
out_tensor.get_thread_buffer()[out_offset] =
type_convert<OutputDataType>(trans_tensor.get_thread_buffer()[in_offset]);
out_tensor.get_thread_buffer()[out_off] =
type_convert<OutputDataType>(trans_tensor.get_thread_buffer()[in_off]);
});
}