mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
sync with upstream
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
#include "ck/wrapper/layout.hpp"
|
||||
#include "ck/wrapper/tensor.hpp"
|
||||
#include "ck/wrapper/operations/copy.hpp"
|
||||
#include "ck/wrapper/utils/kernel_utils.hpp"
|
||||
|
||||
static constexpr ck::index_t NumDimSpatial = 3;
|
||||
using DataType = float;
|
||||
@@ -36,21 +37,20 @@ struct SimpleDeviceMem
|
||||
void* p_mem_;
|
||||
};
|
||||
|
||||
// Test copy from Global to Global through LDS and VGPR
|
||||
template <typename InputTensor,
|
||||
typename OutputTensor,
|
||||
typename BlockShape,
|
||||
typename ThreadLayoutShape>
|
||||
__global__ void DeviceImageToColumnPad0(InputTensor input_tensor,
|
||||
OutputTensor output_tensor,
|
||||
const BlockShape tile_shape,
|
||||
const ThreadLayoutShape thread_layout)
|
||||
template <typename InputTensor, typename OutputTensor, typename BlockShape, typename ThreadLayout>
|
||||
__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__
|
||||
DeviceImageToColumnPad0(InputTensor input_tensor,
|
||||
OutputTensor output_tensor,
|
||||
const BlockShape tile_shape,
|
||||
const ThreadLayout thread_layout)
|
||||
{
|
||||
const ck::index_t block_idx = static_cast<ck::index_t>(blockIdx.x);
|
||||
// grid layout (dim1, dim0)
|
||||
const auto block_idxs =
|
||||
ck::make_tuple(static_cast<ck::index_t>(blockIdx.y), static_cast<ck::index_t>(blockIdx.x));
|
||||
|
||||
// Get local tiles for global memory
|
||||
auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idx);
|
||||
auto output_local_tile = ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idx);
|
||||
auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idxs);
|
||||
auto output_local_tile = ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idxs);
|
||||
|
||||
// Get partition per thread
|
||||
const auto input_local_partition =
|
||||
@@ -112,9 +112,11 @@ void PerformImageToColumnPad0(const ck::index_t G,
|
||||
SimpleDeviceMem out_buf(ck::wrapper::size(out_layout) * sizeof(DataType));
|
||||
|
||||
// User can choose appropriate number of threads and sizes per block
|
||||
const auto thread_layout = ck::make_tuple(ck::Number<8>{}, ck::Number<16>{});
|
||||
const auto thread_layout =
|
||||
ck::wrapper::make_layout(ck::make_tuple(ck::Number<8>{}, ck::Number<16>{}),
|
||||
ck::make_tuple(ck::Number<16>{}, ck::Number<1>{}));
|
||||
// This example doesn't support padding, user should select tile sizes
|
||||
// which divides the shape completely
|
||||
// which divide the shape evenly.
|
||||
const auto tile_shape = ck::make_tuple(ck::Number<32>{}, ck::Number<64>{});
|
||||
|
||||
// Create buffers for global memory
|
||||
@@ -123,10 +125,11 @@ void PerformImageToColumnPad0(const ck::index_t G,
|
||||
auto output_tensor_global = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
|
||||
static_cast<DataType*>(out_buf.GetDeviceBuffer()), out_layout);
|
||||
|
||||
const ck::index_t grid_size = ck::math::integer_divide_ceil(ck::wrapper::size<0>(in_layout),
|
||||
ck::wrapper::size<0>(tile_shape)) *
|
||||
ck::math::integer_divide_ceil(ck::wrapper::size<1>(in_layout),
|
||||
ck::wrapper::size<1>(tile_shape));
|
||||
// grid layout (dim1, dim0)
|
||||
const ck::index_t grid_size_x = ck::math::integer_divide_ceil(ck::wrapper::size<1>(in_layout),
|
||||
ck::wrapper::size<1>(tile_shape));
|
||||
const ck::index_t grid_size_y = ck::math::integer_divide_ceil(ck::wrapper::size<0>(in_layout),
|
||||
ck::wrapper::size<0>(tile_shape));
|
||||
|
||||
const auto kernel = DeviceImageToColumnPad0<decltype(input_tensor_global),
|
||||
decltype(output_tensor_global),
|
||||
@@ -134,7 +137,7 @@ void PerformImageToColumnPad0(const ck::index_t G,
|
||||
decltype(thread_layout)>;
|
||||
const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(grid_size_x, grid_size_y, 1),
|
||||
dim3(ck::wrapper::size(thread_layout)),
|
||||
0,
|
||||
input_tensor_global,
|
||||
@@ -178,3 +181,4 @@ int main(int argc, char* argv[])
|
||||
{1, 1, 1} /*filter_dilations*/);
|
||||
return 0;
|
||||
}
|
||||
// MI100 Perf: 0.255178 ms, 1698.9 GB/s,
|
||||
|
||||
Reference in New Issue
Block a user