WIP: Integration of packed cast into gridwise_gemm_xdl_cshuffle_conv_v3.

This commit is contained in:
Ville Pietilä
2025-08-05 15:12:36 +00:00
parent e92c0bf68e
commit 44202b9d32
4 changed files with 161 additions and 23 deletions

View File

@@ -112,9 +112,10 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
static constexpr bool is_gfx950_and_bf16_input_ = false;
#endif
using CShuffleInputDataType = std::conditional_t<is_gfx950_and_bf16_input_,
CShuffleDataType,
AccDataType>;
// using CShuffleInputDataType = std::conditional_t<is_gfx950_and_bf16_input_,
// CShuffleDataType,
// AccDataType>;
using CShuffleInputDataType = AccDataType;
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
{
@@ -918,7 +919,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
1,
InMemoryDataOperationEnum::Set,
1,
true>{
true,
is_gfx950_and_bf16_input_>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
@@ -988,15 +990,26 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
if constexpr (is_gfx950_and_bf16_input_)
{
packed_cast(sfc_c_vgpr);
auto c_thread_packed_cast = PackedCast<
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
M2,
M4,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle
>{};
c_thread_packed_cast.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
c_thread_buf // source buffer
);
}
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
@@ -1295,7 +1308,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
1,
InMemoryDataOperationEnum::Set,
1,
true>{
true,
is_gfx950_and_bf16_input_>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
@@ -1365,7 +1379,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
if constexpr (is_gfx950_and_bf16_input_)
{
packed_cast(sfc_c_vgpr);
auto c_thread_packed_cast = PackedCast<
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
M2,
M4,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle
>{};
c_thread_packed_cast.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
c_thread_buf // source buffer
);
}
// each thread write its data from VGPR to LDS

View File

@@ -2,6 +2,8 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/host_utility/hip_check_error.hpp"
@@ -9,9 +11,77 @@
namespace ck {
__host__ __device__ inline void packed_cast(const auto& sfc_c_vgpr)
template <typename SrcDesc, ck::index_t M2, ck::index_t M4, ck::index_t CShuffleMXdlPerWavePerShuffle, ck::index_t CShuffleNXdlPerWavePerShuffle>
struct PackedCast
{
// This function is a placeholder for packed cast operations.
// For now, it does nothing.
template <typename SrcSliceOriginIdx, typename SrcBuffer>
__device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&,
const SrcBuffer& src_buf)
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
"wrong! SrcSliceOrigin need to known at compile-time");
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
// Calculate total elements in this slice
constexpr index_t elements_per_slice =
CShuffleMXdlPerWavePerShuffle * CShuffleNXdlPerWavePerShuffle * M2 * M4;
constexpr auto calculate_coords = [&](auto idx) constexpr {
constexpr index_t m4_offset = idx.value % M4;
constexpr index_t m2_offset = (idx.value / M4) % M2;
constexpr index_t n_xdl_offset = (idx.value / (M4 * M2)) % CShuffleNXdlPerWavePerShuffle;
constexpr index_t m_xdl_offset = idx.value / (M4 * M2 * CShuffleNXdlPerWavePerShuffle);
return make_tuple(
src_slice_origin_idx[Number<0>{}] + Number<m_xdl_offset>{},
src_slice_origin_idx[Number<1>{}] + Number<n_xdl_offset>{},
Number<0>{},
Number<0>{},
src_slice_origin_idx[Number<4>{}] + Number<m2_offset>{},
Number<0>{},
src_slice_origin_idx[Number<6>{}] + Number<m4_offset>{},
Number<0>{}
);
};
constexpr index_t num_pairs = elements_per_slice / 2;
constexpr bool has_odd_element = (elements_per_slice % 2 == 1);
static_for<0, num_pairs, 1>{}([&](auto pair_idx) {
constexpr auto idx_0 = Number<pair_idx * 2>{};
constexpr auto idx_1 = Number<pair_idx * 2 + 1>{};
constexpr auto coord_0 = calculate_coords(idx_0);
constexpr auto coord_1 = calculate_coords(idx_1);
float& val_0 = src_buf[coord_0];
float& val_1 = src_buf[coord_1];
// Use packed conversion
static_cast_float_to_bhalf_packed(val_0, val_1);
});
// Handle last element if the number of elements is odd.
if constexpr (has_odd_element)
{
constexpr auto last_idx = Number<elements_per_slice - 1>{};
constexpr auto last_coord = calculate_coords(last_idx);
// Single element conversion
float& last_val = src_buf[last_coord];
const auto single_bf16 = static_cast<__bf16>(last_val);
uint16_t* parts = reinterpret_cast<uint16_t*>(&last_val);
const uint16_t* bf16_bits = reinterpret_cast<const uint16_t*>(&single_bf16);
parts[1] = bf16_bits[0];
}
};
};
}

View File

@@ -34,11 +34,15 @@ template <typename SrcData,
InMemoryDataOperationEnum DstInMemOp,
index_t DstScalarStrideInVector,
bool DstResetCoordinateAfterRun,
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false,
bool PackedInput = false>
struct ThreadwiseTensorSliceTransfer_v1r3
{
static constexpr index_t nDim = SliceLengths::Size();
static constexpr bool float_input_and_bf16_output_ =
std::is_same_v<SrcData, float> && std::is_same_v<DstData, ck::bhalf_t>;
using Index = MultiIndex<nDim>;
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
@@ -106,17 +110,30 @@ struct ThreadwiseTensorSliceTransfer_v1r3
// copy data from src_buf into dst_vector
// TODO: It's a hack here to use \p dst_scalar_step_in_vector. Use SpaceFillingCurve?
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
if constexpr (PackedInput && float_input_and_bf16_output_)
{
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
DstData v;
const float packed_float = src_buf[Number<src_offset>{}];
const bhalf_t* bf16_array = reinterpret_cast<const bhalf_t*>(&packed_float);
dst_vector.template AsType<DstData>()(i) = bf16_array[0];
});
}
else
{
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
// apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]);
DstData v;
dst_vector.template AsType<DstData>()(i) = v;
});
// apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]);
dst_vector.template AsType<DstData>()(i) = v;
});
}
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

View File

@@ -68,6 +68,32 @@ inline __device__ bhalf2_t static_cast_float2_to_bhalf2_rne(float2_t x)
}
#endif
// TODO: Why do we need the host instance?
inline __host__ __device__ void static_cast_float_to_bhalf_packed(float& x, float& y)
{
#if defined(__gfx950__)
uint32_t result;
asm volatile("v_cvt_pk_bf16_f32 %0, %1, %2"
: "=v"(result)
: "v"(x), "v"(y));
// Extract individual BF16 values from packed result
const uint16_t* bf16_values = reinterpret_cast<const uint16_t*>(&result);
// Treat x and y as arrays of uint16_t
uint16_t* x_parts = reinterpret_cast<uint16_t*>(&x);
uint16_t* y_parts = reinterpret_cast<uint16_t*>(&y);
// Store BF16 values directly to the upper 16 bits (index 1 on little-endian)
x_parts[1] = bf16_values[0];
y_parts[1] = bf16_values[1];
#else
// Skip conversion for non-GFX950 architectures
x = static_cast<float>(static_cast<bhalf_t>(x));
y = static_cast<float>(static_cast<bhalf_t>(y));
#endif
}
// Declare a template function for conversion of bf16 vector of two values using RNE
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16x2_convert_rne(X x);