diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp
index d839518285..39260c8acd 100644
--- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp
+++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp
@@ -58,6 +58,8 @@
 #include "ck_tile/builder/types.hpp"
 #include "ck_tile/builder/versions.hpp"
 
+#include "ck_tile/builder/conv_signature_utils.hpp"
+
 namespace ck_tile::builder::factory_internal {
 
 // Type mappings from the builder FwdGroupConvLayout enum classes to the CK tensor data types.
@@ -665,7 +667,7 @@ struct ConvFactory
                                                   SPATIAL_DIM,
                                                   ConvDirection::FORWARD>());
     using Types = factory_internal::ConvTensorTypes;
-    using Ops = factory_internal::ElementwiseOps;
+    using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
     using AlgorithmType = decltype(ALGORITHM);
 
     static constexpr auto FWD_CONV_SPECIALIZATION =
@@ -762,7 +764,7 @@ struct ConvFactory
                                                   SPATIAL_DIM,
                                                   ConvDirection::FORWARD>());
     using Types = factory_internal::ConvTensorTypes;
-    using Ops = factory_internal::ElementwiseOps;
+    using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
     using AlgorithmType = decltype(ALGORITHM);
 
     static constexpr auto FWD_CONV_SPECIALIZATION =
@@ -858,7 +860,7 @@ struct ConvFactory
                                                   SPATIAL_DIM,
                                                   ConvDirection::FORWARD>());
     using Types = factory_internal::ConvTensorTypes;
-    using Ops = factory_internal::ElementwiseOps;
+    using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
     using AlgorithmType = decltype(ALGORITHM);
 
     static constexpr auto FWD_CONV_SPECIALIZATION =
@@ -980,7 +982,7 @@ struct ConvFactory
                                                   SPATIAL_DIM,
                                                   ConvDirection::FORWARD>());
     using Types = factory_internal::ConvTensorTypes;
-    using Ops = factory_internal::ElementwiseOps;
+    using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
    using AlgorithmType = decltype(ALGORITHM);
 
     static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm;
diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_utils.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_utils.hpp
new file mode 100644
index 0000000000..3ba2bf24dd
--- /dev/null
+++ b/experimental/builder/include/ck_tile/builder/conv_signature_utils.hpp
@@ -0,0 +1,47 @@
+// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include
+#include
+
+#include "ck_tile/builder/types.hpp"
+
+namespace ck_tile::builder {
+/**********************************************
+ * constexpr helper functions for optional parameters
+ **********************************************/
+
+template <auto Sig>
+concept ProvidesElementwiseOperation = requires { Sig.elementwise_operation; };
+
+template <auto Sig>
+concept ProvidesConvolutionDirection = requires { Sig.direction; };
+
+template <auto Sig>
+constexpr auto get_elementwise_operation()
+{
+    if constexpr(ProvidesElementwiseOperation<Sig>)
+    {
+        return Sig.elementwise_operation;
+    }
+    else
+    {
+        return ElementwiseOperation::PASS_THROUGH;
+    }
+}
+
+template <auto Sig>
+constexpr auto get_conv_direction()
+{
+    if constexpr(ProvidesConvolutionDirection<Sig>)
+    {
+        return Sig.direction;
+    }
+    else
+    {
+        return ConvDirection::FORWARD;
+    }
+}
+} // namespace ck_tile::builder
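Note on intent (not part of the patch): the helpers above let a signature type simply omit the optional fields and still receive the documented defaults. The sketch below is illustrative only; the two signature structs are hypothetical, it assumes the helpers take the signature value as a non-type template parameter as written above, and it assumes ElementwiseOperation and ConvDirection are comparable enum values as the header suggests.

    #include "ck_tile/builder/conv_signature_utils.hpp"

    namespace ckb = ck_tile::builder;

    // Hypothetical signature with no optional fields: the defaults apply.
    struct MinimalSignature
    {
    };

    // Hypothetical signature that opts into the elementwise_operation field.
    struct SignatureWithOp
    {
        ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH;
    };

    static_assert(!ckb::ProvidesElementwiseOperation<MinimalSignature{}>);
    static_assert(ckb::get_elementwise_operation<MinimalSignature{}>() ==
                  ckb::ElementwiseOperation::PASS_THROUGH);
    static_assert(ckb::get_conv_direction<MinimalSignature{}>() == ckb::ConvDirection::FORWARD);
    static_assert(ckb::ProvidesElementwiseOperation<SignatureWithOp{}>);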
diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp
index b53cdc39c7..c2f7039348 100644
--- a/experimental/builder/test/test_conv_description.cpp
+++ b/experimental/builder/test/test_conv_description.cpp
@@ -49,6 +49,7 @@ struct ConvSignatureWithInvalidOptionalParams
     ckb::GroupConvDeviceOp device_operation =
         ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
 };
+static_assert(!ckb::ConvSignatureDescriptor<ConvSignatureWithInvalidOptionalParams>);
 
 struct DefaultAlgorithm
 {
diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp
index efe1f300c2..79efd77edb 100644
--- a/include/ck/utility/amd_inline_asm.hpp
+++ b/include/ck/utility/amd_inline_asm.hpp
@@ -431,5 +431,6 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
                                    c3);
 }
 #endif
+
 } // namespace ck
 #endif
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index d47e55db64..3a667cb551 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -296,5 +296,8 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx950")
     add_subdirectory(mx_mfma_op)
     add_subdirectory(gemm_mx)
 endif()
+if(SUPPORTED_GPU_TARGETS MATCHES "gfx12")
+    add_subdirectory(s_prefetch_op)
+endif()
 add_subdirectory(position_embedding)
 add_subdirectory(scatter_gather)
diff --git a/test/s_prefetch_op/CMakeLists.txt b/test/s_prefetch_op/CMakeLists.txt
new file mode 100644
index 0000000000..1b598cc952
--- /dev/null
+++ b/test/s_prefetch_op/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_test_executable(test_s_prefetch_op s_prefetch_op.cpp)
+target_link_libraries(test_s_prefetch_op PRIVATE utility)
diff --git a/test/s_prefetch_op/s_prefetch_op.cpp b/test/s_prefetch_op/s_prefetch_op.cpp
new file mode 100644
index 0000000000..fc0ae84132
--- /dev/null
+++ b/test/s_prefetch_op/s_prefetch_op.cpp
@@ -0,0 +1,66 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "ck/ck.hpp"
+#include "ck/host_utility/device_prop.hpp"
+
+#include "s_prefetch_op_util.hpp"
+
+template <typename T, ck::index_t NUM_THREADS, ck::index_t NUM_SCALARS>
+bool run_test(bool time_kernels)
+{
+    bool pass = true;
+
+    const auto s_prefetch_kernel =
+        ck::s_prefetch_op_util::kernel_with_prefetch<T, NUM_THREADS, NUM_SCALARS, ck::s_prefetch_op_util::SPrefetchDataOp<T>>;
+    const auto s_buffer_prefetch_kernel = ck::s_prefetch_op_util::kernel_with_prefetch<
+        T,
+        NUM_THREADS,
+        NUM_SCALARS,
+        ck::s_prefetch_op_util::SBufferPrefetchDataOp<T, NUM_SCALARS>>;
+
+    const auto prefetch_kernel_container =
+        std::make_tuple(s_prefetch_kernel, s_buffer_prefetch_kernel);
+
+    ck::static_for<0, 2, 1>{}([&](auto i) {
+        std::string kernel_name = (i == 1 ? "s_buffer_prefetch" : "s_prefetch");
+
+        auto kernel = std::get<decltype(i)::value>(prefetch_kernel_container);
+
+        pass &= ck::s_prefetch_op_util::
+            test_prefetch_impl<T, NUM_THREADS, NUM_SCALARS>(
+                time_kernels, kernel, kernel_name);
+    });
+
+    return pass;
+}
+
+int main(int argc, char* argv[])
+{
+    if(!ck::is_gfx12_supported())
+    {
+        std::cout << "This feature is not supported by current HW, skipping tests." << std::endl;
+        return 0;
+    }
+
+    bool time_kernels = false;
+
+    if(argc == 2)
+    {
+        time_kernels = std::stoi(argv[1]);
+    }
+
+    bool pass = true;
+
+    std::cout << "=== Testing Constant Cache Prefetch ===" << std::endl;
+
+    // Test different data types
+    pass &= run_test(time_kernels);
+    pass &= run_test(time_kernels);
+
+    std::cout << "TestConstantPrefetch ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
+    return pass ? 0 : 1;
+}
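For reference, the contract the utility header below verifies is simple: each thread loads one element of `src`, adds the sum of all `NUM_SCALARS` constant-cache scalars, and writes the result to `dst`; the prefetch only affects latency, never the value. A host-side model of that relation (illustrative only, not part of the patch; the helper name `expected_output` is made up here):

    #include <cstddef>
    #include <numeric>
    #include <vector>

    // Host model of what kernel_with_prefetch writes: dst[i] = src[i] + sum(scalars).
    template <typename T>
    std::vector<T> expected_output(const std::vector<T>& src, const std::vector<T>& scalars)
    {
        const T scalar_sum = std::accumulate(scalars.begin(), scalars.end(), T{0});
        std::vector<T> out(src.size());
        for(std::size_t i = 0; i < src.size(); ++i)
        {
            out[i] = src[i] + scalar_sum;
        }
        return out;
    }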
"s_buffer_prefetch" : "s_prefetch"); + + auto kernel = std::get{}>(prefetch_kernel_container); + + pass &= ck::s_prefetch_op_util:: + test_prefetch_impl( + time_kernels, kernel, kernel_name); + }); + + return pass; +} + +int main(int argc, char* argv[]) +{ + if(!ck::is_gfx12_supported()) + { + std::cout << "This feature is not supported by current HW, skipping tests." << std::endl; + return 0; + } + + bool time_kernels = false; + + if(argc == 2) + { + time_kernels = std::stoi(argv[1]); + } + + bool pass = true; + + std::cout << "=== Testing Constant Cache Prefetch ===" << std::endl; + + // Test different data types + pass &= run_test(time_kernels); + pass &= run_test(time_kernels); + + std::cout << "TestConstantPrefetch ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl; + return pass ? 0 : 1; +} diff --git a/test/s_prefetch_op/s_prefetch_op_util.hpp b/test/s_prefetch_op/s_prefetch_op_util.hpp new file mode 100644 index 0000000000..077b876b1a --- /dev/null +++ b/test/s_prefetch_op/s_prefetch_op_util.hpp @@ -0,0 +1,249 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/host_utility/hip_check_error.hpp" + +#include + +namespace ck { +namespace s_prefetch_op_util { + +// Enable scalar prefetch in hardware (required on gfx12 before using s_prefetch) +__device__ __forceinline__ void enable_scalar_prefetch() +{ +#if defined(__gfx12__) + // SCALAR_PREFETCH_EN is bit 24 in MODE register (hwreg 1) + // Set 1 bit at offset 24 to value 1 + __builtin_amdgcn_s_setreg(1 | (24 << 6), 1); // Set bit to 1 +#endif +} + +template +struct SPrefetchDataOp +{ + // Prefetch to constant cache using AMD builtin with cachelines to prefetch(1..32) + __device__ __forceinline__ void operator()(const T CK_CONSTANT_ADDRESS_SPACE* addr, + unsigned int num_cachelines) const + { +#if defined(__gfx12__) + assert(num_cachelines > 0 && num_cachelines <= 32); + __builtin_amdgcn_s_prefetch_data(addr, num_cachelines - 1); // we need to pass 0..31 +#else + // ignore - not supported + (void)addr; + (void)num_cachelines; +#endif + } +}; + +template +struct SBufferPrefetchDataOp +{ + // Prefetch to constant cache using AMD builtin with cachelines to prefetch(1..32) + __device__ __forceinline__ void operator()(const T CK_CONSTANT_ADDRESS_SPACE* addr, + unsigned int num_cachelines) const + { +#if defined(__gfx12__) + __amdgpu_buffer_rsrc_t buf_res = make_wave_buffer_resource_new(addr, NUM_SCALARS); + assert(num_cachelines > 0 && num_cachelines <= 32); + __builtin_amdgcn_s_buffer_prefetch_data(buf_res, 0, num_cachelines - 1); +#else + // ignore - not supported + (void)addr; + (void)num_cachelines; +#endif + } +}; + +template +__global__ void kernel_with_prefetch(const T* src, + T* dst, + const T CK_CONSTANT_ADDRESS_SPACE* scalar_data, + bool enable_prefetch) +{ + uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; + + // Calculate number of 128B cachelines needed to cover num_scalars elements + constexpr index_t cachelineSize = 128; + constexpr index_t elements_per_cachelineSize = cachelineSize / sizeof(T); + constexpr unsigned int cachelinesNeeded = + (NUM_SCALARS + elements_per_cachelineSize - 1) / elements_per_cachelineSize; + + // Prefetch all scalar data at once + if(threadIdx.x == 0) + { + if(enable_prefetch) + { + 
+            enable_scalar_prefetch();
+        }
+        PrefetchOp{}(scalar_data, cachelinesNeeded);
+    }
+
+    T sum = 0;
+    if(tid < NUM_THREADS)
+    {
+        sum = src[tid]; // load from global mem to give time for prefetch to finish or be close to
+                        // finishing
+    }
+    __syncthreads(); // waits on loads from global mem
+    if(tid < NUM_THREADS)
+    {
+        // Access prefetched scalar data
+        for(uint32_t i = 0; i < NUM_SCALARS; i++)
+        {
+            sum += scalar_data[i]; // should be fast due to scalars being preloaded
+        }
+
+        dst[tid] = sum;
+    }
+}
+
+template <typename T, index_t NUM_THREADS, index_t NUM_SCALARS, typename PrefetchKernel>
+bool test_prefetch_impl(bool time_kernels,
+                        const PrefetchKernel& prefetch_kernel,
+                        const std::string& kernel_name)
+{
+    // TODO: maybe add more prefetch instructions inside the kernel to support more values
+    assert(NUM_SCALARS * sizeof(T) <= (128 * 32));
+    constexpr index_t num_elements = NUM_THREADS;
+    constexpr index_t num_scalars = NUM_SCALARS;
+    constexpr index_t block_size = 256;
+    constexpr index_t grid_size = (num_elements + block_size - 1) / block_size;
+
+    std::cout << "Testing " << kernel_name << " to constant cache for type: " << typeid(T).name()
+              << std::endl;
+    std::cout << "Elements: " << num_elements << ", Scalars: " << num_scalars << std::endl;
+
+    // Host data
+    std::vector<T> h_src(num_elements);
+    std::vector<T> h_scalar(num_scalars);
+    std::vector<T> h_dst_with_prefetch_chunks(num_elements);
+    std::vector<T> h_expected(num_elements);
+
+    // Initialize data
+    for(index_t i = 0; i < num_elements; i++)
+    {
+        h_src[i] = static_cast<T>(i % 100);
+    }
+
+    T scalar_sum = 0;
+    for(index_t i = 0; i < num_scalars; i++)
+    {
+        h_scalar[i] = static_cast<T>(i + 1);
+        scalar_sum += h_scalar[i];
+    }
+
+    // Expected results
+    for(index_t i = 0; i < num_elements; i++)
+    {
+        h_expected[i] = h_src[i] + scalar_sum;
+    }
+
+    // Device memory
+    DeviceMem d_src(sizeof(T) * num_elements);
+    DeviceMem d_scalar(sizeof(T) * num_scalars);
+    DeviceMem d_dst_with_prefetch_chunks(sizeof(T) * num_elements);
+
+    d_src.ToDevice(h_src.data());
+    d_scalar.ToDevice(h_scalar.data());
+
+    hipStream_t stream;
+    hip_check_error(hipStreamCreate(&stream));
+
+    if(time_kernels)
+    {
+        ck::static_for<0, 2, 1>{}([&](auto static_i) {
+            constexpr bool prefetch_enabled = static_i == 0;
+            std::cout << "PREFETCH " << (prefetch_enabled ? "ENABLED!" : "DISABLED!") << std::endl;
: "DISABLED!") << std::endl; + + constexpr int num_warmup = 1; + constexpr int num_iterations = 10; + + // Warmup runs + for(int i = 0; i < num_warmup; i++) + { + prefetch_kernel<<>>( + static_cast(d_src.GetDeviceBuffer()), + static_cast(d_dst_with_prefetch_chunks.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + static_cast(d_scalar.GetDeviceBuffer())), + prefetch_enabled); + } + hip_check_error(hipStreamSynchronize(stream)); + + // Performance measurement + hipEvent_t start, stop; + hip_check_error(hipEventCreate(&start)); + hip_check_error(hipEventCreate(&stop)); + + hip_check_error(hipEventRecord(start, stream)); + for(int i = 0; i < num_iterations; i++) + { + prefetch_kernel<<>>( + static_cast(d_src.GetDeviceBuffer()), + static_cast(d_dst_with_prefetch_chunks.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + static_cast(d_scalar.GetDeviceBuffer())), + prefetch_enabled); + } + hip_check_error(hipEventRecord(stop, stream)); + + hip_check_error(hipStreamSynchronize(stream)); + + float elapsed_ms = 0; + hip_check_error(hipEventElapsedTime(&elapsed_ms, start, stop)); + + float avg_time_us = (elapsed_ms * 1000.0f) / num_iterations; + float total_bytes = (num_elements * sizeof(T) + num_scalars * sizeof(T)); // read + float bandwidth_gb_s = (total_bytes / (avg_time_us * 1e-6)) / 1e9; + float ops_per_iteration = num_elements * num_scalars; // adds + float gflops = (ops_per_iteration / (avg_time_us * 1e-6)) / 1e9; + + std::cout << " Performance: " << std::endl; + std::cout << " Average kernel time: " << avg_time_us << " us" << std::endl; + std::cout << " Effective bandwidth: " << bandwidth_gb_s << " GB/s" << std::endl; + std::cout << " Compute throughput: " << gflops << " GFLOPS" << std::endl; + + hip_check_error(hipEventDestroy(start)); + hip_check_error(hipEventDestroy(stop)); + }); + } + else + { + prefetch_kernel<<>>( + static_cast(d_src.GetDeviceBuffer()), + static_cast(d_dst_with_prefetch_chunks.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + static_cast(d_scalar.GetDeviceBuffer())), + true); + + hip_check_error(hipStreamSynchronize(stream)); + } + + // Copy results back + d_dst_with_prefetch_chunks.FromDevice(h_dst_with_prefetch_chunks.data()); + + // Verify results + bool pass = ck::utils::check_err(h_dst_with_prefetch_chunks, h_expected); + + std::cout << " Correctness: " << (pass ? "PASS" : "FAIL") << std::endl; + std::cout << std::endl; + + hip_check_error(hipStreamDestroy(stream)); + + return pass; +} + +} // namespace s_prefetch_op_util +} // namespace ck