mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
[Navi3x-LWPCK-449] wmma_op + unit test (#484)
* wmma_op + unit test
* add arch limitation to wmma test
* change arch limitation
* Refactor + add unit tests for all types (int4 failed to compile)
* Add f32_16x16x16_bf16 unit test
* Remove int4 related
* delete deprecated test
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: abf9cc6c5c]
This commit is contained in:
@@ -25,7 +25,7 @@
|
||||
// check GPU target
|
||||
#ifdef __HIP_DEVICE_COMPILE__
|
||||
#if !(defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx1030__))
|
||||
defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__))
|
||||
#error Not supported target
|
||||
#endif
|
||||
#endif
|
||||
@@ -38,6 +38,8 @@
|
||||
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
|
||||
#elif defined(__gfx1030__) // for GPU code
|
||||
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
|
||||
#elif defined(__gfx1100__) // for GPU code
|
||||
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x10020000
|
||||
#endif
|
||||
|
||||
// FMA instruction
|
||||
@@ -62,6 +64,13 @@
|
||||
#define CK_USE_AMD_MFMA_BF16_1K_OP
|
||||
#endif
|
||||
|
||||
// WMMA instruction
|
||||
// Enable the WMMA code path for host compilation and, in device compilation,
// only when the target is gfx1100.
#if !defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__)
#define CK_USE_AMD_WMMA
#endif
|
||||
|
||||
// buffer load
|
||||
#define CK_USE_AMD_BUFFER_LOAD 1
|
||||
|
||||
|
||||
102
include/ck/utility/amd_wmma.hpp
Normal file
102
include/ck/utility/amd_wmma.hpp
Normal file
@@ -0,0 +1,102 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_AMD_WMMA_HPP
|
||||
#define CK_AMD_WMMA_HPP
|
||||
|
||||
#include "data_type.hpp"
|
||||
// TODO: Add arch limitation
|
||||
namespace ck {
|
||||
|
||||
// wave32 only
|
||||
// src: fp16, dst: fp32
|
||||
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32;

// 16x16 specialization: thin wrapper around the wave32 builtin
// __builtin_amdgcn_wmma_f32_16x16x16_f16_w32 (src: fp16, dst: fp32).
template <>
struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
{
    // reg_a / reg_b : half16_t operands, forwarded to the builtin unchanged.
    // reg_c         : accumulator; element 0 of its float8_t view is read as the
    //                 builtin's C input and then overwritten with the result.
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        constexpr auto slot = Number<0>{};

        // Read the incoming accumulator before the destination is written.
        const float8_t acc_in = reg_c.template AsType<float8_t>()[slot];

        reg_c.template AsType<float8_t>()(slot) =
            __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(reg_a, reg_b, acc_in);
    }
};
|
||||
|
||||
// src: bf16, dst: fp32
|
||||
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w32;

// 16x16 specialization: thin wrapper around the wave32 builtin
// __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32 (src: bf16, dst: fp32).
template <>
struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
{
    // reg_a / reg_b : bhalf16_t operands, forwarded to the builtin unchanged.
    // reg_c         : accumulator; element 0 of its float8_t view is read as the
    //                 builtin's C input and then overwritten with the result.
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
        constexpr auto slot = Number<0>{};

        // Read the incoming accumulator before the destination is written.
        const float8_t acc_in = reg_c.template AsType<float8_t>()[slot];

        reg_c.template AsType<float8_t>()(slot) =
            __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(reg_a, reg_b, acc_in);
    }
};
|
||||
|
||||
// src: fp16, dst: fp16
|
||||
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
|
||||
struct intrin_wmma_f16_16x16x16_f16_w32;
|
||||
|
||||
template <index_t Opsel>
|
||||
struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
// opsel usage
|
||||
// false: D0.[0:15] = result
|
||||
// true : D0.[16:31]= result
|
||||
reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
|
||||
reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
|
||||
}
|
||||
};
|
||||
|
||||
// src: bf16, dst: bf16
|
||||
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
|
||||
struct intrin_wmma_bf16_16x16x16_bf16_w32;
|
||||
|
||||
template <index_t Opsel>
|
||||
struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
// opsel usage
|
||||
// false: D0.[0:15] = result
|
||||
// true : D0.[16:31]= result
|
||||
reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
|
||||
reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
|
||||
}
|
||||
};
|
||||
|
||||
// src: iu8, dst: i32
|
||||
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
|
||||
struct intrin_wmma_i32_16x16x16_iu8_w32;
|
||||
|
||||
template <bool neg_a, bool neg_b, bool clamp>
|
||||
struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<int32x8_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
|
||||
neg_a,
|
||||
bit_cast<int32x4_t>(reg_a),
|
||||
neg_b,
|
||||
bit_cast<int32x4_t>(reg_b),
|
||||
reg_c.template AsType<int32x8_t>()[Number<0>{}],
|
||||
clamp);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
Reference in New Issue
Block a user