mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Batchnorm splitk single kernel (#771)
* Use dim 0 as faster dim for writing mean/var/count workspace in batchnorm multiblock method [performance] * Add CountDataType as template parameter in blockwise_welford * Add utility/get_shift.hpp * Add BatchNorm multiblock single-kernel implementation * Add smem inline assembly based implementation of gms_init/gms_barrier/gms_reset for gfx90a * Renaming in device_batchnorm_forward_impl.hpp * Tiny fix in the batchnorm_fwd profiler * Revert "Add smem inline assembly based implementation of gms_init/gms_barrier/gms_reset for gfx90a" This reverts commitd16d00919c. * Use the old two-kernel batchnorm multiblock method for gfx1030 * Use the old two-kernel batchnorm multiblock method for gfx908 * use the single-kernel batchnorm multiblock method only for gfx90a * Remove get_wave_id() from utility/get_id.hpp since it is not used * Set true for testing running mean/variance and saving mean/invvariance in the examples * Fix to copy-right words * Remove un-needed including in utility/get_id.hpp * Add comments to workgroup_synchronization.hpp * Remove un-used codes in gridwise_multiblock_batchnorm_forward.hpp * Renaming in the kernels * Remove un-used kernel file [ROCm/composable_kernel commit:8f5cafaf04]
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
add_example_executable(example_batchnorm_forward_training batchnorm_forward_training_nhwc.cpp)
|
||||
add_example_executable(example_batchnorm_forward_training_obsolete batchnorm_forward_training_nhwc_obsolete.cpp)
|
||||
add_example_executable(example_batchnorm_forward_inferring batchnorm_forward_inferring_nhwc.cpp)
|
||||
add_example_executable(example_batchnorm_backward batchnorm_backward_nhwc.cpp)
|
||||
|
||||
@@ -414,7 +414,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
(void)invoker_ptr_ref->Run(argument_ptr_ref.get());
|
||||
|
||||
y_dev.FromDevice(y.mData.data());
|
||||
pass = pass && ck::utils::check_err(y, y_ref);
|
||||
pass = pass && ck::utils::check_err(y, y_ref, "Incorrect normalized output values");
|
||||
|
||||
if(updateMovingAverage)
|
||||
{
|
||||
@@ -424,8 +424,12 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
resultRunningMean_dev.FromDevice(resultRunningMean.mData.data());
|
||||
resultRunningVariance_dev.FromDevice(resultRunningVariance.mData.data());
|
||||
|
||||
pass = pass && ck::utils::check_err(resultRunningMean, resultRunningMean_ref);
|
||||
pass = pass && ck::utils::check_err(resultRunningVariance, resultRunningVariance_ref);
|
||||
pass = pass && ck::utils::check_err(resultRunningMean,
|
||||
resultRunningMean_ref,
|
||||
"Incorrect running mean values");
|
||||
pass = pass && ck::utils::check_err(resultRunningVariance,
|
||||
resultRunningVariance_ref,
|
||||
"Incorrect running variance values");
|
||||
};
|
||||
|
||||
if(saveMeanAndInvVariance)
|
||||
@@ -438,8 +442,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
resultSaveMean_dev.FromDevice(resultSaveMean.mData.data());
|
||||
resultSaveInvVariance_dev.FromDevice(resultSaveInvVariance.mData.data());
|
||||
|
||||
pass = pass && ck::utils::check_err(resultSaveMean, resultSaveMean_ref);
|
||||
pass = pass && ck::utils::check_err(resultSaveInvVariance, resultSaveInvVariance_ref);
|
||||
pass = pass && ck::utils::check_err(
|
||||
resultSaveMean, resultSaveMean_ref, "Incorrect saved mean values");
|
||||
pass = pass && ck::utils::check_err(resultSaveInvVariance,
|
||||
resultSaveInvVariance_ref,
|
||||
"Incorrect saved invvariance values");
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,598 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <limits>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <array>
|
||||
#include <algorithm>
|
||||
#include <getopt.h>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/host_common_util.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp"
|
||||
#include "ck/library/utility/host_common_util.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
|
||||
{"verify", required_argument, nullptr, 'v'},
|
||||
{"help", no_argument, nullptr, '?'},
|
||||
{nullptr, 0, nullptr, 0}};
|
||||
|
||||
class BatchNormFwdArg
|
||||
{
|
||||
private:
|
||||
int option_index = 0;
|
||||
|
||||
public:
|
||||
std::vector<size_t> inOutLengths;
|
||||
|
||||
bool do_verification = false;
|
||||
|
||||
bool updateMovingAverage;
|
||||
bool saveMeanAndInvVariance;
|
||||
|
||||
int data_type = 0;
|
||||
int init_method = 2;
|
||||
bool time_kernel = false;
|
||||
bool use_multiblock_welford = false;
|
||||
|
||||
public:
|
||||
void show_usage(const char* cmd)
|
||||
{
|
||||
std::cout << "Usage of " << cmd << std::endl;
|
||||
std::cout << "--inOutLengths or -D, comma separated list of input tensor dimension "
|
||||
"lengths, must have 4 integers for nhwc"
|
||||
<< std::endl;
|
||||
std::cout << "--verify or -v, 1/0 to indicate whether to verify the batch-normalization "
|
||||
"result by "
|
||||
"comparing with the host-based batch-normalization"
|
||||
<< std::endl;
|
||||
std::cout << "Arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64)" << std::endl;
|
||||
std::cout << "Arg2: 1/0 to indicate whether to update the moving average and variance "
|
||||
"(0=no, 1=yes)"
|
||||
<< std::endl;
|
||||
std::cout << "Arg3: 1/0 to indicate whether to save the calculated mean and invVariance "
|
||||
"(0=no, 1=yes)"
|
||||
<< std::endl;
|
||||
std::cout << "Arg4: init method used for bnScale and bnBias (0=no init, 1=single integer "
|
||||
"value, 2=scope integer "
|
||||
"value, 3=decimal value)"
|
||||
<< std::endl;
|
||||
std::cout << "Arg5: time kernel (0=no, 1=yes)" << std::endl;
|
||||
std::cout << "Arg6: use multi-block welford (0=n0, 1=yes)" << std::endl;
|
||||
};
|
||||
|
||||
int processArgs(int argc, char* argv[])
|
||||
{
|
||||
using ck::host_common::getTypeValuesFromString;
|
||||
|
||||
int ch;
|
||||
|
||||
while(1)
|
||||
{
|
||||
ch = getopt_long(argc, argv, "D:v:", long_options, &option_index);
|
||||
if(ch == -1)
|
||||
break;
|
||||
switch(ch)
|
||||
{
|
||||
case 'D':
|
||||
if(!optarg)
|
||||
throw std::runtime_error("Invalid option format!");
|
||||
|
||||
inOutLengths = getTypeValuesFromString<size_t>(optarg);
|
||||
|
||||
if(inOutLengths.size() != 4)
|
||||
throw std::runtime_error(
|
||||
"NHWC tensor layout should have 4 length values specified!");
|
||||
break;
|
||||
case 'v':
|
||||
if(!optarg)
|
||||
throw std::runtime_error("Invalid option format!");
|
||||
|
||||
do_verification = static_cast<bool>(std::atoi(optarg));
|
||||
break;
|
||||
case '?':
|
||||
if(std::string(long_options[option_index].name) == "help")
|
||||
{
|
||||
show_usage(argv[0]);
|
||||
return (-1);
|
||||
};
|
||||
break;
|
||||
default: show_usage(argv[0]); return (-1);
|
||||
};
|
||||
};
|
||||
|
||||
if(optind + 6 > argc)
|
||||
throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!");
|
||||
|
||||
data_type = std::atoi(argv[optind++]);
|
||||
updateMovingAverage = std::atoi(argv[optind++]);
|
||||
saveMeanAndInvVariance = std::atoi(argv[optind++]);
|
||||
init_method = std::atoi(argv[optind++]);
|
||||
time_kernel = static_cast<bool>(std::atoi(argv[optind++]));
|
||||
use_multiblock_welford = static_cast<bool>(std::atoi(argv[optind]));
|
||||
|
||||
if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6)
|
||||
return (-1);
|
||||
|
||||
return (0);
|
||||
};
|
||||
};
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <typename InOutDataType, typename AccDataType, bool UseMultiblockInK>
|
||||
bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
int init_method,
|
||||
bool time_kernel,
|
||||
const std::vector<size_t> inOutLengths,
|
||||
bool updateMovingAverage,
|
||||
bool saveMeanAndInvVariance,
|
||||
double averageFactor,
|
||||
double epsilon)
|
||||
{
|
||||
// for NHWC BatchNorm calculation of mean and meansquare
|
||||
constexpr int Rank = 4;
|
||||
constexpr int NumReduceDim = 3;
|
||||
|
||||
// when using lengths[] to create a tensor, lengths[0] is the length of highest dimension
|
||||
// eg. N of NHWC, so lengths[3] is the dimension C length of NHWC
|
||||
const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]};
|
||||
|
||||
// input data of the batchnorm forward algorithm
|
||||
Tensor<InOutDataType> x(inOutLengths);
|
||||
Tensor<AccDataType> bnScale(scaleBiasMeanVarLengths);
|
||||
Tensor<AccDataType> bnBias(scaleBiasMeanVarLengths);
|
||||
|
||||
// output data of the batchnorm forward algorithm
|
||||
Tensor<InOutDataType> y_ref(inOutLengths);
|
||||
Tensor<InOutDataType> y(inOutLengths);
|
||||
|
||||
Tensor<AccDataType> resultSaveMean_ref(scaleBiasMeanVarLengths);
|
||||
Tensor<AccDataType> resultSaveInvVariance_ref(scaleBiasMeanVarLengths);
|
||||
|
||||
Tensor<AccDataType> resultRunningMean_ref(scaleBiasMeanVarLengths);
|
||||
Tensor<AccDataType> resultRunningVariance_ref(scaleBiasMeanVarLengths);
|
||||
|
||||
auto inOutStrides = x.mDesc.GetStrides();
|
||||
auto scaleBiasMeanVarStrides = bnScale.mDesc.GetStrides();
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
|
||||
if(updateMovingAverage)
|
||||
{
|
||||
if constexpr(std::is_same<InOutDataType, int8_t>::value)
|
||||
{
|
||||
x.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
|
||||
|
||||
const float x_mean = 0.0f;
|
||||
const float x_stddev = 2.5f;
|
||||
const float noise_stddev = 0.04f;
|
||||
|
||||
resultRunningMean_ref.GenerateTensorValue(
|
||||
GeneratorTensor_4<AccDataType>{x_mean, noise_stddev}, num_thread);
|
||||
|
||||
resultRunningVariance_ref.GenerateTensorValue(
|
||||
GeneratorTensor_4<AccDataType>{x_stddev * x_stddev, noise_stddev}, num_thread);
|
||||
}
|
||||
else
|
||||
{
|
||||
const float x_mean = 0.0f;
|
||||
const float x_stddev = 1.0f;
|
||||
const float noise_stddev = 0.04f;
|
||||
|
||||
// input data in normal distribution
|
||||
x.GenerateTensorValue(GeneratorTensor_4<InOutDataType>{x_mean, x_stddev}, num_thread);
|
||||
|
||||
// initialize the runningMean to be values with tiny variation to the mean of the x
|
||||
// values
|
||||
resultRunningMean_ref.GenerateTensorValue(
|
||||
GeneratorTensor_4<AccDataType>{x_mean, noise_stddev}, num_thread);
|
||||
|
||||
// initialize the runningVariance to be values with tiny variation to the variance of
|
||||
// the x values
|
||||
resultRunningVariance_ref.GenerateTensorValue(
|
||||
GeneratorTensor_4<AccDataType>{x_stddev * x_stddev, noise_stddev}, num_thread);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(std::is_same<InOutDataType, int8_t>::value)
|
||||
x.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
|
||||
else
|
||||
x.GenerateTensorValue(GeneratorTensor_3<InOutDataType>{-5.0f, 5.0f}, num_thread);
|
||||
};
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_0<AccDataType>{}, num_thread);
|
||||
bnBias.GenerateTensorValue(GeneratorTensor_0<AccDataType>{}, num_thread);
|
||||
break;
|
||||
case 1:
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_1<AccDataType>{1}, num_thread);
|
||||
bnBias.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_2<AccDataType>{-5, 5}, num_thread);
|
||||
bnBias.GenerateTensorValue(GeneratorTensor_2<AccDataType>{-5, 5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_3<AccDataType>{-5.0f, 5.0f}, num_thread);
|
||||
bnBias.GenerateTensorValue(GeneratorTensor_3<AccDataType>{-5.0f, 5.0f}, num_thread);
|
||||
}
|
||||
};
|
||||
|
||||
// these buffers are usually provided by the user application
|
||||
DeviceMem x_dev(sizeof(InOutDataType) * x.mDesc.GetElementSpaceSize());
|
||||
DeviceMem y_dev(sizeof(InOutDataType) * y.mDesc.GetElementSpaceSize());
|
||||
DeviceMem bnScale_dev(sizeof(AccDataType) * bnScale.mDesc.GetElementSpaceSize());
|
||||
DeviceMem bnBias_dev(sizeof(AccDataType) * bnBias.mDesc.GetElementSpaceSize());
|
||||
|
||||
// mean_dev or resultSaveMean_dev
|
||||
DeviceMem resultSaveMean_dev(sizeof(AccDataType) *
|
||||
resultSaveMean_ref.mDesc.GetElementSpaceSize());
|
||||
// meansquare_dev or resultSaveInvVariance_dev
|
||||
DeviceMem resultSaveInvVariance_dev(sizeof(AccDataType) *
|
||||
resultSaveInvVariance_ref.mDesc.GetElementSpaceSize());
|
||||
// resultRunningMean_dev
|
||||
DeviceMem resultRunningMean_dev(sizeof(AccDataType) *
|
||||
resultRunningMean_ref.mDesc.GetElementSpaceSize());
|
||||
// resultRunningVariance_dev
|
||||
DeviceMem resultRunningVariance_dev(sizeof(AccDataType) *
|
||||
resultRunningVariance_ref.mDesc.GetElementSpaceSize());
|
||||
|
||||
x_dev.ToDevice(x.mData.data());
|
||||
bnScale_dev.ToDevice(bnScale.mData.data());
|
||||
bnBias_dev.ToDevice(bnBias.mData.data());
|
||||
|
||||
if(updateMovingAverage)
|
||||
{
|
||||
resultRunningMean_dev.ToDevice(resultRunningMean_ref.mData.data());
|
||||
resultRunningVariance_dev.ToDevice(resultRunningVariance_ref.mData.data());
|
||||
};
|
||||
|
||||
std::array<index_t, Rank> i_inOutLengths;
|
||||
std::array<index_t, Rank> i_inOutStrides;
|
||||
std::array<index_t, Rank - NumReduceDim> i_scaleBiasMeanVarLengths;
|
||||
std::array<index_t, Rank - NumReduceDim> i_scaleBiasMeanVarStrides;
|
||||
|
||||
ck::ranges::copy(inOutLengths, i_inOutLengths.begin());
|
||||
ck::ranges::copy(inOutStrides, i_inOutStrides.begin());
|
||||
ck::ranges::copy(scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths.begin());
|
||||
ck::ranges::copy(scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides.begin());
|
||||
|
||||
using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using DeviceBatchNormFwdInstance =
|
||||
ck::tensor_operation::device::DeviceBatchNormFwdImpl<InOutDataType,
|
||||
InOutDataType,
|
||||
AccDataType,
|
||||
AccDataType, // ScaleDataType
|
||||
AccDataType, // BiasDataType
|
||||
AccDataType, // MeanVarDataType
|
||||
PassThroughOp, // YElementwiseOp
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
UseMultiblockInK,
|
||||
256,
|
||||
16,
|
||||
16,
|
||||
1,
|
||||
2,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1>;
|
||||
|
||||
auto batchnorm_fwd = DeviceBatchNormFwdInstance{};
|
||||
|
||||
auto argument_ptr = batchnorm_fwd.MakeArgumentPointer(
|
||||
i_inOutLengths,
|
||||
i_inOutStrides,
|
||||
i_inOutStrides,
|
||||
{0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[]
|
||||
i_scaleBiasMeanVarLengths,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
x_dev.GetDeviceBuffer(),
|
||||
bnScale_dev.GetDeviceBuffer(),
|
||||
bnBias_dev.GetDeviceBuffer(),
|
||||
epsilon,
|
||||
PassThroughOp{},
|
||||
y_dev.GetDeviceBuffer(),
|
||||
saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr,
|
||||
saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr,
|
||||
averageFactor,
|
||||
updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
|
||||
updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr);
|
||||
|
||||
if(!batchnorm_fwd.IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
std::cout << "The runtime parameters seems not supported by the BatchNorm device instance, "
|
||||
"exiting!"
|
||||
<< std::endl;
|
||||
return (false);
|
||||
};
|
||||
|
||||
size_t workspace_sz = batchnorm_fwd.GetWorkSpaceSize(argument_ptr.get());
|
||||
|
||||
DeviceMem workspace_dev(workspace_sz);
|
||||
|
||||
batchnorm_fwd.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
|
||||
|
||||
auto invoker_ptr = batchnorm_fwd.MakeInvokerPointer();
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
float avg_time = 0.0f;
|
||||
size_t num_bytes = 0;
|
||||
|
||||
size_t total_length = inOutLengths[0] * inOutLengths[1] * inOutLengths[2] * inOutLengths[3];
|
||||
size_t invariant_length = inOutLengths[3];
|
||||
|
||||
avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
// inputing of x, scale, bias, outputing of y
|
||||
num_bytes +=
|
||||
total_length * sizeof(InOutDataType) * 2 + invariant_length * sizeof(AccDataType) * 2;
|
||||
|
||||
// outputing of mean, inv-variance
|
||||
num_bytes += saveMeanAndInvVariance ? invariant_length * sizeof(AccDataType) * 2 : 0;
|
||||
|
||||
// updating of moving mean, variance
|
||||
num_bytes += updateMovingAverage ? invariant_length * sizeof(AccDataType) * 4 : 0;
|
||||
|
||||
float gb_per_sec = num_bytes / 1.E6 / avg_time;
|
||||
|
||||
std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
|
||||
}
|
||||
else
|
||||
(void)invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
|
||||
using ReferenceBatchNormFwdInstance =
|
||||
ck::tensor_operation::host::ReferenceBatchNormFwd<InOutDataType,
|
||||
InOutDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
PassThroughOp,
|
||||
Rank,
|
||||
NumReduceDim>;
|
||||
|
||||
auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{};
|
||||
|
||||
auto argument_ptr_ref = batchNormFwd_ref.MakeArgumentPointer(
|
||||
i_inOutLengths,
|
||||
i_inOutStrides,
|
||||
i_inOutStrides,
|
||||
{0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[]
|
||||
i_scaleBiasMeanVarLengths,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
x.mData.data(),
|
||||
bnScale.mData.data(),
|
||||
bnBias.mData.data(),
|
||||
epsilon,
|
||||
PassThroughOp{},
|
||||
y_ref.mData.data(),
|
||||
saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr,
|
||||
saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr,
|
||||
averageFactor,
|
||||
updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr,
|
||||
updateMovingAverage ? resultRunningVariance_ref.mData.data() : nullptr);
|
||||
|
||||
if(!batchNormFwd_ref.IsSupportedArgument(argument_ptr_ref.get()))
|
||||
{
|
||||
std::cout << "The runtime parameters seems not supported by the BatchNorm reference "
|
||||
"instance, exiting!"
|
||||
<< std::endl;
|
||||
return (false);
|
||||
};
|
||||
|
||||
auto invoker_ptr_ref = batchNormFwd_ref.MakeInvokerPointer();
|
||||
|
||||
(void)invoker_ptr_ref->Run(argument_ptr_ref.get());
|
||||
|
||||
y_dev.FromDevice(y.mData.data());
|
||||
pass = pass && ck::utils::check_err(y, y_ref, "Incorrect normalized output values");
|
||||
|
||||
if(updateMovingAverage)
|
||||
{
|
||||
Tensor<AccDataType> resultRunningMean(scaleBiasMeanVarLengths);
|
||||
Tensor<AccDataType> resultRunningVariance(scaleBiasMeanVarLengths);
|
||||
|
||||
resultRunningMean_dev.FromDevice(resultRunningMean.mData.data());
|
||||
resultRunningVariance_dev.FromDevice(resultRunningVariance.mData.data());
|
||||
|
||||
pass = pass && ck::utils::check_err(resultRunningMean,
|
||||
resultRunningMean_ref,
|
||||
"Incorrect running mean values");
|
||||
pass = pass && ck::utils::check_err(resultRunningVariance,
|
||||
resultRunningVariance_ref,
|
||||
"Incorrect running variance values");
|
||||
};
|
||||
|
||||
if(saveMeanAndInvVariance)
|
||||
{
|
||||
using ck::host_common::dumpBufferToFile;
|
||||
|
||||
Tensor<AccDataType> resultSaveMean(scaleBiasMeanVarLengths);
|
||||
Tensor<AccDataType> resultSaveInvVariance(scaleBiasMeanVarLengths);
|
||||
|
||||
resultSaveMean_dev.FromDevice(resultSaveMean.mData.data());
|
||||
resultSaveInvVariance_dev.FromDevice(resultSaveInvVariance.mData.data());
|
||||
|
||||
pass = pass && ck::utils::check_err(
|
||||
resultSaveMean, resultSaveMean_ref, "Incorrect saved mean values");
|
||||
pass = pass && ck::utils::check_err(resultSaveInvVariance,
|
||||
resultSaveInvVariance_ref,
|
||||
"Incorrect saved invvariance values");
|
||||
};
|
||||
};
|
||||
|
||||
return (pass);
|
||||
};
|
||||
|
||||
const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
static const double averageFactor = 0.1;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool pass = true;
|
||||
|
||||
if(argc > 1)
|
||||
{
|
||||
BatchNormFwdArg arg;
|
||||
|
||||
if(arg.processArgs(argc, argv) < 0)
|
||||
return (-1);
|
||||
|
||||
if(arg.data_type == 0)
|
||||
{
|
||||
if(arg.use_multiblock_welford)
|
||||
pass = bnorm_fwd_nhwc_test<ck::half_t, float, true>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
else
|
||||
pass = bnorm_fwd_nhwc_test<ck::half_t, float, false>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
}
|
||||
else if(arg.data_type == 1)
|
||||
{
|
||||
if(arg.use_multiblock_welford)
|
||||
pass = bnorm_fwd_nhwc_test<float, float, true>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
else
|
||||
pass = bnorm_fwd_nhwc_test<float, float, false>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
}
|
||||
else if(arg.data_type == 3)
|
||||
{
|
||||
if(arg.use_multiblock_welford)
|
||||
pass = bnorm_fwd_nhwc_test<int8_t, float, true>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
else
|
||||
pass = bnorm_fwd_nhwc_test<int8_t, float, false>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
}
|
||||
else if(arg.data_type == 5)
|
||||
{
|
||||
if(arg.use_multiblock_welford)
|
||||
pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float, true>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
else
|
||||
pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float, false>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
}
|
||||
else if(arg.data_type == 6)
|
||||
{
|
||||
if(arg.use_multiblock_welford)
|
||||
pass = bnorm_fwd_nhwc_test<double, double, true>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
else
|
||||
pass = bnorm_fwd_nhwc_test<double, double, false>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
pass = bnorm_fwd_nhwc_test<ck::half_t, float, true>(true,
|
||||
2,
|
||||
false, // don't time kernel
|
||||
{128, 16, 6, 512},
|
||||
true,
|
||||
true,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
|
||||
pass = pass && bnorm_fwd_nhwc_test<ck::half_t, float, false>(true,
|
||||
2,
|
||||
false, // don't time kernel
|
||||
{128, 16, 3, 1024},
|
||||
true,
|
||||
true,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
};
|
||||
|
||||
return (pass ? 0 : 1);
|
||||
}
|
||||
@@ -4,7 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
#include "ck/utility/reduction_common.hpp"
|
||||
#include "ck/utility/get_shift.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -35,10 +35,11 @@ struct BlockwiseWelford
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
template <typename CountDataType>
|
||||
__device__ static inline void
|
||||
Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b)
|
||||
Merge(T& mean_a, T& var_a, CountDataType& count_a, T mean_b, T var_b, CountDataType count_b)
|
||||
{
|
||||
int count = count_a + count_b;
|
||||
CountDataType count = count_a + count_b;
|
||||
T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
|
||||
T delta = mean_b - mean_a;
|
||||
mean_a += delta * count_b_over_count;
|
||||
@@ -46,11 +47,12 @@ struct BlockwiseWelford
|
||||
count_a = count;
|
||||
}
|
||||
|
||||
__device__ static void Run(T& mean_value, T& var_value, int& count)
|
||||
template <typename CountDataType>
|
||||
__device__ static void Run(T& mean_value, T& var_value, CountDataType& count)
|
||||
{
|
||||
__shared__ T mean_block_buf[BlockSize];
|
||||
__shared__ T var_block_buf[BlockSize];
|
||||
__shared__ int count_block_buf[BlockSize];
|
||||
__shared__ CountDataType count_block_buf[BlockSize];
|
||||
|
||||
constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
|
||||
|
||||
@@ -76,13 +78,13 @@ struct BlockwiseWelford
|
||||
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
|
||||
make_tuple(0, indOffset));
|
||||
|
||||
T mean1 = mean_block_buf[offset1];
|
||||
T var1 = var_block_buf[offset1];
|
||||
int count1 = count_block_buf[offset1];
|
||||
T mean1 = mean_block_buf[offset1];
|
||||
T var1 = var_block_buf[offset1];
|
||||
CountDataType count1 = count_block_buf[offset1];
|
||||
|
||||
T mean2 = mean_block_buf[offset2];
|
||||
T var2 = var_block_buf[offset2];
|
||||
int count2 = count_block_buf[offset2];
|
||||
T mean2 = mean_block_buf[offset2];
|
||||
T var2 = var_block_buf[offset2];
|
||||
CountDataType count2 = count_block_buf[offset2];
|
||||
|
||||
Merge(mean1, var1, count1, mean2, var2, count2);
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
#include "ck/utility/reduction_common.hpp"
|
||||
#include "ck/utility/get_shift.hpp"
|
||||
#include "ck/utility/reduction_functions_accumulate.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -10,12 +10,14 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -114,8 +116,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
|
||||
|
||||
static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
|
||||
{
|
||||
const auto grid_desc_m_g =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));
|
||||
const auto grid_desc_m_g = make_naive_tensor_descriptor(
|
||||
make_tuple(invariantLength, blkGroupSize), make_tuple(1, invariantLength));
|
||||
|
||||
const auto mPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
@@ -132,9 +134,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
|
||||
|
||||
static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
|
||||
{
|
||||
const auto reduceLength = blkGroupSize;
|
||||
const auto grid_desc_m_k =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, reduceLength));
|
||||
const auto reduceLength = blkGroupSize;
|
||||
const auto grid_desc_m_k = make_naive_tensor_descriptor(
|
||||
make_tuple(invariantLength, reduceLength), make_tuple(1, invariantLength));
|
||||
|
||||
const auto mPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
@@ -244,8 +246,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
|
||||
int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
// we want the blkGroupSize be not more than 128
|
||||
if(testBlkGroupSize <= 128)
|
||||
// we want the blkGroupSize be not more than 16
|
||||
if(testBlkGroupSize <= 16)
|
||||
break;
|
||||
|
||||
iterations++;
|
||||
@@ -319,6 +321,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
|
||||
void* workspace_mean_;
|
||||
void* workspace_variance_;
|
||||
void* workspace_count_;
|
||||
|
||||
void* control_;
|
||||
};
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
|
||||
@@ -340,6 +344,11 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
|
||||
// workspace for welford intermediate count
|
||||
workspace_size +=
|
||||
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64;
|
||||
|
||||
// workspace for barrier objects, each barrier object consists of two integers
|
||||
// TODO: allocate barrier object memory globally to reuse it by other operators
|
||||
workspace_size += (pArg_->invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize *
|
||||
sizeof(int) * 2;
|
||||
}
|
||||
|
||||
return (workspace_size);
|
||||
@@ -353,7 +362,6 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
|
||||
|
||||
if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
|
||||
{
|
||||
|
||||
// setup buffer used for intermediate welford mean
|
||||
pArg_->workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
|
||||
|
||||
@@ -374,6 +382,18 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
|
||||
// setup buffer used for intermediate welfor count
|
||||
pArg_->workspace_count_ =
|
||||
reinterpret_cast<char*>(pArg_->workspace_variance_) + variance_space_sz;
|
||||
|
||||
index_t count_space_sz =
|
||||
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t);
|
||||
|
||||
count_space_sz = math::integer_least_multiple(count_space_sz, 64);
|
||||
|
||||
pArg_->control_ = reinterpret_cast<char*>(pArg_->workspace_count_) + count_space_sz;
|
||||
|
||||
index_t control_space_sz = (pArg_->invariant_length_ + M_BlockTileSize - 1) /
|
||||
M_BlockTileSize * sizeof(int) * 2;
|
||||
|
||||
hip_check_error(hipMemset(pArg_->control_, 0, control_space_sz));
|
||||
};
|
||||
};
|
||||
|
||||
@@ -402,6 +422,32 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
|
||||
using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
|
||||
using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);
|
||||
|
||||
using GridwiseMultiblockBatchNormForward_ =
|
||||
GridwiseMultiblockBatchNormForward<XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarCountGridDesc_M_G,
|
||||
MeanVarCountGridDesc_M_K,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
GetReduceCountPerThreadFunctor,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
XSrcYDstVectorDim,
|
||||
XSrcVectorSize,
|
||||
YDstVectorSize,
|
||||
ScaleSrcVectorSize,
|
||||
BiasSrcVectorSize,
|
||||
MeanVarSrcDstVectorSize>;
|
||||
|
||||
using GridwiseMultiblockWelfordFirstHalf_ =
|
||||
GridwiseMultiblockWelfordFirstHalf<XDataType,
|
||||
AccDataType,
|
||||
@@ -441,78 +487,136 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
|
||||
BiasSrcVectorSize,
|
||||
MeanVarSrcDstVectorSize>;
|
||||
|
||||
index_t numMeanVarCountBlockTileIteration =
|
||||
(arg.blkGroupSize_ + KThreadClusterSize - 1) / KThreadClusterSize;
|
||||
// It is found that:
|
||||
// 1) gfx1030 does not support the GLC enabled vector load/store, so using the
|
||||
// two-kernel method for gfx1030
|
||||
// 2) Profiler on gfx908 could hang even though it works when running examples
|
||||
// 3) Single-kernel method works on gfx1100, but the performance it not better
|
||||
// than two-kernel method (due to more warps participating the barrier)
|
||||
if(ck::get_device_name() == "gfx90a")
|
||||
{
|
||||
const auto kern_multiblock_batchnorm_fwd_ =
|
||||
kernel_multiblock_batchnorm_forward<GridwiseMultiblockBatchNormForward_,
|
||||
XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarCountGridDesc_M_G,
|
||||
MeanVarCountGridDesc_M_K,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
GetReduceCountPerThreadFunctor>;
|
||||
|
||||
const auto kern_multiblock_welford_first_half =
|
||||
kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
|
||||
XDataType,
|
||||
MeanVarDataType,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarCountGridDesc_M_G,
|
||||
GetReduceCountPerThreadFunctor>;
|
||||
avg_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
kern_multiblock_batchnorm_fwd_,
|
||||
dim3(arg.gridSize_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.x_grid_desc_m_k_,
|
||||
arg.y_grid_desc_m_k_,
|
||||
mean_var_count_grid_desc_m_g, // for writing to mean/variance/count
|
||||
// workspace by multiple workgroups
|
||||
mean_var_count_grid_desc_m_k, // for reading from mean/variance/count
|
||||
// workspace by each workgroup
|
||||
arg.scale_grid_desc_m_,
|
||||
arg.bias_grid_desc_m_,
|
||||
arg.mean_var_grid_desc_m_,
|
||||
get_reduce_count_per_thread,
|
||||
arg.numBlockTileIteration_,
|
||||
arg.epsilon_,
|
||||
arg.p_x_,
|
||||
static_cast<MeanVarDataType*>(arg.workspace_mean_),
|
||||
static_cast<MeanVarDataType*>(arg.workspace_variance_),
|
||||
static_cast<int32_t*>(arg.workspace_count_),
|
||||
static_cast<int*>(arg.control_),
|
||||
arg.p_scale_,
|
||||
arg.p_bias_,
|
||||
arg.y_elementwise_op_,
|
||||
arg.p_y_,
|
||||
arg.updateMovingAverage_, // true or false
|
||||
arg.averageFactor_,
|
||||
arg.resultRunningMean_,
|
||||
arg.resultRunningVariance_,
|
||||
arg.saveMeanInvVariance_, // true or false
|
||||
arg.resultSaveMean_,
|
||||
arg.resultSaveInvVariance_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kern_multiblock_welford_first_half =
|
||||
kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
|
||||
XDataType,
|
||||
MeanVarDataType,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarCountGridDesc_M_G,
|
||||
GetReduceCountPerThreadFunctor>;
|
||||
|
||||
const auto kern_welford_second_half_batchnorm_forward_final =
|
||||
kernel_welford_second_half_batchnorm_forward_final<
|
||||
GridwiseWelfordSecondHalfBatchNormForwardFinal_,
|
||||
XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarCountGridDesc_M_K,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
ScaleBiasMeanVarGridDesc_M>;
|
||||
const auto kern_welford_second_half_batchnorm_forward_final =
|
||||
kernel_welford_second_half_batchnorm_forward_final<
|
||||
GridwiseWelfordSecondHalfBatchNormForwardFinal_,
|
||||
XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarCountGridDesc_M_K,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
ScaleBiasMeanVarGridDesc_M>;
|
||||
|
||||
avg_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kern_multiblock_welford_first_half,
|
||||
dim3(arg.gridSize_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.x_grid_desc_m_k_,
|
||||
mean_var_count_grid_desc_m_g,
|
||||
get_reduce_count_per_thread,
|
||||
arg.numBlockTileIteration_,
|
||||
arg.p_x_,
|
||||
static_cast<MeanVarDataType*>(arg.workspace_mean_),
|
||||
static_cast<MeanVarDataType*>(arg.workspace_variance_),
|
||||
static_cast<int32_t*>(arg.workspace_count_));
|
||||
avg_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
kern_multiblock_welford_first_half,
|
||||
dim3(arg.gridSize_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.x_grid_desc_m_k_,
|
||||
mean_var_count_grid_desc_m_g,
|
||||
get_reduce_count_per_thread,
|
||||
arg.numBlockTileIteration_,
|
||||
arg.p_x_,
|
||||
static_cast<MeanVarDataType*>(arg.workspace_mean_),
|
||||
static_cast<MeanVarDataType*>(arg.workspace_variance_),
|
||||
static_cast<int32_t*>(arg.workspace_count_));
|
||||
|
||||
avg_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kern_welford_second_half_batchnorm_forward_final,
|
||||
dim3(arg.gridSize_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.x_grid_desc_m_k_,
|
||||
arg.y_grid_desc_m_k_,
|
||||
mean_var_count_grid_desc_m_k,
|
||||
arg.scale_grid_desc_m_,
|
||||
arg.bias_grid_desc_m_,
|
||||
arg.mean_var_grid_desc_m_,
|
||||
arg.blkGroupSize_,
|
||||
arg.numBlockTileIteration_,
|
||||
numMeanVarCountBlockTileIteration,
|
||||
arg.epsilon_,
|
||||
static_cast<MeanVarDataType*>(arg.workspace_mean_),
|
||||
static_cast<MeanVarDataType*>(arg.workspace_variance_),
|
||||
static_cast<int32_t*>(arg.workspace_count_),
|
||||
arg.p_x_,
|
||||
arg.p_scale_,
|
||||
arg.p_bias_,
|
||||
arg.y_elementwise_op_,
|
||||
arg.p_y_,
|
||||
arg.updateMovingAverage_,
|
||||
arg.averageFactor_,
|
||||
arg.resultRunningMean_,
|
||||
arg.resultRunningVariance_,
|
||||
arg.saveMeanInvVariance_,
|
||||
arg.resultSaveMean_,
|
||||
arg.resultSaveInvVariance_);
|
||||
avg_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
kern_welford_second_half_batchnorm_forward_final,
|
||||
dim3(arg.gridSize_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.x_grid_desc_m_k_,
|
||||
arg.y_grid_desc_m_k_,
|
||||
mean_var_count_grid_desc_m_k,
|
||||
arg.scale_grid_desc_m_,
|
||||
arg.bias_grid_desc_m_,
|
||||
arg.mean_var_grid_desc_m_,
|
||||
arg.blkGroupSize_,
|
||||
arg.numBlockTileIteration_,
|
||||
arg.epsilon_,
|
||||
static_cast<MeanVarDataType*>(arg.workspace_mean_),
|
||||
static_cast<MeanVarDataType*>(arg.workspace_variance_),
|
||||
static_cast<int32_t*>(arg.workspace_count_),
|
||||
arg.p_x_,
|
||||
arg.p_scale_,
|
||||
arg.p_bias_,
|
||||
arg.y_elementwise_op_,
|
||||
arg.p_y_,
|
||||
arg.updateMovingAverage_,
|
||||
arg.averageFactor_,
|
||||
arg.resultRunningMean_,
|
||||
arg.resultRunningVariance_,
|
||||
arg.saveMeanInvVariance_,
|
||||
arg.resultSaveMean_,
|
||||
arg.resultSaveInvVariance_);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -0,0 +1,714 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/reduction_operator.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename XDataType,
|
||||
typename YDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename YElementwiseOp,
|
||||
index_t Rank,
|
||||
index_t NumBatchNormReduceDim,
|
||||
bool UseMultiblockInK,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t XSrcYDstVectorDim,
|
||||
index_t XSrcVectorSize,
|
||||
index_t YDstVectorSize,
|
||||
index_t ScaleSrcVectorSize,
|
||||
index_t BiasSrcVectorSize,
|
||||
index_t MeanVarSrcDstVectorSize>
|
||||
struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
Rank,
|
||||
NumBatchNormReduceDim>
|
||||
{
|
||||
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
|
||||
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
|
||||
"Invalid thread cluster size assignments!");
|
||||
|
||||
static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
|
||||
(XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
|
||||
|
||||
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
|
||||
|
||||
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
static auto MakeXY2dDescriptor(const std::array<index_t, Rank>& xyLengths,
|
||||
const std::array<index_t, Rank>& xyStrides,
|
||||
int blkGroupSize,
|
||||
int numBlockTileIteration)
|
||||
{
|
||||
const auto tupleXYLengths =
|
||||
generate_tuple([&](auto I) { return xyLengths[I]; }, Number<Rank>{});
|
||||
const auto tupleXYStrides =
|
||||
generate_tuple([&](auto I) { return xyStrides[I]; }, Number<Rank>{});
|
||||
|
||||
const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides);
|
||||
|
||||
const auto grid_desc_m_k = [&]() {
|
||||
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
|
||||
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
|
||||
|
||||
const auto reduceDimLengths =
|
||||
generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; },
|
||||
Number<NumBatchNormReduceDim>{});
|
||||
const auto invariantDimLengths =
|
||||
generate_tuple([&](auto I) { return xyLengths[I]; }, Number<NumInvariantDim>{});
|
||||
|
||||
return transform_tensor_descriptor(raw_grid_desc,
|
||||
make_tuple(make_merge_transform(invariantDimLengths),
|
||||
make_merge_transform(reduceDimLengths)),
|
||||
make_tuple(InvariantDims{}, ReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}();
|
||||
|
||||
const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{});
|
||||
const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration;
|
||||
const auto mPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
const auto kPad = workSizePerBlock * blkGroupSize - reduceLength;
|
||||
|
||||
auto grid_desc_m_k_padded =
|
||||
transform_tensor_descriptor(grid_desc_m_k,
|
||||
make_tuple(make_right_pad_transform(invariantLength, mPad),
|
||||
make_right_pad_transform(reduceLength, kPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (grid_desc_m_k_padded);
|
||||
};
|
||||
|
||||
static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
|
||||
{
|
||||
const auto grid_desc_m_g = make_naive_tensor_descriptor(
|
||||
make_tuple(invariantLength, blkGroupSize), make_tuple(1, invariantLength));
|
||||
|
||||
const auto mPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
|
||||
auto grid_desc_m_g_padded =
|
||||
transform_tensor_descriptor(grid_desc_m_g,
|
||||
make_tuple(make_right_pad_transform(invariantLength, mPad),
|
||||
make_pass_through_transform(blkGroupSize)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (grid_desc_m_g_padded);
|
||||
};
|
||||
|
||||
static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
|
||||
{
|
||||
const auto reduceLength = blkGroupSize;
|
||||
const auto grid_desc_m_k = make_naive_tensor_descriptor(
|
||||
make_tuple(invariantLength, reduceLength), make_tuple(1, invariantLength));
|
||||
|
||||
const auto mPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
const auto kPad =
|
||||
math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength;
|
||||
|
||||
auto grid_desc_m_k_padded =
|
||||
transform_tensor_descriptor(grid_desc_m_k,
|
||||
make_tuple(make_right_pad_transform(invariantLength, mPad),
|
||||
make_right_pad_transform(reduceLength, kPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (grid_desc_m_k_padded);
|
||||
};
|
||||
|
||||
static auto
|
||||
MakeScaleBiasMeanVar1dDescriptor(const std::array<index_t, NumInvariantDim>& lengths,
|
||||
const std::array<index_t, NumInvariantDim>& strides)
|
||||
{
|
||||
const auto tupleLengths =
|
||||
generate_tuple([&](auto I) { return lengths[I]; }, Number<NumInvariantDim>{});
|
||||
const auto tupleStrides =
|
||||
generate_tuple([&](auto I) { return strides[I]; }, Number<NumInvariantDim>{});
|
||||
|
||||
auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
|
||||
|
||||
auto grid_desc_m = transform_tensor_descriptor(
|
||||
raw_grid_desc,
|
||||
make_tuple(make_merge_transform(tupleLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
|
||||
|
||||
const auto mPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
|
||||
auto grid_desc_m_padded =
|
||||
transform_tensor_descriptor(grid_desc_m,
|
||||
make_tuple(make_right_pad_transform(invariantLength, mPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return (grid_desc_m_padded);
|
||||
};
|
||||
|
||||
using XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1));
|
||||
using ScaleBiasMeanVarGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1}));
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::array<index_t, Rank> xyLengths,
|
||||
const std::array<index_t, Rank> xStrides,
|
||||
const std::array<index_t, Rank> yStrides,
|
||||
const std::array<int, NumBatchNormReduceDim> reduceDims,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
|
||||
const XDataType* p_x,
|
||||
const ScaleDataType* p_scale,
|
||||
const BiasDataType* p_bias,
|
||||
const YElementwiseOp y_elementwise_op,
|
||||
double epsilon,
|
||||
YDataType* p_y,
|
||||
MeanVarDataType* resultSaveMean,
|
||||
MeanVarDataType* resultSaveInvVariance,
|
||||
double averageFactor,
|
||||
MeanVarDataType* resultRunningMean,
|
||||
MeanVarDataType* resultRunningVariance)
|
||||
: bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
|
||||
bnScaleStrides_(bnScaleStrides),
|
||||
bnBiasStrides_(bnBiasStrides),
|
||||
bnMeanVarStrides_(bnMeanVarStrides),
|
||||
p_x_(p_x),
|
||||
p_scale_(p_scale),
|
||||
p_bias_(p_bias),
|
||||
y_elementwise_op_(y_elementwise_op),
|
||||
p_y_(p_y),
|
||||
resultSaveMean_(resultSaveMean),
|
||||
resultSaveInvVariance_(resultSaveInvVariance),
|
||||
resultRunningMean_(resultRunningMean),
|
||||
resultRunningVariance_(resultRunningVariance)
|
||||
{
|
||||
xyLengths_ =
|
||||
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xyLengths, reduceDims);
|
||||
xStrides_ =
|
||||
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xStrides, reduceDims);
|
||||
yStrides_ =
|
||||
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(yStrides, reduceDims);
|
||||
|
||||
std::tie(invariant_length_, reduce_length_) =
|
||||
get_2d_lengths<Rank, NumBatchNormReduceDim>(xyLengths_);
|
||||
|
||||
epsilon_ = type_convert<AccDataType>(epsilon);
|
||||
averageFactor_ = type_convert<AccDataType>(averageFactor);
|
||||
|
||||
updateMovingAverage_ =
|
||||
(resultRunningMean != nullptr && resultRunningVariance != nullptr);
|
||||
saveMeanInvVariance_ = (resultSaveMean != nullptr && resultSaveInvVariance_ != nullptr);
|
||||
|
||||
if(UseMultiblockInK)
|
||||
{
|
||||
int iterations = 1;
|
||||
while(true)
|
||||
{
|
||||
int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
// we want the blkGroupSize be not more than 16
|
||||
if(testBlkGroupSize <= 16)
|
||||
break;
|
||||
|
||||
iterations++;
|
||||
};
|
||||
|
||||
blkGroupSize_ = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
numBlockTileIteration_ = iterations;
|
||||
}
|
||||
else
|
||||
{
|
||||
blkGroupSize_ = 1;
|
||||
numBlockTileIteration_ = (reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize;
|
||||
};
|
||||
|
||||
gridSize_ = (invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize_;
|
||||
|
||||
x_grid_desc_m_k_ =
|
||||
MakeXY2dDescriptor(xyLengths_, xStrides_, blkGroupSize_, numBlockTileIteration_);
|
||||
y_grid_desc_m_k_ =
|
||||
MakeXY2dDescriptor(xyLengths_, yStrides_, blkGroupSize_, numBlockTileIteration_);
|
||||
scale_grid_desc_m_ =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides_);
|
||||
bias_grid_desc_m_ =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides_);
|
||||
mean_var_grid_desc_m_ =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides_);
|
||||
}
|
||||
|
||||
AccDataType epsilon_;
|
||||
AccDataType averageFactor_;
|
||||
|
||||
bool updateMovingAverage_;
|
||||
bool saveMeanInvVariance_;
|
||||
|
||||
std::array<index_t, Rank> xyLengths_;
|
||||
std::array<index_t, Rank> xStrides_;
|
||||
std::array<index_t, Rank> yStrides_;
|
||||
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;
|
||||
|
||||
const XDataType* p_x_;
|
||||
const ScaleDataType* p_scale_;
|
||||
const BiasDataType* p_bias_;
|
||||
const YElementwiseOp y_elementwise_op_;
|
||||
YDataType* p_y_;
|
||||
|
||||
MeanVarDataType* resultSaveMean_;
|
||||
MeanVarDataType* resultSaveInvVariance_;
|
||||
|
||||
MeanVarDataType* resultRunningMean_;
|
||||
MeanVarDataType* resultRunningVariance_;
|
||||
|
||||
long_index_t invariant_length_;
|
||||
long_index_t reduce_length_;
|
||||
|
||||
int blkGroupSize_;
|
||||
int numBlockTileIteration_;
|
||||
size_t gridSize_;
|
||||
|
||||
XYGridDesc_M_K x_grid_desc_m_k_;
|
||||
XYGridDesc_M_K y_grid_desc_m_k_;
|
||||
ScaleBiasMeanVarGridDesc_M scale_grid_desc_m_;
|
||||
ScaleBiasMeanVarGridDesc_M bias_grid_desc_m_;
|
||||
ScaleBiasMeanVarGridDesc_M mean_var_grid_desc_m_;
|
||||
|
||||
void* workspace_mean_;
|
||||
void* workspace_variance_;
|
||||
void* workspace_count_;
|
||||
};
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
|
||||
{
|
||||
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
|
||||
|
||||
size_t workspace_size = 0;
|
||||
|
||||
if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
|
||||
{
|
||||
// workspace for welford intermediate mean
|
||||
workspace_size +=
|
||||
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;
|
||||
|
||||
// workspace for welford intermediate variance
|
||||
workspace_size +=
|
||||
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;
|
||||
|
||||
// workspace for welford intermediate count
|
||||
workspace_size +=
|
||||
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64;
|
||||
}
|
||||
|
||||
return (workspace_size);
|
||||
};
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
|
||||
{
|
||||
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
|
||||
|
||||
pArg_->p_workspace_ = p_workspace;
|
||||
|
||||
if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
|
||||
{
|
||||
|
||||
// setup buffer used for intermediate welford mean
|
||||
pArg_->workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
|
||||
|
||||
index_t mean_space_sz =
|
||||
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);
|
||||
|
||||
mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
|
||||
|
||||
// setup buffer used for intermediate welford varirance
|
||||
pArg_->workspace_variance_ =
|
||||
reinterpret_cast<char*>(pArg_->workspace_mean_) + mean_space_sz;
|
||||
|
||||
index_t variance_space_sz =
|
||||
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);
|
||||
|
||||
variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
|
||||
|
||||
// setup buffer used for intermediate welfor count
|
||||
pArg_->workspace_count_ =
|
||||
reinterpret_cast<char*>(pArg_->workspace_variance_) + variance_space_sz;
|
||||
};
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
float avg_time = 0;
|
||||
|
||||
if(UseMultiblockInK && arg.blkGroupSize_ > 1)
|
||||
{
|
||||
using GetReduceCountPerThreadFunctor =
|
||||
GetReduceCountPerThreadForMultiblockWelford<K_BlockTileSize, KThreadSliceSize>;
|
||||
|
||||
GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
|
||||
arg.blkGroupSize_, arg.numBlockTileIteration_, arg.reduce_length_);
|
||||
|
||||
const auto mean_var_count_grid_desc_m_g =
|
||||
DeviceBatchNormFwdImpl::MakeMeanVarCountOutputMG2dDescriptor(
|
||||
arg.invariant_length_, arg.blkGroupSize_);
|
||||
|
||||
const auto mean_var_count_grid_desc_m_k =
|
||||
DeviceBatchNormFwdImpl::MakeMeanVarCountInputMK2dDescriptor(
|
||||
arg.invariant_length_, arg.blkGroupSize_);
|
||||
|
||||
using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
|
||||
using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);
|
||||
|
||||
using GridwiseMultiblockWelfordFirstHalf_ =
|
||||
GridwiseMultiblockWelfordFirstHalf<XDataType,
|
||||
AccDataType,
|
||||
MeanVarDataType,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarCountGridDesc_M_G,
|
||||
GetReduceCountPerThreadFunctor,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
XSrcYDstVectorDim,
|
||||
XSrcVectorSize>;
|
||||
|
||||
using GridwiseWelfordSecondHalfBatchNormForwardFinal_ =
|
||||
GridwiseWelfordSecondHalfBatchNormForwardFinal<XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarCountGridDesc_M_K,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
XSrcYDstVectorDim,
|
||||
XSrcVectorSize,
|
||||
YDstVectorSize,
|
||||
ScaleSrcVectorSize,
|
||||
BiasSrcVectorSize,
|
||||
MeanVarSrcDstVectorSize>;
|
||||
|
||||
const auto kern_multiblock_welford_first_half =
|
||||
kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
|
||||
XDataType,
|
||||
MeanVarDataType,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarCountGridDesc_M_G,
|
||||
GetReduceCountPerThreadFunctor>;
|
||||
|
||||
const auto kern_welford_second_half_batchnorm_forward_final =
|
||||
kernel_welford_second_half_batchnorm_forward_final<
|
||||
GridwiseWelfordSecondHalfBatchNormForwardFinal_,
|
||||
XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarCountGridDesc_M_K,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
ScaleBiasMeanVarGridDesc_M>;
|
||||
|
||||
avg_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kern_multiblock_welford_first_half,
|
||||
dim3(arg.gridSize_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.x_grid_desc_m_k_,
|
||||
mean_var_count_grid_desc_m_g,
|
||||
get_reduce_count_per_thread,
|
||||
arg.numBlockTileIteration_,
|
||||
arg.p_x_,
|
||||
static_cast<MeanVarDataType*>(arg.workspace_mean_),
|
||||
static_cast<MeanVarDataType*>(arg.workspace_variance_),
|
||||
static_cast<int32_t*>(arg.workspace_count_));
|
||||
|
||||
avg_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kern_welford_second_half_batchnorm_forward_final,
|
||||
dim3(arg.gridSize_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.x_grid_desc_m_k_,
|
||||
arg.y_grid_desc_m_k_,
|
||||
mean_var_count_grid_desc_m_k,
|
||||
arg.scale_grid_desc_m_,
|
||||
arg.bias_grid_desc_m_,
|
||||
arg.mean_var_grid_desc_m_,
|
||||
arg.blkGroupSize_,
|
||||
arg.numBlockTileIteration_,
|
||||
arg.epsilon_,
|
||||
static_cast<MeanVarDataType*>(arg.workspace_mean_),
|
||||
static_cast<MeanVarDataType*>(arg.workspace_variance_),
|
||||
static_cast<int32_t*>(arg.workspace_count_),
|
||||
arg.p_x_,
|
||||
arg.p_scale_,
|
||||
arg.p_bias_,
|
||||
arg.y_elementwise_op_,
|
||||
arg.p_y_,
|
||||
arg.updateMovingAverage_,
|
||||
arg.averageFactor_,
|
||||
arg.resultRunningMean_,
|
||||
arg.resultRunningVariance_,
|
||||
arg.saveMeanInvVariance_,
|
||||
arg.resultSaveMean_,
|
||||
arg.resultSaveInvVariance_);
|
||||
}
|
||||
else
|
||||
{
|
||||
using GetReduceCountPerThreadFunctor =
|
||||
GetReduceCountPerThreadForBlockwiseWelford<K_BlockTileSize, KThreadSliceSize>;
|
||||
|
||||
GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
|
||||
arg.numBlockTileIteration_, arg.reduce_length_);
|
||||
|
||||
using GridwiseBatchNormForwardWithBlockwiseWelford_ =
|
||||
GridwiseBatchNormForwardWithBlockwiseWelford<XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
GetReduceCountPerThreadFunctor,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
XSrcYDstVectorDim,
|
||||
XSrcVectorSize,
|
||||
YDstVectorSize,
|
||||
ScaleSrcVectorSize,
|
||||
BiasSrcVectorSize,
|
||||
MeanVarSrcDstVectorSize>;
|
||||
|
||||
const auto kern_batchnorm_fwd = kernel_batchnorm_forward_with_blockwise_welford<
|
||||
GridwiseBatchNormForwardWithBlockwiseWelford_,
|
||||
XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
GetReduceCountPerThreadFunctor>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kern_batchnorm_fwd,
|
||||
dim3(arg.gridSize_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.x_grid_desc_m_k_,
|
||||
arg.y_grid_desc_m_k_,
|
||||
arg.scale_grid_desc_m_,
|
||||
arg.bias_grid_desc_m_,
|
||||
arg.mean_var_grid_desc_m_,
|
||||
get_reduce_count_per_thread,
|
||||
arg.numBlockTileIteration_,
|
||||
arg.epsilon_,
|
||||
arg.p_x_,
|
||||
arg.p_scale_,
|
||||
arg.p_bias_,
|
||||
arg.y_elementwise_op_,
|
||||
arg.p_y_,
|
||||
arg.updateMovingAverage_, // true or false
|
||||
arg.averageFactor_,
|
||||
arg.resultRunningMean_,
|
||||
arg.resultRunningVariance_,
|
||||
arg.saveMeanInvVariance_, // true or false
|
||||
arg.resultSaveMean_,
|
||||
arg.resultSaveInvVariance_);
|
||||
};
|
||||
|
||||
return (avg_time);
|
||||
};
|
||||
|
||||
float Run(const BaseArgument* pArg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(pArg), stream_config);
|
||||
};
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* pArg) override
|
||||
{
|
||||
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
|
||||
|
||||
if constexpr(XSrcYDstVectorDim == 0)
|
||||
{
|
||||
if(pArg_->xStrides_[NumInvariantDim - 1] != 1 ||
|
||||
pArg_->yStrides_[NumInvariantDim - 1] != 1)
|
||||
return false;
|
||||
|
||||
if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 ||
|
||||
pArg_->xyLengths_[NumInvariantDim - 1] % YDstVectorSize != 0)
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->yStrides_[Rank - 1] != 1)
|
||||
return false;
|
||||
|
||||
if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 ||
|
||||
pArg_->xyLengths_[Rank - 1] % YDstVectorSize != 0)
|
||||
return false;
|
||||
};
|
||||
|
||||
if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1)
|
||||
return false;
|
||||
if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasSrcVectorSize != 1)
|
||||
return false;
|
||||
|
||||
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0)
|
||||
return false;
|
||||
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasSrcVectorSize != 0)
|
||||
return false;
|
||||
|
||||
if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcDstVectorSize != 1)
|
||||
return false;
|
||||
|
||||
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcDstVectorSize != 0)
|
||||
return false;
|
||||
|
||||
bool is_valid = true;
|
||||
|
||||
static_for<0, NumInvariantDim, 1>{}([&](auto I) {
|
||||
if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I])
|
||||
is_valid = false;
|
||||
});
|
||||
|
||||
if(!is_valid)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
const std::array<index_t, Rank> xyLengths,
|
||||
const std::array<index_t, Rank> xStrides,
|
||||
const std::array<index_t, Rank> yStrides,
|
||||
const std::array<int, NumBatchNormReduceDim> reduceDims,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
|
||||
const void* p_x,
|
||||
const void* p_scale,
|
||||
const void* p_bias,
|
||||
double epsilon,
|
||||
const YElementwiseOp y_elementwise_op,
|
||||
void* p_y,
|
||||
void* resultSaveMean,
|
||||
void* resultSaveInvVariance,
|
||||
double averageFactor,
|
||||
void* resultRunningMean,
|
||||
void* resultRunningVariance) override
|
||||
{
|
||||
return std::make_unique<Argument>(xyLengths,
|
||||
xStrides,
|
||||
yStrides,
|
||||
reduceDims,
|
||||
bnScaleBiasMeanVarLengths,
|
||||
bnScaleStrides,
|
||||
bnBiasStrides,
|
||||
bnMeanVarStrides,
|
||||
static_cast<const XDataType*>(p_x),
|
||||
static_cast<const ScaleDataType*>(p_scale),
|
||||
static_cast<const BiasDataType*>(p_bias),
|
||||
y_elementwise_op,
|
||||
epsilon,
|
||||
static_cast<YDataType*>(p_y),
|
||||
static_cast<MeanVarDataType*>(resultSaveMean),
|
||||
static_cast<MeanVarDataType*>(resultSaveInvVariance),
|
||||
averageFactor,
|
||||
static_cast<MeanVarDataType*>(resultRunningMean),
|
||||
static_cast<MeanVarDataType*>(resultRunningVariance));
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceBatchNormFwdImpl<" << BlockSize << ",";
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "XSrcYDstVectorDim_" << XSrcYDstVectorDim << ",";
|
||||
str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << BiasSrcVectorSize << "_mean_var_" << MeanVarSrcDstVectorSize << "_Y" << YDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,704 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/math_v2.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/utility/workgroup_synchronization.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseMultiblockBatchNormForward_,
|
||||
typename XDataType,
|
||||
typename YDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename YElementwiseOp,
|
||||
typename XYGridDesc_M_K,
|
||||
typename MeanVarCountGridDesc_M_G,
|
||||
typename MeanVarCountGridDesc_M_K,
|
||||
typename ScaleBiasGridDesc_M,
|
||||
typename MeanVarGridDesc_M,
|
||||
typename GetReduceCountPerThreadFunctor>
|
||||
__global__ void kernel_multiblock_batchnorm_forward(
|
||||
const XYGridDesc_M_K x_grid_desc_m_k,
|
||||
const XYGridDesc_M_K y_grid_desc_m_k,
|
||||
const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g,
|
||||
const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
|
||||
const ScaleBiasGridDesc_M scale_grid_desc_m,
|
||||
const ScaleBiasGridDesc_M bias_grid_desc_m,
|
||||
const MeanVarGridDesc_M mean_var_grid_desc_m,
|
||||
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
|
||||
index_t num_k_block_tile_iteration,
|
||||
AccDataType epsilon,
|
||||
const XDataType* const __restrict__ p_x,
|
||||
MeanVarDataType* const __restrict__ p_welford_mean,
|
||||
MeanVarDataType* const __restrict__ p_welford_variance,
|
||||
int32_t* const __restrict__ p_welford_count,
|
||||
int32_t* const __restrict__ p_control,
|
||||
const ScaleDataType* const __restrict__ p_scale,
|
||||
const BiasDataType* const __restrict__ p_bias,
|
||||
const YElementwiseOp y_elementwise_op,
|
||||
YDataType* const __restrict__ p_y,
|
||||
bool updateMovingAverage,
|
||||
AccDataType averageFactor,
|
||||
MeanVarDataType* const __restrict__ resultRunningMean,
|
||||
MeanVarDataType* const __restrict__ resultRunningVariance,
|
||||
bool saveMeanInvVariance,
|
||||
MeanVarDataType* const __restrict__ resultSaveMean,
|
||||
MeanVarDataType* const __restrict__ resultSaveInvVariance)
|
||||
{
|
||||
GridwiseMultiblockBatchNormForward_::Run(x_grid_desc_m_k,
|
||||
y_grid_desc_m_k,
|
||||
mean_var_count_grid_desc_m_g,
|
||||
mean_var_count_grid_desc_m_k,
|
||||
scale_grid_desc_m,
|
||||
bias_grid_desc_m,
|
||||
mean_var_grid_desc_m,
|
||||
get_reduce_count_per_thread,
|
||||
num_k_block_tile_iteration,
|
||||
epsilon,
|
||||
p_x,
|
||||
p_welford_mean,
|
||||
p_welford_variance,
|
||||
p_welford_count,
|
||||
p_control,
|
||||
p_scale,
|
||||
p_bias,
|
||||
y_elementwise_op,
|
||||
p_y,
|
||||
updateMovingAverage,
|
||||
averageFactor,
|
||||
resultRunningMean,
|
||||
resultRunningVariance,
|
||||
saveMeanInvVariance,
|
||||
resultSaveMean,
|
||||
resultSaveInvVariance);
|
||||
};
|
||||
|
||||
template <typename XDataType,
|
||||
typename YDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename YElementwiseOp,
|
||||
typename XYGridDesc_M_K,
|
||||
typename MeanVarCountGridDesc_M_G,
|
||||
typename MeanVarCountGridDesc_M_K,
|
||||
typename ScaleBiasGridDesc_M,
|
||||
typename MeanVarGridDesc_M,
|
||||
typename GetReduceCountPerThreadFunctor,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t XSrcYDstVectorDim,
|
||||
index_t XSrcVectorSize,
|
||||
index_t YDstVectorSize,
|
||||
index_t ScaleSrcVectorSize,
|
||||
index_t BiasSrcVectorSize,
|
||||
index_t MeanVarSrcDstVectorSize>
|
||||
struct GridwiseMultiblockBatchNormForward
|
||||
{
|
||||
static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
|
||||
(XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
|
||||
|
||||
static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
|
||||
(XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
|
||||
|
||||
static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0);
|
||||
|
||||
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
|
||||
|
||||
using ThreadBufferDimAccessOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
using ThreadClusterArrangeOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
|
||||
using ThreadReduceDstDesc_M =
|
||||
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
|
||||
|
||||
using ThreadReduceSrcDesc_M_1 = decltype(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
|
||||
|
||||
using ThreadwiseWelford1 =
|
||||
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
|
||||
|
||||
using ThreadwiseWelford2 =
|
||||
ThreadwiseWelfordMerge<AccDataType, ThreadReduceSrcDesc_M_1, ThreadReduceDstDesc_M>;
|
||||
|
||||
using BlockwiseWelford1 = BlockwiseWelford<AccDataType,
|
||||
BlockSize,
|
||||
ThreadClusterLengths_M_K,
|
||||
ThreadClusterArrangeOrder,
|
||||
false>;
|
||||
|
||||
using BlockwiseWelford2 = BlockwiseWelford<AccDataType,
|
||||
BlockSize,
|
||||
ThreadClusterLengths_M_K,
|
||||
ThreadClusterArrangeOrder,
|
||||
true>;
|
||||
|
||||
using PassThroughOp = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
__device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
|
||||
const XYGridDesc_M_K& y_grid_desc_m_k,
|
||||
const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
|
||||
const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
|
||||
const ScaleBiasGridDesc_M& scale_grid_desc_m,
|
||||
const ScaleBiasGridDesc_M& bias_grid_desc_m,
|
||||
const MeanVarGridDesc_M& mean_var_grid_desc_m,
|
||||
const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
|
||||
index_t num_k_block_tile_iteration,
|
||||
AccDataType epsilon,
|
||||
const XDataType* const __restrict__ p_x,
|
||||
MeanVarDataType* const __restrict__ p_welford_mean,
|
||||
MeanVarDataType* const __restrict__ p_welford_variance,
|
||||
int32_t* const __restrict__ p_welford_count,
|
||||
int32_t* const __restrict__ p_control,
|
||||
const ScaleDataType* const __restrict__ p_scale,
|
||||
const BiasDataType* const __restrict__ p_bias,
|
||||
const YElementwiseOp y_elementwise_op,
|
||||
YDataType* const __restrict__ p_y,
|
||||
bool updateMovingAverage,
|
||||
AccDataType averageFactor,
|
||||
MeanVarDataType* const __restrict__ resultRunningMean,
|
||||
MeanVarDataType* const __restrict__ resultRunningVariance,
|
||||
bool saveMeanInvVariance,
|
||||
MeanVarDataType* const __restrict__ resultSaveMean,
|
||||
MeanVarDataType* const __restrict__ resultSaveInvVariance)
|
||||
{
|
||||
using ck::math::sqrt;
|
||||
|
||||
const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1);
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_id = get_block_1d_id();
|
||||
const index_t blkgroup_id = block_global_id / blkgroup_size;
|
||||
const index_t block_local_id = block_global_id % blkgroup_size;
|
||||
|
||||
if(block_local_id == 0)
|
||||
gms_init(BlockSize / warpSize * blkgroup_size, &p_control[blkgroup_id * 2]);
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
|
||||
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
|
||||
|
||||
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
constexpr auto thread_buffer_desc_m =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
|
||||
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
x_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true> count_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
|
||||
tmp_mean_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
|
||||
tmp_var_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true> tmp_count_thread_buf;
|
||||
|
||||
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
|
||||
|
||||
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
|
||||
AccDataType,
|
||||
XYGridDesc_M_K,
|
||||
decltype(thread_buffer_desc_m_k),
|
||||
ThreadBufferLengths_M_K,
|
||||
ThreadBufferDimAccessOrder,
|
||||
XSrcYDstVectorDim,
|
||||
XSrcVectorSize,
|
||||
1,
|
||||
true>(
|
||||
x_grid_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id * reduceSizePerBlock +
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
constexpr auto xy_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
|
||||
|
||||
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_x, x_grid_desc_m_k.GetElementSpaceSize());
|
||||
|
||||
// Step 1: each workgroup does local welford reduction
|
||||
|
||||
auto threadwise_welford_1 = ThreadwiseWelford1();
|
||||
threadwise_welford_1.max_count_ =
|
||||
get_reduce_count_per_thread(block_local_id, thread_k_cluster_id);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
|
||||
var_thread_buf(I) = type_convert<AccDataType>(0.0f);
|
||||
});
|
||||
|
||||
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
|
||||
{
|
||||
threadwise_x_load.Run(x_grid_desc_m_k,
|
||||
x_global_val_buf,
|
||||
thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
x_thread_buf);
|
||||
|
||||
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_copy_fwd_step_m_k);
|
||||
threadwise_welford_1.Run(x_thread_buf, mean_thread_buf, var_thread_buf);
|
||||
}
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
if constexpr(I > 0)
|
||||
block_sync_lds();
|
||||
|
||||
count_thread_buf(I) = threadwise_welford_1.cur_count_;
|
||||
BlockwiseWelford1::Run(mean_thread_buf(I), var_thread_buf(I), count_thread_buf(I));
|
||||
});
|
||||
|
||||
// Step 2: each workgroup writes its local welford result to workspace memory
|
||||
|
||||
auto mean_global_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
|
||||
p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
|
||||
|
||||
auto var_global_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
|
||||
p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
|
||||
|
||||
auto count_global_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
|
||||
p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
|
||||
|
||||
auto threadwise_mean_var_store_m_g =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
MeanVarDataType,
|
||||
decltype(thread_buffer_desc_m_1),
|
||||
MeanVarCountGridDesc_M_G,
|
||||
PassThroughOp,
|
||||
ThreadBufferLengths_M_1,
|
||||
Sequence<0, 1>,
|
||||
0,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
mean_var_count_grid_desc_m_g,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id),
|
||||
PassThroughOp{});
|
||||
|
||||
auto threadwise_count_store_m_g =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<int32_t,
|
||||
int32_t,
|
||||
decltype(thread_buffer_desc_m_1),
|
||||
MeanVarCountGridDesc_M_G,
|
||||
PassThroughOp,
|
||||
ThreadBufferLengths_M_1,
|
||||
Sequence<0, 1>,
|
||||
0,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
mean_var_count_grid_desc_m_g,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id),
|
||||
PassThroughOp{});
|
||||
|
||||
if(thread_k_cluster_id == 0)
|
||||
{
|
||||
threadwise_mean_var_store_m_g.Run(thread_buffer_desc_m_1,
|
||||
make_tuple(I0, I0),
|
||||
mean_thread_buf,
|
||||
mean_var_count_grid_desc_m_g,
|
||||
mean_global_val_buf);
|
||||
|
||||
threadwise_mean_var_store_m_g.Run(thread_buffer_desc_m_1,
|
||||
make_tuple(I0, I0),
|
||||
var_thread_buf,
|
||||
mean_var_count_grid_desc_m_g,
|
||||
var_global_val_buf);
|
||||
|
||||
threadwise_count_store_m_g.Run(thread_buffer_desc_m_1,
|
||||
make_tuple(I0, I0),
|
||||
count_thread_buf,
|
||||
mean_var_count_grid_desc_m_g,
|
||||
count_global_val_buf);
|
||||
};
|
||||
|
||||
gms_barrier(&p_control[blkgroup_id * 2]);
|
||||
|
||||
if(block_local_id == 0)
|
||||
gms_reset(&p_control[blkgroup_id * 2]);
|
||||
|
||||
// Step 3: each workgroup reads welford results from workspace memory and does final welford
|
||||
// reduction
|
||||
|
||||
auto threadwise_mean_var_load_m_k =
|
||||
ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
|
||||
AccDataType,
|
||||
MeanVarCountGridDesc_M_K,
|
||||
decltype(thread_buffer_desc_m_1),
|
||||
ThreadBufferLengths_M_1,
|
||||
Sequence<0, 1>,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
true>(
|
||||
mean_var_count_grid_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * 1));
|
||||
|
||||
auto threadwise_count_load_m_k =
|
||||
ThreadwiseTensorSliceTransfer_v2<int32_t,
|
||||
int32_t,
|
||||
MeanVarCountGridDesc_M_K,
|
||||
decltype(thread_buffer_desc_m_1),
|
||||
ThreadBufferLengths_M_1,
|
||||
Sequence<0, 1>,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
true>(
|
||||
mean_var_count_grid_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * 1));
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
|
||||
var_thread_buf(I) = type_convert<AccDataType>(0.0f);
|
||||
count_thread_buf(I) = 0;
|
||||
});
|
||||
|
||||
constexpr auto mean_var_count_read_fwd_step_m_k = make_multi_index(0, KThreadClusterSize);
|
||||
|
||||
int32_t reducedSize = 0;
|
||||
while(reducedSize < blkgroup_size)
|
||||
{
|
||||
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
|
||||
mean_global_val_buf,
|
||||
thread_buffer_desc_m_1,
|
||||
make_tuple(I0, I0),
|
||||
tmp_mean_thread_buf);
|
||||
|
||||
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
|
||||
var_global_val_buf,
|
||||
thread_buffer_desc_m_1,
|
||||
make_tuple(I0, I0),
|
||||
tmp_var_thread_buf);
|
||||
|
||||
threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k,
|
||||
count_global_val_buf,
|
||||
thread_buffer_desc_m_1,
|
||||
make_tuple(I0, I0),
|
||||
tmp_count_thread_buf);
|
||||
|
||||
ThreadwiseWelford2::Run(tmp_mean_thread_buf,
|
||||
tmp_var_thread_buf,
|
||||
tmp_count_thread_buf,
|
||||
mean_thread_buf,
|
||||
var_thread_buf,
|
||||
count_thread_buf);
|
||||
|
||||
reducedSize += KThreadClusterSize;
|
||||
|
||||
threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
|
||||
mean_var_count_read_fwd_step_m_k);
|
||||
threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
|
||||
mean_var_count_read_fwd_step_m_k);
|
||||
};
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
if constexpr(I > 0)
|
||||
block_sync_lds();
|
||||
|
||||
BlockwiseWelford2::Run(mean_thread_buf(I), var_thread_buf(I), count_thread_buf(I));
|
||||
});
|
||||
|
||||
// Step 4: do normalization using the mean/variance
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> scale_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> bias_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
y_thread_buf;
|
||||
|
||||
auto threadwise_y_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
YDataType,
|
||||
decltype(thread_buffer_desc_m_k),
|
||||
XYGridDesc_M_K,
|
||||
YElementwiseOp,
|
||||
ThreadBufferLengths_M_K,
|
||||
ThreadBufferDimAccessOrder,
|
||||
XSrcYDstVectorDim,
|
||||
YDstVectorSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
y_grid_desc_m_k,
|
||||
make_multi_index(
|
||||
blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize),
|
||||
y_elementwise_op);
|
||||
|
||||
auto threadwise_scale_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
|
||||
AccDataType,
|
||||
ScaleBiasGridDesc_M,
|
||||
decltype(thread_buffer_desc_m),
|
||||
ThreadBufferLengths_M,
|
||||
Sequence<0>,
|
||||
0,
|
||||
ScaleSrcVectorSize,
|
||||
1,
|
||||
true>(
|
||||
scale_grid_desc_m,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize));
|
||||
|
||||
auto threadwise_bias_load = ThreadwiseTensorSliceTransfer_v2<BiasDataType,
|
||||
AccDataType,
|
||||
ScaleBiasGridDesc_M,
|
||||
decltype(thread_buffer_desc_m),
|
||||
ThreadBufferLengths_M,
|
||||
Sequence<0>,
|
||||
0,
|
||||
BiasSrcVectorSize,
|
||||
1,
|
||||
true>(
|
||||
bias_grid_desc_m,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize));
|
||||
|
||||
const auto scale_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_scale, scale_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
const auto bias_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_bias, bias_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_y, y_grid_desc_m_k.GetElementSpaceSize());
|
||||
|
||||
threadwise_scale_load.Run(scale_grid_desc_m,
|
||||
scale_global_val_buf,
|
||||
thread_buffer_desc_m,
|
||||
make_tuple(I0),
|
||||
scale_thread_buf);
|
||||
|
||||
threadwise_bias_load.Run(bias_grid_desc_m,
|
||||
bias_global_val_buf,
|
||||
thread_buffer_desc_m,
|
||||
make_tuple(I0),
|
||||
bias_thread_buf);
|
||||
|
||||
threadwise_x_load.SetSrcSliceOrigin(
|
||||
x_grid_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id * reduceSizePerBlock +
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
|
||||
{
|
||||
threadwise_x_load.Run(x_grid_desc_m_k,
|
||||
x_global_val_buf,
|
||||
thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
x_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
AccDataType multiplier =
|
||||
scale_thread_buf[Number<iM>{}] / sqrt(var_thread_buf[iM] + epsilon);
|
||||
|
||||
AccDataType fused_mean_bias =
|
||||
bias_thread_buf[Number<iM>{}] - mean_thread_buf[iM] * multiplier;
|
||||
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
// normalize
|
||||
y_thread_buf(Number<offset>{}) =
|
||||
x_thread_buf[Number<offset>{}] * multiplier + fused_mean_bias;
|
||||
});
|
||||
});
|
||||
|
||||
threadwise_y_store.Run(thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
y_thread_buf,
|
||||
y_grid_desc_m_k,
|
||||
y_global_val_buf);
|
||||
|
||||
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_copy_fwd_step_m_k);
|
||||
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, xy_copy_fwd_step_m_k);
|
||||
}
|
||||
|
||||
// Step 5: update the moving average of mean and variance (optional)
|
||||
|
||||
if(updateMovingAverage && block_local_id == 0 && thread_k_cluster_id == 0)
|
||||
{
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
|
||||
running_mean_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
|
||||
running_var_thread_buf;
|
||||
|
||||
auto running_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
auto running_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
auto threadwise_mean_var_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
|
||||
AccDataType,
|
||||
MeanVarGridDesc_M,
|
||||
decltype(thread_buffer_desc_m),
|
||||
ThreadBufferLengths_M,
|
||||
Sequence<0>,
|
||||
0,
|
||||
MeanVarSrcDstVectorSize,
|
||||
1,
|
||||
true>(
|
||||
mean_var_grid_desc_m,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize));
|
||||
|
||||
threadwise_mean_var_load.Run(mean_var_grid_desc_m,
|
||||
running_mean_global_buf,
|
||||
thread_buffer_desc_m,
|
||||
make_tuple(I0),
|
||||
running_mean_thread_buf);
|
||||
|
||||
threadwise_mean_var_load.Run(mean_var_grid_desc_m,
|
||||
running_var_global_buf,
|
||||
thread_buffer_desc_m,
|
||||
make_tuple(I0),
|
||||
running_var_thread_buf);
|
||||
|
||||
AccDataType oneMinusAverageFactor = type_convert<AccDataType>(1.0) - averageFactor;
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor +
|
||||
mean_thread_buf[I] * averageFactor;
|
||||
running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor +
|
||||
var_thread_buf[I] * averageFactor;
|
||||
});
|
||||
|
||||
auto threadwise_mean_var_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
MeanVarDataType,
|
||||
decltype(thread_buffer_desc_m),
|
||||
MeanVarGridDesc_M,
|
||||
PassThroughOp,
|
||||
ThreadBufferLengths_M,
|
||||
Sequence<0>,
|
||||
0,
|
||||
MeanVarSrcDstVectorSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
mean_var_grid_desc_m,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize),
|
||||
PassThroughOp{});
|
||||
|
||||
threadwise_mean_var_store.Run(thread_buffer_desc_m,
|
||||
make_tuple(I0),
|
||||
running_mean_thread_buf,
|
||||
mean_var_grid_desc_m,
|
||||
running_mean_global_buf);
|
||||
|
||||
threadwise_mean_var_store.Run(thread_buffer_desc_m,
|
||||
make_tuple(I0),
|
||||
running_var_thread_buf,
|
||||
mean_var_grid_desc_m,
|
||||
running_var_global_buf);
|
||||
};
|
||||
|
||||
// Step 6: save mean and inv-variance (optional)
|
||||
|
||||
if(saveMeanInvVariance && block_local_id == 0 && thread_k_cluster_id == 0)
|
||||
{
|
||||
auto result_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
// calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
var_thread_buf(I) =
|
||||
type_convert<AccDataType>(1.0f) / sqrt(epsilon + var_thread_buf[I]);
|
||||
});
|
||||
|
||||
auto threadwise_mean_inv_var_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
MeanVarDataType,
|
||||
decltype(thread_buffer_desc_m),
|
||||
MeanVarGridDesc_M,
|
||||
PassThroughOp,
|
||||
ThreadBufferLengths_M,
|
||||
Sequence<0>,
|
||||
0,
|
||||
MeanVarSrcDstVectorSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
mean_var_grid_desc_m,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize),
|
||||
PassThroughOp{});
|
||||
|
||||
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
|
||||
make_tuple(I0),
|
||||
mean_thread_buf,
|
||||
mean_var_grid_desc_m,
|
||||
result_mean_global_buf);
|
||||
|
||||
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
|
||||
make_tuple(I0),
|
||||
var_thread_buf,
|
||||
mean_var_grid_desc_m,
|
||||
result_inv_var_global_buf);
|
||||
};
|
||||
}
|
||||
}; // namespace ck
|
||||
|
||||
} // namespace ck
|
||||
@@ -161,7 +161,7 @@ struct GridwiseMultiblockWelfordFirstHalf
|
||||
PassThroughOp,
|
||||
ThreadBufferLengths_M_1,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
@@ -180,7 +180,7 @@ struct GridwiseMultiblockWelfordFirstHalf
|
||||
PassThroughOp,
|
||||
ThreadBufferLengths_M_1,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
|
||||
@@ -33,7 +33,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final(
|
||||
const MeanVarGridDesc_M mean_var_grid_desc_m,
|
||||
index_t blkgroup_size,
|
||||
index_t num_xy_k_block_tile_iteration,
|
||||
index_t num_mean_var_count_k_block_tile_iteration,
|
||||
AccDataType epsilon,
|
||||
const MeanVarDataType* const __restrict__ p_in_welford_mean,
|
||||
const MeanVarDataType* const __restrict__ p_in_welford_variance,
|
||||
@@ -59,7 +58,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final(
|
||||
mean_var_grid_desc_m,
|
||||
blkgroup_size,
|
||||
num_xy_k_block_tile_iteration,
|
||||
num_mean_var_count_k_block_tile_iteration,
|
||||
epsilon,
|
||||
p_in_welford_mean,
|
||||
p_in_welford_variance,
|
||||
@@ -152,7 +150,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
|
||||
const MeanVarGridDesc_M& mean_var_grid_desc_m,
|
||||
index_t blkgroup_size,
|
||||
index_t num_xy_k_block_tile_iteration,
|
||||
index_t num_mean_var_count_k_block_tile_iteration,
|
||||
AccDataType epsilon,
|
||||
const MeanVarDataType* const __restrict__ p_in_welford_mean,
|
||||
const MeanVarDataType* const __restrict__ p_in_welford_variance,
|
||||
@@ -223,7 +220,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
|
||||
decltype(thread_buffer_desc_m_1),
|
||||
ThreadBufferLengths_M_1,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
true>(
|
||||
@@ -239,7 +236,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
|
||||
decltype(thread_buffer_desc_m_1),
|
||||
ThreadBufferLengths_M_1,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
true>(
|
||||
@@ -257,9 +254,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
|
||||
const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
|
||||
|
||||
constexpr auto mean_var_count_thread_copy_step_m_k =
|
||||
make_multi_index(0, KThreadClusterSize * 1);
|
||||
|
||||
// Step 1: do final welford reduction to get mean and variance
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
@@ -268,8 +262,11 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
|
||||
welford_count_thread_buf(I) = 0;
|
||||
});
|
||||
|
||||
for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration;
|
||||
++reducedTiles)
|
||||
constexpr auto mean_var_count_thread_copy_step_m_k =
|
||||
make_multi_index(0, KThreadClusterSize);
|
||||
|
||||
int32_t reducedSize = 0;
|
||||
while(reducedSize < blkgroup_size)
|
||||
{
|
||||
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
|
||||
welford_mean_global_val_buf,
|
||||
@@ -296,6 +293,8 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
|
||||
welford_var_thread_buf,
|
||||
welford_count_thread_buf);
|
||||
|
||||
reducedSize += KThreadClusterSize;
|
||||
|
||||
threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
|
||||
mean_var_count_thread_copy_step_m_k);
|
||||
threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
|
||||
20
include/ck/utility/get_shift.hpp
Normal file
20
include/ck/utility/get_shift.hpp
Normal file
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t N>
|
||||
static constexpr __device__ index_t get_shift()
|
||||
{
|
||||
return (get_shift<N / 2>() + 1);
|
||||
};
|
||||
|
||||
template <>
|
||||
constexpr __device__ index_t get_shift<1>()
|
||||
{
|
||||
return (0);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
@@ -25,16 +25,4 @@ struct float_equal_zero
|
||||
};
|
||||
};
|
||||
|
||||
template <index_t N>
|
||||
static constexpr __device__ index_t get_shift()
|
||||
{
|
||||
return (get_shift<N / 2>() + 1);
|
||||
};
|
||||
|
||||
template <>
|
||||
constexpr __device__ index_t get_shift<1>()
|
||||
{
|
||||
return (0);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
74
include/ck/utility/workgroup_synchronization.hpp
Normal file
74
include/ck/utility/workgroup_synchronization.hpp
Normal file
@@ -0,0 +1,74 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Initialization flag of Barrier object, can be any value except for zero
|
||||
static constexpr int BarrierInitFlag = 0x7856;
|
||||
|
||||
// 1) only the first thread-block in the synchronizaton group is supposed to call this function. It
|
||||
// is the responsibility of the user to ensure the two integer values in p_control_bits are zeros
|
||||
// before calling gms_init().
|
||||
// 2) Aftercalling gms_reset(), the two integer values in p_control_bits will be zeros, so no
|
||||
// repetitious initialization of p_control_bits buffer is required
|
||||
static __device__ void gms_init(int NumWarps, int* p_control_bits)
|
||||
{
|
||||
union
|
||||
{
|
||||
int two32[2];
|
||||
unsigned long one64;
|
||||
} regs;
|
||||
|
||||
regs.two32[0] = BarrierInitFlag;
|
||||
regs.two32[1] = NumWarps;
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
atomicCAS(reinterpret_cast<unsigned long*>(p_control_bits), 0, regs.one64);
|
||||
};
|
||||
|
||||
// all the workgroups in the synchronization group is supposed to call this function
|
||||
static __device__ void gms_barrier(int* p_control_bits)
|
||||
{
|
||||
constexpr int mask = warpSize - 1;
|
||||
|
||||
if((threadIdx.x & mask) == 0)
|
||||
{
|
||||
// ensure the barrier object is initialized
|
||||
do
|
||||
{
|
||||
const int r0 = __atomic_load_n(&p_control_bits[0], __ATOMIC_RELAXED);
|
||||
|
||||
if(r0 == BarrierInitFlag)
|
||||
break;
|
||||
|
||||
} while(true);
|
||||
|
||||
// go ahead toward the barrier line
|
||||
atomicSub(&p_control_bits[1], 1);
|
||||
|
||||
// wait until all warps have arrived
|
||||
do
|
||||
{
|
||||
const int r1 = __atomic_load_n(&p_control_bits[1], __ATOMIC_RELAXED);
|
||||
|
||||
if(r1 == 0)
|
||||
break;
|
||||
|
||||
} while(true);
|
||||
};
|
||||
};
|
||||
|
||||
// 1) Only the first thread-block in the synchronizaton group is supposed to call this function.
|
||||
// 2) Aftercalling gms_reset(), the two integer values in p_control_bits will be zeros, so no
|
||||
// repetitious initialization of p_control_bits buffer is required
|
||||
static __device__ void gms_reset(int* p_control_bits)
|
||||
{
|
||||
// reset the barrier object
|
||||
if(threadIdx.x == 0)
|
||||
(void)atomicCAS(&p_control_bits[0], BarrierInitFlag, 0);
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -148,7 +148,7 @@ int profile_batchnorm_forward(int argc, char* argv[])
|
||||
{
|
||||
if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3)
|
||||
{
|
||||
profile_batchnorm_forward_impl<F16, F16, F32, F16, F16, F16, 4, 3>(
|
||||
profile_batchnorm_forward_impl<F16, F16, F32, F16, F16, F32, 4, 3>(
|
||||
arg_parser.do_verification,
|
||||
arg_parser.init_method,
|
||||
arg_parser.do_dumpout,
|
||||
|
||||
Reference in New Issue
Block a user