mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 13:48:30 +00:00
Add back the separate packed cast step.
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/packed_cast.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
@@ -897,7 +898,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
|
||||
using ThreadwiseTransfer = std::conditional_t<
|
||||
is_gfx650_and_bf16_output(),
|
||||
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<AccDataType,
|
||||
ThreadwiseTensorSliceTransfer_v1r3_pass_through<AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
@@ -1006,6 +1007,21 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
// make sure it's safe to write to LDS
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr (is_gfx650_and_bf16_output())
|
||||
{
|
||||
auto c_thread_packed_cast = PackedCastV2<
|
||||
M2,
|
||||
M4,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle
|
||||
>{};
|
||||
c_thread_packed_cast.Run(
|
||||
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
|
||||
c_thread_buf // source buffer
|
||||
);
|
||||
}
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
@@ -1292,7 +1308,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
|
||||
using ThreadwiseTransfer = std::conditional_t<
|
||||
is_gfx650_and_bf16_output(),
|
||||
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<AccDataType,
|
||||
ThreadwiseTensorSliceTransfer_v1r3_pass_through<AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
@@ -1401,6 +1417,21 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
// make sure it's safe to write to LDS
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr (is_gfx650_and_bf16_output())
|
||||
{
|
||||
auto c_thread_packed_cast = PackedCastV2<
|
||||
M2,
|
||||
M4,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle
|
||||
>{};
|
||||
c_thread_packed_cast.Run(
|
||||
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
|
||||
c_thread_buf // source buffer
|
||||
);
|
||||
}
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/packed_cast.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
@@ -1645,7 +1646,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
using ThreadwiseTransfer = std::conditional_t<
|
||||
is_gfx650_and_bf16_output(),
|
||||
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
|
||||
ThreadwiseTensorSliceTransfer_v1r3_pass_through<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
@@ -1762,6 +1763,21 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
// make sure it's safe to write to LDS
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr (is_gfx650_and_bf16_output())
|
||||
{
|
||||
auto c_thread_packed_cast = PackedCastV2<
|
||||
M2,
|
||||
M4,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle
|
||||
>{};
|
||||
c_thread_packed_cast.Run(
|
||||
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
|
||||
c_thread_buf // source buffer
|
||||
);
|
||||
}
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
@@ -2077,7 +2093,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
using ThreadwiseTransfer = std::conditional_t<
|
||||
is_gfx650_and_bf16_output(),
|
||||
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
|
||||
ThreadwiseTensorSliceTransfer_v1r3_pass_through<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
@@ -2188,6 +2204,21 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
// make sure it's safe to write to LDS
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr (is_gfx650_and_bf16_output())
|
||||
{
|
||||
auto c_thread_packed_cast = PackedCastV2<
|
||||
M2,
|
||||
M4,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle
|
||||
>{};
|
||||
c_thread_packed_cast.Run(
|
||||
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
|
||||
c_thread_buf // source buffer
|
||||
);
|
||||
}
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/packed_cast.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
@@ -896,7 +897,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
|
||||
using ThreadwiseTransfer = std::conditional_t<
|
||||
is_gfx650_and_bf16_output(),
|
||||
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
|
||||
ThreadwiseTensorSliceTransfer_v1r3_pass_through<
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
|
||||
@@ -1001,6 +1002,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
// make sure it's safe to do ds_write
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr (is_gfx650_and_bf16_output())
|
||||
{
|
||||
auto c_thread_packed_cast = PackedCastV2<
|
||||
M2,
|
||||
M4,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle
|
||||
>{};
|
||||
c_thread_packed_cast.Run(
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, // source desc (TensorDescriptor struct)
|
||||
make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), // source slice origin
|
||||
c_thread_buf // source buffer
|
||||
);
|
||||
}
|
||||
|
||||
// VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
|
||||
|
||||
141
include/ck/tensor_operation/gpu/grid/packed_cast.hpp
Normal file
141
include/ck/tensor_operation/gpu/grid/packed_cast.hpp
Normal file
@@ -0,0 +1,141 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <ck::index_t M2, ck::index_t M4, ck::index_t CShuffleMXdlPerWavePerShuffle, ck::index_t CShuffleNXdlPerWavePerShuffle>
|
||||
struct PackedCast
|
||||
{
|
||||
template <typename SrcDesc, typename SrcSliceOriginIdx, typename SrcBuffer>
|
||||
__device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, SrcBuffer& src_buf)
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
|
||||
"wrong! SrcSliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
|
||||
|
||||
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
||||
|
||||
// Calculate total elements in this slice
|
||||
constexpr index_t elements_per_slice =
|
||||
CShuffleMXdlPerWavePerShuffle * CShuffleNXdlPerWavePerShuffle * M2 * M4;
|
||||
|
||||
constexpr auto calculate_coords = [src_slice_origin_idx](auto idx) constexpr {
|
||||
|
||||
// We know that the access order is
|
||||
// Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
// Sequence<CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, 1, 1, M2, 1, M4, 1>
|
||||
|
||||
constexpr index_t m4_offset = idx.value % M4;
|
||||
constexpr index_t m2_offset = (idx.value / M4) % M2;
|
||||
constexpr index_t n_xdl_offset = (idx.value / (M4 * M2)) % CShuffleNXdlPerWavePerShuffle;
|
||||
constexpr index_t m_xdl_offset = idx.value / (M4 * M2 * CShuffleNXdlPerWavePerShuffle);
|
||||
|
||||
return make_tuple(
|
||||
src_slice_origin_idx[Number<0>{}] + Number<m_xdl_offset>{},
|
||||
src_slice_origin_idx[Number<1>{}] + Number<n_xdl_offset>{},
|
||||
Number<0>{}, // this dim has unit size
|
||||
Number<0>{}, // this dim has unit size
|
||||
src_slice_origin_idx[Number<4>{}] + Number<m2_offset>{},
|
||||
Number<0>{}, // this dim has unit size
|
||||
src_slice_origin_idx[Number<6>{}] + Number<m4_offset>{},
|
||||
Number<0>{} // this dim has unit size
|
||||
);
|
||||
};
|
||||
|
||||
constexpr index_t num_pairs = elements_per_slice / 2;
|
||||
constexpr bool has_odd_element = (elements_per_slice % 2 == 1);
|
||||
|
||||
static_assert(!has_odd_element, "PackedCast does not support odd number of elements");
|
||||
|
||||
static_for<0, num_pairs, 1>{}([&](auto pair_idx) {
|
||||
constexpr auto idx_0 = Number<pair_idx * 2>{};
|
||||
constexpr auto idx_1 = Number<pair_idx * 2 + 1>{};
|
||||
|
||||
constexpr auto coord_0 = calculate_coords(idx_0);
|
||||
constexpr auto coord_1 = calculate_coords(idx_1);
|
||||
|
||||
constexpr auto offset_0 = src_desc.CalculateOffset(coord_0);
|
||||
constexpr auto offset_1 = src_desc.CalculateOffset(coord_1);
|
||||
|
||||
float& val_0 = src_buf(Number<offset_0>{});
|
||||
const float val_1 = src_buf[Number<offset_1>{}];
|
||||
|
||||
static_cast_float_to_bhalf_packed_v2(val_0, val_1);
|
||||
});
|
||||
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
template <ck::index_t M2, ck::index_t M4, ck::index_t CShuffleMXdlPerWavePerShuffle, ck::index_t CShuffleNXdlPerWavePerShuffle>
|
||||
struct PackedCastV2
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
template <typename SrcDesc, typename SrcSliceOriginIdx, typename SrcBuffer>
|
||||
__device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, SrcBuffer& src_buf)
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
|
||||
"wrong! SrcSliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
|
||||
|
||||
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
||||
|
||||
using SliceLengths = Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>;
|
||||
using DimAccessOrder = Sequence<0, 1, 2, 3, 4, 5, 6, 7>;
|
||||
using DstScalarPerAccess = Sequence<1, 1, 1, 1, 1, 1, 1, 1>;
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
DstScalarPerAccess>;
|
||||
|
||||
static_assert(SpaceFillingCurve::ScalarPerVector == 1,
|
||||
"wrong! SpaceFillingCurve::ScalarPerVector must be 1 for PackedCastV2");
|
||||
|
||||
constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
constexpr index_t num_pairs = num_access / 2;
|
||||
constexpr bool has_odd_element = (num_access % 2 == 1);
|
||||
|
||||
static_assert(!has_odd_element, "PackedCastV2 does not support odd number of elements");
|
||||
|
||||
static_for<0, num_pairs, 1>{}([&](auto i_pair)
|
||||
{
|
||||
constexpr auto idx_1d_0 = I2 * i_pair;
|
||||
constexpr auto idx_1d_1 = I2 * i_pair + I1;
|
||||
constexpr auto idx_md_0 = SpaceFillingCurve::GetIndex(idx_1d_0);
|
||||
constexpr auto idx_md_1 = SpaceFillingCurve::GetIndex(idx_1d_1);
|
||||
|
||||
constexpr index_t src_offset_0 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_0);
|
||||
constexpr index_t src_offset_1 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_1);
|
||||
|
||||
float& val_0 = src_buf(Number<src_offset_0>{});
|
||||
const float val_1 = src_buf[Number<src_offset_1>{}];
|
||||
static_cast_float_to_bhalf_packed_v2(val_0, val_1);
|
||||
});
|
||||
};
|
||||
};
|
||||
}
|
||||
@@ -13,6 +13,200 @@
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
|
||||
|
||||
namespace ck {
|
||||
// Assume:
|
||||
// 1. src:
|
||||
// 1. SrcDesc is known at compile-time
|
||||
// 2. SrcBuffer is StaticBuffer
|
||||
// 3. SrcSliceOrginIdx is known at compile-time
|
||||
// 2. dst:
|
||||
// 1. DstDesc is not known at compile-time
|
||||
// 2. DstBuffer is DynamicBuffer
|
||||
// 3. DstSliceOrginIdx is not known at compile time
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename ElementwiseOperation,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
index_t DstVectorDim,
|
||||
index_t DstScalarPerVector,
|
||||
InMemoryDataOperationEnum DstInMemOp,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool DstResetCoordinateAfterRun,
|
||||
bool PackedInput = false,
|
||||
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
struct ThreadwiseTensorSliceTransfer_v1r3_pass_through
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
static constexpr bool float_input_and_bf16_output_ =
|
||||
std::is_same_v<SrcData, float> && std::is_same_v<DstData, ck::bhalf_t>;
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3_pass_through(const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin_idx,
|
||||
const ElementwiseOperation& element_op)
|
||||
: dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
|
||||
element_op_{element_op}
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
|
||||
"wrong! Not divisible");
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
||||
{
|
||||
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||
}
|
||||
|
||||
template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
|
||||
__device__ void Run(const SrcDesc&,
|
||||
const SrcSliceOriginIdx&,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
|
||||
"wrong! SrcSliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
|
||||
|
||||
// SrcDesc and src_slice_origin_idx are known at compile-time
|
||||
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
||||
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>>;
|
||||
|
||||
// TODO: Use SpaceFillingCurve::ScalarsPerAccess instread of DstScalarPerVector?
|
||||
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
|
||||
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
|
||||
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
|
||||
using dst_vector_t = typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
static_assert(std::is_same_v<DimAccessOrder, Sequence<0, 1, 2, 3, 4, 5, 6, 7>>,
|
||||
"wrong! DimAccessOrder must be the identity sequence <0, 1, 2, 3, 4, 5, 6, 7>");
|
||||
|
||||
static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector");
|
||||
static_assert(1 == DstScalarPerVector, "wrong!1 != DstScalarPerVector");
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d)
|
||||
{
|
||||
// We need map the odd indices to the even indices, since
|
||||
// the even indices contain a packed bf16x2 value, where
|
||||
// the first value contains the bf16 value for the corresponding even index
|
||||
// and the second value contains the bf16 value for the odd index following the even index.
|
||||
// The odd indices are not used, so we can just ignore them.
|
||||
constexpr auto pair_index = idx_1d % I2;
|
||||
constexpr auto idx_src_1d = idx_1d - pair_index;
|
||||
|
||||
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_src_1d);
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md);
|
||||
|
||||
union
|
||||
{
|
||||
float src_float;
|
||||
bhalf2_t src_bf16x2;
|
||||
} packed_value;
|
||||
|
||||
packed_value.src_float = src_buf[Number<src_offset>{}];
|
||||
|
||||
DstData v;
|
||||
element_op_(v, packed_value.src_bf16x2[pair_index.value]);
|
||||
dst_vector.template AsType<DstData>()(I0) = v;
|
||||
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
|
||||
|
||||
// copy data from dst_vector into dst_buf
|
||||
dst_buf.template Update<DstInMemOp, dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
|
||||
|
||||
if constexpr(idx_1d.value != num_access - 1)
|
||||
{
|
||||
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
|
||||
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
|
||||
}
|
||||
});
|
||||
|
||||
// move dst coordinate back to slice origin (or not)
|
||||
if constexpr(DstResetCoordinateAfterRun)
|
||||
{
|
||||
const auto dst_reset_step =
|
||||
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetDstCoordinateResetStep()
|
||||
{
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>>;
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
if constexpr(num_access == 0)
|
||||
{
|
||||
return typename SpaceFillingCurve::Index{};
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto reset_step =
|
||||
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
|
||||
|
||||
return reset_step;
|
||||
}
|
||||
}
|
||||
|
||||
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin_step_idx)
|
||||
{
|
||||
// if dst coord was not reset by Run(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
|
||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
private:
|
||||
DstCoord dst_coord_;
|
||||
const ElementwiseOperation element_op_;
|
||||
}; // namespace ThreadwiseTensorSliceTransfer_v1r3_pass_through
|
||||
|
||||
// Assume:
|
||||
// 1. src:
|
||||
@@ -292,7 +486,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
|
||||
// TODO: Use SpaceFillingCurve::ScalarsPerAccess instead of DstScalarPerVector?
|
||||
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
|
||||
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
|
||||
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
|
||||
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
|
||||
using dst_vector_t = typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
|
||||
|
||||
@@ -305,12 +499,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
// TODO: It's a hack here to use \p dst_scalar_step_in_vector. Use SpaceFillingCurve?
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(
|
||||
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
|
||||
DstData v;
|
||||
|
||||
// apply element-wise operation
|
||||
element_op_(v, src_buf[Number<src_offset>{}]);
|
||||
|
||||
dst_vector.template AsType<DstData>()(i) = v;
|
||||
});
|
||||
|
||||
@@ -379,6 +574,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
private:
|
||||
DstCoord dst_coord_;
|
||||
const ElementwiseOperation element_op_;
|
||||
|
||||
@@ -91,6 +91,25 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(fl
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts two floats into a vector of 2 16-bit bfloat types (bhalf2_t) using
|
||||
* rounding to nearest/even (RNE).
|
||||
*
|
||||
* @param x First float value. The packed value is stored here.
|
||||
* @param y Second float value.
|
||||
*/
|
||||
inline __host__ __device__ void static_cast_float_to_bhalf_packed_v2(float& x, float y)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
bhalf2_t bf16x2;
|
||||
} converter;
|
||||
|
||||
converter.bf16x2 = {bf16_convert_rtn<bhalf_t>(x), bf16_convert_rtn<bhalf_t>(y)};
|
||||
x = converter.fp32;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts two floats into a vector of 2 16-bit bfloat types (bhalf2_t) using
|
||||
* rounding to nearest/even (RNE).
|
||||
|
||||
@@ -15,45 +15,6 @@
|
||||
using ck::bhalf_t;
|
||||
using ck::type_convert;
|
||||
|
||||
// A quiet float NaN must convert to the bhalf_t bit pattern 0x7FC0.
TEST(BHALF_T, Nan)
{
    const bhalf_t expected  = ck::bit_cast<bhalf_t>(uint16_t{0x7FC0});
    const bhalf_t converted = type_convert<bhalf_t>(ck::NumericLimits<float>::QuietNaN());

    EXPECT_EQ(expected, converted);
}
|
||||
|
||||
// Positive float infinity must convert to the bhalf_t bit pattern 0x7F80.
TEST(BHALF_T, Inf)
{
    const bhalf_t expected  = ck::bit_cast<bhalf_t>(uint16_t{0x7F80});
    const bhalf_t converted = type_convert<bhalf_t>(ck::NumericLimits<float>::Infinity());

    EXPECT_EQ(expected, converted);
}
|
||||
|
||||
// Round trip float -> bhalf_t -> float for 0x81FFFFFF (sign bit set, small
// exponent, every mantissa bit set) must stay within 2^-7 of the input.
TEST(BHALF_T, MantisaOverflow)
{
    const float tolerance = std::pow(2, -7);
    const float original  = ck::bit_cast<float>(uint32_t{0x81FFFFFF});

    const bhalf_t narrowed = type_convert<bhalf_t>(original);
    const float round_trip = type_convert<float>(narrowed);

    ASSERT_NEAR(original, round_trip, tolerance);
}
|
||||
|
||||
// 0xFF800000 (sign bit set, exponent all ones, zero mantissa) must survive the
// float -> bhalf_t -> float round trip unchanged.
TEST(BHALF_T, ExpOverflow)
{
    const float original = ck::bit_cast<float>(uint32_t{0xFF800000});

    const bhalf_t narrowed = type_convert<bhalf_t>(original);
    const float round_trip = type_convert<float>(narrowed);

    ASSERT_EQ(round_trip, original);
}
|
||||
|
||||
// 0xFFFFFFFF is a NaN as a float; it must still be a NaN after the
// float -> bhalf_t -> float round trip.
TEST(BHALF_T, MantisaExpOverflow)
{
    const float original = ck::bit_cast<float>(uint32_t{0xFFFFFFFF});
    ASSERT_TRUE(std::isnan(original));

    const bhalf_t narrowed = type_convert<bhalf_t>(original);
    const float round_trip = type_convert<float>(narrowed);
    ASSERT_TRUE(std::isnan(round_trip));
}
|
||||
|
||||
__global__ void cast_roundtrip(const float2 input, float2* output)
|
||||
{
|
||||
const ck::bhalf2_t bhalf2_val = ck::bf16x2_convert_rne<ck::bhalf2_t, float>(input.x, input.y);
|
||||
@@ -74,10 +35,26 @@ __global__ void cast(const float input, float* output)
|
||||
*output = type_convert<float>(bhalf_val);
|
||||
}
|
||||
|
||||
// Device helper for the in-place packed cast test: converts (x1, x2) with
// static_cast_float_to_bhalf_packed_v2 and writes the resulting bf16 pair to *output.
__global__ void packed_cast_in_place(const float x1, const float x2, ck::bhalf2_t* output)
{
    // The conversion overwrites its first argument, so work on a local copy of x1.
    float packed_slot = x1;
    ck::static_cast_float_to_bhalf_packed_v2(packed_slot, x2);

    // Reinterpret the 32-bit slot as the two bf16 lanes it now carries.
    union
    {
        float src;
        ck::bhalf2_t dst;
    } view;

    view.src = packed_slot;
    *output  = view.dst;
}
|
||||
|
||||
// Selects which float -> bf16 conversion path a benchmark kernel exercises.
enum struct CastMode : int
{
    Standard      = 0, // one scalar conversion per element
    Packed        = 1, // packed conversion of two elements
    PackedInPlace = 2  // packed conversion that writes the pair back into the first float
};
|
||||
|
||||
template <CastMode PackedCast, int NumElements>
|
||||
@@ -119,6 +96,60 @@ __global__ void test_performance_kernel(float* input, ck::bhalf_t* output)
|
||||
}
|
||||
}
|
||||
|
||||
template <CastMode PackedCast, int NumElements>
|
||||
__global__ void test_in_place_performance_kernel(float* input, ck::bhalf_t* output)
|
||||
{
|
||||
ck::bhalf_t buffer_bf16[NumElements];
|
||||
float buffer_float[NumElements];
|
||||
|
||||
// Initialize input data
|
||||
for(int i = 0; i < NumElements; i++)
|
||||
{
|
||||
buffer_float[i] = input[i];
|
||||
}
|
||||
|
||||
if constexpr (PackedCast == CastMode::PackedInPlace)
|
||||
{
|
||||
union
|
||||
{
|
||||
float src;
|
||||
ck::bhalf2_t dst;
|
||||
} workspace;
|
||||
|
||||
for(int i = 0; i < NumElements; i++)
|
||||
{
|
||||
for (int j = 0; j < NumElements; j++)
|
||||
{
|
||||
int index = (i + j) % NumElements;
|
||||
index = index < NumElements - 1 ? index : NumElements - 2;
|
||||
workspace.src = buffer_float[i];
|
||||
ck::static_cast_float_to_bhalf_packed_v2(workspace.src, buffer_float[j]);
|
||||
ck::bhalf2_t* buffer_range = reinterpret_cast<ck::bhalf2_t*>(&buffer_bf16[index]);
|
||||
*buffer_range = workspace.dst;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int i = 0; i < NumElements; i++)
|
||||
{
|
||||
for (int j = 0; j < NumElements; j++)
|
||||
{
|
||||
int index = (i + j) % NumElements;
|
||||
index = index < NumElements - 1 ? index : NumElements - 2;
|
||||
buffer_bf16[index] = ck::bf16_convert_rtn_base(buffer_float[i]);
|
||||
buffer_bf16[index + 1] = ck::bf16_convert_rtn_base(buffer_float[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Copy results back to output
|
||||
for(int i = 0; i < NumElements; i++)
|
||||
{
|
||||
output[i] = buffer_bf16[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <int NumElements>
|
||||
void run_performance_test()
|
||||
{
|
||||
@@ -165,6 +196,91 @@ void run_performance_test()
|
||||
ASSERT_LT(packed_time, baseline_time);
|
||||
}
|
||||
|
||||
template <int NumElements>
|
||||
void run_in_place_performance_test()
|
||||
{
|
||||
float* input_dev;
|
||||
ck::bhalf_t* output_dev;
|
||||
std::vector<ck::bhalf_t> output_host(NumElements);
|
||||
|
||||
hip_check_error(hipMalloc(&input_dev, sizeof(float) * NumElements));
|
||||
hip_check_error(hipMalloc(&output_dev, sizeof(ck::bhalf_t) * NumElements));
|
||||
|
||||
// Initialize input data on the device
|
||||
std::vector<float> input_host(NumElements);
|
||||
for (int i = 0; i < NumElements; i++)
|
||||
{
|
||||
input_host[i] = 3.14f * static_cast<float>(i) - 1.7f;
|
||||
}
|
||||
|
||||
hip_check_error(hipMemcpy(input_dev, input_host.data(), sizeof(float) * NumElements, hipMemcpyHostToDevice));
|
||||
|
||||
StreamConfig stream_config;
|
||||
stream_config.time_kernel_ = true;
|
||||
|
||||
auto baseline_kernel = test_in_place_performance_kernel<CastMode::Standard, NumElements>;
|
||||
auto packed_kernel = test_in_place_performance_kernel<CastMode::PackedInPlace, NumElements>;
|
||||
|
||||
constexpr dim3 grid_size(1);
|
||||
constexpr dim3 block_size(1);
|
||||
constexpr size_t shared_mem_size = 0;
|
||||
|
||||
const float baseline_time = launch_and_time_kernel(stream_config, baseline_kernel, grid_size, block_size, shared_mem_size, input_dev, output_dev);
|
||||
hip_check_error(hipMemcpy(output_host.data(), output_dev, sizeof(ck::bhalf_t) * NumElements, hipMemcpyDeviceToHost));
|
||||
|
||||
const float packed_time = launch_and_time_kernel(stream_config, packed_kernel, grid_size, block_size, shared_mem_size, input_dev, output_dev);
|
||||
hip_check_error(hipMemcpy(output_host.data(), output_dev, sizeof(ck::bhalf_t) * NumElements, hipMemcpyDeviceToHost));
|
||||
|
||||
// Cleanup
|
||||
hip_check_error(hipFree(input_dev));
|
||||
hip_check_error(hipFree(output_dev));
|
||||
|
||||
std::cout << "Packed cast time ( " << NumElements << " elements): " << packed_time << " ms" << std::endl;
|
||||
std::cout << "Baseline cast time ( " << NumElements << " elements): " << baseline_time << " ms" << std::endl;
|
||||
|
||||
// Check if packed cast is faster than baseline
|
||||
ASSERT_LT(packed_time, baseline_time);
|
||||
}
|
||||
|
||||
// A quiet float NaN must convert to the bhalf_t bit pattern 0x7FC0.
TEST(BHALF_T, Nan)
{
    const bhalf_t expected  = ck::bit_cast<bhalf_t>(uint16_t{0x7FC0});
    const bhalf_t converted = type_convert<bhalf_t>(ck::NumericLimits<float>::QuietNaN());

    EXPECT_EQ(expected, converted);
}
|
||||
|
||||
// Positive float infinity must convert to the bhalf_t bit pattern 0x7F80.
TEST(BHALF_T, Inf)
{
    const bhalf_t expected  = ck::bit_cast<bhalf_t>(uint16_t{0x7F80});
    const bhalf_t converted = type_convert<bhalf_t>(ck::NumericLimits<float>::Infinity());

    EXPECT_EQ(expected, converted);
}
|
||||
|
||||
// Round trip float -> bhalf_t -> float for 0x81FFFFFF (sign bit set, small
// exponent, every mantissa bit set) must stay within 2^-7 of the input.
TEST(BHALF_T, MantisaOverflow)
{
    const float tolerance = std::pow(2, -7);
    const float original  = ck::bit_cast<float>(uint32_t{0x81FFFFFF});

    const bhalf_t narrowed = type_convert<bhalf_t>(original);
    const float round_trip = type_convert<float>(narrowed);

    ASSERT_NEAR(original, round_trip, tolerance);
}
|
||||
|
||||
// 0xFF800000 (sign bit set, exponent all ones, zero mantissa) must survive the
// float -> bhalf_t -> float round trip unchanged.
TEST(BHALF_T, ExpOverflow)
{
    const float original = ck::bit_cast<float>(uint32_t{0xFF800000});

    const bhalf_t narrowed = type_convert<bhalf_t>(original);
    const float round_trip = type_convert<float>(narrowed);

    ASSERT_EQ(round_trip, original);
}
|
||||
|
||||
// 0xFFFFFFFF is a NaN as a float; it must still be a NaN after the
// float -> bhalf_t -> float round trip.
TEST(BHALF_T, MantisaExpOverflow)
{
    const float original = ck::bit_cast<float>(uint32_t{0xFFFFFFFF});
    ASSERT_TRUE(std::isnan(original));

    const bhalf_t narrowed = type_convert<bhalf_t>(original);
    const float round_trip = type_convert<float>(narrowed);
    ASSERT_TRUE(std::isnan(round_trip));
}
|
||||
|
||||
TEST(BHALF_T, Performance)
|
||||
{
|
||||
if (ck::get_device_name() == "gfx950")
|
||||
@@ -365,3 +481,41 @@ TEST(BHALF_T, CastOnDevice)
|
||||
ASSERT_NEAR(float_val_after_cast_host, -float_vals[idx], abs_tol);
|
||||
}
|
||||
}
|
||||
|
||||
// Runs packed_cast_in_place on the device for one pair of floats and checks that
// both bf16 lanes convert back to within 2^-7 of the original values.
TEST(BHALF_T, PackedCast_in_place)
{
    const float v1 = 3.14f;
    const float v2 = -1.618f;

    ck::bhalf2_t* bhalf2_val_d = nullptr;
    hip_check_error(hipMalloc(&bhalf2_val_d, sizeof(ck::bhalf2_t)));

    packed_cast_in_place<<<1, 1>>>(v1, v2, bhalf2_val_d);
    hip_check_error(hipGetLastError());

    ck::bhalf2_t bhalf2_val_h;
    hip_check_error(
        hipMemcpy(&bhalf2_val_h, bhalf2_val_d, sizeof(ck::bhalf2_t), hipMemcpyDeviceToHost));

    // The result is on the host now. Free the device buffer before asserting:
    // a failing ASSERT_* returns from the test body and would skip a later hipFree.
    hip_check_error(hipFree(bhalf2_val_d));

    // Convert back to floats
    const float fval1 = type_convert<float>(bhalf2_val_h[0]);
    const float fval2 = type_convert<float>(bhalf2_val_h[1]);

    const float abs_tol = std::pow(2, -7);
    ASSERT_NEAR(fval1, v1, abs_tol);
    ASSERT_NEAR(fval2, v2, abs_tol);
}
|
||||
|
||||
// Runs the in-place packed cast benchmark for 32, 64 and 128 elements.
// Skipped unless the device name reported by ck::get_device_name() is "gfx950".
TEST(BHALF_T, PackedCast_in_place_performance)
{
    if(!(ck::get_device_name() == "gfx950"))
    {
        GTEST_SKIP() << "Packed cast performance test requires gfx950.";
    }

    run_in_place_performance_test<32>();
    run_in_place_performance_test<64>();
    run_in_place_performance_test<128>();
}
|
||||
|
||||
Reference in New Issue
Block a user