mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 23:05:54 +00:00
WIP: start kernel implementation + test structure
This commit is contained in:
@@ -41,3 +41,4 @@ add_subdirectory(fmha)
|
||||
add_subdirectory(gemm_tile_engine)
|
||||
add_subdirectory(pooling)
|
||||
add_subdirectory(grouped_conv)
|
||||
add_subdirectory(mhc)
|
||||
|
||||
@@ -10,11 +10,9 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
|
||||
#include "test_multi_reduce2d_multiblock_impl.hpp"
|
||||
#include "test_mhc_impl.hpp"
|
||||
|
||||
// Shape parameters for different test configurations
|
||||
using Shape1_BlockWarps = ck_tile::sequence<4, 1>;
|
||||
|
||||
@@ -43,166 +43,181 @@ class TestCkTileMHC : public ::testing::Test
|
||||
// ReduceDimSeq reduce_dims)
|
||||
void RunGenericTest()
|
||||
{
|
||||
static_assert(
|
||||
ReduceOpsType::size() == ElementwiseOpsType::size() &&
|
||||
ReduceOpsType::size() == AccumulatorOpsType::size() &&
|
||||
ReduceOpsType::size() == InterBlockReduceOpsType::size(),
|
||||
"Error: All operations tuple size must match the number of reduction operations");
|
||||
|
||||
const auto number_operations = ReduceOpsType::size();
|
||||
// Test parameters
|
||||
const int B = 8; // Batch size
|
||||
const int n = 4; // Expansion rate (aka streams)
|
||||
const int C = 256; // Output layer dim
|
||||
const int nC = n * C; // Total input dimension
|
||||
|
||||
ck_tile::HostTensor<XDataType> h_x(input_shape, input_strides);
|
||||
const int output_dim = 2 * n + n * n; // 2n + n^2 = 8 + 16 = 24 for n=4
|
||||
|
||||
auto h_ys = ck_tile::generate_tuple(
|
||||
[&output_shape, &output_strides](auto /*i*/) {
|
||||
return ck_tile::HostTensor<YDataType>(output_shape, output_strides);
|
||||
},
|
||||
ck_tile::number<number_operations>{});
|
||||
// Allocate host tensors
|
||||
ck_tile::HostTensor<float> h_x({B, nC}); // Input [B, nC]
|
||||
ck_tile::HostTensor<float> h_phi({nC, output_dim}); // Weights [nC, 2n+n^2]
|
||||
ck_tile::HostTensor<float> h_output({B, output_dim}); // Output [B, 2n+n^2]
|
||||
|
||||
auto h_ys_ref = ck_tile::generate_tuple(
|
||||
[&output_shape, &output_strides](auto /*i*/) {
|
||||
return ck_tile::HostTensor<YDataType>(output_shape, output_strides);
|
||||
},
|
||||
ck_tile::number<number_operations>{});
|
||||
// Initialize with random data
|
||||
ck_tile::FillUniformDistribution<float>{-1.0f, 1.0f}(h_x);
|
||||
ck_tile::FillUniformDistribution<float>{-0.5f, 0.5f}(h_phi);
|
||||
h_output.SetZero();
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(h_x);
|
||||
|
||||
ck_tile::static_for<0, number_operations, 1>{}([&](auto i) {
|
||||
h_ys.template at<i>().SetZero();
|
||||
h_ys_ref.template at<i>().SetZero();
|
||||
});
|
||||
|
||||
auto output_number_elements = [&output_shape]() {
|
||||
ck_tile::index_t prod = 1;
|
||||
for(auto len : output_shape)
|
||||
prod *= len;
|
||||
return prod;
|
||||
}();
|
||||
|
||||
auto output_buffer_size =
|
||||
number_operations * h_ys.get(ck_tile::number<0>{}).get_element_space_size_in_bytes();
|
||||
// Allocate device memory
|
||||
ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_y_mem(output_buffer_size);
|
||||
|
||||
std::vector<YDataType> h(number_operations * output_number_elements);
|
||||
|
||||
// Init the output data with identity values respective to each reduce op
|
||||
ck_tile::static_for<0, number_operations, 1>{}([&](auto i) {
|
||||
constexpr auto op = ReduceOpsType{}.at(i);
|
||||
const auto identity_val = op.template GetIdentityValue<YDataType>();
|
||||
std::fill(h.begin() + i * output_number_elements,
|
||||
h.begin() + (i + 1) * output_number_elements,
|
||||
identity_val);
|
||||
});
|
||||
ck_tile::DeviceMem d_phi_mem(h_phi.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_output_mem(h_output.get_element_space_size_in_bytes());
|
||||
|
||||
// Copy data to device
|
||||
d_x_mem.ToDevice(h_x.data());
|
||||
d_y_mem.ToDevice(h.data());
|
||||
d_phi_mem.ToDevice(h_phi.data());
|
||||
d_output_mem.ToDevice(h_output.data());
|
||||
|
||||
using Problem = ck_tile::Reduce2dProblem<XDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
TestReduce2dShape,
|
||||
ReduceOpsType,
|
||||
KeptDimSeq,
|
||||
ReduceDimSeq,
|
||||
InputDim>;
|
||||
|
||||
using Kernel = ck_tile::MultiReduceMultiblock<Problem>;
|
||||
|
||||
// Launch configuration
|
||||
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
|
||||
// Kernel launch configuration
|
||||
const ck_tile::index_t kBlockSize = 256; // 256 threads per block
|
||||
const ck_tile::index_t kGridSize = B; // One block per batch element
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
auto elementwise_ops =
|
||||
make_elementwise_ops_tuple(total_reduce_elements, ElementwiseOpsType{});
|
||||
auto accumulator_ops =
|
||||
make_elementwise_ops_tuple(total_reduce_elements, AccumulatorOpsType{});
|
||||
// TODO: Define Problem and Policy types
|
||||
// using Problem = ck_tile::MHCProblem<...>;
|
||||
// using Kernel = ck_tile::ManifoldConstrainedHyperConnection<Problem, Policy>;
|
||||
|
||||
auto [num_block_tile_iterations, block_group_size] =
|
||||
typename Kernel::TilePartitioner{total_reduce_elements}.GetBlockGroupParams();
|
||||
|
||||
std::cout << "Block group size: " << block_group_size
|
||||
<< ", Num block tile iterations: " << num_block_tile_iterations
|
||||
<< ", Reduce total length: " << total_reduce_elements << std::endl;
|
||||
|
||||
ck_tile::index_t kGridSize =
|
||||
((kept_dim_len_prod + TestReduce2dShape::Block_M - 1) / TestReduce2dShape::Block_M) *
|
||||
block_group_size;
|
||||
|
||||
// Generic helper to create tuple from vector based on compile-time size
|
||||
auto make_shape_tuple = []<std::size_t N>(const std::vector<ck_tile::index_t>& vec) {
|
||||
return [&vec]<std::size_t... I>(std::index_sequence<I...>) {
|
||||
return ck_tile::make_tuple(vec[I]...);
|
||||
}(std::make_index_sequence<N>{});
|
||||
};
|
||||
|
||||
auto input_shape_tuple = make_shape_tuple.template operator()<InputDim>(input_shape);
|
||||
auto input_strides_tuple = make_shape_tuple.template operator()<InputDim>(input_strides);
|
||||
|
||||
if(!Kernel::IsSupportedArgument(
|
||||
total_reduce_elements,
|
||||
input_strides_tuple)) // output tensor's continuous dimension
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported!\n");
|
||||
}
|
||||
std::cout << "Launching MHC kernel with:" << std::endl;
|
||||
std::cout << " Batch size (B): " << B << std::endl;
|
||||
std::cout << " Expansion factor (n): " << n << std::endl;
|
||||
std::cout << " Channels per stream (C): " << C << std::endl;
|
||||
std::cout << " Input dimension (nC): " << nC << std::endl;
|
||||
std::cout << " Output dimension (2n+n²): " << output_dim << std::endl;
|
||||
std::cout << " Grid size: " << kGridSize << std::endl;
|
||||
std::cout << " Block size: " << kBlockSize << std::endl;
|
||||
|
||||
// Kernel launch
|
||||
/*
|
||||
ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, false, 0},
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<XDataType*>(d_x_mem.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(d_y_mem.GetDeviceBuffer()),
|
||||
input_shape_tuple,
|
||||
input_strides_tuple,
|
||||
kept_dims,
|
||||
reduce_dims,
|
||||
output_number_elements,
|
||||
elementwise_ops,
|
||||
accumulator_ops,
|
||||
InterBlockReduceOpsType{}));
|
||||
ck_tile::make_kernel<kBlockPerCu>(
|
||||
Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0, // shared memory size
|
||||
static_cast<float*>(d_x_mem.GetDeviceBuffer()),
|
||||
static_cast<float*>(d_phi_mem.GetDeviceBuffer()),
|
||||
static_cast<float*>(d_output_mem.GetDeviceBuffer()),
|
||||
B, n, C));
|
||||
*/
|
||||
|
||||
// Reference computation
|
||||
ck_tile::reference_multiple_reduce_multiblock<XDataType, ComputeDataType, YDataType>(
|
||||
h_x,
|
||||
h_ys_ref,
|
||||
ReduceOpsType{},
|
||||
kept_dims,
|
||||
reduce_dims,
|
||||
elementwise_ops,
|
||||
accumulator_ops,
|
||||
InterBlockReduceOpsType{},
|
||||
block_group_size);
|
||||
// Copy results back to host
|
||||
// d_output_mem.FromDevice(h_output.data());
|
||||
|
||||
// Calculate proper error thresholds based on data types and number of accumulations
|
||||
// const auto rtol = ck_tile::get_relative_threshold<XDataType, YDataType, ComputeDataType>(
|
||||
// total_reduce_elements);
|
||||
// const auto atol = ck_tile::get_absolute_threshold<YDataType, YDataType, ComputeDataType>(
|
||||
// 5.0f, total_reduce_elements);
|
||||
// TODO: Add reference computation and validation
|
||||
|
||||
// Unfortunately, due to the non-sequentiality, down-casting on the output buffer
|
||||
// and further operations on this buffer, the error is compounding at a faster
|
||||
// rate than what the host reference can support. A large tolerance is then required
|
||||
const auto rtol = 1e-2;
|
||||
const auto atol = 1e-1;
|
||||
// auto h_ys = ck_tile::generate_tuple(
|
||||
// [&output_shape, &output_strides](auto /*i*/) {
|
||||
// return ck_tile::HostTensor<YDataType>(output_shape, output_strides);
|
||||
// },
|
||||
// ck_tile::number<number_operations>{});
|
||||
|
||||
// Transfer data from device and check error for each operation
|
||||
std::vector<YDataType> h_y_tmp(output_number_elements * number_operations);
|
||||
d_y_mem.FromDevice(h_y_tmp.data());
|
||||
bool result = true;
|
||||
ck_tile::static_for<0, number_operations, 1>{}([&](auto i) {
|
||||
std::memcpy(h_ys.get(ck_tile::number<i>{}).data(),
|
||||
h_y_tmp.data() + i * output_number_elements,
|
||||
output_number_elements * sizeof(YDataType));
|
||||
std::cout << "Checking errors for operation: " << i << std::endl;
|
||||
result &= ck_tile::check_err(h_ys.get(ck_tile::number<i>{}),
|
||||
h_ys_ref.get(ck_tile::number<i>{}),
|
||||
"Error: Incorrect reduce results!",
|
||||
rtol,
|
||||
atol);
|
||||
});
|
||||
// auto h_ys_ref = ck_tile::generate_tuple(
|
||||
// [&output_shape, &output_strides](auto /*i*/) {
|
||||
// return ck_tile::HostTensor<YDataType>(output_shape, output_strides);
|
||||
// },
|
||||
// ck_tile::number<number_operations>{});
|
||||
|
||||
EXPECT_TRUE(result);
|
||||
// ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(h_x);
|
||||
|
||||
// ck_tile::static_for<0, number_operations, 1>{}([&](auto i) {
|
||||
// h_ys.template at<i>().SetZero();
|
||||
// h_ys_ref.template at<i>().SetZero();
|
||||
// });
|
||||
|
||||
// auto output_number_elements = [&output_shape]() {
|
||||
// ck_tile::index_t prod = 1;
|
||||
// for(auto len : output_shape)
|
||||
// prod *= len;
|
||||
// return prod;
|
||||
// }();
|
||||
|
||||
// auto output_buffer_size =
|
||||
// number_operations * h_ys.get(ck_tile::number<0>{}).get_element_space_size_in_bytes();
|
||||
// ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes());
|
||||
// ck_tile::DeviceMem d_y_mem(output_buffer_size);
|
||||
|
||||
// std::vector<YDataType> h(number_operations * output_number_elements);
|
||||
|
||||
// // Init the output data with identity values respective to each reduce op
|
||||
// ck_tile::static_for<0, number_operations, 1>{}([&](auto i) {
|
||||
// constexpr auto op = ReduceOpsType{}.at(i);
|
||||
// const auto identity_val = op.template GetIdentityValue<YDataType>();
|
||||
// std::fill(h.begin() + i * output_number_elements,
|
||||
// h.begin() + (i + 1) * output_number_elements,
|
||||
// identity_val);
|
||||
// });
|
||||
|
||||
// d_x_mem.ToDevice(h_x.data());
|
||||
// d_y_mem.ToDevice(h.data());
|
||||
|
||||
// using Problem = ck_tile::Reduce2dProblem<XDataType,
|
||||
// ComputeDataType,
|
||||
// YDataType,
|
||||
// TestReduce2dShape,
|
||||
// ReduceOpsType,
|
||||
// KeptDimSeq,
|
||||
// ReduceDimSeq,
|
||||
// InputDim>;
|
||||
|
||||
// using Kernel = ck_tile::MultiReduceMultiblock<Problem>;
|
||||
|
||||
// // Launch configuration
|
||||
// const ck_tile::index_t kBlockSize = Kernel::BlockSize();
|
||||
// constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
// auto elementwise_ops =
|
||||
// make_elementwise_ops_tuple(total_reduce_elements, ElementwiseOpsType{});
|
||||
// auto accumulator_ops =
|
||||
// make_elementwise_ops_tuple(total_reduce_elements, AccumulatorOpsType{});
|
||||
|
||||
// auto [num_block_tile_iterations, block_group_size] =
|
||||
// typename Kernel::TilePartitioner{total_reduce_elements}.GetBlockGroupParams();
|
||||
|
||||
// std::cout << "Block group size: " << block_group_size
|
||||
// << ", Num block tile iterations: " << num_block_tile_iterations
|
||||
// << ", Reduce total length: " << total_reduce_elements << std::endl;
|
||||
|
||||
// ck_tile::index_t kGridSize =
|
||||
// ((kept_dim_len_prod + TestReduce2dShape::Block_M - 1) / TestReduce2dShape::Block_M) *
|
||||
// block_group_size;
|
||||
|
||||
// // Generic helper to create tuple from vector based on compile-time size
|
||||
// auto make_shape_tuple = []<std::size_t N>(const std::vector<ck_tile::index_t>& vec) {
|
||||
// return [&vec]<std::size_t... I>(std::index_sequence<I...>) {
|
||||
// return ck_tile::make_tuple(vec[I]...);
|
||||
// }(std::make_index_sequence<N>{});
|
||||
// };
|
||||
|
||||
// auto input_shape_tuple = make_shape_tuple.template operator()<InputDim>(input_shape);
|
||||
// auto input_strides_tuple = make_shape_tuple.template operator()<InputDim>(input_strides);
|
||||
|
||||
// if(!Kernel::IsSupportedArgument()) // TODO
|
||||
// {
|
||||
// }
|
||||
|
||||
// ck_tile::launch_kernel(
|
||||
// ck_tile::stream_config{nullptr, false, 0},
|
||||
// ck_tile::make_kernel<kBlockPerCu>(Kernel{},
|
||||
// kGridSize,
|
||||
// kBlockSize,
|
||||
// 0,
|
||||
// static_cast<XDataType*>(d_x_mem.GetDeviceBuffer()),
|
||||
// static_cast<YDataType*>(d_y_mem.GetDeviceBuffer()),
|
||||
// input_shape_tuple,
|
||||
// input_strides_tuple,
|
||||
// kept_dims,
|
||||
// reduce_dims,
|
||||
// output_number_elements,
|
||||
// elementwise_ops,
|
||||
// accumulator_ops,
|
||||
// InterBlockReduceOpsType{}));
|
||||
|
||||
// TODO: Reference computation + Transfer data back to host
|
||||
EXPECT_TRUE(true);
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user