mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 21:58:13 +00:00
Working version 2 of the packed cast.
This commit is contained in:
@@ -800,15 +800,6 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
|
||||
auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
|
||||
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
|
||||
if constexpr(is_gfx950_and_bf16_input_)
|
||||
{
|
||||
constexpr auto register_size_per_xdl_op = blockwise_gemm_pipeline.xdlops_gemm.GetRegSizePerXdlops();
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
ck::bhalf2_t,
|
||||
(MXdlPerWave * NXdlPerWave + 1)/ 2,
|
||||
register_size_per_xdl_op,
|
||||
true> c_thread_buf_bf16;
|
||||
}
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
@@ -1010,8 +1001,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
c_thread_packed_cast.Run(
|
||||
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
|
||||
c_thread_buf, // source buffer,
|
||||
c_thread_buf_bf16 // destination buffer
|
||||
c_thread_buf // source buffer
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1390,7 +1380,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
|
||||
if constexpr (is_gfx950_and_bf16_input_)
|
||||
{
|
||||
auto c_thread_packed_cast = PackedCast<
|
||||
auto c_thread_packed_cast = PackedCastV2<
|
||||
M2,
|
||||
M4,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -94,6 +95,10 @@ namespace ck {
|
||||
template <ck::index_t M2, ck::index_t M4, ck::index_t CShuffleMXdlPerWavePerShuffle, ck::index_t CShuffleNXdlPerWavePerShuffle>
|
||||
struct PackedCastV2
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
template <typename SrcDesc, typename SrcSliceOriginIdx, typename SrcBuffer>
|
||||
__device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, SrcBuffer& src_buf)
|
||||
{
|
||||
@@ -130,8 +135,8 @@ namespace ck {
|
||||
|
||||
static_for<0, num_pairs, 1>{}([&](auto i_pair)
|
||||
{
|
||||
constexpr auto idx_1d_0 = 2 * i_pair;
|
||||
constexpr auto idx_1d_1 = 2 * i_pair + 1;
|
||||
constexpr auto idx_1d_0 = I2 * i_pair;
|
||||
constexpr auto idx_1d_1 = I2 * i_pair + I1;
|
||||
constexpr auto idx_md_0 = SpaceFillingCurve::GetIndex(idx_1d_0);
|
||||
constexpr auto idx_md_1 = SpaceFillingCurve::GetIndex(idx_1d_1);
|
||||
|
||||
@@ -146,16 +151,17 @@ namespace ck {
|
||||
// Handle last element if the number of elements is odd.
|
||||
if constexpr (has_odd_element)
|
||||
{
|
||||
constexpr auto last_idx_1d = num_access - 1;
|
||||
constexpr auto last_idx_1d = Number<num_access - 1>{};
|
||||
constexpr auto last_idx_md = SpaceFillingCurve::GetIndex(last_idx_1d);
|
||||
|
||||
// Single element conversion
|
||||
constexpr auto last_src_offset = src_desc.CalculateOffset(last_idx_1d);
|
||||
float& last_val = src_buf(Number<last_offset>{});
|
||||
float& last_val = src_buf(Number<last_src_offset>{});
|
||||
const auto single_bf16 = static_cast<__bf16>(last_val);
|
||||
uint16_t* parts = reinterpret_cast<uint16_t*>(&last_val);
|
||||
const uint16_t* bf16_bits = reinterpret_cast<const uint16_t*>(&single_bf16);
|
||||
parts[0] = bf16_bits[0];
|
||||
}
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
@@ -35,10 +35,14 @@ template <typename SrcData,
|
||||
InMemoryDataOperationEnum DstInMemOp,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool DstResetCoordinateAfterRun,
|
||||
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false,
|
||||
bool PackedInput = false>
|
||||
bool PackedInput = false,
|
||||
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
static constexpr bool float_input_and_bf16_output_ =
|
||||
@@ -86,38 +90,76 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
||||
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_scalar_step_in_vector =
|
||||
generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>>;
|
||||
|
||||
// TODO: Use SpaceFillingCurve::ScalarsPerAccess instead of DstScalarPerVector?
|
||||
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
|
||||
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
|
||||
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
|
||||
using dst_vector_t = typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
if constexpr (PackedInput && float_input_and_bf16_output_)
|
||||
{
|
||||
static_assert(DstScalarPerVector == 2, "wrong! DstScalarPerVector must be 2");
|
||||
static_assert(std::is_same_v<DimAccessOrder, Sequence<0, 1, 2, 3, 4, 5, 6, 7>>,
|
||||
"wrong! DimAccessOrder must be the identity sequence <0, 1, 2, 3, 4, 5, 6, 7>");
|
||||
|
||||
// TODO: Fill the dst_vector using a vectorized access of size 2.
|
||||
// Copying the dst_vector into dst_buf should be the same as before.
|
||||
static_assert(1 == SpaceFillingCurve::ScalarPerVector, "wrong!1 != SpaceFillingCurve::ScalarPerVector");
|
||||
static_assert(1 == DstScalarPerVector, "wrong!1 != DstScalarPerVector");
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d)
|
||||
{
|
||||
// We need to map the odd indices to the even indices, since
|
||||
// the even indices contain a packed bf16x2 value, where
|
||||
// the first value contains the bf16 value for the corresponding even index
|
||||
// and the second value contains the bf16 value for the odd index following the even index.
|
||||
// The odd indices are not used, so we can just ignore them.
|
||||
constexpr auto pair_index = idx_1d % I2;
|
||||
constexpr auto idx_src_1d = idx_1d - pair_index;
|
||||
|
||||
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_src_1d);
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md);
|
||||
|
||||
union
|
||||
{
|
||||
float src_float;
|
||||
bhalf16_t src_bf16x2;
|
||||
} packed_value;
|
||||
|
||||
packed_value.src_float = src_buf[Number<src_offset>{}];
|
||||
dst_vector.template AsType<DstData>()(I0) = packed_value.src_bf16x2[pair_index.value];
|
||||
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
|
||||
|
||||
// copy data from dst_vector into dst_buf
|
||||
dst_buf.template Update<DstInMemOp, dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
|
||||
|
||||
if constexpr(idx_1d.value != num_access - 1)
|
||||
{
|
||||
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
|
||||
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
|
||||
}
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_scalar_step_in_vector =
|
||||
generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>>;
|
||||
|
||||
// TODO: Use SpaceFillingCurve::ScalarsPerAccess instead of DstScalarPerVector?
|
||||
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
|
||||
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
|
||||
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
|
||||
using dst_vector_t = typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d) {
|
||||
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
|
||||
|
||||
|
||||
@@ -227,7 +227,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D)
|
||||
}
|
||||
|
||||
// Use --gtest_filter="TestGroupedConvndBwdWeight2d_bf16_gfx950/*" to run only this subset of tests.
|
||||
#if defined(__gfx950__)
|
||||
// List tests in the test suite with --gtest_list_tests
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdWeight2d_bf16_gfx950 : public TestGroupedConvndBwdWeight<Tuple>
|
||||
@@ -257,4 +257,3 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d_bf16_gfx950, Test2D)
|
||||
|
||||
this->Run();
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user