Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-14 02:02:46 +00:00.
* Add support for mixed precision in contraction scale and bilinear (#936)
* Extract common functionality to separate files
* Reference contraction: Remove incorrect consts from type_converts
* Reference contraction: Add missing type_convert for dst value
* Reference contraction: Fix incorrect order of B matrix dimensions
* Add support for mixed precision in contraction scale and bilinear
* Move using statements from instances to a common file
* Move using statements from examples to a common file
* Fix the order of B matrix dimensions across examples and profiler
* Fix the computation of error threshold
* Make ComputeDataType an optional argument
* Include possible DataType -> ComputeDataType casting error in the threshold
* Remove commented code
* Make the ComputeDataType an optional argument in instance
---------
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
[ROCm/composable_kernel commit: 4ef704d8a6]
86 lines · 4.4 KiB · C++
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/ck.hpp"
|
|
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
|
|
|
#include "common_instances.hpp"
|
|
|
|
using ADataType = F32;
|
|
using BDataType = F32;
|
|
using AccDataType = F32;
|
|
using CShuffleDataType = F32;
|
|
using DsDataType = ck::Tuple<>;
|
|
using EDataType = F32;
|
|
using ComputeDataType = F32;
|
|
|
|
// Number of tensor dimensions grouped into each index set of the contraction.
// By their names these are the M, N and K groups. Each is 2 here. Confirm the
// exact meaning against DeviceOpInstance*_Generic, which is presumably declared
// in common_instances.hpp.
static constexpr ck::index_t NumDimM{2};
static constexpr ck::index_t NumDimN{2};
static constexpr ck::index_t NumDimK{2};
|
|
|
|
// Element-wise operators applied around the contraction.
// A and B use PassThrough: by name, their values are read unchanged.
// The CDE epilogue uses ck's Scale operator. Presumably this multiplies by a
// scalar factor; confirm in element_wise_operation.hpp.
using AElementOp   = ck::tensor_operation::element_wise::PassThrough;
using BElementOp   = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Scale;
|
|
|
|
// Contraction device instance from DeviceOpInstanceKK_Generic, which is
// presumably declared in common_instances.hpp. The KK/KN/MK/MN prefixes
// presumably encode the A/B memory layouts; confirm in that header.
using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic<
    // dimension counts
    NumDimM, NumDimN, NumDimK,
    // data types
    ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType,
    // element-wise operators
    AElementOp, BElementOp, CDEElementOp>;
|
|
|
|
// Contraction device instance from DeviceOpInstanceKN_Generic, which is
// presumably declared in common_instances.hpp. The "KN" prefix presumably
// encodes the A/B memory layouts; confirm in that header.
using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic<
    // dimension counts
    NumDimM, NumDimN, NumDimK,
    // data types
    ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType,
    // element-wise operators
    AElementOp, BElementOp, CDEElementOp>;
|
|
|
|
// Contraction device instance from DeviceOpInstanceMK_Generic, which is
// presumably declared in common_instances.hpp. The "MK" prefix presumably
// encodes the A/B memory layouts; confirm in that header.
using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic<
    // dimension counts
    NumDimM, NumDimN, NumDimK,
    // data types
    ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType,
    // element-wise operators
    AElementOp, BElementOp, CDEElementOp>;
|
|
|
|
// Contraction device instance from DeviceOpInstanceMN_Generic, which is
// presumably declared in common_instances.hpp. The "MN" prefix presumably
// encodes the A/B memory layouts; confirm in that header.
using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic<
    // dimension counts
    NumDimM, NumDimN, NumDimK,
    // data types
    ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType,
    // element-wise operators
    AElementOp, BElementOp, CDEElementOp>;
|
|
|
|
// The instance this example runs: the KKN variant. The name DeviceOpInstance
// is presumably what run_contraction_scale_example.inc refers to; verify there.
using DeviceOpInstance = DeviceOpInstanceKKN;
|
|
|
|
#include "run_contraction_scale_example.inc"
|
|
|
|
// Program entry point. Forwards the command-line arguments unchanged to
// run_contraction_scale_example, which is presumably provided by the
// run_contraction_scale_example.inc include. Its result becomes the exit status.
int main(int argc, char* argv[])
{
    const int status = run_contraction_scale_example(argc, argv);
    return status;
}
|