mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
* [fix] align v3 gufusion pipeline * fix device kernel selection. * Add .co direct asm support by CK_USE_ASM_MOE_STAGE2_BLOCKSCALE * experimental optimization for scale load in blkscale gemm * Add asm for no-loop v3_128x128x128 * fix bugs * tune fp8 example * Update v1_128x128x128 to 2x2 instead of 4x1 * wip * add warmup to asm launch * wip2 * 16x16 function merged to moe * temp save, a performant version. * wip3 * Update .co binary to 16x16 * 16x16x128 correct; 64x64x128 failed * update * use mem_op::set when topk=1 * add mx fp8 b_preshuffle support, function not yet tested. * Spilt the fp4 target. Fix the known bugs. 128x128x128 sanity checked; remove prints * some fixes * fix update * remove some unnecessary hacky; enable 256x256x256 tilesize * update for function debug * Add pipeline v3. Have some runtime issue and register spill * Fix pipe v3 correctness issue * remove unnecessary hacky * clang format * fix a bug * fix the bug, functional test passed * tempsave; buggy at passed 4 e8m0 to scaled mfma * added fp4_bpreshuffle example, build failures * fixed some bugs * implement shuffled scale mxfp4gemm, blocker: opsel not effect * hotfix * fix bugs, build passed * (M, N, K)=(128, 128, 128) function failed. * temp save for gemm1. Function not ready * fix compile error. Gemm2 pass. Gemm1 WIP * fix bug for a lds read * update moe * Compile pass. Gemm1 function WIP * update moe * fix fp8; fix even/odd * tempsave * update moe * Revert "update" This reverts commit960b2bce1c. * Revert "use mem_op::set when topk=1" This reverts commitdef952a178. * Add v3 128x128x128_4x4_16x16.co for gfx950 * temp cmake flag suppression for aiter test * add code for mxfp4 gemm, blockscale not supported yet * gemm1 up-only pass. GU WIP * function pass with inline asm hacky * revert unexpected file change * updated and build passed * update CE elementOP * added code for debug * Gemm1 GUFusion function pass. 
Perf WIP * Fix fp8/bf8; remove duplicated code * disable the scheduler in v3; bring it back when compiler feature ready. * update moe v1 pipeline * Add gemm1 v1 32x128x128 * remove schedule barrier * updated * Fix fp8/bf8 B-row * mfma using asm, device result correct, host result need to check * gemm1 v3 64x128x128 debug * fix cpu ref * a/b thread_desc stride fix * Use random scale for init1 * 16x16x128 input size blockscale function passed * fix blockscale gemm bug * tempsave. Almost all instances passed. * v1 fix for mi350. * temp save * debug save * update debug * fix the bug, 128x128x256 tile function passed * v3 * rename moe block selector and pipeline * Add gemm1 v1 * Add gemm1 v1 to selector * added mx moe block v3 support, function passed * compile error fix * Improve the pipeline * Pack e8m0 as int32_t * v1 compile pass. Function not ready * debug synchronize issue over different GPU/ROCm * minor fix * Add profiler filter * Add f4 ckProfiler * Fix example compile error * Add f4 profiler examples * tempsave * v1 function pass. * v3 function pass * align file and function name * mx_moe_fp4 ready for aiter with clang-format. * modify the way we represent fp4 * generalize the pipeline scheduling. * init moe mx f4 scale shuffle * Cmakelist diable compiler-bound flags * mx_fp4 default parameter change * Moe blockscale gemm1&gemm2 asm support for aiter. Suppression cmkae flag til new compler. * update code * tempsave; modify the way we represent fp4 * generalize the pipeline scheduling. * Add gemm1 gfx942 .co support * updated code, build passed. * Update gemm2 asm with latest compiler flag * Fix mx f4 ckProfiler * Fix blockwise gemm mx v1 * lds conflict free + buffer load lds * Add gemm2 v3 64x128x128 * fix a, b scale loading bugs, a, b scale loading now correctly * Add gemm2 v3 64x128x128 * commit with debug info * fix fp4 profiler * Add mx fp4 pileline v1 instances * Fix v2 topk_weight cal. Add silu asm. 
* v2 tok_weight WIP * init mx fp4 B no preshuffle version * tempsave. compile pass, function wrong * enable fp4 moe no weigth preshuffle, function pass * update the TFlops calculation in the example * Add gemm2 64x128x128 asm. Fix BF16 ref. * fix 2 typos in fp4_preshuffle * Better kernel selection in device classes * correct preShuffleBuffer we should used packed k to do shuffle. * lds conflict free + buffer load lds * optimize offset math in dma * Fix fp4 ckProfiler * Fix MX MFMA tests * fix f4 pipeline issues * gemm1 func pass * update mx moe gemm1_bns tile size to 64x128x256 * update mx moe gemm1 gemm2 TF and BW calculation * fix typo * temp save * Fix example_gemm_mx build * rename the block pipeline * correct a typo in tail * Add rotating to mx examples * fix the correctness issue * Fix v1; use M padding * Add NT flag to B/BScale buffer * Merge gemm_mx_common.hpp * temp save, 4.4~4.5 * Fix 'Merge gemm_mx_common.hpp' * refactor the pipeline * Pad the M for scale buffer unconditionaly * update MX moe GEMM1 hotloopscheduling * change the gemm1 tile from 64x128x128 to 128x64x128 * Unconditional Ascale padding * Pad shuffled a scale only * pad ascale * add vmcnt guard for async copy * Profiler add f4 wp * Merge preshuffle device * Add more fp4 wp instances * Fix do_weight in gemm1. Fix cshuffle_datatype. Clang-format * Clang-format after 2 merges * Remove rocm6.3 workaround flags and macro * Fix fp8 config * Fix bf8 config * flag and barrier fix for copmiler branch MainOpSelV3 * Add fp8 profiler instances * Remove debug infos; Enable flags for blockscale f8 * No asm ver. 
for merging moe blocksale fp8 into mainline * update the flag name for f8blockscale * recover example * fix performance bug of bpreshuffle f8 gemm * clang format, remove single rate mfma restriction for f8 * remove single rate mfma restriction for f8 blockscale gemm * Fix moe blockscale gemm1 barrier 0x800 for new compiler * add pipeline v1 for MOE Gemm2 * Use v1 pipeline for example_moe_gemm2_xdl_mx_fp4_bns * Fix OOB; add MB96 instances * remove unnecessary files * fix the cmake issue * Enable splitk for mxfp4; clang format; * Generate random tensor values with multiple threads * Use packed_size_v for A/BPackedSize * Fix warning * Fix target_compile_options for disabled target on gfx942 * fix moe pki4 on gfx950 * doc the kGroup definition * Fix ThreadwiseTensorSliceTransfer_v4::Run (Fuse scale) * Refactor thread_copy_lds_direct_load; fix gfx942 direct lds load example; fix f16_pki4 example * Fix unknown compiler flag * fix two failed examples. * fix some failure tile size in gfx950 universal gemm. 
fix test_gemm_fp16 * workaround fix for test_gemm_f32; * We have very limited support for lds direct load if input matrix is not K major * fix test_gemm_splitk; * Fix compile for mx_mfma_op * add mfma selection logic for multipled_v3 * Clean up * Fix device gemm mx link error * improve the global atomic pattern * Revert unnecessary copyright updates * restore minimum_occupancy logic * Avoid data race in moe gemm2 ref * Build fp8 gemm_multiply_multiply and moe only on gfx94/95 * update the instance in device_mx_gemm * Resolve comments * Copyright 2025 * Remove unused code * fix library linking issue --------- Co-authored-by: OscarXu <huaiguxu@amd.com> Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com> Co-authored-by: mtgu0705 <mtgu@amd.com> Co-authored-by: aska-0096 <haocwang@amd.com> Co-authored-by: Your Name <you@example.com> Co-authored-by: valarLip <340077269@qq.com> Co-authored-by: feifei14119 <feiw@amd.com> Co-authored-by: Lin, Qun <qlin@amd.com> Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com> Co-authored-by: joye <joye@amd.com> Co-authored-by: asleepzzz <hanwen.chang@amd.com>
1544 lines
52 KiB
C++
1544 lines
52 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
#include "ck/utility/dtype_fp64.hpp"
|
|
|
|
namespace ck {
|
|
// Define the common __gfx94__ macro for the gfx942 (MI300) and gfx950 targets
|
|
#if defined(__gfx942__) || defined(__gfx950__)
|
|
#define __gfx94__
|
|
#endif
|
|
|
|
// fp32
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x1f32;

/// @brief Wrapper around __builtin_amdgcn_mfma_f32_32x32x1f32 for MPerWave = 64, NPerWave = 64.
///
/// The builtin is issued twice: once for each of the two float32_t slots of reg_c. Each call
/// reads the slot as the third (C) operand and writes the builtin's result back to that slot.
template <>
struct intrin_mfma_f32_32x32x1f32<64, 64>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 1;
        constexpr int blgp = 0;

