From 12f4cfce96a8dab6ff0e790ae9028d39ee88e303 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Wed, 23 Mar 2022 22:19:38 -0500 Subject: [PATCH] fixed alloc mem size (#145) --- example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp | 6 +++--- profiler/include/profile_grouped_gemm_impl.hpp | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp index 03afb7c44c..7c23a2f468 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp @@ -169,11 +169,11 @@ int main(int argc, char* argv[]) for(int i = 0; i < gemm_shapes.size(); i++) { a_tensors_device.emplace_back( - std::make_unique(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize())); + std::make_unique(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpace())); b_tensors_device.emplace_back( - std::make_unique(sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize())); + std::make_unique(sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpace())); c_tensors_device.emplace_back(std::make_unique( - sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize())); + sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSpace())); a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); diff --git a/profiler/include/profile_grouped_gemm_impl.hpp b/profiler/include/profile_grouped_gemm_impl.hpp index 2d99e93cfd..33ea11c341 100644 --- a/profiler/include/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profile_grouped_gemm_impl.hpp @@ -145,12 +145,12 @@ void profile_grouped_gemm_impl(int do_verification, for(int i = 0; i < group_count; i++) { a_device_buf.emplace_back( - std::make_unique(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSize())); + std::make_unique(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace())); b_device_buf.emplace_back( - std::make_unique(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSize())); + std::make_unique(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace())); c_device_buf.emplace_back(std::make_unique( - sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSize())); + sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpace())); a_device_buf[i]->ToDevice(a_m_k[i].mData.data()); b_device_buf[i]->ToDevice(b_k_n[i].mData.data());