Candidate fix 7

This commit is contained in:
Sami Aario
2026-02-23 09:03:20 +00:00
parent 1de8bc9501
commit c79ab1f84a

View File

@@ -566,7 +566,6 @@ CK_TILE_DEVICE void load_tile_transpose_convert_with_offset(
NumCoord>& __restrict__ tile_window,
index_t offset)
{
using InputDataType = typename BottomTensorView_::DataType;
using OutputDataType = typename DistributedTensor_::DataType;
auto trans_tensor = tile_window.template load_transpose_with_offset<Policy>(offset);
@@ -574,9 +573,6 @@ CK_TILE_DEVICE void load_tile_transpose_convert_with_offset(
using InTensor = remove_cvref_t<decltype(trans_tensor)>;
using OutTensor = remove_cvref_t<DistributedTensor_>;
using InDstrEncode = typename InTensor::StaticTileDistribution::DstrEncode;
using OutDstrEncode = typename OutTensor::StaticTileDistribution::DstrEncode;
constexpr auto input_distr = typename InTensor::StaticTileDistribution{};
constexpr auto output_distr = typename OutTensor::StaticTileDistribution{};
@@ -606,54 +602,78 @@ CK_TILE_DEVICE void load_tile_transpose_convert_with_offset(
static_assert(total_elems_in == total_elems_out,
"For mixed precision transpose, input/output element counts must match!");
constexpr index_t in_rh_major_count = InDstrEncode::NDimX + 1;
constexpr index_t out_rh_major_count = OutDstrEncode::NDimX + 1;
constexpr auto in_ps_ys_to_xs = input_distr.get_ps_ys_to_xs_adaptor();
constexpr auto out_ps_ys_to_xs = output_distr.get_ps_ys_to_xs_adaptor();
static_assert(in_rh_major_count == out_rh_major_count,
"Input/output RH major dimension count must match.");
const auto in_ps_idx = input_distr.get_partition_index();
const auto out_ps_idx = output_distr.get_partition_index();
static_for<1, in_rh_major_count, 1>{}([&](auto rh_major) {
constexpr index_t in_ndim_rh_minor = InDstrEncode::detail::ndims_rhs_minor_[rh_major];
constexpr index_t out_ndim_rh_minor = OutDstrEncode::detail::ndims_rhs_minor_[rh_major];
using InDimAccessOrderY = typename arithmetic_sequence_gen<0, NDimYIn, 1>::type;
using OutDimAccessOrderY = typename arithmetic_sequence_gen<0, NDimYOut, 1>::type;
using InScalarsPerElemY = typename uniform_sequence_gen<NDimYIn, 1>::type;
using OutScalarsPerElemY = typename uniform_sequence_gen<NDimYOut, 1>::type;
static_assert(in_ndim_rh_minor == out_ndim_rh_minor,
"Input/output RH minor dimension count must match per RH major.");
using InSFC_Y =
space_filling_curve<decltype(y_in_lengths), InDimAccessOrderY, InScalarsPerElemY, false>;
using OutSFC_Y = space_filling_curve<decltype(y_out_lengths),
OutDimAccessOrderY,
OutScalarsPerElemY,
false>;
static_for<0, in_ndim_rh_minor, 1>{}([&](auto rh_minor) {
constexpr index_t i_in = InDstrEncode::detail::rhs_major_minor_to_ys_[rh_major]
[rh_minor];
constexpr index_t i_out = OutDstrEncode::detail::rhs_major_minor_to_ys_[rh_major]
[rh_minor];
static_assert(InSFC_Y::get_num_of_access() == total_elems_in,
"Unexpected input SFC access count for mixed precision transpose.");
static_assert(OutSFC_Y::get_num_of_access() == total_elems_out,
"Unexpected output SFC access count for mixed precision transpose.");
static_assert(i_in >= 0 && i_out >= 0,
"Every H-space RH coordinate must map to valid Y dims.");
static_assert(y_in_lengths[number<i_in>{}] == y_out_lengths[number<i_out>{}],
"Mapped Y dimensions must have equal lengths.");
static_for<0, total_elems_in, 1>{}([&](auto i_in) {
constexpr auto in_idx_y = InSFC_Y::get_index(i_in);
constexpr index_t in_offset = y_in_desc.calculate_offset(in_idx_y);
array<index_t, NDimYIn> in_idx_y_arr{};
static_for<0, NDimYIn, 1>{}([&](auto iy) {
in_idx_y_arr(iy) = in_idx_y[iy];
});
});
const auto in_ps_ys_idx = container_concat(in_ps_idx, in_idx_y_arr);
const auto in_adaptor_coord = make_tensor_adaptor_coordinate(in_ps_ys_to_xs, in_ps_ys_idx);
const auto in_x_idx = in_adaptor_coord.get_bottom_index();
sweep_tile<InTensor>([&](auto idx_in) {
constexpr auto idx_y_in = InTensor::get_tile_distribution().get_y_indices_from_distributed_indices(
decltype(idx_in){});
index_t out_offset = -1;
bool found = false;
constexpr auto idx_y_out = generate_sequence_v2(
[&](auto i_out) {
constexpr index_t rh_major = OutDstrEncode::ys_to_rhs_major_[i_out];
constexpr index_t rh_minor = OutDstrEncode::ys_to_rhs_minor_[i_out];
constexpr index_t i_in =
InDstrEncode::detail::rhs_major_minor_to_ys_[rh_major][rh_minor];
static_for<0, total_elems_out, 1>{}([&](auto i_out) {
if(found)
{
return;
}
static_assert(i_in >= 0, "Input Y dim for output RH coordinate was not found.");
constexpr auto out_idx_y = OutSFC_Y::get_index(i_out);
array<index_t, NDimYOut> out_idx_y_arr{};
static_for<0, NDimYOut, 1>{}([&](auto iy) {
out_idx_y_arr(iy) = out_idx_y[iy];
});
const auto out_ps_ys_idx = container_concat(out_ps_idx, out_idx_y_arr);
const auto out_adaptor_coord =
make_tensor_adaptor_coordinate(out_ps_ys_to_xs, out_ps_ys_idx);
const auto out_x_idx = out_adaptor_coord.get_bottom_index();
return number<idx_y_in[number<i_in>{}]>{};
},
number<NDimYOut>{});
bool same_x = true;
static_for<0, input_distr.get_num_of_dimension_x(), 1>{}([&](auto ix) {
same_x = same_x && (in_x_idx[ix] == out_x_idx[ix]);
});
//constexpr index_t in_off = y_in_desc.calculate_offset(idx_y_in);
constexpr index_t out_off = y_out_desc.calculate_offset(idx_y_out);
if(same_x)
{
out_offset = y_out_desc.calculate_offset(out_idx_y);
found = true;
}
});
out_tensor.get_thread_buffer()[number<out_off>{}] =
type_convert<OutputDataType>(trans_tensor(idx_in));
if(!found)
{
out_offset = in_offset;
}
out_tensor.get_thread_buffer()[out_offset] =
type_convert<OutputDataType>(trans_tensor.get_thread_buffer()[number<in_offset>{}]);
});
}