mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
* Add support for fp16 + compute fp16 and bf16 + compute bf16 contractions. This enables hipTensor to access the WMMA hardware functionality for these datatype combinations on gfx11 and gfx12. * Fix change to contraction scale tests * Fix clang-format
87 lines
4.4 KiB
C++
87 lines
4.4 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#include "ck/ck.hpp"
|
|
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
|
|
|
#include "common_instances.hpp"
|
|
|
|
using ADataType = F16;
|
|
using BDataType = F16;
|
|
using AccDataType = F32;
|
|
using CShuffleDataType = F16;
|
|
using DDataType = F16;
|
|
using DsDataType = ck::Tuple<DDataType>;
|
|
using EDataType = F16;
|
|
using ComputeDataType = F16;
|
|
|
|
// Rank of the contraction: number of M, N and K dimensions (two of each).
static constexpr ck::index_t NumDimM{2};
static constexpr ck::index_t NumDimN{2};
static constexpr ck::index_t NumDimK{2};
|
|
|
|
// Element-wise operators: A and B are forwarded unchanged (PassThrough), and
// the CDE epilogue applies Bilinear to the contraction result and the D tensor
// (presumably alpha * C + beta * D - see element_wise::Bilinear).
using AElementOp   = ck::tensor_operation::element_wise::PassThrough;
using BElementOp   = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Bilinear;
|
|
|
|
// "KKNN" layout variant, instantiated from DeviceOpInstanceKK_Generic
// (presumably declared in common_instances.hpp). NOTE(review): the suffix
// presumably names the fastest-varying dimension of A, B, D and E - confirm.
// Arguments are positional and must stay in this order.
using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic<
    // Rank: number of M, N and K dimensions.
    NumDimM, NumDimN, NumDimK,
    // Element types.
    ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType,
    // Element-wise operators for A, B and the CDE epilogue.
    AElementOp, BElementOp, CDEElementOp>;
|
|
|
|
// "KNNN" layout variant, instantiated from DeviceOpInstanceKN_Generic
// (presumably declared in common_instances.hpp). NOTE(review): the suffix
// presumably names the fastest-varying dimension of A, B, D and E - confirm.
// Arguments are positional and must stay in this order.
using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic<
    // Rank: number of M, N and K dimensions.
    NumDimM, NumDimN, NumDimK,
    // Element types.
    ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType,
    // Element-wise operators for A, B and the CDE epilogue.
    AElementOp, BElementOp, CDEElementOp>;
|
|
|
|
// "MKNN" layout variant, instantiated from DeviceOpInstanceMK_Generic
// (presumably declared in common_instances.hpp). NOTE(review): the suffix
// presumably names the fastest-varying dimension of A, B, D and E - confirm.
// Arguments are positional and must stay in this order.
using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic<
    // Rank: number of M, N and K dimensions.
    NumDimM, NumDimN, NumDimK,
    // Element types.
    ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType,
    // Element-wise operators for A, B and the CDE epilogue.
    AElementOp, BElementOp, CDEElementOp>;
|
|
|
|
// "MNNN" layout variant, instantiated from DeviceOpInstanceMN_Generic
// (presumably declared in common_instances.hpp). NOTE(review): the suffix
// presumably names the fastest-varying dimension of A, B, D and E - confirm.
// Arguments are positional and must stay in this order.
using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic<
    // Rank: number of M, N and K dimensions.
    NumDimM, NumDimN, NumDimK,
    // Element types.
    ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType,
    // Element-wise operators for A, B and the CDE epilogue.
    AElementOp, BElementOp, CDEElementOp>;
|
|
|
|
// Layout variant this example runs. DeviceOpInstance is presumably the name
// consumed by run_contraction_bilinear_example.inc; any of the other
// DeviceOpInstance* aliases defined in this file can be substituted here.
using DeviceOpInstance = DeviceOpInstanceKKNN;
|
|
|
|
#include "run_contraction_bilinear_example.inc"
|
|
|
|
// Entry point: hands the command line to the shared example driver (presumably
// provided by the included run_contraction_bilinear_example.inc) and uses its
// result as the process exit status.
int main(int argc, char* argv[])
{
    const auto exit_status = run_contraction_bilinear_example(argc, argv);
    return exit_status;
}
|