diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index c26e52a37b..600767d47c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -9,7 +9,6 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" -#include "ck/tensor_operation/gpu/grid/packed_cast.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" @@ -898,7 +897,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 using ThreadwiseTransfer = std::conditional_t< is_gfx650_and_bf16_output(), - ThreadwiseTensorSliceTransfer_v1r3_pass_through< + ThreadwiseTensorSliceTransfer_v1r3_packed_cast< AccDataType, CShuffleDataType, decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), @@ -1007,21 +1006,6 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 // make sure it's safe to write to LDS block_sync_lds(); - if constexpr (is_gfx650_and_bf16_output()) - { - auto c_thread_packed_cast = PackedCastV2< - M2, - M4, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle - >{}; - c_thread_packed_cast.Run( - c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct) - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin - c_thread_buf // source buffer - ); - } - // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), @@ -1308,7 +1292,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 using ThreadwiseTransfer = std::conditional_t< is_gfx650_and_bf16_output(), - ThreadwiseTensorSliceTransfer_v1r3_pass_through< + ThreadwiseTensorSliceTransfer_v1r3_packed_cast< AccDataType, CShuffleDataType, decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), @@ -1417,21 +1401,6 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 // make sure it's safe to write to LDS block_sync_lds(); - if constexpr (is_gfx650_and_bf16_output()) - { - auto c_thread_packed_cast = PackedCastV2< - M2, - M4, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle - >{}; - c_thread_packed_cast.Run( - c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin - c_thread_buf // source buffer - ); - } - // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index e4f485971a..85bd5612f3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -9,7 +9,6 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" -#include "ck/tensor_operation/gpu/grid/packed_cast.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" @@ -1761,21 +1760,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 // make sure it's safe to write to LDS block_sync_lds(); - // if constexpr (is_gfx650_and_bf16_output()) - // { - // auto c_thread_packed_cast = PackedCastV2< - // M2, - // M4, - // CShuffleMXdlPerWavePerShuffle, - // CShuffleNXdlPerWavePerShuffle - // >{}; - // c_thread_packed_cast.Run( - // c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct) - // sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin - // c_thread_buf // source buffer - // ); - // } - // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), @@ -2203,21 +2187,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 // make sure it's safe to write to LDS block_sync_lds(); - // if constexpr (is_gfx650_and_bf16_output()) - // { - // auto c_thread_packed_cast = PackedCastV2< - // M2, - // M4, - // CShuffleMXdlPerWavePerShuffle, - // CShuffleNXdlPerWavePerShuffle - // >{}; - // c_thread_packed_cast.Run( - // c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc - // sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin - // c_thread_buf // source buffer - // ); - // } - // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index 983b83dc13..8d63cf1525 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -8,7 +8,6 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" -#include "ck/tensor_operation/gpu/grid/packed_cast.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" @@ -1594,7 +1593,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 using ThreadwiseTransfer = std::conditional_t< is_gfx650_and_bf16_output(), - ThreadwiseTensorSliceTransfer_v1r3_pass_through< + ThreadwiseTensorSliceTransfer_v1r3_packed_cast< AccDataType, CShuffleDataType, decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), @@ -1756,21 +1755,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 // make sure it's safe to write to LDS block_sync_lds(); - if constexpr (is_gfx650_and_bf16_output()) - { - auto c_thread_packed_cast = PackedCastV2< - M2, - M4, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle - >{}; - c_thread_packed_cast.Run( - c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct) - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin - c_thread_buf // source buffer - ); - } - // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), @@ -2166,7 +2150,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 // shuffle: threadwise copy C from VGPR to LDS using ThreadwiseTransfer = std::conditional_t< is_gfx650_and_bf16_output(), - ThreadwiseTensorSliceTransfer_v1r3_pass_through< + ThreadwiseTensorSliceTransfer_v1r3_packed_cast< AccDataType, CShuffleDataType, decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), @@ -2326,21 +2310,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 // make sure it's safe to write to LDS block_sync_lds(); - if constexpr (is_gfx650_and_bf16_output()) - { - auto c_thread_packed_cast = PackedCastV2< - M2, - M4, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle - >{}; - c_thread_packed_cast.Run( - c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin - c_thread_buf // source buffer - ); - } - // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 52814fae66..a1f42ffcb6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -8,7 +8,6 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" -#include "ck/tensor_operation/gpu/grid/packed_cast.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" @@ -897,7 +896,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight using ThreadwiseTransfer = std::conditional_t< is_gfx650_and_bf16_output(), - ThreadwiseTensorSliceTransfer_v1r3_pass_through< + ThreadwiseTensorSliceTransfer_v1r3_packed_cast< FloatAcc, FloatC, decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc), @@ -1002,21 +1001,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight // make sure it's safe to do ds_write block_sync_lds(); - if constexpr (is_gfx650_and_bf16_output()) - { - auto c_thread_packed_cast = PackedCastV2< - M2, - M4, - CShuffleMRepeatPerShuffle, - CShuffleNRepeatPerShuffle - >{}; - c_thread_packed_cast.Run( - c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, // source desc (TensorDescriptor struct) - make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), // source slice origin - c_thread_buf // source buffer - ); - } - // VGPR to LDS c_thread_copy_vgpr_to_lds.Run( c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, diff --git a/include/ck/tensor_operation/gpu/grid/packed_cast.hpp b/include/ck/tensor_operation/gpu/grid/packed_cast.hpp deleted file mode 100644 index 6a7190aa05..0000000000 --- a/include/ck/tensor_operation/gpu/grid/packed_cast.hpp +++ /dev/null @@ -1,93 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/utility/data_type.hpp" -#include "ck/utility/type_convert.hpp" -#include "ck/utility/static_buffer.hpp" -#include "ck/tensor_description/tensor_space_filling_curve.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -namespace ck { - template - struct PackedCastV2 - { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - - template - __device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, SrcBuffer& src_buf) - { - static_assert(SrcDesc::IsKnownAtCompileTime(), - "wrong! SrcDesc need to known at compile-time"); - static_assert(is_known_at_compile_time>::value, - "wrong! SrcSliceOrigin need to known at compile-time"); - - static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer"); - - constexpr auto src_desc = remove_cvref_t{}; - constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); - - using SliceLengths = Sequence; - using DimAccessOrder = Sequence<0, 1, 2, 3, 4, 5, 6, 7>; - using DstScalarPerAccess = Sequence<1, 1, 1, 1, 1, 1, 1, 1>; - using SpaceFillingCurve = SpaceFillingCurve; - - static_assert(SpaceFillingCurve::ScalarPerVector == 1, - "wrong! SpaceFillingCurve::ScalarPerVector must be 1 for PackedCastV2"); - - constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess(); - constexpr index_t num_pairs = num_access / 2; - constexpr bool has_odd_element = (num_access % 2 == 1); - - static_assert(!has_odd_element, "PackedCastV2 does not support odd number of elements"); - - ck::float2_t float2_buffer; - static_for<0, num_pairs, 1>{}([&](auto i_pair) - { - constexpr auto idx_1d_0 = I2 * i_pair; - constexpr auto idx_1d_1 = I2 * i_pair + I1; - constexpr auto idx_md_0 = SpaceFillingCurve::GetIndex(idx_1d_0); - constexpr auto idx_md_1 = SpaceFillingCurve::GetIndex(idx_1d_1); - constexpr auto idx_md_pair = SpaceFillingCurve::GetIndex(i_pair); - - constexpr index_t src_offset_0 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_0); - constexpr index_t src_offset_1 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_1); - constexpr index_t pair_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_pair); - - if constexpr (src_offset_1 - src_offset_0 == 1) - { - // Load two consecutive float values from the src buffer - float2_buffer = src_buf.template GetAsType(Number{}); - } - else - { - // Load the two float values one by one - float2_buffer= {src_buf[Number{}], src_buf[Number{}]}; - } - - // Store the packed bfloat2 value back to the src buffer - const ck::bhalf2_t packed_value= bf16x2_convert_rne(float2_buffer[0], float2_buffer[1]); - union { - ck::bhalf2_t bhalf2; - float fp32; - } converter; - converter.bhalf2 = packed_value; - src_buf(Number{}) = converter.fp32; - }); - }; - }; -}