// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

// NOTE(review): in this copy of the file the angle-bracketed template parameter and
// argument lists that start with an identifier appear to be missing (bare `template`
// keywords, `const_cast*>(...)`, `static_cast(coherence)`, `is_same::value`, ...).
// All code tokens below are kept exactly as found; only layout and comments were touched.
// Restore the missing lists from the authoritative source before compiling.

#pragma once

#include "data_type.hpp"
#include "ck/utility/amd_buffer_coherence.hpp"
#if defined(__gfx125__)
#include "ck/utility/amd_address_space.hpp"
#endif

namespace ck {

// Union that overlays three indexed views (address / range / config) on one int32x4_t
// `content`, so individual fields of a buffer resource can be written through the views
// and the whole resource read back as a single vector.
template union BufferResource
{
    // Value-initializes `content`.
    __device__ constexpr BufferResource() : content{} {}

    // 128 bit SGPRs to supply buffer resource in buffer instructions
    // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
    int32x4_t content;
    StaticallyIndexedArray address;
    StaticallyIndexedArray range;
    StaticallyIndexedArray config;
};

// Builds a buffer resource for the buffer at p_wave holding element_space_size elements of T
// and returns it as an int32x4_t.
//   - The range is expressed in bytes: element_space_size * sizeof(T).
//   - Dword 3 is set to CK_BUFFER_RESOURCE_3RD_DWORD.
//   - On __gfx125__ the byte count is split: its low 7 bits are OR-ed into bits 31:25 of
//     dword 1 and the remaining bits are stored in dword 2.
template __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size)
{
    BufferResource wave_buffer_resource;

#if defined(__gfx125__)
    // wavewise base address (57 bit)
    wave_buffer_resource.address(Number<0>{}) = const_cast*>(p_wave);
    // wavewise range (45 bit)
    // NOTE: the high 6 bits live in wave_buffer_resource.range[3], which overlaps the config
    // dword. Because element_space_size only has 32 bits, it is safe to assume those high
    // 6 bits are 0.
    uint64_t num_records = element_space_size * sizeof(T);
    wave_buffer_resource.range(Number<1>{}) |= (num_records & 0x7f) << 25;
    wave_buffer_resource.range(Number<2>{}) = (num_records >> 7);
    // wavewise setting (26 bit)
    wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
#else
    // wavewise base address (64 bit)
    wave_buffer_resource.address(Number<0>{}) = const_cast*>(p_wave);
    // wavewise range (32 bit)
    wave_buffer_resource.range(Number<2>{}) = element_space_size * sizeof(T);
#endif
    // wavewise setting (32 bit)
    // (on __gfx125__ this repeats, with the same value, the assignment made in the #if branch)
    wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;

    return wave_buffer_resource.content;
}

#if defined(__gfx125__)
// Overload of make_wave_buffer_resource for pointers qualified with
// CK_CONSTANT_ADDRESS_SPACE; the pointer is first converted with
// cast_pointer_to_generic_address_space().
// SW workaround for HW issue: SMEM Buffer Ops Misinterpreting V# NUM_RECORDS (A)
// W/A - set STRIDE (bits 121:108) to 1 for constant address space buffer access
template __device__ int32x4_t make_wave_buffer_resource(T CK_CONSTANT_ADDRESS_SPACE* p_wave,
                                                         index_t element_space_size)
{
    BufferResource wave_buffer_resource;

    // Cast constant address space pointer to generic
    wave_buffer_resource.address(Number<0>{}) =
        const_cast*>(cast_pointer_to_generic_address_space(p_wave));
    // wavewise range (45 bit): same split of the byte count as the generic-pointer overload
    uint64_t num_records = element_space_size * sizeof(T);
    wave_buffer_resource.range(Number<1>{}) |= (num_records & 0x7f) << 25;
    wave_buffer_resource.range(Number<2>{}) = (num_records >> 7);
    // wavewise setting (26 bit) with STRIDE=1 at bits 121:108 (bits 25:12 of dword 3)
    wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD | (1 << 12);

    return wave_buffer_resource.content;
}
#endif

// Same as make_wave_buffer_resource but takes no element count: dword 2 of the resource is
// set to 0xffffffff.
template __device__ int32x4_t make_wave_buffer_resource_with_default_range(T* p_wave)
{
    BufferResource wave_buffer_resource;

    // wavewise base address (64 bit)
    wave_buffer_resource.address(Number<0>{}) = const_cast*>(p_wave);
    // wavewise range (32 bit)
    wave_buffer_resource.range(Number<2>{}) = 0xffffffff; // max possible range
    // wavewise setting (32 bit)
    wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;

    return wave_buffer_resource.content;
}

#if defined(__gfx125__)
// CK_CONSTANT_ADDRESS_SPACE overload of make_wave_buffer_resource_with_default_range.
// SW workaround for HW issue: SMEM Buffer Ops Misinterpreting V# NUM_RECORDS (A)
// W/A - set STRIDE (bits 121:108) to 1 for constant address space buffer access
template __device__ int32x4_t
make_wave_buffer_resource_with_default_range(T CK_CONSTANT_ADDRESS_SPACE* p_wave)
{
    BufferResource wave_buffer_resource;

    // Cast constant address space pointer to generic
    wave_buffer_resource.address(Number<0>{}) =
        const_cast*>(cast_pointer_to_generic_address_space(p_wave));
    // wavewise range (32 bit)
    wave_buffer_resource.range(Number<2>{}) = 0xffffffff; // max possible range
    // wavewise setting (32 bit) with STRIDE=1 at bits 121:108 (bits 25:12 of dword 3)
    wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD | (1 << 12);

    return wave_buffer_resource.content;
}
#endif

// Variant of make_wave_buffer_resource that returns an __amdgpu_buffer_rsrc_t produced by
// __builtin_amdgcn_make_buffer_rsrc(pointer, stride, num, flags) with stride 0,
// num = element_space_size * sizeof(T) and flags = CK_BUFFER_RESOURCE_3RD_DWORD.
template __device__ __amdgpu_buffer_rsrc_t make_wave_buffer_resource_new(T* p_wave,
                                                                          index_t element_space_size)
{
    // wavewise base address (64 bit)
    auto p         = const_cast*>(p_wave);
    int32_t stride = 0;
    // byte count, stored in a 32-bit signed integer
    int32_t num = element_space_size * sizeof(T);
    auto flags  = CK_BUFFER_RESOURCE_3RD_DWORD;

    return __builtin_amdgcn_make_buffer_rsrc(p, stride, num, flags);
}

