mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Initial integration of packed cast.
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/packed_cast.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
@@ -100,6 +101,21 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
// gfx950 specific optimizations for BF16 inputs
|
||||
#if defined(__gfx950__)
|
||||
static constexpr bool is_gfx950_and_bf16_input_ =
|
||||
std::is_same_v<ADataType, bhalf_t> &&
|
||||
std::is_same_v<BDataType, bhalf_t> &&
|
||||
std::is_same_v<CShuffleDataType, bhalf_t> &&
|
||||
std::is_same_v<AccDataType, float>;
|
||||
#else
|
||||
static constexpr bool is_gfx950_and_bf16_input_ = false;
|
||||
#endif
|
||||
|
||||
using CShuffleInputDataType = std::conditional_t<is_gfx950_and_bf16_input_,
|
||||
CShuffleDataType,
|
||||
AccDataType>;
|
||||
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
|
||||
{
|
||||
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch);
|
||||
@@ -884,7 +900,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
ThreadwiseTensorSliceTransfer_v1r3<CShuffleInputDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
@@ -970,6 +986,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
// make sure it's safe to write to LDS
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr (is_gfx950_and_bf16_input_)
|
||||
{
|
||||
packed_cast(sfc_c_vgpr);
|
||||
}
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
@@ -1256,7 +1277,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
ThreadwiseTensorSliceTransfer_v1r3<CShuffleInputDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
@@ -1341,6 +1362,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr (is_gfx950_and_bf16_input_)
|
||||
{
|
||||
packed_cast(sfc_c_vgpr);
|
||||
}
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
|
||||
17
include/ck/tensor_operation/gpu/grid/packed_cast.hpp
Normal file
17
include/ck/tensor_operation/gpu/grid/packed_cast.hpp
Normal file
@@ -0,0 +1,17 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
/// @brief Placeholder for the packed cast operation; currently a no-op.
///
/// The function accepts any argument by const reference and ignores it, so a call
/// such as `packed_cast(x)` compiles and has no effect.
///
/// @tparam SfcCVgpr Deduced type of the argument (unconstrained).
/// @param sfc_c_vgpr Accepted and ignored for now.
///        NOTE(review): the name suggests a C-tile VGPR access descriptor; confirm
///        what the real implementation is expected to read from it.
template <typename SfcCVgpr>
__host__ __device__ inline void packed_cast([[maybe_unused]] const SfcCVgpr& sfc_c_vgpr)
{
    // Intentionally empty until the packed cast is implemented.
}
|
||||
}
|
||||
Reference in New Issue
Block a user