Working version 2 of the packed cast.

This commit is contained in:
Ville Pietilä
2025-08-12 12:46:01 +00:00
parent 6148d1c75f
commit cee7644c85
4 changed files with 80 additions and 43 deletions

View File

@@ -800,15 +800,6 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
if constexpr(is_gfx950_and_bf16_input_)
{
constexpr auto register_size_per_xdl_op = blockwise_gemm_pipeline.xdlops_gemm.GetRegSizePerXdlops();
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
ck::bhalf2_t,
(MXdlPerWave * NXdlPerWave + 1)/ 2,
register_size_per_xdl_op,
true> c_thread_buf_bf16;
}
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
@@ -1010,8 +1001,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
c_thread_packed_cast.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
c_thread_buf, // source buffer,
c_thread_buf_bf16 // destination buffer
c_thread_buf // source buffer
);
}
@@ -1390,7 +1380,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
if constexpr (is_gfx950_and_bf16_input_)
{
auto c_thread_packed_cast = PackedCast<
auto c_thread_packed_cast = PackedCastV2<
M2,
M4,
CShuffleMXdlPerWavePerShuffle,

View File

@@ -6,6 +6,7 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
@@ -94,6 +95,10 @@ namespace ck {
template <ck::index_t M2, ck::index_t M4, ck::index_t CShuffleMXdlPerWavePerShuffle, ck::index_t CShuffleNXdlPerWavePerShuffle>
struct PackedCastV2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
template <typename SrcDesc, typename SrcSliceOriginIdx, typename SrcBuffer>
__device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, SrcBuffer& src_buf)
{
@@ -130,8 +135,8 @@ namespace ck {
static_for<0, num_pairs, 1>{}([&](auto i_pair)
{
constexpr auto idx_1d_0 = 2 * i_pair;
constexpr auto idx_1d_1 = 2 * i_pair + 1;
constexpr auto idx_1d_0 = I2 * i_pair;
constexpr auto idx_1d_1 = I2 * i_pair + I1;
constexpr auto idx_md_0 = SpaceFillingCurve::GetIndex(idx_1d_0);
constexpr auto idx_md_1 = SpaceFillingCurve::GetIndex(idx_1d_1);
@@ -146,16 +151,17 @@ namespace ck {
// Handle last element if the number of elements is odd.
if constexpr (has_odd_element)
{
constexpr auto last_idx_1d = num_access - 1;
constexpr auto last_idx_1d = Number<num_access - 1>{};
constexpr auto last_idx_md = SpaceFillingCurve::GetIndex(last_idx_1d);
// Single element conversion
constexpr auto last_src_offset = src_desc.CalculateOffset(last_idx_1d);
float& last_val = src_buf(Number<last_offset>{});
float& last_val = src_buf(Number<last_src_offset>{});
const auto single_bf16 = static_cast<__bf16>(last_val);
uint16_t* parts = reinterpret_cast<uint16_t*>(&last_val);
const uint16_t* bf16_bits = reinterpret_cast<const uint16_t*>(&single_bf16);
parts[0] = bf16_bits[0];
}
};
};
}

View File

@@ -35,10 +35,14 @@ template <typename SrcData,
InMemoryDataOperationEnum DstInMemOp,
index_t DstScalarStrideInVector,
bool DstResetCoordinateAfterRun,
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false,
bool PackedInput = false>
bool PackedInput = false,
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v1r3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t nDim = SliceLengths::Size();
static constexpr bool float_input_and_bf16_output_ =
@@ -86,38 +90,76 @@ struct ThreadwiseTensorSliceTransfer_v1r3
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_scalar_step_in_vector =
generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(dst_scalar_per_access)>>;
// TODO: Use SpaceFillingCurve::ScalarsPerAccess instread of DstScalarPerVector?
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
using dst_vector_t = typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
if constexpr (PackedInput && float_input_and_bf16_output_)
{
static_assert(DstScalarPerVector == 2, "wrong! DstScalarPerVector must be 2");
static_assert(std::is_same_v<DimAccessOrder, Sequence<0, 1, 2, 3, 4, 5, 6, 7>>,
"wrong! DimAccessOrder must be the identity sequence <0, 1, 2, 3, 4, 5, 6, 7>");
// TODO: Fill the dst_vector vectprized access of size 2.
// Copying the dst_vector into dst_buf should be the same as before.
static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector");
static_assert(1 == DstScalarPerVector, "wrong!1 != DstScalarPerVector");
static_for<0, num_access, 1>{}([&](auto idx_1d)
{
// We need map the odd indices to the even indices, since
// the even indices contain a packed bf16x2 value, where
// the first value contains the bf16 value for the corresponding even index
// and the second value contains the bf16 value for the odd index following the even index.
// The odd indices are not used, so we can just ignore them.
constexpr auto pair_index = idx_1d % I2;
constexpr auto idx_src_1d = idx_1d - pair_index;
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_src_1d);
constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md);
union
{
float src_float;
bhalf16_t src_bf16x2;
} packed_value;
packed_value.src_float = src_buf[Number<src_offset>{}];
dst_vector.template AsType<DstData>()(I0) = packed_value.src_bf16x2[pair_index.value];
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
// copy data from dst_vector into dst_buf
dst_buf.template Update<DstInMemOp, dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
if constexpr(idx_1d.value != num_access - 1)
{
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
}
});
}
else
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_scalar_step_in_vector =
generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(dst_scalar_per_access)>>;
// TODO: Use SpaceFillingCurve::ScalarsPerAccess instread of DstScalarPerVector?
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
using dst_vector_t = typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
static_for<0, num_access, 1>{}([&](auto idx_1d) {
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);

View File

@@ -227,7 +227,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D)
}
// Use --gtest_filter="TestGroupedConvndBwdWeight2d_bf16_gfx950/*" to only this subset of tests.
#if defined(__gfx950__)
// List tests in the test suite with --gtest_list_tests
template <typename Tuple>
class TestGroupedConvndBwdWeight2d_bf16_gfx950 : public TestGroupedConvndBwdWeight<Tuple>
@@ -257,4 +257,3 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d_bf16_gfx950, Test2D)
this->Run();
}
#endif