Remove separate packed cast step.

This commit is contained in:
Ville Pietilä
2025-09-03 10:57:16 +00:00
parent d56d7bc821
commit 70d57ca8b9
5 changed files with 5 additions and 207 deletions

View File

@@ -9,7 +9,6 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/packed_cast.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
@@ -898,7 +897,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_pass_through<
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
@@ -1007,21 +1006,6 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
// make sure it's safe to write to LDS
block_sync_lds();
if constexpr (is_gfx650_and_bf16_output())
{
auto c_thread_packed_cast = PackedCastV2<
M2,
M4,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle
>{};
c_thread_packed_cast.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
c_thread_buf // source buffer
);
}
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
@@ -1308,7 +1292,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_pass_through<
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
@@ -1417,21 +1401,6 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
// make sure it's safe to write to LDS
block_sync_lds();
if constexpr (is_gfx650_and_bf16_output())
{
auto c_thread_packed_cast = PackedCastV2<
M2,
M4,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle
>{};
c_thread_packed_cast.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
c_thread_buf // source buffer
);
}
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),

View File

@@ -9,7 +9,6 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/packed_cast.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
@@ -1761,21 +1760,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
// make sure it's safe to write to LDS
block_sync_lds();
// if constexpr (is_gfx650_and_bf16_output())
// {
// auto c_thread_packed_cast = PackedCastV2<
// M2,
// M4,
// CShuffleMXdlPerWavePerShuffle,
// CShuffleNXdlPerWavePerShuffle
// >{};
// c_thread_packed_cast.Run(
// c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
// sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
// c_thread_buf // source buffer
// );
// }
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
@@ -2203,21 +2187,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
// make sure it's safe to write to LDS
block_sync_lds();
// if constexpr (is_gfx650_and_bf16_output())
// {
// auto c_thread_packed_cast = PackedCastV2<
// M2,
// M4,
// CShuffleMXdlPerWavePerShuffle,
// CShuffleNXdlPerWavePerShuffle
// >{};
// c_thread_packed_cast.Run(
// c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
// sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
// c_thread_buf // source buffer
// );
// }
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),

View File

@@ -8,7 +8,6 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/packed_cast.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
@@ -1594,7 +1593,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_pass_through<
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
@@ -1756,21 +1755,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
// make sure it's safe to write to LDS
block_sync_lds();
if constexpr (is_gfx650_and_bf16_output())
{
auto c_thread_packed_cast = PackedCastV2<
M2,
M4,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle
>{};
c_thread_packed_cast.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
c_thread_buf // source buffer
);
}
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
@@ -2166,7 +2150,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
// shuffle: threadwise copy C from VGPR to LDS
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_pass_through<
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
@@ -2326,21 +2310,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
// make sure it's safe to write to LDS
block_sync_lds();
if constexpr (is_gfx650_and_bf16_output())
{
auto c_thread_packed_cast = PackedCastV2<
M2,
M4,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle
>{};
c_thread_packed_cast.Run(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
c_thread_buf // source buffer
);
}
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),

View File

@@ -8,7 +8,6 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/packed_cast.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
@@ -897,7 +896,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
using ThreadwiseTransfer = std::conditional_t<
is_gfx650_and_bf16_output(),
ThreadwiseTensorSliceTransfer_v1r3_pass_through<
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
FloatAcc,
FloatC,
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
@@ -1002,21 +1001,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
// make sure it's safe to do ds_write
block_sync_lds();
if constexpr (is_gfx650_and_bf16_output())
{
auto c_thread_packed_cast = PackedCastV2<
M2,
M4,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle
>{};
c_thread_packed_cast.Run(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, // source desc (TensorDescriptor struct)
make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), // source slice origin
c_thread_buf // source buffer
);
}
// VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,

View File

@@ -1,93 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/static_buffer.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <ck::index_t M2, ck::index_t M4, ck::index_t CShuffleMXdlPerWavePerShuffle, ck::index_t CShuffleNXdlPerWavePerShuffle>
struct PackedCastV2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
template <typename SrcDesc, typename SrcSliceOriginIdx, typename SrcBuffer>
__device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, SrcBuffer& src_buf)
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
"wrong! SrcSliceOrigin need to known at compile-time");
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
using SliceLengths = Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>;
using DimAccessOrder = Sequence<0, 1, 2, 3, 4, 5, 6, 7>;
using DstScalarPerAccess = Sequence<1, 1, 1, 1, 1, 1, 1, 1>;
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
DstScalarPerAccess,
false>;
static_assert(SpaceFillingCurve::ScalarPerVector == 1,
"wrong! SpaceFillingCurve::ScalarPerVector must be 1 for PackedCastV2");
constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess();
constexpr index_t num_pairs = num_access / 2;
constexpr bool has_odd_element = (num_access % 2 == 1);
static_assert(!has_odd_element, "PackedCastV2 does not support odd number of elements");
ck::float2_t float2_buffer;
static_for<0, num_pairs, 1>{}([&](auto i_pair)
{
constexpr auto idx_1d_0 = I2 * i_pair;
constexpr auto idx_1d_1 = I2 * i_pair + I1;
constexpr auto idx_md_0 = SpaceFillingCurve::GetIndex(idx_1d_0);
constexpr auto idx_md_1 = SpaceFillingCurve::GetIndex(idx_1d_1);
constexpr auto idx_md_pair = SpaceFillingCurve::GetIndex(i_pair);
constexpr index_t src_offset_0 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_0);
constexpr index_t src_offset_1 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_1);
constexpr index_t pair_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_pair);
if constexpr (src_offset_1 - src_offset_0 == 1)
{
// Load two consecutive float values from the src buffer
float2_buffer = src_buf.template GetAsType<ck::float2_t>(Number<src_offset_0>{});
}
else
{
// Load the two float values one by one
float2_buffer= {src_buf[Number<src_offset_0>{}], src_buf[Number<src_offset_1>{}]};
}
// Store the packed bfloat2 value back to the src buffer
const ck::bhalf2_t packed_value= bf16x2_convert_rne<ck::bhalf2_t, float>(float2_buffer[0], float2_buffer[1]);
union {
ck::bhalf2_t bhalf2;
float fp32;
} converter;
converter.bhalf2 = packed_value;
src_buf(Number<pair_offset>{}) = converter.fp32;
});
};
};
}