#if defined(__gfx125__)
// CK_CONSTANT_ADDRESS_SPACE overload of make_wave_buffer_resource_new.
// SW workaround for HW issue: SMEM Buffer Ops Misinterpreting V# NUM_RECORDS (A)
// W/A - set STRIDE to 1 for constant address space buffer access
template __device__ __amdgpu_buffer_rsrc_t
make_wave_buffer_resource_new(T CK_CONSTANT_ADDRESS_SPACE* p_wave, index_t element_space_size)
{
    // Cast constant address space pointer to generic and set stride = 1
    auto p         = const_cast*>(cast_pointer_to_generic_address_space(p_wave));
    int32_t stride = 1; // stride = 1 for constant address space buffer access
    int32_t num    = element_space_size * sizeof(T);
    auto flags     = CK_BUFFER_RESOURCE_3RD_DWORD;

    return __builtin_amdgcn_make_buffer_rsrc(p, stride, num, flags);
}
#endif

// Variant of make_wave_buffer_resource_new without an element count: `num` is initialized
// from the literal 0xffffffff.
template __device__ __amdgpu_buffer_rsrc_t
make_wave_buffer_resource_with_default_range_new(T* p_wave)
{
    // wavewise base address (64 bit)
    auto p         = const_cast*>(p_wave);
    int32_t stride = 0;
    int32_t num    = 0xffffffff;
    auto flags     = CK_BUFFER_RESOURCE_3RD_DWORD;

    return __builtin_amdgcn_make_buffer_rsrc(p, stride, num, flags);
}

#if defined(__gfx125__)
// CK_CONSTANT_ADDRESS_SPACE overload of make_wave_buffer_resource_with_default_range_new.
// SW workaround for HW issue: SMEM Buffer Ops Misinterpreting V# NUM_RECORDS (A)
// W/A - set STRIDE to 1 for constant address space buffer access
template __device__ __amdgpu_buffer_rsrc_t
make_wave_buffer_resource_with_default_range_new(T CK_CONSTANT_ADDRESS_SPACE* p_wave)
{
    // Cast constant address space pointer to generic and set stride = 1
    auto p         = const_cast*>(cast_pointer_to_generic_address_space(p_wave));
    int32_t stride = 1; // stride = 1 for constant address space buffer access
    int32_t num    = 0xffffffff;
    auto flags     = CK_BUFFER_RESOURCE_3RD_DWORD;

    return __builtin_amdgcn_make_buffer_rsrc(p, stride, num, flags);
}
#endif

// The four declarations below bind a C++ name to an LLVM intrinsic through its __asm label.

// buffer atomic-add fp16
__device__ half2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
    half2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16.v4i32");

// buffer atomic-add i32
__device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
    int32_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.atomic.add.i32.v4i32");

// buffer atomic-add fp32
__device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
    float vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32.v4i32");

// buffer atomic-max fp64
__device__ double
llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
                                       int32x4_t rsrc, // dst_wave_buffer_resource
                                       int voffset,    // dst_thread_addr_offset
                                       int soffset,    // dst_wave_addr_offset
                                       int glc_slc)
    __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64.v4i32");

// Raw buffer load dispatched on N at compile time.
//   N == 1 / 2 / 4 / 8 / 16 : one b8 / b16 / b32 / b64 / b128 builtin load.
//   N == 32 / 64            : two / four b128 loads; the wave offset of each successive load
//                             is advanced by 4 * sizeof(int32_t).
// `coherence` is forwarded, after a static_cast, as the last argument of every builtin.
// Any other N is rejected by the static_assert.
template __device__ typename vector_type::type
amd_buffer_load_impl_raw(__amdgpu_buffer_rsrc_t src_wave_buffer_resource,
                         index_t src_thread_addr_offset,
                         index_t src_wave_addr_offset)
{
    static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
                  "wrong! not implemented");

    if constexpr(N == 1)
    {
        return __builtin_amdgcn_raw_buffer_load_b8(src_wave_buffer_resource,
                                                   src_thread_addr_offset,
                                                   src_wave_addr_offset,
                                                   static_cast(coherence));
    }
    else if constexpr(N == 2)
    {
        int16_t tmp = __builtin_amdgcn_raw_buffer_load_b16(src_wave_buffer_resource,
                                                           src_thread_addr_offset,
                                                           src_wave_addr_offset,
                                                           static_cast(coherence));

        return bit_cast(tmp);
    }
    else if constexpr(N == 4)
    {
        int32_t tmp = __builtin_amdgcn_raw_buffer_load_b32(src_wave_buffer_resource,
                                                           src_thread_addr_offset,
                                                           src_wave_addr_offset,
                                                           static_cast(coherence));

        return bit_cast(tmp);
    }
    else if constexpr(N == 8)
    {
        int32x2_t tmp = __builtin_amdgcn_raw_buffer_load_b64(src_wave_buffer_resource,
                                                             src_thread_addr_offset,
                                                             src_wave_addr_offset,
                                                             static_cast(coherence));

        return bit_cast(tmp);
    }
    else if constexpr(N == 16)
    {
        int32x4_t tmp = __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
                                                              src_thread_addr_offset,
                                                              src_wave_addr_offset,
                                                              static_cast(coherence));

        return bit_cast(tmp);
    }
    else if constexpr(N == 32)
    {
        // two 128-bit loads, the second one 16 bytes further
        int32x4_t tmp0 = __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
                                                               src_thread_addr_offset,
                                                               src_wave_addr_offset,
                                                               static_cast(coherence));
        int32x4_t tmp1 =
            __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
                                                  src_thread_addr_offset,
                                                  src_wave_addr_offset + 4 * sizeof(int32_t),
                                                  static_cast(coherence));
        vector_type tmp;

        tmp.AsType()(Number<0>{}) = tmp0;
        tmp.AsType()(Number<1>{}) = tmp1;

        return bit_cast(tmp);
    }
    else if constexpr(N == 64)
    {
        // four 128-bit loads at wave offsets +0, +16, +32 and +48 bytes
        int32x4_t tmp0 = __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
                                                               src_thread_addr_offset,
                                                               src_wave_addr_offset,
                                                               static_cast(coherence));
        int32x4_t tmp1 =
            __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
                                                  src_thread_addr_offset,
                                                  src_wave_addr_offset + 4 * sizeof(int32_t),
                                                  static_cast(coherence));
        int32x4_t tmp2 =
            __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
                                                  src_thread_addr_offset,
                                                  src_wave_addr_offset + 8 * sizeof(int32_t),
                                                  static_cast(coherence));
        int32x4_t tmp3 =
            __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
                                                  src_thread_addr_offset,
                                                  src_wave_addr_offset + 12 * sizeof(int32_t),
                                                  static_cast(coherence));

        vector_type tmp;

        tmp.AsType()(Number<0>{}) = tmp0;
        tmp.AsType()(Number<1>{}) = tmp1;
        tmp.AsType()(Number<2>{}) = tmp2;
        tmp.AsType()(Number<3>{}) = tmp3;

        return bit_cast(tmp);
    }
}

