diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp index 7a8e1d9a37..ca17255b25 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp @@ -8,6 +8,7 @@ #include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/library/utility/host_tensor.hpp" +#include namespace ck { namespace tensor_operation { @@ -30,14 +31,18 @@ struct ReferenceBatchedGemm : public device::BaseOperator Tensor& c_g_m_n, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) + CElementwiseOperation c_element_op, + const int k_batch=1) : a_g_m_k_{a_g_m_k}, b_g_k_n_{b_g_k_n}, c_g_m_n_{c_g_m_n}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, - c_element_op_{c_element_op} + c_element_op_{c_element_op}, + k_batch_(k_batch) { + if(k_batch < 1) + throw std::invalid_argument("Batch size must be at least 1"); } const Tensor& a_g_m_k_; @@ -47,6 +52,8 @@ struct ReferenceBatchedGemm : public device::BaseOperator AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; CElementwiseOperation c_element_op_; + + const int k_batch_; }; // Invoker @@ -59,23 +66,52 @@ struct ReferenceBatchedGemm : public device::BaseOperator auto f_gmk_gkn_gmn = [&](auto g, auto m, auto n) { const int K = arg.a_g_m_k_.mDesc.GetLengths()[2]; - AccDataType v_acc = 0; + // simulate fp accuacy implications of k batching + std::vector partialSums(arg.k_batch_); - for(int k = 0; k < K; ++k) + for(int batchIdx = 0; batchIdx < arg.k_batch_; ++batchIdx) { - ADataType v_a; - BDataType v_b; + int batchSize = std::max(K/arg.k_batch_, 1); + int batchStart = batchSize*batchIdx; + int batchEnd = batchSize*(batchIdx+1); + // add any extra round-off to last batch + if(batchIdx == arg.k_batch_-1) + batchEnd = K; - arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k)); - arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n)); + AccDataType v_acc = 0; + for(int k = batchStart; k < batchEnd; ++k) + { + ADataType v_a; + BDataType v_b; + + arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k)); + arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n)); + + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } + + AccDataType v_c; + arg.c_element_op_(v_c, v_acc); + partialSums[batchIdx] = ck::type_convert(v_c); - v_acc += - ck::type_convert(v_a) * ck::type_convert(v_b); } - AccDataType v_c; - - arg.c_element_op_(v_c, v_acc); + // finally, sum up partial sums + // note that we can't simulate the random nature of atomic additions, but at least we can + // simulate the effect of partial sums + AccDataType v_c = 0; + if(arg.k_batch_ > 1) + { + for(int batchIdx = 0;batchIdx < arg.k_batch_;batchIdx++) + { + v_c = ck::type_convert(ck::type_convert(v_c) + ck::type_convert(partialSums[batchIdx])); + } + } + else + { + v_c = partialSums[0]; + } arg.c_g_m_n_(g, m, n) = ck::type_convert(v_c); }; @@ -108,9 +144,10 @@ struct ReferenceBatchedGemm : public device::BaseOperator Tensor& c_g_m_n, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) + CElementwiseOperation c_element_op, + const int k_batch=1) { - return Argument{a_g_m_k, b_g_k_n, c_g_m_n, a_element_op, b_element_op, c_element_op}; + return Argument{a_g_m_k, b_g_k_n, c_g_m_n, a_element_op, b_element_op, c_element_op, k_batch}; } static auto MakeInvoker() { return Invoker{}; } diff --git a/profiler/include/profiler/profile_gemm_b_scale_impl.hpp b/profiler/include/profiler/profile_gemm_b_scale_impl.hpp index 4d0d69593d..8908376b08 100644 --- a/profiler/include/profiler/profile_gemm_b_scale_impl.hpp +++ b/profiler/include/profiler/profile_gemm_b_scale_impl.hpp @@ -105,7 +105,7 @@ bool profile_gemm_b_scale_impl(int do_verification, break; case 2: a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-1, 2}); b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); break; default: diff --git a/profiler/include/profiler/profile_gemm_universal_impl.hpp b/profiler/include/profiler/profile_gemm_universal_impl.hpp index bb73c4e3da..bee907dd76 100644 --- a/profiler/include/profiler/profile_gemm_universal_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_impl.hpp @@ -90,7 +90,7 @@ bool profile_gemm_universal_impl(int do_verification, break; case 2: a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-1, 2}); break; default: a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0});