mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
Standalone sweep once softmax kernel w/ ckProfiler (#295)
* use 'sweep once' softmax kernel where applicable
* threadwise copy's dst buffer can specify invalid element value
* add int8 in/out float compute softmax support
give a bit of leeway for int absolute tolerance as there's a single data point of all test cases showing off-by-1 error
* format
* softmax inherits DeviceNormalization
* softmax profiler stub
* tighten up reference softmax interface
* example prints tensor dimension
* add fp32 to softmax profiler
* rename header
* hook with ckProfiler
* format
* resolve merge conflict
* resolve merge conflicts
* update normalization profiler help string
* resolve conflict
* typo
* remove residual
* softmax profiler: address feedback
* test for mixed precision input/output
* fully qualify ck::math::isnan
* add comment for device normalization interface
* revise wording
* constness for alpha/beta scaler pointer
[ROCm/composable_kernel commit: 93c99f3d87]
This commit is contained in:
@@ -150,6 +150,9 @@ int main(int argc, char* argv[])
|
||||
AccDataType alpha = args.scales[0];
|
||||
AccDataType beta = args.scales[1];
|
||||
|
||||
std::cout << "in: " << in.mDesc << std::endl;
|
||||
std::cout << "out: " << out.mDesc << std::endl;
|
||||
|
||||
std::size_t num_thread = 1;
|
||||
|
||||
if(args.do_verification)
|
||||
@@ -195,7 +198,7 @@ int main(int argc, char* argv[])
|
||||
using ReferenceInstance =
|
||||
tensor_operation::host::ReferenceSoftmax<InDataType, OutDataType, AccDataType>;
|
||||
ReferenceInstance ref;
|
||||
auto ref_arg = ref.MakeArgument(in, out_ref, alpha, beta, Rank, reduceDims);
|
||||
auto ref_arg = ref.MakeArgument(in, out_ref, alpha, beta, reduceDims);
|
||||
auto invoker = ref.MakeInvoker();
|
||||
invoker.Run(ref_arg);
|
||||
// LogRangeAsType<float>(std::cout << "tensor out_ref: ", out_ref.mData, ",") << std::endl;
|
||||
@@ -212,8 +215,8 @@ int main(int argc, char* argv[])
|
||||
auto argument_ptr = device_instance.MakeArgumentPointer(i_inLengths,
|
||||
i_inStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
&alpha,
|
||||
&beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
out_dev.GetDeviceBuffer());
|
||||
|
||||
|
||||
Reference in New Issue
Block a user