// Checked front end for buffer loads: the static_assert restricts the accepted
// (element type, N) combinations, the data is fetched through amd_buffer_load_impl_raw and
// the raw result is bit_cast to r_t.
// NOTE(review): the type named in each is_same clause is missing in this copy; the N lists
// are intact.
template __device__ typename vector_type::type
amd_buffer_load_impl(__amdgpu_buffer_rsrc_t src_wave_buffer_resource,
                     index_t src_thread_addr_offset,
                     index_t src_wave_addr_offset)
{
    static_assert(
        (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
            (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
            (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
            (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
            (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
            (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)),
        "wrong! not implemented");

    using r_t     = typename vector_type::type;
    auto raw_data = amd_buffer_load_impl_raw(
        src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset);
    return bit_cast(raw_data);
}

// Raw buffer store dispatched on N at compile time; mirror image of the raw load.
//   N == 1 / 2 / 4 / 8 / 16 : one b8 / b16 / b32 / b64 / b128 builtin store.
//   N == 32 / 64            : two / four b128 stores; the wave offset of each successive store
//                             is advanced by sizeof(int32_t) * 4.
template __device__ void
amd_buffer_store_impl_raw(const typename vector_type::type src_thread_data,
                          __amdgpu_buffer_rsrc_t dst_wave_buffer_resource,
                          index_t dst_thread_addr_offset,
                          index_t dst_wave_addr_offset)
{
    static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
                  "wrong! not implemented");

    if constexpr(N == 1)
    {
        __builtin_amdgcn_raw_buffer_store_b8(src_thread_data,
                                             dst_wave_buffer_resource,
                                             dst_thread_addr_offset,
                                             dst_wave_addr_offset,
                                             static_cast(coherence));
    }
    else if constexpr(N == 2)
    {
        __builtin_amdgcn_raw_buffer_store_b16(bit_cast(src_thread_data),
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
                                              static_cast(coherence));
    }
    else if constexpr(N == 4)
    {
        __builtin_amdgcn_raw_buffer_store_b32(bit_cast(src_thread_data),
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
                                              static_cast(coherence));
    }
    else if constexpr(N == 8)
    {
        __builtin_amdgcn_raw_buffer_store_b64(bit_cast(src_thread_data),
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
                                              static_cast(coherence));
    }
    else if constexpr(N == 16)
    {
        __builtin_amdgcn_raw_buffer_store_b128(bit_cast(src_thread_data),
                                               dst_wave_buffer_resource,
                                               dst_thread_addr_offset,
                                               dst_wave_addr_offset,
                                               static_cast(coherence));
    }
    else if constexpr(N == 32)
    {
        vector_type tmp{bit_cast(src_thread_data)};

        __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType()[Number<0>{}],
                                               dst_wave_buffer_resource,
                                               dst_thread_addr_offset,
                                               dst_wave_addr_offset,
                                               static_cast(coherence));

        __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType()[Number<1>{}],
                                               dst_wave_buffer_resource,
                                               dst_thread_addr_offset,
                                               dst_wave_addr_offset + sizeof(int32_t) * 4,
                                               static_cast(coherence));
    }
    else if constexpr(N == 64)
    {
        vector_type tmp{bit_cast(src_thread_data)};

        __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType()[Number<0>{}],
                                               dst_wave_buffer_resource,
                                               dst_thread_addr_offset,
                                               dst_wave_addr_offset,
                                               static_cast(coherence));

        __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType()[Number<1>{}],
                                               dst_wave_buffer_resource,
                                               dst_thread_addr_offset,
                                               dst_wave_addr_offset + sizeof(int32_t) * 4,
                                               static_cast(coherence));

        __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType()[Number<2>{}],
                                               dst_wave_buffer_resource,
                                               dst_thread_addr_offset,
                                               dst_wave_addr_offset + sizeof(int32_t) * 8,
                                               static_cast(coherence));

        __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType()[Number<3>{}],
                                               dst_wave_buffer_resource,
                                               dst_thread_addr_offset,
                                               dst_wave_addr_offset + sizeof(int32_t) * 12,
                                               static_cast(coherence));
    }
}

