From c4830c565def372f556c49eeea25df236b519ef5 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Fri, 11 Nov 2022 03:03:01 +0800 Subject: [PATCH] Rangify FillUniformDistributionIntegerValue<> (#443) Allow passing forward range to its call operator [ROCm/composable_kernel commit: 6f0564f013bbc33428c6343042cf658461b1bbed] --- example/01_gemm/run_gemm_example.inc | 10 ++++------ .../run_convnd_fwd_max_example.inc | 11 ++++------ .../gemm_add_addsquare_xdl_int8.cpp | 10 ++++------ .../gemm_reduce_xdl_common.hpp | 20 ++++++++----------- .../42_groupnorm/groupnorm_sigmoid_fp16.cpp | 6 +++--- library/include/ck/library/utility/fill.hpp | 17 +++++++++++++--- 6 files changed, 37 insertions(+), 37 deletions(-) diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 10b9917376..4d3759eb9d 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -32,14 +32,12 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) { case 0: break; case 1: - ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k.begin(), - a_m_k.end()); - ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n.begin(), - b_k_n.end()); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); break; default: - ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k.begin(), a_m_k.end()); - ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n.begin(), b_k_n.end()); + ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n); } Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc index 32c6475020..c93ee941c1 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc @@ -77,15 +77,12 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size, { case 0: break; case 1: - ck::utils::FillUniformDistributionIntegerValue{-8, 7}(conv_input.begin(), - conv_input.end()); - ck::utils::FillUniformDistributionIntegerValue{-8, 7}(conv_weight.begin(), - conv_weight.end()); + ck::utils::FillUniformDistributionIntegerValue{-8, 7}(conv_input); + ck::utils::FillUniformDistributionIntegerValue{-8, 7}(conv_weight); break; default: - ck::utils::FillUniformDistribution{-5, 5}(conv_input.begin(), conv_input.end()); - ck::utils::FillUniformDistribution{-5, 5}(conv_weight.begin(), - conv_weight.end()); + ck::utils::FillUniformDistribution{-5, 5}(conv_input); + ck::utils::FillUniformDistribution{-5, 5}(conv_weight); } DeviceMem conv_input_device_buf(sizeof(ADataType) * conv_input.mDesc.GetElementSpaceSize()); diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp index bc621a4b8b..f644440334 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp @@ -160,14 +160,12 @@ bool run_gemm_reduce_add_addsquare_xdl(ck::index_t M, { case 0: break; case 1: - ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k.begin(), - a_m_k.end()); - ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n.begin(), - b_k_n.end()); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); break; default: - ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k.begin(), a_m_k.end()); - ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n.begin(), b_k_n.end()); + ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n); break; } diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp b/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp index 036ab436cc..8ba6342c8d 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp @@ -134,14 +134,12 @@ auto run_gemm_reduce_max_xdl(ck::index_t M, { case 0: break; case 1: - ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k.begin(), - a_m_k.end()); - ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n.begin(), - b_k_n.end()); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); break; default: - ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k.begin(), a_m_k.end()); - ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n.begin(), b_k_n.end()); + ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n); break; } @@ -339,14 +337,12 @@ bool run_gemm_reduce_mean_meansquare_xdl(ck::index_t M, { case 0: break; case 1: - ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k.begin(), - a_m_k.end()); - ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n.begin(), - b_k_n.end()); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); break; default: - ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k.begin(), a_m_k.end()); - ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n.begin(), b_k_n.end()); + ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n); break; } diff --git a/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp b/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp index 69cacfd143..d8a8a27c97 100644 --- a/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp +++ b/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp @@ -100,9 +100,9 @@ int main(int argc, char* argv[]) Tensor gamma({G, C}); Tensor beta({G, C}); - ck::utils::FillUniformDistribution{0.f, 1.f}(x.begin(), x.end()); - ck::utils::FillUniformDistribution{0.f, 1.f}(gamma.begin(), gamma.end()); - ck::utils::FillUniformDistribution{0.f, 1.f}(beta.begin(), beta.end()); + ck::utils::FillUniformDistribution{0.f, 1.f}(x); + ck::utils::FillUniformDistribution{0.f, 1.f}(gamma); + ck::utils::FillUniformDistribution{0.f, 1.f}(beta); DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); diff --git a/library/include/ck/library/utility/fill.hpp b/library/include/ck/library/utility/fill.hpp index d717738dc4..54d58f362c 100644 --- a/library/include/ck/library/utility/fill.hpp +++ b/library/include/ck/library/utility/fill.hpp @@ -30,9 +30,10 @@ struct FillUniformDistribution } template - auto operator()(ForwardRange&& range) -> std::void_t()(std::begin(std::forward(range)), - std::end(std::forward(range))))> + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> { (*this)(std::begin(std::forward(range)), std::end(std::forward(range))); @@ -72,6 +73,16 @@ struct FillUniformDistributionIntegerValue std::generate( first, last, [&dis, &gen]() { return ck::type_convert(std::round(dis(gen))); }); } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } }; template