Improved integration test.

This commit is contained in:
Ville Pietilä
2025-09-16 15:29:01 +00:00
parent 0d802a305f
commit 9175bef679

View File

@@ -555,7 +555,18 @@ __global__ void test_4x4_matrix_get_2x2_blocks_kernel(int* input, int* output)
distribution);
auto distributed_tensor = tile_window.load();
// Up to this point, we have set-up the distributed tensor for the 4x4 matrix
// similar to the output of the MFMA when 2 conv groups are merged.
// We want to copy only the diagonal 2x2 blocks to the output, similar to the epilogue
// part of batched iGEMM for 2 conv groups.
// Create output encoding for 4x2 matrix (2 diagonal blocks stacked vertically)
// The threads access the output distributed tensor in a sliced manner
// T0 -> row 0
// T1 -> row 1 (this is masked out when copying from input)
// T2 -> row 2 (this is masked out when copying from input)
// T3 -> row 3
// Each thread copies 2 elements (a row of a 2x2 block)
constexpr auto output_encoding = tile_distribution_encoding<
sequence<>,
tuple<
@@ -573,23 +584,74 @@ __global__ void test_4x4_matrix_get_2x2_blocks_kernel(int* input, int* output)
auto output_global_view = make_naive_tensor_view_packed<address_space_enum::global>(
output, make_tuple(4, 2));
static_for<0, 2, 1>{}([&](auto row_offset)
auto get_block_number = [&]() -> index_t
{
if (threadIdx.x == 0)
{
output_distributed_tensor.get_thread_buffer() = distributed_tensor.get_y_sliced_thread_data(
sequence<0, 0, row_offset, 0>{},
sequence<1, 1, 1, 2>{}); // copy one row of a 2x2 block
}
const auto x_space_coord = distribution.calculate_index();
if (x_space_coord[0] == 0 && x_space_coord[1] == 0)
{
return 0;
}
else if (x_space_coord[0] == 0 && x_space_coord[1] == 2)
{
return 1;
}
else if (x_space_coord[0] == 2 && x_space_coord[1] == 0)
{
return 2;
}
else if (x_space_coord[0] == 2 && x_space_coord[1] == 2)
{
return 3;
}
return -1;
};
if (threadIdx.x == 3)
auto mask = [&]() -> bool
{
// Only blocks 0 and 3 are diagonal
const auto blockId = get_block_number();
return (blockId == 0 || blockId == 3);
};
auto get_output_row_offset = [&](auto row) -> index_t
{
const auto blockId = get_block_number();
if (blockId == 0)
{
output_distributed_tensor.get_thread_buffer() = distributed_tensor.get_y_sliced_thread_data(
sequence<0, 0, 1-row_offset, 0>{},
sequence<1, 1, 1, 2>{}); // copy one row of a 2x2 block
return row;
}
else if (blockId == 3)
{
return -row;
}
else
{
return -1000; // Invalid for other threads
}
};
// Because we copy one row at the time, we need to loop over the 2 rows of the 2x2 blocks.
// We mask out the threads that do contribute to the diagonal blocks.
static_for<0, 2, 1>{}([&](auto row)
{
if (mask())
{
//const auto row_offset = input_row_offset<row>(get_block_number());
const auto block_id = get_block_number();
if (block_id == 0)
{
output_distributed_tensor.get_thread_buffer() = distributed_tensor.get_y_sliced_thread_data(
sequence<0, 0, row, 0>{},
sequence<1, 1, 1, 2>{}); // copy one row of a 2x2 block
}
else if (block_id == 3)
{
output_distributed_tensor.get_thread_buffer() = distributed_tensor.get_y_sliced_thread_data(
sequence<0, 0, 1 - row, 0>{}, //row 0 -> row 1, row 1 -> row 0
sequence<1, 1, 1, 2>{}); // copy one row of a 2x2 block
}
}
// Print the output distributed tensor for verification
if constexpr (DebugOutput)
{
block_sync_lds();
@@ -608,8 +670,9 @@ __global__ void test_4x4_matrix_get_2x2_blocks_kernel(int* input, int* output)
});
}
if (threadIdx.x == 0)
if (mask())
{
const auto row_offset = get_output_row_offset(row);
auto output_tile_window = make_tile_window(output_global_view,
make_tuple(4, 2),
{row_offset, 0},
@@ -617,16 +680,6 @@ __global__ void test_4x4_matrix_get_2x2_blocks_kernel(int* input, int* output)
output_tile_window.store(output_distributed_tensor);
}
if (threadIdx.x == 3)
{
auto output_tile_window = make_tile_window(output_global_view,
make_tuple(4, 2),
{-row_offset, 0},
output_distribution);
output_tile_window.store(output_distributed_tensor);
}
});
}