mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
add missing type convert
This commit is contained in:
@@ -11,8 +11,7 @@
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
#include "device_gemm_xdl_c_shuffle.hpp"
|
||||
#include "device_gemm_xdl_cshuffle.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "reference_gemm.hpp"
|
||||
#include "gemm_specialization.hpp"
|
||||
@@ -37,47 +36,51 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle<
|
||||
ADataType, // ADataType
|
||||
BDataType, // BDataType
|
||||
CDataType, // CDataType
|
||||
AccDataType, // AccDataType
|
||||
CDataType, // CShuffleDataType
|
||||
ALayout, // ALayout
|
||||
BLayout, // BLayout
|
||||
CLayout, // CLayout
|
||||
PassThrough, // AElementwiseOperation
|
||||
PassThrough, // BElementwiseOperation
|
||||
PassThrough, // CElementwiseOperation
|
||||
256, // BlockSize
|
||||
256, // MPerBlock
|
||||
128, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_K1
|
||||
true, // ABlockLdsAddExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_K1
|
||||
true, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
|
||||
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
<ALayout, // typename ALayout,
|
||||
BLayout, // typename BLayout,
|
||||
CLayout, // typename CLayout,
|
||||
ADataType, // typename ADataType,
|
||||
BDataType, // typename BDataType,
|
||||
CDataType, // typename CDataType,
|
||||
AccDataType, // typename GemmAccDataType,
|
||||
CDataType, // typename CShuffleDataType,
|
||||
PassThrough, // typename AElementwiseOperation,
|
||||
PassThrough, // typename BElementwiseOperation,
|
||||
PassThrough, // typename CElementwiseOperation,
|
||||
GemmDefault, // GemmSpecialization GemmSpec,
|
||||
1, // index_t NumGemmKPrefetchStage,
|
||||
256, // index_t BlockSize,
|
||||
256, // index_t MPerBlock,
|
||||
128, // index_t NPerBlock,
|
||||
32, // index_t KPerBlock,
|
||||
8, // index_t AK1,
|
||||
8, // index_t BK1,
|
||||
32, // index_t MPerXDL,
|
||||
32, // index_t NPerXDL,
|
||||
4, // index_t MXdlPerWave,
|
||||
2, // index_t NXdlPerWave,
|
||||
S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder,
|
||||
S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder,
|
||||
2, // index_t ABlockTransferSrcVectorDim,
|
||||
8, // index_t ABlockTransferSrcScalarPerVector,
|
||||
8, // index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
1, // bool ABlockLdsExtraM,
|
||||
S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder,
|
||||
S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder,
|
||||
2, // index_t BBlockTransferSrcVectorDim,
|
||||
8, // index_t BBlockTransferSrcScalarPerVector,
|
||||
8, // index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
1, // bool BBlockLdsExtraN,
|
||||
1, // index_t CShuffleMXdlPerWavePerShuffle,
|
||||
1, // index_t CShuffleNXdlPerWavePerShuffle,
|
||||
S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
|
||||
@@ -46,7 +46,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpeciali
|
||||
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
|
||||
|
||||
// clang-format off
|
||||
#if 0
|
||||
#if 1
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
|
||||
@@ -11,8 +11,7 @@
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
#include "device_gemm_xdl_c_shuffle.hpp"
|
||||
#include "device_gemm_xdl_cshuffle.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "reference_gemm.hpp"
|
||||
#include "gemm_specialization.hpp"
|
||||
@@ -20,64 +19,63 @@
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = int8_t;
|
||||
using BDataType = int8_t;
|
||||
using CDataType = int32_t;
|
||||
using CDataType = int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using CShuffleDataType = int32_t;
|
||||
using CShuffleDataType = int8_t;
|
||||
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle<
|
||||
ADataType, // ADataType
|
||||
BDataType, // BDataType
|
||||
CDataType, // CDataType
|
||||
AccDataType, // AccDataType
|
||||
CShuffleDataType, // CShuffleDataType
|
||||
ALayout, // ALayout
|
||||
BLayout, // BLayout
|
||||
CLayout, // CLayout
|
||||
PassThrough, // AElementwiseOperation
|
||||
PassThrough, // BElementwiseOperation
|
||||
PassThrough, // CElementwiseOperation
|
||||
256, // BlockSize
|
||||
256, // MPerBlock
|
||||
128, // NPerBlock
|
||||
64, // KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
16, // ABlockTransferSrcScalarPerVector
|
||||
16, // ABlockTransferDstScalarPerVector_K1
|
||||
true, // ABlockLdsAddExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
16, // BBlockTransferSrcScalarPerVector
|
||||
16, // BBlockTransferDstScalarPerVector_K1
|
||||
true, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
|
||||
4>; // CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
|
||||
ALayout, // typename ALayout,
|
||||
BLayout, // typename BLayout,
|
||||
CLayout, // typename CLayout,
|
||||
ADataType, // typename ADataType,
|
||||
BDataType, // typename BDataType,
|
||||
CDataType, // typename CDataType,
|
||||
AccDataType, // typename GemmAccDataType,
|
||||
CShuffleDataType, // typename CShuffleDataType,
|
||||
PassThrough, // typename AElementwiseOperation,
|
||||
PassThrough, // typename BElementwiseOperation,
|
||||
PassThrough, // typename CElementwiseOperation,
|
||||
GemmDefault, // GemmSpecialization GemmSpec,
|
||||
1, // index_t NumGemmKPrefetchStage,
|
||||
256, // index_t BlockSize,
|
||||
256, // index_t MPerBlock,
|
||||
128, // index_t NPerBlock,
|
||||
64, // index_t KPerBlock,
|
||||
16, // index_t AK1,
|
||||
16, // index_t BK1,
|
||||
32, // index_t MPerXDL,
|
||||
32, // index_t NPerXDL,
|
||||
4, // index_t MXdlPerWave,
|
||||
2, // index_t NXdlPerWave,
|
||||
S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder,
|
||||
S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder,
|
||||
2, // index_t ABlockTransferSrcVectorDim,
|
||||
16, // index_t ABlockTransferSrcScalarPerVector,
|
||||
16, // index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
1, // bool ABlockLdsExtraM,
|
||||
S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder,
|
||||
S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder,
|
||||
2, // index_t BBlockTransferSrcVectorDim,
|
||||
8, // index_t BBlockTransferSrcScalarPerVector,
|
||||
8, // index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
1, // bool BBlockLdsExtraN,
|
||||
1, // index_t CShuffleMXdlPerWavePerShuffle,
|
||||
1, // index_t CShuffleNXdlPerWavePerShuffle,
|
||||
S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
|
||||
@@ -51,7 +51,7 @@ template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename DstElementwiseOperation,
|
||||
typename ElementwiseOperation,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
index_t DstVectorDim,
|
||||
@@ -70,12 +70,11 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin_idx,
|
||||
const DstElementwiseOperation& dst_element_op)
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin_idx,
|
||||
const ElementwiseOperation& element_op)
|
||||
: dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
|
||||
dst_element_op_{dst_element_op}
|
||||
element_op_{element_op}
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
@@ -136,13 +135,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(
|
||||
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
|
||||
SrcData dst_v;
|
||||
SrcData v;
|
||||
|
||||
// apply element-wise operation
|
||||
dst_element_op_(dst_v, src_buf[Number<src_offset>{}]);
|
||||
element_op_(v, src_buf[Number<src_offset>{}]);
|
||||
|
||||
// apply type convert
|
||||
dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(dst_v);
|
||||
dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(v);
|
||||
});
|
||||
|
||||
const bool is_dst_valid =
|
||||
@@ -213,7 +212,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
|
||||
private:
|
||||
DstCoord dst_coord_;
|
||||
const DstElementwiseOperation dst_element_op_;
|
||||
const ElementwiseOperation element_op_;
|
||||
}; // namespace ThreadwiseTensorSliceTransfer_v1r3
|
||||
|
||||
// Assume:
|
||||
|
||||
@@ -102,8 +102,13 @@ struct ThreadwiseTensorSliceTransfer_v6r1
|
||||
|
||||
// apply pointwise operation
|
||||
static_for<0, ScalarPerVector, 1>{}([&](auto i) {
|
||||
element_op_(dst_vector_container.template AsType<DstData>()(i),
|
||||
src_vector_container.template AsType<SrcData>()[i]);
|
||||
SrcData v;
|
||||
|
||||
// apply element-wise operation
|
||||
element_op_(v, src_vector_container.template AsType<SrcData>()[i]);
|
||||
|
||||
// apply type convert
|
||||
dst_vector_container.template AsType<DstData>()(i) = type_convert<DstData>(v);
|
||||
});
|
||||
|
||||
const bool is_dst_valid =
|
||||
|
||||
Reference in New Issue
Block a user