mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
* Turning compare warnings on
* Cleaning part I
* Cleaning part II
* Explicit static_cast to ck::type_convert
* Resolving large tensor size issue.
* format
* revert change to tensor descriptor; promote ElementSpaceSize to 64bit
* use integer value for GEMM test
* Review remarks
* Review remarks + issues with (un)signed arithmetic
* Format fix
* Format
* Clang-format.
* fix 2 GB limit issue
Co-authored-by: Chao Liu <chao.liu2@amd.com>
Co-authored-by: Adam Osewski <aosewski@amd.com>
[ROCm/composable_kernel commit: f03a1738d9]
This commit is contained in:
@@ -222,7 +222,7 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
|
||||
{
|
||||
float max_diff = 1e-6;
|
||||
|
||||
for(int i = 0; i < ref.mData.size(); ++i)
|
||||
for(std::size_t i = 0; i < ref.mData.size(); ++i)
|
||||
{
|
||||
float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
|
||||
if(max_diff < diff)
|
||||
@@ -236,16 +236,16 @@ template <typename DataType>
|
||||
void show_data_nhwc_layout(Tensor<DataType>& nhwc)
|
||||
{
|
||||
std::cout << "[";
|
||||
for(int n = 0; n < nhwc.mDesc.GetLengths()[0]; n++)
|
||||
for(int n = 0; n < ck::type_convert<int>(nhwc.mDesc.GetLengths()[0]); n++)
|
||||
{
|
||||
std::cout << "[";
|
||||
for(int hi = 0; hi < nhwc.mDesc.GetLengths()[2]; hi++)
|
||||
for(int hi = 0; hi < ck::type_convert<int>(nhwc.mDesc.GetLengths()[2]); hi++)
|
||||
{
|
||||
std::cout << "[";
|
||||
for(int wi = 0; wi < nhwc.mDesc.GetLengths()[3]; wi++)
|
||||
for(int wi = 0; wi < ck::type_convert<int>(nhwc.mDesc.GetLengths()[3]); wi++)
|
||||
{
|
||||
std::cout << "[";
|
||||
for(int c = 0; c < nhwc.mDesc.GetLengths()[1]; c++)
|
||||
for(int c = 0; c < ck::type_convert<int>(nhwc.mDesc.GetLengths()[1]); c++)
|
||||
{
|
||||
std::cout << static_cast<float>(nhwc(n, c, hi, wi)) << " ";
|
||||
}
|
||||
|
||||
@@ -50,12 +50,12 @@ void profile_grouped_gemm_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
int nrepeat,
|
||||
std::vector<int> Ms,
|
||||
std::vector<int> Ns,
|
||||
std::vector<int> Ks,
|
||||
std::vector<int> StrideAs,
|
||||
std::vector<int> StrideBs,
|
||||
std::vector<int> StrideCs)
|
||||
const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideCs)
|
||||
{
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
@@ -71,7 +71,7 @@ void profile_grouped_gemm_impl(int do_verification,
|
||||
}
|
||||
};
|
||||
|
||||
int group_count = Ms.size();
|
||||
std::size_t group_count = Ms.size();
|
||||
|
||||
if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() &&
|
||||
group_count == StrideBs.size() && group_count == StrideCs.size()))
|
||||
@@ -83,7 +83,7 @@ void profile_grouped_gemm_impl(int do_verification,
|
||||
std::vector<Tensor<BDataType>> b_k_n;
|
||||
std::vector<Tensor<CDataType>> c_m_n_device_results;
|
||||
|
||||
for(int i = 0; i < Ms.size(); i++)
|
||||
for(std::size_t i = 0; i < group_count; i++)
|
||||
{
|
||||
a_m_k.push_back(
|
||||
Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
|
||||
@@ -144,7 +144,7 @@ void profile_grouped_gemm_impl(int do_verification,
|
||||
|
||||
gemm_shapes.reserve(group_count);
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
for(std::size_t i = 0; i < group_count; i++)
|
||||
{
|
||||
a_device_buf.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace()));
|
||||
@@ -234,7 +234,7 @@ void profile_grouped_gemm_impl(int do_verification,
|
||||
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat);
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
for(int i = 0; i < gemm_shapes.size(); i++)
|
||||
for(std::size_t i = 0; i < gemm_shapes.size(); i++)
|
||||
{
|
||||
flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
|
||||
|
||||
@@ -258,7 +258,7 @@ void profile_grouped_gemm_impl(int do_verification,
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
for(int i = 0; i < gemm_shapes.size(); i++)
|
||||
for(std::size_t i = 0; i < gemm_shapes.size(); i++)
|
||||
{
|
||||
|
||||
c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
|
||||
|
||||
Reference in New Issue
Block a user