mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add Normalization splitk instances (#829)
* Add normalization splitK to layernorm and groupnorm instances
* Fix bug of GetKPerThread()
* Refine naming
* clang format
[ROCm/composable_kernel commit: 03b8119e2e]
This commit is contained in:
@@ -78,17 +78,18 @@ struct GridwiseNormalizationSplitK1st
|
||||
static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
|
||||
|
||||
__device__ static int
|
||||
GetKPerThread(int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id)
|
||||
GetKPerThread(int k, int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id)
|
||||
{
|
||||
bool is_rightmost_block = block_k_cluster_id == kGridSize - 1;
|
||||
|
||||
if(is_rightmost_block)
|
||||
{
|
||||
int left_kPerBlock = math::integer_divide_ceil(kRaw, kGridSize);
|
||||
int kPerBlock = kRaw % kGridSize == 0 ? left_kPerBlock : kRaw % left_kPerBlock;
|
||||
int kPerThread =
|
||||
kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize);
|
||||
int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize;
|
||||
int left_kPerBlock = math::integer_divide_ceil(k, kGridSize);
|
||||
int kRightmostBlock = kRaw - left_kPerBlock * (kGridSize - 1);
|
||||
int kPerThread = kRightmostBlock < K_BlockTileSize
|
||||
? 0
|
||||
: KThreadSliceSize * (kRightmostBlock / K_BlockTileSize);
|
||||
int kPerBlockTail = kRightmostBlock - kPerThread * KThreadClusterSize;
|
||||
|
||||
if(kPerBlockTail > 0)
|
||||
{
|
||||
@@ -105,7 +106,7 @@ struct GridwiseNormalizationSplitK1st
|
||||
}
|
||||
else
|
||||
{
|
||||
int kPerBlock = math::integer_divide_ceil(kRaw, kGridSize);
|
||||
int kPerBlock = math::integer_divide_ceil(k, kGridSize);
|
||||
return KThreadSliceSize * (kPerBlock / K_BlockTileSize);
|
||||
}
|
||||
}
|
||||
@@ -193,10 +194,13 @@ struct GridwiseNormalizationSplitK1st
|
||||
auto var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());
|
||||
|
||||
auto threadwise_welford = ThreadwiseWelford();
|
||||
int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
|
||||
threadwise_welford.max_count_ =
|
||||
GetKPerThread(kRaw, k_grid_size, block_k_cluster_id, thread_k_cluster_id);
|
||||
auto threadwise_welford = ThreadwiseWelford();
|
||||
int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
|
||||
threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k.GetLength(I1),
|
||||
kRaw,
|
||||
k_grid_size,
|
||||
block_k_cluster_id,
|
||||
thread_k_cluster_id);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
|
||||
|
||||
@@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_f16_instances(
|
||||
add_device_operation_instances(instances,
|
||||
device_normalization_f16_generic_instance<Pass, 5, 3>{});
|
||||
add_device_operation_instances(instances, device_normalization_f16_instances<Pass, 5, 3>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_normalization_splitk_f16_instances<Pass, 5, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_f32_instances(
|
||||
add_device_operation_instances(instances,
|
||||
device_normalization_f32_generic_instance<Pass, 5, 3>{});
|
||||
add_device_operation_instances(instances, device_normalization_f32_instances<Pass, 5, 3>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_normalization_splitk_f32_instances<Pass, 5, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -18,6 +18,8 @@ void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(
|
||||
instances, device_normalization_f16_f32_f32_f16_generic_instance<Swish, 5, 3>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_normalization_f16_f32_f32_f16_instances<Swish, 5, 3>{});
|
||||
add_device_operation_instances(
|
||||
instances, device_normalization_splitk_f16_f32_f32_f16_instances<Swish, 5, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_swish_f16_instances(
|
||||
add_device_operation_instances(instances,
|
||||
device_normalization_f16_generic_instance<Swish, 5, 3>{});
|
||||
add_device_operation_instances(instances, device_normalization_f16_instances<Swish, 5, 3>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_normalization_splitk_f16_instances<Swish, 5, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_swish_f32_instances(
|
||||
add_device_operation_instances(instances,
|
||||
device_normalization_f32_generic_instance<Swish, 5, 3>{});
|
||||
add_device_operation_instances(instances, device_normalization_f32_instances<Swish, 5, 3>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_normalization_splitk_f32_instances<Swish, 5, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -17,6 +17,8 @@ void add_device_normalization_rank_2_1_f16_instances(
|
||||
add_device_operation_instances(instances,
|
||||
device_normalization_f16_generic_instance<Pass, 2, 1>{});
|
||||
add_device_operation_instances(instances, device_normalization_f16_instances<Pass, 2, 1>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_normalization_splitk_f16_instances<Pass, 2, 1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -17,6 +17,8 @@ void add_device_normalization_rank_2_1_f32_instances(
|
||||
add_device_operation_instances(instances,
|
||||
device_normalization_f32_generic_instance<Pass, 2, 1>{});
|
||||
add_device_operation_instances(instances, device_normalization_f32_instances<Pass, 2, 1>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_normalization_splitk_f32_instances<Pass, 2, 1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -17,6 +17,8 @@ void add_device_normalization_rank_4_3_f16_instances(
|
||||
add_device_operation_instances(instances,
|
||||
device_normalization_f16_generic_instance<Pass, 4, 3>{});
|
||||
add_device_operation_instances(instances, device_normalization_f16_instances<Pass, 4, 3>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_normalization_splitk_f16_instances<Pass, 4, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -17,6 +17,8 @@ void add_device_normalization_rank_4_3_f32_instances(
|
||||
add_device_operation_instances(instances,
|
||||
device_normalization_f32_generic_instance<Pass, 4, 3>{});
|
||||
add_device_operation_instances(instances, device_normalization_f32_instances<Pass, 4, 3>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_normalization_splitk_f32_instances<Pass, 4, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
@@ -43,6 +44,32 @@ using device_normalization_f16_instances =
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_splitk_f16_instances =
|
||||
// clang-format off
|
||||
std::tuple <
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationSplitKImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_f16_generic_instance = std::tuple<
|
||||
// clang-format off
|
||||
@@ -76,6 +103,32 @@ using device_normalization_f32_instances = std::tuple<
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_splitk_f32_instances = std::tuple<
|
||||
// clang-format off
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_f32_generic_instance = std::tuple<
|
||||
// clang-format off
|
||||
@@ -109,6 +162,32 @@ using device_normalization_f16_f32_f32_f16_instances = std::tuple<
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_splitk_f16_f32_f32_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationSplitKImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_f16_f32_f32_f16_generic_instance = std::tuple<
|
||||
// clang-format off
|
||||
|
||||
@@ -139,6 +139,10 @@ bool profile_groupnorm_impl(int do_verification,
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get());
|
||||
DeviceMem workspace_dev(workspace_sz);
|
||||
inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
|
||||
|
||||
auto invoker_ptr = inst_ptr->MakeInvokerPointer();
|
||||
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
@@ -155,6 +155,10 @@ bool profile_layernorm_impl(int do_verification,
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get());
|
||||
DeviceMem workspace_dev(workspace_sz);
|
||||
inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
|
||||
|
||||
auto invoker_ptr = inst_ptr->MakeInvokerPointer();
|
||||
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
Reference in New Issue
Block a user