mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-26 08:00:13 +00:00
[Bf16 & int8] [example & ckprofiler] (#100)
* Add int8 of mk_nk_mn to the ckProfiler
* Add example of int8 gemm
* Fix typo, use ushort instead of half_t for bfloat16
* replace ushortXXX_t to bhalfXXX_t
* rename ushort to bhalf_t
* Add bf16 example
* Add bf16 gemm to ckProfiler
* Fix alignment
* Fix typo
* Add unit test for gemm_xdl int8
* Add gemm_xdl fp32 unit test
* Add gemm_xdl bf16 unit test
* fix build
* fix build issue due to merge conflict
* Fix build
* Fix build error
Co-authored-by: rocking <chunylai@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: 7e9a9d32c7]
This commit is contained in:
@@ -174,9 +174,9 @@ void profile_conv_fwd_impl(int do_verification,
|
||||
ck::tensor_operation::device::device_conv2d_fwd_instance::
|
||||
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
|
||||
}
|
||||
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ushort> &&
|
||||
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ushort> &&
|
||||
ck::is_same_v<ck::remove_cv_t<OutDataType>, ushort>)
|
||||
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, bhalf_t> &&
|
||||
ck::is_same_v<ck::remove_cv_t<WeiDataType>, bhalf_t> &&
|
||||
ck::is_same_v<ck::remove_cv_t<OutDataType>, bhalf_t>)
|
||||
{
|
||||
ck::tensor_operation::device::device_conv2d_fwd_instance::
|
||||
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
|
||||
|
||||
@@ -26,11 +26,17 @@ void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNo
|
||||
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
|
||||
std::vector<DeviceGemmNoOpPtr>&);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(
|
||||
std::vector<DeviceGemmNoOpPtr>&);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<DeviceGemmNoOpPtr>&);
|
||||
|
||||
@@ -91,12 +97,11 @@ void profile_gemm_impl(int do_verification,
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
switch(init_method)
|
||||
@@ -122,19 +127,10 @@ void profile_gemm_impl(int do_verification,
|
||||
const auto b_element_op = BElementOp{};
|
||||
const auto c_element_op = CElementOp{};
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
|
||||
// if(do_verification)
|
||||
// {
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
}
|
||||
// }
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
|
||||
@@ -290,6 +286,29 @@ void profile_gemm_impl(int do_verification,
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<ADataType, ck::bhalf_t>::value &&
|
||||
is_same<BDataType, ck::bhalf_t>::value &&
|
||||
is_same<CDataType, ck::bhalf_t>::value)
|
||||
{
|
||||
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<ADataType, int8_t>::value && is_same<BDataType, int8_t>::value &&
|
||||
is_same<CDataType, int8_t>::value)
|
||||
{
|
||||
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemm_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
if(gemm_ptrs.size() <= 0)
|
||||
{
|
||||
@@ -351,14 +370,79 @@ void profile_gemm_impl(int do_verification,
|
||||
{
|
||||
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
check_error(c_m_n_host_result, c_m_n_device_result);
|
||||
if constexpr(is_same<ADataType, ck::bhalf_t>::value &&
|
||||
is_same<BDataType, ck::bhalf_t>::value &&
|
||||
is_same<CDataType, ck::bhalf_t>::value)
|
||||
{
|
||||
Tensor<float> a_f32_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<float> b_f32_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<float> c_m_n_host_result(
|
||||
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<float> c_m_n_device_f32_result(
|
||||
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
bf16_to_f32_(a_m_k, a_f32_m_k);
|
||||
bf16_to_f32_(b_k_n, b_f32_k_n);
|
||||
bf16_to_f32_(c_m_n_device_result, c_m_n_device_f32_result);
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<float, float, float, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_f32_m_k,
|
||||
b_f32_k_n,
|
||||
c_m_n_host_result,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
check_error(c_m_n_host_result, c_m_n_device_f32_result);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "c_host : ", c_m_n_host_result.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
Tensor<CDataType> c_m_n_host_result(
|
||||
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
check_error(c_m_n_host_result, c_m_n_device_result);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "c_host : ", c_m_n_host_result.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "c_host : ", c_m_n_host_result.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user