// Checked front end for buffer stores: the static_assert restricts the accepted
// (element type, N) combinations, then the data is bit_cast and handed to
// amd_buffer_store_impl_raw.
// NOTE(review): the type named in each is_same clause is missing in this copy.
template __device__ void
amd_buffer_store_impl(const typename vector_type::type src_thread_data,
                      __amdgpu_buffer_rsrc_t dst_wave_buffer_resource,
                      index_t dst_thread_addr_offset,
                      index_t dst_wave_addr_offset)
{
    static_assert(
        (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
            (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
        "wrong! not implemented");

    using r_t = typename vector_type::type;

    amd_buffer_store_impl_raw(bit_cast(src_thread_data),
                              dst_wave_buffer_resource,
                              dst_thread_addr_offset,
                              dst_wave_addr_offset);
}

// Atomic add on a plain pointer (no buffer resource): the vector is processed two elements
// at a time, issuing N / 2 packed atomic-fadd builtins at addr + i.
// The first type branch uses the v2f16 builtin; the second, compiled only for
// gfx942 / gfx950 / gfx12, uses the v2bf16 builtin.
template __device__ void
amd_global_atomic_add_impl(const typename vector_type::type src_thread_data, T* addr)
{
    static_assert((is_same::value && (N == 2 || N == 4 || N == 8)) ||
                      (is_same::value && (N == 2 || N == 4 || N == 8)),
                  "wrong! not implemented");

    if constexpr(is_same::value)
    {
        vector_type tmp{src_thread_data};
        static_for<0, N / 2, 1>{}([&](auto i) {
            __builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast(addr) + i,
                                                      tmp.template AsType()[i]);
        });
    }
#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx12__)
    else if constexpr(is_same::value)
    {
        vector_type tmp{src_thread_data};
        static_for<0, N / 2, 1>{}([&](auto i) {
            __builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast(addr) + i,
                                                       tmp.template AsType()[i]);
        });
    }
#endif
}

// Buffer atomic add. Three type branches, each issuing one intrinsic call per element
// (fp32, i32) or per packed pair (fp16x2); the wave offset of call i is advanced by
// i * sizeof(element or pair). The last intrinsic argument is always 0.
template __device__ void
amd_buffer_atomic_add_impl(const typename vector_type::type src_thread_data,
                           int32x4_t dst_wave_buffer_resource,
                           index_t dst_thread_addr_offset,
                           index_t dst_wave_addr_offset)
{
    static_assert((is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
                      (is_same::value && (N == 2 || N == 4 || N == 8)) ||
                      (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)),
                  "wrong! not implemented");

    if constexpr(is_same::value)
    {
        // fp32 path
        if constexpr(N == 1)
        {
            llvm_amdgcn_raw_buffer_atomic_add_fp32(src_thread_data,
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset,
                                                   0);
        }
        else
        {
            vector_type tmp{src_thread_data};
            static_for<0, N, 1>{}([&](auto i) {
                llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.template AsType()[i],
                                                       dst_wave_buffer_resource,
                                                       dst_thread_addr_offset,
                                                       dst_wave_addr_offset + i * sizeof(float),
                                                       0);
            });
        }
    }
    else if constexpr(is_same::value)
    {
        // packed fp16x2 path: N / 2 calls
        if constexpr(N == 2)
        {
            llvm_amdgcn_raw_buffer_atomic_add_fp16x2(src_thread_data,
                                                     dst_wave_buffer_resource,
                                                     dst_thread_addr_offset,
                                                     dst_wave_addr_offset,
                                                     0);
        }
        else if constexpr(N == 4)
        {
            vector_type tmp{src_thread_data};

            static_for<0, 2, 1>{}([&](auto i) {
                llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType()[i],
                                                         dst_wave_buffer_resource,
                                                         dst_thread_addr_offset,
                                                         dst_wave_addr_offset + i * sizeof(half2_t),
                                                         0);
            });
        }
        else if constexpr(N == 8)
        {
            vector_type tmp{src_thread_data};

            static_for<0, 4, 1>{}([&](auto i) {
                llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType()[i],
                                                         dst_wave_buffer_resource,
                                                         dst_thread_addr_offset,
                                                         dst_wave_addr_offset + i * sizeof(half2_t),
                                                         0);
            });
        }
    }
    else if constexpr(is_same::value)
    {
        // i32 path
        if constexpr(N == 1)
        {
            llvm_amdgcn_raw_buffer_atomic_add_i32(src_thread_data,
                                                  dst_wave_buffer_resource,
                                                  dst_thread_addr_offset,
                                                  dst_wave_addr_offset,
                                                  0);
        }
        else
        {
            vector_type tmp{src_thread_data};

            static_for<0, N, 1>{}([&](auto i) {
                llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.template AsType()[i],
                                                      dst_wave_buffer_resource,
                                                      dst_thread_addr_offset,
                                                      dst_wave_addr_offset + i * sizeof(int32_t),
                                                      0);
            });
        }
    }
}

// Buffer atomic max. A single type branch with N == 1, 2 or 4: one fp64 atomic-max intrinsic
// call per element, element k at wave offset + k * sizeof(double).
template __device__ void
amd_buffer_atomic_max_impl(const typename vector_type::type src_thread_data,
                           int32x4_t dst_wave_buffer_resource,
                           index_t dst_thread_addr_offset,
                           index_t dst_wave_addr_offset)
{
    static_assert((is_same::value && (N == 1 || N == 2 || N == 4)),
                  "wrong! not implemented");
    if constexpr(is_same::value)
    {
        if constexpr(N == 1)
        {
            llvm_amdgcn_raw_buffer_atomic_max_fp64(src_thread_data,
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset,
                                                   0);
        }
        else if constexpr(N == 2)
        {
            vector_type tmp{src_thread_data};

            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<0>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset,
                                                   0);

            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<1>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset + sizeof(double),
                                                   0);
        }
        else if constexpr(N == 4)
        {
            vector_type tmp{src_thread_data};

            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<0>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset,
                                                   0);

            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<1>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset + sizeof(double),
                                                   0);

            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<2>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset + 2 * sizeof(double),
                                                   0);

            llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<3>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset + 3 * sizeof(double),
                                                   0);
        }
    }
}

// buffer_load requires:
//   1) p_src_wave must point to global memory space
//   2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
// Loads one vector through a buffer resource built from (p_src_wave, src_element_space_size)
// and yields vector_t(0) when src_thread_element_valid is false.
//   src_thread_element_offset : offset in elements of T; converted to bytes with sizeof(T).
//   CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK set:
//       no select is performed; 0x80000000 is added to the byte offset of an invalid
//       element and the loaded value is returned as is.
//       NOTE(review): presumably that offset falls outside the resource range so the load
//       itself produces zero - confirm against the hardware documentation.
//   otherwise: the load is always issued and the result is replaced by vector_t(0) when the
//       element is invalid.
// NOTE(review): the template parameter/argument lists are missing in this copy of the file;
// code tokens are kept exactly as found.
template __device__ typename vector_type_maker::type::type
amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
                                            index_t src_thread_element_offset,
                                            bool src_thread_element_valid,
                                            index_t src_element_space_size)
{
    const __amdgpu_buffer_rsrc_t src_wave_buffer_resource =
        make_wave_buffer_resource_new(p_src_wave, src_element_space_size);

    // element offset -> byte offset
    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);

    using vector_t                = typename vector_type_maker::type::type;
    using scalar_t                = typename scalar_type::type;
    constexpr index_t vector_size = scalar_type::vector_size;

