mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 21:58:13 +00:00
Integrate new packed cast threadwise tensor slice transfer into gridwise gemm pipelines.
This commit is contained in:
@@ -4,13 +4,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/utility/type.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/packed_cast.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
@@ -101,18 +99,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
// gfx950 specific optimizations for BF16 inputs
|
||||
#if defined(__gfx950__)
|
||||
static constexpr bool is_gfx950_and_bf16_input_ =
|
||||
std::is_same_v<ADataType, ck::bhalf_t> &&
|
||||
std::is_same_v<BDataType, ck::bhalf_t> &&
|
||||
std::is_same_v<CShuffleDataType, ck::bhalf_t> &&
|
||||
std::is_same_v<AccDataType, float>;
|
||||
#else
|
||||
static constexpr bool is_gfx950_and_bf16_input_ = false;
|
||||
#endif
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
|
||||
{
|
||||
@@ -274,10 +261,6 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
p_b_grid{p_b_grid_},
|
||||
p_c_grid{p_c_grid_}
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[GridwiseGemm_xdl_cshuffle_conv_v3] GFX950 and BF16 optimization enabled: " << is_gfx950_and_bf16_input_ << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid;
|
||||
@@ -656,6 +639,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
// if arch = gfx942
|
||||
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
|
||||
|
||||
// check if we should apply gf950 device specific optimization for BF16 output
|
||||
__device__ static constexpr bool is_gfx650_and_bf16_output()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return
|
||||
std::is_same_v<CShuffleDataType, ck::bhalf_t> &&
|
||||
std::is_same_v<AccDataType, float>;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -900,9 +895,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(n_thread_data_on_block));
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
using ThreadwiseTransfer = std::conditional_t<
|
||||
is_gfx650_and_bf16_output(),
|
||||
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
@@ -920,8 +915,30 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true,
|
||||
is_gfx950_and_bf16_input_>{
|
||||
true>,
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>
|
||||
>;
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds = ThreadwiseTransfer{
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
@@ -989,21 +1006,6 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
// make sure it's safe to write to LDS
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr (is_gfx950_and_bf16_input_)
|
||||
{
|
||||
auto c_thread_packed_cast = PackedCastV2<
|
||||
M2,
|
||||
M4,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle
|
||||
>{};
|
||||
c_thread_packed_cast.Run(
|
||||
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
|
||||
c_thread_buf // source buffer
|
||||
);
|
||||
}
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
@@ -1288,9 +1290,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(n_thread_data_on_block));
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
using ThreadwiseTransfer = std::conditional_t<
|
||||
is_gfx650_and_bf16_output(),
|
||||
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
@@ -1308,8 +1310,30 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true,
|
||||
is_gfx950_and_bf16_input_>{
|
||||
true>,
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>
|
||||
>;
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds = ThreadwiseTransfer{
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
@@ -1376,21 +1400,6 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr (is_gfx950_and_bf16_input_)
|
||||
{
|
||||
auto c_thread_packed_cast = PackedCastV2<
|
||||
M2,
|
||||
M4,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle
|
||||
>{};
|
||||
c_thread_packed_cast.Run(
|
||||
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
|
||||
c_thread_buf // source buffer
|
||||
);
|
||||
}
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/packed_cast.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
@@ -1385,6 +1384,18 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
|
||||
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
|
||||
|
||||
// check if we should apply gf950 device specific optimization for BF16 output
|
||||
__device__ static constexpr bool is_gfx650_and_bf16_output()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return
|
||||
std::is_same_v<CShuffleDataType, ck::bhalf_t> &&
|
||||
std::is_same_v<AccDataType, float>;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -1647,9 +1658,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
}
|
||||
};
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
using ThreadwiseTransfer = std::conditional_t<
|
||||
is_gfx650_and_bf16_output(),
|
||||
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
@@ -1669,8 +1681,33 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true,
|
||||
is_gfx950_and_bf16_input_>{
|
||||
true>,
|
||||
ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
conditional_t<DoElementwiseBeforeCShuffle,
|
||||
CElementwiseOperation,
|
||||
tensor_operation::element_wise::PassThrough>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>
|
||||
>;
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds = ThreadwiseTransfer{
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
@@ -1740,21 +1777,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
// make sure it's safe to write to LDS
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr (is_gfx950_and_bf16_input_)
|
||||
{
|
||||
auto c_thread_packed_cast = PackedCastV2<
|
||||
M2,
|
||||
M4,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle
|
||||
>{};
|
||||
c_thread_packed_cast.Run(
|
||||
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc (TensorDescriptor struct)
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
|
||||
c_thread_buf // source buffer
|
||||
);
|
||||
}
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
@@ -2068,28 +2090,52 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(n_thread_data_on_block));
|
||||
|
||||
using ThreadwiseTransfer = std::conditional_t<
|
||||
is_gfx650_and_bf16_output(),
|
||||
ThreadwiseTensorSliceTransfer_v1r3_packed_cast<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>,
|
||||
ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>
|
||||
>;
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true,
|
||||
is_gfx950_and_bf16_input_>{
|
||||
auto c_thread_copy_vgpr_to_lds = ThreadwiseTransfer{
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
@@ -2157,21 +2203,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
// make sure it's safe to write to LDS
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr (is_gfx950_and_bf16_input_)
|
||||
{
|
||||
auto c_thread_packed_cast = PackedCastV2<
|
||||
M2,
|
||||
M4,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle
|
||||
>{};
|
||||
c_thread_packed_cast.Run(
|
||||
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, // source desc
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), // source slice origin
|
||||
c_thread_buf // source buffer
|
||||
);
|
||||
}
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
|
||||
@@ -1,167 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <ck::index_t M2, ck::index_t M4, ck::index_t CShuffleMXdlPerWavePerShuffle, ck::index_t CShuffleNXdlPerWavePerShuffle>
|
||||
struct PackedCast
|
||||
{
|
||||
template <typename SrcDesc, typename SrcSliceOriginIdx, typename SrcBuffer>
|
||||
__device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, SrcBuffer& src_buf)
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
|
||||
"wrong! SrcSliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
|
||||
|
||||
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
||||
|
||||
// Calculate total elements in this slice
|
||||
constexpr index_t elements_per_slice =
|
||||
CShuffleMXdlPerWavePerShuffle * CShuffleNXdlPerWavePerShuffle * M2 * M4;
|
||||
|
||||
constexpr auto calculate_coords = [src_slice_origin_idx](auto idx) constexpr {
|
||||
|
||||
// We know that the access order is
|
||||
// Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
// Sequence<CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, 1, 1, M2, 1, M4, 1>
|
||||
|
||||
constexpr index_t m4_offset = idx.value % M4;
|
||||
constexpr index_t m2_offset = (idx.value / M4) % M2;
|
||||
constexpr index_t n_xdl_offset = (idx.value / (M4 * M2)) % CShuffleNXdlPerWavePerShuffle;
|
||||
constexpr index_t m_xdl_offset = idx.value / (M4 * M2 * CShuffleNXdlPerWavePerShuffle);
|
||||
|
||||
return make_tuple(
|
||||
src_slice_origin_idx[Number<0>{}] + Number<m_xdl_offset>{},
|
||||
src_slice_origin_idx[Number<1>{}] + Number<n_xdl_offset>{},
|
||||
Number<0>{}, // this dim has unit size
|
||||
Number<0>{}, // this dim has unit size
|
||||
src_slice_origin_idx[Number<4>{}] + Number<m2_offset>{},
|
||||
Number<0>{}, // this dim has unit size
|
||||
src_slice_origin_idx[Number<6>{}] + Number<m4_offset>{},
|
||||
Number<0>{} // this dim has unit size
|
||||
);
|
||||
};
|
||||
|
||||
constexpr index_t num_pairs = elements_per_slice / 2;
|
||||
constexpr bool has_odd_element = (elements_per_slice % 2 == 1);
|
||||
|
||||
static_for<0, num_pairs, 1>{}([&](auto pair_idx) {
|
||||
constexpr auto idx_0 = Number<pair_idx * 2>{};
|
||||
constexpr auto idx_1 = Number<pair_idx * 2 + 1>{};
|
||||
|
||||
constexpr auto coord_0 = calculate_coords(idx_0);
|
||||
constexpr auto coord_1 = calculate_coords(idx_1);
|
||||
|
||||
constexpr auto offset_0 = src_desc.CalculateOffset(coord_0);
|
||||
constexpr auto offset_1 = src_desc.CalculateOffset(coord_1);
|
||||
|
||||
float& val_0 = src_buf(Number<offset_0>{});
|
||||
float& val_1 = src_buf(Number<offset_1>{});
|
||||
|
||||
static_cast_float_to_bhalf_packed(val_0, val_1);
|
||||
});
|
||||
|
||||
// Handle last element if the number of elements is odd.
|
||||
if constexpr (has_odd_element)
|
||||
{
|
||||
constexpr auto last_idx = Number<elements_per_slice - 1>{};
|
||||
constexpr auto last_coord = calculate_coords(last_idx);
|
||||
|
||||
// Single element conversion
|
||||
constexpr auto last_offset = src_desc.CalculateOffset(last_coord);
|
||||
float& last_val = src_buf[Number<last_offset>{}];
|
||||
const auto single_bf16 = static_cast<__bf16>(last_val);
|
||||
uint16_t* parts = reinterpret_cast<uint16_t*>(&last_val);
|
||||
const uint16_t* bf16_bits = reinterpret_cast<const uint16_t*>(&single_bf16);
|
||||
parts[1] = bf16_bits[0];
|
||||
}
|
||||
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
template <ck::index_t M2, ck::index_t M4, ck::index_t CShuffleMXdlPerWavePerShuffle, ck::index_t CShuffleNXdlPerWavePerShuffle>
|
||||
struct PackedCastV2
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
template <typename SrcDesc, typename SrcSliceOriginIdx, typename SrcBuffer>
|
||||
__device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, SrcBuffer& src_buf)
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
|
||||
"wrong! SrcSliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
|
||||
|
||||
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
||||
|
||||
using SliceLengths = Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>;
|
||||
using DimAccessOrder = Sequence<0, 1, 2, 3, 4, 5, 6, 7>;
|
||||
using DstScalarPerAccess = Sequence<1, 1, 1, 1, 1, 1, 1, 1>;
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
DstScalarPerAccess>;
|
||||
|
||||
static_assert(SpaceFillingCurve::ScalarPerVector == 1,
|
||||
"wrong! SpaceFillingCurve::ScalarPerVector must be 1 for PackedCastV2");
|
||||
|
||||
constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
constexpr index_t num_pairs = num_access / 2;
|
||||
constexpr bool has_odd_element = (num_access % 2 == 1);
|
||||
|
||||
static_for<0, num_pairs, 1>{}([&](auto i_pair)
|
||||
{
|
||||
constexpr auto idx_1d_0 = I2 * i_pair;
|
||||
constexpr auto idx_1d_1 = I2 * i_pair + I1;
|
||||
constexpr auto idx_md_0 = SpaceFillingCurve::GetIndex(idx_1d_0);
|
||||
constexpr auto idx_md_1 = SpaceFillingCurve::GetIndex(idx_1d_1);
|
||||
|
||||
constexpr index_t src_offset_0 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_0);
|
||||
constexpr index_t src_offset_1 = src_desc.CalculateOffset(src_slice_origin_idx + idx_md_1);
|
||||
|
||||
float& val_0 = src_buf(Number<src_offset_0>{});
|
||||
float& val_1 = src_buf(Number<src_offset_1>{});
|
||||
static_cast_float_to_bhalf_packed_v2(val_0, val_1);
|
||||
});
|
||||
|
||||
// Handle last element if the number of elements is odd.
|
||||
if constexpr (has_odd_element)
|
||||
{
|
||||
constexpr auto last_idx_1d = Number<num_access - 1>{};
|
||||
constexpr auto last_idx_md = SpaceFillingCurve::GetIndex(last_idx_1d);
|
||||
|
||||
// Single element conversion
|
||||
constexpr auto last_src_offset = src_desc.CalculateOffset(last_idx_1d);
|
||||
float& last_val = src_buf(Number<last_src_offset>{});
|
||||
const auto single_bf16 = static_cast<__bf16>(last_val);
|
||||
uint16_t* parts = reinterpret_cast<uint16_t*>(&last_val);
|
||||
const uint16_t* bf16_bits = reinterpret_cast<const uint16_t*>(&single_bf16);
|
||||
parts[0] = bf16_bits[0];
|
||||
}
|
||||
};
|
||||
};
|
||||
}
|
||||
Reference in New Issue
Block a user