        auto&& acc = reg_c.template AsType<float32_t>();

        // Slot 0 uses abid = 0, slot 1 uses abid = 1.
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, acc[Number<0>{}], cbsz, 0, blgp);
        acc(Number<1>{}) =
            __builtin_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, acc[Number<1>{}], cbsz, 1, blgp);
    }
};
|
|
|
|
template <>
|
|
struct intrin_mfma_f32_32x32x1f32<32, 64>
|
|
{
|
|
template <class FloatC>
|
|
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
|
{
|
|
reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
|
|
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
|
|
}
|
|
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x2f32;

/// @brief Wrapper around __builtin_amdgcn_mfma_f32_32x32x2f32 for MPerWave = 32, NPerWave = 32.
///
/// Issues the builtin once, using float16_t slot 0 of reg_c both as the third (C) operand and as
/// the destination of the result.
template <>
struct intrin_mfma_f32_32x32x2f32<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 0;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        auto&& acc = reg_c.template AsType<float16_t>();
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_32x32x2f32(reg_a, reg_b, acc[Number<0>{}], cbsz, abid, blgp);
    }
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x4f32;

/// @brief Wrapper around __builtin_amdgcn_mfma_f32_16x16x4f32 for MPerWave = 16, NPerWave = 16.
///
/// Issues the builtin once, using float4_t slot 0 of reg_c both as the third (C) operand and as
/// the destination of the result.
template <>
struct intrin_mfma_f32_16x16x4f32<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 0;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        auto&& acc = reg_c.template AsType<float4_t>();
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_16x16x4f32(reg_a, reg_b, acc[Number<0>{}], cbsz, abid, blgp);
    }
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x1f32;

/// @brief Wrapper around __builtin_amdgcn_mfma_f32_16x16x1f32 for MPerWave = 16, NPerWave = 64.
///
/// Issues the builtin once, using float16_t slot 0 of reg_c both as the third (C) operand and as
/// the destination of the result.
template <>
struct intrin_mfma_f32_16x16x1f32<16, 64>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 2;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        auto&& acc = reg_c.template AsType<float16_t>();
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_16x16x1f32(reg_a, reg_b, acc[Number<0>{}], cbsz, abid, blgp);
    }
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x1f32;

/// @brief Wrapper around __builtin_amdgcn_mfma_f32_4x4x1f32 for MPerWave = 4, NPerWave = 64.
///
/// Issues the builtin once, using float4_t slot 0 of reg_c both as the third (C) operand and as
/// the destination of the result.
template <>
struct intrin_mfma_f32_4x4x1f32<4, 64>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 4;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        auto&& acc = reg_c.template AsType<float4_t>();
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_4x4x1f32(reg_a, reg_b, acc[Number<0>{}], cbsz, abid, blgp);
    }
};
|
|
|
|
template <>
|
|
struct intrin_mfma_f32_4x4x1f32<8, 64>
|
|
{
|
|
template <class FloatC>
|
|
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
|
{
|
|
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
|
|
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
|
|
reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
|
|
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
|
|
}
|
|
};
|
|
|
|
// fp16
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x4f16;

/// @brief Wrapper around __builtin_amdgcn_mfma_f32_32x32x4f16 for MPerWave = 64, NPerWave = 64.
///
/// The builtin is issued twice: once for each of float32_t slots 0 and 1 of reg_c. Each call
/// reads the slot as the third (C) operand and writes the builtin's result back to that slot.
template <>
struct intrin_mfma_f32_32x32x4f16<64, 64>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 1;
        constexpr int blgp = 0;

        auto&& acc = reg_c.template AsType<float32_t>();

        // Slot 0 uses abid = 0, slot 1 uses abid = 1.
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_32x32x4f16(reg_a, reg_b, acc[Number<0>{}], cbsz, 0, blgp);
        acc(Number<1>{}) =
            __builtin_amdgcn_mfma_f32_32x32x4f16(reg_a, reg_b, acc[Number<1>{}], cbsz, 1, blgp);
    }
};
|
|
|
|
template <>
|
|
struct intrin_mfma_f32_32x32x4f16<32, 64>
|
|
{
|
|
template <class FloatC>
|
|
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
|
{
|
|
reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
|
|
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
|
|
}
|
|
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f16;

/// @brief Wrapper around __builtin_amdgcn_mfma_f32_32x32x16_f16 for MPerWave = 32, NPerWave = 32.
///
/// The builtin is only emitted when compiling for gfx950; it uses float16_t slot 0 of reg_c both
/// as the third (C) operand and as the destination. On every other target the arguments are
/// merely assigned to `ignore`.
template <>
struct intrin_mfma_f32_32x32x16f16<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 0;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        auto&& acc = reg_c.template AsType<float16_t>();
        acc(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_f16(
            reg_a, reg_b, acc[Number<0>{}], cbsz, abid, blgp);
#else
        // Not compiling for gfx950: just consume the arguments.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif // defined(__gfx950__)
    }
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32f16;

/// @brief Wrapper around __builtin_amdgcn_mfma_f32_16x16x32_f16 for MPerWave = 16, NPerWave = 16.
///
/// The builtin is only emitted when compiling for gfx950; it uses float4_t slot 0 of reg_c both
/// as the third (C) operand and as the destination. On every other target the arguments are
/// merely assigned to `ignore`.
template <>
struct intrin_mfma_f32_16x16x32f16<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 0;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        auto&& acc = reg_c.template AsType<float4_t>();
        acc(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_f16(
            reg_a, reg_b, acc[Number<0>{}], cbsz, abid, blgp);
#else
        // Not compiling for gfx950: just consume the arguments.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif // defined(__gfx950__)
    }
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x8f16;

/// @brief Wrapper around __builtin_amdgcn_mfma_f32_32x32x8f16 for MPerWave = 32, NPerWave = 32.
///
/// Issues the builtin once, using float16_t slot 0 of reg_c both as the third (C) operand and as
/// the destination of the result.
template <>
struct intrin_mfma_f32_32x32x8f16<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 0;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        auto&& acc = reg_c.template AsType<float16_t>();
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_32x32x8f16(reg_a, reg_b, acc[Number<0>{}], cbsz, abid, blgp);
    }
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x16f16;

/// @brief Wrapper around __builtin_amdgcn_mfma_f32_16x16x16f16 for MPerWave = 16, NPerWave = 16.
///
/// Issues the builtin once, using float4_t slot 0 of reg_c both as the third (C) operand and as
/// the destination of the result.
template <>
struct intrin_mfma_f32_16x16x16f16<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 0;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        auto&& acc = reg_c.template AsType<float4_t>();
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_16x16x16f16(reg_a, reg_b, acc[Number<0>{}], cbsz, abid, blgp);
    }
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x4f16;

/// @brief Wrapper around __builtin_amdgcn_mfma_f32_16x16x4f16 for MPerWave = 16, NPerWave = 64.
///
/// Issues the builtin once, using float16_t slot 0 of reg_c both as the third (C) operand and as
/// the destination of the result.
template <>
struct intrin_mfma_f32_16x16x4f16<16, 64>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 2;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        auto&& acc = reg_c.template AsType<float16_t>();
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_16x16x4f16(reg_a, reg_b, acc[Number<0>{}], cbsz, abid, blgp);
    }
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x4f16;

