mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 20:09:25 +00:00
Fast GeLU using built-in function (#587)
* clean up
* fast gelu using builtin function
* clean
* clean
* clean
* clean:
* clean
* fix compilation
* clean
* clean
---------
Co-authored-by: zjing14 <zhangjing14@gmail.com>
[ROCm/composable_kernel commit: 8f455615a8]
This commit is contained in:
@@ -62,7 +62,7 @@ struct ExecutionConfig final
|
||||
};
|
||||
|
||||
inline bool
|
||||
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig config)
|
||||
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
|
||||
{
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -7,10 +7,11 @@ using ADataType = BF16;
|
||||
using BDataType = BF16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using D0DataType = BF16;
|
||||
using D1DataType = BF16;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
||||
using EDataType = BF16;
|
||||
using CDataType = F32; // C matrix doesn't exsit in GPU memory, this is used for host verification
|
||||
using D0DataType = BF16;
|
||||
using D1DataType = BF16;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
||||
using EDataType = BF16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
@@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
|
||||
@@ -7,10 +7,11 @@ using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using D0DataType = F16;
|
||||
using D1DataType = F16;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
||||
using EDataType = F16;
|
||||
using CDataType = F32; // C matrix doesn't exsit in GPU memory, this is used for host verification
|
||||
using D0DataType = F16;
|
||||
using D1DataType = F16;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
||||
using EDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
@@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
@@ -7,10 +6,11 @@ using ADataType = F32;
|
||||
using BDataType = F32;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
||||
using EDataType = F32;
|
||||
using CDataType = F32; // C matrix doesn't exsit in GPU memory, this is used for host verification
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
||||
using EDataType = F32;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
@@ -36,7 +36,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
|
||||
@@ -11,10 +11,11 @@ using ADataType = I4;
|
||||
using BDataType = I4;
|
||||
using AccDataType = I32;
|
||||
using CShuffleDataType = I32;
|
||||
using D0DataType = I4;
|
||||
using D1DataType = I4;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
||||
using EDataType = I4;
|
||||
using CDataType = I32; // C matrix doesn't exsit in GPU memory, this is used for host verification
|
||||
using D0DataType = I4;
|
||||
using D1DataType = I4;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
||||
using EDataType = I4;
|
||||
|
||||
using KernelADataType = I8;
|
||||
using KernelBDataType = I8;
|
||||
@@ -47,7 +48,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
|
||||
@@ -7,10 +7,11 @@ using ADataType = I8;
|
||||
using BDataType = I8;
|
||||
using AccDataType = I32;
|
||||
using CShuffleDataType = I32;
|
||||
using D0DataType = I8;
|
||||
using D1DataType = I8;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
||||
using EDataType = I8;
|
||||
using CDataType = I32; // C matrix doesn't exsit in GPU memory, this is used for host verification
|
||||
using D0DataType = I8;
|
||||
using D1DataType = I8;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
||||
using EDataType = I8;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
@@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
|
||||
@@ -124,7 +124,7 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC
|
||||
|
||||
if(config.do_verification)
|
||||
{
|
||||
Tensor<AccDataType> c_m_n({M, N});
|
||||
Tensor<CDataType> c_m_n({M, N});
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
Reference in New Issue
Block a user