mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
WIP: Integration of packed cast into gridwise_gemm_xdl_cshuffle_conv_v3.
This commit is contained in:
@@ -112,9 +112,10 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
static constexpr bool is_gfx950_and_bf16_input_ = false;
|
||||
#endif
|
||||
|
||||
using CShuffleInputDataType = std::conditional_t<is_gfx950_and_bf16_input_,
|
||||
CShuffleDataType,
|
||||
AccDataType>;
|
||||
// using CShuffleInputDataType = std::conditional_t<is_gfx950_and_bf16_input_,
|
||||
// CShuffleDataType,
|
||||
// AccDataType>;
|
||||
using CShuffleInputDataType = AccDataType;
|
||||
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
|
||||
{
|
||||
@@ -918,7 +919,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{
|
||||
true,
|
||||
is_gfx950_and_bf16_input_>{
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
@@ -988,15 +990,26 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
|
||||
if constexpr (is_gfx950_and_bf16_input_)
|
||||
{
|
||||
packed_cast(sfc_c_vgpr);
|
||||
auto c_thread_packed_cast = PackedCast<
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
M2,
|
||||
M4,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle
|
||||
>{};
|
||||
c_thread_packed_cast.Run(
|
||||
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
|
||||
c_thread_buf // source buffer
|
||||
);
|
||||
}
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
c_thread_buf,
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
c_shuffle_block_buf);
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
c_thread_buf,
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
c_shuffle_block_buf);
|
||||
|
||||
// make sure it's safe to read from LDS
|
||||
block_sync_lds();
|
||||
@@ -1295,7 +1308,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{
|
||||
true,
|
||||
is_gfx950_and_bf16_input_>{
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
@@ -1365,7 +1379,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
|
||||
if constexpr (is_gfx950_and_bf16_input_)
|
||||
{
|
||||
packed_cast(sfc_c_vgpr);
|
||||
auto c_thread_packed_cast = PackedCast<
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
M2,
|
||||
M4,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle
|
||||
>{};
|
||||
c_thread_packed_cast.Run(
|
||||
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
|
||||
c_thread_buf // source buffer
|
||||
);
|
||||
}
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
@@ -9,9 +11,77 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
__host__ __device__ inline void packed_cast(const auto& sfc_c_vgpr)
|
||||
template <typename SrcDesc, ck::index_t M2, ck::index_t M4, ck::index_t CShuffleMXdlPerWavePerShuffle, ck::index_t CShuffleNXdlPerWavePerShuffle>
|
||||
struct PackedCast
|
||||
{
|
||||
// This function is a placeholder for packed cast operations.
|
||||
// For now, it does nothing.
|
||||
template <typename SrcSliceOriginIdx, typename SrcBuffer>
|
||||
__device__ void Run(const SrcDesc&,
|
||||
const SrcSliceOriginIdx&,
|
||||
const SrcBuffer& src_buf)
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
|
||||
"wrong! SrcSliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
|
||||
|
||||
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
||||
|
||||
// Calculate total elements in this slice
|
||||
constexpr index_t elements_per_slice =
|
||||
CShuffleMXdlPerWavePerShuffle * CShuffleNXdlPerWavePerShuffle * M2 * M4;
|
||||
|
||||
constexpr auto calculate_coords = [&](auto idx) constexpr {
|
||||
constexpr index_t m4_offset = idx.value % M4;
|
||||
constexpr index_t m2_offset = (idx.value / M4) % M2;
|
||||
constexpr index_t n_xdl_offset = (idx.value / (M4 * M2)) % CShuffleNXdlPerWavePerShuffle;
|
||||
constexpr index_t m_xdl_offset = idx.value / (M4 * M2 * CShuffleNXdlPerWavePerShuffle);
|
||||
|
||||
return make_tuple(
|
||||
src_slice_origin_idx[Number<0>{}] + Number<m_xdl_offset>{},
|
||||
src_slice_origin_idx[Number<1>{}] + Number<n_xdl_offset>{},
|
||||
Number<0>{},
|
||||
Number<0>{},
|
||||
src_slice_origin_idx[Number<4>{}] + Number<m2_offset>{},
|
||||
Number<0>{},
|
||||
src_slice_origin_idx[Number<6>{}] + Number<m4_offset>{},
|
||||
Number<0>{}
|
||||
);
|
||||
};
|
||||
|
||||
constexpr index_t num_pairs = elements_per_slice / 2;
|
||||
constexpr bool has_odd_element = (elements_per_slice % 2 == 1);
|
||||
|
||||
static_for<0, num_pairs, 1>{}([&](auto pair_idx) {
|
||||
constexpr auto idx_0 = Number<pair_idx * 2>{};
|
||||
constexpr auto idx_1 = Number<pair_idx * 2 + 1>{};
|
||||
|
||||
constexpr auto coord_0 = calculate_coords(idx_0);
|
||||
constexpr auto coord_1 = calculate_coords(idx_1);
|
||||
|
||||
float& val_0 = src_buf[coord_0];
|
||||
float& val_1 = src_buf[coord_1];
|
||||
|
||||
// Use packed conversion
|
||||
static_cast_float_to_bhalf_packed(val_0, val_1);
|
||||
});
|
||||
|
||||
// Handle last element if the number of elements is odd.
|
||||
if constexpr (has_odd_element)
|
||||
{
|
||||
constexpr auto last_idx = Number<elements_per_slice - 1>{};
|
||||
constexpr auto last_coord = calculate_coords(last_idx);
|
||||
|
||||
// Single element conversion
|
||||
float& last_val = src_buf[last_coord];
|
||||
const auto single_bf16 = static_cast<__bf16>(last_val);
|
||||
uint16_t* parts = reinterpret_cast<uint16_t*>(&last_val);
|
||||
const uint16_t* bf16_bits = reinterpret_cast<const uint16_t*>(&single_bf16);
|
||||
parts[1] = bf16_bits[0];
|
||||
}
|
||||
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
@@ -34,11 +34,15 @@ template <typename SrcData,
|
||||
InMemoryDataOperationEnum DstInMemOp,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool DstResetCoordinateAfterRun,
|
||||
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false,
|
||||
bool PackedInput = false>
|
||||
struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
static constexpr bool float_input_and_bf16_output_ =
|
||||
std::is_same_v<SrcData, float> && std::is_same_v<DstData, ck::bhalf_t>;
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||
@@ -106,17 +110,30 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
|
||||
// copy data from src_buf into dst_vector
|
||||
// TODO: It's a hack here to use \p dst_scalar_step_in_vector. Use SpaceFillingCurve?
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(
|
||||
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
if constexpr (PackedInput && float_input_and_bf16_output_)
|
||||
{
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(
|
||||
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
|
||||
DstData v;
|
||||
const float packed_float = src_buf[Number<src_offset>{}];
|
||||
const bhalf_t* bf16_array = reinterpret_cast<const bhalf_t*>(&packed_float);
|
||||
dst_vector.template AsType<DstData>()(i) = bf16_array[0];
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(
|
||||
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
|
||||
// apply element-wise operation
|
||||
element_op_(v, src_buf[Number<src_offset>{}]);
|
||||
DstData v;
|
||||
|
||||
dst_vector.template AsType<DstData>()(i) = v;
|
||||
});
|
||||
// apply element-wise operation
|
||||
element_op_(v, src_buf[Number<src_offset>{}]);
|
||||
dst_vector.template AsType<DstData>()(i) = v;
|
||||
});
|
||||
}
|
||||
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
|
||||
|
||||
@@ -68,6 +68,32 @@ inline __device__ bhalf2_t static_cast_float2_to_bhalf2_rne(float2_t x)
|
||||
}
|
||||
#endif
|
||||
|
||||
// TODO: Why do we need the host instance?
|
||||
inline __host__ __device__ void static_cast_float_to_bhalf_packed(float& x, float& y)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
uint32_t result;
|
||||
asm volatile("v_cvt_pk_bf16_f32 %0, %1, %2"
|
||||
: "=v"(result)
|
||||
: "v"(x), "v"(y));
|
||||
|
||||
// Extract individual BF16 values from packed result
|
||||
const uint16_t* bf16_values = reinterpret_cast<const uint16_t*>(&result);
|
||||
|
||||
// Treat x and y as arrays of uint16_t
|
||||
uint16_t* x_parts = reinterpret_cast<uint16_t*>(&x);
|
||||
uint16_t* y_parts = reinterpret_cast<uint16_t*>(&y);
|
||||
|
||||
// Store BF16 values directly to the upper 16 bits (index 1 on little-endian)
|
||||
x_parts[1] = bf16_values[0];
|
||||
y_parts[1] = bf16_values[1];
|
||||
#else
|
||||
// Skip conversion for non-GFX950 architectures
|
||||
x = static_cast<float>(static_cast<bhalf_t>(x));
|
||||
y = static_cast<float>(static_cast<bhalf_t>(y));
|
||||
#endif
|
||||
}
|
||||
|
||||
// Declare a template function for conversion of bf16 vector of two values using RNE
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y bf16x2_convert_rne(X x);
|
||||
|
||||
Reference in New Issue
Block a user