/// @brief Wrapper around __builtin_amdgcn_mfma_f32_4x4x4f16 for MPerWave = 4, NPerWave = 64.
///
/// Issues the builtin once, using float4_t slot 0 of reg_c both as the third (C) operand and as
/// the destination of the result.
template <>
struct intrin_mfma_f32_4x4x4f16<4, 64>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 4;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        auto&& acc = reg_c.template AsType<float4_t>();
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_4x4x4f16(reg_a, reg_b, acc[Number<0>{}], cbsz, abid, blgp);
    }
};
|
|
|
|
template <>
|
|
struct intrin_mfma_f32_4x4x4f16<8, 64>
|
|
{
|
|
template <class FloatC>
|
|
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
|
{
|
|
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
|
|
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
|
|
reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
|
|
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
|
|
}
|
|
};
|
|
|
|
// bf16
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf16;

/// @brief Wrapper around __builtin_amdgcn_mfma_f32_32x32x16_bf16 for MPerWave = 32, NPerWave = 32.
///
/// The builtin is only emitted when compiling for gfx950; it uses float16_t slot 0 of reg_c both
/// as the third (C) operand and as the destination. On every other target the arguments are
/// merely assigned to `ignore`.
template <>
struct intrin_mfma_f32_32x32x16bf16<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 0;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        auto&& acc = reg_c.template AsType<float16_t>();
        acc(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16(
            reg_a, reg_b, acc[Number<0>{}], cbsz, abid, blgp);
#else
        // Not compiling for gfx950: just consume the arguments.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif // defined(__gfx950__)
    }
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32bf16;

/// @brief Wrapper around __builtin_amdgcn_mfma_f32_16x16x32_bf16 for MPerWave = 16, NPerWave = 16.
///
/// The builtin is only emitted when compiling for gfx950; it uses float4_t slot 0 of reg_c both
/// as the third (C) operand and as the destination. On every other target the arguments are
/// merely assigned to `ignore`.
template <>
struct intrin_mfma_f32_16x16x32bf16<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 0;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        auto&& acc = reg_c.template AsType<float4_t>();
        acc(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16(
            reg_a, reg_b, acc[Number<0>{}], cbsz, abid, blgp);
#else
        // Not compiling for gfx950: just consume the arguments.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif // defined(__gfx950__)
    }
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x8bf16_1k;

/// @brief Wrapper around __builtin_amdgcn_mfma_f32_32x32x8bf16_1k for MPerWave = 32,
/// NPerWave = 32.
///
/// Issues the builtin once, using float16_t slot 0 of reg_c both as the third (C) operand and as
/// the destination of the result.
template <>
struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
    {
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 0;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        auto&& acc = reg_c.template AsType<float16_t>();
        acc(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
            reg_a, reg_b, acc[Number<0>{}], cbsz, abid, blgp);
    }
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x16bf16_1k;

/// @brief Wrapper around __builtin_amdgcn_mfma_f32_16x16x16bf16_1k for MPerWave = 16,
/// NPerWave = 16.
///
/// Issues the builtin once, using float4_t slot 0 of reg_c both as the third (C) operand and as
/// the destination of the result.
template <>
struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
    {
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 0;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        auto&& acc = reg_c.template AsType<float4_t>();
        acc(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
            reg_a, reg_b, acc[Number<0>{}], cbsz, abid, blgp);
    }
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x4bf16;

/// @brief Wrapper around __builtin_amdgcn_mfma_f32_32x32x4bf16 for MPerWave = 32, NPerWave = 32.
///
/// Issues the builtin once, using float16_t slot 0 of reg_c both as the third (C) operand and as
/// the destination of the result.
template <>
struct intrin_mfma_f32_32x32x4bf16<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
    {
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 0;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        auto&& acc = reg_c.template AsType<float16_t>();
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_32x32x4bf16(reg_a, reg_b, acc[Number<0>{}], cbsz, abid, blgp);
    }
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x8bf16;

/// @brief Wrapper around __builtin_amdgcn_mfma_f32_16x16x8bf16 for MPerWave = 16, NPerWave = 16.
///
/// Issues the builtin once, using float4_t slot 0 of reg_c both as the third (C) operand and as
/// the destination of the result.
template <>
struct intrin_mfma_f32_16x16x8bf16<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
    {
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 0;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        auto&& acc = reg_c.template AsType<float4_t>();
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_16x16x8bf16(reg_a, reg_b, acc[Number<0>{}], cbsz, abid, blgp);
    }
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x8i8;

