mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Candidate fix 7
This commit is contained in:
@@ -566,7 +566,6 @@ CK_TILE_DEVICE void load_tile_transpose_convert_with_offset(
|
||||
NumCoord>& __restrict__ tile_window,
|
||||
index_t offset)
|
||||
{
|
||||
using InputDataType = typename BottomTensorView_::DataType;
|
||||
using OutputDataType = typename DistributedTensor_::DataType;
|
||||
|
||||
auto trans_tensor = tile_window.template load_transpose_with_offset<Policy>(offset);
|
||||
@@ -574,9 +573,6 @@ CK_TILE_DEVICE void load_tile_transpose_convert_with_offset(
|
||||
using InTensor = remove_cvref_t<decltype(trans_tensor)>;
|
||||
using OutTensor = remove_cvref_t<DistributedTensor_>;
|
||||
|
||||
using InDstrEncode = typename InTensor::StaticTileDistribution::DstrEncode;
|
||||
using OutDstrEncode = typename OutTensor::StaticTileDistribution::DstrEncode;
|
||||
|
||||
constexpr auto input_distr = typename InTensor::StaticTileDistribution{};
|
||||
constexpr auto output_distr = typename OutTensor::StaticTileDistribution{};
|
||||
|
||||
@@ -606,54 +602,78 @@ CK_TILE_DEVICE void load_tile_transpose_convert_with_offset(
|
||||
static_assert(total_elems_in == total_elems_out,
|
||||
"For mixed precision transpose, input/output element counts must match!");
|
||||
|
||||
constexpr index_t in_rh_major_count = InDstrEncode::NDimX + 1;
|
||||
constexpr index_t out_rh_major_count = OutDstrEncode::NDimX + 1;
|
||||
constexpr auto in_ps_ys_to_xs = input_distr.get_ps_ys_to_xs_adaptor();
|
||||
constexpr auto out_ps_ys_to_xs = output_distr.get_ps_ys_to_xs_adaptor();
|
||||
|
||||
static_assert(in_rh_major_count == out_rh_major_count,
|
||||
"Input/output RH major dimension count must match.");
|
||||
const auto in_ps_idx = input_distr.get_partition_index();
|
||||
const auto out_ps_idx = output_distr.get_partition_index();
|
||||
|
||||
static_for<1, in_rh_major_count, 1>{}([&](auto rh_major) {
|
||||
constexpr index_t in_ndim_rh_minor = InDstrEncode::detail::ndims_rhs_minor_[rh_major];
|
||||
constexpr index_t out_ndim_rh_minor = OutDstrEncode::detail::ndims_rhs_minor_[rh_major];
|
||||
using InDimAccessOrderY = typename arithmetic_sequence_gen<0, NDimYIn, 1>::type;
|
||||
using OutDimAccessOrderY = typename arithmetic_sequence_gen<0, NDimYOut, 1>::type;
|
||||
using InScalarsPerElemY = typename uniform_sequence_gen<NDimYIn, 1>::type;
|
||||
using OutScalarsPerElemY = typename uniform_sequence_gen<NDimYOut, 1>::type;
|
||||
|
||||
static_assert(in_ndim_rh_minor == out_ndim_rh_minor,
|
||||
"Input/output RH minor dimension count must match per RH major.");
|
||||
using InSFC_Y =
|
||||
space_filling_curve<decltype(y_in_lengths), InDimAccessOrderY, InScalarsPerElemY, false>;
|
||||
using OutSFC_Y = space_filling_curve<decltype(y_out_lengths),
|
||||
OutDimAccessOrderY,
|
||||
OutScalarsPerElemY,
|
||||
false>;
|
||||
|
||||
static_for<0, in_ndim_rh_minor, 1>{}([&](auto rh_minor) {
|
||||
constexpr index_t i_in = InDstrEncode::detail::rhs_major_minor_to_ys_[rh_major]
|
||||
[rh_minor];
|
||||
constexpr index_t i_out = OutDstrEncode::detail::rhs_major_minor_to_ys_[rh_major]
|
||||
[rh_minor];
|
||||
static_assert(InSFC_Y::get_num_of_access() == total_elems_in,
|
||||
"Unexpected input SFC access count for mixed precision transpose.");
|
||||
static_assert(OutSFC_Y::get_num_of_access() == total_elems_out,
|
||||
"Unexpected output SFC access count for mixed precision transpose.");
|
||||
|
||||
static_assert(i_in >= 0 && i_out >= 0,
|
||||
"Every H-space RH coordinate must map to valid Y dims.");
|
||||
static_assert(y_in_lengths[number<i_in>{}] == y_out_lengths[number<i_out>{}],
|
||||
"Mapped Y dimensions must have equal lengths.");
|
||||
static_for<0, total_elems_in, 1>{}([&](auto i_in) {
|
||||
constexpr auto in_idx_y = InSFC_Y::get_index(i_in);
|
||||
constexpr index_t in_offset = y_in_desc.calculate_offset(in_idx_y);
|
||||
array<index_t, NDimYIn> in_idx_y_arr{};
|
||||
static_for<0, NDimYIn, 1>{}([&](auto iy) {
|
||||
in_idx_y_arr(iy) = in_idx_y[iy];
|
||||
});
|
||||
});
|
||||
const auto in_ps_ys_idx = container_concat(in_ps_idx, in_idx_y_arr);
|
||||
const auto in_adaptor_coord = make_tensor_adaptor_coordinate(in_ps_ys_to_xs, in_ps_ys_idx);
|
||||
const auto in_x_idx = in_adaptor_coord.get_bottom_index();
|
||||
|
||||
sweep_tile<InTensor>([&](auto idx_in) {
|
||||
constexpr auto idx_y_in = InTensor::get_tile_distribution().get_y_indices_from_distributed_indices(
|
||||
decltype(idx_in){});
|
||||
index_t out_offset = -1;
|
||||
bool found = false;
|
||||
|
||||
constexpr auto idx_y_out = generate_sequence_v2(
|
||||
[&](auto i_out) {
|
||||
constexpr index_t rh_major = OutDstrEncode::ys_to_rhs_major_[i_out];
|
||||
constexpr index_t rh_minor = OutDstrEncode::ys_to_rhs_minor_[i_out];
|
||||
constexpr index_t i_in =
|
||||
InDstrEncode::detail::rhs_major_minor_to_ys_[rh_major][rh_minor];
|
||||
static_for<0, total_elems_out, 1>{}([&](auto i_out) {
|
||||
if(found)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
static_assert(i_in >= 0, "Input Y dim for output RH coordinate was not found.");
|
||||
constexpr auto out_idx_y = OutSFC_Y::get_index(i_out);
|
||||
array<index_t, NDimYOut> out_idx_y_arr{};
|
||||
static_for<0, NDimYOut, 1>{}([&](auto iy) {
|
||||
out_idx_y_arr(iy) = out_idx_y[iy];
|
||||
});
|
||||
const auto out_ps_ys_idx = container_concat(out_ps_idx, out_idx_y_arr);
|
||||
const auto out_adaptor_coord =
|
||||
make_tensor_adaptor_coordinate(out_ps_ys_to_xs, out_ps_ys_idx);
|
||||
const auto out_x_idx = out_adaptor_coord.get_bottom_index();
|
||||
|
||||
return number<idx_y_in[number<i_in>{}]>{};
|
||||
},
|
||||
number<NDimYOut>{});
|
||||
bool same_x = true;
|
||||
static_for<0, input_distr.get_num_of_dimension_x(), 1>{}([&](auto ix) {
|
||||
same_x = same_x && (in_x_idx[ix] == out_x_idx[ix]);
|
||||
});
|
||||
|
||||
//constexpr index_t in_off = y_in_desc.calculate_offset(idx_y_in);
|
||||
constexpr index_t out_off = y_out_desc.calculate_offset(idx_y_out);
|
||||
if(same_x)
|
||||
{
|
||||
out_offset = y_out_desc.calculate_offset(out_idx_y);
|
||||
found = true;
|
||||
}
|
||||
});
|
||||
|
||||
out_tensor.get_thread_buffer()[number<out_off>{}] =
|
||||
type_convert<OutputDataType>(trans_tensor(idx_in));
|
||||
if(!found)
|
||||
{
|
||||
out_offset = in_offset;
|
||||
}
|
||||
|
||||
out_tensor.get_thread_buffer()[out_offset] =
|
||||
type_convert<OutputDataType>(trans_tensor.get_thread_buffer()[number<in_offset>{}]);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user