[CK] Added s_prefetch unit test.

- Added s_buffer_load_b32/b64 inline assembly helpers.
- Added amd_s_buffer_load_impl.

Signed-off-by: Michal Kulikowski <Michal.Kulikowski@amd.com>


[ROCm/composable_kernel commit: f3ef7acca0]
This commit is contained in:
Michal Kulikowski
2025-11-05 14:09:04 +01:00
committed by Michał Kulikowski
parent b9ee41c660
commit 8fc5eca798
7 changed files with 385 additions and 0 deletions

View File

@@ -3,6 +3,7 @@
#pragma once
#include "data_type.hpp"
#include "amd_inline_asm.hpp"
namespace ck {
@@ -1064,4 +1065,48 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
}
#endif
// Scalar (SGPR-path) buffer load of N raw bytes through a buffer resource
// descriptor; the bytes are returned as a vector of int8_t.
//
// N                        : number of bytes to load — only 4 and 8 are implemented,
//                            matching the b32/b64 assembly helpers below.
// src_wave_buffer_resource : 128-bit buffer resource descriptor (per the name,
//                            expected to be wave-uniform — scalar loads read
//                            through SGPRs).
// src_wave_addr_offset     : byte offset into the buffer.
template <index_t N>
__device__ typename vector_type<int8_t, N>::type
amd_s_buffer_load_impl_raw(__amdgpu_buffer_rsrc_t src_wave_buffer_resource,
                           index_t src_wave_addr_offset)
{
    static_assert(N == 4 || N == 8, "wrong! not implemented");

    // TODO: add other variants of s_buffer_load
    if constexpr(N == 4)
    {
        // 32-bit scalar load, reinterpreted as 4 bytes.
        int32_t tmp =
            amd_assembly_s_buffer_load_b32(src_wave_buffer_resource, src_wave_addr_offset);
        return bit_cast<int8x4_t>(tmp);
    }
    else if constexpr(N == 8)
    {
        // 64-bit scalar load, reinterpreted as 8 bytes.
        int32x2_t tmp =
            amd_assembly_s_buffer_load_b64(src_wave_buffer_resource, src_wave_addr_offset);
        return bit_cast<int8x8_t>(tmp);
    }
}
// Typed scalar buffer load: loads N elements of T via the raw byte load above
// and bit_casts the result to vector_type<T, N>::type.
//
// The static_assert enumerates exactly the (T, N) pairs whose total size
// sizeof(T) * N is 4 or 8 bytes — the only widths the raw load implements.
template <typename T, index_t N>
__device__ typename vector_type<T, N>::type
amd_s_buffer_load_impl(__amdgpu_buffer_rsrc_t src_wave_buffer_resource,
                       index_t src_wave_addr_offset)
{
    static_assert((is_same<T, double>::value && (N == 1)) ||
                      (is_same<T, float>::value && (N == 1 || N == 2)) ||
                      (is_same<T, half_t>::value && (N == 2 || N == 4)) ||
                      (is_same<T, bhalf_t>::value && (N == 2 || N == 4)) ||
                      (is_same<T, int32_t>::value && (N == 1 || N == 2)) ||
                      (is_same<T, f8_t>::value && (N == 4 || N == 8)) ||
                      (is_same<T, bf8_t>::value && (N == 4 || N == 8)) ||
                      (is_same<T, int8_t>::value && (N == 4 || N == 8)) ||
                      (is_same<T, uint8_t>::value && (N == 4 || N == 8)) ||
                      (is_same<T, pk_i4_t>::value && (N == 4 || N == 8)),
                  "wrong! not implemented");

    using r_t = typename vector_type<T, N>::type;

    // Load sizeof(T) * N raw bytes, then reinterpret as the requested vector.
    auto raw_data =
        amd_s_buffer_load_impl_raw<sizeof(T) * N>(src_wave_buffer_resource, src_wave_addr_offset);
    return bit_cast<r_t>(raw_data);
}
} // namespace ck

View File

