mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
Add optimized copy to ck wrapper (#1126)
* Add optimized copy to ck wrapper * Example optimizations * Fixes * Move img2col test to client example * Refactor example * Fix docs * Fixes * Fix * Fixes * Fixes * Fixes * Fixes * Fixes --------- Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
@@ -21,49 +21,59 @@ template <typename InputTensor,
|
||||
typename OutputTensor,
|
||||
typename BlockShape,
|
||||
typename ThreadLayoutShape,
|
||||
typename LocalTileSteps,
|
||||
typename LocalPartitionSteps>
|
||||
bool UseOptimizedCopy>
|
||||
__global__ void TestCopyDevice(const InputTensor input_tensor,
|
||||
OutputTensor output_tensor,
|
||||
const BlockShape tile_shape,
|
||||
const ThreadLayoutShape thread_layout,
|
||||
const LocalTileSteps block_steps,
|
||||
const LocalPartitionSteps thread_steps)
|
||||
const ThreadLayoutShape thread_layout)
|
||||
{
|
||||
__shared__ ck::index_t p_shared[ck::wrapper::size(tile_shape)];
|
||||
auto tensor_lds = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
|
||||
const auto tensor_lds = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
|
||||
p_shared, ck::wrapper::make_layout(tile_shape));
|
||||
|
||||
const auto block_idxs = ck::make_tuple(ck::make_tuple(0, 0), blockIdx.x);
|
||||
const auto block_idx = static_cast<ck::index_t>(blockIdx.x);
|
||||
|
||||
// Get local tiles for global memory
|
||||
const auto input_local_tile =
|
||||
ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idxs, block_steps);
|
||||
const auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idx);
|
||||
const auto output_local_tile =
|
||||
ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idxs, block_steps);
|
||||
ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idx);
|
||||
|
||||
// Get partition per thread
|
||||
const auto input_local_partition = ck::wrapper::make_local_partition(
|
||||
input_local_tile, thread_layout, threadIdx.x, thread_steps);
|
||||
const auto input_local_partition =
|
||||
ck::wrapper::make_local_partition(input_local_tile, thread_layout, threadIdx.x);
|
||||
auto lds_local_partition =
|
||||
ck::wrapper::make_local_partition(tensor_lds, thread_layout, threadIdx.x, thread_steps);
|
||||
auto output_local_partition = ck::wrapper::make_local_partition(
|
||||
output_local_tile, thread_layout, threadIdx.x, thread_steps);
|
||||
ck::wrapper::make_local_partition(tensor_lds, thread_layout, threadIdx.x);
|
||||
auto output_local_partition =
|
||||
ck::wrapper::make_local_partition(output_local_tile, thread_layout, threadIdx.x);
|
||||
|
||||
// Allocate VGPR
|
||||
constexpr ck::index_t scalar_per_vector = 1;
|
||||
constexpr ck::index_t vgpr_size = ck::wrapper::size(lds_local_partition);
|
||||
auto tensor_vgpr = ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr,
|
||||
vgpr_size,
|
||||
scalar_per_vector,
|
||||
ck::index_t>();
|
||||
auto tensor_vgpr =
|
||||
ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, ck::index_t>(
|
||||
layout(lds_local_partition));
|
||||
|
||||
// Perform copy
|
||||
ck::wrapper::copy(input_local_partition, lds_local_partition);
|
||||
ck::wrapper::copy(lds_local_partition, tensor_vgpr);
|
||||
ck::wrapper::copy(tensor_vgpr, output_local_partition);
|
||||
if constexpr(UseOptimizedCopy)
|
||||
{
|
||||
using DimAccessOrder = ck::Tuple<ck::Number<1>, ck::Number<0>>;
|
||||
constexpr ck::index_t vector_dim = 0;
|
||||
constexpr ck::index_t scalar_per_vector = 2;
|
||||
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(input_local_partition,
|
||||
lds_local_partition);
|
||||
// TODO: Enable optimized copy for static buffers
|
||||
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(lds_local_partition,
|
||||
tensor_vgpr);
|
||||
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(tensor_vgpr,
|
||||
output_local_partition);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck::wrapper::copy(input_local_partition, lds_local_partition);
|
||||
ck::wrapper::copy(lds_local_partition, tensor_vgpr);
|
||||
ck::wrapper::copy(tensor_vgpr, output_local_partition);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool UseOptimizedCopy>
|
||||
void PerformCopyGlobalToGlobalViaLDS()
|
||||
{
|
||||
const auto shape =
|
||||
@@ -89,15 +99,8 @@ void PerformCopyGlobalToGlobalViaLDS()
|
||||
auto output_tensor_global = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
|
||||
static_cast<ck::index_t*>(out_buf.GetDeviceBuffer()), layout);
|
||||
|
||||
const auto thread_layout =
|
||||
ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}), ck::Number<32>{});
|
||||
const auto tile_shape =
|
||||
ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}), ck::Number<64>{});
|
||||
|
||||
const auto thread_steps =
|
||||
ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}), ck::Number<2>{});
|
||||
const auto block_steps =
|
||||
ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}), ck::Number<64>{});
|
||||
const auto thread_layout = ck::make_tuple(ck::Number<1>{}, ck::Number<32>{});
|
||||
const auto tile_shape = ck::make_tuple(ck::Number<4>{}, ck::Number<64>{});
|
||||
|
||||
const ck::index_t grid_size = ck::math::integer_divide_ceil(
|
||||
ck::wrapper::size(input_tensor_global), ck::wrapper::size(tile_shape));
|
||||
@@ -106,8 +109,7 @@ void PerformCopyGlobalToGlobalViaLDS()
|
||||
decltype(output_tensor_global),
|
||||
decltype(tile_shape),
|
||||
decltype(thread_layout),
|
||||
decltype(block_steps),
|
||||
decltype(thread_steps)>;
|
||||
UseOptimizedCopy>;
|
||||
launch_and_time_kernel(StreamConfig{},
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
@@ -116,9 +118,7 @@ void PerformCopyGlobalToGlobalViaLDS()
|
||||
input_tensor_global,
|
||||
output_tensor_global,
|
||||
tile_shape,
|
||||
thread_layout,
|
||||
block_steps,
|
||||
thread_steps);
|
||||
thread_layout);
|
||||
|
||||
// Verify results
|
||||
std::vector<ck::index_t> output_data(ck::wrapper::size(shape));
|
||||
@@ -126,4 +126,5 @@ void PerformCopyGlobalToGlobalViaLDS()
|
||||
EXPECT_TRUE(ck::utils::check_err(output_data, input_data));
|
||||
}
|
||||
|
||||
TEST(TestCopy, CopyGlobalToGlobalViaLDS) { PerformCopyGlobalToGlobalViaLDS(); }
|
||||
TEST(TestCopyGlobalToGlobalViaLDS, GenericCopy) { PerformCopyGlobalToGlobalViaLDS<false>(); }
|
||||
TEST(TestCopyGlobalToGlobalViaLDS, OptimizedCopy) { PerformCopyGlobalToGlobalViaLDS<true>(); }
|
||||
|
||||
Reference in New Issue
Block a user