Remove obsolete version of the packed cast.

This commit is contained in:
Ville Pietilä
2025-08-27 11:28:55 +00:00
parent 481df169f2
commit 54302c6f77

View File

@@ -10,75 +10,6 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <ck::index_t M2, ck::index_t M4, ck::index_t CShuffleMXdlPerWavePerShuffle, ck::index_t CShuffleNXdlPerWavePerShuffle>
struct PackedCast
{
template <typename SrcDesc, typename SrcSliceOriginIdx, typename SrcBuffer>
__device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, SrcBuffer& src_buf)
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
"wrong! SrcSliceOrigin need to known at compile-time");
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
// Calculate total elements in this slice
constexpr index_t elements_per_slice =
CShuffleMXdlPerWavePerShuffle * CShuffleNXdlPerWavePerShuffle * M2 * M4;
constexpr auto calculate_coords = [src_slice_origin_idx](auto idx) constexpr {
// We know that the access order is
// Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
// Sequence<CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, 1, 1, M2, 1, M4, 1>
constexpr index_t m4_offset = idx.value % M4;
constexpr index_t m2_offset = (idx.value / M4) % M2;
constexpr index_t n_xdl_offset = (idx.value / (M4 * M2)) % CShuffleNXdlPerWavePerShuffle;
constexpr index_t m_xdl_offset = idx.value / (M4 * M2 * CShuffleNXdlPerWavePerShuffle);
return make_tuple(
src_slice_origin_idx[Number<0>{}] + Number<m_xdl_offset>{},
src_slice_origin_idx[Number<1>{}] + Number<n_xdl_offset>{},
Number<0>{}, // this dim has unit size
Number<0>{}, // this dim has unit size
src_slice_origin_idx[Number<4>{}] + Number<m2_offset>{},
Number<0>{}, // this dim has unit size
src_slice_origin_idx[Number<6>{}] + Number<m4_offset>{},
Number<0>{} // this dim has unit size
);
};
constexpr index_t num_pairs = elements_per_slice / 2;
constexpr bool has_odd_element = (elements_per_slice % 2 == 1);
static_assert(!has_odd_element, "PackedCast does not support odd number of elements");
static_for<0, num_pairs, 1>{}([&](auto pair_idx) {
constexpr auto idx_0 = Number<pair_idx * 2>{};
constexpr auto idx_1 = Number<pair_idx * 2 + 1>{};
constexpr auto coord_0 = calculate_coords(idx_0);
constexpr auto coord_1 = calculate_coords(idx_1);
constexpr auto offset_0 = src_desc.CalculateOffset(coord_0);
constexpr auto offset_1 = src_desc.CalculateOffset(coord_1);
float& val_0 = src_buf(Number<offset_0>{});
const float val_1 = src_buf[Number<offset_1>{}];
static_cast_float_to_bhalf_packed_v2(val_0, val_1);
});
};
};
template <ck::index_t M2, ck::index_t M4, ck::index_t CShuffleMXdlPerWavePerShuffle, ck::index_t CShuffleNXdlPerWavePerShuffle>
struct PackedCastV2
{