/// @brief Wrapper around __builtin_amdgcn_mfma_i32_32x32x8i8 for MPerWave = 32, NPerWave = 32.
///
/// Both int8x4_t inputs are bit-cast to int32_t before being handed to the builtin. int32x16_t
/// slot 0 of reg_c serves as the third (C) operand and as the destination of the result.
template <>
struct intrin_mfma_i32_32x32x8i8<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
    {
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 0;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        // The builtin takes the four packed bytes as a single 32-bit integer.
        const int32_t packed_a = bit_cast<int32_t>(reg_a);
        const int32_t packed_b = bit_cast<int32_t>(reg_b);

        auto&& acc = reg_c.template AsType<int32x16_t>();
        acc(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x8i8(
            packed_a, packed_b, acc[Number<0>{}], cbsz, abid, blgp);
    }
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x16i8;

/// @brief Wrapper around __builtin_amdgcn_mfma_i32_16x16x16i8 for MPerWave = 16, NPerWave = 16.
///
/// Both int8x4_t inputs are bit-cast to int32_t before being handed to the builtin. int32x4_t
/// slot 0 of reg_c serves as the third (C) operand and as the destination of the result.
template <>
struct intrin_mfma_i32_16x16x16i8<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
    {
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 0;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        // The builtin takes the four packed bytes as a single 32-bit integer.
        const int32_t packed_a = bit_cast<int32_t>(reg_a);
        const int32_t packed_b = bit_cast<int32_t>(reg_b);

        auto&& acc = reg_c.template AsType<int32x4_t>();
        acc(Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x16i8(
            packed_a, packed_b, acc[Number<0>{}], cbsz, abid, blgp);
    }
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x32i8;

/// @brief Wrapper around __builtin_amdgcn_mfma_i32_32x32x32_i8 for MPerWave = 32, NPerWave = 32.
///
/// The builtin is only emitted when compiling for gfx950; it uses int32x16_t slot 0 of reg_c both
/// as the third (C) operand and as the destination. On every other target the arguments are
/// merely assigned to `ignore`.
template <>
struct intrin_mfma_i32_32x32x32i8<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 0;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        auto&& acc = reg_c.template AsType<int32x16_t>();
        acc(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x32_i8(
            reg_a, reg_b, acc[Number<0>{}], cbsz, abid, blgp);
#else
        // Not compiling for gfx950: just consume the arguments.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif // defined(__gfx950__)
    }
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x64i8;

/// @brief Wrapper around __builtin_amdgcn_mfma_i32_16x16x64_i8 for MPerWave = 16, NPerWave = 16.
///
/// The builtin is only emitted when compiling for gfx950; it uses int32x4_t slot 0 of reg_c both
/// as the third (C) operand and as the destination. On every other target the arguments are
/// merely assigned to `ignore`.
template <>
struct intrin_mfma_i32_16x16x64i8<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 0;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        auto&& acc = reg_c.template AsType<int32x4_t>();
        acc(Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x64_i8(
            reg_a, reg_b, acc[Number<0>{}], cbsz, abid, blgp);
#else
        // Not compiling for gfx950: just consume the arguments.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif // defined(__gfx950__)
    }
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x16i8;

/// @brief Wrapper around __builtin_amdgcn_mfma_i32_32x32x16_i8 for MPerWave = 32, NPerWave = 32.
///
/// Both int8x8_t inputs are bit-cast to int64_t before being handed to the builtin. int32x16_t
/// slot 0 of reg_c serves as the third (C) operand and as the destination of the result.
template <>
struct intrin_mfma_i32_32x32x16i8<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
    {
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 0;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        // The builtin takes the eight packed bytes as a single 64-bit integer.
        const int64_t packed_a = bit_cast<int64_t>(reg_a);
        const int64_t packed_b = bit_cast<int64_t>(reg_b);

        auto&& acc = reg_c.template AsType<int32x16_t>();
        acc(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x16_i8(
            packed_a, packed_b, acc[Number<0>{}], cbsz, abid, blgp);
    }
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x32i8;

/// @brief Wrapper around __builtin_amdgcn_mfma_i32_16x16x32_i8 for MPerWave = 16, NPerWave = 16.
///
/// Both int8x8_t inputs are bit-cast to int64_t before being handed to the builtin. int32x4_t
/// slot 0 of reg_c serves as the third (C) operand and as the destination of the result.
template <>
struct intrin_mfma_i32_16x16x32i8<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
    {
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 0;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        // The builtin takes the eight packed bytes as a single 64-bit integer.
        const int64_t packed_a = bit_cast<int64_t>(reg_a);
        const int64_t packed_b = bit_cast<int64_t>(reg_b);

        auto&& acc = reg_c.template AsType<int32x4_t>();
        acc(Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x32_i8(
            packed_a, packed_b, acc[Number<0>{}], cbsz, abid, blgp);
    }
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f64_16x16x4f64;

/// @brief Wrapper around __builtin_amdgcn_mfma_f64_16x16x4f64 for MPerWave = 16, NPerWave = 16.
///
/// The builtin is only emitted when __gfx90a__ or __gfx94__ is defined; it uses double4_t slot 0
/// of reg_c both as the third (C) operand and as the destination. Otherwise the arguments are
/// merely assigned to `ignore`.
template <>
struct intrin_mfma_f64_16x16x4f64<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
    {
#if defined(__gfx90a__) || defined(__gfx94__)
        // Trailing immediates of the builtin are, in order: cbsz, abid, blgp.
        constexpr int cbsz = 0;
        constexpr int abid = 0;
        constexpr int blgp = 0;

        auto&& acc = reg_c.template AsType<double4_t>();
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f64_16x16x4f64(reg_a, reg_b, acc[Number<0>{}], cbsz, abid, blgp);
#else
        // Target without this builtin enabled here: just consume the arguments.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x64f8f6f4;

/// @brief Matrix fused multiply-accumulate on 32x32x64 submatrices for the f8, f6 and f4 data
/// types.
///
/// @note Every overload calls the *scaled* builtin because the original (unscaled) instruction is
/// not supported in the backend; this is the intended use. The backend has an optimization that
/// selects the unscaled operation when the scale is 0.
///
/// Operand order of the builtin as used below:
///   (A, B, C, cbsz, blgp, opsel_a, scale_a, opsel_b, scale_b)
/// where cbsz / blgp select the format of A / B respectively:
///   {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
///
/// All overloads update float16_t slot 0 of reg_c and compile to plain `ignore` assignments when
/// the target is not gfx950.
template <>
struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
{
    /// A = FP8 E4M3, B = FP8 E4M3.
    template <class FloatC>
    __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        constexpr int cbsz  = 0; // A format: FP8 E4M3
        constexpr int blgp  = 0; // B format: FP8 E4M3
        constexpr int opsel = 0;
        constexpr int scale = 0;

        auto&& acc = reg_c.template AsType<float16_t>();
        acc(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
            reg_a, reg_b, acc[Number<0>{}], cbsz, blgp, opsel, scale, opsel, scale);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }

    /// A = FP8 E5M2, B = FP8 E5M2.
    template <class FloatC>
    __device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        constexpr int cbsz  = 1; // A format: FP8 E5M2
        constexpr int blgp  = 1; // B format: FP8 E5M2
        constexpr int opsel = 0;
        constexpr int scale = 0;

        auto&& acc = reg_c.template AsType<float16_t>();
        acc(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
            reg_a, reg_b, acc[Number<0>{}], cbsz, blgp, opsel, scale, opsel, scale);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }

    /// A = FP8 E5M2, B = FP8 E4M3.
    template <class FloatC>
    __device__ static void Run(const bf8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        constexpr int cbsz  = 1; // A format: FP8 E5M2
        constexpr int blgp  = 0; // B format: FP8 E4M3
        constexpr int opsel = 0;
        constexpr int scale = 0;

        auto&& acc = reg_c.template AsType<float16_t>();
        acc(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
            reg_a, reg_b, acc[Number<0>{}], cbsz, blgp, opsel, scale, opsel, scale);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }

    /// A = FP8 E4M3, B = FP8 E5M2.
    template <class FloatC>
    __device__ static void Run(const f8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        constexpr int cbsz  = 0; // A format: FP8 E4M3
        constexpr int blgp  = 1; // B format: FP8 E5M2
        constexpr int opsel = 0;
        constexpr int scale = 0;

        auto&& acc = reg_c.template AsType<float16_t>();
        acc(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
            reg_a, reg_b, acc[Number<0>{}], cbsz, blgp, opsel, scale, opsel, scale);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }

    /// A = FP4 E2M1, B = FP4 E2M1.
    template <class FloatC>
    __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        constexpr int cbsz  = 4; // A format: FP4 E2M1
        constexpr int blgp  = 4; // B format: FP4 E2M1
        constexpr int opsel = 0;
        constexpr int scale = 0;

        // Reinterpret each input as four 32-bit words, then widen to the eight-word operand the
        // builtin is given here; the upper four words are zero-filled.
        const int32x4_t bits_a = bit_cast<int32x4_t>(reg_a);
        const int32x4_t bits_b = bit_cast<int32x4_t>(reg_b);

        const int32x8_t wide_a{bits_a[0], bits_a[1], bits_a[2], bits_a[3], 0, 0, 0, 0};
        const int32x8_t wide_b{bits_b[0], bits_b[1], bits_b[2], bits_b[3], 0, 0, 0, 0};

        auto&& acc = reg_c.template AsType<float16_t>();
        acc(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
            wide_a, wide_b, acc[Number<0>{}], cbsz, blgp, opsel, scale, opsel, scale);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }

    /// A = FP6 E2M3, B = FP6 E2M3.
    template <class FloatC>
    __device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        constexpr int cbsz  = 2; // A format: FP6 E2M3
        constexpr int blgp  = 2; // B format: FP6 E2M3
        constexpr int opsel = 0;
        constexpr int scale = 0;

        // Reinterpret each input as six 32-bit words, then widen to the eight-word operand the
        // builtin is given here; the upper two words are zero-filled.
        const int32x6_t bits_a = bit_cast<int32x6_t>(reg_a);
        const int32x6_t bits_b = bit_cast<int32x6_t>(reg_b);

        const int32x8_t wide_a{
            bits_a[0], bits_a[1], bits_a[2], bits_a[3], bits_a[4], bits_a[5], 0, 0};
        const int32x8_t wide_b{
            bits_b[0], bits_b[1], bits_b[2], bits_b[3], bits_b[4], bits_b[5], 0, 0};

        auto&& acc = reg_c.template AsType<float16_t>();
        acc(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
            wide_a, wide_b, acc[Number<0>{}], cbsz, blgp, opsel, scale, opsel, scale);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }

    /// A = FP6 E3M2, B = FP6 E3M2.
    template <class FloatC>
    __device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        constexpr int cbsz  = 3; // A format: FP6 E3M2
        constexpr int blgp  = 3; // B format: FP6 E3M2
        constexpr int opsel = 0;
        constexpr int scale = 0;

        // Reinterpret each input as six 32-bit words, then widen to the eight-word operand the
        // builtin is given here; the upper two words are zero-filled.
        const int32x6_t bits_a = bit_cast<int32x6_t>(reg_a);
        const int32x6_t bits_b = bit_cast<int32x6_t>(reg_b);

        const int32x8_t wide_a{
            bits_a[0], bits_a[1], bits_a[2], bits_a[3], bits_a[4], bits_a[5], 0, 0};
        const int32x8_t wide_b{
            bits_b[0], bits_b[1], bits_b[2], bits_b[3], bits_b[4], bits_b[5], 0, 0};

        auto&& acc = reg_c.template AsType<float16_t>();
        acc(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
            wide_a, wide_b, acc[Number<0>{}], cbsz, blgp, opsel, scale, opsel, scale);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave, index_t OpselA, index_t OpselB>
|
|
struct intrin_mfma_scale_f32_32x32x64f8f6f4;
|
|
|
|
template <index_t OpselA, index_t OpselB>
|
|
struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32, OpselA, OpselB>
|
|
{
|
|
template <class FloatC>
|
|
__device__ static void Run(const f8x32_t& reg_a,
|
|
const int32_t& scale_a,
|
|
const f8x32_t& reg_b,
|
|
const int32_t& scale_b,
|
|
FloatC& reg_c)
|
|
{
|
|
#if defined(__gfx950__)
|
|
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
|
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
|
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c.template AsType<float16_t>()[Number<0>{}],
|
|
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
|
0, // blgp
|
|
OpselA, // OPSEL
|
|
scale_a,
|
|
OpselB, // OPSEL
|
|
scale_b);
|
|
// XXX: Note on the scale_a and scale_b parameters:
|
|
// If compiler detects that one or both scales are constant values, it will treat that
|
|
// constant as F32 constant. I.e., if scale_a at some point was declared as
|
|
// `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is
|
|
// assigned value `bit_cast<int32_t>(static_cast<float>(a_scale))`.
|
|
|
|
// XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even
|
|
// when OPSEL is set otherwise.
|
|
#else
|
|
ignore = reg_a;
|
|
ignore = scale_a;
|
|
ignore = reg_b;
|
|
ignore = scale_b;
|
|
ignore = reg_c;
|
|
#endif
|
|
}
|
|
|
|
template <class FloatC>
|
|
__device__ static void Run(const bf8x32_t& reg_a,
|
|
const int32_t& scale_a,
|
|
const bf8x32_t& reg_b,
|
|
const int32_t& scale_b,
|
|
FloatC& reg_c)
|
|
{
|
|
#if defined(__gfx950__)
|
|
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
|
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
|
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c.template AsType<float16_t>()[Number<0>{}],
|
|
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
|
1, // blgp
|
|
OpselA, // OPSEL
|
|
scale_a,
|
|
OpselB, // OPSEL
|
|
scale_b);
|
|
// XXX: Note on the scale_a and scale_b parameters:
|
|
// If compiler detects that one or both scales are constant values, it will treat that
|
|
// constant as F32 constant. I.e., if scale_a at some point was declared as
|
|
// `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is
|
|
// assigned value `bit_cast<int32_t>(static_cast<float>(a_scale))`.
|
|
|
|
// XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even
|
|
// when OPSEL is set otherwise.
|
|
#else
|
|
ignore = reg_a;
|
|
ignore = scale_a;
|
|
ignore = reg_b;
|
|
ignore = scale_b;
|
|
ignore = reg_c;
|
|
#endif
|
|
}
|
|
|
|
template <class FloatC>
|
|
__device__ static void Run(const bf8x32_t& reg_a,
|
|
const int32_t& scale_a,
|
|
const f8x32_t& reg_b,
|
|
const int32_t& scale_b,
|
|
FloatC& reg_c)
|
|
{
|
|
#if defined(__gfx950__)
|
|
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
|
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
|
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c.template AsType<float16_t>()[Number<0>{}],
|
|
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
|
0, // blgp
|
|
OpselA, // OPSEL
|
|
scale_a,
|
|
OpselB, // OPSEL
|
|
scale_b);
|
|
// XXX: Note on the scale_a and scale_b parameters:
|
|
// If compiler detects that one or both scales are constant values, it will treat that
|
|
// constant as F32 constant. I.e., if scale_a at some point was declared as
|
|
// `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is
|
|
// assigned value `bit_cast<int32_t>(static_cast<float>(a_scale))`.
|
|
|
|
// XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even
|
|
// when OPSEL is set otherwise.
|
|
#else
|
|
ignore = reg_a;
|
|
ignore = scale_a;
|
|
ignore = reg_b;
|
|
ignore = scale_b;
|
|
ignore = reg_c;
|
|
#endif
|
|
}
|
|
|
|
template <class FloatC>
|
|
__device__ static void Run(const f6x32_t& reg_a,
|
|
const int32_t scale_a,
|
|
const f6x32_t& reg_b,
|
|
const int32_t scale_b,
|
|
FloatC& reg_c)
|
|
{
|
|
#if defined(__gfx950__)
|
|
|
|
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
|
|
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
|
|
|
|
using arg_type = int32x8_t;
|
|
|
|
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
|
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
|
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
|
|
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
|
|
reg_c.template AsType<float16_t>()[Number<0>{}],
|
|
2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
|
2, // blgp
|
|
OpselA, // OPSEL
|
|
scale_a,
|
|
OpselB, // OPSEL
|
|
scale_b);
|
|
#else
|
|
ignore = reg_a;
|
|
ignore = scale_a;
|
|
ignore = reg_b;
|
|
ignore = scale_b;
|
|
ignore = reg_c;
|
|
#endif
|
|
}
|
|
|
|
template <class FloatC>
|
|
__device__ static void Run(const bf6x32_t& reg_a,
|
|
const int32_t scale_a,
|
|
const bf6x32_t& reg_b,
|
|
const int32_t scale_b,
|
|
FloatC& reg_c)
|
|
{
|
|
#if defined(__gfx950__)
|
|
|
|
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
|
|
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
|
|
|
|
using arg_type = int32x8_t;
|
|
|
|
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
|
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
|
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
|
|
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
|
|
reg_c.template AsType<float16_t>()[Number<0>{}],
|
|
3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
|
3, // blgp
|
|
OpselA, // OPSEL
|
|
scale_a,
|
|
OpselB, // OPSEL
|
|
scale_b);
|
|
#else
|
|
ignore = reg_a;
|
|
ignore = scale_a;
|
|
ignore = reg_b;
|
|
ignore = scale_b;
|
|
ignore = reg_c;
|
|
#endif
|
|
}
|
|
|
|
template <class FloatC>
|
|
__device__ static void Run(const f4x32_t& reg_a,
|
|
const int32_t scale_a,
|
|
const f4x32_t& reg_b,
|
|
const int32_t scale_b,
|
|
FloatC& reg_c)
|
|
{
|
|
#if defined(__gfx950__)
|
|
|
|
int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
|
|
int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
|
|
|
|
using arg_type = int32x8_t;
|
|
|
|
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
|
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
|
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
|
|
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
|
|
reg_c.template AsType<float16_t>()[Number<0>{}],
|
|
4, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
|
4, // blgp
|
|
OpselA, // OPSEL
|
|
scale_a,
|
|
OpselB, // OPSEL
|
|
scale_b);
|
|
#else
|
|
ignore = reg_a;
|
|
ignore = scale_a;
|
|
ignore = reg_b;
|
|
ignore = scale_b;
|
|
ignore = reg_c;
|
|
#endif
|
|
}
|
|
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave, index_t OpselA, index_t OpselB>
|
|
struct intrin_mfma_scale_f32_16x16x128f8f6f4;
|
|
|
|
template <index_t OpselA, index_t OpselB>
|
|
struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
|
|
{
|
|
template <class FloatC>
|
|
__device__ static void Run(const f8x32_t& reg_a,
|
|
const int32_t& scale_a,
|
|
const f8x32_t& reg_b,
|
|
const int32_t& scale_b,
|
|
FloatC& reg_c)
|
|
{
|
|
#if defined(__gfx950__)
|
|
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
|
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
|
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c.template AsType<float4_t>()[Number<0>{}],
|
|
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
|
0, // blgp
|
|
OpselA, // OPSEL
|
|
scale_a,
|
|
OpselB, // OPSEL
|
|
scale_b);
|
|
#else
|
|
ignore = reg_a;
|
|
ignore = scale_a;
|
|
ignore = reg_b;
|
|
ignore = scale_b;
|
|
ignore = reg_c;
|
|
#endif
|
|
}
|
|
|
|
template <class FloatC>
|
|
__device__ static void Run(const bf8x32_t& reg_a,
|
|
const int32_t& scale_a,
|
|
const bf8x32_t& reg_b,
|
|
const int32_t& scale_b,
|
|
FloatC& reg_c)
|
|
{
|
|
#if defined(__gfx950__)
|
|
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
|
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
|
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c.template AsType<float4_t>()[Number<0>{}],
|
|
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
|
1, // blgp
|
|
OpselA, // OPSEL
|
|
scale_a,
|
|
OpselB, // OPSEL
|
|
scale_b);
|
|
#else
|
|
ignore = reg_a;
|
|
ignore = scale_a;
|
|
ignore = reg_b;
|
|
ignore = scale_b;
|
|
ignore = reg_c;
|
|
#endif
|
|
}
|
|
|
|
template <class FloatC>
|
|
__device__ static void Run(const f8x32_t& reg_a,
|
|
const int32_t& scale_a,
|
|
const bf8x32_t& reg_b,
|
|
const int32_t& scale_b,
|
|
FloatC& reg_c)
|
|
{
|
|
#if defined(__gfx950__)
|
|
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
|
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
|
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c.template AsType<float4_t>()[Number<0>{}],
|
|
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
|
1, // blgp
|
|
OpselA, // OPSEL
|
|
scale_a,
|
|
OpselB, // OPSEL
|
|
scale_b);
|
|
#else
|
|
ignore = reg_a;
|
|
ignore = scale_a;
|
|
ignore = reg_b;
|
|
ignore = scale_b;
|
|
ignore = reg_c;
|
|
#endif
|
|
}
|
|
|
|
template <class FloatC>
|
|
__device__ static void Run(const bf8x32_t& reg_a,
|
|
const int32_t& scale_a,
|
|
const f8x32_t& reg_b,
|
|
const int32_t& scale_b,
|
|
FloatC& reg_c)
|
|
{
|
|
#if defined(__gfx950__)
|
|
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
|
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
|
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c.template AsType<float4_t>()[Number<0>{}],
|
|
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
|
0, // blgp
|
|
OpselA, // OPSEL
|
|
scale_a,
|
|
OpselB, // OPSEL
|
|
scale_b);
|
|
#else
|
|
ignore = reg_a;
|
|
ignore = scale_a;
|
|
ignore = reg_b;
|
|
ignore = scale_b;
|
|
ignore = reg_c;
|
|
#endif
|
|
}
|
|
|
|
template <class FloatC>
|
|
__device__ static void Run(const f6x32_t& reg_a,
|
|
const int32_t scale_a,
|
|
const f6x32_t& reg_b,
|
|
const int32_t scale_b,
|
|
FloatC& reg_c)
|
|
{
|
|
#if defined(__gfx950__)
|
|
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
|
|
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
|
|
|
|
using arg_type = int32x8_t;
|
|
|
|
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
|
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
|
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
|
|
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
|
|
reg_c.template AsType<float4_t>()[Number<0>{}],
|
|
2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
|
2, // blgp
|
|
OpselA, // OPSEL
|
|
scale_a,
|
|
OpselB, // OPSEL
|
|
scale_b);
|
|
#else
|
|
ignore = reg_a;
|
|
ignore = scale_a;
|
|
ignore = reg_b;
|
|
ignore = scale_b;
|
|
ignore = reg_c;
|
|
#endif
|
|
}
|
|
|
|
template <class FloatC>
|
|
__device__ static void Run(const bf6x32_t& reg_a,
|
|
const int32_t scale_a,
|
|
const bf6x32_t& reg_b,
|
|
const int32_t scale_b,
|
|
FloatC& reg_c)
|
|
{
|
|
#if defined(__gfx950__)
|
|
int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
|
|
int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
|
|
|
|
using arg_type = int32x8_t;
|
|
|
|
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
|
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
|
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
|
|
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
|
|
reg_c.template AsType<float4_t>()[Number<0>{}],
|
|
3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
|
3, // blgp
|
|
OpselA, // OPSEL
|
|
scale_a,
|
|
OpselB, // OPSEL
|
|
scale_b);
|
|
#else
|
|
ignore = reg_a;
|
|
ignore = scale_a;
|
|
ignore = reg_b;
|
|
ignore = scale_b;
|
|
ignore = reg_c;
|
|
#endif
|
|
}
|
|
|
|
template <class FloatC>
|
|
__device__ static void Run(const f4x32_t& reg_a,
|
|
const int32_t scale_a,
|
|
const f4x32_t& reg_b,
|
|
const int32_t scale_b,
|
|
FloatC& reg_c)
|
|
{
|
|
#if defined(__gfx950__)
|
|
int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
|
|
int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
|
|
using arg_type = int32x8_t;
|
|
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
|
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
|
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
|
|
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
|
|
reg_c.template AsType<float4_t>()[Number<0>{}],
|
|
4, // cbsz
|
|
4, // blgp
|
|
OpselA, // OPSEL
|
|
scale_a,
|
|
OpselB, // OPSEL
|
|
scale_b);
|
|
#else
|
|
ignore = reg_a;
|
|
ignore = scale_a;
|
|
ignore = reg_b;
|
|
ignore = scale_b;
|
|
ignore = reg_c;
|
|
#endif
|
|
}
|
|
};
|
|
|
|
/// Primary template; the <16, 16> specialization follows.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x128f8f6f4;

