// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "data_type.hpp"

namespace ck {

template <typename T>
union BufferResource
{
    __device__ constexpr BufferResource() : content{} {}

    // 128 bit SGPRs to supply buffer resource in buffer instructions
    // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
    int32x4_t content;
    StaticallyIndexedArray<T*, 2> address;
    StaticallyIndexedArray<int32_t, 4> range;
    StaticallyIndexedArray<int32_t, 4> config;
};

template <typename T>
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size)
{
    BufferResource<T> wave_buffer_resource;

    // wavewise base address (64 bit)
    wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
    // wavewise range (32 bit)
    wave_buffer_resource.range(Number<2>{}) = element_space_size * sizeof(T);
    // wavewise setting (32 bit)
    wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;

    return wave_buffer_resource.content;
}

template <typename T>
__device__ int32x4_t make_wave_buffer_resource_with_default_range(T* p_wave)
{
    BufferResource<T> wave_buffer_resource;

    // wavewise base address (64 bit)
    wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
    // wavewise range (32 bit)
    wave_buffer_resource.range(Number<2>{}) = 0xffffffff; // max possible range
    // wavewise setting (32 bit)
    wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;

    return wave_buffer_resource.content;
}
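
// Usage sketch (illustrative only; `p_src`, `n`, and `thread_byte_offset` are
// hypothetical names, not part of this header): build the 128-bit descriptor
// for a wavewise pointer and feed it to the raw buffer intrinsics declared
// below. voffset is the per-thread byte offset, soffset the wavewise one.
//
//   const float* p_src   = ...; // wavewise base pointer into global memory
//   const index_t n      = ...; // number of valid elements behind p_src
//   const int32x4_t rsrc = make_wave_buffer_resource(p_src, n);
//   float v = llvm_amdgcn_raw_buffer_load_fp32(rsrc, thread_byte_offset, 0, 0);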

// buffer load i8
__device__ int8_t
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
                               index_t voffset,
                               index_t soffset,
                               index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");

__device__ int8x2_t
llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc,
                                 index_t voffset,
                                 index_t soffset,
                                 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8");

__device__ int8x4_t
llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
                                 index_t voffset,
                                 index_t soffset,
                                 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");

// buffer load i16
__device__ bhalf_t
llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
                                index_t voffset,
                                index_t soffset,
                                index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16");

__device__ bhalf2_t
llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc,
                                  index_t voffset,
                                  index_t soffset,
                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16");

__device__ bhalf4_t
llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc,
                                  index_t voffset,
                                  index_t soffset,
                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16");

// buffer load i32
__device__ int32_t
llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
                                index_t voffset,
                                index_t soffset,
                                index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");

__device__ int32x2_t
llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
                                  index_t voffset,
                                  index_t soffset,
                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");

__device__ int32x4_t
llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
                                  index_t voffset,
                                  index_t soffset,
                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");

// buffer load fp16
__device__ half_t
llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc,
                                 index_t voffset,
                                 index_t soffset,
                                 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");

__device__ half2_t
llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc,
                                   index_t voffset,
                                   index_t soffset,
                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16");

__device__ half4_t
llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc,
                                   index_t voffset,
                                   index_t soffset,
                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16");

// buffer load fp32
__device__ float
llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
                                 index_t voffset,
                                 index_t soffset,
                                 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32");

__device__ float2_t
llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc,
                                   index_t voffset,
                                   index_t soffset,
                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32");

__device__ float4_t
llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc,
                                   index_t voffset,
                                   index_t soffset,
                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");

// buffer store i8
__device__ void
llvm_amdgcn_raw_buffer_store_i8(int8_t vdata,
                                int32x4_t rsrc,
                                index_t voffset,
                                index_t soffset,
                                index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");

__device__ void
llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata,
                                  int32x4_t rsrc,
                                  index_t voffset,
                                  index_t soffset,
                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8");

__device__ void
llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
                                  int32x4_t rsrc,
                                  index_t voffset,
                                  index_t soffset,
                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");

// buffer store i16
__device__ void
llvm_amdgcn_raw_buffer_store_i16(bhalf_t vdata,
                                 int32x4_t rsrc,
                                 index_t voffset,
                                 index_t soffset,
                                 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");

__device__ void
llvm_amdgcn_raw_buffer_store_i16x2(bhalf2_t vdata,
                                   int32x4_t rsrc,
                                   index_t voffset,
                                   index_t soffset,
                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");

__device__ void
llvm_amdgcn_raw_buffer_store_i16x4(bhalf4_t vdata,
                                   int32x4_t rsrc,
                                   index_t voffset,
                                   index_t soffset,
                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");

// buffer store i32
__device__ void
llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
                                 int32x4_t rsrc,
                                 index_t voffset,
                                 index_t soffset,
                                 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");

__device__ void
llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
                                   int32x4_t rsrc,
                                   index_t voffset,
                                   index_t soffset,
                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");

__device__ void
llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
                                   int32x4_t rsrc,
                                   index_t voffset,
                                   index_t soffset,
                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");

// buffer store fp16
__device__ void
llvm_amdgcn_raw_buffer_store_fp16(half_t vdata,
                                  int32x4_t rsrc,
                                  index_t voffset,
                                  index_t soffset,
                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");

__device__ void
llvm_amdgcn_raw_buffer_store_fp16x2(half2_t vdata,
                                    int32x4_t rsrc,
                                    index_t voffset,
                                    index_t soffset,
                                    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16");

__device__ void
llvm_amdgcn_raw_buffer_store_fp16x4(half4_t vdata,
                                    int32x4_t rsrc,
                                    index_t voffset,
                                    index_t soffset,
                                    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");

// buffer store fp32
__device__ void
llvm_amdgcn_raw_buffer_store_fp32(float vdata,
                                  int32x4_t rsrc,
                                  index_t voffset,
                                  index_t soffset,
                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32");

__device__ void
llvm_amdgcn_raw_buffer_store_fp32x2(float2_t vdata,
                                    int32x4_t rsrc,
                                    index_t voffset,
                                    index_t soffset,
                                    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32");

__device__ void
llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
                                    int32x4_t rsrc,
                                    index_t voffset,
                                    index_t soffset,
                                    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");

// buffer atomic-add fp16
__device__ half2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
    half2_t vdata,
    int32x4_t rsrc,
    index_t voffset,
    index_t soffset,
    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");

// buffer atomic-add i32
__device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
    int32_t vdata,
    int32x4_t rsrc,
    index_t voffset,
    index_t soffset,
    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32");

// buffer atomic-add fp32
__device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
    float vdata,
    int32x4_t rsrc,
    index_t voffset,
    index_t soffset,
    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");

// buffer atomic-max fp64
__device__ double
llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
                                       int32x4_t rsrc, // dst_wave_buffer_resource
                                       int voffset,    // dst_thread_addr_offset
                                       int soffset,    // dst_wave_addr_offset
                                       int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");

// memory coherency bit for buffer store/load instruction
// check ISA manual for each GFX target
// e.g. for
// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf,
// page 67~68
enum struct AmdBufferCoherenceEnum
{
    DefaultCoherence = 0, // default value
    GLC              = 1,
    SLC              = 2,
    GLC_SLC          = 3,
    // gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
    // SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
    // NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
    WAVE_NT0   = 0,
    WAVE_NT1   = 2,
    GROUP_NT0  = 1,
    GROUP_NT1  = 3,
    DEVICE_NT0 = 8,
    DEVICE_NT1 = 10,
    SYSTEM_NT0 = 9,
    SYSTEM_NT1 = 11,
};
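
// Usage sketch (illustrative; `rsrc` and `offset` are hypothetical): the
// coherence enum is passed as a template argument and forwarded verbatim as
// the glc_slc operand of the raw buffer intrinsics. For example, a
// non-temporal device-level load on gfx94 could be issued as:
//
//   auto v = amd_buffer_load_impl<float, 4, AmdBufferCoherenceEnum::DEVICE_NT1>(
//       rsrc, offset, 0);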

template <index_t N, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ typename vector_type<int8_t, N>::type
amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource,
                         index_t src_thread_addr_offset,
                         index_t src_wave_addr_offset)
{
    static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
                  "wrong! not implemented");

    if constexpr(N == 1)
    {
        return llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource,
                                              src_thread_addr_offset,
                                              src_wave_addr_offset,
                                              static_cast<index_t>(coherence));
    }
    else if constexpr(N == 2)
    {
        int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
                                                      src_thread_addr_offset,
                                                      src_wave_addr_offset,
                                                      static_cast<index_t>(coherence));

        return bit_cast<int8x2_t>(tmp);
    }
    else if constexpr(N == 4)
    {
        int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource,
                                                      src_thread_addr_offset,
                                                      src_wave_addr_offset,
                                                      static_cast<index_t>(coherence));

        return bit_cast<int8x4_t>(tmp);
    }
    else if constexpr(N == 8)
    {
        int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource,
                                                          src_thread_addr_offset,
                                                          src_wave_addr_offset,
                                                          static_cast<index_t>(coherence));

        return bit_cast<int8x8_t>(tmp);
    }
    else if constexpr(N == 16)
    {
        int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                                          src_thread_addr_offset,
                                                          src_wave_addr_offset,
                                                          static_cast<index_t>(coherence));
        return bit_cast<int8x16_t>(tmp);
    }
    else if constexpr(N == 32)
    {
        int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                                           src_thread_addr_offset,
                                                           src_wave_addr_offset,
                                                           static_cast<index_t>(coherence));
        int32x4_t tmp1 =
            llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                              src_thread_addr_offset,
                                              src_wave_addr_offset + 4 * sizeof(int32_t),
                                              static_cast<index_t>(coherence));
        vector_type<int32_t, 8> tmp;

        tmp.AsType<int32x4_t>()(Number<0>{}) = tmp0;
        tmp.AsType<int32x4_t>()(Number<1>{}) = tmp1;

        return bit_cast<int8x32_t>(tmp);
    }
    else if constexpr(N == 64)
    {
        int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                                           src_thread_addr_offset,
                                                           src_wave_addr_offset,
                                                           static_cast<index_t>(coherence));
        int32x4_t tmp1 =
            llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                              src_thread_addr_offset,
                                              src_wave_addr_offset + 4 * sizeof(int32_t),
                                              static_cast<index_t>(coherence));
        int32x4_t tmp2 =
            llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                              src_thread_addr_offset,
                                              src_wave_addr_offset + 8 * sizeof(int32_t),
                                              static_cast<index_t>(coherence));
        int32x4_t tmp3 =
            llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                              src_thread_addr_offset,
                                              src_wave_addr_offset + 12 * sizeof(int32_t),
                                              static_cast<index_t>(coherence));

        vector_type<int32_t, 16> tmp;

        tmp.AsType<int32x4_t>()(Number<0>{}) = tmp0;
        tmp.AsType<int32x4_t>()(Number<1>{}) = tmp1;
        tmp.AsType<int32x4_t>()(Number<2>{}) = tmp2;
        tmp.AsType<int32x4_t>()(Number<3>{}) = tmp3;

        return bit_cast<int8x64_t>(tmp);
    }
}

template <typename T,
          index_t N,
          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
                                                                 index_t src_thread_addr_offset,
                                                                 index_t src_wave_addr_offset)
{
    static_assert(
        (is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
            (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, pk_i4_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
        "wrong! not implemented");

    using r_t     = typename vector_type<T, N>::type;
    auto raw_data = amd_buffer_load_impl_raw<sizeof(T) * N, coherence>(
        src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset);
    return bit_cast<r_t>(raw_data);
}
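
// Usage sketch (illustrative; `rsrc` and `thread_byte_offset` are
// hypothetical): the typed wrapper forwards to the byte-wise loader, so
// loading 8 contiguous half_t values (16 bytes, one dwordx4) looks like:
//
//   vector_type<half_t, 8>::type v =
//       amd_buffer_load_impl<half_t, 8>(rsrc, thread_byte_offset, 0);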

template <index_t N, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ void
amd_buffer_store_impl_raw(const typename vector_type<int8_t, N>::type src_thread_data,
                          int32x4_t dst_wave_buffer_resource,
                          index_t dst_thread_addr_offset,
                          index_t dst_wave_addr_offset)
{
    static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
                  "wrong! not implemented");

    if constexpr(N == 1)
    {
        llvm_amdgcn_raw_buffer_store_i8(src_thread_data,
                                        dst_wave_buffer_resource,
                                        dst_thread_addr_offset,
                                        dst_wave_addr_offset,
                                        static_cast<index_t>(coherence));
    }
    else if constexpr(N == 2)
    {
        llvm_amdgcn_raw_buffer_store_i16(bit_cast<int16_t>(src_thread_data),
                                         dst_wave_buffer_resource,
                                         dst_thread_addr_offset,
                                         dst_wave_addr_offset,
                                         static_cast<index_t>(coherence));
    }
    else if constexpr(N == 4)
    {
        llvm_amdgcn_raw_buffer_store_i32(bit_cast<int32_t>(src_thread_data),
                                         dst_wave_buffer_resource,
                                         dst_thread_addr_offset,
                                         dst_wave_addr_offset,
                                         static_cast<index_t>(coherence));
    }
    else if constexpr(N == 8)
    {
        llvm_amdgcn_raw_buffer_store_i32x2(bit_cast<int32x2_t>(src_thread_data),
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset,
                                           static_cast<index_t>(coherence));
    }
    else if constexpr(N == 16)
    {
        llvm_amdgcn_raw_buffer_store_i32x4(bit_cast<int32x4_t>(src_thread_data),
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset,
                                           static_cast<index_t>(coherence));
    }
    else if constexpr(N == 32)
    {
        vector_type<int32_t, 8> tmp{bit_cast<int32x8_t>(src_thread_data)};

        llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<0>{}],
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset,
                                           static_cast<index_t>(coherence));

        llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<1>{}],
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset + sizeof(int32_t) * 4,
                                           static_cast<index_t>(coherence));
    }
    else if constexpr(N == 64)
    {
        vector_type<int32_t, 16> tmp{bit_cast<int32x16_t>(src_thread_data)};

        llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<0>{}],
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset,
                                           static_cast<index_t>(coherence));

        llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<1>{}],
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset + sizeof(int32_t) * 4,
                                           static_cast<index_t>(coherence));

        llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<2>{}],
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset + sizeof(int32_t) * 8,
                                           static_cast<index_t>(coherence));

        llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<3>{}],
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset + sizeof(int32_t) * 12,
                                           static_cast<index_t>(coherence));
    }
}

template <typename T,
          index_t N,
          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
                                      int32x4_t dst_wave_buffer_resource,
                                      index_t dst_thread_addr_offset,
                                      index_t dst_wave_addr_offset)
{
    static_assert(
        (is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
            (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, f8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, bf8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, fp8_storage_t>::value &&
             (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
        "wrong! not implemented");

    using r_t = typename vector_type<int8_t, sizeof(T) * N>::type;

    amd_buffer_store_impl_raw<sizeof(T) * N, coherence>(bit_cast<r_t>(src_thread_data),
                                                        dst_wave_buffer_resource,
                                                        dst_thread_addr_offset,
                                                        dst_wave_addr_offset);
}
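
// Usage sketch (illustrative; `rsrc`, `data`, and `thread_byte_offset` are
// hypothetical): store 4 floats (one dwordx4) from a thread's registers:
//
//   vector_type<float, 4>::type data = ...;
//   amd_buffer_store_impl<float, 4>(data, rsrc, thread_byte_offset, 0);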

template <typename T, index_t N>
__device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
                                           T* addr)
{
    static_assert((is_same<T, bhalf_t>::value && (N == 2 || N == 4 || N == 8)) ||
                      (is_same<T, half_t>::value && (N == 2 || N == 4 || N == 8)),
                  "wrong! not implemented");

    if constexpr(is_same<T, half_t>::value)
    {
        vector_type<half_t, N> tmp{src_thread_data};
        static_for<0, N / 2, 1>{}([&](auto i) {
            __builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast<half2_t*>(addr) + i,
                                                      tmp.template AsType<half2_t>()[i]);
        });
    }
#if defined(__gfx942__) || defined(__gfx950__)
    else if constexpr(is_same<T, bhalf_t>::value)
    {
        vector_type<bhalf_t, N> tmp{src_thread_data};
        static_for<0, N / 2, 1>{}([&](auto i) {
            __builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast<bhalf2_t*>(addr) + i,
                                                       tmp.template AsType<bhalf2_t>()[i]);
        });
    }
#endif
}

template <typename T, index_t N>
__device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
                                           int32x4_t dst_wave_buffer_resource,
                                           index_t dst_thread_addr_offset,
                                           index_t dst_wave_addr_offset)
{
    static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
                      (is_same<T, half_t>::value && (N == 2 || N == 4 || N == 8)) ||
                      (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
                  "wrong! not implemented");

    if constexpr(is_same<T, float>::value)
    {
        if constexpr(N == 1)
        {
            llvm_amdgcn_raw_buffer_atomic_add_fp32(src_thread_data,
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset,
                                                   0);
        }
        else if constexpr(N == 2)
        {
            vector_type<float, 2> tmp{src_thread_data};

            llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<0>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset,
                                                   0);

            llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<1>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset + sizeof(float),
                                                   0);
        }
        else if constexpr(N == 4)
        {
            vector_type<float, 4> tmp{src_thread_data};

            llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<0>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset,
                                                   0);

            llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<1>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset + sizeof(float),
                                                   0);

            llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<2>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset + 2 * sizeof(float),
                                                   0);

            llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<3>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset + 3 * sizeof(float),
                                                   0);
        }
    }
    else if constexpr(is_same<T, half_t>::value)
    {
        if constexpr(N == 2)
        {
            llvm_amdgcn_raw_buffer_atomic_add_fp16x2(src_thread_data,
                                                     dst_wave_buffer_resource,
                                                     dst_thread_addr_offset,
                                                     dst_wave_addr_offset,
                                                     0);
        }
        else if constexpr(N == 4)
        {
            vector_type<half_t, 4> tmp{src_thread_data};

            static_for<0, 2, 1>{}([&](auto i) {
                llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType<half2_t>()[i],
                                                         dst_wave_buffer_resource,
                                                         dst_thread_addr_offset,
                                                         dst_wave_addr_offset + i * sizeof(half2_t),
                                                         0);
            });
        }
        else if constexpr(N == 8)
        {
            vector_type<half_t, 8> tmp{src_thread_data};

            static_for<0, 4, 1>{}([&](auto i) {
                llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType<half2_t>()[i],
                                                         dst_wave_buffer_resource,
                                                         dst_thread_addr_offset,
                                                         dst_wave_addr_offset + i * sizeof(half2_t),
                                                         0);
            });
        }
    }
    else if constexpr(is_same<T, int32_t>::value)
    {
        if constexpr(N == 1)
        {
            llvm_amdgcn_raw_buffer_atomic_add_i32(src_thread_data,
                                                  dst_wave_buffer_resource,
                                                  dst_thread_addr_offset,
                                                  dst_wave_addr_offset,
                                                  0);
        }
        else if constexpr(N == 2)
        {
            vector_type<int32_t, 2> tmp{src_thread_data};

            llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<0>{}],
                                                  dst_wave_buffer_resource,
                                                  dst_thread_addr_offset,
                                                  dst_wave_addr_offset,
                                                  0);

            llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<1>{}],
                                                  dst_wave_buffer_resource,
                                                  dst_thread_addr_offset,
                                                  dst_wave_addr_offset + sizeof(int32_t),
                                                  0);
        }
        else if constexpr(N == 4)
        {
            vector_type<int32_t, 4> tmp{src_thread_data};

            llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<0>{}],
                                                  dst_wave_buffer_resource,
                                                  dst_thread_addr_offset,
                                                  dst_wave_addr_offset,
                                                  0);

            llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<1>{}],
                                                  dst_wave_buffer_resource,
                                                  dst_thread_addr_offset,
                                                  dst_wave_addr_offset + sizeof(int32_t),
                                                  0);

            llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<2>{}],
                                                  dst_wave_buffer_resource,
                                                  dst_thread_addr_offset,
                                                  dst_wave_addr_offset + 2 * sizeof(int32_t),
                                                  0);

            llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<3>{}],
                                                  dst_wave_buffer_resource,
                                                  dst_thread_addr_offset,
                                                  dst_wave_addr_offset + 3 * sizeof(int32_t),
                                                  0);
        }
    }
}
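
// Usage sketch (illustrative; `rsrc`, `vals`, and `thread_byte_offset` are
// hypothetical): atomically accumulate two floats into global memory through
// the buffer descriptor; wider vectors are decomposed into scalar fadds above.
//
//   vector_type<float, 2>::type vals = ...;
//   amd_buffer_atomic_add_impl<float, 2>(vals, rsrc, thread_byte_offset, 0);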

template <typename T, index_t N>
__device__ void amd_buffer_atomic_max_impl(const typename vector_type<T, N>::type src_thread_data,
                                           int32x4_t dst_wave_buffer_resource,
                                           index_t dst_thread_addr_offset,
                                           index_t dst_wave_addr_offset)
{
    static_assert((is_same<T, double>::value && (N == 1 || N == 2 || N == 4)),
                  "wrong! not implemented");
    if constexpr(is_same<T, double>::value)
    {
        if constexpr(N == 1)
        {
            llvm_amdgcn_raw_buffer_atomic_max_fp64(src_thread_data,
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset,
                                                   0);
        }
        else if constexpr(N == 2)
        {
            vector_type<double, 2> tmp{src_thread_data};

            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<0>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset,
                                                   0);

            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<1>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset + sizeof(double),
                                                   0);
        }
        else if constexpr(N == 4)
        {
            vector_type<double, 4> tmp{src_thread_data};

            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<0>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset,
                                                   0);

            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<1>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset + sizeof(double),
                                                   0);

            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<2>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset + 2 * sizeof(double),
                                                   0);

            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<3>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset + 3 * sizeof(double),
                                                   0);
        }
    }
}

// buffer_load requires:
// 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T,
          index_t N,
          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ typename vector_type_maker<T, N>::type::type
amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
                                            index_t src_thread_element_offset,
                                            bool src_thread_element_valid,
                                            index_t src_element_space_size)
{
    const int32x4_t src_wave_buffer_resource =
        make_wave_buffer_resource(p_src_wave, src_element_space_size);

    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);

    using vector_t = typename vector_type_maker<T, N>::type::type;
    using scalar_t = typename scalar_type<vector_t>::type;

    constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
    uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
    return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
        src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#else
    vector_t tmp{amd_buffer_load_impl<scalar_t, vector_size, coherence>(
        src_wave_buffer_resource, src_thread_addr_offset, 0)};
    return src_thread_element_valid ? tmp : vector_t(0);
#endif
}
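
// The offset trick above leans on the buffer descriptor's hardware range
// check: adding 0x80000000 pushes the byte offset past the descriptor's range
// field, so out-of-range loads return zero without a branch. Usage sketch
// (illustrative; `p_src`, `off`, `valid`, and `n` are hypothetical):
//
//   auto v = amd_buffer_load_invalid_element_return_zero<float, 4>(
//       p_src, off, valid, n);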

// buffer_load requires:
// 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T,
          index_t N,
          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ typename vector_type_maker<T, N>::type::type
amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
                                                        index_t src_thread_element_offset,
                                                        bool src_thread_element_valid,
                                                        index_t src_element_space_size,
                                                        T customized_value)
{
    const int32x4_t src_wave_buffer_resource =
        make_wave_buffer_resource(p_src_wave, src_element_space_size);

    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);

    using vector_t = typename vector_type_maker<T, N>::type::type;
    using scalar_t = typename scalar_type<vector_t>::type;

    constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

    vector_t tmp{amd_buffer_load_impl<scalar_t, vector_size, coherence>(
        src_wave_buffer_resource, src_thread_addr_offset, 0)};

    return src_thread_element_valid ? tmp : vector_t(customized_value);
}

// buffer_store requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T,
          index_t N,
          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
                                 T* p_dst_wave,
                                 const index_t dst_thread_element_offset,
                                 const bool dst_thread_element_valid,
                                 const index_t dst_element_space_size)
{
    const int32x4_t dst_wave_buffer_resource =
        make_wave_buffer_resource(p_dst_wave, dst_element_space_size);

    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);

    using vector_t                 = typename vector_type_maker<T, N>::type::type;
    using scalar_t                 = typename scalar_type<vector_t>::type;
    constexpr index_t vector_size  = scalar_type<vector_t>::vector_size;

#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
    amd_buffer_store_impl<scalar_t, vector_size, coherence>(
        src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
    if(dst_thread_element_valid)
    {
        amd_buffer_store_impl<scalar_t, vector_size, coherence>(
            src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
    }
#endif
}
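
// Usage sketch (illustrative; `data`, `p_dst`, `dst_off`, `dst_valid`, and
// `n` are hypothetical): masked vector store, where invalid lanes are dropped
// either via the out-of-range offset trick or the explicit branch above:
//
//   amd_buffer_store<float, 4>(data, p_dst, dst_off, dst_valid, n);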

// buffer_atomic_add requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ void
amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thread_data,
                      T* p_dst_wave,
                      const index_t dst_thread_element_offset,
                      const bool dst_thread_element_valid,
                      const index_t dst_element_space_size)
{
    const int32x4_t dst_wave_buffer_resource =
        make_wave_buffer_resource(p_dst_wave, dst_element_space_size);

    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);

    using vector_t                 = typename vector_type_maker<T, N>::type::type;
    using scalar_t                 = typename scalar_type<vector_t>::type;
    constexpr index_t vector_size  = scalar_type<vector_t>::vector_size;

    if constexpr(is_same<T, bhalf_t>::value)
    {
        if(dst_thread_element_valid)
        {
            amd_global_atomic_add_impl<scalar_t, vector_size>(
                src_thread_data, p_dst_wave + dst_thread_element_offset);
        }
    }
    else
    {
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
        uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;

        amd_buffer_atomic_add_impl<scalar_t, vector_size>(
            src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
        if(dst_thread_element_valid)
        {
            amd_buffer_atomic_add_impl<scalar_t, vector_size>(
                src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
        }
#endif
    }
}

// buffer_atomic_max requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ void
amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thread_data,
                      T* p_dst_wave,
                      const index_t dst_thread_element_offset,
                      const bool dst_thread_element_valid,
                      const index_t dst_element_space_size)
{
    const int32x4_t dst_wave_buffer_resource =
        make_wave_buffer_resource(p_dst_wave, dst_element_space_size);

    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);

    using vector_t                 = typename vector_type_maker<T, N>::type::type;
    using scalar_t                 = typename scalar_type<vector_t>::type;
    constexpr index_t vector_size  = scalar_type<vector_t>::vector_size;

#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;

    amd_buffer_atomic_max_impl<scalar_t, vector_size>(
        src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
    if(dst_thread_element_valid)
    {
        amd_buffer_atomic_max_impl<scalar_t, vector_size>(
            src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
    }
#endif
}

// Direct loads from global to LDS.
__device__ void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
                                __attribute__((address_space(3))) uint32_t* lds_ptr,
                                index_t size,
                                index_t voffset,
                                index_t soffset,
                                index_t offset,
                                index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");

template <typename T, index_t NumElemsPerThread>
__device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
                                              const index_t global_offset,
                                              T* lds_base_ptr,
                                              const index_t lds_offset,
                                              const bool is_valid,
                                              const index_t src_element_space_size)
{
    // Direct loads require that each thread reads and writes exactly a single DWORD.
    constexpr auto dword_bytes      = 4;
    constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
    static_assert(bytes_per_thread == dword_bytes);

#ifndef CK_CODE_GEN_RTC
    const uint32_t* global_ptr =
        reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
#else
    const uint32_t* global_ptr =
        reinterpret_cast<uint32_t*>(reinterpret_cast<size_t>(global_base_ptr));
#endif
    const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size);
    const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;

#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
    T* lds_ptr = lds_base_ptr + lds_offset;
#ifndef CK_CODE_GEN_RTC
    auto const lds_ptr_sgpr =
        __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr)));
#else
    auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast<size_t>(lds_ptr)));
#endif
    asm volatile("s_mov_b32 m0, %0; \n\t"
                 "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
                 "v"(global_offset_bytes),
                 "s"(src_resource)
                 : "memory");
#else
    // LDS pointer must be attributed with the LDS address space.
    __attribute__((address_space(3))) uint32_t* lds_ptr =
#ifndef CK_CODE_GEN_RTC
        reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
            reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
#else
        reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
            reinterpret_cast<size_t>(lds_base_ptr + lds_offset));
#endif

    llvm_amdgcn_raw_buffer_load_lds(
        src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
#endif
}
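
// Usage sketch (illustrative; `p_global`, `g_off`, `p_lds`, `l_off`, `valid`,
// and `n` are hypothetical): each thread moves exactly one DWORD (e.g. one
// float) from global memory straight into LDS, bypassing VGPRs; invalid lanes
// are voided via the 0x80000000 out-of-range offset:
//
//   amd_direct_load_global_to_lds<float, 1>(p_global, g_off, p_lds, l_off, valid, n);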

} // namespace ck