mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
* Add support for mixed precision in contraction scale and bilinear (#936)
  * Extract common functionality to separate files
  * Reference contraction: Remove incorrect consts from type_converts
  * Reference contraction: Add missing type_convert for dst value
  * Reference contraction: Fix incorrect order of B matrix dimensions
  * Add support for mixed precision in contraction scale and bilinear
  * Move using statements from instances to a common file
  * Move using statements from examples to a common file
  * Fix the order of B matrix dimensions across examples and profiler
  * Fix the computation of error threshold
  * Make ComputeDataType an optional argument
  * Include possible DataType -> ComputeDataType casting error in the threshold
  * Remove commented code
  * Make the ComputeDataType an optional argument in instance
  ---------
  Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
86 lines
4.4 KiB
C++
86 lines
4.4 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#include "ck/ck.hpp"
|
|
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
|
|
|
#include "common_instances.hpp"
|
|
|
|
using ADataType = F32;
|
|
using BDataType = F32;
|
|
using AccDataType = F32;
|
|
using CShuffleDataType = F32;
|
|
using DsDataType = ck::Tuple<>;
|
|
using EDataType = F32;
|
|
using ComputeDataType = F32;
|
|
|
|
// Number of M, N and K dimensions taking part in the contraction
// (two of each in this example).
static constexpr ck::index_t NumDimM{2};
static constexpr ck::index_t NumDimN{2};
static constexpr ck::index_t NumDimK{2};
|
|
|
|
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
|
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
|
using CDEElementOp = ck::tensor_operation::element_wise::Scale;
|
|
|
|
// "KK" device-op variant, built from DeviceOpInstanceKK_Generic (presumably
// declared in common_instances.hpp).
// NOTE(review): the KK/KN/MK/MN suffixes look like they name the layout of
// A and B; confirm against the generic template's definition.
using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic<
    // number of M / N / K dimensions
    NumDimM, NumDimN, NumDimK,
    // element types
    ADataType, BDataType, AccDataType, CShuffleDataType,
    DsDataType, EDataType, ComputeDataType,
    // element-wise operators for A, B and the C/D/E stage
    AElementOp, BElementOp, CDEElementOp>;
|
|
|
|
// "KN" device-op variant, built from DeviceOpInstanceKN_Generic (presumably
// declared in common_instances.hpp). Same arguments as the other variants;
// only the generic template differs.
using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic<
    // number of M / N / K dimensions
    NumDimM, NumDimN, NumDimK,
    // element types
    ADataType, BDataType, AccDataType, CShuffleDataType,
    DsDataType, EDataType, ComputeDataType,
    // element-wise operators for A, B and the C/D/E stage
    AElementOp, BElementOp, CDEElementOp>;
|
|
|
|
// "MK" device-op variant, built from DeviceOpInstanceMK_Generic (presumably
// declared in common_instances.hpp). Same arguments as the other variants;
// only the generic template differs.
using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic<
    // number of M / N / K dimensions
    NumDimM, NumDimN, NumDimK,
    // element types
    ADataType, BDataType, AccDataType, CShuffleDataType,
    DsDataType, EDataType, ComputeDataType,
    // element-wise operators for A, B and the C/D/E stage
    AElementOp, BElementOp, CDEElementOp>;
|
|
|
|
// "MN" device-op variant, built from DeviceOpInstanceMN_Generic (presumably
// declared in common_instances.hpp). Same arguments as the other variants;
// only the generic template differs.
using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic<
    // number of M / N / K dimensions
    NumDimM, NumDimN, NumDimK,
    // element types
    ADataType, BDataType, AccDataType, CShuffleDataType,
    DsDataType, EDataType, ComputeDataType,
    // element-wise operators for A, B and the C/D/E stage
    AElementOp, BElementOp, CDEElementOp>;
|
|
|
|
// The variant this example runs: the KK instance. The KN, MK and MN aliases
// are declared in this file but not selected here.
// NOTE(review): DeviceOpInstance is presumably consumed by the driver
// included below; confirm in run_contraction_scale_example.inc.
using DeviceOpInstance = DeviceOpInstanceKKN;

#include "run_contraction_scale_example.inc"

// Entry point: hand the command line to the shared contraction-scale driver
// and return its result as the process exit status.
int main(int argc, char* argv[])
{
    const int status = run_contraction_scale_example(argc, argv);
    return status;
}
|