/// @brief Performs a matrix fused multiply-accumulate operation on 16x16x128 submatrices for
/// f8f6f4 data types, without user-supplied scales.
///
/// Every overload behaves the same way: on gfx950 the first float4_t slot of `reg_c` is read,
/// passed to the builtin as the accumulator input and overwritten with the result; on any other
/// target the arguments are only assigned to `ignore`.
///
/// The overloads differ in the format codes handed to the builtin (cbsz for A, blgp for B):
/// {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}. The 6-bit and 4-bit operands are
/// additionally reinterpreted as 32-bit words and zero-padded to the eight-word vector that is
/// passed to the builtin.
///
/// @note Calls the scaled version of the instruction, with both OPSEL arguments and both scales
/// set to 0, as the original instruction is not supported in the backend. That is the intended
/// use. There is a backend optimization to select the unscaled operation if the scale is 0.
///
/// Argument order reference:
/// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
template <>
struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
{
    /// f8 (FP8 E4M3) A operand, f8 (FP8 E4M3) B operand.
    template <class FloatC>
    __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        constexpr int a_format = 0; // cbsz: FP8 E4M3
        constexpr int b_format = 0; // blgp: FP8 E4M3

        const float4_t acc_in = reg_c.template AsType<float4_t>()[Number<0>{}];

        reg_c.template AsType<float4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(reg_a,
                                                             reg_b,
                                                             acc_in,
                                                             a_format,
                                                             b_format,
                                                             0,  // OPSEL (A)
                                                             0,  // scale A
                                                             0,  // OPSEL (B)
                                                             0); // scale B
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }

    /// bf8 (FP8 E5M2) A operand, bf8 (FP8 E5M2) B operand.
    template <class FloatC>
    __device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        constexpr int a_format = 1; // cbsz: FP8 E5M2
        constexpr int b_format = 1; // blgp: FP8 E5M2

        const float4_t acc_in = reg_c.template AsType<float4_t>()[Number<0>{}];

        reg_c.template AsType<float4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(reg_a,
                                                             reg_b,
                                                             acc_in,
                                                             a_format,
                                                             b_format,
                                                             0,  // OPSEL (A)
                                                             0,  // scale A
                                                             0,  // OPSEL (B)
                                                             0); // scale B
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }

    /// bf8 (FP8 E5M2) A operand, f8 (FP8 E4M3) B operand.
    template <class FloatC>
    __device__ static void Run(const bf8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        constexpr int a_format = 1; // cbsz: FP8 E5M2
        constexpr int b_format = 0; // blgp: FP8 E4M3

        const float4_t acc_in = reg_c.template AsType<float4_t>()[Number<0>{}];

        reg_c.template AsType<float4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(reg_a,
                                                             reg_b,
                                                             acc_in,
                                                             a_format,
                                                             b_format,
                                                             0,  // OPSEL (A)
                                                             0,  // scale A
                                                             0,  // OPSEL (B)
                                                             0); // scale B
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }

    /// f8 (FP8 E4M3) A operand, bf8 (FP8 E5M2) B operand.
    template <class FloatC>
    __device__ static void Run(const f8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        constexpr int a_format = 0; // cbsz: FP8 E4M3
        constexpr int b_format = 1; // blgp: FP8 E5M2

        const float4_t acc_in = reg_c.template AsType<float4_t>()[Number<0>{}];

        reg_c.template AsType<float4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(reg_a,
                                                             reg_b,
                                                             acc_in,
                                                             a_format,
                                                             b_format,
                                                             0,  // OPSEL (A)
                                                             0,  // scale A
                                                             0,  // OPSEL (B)
                                                             0); // scale B
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }

    /// f4 (FP4 E2M1) A and B operands; four payload words each, padded with four zero words.
    template <class FloatC>
    __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        constexpr int a_format = 4; // cbsz: FP4 E2M1
        constexpr int b_format = 4; // blgp: FP4 E2M1

        const int32x4_t a_words = bit_cast<int32x4_t>(reg_a);
        const int32x4_t b_words = bit_cast<int32x4_t>(reg_b);

        const int32x8_t a_padded{a_words[0], a_words[1], a_words[2], a_words[3], 0, 0, 0, 0};
        const int32x8_t b_padded{b_words[0], b_words[1], b_words[2], b_words[3], 0, 0, 0, 0};

        const float4_t acc_in = reg_c.template AsType<float4_t>()[Number<0>{}];

        reg_c.template AsType<float4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a_padded,
                                                             b_padded,
                                                             acc_in,
                                                             a_format,
                                                             b_format,
                                                             0,  // OPSEL (A)
                                                             0,  // scale A
                                                             0,  // OPSEL (B)
                                                             0); // scale B
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }

    /// f6 (FP6 E2M3) A and B operands; six payload words each, padded with two zero words.
    template <class FloatC>
    __device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        constexpr int a_format = 2; // cbsz: FP6 E2M3
        constexpr int b_format = 2; // blgp: FP6 E2M3

        const int32x6_t a_words = bit_cast<int32x6_t>(reg_a);
        const int32x6_t b_words = bit_cast<int32x6_t>(reg_b);

        const int32x8_t a_padded{
            a_words[0], a_words[1], a_words[2], a_words[3], a_words[4], a_words[5], 0, 0};
        const int32x8_t b_padded{
            b_words[0], b_words[1], b_words[2], b_words[3], b_words[4], b_words[5], 0, 0};

        const float4_t acc_in = reg_c.template AsType<float4_t>()[Number<0>{}];

        reg_c.template AsType<float4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a_padded,
                                                             b_padded,
                                                             acc_in,
                                                             a_format,
                                                             b_format,
                                                             0,  // OPSEL (A)
                                                             0,  // scale A
                                                             0,  // OPSEL (B)
                                                             0); // scale B
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }

    /// bf6 (FP6 E3M2) A and B operands; six payload words each, padded with two zero words.
    template <class FloatC>
    __device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        constexpr int a_format = 3; // cbsz: FP6 E3M2
        constexpr int b_format = 3; // blgp: FP6 E3M2

        const int32x6_t a_words = bit_cast<int32x6_t>(reg_a);
        const int32x6_t b_words = bit_cast<int32x6_t>(reg_b);

        const int32x8_t a_padded{
            a_words[0], a_words[1], a_words[2], a_words[3], a_words[4], a_words[5], 0, 0};
        const int32x8_t b_padded{
            b_words[0], b_words[1], b_words[2], b_words[3], b_words[4], b_words[5], 0, 0};

        const float4_t acc_in = reg_c.template AsType<float4_t>()[Number<0>{}];

        reg_c.template AsType<float4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a_padded,
                                                             b_padded,
                                                             acc_in,
                                                             a_format,
                                                             b_format,
                                                             0,  // OPSEL (A)
                                                             0,  // scale A
                                                             0,  // OPSEL (B)
                                                             0); // scale B
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};
|
|
|
|
/// Primary template; the <32, 32> specialization follows.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8;

