Initial integaration of packed cast.

This commit is contained in:
Ville Pietilä
2025-08-04 15:34:35 +00:00
parent 590e119828
commit e92c0bf68e
2 changed files with 45 additions and 2 deletions

View File

@@ -8,6 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/packed_cast.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
@@ -100,6 +101,21 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// gfx950 specific optimizations for BF16 inputs
#if defined(__gfx950__)
static constexpr bool is_gfx950_and_bf16_input_ =
std::is_same_v<ADataType, bhalf_t> &&
std::is_same_v<BDataType, bhalf_t> &&
std::is_same_v<CShuffleDataType, bhalf_t> &&
std::is_same_v<AccDataType, float>;
#else
static constexpr bool is_gfx950_and_bf16_input_ = false;
#endif
using CShuffleInputDataType = std::conditional_t<is_gfx950_and_bf16_input_,
CShuffleDataType,
AccDataType>;
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
{
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch);
@@ -884,7 +900,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ThreadwiseTensorSliceTransfer_v1r3<CShuffleInputDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
@@ -970,6 +986,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
// make sure it's safe to write to LDS
block_sync_lds();
if constexpr (is_gfx950_and_bf16_input_)
{
packed_cast(sfc_c_vgpr);
}
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
@@ -1256,7 +1277,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ThreadwiseTensorSliceTransfer_v1r3<CShuffleInputDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
@@ -1341,6 +1362,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
if constexpr (is_gfx950_and_bf16_input_)
{
packed_cast(sfc_c_vgpr);
}
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,

View File

@@ -0,0 +1,17 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
__host__ __device__ inline void packed_cast(const auto& sfc_c_vgpr)
{
// This function is a placeholder for packed cast operations.
// For now, it does nothing.
};
}