#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
    uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
    return amd_buffer_load_impl(
        src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#else
    vector_t tmp{amd_buffer_load_impl(src_wave_buffer_resource, src_thread_addr_offset, 0)};
    return src_thread_element_valid ? tmp : vector_t(0);
#endif
}

// buffer_load requires:
//   1) p_src_wave must point to global memory space
//   2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
//
// Same load as above, except that an invalid element yields vector_t(customized_value).
// There is no offset-trick variant here: the load is always issued with the unmodified
// byte offset and the result is selected afterwards.
template __device__ typename vector_type_maker::type::type
amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
                                                        index_t src_thread_element_offset,
                                                        bool src_thread_element_valid,
                                                        index_t src_element_space_size,
                                                        T customized_value)
{
    const __amdgpu_buffer_rsrc_t src_wave_buffer_resource =
        make_wave_buffer_resource_new(p_src_wave, src_element_space_size);

    // element offset -> byte offset
    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);

    using vector_t                = typename vector_type_maker::type::type;
    using scalar_t                = typename scalar_type::type;
    constexpr index_t vector_size = scalar_type::vector_size;

    vector_t tmp{amd_buffer_load_impl(src_wave_buffer_resource, src_thread_addr_offset, 0)};

    return src_thread_element_valid ? tmp : vector_t(customized_value);
}

// buffer_store requires:
//   1) p_dst_wave must point to global memory
//   2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
//
// Stores src_thread_data through a buffer resource built from
// (p_dst_wave, dst_element_space_size) at dst_thread_element_offset (in elements of T).
//   CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK set:
//       the store is always issued; 0x80000000 is added to the byte offset of an invalid
//       element. NOTE(review): presumably this makes the hardware discard the store - confirm.
//   otherwise: the store is issued only when dst_thread_element_valid is true.
template __device__ void
amd_buffer_store(const typename vector_type_maker::type::type src_thread_data,
                 T* p_dst_wave,
                 const index_t dst_thread_element_offset,
                 const bool dst_thread_element_valid,
                 const index_t dst_element_space_size)
{
    const __amdgpu_buffer_rsrc_t dst_wave_buffer_resource =
        make_wave_buffer_resource_new(p_dst_wave, dst_element_space_size);

    // element offset -> byte offset
    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);

    using vector_t                = typename vector_type_maker::type::type;
    using scalar_t                = typename scalar_type::type;
    constexpr index_t vector_size = scalar_type::vector_size;