/// @brief 32x32x16 MFMA with f8 A and B operands accumulating into f32.
///
/// When `__gfx94__` is defined, the first float16_t slot of `reg_c` is updated through
/// __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8. Otherwise each of the eight packed values of both
/// operands is converted to float and accumulated with one
/// intrin_mfma_f32_32x32x2f32<32, 32>::Run call per value.
template <>
struct intrin_mfma_f32_32x32x16f8f8<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx94__)
        // The builtin takes each eight-value operand as a single `long`.
        const long a_bits = bit_cast<long>(reg_a);
        const long b_bits = bit_cast<long>(reg_b);

        const float16_t acc_in = reg_c.template AsType<float16_t>()[Number<0>{}];

        // The three trailing immediate arguments are left at 0.
        reg_c.template AsType<float16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(a_bits, b_bits, acc_in, 0, 0, 0);
#else
        vector_type<f8_t, 8> a_vec(reg_a);
        vector_type<f8_t, 8> b_vec(reg_b);

        // One f32 MFMA per packed element, all accumulating into reg_c.
        static_for<0, 8, 1>{}([&](auto idx) {
            float a_f32 = type_convert<float>(a_vec.template AsType<f8_t>()[Number<idx>{}]);
            float b_f32 = type_convert<float>(b_vec.template AsType<f8_t>()[Number<idx>{}]);

            intrin_mfma_f32_32x32x2f32<32, 32>::Run(a_f32, b_f32, reg_c);
        });
