mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
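Fixed packed_cast implementation for slice access: SrcDesc is now a template parameter of Run() instead of PackedCast, and the source buffer is indexed by the offset computed via src_desc.CalculateOffset() rather than by the coordinate tuple.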
This commit is contained in:
@@ -991,14 +991,13 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
if constexpr (is_gfx950_and_bf16_input_)
|
||||
{
|
||||
auto c_thread_packed_cast = PackedCast<
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
M2,
|
||||
M4,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle
|
||||
>{};
|
||||
c_thread_packed_cast.Run(
|
||||
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
|
||||
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
|
||||
c_thread_buf // source buffer
|
||||
);
|
||||
@@ -1380,7 +1379,6 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
if constexpr (is_gfx950_and_bf16_input_)
|
||||
{
|
||||
auto c_thread_packed_cast = PackedCast<
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
M2,
|
||||
M4,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
|
||||
@@ -11,44 +11,47 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename SrcDesc, ck::index_t M2, ck::index_t M4, ck::index_t CShuffleMXdlPerWavePerShuffle, ck::index_t CShuffleNXdlPerWavePerShuffle>
|
||||
template <ck::index_t M2, ck::index_t M4, ck::index_t CShuffleMXdlPerWavePerShuffle, ck::index_t CShuffleNXdlPerWavePerShuffle>
|
||||
struct PackedCast
|
||||
{
|
||||
template <typename SrcSliceOriginIdx, typename SrcBuffer>
|
||||
__device__ void Run(const SrcDesc&,
|
||||
const SrcSliceOriginIdx&,
|
||||
const SrcBuffer& src_buf)
|
||||
template <typename SrcDesc, typename SrcSliceOriginIdx, typename SrcBuffer>
|
||||
__device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, SrcBuffer& src_buf)
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
|
||||
"wrong! SrcSliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
|
||||
|
||||
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
||||
|
||||
// Calculate total elements in this slice
|
||||
constexpr index_t elements_per_slice =
|
||||
CShuffleMXdlPerWavePerShuffle * CShuffleNXdlPerWavePerShuffle * M2 * M4;
|
||||
|
||||
constexpr auto calculate_coords = [&](auto idx) constexpr {
|
||||
constexpr index_t m4_offset = idx.value % M4;
|
||||
constexpr index_t m2_offset = (idx.value / M4) % M2;
|
||||
constexpr index_t n_xdl_offset = (idx.value / (M4 * M2)) % CShuffleNXdlPerWavePerShuffle;
|
||||
constexpr index_t m_xdl_offset = idx.value / (M4 * M2 * CShuffleNXdlPerWavePerShuffle);
|
||||
|
||||
return make_tuple(
|
||||
src_slice_origin_idx[Number<0>{}] + Number<m_xdl_offset>{},
|
||||
src_slice_origin_idx[Number<1>{}] + Number<n_xdl_offset>{},
|
||||
Number<0>{},
|
||||
Number<0>{},
|
||||
src_slice_origin_idx[Number<4>{}] + Number<m2_offset>{},
|
||||
Number<0>{},
|
||||
src_slice_origin_idx[Number<6>{}] + Number<m4_offset>{},
|
||||
Number<0>{}
|
||||
);
|
||||
constexpr auto calculate_coords = [src_slice_origin_idx](auto idx) constexpr {
|
||||
|
||||
// We know that the access order is
|
||||
// Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
// Sequence<CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, 1, 1, M2, 1, M4, 1>
|
||||
|
||||
constexpr index_t m4_offset = idx.value % M4;
|
||||
constexpr index_t m2_offset = (idx.value / M4) % M2;
|
||||
constexpr index_t n_xdl_offset = (idx.value / (M4 * M2)) % CShuffleNXdlPerWavePerShuffle;
|
||||
constexpr index_t m_xdl_offset = idx.value / (M4 * M2 * CShuffleNXdlPerWavePerShuffle);
|
||||
|
||||
return make_tuple(
|
||||
src_slice_origin_idx[Number<0>{}] + Number<m_xdl_offset>{},
|
||||
src_slice_origin_idx[Number<1>{}] + Number<n_xdl_offset>{},
|
||||
Number<0>{}, // this dim has unit size
|
||||
Number<0>{}, // this dim has unit size
|
||||
src_slice_origin_idx[Number<4>{}] + Number<m2_offset>{},
|
||||
Number<0>{}, // this dim has unit size
|
||||
src_slice_origin_idx[Number<6>{}] + Number<m4_offset>{},
|
||||
Number<0>{} // this dim has unit size
|
||||
);
|
||||
};
|
||||
|
||||
constexpr index_t num_pairs = elements_per_slice / 2;
|
||||
@@ -60,9 +63,12 @@ namespace ck {
|
||||
|
||||
constexpr auto coord_0 = calculate_coords(idx_0);
|
||||
constexpr auto coord_1 = calculate_coords(idx_1);
|
||||
|
||||
constexpr auto offset_0 = src_desc.CalculateOffset(coord_0);
|
||||
constexpr auto offset_1 = src_desc.CalculateOffset(coord_1);
|
||||
|
||||
float& val_0 = src_buf[coord_0];
|
||||
float& val_1 = src_buf[coord_1];
|
||||
float& val_0 = src_buf(Number<offset_0>{});
|
||||
float& val_1 = src_buf(Number<offset_1>{});
|
||||
|
||||
// Use packed conversion
|
||||
static_cast_float_to_bhalf_packed(val_0, val_1);
|
||||
@@ -75,7 +81,8 @@ namespace ck {
|
||||
constexpr auto last_coord = calculate_coords(last_idx);
|
||||
|
||||
// Single element conversion
|
||||
float& last_val = src_buf[last_coord];
|
||||
constexpr auto last_offset = src_desc.CalculateOffset(last_coord);
|
||||
float& last_val = src_buf[Number<last_offset>{}];
|
||||
const auto single_bf16 = static_cast<__bf16>(last_val);
|
||||
uint16_t* parts = reinterpret_cast<uint16_t*>(&last_val);
|
||||
const uint16_t* bf16_bits = reinterpret_cast<const uint16_t*>(&single_bf16);
|
||||
|
||||
Reference in New Issue
Block a user