Mirror of https://github.com/ROCm/composable_kernel.git
Simulate TF32 with BF16x3 (#3142)
* tf32/bf16x3: use bf16x3 to emulate tf32 gemm
* change blockwiseGemm to demo bf16x3
* temp push
* self review
* self review
* fix multi-device compile error
* bug fix
* code refactor
* limit to gfx950
* tighten gemm gfx942 threshold
* lower change from blockwise to warpwise
* refactor code
* refactor code
* error fix
* change threshold
* bug fix
* fix threshold error
* change host reference implementation to match the device
* bug fix
* bug fix
* code refactor
* fix clang-format failure
* code refine
[ROCm/composable_kernel commit: 2a73eb3bc0]
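The core idea of the change, sketched below as standalone host C++ (not CK code; helper names are hypothetical, and to_bf16 emulates bf16 rounding by truncation, a simplification of the hardware's round-to-nearest-even): each fp32 operand is split into a "big" bf16 part and a "small" bf16 remainder, and the tf32-precision product is approximated by three bf16 products accumulated in fp32.

// Minimal sketch of the bf16x3 decomposition this commit implements with
// gfx950 bf16 MFMA instructions. Only the math mirrors the commit.
#include <cstdint>
#include <cstdio>
#include <cstring>

// Truncate an fp32 value to bf16 precision (keep sign, exponent, and the
// top 7 mantissa bits). Hardware conversion rounds rather than truncates.
static float to_bf16(float x)
{
    std::uint32_t u;
    std::memcpy(&u, &x, sizeof(u));
    u &= 0xFFFF0000u;
    std::memcpy(&x, &u, sizeof(x));
    return x;
}

// Three partial products share one fp32 accumulator, matching the three
// MFMA calls in intrin_mfma_f32_16x16x32xf32 / intrin_mfma_f32_32x32x16xf32
// below. The small*small term is deliberately dropped.
static float bf16x3_mul(float a, float b)
{
    const float a_big = to_bf16(a), a_small = to_bf16(a - a_big);
    const float b_big = to_bf16(b), b_small = to_bf16(b - b_big);
    return a_small * b_big + a_big * b_small + a_big * b_big;
}

int main()
{
    const float a = 1.2345678f, b = 3.1415927f;
    std::printf("fp32 product : %.9f\n", a * b);
    std::printf("bf16 product : %.9f\n", to_bf16(a) * to_bf16(b));
    std::printf("bf16x3       : %.9f\n", bf16x3_mul(a, b)); // close to fp32
}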
@@ -105,7 +105,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
     endif()
 endforeach()
 
-list(APPEND gpu_list_tf32 gfx942)
+list(APPEND gpu_list_tf32 gfx942 gfx950)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0)
@@ -21,7 +21,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
     endif()
 endforeach()
 
-list(APPEND gpu_list_tf32 gfx942)
+list(APPEND gpu_list_tf32 gfx942 gfx950)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0)
@@ -77,7 +77,7 @@ inline __host__ __device__ constexpr double get_atol()
 {
     if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::tf32_t>)
     {
-        return 1e-2;
+        return 1e-3;
     }
     else if constexpr(std::is_same_v<DataType, float>)
     {
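A plausible reading of this tightening: tf32 keeps 10 explicit mantissa bits, so a single tf32 product carries a relative error on the order of 2^-11 ≈ 5e-4, while the bf16x3 scheme recovers roughly 15-16 effective mantissa bits (two chained bf16 parts of 8 significand bits each), i.e. errors near 2^-16 ≈ 1.5e-5 per product. With gfx950's emulation in the mix, the test can therefore afford atol = 1e-3 instead of 1e-2.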
@@ -33,3 +33,13 @@ if(USE_BITINT_EXTENSION_INT4)
     add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
     add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
 endif()
+
+list(APPEND gpu_list_tf32 gfx942 gfx950)
+set(target 0)
+foreach(gpu IN LISTS GPU_TARGETS)
+    if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0)
+        add_example_executable(example_grouped_gemm_xdl_fp32_tf32 grouped_gemm_xdl_fp32_tf32.cpp)
+        add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32_tf32)
+        set(target 1)
+    endif()
+endforeach()
example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp (new file, 66 lines)
@@ -0,0 +1,66 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <iostream>
+#include <numeric>
+#include <initializer_list>
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+
+#define EXAMPLE_WITH_COMPUTE_DATATYPE
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+using ADataType        = F32;
+using BDataType        = F32;
+using AccDataType      = F32;
+using CShuffleDataType = F32;
+using DsDataType       = ck::Tuple<>;
+using EDataType        = F32;
+using ComputeDataType  = ck::tf32_t;
+
+using ALayout  = Row;
+using BLayout  = Col;
+using DsLayout = ck::Tuple<>;
+using ELayout  = Row;
+
+using AElementOp   = PassThrough;
+using BElementOp   = PassThrough;
+using CDEElementOp = PassThrough;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl
+// clang-format off
+//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, ck::LoopScheduler::Default, ComputeDataType>;
+// clang-format on
+
+#include "run_grouped_gemm_example.inc"
+
+int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
+
+#undef EXAMPLE_WITH_COMPUTE_DATATYPE
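Per the CMake hunk above, this example is registered as example_grouped_gemm_xdl_fp32_tf32 and is only added to the build when GPU_TARGETS contains gfx942 or gfx950. The EXAMPLE_WITH_COMPUTE_DATATYPE macro tells run_grouped_gemm_example.inc (next hunk) not to default ComputeDataType to AccDataType, so this file's ck::tf32_t definition takes effect.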
@@ -3,6 +3,11 @@
 
 #pragma once
 
+// use macro to minimize code change
+#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE
+using ComputeDataType = AccDataType;
+#endif
+
 struct ProblemSize final
 {
     std::vector<ck::index_t> Ms;
@@ -231,7 +236,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
                                                                 AccDataType,
                                                                 AElementOp,
                                                                 BElementOp,
-                                                                CDEElementOp>;
+                                                                CDEElementOp,
+                                                                ComputeDataType>;
 
     for(std::size_t i = 0; i < gemm_descs.size(); i++)
     {
@@ -253,7 +259,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
             pass &= ck::utils::check_err(c_device_result_converted, c_host_tensors[i]);
 
 #else
-            pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]);
+            pass &= ck::utils::check_err<decltype(c_device_tensors[i]),
+                                         decltype(c_host_tensors[i]),
+                                         ComputeDataType>(c_device_tensors[i], c_host_tensors[i]);
 #endif
         }
     }
@@ -129,7 +129,10 @@ inline bool is_wmma_supported()
     return is_gfx103_supported() || is_gfx11_supported() || is_gfx12_supported();
 }
 
-inline bool is_tf32_supported() { return (ck::get_device_name() == "gfx942") ? true : false; }
+inline bool is_tf32_supported()
+{
+    return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
+}
 
 } // namespace ck
 #endif
@@ -168,8 +168,8 @@ typename std::enable_if<
 check_err(const Range& out,
           const RefRange& ref,
           const std::string& msg = "Error: Incorrect results!",
-          double rtol            = 1e-5,
-          double atol            = 3e-5)
+          double rtol            = 5e-4,
+          double atol            = 5e-4)
 {
     if(out.size() != ref.size())
     {
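For context, check_err-style comparisons conventionally combine both tolerances per element. A minimal sketch of the usual criterion, assuming this is what the function applies (its body is not shown in the hunk):

// Hypothetical stand-in for ck::utils::check_err's per-element test:
// an element passes if |out - ref| <= atol + rtol * |ref|.
#include <cmath>
#include <cstddef>
#include <vector>

bool check_close(const std::vector<float>& out,
                 const std::vector<float>& ref,
                 double rtol = 5e-4, // new defaults from this hunk
                 double atol = 5e-4)
{
    if(out.size() != ref.size())
        return false;
    for(std::size_t i = 0; i < out.size(); ++i)
        if(std::fabs(out[i] - ref[i]) > atol + rtol * std::fabs(ref[i]))
            return false;
    return true;
}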
@@ -94,7 +94,8 @@ template <typename ALayout,
           typename EDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CElementwiseOperation>
+          typename CElementwiseOperation,
+          typename ComputeDataType = ADataType>
 struct DeviceGroupedGemm : public BaseOperator
 {
     static constexpr index_t NumDTensor = DsDataType::Size();
@@ -134,7 +134,8 @@ template <typename ALayout,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
-          LoopScheduler LoopSched = make_default_loop_scheduler()>
+          LoopScheduler LoopSched = make_default_loop_scheduler(),
+          typename ComputeDataType = ADataType>
 struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
                                                         BLayout,
                                                         DsLayout,
@@ -145,7 +146,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
                                                         EDataType,
                                                         AElementwiseOperation,
                                                         BElementwiseOperation,
-                                                        CDEElementwiseOperation>
+                                                        CDEElementwiseOperation,
+                                                        ComputeDataType>
 {
     using DeviceOp = DeviceGroupedGemm_Xdl;
     GET_NXDL_PER_WAVE_IMPL
@@ -233,8 +235,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
     using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
     using EGridDesc_M_N  = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
 
-    using ComputeDataType = ADataType;
-
     // GridwiseGemm
     template <index_t NXdlPerWave_>
     using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
@@ -80,8 +80,10 @@ enum struct MfmaInstr
     mfma_f32_16x16x128f8f6f4,
     mfma_scale_f32_32x32x64f8f6f4,
     mfma_scale_f32_16x16x128f8f6f4,
-    mfma_f32_16x16x8xf32, // tf32
-    mfma_f32_32x32x4xf32,
+    mfma_f32_16x16x8xf32,  // tf32 on gfx942
+    mfma_f32_32x32x4xf32,  // tf32 on gfx942
+    mfma_f32_16x16x32xf32, // bf16x3 simulate tf32 on gfx950
+    mfma_f32_32x32x16xf32, // bf16x3 simulate tf32 on gfx950
     // gfx11
     wmma_f32_16x16x16_f16,
     wmma_f32_16x16x16_bf16,
@@ -1015,6 +1017,51 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x4xf32>
     }
 };
 
+template <>
+struct mfma_type<MfmaInstr::mfma_f32_32x32x16xf32>
+{
+    // gfx950 specific: use bf16x3 simulate tf32
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 4;
+    static constexpr index_t num_regs_per_blk    = 16;
+    static constexpr index_t num_threads_per_blk = 32;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 2;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 32;
+    static constexpr index_t n_per_blk           = 32;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_32x32x16xf32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+template <>
+struct mfma_type<MfmaInstr::mfma_f32_16x16x32xf32>
+{
+    // gfx950 specific: use bf16x3 simulate tf32
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 1;
+    static constexpr index_t num_regs_per_blk    = 4;
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 4;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 16;
+    static constexpr index_t n_per_blk           = 16;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_16x16x32xf32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
 // gfx11
 struct mfma_type_gfx11_base
 {
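The new mfma_type constants are internally consistent with the instruction names: for mfma_f32_32x32x16xf32, wave_size / num_threads_per_blk = 64 / 32 = 2 = num_input_blks, and with is_k_reduction the per-instruction K is k_per_blk * num_input_blks = 8 * 2 = 16, matching 32x32x16; likewise the 16x16 variant gives 8 * 4 = 32. Output registers also line up: group_size * num_groups_per_blk = num_regs_per_blk (4 * 4 = 16 and 4 * 1 = 4).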
@@ -1275,12 +1322,14 @@ struct MfmaSelector
     }
 
     template <>
-    constexpr auto GetMfma<tf32_t, 32, 32>()
+    constexpr auto GetMfma<tf32_t, 32, 32, tf32_t>()
     {
 #if defined(__gfx12__)
         return MfmaInstr::wmma_unsupport_16x16_gfx12;
 #elif defined(__gfx11__)
         return MfmaInstr::wmma_unsupport_16x16_gfx11;
+#elif defined(__gfx950__)
+        return MfmaInstr::mfma_f32_32x32x16xf32;
 #elif defined(__gfx942__)
         return MfmaInstr::mfma_f32_32x32x4xf32;
 #else
@@ -1289,12 +1338,14 @@ struct MfmaSelector
     }
 
     template <>
-    constexpr auto GetMfma<tf32_t, 16, 16>()
+    constexpr auto GetMfma<tf32_t, 16, 16, tf32_t>()
     {
 #if defined(__gfx12__)
         return MfmaInstr::wmma_unsupport_16x16_gfx12;
 #elif defined(__gfx11__)
         return MfmaInstr::wmma_unsupport_16x16_gfx11;
+#elif defined(__gfx950__)
+        return MfmaInstr::mfma_f32_16x16x32xf32;
 #elif defined(__gfx942__)
         return MfmaInstr::mfma_f32_16x16x8xf32;
 #else
@@ -2185,6 +2236,10 @@ struct XdlopsGemm
             (is_same<base_type, int8_t>::value && KPack <= 8) ||
             ((is_same<base_type, f8_t>::value || is_same<base_type, bf8_t>::value) && KPack < 32) ||
             is_same<additional_type, pk_i4_t>::value)
+#if defined(__gfx950__)
+            // tf32 on gfx950 is implemented as bf16x3, so it should be treated as bf16.
+            || (is_same<base_type, tf32_t>::value && KPack <= 4)
+#endif
                 ? true
                 : false;
     static constexpr auto mfma = MfmaSelector<base_type,
@@ -10,6 +10,25 @@ namespace ck {
 #define __gfx94__
 #endif
 
+// Helper function to convert float vector to bf16 vectors (big and small parts)
+// This is used by both tf32 and xf32 implementations
+template <index_t VecSize>
+__device__ __forceinline__ void
+convert_float_to_bf16_pairs(const vector_type<float, VecSize>& reg_f32,
+                            vector_type<bhalf_t, VecSize>& reg_bf16_big,
+                            vector_type<bhalf_t, VecSize>& reg_bf16_small)
+{
+    static_for<0, VecSize, 1>{}([&](auto k) {
+        using IK = Number<k>;
+        reg_bf16_big.template AsType<bhalf_t>()(k) =
+            type_convert<bhalf_t, float>(reg_f32.template AsType<float>()[IK{}]);
+        reg_bf16_small.template AsType<bhalf_t>()(k) = type_convert<bhalf_t, float>(
+            reg_f32.template AsType<float>()[IK{}] -
+            type_convert<float, bhalf_t>(reg_bf16_big.template AsType<bhalf_t>()[IK{}]));
+    });
+}
+
 // fp32
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x1f32;
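A quick host-side sanity check of this split (standalone sketch, not CK code; truncation again stands in for the rounding type_convert performs): big + small should reconstruct the fp32 input to roughly the product of two bf16 rounding errors.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

static float to_bf16(float x) // truncation stand-in for type_convert<bhalf_t>
{
    std::uint32_t u;
    std::memcpy(&u, &x, sizeof(u));
    u &= 0xFFFF0000u;
    std::memcpy(&x, &u, sizeof(x));
    return x;
}

int main()
{
    double max_rel = 0.0;
    for(int i = 1; i <= 1000000; ++i)
    {
        const float x     = 1.0f + static_cast<float>(i) * 1e-6f; // scan (1, 2]
        const float big   = to_bf16(x);
        const float small = to_bf16(x - big); // x - big is exact in fp32
        max_rel           = std::fmax(
            max_rel, std::fabs(static_cast<double>(big) + small - x) / x);
    }
    // expect on the order of 2^-14 ~ 6e-5 with truncation; hardware
    // round-to-nearest does somewhat better at each step
    std::printf("max relative reconstruction error: %.3e\n", max_rel);
}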
@@ -1636,7 +1655,7 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
     }
 };
 
-/******************* tf32 *************************************/
+/******************* tf32 on gfx942 *************************************/
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_16x16x8xf32;
 
@@ -1646,7 +1665,7 @@ struct intrin_mfma_f32_16x16x8xf32<16, 16>
     template <class FloatC>
     __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx94__)
+#if defined(__gfx942__)
         reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32(
             reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
 #else
@@ -1666,7 +1685,7 @@ struct intrin_mfma_f32_32x32x4xf32<32, 32>
     template <class FloatC>
     __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx94__)
+#if defined(__gfx942__)
         reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4_xf32(
             reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
 #else
@@ -1677,4 +1696,102 @@ struct intrin_mfma_f32_32x32x4xf32<32, 32>
     }
 };
 
+/******************* tf32/xf32 on gfx950 ********************************/
+/* bf16x3 simulate tf32/xf32: input/output/accumulator are all float;  */
+/* step:                                                                */
+/*   1. separate one input to 2 bf16 registers:                         */
+/*        in_bf16_big   = f32_to_bf16(in_f32)                           */
+/*        in_bf16_small = in_f32 - in_bf16_big                          */
+/*   2. run 3 xdlops gemm: the accumulator of each gemm is the same.    */
+/*        out_f32  = A_bf16_big * B_bf16_big                            */
+/*        out_f32 += A_bf16_small * B_bf16_big                          */
+/*        out_f32 += A_bf16_big * B_bf16_small                          */
+/************************************************************************/
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_16x16x32xf32;
+
+template <>
+struct intrin_mfma_f32_16x16x32xf32<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const float8_t& reg_a, const float8_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        using I0 = Number<0>;
+        vector_type<float, 8> reg_a_v(reg_a);
+        vector_type<float, 8> reg_b_v(reg_b);
+
+        vector_type<bhalf_t, 8> v_reg_a_bf16_big;
+        vector_type<bhalf_t, 8> v_reg_a_bf16_small;
+        vector_type<bhalf_t, 8> v_reg_b_bf16_big;
+        vector_type<bhalf_t, 8> v_reg_b_bf16_small;
+
+        convert_float_to_bf16_pairs(reg_a_v, v_reg_a_bf16_big, v_reg_a_bf16_small);
+        convert_float_to_bf16_pairs(reg_b_v, v_reg_b_bf16_big, v_reg_b_bf16_small);
+
+        // Run 3 times: big*big, small*big, big*small
+        intrin_mfma_f32_16x16x32bf16<16, 16>::Run(
+            v_reg_a_bf16_small.template AsType<bhalf8_t>()[I0{}],
+            v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
+            reg_c);
+        intrin_mfma_f32_16x16x32bf16<16, 16>::Run(
+            v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
+            v_reg_b_bf16_small.template AsType<bhalf8_t>()[I0{}],
+            reg_c);
+        intrin_mfma_f32_16x16x32bf16<16, 16>::Run(
+            v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
+            v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
+            reg_c);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif // defined(__gfx950__)
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_32x32x16xf32;
+
+template <>
+struct intrin_mfma_f32_32x32x16xf32<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const float8_t& reg_a, const float8_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        using I0 = Number<0>;
+        vector_type<float, 8> reg_a_v(reg_a);
+        vector_type<float, 8> reg_b_v(reg_b);
+
+        vector_type<bhalf_t, 8> v_reg_a_bf16_big;
+        vector_type<bhalf_t, 8> v_reg_a_bf16_small;
+        vector_type<bhalf_t, 8> v_reg_b_bf16_big;
+        vector_type<bhalf_t, 8> v_reg_b_bf16_small;
+
+        convert_float_to_bf16_pairs(reg_a_v, v_reg_a_bf16_big, v_reg_a_bf16_small);
+        convert_float_to_bf16_pairs(reg_b_v, v_reg_b_bf16_big, v_reg_b_bf16_small);
+
+        // Run 3 times: big*big, small*big, big*small
+        intrin_mfma_f32_32x32x16bf16<32, 32>::Run(
+            v_reg_a_bf16_small.template AsType<bhalf8_t>()[I0{}],
+            v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
+            reg_c);
+        intrin_mfma_f32_32x32x16bf16<32, 32>::Run(
+            v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
+            v_reg_b_bf16_small.template AsType<bhalf8_t>()[I0{}],
+            reg_c);
+        intrin_mfma_f32_32x32x16bf16<32, 32>::Run(
+            v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
+            v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
+            reg_c);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif // defined(__gfx950__)
+    }
+};
+
+/******************* tf32/xf32 on gfx950 end ************************************/
 } // namespace ck
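Note the deliberately missing fourth product: with |a_small| ≲ 2^-8 |a| and |b_small| ≲ 2^-8 |b|, the omitted a_small * b_small term is bounded by roughly 2^-16 |ab|, at or below the scheme's overall rounding noise, so dropping it saves one of four potential MFMA issues per K-step at essentially no accuracy cost. (This is the standard argument for three-term split-product schemes; the commit itself only states the three products.)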
@@ -14,6 +14,8 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
 
+#include "ck/host_utility/device_prop.hpp"
+
 #include "ck/library/utility/algorithm.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/fill.hpp"
@@ -92,7 +94,8 @@ struct ReferenceConvFwd : public device::BaseOperator
               in_right_pads_{input_right_pads},
               in_element_op_{in_element_op},
               wei_element_op_{wei_element_op},
-              out_element_op_{out_element_op}
+              out_element_op_{out_element_op},
+              device_name_{ck::get_device_name()}
         {
         }
 
@@ -112,6 +115,7 @@ struct ReferenceConvFwd : public device::BaseOperator
         InElementwiseOperation in_element_op_;
         WeiElementwiseOperation wei_element_op_;
         OutElementwiseOperation out_element_op_;
+        ::std::string device_name_; // the device which this conv is compared with
     };
 
     struct Invoker : public device::BaseInvoker
@@ -251,10 +255,39 @@ struct ReferenceConvFwd : public device::BaseOperator
                                 x);
                             if constexpr(is_same_v<ComputeDataType, ck::tf32_t>)
                             {
-                                v_acc += ck::type_convert<float>(
-                                             ck::type_convert<ComputeDataType>(v_in)) *
-                                         ck::type_convert<float>(
-                                             ck::type_convert<ComputeDataType>(v_wei));
+                                if(arg.device_name_ == "gfx942")
+                                {
+                                    v_acc += ck::type_convert<float>(
+                                                 ck::type_convert<ck::tf32_t>(v_in)) *
+                                             ck::type_convert<float>(
+                                                 ck::type_convert<ck::tf32_t>(v_wei));
+                                }
+                                else if(arg.device_name_ == "gfx950")
+                                {
+                                    ck::bhalf_t v_in_bf16_big =
+                                        ck::type_convert<ck::bhalf_t>(v_in);
+                                    ck::bhalf_t v_in_bf16_small =
+                                        ck::type_convert<ck::bhalf_t>(
+                                            v_in - type_convert<float>(v_in_bf16_big));
+                                    ck::bhalf_t v_wei_bf16_big =
+                                        ck::type_convert<ck::bhalf_t>(v_wei);
+                                    ck::bhalf_t v_wei_bf16_small =
+                                        ck::type_convert<ck::bhalf_t>(
+                                            v_wei - type_convert<float>(v_wei_bf16_big));
+
+                                    v_acc += ck::type_convert<float>(v_in_bf16_big) *
+                                                 ck::type_convert<float>(v_wei_bf16_small) +
+                                             ck::type_convert<float>(v_in_bf16_small) *
+                                                 ck::type_convert<float>(v_wei_bf16_big) +
+                                             ck::type_convert<float>(v_in_bf16_big) *
+                                                 ck::type_convert<float>(v_wei_bf16_big);
+                                }
+                                else
+                                {
+                                    throw std::runtime_error(
+                                        "Unsupported device: " + arg.device_name_ +
+                                        " for tf32 computation");
+                                }
                             }
                             else
                             {
@@ -350,10 +383,41 @@ struct ReferenceConvFwd : public device::BaseOperator
                                 x);
                             if constexpr(is_same_v<ComputeDataType, ck::tf32_t>)
                             {
-                                v_acc += ck::type_convert<float>(
-                                             ck::type_convert<ComputeDataType>(v_in)) *
-                                         ck::type_convert<float>(
-                                             ck::type_convert<ComputeDataType>(v_wei));
+                                if(arg.device_name_ == "gfx942")
+                                {
+                                    v_acc += ck::type_convert<float>(
+                                                 ck::type_convert<ck::tf32_t>(v_in)) *
+                                             ck::type_convert<float>(
+                                                 ck::type_convert<ck::tf32_t>(v_wei));
+                                }
+                                else if(arg.device_name_ == "gfx950")
+                                {
+                                    ck::bhalf_t v_in_bf16_big =
+                                        ck::type_convert<ck::bhalf_t>(v_in);
+                                    ck::bhalf_t v_in_bf16_small =
+                                        ck::type_convert<ck::bhalf_t>(
+                                            v_in - type_convert<float>(v_in_bf16_big));
+                                    ck::bhalf_t v_wei_bf16_big =
+                                        ck::type_convert<ck::bhalf_t>(v_wei);
+                                    ck::bhalf_t v_wei_bf16_small =
+                                        ck::type_convert<ck::bhalf_t>(
+                                            v_wei -
+                                            type_convert<float>(v_wei_bf16_big));
+
+                                    v_acc +=
+                                        ck::type_convert<float>(v_in_bf16_big) *
+                                            ck::type_convert<float>(v_wei_bf16_small) +
+                                        ck::type_convert<float>(v_in_bf16_small) *
+                                            ck::type_convert<float>(v_wei_bf16_big) +
+                                        ck::type_convert<float>(v_in_bf16_big) *
+                                            ck::type_convert<float>(v_wei_bf16_big);
+                                }
+                                else
+                                {
+                                    throw std::runtime_error(
+                                        "Unsupported device: " + arg.device_name_ +
+                                        " for tf32 computation");
+                                }
                             }
                             else
                             {
@@ -6,6 +6,7 @@
 #include <iostream>
 #include <sstream>
 
+#include "ck/host_utility/device_prop.hpp"
 #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
 #include "ck/library/utility/host_tensor.hpp"
@@ -45,7 +46,8 @@ struct ReferenceGemm : public device::BaseOperator
               c_m_n_{c_m_n},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
-              c_element_op_{c_element_op}
+              c_element_op_{c_element_op},
+              device_name_{ck::get_device_name()}
         {
         }
 
@@ -56,6 +58,7 @@ struct ReferenceGemm : public device::BaseOperator
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
         CElementwiseOperation c_element_op_;
+        ::std::string device_name_; // the device which this gemm is compared with
     };
 
     // Invoker
@@ -142,12 +145,37 @@ struct ReferenceGemm : public device::BaseOperator
                     arg.b_element_op_(v_b, arg.b_k_n_(k, n));
                 }
 
-                if constexpr(is_same_v<ComputeTypeA, ComputeTypeB> &&
-                             is_same_v<ComputeTypeA, ck::tf32_t>)
-                { // only for tf32 now
-                    v_acc +=
-                        ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeA>(v_a)) *
-                        ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeB>(v_b));
+                if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
+                             is_same_v<CDataType, float> && is_same_v<AccDataType, float> &&
+                             is_same_v<ComputeTypeA, ck::tf32_t> &&
+                             is_same_v<ComputeTypeB, ck::tf32_t>)
+                {
+                    if(arg.device_name_ == "gfx942")
+                    {
+                        v_acc +=
+                            ck::type_convert<AccDataType>(ck::type_convert<ck::tf32_t>(v_a)) *
+                            ck::type_convert<AccDataType>(ck::type_convert<ck::tf32_t>(v_b));
+                    }
+                    else if(arg.device_name_ == "gfx950")
+                    {
+                        ck::bhalf_t v_a_bf16_big   = ck::type_convert<ck::bhalf_t>(v_a);
+                        ck::bhalf_t v_a_bf16_small = ck::type_convert<ck::bhalf_t>(
+                            v_a - type_convert<float>(v_a_bf16_big));
+                        ck::bhalf_t v_b_bf16_big   = ck::type_convert<ck::bhalf_t>(v_b);
+                        ck::bhalf_t v_b_bf16_small = ck::type_convert<ck::bhalf_t>(
+                            v_b - type_convert<float>(v_b_bf16_big));
+
+                        v_acc += ck::type_convert<AccDataType>(v_a_bf16_big) *
+                                     ck::type_convert<AccDataType>(v_b_bf16_small) +
+                                 ck::type_convert<AccDataType>(v_a_bf16_small) *
+                                     ck::type_convert<AccDataType>(v_b_bf16_big) +
+                                 ck::type_convert<AccDataType>(v_a_bf16_big) *
+                                     ck::type_convert<AccDataType>(v_b_bf16_big);
+                    }
+                    else
+                    {
+                        throw std::runtime_error("Unsupported device: " + arg.device_name_);
+                    }
                 }
                 else
                 {
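This hunk realizes the commit bullet "change host reference implementation to match the device": the host reference now reproduces, per device, the exact arithmetic the kernel performs (tf32 conversion on gfx942, the three-product bf16 decomposition on gfx950), so verification compares like with like instead of relying on looser tolerances.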
@@ -82,9 +82,27 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
             // multiply and accumulate
             if constexpr(is_same_v<ComputeTypeA, ComputeTypeB> &&
                          is_same_v<ComputeTypeA, ck::tf32_t>)
-            { // only for tf32 now
-                v_acc += ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeA>(v_a)) *
-                         ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeB>(v_b));
+            {
+#if defined(__gfx942__)
+                v_acc += ck::type_convert<AccDataType>(ck::type_convert<ck::tf32_t>(v_a)) *
+                         ck::type_convert<AccDataType>(ck::type_convert<ck::tf32_t>(v_b));
+#elif defined(__gfx950__)
+                ck::bhalf_t v_a_bf16_big = ck::type_convert<ck::bhalf_t>(v_a);
+                ck::bhalf_t v_a_bf16_small =
+                    ck::type_convert<ck::bhalf_t>(v_a - type_convert<float>(v_a_bf16_big));
+                ck::bhalf_t v_b_bf16_big = ck::type_convert<ck::bhalf_t>(v_b);
+                ck::bhalf_t v_b_bf16_small =
+                    ck::type_convert<ck::bhalf_t>(v_b - type_convert<float>(v_b_bf16_big));
+
+                v_acc += ck::type_convert<AccDataType>(v_a_bf16_big) *
+                             ck::type_convert<AccDataType>(v_b_bf16_small) +
+                         ck::type_convert<AccDataType>(v_a_bf16_small) *
+                             ck::type_convert<AccDataType>(v_b_bf16_big) +
+                         ck::type_convert<AccDataType>(v_a_bf16_big) *
+                             ck::type_convert<AccDataType>(v_b_bf16_big);
+#else
+                v_acc += type_convert<AccDataType>(v_a) * type_convert<AccDataType>(v_b);
+#endif
             }
             else
             {
@@ -105,7 +105,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
     using INT8 = int8_t;
     using F8   = ck::f8_t;
     using BF8  = ck::bf8_t;
-#if defined(__gfx942__)
+#if defined(__gfx942__) || defined(__gfx950__)
     using TF32 = ck::tf32_t;
 #endif
@@ -228,7 +228,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
     }
     else if(data_type == ConvDataType::F32_F32_F32_TF32)
    {
-#if defined(__gfx942__)
+#if defined(__gfx942__) || defined(__gfx950__)
         return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
 #endif
     }
@@ -253,7 +253,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
     }
     else if(data_type == ConvDataType::F32_F32_F32_TF32)
     {
-#if defined(__gfx942__)
+#if defined(__gfx942__) || defined(__gfx950__)
         return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
 #endif
     }
@@ -280,7 +280,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
     }
     else if(data_type == ConvDataType::F32_F32_F32_TF32)
     {
-#if defined(__gfx942__)
+#if defined(__gfx942__) || defined(__gfx950__)
         return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
 #endif
     }
@@ -306,7 +306,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
     }
     else if(data_type == ConvDataType::F32_F32_F32_TF32)
     {
-#if defined(__gfx942__)
+#if defined(__gfx942__) || defined(__gfx950__)
         return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
 #endif
     }
@@ -331,7 +331,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
     }
     else if(data_type == ConvDataType::F32_F32_F32_TF32)
     {
-#if defined(__gfx942__)
+#if defined(__gfx942__) || defined(__gfx950__)
         return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
 #endif
     }
@@ -352,7 +352,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
     }
     else if(data_type == ConvDataType::F32_F32_F32_TF32)
     {
-#if defined(__gfx942__)
+#if defined(__gfx942__) || defined(__gfx950__)
         return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
 #endif
     }
@@ -373,7 +373,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
     }
     else if(data_type == ConvDataType::F32_F32_F32_TF32)
     {
-#if defined(__gfx942__)
+#if defined(__gfx942__) || defined(__gfx950__)
         return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
 #endif
     }
@@ -416,7 +416,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
     }
     else if(data_type == ConvDataType::F32_F32_F32_TF32)
     {
-#if defined(__gfx942__)
+#if defined(__gfx942__) || defined(__gfx950__)
         return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
 #endif
     }
@@ -439,7 +439,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
     }
     else if(data_type == ConvDataType::F32_F32_F32_TF32)
     {
-#if defined(__gfx942__)
+#if defined(__gfx942__) || defined(__gfx950__)
         return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
 #endif
     }