#endif
    }
};
|
|
|
|
/// Primary template; the <16, 16> specialization follows.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32f8f8;

/// @brief 16x16x32 MFMA with f8 A and B operands accumulating into f32.
///
/// When `__gfx94__` is defined, the first float4_t slot of `reg_c` is updated through
/// __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8. Otherwise each of the eight packed values of both
/// operands is converted to float and accumulated with one
/// intrin_mfma_f32_16x16x4f32<16, 16>::Run call per value.
template <>
struct intrin_mfma_f32_16x16x32f8f8<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx94__)
        // The builtin takes each eight-value operand as a single `long`.
        const long a_bits = bit_cast<long>(reg_a);
        const long b_bits = bit_cast<long>(reg_b);

        const float4_t acc_in = reg_c.template AsType<float4_t>()[Number<0>{}];

        // The three trailing immediate arguments are left at 0.
        reg_c.template AsType<float4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(a_bits, b_bits, acc_in, 0, 0, 0);
#else
        vector_type<f8_t, 8> a_vec(reg_a);
        vector_type<f8_t, 8> b_vec(reg_b);

        // One f32 MFMA per packed element, all accumulating into reg_c.
        static_for<0, 8, 1>{}([&](auto idx) {
            float a_f32 = type_convert<float>(a_vec.template AsType<f8_t>()[Number<idx>{}]);
            float b_f32 = type_convert<float>(b_vec.template AsType<f8_t>()[Number<idx>{}]);

            intrin_mfma_f32_16x16x4f32<16, 16>::Run(a_f32, b_f32, reg_c);
        });
#endif
    }
};
|
|
|
|
/// Primary template; the <32, 32> specialization follows.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf8bf8;

/// @brief 32x32x16 MFMA with bf8 A and B operands accumulating into f32.
///
/// When `__gfx94__` is defined, the first float16_t slot of `reg_c` is updated through
/// __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8. Otherwise each of the eight packed values of both
/// operands is converted to float and accumulated with one
/// intrin_mfma_f32_32x32x2f32<32, 32>::Run call per value.
template <>
struct intrin_mfma_f32_32x32x16bf8bf8<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx94__)
        // The builtin takes each eight-value operand as a single `long`.
        const long a_bits = bit_cast<long>(reg_a);
        const long b_bits = bit_cast<long>(reg_b);

        const float16_t acc_in = reg_c.template AsType<float16_t>()[Number<0>{}];

        // The three trailing immediate arguments are left at 0.
        reg_c.template AsType<float16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(a_bits, b_bits, acc_in, 0, 0, 0);
#else
        vector_type<bf8_t, 8> a_vec(reg_a);
        vector_type<bf8_t, 8> b_vec(reg_b);

        // One f32 MFMA per packed element, all accumulating into reg_c.
        static_for<0, 8, 1>{}([&](auto idx) {
            float a_f32 = type_convert<float>(a_vec.template AsType<bf8_t>()[Number<idx>{}]);
            float b_f32 = type_convert<float>(b_vec.template AsType<bf8_t>()[Number<idx>{}]);

            intrin_mfma_f32_32x32x2f32<32, 32>::Run(a_f32, b_f32, reg_c);
        });
