mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Device Op GroupedGemmMultipleD + example fp16 (#633)
* Pass shared mem pointer as pointer to void. * Device Op GroupedGEMM Multiple D * Example for grouped gemm multiple d. * Add MI200 to supported archs. --------- Co-authored-by: Adam Osewski <aosewski@amd.com> Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
@@ -4,12 +4,15 @@ add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp)
|
||||
add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp)
|
||||
add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
|
||||
add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp)
|
||||
|
||||
|
||||
add_dependencies(example_grouped_gemm_xdl
|
||||
example_grouped_gemm_xdl_fp32
|
||||
example_grouped_gemm_xdl_fp16
|
||||
example_grouped_gemm_xdl_bfp16
|
||||
example_grouped_gemm_xdl_int8)
|
||||
example_grouped_gemm_xdl_int8
|
||||
example_grouped_gemm_multiple_d_dl_fp16)
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
|
||||
|
||||
67
example/15_grouped_gemm/grouped_gemm_multiple_d_dl_fp16.cpp
Normal file
67
example/15_grouped_gemm/grouped_gemm_multiple_d_dl_fp16.cpp
Normal file
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <initializer_list>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using EDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Row;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::
|
||||
// ##################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
|
||||
// ##################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
|
||||
// ##################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
|
||||
// ##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemmMultipleD_Dl< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
|
||||
Reference in New Issue
Block a user