composable_kernel/include/ck/utility/amd_inline_asm.hpp
zjing14 1837040a9c Navi3 rel (#1176)
2024-03-08 17:11:51 -08:00

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_AMD_INLINE_ASM_HPP
#define CK_AMD_INLINE_ASM_HPP

#include "data_type.hpp"
#include "c_style_pointer_cast.hpp"

// TODO: deprecate all amd_assembly_outer_product_xxx

namespace ck {

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
{
    asm volatile("\n \
            v_fmac_f32 %0, %2, %3 \n \
            v_fmac_f32 %1, %2, %4 \n \
            "
                 : "=v"(c0), "=v"(c1)
                 : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(
    float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3)
{
    asm volatile("\n \
            v_fmac_f32 %0, %4, %5 \n \
            v_fmac_f32 %1, %4, %6 \n \
            v_fmac_f32 %2, %4, %7 \n \
            v_fmac_f32 %3, %4, %8 \n \
            "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
                 : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1)
{
    asm volatile("\n \
            v_dot2_f32_f16 %0, %2, %3, %0\n \
            v_dot2_f32_f16 %1, %2, %4, %1\n \
            "
                 : "=v"(c0), "=v"(c1)
                 : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
{
    // TODO remove pointer casting
    const half2_t* p_a_half2  = c_style_pointer_cast<const half2_t*>(&a);
    const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
    const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);

    // two dot2 steps per accumulator: lanes [0, 1] first, then lanes [2, 3]
    asm volatile("\n \
            v_dot2_f32_f16 %0, %2, %4, %0\n \
            v_dot2_f32_f16 %1, %2, %6, %1\n \
            v_dot2_f32_f16 %0, %3, %5, %0\n \
            v_dot2_f32_f16 %1, %3, %7, %1\n \
            "
                 : "=v"(c0), "=v"(c1)
                 : "v"(p_a_half2[0]),
                   "v"(p_a_half2[1]),
                   "v"(p_b0_half2[0]),
                   "v"(p_b0_half2[1]),
                   "v"(p_b1_half2[0]),
                   "v"(p_b1_half2[1]),
                   "0"(c0),
                   "1"(c1));
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(half2_t a,
                                               half2_t b0,
                                               half2_t b1,
                                               half2_t b2,
                                               half2_t b3,
                                               float& c0,
                                               float& c1,
                                               float& c2,
                                               float& c3)
{
    asm volatile("\n \
            v_dot2_f32_f16 %0, %4, %5, %0\n \
            v_dot2_f32_f16 %1, %4, %6, %1\n \
            v_dot2_f32_f16 %2, %4, %7, %2\n \
            v_dot2_f32_f16 %3, %4, %8, %3\n \
            "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
                 : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(half4_t a,
                                               half4_t b0,
                                               half4_t b1,
                                               half4_t b2,
                                               half4_t b3,
                                               float& c0,
                                               float& c1,
                                               float& c2,
                                               float& c3)
{
    // TODO remove pointer casting
    const half2_t* p_a_half2  = c_style_pointer_cast<const half2_t*>(&a);
    const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
    const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
    const half2_t* p_b2_half2 = c_style_pointer_cast<const half2_t*>(&b2);
    const half2_t* p_b3_half2 = c_style_pointer_cast<const half2_t*>(&b3);

    // two dot2 steps per accumulator: lanes [0, 1] first, then lanes [2, 3]
    asm volatile("\n \
            v_dot2_f32_f16 %0, %4, %6, %0\n \
            v_dot2_f32_f16 %1, %4, %8, %1\n \
            v_dot2_f32_f16 %2, %4, %10, %2\n \
            v_dot2_f32_f16 %3, %4, %12, %3\n \
            v_dot2_f32_f16 %0, %5, %7, %0\n \
            v_dot2_f32_f16 %1, %5, %9, %1\n \
            v_dot2_f32_f16 %2, %5, %11, %2\n \
            v_dot2_f32_f16 %3, %5, %13, %3\n \
            "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
                 : "v"(p_a_half2[0]),
                   "v"(p_a_half2[1]),
                   "v"(p_b0_half2[0]),
                   "v"(p_b0_half2[1]),
                   "v"(p_b1_half2[0]),
                   "v"(p_b1_half2[1]),
                   "v"(p_b2_half2[0]),
                   "v"(p_b2_half2[1]),
                   "v"(p_b3_half2[0]),
                   "v"(p_b3_half2[1]),
                   "0"(c0),
                   "1"(c1),
                   "2"(c2),
                   "3"(c3));
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
// computed as two half4_t calls, one per half of each vector
__device__ void amd_assembly_outer_product_1x4(half8_t a,
                                               half8_t b0,
                                               half8_t b1,
                                               half8_t b2,
                                               half8_t b3,
                                               float& c0,
                                               float& c1,
                                               float& c2,
                                               float& c3)
{
    // TODO remove pointer casting
    const half4_t* p_a_half4  = c_style_pointer_cast<const half4_t*>(&a);
    const half4_t* p_b0_half4 = c_style_pointer_cast<const half4_t*>(&b0);
    const half4_t* p_b1_half4 = c_style_pointer_cast<const half4_t*>(&b1);
    const half4_t* p_b2_half4 = c_style_pointer_cast<const half4_t*>(&b2);
    const half4_t* p_b3_half4 = c_style_pointer_cast<const half4_t*>(&b3);

    amd_assembly_outer_product_1x4(
        p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3);
    amd_assembly_outer_product_1x4(
        p_a_half4[1], p_b0_half4[1], p_b1_half4[1], p_b2_half4[1], p_b3_half4[1], c0, c1, c2, c3);
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
// computed as two half8_t calls, one per half of each vector
__device__ void amd_assembly_outer_product_1x4(half16_t a,
                                               half16_t b0,
                                               half16_t b1,
                                               half16_t b2,
                                               half16_t b3,
                                               float& c0,
                                               float& c1,
                                               float& c2,
                                               float& c3)
{
    // TODO remove pointer casting
    const half8_t* p_a_half8  = c_style_pointer_cast<const half8_t*>(&a);
    const half8_t* p_b0_half8 = c_style_pointer_cast<const half8_t*>(&b0);
    const half8_t* p_b1_half8 = c_style_pointer_cast<const half8_t*>(&b1);
    const half8_t* p_b2_half8 = c_style_pointer_cast<const half8_t*>(&b2);
    const half8_t* p_b3_half8 = c_style_pointer_cast<const half8_t*>(&b3);

    amd_assembly_outer_product_1x4(
        p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3);
    amd_assembly_outer_product_1x4(
        p_a_half8[1], p_b0_half8[1], p_b1_half8[1], p_b2_half8[1], p_b3_half8[1], c0, c1, c2, c3);
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1)
{
#if 1
    // inline asm path; the #else branch is an alternative using the
    // __builtin_amdgcn_sdot4 intrinsic with clamping disabled
    asm volatile("\n \
            v_dot4_i32_i8 %0, %2, %3, %0\n \
            v_dot4_i32_i8 %1, %2, %4, %1\n \
            "
                 : "=v"(c0), "=v"(c1)
                 : "v"(bit_cast<int32_t>(a)),
                   "v"(bit_cast<int32_t>(b0)),
                   "v"(bit_cast<int32_t>(b1)),
                   "0"(c0),
                   "1"(c1));
#else
    c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
    c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
#endif
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(int8x4_t a,
                                               int8x4_t b0,
                                               int8x4_t b1,
                                               int8x4_t b2,
                                               int8x4_t b3,
                                               int32_t& c0,
                                               int32_t& c1,
                                               int32_t& c2,
                                               int32_t& c3)
{
#if 1
    // inline asm path; the #else branch is an alternative using the
    // __builtin_amdgcn_sdot4 intrinsic with clamping disabled
    asm volatile("\n \
            v_dot4_i32_i8 %0, %4, %5, %0\n \
            v_dot4_i32_i8 %1, %4, %6, %1\n \
            v_dot4_i32_i8 %2, %4, %7, %2\n \
            v_dot4_i32_i8 %3, %4, %8, %3\n \
            "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
                 : "v"(bit_cast<int32_t>(a)),
                   "v"(bit_cast<int32_t>(b0)),
                   "v"(bit_cast<int32_t>(b1)),
                   "v"(bit_cast<int32_t>(b2)),
                   "v"(bit_cast<int32_t>(b3)),
                   "0"(c0),
                   "1"(c1),
                   "2"(c2),
                   "3"(c3));
#else
    c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
    c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
    c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
    c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
#endif
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
// computed as two int8x4_t calls, one per 4-lane chunk of each vector
__device__ void amd_assembly_outer_product_1x4(int8x8_t a,
                                               int8x8_t b0,
                                               int8x8_t b1,
                                               int8x8_t b2,
                                               int8x8_t b3,
                                               int32_t& c0,
                                               int32_t& c1,
                                               int32_t& c2,
                                               int32_t& c3)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I0],
                                   c0,
                                   c1,
                                   c2,
                                   c3);

    amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I1],
                                   c0,
                                   c1,
                                   c2,
                                   c3);
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
// computed as four int8x4_t calls, one per 4-lane chunk of each vector
__device__ void amd_assembly_outer_product_1x4(int8x16_t a,
                                               int8x16_t b0,
                                               int8x16_t b1,
                                               int8x16_t b2,
                                               int8x16_t b3,
                                               int32_t& c0,
                                               int32_t& c1,
                                               int32_t& c2,
                                               int32_t& c3)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I0],
                                   c0,
                                   c1,
                                   c2,
                                   c3);

    amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I1],
                                   c0,
                                   c1,
                                   c2,
                                   c3);

    amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
                                   vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I2],
                                   vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I2],
                                   vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I2],
                                   vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I2],
                                   c0,
                                   c1,
                                   c2,
                                   c3);

    amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
                                   vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I3],
                                   vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I3],
                                   vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I3],
                                   vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I3],
                                   c0,
                                   c1,
                                   c2,
                                   c3);
}

} // namespace ck

#endif
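To make the arithmetic of the overloads above easier to check off-device, here is a host-side C++ sketch of what each instruction computes and how the wider overloads decompose into narrower calls. It is a reference model under stated assumptions, not CK code: every helper name (`ref_fmac_f32`, `ref_dot2_f32_f16`, `ref_outer_product_1x2_half4`, `pack_i8x4`, `ref_dot4_i32_i8`, `ref_outer_product_1x4_i8`) is hypothetical. fp16 lanes are modeled as `float`, so hardware rounding and denormal handling are not reproduced. `std::memcpy` stands in for `ck::bit_cast` when packing four `int8_t` lanes into one `int32_t`.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical host-side reference models; not part of CK.

// v_fmac_f32 dst, src0, src1  =>  dst = src0 * src1 + dst
inline float ref_fmac_f32(float dst, float src0, float src1) { return src0 * src1 + dst; }

// v_dot2_f32_f16 dst, a, b, c  =>  dst = a[0] * b[0] + a[1] * b[1] + c
// Assumption: fp16 lanes modeled as float; hardware rounding is not modeled.
inline float ref_dot2_f32_f16(const float* a, const float* b, float c)
{
    return a[0] * b[0] + a[1] * b[1] + c;
}

// Mirrors the half4_t outer_product_1x2 overload: for each accumulator,
// one dot2 over lanes [0, 1], then one dot2 over lanes [2, 3].
inline void ref_outer_product_1x2_half4(const std::array<float, 4>& a,
                                        const std::array<float, 4>& b0,
                                        const std::array<float, 4>& b1,
                                        float& c0,
                                        float& c1)
{
    for(std::size_t k = 0; k < 4; k += 2)
    {
        c0 = ref_dot2_f32_f16(a.data() + k, b0.data() + k, c0);
        c1 = ref_dot2_f32_f16(a.data() + k, b1.data() + k, c1);
    }
}

// Reinterprets four int8 lanes as one int32 (stand-in for bit_cast<int32_t>(int8x4_t)).
inline int32_t pack_i8x4(const int8_t* lanes)
{
    int32_t v;
    std::memcpy(&v, lanes, sizeof(v));
    return v;
}

// v_dot4_i32_i8 dst, a, b, c  =>  dst = sum of four signed 8-bit lane products + c
inline int32_t ref_dot4_i32_i8(int32_t a, int32_t b, int32_t c)
{
    int8_t la[4];
    int8_t lb[4];
    std::memcpy(la, &a, sizeof(la));
    std::memcpy(lb, &b, sizeof(lb));
    for(int i = 0; i < 4; ++i)
        c += int32_t{la[i]} * int32_t{lb[i]};
    return c;
}

// Mirrors the int8x8_t / int8x16_t outer_product_1x4 overloads: one dot4 per
// 4-lane chunk, every chunk accumulating into the same four accumulators.
template <std::size_t N>
void ref_outer_product_1x4_i8(const std::array<int8_t, N>& a,
                              const std::array<std::array<int8_t, N>, 4>& b,
                              std::array<int32_t, 4>& c)
{
    static_assert(N % 4 == 0, "lane count must be a multiple of 4");
    for(std::size_t k = 0; k < N; k += 4)
        for(std::size_t j = 0; j < 4; ++j)
            c[j] = ref_dot4_i32_i8(pack_i8x4(a.data() + k), pack_i8x4(b[j].data() + k), c[j]);
}
```

Because every chunk feeds the same accumulators, splitting a long vector into dot2 or dot4 steps yields the full inner product. For the int8 path this holds exactly, provided the 32-bit accumulator does not overflow. For the fp16 path it holds up to rounding, since the intermediate sums are rounded at each step.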