mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
Add instances for conv_scale with fp8 in/out (#1193)
* Add fp8 conv instances and client example
* Format
* Add example
* Update cmakelists
* Add profiler mode
* Format
* Fix copyright headers
[ROCm/composable_kernel commit: e626d5202a]
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -23,6 +23,7 @@ enum struct ConvDataType
|
||||
F16_F16_F16, // 1
|
||||
BF16_BF16_BF16, // 2
|
||||
INT8_INT8_INT8, // 3
|
||||
F8_F8_F8, // 4
|
||||
};
|
||||
|
||||
#define OP_NAME "grouped_conv_fwd"
|
||||
@@ -36,7 +37,8 @@ static void print_helper_msg()
|
||||
<< "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n"
|
||||
<< " 1: Input fp16, Weight fp16, Output fp16\n"
|
||||
<< " 2: Input bf16, Weight bf16, Output bf16\n"
|
||||
<< " 3: Input int8, Weight int8, Output int8)\n"
|
||||
<< " 3: Input int8, Weight int8, Output int8\n"
|
||||
<< " 4: Input fp8, Weight fp8, Output fp8)\n"
|
||||
<< "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
|
||||
<< " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n"
|
||||
<< "arg4: verification (0: no, 1: yes)\n"
|
||||
@@ -79,6 +81,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using INT8 = int8_t;
|
||||
using F8 = ck::f8_t;
|
||||
|
||||
//
|
||||
using GNWC = ck::tensor_layout::convolution::GNWC;
|
||||
@@ -250,6 +253,10 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
{
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F8_F8_F8)
|
||||
{
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{});
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
Reference in New Issue
Block a user