@@ -3,6 +3,7 @@
#pragma once
#include "data_type.hpp"
#include "amd_inline_asm.hpp"
namespace ck {
@@ -885,4 +886,48 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
}
#endif
// Scalar (SGPR-path) buffer load of N raw bytes through a buffer resource
// descriptor; the bytes are returned as a vector of int8_t.
//
// N                        : number of bytes to load — only 4 and 8 are implemented,
//                            matching the b32/b64 assembly helpers.
// src_wave_buffer_resource : 128-bit buffer resource descriptor (per the name,
//                            expected to be wave-uniform — scalar loads read
//                            through SGPRs).
// src_wave_addr_offset     : byte offset into the buffer.
template <index_t N>
__device__ typename vector_type<int8_t, N>::type
amd_s_buffer_load_impl_raw(__amdgpu_buffer_rsrc_t src_wave_buffer_resource,
                           index_t src_wave_addr_offset)
{
    static_assert(N == 4 || N == 8, "wrong! not implemented");

    // TODO: add other variants of s_buffer_load
    if constexpr(N == 4)
    {
        // 32-bit scalar load, reinterpreted as 4 bytes.
        int32_t tmp =
            amd_assembly_s_buffer_load_b32(src_wave_buffer_resource, src_wave_addr_offset);
        return bit_cast<int8x4_t>(tmp);
    }
    else if constexpr(N == 8)
    {
        // 64-bit scalar load, reinterpreted as 8 bytes.
        int32x2_t tmp =
            amd_assembly_s_buffer_load_b64(src_wave_buffer_resource, src_wave_addr_offset);
        return bit_cast<int8x8_t>(tmp);
    }
}
// Typed scalar buffer load: loads N elements of T via the raw byte load above
// and bit_casts the result to vector_type<T, N>::type.
//
// The static_assert enumerates exactly the (T, N) pairs whose total size
// sizeof(T) * N is 4 or 8 bytes — the only widths the raw load implements.
template <typename T, index_t N>
__device__ typename vector_type<T, N>::type
amd_s_buffer_load_impl(__amdgpu_buffer_rsrc_t src_wave_buffer_resource,
                       index_t src_wave_addr_offset)
{
    static_assert((is_same<T, double>::value && (N == 1)) ||
                      (is_same<T, float>::value && (N == 1 || N == 2)) ||
                      (is_same<T, half_t>::value && (N == 2 || N == 4)) ||
                      (is_same<T, bhalf_t>::value && (N == 2 || N == 4)) ||
                      (is_same<T, int32_t>::value && (N == 1 || N == 2)) ||
                      (is_same<T, f8_t>::value && (N == 4 || N == 8)) ||
                      (is_same<T, bf8_t>::value && (N == 4 || N == 8)) ||
                      (is_same<T, int8_t>::value && (N == 4 || N == 8)) ||
                      (is_same<T, uint8_t>::value && (N == 4 || N == 8)) ||
                      (is_same<T, pk_i4_t>::value && (N == 4 || N == 8)),
                  "wrong! not implemented");

    using r_t = typename vector_type<T, N>::type;

    // Load sizeof(T) * N raw bytes, then reinterpret as the requested vector.
    auto raw_data =
        amd_s_buffer_load_impl_raw<sizeof(T) * N>(src_wave_buffer_resource, src_wave_addr_offset);
    return bit_cast<r_t>(raw_data);
}
} // namespace ck

View File

