mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
Improved integration test.
This commit is contained in:
@@ -555,7 +555,18 @@ __global__ void test_4x4_matrix_get_2x2_blocks_kernel(int* input, int* output)
|
||||
distribution);
|
||||
auto distributed_tensor = tile_window.load();
|
||||
|
||||
// Up to this point, we have set up the distributed tensor for the 4x4 matrix
|
||||
// similar to the output of the MFMA when 2 conv groups are merged.
|
||||
// We want to copy only the diagonal 2x2 blocks to the output, similar to the epilogue
|
||||
// part of batched iGEMM for 2 conv groups.
|
||||
|
||||
// Create output encoding for 4x2 matrix (2 diagonal blocks stacked vertically)
|
||||
// The threads access the output distributed tensor in a sliced manner
|
||||
// T0 -> row 0
|
||||
// T1 -> row 1 (this is masked out when copying from input)
|
||||
// T2 -> row 2 (this is masked out when copying from input)
|
||||
// T3 -> row 3
|
||||
// Each thread copies 2 elements (a row of a 2x2 block)
|
||||
constexpr auto output_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<
|
||||
@@ -573,23 +584,74 @@ __global__ void test_4x4_matrix_get_2x2_blocks_kernel(int* input, int* output)
|
||||
auto output_global_view = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
output, make_tuple(4, 2));
|
||||
|
||||
static_for<0, 2, 1>{}([&](auto row_offset)
|
||||
auto get_block_number = [&]() -> index_t
|
||||
{
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
output_distributed_tensor.get_thread_buffer() = distributed_tensor.get_y_sliced_thread_data(
|
||||
sequence<0, 0, row_offset, 0>{},
|
||||
sequence<1, 1, 1, 2>{}); // copy one row of a 2x2 block
|
||||
}
|
||||
const auto x_space_coord = distribution.calculate_index();
|
||||
if (x_space_coord[0] == 0 && x_space_coord[1] == 0)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else if (x_space_coord[0] == 0 && x_space_coord[1] == 2)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
else if (x_space_coord[0] == 2 && x_space_coord[1] == 0)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if (x_space_coord[0] == 2 && x_space_coord[1] == 2)
|
||||
{
|
||||
return 3;
|
||||
}
|
||||
return -1;
|
||||
};
|
||||
|
||||
if (threadIdx.x == 3)
|
||||
auto mask = [&]() -> bool
|
||||
{
|
||||
// Only blocks 0 and 3 are diagonal
|
||||
const auto blockId = get_block_number();
|
||||
return (blockId == 0 || blockId == 3);
|
||||
};
|
||||
|
||||
auto get_output_row_offset = [&](auto row) -> index_t
|
||||
{
|
||||
const auto blockId = get_block_number();
|
||||
if (blockId == 0)
|
||||
{
|
||||
output_distributed_tensor.get_thread_buffer() = distributed_tensor.get_y_sliced_thread_data(
|
||||
sequence<0, 0, 1-row_offset, 0>{},
|
||||
sequence<1, 1, 1, 2>{}); // copy one row of a 2x2 block
|
||||
return row;
|
||||
}
|
||||
else if (blockId == 3)
|
||||
{
|
||||
return -row;
|
||||
}
|
||||
else
|
||||
{
|
||||
return -1000; // Invalid for other threads
|
||||
}
|
||||
};
|
||||
|
||||
// Because we copy one row at a time, we need to loop over the 2 rows of the 2x2 blocks.
|
||||
// We mask out the threads that do not contribute to the diagonal blocks.
|
||||
static_for<0, 2, 1>{}([&](auto row)
|
||||
{
|
||||
if (mask())
|
||||
{
|
||||
//const auto row_offset = input_row_offset<row>(get_block_number());
|
||||
const auto block_id = get_block_number();
|
||||
if (block_id == 0)
|
||||
{
|
||||
output_distributed_tensor.get_thread_buffer() = distributed_tensor.get_y_sliced_thread_data(
|
||||
sequence<0, 0, row, 0>{},
|
||||
sequence<1, 1, 1, 2>{}); // copy one row of a 2x2 block
|
||||
}
|
||||
else if (block_id == 3)
|
||||
{
|
||||
output_distributed_tensor.get_thread_buffer() = distributed_tensor.get_y_sliced_thread_data(
|
||||
sequence<0, 0, 1 - row, 0>{}, //row 0 -> row 1, row 1 -> row 0
|
||||
sequence<1, 1, 1, 2>{}); // copy one row of a 2x2 block
|
||||
}
|
||||
}
|
||||
|
||||
// Print the output distributed tensor for verification
|
||||
if constexpr (DebugOutput)
|
||||
{
|
||||
block_sync_lds();
|
||||
@@ -608,8 +670,9 @@ __global__ void test_4x4_matrix_get_2x2_blocks_kernel(int* input, int* output)
|
||||
});
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
if (mask())
|
||||
{
|
||||
const auto row_offset = get_output_row_offset(row);
|
||||
auto output_tile_window = make_tile_window(output_global_view,
|
||||
make_tuple(4, 2),
|
||||
{row_offset, 0},
|
||||
@@ -617,16 +680,6 @@ __global__ void test_4x4_matrix_get_2x2_blocks_kernel(int* input, int* output)
|
||||
|
||||
output_tile_window.store(output_distributed_tensor);
|
||||
}
|
||||
|
||||
if (threadIdx.x == 3)
|
||||
{
|
||||
auto output_tile_window = make_tile_window(output_global_view,
|
||||
make_tuple(4, 2),
|
||||
{-row_offset, 0},
|
||||
output_distribution);
|
||||
|
||||
output_tile_window.store(output_distributed_tensor);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user