Standalone sweep once softmax kernel w/ ckProfiler (#295)

* use 'sweep once' softmax kernel where applicable

* threadwise copy's dst buffer can specify invalid element value

* add int8 in/out float compute softmax support

give a bit of leeway for int absolute tolerance as there's a single data point of all test cases showing off-by-1 error

* format

* softmax inherits DeviceNormalization

* softmax profiler stub

* tighten up reference softmax interface

* example prints tensor dimension

* add fp32 to softmax profiler

* rename header

* hook with ckProfiler

* format

* resolve merge conflict

* resolve merge conflicts

* update normalization profiler help string

* resolve conflict

* typo

* remove residual

* softmax profiler: address feedback

* test for mixed precision input/output

* fully qualify ck::math::isnan

* add comment for device normalization interface

* revise wording

* constness for alpha/beta scaler pointer
This commit is contained in:
Anthony Chang
2022-07-01 01:08:50 +08:00
committed by GitHub
parent eccf8773a6
commit 93c99f3d87
24 changed files with 809 additions and 106 deletions

View File

@@ -150,6 +150,9 @@ int main(int argc, char* argv[])
AccDataType alpha = args.scales[0];
AccDataType beta = args.scales[1];
std::cout << "in: " << in.mDesc << std::endl;
std::cout << "out: " << out.mDesc << std::endl;
std::size_t num_thread = 1;
if(args.do_verification)
@@ -195,7 +198,7 @@ int main(int argc, char* argv[])
using ReferenceInstance =
tensor_operation::host::ReferenceSoftmax<InDataType, OutDataType, AccDataType>;
ReferenceInstance ref;
auto ref_arg = ref.MakeArgument(in, out_ref, alpha, beta, Rank, reduceDims);
auto ref_arg = ref.MakeArgument(in, out_ref, alpha, beta, reduceDims);
auto invoker = ref.MakeInvoker();
invoker.Run(ref_arg);
// LogRangeAsType<float>(std::cout << "tensor out_ref: ", out_ref.mData, ",") << std::endl;
@@ -212,8 +215,8 @@ int main(int argc, char* argv[])
auto argument_ptr = device_instance.MakeArgumentPointer(i_inLengths,
i_inStrides,
reduceDims,
alpha,
beta,
&alpha,
&beta,
in_dev.GetDeviceBuffer(),
out_dev.GetDeviceBuffer());