@@ -431,5 +431,29 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
c3);
}
#endif
// s_buffer_loads
//
// 32-bit scalar buffer load via inline assembly.
// src_wave_buffer_resource : 128-bit buffer resource descriptor.
// offset                   : byte offset into the buffer.
// Returns the loaded dword.
inline __device__ int32_t
amd_assembly_s_buffer_load_b32(__amdgpu_buffer_rsrc_t src_wave_buffer_resource, unsigned int offset)
{
    int32_t result;
    // s_buffer_load_* is asynchronous: the destination SGPR is not valid until
    // the scalar-memory counter drains. The compiler cannot insert that wait
    // for a load issued inside inline asm, so the asm block must do it itself
    // before the "=s" output is consumed.
    // NOTE(review): s_buffer_load_b32 / s_wait_kmcnt are the gfx12 mnemonics
    // (feature is gated on gfx12 in the tests); pre-gfx12 the wait would be
    // s_waitcnt lgkmcnt(0) — confirm intended targets.
    asm volatile("s_buffer_load_b32 %0, %1, %2\n"
                 "s_wait_kmcnt 0x0"
                 : "=s"(result)
                 : "s"(src_wave_buffer_resource), "s"(offset)
                 : "memory");
    return result;
}
// 64-bit scalar buffer load via inline assembly.
// src_wave_buffer_resource : 128-bit buffer resource descriptor.
// offset                   : byte offset into the buffer.
// Returns the loaded dword pair.
inline __device__ int32x2_t
amd_assembly_s_buffer_load_b64(__amdgpu_buffer_rsrc_t src_wave_buffer_resource, unsigned int offset)
{
    int32x2_t result;
    // Drain the scalar-memory counter inside the asm block: the compiler does
    // not track loads issued by inline asm, so without the explicit wait the
    // "=s" output could be read before the load completes.
    // NOTE(review): gfx12 mnemonics — pre-gfx12 the wait would be
    // s_waitcnt lgkmcnt(0); confirm intended targets.
    asm volatile("s_buffer_load_b64 %0, %1, %2\n"
                 "s_wait_kmcnt 0x0"
                 : "=s"(result)
                 : "s"(src_wave_buffer_resource), "s"(offset)
                 : "memory");
    return result;
}
} // namespace ck
#endif

View File

@@ -296,5 +296,8 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx950")
add_subdirectory(mx_mfma_op)
add_subdirectory(gemm_mx)
endif()
# Scalar-cache prefetch (s_prefetch / s_buffer_prefetch) is exercised only on
# gfx12 targets — the test itself also skips at runtime on unsupported HW.
if(SUPPORTED_GPU_TARGETS MATCHES "gfx12")
    add_subdirectory(s_prefetch_op)
endif()
add_subdirectory(position_embedding)
add_subdirectory(scatter_gather)

View File

@@ -0,0 +1,2 @@
# Unit test for scalar (constant-cache) prefetch: s_prefetch / s_buffer_prefetch.
add_test_executable(test_s_prefetch_op s_prefetch_op.cpp)
target_link_libraries(test_s_prefetch_op PRIVATE utility)

View File

@@ -0,0 +1,69 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include <chrono>
#include "ck/ck.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/host_utility/device_prop.hpp"
#include <hip/hip_runtime.h>
#if __clang_major__ >= 20
#include "ck/utility/amd_buffer_addressing_builtins.hpp"
#else
#include "ck/utility/amd_buffer_addressing.hpp"
#endif
#include "s_prefetch_op_util.hpp"
// Runs the constant-prefetch check for element type T with both prefetch
// flavours (plain s_prefetch and buffer-resource s_buffer_prefetch).
// Returns true only if both kernels produce the expected results.
template <typename T>
bool run_test()
{
    const auto scalar_kernel = ck::s_prefetch_op_util::kernel_with_scalar_prefetch<T>;
    const auto buffer_kernel = ck::s_prefetch_op_util::kernel_with_scalar_buffer_prefetch<T>;

    // Same order as before: plain prefetch first, then the buffer variant.
    bool pass = true;
    pass &= ck::s_prefetch_op_util::test_constant_prefetch_impl<decltype(scalar_kernel), T>(
        scalar_kernel, "s_prefetch");
    pass &= ck::s_prefetch_op_util::test_constant_prefetch_impl<decltype(buffer_kernel), T>(
        buffer_kernel, "s_buffer_prefetch");

    return pass;
}
// Entry point: skips (successfully) on non-gfx12 hardware, otherwise runs the
// prefetch tests for float and double and reports the combined result.
int main(int, char*[])
{
    if(!ck::is_gfx12_supported())
    {
        std::cout << "This feature is not supported by current HW, skipping tests." << std::endl;
        return 0;
    }

    std::cout << "=== Testing Constant Cache Prefetch ===" << std::endl;

    // Test different data types; both run unconditionally so every failure is reported.
    const bool pass_float  = run_test<float>();
    const bool pass_double = run_test<double>();
    const bool pass        = pass_float && pass_double;

    std::cout << "TestConstantPrefetch ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
    return pass ? 0 : 1;
}

View File

