Fixed packed_cast implementation for slice access.

This commit is contained in:
Ville Pietilä
2025-08-06 11:04:29 +00:00
parent 44202b9d32
commit 4b8a559da9
2 changed files with 33 additions and 28 deletions

View File

@@ -991,14 +991,13 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
if constexpr (is_gfx950_and_bf16_input_)
{
auto c_thread_packed_cast = PackedCast<
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
M2,
M4,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle
>{};
c_thread_packed_cast.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
c_thread_buf // source buffer
);
@@ -1380,7 +1379,6 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
if constexpr (is_gfx950_and_bf16_input_)
{
auto c_thread_packed_cast = PackedCast<
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
M2,
M4,
CShuffleMXdlPerWavePerShuffle,

View File

@@ -11,44 +11,47 @@
namespace ck {
template <typename SrcDesc, ck::index_t M2, ck::index_t M4, ck::index_t CShuffleMXdlPerWavePerShuffle, ck::index_t CShuffleNXdlPerWavePerShuffle>
template <ck::index_t M2, ck::index_t M4, ck::index_t CShuffleMXdlPerWavePerShuffle, ck::index_t CShuffleNXdlPerWavePerShuffle>
struct PackedCast
{
template <typename SrcSliceOriginIdx, typename SrcBuffer>
__device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&,
const SrcBuffer& src_buf)
template <typename SrcDesc, typename SrcSliceOriginIdx, typename SrcBuffer>
__device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, SrcBuffer& src_buf)
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
"wrong! SrcSliceOrigin need to known at compile-time");
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
// Calculate total elements in this slice
constexpr index_t elements_per_slice =
CShuffleMXdlPerWavePerShuffle * CShuffleNXdlPerWavePerShuffle * M2 * M4;
constexpr auto calculate_coords = [&](auto idx) constexpr {
constexpr index_t m4_offset = idx.value % M4;
constexpr index_t m2_offset = (idx.value / M4) % M2;
constexpr index_t n_xdl_offset = (idx.value / (M4 * M2)) % CShuffleNXdlPerWavePerShuffle;
constexpr index_t m_xdl_offset = idx.value / (M4 * M2 * CShuffleNXdlPerWavePerShuffle);
return make_tuple(
src_slice_origin_idx[Number<0>{}] + Number<m_xdl_offset>{},
src_slice_origin_idx[Number<1>{}] + Number<n_xdl_offset>{},
Number<0>{},
Number<0>{},
src_slice_origin_idx[Number<4>{}] + Number<m2_offset>{},
Number<0>{},
src_slice_origin_idx[Number<6>{}] + Number<m4_offset>{},
Number<0>{}
);
constexpr auto calculate_coords = [src_slice_origin_idx](auto idx) constexpr {
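// The slice is traversed with dimension access order
// Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
// and (presumably) per-dimension slice lengths
// Sequence<CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, 1, 1, M2, 1, M4, 1>,
// so the linear index decomposes into (m_xdl, n_xdl, m2, m4) with m4 varying fastest.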
constexpr index_t m4_offset = idx.value % M4;
constexpr index_t m2_offset = (idx.value / M4) % M2;
constexpr index_t n_xdl_offset = (idx.value / (M4 * M2)) % CShuffleNXdlPerWavePerShuffle;
constexpr index_t m_xdl_offset = idx.value / (M4 * M2 * CShuffleNXdlPerWavePerShuffle);
return make_tuple(
src_slice_origin_idx[Number<0>{}] + Number<m_xdl_offset>{},
src_slice_origin_idx[Number<1>{}] + Number<n_xdl_offset>{},
Number<0>{}, // this dim has unit size
Number<0>{}, // this dim has unit size
src_slice_origin_idx[Number<4>{}] + Number<m2_offset>{},
Number<0>{}, // this dim has unit size
src_slice_origin_idx[Number<6>{}] + Number<m4_offset>{},
Number<0>{} // this dim has unit size
);
};
constexpr index_t num_pairs = elements_per_slice / 2;
@@ -60,9 +63,12 @@ namespace ck {
constexpr auto coord_0 = calculate_coords(idx_0);
constexpr auto coord_1 = calculate_coords(idx_1);
constexpr auto offset_0 = src_desc.CalculateOffset(coord_0);
constexpr auto offset_1 = src_desc.CalculateOffset(coord_1);
float& val_0 = src_buf[coord_0];
float& val_1 = src_buf[coord_1];
float& val_0 = src_buf(Number<offset_0>{});
float& val_1 = src_buf(Number<offset_1>{});
// Use packed conversion
static_cast_float_to_bhalf_packed(val_0, val_1);
@@ -75,7 +81,8 @@ namespace ck {
constexpr auto last_coord = calculate_coords(last_idx);
// Single element conversion
float& last_val = src_buf[last_coord];
constexpr auto last_offset = src_desc.CalculateOffset(last_coord);
float& last_val = src_buf[Number<last_offset>{}];
const auto single_bf16 = static_cast<__bf16>(last_val);
uint16_t* parts = reinterpret_cast<uint16_t*>(&last_val);
const uint16_t* bf16_bits = reinterpret_cast<const uint16_t*>(&single_bf16);