mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-25 01:27:40 +00:00
* chore(copyright): update copyright header for codegen directory
* chore(copyright): update copyright header for example directory
86 lines
4.4 KiB
C++
86 lines
4.4 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#include "ck/ck.hpp"
|
|
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
|
|
|
#include "common_instances.hpp"
|
|
|
|
using ADataType = F32;
|
|
using BDataType = F32;
|
|
using AccDataType = F32;
|
|
using CShuffleDataType = F32;
|
|
using DsDataType = ck::Tuple<>;
|
|
using EDataType = F32;
|
|
using ComputeDataType = F32;
|
|
|
|
// Number of dimensions in each index group of the contraction: two each for
// M, N and K (K presumably being the summed-over dimensions — confirm in
// common_instances.hpp).
static constexpr ck::index_t NumDimK{2};
static constexpr ck::index_t NumDimN{2};
static constexpr ck::index_t NumDimM{2};
|
|
|
|
// Element-wise operators handed to the device op. The CDE stage uses Scale;
// presumably its scale factor is supplied by run_contraction_scale_example.inc
// — TODO confirm.
using CDEElementOp = ck::tensor_operation::element_wise::Scale;

// A and B are read through the PassThrough operator.
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
|
|
|
// Contraction instance built on DeviceOpInstanceKK_Generic.
// NOTE(review): the "KK" prefix presumably names the contiguous dimension of
// A and B (both K) and the trailing "N" the contiguous dimension of E —
// TODO confirm against common_instances.hpp.
using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic<
    NumDimM, NumDimN, NumDimK,               // dimension counts
    ADataType, BDataType, AccDataType,       // input and accumulator types
    CShuffleDataType, DsDataType, EDataType, // intermediate, D and output types
    ComputeDataType,                         // compute type
    AElementOp, BElementOp, CDEElementOp>;   // element-wise operators
|
|
|
|
// Contraction instance built on DeviceOpInstanceKN_Generic.
// NOTE(review): "KN" presumably means A is K-contiguous and B is
// N-contiguous — TODO confirm against common_instances.hpp.
using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic<
    NumDimM, NumDimN, NumDimK,               // dimension counts
    ADataType, BDataType, AccDataType,       // input and accumulator types
    CShuffleDataType, DsDataType, EDataType, // intermediate, D and output types
    ComputeDataType,                         // compute type
    AElementOp, BElementOp, CDEElementOp>;   // element-wise operators
|
|
|
|
// Contraction instance built on DeviceOpInstanceMK_Generic.
// NOTE(review): "MK" presumably means A is M-contiguous and B is
// K-contiguous — TODO confirm against common_instances.hpp.
using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic<
    NumDimM, NumDimN, NumDimK,               // dimension counts
    ADataType, BDataType, AccDataType,       // input and accumulator types
    CShuffleDataType, DsDataType, EDataType, // intermediate, D and output types
    ComputeDataType,                         // compute type
    AElementOp, BElementOp, CDEElementOp>;   // element-wise operators
|
|
|
|
// Contraction instance built on DeviceOpInstanceMN_Generic.
// NOTE(review): "MN" presumably means A is M-contiguous and B is
// N-contiguous — TODO confirm against common_instances.hpp.
using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic<
    NumDimM, NumDimN, NumDimK,               // dimension counts
    ADataType, BDataType, AccDataType,       // input and accumulator types
    CShuffleDataType, DsDataType, EDataType, // intermediate, D and output types
    ComputeDataType,                         // compute type
    AElementOp, BElementOp, CDEElementOp>;   // element-wise operators
|
|
|
|
// Instance selected for this example. Presumably consumed by
// run_contraction_scale_example.inc, which is included after this alias.
// The KNN, MKN and MNN instances declared above are not selected here; change
// this alias to run one of them instead.
using DeviceOpInstance = DeviceOpInstanceKKN;
|
|
|
|
#include "run_contraction_scale_example.inc"
|
|
|
|
// Program entry point: forwards the command-line arguments unchanged to
// run_contraction_scale_example (not defined in this file; presumably provided
// by run_contraction_scale_example.inc) and returns its result as the exit code.
int main(int argc, char* argv[])
{
    const auto exit_code = run_contraction_scale_example(argc, argv);
    return exit_code;
}
|