@@ -0,0 +1,197 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
namespace ck {
namespace s_prefetch_op_util {
// Prefetch to constant cache using AMD builtin with chunks_to_prefetch(1..32: 1 chunk = 128B)
//
// addr               : start address to prefetch from.
// chunks_to_prefetch : number of 128-byte chunks, 1..32 (the builtin takes 0..31,
//                      hence the -1 below).
// No-op on non-gfx12 targets.
template <typename T>
__device__ __forceinline__ void prefetch_to_constant_cache(const T* addr,
                                                           unsigned int chunks_to_prefetch)
{
#if defined(__gfx12__)
    assert(chunks_to_prefetch > 0 && chunks_to_prefetch <= 32);
    __builtin_amdgcn_s_prefetch_data(addr, chunks_to_prefetch - 1); // we need to pass 0..31
#else
    // ignore - not supported
    (void)addr;
    (void)chunks_to_prefetch;
#endif
}
// Prefetch to constant cache using AMD builtin with chunks_to_prefetch(1..32: 1 chunk = 128B)
//
// offset             : compile-time byte offset into the buffer resource.
// buf_res            : 128-bit buffer resource descriptor to prefetch through.
// chunks_to_prefetch : number of 128-byte chunks, 1..32 (the builtin takes 0..31,
//                      hence the -1 below).
// No-op on non-gfx12 targets.
template <unsigned int offset>
__device__ __forceinline__ void prefetch_to_constant_cache(__amdgpu_buffer_rsrc_t buf_res,
                                                           unsigned int chunks_to_prefetch)
{
#if defined(__gfx12__)
    assert(chunks_to_prefetch > 0 && chunks_to_prefetch <= 32);
    __builtin_amdgcn_s_buffer_prefetch_data(buf_res, offset, chunks_to_prefetch - 1);
#else
    // ignore - not supported
    (void)buf_res;
    (void)chunks_to_prefetch;
#endif
}
// Kernel: thread 0 of each block prefetches the scalar array into the constant
// cache, then every in-range thread writes dst[tid] = src[tid] + sum(scalars).
//
// Expected launch: 1D grid and block covering at least num_elements threads.
// src         : num_elements values in global memory.
// dst         : num_elements outputs.
// scalar_data : num_scalars values of T in the constant address space.
template <typename T>
__global__ void kernel_with_scalar_prefetch(const T* src,
                                            T* dst,
                                            const void CK_CONSTANT_ADDRESS_SPACE* scalar_data,
                                            index_t num_elements,
                                            index_t num_scalars)
{
    index_t tid = blockIdx.x * blockDim.x + threadIdx.x;

    const T CK_CONSTANT_ADDRESS_SPACE* scalar_elems =
        static_cast<const T CK_CONSTANT_ADDRESS_SPACE*>(scalar_data);

    // Calculate number of 128B chunks needed to cover num_scalars elements
    constexpr index_t chunk_size_bytes   = 128;
    constexpr index_t elements_per_chunk = chunk_size_bytes / sizeof(T);
    unsigned int chunks_needed = (num_scalars + elements_per_chunk - 1) / elements_per_chunk;

    // Prefetch all scalar data at once using chunks parameter
    if(threadIdx.x == 0)
    {
        prefetch_to_constant_cache(scalar_elems, chunks_needed);
    }

    T sum = 0;
    if(tid < num_elements)
    {
        sum = src[tid]; // load from global mem to make sure prefetch finished
    }
    // Block-wide barrier reached by all threads (it sits outside the divergent
    // tid < num_elements branches above/below).
    __syncthreads(); // waits on loads from global mem

    if(tid < num_elements)
    {
        // Access prefetched scalar data
        for(index_t i = 0; i < num_scalars; i++)
        {
            sum += scalar_elems[i]; // should be fast due to scalars being preloaded
        }
        dst[tid] = sum;
    }
}
// Kernel: like kernel_with_scalar_prefetch, but the prefetch and the scalar
// reads both go through a buffer resource descriptor (s_buffer_prefetch +
// s_buffer_load). Every in-range thread writes dst[tid] = src[tid] + sum(scalars).
//
// Expected launch: 1D grid and block covering at least num_elements threads.
template <typename T>
__global__ void
kernel_with_scalar_buffer_prefetch(const T* src,
                                   T* dst,
                                   const void CK_CONSTANT_ADDRESS_SPACE* scalar_data,
                                   index_t num_elements,
                                   index_t num_scalars)
{
    index_t tid = blockIdx.x * blockDim.x + threadIdx.x;

    const T CK_CONSTANT_ADDRESS_SPACE* scalar_elems =
        static_cast<const T CK_CONSTANT_ADDRESS_SPACE*>(scalar_data);

    // Calculate number of 128B chunks needed to cover num_scalars elements
    constexpr index_t chunk_size_bytes   = 128;
    constexpr index_t elements_per_chunk = chunk_size_bytes / sizeof(T);
    unsigned int chunks_needed = (num_scalars + elements_per_chunk - 1) / elements_per_chunk;

    // NOTE(review): second argument looks like an element count while the loads
    // below use byte offsets (i * sizeof(T)) — confirm the units
    // make_wave_buffer_resource_new expects for the buffer range.
    __amdgpu_buffer_rsrc_t src_wave_buffer_resource =
        make_wave_buffer_resource_new(scalar_elems, num_scalars);

    // Prefetch all scalar data at once using chunks parameter
    if(threadIdx.x == 0)
    {
        prefetch_to_constant_cache<0>(src_wave_buffer_resource, chunks_needed);
    }

    T sum = 0;
    if(tid < num_elements)
    {
        sum = src[tid]; // load from global mem to make sure prefetch finished
    }
    // Block-wide barrier reached by all threads (outside divergent branches).
    __syncthreads(); // waits on loads from global mem

    if(tid < num_elements)
    {
        // Access prefetched scalar data
        for(index_t i = 0; i < num_scalars; i++)
        {
            sum += amd_s_buffer_load_impl<T, 1>(
                src_wave_buffer_resource,
                i * sizeof(T)); // should be fast due to scalars being preloaded
        }
        dst[tid] = sum;
    }
}
template <typename PrefetchKernel, typename T>
bool test_constant_prefetch_impl(const PrefetchKernel& prefetch_kernel,
const std::string& kernel_name)
{
constexpr index_t num_elements = 512;
constexpr index_t num_scalars = 512;
constexpr index_t block_size = 256;
constexpr index_t grid_size = (num_elements + block_size - 1) / block_size;
std::cout << "Testing " << kernel_name << " to constant cache for type: " << typeid(T).name()
<< std::endl;
std::cout << "Elements: " << num_elements << ", Scalars: " << num_scalars << std::endl;
// Host data
std::vector<T> h_src(num_elements);
std::vector<T> h_scalar(num_scalars);
std::vector<T> h_dst_with_prefetch_chunks(num_elements);
std::vector<T> h_expected(num_elements);
// Initialize data
for(index_t i = 0; i < num_elements; i++)
{
h_src[i] = static_cast<T>(i % 100);
}
T scalar_sum = 0;
for(index_t i = 0; i < num_scalars; i++)
{
h_scalar[i] = static_cast<T>(i + 1);
scalar_sum += h_scalar[i];
}
// Expected results
for(index_t i = 0; i < num_elements; i++)
{
h_expected[i] = h_src[i] + scalar_sum;
}
// Device memory
DeviceMem d_src(sizeof(T) * num_elements);
DeviceMem d_scalar(sizeof(T) * num_scalars);
DeviceMem d_dst_with_prefetch_chunks(sizeof(T) * num_elements);
d_src.ToDevice(h_src.data());
d_scalar.ToDevice(h_scalar.data());
hipStream_t stream;
hip_check_error(hipStreamCreate(&stream));
prefetch_kernel<<<grid_size, block_size, 0, stream>>>(
static_cast<const T*>(d_src.GetDeviceBuffer()),
static_cast<T*>(d_dst_with_prefetch_chunks.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(d_scalar.GetDeviceBuffer()),
num_elements,
num_scalars);
hip_check_error(hipStreamSynchronize(stream));
// Copy results back
d_dst_with_prefetch_chunks.FromDevice(h_dst_with_prefetch_chunks.data());
// Verify results
bool pass = ck::utils::check_err(h_dst_with_prefetch_chunks, h_expected);
std::cout << (pass ? "PASS" : "FAIL") << std::endl;
hip_check_error(hipStreamDestroy(stream));
return pass;
}
} // namespace s_prefetch_op_util
} // namespace ck