mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Merge commit 'cd8af997e6d1fde6bc4397bd6ab4fca46510e776' into develop
This commit is contained in:
@@ -58,6 +58,8 @@
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
#include "ck_tile/builder/versions.hpp"
|
||||
|
||||
#include "ck_tile/builder/conv_signature_utils.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory_internal {
|
||||
|
||||
// Type mappings from the builder FwdGroupConvLayout enum classes to the CK tensor data types.
|
||||
@@ -665,7 +667,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
@@ -762,7 +764,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
@@ -858,7 +860,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
@@ -980,7 +982,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm;
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
/**********************************************
|
||||
* constexpr helper functions for optional parameters
|
||||
**********************************************/
|
||||
|
||||
template <auto Sig>
|
||||
concept ProvidesElementwiseOperation = requires { Sig.elementwiseOperation; };
|
||||
|
||||
template <auto Sig>
|
||||
concept ProvidesConvolutionDirection = requires { Sig.direction; };
|
||||
|
||||
template <auto Sig>
|
||||
constexpr auto get_elementwise_operation()
|
||||
{
|
||||
if constexpr(ProvidesElementwiseOperation<Sig>)
|
||||
{
|
||||
return Sig.elementwise_operation;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ElementwiseOperation::PASS_THROUGH;
|
||||
}
|
||||
}
|
||||
|
||||
template <auto Sig>
|
||||
constexpr auto get_conv_direction()
|
||||
{
|
||||
if constexpr(ProvidesConvolutionDirection<Sig>)
|
||||
{
|
||||
return Sig.direction;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ConvDirection::FORWARD;
|
||||
}
|
||||
}
|
||||
} // namespace ck_tile::builder
|
||||
@@ -49,6 +49,7 @@ struct ConvSignatureWithInvalidOptionalParams
|
||||
ckb::GroupConvDeviceOp device_operation =
|
||||
ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
|
||||
};
|
||||
|
||||
static_assert(!ckb::ConvSignatureDescriptor<ConvSignatureWithInvalidOptionalParams>);
|
||||
|
||||
struct DefaultAlgorithm
|
||||
|
||||
@@ -431,5 +431,6 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
|
||||
c3);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -296,5 +296,8 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx950")
|
||||
add_subdirectory(mx_mfma_op)
|
||||
add_subdirectory(gemm_mx)
|
||||
endif()
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
add_subdirectory(s_prefetch_op)
|
||||
endif()
|
||||
add_subdirectory(position_embedding)
|
||||
add_subdirectory(scatter_gather)
|
||||
|
||||
2
test/s_prefetch_op/CMakeLists.txt
Normal file
2
test/s_prefetch_op/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_test_executable(test_s_prefetch_op s_prefetch_op.cpp)
|
||||
target_link_libraries(test_s_prefetch_op PRIVATE utility)
|
||||
66
test/s_prefetch_op/s_prefetch_op.cpp
Normal file
66
test/s_prefetch_op/s_prefetch_op.cpp
Normal file
@@ -0,0 +1,66 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
#include "s_prefetch_op_util.hpp"
|
||||
|
||||
template <typename T, uint32_t NUM_THREADS, uint32_t NUM_SCALARS>
|
||||
bool run_test(bool time_kernels)
|
||||
{
|
||||
bool pass = true;
|
||||
|
||||
const auto s_prefetch_kernel =
|
||||
ck::s_prefetch_op_util::kernel_with_prefetch<T,
|
||||
NUM_THREADS,
|
||||
NUM_SCALARS,
|
||||
ck::s_prefetch_op_util::SPrefetchDataOp<T>>;
|
||||
const auto s_buffer_prefetch_kernel = ck::s_prefetch_op_util::kernel_with_prefetch<
|
||||
T,
|
||||
NUM_THREADS,
|
||||
NUM_SCALARS,
|
||||
ck::s_prefetch_op_util::SBufferPrefetchDataOp<T, NUM_SCALARS>>;
|
||||
|
||||
const auto prefetch_kernel_container =
|
||||
std::make_tuple(s_prefetch_kernel, s_buffer_prefetch_kernel);
|
||||
|
||||
ck::static_for<0, 2, 1>{}([&](auto i) {
|
||||
std::string kernel_name = (i == 1 ? "s_buffer_prefetch" : "s_prefetch");
|
||||
|
||||
auto kernel = std::get<ck::Number<i>{}>(prefetch_kernel_container);
|
||||
|
||||
pass &= ck::s_prefetch_op_util::
|
||||
test_prefetch_impl<decltype(kernel), T, NUM_THREADS, NUM_SCALARS>(
|
||||
time_kernels, kernel, kernel_name);
|
||||
});
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(!ck::is_gfx12_supported())
|
||||
{
|
||||
std::cout << "This feature is not supported by current HW, skipping tests." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool time_kernels = false;
|
||||
|
||||
if(argc == 2)
|
||||
{
|
||||
time_kernels = std::stoi(argv[1]);
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
|
||||
std::cout << "=== Testing Constant Cache Prefetch ===" << std::endl;
|
||||
|
||||
// Test different data types
|
||||
pass &= run_test<float, 4096, 1024>(time_kernels);
|
||||
pass &= run_test<double, 4096, 512>(time_kernels);
|
||||
|
||||
std::cout << "TestConstantPrefetch ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
249
test/s_prefetch_op/s_prefetch_op_util.hpp
Normal file
249
test/s_prefetch_op/s_prefetch_op_util.hpp
Normal file
@@ -0,0 +1,249 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck {
|
||||
namespace s_prefetch_op_util {
|
||||
|
||||
// Enable scalar prefetch in hardware (required on gfx12 before using s_prefetch)
|
||||
__device__ __forceinline__ void enable_scalar_prefetch()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
// SCALAR_PREFETCH_EN is bit 24 in MODE register (hwreg 1)
|
||||
// Set 1 bit at offset 24 to value 1
|
||||
__builtin_amdgcn_s_setreg(1 | (24 << 6), 1); // Set bit to 1
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct SPrefetchDataOp
|
||||
{
|
||||
// Prefetch to constant cache using AMD builtin with cachelines to prefetch(1..32)
|
||||
__device__ __forceinline__ void operator()(const T CK_CONSTANT_ADDRESS_SPACE* addr,
|
||||
unsigned int num_cachelines) const
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
assert(num_cachelines > 0 && num_cachelines <= 32);
|
||||
__builtin_amdgcn_s_prefetch_data(addr, num_cachelines - 1); // we need to pass 0..31
|
||||
#else
|
||||
// ignore - not supported
|
||||
(void)addr;
|
||||
(void)num_cachelines;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, uint32_t NUM_SCALARS>
|
||||
struct SBufferPrefetchDataOp
|
||||
{
|
||||
// Prefetch to constant cache using AMD builtin with cachelines to prefetch(1..32)
|
||||
__device__ __forceinline__ void operator()(const T CK_CONSTANT_ADDRESS_SPACE* addr,
|
||||
unsigned int num_cachelines) const
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
__amdgpu_buffer_rsrc_t buf_res = make_wave_buffer_resource_new(addr, NUM_SCALARS);
|
||||
assert(num_cachelines > 0 && num_cachelines <= 32);
|
||||
__builtin_amdgcn_s_buffer_prefetch_data(buf_res, 0, num_cachelines - 1);
|
||||
#else
|
||||
// ignore - not supported
|
||||
(void)addr;
|
||||
(void)num_cachelines;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, uint32_t NUM_THREADS, uint32_t NUM_SCALARS, typename PrefetchOp>
|
||||
__global__ void kernel_with_prefetch(const T* src,
|
||||
T* dst,
|
||||
const T CK_CONSTANT_ADDRESS_SPACE* scalar_data,
|
||||
bool enable_prefetch)
|
||||
{
|
||||
uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
// Calculate number of 128B cachelines needed to cover num_scalars elements
|
||||
constexpr index_t cachelineSize = 128;
|
||||
constexpr index_t elements_per_cachelineSize = cachelineSize / sizeof(T);
|
||||
constexpr unsigned int cachelinesNeeded =
|
||||
(NUM_SCALARS + elements_per_cachelineSize - 1) / elements_per_cachelineSize;
|
||||
|
||||
// Prefetch all scalar data at once
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
if(enable_prefetch)
|
||||
{
|
||||
enable_scalar_prefetch();
|
||||
}
|
||||
PrefetchOp{}(scalar_data, cachelinesNeeded);
|
||||
}
|
||||
|
||||
T sum = 0;
|
||||
if(tid < NUM_THREADS)
|
||||
{
|
||||
sum = src[tid]; // load from global mem to give time for prefetch to finish or be close to
|
||||
// finishs
|
||||
}
|
||||
__syncthreads(); // waits on loads from global mem
|
||||
if(tid < NUM_THREADS)
|
||||
{
|
||||
// Access prefetched scalar data
|
||||
for(uint32_t i = 0; i < NUM_SCALARS; i++)
|
||||
{
|
||||
sum += scalar_data[i]; // should be fast due to scalars being preloaded
|
||||
}
|
||||
|
||||
dst[tid] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PrefetchKernel, typename T, uint32_t NUM_THREADS, uint32_t NUM_SCALARS>
|
||||
bool test_prefetch_impl(bool time_kernels,
|
||||
const PrefetchKernel& prefetch_kernel,
|
||||
const std::string& kernel_name)
|
||||
{
|
||||
// TODO: maybe add more prefetch instructions inside kernel to support more values
|
||||
assert(NUM_SCALARS / sizeof(T) < (128 * 32));
|
||||
constexpr index_t num_elements = NUM_THREADS;
|
||||
constexpr index_t num_scalars = NUM_SCALARS;
|
||||
constexpr index_t block_size = 256;
|
||||
constexpr index_t grid_size = (num_elements + block_size - 1) / block_size;
|
||||
|
||||
std::cout << "Testing " << kernel_name << " to constant cache for type: " << typeid(T).name()
|
||||
<< std::endl;
|
||||
std::cout << "Elements: " << num_elements << ", Scalars: " << num_scalars << std::endl;
|
||||
|
||||
// Host data
|
||||
std::vector<T> h_src(num_elements);
|
||||
std::vector<T> h_scalar(num_scalars);
|
||||
std::vector<T> h_dst_with_prefetch_chunks(num_elements);
|
||||
std::vector<T> h_expected(num_elements);
|
||||
|
||||
// Initialize data
|
||||
for(index_t i = 0; i < num_elements; i++)
|
||||
{
|
||||
h_src[i] = static_cast<T>(i % 100);
|
||||
}
|
||||
|
||||
T scalar_sum = 0;
|
||||
for(index_t i = 0; i < num_scalars; i++)
|
||||
{
|
||||
h_scalar[i] = static_cast<T>(i + 1);
|
||||
scalar_sum += h_scalar[i];
|
||||
}
|
||||
|
||||
// Expected results
|
||||
for(index_t i = 0; i < num_elements; i++)
|
||||
{
|
||||
h_expected[i] = h_src[i] + scalar_sum;
|
||||
}
|
||||
|
||||
// Device memory
|
||||
DeviceMem d_src(sizeof(T) * num_elements);
|
||||
DeviceMem d_scalar(sizeof(T) * num_scalars);
|
||||
DeviceMem d_dst_with_prefetch_chunks(sizeof(T) * num_elements);
|
||||
|
||||
d_src.ToDevice(h_src.data());
|
||||
d_scalar.ToDevice(h_scalar.data());
|
||||
|
||||
hipStream_t stream;
|
||||
hip_check_error(hipStreamCreate(&stream));
|
||||
|
||||
if(time_kernels)
|
||||
{
|
||||
ck::static_for<0, 2, 1>{}([&](auto static_i) {
|
||||
constexpr bool prefetch_enabled = static_i == 0;
|
||||
std::cout << "PREFETCH " << (prefetch_enabled ? "ENABLED!" : "DISABLED!") << std::endl;
|
||||
|
||||
constexpr int num_warmup = 1;
|
||||
constexpr int num_iterations = 10;
|
||||
|
||||
// Warmup runs
|
||||
for(int i = 0; i < num_warmup; i++)
|
||||
{
|
||||
prefetch_kernel<<<grid_size, block_size, 0, stream>>>(
|
||||
static_cast<const T*>(d_src.GetDeviceBuffer()),
|
||||
static_cast<T*>(d_dst_with_prefetch_chunks.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
static_cast<const T*>(d_scalar.GetDeviceBuffer())),
|
||||
prefetch_enabled);
|
||||
}
|
||||
hip_check_error(hipStreamSynchronize(stream));
|
||||
|
||||
// Performance measurement
|
||||
hipEvent_t start, stop;
|
||||
hip_check_error(hipEventCreate(&start));
|
||||
hip_check_error(hipEventCreate(&stop));
|
||||
|
||||
hip_check_error(hipEventRecord(start, stream));
|
||||
for(int i = 0; i < num_iterations; i++)
|
||||
{
|
||||
prefetch_kernel<<<grid_size, block_size, 0, stream>>>(
|
||||
static_cast<const T*>(d_src.GetDeviceBuffer()),
|
||||
static_cast<T*>(d_dst_with_prefetch_chunks.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
static_cast<const T*>(d_scalar.GetDeviceBuffer())),
|
||||
prefetch_enabled);
|
||||
}
|
||||
hip_check_error(hipEventRecord(stop, stream));
|
||||
|
||||
hip_check_error(hipStreamSynchronize(stream));
|
||||
|
||||
float elapsed_ms = 0;
|
||||
hip_check_error(hipEventElapsedTime(&elapsed_ms, start, stop));
|
||||
|
||||
float avg_time_us = (elapsed_ms * 1000.0f) / num_iterations;
|
||||
float total_bytes = (num_elements * sizeof(T) + num_scalars * sizeof(T)); // read
|
||||
float bandwidth_gb_s = (total_bytes / (avg_time_us * 1e-6)) / 1e9;
|
||||
float ops_per_iteration = num_elements * num_scalars; // adds
|
||||
float gflops = (ops_per_iteration / (avg_time_us * 1e-6)) / 1e9;
|
||||
|
||||
std::cout << " Performance: " << std::endl;
|
||||
std::cout << " Average kernel time: " << avg_time_us << " us" << std::endl;
|
||||
std::cout << " Effective bandwidth: " << bandwidth_gb_s << " GB/s" << std::endl;
|
||||
std::cout << " Compute throughput: " << gflops << " GFLOPS" << std::endl;
|
||||
|
||||
hip_check_error(hipEventDestroy(start));
|
||||
hip_check_error(hipEventDestroy(stop));
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
prefetch_kernel<<<grid_size, block_size, 0, stream>>>(
|
||||
static_cast<const T*>(d_src.GetDeviceBuffer()),
|
||||
static_cast<T*>(d_dst_with_prefetch_chunks.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
static_cast<const T*>(d_scalar.GetDeviceBuffer())),
|
||||
true);
|
||||
|
||||
hip_check_error(hipStreamSynchronize(stream));
|
||||
}
|
||||
|
||||
// Copy results back
|
||||
d_dst_with_prefetch_chunks.FromDevice(h_dst_with_prefetch_chunks.data());
|
||||
|
||||
// Verify results
|
||||
bool pass = ck::utils::check_err(h_dst_with_prefetch_chunks, h_expected);
|
||||
|
||||
std::cout << " Correctness: " << (pass ? "PASS" : "FAIL") << std::endl;
|
||||
std::cout << std::endl;
|
||||
|
||||
hip_check_error(hipStreamDestroy(stream));
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
} // namespace s_prefetch_op_util
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user