mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK profiler] Perform verification on GPU when using GPU reference (#3482)
* Simple verification kernel for ckProfiler * Verification kernel unit tests * Explicit synchronization * Address review comments
This commit is contained in:
313
profiler/include/profiler/gpu_verification.hpp
Normal file
313
profiler/include/profiler/gpu_verification.hpp
Normal file
@@ -0,0 +1,313 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
#include "ck/utility/type.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
// Compute relative tolerance for GPU verification
|
||||
// Matches the logic of ck::utils::get_relative_threshold but handles all types
|
||||
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
|
||||
inline float compute_relative_tolerance(const int number_of_accumulations = 1)
|
||||
{
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using I8 = int8_t;
|
||||
using I16 = int16_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
// For integer types, tolerance is 0
|
||||
if constexpr(std::is_same_v<ComputeDataType, I8> || std::is_same_v<ComputeDataType, I16> ||
|
||||
std::is_same_v<ComputeDataType, I32> || std::is_same_v<ComputeDataType, int>)
|
||||
{
|
||||
return 0.0f;
|
||||
}
|
||||
// For types supported by get_relative_threshold, use it
|
||||
else if constexpr((std::is_same_v<ComputeDataType, F16> ||
|
||||
std::is_same_v<ComputeDataType, BF16> ||
|
||||
std::is_same_v<ComputeDataType, F32>) &&
|
||||
(std::is_same_v<OutDataType, F16> || std::is_same_v<OutDataType, BF16> ||
|
||||
std::is_same_v<OutDataType, F32>) &&
|
||||
(std::is_same_v<AccDataType, F16> || std::is_same_v<AccDataType, BF16> ||
|
||||
std::is_same_v<AccDataType, F32>))
|
||||
{
|
||||
return static_cast<float>(
|
||||
ck::utils::get_relative_threshold<ComputeDataType, OutDataType, AccDataType>(
|
||||
number_of_accumulations));
|
||||
}
|
||||
// For unsupported types (FP8, BF8, etc.), use default tolerances based on output type
|
||||
else
|
||||
{
|
||||
if constexpr(std::is_same_v<OutDataType, F16>)
|
||||
{
|
||||
return 1e-3f;
|
||||
}
|
||||
else if constexpr(std::is_same_v<OutDataType, BF16>)
|
||||
{
|
||||
return 1e-1f;
|
||||
}
|
||||
else
|
||||
{
|
||||
// For FP8/BF8 and other types, use conservative tolerance
|
||||
return 1e-1f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GPU verification kernel - compares device result against reference using relative and absolute
|
||||
// tolerance Returns 1 in passed if all elements match within tolerance, 0 otherwise
|
||||
template <typename T>
|
||||
__global__ void gpu_verify_kernel(const T* __restrict__ device_result,
|
||||
const T* __restrict__ reference_result,
|
||||
float rtol,
|
||||
float atol,
|
||||
long long size,
|
||||
int* passed)
|
||||
{
|
||||
// Grid-stride loop to handle any tensor size
|
||||
long long idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
long long stride = blockDim.x * gridDim.x;
|
||||
|
||||
for(long long i = idx; i < size; i += stride)
|
||||
{
|
||||
// Convert to float for comparison
|
||||
float dev_val = type_convert<float>(device_result[i]);
|
||||
float ref_val = type_convert<float>(reference_result[i]);
|
||||
|
||||
// Compute absolute difference
|
||||
float abs_diff = fabsf(dev_val - ref_val);
|
||||
|
||||
// Check tolerance (matches CPU check_err logic: err > atol + rtol * abs(ref))
|
||||
if(abs_diff > atol + rtol * fabsf(ref_val))
|
||||
{
|
||||
atomicMin(passed, 0); // Mark as failed
|
||||
return; // Early exit on first failure
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Host-side wrapper for GPU verification with explicit tolerances
|
||||
// Returns true if verification passed, false otherwise
|
||||
template <typename T>
|
||||
bool gpu_verify(const void* device_result,
|
||||
const void* reference_result,
|
||||
float rtol,
|
||||
float atol,
|
||||
std::size_t size,
|
||||
hipStream_t stream = nullptr)
|
||||
{
|
||||
// Allocate result buffer on device
|
||||
int* passed_dev;
|
||||
hip_check_error(hipMalloc(&passed_dev, sizeof(int)));
|
||||
|
||||
// Initialize to passed (1)
|
||||
int passed_host = 1;
|
||||
hip_check_error(hipMemcpy(passed_dev, &passed_host, sizeof(int), hipMemcpyHostToDevice));
|
||||
|
||||
// Launch kernel with grid-stride loop
|
||||
// Use 65535 as max grid size (hardware limit for grid dimension in x)
|
||||
// Grid-stride loop handles any tensor size regardless of grid dimensions
|
||||
constexpr int block_size = 256;
|
||||
int grid_size = std::min<int>(65535, (size + block_size - 1) / block_size);
|
||||
|
||||
gpu_verify_kernel<T>
|
||||
<<<grid_size, block_size, 0, stream>>>(static_cast<const T*>(device_result),
|
||||
static_cast<const T*>(reference_result),
|
||||
rtol,
|
||||
atol,
|
||||
static_cast<long long>(size),
|
||||
passed_dev);
|
||||
|
||||
hip_check_error(hipGetLastError());
|
||||
|
||||
// Synchronize the stream to ensure kernel completion before reading results
|
||||
hip_check_error(hipStreamSynchronize(stream));
|
||||
|
||||
// Get result
|
||||
hip_check_error(hipMemcpy(&passed_host, passed_dev, sizeof(int), hipMemcpyDeviceToHost));
|
||||
|
||||
// Free device memory
|
||||
hip_check_error(hipFree(passed_dev));
|
||||
|
||||
return passed_host == 1;
|
||||
}
|
||||
|
||||
// Forward declaration of gpu_reduce_max
|
||||
template <typename T>
|
||||
float gpu_reduce_max(const void* device_buffer, std::size_t size, hipStream_t stream = nullptr);
|
||||
|
||||
// Host-side wrapper for GPU verification with automatic tolerance computation
|
||||
// Computes max value on GPU, then computes tolerances and verifies
|
||||
// Returns true if verification passed, false otherwise
|
||||
template <typename OutDataType,
|
||||
typename ComputeDataType = OutDataType,
|
||||
typename AccDataType = ComputeDataType>
|
||||
bool gpu_verify(const void* device_result,
|
||||
const void* reference_result,
|
||||
int number_of_accumulations,
|
||||
std::size_t size,
|
||||
hipStream_t stream = nullptr)
|
||||
{
|
||||
// Compute max absolute value on GPU (only 4 bytes transferred!)
|
||||
double max_abs_value =
|
||||
static_cast<double>(gpu_reduce_max<OutDataType>(reference_result, size, stream));
|
||||
|
||||
// Compute tolerances based on data types and accumulation count
|
||||
float rtol = compute_relative_tolerance<ComputeDataType, OutDataType, AccDataType>(
|
||||
number_of_accumulations);
|
||||
|
||||
float atol = 0.0f;
|
||||
// Only compute absolute tolerance for supported types
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
if constexpr((std::is_same_v<ComputeDataType, F16> || std::is_same_v<ComputeDataType, BF16> ||
|
||||
std::is_same_v<ComputeDataType, F32>) &&
|
||||
(std::is_same_v<OutDataType, F16> || std::is_same_v<OutDataType, BF16> ||
|
||||
std::is_same_v<OutDataType, F32>) &&
|
||||
(std::is_same_v<AccDataType, F16> || std::is_same_v<AccDataType, BF16> ||
|
||||
std::is_same_v<AccDataType, F32>))
|
||||
{
|
||||
atol = static_cast<float>(
|
||||
ck::utils::get_absolute_threshold<ComputeDataType, OutDataType, AccDataType>(
|
||||
max_abs_value, number_of_accumulations));
|
||||
}
|
||||
|
||||
// Call the explicit tolerance version
|
||||
return gpu_verify<OutDataType>(device_result, reference_result, rtol, atol, size, stream);
|
||||
}
|
||||
|
||||
//
|
||||
// Helper function for atomic float max (using compare-and-swap)
|
||||
__device__ __forceinline__ float atomicMaxFloat(float* address, float val)
|
||||
{
|
||||
int* address_as_int = reinterpret_cast<int*>(address);
|
||||
int old = *address_as_int;
|
||||
int assumed;
|
||||
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
old =
|
||||
atomicCAS(address_as_int, assumed, __float_as_int(fmaxf(val, __int_as_float(assumed))));
|
||||
} while(assumed != old);
|
||||
|
||||
return __int_as_float(old);
|
||||
}
|
||||
|
||||
// GPU reduction kernel for computing max(abs(data))
|
||||
// This is an internal kernel called only by gpu_reduce_max() wrapper.
|
||||
//
|
||||
// Assumption: Block size is 256
|
||||
template <typename T>
|
||||
__global__ void
|
||||
gpu_reduce_max_kernel(const T* __restrict__ data, long long size, float* __restrict__ max_val)
|
||||
{
|
||||
constexpr int block_size = 256;
|
||||
__shared__ float shared_max[block_size];
|
||||
|
||||
long long idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
long long stride = blockDim.x * gridDim.x;
|
||||
|
||||
float local_max = 0.0f;
|
||||
|
||||
for(long long i = idx; i < size; i += stride)
|
||||
{
|
||||
float val = fabsf(type_convert<float>(data[i]));
|
||||
local_max = fmaxf(local_max, val);
|
||||
}
|
||||
|
||||
shared_max[threadIdx.x] = local_max;
|
||||
__syncthreads();
|
||||
|
||||
// Block-level reduction: 256 -> 128 -> 64 -> 32
|
||||
for(unsigned int s = block_size / 2; s > 32; s >>= 1)
|
||||
{
|
||||
if(threadIdx.x < s)
|
||||
{
|
||||
shared_max[threadIdx.x] = fmaxf(shared_max[threadIdx.x], shared_max[threadIdx.x + s]);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Warp-level reduction: 32 -> 16 -> 8 -> 4 -> 2 -> 1
|
||||
// No sync needed within a warp
|
||||
if(threadIdx.x < 32)
|
||||
{
|
||||
volatile float* smem = shared_max;
|
||||
smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + 32]);
|
||||
smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + 16]);
|
||||
smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + 8]);
|
||||
smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + 4]);
|
||||
smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + 2]);
|
||||
smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + 1]);
|
||||
}
|
||||
|
||||
// Two-phase reduction pattern minimizes atomic contention:
|
||||
// 1. Each block reduces to shared memory (above)
|
||||
// 2. Single thread per block updates global max (below)
|
||||
// This limits atomic operations to O(grid_size) rather than O(total_threads)
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
atomicMaxFloat(max_val, shared_max[0]);
|
||||
}
|
||||
}
|
||||
|
||||
// Host-side wrapper for GPU max reduction
|
||||
// Computes max(abs(data)) and returns as float
|
||||
// Only transfers 4 bytes (the final max value) instead of entire tensor
|
||||
template <typename T>
|
||||
float gpu_reduce_max(const void* device_buffer, std::size_t size, hipStream_t stream)
|
||||
{
|
||||
if(size == 0)
|
||||
{
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
// Allocate device memory for result
|
||||
float* max_dev;
|
||||
hip_check_error(hipMalloc(&max_dev, sizeof(float)));
|
||||
|
||||
// Initialize to zero
|
||||
float init_val = 0.0f;
|
||||
hip_check_error(hipMemcpy(max_dev, &init_val, sizeof(float), hipMemcpyHostToDevice));
|
||||
|
||||
// Launch reduction kernel
|
||||
// Use 1024 blocks max for reduction to balance occupancy vs. grid-stride iterations
|
||||
// For very large tensors (>256M elements), grid-stride loop handles the remainder
|
||||
constexpr int block_size = 256;
|
||||
int grid_size = std::min<int>(1024, (size + block_size - 1) / block_size);
|
||||
|
||||
gpu_reduce_max_kernel<T><<<grid_size, block_size, 0, stream>>>(
|
||||
static_cast<const T*>(device_buffer), static_cast<long long>(size), max_dev);
|
||||
|
||||
hip_check_error(hipGetLastError());
|
||||
|
||||
// Synchronize if using default stream
|
||||
if(stream == nullptr)
|
||||
{
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
}
|
||||
|
||||
// Copy result to host (only 4 bytes!)
|
||||
float max_host;
|
||||
hip_check_error(hipMemcpy(&max_host, max_dev, sizeof(float), hipMemcpyDeviceToHost));
|
||||
|
||||
// Free device memory
|
||||
hip_check_error(hipFree(max_dev));
|
||||
|
||||
return max_host;
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace ck
|
||||
@@ -20,6 +20,7 @@
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp"
|
||||
#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp"
|
||||
#include "profiler/gpu_verification.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
@@ -89,14 +90,15 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
out_device_buf.ToDevice(out.mData.data());
|
||||
wei_device_buf.ToDevice(wei.mData.data());
|
||||
|
||||
// Allocate GPU reference buffer (used only if do_verification == 2)
|
||||
DeviceMem gpu_ref_in_buf(
|
||||
do_verification == 2 ? sizeof(InDataType) * in_host.mDesc.GetElementSpaceSize() : 0);
|
||||
|
||||
float max_accumulated_value = 0;
|
||||
if(do_verification == 2)
|
||||
{
|
||||
// Use GPU reference for verification
|
||||
std::cout << "Using GPU reference for verification" << std::endl;
|
||||
|
||||
// Allocate GPU reference output buffer
|
||||
DeviceMem gpu_ref_in_buf(sizeof(InDataType) * in_host.mDesc.GetElementSpaceSize());
|
||||
// Use GPU reference with GPU verification
|
||||
std::cout << "Using GPU reference with GPU verification" << std::endl;
|
||||
|
||||
// Call GPU reference with ConvParam directly
|
||||
ref::naive_conv_bwd_data<InLayout,
|
||||
@@ -116,9 +118,9 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
// Copy GPU reference result to host for comparison
|
||||
gpu_ref_in_buf.FromDevice(in_host.mData.data());
|
||||
max_accumulated_value = *std::max_element(in_host.mData.begin(), in_host.mData.end());
|
||||
// Compute max value on GPU for tolerance calculation (only 4 bytes transferred!)
|
||||
max_accumulated_value = ck::profiler::gpu_reduce_max<InDataType>(
|
||||
gpu_ref_in_buf.GetDeviceBuffer(), in_host.mDesc.GetElementSpaceSize());
|
||||
}
|
||||
else if(do_verification == 1)
|
||||
{
|
||||
@@ -204,8 +206,96 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
best_split_k = split_k_for_run;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
// Synchronize before verification to ensure kernel has completed
|
||||
if(do_verification > 0 && !time_kernel)
|
||||
{
|
||||
hip_check_error(hipStreamSynchronize(nullptr));
|
||||
}
|
||||
|
||||
if(do_verification == 2)
|
||||
{
|
||||
// GPU verification path
|
||||
using ComputeType_ = std::conditional_t<sizeof(OutDataType) < sizeof(WeiDataType),
|
||||
OutDataType,
|
||||
WeiDataType>;
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ComputeType_) < sizeof(ComputeDataType),
|
||||
ComputeType_,
|
||||
ComputeDataType>;
|
||||
using AccDataType =
|
||||
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
|
||||
|
||||
// Calculate number of accumulations accounting for split_k
|
||||
const int num_accums = static_cast<int>(conv_param.K_ / split_k_for_run);
|
||||
|
||||
// Additional tolerance for split_k accumulation if needed
|
||||
int total_accums = num_accums;
|
||||
if(split_k_for_run > 1)
|
||||
{
|
||||
total_accums = std::max(num_accums, static_cast<int>(split_k_for_run));
|
||||
}
|
||||
|
||||
// Perform GPU verification (max value computed internally on GPU)
|
||||
const std::size_t tensor_size = in_device.mDesc.GetElementSpaceSize();
|
||||
bool gpu_passed = ck::profiler::gpu_verify<InDataType, ComputeType, AccDataType>(
|
||||
in_device_buf.GetDeviceBuffer(),
|
||||
gpu_ref_in_buf.GetDeviceBuffer(),
|
||||
total_accums,
|
||||
tensor_size);
|
||||
|
||||
if(!gpu_passed)
|
||||
{
|
||||
// GPU verification failed - fall back to CPU for detailed diagnostics
|
||||
std::cout << "GPU verification failed, running CPU verification for details..."
|
||||
<< std::endl;
|
||||
|
||||
// Copy both buffers to host
|
||||
in_device_buf.FromDevice(in_device.mData.data());
|
||||
gpu_ref_in_buf.FromDevice(in_host.mData.data());
|
||||
|
||||
// Recalculate tolerances for CPU verification with original logic
|
||||
auto rtol =
|
||||
ck::utils::get_relative_threshold<ComputeType, InDataType, AccDataType>(
|
||||
num_accums);
|
||||
auto atol =
|
||||
ck::utils::get_absolute_threshold<ComputeType, InDataType, AccDataType>(
|
||||
max_accumulated_value / split_k_for_run, num_accums);
|
||||
|
||||
if(split_k_for_run > 1)
|
||||
{
|
||||
auto rtol_split_k =
|
||||
ck::utils::get_relative_threshold<InDataType, InDataType, InDataType>(
|
||||
split_k_for_run);
|
||||
auto atol_split_k =
|
||||
ck::utils::get_absolute_threshold<InDataType, InDataType, InDataType>(
|
||||
max_accumulated_value, split_k_for_run);
|
||||
rtol = std::max(rtol, rtol_split_k);
|
||||
atol = std::max(atol, atol_split_k);
|
||||
}
|
||||
|
||||
// Run CPU verification for detailed error messages
|
||||
ck::utils::check_err(
|
||||
in_device, in_host, "Error: Incorrect results!", rtol, atol);
|
||||
pass = false;
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol
|
||||
<< " Absolute error threshold: " << atol << std::endl;
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "output : ", out.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(std::cout << "weight: ", wei.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "in_host : ", in_host.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(std::cout << "in_device: ", in_device.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(do_verification == 1)
|
||||
{
|
||||
// CPU verification path (original behavior)
|
||||
in_device_buf.FromDevice(in_device.mData.data());
|
||||
|
||||
using ComputeType_ = std::conditional_t<sizeof(OutDataType) < sizeof(WeiDataType),
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp"
|
||||
#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp"
|
||||
#include "profiler/gpu_verification.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
@@ -91,6 +92,11 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
in_device_buf.ToDevice(input.mData.data());
|
||||
out_device_buf.ToDevice(output.mData.data());
|
||||
|
||||
// Allocate GPU reference buffer (used only if do_verification == 2)
|
||||
DeviceMem gpu_ref_wei_buf(
|
||||
do_verification == 2 ? sizeof(WeiDataType) * weight_host_result.mDesc.GetElementSpaceSize()
|
||||
: 0);
|
||||
|
||||
float max_accumulated_value = 0;
|
||||
if(do_verification)
|
||||
{
|
||||
@@ -120,20 +126,13 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
max_accumulated_value =
|
||||
*std::max_element(weight_host_result.mData.begin(), weight_host_result.mData.end());
|
||||
}
|
||||
else if(do_verification == 2)
|
||||
{
|
||||
// GPU reference
|
||||
std::cout << "Running GPU reference implementation..." << std::endl;
|
||||
|
||||
// Allocate device memory for reference
|
||||
DeviceMem in_ref_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize());
|
||||
DeviceMem wei_ref_buf(sizeof(WeiDataType) *
|
||||
weight_host_result.mDesc.GetElementSpaceSize());
|
||||
DeviceMem out_ref_buf(sizeof(OutDataType) * output.mDesc.GetElementSpaceSize());
|
||||
|
||||
in_ref_buf.ToDevice(input.mData.data());
|
||||
out_ref_buf.ToDevice(output.mData.data());
|
||||
// Use GPU reference with GPU verification
|
||||
std::cout << "Using GPU reference with GPU verification" << std::endl;
|
||||
|
||||
// Call GPU reference with ConvParam directly
|
||||
ck::ref::naive_conv_bwd_weight<InLayout,
|
||||
@@ -145,20 +144,14 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>(
|
||||
static_cast<const InDataType*>(in_ref_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_ref_buf.GetDeviceBuffer()),
|
||||
static_cast<const OutDataType*>(out_ref_buf.GetDeviceBuffer()),
|
||||
static_cast<const InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(gpu_ref_wei_buf.GetDeviceBuffer()),
|
||||
static_cast<const OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
conv_param,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
// Copy result back to host
|
||||
wei_ref_buf.FromDevice(weight_host_result.mData.data());
|
||||
}
|
||||
|
||||
max_accumulated_value =
|
||||
*std::max_element(weight_host_result.mData.begin(), weight_host_result.mData.end());
|
||||
}
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<NDimSpatial,
|
||||
@@ -320,8 +313,109 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
best_split_k = split_k_param_str;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
// Synchronize before verification to ensure kernel has completed
|
||||
if(do_verification > 0 && !time_kernel)
|
||||
{
|
||||
hip_check_error(hipStreamSynchronize(nullptr));
|
||||
}
|
||||
|
||||
if(do_verification == 2)
|
||||
{
|
||||
// GPU verification path
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ComputeTypeA) < sizeof(ComputeTypeB),
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
using AccDataType =
|
||||
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
|
||||
|
||||
// Calculate number of accumulations accounting for split_k
|
||||
const int num_accums =
|
||||
static_cast<int>(output.GetElementSize() / conv_param.K_ / split_k_value);
|
||||
|
||||
// Additional tolerance for split_k accumulation if needed
|
||||
int total_accums = num_accums;
|
||||
if(split_k_value > 1)
|
||||
{
|
||||
total_accums = std::max(num_accums, static_cast<int>(split_k_value));
|
||||
}
|
||||
|
||||
// Perform GPU verification (max value computed internally on GPU)
|
||||
const std::size_t tensor_size =
|
||||
weight_device_result.mDesc.GetElementSpaceSize();
|
||||
bool gpu_passed =
|
||||
ck::profiler::gpu_verify<WeiDataType, ComputeType, AccDataType>(
|
||||
wei_device_buf.GetDeviceBuffer(),
|
||||
gpu_ref_wei_buf.GetDeviceBuffer(),
|
||||
total_accums,
|
||||
tensor_size);
|
||||
|
||||
if(!gpu_passed)
|
||||
{
|
||||
// GPU verification failed - fall back to CPU for detailed diagnostics
|
||||
std::cout
|
||||
<< "GPU verification failed, running CPU verification for details..."
|
||||
<< std::endl;
|
||||
|
||||
// Copy both buffers to host
|
||||
wei_device_buf.FromDevice(weight_device_result.mData.data());
|
||||
gpu_ref_wei_buf.FromDevice(weight_host_result.mData.data());
|
||||
|
||||
// Recalculate tolerances for CPU verification with original logic
|
||||
const index_t num_accums_full = output.GetElementSize() / conv_param.K_;
|
||||
const index_t num_accums_split_k = split_k_value;
|
||||
auto rtol = ck::utils::
|
||||
get_relative_threshold<ComputeType, WeiDataType, AccDataType>(
|
||||
num_accums_full / num_accums_split_k);
|
||||
auto atol = ck::utils::
|
||||
get_absolute_threshold<ComputeType, WeiDataType, AccDataType>(
|
||||
max_accumulated_value / num_accums_split_k,
|
||||
num_accums_full / num_accums_split_k);
|
||||
|
||||
if(split_k_value > 1)
|
||||
{
|
||||
auto rtol_split_k =
|
||||
ck::utils::get_relative_threshold<WeiDataType,
|
||||
WeiDataType,
|
||||
WeiDataType>(num_accums_split_k);
|
||||
auto atol_split_k = ck::utils::
|
||||
get_absolute_threshold<WeiDataType, WeiDataType, WeiDataType>(
|
||||
max_accumulated_value, num_accums_split_k);
|
||||
rtol = std::max(rtol, rtol_split_k);
|
||||
atol = std::max(atol, atol_split_k);
|
||||
}
|
||||
|
||||
// Run CPU verification for detailed error messages
|
||||
ck::utils::check_err(weight_device_result,
|
||||
weight_host_result,
|
||||
"Error: Incorrect results!",
|
||||
rtol,
|
||||
atol);
|
||||
all_pass = false;
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol
|
||||
<< " Absolute error threshold: " << atol << std::endl;
|
||||
std::cout << "Fail info: splitK: " << split_k_value << " "
|
||||
<< op_ptr->GetTypeString() << std::endl;
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "output : ", output.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "weight (device): ", weight_device_result.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "weight (host): ", weight_host_result.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(std::cout << "input: ", input.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(do_verification == 1)
|
||||
{
|
||||
// CPU verification path (original behavior)
|
||||
wei_device_buf.FromDevice(weight_device_result.mData.data());
|
||||
|
||||
using ComputeType =
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp"
|
||||
#include "profiler/gpu_verification.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
@@ -113,14 +114,15 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
|
||||
in_device_buf.ToDevice(input.mData.data());
|
||||
wei_device_buf.ToDevice(weight.mData.data());
|
||||
|
||||
// Allocate GPU reference buffer (used only if do_verification == 2)
|
||||
DeviceMem gpu_ref_out_buf(
|
||||
do_verification == 2 ? sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize() : 0);
|
||||
|
||||
// run reference op
|
||||
if(do_verification == 2)
|
||||
{
|
||||
// Use GPU reference for verification
|
||||
std::cout << "Using GPU reference for verification" << std::endl;
|
||||
|
||||
// Allocate GPU reference output buffer
|
||||
DeviceMem gpu_ref_out_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize());
|
||||
// Use GPU reference with GPU verification
|
||||
std::cout << "Using GPU reference with GPU verification" << std::endl;
|
||||
|
||||
// Call GPU reference with ConvParam directly
|
||||
ref::naive_conv_fwd<InLayout,
|
||||
@@ -139,9 +141,6 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
// Copy GPU reference result to host for comparison
|
||||
gpu_ref_out_buf.FromDevice(host_output.mData.data());
|
||||
}
|
||||
else if(do_verification == 1)
|
||||
{
|
||||
@@ -225,8 +224,63 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
// Synchronize before verification to ensure kernel has completed
|
||||
if(do_verification > 0 && !time_kernel)
|
||||
{
|
||||
hip_check_error(hipStreamSynchronize(nullptr));
|
||||
}
|
||||
|
||||
if(do_verification == 2)
|
||||
{
|
||||
// GPU verification path
|
||||
// Calculate number of accumulations (C * filter spatial dimensions)
|
||||
std::size_t filter_spatial_size = 1;
|
||||
for(auto len : conv_param.filter_spatial_lengths_)
|
||||
{
|
||||
filter_spatial_size *= len;
|
||||
}
|
||||
const int num_accums = static_cast<int>(conv_param.C_ * filter_spatial_size);
|
||||
|
||||
// Perform GPU verification (max value computed internally on GPU)
|
||||
const std::size_t tensor_size = device_output.mDesc.GetElementSpaceSize();
|
||||
bool gpu_passed = ck::profiler::gpu_verify<OutDataType, AComputeType, OutDataType>(
|
||||
out_device_buf.GetDeviceBuffer(),
|
||||
gpu_ref_out_buf.GetDeviceBuffer(),
|
||||
num_accums,
|
||||
tensor_size);
|
||||
|
||||
if(!gpu_passed)
|
||||
{
|
||||
// GPU verification failed - fall back to CPU for detailed diagnostics
|
||||
std::cout << "GPU verification failed, running CPU verification for details..."
|
||||
<< std::endl;
|
||||
|
||||
// Copy both buffers to host
|
||||
out_device_buf.FromDevice(device_output.mData.data());
|
||||
gpu_ref_out_buf.FromDevice(host_output.mData.data());
|
||||
|
||||
// Run CPU verification for detailed error messages
|
||||
ck::utils::check_err(device_output, host_output);
|
||||
pass = false;
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "input : ", input.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(std::cout << "weight: ", weight.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "host_output : ", host_output.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "device_output: ", device_output.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(do_verification == 1)
|
||||
{
|
||||
// CPU verification path (original behavior)
|
||||
out_device_buf.FromDevice(device_output.mData.data());
|
||||
|
||||
pass = pass & ck::utils::check_err(device_output, host_output);
|
||||
|
||||
Reference in New Issue
Block a user