#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
    amd_buffer_store_impl(
        src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
    if(dst_thread_element_valid)
    {
        amd_buffer_store_impl(
            src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
    }
#endif
}

// buffer_atomic_add requires:
//   1) p_dst_wave must point to global memory
//   2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
// Atomically adds src_thread_data to the destination at dst_thread_element_offset
// (in elements of T) from p_dst_wave.
//   - An int32x4_t buffer resource is built from (p_dst_wave, dst_element_space_size) up
//     front; it is only consumed by the `else` branch below.
//   - For the element type tested by the first `if constexpr`, the add goes through
//     amd_global_atomic_add_impl on the plain pointer p_dst_wave + dst_thread_element_offset
//     and is skipped when dst_thread_element_valid is false.
//   - For every other type the add goes through amd_buffer_atomic_add_impl:
//       CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK set:
//           always issued; 0x80000000 is added to the byte offset of an invalid element.
//           NOTE(review): presumably this makes the hardware discard the access - confirm.
//       otherwise: issued only when dst_thread_element_valid is true.
// NOTE(review): the template parameter/argument lists (including the type compared in
// is_same) are missing in this copy of the file; code tokens are kept exactly as found.
template __device__ void
amd_buffer_atomic_add(const typename vector_type_maker::type::type src_thread_data,
                      T* p_dst_wave,
                      const index_t dst_thread_element_offset,
                      const bool dst_thread_element_valid,
                      const index_t dst_element_space_size)
{
    const int32x4_t dst_wave_buffer_resource =
        make_wave_buffer_resource(p_dst_wave, dst_element_space_size);

    // element offset -> byte offset
    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);

    using vector_t                = typename vector_type_maker::type::type;
    using scalar_t                = typename scalar_type::type;
    constexpr index_t vector_size = scalar_type::vector_size;

    if constexpr(is_same::value)
    {
        // pointer-based path: validity is handled with a branch
        if(dst_thread_element_valid)
        {
            amd_global_atomic_add_impl(src_thread_data, p_dst_wave + dst_thread_element_offset);
        }
    }
    else
    {
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
        uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;

        amd_buffer_atomic_add_impl(
            src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
        if(dst_thread_element_valid)
        {
            amd_buffer_atomic_add_impl(
                src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
        }
#endif
    }
}

// buffer_atomic_max requires:
//   1) p_dst_wave must point to global memory
//   2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template __device__ void amd_buffer_atomic_max(const typename vector_type_maker::type::type src_thread_data, T* p_dst_wave, const index_t dst_thread_element_offset, const bool dst_thread_element_valid, const index_t dst_element_space_size) { const int32x4_t dst_wave_buffer_resource = make_wave_buffer_resource(p_dst_wave, dst_element_space_size); index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); using vector_t = typename vector_type_maker::type::type; using scalar_t = typename scalar_type::type; constexpr index_t vector_size = scalar_type::vector_size; #if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; amd_buffer_atomic_max_impl( src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); #else if(dst_thread_element_valid) { amd_buffer_atomic_max_impl( src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); } #endif } // Direct loads from global to LDS. #if __clang_major__ >= 21 && __clang_major__ < 23 __device__ void llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, __attribute__((address_space(3))) uint32_t* lds_ptr, index_t size, index_t voffset, index_t soffset, index_t offset, index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds.v4i32"); #else __device__ void llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, __attribute__((address_space(3))) uint32_t* lds_ptr, index_t size, index_t voffset, index_t soffset, index_t offset, index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds"); #endif #ifndef __HIPCC_RTC__ template __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, const index_t global_offset, T* lds_base_ptr, const index_t lds_offset, const bool is_valid, const index_t src_element_space_size) { // Direct loads require that each thread reads and writes a multiple of DWORDs (4 bytes). 
// For gfx950: supports 1, 3, or 4 DWORDs per thread // For gfx942: supports exactly 1 DWORD per thread constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; #if defined(__gfx950__) constexpr auto dword_bytes = 4; static_assert(bytes_per_thread == dword_bytes || bytes_per_thread == dword_bytes * 3 || bytes_per_thread == dword_bytes * 4); #elif defined(__gfx942__) constexpr auto dword_bytes = 4; static_assert(bytes_per_thread == dword_bytes); #else ignore = bytes_per_thread; #endif const int32x4_t src_resource = make_wave_buffer_resource(global_base_ptr, src_element_space_size); const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000; #if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM T* lds_ptr = lds_base_ptr + lds_offset; #ifndef CK_CODE_GEN_RTC auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast(lds_ptr))); #else auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast(lds_ptr))); #endif asm volatile("s_mov_b32 m0, %0; \n\t" "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), "v"(global_offset_bytes), "s"(src_resource) : "memory"); #else // LDS pointer must be attributed with the LDS address space. 
__attribute__((address_space(3))) uint32_t* lds_ptr = #ifndef CK_CODE_GEN_RTC reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( reinterpret_cast(lds_base_ptr + lds_offset)); #else reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( reinterpret_cast(lds_base_ptr + lds_offset)); #endif llvm_amdgcn_raw_buffer_load_lds( src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0); #endif } #endif template __device__ void amd_async_copy_to_lds_impl_raw(__attribute__((address_space(1))) const T* src_ptr, index_t src_offset, __attribute__((address_space(3))) T* dst_ptr) { static_assert(NumBytesPerThread == 1 || NumBytesPerThread == 4 || NumBytesPerThread == 8 || NumBytesPerThread == 16, "NumBytesPerThread must be 1, 4, 8, or 16"); // ROCm 7.0.1 compiler flags unsupported builtins even though the function is never instantiated // for gfx9xx architectures #if defined(__gfx125__) #if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM constexpr bool use_asm_path = is_uniform_src_ptr; #else constexpr bool use_asm_path = false; #endif if constexpr(NumBytesPerThread == 1) { if constexpr(use_asm_path) { asm volatile("global_load_async_to_lds_b8 %0, %1, %2, offset:%3\n\t" ::"v"( static_cast(reinterpret_cast(dst_ptr))), "v"(static_cast((src_offset - static_dst_offset) * sizeof(T))), "s"(reinterpret_cast(src_ptr)), "n"(static_cast(static_dst_offset * sizeof(T))) : "memory"); } else { __attribute__((address_space(1))) char* cv_ptr = const_cast<__attribute__((address_space(1))) char*>( reinterpret_cast( src_ptr + src_offset - static_dst_offset)); __attribute__((address_space(3))) char* lds_ptr = reinterpret_cast<__attribute__((address_space(3))) char*>(dst_ptr); __builtin_amdgcn_global_load_async_to_lds_b8( cv_ptr, lds_ptr, static_dst_offset * sizeof(T), static_cast(coherence)); } return; } if constexpr(NumBytesPerThread == 4) { if constexpr(use_asm_path) { asm volatile("global_load_async_to_lds_b32 %0, %1, %2, offset:%3\n\t" ::"v"( 
static_cast(reinterpret_cast(dst_ptr))), "v"(static_cast((src_offset - static_dst_offset) * sizeof(T))), "s"(reinterpret_cast(src_ptr)), "n"(static_cast(static_dst_offset * sizeof(T))) : "memory"); } else { __attribute__((address_space(1))) int* cv_ptr = const_cast<__attribute__((address_space(1))) int*>( reinterpret_cast( src_ptr + src_offset - static_dst_offset)); __attribute__((address_space(3))) int* lds_ptr = reinterpret_cast<__attribute__((address_space(3))) int*>(dst_ptr); __builtin_amdgcn_global_load_async_to_lds_b32( cv_ptr, lds_ptr, static_dst_offset * sizeof(T), static_cast(coherence)); } return; } if constexpr(NumBytesPerThread == 8) { if constexpr(use_asm_path) { asm volatile("global_load_async_to_lds_b64 %0, %1, %2, offset:%3\n\t" ::"v"( static_cast(reinterpret_cast(dst_ptr))), "v"(static_cast((src_offset - static_dst_offset) * sizeof(T))), "s"(reinterpret_cast(src_ptr)), "n"(static_cast(static_dst_offset * sizeof(T))) : "memory"); } else { __attribute__((address_space(1))) int32x2_t* cv_ptr = const_cast<__attribute__((address_space(1))) int32x2_t*>( reinterpret_cast( src_ptr + src_offset - static_dst_offset)); __attribute__((address_space(3))) int32x2_t* lds_ptr = reinterpret_cast<__attribute__((address_space(3))) int32x2_t*>(dst_ptr); __builtin_amdgcn_global_load_async_to_lds_b64( cv_ptr, lds_ptr, static_dst_offset * sizeof(T), static_cast(coherence)); } return; } if constexpr(NumBytesPerThread == 16) { if constexpr(use_asm_path) { asm volatile("global_load_async_to_lds_b128 %0, %1, %2, offset:%3\n\t" ::"v"( static_cast(reinterpret_cast(dst_ptr))), "v"(static_cast((src_offset - static_dst_offset) * sizeof(T))), "s"(reinterpret_cast(src_ptr)), "n"(static_cast(static_dst_offset * sizeof(T))) : "memory"); } else { __attribute__((address_space(1))) int32x4_t* cv_ptr = const_cast<__attribute__((address_space(1))) int32x4_t*>( reinterpret_cast( src_ptr + src_offset - static_dst_offset)); __attribute__((address_space(3))) int32x4_t* lds_ptr = 
reinterpret_cast<__attribute__((address_space(3))) int32x4_t*>(dst_ptr);
            __builtin_amdgcn_global_load_async_to_lds_b128(
                cv_ptr, lds_ptr, static_dst_offset * sizeof(T), static_cast(coherence));
        }
        return;
    }
