[GEMM] UniversalGemm update (#1262)

* Add bf16 instances

* Add bf16 gemm universal example

* tempsave

* Add guard to navi compilation

* workground on a specific mixed gemm instance ( bring back it when compiler fix upload)

* fix formatting condition statement issue

* solve conflict

---------

Co-authored-by: Jun Liu <Liu.Jun@amd.com>

[ROCm/composable_kernel commit: 764164b488]
This commit is contained in:
Haocong WANG
2024-04-27 01:56:07 +08:00
committed by GitHub
parent 2d642d2737
commit 1a34c500a6
37 changed files with 1235 additions and 174 deletions

View File

@@ -28,6 +28,8 @@ add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3)
add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3)
add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3)
add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)

View File

@@ -0,0 +1,48 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
using ADataType = ck::bhalf_t;
using BDataType = ck::bhalf_t;
using AccDataType = float;
using CShuffleDataType = ck::bhalf_t;
using CDataType = ck::bhalf_t;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmV2Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault,
256,
128, 128,
64, 8, 8,
16, 16,
4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
1, 2, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_gemm_example_v2.inc"
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }

View File

@@ -3,6 +3,88 @@
#pragma once
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 1e-1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 1.5e-1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 16.1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 8192.1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
@@ -180,7 +262,11 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
#else
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
pass &= ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
pass &= ck::utils::check_err(c_m_n_device_result,
c_m_n_host_result,
"Error: Incorrect results!",
get_rtol<CDataType>(),
get_atol<CDataType>());
#endif
}