diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp
index 783fc661ce..e626603949 100644
--- a/include/ck/utility/amd_buffer_addressing.hpp
+++ b/include/ck/utility/amd_buffer_addressing.hpp
@@ -3,6 +3,7 @@
 #pragma once
 
 #include "data_type.hpp"
+#include "amd_inline_asm.hpp"
 
 namespace ck {
 
@@ -1064,4 +1065,48 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
 }
 #endif
 
+template <index_t N>
+__device__ typename vector_type<int8_t, N>::type
+amd_s_buffer_load_impl_raw(__amdgpu_buffer_rsrc_t src_wave_buffer_resource,
+                           index_t src_wave_addr_offset)
+{
+    static_assert(N == 4 || N == 8, "wrong! not implemented");
+    // TODO: add other variants of s_buffer_load
+    if constexpr(N == 4)
+    {
+        int32_t tmp =
+            amd_assembly_s_buffer_load_b32(src_wave_buffer_resource, src_wave_addr_offset);
+        return bit_cast<int8x4_t>(tmp);
+    }
+    else if constexpr(N == 8)
+    {
+        int32x2_t tmp =
+            amd_assembly_s_buffer_load_b64(src_wave_buffer_resource, src_wave_addr_offset);
+        return bit_cast<int8x8_t>(tmp);
+    }
+}
+
+template <typename T, index_t N>
+__device__ typename vector_type<T, N>::type
+amd_s_buffer_load_impl(__amdgpu_buffer_rsrc_t src_wave_buffer_resource,
+                       index_t src_wave_addr_offset)
+{
+    static_assert((is_same<T, double>::value && (N == 1)) ||
+                      (is_same<T, float>::value && (N == 1 || N == 2)) ||
+                      (is_same<T, half_t>::value && (N == 2 || N == 4)) ||
+                      (is_same<T, bhalf_t>::value && (N == 2 || N == 4)) ||
+                      (is_same<T, int32_t>::value && (N == 1 || N == 2)) ||
+                      (is_same<T, int8_t>::value && (N == 4 || N == 8)) ||
+                      (is_same<T, uint8_t>::value && (N == 4 || N == 8)) ||
+                      (is_same<T, f8_t>::value && (N == 4 || N == 8)) ||
+                      (is_same<T, bf8_t>::value && (N == 4 || N == 8)) ||
+                      (is_same<T, pk_i4_t>::value && (N == 4 || N == 8)),
+                  "wrong! not implemented");
+
+    using r_t = typename vector_type<T, N>::type;
+    auto raw_data =
+        amd_s_buffer_load_impl_raw<sizeof(T) * N>(src_wave_buffer_resource, src_wave_addr_offset);
+    return bit_cast<r_t>(raw_data);
+}
+
 } // namespace ck
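Note: `amd_s_buffer_load_impl` maps an element type and count onto the 4- or 8-byte raw loads, so the total size `sizeof(T) * N` must be 4 or 8. As a minimal usage sketch (not part of the patch; `load_two_floats` is a hypothetical helper), loading two floats resolves to a single s_buffer_load_b64:

    // Hypothetical illustration only: two floats = 8 bytes, so the b64 path is taken.
    __device__ ck::float2_t load_two_floats(__amdgpu_buffer_rsrc_t rsrc, ck::index_t byte_offset)
    {
        // vector_type<float, 2>::type is float2_t; sizeof(float) * 2 == 8
        return ck::amd_s_buffer_load_impl<float, 2>(rsrc, byte_offset);
    }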
not implemented"); + + using r_t = typename vector_type::type; + auto raw_data = + amd_s_buffer_load_impl_raw(src_wave_buffer_resource, src_wave_addr_offset); + return bit_cast(raw_data); +} + } // namespace ck diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp index efe1f300c2..e9f9e407d6 100644 --- a/include/ck/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -431,5 +431,29 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a, c3); } #endif + +// s_buffer_loads +inline __device__ int32_t +amd_assembly_s_buffer_load_b32(__amdgpu_buffer_rsrc_t src_wave_buffer_resource, unsigned int offset) +{ + int32_t result; + asm volatile("s_buffer_load_b32 %0, %1, %2" + : "=s"(result) + : "s"(src_wave_buffer_resource), "s"(offset) + : "memory"); + return result; +} + +inline __device__ int32x2_t +amd_assembly_s_buffer_load_b64(__amdgpu_buffer_rsrc_t src_wave_buffer_resource, unsigned int offset) +{ + int32x2_t result; + asm volatile("s_buffer_load_b64 %0, %1, %2" + : "=s"(result) + : "s"(src_wave_buffer_resource), "s"(offset) + : "memory"); + return result; +} + } // namespace ck #endif diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index d47e55db64..3a667cb551 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -296,5 +296,8 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx950") add_subdirectory(mx_mfma_op) add_subdirectory(gemm_mx) endif() +if(SUPPORTED_GPU_TARGETS MATCHES "gfx12") + add_subdirectory(s_prefetch_op) +endif() add_subdirectory(position_embedding) add_subdirectory(scatter_gather) diff --git a/test/s_prefetch_op/CMakeLists.txt b/test/s_prefetch_op/CMakeLists.txt new file mode 100644 index 0000000000..1b598cc952 --- /dev/null +++ b/test/s_prefetch_op/CMakeLists.txt @@ -0,0 +1,2 @@ +add_test_executable(test_s_prefetch_op s_prefetch_op.cpp) +target_link_libraries(test_s_prefetch_op PRIVATE utility) diff --git a/test/s_prefetch_op/s_prefetch_op.cpp b/test/s_prefetch_op/s_prefetch_op.cpp new file mode 100644 index 0000000000..1ec3e57794 --- /dev/null +++ b/test/s_prefetch_op/s_prefetch_op.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/host_utility/device_prop.hpp" + +#include + +#if __clang_major__ >= 20 +#include "ck/utility/amd_buffer_addressing_builtins.hpp" +#else +#include "ck/utility/amd_buffer_addressing.hpp" +#endif + +#include "s_prefetch_op_util.hpp" + +template +bool run_test() +{ + bool pass = true; + + const auto s_prefetch_kernel = ck::s_prefetch_op_util::kernel_with_scalar_prefetch; + const auto s_buffer_prefetch_kernel = + ck::s_prefetch_op_util::kernel_with_scalar_buffer_prefetch; + + const auto prefetch_kernel_container = + std::make_tuple(s_prefetch_kernel, s_buffer_prefetch_kernel); + + ck::static_for<0, 2, 1>{}([&](auto i) { + std::string kernel_name = (i == 1 ? 
"s_buffer_prefetch" : "s_prefetch"); + pass &= ck::s_prefetch_op_util::test_constant_prefetch_impl< + decltype(std::get{}>(prefetch_kernel_container)), + T>(std::get{}>(prefetch_kernel_container), kernel_name); + }); + + return pass; +} + +int main(int, char*[]) +{ + if(!ck::is_gfx12_supported()) + { + std::cout << "This feature is not supported by current HW, skipping tests." << std::endl; + return 0; + } + + bool pass = true; + + std::cout << "=== Testing Constant Cache Prefetch ===" << std::endl; + + // Test different data types + pass &= run_test(); + pass &= run_test(); + + std::cout << "TestConstantPrefetch ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl; + return pass ? 0 : 1; +} diff --git a/test/s_prefetch_op/s_prefetch_op_util.hpp b/test/s_prefetch_op/s_prefetch_op_util.hpp new file mode 100644 index 0000000000..e894baf677 --- /dev/null +++ b/test/s_prefetch_op/s_prefetch_op_util.hpp @@ -0,0 +1,197 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +namespace ck { +namespace s_prefetch_op_util { + +// Prefetch to constant cache using AMD builtin with chunks_to_prefetch(1..32: 1 chunk = 128B) +template +__device__ __forceinline__ void prefetch_to_constant_cache(const T* addr, + unsigned int chunks_to_prefetch) +{ +#if defined(__gfx12__) + assert(chunks_to_prefetch > 0 && chunks_to_prefetch <= 32); + __builtin_amdgcn_s_prefetch_data(addr, chunks_to_prefetch - 1); // we need to pass 0..31 +#else + // ignore - not supported + (void)addr; + (void)chunks_to_prefetch; +#endif +} + +// Prefetch to constant cache using AMD builtin with chunks_to_prefetch(1..32: 1 chunk = 128B) +template +__device__ __forceinline__ void prefetch_to_constant_cache(__amdgpu_buffer_rsrc_t buf_res, + unsigned int chunks_to_prefetch) +{ +#if defined(__gfx12__) + assert(chunks_to_prefetch > 0 && chunks_to_prefetch <= 32); + __builtin_amdgcn_s_buffer_prefetch_data(buf_res, offset, chunks_to_prefetch - 1); +#else + // ignore - not supported + (void)buf_res; + (void)chunks_to_prefetch; +#endif +} + +template +__global__ void kernel_with_scalar_prefetch(const T* src, + T* dst, + const void CK_CONSTANT_ADDRESS_SPACE* scalar_data, + index_t num_elements, + index_t num_scalars) +{ + index_t tid = blockIdx.x * blockDim.x + threadIdx.x; + + const T CK_CONSTANT_ADDRESS_SPACE* scalar_elems = + static_cast(scalar_data); + + // Calculate number of 128B chunks needed to cover num_scalars elements + constexpr index_t chunk_size_bytes = 128; + constexpr index_t elements_per_chunk = chunk_size_bytes / sizeof(T); + unsigned int chunks_needed = (num_scalars + elements_per_chunk - 1) / elements_per_chunk; + + // Prefetch all scalar data at once using chunks parameter + if(threadIdx.x == 0) + { + prefetch_to_constant_cache(scalar_elems, chunks_needed); + } + + T sum = 0; + if(tid < num_elements) + { + sum = src[tid]; // load from global mem to make sure prefetch finished + } + __syncthreads(); // waits on loads from global mem + if(tid < num_elements) + { + // Access prefetched scalar data + for(index_t i = 0; i < num_scalars; i++) + { + sum += scalar_elems[i]; // should be fast due to scalars being preloaded + } + + dst[tid] = sum; + } +} + +template +__global__ void +kernel_with_scalar_buffer_prefetch(const T* src, + T* dst, + const void CK_CONSTANT_ADDRESS_SPACE* scalar_data, + index_t num_elements, + index_t num_scalars) +{ + index_t tid = blockIdx.x * blockDim.x + threadIdx.x; + + const T CK_CONSTANT_ADDRESS_SPACE* scalar_elems = + static_cast(scalar_data); + + 
diff --git a/test/s_prefetch_op/s_prefetch_op_util.hpp b/test/s_prefetch_op/s_prefetch_op_util.hpp
new file mode 100644
index 0000000000..e894baf677
--- /dev/null
+++ b/test/s_prefetch_op/s_prefetch_op_util.hpp
@@ -0,0 +1,197 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+namespace ck {
+namespace s_prefetch_op_util {
+
+// Prefetch to constant cache using the AMD builtin; chunks_to_prefetch is 1..32 (1 chunk = 128 B)
+template <typename T>
+__device__ __forceinline__ void prefetch_to_constant_cache(const T* addr,
+                                                           unsigned int chunks_to_prefetch)
+{
+#if defined(__gfx12__)
+    assert(chunks_to_prefetch > 0 && chunks_to_prefetch <= 32);
+    __builtin_amdgcn_s_prefetch_data(addr, chunks_to_prefetch - 1); // the builtin expects 0..31
+#else
+    // ignore - not supported
+    (void)addr;
+    (void)chunks_to_prefetch;
+#endif
+}
+
+// Prefetch to constant cache using the AMD builtin; chunks_to_prefetch is 1..32 (1 chunk = 128 B)
+template <index_t offset>
+__device__ __forceinline__ void prefetch_to_constant_cache(__amdgpu_buffer_rsrc_t buf_res,
+                                                           unsigned int chunks_to_prefetch)
+{
+#if defined(__gfx12__)
+    assert(chunks_to_prefetch > 0 && chunks_to_prefetch <= 32);
+    __builtin_amdgcn_s_buffer_prefetch_data(buf_res, offset, chunks_to_prefetch - 1);
+#else
+    // ignore - not supported
+    (void)buf_res;
+    (void)chunks_to_prefetch;
+#endif
+}
+
+template <typename T>
+__global__ void kernel_with_scalar_prefetch(const T* src,
+                                            T* dst,
+                                            const void CK_CONSTANT_ADDRESS_SPACE* scalar_data,
+                                            index_t num_elements,
+                                            index_t num_scalars)
+{
+    index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+    const T CK_CONSTANT_ADDRESS_SPACE* scalar_elems =
+        static_cast<const T CK_CONSTANT_ADDRESS_SPACE*>(scalar_data);
+
+    // Calculate the number of 128 B chunks needed to cover num_scalars elements
+    constexpr index_t chunk_size_bytes   = 128;
+    constexpr index_t elements_per_chunk = chunk_size_bytes / sizeof(T);
+    unsigned int chunks_needed = (num_scalars + elements_per_chunk - 1) / elements_per_chunk;
+
+    // Prefetch all scalar data at once using the chunks parameter
+    if(threadIdx.x == 0)
+    {
+        prefetch_to_constant_cache(scalar_elems, chunks_needed);
+    }
+
+    T sum = 0;
+    if(tid < num_elements)
+    {
+        sum = src[tid]; // load from global mem to make sure prefetch finished
+    }
+    __syncthreads(); // waits on loads from global mem
+    if(tid < num_elements)
+    {
+        // Access prefetched scalar data
+        for(index_t i = 0; i < num_scalars; i++)
+        {
+            sum += scalar_elems[i]; // should be fast due to scalars being preloaded
+        }
+
+        dst[tid] = sum;
+    }
+}
+
+template <typename T>
+__global__ void
+kernel_with_scalar_buffer_prefetch(const T* src,
+                                   T* dst,
+                                   const void CK_CONSTANT_ADDRESS_SPACE* scalar_data,
+                                   index_t num_elements,
+                                   index_t num_scalars)
+{
+    index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+    const T CK_CONSTANT_ADDRESS_SPACE* scalar_elems =
+        static_cast<const T CK_CONSTANT_ADDRESS_SPACE*>(scalar_data);
+
+    // Calculate the number of 128 B chunks needed to cover num_scalars elements
+    constexpr index_t chunk_size_bytes   = 128;
+    constexpr index_t elements_per_chunk = chunk_size_bytes / sizeof(T);
+    unsigned int chunks_needed = (num_scalars + elements_per_chunk - 1) / elements_per_chunk;
+
+    __amdgpu_buffer_rsrc_t src_wave_buffer_resource =
+        make_wave_buffer_resource_new(scalar_elems, num_scalars);
+
+    // Prefetch all scalar data at once using the chunks parameter
+    if(threadIdx.x == 0)
+    {
+        prefetch_to_constant_cache<0>(src_wave_buffer_resource, chunks_needed);
+    }
+
+    T sum = 0;
+    if(tid < num_elements)
+    {
+        sum = src[tid]; // load from global mem to make sure prefetch finished
+    }
+    __syncthreads(); // waits on loads from global mem
+    if(tid < num_elements)
+    {
+        // Access prefetched scalar data
+        for(index_t i = 0; i < num_scalars; i++)
+        {
+            sum += amd_s_buffer_load_impl<T, 1>(
+                src_wave_buffer_resource,
+                i * sizeof(T)); // should be fast due to scalars being preloaded
+        }
+
+        dst[tid] = sum;
+    }
+}
+
+template <typename PrefetchKernel, typename T>
+bool test_constant_prefetch_impl(const PrefetchKernel& prefetch_kernel,
+                                 const std::string& kernel_name)
+{
+    constexpr index_t num_elements = 512;
+    constexpr index_t num_scalars  = 512;
+    constexpr index_t block_size   = 256;
+    constexpr index_t grid_size    = (num_elements + block_size - 1) / block_size;
+
+    std::cout << "Testing " << kernel_name << " to constant cache for type: " << typeid(T).name()
+              << std::endl;
+    std::cout << "Elements: " << num_elements << ", Scalars: " << num_scalars << std::endl;
+
+    // Host data
+    std::vector<T> h_src(num_elements);
+    std::vector<T> h_scalar(num_scalars);
+    std::vector<T> h_dst_with_prefetch_chunks(num_elements);
+    std::vector<T> h_expected(num_elements);
+
+    // Initialize data
+    for(index_t i = 0; i < num_elements; i++)
+    {
+        h_src[i] = static_cast<T>(i % 100);
+    }
+
+    T scalar_sum = 0;
+    for(index_t i = 0; i < num_scalars; i++)
+    {
+        h_scalar[i] = static_cast<T>(i + 1);
+        scalar_sum += h_scalar[i];
+    }
+
+    // Expected results
+    for(index_t i = 0; i < num_elements; i++)
+    {
+        h_expected[i] = h_src[i] + scalar_sum;
+    }
+
+    // Device memory
+    DeviceMem d_src(sizeof(T) * num_elements);
+    DeviceMem d_scalar(sizeof(T) * num_scalars);
+    DeviceMem d_dst_with_prefetch_chunks(sizeof(T) * num_elements);
+
+    d_src.ToDevice(h_src.data());
+    d_scalar.ToDevice(h_scalar.data());
+
+    hipStream_t stream;
+    hip_check_error(hipStreamCreate(&stream));
+
+    prefetch_kernel<<<grid_size, block_size, 0, stream>>>(
+        static_cast<const T*>(d_src.GetDeviceBuffer()),
+        static_cast<T*>(d_dst_with_prefetch_chunks.GetDeviceBuffer()),
+        cast_pointer_to_constant_address_space(d_scalar.GetDeviceBuffer()),
+        num_elements,
+        num_scalars);
+
+    hip_check_error(hipStreamSynchronize(stream));
+
+    // Copy results back
+    d_dst_with_prefetch_chunks.FromDevice(h_dst_with_prefetch_chunks.data());
+
+    // Verify results
+    bool pass = ck::utils::check_err(h_dst_with_prefetch_chunks, h_expected);
+
+    std::cout << (pass ? "PASS" : "FAIL") << std::endl;
+
+    hip_check_error(hipStreamDestroy(stream));
+
+    return pass;
+}
+
+} // namespace s_prefetch_op_util
+} // namespace ck
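Note: s_prefetch_data is a scalar instruction, so a single lane issuing it warms the constant cache for the whole wavefront; that is why both kernels gate the prefetch on threadIdx.x == 0. A minimal standalone sketch of the same idiom (hypothetical kernel name, assuming a 256-entry table and a 256-thread block; not part of the patch):

    // Hypothetical example mirroring the kernels above: warm the constant cache
    // for a 256-element table before the block consumes it.
    template <typename T>
    __global__ void warm_constant_cache_then_use(const T* table, T* out)
    {
        if(threadIdx.x == 0)
        {
            // round 256 * sizeof(T) bytes up to 128 B chunks (<= 32 for any T here)
            constexpr unsigned int chunks = (256 * sizeof(T) + 127) / 128;
            ck::s_prefetch_op_util::prefetch_to_constant_cache(table, chunks);
        }
        __syncthreads();
        out[threadIdx.x] = table[threadIdx.x % 256];
    }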