#else
    ignore = src_ptr;
    ignore = dst_ptr;
    ignore = src_offset;
#endif
}

// Stores NumBytesPerThread bytes from `src_ptr` (an address_space(3) pointer, called `lds_ptr`
// below) to `dst_ptr` (an address_space(1) pointer, called `global_ptr` below) by calling the
// __builtin_amdgcn_global_store_async_from_lds_b{8,32,64,128} builtin whose width matches
// NumBytesPerThread.
//
// src_ptr  - source of the store; its constness is cast away before the builtin call.
// dst_ptr  - destination of the store.
//
// NumBytesPerThread must be 1, 4, 8 or 16 (enforced by the static_assert below). `coherence` is
// forwarded as the last builtin argument. When __gfx125__ is not defined the function only marks
// its arguments as used and does nothing else.
//
// NOTE(review): no wait/synchronization is issued here after the async builtin; completion
// handling is presumably the caller's responsibility - confirm.
template __device__ void
amd_async_store_to_global_impl_raw(__attribute__((address_space(3))) const T* src_ptr,
                                   __attribute__((address_space(1))) T* dst_ptr)
{
    static_assert(NumBytesPerThread == 1 || NumBytesPerThread == 4 || NumBytesPerThread == 8 ||
                      NumBytesPerThread == 16,
                  "NumBytesPerThread must be 1, 4, 8, or 16");
    // The builtins are guarded by the preprocessor (not only by `if constexpr`) because the
    // ROCm 7.0.1 compiler flags unsupported builtins even though the function is never
    // instantiated for gfx9xx architectures.
#if defined(__gfx125__)
    // In every branch the pointers are only reinterpreted (char / int / int32x2_t / int32x4_t) so
    // that the pointee width matches the builtin being called. The third builtin argument is
    // always 0.
    // NOTE(review): that 0 is presumably the immediate offset of the instruction - confirm
    // against the builtin's signature.
    if constexpr(NumBytesPerThread == 1)
    {
        // 1 byte per thread -> b8 builtin.
        __attribute__((address_space(3))) char* lds_ptr =
            const_cast<__attribute__((address_space(3))) char*>(
                reinterpret_cast(src_ptr));
        __attribute__((address_space(1))) char* global_ptr =
            reinterpret_cast<__attribute__((address_space(1))) char*>(dst_ptr);
        __builtin_amdgcn_global_store_async_from_lds_b8(
            global_ptr, lds_ptr, 0, static_cast(coherence));
        return;
    }
    if constexpr(NumBytesPerThread == 4)
    {
        // 4 bytes per thread -> b32 builtin.
        __attribute__((address_space(3))) int* lds_ptr =
            const_cast<__attribute__((address_space(3))) int*>(
                reinterpret_cast(src_ptr));
        __attribute__((address_space(1))) int* global_ptr =
            reinterpret_cast<__attribute__((address_space(1))) int*>(dst_ptr);
        __builtin_amdgcn_global_store_async_from_lds_b32(
            global_ptr, lds_ptr, 0, static_cast(coherence));
        return;
    }
    if constexpr(NumBytesPerThread == 8)
    {
        // 8 bytes per thread -> b64 builtin.
        __attribute__((address_space(3))) int32x2_t* lds_ptr =
            const_cast<__attribute__((address_space(3))) int32x2_t*>(
                reinterpret_cast(src_ptr));
        __attribute__((address_space(1))) int32x2_t* global_ptr =
            reinterpret_cast<__attribute__((address_space(1))) int32x2_t*>(dst_ptr);
        __builtin_amdgcn_global_store_async_from_lds_b64(
            global_ptr, lds_ptr, 0, static_cast(coherence));
        return;
    }
    if constexpr(NumBytesPerThread == 16)
    {
        // 16 bytes per thread -> b128 builtin.
        __attribute__((address_space(3))) int32x4_t* lds_ptr =
            const_cast<__attribute__((address_space(3))) int32x4_t*>(
                reinterpret_cast(src_ptr));
        __attribute__((address_space(1))) int32x4_t* global_ptr =
            reinterpret_cast<__attribute__((address_space(1))) int32x4_t*>(dst_ptr);
        __builtin_amdgcn_global_store_async_from_lds_b128(
            global_ptr, lds_ptr, 0, static_cast(coherence));
        return;
    }
#else
    // Not a __gfx125__ build: nothing is stored; the arguments are only marked as used.
    ignore = src_ptr;
    ignore = dst_ptr;
#endif
}

// Compile-time-checked front end for amd_async_copy_to_lds_impl_raw.
//
// src_ptr     - address_space(1) source base pointer, forwarded unchanged.
// src_offfset - source offset, forwarded unchanged to the raw implementation.
//               NOTE(review): the parameter name is misspelled (three f's); kept as is.
// dst_ptr     - address_space(3) destination pointer, forwarded unchanged.
//
// On __gfx125__ builds the static_assert rejects any (T, N) combination that is not listed and
// the call is then forwarded to the raw implementation. On other builds the function does
// nothing except mark its arguments as used.
template __device__ void
amd_async_copy_to_lds_impl(__attribute__((address_space(1))) const T* src_ptr,
                           index_t src_offfset,
                           __attribute__((address_space(3))) T* dst_ptr)
{
#if defined(__gfx125__)
    // Currently only b8, b32, b64 and b128 transfers are supported for a single async copy.
    // Each clause below pairs one element type with the vector lengths N accepted for it.
    // NOTE(review): presumably N * sizeof(T) is the per-thread byte count handed to the raw
    // implementation - confirm.
    static_assert((is_same::value && (N == 1 || N == 2)) ||
                      (is_same::value && (N == 1 || N == 2 || N == 4)) ||
                      (is_same::value && (N == 1 || N == 2 || N == 4)) ||
                      (is_same::value && (N == 2 || N == 4 || N == 8)) ||
                      (is_same::value && (N == 2 || N == 4 || N == 8)) ||
                      (is_same::value && (N == 1 || N == 4 || N == 8 || N == 16)) ||
                      (is_same::value && (N == 1 || N == 4 || N == 8 || N == 16)) ||
                      (is_same::value && (N == 1 || N == 4 || N == 8 || N == 16)) ||
                      (is_same::value && (N == 1 || N == 4 || N == 8 || N == 16)) ||
                      (is_same::value && (N == 1 || N == 4 || N == 8 || N == 16)) ||
                      (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
                      (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)),
                  "wrong! not yet supported");

    amd_async_copy_to_lds_impl_raw(src_ptr, src_offfset, dst_ptr);
#else
    // Not a __gfx125__ build: no copy is issued; the arguments are only marked as used.
    ignore = src_ptr;
    ignore = dst_ptr;
    ignore = src_offfset;
#endif
    return;
}

