Add blockwise gemm to ck wrapper (#1139)

* Add blockwise gemm to ck wrapper

* Add blockwise gemm traits

* Disable test_gemm for non xdl devices

* Fixes

* Add c layout descritpions
This commit is contained in:
Bartłomiej Kocot
2024-01-31 21:24:40 +01:00
committed by GitHub
parent 6651a124cc
commit f3b6c23ac5
12 changed files with 1064 additions and 116 deletions

View File

@@ -29,17 +29,24 @@ TEST(TestPartition, LocalPartition)
const auto tensor =
ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Generic>(data.data(), layout);
const auto thread_steps = ck::make_tuple(ck::Number<8>{}, ck::Number<1>{});
const auto thread_layout = ck::make_tuple(ck::Number<8>{}, ck::Number<1>{});
const auto thread_steps = ck::make_tuple(ck::Number<1>{}, ck::Number<8>{}, ck::Number<1>{});
const auto thread_layout = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}, ck::Number<1>{});
// 3d partition on 2d shape (calculate partition on 3d thread layout, and then skip first dim)
const auto thread_projection =
ck::make_tuple(ck::wrapper::slice(4), ck::Number<1>{}, ck::Number<1>{});
constexpr ck::index_t projection_thread_length = ck::Number<4>{};
for(ck::index_t thread_id = 0; thread_id < ck::wrapper::size(thread_layout); thread_id++)
for(ck::index_t thread_id = 0;
thread_id < ck::wrapper::size(thread_layout) / projection_thread_length;
thread_id++)
{
const auto packed_partition =
ck::wrapper::make_local_partition(tensor, thread_layout, thread_id);
ck::wrapper::make_local_partition(tensor, thread_layout, thread_id, thread_projection);
const auto expected_partition_size =
ck::wrapper::size(tensor) / ck::wrapper::size(thread_layout);
const auto expected_partition_first_val = thread_id * ck::wrapper::size<0>(thread_steps);
ck::wrapper::size(tensor) /
(ck::wrapper::size(thread_layout) / projection_thread_length);
const auto expected_partition_first_val = thread_id * ck::wrapper::size<1>(thread_steps);
const auto expected_partition_second_val = expected_partition_first_val + 1;
EXPECT_EQ(ck::wrapper::size(packed_partition), expected_partition_size);
EXPECT_EQ(packed_partition(0), expected_partition_first_val);
@@ -58,8 +65,12 @@ TEST(TestPartition, LocalTile)
const auto tensor =
ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Generic>(data.data(), layout);
const auto block_shape = ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}, ck::Number<2>{});
// 4d tile partitioning on 3d shape (calculate tile on 4d tile layout, and then skip last dim)
const auto block_shape =
ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}, ck::Number<2>{}, ck::Number<2>{});
const auto block_projection =
ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(2));
constexpr ck::index_t projection_block_dim = ck::Number<2>{};
const auto num_blocks =
ck::make_tuple(ck::wrapper::size<0>(shape) / ck::wrapper::size<0>(block_shape),
ck::wrapper::size<1>(shape) / ck::wrapper::size<1>(block_shape),
@@ -69,9 +80,10 @@ TEST(TestPartition, LocalTile)
for(auto block_idx : block_idxs)
{
const auto packed_tile = ck::wrapper::make_local_tile(tensor, block_shape, block_idx);
const auto packed_tile =
ck::wrapper::make_local_tile(tensor, block_shape, block_idx, block_projection);
const auto expected_tile_size = ck::wrapper::size(block_shape);
const auto expected_tile_size = ck::wrapper::size(block_shape) / projection_block_dim;
auto expected_tile_first_val = (block_idx % ck::wrapper::size<2>(num_blocks)) *
ck::wrapper::size<2>(block_shape) *
ck::wrapper::size<2>(strides);