mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
Gemm alpha beta profiler (fp32 & fp16) (#91)
* [What] Refactor the verification of gemm alpha_beta by moving it into a reference operation
[Why] To stay in sync with how the other verifications are done
* Profile mk_nk for gemm bias 2d
* Support bias 2d with mk * kn in profiler
* Support bias 2d with km*kn and km*nk in profiler
* Support fp32 bias 2d in profiler
* format
* format
Co-authored-by: rocking <chunylai@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: 19c5d6e651]
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
#include "device_base.hpp"
|
||||
#include "device_gemm_xdl_c_shuffle_bias_2d.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "reference_gemm_bias_2d.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
@@ -72,43 +73,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
|
||||
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
// clang-format on
|
||||
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
typename C0Type,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
static void host_verify(const Tensor<AType>& a_m_k,
|
||||
const Tensor<BType>& b_k_n,
|
||||
const Tensor<C0Type>& c0_k_n,
|
||||
Tensor<CType>& c_m_n,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CElementwiseOperation& c_element_op)
|
||||
{
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
const int K = a_m_k.mDesc.GetLengths()[1];
|
||||
|
||||
AccDataType v = 0;
|
||||
AccDataType a = 0;
|
||||
AccDataType b = 0;
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
a_element_op(a, a_m_k(m, k));
|
||||
b_element_op(b, b_k_n(k, n));
|
||||
v += a * b;
|
||||
}
|
||||
|
||||
CType y = static_cast<CType>(v);
|
||||
|
||||
c_element_op(c_m_n(m, n), y, c0_k_n(m, n));
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_kn_mn,
|
||||
c_m_n.mDesc.GetLengths()[0],
|
||||
c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
|
||||
}
|
||||
// Host reference operation for GEMM with a 2D bias. main() runs it into the host
// result tensor, which is then compared against the device result. Note that
// CDataType is supplied for both the third and the fourth template argument.
using ReferenceGemmInstance =
    ck::tensor_operation::host::ReferenceGemmBias2D<ADataType,
                                                    BDataType,
                                                    CDataType,
                                                    CDataType,
                                                    AccDataType,
                                                    AElementOp,
                                                    BElementOp,
                                                    CElementOp>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
@@ -259,13 +231,18 @@ int main(int argc, char* argv[])
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_verify(a_m_k,
|
||||
b_k_n,
|
||||
c0_m_n,
|
||||
c_m_n_host_result,
|
||||
AElementOp{},
|
||||
BElementOp{},
|
||||
CElementOp{alpha, beta});
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_m_k,
|
||||
b_k_n,
|
||||
c0_m_n,
|
||||
c_m_n_host_result,
|
||||
AElementOp{},
|
||||
BElementOp{},
|
||||
CElementOp{alpha, beta});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
check_error(c_m_n_host_result, c_m_n_device_result);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user