mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
360 lines
15 KiB
C++
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_AMD_INLINE_ASM_HPP
#define CK_AMD_INLINE_ASM_HPP

#include "data_type.hpp"
#include "c_style_pointer_cast.hpp"

// TODO: deprecate all amd_assembly_outer_product_xxx

namespace ck {
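All of the routines below share one contract: given a vector `a` and N vectors `b0..b(N-1)`, accumulator `cj` receives the inner product of `a` with `bj`. A minimal host-side reference sketch (hypothetical names, not part of the library) captures that contract in portable C++:

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Reference semantics of amd_assembly_outer_product_1xN:
// c[j] += inner_product(a, b[j]) for j = 0..N-1, with K-wide vectors.
template <std::size_t K, std::size_t N>
void outer_product_reference(const std::array<float, K>& a,
                             const std::array<std::array<float, K>, N>& b,
                             std::array<float, N>& c)
{
    for(std::size_t j = 0; j < N; ++j)
        for(std::size_t k = 0; k < K; ++k)
            c[j] += a[k] * b[j][k];
}
```

With K = 1 this reduces to the scalar `v_fmac_f32` variants below; with K = 2 or 4 it matches the `v_dot2`/`v_dot4` variants (up to fp16/int8 rounding).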
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
{
    asm volatile("\n \
            v_fmac_f32 %0, %2, %3 \n \
            v_fmac_f32 %1, %2, %4 \n \
            "
                 : "=v"(c0), "=v"(c1)
                 : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(
    float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3)
{
    asm volatile("\n \
            v_fmac_f32 %0, %4, %5 \n \
            v_fmac_f32 %1, %4, %6 \n \
            v_fmac_f32 %2, %4, %7 \n \
            v_fmac_f32 %3, %4, %8 \n \
            "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
                 : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1)
{
    asm volatile("\n \
            v_dot2_f32_f16 %0, %2, %3, %0\n \
            v_dot2_f32_f16 %1, %2, %4, %1\n \
            "
                 : "=v"(c0), "=v"(c1)
                 : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
}
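Each `v_dot2_f32_f16` instruction folds a 2-wide fp16 dot product into an f32 accumulator: `c += a.x * b.x + a.y * b.y`. A host-side sketch of the two-instruction sequence above (hypothetical `dot2_acc` helper; fp16 lanes modeled as `float` pairs, so rounding differs from real hardware):

```cpp
#include <cassert>
#include <utility>

// Model a half2_t as a pair of floats for host-side illustration.
using h2 = std::pair<float, float>;

// Host model of v_dot2_f32_f16 dst, a, b, dst: c += a.x*b.x + a.y*b.y
inline void dot2_acc(const h2& a, const h2& b, float& c)
{
    c += a.first * b.first + a.second * b.second;
}

inline void
outer_product_1x2_reference(const h2& a, const h2& b0, const h2& b1, float& c0, float& c1)
{
    dot2_acc(a, b0, c0); // mirrors: v_dot2_f32_f16 %0, %2, %3, %0
    dot2_acc(a, b1, c1); // mirrors: v_dot2_f32_f16 %1, %2, %4, %1
}
```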
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
{
    // TODO remove pointer casting
    const half2_t* p_a_half2  = c_style_pointer_cast<const half2_t*>(&a);
    const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
    const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);

    // issue v_dot2 twice per accumulator: low halves first, then high halves
    asm volatile("\n \
            v_dot2_f32_f16 %0, %2, %4, %0\n \
            v_dot2_f32_f16 %1, %2, %6, %1\n \
            v_dot2_f32_f16 %0, %3, %5, %0\n \
            v_dot2_f32_f16 %1, %3, %7, %1\n \
            "
                 : "=v"(c0), "=v"(c1)
                 : "v"(p_a_half2[0]),
                   "v"(p_a_half2[1]),
                   "v"(p_b0_half2[0]),
                   "v"(p_b0_half2[1]),
                   "v"(p_b1_half2[0]),
                   "v"(p_b1_half2[1]),
                   "0"(c0),
                   "1"(c1));
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(half2_t a,
                                               half2_t b0,
                                               half2_t b1,
                                               half2_t b2,
                                               half2_t b3,
                                               float& c0,
                                               float& c1,
                                               float& c2,
                                               float& c3)
{
    asm volatile("\n \
            v_dot2_f32_f16 %0, %4, %5, %0\n \
            v_dot2_f32_f16 %1, %4, %6, %1\n \
            v_dot2_f32_f16 %2, %4, %7, %2\n \
            v_dot2_f32_f16 %3, %4, %8, %3\n \
            "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
                 : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(half4_t a,
                                               half4_t b0,
                                               half4_t b1,
                                               half4_t b2,
                                               half4_t b3,
                                               float& c0,
                                               float& c1,
                                               float& c2,
                                               float& c3)
{
    // TODO remove pointer casting
    const half2_t* p_a_half2  = c_style_pointer_cast<const half2_t*>(&a);
    const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
    const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
    const half2_t* p_b2_half2 = c_style_pointer_cast<const half2_t*>(&b2);
    const half2_t* p_b3_half2 = c_style_pointer_cast<const half2_t*>(&b3);

    // issue v_dot2 twice per accumulator: low halves first, then high halves
    asm volatile("\n \
            v_dot2_f32_f16 %0, %4, %6, %0\n \
            v_dot2_f32_f16 %1, %4, %8, %1\n \
            v_dot2_f32_f16 %2, %4, %10, %2\n \
            v_dot2_f32_f16 %3, %4, %12, %3\n \
            v_dot2_f32_f16 %0, %5, %7, %0\n \
            v_dot2_f32_f16 %1, %5, %9, %1\n \
            v_dot2_f32_f16 %2, %5, %11, %2\n \
            v_dot2_f32_f16 %3, %5, %13, %3\n \
            "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
                 : "v"(p_a_half2[0]),
                   "v"(p_a_half2[1]),
                   "v"(p_b0_half2[0]),
                   "v"(p_b0_half2[1]),
                   "v"(p_b1_half2[0]),
                   "v"(p_b1_half2[1]),
                   "v"(p_b2_half2[0]),
                   "v"(p_b2_half2[1]),
                   "v"(p_b3_half2[0]),
                   "v"(p_b3_half2[1]),
                   "0"(c0),
                   "1"(c1),
                   "2"(c2),
                   "3"(c3));
}
__device__ void amd_assembly_outer_product_1x4(half8_t a,
                                               half8_t b0,
                                               half8_t b1,
                                               half8_t b2,
                                               half8_t b3,
                                               float& c0,
                                               float& c1,
                                               float& c2,
                                               float& c3)
{
    // TODO remove pointer casting
    const half4_t* p_a_half4  = c_style_pointer_cast<const half4_t*>(&a);
    const half4_t* p_b0_half4 = c_style_pointer_cast<const half4_t*>(&b0);
    const half4_t* p_b1_half4 = c_style_pointer_cast<const half4_t*>(&b1);
    const half4_t* p_b2_half4 = c_style_pointer_cast<const half4_t*>(&b2);
    const half4_t* p_b3_half4 = c_style_pointer_cast<const half4_t*>(&b3);

    amd_assembly_outer_product_1x4(
        p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3);

    amd_assembly_outer_product_1x4(
        p_a_half4[1], p_b0_half4[1], p_b1_half4[1], p_b2_half4[1], p_b3_half4[1], c0, c1, c2, c3);
}
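The wide overloads above reuse the narrower ones by viewing each operand as two halves and accumulating both; since the inner product is a plain sum over lanes, splitting the vectors and accumulating each half gives the same result. A generic host-side sketch of that split-and-accumulate pattern (hypothetical names, plain `float` lanes):

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Flat reference: c += inner_product(a, b) over K lanes.
template <std::size_t K>
void dot_acc(const std::array<float, K>& a, const std::array<float, K>& b, float& c)
{
    for(std::size_t k = 0; k < K; ++k)
        c += a[k] * b[k];
}

// Same result via two half-width accumulations, like the half8_t overload
// dispatching to the half4_t overload via p_x_half4[0] / p_x_half4[1].
template <std::size_t K>
void dot_acc_split(const std::array<float, K>& a, const std::array<float, K>& b, float& c)
{
    static_assert(K % 2 == 0, "split assumes an even lane count");
    std::array<float, K / 2> a_lo, a_hi, b_lo, b_hi;
    for(std::size_t k = 0; k < K / 2; ++k)
    {
        a_lo[k] = a[k];
        a_hi[k] = a[k + K / 2];
        b_lo[k] = b[k];
        b_hi[k] = b[k + K / 2];
    }
    dot_acc(a_lo, b_lo, c); // lower half
    dot_acc(a_hi, b_hi, c); // upper half
}
```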
__device__ void amd_assembly_outer_product_1x4(half16_t a,
                                               half16_t b0,
                                               half16_t b1,
                                               half16_t b2,
                                               half16_t b3,
                                               float& c0,
                                               float& c1,
                                               float& c2,
                                               float& c3)
{
    // TODO remove pointer casting
    const half8_t* p_a_half8  = c_style_pointer_cast<const half8_t*>(&a);
    const half8_t* p_b0_half8 = c_style_pointer_cast<const half8_t*>(&b0);
    const half8_t* p_b1_half8 = c_style_pointer_cast<const half8_t*>(&b1);
    const half8_t* p_b2_half8 = c_style_pointer_cast<const half8_t*>(&b2);
    const half8_t* p_b3_half8 = c_style_pointer_cast<const half8_t*>(&b3);

    amd_assembly_outer_product_1x4(
        p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3);

    amd_assembly_outer_product_1x4(
        p_a_half8[1], p_b0_half8[1], p_b1_half8[1], p_b2_half8[1], p_b3_half8[1], c0, c1, c2, c3);
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1)
{
#if 1
    asm volatile("\n \
            v_dot4_i32_i8 %0, %2, %3, %0\n \
            v_dot4_i32_i8 %1, %2, %4, %1\n \
            "
                 : "=v"(c0), "=v"(c1)
                 : "v"(bit_cast<int32_t>(a)),
                   "v"(bit_cast<int32_t>(b0)),
                   "v"(bit_cast<int32_t>(b1)),
                   "0"(c0),
                   "1"(c1));
#else
    c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
    c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
#endif
}
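`v_dot4_i32_i8` (and the `__builtin_amdgcn_sdot4` fallback, whose final `false` argument disables clamping) accumulates the dot product of four signed bytes packed into each 32-bit operand. A host model of that semantics (hypothetical `sdot4_reference`; no saturation handling, matching the clamp-off case):

```cpp
#include <cassert>
#include <cstdint>

// Host model of v_dot4_i32_i8 dst, a, b, dst with clamp off:
// c += sum over i of (signed byte i of a) * (signed byte i of b)
inline int32_t sdot4_reference(int32_t a, int32_t b, int32_t c)
{
    for(int i = 0; i < 4; ++i)
    {
        // extract byte i with unsigned shifts, then reinterpret as signed
        const auto ai = static_cast<int8_t>((static_cast<uint32_t>(a) >> (8 * i)) & 0xFF);
        const auto bi = static_cast<int8_t>((static_cast<uint32_t>(b) >> (8 * i)) & 0xFF);
        c += static_cast<int32_t>(ai) * static_cast<int32_t>(bi);
    }
    return c;
}
```

For example, with both operands holding bytes {1, 2, 3, 4} (i.e. `0x04030201`) the lane products sum to 1 + 4 + 9 + 16 = 30, added on top of the incoming accumulator.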
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(int8x4_t a,
                                               int8x4_t b0,
                                               int8x4_t b1,
                                               int8x4_t b2,
                                               int8x4_t b3,
                                               int32_t& c0,
                                               int32_t& c1,
                                               int32_t& c2,
                                               int32_t& c3)
{
#if 1
    asm volatile("\n \
            v_dot4_i32_i8 %0, %4, %5, %0\n \
            v_dot4_i32_i8 %1, %4, %6, %1\n \
            v_dot4_i32_i8 %2, %4, %7, %2\n \
            v_dot4_i32_i8 %3, %4, %8, %3\n \
            "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
                 : "v"(bit_cast<int32_t>(a)),
                   "v"(bit_cast<int32_t>(b0)),
                   "v"(bit_cast<int32_t>(b1)),
                   "v"(bit_cast<int32_t>(b2)),
                   "v"(bit_cast<int32_t>(b3)),
                   "0"(c0),
                   "1"(c1),
                   "2"(c2),
                   "3"(c3));
#else
    c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
    c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
    c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
    c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
#endif
}
__device__ void amd_assembly_outer_product_1x4(int8x8_t a,
                                               int8x8_t b0,
                                               int8x8_t b1,
                                               int8x8_t b2,
                                               int8x8_t b3,
                                               int32_t& c0,
                                               int32_t& c1,
                                               int32_t& c2,
                                               int32_t& c3)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I0],
                                   c0,
                                   c1,
                                   c2,
                                   c3);

    amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I1],
                                   c0,
                                   c1,
                                   c2,
                                   c3);
}
__device__ void amd_assembly_outer_product_1x4(int8x16_t a,
                                               int8x16_t b0,
                                               int8x16_t b1,
                                               int8x16_t b2,
                                               int8x16_t b3,
                                               int32_t& c0,
                                               int32_t& c1,
                                               int32_t& c2,
                                               int32_t& c3)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I0],
                                   c0,
                                   c1,
                                   c2,
                                   c3);

    amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I1],
                                   c0,
                                   c1,
                                   c2,
                                   c3);

    amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
                                   vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I2],
                                   vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I2],
                                   vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I2],
                                   vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I2],
                                   c0,
                                   c1,
                                   c2,
                                   c3);

    amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
                                   vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I3],
                                   vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I3],
                                   vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I3],
                                   vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I3],
                                   c0,
                                   c1,
                                   c2,
                                   c3);
}
} // namespace ck

#endif // CK_AMD_INLINE_ASM_HPP