mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 19:09:59 +00:00
Remove int8 from batchnorm-forward instances since it is not needed for forward training and could fail test (#516)
[ROCm/composable_kernel commit: 5bf0475afd]
This commit is contained in:
@@ -31,10 +31,6 @@ void add_device_batchnorm_forward_rank_4_3_bf16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceBatchNormFwd<BF16, BF16, F32, BF16, BF16, F32, PassThrough, 4, 3>>>&);
|
||||
|
||||
// Int8
|
||||
void add_device_batchnorm_forward_rank_4_3_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceBatchNormFwd<I8, I8, F32, I8, I8, F32, PassThrough, 4, 3>>>&);
|
||||
|
||||
// FP64
|
||||
void add_device_batchnorm_forward_rank_4_3_f64_instances(
|
||||
std::vector<
|
||||
@@ -101,15 +97,6 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_batchnorm_forward_rank_4_3_bf16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<XDataType, I8> && is_same_v<YDataType, I8> &&
|
||||
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, I8> &&
|
||||
is_same_v<BiasDataType, I8> && is_same_v<MeanVarDataType, F32>)
|
||||
{
|
||||
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
|
||||
{
|
||||
add_device_batchnorm_forward_rank_4_3_i8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
|
||||
is_same_v<AccDataType, F64> && is_same_v<ScaleDataType, F64> &&
|
||||
is_same_v<BiasDataType, F64> && is_same_v<MeanVarDataType, F64>)
|
||||
|
||||
@@ -2,6 +2,5 @@ add_instance_library(device_batchnorm_instance
|
||||
device_batchnorm_forward_f16_instance.cpp
|
||||
device_batchnorm_forward_f32_instance.cpp
|
||||
device_batchnorm_forward_bf16_instance.cpp
|
||||
device_batchnorm_forward_i8_instance.cpp
|
||||
device_batchnorm_forward_f64_instance.cpp
|
||||
)
|
||||
|
||||
@@ -1,145 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using I8 = int8_t;
|
||||
using F32 = float;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// clang-format off
|
||||
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
|
||||
using device_batchnorm_forward_i8_blockwise_instances = std::tuple<
|
||||
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// clang-format off
|
||||
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
|
||||
using device_batchnorm_forward_i8_multiblock_instances =
|
||||
std::tuple <
|
||||
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
void add_device_batchnorm_forward_rank_4_3_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceBatchNormFwd<I8, I8, F32, I8, I8, F32, PassThrough, 4, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_batchnorm_forward_i8_blockwise_instances<4, 3, PassThrough>{});
|
||||
add_device_operation_instances(
|
||||
instances, device_batchnorm_forward_i8_multiblock_instances<4, 3, PassThrough>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -85,39 +85,22 @@ bool profile_batchnorm_forward_impl(int do_verification,
|
||||
|
||||
if(updateMovingAverage)
|
||||
{
|
||||
if constexpr(ck::is_same_v<XDataType, int8_t>)
|
||||
{
|
||||
x.GenerateTensorValue(GeneratorTensor_2<XDataType>{-5, 5}, num_thread);
|
||||
const float x_mean = 0.0f;
|
||||
const float x_stddev = 1.0f;
|
||||
const float noise_stddev = 0.04f;
|
||||
|
||||
const float x_mean = 0.0f;
|
||||
const float x_stddev = 2.5f;
|
||||
const float noise_stddev = 0.04f;
|
||||
// input data in normal distribution
|
||||
x.GenerateTensorValue(GeneratorTensor_4<XDataType>{x_mean, x_stddev}, num_thread);
|
||||
|
||||
resultRunningMean_ref.GenerateTensorValue(
|
||||
GeneratorTensor_4<MeanVarDataType>{x_mean, noise_stddev}, num_thread);
|
||||
// initialize the runningMean to be values with tiny variation to the mean of the x
|
||||
// values
|
||||
resultRunningMean_ref.GenerateTensorValue(
|
||||
GeneratorTensor_4<MeanVarDataType>{x_mean, noise_stddev}, num_thread);
|
||||
|
||||
resultRunningVariance_ref.GenerateTensorValue(
|
||||
GeneratorTensor_4<MeanVarDataType>{x_stddev * x_stddev, noise_stddev}, num_thread);
|
||||
}
|
||||
else
|
||||
{
|
||||
const float x_mean = 0.0f;
|
||||
const float x_stddev = 1.0f;
|
||||
const float noise_stddev = 0.04f;
|
||||
|
||||
// input data in normal distribution
|
||||
x.GenerateTensorValue(GeneratorTensor_4<XDataType>{x_mean, x_stddev}, num_thread);
|
||||
|
||||
// initialize the runningMean to be values with tiny variation to the mean of the x
|
||||
// values
|
||||
resultRunningMean_ref.GenerateTensorValue(
|
||||
GeneratorTensor_4<MeanVarDataType>{x_mean, noise_stddev}, num_thread);
|
||||
|
||||
// initialize the runningVariance to be values with tiny variation to the variance of
|
||||
// the x values
|
||||
resultRunningVariance_ref.GenerateTensorValue(
|
||||
GeneratorTensor_4<MeanVarDataType>{x_stddev * x_stddev, noise_stddev}, num_thread);
|
||||
};
|
||||
// initialize the runningVariance to be values with tiny variation to the variance of
|
||||
// the x values
|
||||
resultRunningVariance_ref.GenerateTensorValue(
|
||||
GeneratorTensor_4<MeanVarDataType>{x_stddev * x_stddev, noise_stddev}, num_thread);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -129,35 +112,24 @@ bool profile_batchnorm_forward_impl(int do_verification,
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
if constexpr(ck::is_same_v<ScaleDataType, int8_t> && ck::is_same_v<BiasDataType, int8_t>)
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_0<ScaleDataType>{}, num_thread);
|
||||
bnBias.GenerateTensorValue(GeneratorTensor_0<BiasDataType>{}, num_thread);
|
||||
break;
|
||||
case 1:
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_1<ScaleDataType>{1}, num_thread);
|
||||
bnBias.GenerateTensorValue(GeneratorTensor_1<BiasDataType>{0}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_2<ScaleDataType>{-5, 5}, num_thread);
|
||||
bnBias.GenerateTensorValue(GeneratorTensor_2<BiasDataType>{-5, 5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_3<ScaleDataType>{-1.0f, 1.0f}, num_thread);
|
||||
bnBias.GenerateTensorValue(GeneratorTensor_3<BiasDataType>{-1.0f, 1.0f}, num_thread);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_0<ScaleDataType>{}, num_thread);
|
||||
bnBias.GenerateTensorValue(GeneratorTensor_0<BiasDataType>{}, num_thread);
|
||||
break;
|
||||
case 1:
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_1<ScaleDataType>{1}, num_thread);
|
||||
bnBias.GenerateTensorValue(GeneratorTensor_1<BiasDataType>{0}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_2<ScaleDataType>{-5, 5}, num_thread);
|
||||
bnBias.GenerateTensorValue(GeneratorTensor_2<BiasDataType>{-5, 5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_3<ScaleDataType>{-1.0f, 1.0f},
|
||||
num_thread);
|
||||
bnBias.GenerateTensorValue(GeneratorTensor_3<BiasDataType>{-1.0f, 1.0f},
|
||||
num_thread);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
// these buffers are usually provided by the user application
|
||||
|
||||
@@ -48,7 +48,7 @@ class BatchnormFwdArgParser
|
||||
std::cout << "--inOutLengths or -D, comma separated list of input tensor dimension lengths, must have 4 integers for nhwc" << std::endl;
|
||||
std::cout << "--reduceDims or -R, comma separated list of dimensions to reduce on" << std::endl;
|
||||
std::cout << "--verify or -v, 1/0 to indicate whether to verify the result by comparing with the host-based batch-normalization" << std::endl;
|
||||
std::cout << "Arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64)" << std::endl;
|
||||
std::cout << "Arg1: data type (0: fp16, 1: fp32, 5: bp16, 6: fp64)" << std::endl;
|
||||
std::cout << "Arg2: 1/0 to indicate whether to update the moving average and variance (0=no, 1=yes)" << std::endl;
|
||||
std::cout << "Arg3: 1/0 to indicate whether to save the calculated mean and invVariance (0=no, 1=yes)" << std::endl;
|
||||
std::cout << "Arg4: init method used for bnScale and bnBias (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)" << std::endl;
|
||||
@@ -141,7 +141,6 @@ int profile_batchnorm_forward(int argc, char* argv[])
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using I8 = int8_t;
|
||||
using F64 = double;
|
||||
|
||||
if(arg_parser.data_type == 0)
|
||||
@@ -178,23 +177,6 @@ int profile_batchnorm_forward(int argc, char* argv[])
|
||||
averageFactor);
|
||||
};
|
||||
}
|
||||
else if(arg_parser.data_type == 3)
|
||||
{
|
||||
if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3)
|
||||
{
|
||||
profile_batchnorm_forward_impl<I8, I8, F32, I8, I8, F32, 4, 3>(
|
||||
arg_parser.do_verification,
|
||||
arg_parser.init_method,
|
||||
arg_parser.do_dumpout,
|
||||
arg_parser.time_kernel,
|
||||
arg_parser.inLengths,
|
||||
arg_parser.reduceDims,
|
||||
arg_parser.updateMovingAverage,
|
||||
arg_parser.saveMeanAndInvVariance,
|
||||
epsilon,
|
||||
averageFactor);
|
||||
};
|
||||
}
|
||||
else if(arg_parser.data_type == 5)
|
||||
{
|
||||
if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3)
|
||||
|
||||
@@ -90,7 +90,6 @@ class TestBatchNormFwdRank4 : public ::testing::Test
|
||||
using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, F16, F16, F32>,
|
||||
std::tuple<F32, F32, F32, F32, F32, F32>,
|
||||
std::tuple<BF16, BF16, F32, BF16, BF16, F32>,
|
||||
std::tuple<I8, I8, F32, I8, I8, F32>,
|
||||
std::tuple<F64, F64, F64, F64, F64, F64>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestBatchNormFwdRank4, KernelTypes);
|
||||
|
||||
Reference in New Issue
Block a user