// Compile-time-checked front end for amd_async_store_to_global_impl_raw.
//
// src_ptr - address_space(3) source pointer, forwarded unchanged.
// dst_ptr - address_space(1) destination pointer, forwarded unchanged.
//
// On __gfx125__ builds the static_assert rejects any (T, N) combination that is not listed and
// the call is then forwarded to the raw implementation. On other builds the function does
// nothing except mark its arguments as used.
template __device__ void
amd_async_store_to_global_impl(__attribute__((address_space(3))) const T* src_ptr,
                               __attribute__((address_space(1))) T* dst_ptr)
{
#if defined(__gfx125__)
    // copy 8, 32, 64, or 128 bit per thread
    // Each clause below pairs one element type with the vector lengths N accepted for it.
    static_assert((is_same::value && (N == 1 || N == 2)) ||
                      (is_same::value && (N == 1 || N == 2 || N == 4)) ||
                      (is_same::value && (N == 1 || N == 2 || N == 4)) ||
                      (is_same::value && (N == 2 || N == 4 || N == 8)) ||
                      (is_same::value && (N == 2 || N == 4 || N == 8)) ||
                      (is_same::value && (N == 1 || N == 4 || N == 8 || N == 16)) ||
                      (is_same::value && (N == 1 || N == 4 || N == 8 || N == 16)) ||
                      (is_same::value && (N == 1 || N == 4 || N == 8 || N == 16)) ||
                      (is_same::value && (N == 1 || N == 4 || N == 8 || N == 16)),
                  "wrong! not implemented");

    amd_async_store_to_global_impl_raw(src_ptr, dst_ptr);
#else
    // Not a __gfx125__ build: no store is issued; the arguments are only marked as used.
    ignore = src_ptr;
    ignore = dst_ptr;
#endif
    return;
}

// Entry point taking plain (generic) pointers: copies from global memory into LDS via
// amd_async_copy_to_lds_impl when the source is valid, otherwise clears the destination.
//
// global_base_ptr - source base pointer; reinterpreted as an address_space(1) pointer.
// global_offset   - source offset; NOT added to the pointer here, it is passed on to
//                   amd_async_copy_to_lds_impl as a separate argument.
// lds_base_ptr    - destination base pointer.
// lds_offset      - destination offset, added to lds_base_ptr before the address-space cast.
// is_src_valid    - when false no copy is issued and the destination is overwritten with a
//                   value-initialized DstVecType instead.
template __device__ void amd_async_load_global_to_lds(const T* global_base_ptr,
                                                      const index_t global_offset,
                                                      T* lds_base_ptr,
                                                      const index_t lds_offset,
                                                      const bool is_src_valid)
{
    if(is_src_valid)
    {
        // The generic pointers are converted to address-space-qualified pointers by way of an
        // intermediate reinterpret_cast.
        __attribute__((address_space(1))) const T* global_ptr =
            reinterpret_cast<__attribute__((address_space(1))) T*>(
                reinterpret_cast(global_base_ptr));
        // static_dst_offset is not added on this path.
        // NOTE(review): it is presumably applied by the callee - confirm.
        __attribute__((address_space(3))) T* lds_ptr =
            reinterpret_cast<__attribute__((address_space(3))) T*>(
                reinterpret_cast(lds_base_ptr + lds_offset));
        amd_async_copy_to_lds_impl(global_ptr, global_offset, lds_ptr);
    }
    else
    {
        // Invalid source: write a value-initialized (`{}`) DstVecType with an ordinary store
        // through a generic pointer. Here static_dst_offset is added to the address explicitly.
        // NOTE(review): `{}` is presumably all zeros for this vector type - confirm.
        using DstVecType = typename vector_type_maker::type;
        DstVecType* lds_ptr =
            reinterpret_cast(lds_base_ptr + lds_offset + static_dst_offset);
        *lds_ptr = {};
    }
}

// Entry point taking plain (generic) pointers: stores from LDS to global memory via
// amd_async_store_to_global_impl.
//
// lds_base_ptr    - source base pointer.
// lds_offset      - source offset, added to lds_base_ptr before the address-space cast.
// global_base_ptr - destination base pointer.
// global_offset   - destination offset, added to global_base_ptr before the address-space cast.
// is_src_valid,
// is_dst_valid    - the store is issued only when both are true; otherwise the function does
//                   nothing.
template __device__ void amd_async_store_lds_to_global(const T* lds_base_ptr,
                                                       const index_t lds_offset,
                                                       T* global_base_ptr,
                                                       const index_t global_offset,
                                                       const bool is_src_valid,
                                                       const bool is_dst_valid)
{
    if(is_src_valid && is_dst_valid)
    {
        // Unlike the load entry point, both offsets are folded into the pointers here before
        // they are converted to address-space-qualified pointers.
        __attribute__((address_space(3))) const T* lds_ptr =
            reinterpret_cast<__attribute__((address_space(3))) T*>(
                reinterpret_cast(lds_base_ptr + lds_offset));
        __attribute__((address_space(1))) T* global_ptr =
            reinterpret_cast<__attribute__((address_space(1))) T*>(
                reinterpret_cast(global_base_ptr + global_offset));
        amd_async_store_to_global_impl(lds_ptr, global_ptr);
    }
}

} // namespace ck