diff --git a/example/01_gemm/gemm_xdl_fp16_i8_v3_b_scale.cpp b/example/01_gemm/gemm_xdl_fp16_i8_v3_b_scale.cpp index fb116b689b..bdd3a78c7c 100644 --- a/example/01_gemm/gemm_xdl_fp16_i8_v3_b_scale.cpp +++ b/example/01_gemm/gemm_xdl_fp16_i8_v3_b_scale.cpp @@ -29,16 +29,16 @@ static constexpr ck::index_t BlockSize = 256; static constexpr ck::index_t Scale_Block_N = 1; static constexpr ck::index_t Scale_Block_K = 128; -static constexpr ck::index_t MPerBlock = 128; -static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t KPerBlock = 64; +static constexpr ck::index_t MPerBlock = 16; +static constexpr ck::index_t NPerBlock = 64; +static constexpr ck::index_t KPerBlock = 256; static constexpr ck::index_t AK1 = 8; static constexpr ck::index_t BK1 = 32; -static constexpr ck::index_t MPerXDL = 32; -static constexpr ck::index_t NPerXDL = 32; -static constexpr ck::index_t MXdlPerWave = 4; +static constexpr ck::index_t MPerXDL = 16; +static constexpr ck::index_t NPerXDL = 16; +static constexpr ck::index_t MXdlPerWave = 1; static constexpr ck::index_t NXdlPerWave = 1; // clang-format off @@ -52,13 +52,29 @@ using DeviceGemmV2Instance = MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, - S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 1, 1, S<1, 32, 1, 8>, 8, + 1, 1, S<1, 16, 1, 8>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, PermuteA, PermuteB>; +using DeviceGemmV2Instance2 = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 64, + 1, 128, + 16, 16, 128, + 8, 16, + 16, 16, 1, 1, + S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 1, S<1, 16, 1, 4>, 4, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4, CDataType, CDataType, PermuteA, PermuteB>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm>>& instances); + +void add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_default_instances( + std::vector>>& instances); #endif template && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_default_instances(op_ptrs); + } + } return op_ptrs; } diff --git a/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt index 424320fa8f..a7e30a5e29 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt @@ -3,8 +3,10 @@ set(GEMM_B_SCALE_INSTANCES) list(APPEND GEMM_B_SCALE_INSTANCES device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp + device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_default_instance.cpp ) set_source_files_properties(device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -add_instance_library(device_gemm_b_scale_instance ${GEMM_B_SCALE_INSTANCES}) \ No newline at end of file +add_instance_library(device_gemm_b_scale_instance ${GEMM_B_SCALE_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn.hpp new file mode 100644 index 0000000000..0e80fb46d6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn.hpp @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I8 = int8_t; +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +#if 0 +template +using device_gemm_xdl_b_scale_f16_i8_f16_mk_nk_mn_comp_instances = std::tuple< + +#endif + +template +using device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | Type| | | | Operation| Operation| Operation| | | N| K| | | | | |Wave| Wave| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //26 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 16, 64, 256, 8, 32, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, + + //Compute friendly + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 32, 128, 8, 32, 32, 32, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //9 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 128, 8, 16, 16, 16, 4, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //10 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 32, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //11 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //12 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //13 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //16 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 64, 128, 8, 32, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //17 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 64, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //18 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 128, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //19 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //20 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //21 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //22 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 32, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //23 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //24 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //25 + // DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //26 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //28 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 64, 128, 8, 32, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //29 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 64, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //30 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 128, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //31 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //32 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //5 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //6 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //7 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 32, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //1 + + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //3 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //4 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //8 + //new Compute friendly kernel + // DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 16, 64, 256, 8, 32, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //35 + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false> //36 + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_default_instance.cpp new file mode 100644 index 0000000000..e8ce4638ca --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i8_f16/device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_default_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_b_scale_xdl_f16_i8_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_b_scale_impl.hpp b/profiler/include/profiler/profile_gemm_b_scale_impl.hpp index fe977e766e..f68c3d607c 100644 --- a/profiler/include/profiler/profile_gemm_b_scale_impl.hpp +++ b/profiler/include/profiler/profile_gemm_b_scale_impl.hpp @@ -159,21 +159,28 @@ bool profile_gemm_b_scale_impl(int do_verification, { for(int k = 0; k < K; k++) { - ck::pk_i4_t i4x2 = b_k_n(k, n).data; - int8_t i4 = 0; - if(k % 2 == 1) - i4 = (i4x2.data >> 0) & 0xf; - else - i4 = (i4x2.data >> 4) & 0xf; - i4 = i4 - 8; - v_b = ck::type_convert(i4); + if constexpr(is_same_v) + { + v_b = ck::type_convert(b_k_n(k, n)); + } + else if constexpr(is_same_v) + { + ck::pk_i4_t i4x2 = b_k_n(k, n).data; + int8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2.data >> 0) & 0xf; + else + i4 = (i4x2.data >> 4) & 0xf; + i4 = i4 - 8; + v_b = ck::type_convert(i4); + } b_k_n_dequant(k, n) = ck::type_convert(v_b) * ck::type_convert(b1_k_n(k / ScaleBlockK, n)); } } using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm && is_same_v) + if constexpr(is_same_v && is_same_v) { // vector pk_i4x4 permute for(int i = 0; i < N; i++) diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 9cb70e4670..05062c4157 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -1,90 +1,90 @@ # ckProfiler set(PROFILER_SOURCES profiler.cpp - profile_gemm.cpp - profile_reduce.cpp - profile_groupnorm_bwd_data.cpp - profile_groupnorm_fwd.cpp - profile_layernorm_bwd_data.cpp - profile_layernorm_bwd_gamma_beta.cpp - profile_groupnorm_bwd_gamma_beta.cpp - profile_layernorm_fwd.cpp - profile_max_pool2d_fwd.cpp - profile_pool3d_fwd.cpp - profile_avg_pool3d_bwd.cpp - profile_max_pool3d_bwd.cpp - profile_avg_pool2d_bwd.cpp - profile_max_pool2d_bwd.cpp - profile_softmax.cpp - profile_batchnorm_fwd.cpp - profile_batchnorm_bwd.cpp - profile_batchnorm_infer.cpp - profile_conv_tensor_rearrange.cpp - profile_transpose.cpp - profile_permute_scale.cpp + # profile_gemm.cpp + # profile_reduce.cpp + # profile_groupnorm_bwd_data.cpp + # profile_groupnorm_fwd.cpp + # profile_layernorm_bwd_data.cpp + # profile_layernorm_bwd_gamma_beta.cpp + # profile_groupnorm_bwd_gamma_beta.cpp + # profile_layernorm_fwd.cpp + # profile_max_pool2d_fwd.cpp + # profile_pool3d_fwd.cpp + # profile_avg_pool3d_bwd.cpp + # profile_max_pool3d_bwd.cpp + # profile_avg_pool2d_bwd.cpp + # profile_max_pool2d_bwd.cpp + # profile_softmax.cpp + # profile_batchnorm_fwd.cpp + # profile_batchnorm_bwd.cpp + # profile_batchnorm_infer.cpp + # profile_conv_tensor_rearrange.cpp + # profile_transpose.cpp + # profile_permute_scale.cpp ) -if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") - if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp) - list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp) - endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND PROFILER_SOURCES profile_gemm_reduce.cpp) - list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp) - list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp) - endif() - list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) - if(SUPPORTED_GPU_TARGETS MATCHES "gfx94") - list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply_wp.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp) - endif() - list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp) - list(APPEND PROFILER_SOURCES profile_batched_gemm_reduce.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp) +# if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") +# if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) +# list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp) +# list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp) +# endif() +# if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) +# list(APPEND PROFILER_SOURCES profile_gemm_reduce.cpp) +# list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp) +# list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_add.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp) +# list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) +# list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) +# list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) +# list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp) +# list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp) +# endif() +# list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) +# if(SUPPORTED_GPU_TARGETS MATCHES "gfx94") +# list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply_wp.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp) +# endif() +# list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp) +# list(APPEND PROFILER_SOURCES profile_batched_gemm_reduce.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp) list(APPEND PROFILER_SOURCES profile_gemm_b_scale.cpp) - list(APPEND PROFILER_SOURCES profile_batched_gemm_b_scale.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_universal_batched.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_universal_reduce.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_universal_streamk.cpp) - list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp) - list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp) - list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp) - list(APPEND PROFILER_SOURCES profile_conv_fwd.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd_outelementop.cpp) +# list(APPEND PROFILER_SOURCES profile_batched_gemm_b_scale.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_universal_batched.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_universal_reduce.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_universal_streamk.cpp) +# list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp) +# list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp) +# list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp) +# list(APPEND PROFILER_SOURCES profile_conv_fwd.cpp) +# list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd_outelementop.cpp) -endif() +# endif() -if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9") - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) - endif() - list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_data.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) -endif() +# if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9") +# if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) +# list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) +# endif() +# list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd.cpp) +# list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_data.cpp) +# list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) +# endif() -if(DL_KERNELS) - list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) -endif() +# if(DL_KERNELS) +# list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp) +# list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) +# endif() set(PROFILER_EXECUTABLE ckProfiler) @@ -97,90 +97,90 @@ if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132) endif() target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility getopt::getopt) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool2d_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool2d_bwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool2d_fwd_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool2d_bwd_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance) -if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") - if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance) - endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_tile_loop_instance) - endif() - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) - if(SUPPORTED_GPU_TARGETS MATCHES "gfx94") - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_wp_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance) - endif() - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance) +# if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") +# if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance) +# endif() +# if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_tile_loop_instance) +# endif() +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) +# if(SUPPORTED_GPU_TARGETS MATCHES "gfx94") +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_wp_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance) +# endif() +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_b_scale_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_b_scale_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_batched_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_reduce_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_streamk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convscale_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convinvscale_instance) -endif() +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_b_scale_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_batched_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_reduce_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_streamk_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convscale_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convinvscale_instance) +# endif() -if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) - endif() - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) -endif() +# if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") +# if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) +# endif() +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) +# endif() -if(DL_KERNELS) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) -endif() -rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler) +# if(DL_KERNELS) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) +# endif() +rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler) \ No newline at end of file diff --git a/profiler/src/profile_gemm_b_scale.cpp b/profiler/src/profile_gemm_b_scale.cpp index 443ebff834..cada18110c 100644 --- a/profiler/src/profile_gemm_b_scale.cpp +++ b/profiler/src/profile_gemm_b_scale.cpp @@ -28,6 +28,7 @@ enum struct GemmDataType F16_F16_F16_F8, // 6 F8_F8_BF16, // 7 F16_I4_F16, // 8 + F16_I8_F16 // 9 }; enum struct BScaleBlockTile @@ -46,7 +47,7 @@ int profile_gemm_b_scale(int argc, char* argv[]) printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: " "f16->f8; 7: f8->bf16, " - "comp f8; 8: f16@i4)\n"); + "comp f8; 8: f16@i4 9: f16@i8)\n"); printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); @@ -106,6 +107,7 @@ int profile_gemm_b_scale(int argc, char* argv[]) using F32 = float; using F16 = ck::half_t; using I4 = ck::pk_i4_t; + using I8 = int8_t; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -170,6 +172,13 @@ int profile_gemm_b_scale(int argc, char* argv[]) return profile( F16{}, I4{}, F16{}, F16{}, F32{}, F16{}, ck::Number<128>{}, Row{}, Col{}, Row{}); } + if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_NK_MN && + B_scale_block == BScaleBlockTile::K_128) + { + printf("F16_I8_F16 MK_NK_MN K_128\n"); + return profile( + F16{}, I8{}, F16{}, F16{}, F32{}, F16{}, ck::Number<128>{}, Row{}, Col{}, Row{}); + } else { std::cout << "this data_type & layout is not implemented" << std::endl;