mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
Add bf16 and int8 wmma gemms for Navi3x and Navi4x. (#1671)
* add bf16 gemms for gfx11/gfx12
* reduce the input values in test_gemm
* add int8 wmma gemm instances for gfx11/gfx12
* add example gemm_wmma_int8
* fix bug in gemm_wmma_int8 test
* increase bf16 gemm test tolerance
* update the dates and clean-up commented-out instances
[ROCm/composable_kernel commit: 8aba2724cc]
This commit is contained in:
@@ -13,6 +13,11 @@ namespace ck {
|
||||
defined(__gfx1103__) || defined(__gfx11_generic__)
|
||||
#define __gfx11__
|
||||
#endif
|
||||
|
||||
#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
|
||||
#define __gfx12__
|
||||
#endif
|
||||
|
||||
/********************************WAVE32 MODE***********************************************/
|
||||
|
||||
// src: fp16, dst: fp32
|
||||
@@ -99,7 +104,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
|
||||
// opsel usage
|
||||
// false: D0.[0:15] = result
|
||||
// true : D0.[16:31]= result
|
||||
#if defined(__gfx11__)
|
||||
#if defined(__gfx11__) || defined(__gfx12__)
|
||||
reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
|
||||
reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
|
||||
@@ -261,10 +266,6 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
|
||||
// gfx12
|
||||
/********************************WAVE32 MODE***********************************************/
|
||||
|
||||
#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
|
||||
#define __gfx12__
|
||||
#endif
|
||||
|
||||
// src: fp16, dst: fp32
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_wmma_f32_16x16x16_f16_w32_gfx12;
|
||||
|
||||
Reference in New Issue
Block a user