#endif
    }
};
|
|
|
|
/// Primary template; the <16, 16> specialization follows.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32bf8bf8;

/// @brief 16x16x32 MFMA with bf8 A and B operands accumulating into f32.
///
/// When `__gfx94__` is defined, the first float4_t slot of `reg_c` is updated through
/// __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8. Otherwise each of the eight packed values of both
/// operands is converted to float and accumulated with one
/// intrin_mfma_f32_16x16x4f32<16, 16>::Run call per value.
template <>
struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx94__)
        // The builtin takes each eight-value operand as a single `long`.
        const long a_bits = bit_cast<long>(reg_a);
        const long b_bits = bit_cast<long>(reg_b);

        const float4_t acc_in = reg_c.template AsType<float4_t>()[Number<0>{}];

        // The three trailing immediate arguments are left at 0.
        reg_c.template AsType<float4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(a_bits, b_bits, acc_in, 0, 0, 0);
#else
        vector_type<bf8_t, 8> a_vec(reg_a);
        vector_type<bf8_t, 8> b_vec(reg_b);

        // One f32 MFMA per packed element, all accumulating into reg_c.
        static_for<0, 8, 1>{}([&](auto idx) {
            float a_f32 = type_convert<float>(a_vec.template AsType<bf8_t>()[Number<idx>{}]);
            float b_f32 = type_convert<float>(b_vec.template AsType<bf8_t>()[Number<idx>{}]);

            intrin_mfma_f32_16x16x4f32<16, 16>::Run(a_f32, b_f32, reg_c);
        });
#endif
    }
};
|
|
|
|
/// Primary template; the <32, 32> specialization follows.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8bf8;

/// @brief 32x32x16 MFMA with an f8 A operand and a bf8 B operand accumulating into f32.
///
/// When `__gfx94__` is defined, the first float16_t slot of `reg_c` is updated through
/// __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8. Otherwise each of the eight packed values of both
/// operands is converted to float and accumulated with one
/// intrin_mfma_f32_32x32x2f32<32, 32>::Run call per value.
template <>
struct intrin_mfma_f32_32x32x16f8bf8<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx94__)
        // The builtin takes each eight-value operand as a single `long`.
        const long a_bits = bit_cast<long>(reg_a);
        const long b_bits = bit_cast<long>(reg_b);

        const float16_t acc_in = reg_c.template AsType<float16_t>()[Number<0>{}];

        // The three trailing immediate arguments are left at 0.
        reg_c.template AsType<float16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(a_bits, b_bits, acc_in, 0, 0, 0);
#else
        vector_type<f8_t, 8> a_vec(reg_a);
        vector_type<bf8_t, 8> b_vec(reg_b);

        // One f32 MFMA per packed element, all accumulating into reg_c.
        static_for<0, 8, 1>{}([&](auto idx) {
            float a_f32 = type_convert<float>(a_vec.template AsType<f8_t>()[Number<idx>{}]);
            float b_f32 = type_convert<float>(b_vec.template AsType<bf8_t>()[Number<idx>{}]);

            intrin_mfma_f32_32x32x2f32<32, 32>::Run(a_f32, b_f32, reg_c);
        });
#endif
    }
};
|
|
|
|
/// Primary template; the <16, 16> specialization follows.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32f8bf8;

/// @brief 16x16x32 MFMA with an f8 A operand and a bf8 B operand accumulating into f32.
///
/// When `__gfx94__` is defined, the first float4_t slot of `reg_c` is updated through
/// __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8. Otherwise each of the eight packed values of both
/// operands is converted to float and accumulated with one
/// intrin_mfma_f32_16x16x4f32<16, 16>::Run call per value.
template <>
struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx94__)
        // The builtin takes each eight-value operand as a single `long`.
        const long a_bits = bit_cast<long>(reg_a);
        const long b_bits = bit_cast<long>(reg_b);

        const float4_t acc_in = reg_c.template AsType<float4_t>()[Number<0>{}];

        // The three trailing immediate arguments are left at 0.
        reg_c.template AsType<float4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(a_bits, b_bits, acc_in, 0, 0, 0);
#else
        vector_type<f8_t, 8> a_vec(reg_a);
        vector_type<bf8_t, 8> b_vec(reg_b);

        // One f32 MFMA per packed element, all accumulating into reg_c.
        static_for<0, 8, 1>{}([&](auto idx) {
            float a_f32 = type_convert<float>(a_vec.template AsType<f8_t>()[Number<idx>{}]);
            float b_f32 = type_convert<float>(b_vec.template AsType<bf8_t>()[Number<idx>{}]);

            intrin_mfma_f32_16x16x4f32<16, 16>::Run(a_f32, b_f32, reg_c);
        });
#endif
    }
};
|
|
|
|
/// Primary template; the <32, 32> specialization follows.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf8f8;

/// @brief 32x32x16 MFMA with a bf8 A operand and an f8 B operand accumulating into f32.
///
/// When `__gfx94__` is defined, the first float16_t slot of `reg_c` is updated through
/// __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8. Otherwise each of the eight packed values of both
/// operands is converted to float and accumulated with one
/// intrin_mfma_f32_32x32x2f32<32, 32>::Run call per value.
template <>
struct intrin_mfma_f32_32x32x16bf8f8<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx94__)
        // The builtin takes each eight-value operand as a single `long`.
        const long a_bits = bit_cast<long>(reg_a);
        const long b_bits = bit_cast<long>(reg_b);

        const float16_t acc_in = reg_c.template AsType<float16_t>()[Number<0>{}];

        // The three trailing immediate arguments are left at 0.
        reg_c.template AsType<float16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(a_bits, b_bits, acc_in, 0, 0, 0);
#else
        vector_type<bf8_t, 8> a_vec(reg_a);
        vector_type<f8_t, 8> b_vec(reg_b);

        // One f32 MFMA per packed element, all accumulating into reg_c.
        static_for<0, 8, 1>{}([&](auto idx) {
            float a_f32 = type_convert<float>(a_vec.template AsType<bf8_t>()[Number<idx>{}]);
            float b_f32 = type_convert<float>(b_vec.template AsType<f8_t>()[Number<idx>{}]);

            intrin_mfma_f32_32x32x2f32<32, 32>::Run(a_f32, b_f32, reg_c);
        });
#endif
    }
};
|
|
|
|
/// Primary template; the <16, 16> specialization follows.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32bf8f8;

/// @brief 16x16x32 MFMA with a bf8 A operand and an f8 B operand accumulating into f32.
///
/// When `__gfx94__` is defined, the first float4_t slot of `reg_c` is updated through
/// __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8. Otherwise each of the eight packed values of both
/// operands is converted to float and accumulated with one
/// intrin_mfma_f32_16x16x4f32<16, 16>::Run call per value.
template <>
struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx94__)
        // The builtin takes each eight-value operand as a single `long`.
        const long a_bits = bit_cast<long>(reg_a);
        const long b_bits = bit_cast<long>(reg_b);

        const float4_t acc_in = reg_c.template AsType<float4_t>()[Number<0>{}];

        // The three trailing immediate arguments are left at 0.
        reg_c.template AsType<float4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(a_bits, b_bits, acc_in, 0, 0, 0);
#else
        vector_type<bf8_t, 8> a_vec(reg_a);
        vector_type<f8_t, 8> b_vec(reg_b);

        // One f32 MFMA per packed element, all accumulating into reg_c.
        static_for<0, 8, 1>{}([&](auto idx) {
            float a_f32 = type_convert<float>(a_vec.template AsType<bf8_t>()[Number<idx>{}]);
            float b_f32 = type_convert<float>(b_vec.template AsType<f8_t>()[Number<idx>{}]);

            intrin_mfma_f32_16x16x4f32<16, 16>::Run(a_f32, b_f32, reg_c);
        });
#endif
    }
};
|
|
|
|
} // namespace ck
|