mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Add instances for conv_scale with fp8@bf8->fp8 (#1220)
* Update device op api to support BComputeType * Add example * Add instances * Add profiler mode * Add client example * Update copyright year * Add BComputeType check * Fix compute types
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -40,7 +40,8 @@ using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
* \tparam AElementwiseOperation A elementwise operation.
|
||||
* \tparam BElementwiseOperation B elementwise operation.
|
||||
* \tparam CDEElementwiseOperation CDE elementwise operation.
|
||||
* \tparam ComputeType Compute data type (default: ADataType, first if tuple passed).
|
||||
* \tparam AComputeType Compute data type for A tensor (default: ADataType, first if tuple passed).
|
||||
* \tparam BComputeType Compute data type for B tensor (default: AComputeType).
|
||||
*/
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
@@ -54,12 +55,13 @@ template <index_t NDimSpatial,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename ComputeType =
|
||||
typename AComputeType =
|
||||
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
|
||||
Number<0>,
|
||||
ADataType>())> // ComputeType is InputType by default (first
|
||||
ADataType>()), // AComputeType is InputType by default (first
|
||||
// in tuple for MultiAB), unpack if tuple was
|
||||
// passed
|
||||
typename BComputeType = AComputeType>
|
||||
struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
|
||||
{
|
||||
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -254,13 +254,14 @@ template <index_t NDimSpatial,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
typename ComputeDataType =
|
||||
typename AComputeDataType =
|
||||
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
|
||||
Number<0>,
|
||||
ADataType>()), // ComputeType is InputType by default (first
|
||||
// in tuple for MultiAB), unpack if tuple was
|
||||
// passed
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
typename BComputeDataType = AComputeDataType,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
: public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
|
||||
ALayout,
|
||||
@@ -274,7 +275,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
ComputeDataType>
|
||||
AComputeDataType,
|
||||
BComputeDataType>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
|
||||
|
||||
@@ -386,7 +388,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
|
||||
|
||||
#define GridwiseGemmTemplateParameters \
|
||||
GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
|
||||
GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
|
||||
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
|
||||
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
|
||||
KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \
|
||||
@@ -399,7 +401,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
|
||||
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
|
||||
BComputeDataType
|
||||
// Use appropriate gridwise gemm
|
||||
using GridwiseGemm =
|
||||
std::conditional_t<isMultiA || isMultiB,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -75,13 +75,14 @@ template <index_t NDimSpatial,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
typename ComputeDataType =
|
||||
typename AComputeDataType =
|
||||
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
|
||||
Number<0>,
|
||||
ADataType>()), // ComputeType is InputType by default (first
|
||||
// in tuple for MultiAB), unpack if tuple was
|
||||
// passed
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
typename BComputeDataType = AComputeDataType,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
|
||||
NDimSpatial,
|
||||
ALayout,
|
||||
@@ -128,7 +129,8 @@ using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipl
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
ComputeDataType,
|
||||
AComputeDataType,
|
||||
BComputeDataType,
|
||||
LoopSched>;
|
||||
|
||||
} // namespace device
|
||||
|
||||
Reference in New Issue
Block a user