mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Setup build environment. Format source code.
This commit is contained in:
@@ -9,10 +9,10 @@ include_directories(AFTER
|
||||
# Each stage builds on the previous one
|
||||
|
||||
# Stage 00: Hello ck_tile - First program
|
||||
add_subdirectory(stage_00_hello_ck_tile)
|
||||
#add_subdirectory(stage_00_hello_ck_tile)
|
||||
|
||||
# Stage 01: Tile Distribution - Understanding thread-to-data mapping
|
||||
add_subdirectory(stage_01_tile_distribution)
|
||||
#add_subdirectory(stage_01_tile_distribution)
|
||||
|
||||
# Stage 02: 2D Tensors and Windows
|
||||
if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/stage_02_2d_tensors/CMakeLists.txt)
|
||||
|
||||
@@ -25,91 +25,85 @@ using namespace ck_tile;
|
||||
// Note: These functions demonstrate the API but may be scalarized by the compiler
|
||||
// when returning by value. For true vectorization, use get_vectorized_elements inline.
|
||||
|
||||
template<typename DataType>
|
||||
CK_TILE_DEVICE thread_buffer<DataType, 2>
|
||||
vectorized_read_2(const DataType* p_data, index_t offset)
|
||||
template <typename DataType>
|
||||
CK_TILE_DEVICE thread_buffer<DataType, 2> vectorized_read_2(const DataType* p_data, index_t offset)
|
||||
{
|
||||
auto view = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_data,
|
||||
make_tuple(12), // total elements
|
||||
make_tuple(1), // stride
|
||||
number<2>{}, // GuaranteedLastDimensionVectorLength
|
||||
number<1>{} // GuaranteedLastDimensionVectorStride
|
||||
make_tuple(12), // total elements
|
||||
make_tuple(1), // stride
|
||||
number<2>{}, // GuaranteedLastDimensionVectorLength
|
||||
number<1>{} // GuaranteedLastDimensionVectorStride
|
||||
);
|
||||
auto desc = view.get_tensor_descriptor();
|
||||
auto desc = view.get_tensor_descriptor();
|
||||
auto coord = make_tensor_coordinate(desc, make_tuple(offset));
|
||||
return view.template get_vectorized_elements<thread_buffer<DataType, 2>>(coord, 0);
|
||||
}
|
||||
|
||||
template<typename DataType>
|
||||
CK_TILE_DEVICE thread_buffer<DataType, 4>
|
||||
vectorized_read_4(const DataType* p_data, index_t offset)
|
||||
template <typename DataType>
|
||||
CK_TILE_DEVICE thread_buffer<DataType, 4> vectorized_read_4(const DataType* p_data, index_t offset)
|
||||
{
|
||||
auto view = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_data,
|
||||
make_tuple(12), // total elements
|
||||
make_tuple(1), // stride
|
||||
number<4>{}, // GuaranteedLastDimensionVectorLength
|
||||
number<1>{} // GuaranteedLastDimensionVectorStride
|
||||
make_tuple(12), // total elements
|
||||
make_tuple(1), // stride
|
||||
number<4>{}, // GuaranteedLastDimensionVectorLength
|
||||
number<1>{} // GuaranteedLastDimensionVectorStride
|
||||
);
|
||||
auto desc = view.get_tensor_descriptor();
|
||||
auto desc = view.get_tensor_descriptor();
|
||||
auto coord = make_tensor_coordinate(desc, make_tuple(offset));
|
||||
return view.template get_vectorized_elements<thread_buffer<DataType, 4>>(coord, 0);
|
||||
}
|
||||
|
||||
template<typename DataType>
|
||||
CK_TILE_DEVICE void
|
||||
template <typename DataType>
|
||||
CK_TILE_DEVICE void
|
||||
vectorized_write_4(DataType* p_data, index_t offset, thread_buffer<DataType, 4> buffer)
|
||||
{
|
||||
auto view = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_data,
|
||||
make_tuple(12), // total elements
|
||||
make_tuple(1) // stride
|
||||
auto view = make_naive_tensor_view<address_space_enum::global>(p_data,
|
||||
make_tuple(12), // total elements
|
||||
make_tuple(1) // stride
|
||||
);
|
||||
auto desc = view.get_tensor_descriptor();
|
||||
auto desc = view.get_tensor_descriptor();
|
||||
auto coord = make_tensor_coordinate(desc, make_tuple(offset));
|
||||
view.set_vectorized_elements(coord, 0, buffer);
|
||||
}
|
||||
|
||||
// Additional functions with fp16 to demonstrate vectorization with smaller types
|
||||
// Reads a 4-element half_t thread_buffer from global memory through a tensor view.
//
// A 1-D view in the global address space is built over p_data with a
// hard-coded length of 24 and stride 1, and the read happens at the tensor
// coordinate built from `offset`.
//
// p_data : base pointer of the half_t buffer the view is created over
// offset : 1-D index used to build the tensor coordinate for the read
// returns: the thread_buffer<half_t, 4> produced by get_vectorized_elements
//
// NOTE(review): no GuaranteedLastDimensionVectorLength / Stride arguments are
// passed here, unlike vectorized_read_2 / vectorized_read_4 - confirm whether
// that is intentional.
// NOTE(review): the length 24 is hard-coded; presumably callers must keep
// offset + 4 <= 24 - confirm against the call sites.
CK_TILE_DEVICE thread_buffer<half_t, 4> vectorized_read_4_fp16(const half_t* p_data, index_t offset)
{
    auto view = make_naive_tensor_view<address_space_enum::global>(
        p_data,
        make_tuple(24), // total elements (more for fp16)
        make_tuple(1)   // stride
    );
    auto desc  = view.get_tensor_descriptor();
    auto coord = make_tensor_coordinate(desc, make_tuple(offset));
    // The trailing 0 is the extra linear offset applied on top of the coordinate.
    return view.template get_vectorized_elements<thread_buffer<half_t, 4>>(coord, 0);
}
|
||||
|
||||
// Reads an 8-element half_t thread_buffer from global memory through a tensor view.
//
// A 1-D view in the global address space is built over p_data with a
// hard-coded length of 24 and stride 1, and the read happens at the tensor
// coordinate built from `offset`.
//
// p_data : base pointer of the half_t buffer the view is created over
// offset : 1-D index used to build the tensor coordinate for the read
// returns: the thread_buffer<half_t, 8> produced by get_vectorized_elements
//
// NOTE(review): no GuaranteedLastDimensionVectorLength / Stride arguments are
// passed here, unlike vectorized_read_2 / vectorized_read_4 - confirm whether
// that is intentional.
// NOTE(review): the length 24 is hard-coded; presumably callers must keep
// offset + 8 <= 24 - confirm against the call sites.
CK_TILE_DEVICE thread_buffer<half_t, 8> vectorized_read_8_fp16(const half_t* p_data, index_t offset)
{
    auto view = make_naive_tensor_view<address_space_enum::global>(p_data,
                                                                   make_tuple(24), // total elements
                                                                   make_tuple(1)   // stride
    );
    auto desc  = view.get_tensor_descriptor();
    auto coord = make_tensor_coordinate(desc, make_tuple(offset));
    // The trailing 0 is the extra linear offset applied on top of the coordinate.
    return view.template get_vectorized_elements<thread_buffer<half_t, 8>>(coord, 0);
}
|
||||
|
||||
// The kernel that demonstrates all fundamental concepts
|
||||
template<typename DataType>
|
||||
template <typename DataType>
|
||||
struct TensorFundamentalsKernel
|
||||
{
|
||||
static constexpr index_t kBlockSize = 64;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const DataType* p_input,
|
||||
DataType* p_output,
|
||||
index_t H, index_t W, index_t C) const
|
||||
CK_TILE_DEVICE void
|
||||
operator()(const DataType* p_input, DataType* p_output, index_t H, index_t W, index_t C) const
|
||||
{
|
||||
// Only thread 0 for clean output
|
||||
if(get_thread_id() != 0) return;
|
||||
if(get_thread_id() != 0)
|
||||
return;
|
||||
|
||||
printf("\n=== TENSOR FUNDAMENTALS IN CK_TILE ===\n\n");
|
||||
|
||||
@@ -123,25 +117,31 @@ struct TensorFundamentalsKernel
|
||||
// It contains: lengths (shape) + strides (memory layout)
|
||||
|
||||
// Create a descriptor for [H,W,C] tensor in row-major layout
|
||||
auto hwc_descriptor = make_naive_tensor_descriptor(
|
||||
make_tuple(H, W, C), // lengths: [2, 3, 2]
|
||||
make_tuple(W*C, C, 1) // strides: [6, 2, 1] for row-major
|
||||
);
|
||||
auto hwc_descriptor =
|
||||
make_naive_tensor_descriptor(make_tuple(H, W, C), // lengths: [2, 3, 2]
|
||||
make_tuple(W * C, C, 1) // strides: [6, 2, 1] for row-major
|
||||
);
|
||||
|
||||
// Access descriptor properties
|
||||
auto lengths = hwc_descriptor.get_lengths();
|
||||
// Note: Descriptors don't expose strides directly after transformation
|
||||
|
||||
printf("Descriptor for [H=%ld, W=%ld, C=%ld] tensor:\n",
|
||||
static_cast<long>(H), static_cast<long>(W), static_cast<long>(C));
|
||||
static_cast<long>(H),
|
||||
static_cast<long>(W),
|
||||
static_cast<long>(C));
|
||||
printf(" Lengths: [%ld, %ld, %ld]\n",
|
||||
static_cast<long>(lengths.at(number<0>{})),
|
||||
static_cast<long>(lengths.at(number<1>{})),
|
||||
static_cast<long>(lengths.at(number<2>{})));
|
||||
printf(" Strides: [%ld, %ld, %ld] (row-major)\n",
|
||||
static_cast<long>(W*C), static_cast<long>(C), static_cast<long>(1));
|
||||
static_cast<long>(W * C),
|
||||
static_cast<long>(C),
|
||||
static_cast<long>(1));
|
||||
printf(" Memory formula: offset = h*%ld + w*%ld + c*%ld\n\n",
|
||||
static_cast<long>(W*C), static_cast<long>(C), static_cast<long>(1));
|
||||
static_cast<long>(W * C),
|
||||
static_cast<long>(C),
|
||||
static_cast<long>(1));
|
||||
|
||||
//==================================================================
|
||||
// PART 2: TENSOR VIEW - Three Creation Methods
|
||||
@@ -152,46 +152,51 @@ struct TensorFundamentalsKernel
|
||||
// Method 1: Explicit strides (most control)
|
||||
printf("Method 1: make_naive_tensor_view with explicit strides\n");
|
||||
auto view1 = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_input, // GPU memory pointer
|
||||
make_tuple(H, W, C), // lengths
|
||||
make_tuple(W*C, C, 1) // explicit strides
|
||||
p_input, // GPU memory pointer
|
||||
make_tuple(H, W, C), // lengths
|
||||
make_tuple(W * C, C, 1) // explicit strides
|
||||
);
|
||||
printf(" Created view with shape [%ld,%ld,%ld] and strides [%ld,%ld,%ld]\n",
|
||||
static_cast<long>(H), static_cast<long>(W), static_cast<long>(C),
|
||||
static_cast<long>(W*C), static_cast<long>(C), static_cast<long>(1));
|
||||
static_cast<long>(H),
|
||||
static_cast<long>(W),
|
||||
static_cast<long>(C),
|
||||
static_cast<long>(W * C),
|
||||
static_cast<long>(C),
|
||||
static_cast<long>(1));
|
||||
|
||||
// Method 2: Packed/contiguous (auto-computes row-major strides)
|
||||
printf("\nMethod 2: make_naive_tensor_view_packed (auto strides)\n");
|
||||
auto view2 = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_input, // GPU memory pointer
|
||||
make_tuple(H, W, C) // lengths only, strides auto-computed
|
||||
p_input, // GPU memory pointer
|
||||
make_tuple(H, W, C) // lengths only, strides auto-computed
|
||||
);
|
||||
printf(" Created packed view - strides computed automatically\n");
|
||||
printf(" For row-major: last dim stride=1, each dim stride = next_dim_stride * next_dim_length\n");
|
||||
printf(" For row-major: last dim stride=1, each dim stride = next_dim_stride * "
|
||||
"next_dim_length\n");
|
||||
|
||||
// Method 3: From existing descriptor
|
||||
printf("\nMethod 3: make_tensor_view from descriptor\n");
|
||||
auto view3 = make_tensor_view<address_space_enum::global>(
|
||||
p_input, // GPU memory pointer
|
||||
hwc_descriptor // existing descriptor
|
||||
);
|
||||
auto view3 =
|
||||
make_tensor_view<address_space_enum::global>(p_input, // GPU memory pointer
|
||||
hwc_descriptor // existing descriptor
|
||||
);
|
||||
printf(" Created view using pre-existing descriptor\n");
|
||||
|
||||
// Demonstrate all three views access the same data
|
||||
printf("\nVerifying all three methods create equivalent views:\n");
|
||||
{
|
||||
auto coord_test = make_tensor_coordinate(
|
||||
view1.get_tensor_descriptor(), make_tuple(0, 1, 0));
|
||||
auto coord_test =
|
||||
make_tensor_coordinate(view1.get_tensor_descriptor(), make_tuple(0, 1, 0));
|
||||
auto val1 = view1.template get_vectorized_elements<thread_buffer<DataType, 1>>(
|
||||
coord_test, 0)[number<0>{}];
|
||||
|
||||
auto coord_test2 = make_tensor_coordinate(
|
||||
view2.get_tensor_descriptor(), make_tuple(0, 1, 0));
|
||||
auto coord_test2 =
|
||||
make_tensor_coordinate(view2.get_tensor_descriptor(), make_tuple(0, 1, 0));
|
||||
auto val2 = view2.template get_vectorized_elements<thread_buffer<DataType, 1>>(
|
||||
coord_test2, 0)[number<0>{}];
|
||||
|
||||
auto coord_test3 = make_tensor_coordinate(
|
||||
view3.get_tensor_descriptor(), make_tuple(0, 1, 0));
|
||||
auto coord_test3 =
|
||||
make_tensor_coordinate(view3.get_tensor_descriptor(), make_tuple(0, 1, 0));
|
||||
auto val3 = view3.template get_vectorized_elements<thread_buffer<DataType, 1>>(
|
||||
coord_test3, 0)[number<0>{}];
|
||||
|
||||
@@ -217,11 +222,12 @@ struct TensorFundamentalsKernel
|
||||
|
||||
// Coordinate can compute its linear offset
|
||||
index_t offset = coord.get_offset();
|
||||
printf("Coordinate [1,2,0] maps to linear offset: %ld\n",
|
||||
static_cast<long>(offset));
|
||||
printf("Coordinate [1,2,0] maps to linear offset: %ld\n", static_cast<long>(offset));
|
||||
printf(" Calculation: 1*%ld + 2*%ld + 0*%ld = %ld\n\n",
|
||||
static_cast<long>(W*C), static_cast<long>(C),
|
||||
static_cast<long>(1), static_cast<long>(offset));
|
||||
static_cast<long>(W * C),
|
||||
static_cast<long>(C),
|
||||
static_cast<long>(1),
|
||||
static_cast<long>(offset));
|
||||
|
||||
//==================================================================
|
||||
// PART 4: ELEMENT ACCESS - The Critical Pattern
|
||||
@@ -237,8 +243,8 @@ struct TensorFundamentalsKernel
|
||||
|
||||
// get_vectorized_elements returns thread_buffer<T,N>, not T!
|
||||
auto buffer = view1.template get_vectorized_elements<thread_buffer<DataType, 1>>(
|
||||
read_coord, // coordinate
|
||||
0 // linear_offset (usually 0)
|
||||
read_coord, // coordinate
|
||||
0 // linear_offset (usually 0)
|
||||
);
|
||||
|
||||
// Extract actual value from thread_buffer
|
||||
@@ -261,10 +267,7 @@ struct TensorFundamentalsKernel
|
||||
|
||||
// Write to output view
|
||||
auto output_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_output,
|
||||
make_tuple(H, W, C),
|
||||
make_tuple(W*C, C, 1)
|
||||
);
|
||||
p_output, make_tuple(H, W, C), make_tuple(W * C, C, 1));
|
||||
|
||||
output_view.set_vectorized_elements(write_coord, 0, write_buffer);
|
||||
|
||||
@@ -284,8 +287,8 @@ struct TensorFundamentalsKernel
|
||||
// Create a flattened view for easier vectorized access
|
||||
auto flat_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_input,
|
||||
make_tuple(H*W*C), // [12] - all elements in linear order
|
||||
make_tuple(1) // stride = 1 (contiguous)
|
||||
make_tuple(H * W * C), // [12] - all elements in linear order
|
||||
make_tuple(1) // stride = 1 (contiguous)
|
||||
);
|
||||
auto flat_desc = flat_view.get_tensor_descriptor();
|
||||
|
||||
@@ -294,10 +297,10 @@ struct TensorFundamentalsKernel
|
||||
{
|
||||
// Call the vectorized_read_2 function (easy to disassemble in debugger)
|
||||
auto buffer = vectorized_read_2(p_input, 0);
|
||||
|
||||
|
||||
DataType val0 = buffer[number<0>{}];
|
||||
DataType val1 = buffer[number<1>{}];
|
||||
|
||||
|
||||
printf(" Position [0]: Read 2 elements in one operation\n");
|
||||
printf(" buffer[0] = %.0f\n", static_cast<float>(val0));
|
||||
printf(" buffer[1] = %.0f\n", static_cast<float>(val1));
|
||||
@@ -310,7 +313,7 @@ struct TensorFundamentalsKernel
|
||||
{
|
||||
// Call the vectorized_read_4 function (easy to disassemble in debugger)
|
||||
auto buffer = vectorized_read_4(p_input, 4);
|
||||
|
||||
|
||||
printf(" Position [4]: Read 4 elements in one operation\n");
|
||||
printf(" buffer[0] = %.0f\n", static_cast<float>(buffer[number<0>{}]));
|
||||
printf(" buffer[1] = %.0f\n", static_cast<float>(buffer[number<1>{}]));
|
||||
@@ -329,10 +332,10 @@ struct TensorFundamentalsKernel
|
||||
write_buffer[number<1>{}] = 101.0f;
|
||||
write_buffer[number<2>{}] = 102.0f;
|
||||
write_buffer[number<3>{}] = 103.0f;
|
||||
|
||||
|
||||
// Call the vectorized_write_4 function (easy to disassemble in debugger)
|
||||
vectorized_write_4(p_output, 4, write_buffer);
|
||||
|
||||
|
||||
printf(" Position [4-7]: Wrote 4 elements in one operation\n");
|
||||
printf(" Wrote: 100, 101, 102, 103\n");
|
||||
printf(" ✓ 4x faster than writing elements individually!\n");
|
||||
@@ -343,30 +346,32 @@ struct TensorFundamentalsKernel
|
||||
printf("Example 4: Vectorized copy - INLINE usage (TRUE vectorization)\n");
|
||||
{
|
||||
auto output_flat_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_output,
|
||||
make_tuple(H*W*C),
|
||||
make_tuple(1)
|
||||
);
|
||||
p_output, make_tuple(H * W * C), make_tuple(1));
|
||||
auto out_flat_desc = output_flat_view.get_tensor_descriptor();
|
||||
|
||||
|
||||
// Copy first 8 elements using vector size 4 (2 iterations)
|
||||
// THIS is where real vectorization happens - inline, no function calls!
|
||||
printf(" Copying first 8 elements using 2 vectorized operations:\n");
|
||||
for(index_t i = 0; i < 8; i += 4) {
|
||||
auto in_coord = make_tensor_coordinate(flat_desc, make_tuple(i));
|
||||
for(index_t i = 0; i < 8; i += 4)
|
||||
{
|
||||
auto in_coord = make_tensor_coordinate(flat_desc, make_tuple(i));
|
||||
auto out_coord = make_tensor_coordinate(out_flat_desc, make_tuple(i));
|
||||
|
||||
|
||||
// Read 4 elements - INLINE vectorized load (not through function)
|
||||
auto buffer = flat_view.template get_vectorized_elements<
|
||||
thread_buffer<DataType, 4>>(in_coord, 0);
|
||||
|
||||
auto buffer =
|
||||
flat_view.template get_vectorized_elements<thread_buffer<DataType, 4>>(in_coord,
|
||||
0);
|
||||
|
||||
// Write 4 elements (skip positions 4-7 which we already wrote)
|
||||
if(i != 4) {
|
||||
if(i != 4)
|
||||
{
|
||||
output_flat_view.set_vectorized_elements(out_coord, 0, buffer);
|
||||
}
|
||||
|
||||
printf(" Iteration %ld: Copied elements [%ld-%ld]\n",
|
||||
static_cast<long>(i/4), static_cast<long>(i), static_cast<long>(i+3));
|
||||
|
||||
printf(" Iteration %ld: Copied elements [%ld-%ld]\n",
|
||||
static_cast<long>(i / 4),
|
||||
static_cast<long>(i),
|
||||
static_cast<long>(i + 3));
|
||||
}
|
||||
printf(" ✓ Copied 8 elements with only 2 memory operations!\n");
|
||||
printf(" ✓ THIS loop shows true vectorization in assembly!\n\n");
|
||||
@@ -390,45 +395,39 @@ struct TensorFundamentalsKernel
|
||||
// Create two different views of the same memory
|
||||
// View A: [H, W, C] = [2, 3, 2]
|
||||
auto view_hwc = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_input,
|
||||
make_tuple(H, W, C),
|
||||
make_tuple(W*C, C, 1)
|
||||
);
|
||||
p_input, make_tuple(H, W, C), make_tuple(W * C, C, 1));
|
||||
|
||||
// View B: [HW, C] = [6, 2] - flattened spatial dimensions
|
||||
auto view_hw_c = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_input,
|
||||
make_tuple(H*W, C),
|
||||
make_tuple(C, 1)
|
||||
);
|
||||
p_input, make_tuple(H * W, C), make_tuple(C, 1));
|
||||
|
||||
printf("Two views of same memory:\n");
|
||||
printf(" View A: [H=%ld, W=%ld, C=%ld]\n",
|
||||
static_cast<long>(H), static_cast<long>(W), static_cast<long>(C));
|
||||
printf(" View B: [HW=%ld, C=%ld]\n",
|
||||
static_cast<long>(H*W), static_cast<long>(C));
|
||||
static_cast<long>(H),
|
||||
static_cast<long>(W),
|
||||
static_cast<long>(C));
|
||||
printf(" View B: [HW=%ld, C=%ld]\n", static_cast<long>(H * W), static_cast<long>(C));
|
||||
|
||||
// Show they access the same data
|
||||
printf("\nAccessing same element through different views:\n");
|
||||
|
||||
// Access element at h=1, w=1, c=0 through View A
|
||||
auto desc_a = view_hwc.get_tensor_descriptor();
|
||||
auto desc_a = view_hwc.get_tensor_descriptor();
|
||||
auto coord_a = make_tensor_coordinate(desc_a, make_tuple(1, 1, 0));
|
||||
auto buffer_a = view_hwc.template get_vectorized_elements<thread_buffer<DataType, 1>>(
|
||||
coord_a, 0);
|
||||
auto buffer_a =
|
||||
view_hwc.template get_vectorized_elements<thread_buffer<DataType, 1>>(coord_a, 0);
|
||||
DataType val_a = buffer_a[number<0>{}];
|
||||
|
||||
// Access same element through View B at hw=4 (1*3+1), c=0
|
||||
auto desc_b = view_hw_c.get_tensor_descriptor();
|
||||
auto desc_b = view_hw_c.get_tensor_descriptor();
|
||||
auto coord_b = make_tensor_coordinate(desc_b, make_tuple(4, 0));
|
||||
auto buffer_b = view_hw_c.template get_vectorized_elements<thread_buffer<DataType, 1>>(
|
||||
coord_b, 0);
|
||||
auto buffer_b =
|
||||
view_hw_c.template get_vectorized_elements<thread_buffer<DataType, 1>>(coord_b, 0);
|
||||
DataType val_b = buffer_b[number<0>{}];
|
||||
|
||||
printf(" View A[1,1,0] = %.0f\n", static_cast<float>(val_a));
|
||||
printf(" View B[4,0] = %.0f (same value!)\n", static_cast<float>(val_b));
|
||||
printf(" Both access offset %ld in memory\n\n",
|
||||
static_cast<long>(coord_a.get_offset()));
|
||||
printf(" Both access offset %ld in memory\n\n", static_cast<long>(coord_a.get_offset()));
|
||||
|
||||
//==================================================================
|
||||
// PART 6: PRACTICAL EXAMPLE - Copy with Views
|
||||
@@ -438,24 +437,26 @@ struct TensorFundamentalsKernel
|
||||
|
||||
// Create output view
|
||||
auto output_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_output,
|
||||
make_tuple(H, W, C),
|
||||
make_tuple(W*C, C, 1)
|
||||
);
|
||||
p_output, make_tuple(H, W, C), make_tuple(W * C, C, 1));
|
||||
auto out_desc = output_view.get_tensor_descriptor();
|
||||
|
||||
// Copy all elements using tensor_view API
|
||||
index_t count = 0;
|
||||
for(index_t h = 0; h < H; h++) {
|
||||
for(index_t w = 0; w < W; w++) {
|
||||
for(index_t c = 0; c < C; c++) {
|
||||
for(index_t h = 0; h < H; h++)
|
||||
{
|
||||
for(index_t w = 0; w < W; w++)
|
||||
{
|
||||
for(index_t c = 0; c < C; c++)
|
||||
{
|
||||
// Read from input
|
||||
auto in_coord = make_tensor_coordinate(desc, make_tuple(h, w, c));
|
||||
auto in_buffer = view1.template get_vectorized_elements<
|
||||
thread_buffer<DataType, 1>>(in_coord, 0);
|
||||
auto in_buffer =
|
||||
view1.template get_vectorized_elements<thread_buffer<DataType, 1>>(in_coord,
|
||||
0);
|
||||
|
||||
// Write to output (except [0,0,1] which we already wrote as 99)
|
||||
if(!(h == 0 && w == 0 && c == 1)) {
|
||||
if(!(h == 0 && w == 0 && c == 1))
|
||||
{
|
||||
auto out_coord = make_tensor_coordinate(out_desc, make_tuple(h, w, c));
|
||||
output_view.set_vectorized_elements(out_coord, 0, in_buffer);
|
||||
}
|
||||
@@ -489,7 +490,8 @@ int main()
|
||||
// Initialize HIP
|
||||
int device_count;
|
||||
hip_check_error(hipGetDeviceCount(&device_count));
|
||||
if(device_count == 0) {
|
||||
if(device_count == 0)
|
||||
{
|
||||
std::cerr << "No GPU devices found!\n";
|
||||
return 1;
|
||||
}
|
||||
@@ -500,36 +502,42 @@ int main()
|
||||
std::cout << "Using GPU: " << props.name << "\n";
|
||||
|
||||
// Small tensor for demonstration
|
||||
constexpr index_t H = 2;
|
||||
constexpr index_t W = 3;
|
||||
constexpr index_t C = 2;
|
||||
constexpr index_t H = 2;
|
||||
constexpr index_t W = 3;
|
||||
constexpr index_t C = 2;
|
||||
constexpr index_t size = H * W * C;
|
||||
|
||||
std::cout << "\nTensor configuration:\n";
|
||||
std::cout << " Shape: [" << H << ", " << W << ", " << C << "]\n";
|
||||
std::cout << " Total elements: " << size << "\n";
|
||||
std::cout << " Layout: Row-major (strides = [" << W*C << ", " << C << ", 1])\n\n";
|
||||
std::cout << " Layout: Row-major (strides = [" << W * C << ", " << C << ", 1])\n\n";
|
||||
|
||||
// Create test data: 1, 2, 3, 4, ... 12
|
||||
std::vector<float> h_input(size);
|
||||
std::iota(h_input.begin(), h_input.end(), 1.0f);
|
||||
|
||||
std::cout << "Input data (row-major memory order):\n";
|
||||
for(index_t i = 0; i < size; ++i) {
|
||||
if(i % C == 0 && i > 0) std::cout << " | ";
|
||||
for(index_t i = 0; i < size; ++i)
|
||||
{
|
||||
if(i % C == 0 && i > 0)
|
||||
std::cout << " | ";
|
||||
std::cout << std::setw(2) << h_input[i] << " ";
|
||||
}
|
||||
std::cout << "\n";
|
||||
|
||||
std::cout << "\nLogical view [H,W,C]:\n";
|
||||
for(index_t h = 0; h < H; h++) {
|
||||
for(index_t h = 0; h < H; h++)
|
||||
{
|
||||
std::cout << " H=" << h << ": ";
|
||||
for(index_t w = 0; w < W; w++) {
|
||||
for(index_t w = 0; w < W; w++)
|
||||
{
|
||||
std::cout << "[";
|
||||
for(index_t c = 0; c < C; c++) {
|
||||
for(index_t c = 0; c < C; c++)
|
||||
{
|
||||
index_t idx = h * W * C + w * C + c;
|
||||
std::cout << std::setw(2) << h_input[idx];
|
||||
if(c < C-1) std::cout << ",";
|
||||
if(c < C - 1)
|
||||
std::cout << ",";
|
||||
}
|
||||
std::cout << "] ";
|
||||
}
|
||||
@@ -555,14 +563,15 @@ int main()
|
||||
std::cout << "=====================================\n";
|
||||
|
||||
launch_kernel(stream,
|
||||
make_kernel<block_size>(
|
||||
TensorFundamentalsKernel<float>{},
|
||||
dim3(1), // 1 block
|
||||
dim3(block_size), // 64 threads
|
||||
0, // no shared memory
|
||||
static_cast<const float*>(d_input.GetDeviceBuffer()),
|
||||
static_cast<float*>(d_output.GetDeviceBuffer()),
|
||||
H, W, C));
|
||||
make_kernel<block_size>(TensorFundamentalsKernel<float>{},
|
||||
dim3(1), // 1 block
|
||||
dim3(block_size), // 64 threads
|
||||
0, // no shared memory
|
||||
static_cast<const float*>(d_input.GetDeviceBuffer()),
|
||||
static_cast<float*>(d_output.GetDeviceBuffer()),
|
||||
H,
|
||||
W,
|
||||
C));
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
std::cout << "=====================================\n";
|
||||
@@ -574,17 +583,19 @@ int main()
|
||||
// Verify results
|
||||
std::cout << "\nOutput verification:\n";
|
||||
bool passed = true;
|
||||
for(index_t i = 0; i < size; ++i) {
|
||||
float expected = (i == 1) ? 99.0f : h_input[i]; // We wrote 99 to position [0,0,1]
|
||||
if(std::abs(h_output[i] - expected) > 1e-6f) {
|
||||
for(index_t i = 0; i < size; ++i)
|
||||
{
|
||||
float expected = (i == 1) ? 99.0f : h_input[i]; // We wrote 99 to position [0,0,1]
|
||||
if(std::abs(h_output[i] - expected) > 1e-6f)
|
||||
{
|
||||
passed = false;
|
||||
std::cout << " ✗ Mismatch at index " << i
|
||||
<< ": expected " << expected
|
||||
<< ", got " << h_output[i] << "\n";
|
||||
std::cout << " ✗ Mismatch at index " << i << ": expected " << expected << ", got "
|
||||
<< h_output[i] << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
if(passed) {
|
||||
if(passed)
|
||||
{
|
||||
std::cout << " ✓ All elements correct!\n";
|
||||
std::cout << " ✓ output[0,0,1] = 99 (modified as expected)\n";
|
||||
std::cout << " ✓ All other elements copied correctly\n";
|
||||
|
||||
@@ -22,7 +22,7 @@
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
template<typename DataType>
|
||||
template <typename DataType>
|
||||
struct TensorAdaptorsKernel
|
||||
{
|
||||
static constexpr index_t kBlockSize = 64;
|
||||
@@ -33,43 +33,43 @@ struct TensorAdaptorsKernel
|
||||
printf("PART 1: make_single_stage_tensor_adaptor\n");
|
||||
printf("=========================================\n\n");
|
||||
|
||||
printf("Purpose: Create a tensor adaptor with transformations applied in a single stage.\n");
|
||||
printf(
|
||||
"Purpose: Create a tensor adaptor with transformations applied in a single stage.\n");
|
||||
printf("This is the foundation for building complex layout transformations.\n\n");
|
||||
|
||||
// Example 1.1: Simple dimension split (Unmerge)
|
||||
printf("Example 1.1: Split M dimension for tiling\n");
|
||||
printf("------------------------------------------\n");
|
||||
{
|
||||
constexpr index_t M = 128;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t M = 128;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t M0 = 4;
|
||||
constexpr index_t M1 = 32;
|
||||
|
||||
printf("Input layout: [M=%ld, K=%ld]\n", static_cast<long>(M), static_cast<long>(K));
|
||||
printf("Goal: Split M into [M0=%ld, M1=%ld] for tiling\n",
|
||||
static_cast<long>(M0), static_cast<long>(M1));
|
||||
printf("Goal: Split M into [M0=%ld, M1=%ld] for tiling\n",
|
||||
static_cast<long>(M0),
|
||||
static_cast<long>(M1));
|
||||
|
||||
auto transforms = make_tuple(
|
||||
make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
|
||||
make_pass_through_transform(number<K>{})
|
||||
);
|
||||
auto transforms =
|
||||
make_tuple(make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
|
||||
make_pass_through_transform(number<K>{}));
|
||||
|
||||
auto lower_dims = make_tuple(sequence<0>{}, sequence<1>{});
|
||||
auto upper_dims = make_tuple(sequence<0, 1>{}, sequence<2>{});
|
||||
|
||||
auto adaptor = make_single_stage_tensor_adaptor(
|
||||
transforms, lower_dims, upper_dims
|
||||
);
|
||||
auto adaptor = make_single_stage_tensor_adaptor(transforms, lower_dims, upper_dims);
|
||||
|
||||
printf("\nAdaptor created:\n");
|
||||
printf(" Input: [M, K] = [%ld, %ld]\n",
|
||||
static_cast<long>(M), static_cast<long>(K));
|
||||
printf(" Input: [M, K] = [%ld, %ld]\n", static_cast<long>(M), static_cast<long>(K));
|
||||
printf(" Output: [M0, M1, K] = [%ld, %ld, %ld]\n",
|
||||
static_cast<long>(M0), static_cast<long>(M1), static_cast<long>(K));
|
||||
static_cast<long>(M0),
|
||||
static_cast<long>(M1),
|
||||
static_cast<long>(K));
|
||||
|
||||
auto top_idx = make_tuple(1, 16, 32);
|
||||
auto top_idx = make_tuple(1, 16, 32);
|
||||
auto bottom_idx = adaptor.calculate_bottom_index(top_idx);
|
||||
|
||||
|
||||
printf("\nTest: [M0=1, M1=16, K=32] -> [M=%ld, K=%ld]\n",
|
||||
static_cast<long>(bottom_idx[number<0>{}]),
|
||||
static_cast<long>(bottom_idx[number<1>{}]));
|
||||
@@ -81,8 +81,8 @@ struct TensorAdaptorsKernel
|
||||
printf("Example 1.2: GEMM C Matrix Tiling (Interleaved)\n");
|
||||
printf("------------------------------------------------\n");
|
||||
{
|
||||
constexpr index_t M = 256;
|
||||
constexpr index_t N = 256;
|
||||
constexpr index_t M = 256;
|
||||
constexpr index_t N = 256;
|
||||
constexpr index_t M0 = 4;
|
||||
constexpr index_t M1 = 64;
|
||||
constexpr index_t N0 = 4;
|
||||
@@ -90,25 +90,23 @@ struct TensorAdaptorsKernel
|
||||
|
||||
printf("Input: [M=%ld, N=%ld]\n", static_cast<long>(M), static_cast<long>(N));
|
||||
printf("Output: [M0=%ld, N0=%ld, M1=%ld, N1=%ld] (interleaved)\n",
|
||||
static_cast<long>(M0), static_cast<long>(N0),
|
||||
static_cast<long>(M1), static_cast<long>(N1));
|
||||
static_cast<long>(M0),
|
||||
static_cast<long>(N0),
|
||||
static_cast<long>(M1),
|
||||
static_cast<long>(N1));
|
||||
|
||||
auto transforms = make_tuple(
|
||||
make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
|
||||
make_unmerge_transform(make_tuple(number<N0>{}, number<N1>{}))
|
||||
);
|
||||
auto transforms =
|
||||
make_tuple(make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
|
||||
make_unmerge_transform(make_tuple(number<N0>{}, number<N1>{})));
|
||||
|
||||
auto lower_dims = make_tuple(sequence<0>{}, sequence<1>{});
|
||||
auto upper_dims = make_tuple(
|
||||
sequence<0, 2>{}, // M splits to dims 0,2
|
||||
sequence<1, 3>{} // N splits to dims 1,3
|
||||
auto upper_dims = make_tuple(sequence<0, 2>{}, // M splits to dims 0,2
|
||||
sequence<1, 3>{} // N splits to dims 1,3
|
||||
);
|
||||
|
||||
auto adaptor = make_single_stage_tensor_adaptor(
|
||||
transforms, lower_dims, upper_dims
|
||||
);
|
||||
auto adaptor = make_single_stage_tensor_adaptor(transforms, lower_dims, upper_dims);
|
||||
|
||||
auto top_idx = make_tuple(2, 3, 16, 32);
|
||||
auto top_idx = make_tuple(2, 3, 16, 32);
|
||||
auto bottom_idx = adaptor.calculate_bottom_index(top_idx);
|
||||
printf("\nTest: [M0=2, N0=3, M1=16, N1=32] -> [M=%ld, N=%ld]\n",
|
||||
static_cast<long>(bottom_idx[number<0>{}]),
|
||||
@@ -130,43 +128,44 @@ struct TensorAdaptorsKernel
|
||||
printf("Example 2.1: Two-Stage Hierarchical Tiling\n");
|
||||
printf("-------------------------------------------\n");
|
||||
{
|
||||
constexpr index_t M = 256;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t M = 256;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t M0 = 4;
|
||||
constexpr index_t M1 = 64;
|
||||
constexpr index_t K0 = 4;
|
||||
constexpr index_t K1 = 32;
|
||||
|
||||
printf("Stage 1: [M=%ld, K=%ld] -> [M0=%ld, M1=%ld, K=%ld]\n",
|
||||
static_cast<long>(M), static_cast<long>(K),
|
||||
static_cast<long>(M0), static_cast<long>(M1), static_cast<long>(K));
|
||||
static_cast<long>(M),
|
||||
static_cast<long>(K),
|
||||
static_cast<long>(M0),
|
||||
static_cast<long>(M1),
|
||||
static_cast<long>(K));
|
||||
|
||||
auto stage1_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
|
||||
make_pass_through_transform(number<K>{})
|
||||
),
|
||||
make_tuple(make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
|
||||
make_pass_through_transform(number<K>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{})
|
||||
);
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
printf("Stage 2: [M0=%ld, M1=%ld, K=%ld] -> [M0=%ld, M1=%ld, K0=%ld, K1=%ld]\n",
|
||||
static_cast<long>(M0), static_cast<long>(M1), static_cast<long>(K),
|
||||
static_cast<long>(M0), static_cast<long>(M1),
|
||||
static_cast<long>(K0), static_cast<long>(K1));
|
||||
static_cast<long>(M0),
|
||||
static_cast<long>(M1),
|
||||
static_cast<long>(K),
|
||||
static_cast<long>(M0),
|
||||
static_cast<long>(M1),
|
||||
static_cast<long>(K0),
|
||||
static_cast<long>(K1));
|
||||
|
||||
auto final_adaptor = transform_tensor_adaptor(
|
||||
stage1_adaptor,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<M0>{}),
|
||||
make_pass_through_transform(number<M1>{}),
|
||||
make_unmerge_transform(make_tuple(number<K0>{}, number<K1>{}))
|
||||
),
|
||||
make_tuple(make_pass_through_transform(number<M0>{}),
|
||||
make_pass_through_transform(number<M1>{}),
|
||||
make_unmerge_transform(make_tuple(number<K0>{}, number<K1>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2, 3>{})
|
||||
);
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2, 3>{}));
|
||||
|
||||
auto top_idx = make_tuple(2, 32, 3, 16);
|
||||
auto top_idx = make_tuple(2, 32, 3, 16);
|
||||
auto bottom_idx = final_adaptor.calculate_bottom_index(top_idx);
|
||||
printf("\nTest: [M0=2, M1=32, K0=3, K1=16] -> [M=%ld, K=%ld]\n",
|
||||
static_cast<long>(bottom_idx[number<0>{}]),
|
||||
@@ -187,49 +186,53 @@ struct TensorAdaptorsKernel
|
||||
printf("Example 3.1: Chain Two Adaptors\n");
|
||||
printf("--------------------------------\n");
|
||||
{
|
||||
constexpr index_t M = 128;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t M = 128;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t M0 = 4;
|
||||
constexpr index_t M1 = 32;
|
||||
constexpr index_t K0 = 4;
|
||||
constexpr index_t K1 = 16;
|
||||
|
||||
printf("Adaptor A: [M=%ld, K=%ld] -> [M0=%ld, M1=%ld, K=%ld]\n",
|
||||
static_cast<long>(M), static_cast<long>(K),
|
||||
static_cast<long>(M0), static_cast<long>(M1), static_cast<long>(K));
|
||||
static_cast<long>(M),
|
||||
static_cast<long>(K),
|
||||
static_cast<long>(M0),
|
||||
static_cast<long>(M1),
|
||||
static_cast<long>(K));
|
||||
|
||||
auto adaptor_a = make_single_stage_tensor_adaptor(
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
|
||||
make_pass_through_transform(number<K>{})
|
||||
),
|
||||
make_tuple(make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
|
||||
make_pass_through_transform(number<K>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{})
|
||||
);
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
printf("Adaptor B: [M0=%ld, M1=%ld, K=%ld] -> [M0=%ld, M1=%ld, K0=%ld, K1=%ld]\n",
|
||||
static_cast<long>(M0), static_cast<long>(M1), static_cast<long>(K),
|
||||
static_cast<long>(M0), static_cast<long>(M1),
|
||||
static_cast<long>(K0), static_cast<long>(K1));
|
||||
static_cast<long>(M0),
|
||||
static_cast<long>(M1),
|
||||
static_cast<long>(K),
|
||||
static_cast<long>(M0),
|
||||
static_cast<long>(M1),
|
||||
static_cast<long>(K0),
|
||||
static_cast<long>(K1));
|
||||
|
||||
auto adaptor_b = make_single_stage_tensor_adaptor(
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<M0>{}),
|
||||
make_pass_through_transform(number<M1>{}),
|
||||
make_unmerge_transform(make_tuple(number<K0>{}, number<K1>{}))
|
||||
),
|
||||
make_tuple(make_pass_through_transform(number<M0>{}),
|
||||
make_pass_through_transform(number<M1>{}),
|
||||
make_unmerge_transform(make_tuple(number<K0>{}, number<K1>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2, 3>{})
|
||||
);
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2, 3>{}));
|
||||
|
||||
auto chained = chain_tensor_adaptors(adaptor_a, adaptor_b);
|
||||
|
||||
printf("\nChained: [M=%ld, K=%ld] -> [M0=%ld, M1=%ld, K0=%ld, K1=%ld]\n",
|
||||
static_cast<long>(M), static_cast<long>(K),
|
||||
static_cast<long>(M0), static_cast<long>(M1),
|
||||
static_cast<long>(K0), static_cast<long>(K1));
|
||||
static_cast<long>(M),
|
||||
static_cast<long>(K),
|
||||
static_cast<long>(M0),
|
||||
static_cast<long>(M1),
|
||||
static_cast<long>(K0),
|
||||
static_cast<long>(K1));
|
||||
|
||||
auto top_idx = make_tuple(2, 16, 3, 8);
|
||||
auto top_idx = make_tuple(2, 16, 3, 8);
|
||||
auto bottom_idx = chained.calculate_bottom_index(top_idx);
|
||||
printf("Test: [M0=2, M1=16, K0=3, K1=8] -> [M=%ld, K=%ld]\n",
|
||||
static_cast<long>(bottom_idx[number<0>{}]),
|
||||
@@ -245,32 +248,33 @@ struct TensorAdaptorsKernel
|
||||
printf("PART 4: Real-World GEMM Tiling Example\n");
|
||||
printf("=======================================\n\n");
|
||||
|
||||
constexpr index_t M = 256;
|
||||
constexpr index_t N = 256;
|
||||
constexpr index_t MWaves = 4;
|
||||
constexpr index_t NWaves = 4;
|
||||
constexpr index_t M = 256;
|
||||
constexpr index_t N = 256;
|
||||
constexpr index_t MWaves = 4;
|
||||
constexpr index_t NWaves = 4;
|
||||
constexpr index_t MPerXDL = 16;
|
||||
constexpr index_t NPerXDL = 16;
|
||||
constexpr index_t M0 = M / (MWaves * MPerXDL);
|
||||
constexpr index_t N0 = N / (NWaves * NPerXDL);
|
||||
constexpr index_t M0 = M / (MWaves * MPerXDL);
|
||||
constexpr index_t N0 = N / (NWaves * NPerXDL);
|
||||
|
||||
printf("GEMM C Matrix: [M=%ld, N=%ld]\n",
|
||||
static_cast<long>(M), static_cast<long>(N));
|
||||
printf("GEMM C Matrix: [M=%ld, N=%ld]\n", static_cast<long>(M), static_cast<long>(N));
|
||||
printf("Tiling: [M0=%ld, N0=%ld, M1=%ld, N1=%ld, M2=%ld, N2=%ld]\n",
|
||||
static_cast<long>(M0), static_cast<long>(N0),
|
||||
static_cast<long>(MWaves), static_cast<long>(NWaves),
|
||||
static_cast<long>(MPerXDL), static_cast<long>(NPerXDL));
|
||||
static_cast<long>(M0),
|
||||
static_cast<long>(N0),
|
||||
static_cast<long>(MWaves),
|
||||
static_cast<long>(NWaves),
|
||||
static_cast<long>(MPerXDL),
|
||||
static_cast<long>(NPerXDL));
|
||||
|
||||
auto adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(number<M0>{}, number<MWaves>{}, number<MPerXDL>{})),
|
||||
make_unmerge_transform(make_tuple(number<N0>{}, number<NWaves>{}, number<NPerXDL>{}))
|
||||
),
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<M0>{}, number<MWaves>{}, number<MPerXDL>{})),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<N0>{}, number<NWaves>{}, number<NPerXDL>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0, 2, 4>{}, sequence<1, 3, 5>{})
|
||||
);
|
||||
make_tuple(sequence<0, 2, 4>{}, sequence<1, 3, 5>{}));
|
||||
|
||||
auto top_idx = make_tuple(2, 3, 1, 2, 8, 12);
|
||||
auto top_idx = make_tuple(2, 3, 1, 2, 8, 12);
|
||||
auto bottom_idx = adaptor.calculate_bottom_index(top_idx);
|
||||
printf("\nTest: [M0=2, N0=3, M1=1, N1=2, M2=8, N2=12] -> [M=%ld, N=%ld]\n",
|
||||
static_cast<long>(bottom_idx[number<0>{}]),
|
||||
@@ -288,8 +292,8 @@ struct TensorAdaptorsKernel
|
||||
printf("Demonstrating padding transform with coordinate mapping.\n\n");
|
||||
|
||||
// Original size: 10 elements, pad to 16
|
||||
constexpr index_t OrigSize = 10;
|
||||
constexpr index_t PadRight = 6;
|
||||
constexpr index_t OrigSize = 10;
|
||||
constexpr index_t PadRight = 6;
|
||||
constexpr index_t TotalSize = OrigSize + PadRight;
|
||||
|
||||
printf("Original size: %ld elements\n", static_cast<long>(OrigSize));
|
||||
@@ -301,30 +305,35 @@ struct TensorAdaptorsKernel
|
||||
make_naive_tensor_descriptor_packed(make_tuple(number<OrigSize>{})),
|
||||
make_tuple(make_right_pad_transform(number<OrigSize>{}, number<PadRight>{})),
|
||||
make_tuple(sequence<0>{}),
|
||||
make_tuple(sequence<0>{})
|
||||
);
|
||||
make_tuple(sequence<0>{}));
|
||||
|
||||
printf("Coordinate mapping and memory reads:\n");
|
||||
printf("------------------------------------\n\n");
|
||||
|
||||
printf("Real area (indices 0-9):\n");
|
||||
for(index_t i = 0; i < OrigSize; i++) {
|
||||
auto coord = make_tensor_coordinate(desc_padded, make_tuple(i));
|
||||
for(index_t i = 0; i < OrigSize; i++)
|
||||
{
|
||||
auto coord = make_tensor_coordinate(desc_padded, make_tuple(i));
|
||||
index_t offset = coord.get_offset();
|
||||
DataType val = p_data[offset];
|
||||
|
||||
DataType val = p_data[offset];
|
||||
|
||||
printf(" Index %ld -> offset %ld -> value %.1f (real data)\n",
|
||||
static_cast<long>(i), static_cast<long>(offset), static_cast<float>(val));
|
||||
static_cast<long>(i),
|
||||
static_cast<long>(offset),
|
||||
static_cast<float>(val));
|
||||
}
|
||||
|
||||
printf("\nPadded area (indices 10-15):\n");
|
||||
for(index_t i = OrigSize; i < TotalSize; i++) {
|
||||
auto coord = make_tensor_coordinate(desc_padded, make_tuple(i));
|
||||
for(index_t i = OrigSize; i < TotalSize; i++)
|
||||
{
|
||||
auto coord = make_tensor_coordinate(desc_padded, make_tuple(i));
|
||||
index_t offset = coord.get_offset();
|
||||
DataType val = p_data[offset];
|
||||
|
||||
DataType val = p_data[offset];
|
||||
|
||||
printf(" Index %ld -> offset %ld -> value %.1f (wraps around)\n",
|
||||
static_cast<long>(i), static_cast<long>(offset), static_cast<float>(val));
|
||||
static_cast<long>(i),
|
||||
static_cast<long>(offset),
|
||||
static_cast<float>(val));
|
||||
}
|
||||
|
||||
printf("\nKey Observations:\n");
|
||||
@@ -344,13 +353,11 @@ struct TensorAdaptorsKernel
|
||||
printf("Demonstrating replicate transform with complete coordinate mapping.\n\n");
|
||||
|
||||
// Start with flattened 1D tensor
|
||||
constexpr index_t Size = 16; // H*W = 2*8
|
||||
constexpr index_t Size = 16; // H*W = 2*8
|
||||
|
||||
printf("Step 1: Create initial 1D descriptor [Size=%ld]\n", static_cast<long>(Size));
|
||||
|
||||
auto desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(number<Size>{})
|
||||
);
|
||||
auto desc = make_naive_tensor_descriptor_packed(make_tuple(number<Size>{}));
|
||||
|
||||
printf(" Initial: [16] (flattened)\n\n");
|
||||
|
||||
@@ -362,11 +369,11 @@ struct TensorAdaptorsKernel
|
||||
auto desc_stage1 = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(
|
||||
make_replicate_transform(make_tuple(number<8>{})), // Broadcast to 8
|
||||
make_unmerge_transform(make_tuple(number<8>{}, number<2>{})) // Split 16 -> [8,2]
|
||||
),
|
||||
make_tuple(sequence<>{}, sequence<0>{}), // Replicate has no input, Unmerge uses dim 0
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}) // Rep0=dim0, Unmerge produces dims 1,2
|
||||
make_replicate_transform(make_tuple(number<8>{})), // Broadcast to 8
|
||||
make_unmerge_transform(make_tuple(number<8>{}, number<2>{})) // Split 16 -> [8,2]
|
||||
),
|
||||
make_tuple(sequence<>{}, sequence<0>{}), // Replicate has no input, Unmerge uses dim 0
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}) // Rep0=dim0, Unmerge produces dims 1,2
|
||||
);
|
||||
|
||||
printf("\n After Stage 1: [Rep0=8, Dim0=8, Dim1=2]\n");
|
||||
@@ -378,55 +385,63 @@ struct TensorAdaptorsKernel
|
||||
auto desc_final = transform_tensor_descriptor(
|
||||
desc_stage1,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(number<8>{}, number<8>{})), // Merge Rep0, Dim0
|
||||
make_pass_through_transform(number<2>{}) // Dim1 unchanged
|
||||
),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}), // Merge dims 0,1; pass-through dim 2
|
||||
make_tuple(sequence<0>{}, sequence<1>{}) // Output: [Merged, Dim1]
|
||||
make_merge_transform(make_tuple(number<8>{}, number<8>{})), // Merge Rep0, Dim0
|
||||
make_pass_through_transform(number<2>{}) // Dim1 unchanged
|
||||
),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}), // Merge dims 0,1; pass-through dim 2
|
||||
make_tuple(sequence<0>{}, sequence<1>{}) // Output: [Merged, Dim1]
|
||||
);
|
||||
|
||||
printf("\n Final: [Merged=64, Dim1=2]\n\n");
|
||||
|
||||
// Comprehensive coordinate testing - ALL coordinates
|
||||
printf("COORDINATE MAPPING TEST - ALL %ld coordinates:\n",
|
||||
static_cast<long>(64 * 2));
|
||||
printf("COORDINATE MAPPING TEST - ALL %ld coordinates:\n", static_cast<long>(64 * 2));
|
||||
printf("=======================================================\n");
|
||||
printf("Format: [Merged, Dim1] -> memory_offset\n\n");
|
||||
|
||||
auto lengths_final = desc_final.get_lengths();
|
||||
index_t merged_len = lengths_final[number<0>{}];
|
||||
index_t dim1_len = lengths_final[number<1>{}];
|
||||
index_t dim1_len = lengths_final[number<1>{}];
|
||||
|
||||
printf("Descriptor: [Merged=%ld, Dim1=%ld] = %ld total coordinates\n",
|
||||
static_cast<long>(merged_len), static_cast<long>(dim1_len),
|
||||
static_cast<long>(merged_len),
|
||||
static_cast<long>(dim1_len),
|
||||
static_cast<long>(merged_len * dim1_len));
|
||||
printf("Memory: Only 16 locations (broadcasting effect!)\n\n");
|
||||
|
||||
// Print ALL coordinates to show broadcasting pattern
|
||||
index_t count = 0;
|
||||
for(index_t merged = 0; merged < merged_len; merged++) {
|
||||
for(index_t dim1 = 0; dim1 < dim1_len; dim1++) {
|
||||
auto coord = make_tensor_coordinate(desc_final, make_tuple(merged, dim1));
|
||||
for(index_t merged = 0; merged < merged_len; merged++)
|
||||
{
|
||||
for(index_t dim1 = 0; dim1 < dim1_len; dim1++)
|
||||
{
|
||||
auto coord = make_tensor_coordinate(desc_final, make_tuple(merged, dim1));
|
||||
index_t offset = coord.get_offset();
|
||||
printf(" [%2ld, %ld] -> offset %2ld",
|
||||
static_cast<long>(merged), static_cast<long>(dim1),
|
||||
static_cast<long>(merged),
|
||||
static_cast<long>(dim1),
|
||||
static_cast<long>(offset));
|
||||
|
||||
|
||||
// Add newline every 4 coordinates for readability
|
||||
count++;
|
||||
if(count % 4 == 0) {
|
||||
if(count % 4 == 0)
|
||||
{
|
||||
printf("\n");
|
||||
} else {
|
||||
}
|
||||
else
|
||||
{
|
||||
printf(" | ");
|
||||
}
|
||||
}
|
||||
}
|
||||
if(count % 4 != 0) printf("\n");
|
||||
if(count % 4 != 0)
|
||||
printf("\n");
|
||||
|
||||
printf("\nKey Observations:\n");
|
||||
printf(" - Total coordinates: %ld (Merged=%ld × Dim1=%ld)\n",
|
||||
static_cast<long>(merged_len * dim1_len),
|
||||
static_cast<long>(merged_len), static_cast<long>(dim1_len));
|
||||
static_cast<long>(merged_len),
|
||||
static_cast<long>(dim1_len));
|
||||
printf(" - Memory locations: 16 (original size)\n");
|
||||
printf(" - Broadcasting ratio: %ld:1 (each memory location accessed by %ld coordinates)\n",
|
||||
static_cast<long>((merged_len * dim1_len) / 16),
|
||||
@@ -436,7 +451,8 @@ struct TensorAdaptorsKernel
|
||||
|
||||
CK_TILE_DEVICE void operator()(const DataType* p_data) const
|
||||
{
|
||||
if(get_thread_id() != 0) return;
|
||||
if(get_thread_id() != 0)
|
||||
return;
|
||||
|
||||
printf("\n=== TENSOR ADAPTORS IN CK_TILE ===\n\n");
|
||||
|
||||
@@ -476,7 +492,8 @@ int main()
|
||||
|
||||
int device_count;
|
||||
hip_check_error(hipGetDeviceCount(&device_count));
|
||||
if(device_count == 0) {
|
||||
if(device_count == 0)
|
||||
{
|
||||
std::cerr << "No GPU devices found!\n";
|
||||
return 1;
|
||||
}
|
||||
@@ -488,13 +505,15 @@ int main()
|
||||
|
||||
// Allocate data for padding example (16 elements, but only first 10 have real data)
|
||||
constexpr index_t data_size = 16;
|
||||
std::vector<float> h_data(data_size, 0.0f); // Initialize all to 0
|
||||
std::iota(h_data.begin(), h_data.begin() + 10, 1.0f); // First 10: 1,2,3,...,10
|
||||
std::vector<float> h_data(data_size, 0.0f); // Initialize all to 0
|
||||
std::iota(h_data.begin(), h_data.begin() + 10, 1.0f); // First 10: 1,2,3,...,10
|
||||
|
||||
std::cout << "\nTest data (first 10 real, last 6 padding zeros): ";
|
||||
for(size_t i = 0; i < h_data.size(); i++) {
|
||||
for(size_t i = 0; i < h_data.size(); i++)
|
||||
{
|
||||
std::cout << h_data[i];
|
||||
if(i < h_data.size() - 1) std::cout << " ";
|
||||
if(i < h_data.size() - 1)
|
||||
std::cout << " ";
|
||||
}
|
||||
std::cout << "\n";
|
||||
|
||||
@@ -508,12 +527,11 @@ int main()
|
||||
std::cout << "=====================================\n";
|
||||
|
||||
launch_kernel(stream,
|
||||
make_kernel<block_size>(
|
||||
TensorAdaptorsKernel<float>{},
|
||||
dim3(1),
|
||||
dim3(block_size),
|
||||
0,
|
||||
static_cast<const float*>(d_data.GetDeviceBuffer())));
|
||||
make_kernel<block_size>(TensorAdaptorsKernel<float>{},
|
||||
dim3(1),
|
||||
dim3(block_size),
|
||||
0,
|
||||
static_cast<const float*>(d_data.GetDeviceBuffer())));
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
std::cout << "=====================================\n";
|
||||
|
||||
@@ -21,16 +21,16 @@
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
template<typename DataType>
|
||||
template <typename DataType>
|
||||
struct PaddingTileKernel
|
||||
{
|
||||
static constexpr index_t kBlockSize = 64;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const DataType* p_data,
|
||||
index_t orig_size,
|
||||
index_t padded_size) const
|
||||
CK_TILE_DEVICE void
|
||||
operator()(const DataType* p_data, index_t orig_size, index_t padded_size) const
|
||||
{
|
||||
if(get_thread_id() != 0) return;
|
||||
if(get_thread_id() != 0)
|
||||
return;
|
||||
|
||||
printf("\n=== PADDING WITH TILE WINDOWS ===\n\n");
|
||||
|
||||
@@ -38,24 +38,17 @@ struct PaddingTileKernel
|
||||
printf("Padded size: %ld\n\n", static_cast<long>(padded_size));
|
||||
|
||||
// Step 1: Create original descriptor (runtime)
|
||||
auto desc_orig = make_naive_tensor_descriptor(
|
||||
make_tuple(orig_size),
|
||||
make_tuple(1)
|
||||
);
|
||||
auto desc_orig = make_naive_tensor_descriptor(make_tuple(orig_size), make_tuple(1));
|
||||
|
||||
// Step 2: Apply padding transform
|
||||
index_t pad_amount = padded_size - orig_size;
|
||||
auto desc_padded = transform_tensor_descriptor(
|
||||
desc_orig,
|
||||
make_tuple(make_right_pad_transform(orig_size, pad_amount)),
|
||||
make_tuple(sequence<0>{}),
|
||||
make_tuple(sequence<0>{})
|
||||
);
|
||||
auto desc_padded =
|
||||
transform_tensor_descriptor(desc_orig,
|
||||
make_tuple(make_right_pad_transform(orig_size, pad_amount)),
|
||||
make_tuple(sequence<0>{}),
|
||||
make_tuple(sequence<0>{}));
|
||||
|
||||
auto tensor_simple = make_tensor_view<address_space_enum::global>(
|
||||
p_data,
|
||||
desc_padded
|
||||
);
|
||||
auto tensor_simple = make_tensor_view<address_space_enum::global>(p_data, desc_padded);
|
||||
|
||||
printf("Created tensor_view (simple API, no identity value)\n");
|
||||
printf(" - Padded reads will wrap around to existing data\n\n");
|
||||
@@ -63,26 +56,28 @@ struct PaddingTileKernel
|
||||
// Step 5: Read tiles using get_vectorized_elements
|
||||
constexpr index_t tile_size = 8;
|
||||
|
||||
printf("Reading tiles of size %ld using get_vectorized_elements:\n\n",
|
||||
printf("Reading tiles of size %ld using get_vectorized_elements:\n\n",
|
||||
static_cast<long>(tile_size));
|
||||
|
||||
// Load tiles covering the entire padded range
|
||||
index_t num_tiles = (padded_size + tile_size - 1) / tile_size;
|
||||
|
||||
for(index_t tile_idx = 0; tile_idx < num_tiles; tile_idx++) {
|
||||
|
||||
for(index_t tile_idx = 0; tile_idx < num_tiles; tile_idx++)
|
||||
{
|
||||
// Use get_vectorized_elements directly on tensor_view
|
||||
printf("Tile %ld (indices %ld-%ld):\n",
|
||||
static_cast<long>(tile_idx),
|
||||
static_cast<long>(tile_idx * tile_size),
|
||||
static_cast<long>(tile_idx * tile_size + tile_size - 1));
|
||||
|
||||
|
||||
printf(" Values: ");
|
||||
// Use static_for to access elements with compile-time indices
|
||||
static_for<0, tile_size, 1>{}([&](auto i) {
|
||||
index_t global_idx = tile_idx * tile_size + i;
|
||||
auto coord = make_tensor_coordinate(desc_padded, make_tuple(global_idx));
|
||||
auto buffer = tensor_simple.template get_vectorized_elements<
|
||||
thread_buffer<DataType, 1>>(coord, 0);
|
||||
auto coord = make_tensor_coordinate(desc_padded, make_tuple(global_idx));
|
||||
auto buffer =
|
||||
tensor_simple.template get_vectorized_elements<thread_buffer<DataType, 1>>(
|
||||
coord, 0);
|
||||
// static_for<0, 4, 1>{}([&](auto j) {
|
||||
// DataType val = buffer[number<j>{}];
|
||||
// printf("%.1f ", static_cast<float>(val));
|
||||
@@ -91,11 +86,12 @@ struct PaddingTileKernel
|
||||
printf("%.1f ", static_cast<float>(val));
|
||||
});
|
||||
printf("\n");
|
||||
|
||||
|
||||
// Check if this tile contains padding
|
||||
index_t tile_start = tile_idx * tile_size;
|
||||
index_t tile_end = tile_start + tile_size;
|
||||
if(tile_end > orig_size) {
|
||||
index_t tile_end = tile_start + tile_size;
|
||||
if(tile_end > orig_size)
|
||||
{
|
||||
printf(" Note: Elements %ld-%ld are padded (return identity value 0.0)\n",
|
||||
static_cast<long>(orig_size - tile_start),
|
||||
static_cast<long>(tile_size - 1));
|
||||
@@ -108,7 +104,6 @@ struct PaddingTileKernel
|
||||
printf(" - Out-of-bounds accesses return identity value (0.0)\n");
|
||||
printf(" - get_vectorized_elements properly handles padding\n");
|
||||
printf(" - This is the pattern used in pooling/convolution kernels\n\n");
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
@@ -120,7 +115,8 @@ int main()
|
||||
|
||||
int device_count;
|
||||
hip_check_error(hipGetDeviceCount(&device_count));
|
||||
if(device_count == 0) {
|
||||
if(device_count == 0)
|
||||
{
|
||||
std::cerr << "No GPU devices found!\n";
|
||||
return 1;
|
||||
}
|
||||
@@ -131,14 +127,15 @@ int main()
|
||||
std::cout << "Using GPU: " << props.name << "\n";
|
||||
|
||||
// Create test data: 10 real elements
|
||||
constexpr index_t orig_size = 10;
|
||||
constexpr index_t orig_size = 10;
|
||||
constexpr index_t padded_size = 16;
|
||||
|
||||
|
||||
std::vector<float> h_data(orig_size);
|
||||
std::iota(h_data.begin(), h_data.end(), 1.0f); // 1, 2, 3, ..., 10
|
||||
std::iota(h_data.begin(), h_data.end(), 1.0f); // 1, 2, 3, ..., 10
|
||||
|
||||
std::cout << "\nTest data (" << orig_size << " elements): ";
|
||||
for(auto val : h_data) {
|
||||
for(auto val : h_data)
|
||||
{
|
||||
std::cout << val << " ";
|
||||
}
|
||||
std::cout << "\n";
|
||||
@@ -154,14 +151,13 @@ int main()
|
||||
std::cout << "=====================================\n";
|
||||
|
||||
launch_kernel(stream,
|
||||
make_kernel<block_size>(
|
||||
PaddingTileKernel<float>{},
|
||||
dim3(1),
|
||||
dim3(block_size),
|
||||
0,
|
||||
static_cast<const float*>(d_data.GetDeviceBuffer()),
|
||||
orig_size,
|
||||
padded_size));
|
||||
make_kernel<block_size>(PaddingTileKernel<float>{},
|
||||
dim3(1),
|
||||
dim3(block_size),
|
||||
0,
|
||||
static_cast<const float*>(d_data.GetDeviceBuffer()),
|
||||
orig_size,
|
||||
padded_size));
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
std::cout << "=====================================\n";
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
template<typename DataType>
|
||||
template <typename DataType>
|
||||
struct DescriptorVsAdaptorKernel
|
||||
{
|
||||
static constexpr index_t kBlockSize = 64;
|
||||
@@ -38,24 +38,23 @@ struct DescriptorVsAdaptorKernel
|
||||
printf("Example 1.1: Matrix Tiling [M, K] -> [M0, M1, K]\n");
|
||||
printf("------------------------------------------------\n");
|
||||
{
|
||||
constexpr index_t M = 128;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t M = 128;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t M0 = 4;
|
||||
constexpr index_t M1 = 32;
|
||||
|
||||
printf("Input: [M=%ld, K=%ld]\n", static_cast<long>(M), static_cast<long>(K));
|
||||
printf("Output: [M0=%ld, M1=%ld, K=%ld]\n",
|
||||
static_cast<long>(M0), static_cast<long>(M1), static_cast<long>(K));
|
||||
printf("Output: [M0=%ld, M1=%ld, K=%ld]\n",
|
||||
static_cast<long>(M0),
|
||||
static_cast<long>(M1),
|
||||
static_cast<long>(K));
|
||||
|
||||
// Create adaptor - only transformation logic
|
||||
auto adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
|
||||
make_pass_through_transform(number<K>{})
|
||||
),
|
||||
make_tuple(make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
|
||||
make_pass_through_transform(number<K>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{})
|
||||
);
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
printf("\nAdaptor properties:\n");
|
||||
printf(" - Stores: Transformation logic only\n");
|
||||
@@ -64,9 +63,9 @@ struct DescriptorVsAdaptorKernel
|
||||
printf(" - Cannot do: Calculate memory offsets\n\n");
|
||||
|
||||
// Test coordinate mapping
|
||||
auto top_idx = make_tuple(2, 16, 32);
|
||||
auto top_idx = make_tuple(2, 16, 32);
|
||||
auto bottom_idx = adaptor.calculate_bottom_index(top_idx);
|
||||
|
||||
|
||||
printf("Coordinate mapping test:\n");
|
||||
printf(" Input: [M0=%ld, M1=%ld, K=%ld]\n",
|
||||
static_cast<long>(top_idx.template get<0>()),
|
||||
@@ -90,23 +89,18 @@ struct DescriptorVsAdaptorKernel
|
||||
// Define a generic 2D tiling pattern
|
||||
auto create_tiling_adaptor = [](auto M0, auto M1, auto N0, auto N1) {
|
||||
return make_single_stage_tensor_adaptor(
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(M0, M1)),
|
||||
make_unmerge_transform(make_tuple(N0, N1))
|
||||
),
|
||||
make_tuple(make_unmerge_transform(make_tuple(M0, M1)),
|
||||
make_unmerge_transform(make_tuple(N0, N1))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3>{})
|
||||
);
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}));
|
||||
};
|
||||
|
||||
// Use same pattern for different matrix sizes
|
||||
[[maybe_unused]] auto adaptor_64x64 = create_tiling_adaptor(
|
||||
number<4>{}, number<16>{}, number<4>{}, number<16>{}
|
||||
);
|
||||
|
||||
[[maybe_unused]] auto adaptor_128x128 = create_tiling_adaptor(
|
||||
number<8>{}, number<16>{}, number<8>{}, number<16>{}
|
||||
);
|
||||
[[maybe_unused]] auto adaptor_64x64 =
|
||||
create_tiling_adaptor(number<4>{}, number<16>{}, number<4>{}, number<16>{});
|
||||
|
||||
[[maybe_unused]] auto adaptor_128x128 =
|
||||
create_tiling_adaptor(number<8>{}, number<16>{}, number<8>{}, number<16>{});
|
||||
|
||||
printf("Created two adaptors with same pattern:\n");
|
||||
printf(" - 64x64 matrix: [64, 64] -> [4, 16, 4, 16]\n");
|
||||
@@ -133,12 +127,11 @@ struct DescriptorVsAdaptorKernel
|
||||
constexpr index_t K = 64;
|
||||
|
||||
// Create descriptor - includes memory information
|
||||
auto desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(number<M>{}, number<K>{})
|
||||
);
|
||||
auto desc = make_naive_tensor_descriptor_packed(make_tuple(number<M>{}, number<K>{}));
|
||||
|
||||
printf("Created descriptor for [M=%ld, K=%ld] matrix\n",
|
||||
static_cast<long>(M), static_cast<long>(K));
|
||||
static_cast<long>(M),
|
||||
static_cast<long>(K));
|
||||
|
||||
auto space_size = desc.get_element_space_size();
|
||||
printf("\nDescriptor properties:\n");
|
||||
@@ -150,13 +143,15 @@ struct DescriptorVsAdaptorKernel
|
||||
// Calculate memory offset
|
||||
auto offset1 = desc.calculate_offset(make_tuple(10, 20));
|
||||
auto offset2 = desc.calculate_offset(make_tuple(0, 0));
|
||||
auto offset3 = desc.calculate_offset(make_tuple(M-1, K-1));
|
||||
auto offset3 = desc.calculate_offset(make_tuple(M - 1, K - 1));
|
||||
|
||||
printf("Memory offset calculations:\n");
|
||||
printf(" [10, 20] -> offset %ld (10*64 + 20)\n", static_cast<long>(offset1));
|
||||
printf(" [0, 0] -> offset %ld (first element)\n", static_cast<long>(offset2));
|
||||
printf(" [%ld, %ld] -> offset %ld (last element)\n",
|
||||
static_cast<long>(M-1), static_cast<long>(K-1), static_cast<long>(offset3));
|
||||
static_cast<long>(M - 1),
|
||||
static_cast<long>(K - 1),
|
||||
static_cast<long>(offset3));
|
||||
}
|
||||
|
||||
printf("\n");
|
||||
@@ -165,32 +160,30 @@ struct DescriptorVsAdaptorKernel
|
||||
printf("Example 2.2: Transforming a Descriptor\n");
|
||||
printf("---------------------------------------\n");
|
||||
{
|
||||
constexpr index_t M = 256;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t M = 256;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t M0 = 4;
|
||||
constexpr index_t M1 = 64;
|
||||
|
||||
printf("Step 1: Create initial descriptor\n");
|
||||
auto desc_initial = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(number<M>{}, number<K>{})
|
||||
);
|
||||
auto desc_initial =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(number<M>{}, number<K>{}));
|
||||
printf(" Initial: [M=%ld, K=%ld]\n", static_cast<long>(M), static_cast<long>(K));
|
||||
printf(" Memory size: %ld elements\n\n",
|
||||
printf(" Memory size: %ld elements\n\n",
|
||||
static_cast<long>(desc_initial.get_element_space_size()));
|
||||
|
||||
printf("Step 2: Transform to add tiling\n");
|
||||
auto desc_tiled = transform_tensor_descriptor(
|
||||
desc_initial,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
|
||||
make_pass_through_transform(number<K>{})
|
||||
),
|
||||
make_tuple(make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
|
||||
make_pass_through_transform(number<K>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{})
|
||||
);
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
printf(" Transformed: [M, K] -> [M0=%ld, M1=%ld, K=%ld]\n",
|
||||
static_cast<long>(M0), static_cast<long>(M1), static_cast<long>(K));
|
||||
static_cast<long>(M0),
|
||||
static_cast<long>(M1),
|
||||
static_cast<long>(K));
|
||||
printf(" Memory size preserved: %ld elements\n\n",
|
||||
static_cast<long>(desc_tiled.get_element_space_size()));
|
||||
|
||||
@@ -217,9 +210,7 @@ struct DescriptorVsAdaptorKernel
|
||||
constexpr index_t M = 64;
|
||||
constexpr index_t K = 32;
|
||||
|
||||
auto desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(number<M>{}, number<K>{})
|
||||
);
|
||||
auto desc = make_naive_tensor_descriptor_packed(make_tuple(number<M>{}, number<K>{}));
|
||||
|
||||
printf("Descriptor: [M=%ld, K=%ld]\n\n", static_cast<long>(M), static_cast<long>(K));
|
||||
|
||||
@@ -243,9 +234,7 @@ struct DescriptorVsAdaptorKernel
|
||||
constexpr index_t M = 64;
|
||||
constexpr index_t K = 32;
|
||||
|
||||
auto desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(number<M>{}, number<K>{})
|
||||
);
|
||||
auto desc = make_naive_tensor_descriptor_packed(make_tuple(number<M>{}, number<K>{}));
|
||||
|
||||
printf("Scenario: Iterate through a row of tiles\n");
|
||||
printf("Descriptor: [M=%ld, K=%ld]\n\n", static_cast<long>(M), static_cast<long>(K));
|
||||
@@ -286,28 +275,25 @@ struct DescriptorVsAdaptorKernel
|
||||
printf("Example 3.3: Moving Coordinates with Transformations\n");
|
||||
printf("----------------------------------------------------\n");
|
||||
{
|
||||
constexpr index_t M = 128;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t M = 128;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t M0 = 4;
|
||||
constexpr index_t M1 = 32;
|
||||
|
||||
// Create tiled descriptor
|
||||
auto desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(number<M>{}, number<K>{})
|
||||
);
|
||||
auto desc = make_naive_tensor_descriptor_packed(make_tuple(number<M>{}, number<K>{}));
|
||||
|
||||
auto desc_tiled = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
|
||||
make_pass_through_transform(number<K>{})
|
||||
),
|
||||
make_tuple(make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
|
||||
make_pass_through_transform(number<K>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{})
|
||||
);
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
printf("Tiled descriptor: [M, K] -> [M0=%ld, M1=%ld, K=%ld]\n\n",
|
||||
static_cast<long>(M0), static_cast<long>(M1), static_cast<long>(K));
|
||||
static_cast<long>(M0),
|
||||
static_cast<long>(M1),
|
||||
static_cast<long>(K));
|
||||
|
||||
// Create coordinate
|
||||
auto coord = make_tensor_coordinate(desc_tiled, make_tuple(1, 8, 16));
|
||||
@@ -360,7 +346,8 @@ struct DescriptorVsAdaptorKernel
|
||||
|
||||
CK_TILE_DEVICE void operator()() const
|
||||
{
|
||||
if(get_thread_id() != 0) return;
|
||||
if(get_thread_id() != 0)
|
||||
return;
|
||||
|
||||
printf("\n=== TENSOR DESCRIPTOR VS TENSOR ADAPTOR ===\n\n");
|
||||
|
||||
@@ -400,7 +387,8 @@ int main()
|
||||
|
||||
int device_count;
|
||||
hip_check_error(hipGetDeviceCount(&device_count));
|
||||
if(device_count == 0) {
|
||||
if(device_count == 0)
|
||||
{
|
||||
std::cerr << "No GPU devices found!\n";
|
||||
return 1;
|
||||
}
|
||||
@@ -416,12 +404,9 @@ int main()
|
||||
std::cout << "\nLaunching kernel...\n";
|
||||
std::cout << "=====================================\n";
|
||||
|
||||
launch_kernel(stream,
|
||||
make_kernel<block_size>(
|
||||
DescriptorVsAdaptorKernel<float>{},
|
||||
dim3(1),
|
||||
dim3(block_size),
|
||||
0));
|
||||
launch_kernel(
|
||||
stream,
|
||||
make_kernel<block_size>(DescriptorVsAdaptorKernel<float>{}, dim3(1), dim3(block_size), 0));
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
std::cout << "=====================================\n";
|
||||
|
||||
@@ -26,16 +26,16 @@
|
||||
using namespace ck_tile;
|
||||
|
||||
// Distributed HGEMM kernel using proper tile_distribution
|
||||
template<typename ADataType, typename BDataType, typename CDataType, typename AccDataType>
|
||||
template <typename ADataType, typename BDataType, typename CDataType, typename AccDataType>
|
||||
struct DistributedHgemmKernel
|
||||
{
|
||||
static constexpr index_t kWaveSize = 64; // AMD wave size
|
||||
static constexpr index_t kBlockM = 16; // MFMA M dimension
|
||||
static constexpr index_t kBlockN = 16; // MFMA N dimension
|
||||
static constexpr index_t kBlockK = 16; // MFMA K dimension per instruction
|
||||
static constexpr index_t kWaveSize = 64; // AMD wave size
|
||||
static constexpr index_t kBlockM = 16; // MFMA M dimension
|
||||
static constexpr index_t kBlockN = 16; // MFMA N dimension
|
||||
static constexpr index_t kBlockK = 16; // MFMA K dimension per instruction
|
||||
|
||||
// Use ck_tile's WarpGemm for MFMA
|
||||
using WarpGemm = WarpGemmMfmaF16F16F32M16N16K16;
|
||||
using WarpGemm = WarpGemmMfmaF16F16F32M16N16K16;
|
||||
static constexpr index_t kBlockSize = kWaveSize;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const ADataType* a,
|
||||
@@ -45,18 +45,19 @@ struct DistributedHgemmKernel
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t lda, // Leading dimension of A (column-major)
|
||||
index_t ldb, // Leading dimension of B (row-major)
|
||||
index_t ldc, // Leading dimension of C (column-major)
|
||||
index_t ldd, // Leading dimension of D (column-major)
|
||||
index_t lda, // Leading dimension of A (column-major)
|
||||
index_t ldb, // Leading dimension of B (row-major)
|
||||
index_t ldc, // Leading dimension of C (column-major)
|
||||
index_t ldd, // Leading dimension of D (column-major)
|
||||
AccDataType alpha,
|
||||
AccDataType beta) const
|
||||
{
|
||||
// Calculate which 16×16 block this wave computes
|
||||
// const index_t wave_id = get_block_id() * get_block_size() / kWaveSize + threadIdx.x / kWaveSize;
|
||||
// const index_t wave_id = get_block_id() * get_block_size() / kWaveSize + threadIdx.x /
|
||||
// kWaveSize;
|
||||
const index_t wave_id = get_warp_id();
|
||||
const index_t wave_m = wave_id / (N / kBlockN);
|
||||
const index_t wave_n = wave_id % (N / kBlockN);
|
||||
const index_t wave_m = wave_id / (N / kBlockN);
|
||||
const index_t wave_n = wave_id % (N / kBlockN);
|
||||
|
||||
const index_t m_offset = wave_m * kBlockM;
|
||||
const index_t n_offset = wave_n * kBlockN;
|
||||
@@ -73,89 +74,81 @@ struct DistributedHgemmKernel
|
||||
// A is column-major: M×K with stride lda between columns
|
||||
const auto a_tensor = make_naive_tensor_view<address_space_enum::global>(
|
||||
a,
|
||||
make_tuple(M, K), // Shape: M×K
|
||||
make_tuple(1, lda), // Strides: column-major
|
||||
make_tuple(M, K), // Shape: M×K
|
||||
make_tuple(1, lda), // Strides: column-major
|
||||
number<1>{},
|
||||
number<1>{}
|
||||
);
|
||||
number<1>{});
|
||||
|
||||
// B is row-major: K×N with stride ldb between rows
|
||||
const auto b_tensor = make_naive_tensor_view<address_space_enum::global>(
|
||||
b,
|
||||
make_tuple(K, N), // Shape: K×N
|
||||
make_tuple(ldb, 1), // Strides: row-major
|
||||
make_tuple(K, N), // Shape: K×N
|
||||
make_tuple(ldb, 1), // Strides: row-major
|
||||
number<4>{},
|
||||
number<1>{}
|
||||
);
|
||||
number<1>{});
|
||||
|
||||
// C is column-major: M×N with stride ldc between columns
|
||||
const auto c_tensor = make_naive_tensor_view<address_space_enum::global>(
|
||||
c,
|
||||
make_tuple(M, N), // Shape: M×N
|
||||
make_tuple(1, ldc), // Strides: column-major
|
||||
make_tuple(M, N), // Shape: M×N
|
||||
make_tuple(1, ldc), // Strides: column-major
|
||||
number<1>{},
|
||||
number<1>{}
|
||||
);
|
||||
number<1>{});
|
||||
|
||||
// D is column-major: M×N with stride ldd between columns
|
||||
auto d_tensor = make_naive_tensor_view<address_space_enum::global>(
|
||||
d,
|
||||
make_tuple(M, N), // Shape: M×N
|
||||
make_tuple(1, ldd), // Strides: column-major
|
||||
make_tuple(M, N), // Shape: M×N
|
||||
make_tuple(1, ldd), // Strides: column-major
|
||||
number<1>{},
|
||||
number<1>{}
|
||||
);
|
||||
number<1>{});
|
||||
|
||||
// Use our tested custom distributions from test_a_distribution.cpp and test_b_distribution.cpp
|
||||
// A: Column-major M×K with each thread loading 4 consecutive K values from one M position
|
||||
// Use our tested custom distributions from test_a_distribution.cpp and
|
||||
// test_b_distribution.cpp.
|
||||
// A: Column-major M×K with each thread loading 4 consecutive K values from one M position
|
||||
constexpr auto a_distribution = make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>, // No replication
|
||||
tuple<sequence<16>, // H0 (M): 16 lanes for M
|
||||
sequence<4, 4>>, // H1 (K): 4 lanes × 4 per lane
|
||||
tuple<sequence<2, 1>>, // P-dims map to H-dims
|
||||
tuple<sequence<0, 0>>, // P positions in H-dims
|
||||
sequence<2>, // Y maps to K dimension only
|
||||
sequence<1>>{} // Y at position 1
|
||||
tile_distribution_encoding<sequence<>, // No replication
|
||||
tuple<sequence<16>, // H0 (M): 16 lanes for M
|
||||
sequence<4, 4>>, // H1 (K): 4 lanes × 4 per lane
|
||||
tuple<sequence<2, 1>>, // P-dims map to H-dims
|
||||
tuple<sequence<0, 0>>, // P positions in H-dims
|
||||
sequence<2>, // Y maps to K dimension only
|
||||
sequence<1>>{} // Y at position 1
|
||||
);
|
||||
|
||||
// B: Row-major K×N with each thread loading 4 consecutive K values from one N position
|
||||
constexpr auto b_distribution = make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>, // No replication
|
||||
tuple<sequence<4, 4>, // H0 (K): 4 groups of 4 consecutive K values
|
||||
sequence<16>>, // H1 (N): 16 N positions
|
||||
tuple<sequence<1, 2>>, // P-dims map to H-dims (P0->H1, P1->H0)
|
||||
tuple<sequence<0, 0>>, // P positions in H-dims
|
||||
sequence<1>, // Y maps to K dimension (H0)
|
||||
sequence<1>>{} // Y at position 1 in H0 (the second 4 in sequence<4,4>)
|
||||
sequence<>, // No replication
|
||||
tuple<sequence<4, 4>, // H0 (K): 4 groups of 4 consecutive K values
|
||||
sequence<16>>, // H1 (N): 16 N positions
|
||||
tuple<sequence<1, 2>>, // P-dims map to H-dims (P0->H1, P1->H0)
|
||||
tuple<sequence<0, 0>>, // P positions in H-dims
|
||||
sequence<1>, // Y maps to K dimension (H0)
|
||||
sequence<1>>{} // Y at position 1 in H0 (the second 4 in sequence<4,4>)
|
||||
);
|
||||
|
||||
// Create windows for A and B that we'll move along K
|
||||
auto a_window = make_tile_window(
|
||||
a_tensor,
|
||||
make_tuple(number<kBlockM>{}, number<kBlockK>{}),
|
||||
{m_offset, 0},
|
||||
a_distribution
|
||||
);
|
||||
auto a_window = make_tile_window(a_tensor,
|
||||
make_tuple(number<kBlockM>{}, number<kBlockK>{}),
|
||||
{m_offset, 0},
|
||||
a_distribution);
|
||||
|
||||
auto b_window = make_tile_window(
|
||||
b_tensor,
|
||||
make_tuple(number<kBlockK>{}, number<kBlockN>{}),
|
||||
{0, n_offset},
|
||||
b_distribution
|
||||
);
|
||||
auto b_window = make_tile_window(b_tensor,
|
||||
make_tuple(number<kBlockK>{}, number<kBlockN>{}),
|
||||
{0, n_offset},
|
||||
b_distribution);
|
||||
|
||||
// C distribution (column-major M×N output) - tested in test_c_distribution.cpp
|
||||
constexpr auto c_distribution = make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>, // No replication
|
||||
tuple<sequence<4, 4>, // H0 (M): 4 groups of 4 consecutive M values
|
||||
sequence<16>>, // H1 (N): 16 N positions
|
||||
tuple<sequence<1, 2>>, // P-dims map to H-dims (P0->H1, P1->H0)
|
||||
tuple<sequence<0, 0>>, // P positions in H-dims
|
||||
sequence<1>, // Y maps to M dimension (H0)
|
||||
sequence<1>>{} // Y at position 1 in H0 (the second 4 in sequence<4,4>)
|
||||
sequence<>, // No replication
|
||||
tuple<sequence<4, 4>, // H0 (M): 4 groups of 4 consecutive M values
|
||||
sequence<16>>, // H1 (N): 16 N positions
|
||||
tuple<sequence<1, 2>>, // P-dims map to H-dims (P0->H1, P1->H0)
|
||||
tuple<sequence<0, 0>>, // P positions in H-dims
|
||||
sequence<1>, // Y maps to M dimension (H0)
|
||||
sequence<1>>{} // Y at position 1 in H0 (the second 4 in sequence<4,4>)
|
||||
);
|
||||
|
||||
// Create accumulator using our tested C distribution
|
||||
@@ -178,9 +171,10 @@ struct DistributedHgemmKernel
|
||||
|
||||
// Move windows to next K chunk using the move API
|
||||
// This efficiently updates window_origin_ without recreating the window
|
||||
if(k_iter < num_k_loops - 1) {
|
||||
a_window.move({0, kBlockK}); // Move K forward for A
|
||||
b_window.move({kBlockK, 0}); // Move K forward for B
|
||||
if(k_iter < num_k_loops - 1)
|
||||
{
|
||||
a_window.move({0, kBlockK}); // Move K forward for A
|
||||
b_window.move({kBlockK, 0}); // Move K forward for B
|
||||
}
|
||||
}
|
||||
|
||||
@@ -191,57 +185,60 @@ struct DistributedHgemmKernel
|
||||
// Load C, apply beta, and add to result
|
||||
if(std::abs(beta) > 1e-6f)
|
||||
{
|
||||
auto c_window = make_tile_window(
|
||||
c_tensor,
|
||||
make_tuple(number<kBlockM>{}, number<kBlockN>{}),
|
||||
{m_offset, n_offset},
|
||||
c_distribution
|
||||
);
|
||||
auto c_window = make_tile_window(c_tensor,
|
||||
make_tuple(number<kBlockM>{}, number<kBlockN>{}),
|
||||
{m_offset, n_offset},
|
||||
c_distribution);
|
||||
|
||||
const auto c_tile = load_tile(c_window);
|
||||
|
||||
// Apply beta * C + acc using ck_tile's elementwise API
|
||||
// This combines two tiles with a lambda function
|
||||
tile_elementwise_inout(
|
||||
[beta](const auto& c_val, auto& acc_val) {
|
||||
acc_val += beta * c_val;
|
||||
},
|
||||
c_tile, acc_tile);
|
||||
[beta](const auto& c_val, auto& acc_val) { acc_val += beta * c_val; },
|
||||
c_tile,
|
||||
acc_tile);
|
||||
}
|
||||
|
||||
// Store final result to D
|
||||
auto d_window = make_tile_window(
|
||||
d_tensor,
|
||||
make_tuple(number<kBlockM>{}, number<kBlockN>{}),
|
||||
{m_offset, n_offset},
|
||||
c_distribution
|
||||
);
|
||||
auto d_window = make_tile_window(d_tensor,
|
||||
make_tuple(number<kBlockM>{}, number<kBlockN>{}),
|
||||
{m_offset, n_offset},
|
||||
c_distribution);
|
||||
|
||||
store_tile(d_window, acc_tile);
|
||||
}
|
||||
};
|
||||
|
||||
// CPU reference for verification
|
||||
template<typename InType, typename AccType>
|
||||
void reference_gemm_mixed(const std::vector<InType>& a, // Column-major
|
||||
const std::vector<InType>& b, // Row-major
|
||||
const std::vector<AccType>& c, // Column-major
|
||||
std::vector<AccType>& d, // Column-major
|
||||
index_t M, index_t N, index_t K,
|
||||
index_t lda, index_t ldb, index_t ldc, index_t ldd,
|
||||
AccType alpha, AccType beta)
|
||||
template <typename InType, typename AccType>
|
||||
void reference_gemm_mixed(const std::vector<InType>& a, // Column-major
|
||||
const std::vector<InType>& b, // Row-major
|
||||
const std::vector<AccType>& c, // Column-major
|
||||
std::vector<AccType>& d, // Column-major
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t lda,
|
||||
index_t ldb,
|
||||
index_t ldc,
|
||||
index_t ldd,
|
||||
AccType alpha,
|
||||
AccType beta)
|
||||
{
|
||||
// D = alpha * A * B + beta * C
|
||||
for(index_t n = 0; n < N; ++n) {
|
||||
for(index_t m = 0; m < M; ++m) {
|
||||
for(index_t n = 0; n < N; ++n)
|
||||
{
|
||||
for(index_t m = 0; m < M; ++m)
|
||||
{
|
||||
AccType sum = 0;
|
||||
|
||||
// Compute A * B
|
||||
for(index_t k = 0; k < K; ++k) {
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
// A is column-major: A[m,k] = a[m + k*lda]
|
||||
// B is row-major: B[k,n] = b[k*ldb + n]
|
||||
sum += static_cast<AccType>(a[m + k * lda]) *
|
||||
static_cast<AccType>(b[k * ldb + n]);
|
||||
sum += static_cast<AccType>(a[m + k * lda]) * static_cast<AccType>(b[k * ldb + n]);
|
||||
}
|
||||
|
||||
// D[m,n] = alpha * sum + beta * C[m,n]
|
||||
@@ -252,30 +249,38 @@ void reference_gemm_mixed(const std::vector<InType>& a, // Column-major
|
||||
}
|
||||
|
||||
// Helper to fill matrix with random values
|
||||
template<typename T>
|
||||
template <typename T>
|
||||
void fill_random(std::vector<T>& data, T min_val = -1, T max_val = 1)
|
||||
{
|
||||
for(auto& val : data) {
|
||||
val = static_cast<T>(min_val + (max_val - min_val) *
|
||||
static_cast<float>(rand()) / RAND_MAX);
|
||||
for(auto& val : data)
|
||||
{
|
||||
val = static_cast<T>(min_val + (max_val - min_val) * static_cast<float>(rand()) / RAND_MAX);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to print matrix (for debugging)
|
||||
template<typename T>
|
||||
void print_matrix(const std::vector<T>& mat, index_t rows, index_t cols,
|
||||
index_t ld, bool col_major = true, const std::string& name = "Matrix")
|
||||
template <typename T>
|
||||
void print_matrix(const std::vector<T>& mat,
|
||||
index_t rows,
|
||||
index_t cols,
|
||||
index_t ld,
|
||||
bool col_major = true,
|
||||
const std::string& name = "Matrix")
|
||||
{
|
||||
std::cout << name << " (" << rows << "×" << cols << "):\n";
|
||||
for(index_t i = 0; i < std::min(rows, index_t(8)); ++i) {
|
||||
for(index_t j = 0; j < std::min(cols, index_t(8)); ++j) {
|
||||
for(index_t i = 0; i < std::min(rows, index_t(8)); ++i)
|
||||
{
|
||||
for(index_t j = 0; j < std::min(cols, index_t(8)); ++j)
|
||||
{
|
||||
index_t idx = col_major ? (i + j * ld) : (i * ld + j);
|
||||
std::cout << std::setw(8) << std::setprecision(3) << mat[idx] << " ";
|
||||
}
|
||||
if(cols > 8) std::cout << "...";
|
||||
if(cols > 8)
|
||||
std::cout << "...";
|
||||
std::cout << "\n";
|
||||
}
|
||||
if(rows > 8) std::cout << "...\n";
|
||||
if(rows > 8)
|
||||
std::cout << "...\n";
|
||||
std::cout << "\n";
|
||||
}
|
||||
|
||||
@@ -298,16 +303,16 @@ int main()
|
||||
constexpr index_t K = 64;
|
||||
|
||||
// Leading dimensions
|
||||
constexpr index_t lda = M; // Column-major
|
||||
constexpr index_t ldb = N; // Row-major
|
||||
constexpr index_t ldc = M; // Column-major
|
||||
constexpr index_t ldd = M; // Column-major
|
||||
constexpr index_t lda = M; // Column-major
|
||||
constexpr index_t ldb = N; // Row-major
|
||||
constexpr index_t ldc = M; // Column-major
|
||||
constexpr index_t ldd = M; // Column-major
|
||||
|
||||
using InputType = half_t; // fp16
|
||||
using AccumType = float; // fp32
|
||||
using InputType = half_t; // fp16
|
||||
using AccumType = float; // fp32
|
||||
|
||||
constexpr AccumType alpha = 2.0f;
|
||||
constexpr AccumType beta = 1.5f;
|
||||
constexpr AccumType beta = 1.5f;
|
||||
|
||||
std::cout << "Problem configuration:\n";
|
||||
std::cout << " M×N×K: " << M << "×" << N << "×" << K << "\n";
|
||||
@@ -315,7 +320,7 @@ int main()
|
||||
std::cout << " B: row-major, ldb=" << ldb << " (fp16)\n";
|
||||
std::cout << " C/D: column-major, ldc=" << ldc << ", ldd=" << ldd << " (fp32)\n";
|
||||
std::cout << " alpha=" << alpha << ", beta=" << beta << "\n";
|
||||
std::cout << " Total FLOPs: " << 2*M*N*K << "\n\n";
|
||||
std::cout << " Total FLOPs: " << 2 * M * N * K << "\n\n";
|
||||
|
||||
// Host memory
|
||||
std::vector<InputType> h_a(M * K);
|
||||
@@ -333,7 +338,7 @@ int main()
|
||||
// CPU reference
|
||||
auto cpu_start = std::chrono::high_resolution_clock::now();
|
||||
reference_gemm_mixed(h_a, h_b, h_c, h_d_ref, M, N, K, lda, ldb, ldc, ldd, alpha, beta);
|
||||
auto cpu_end = std::chrono::high_resolution_clock::now();
|
||||
auto cpu_end = std::chrono::high_resolution_clock::now();
|
||||
double cpu_time_ms = std::chrono::duration<double, std::milli>(cpu_end - cpu_start).count();
|
||||
|
||||
// Device memory
|
||||
@@ -348,30 +353,39 @@ int main()
|
||||
d_d.ToDevice(h_d.data(), M * N * sizeof(AccumType));
|
||||
|
||||
// Launch kernel
|
||||
constexpr index_t block_size = 64; // One wave
|
||||
const index_t grid_size = (M / 16) * (N / 16); // One wave per 16×16 output block
|
||||
constexpr index_t block_size = 64; // One wave
|
||||
const index_t grid_size = (M / 16) * (N / 16); // One wave per 16×16 output block
|
||||
|
||||
std::cout << "Launching kernel:\n";
|
||||
std::cout << " Grid: " << grid_size << " blocks\n";
|
||||
std::cout << " Block: " << block_size << " threads (1 wave)\n";
|
||||
std::cout << " Output blocks: " << (M/16) << "×" << (N/16) << " = " << grid_size << "\n";
|
||||
std::cout << " MFMA instructions per block: " << K/16 << "\n\n";
|
||||
std::cout << " Output blocks: " << (M / 16) << "×" << (N / 16) << " = " << grid_size << "\n";
|
||||
std::cout << " MFMA instructions per block: " << K / 16 << "\n\n";
|
||||
|
||||
stream_config stream;
|
||||
|
||||
// Warmup
|
||||
for(int i = 0; i < 5; ++i) {
|
||||
for(int i = 0; i < 5; ++i)
|
||||
{
|
||||
launch_kernel(stream,
|
||||
make_kernel<block_size>(
|
||||
DistributedHgemmKernel<InputType, InputType, AccumType, AccumType>{},
|
||||
dim3(grid_size),
|
||||
dim3(block_size),
|
||||
0,
|
||||
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
|
||||
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
|
||||
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
|
||||
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
|
||||
M, N, K, lda, ldb, ldc, ldd, alpha, beta));
|
||||
make_kernel<block_size>(
|
||||
DistributedHgemmKernel<InputType, InputType, AccumType, AccumType>{},
|
||||
dim3(grid_size),
|
||||
dim3(block_size),
|
||||
0,
|
||||
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
|
||||
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
|
||||
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
|
||||
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
ldd,
|
||||
alpha,
|
||||
beta));
|
||||
}
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
@@ -379,40 +393,50 @@ int main()
|
||||
auto gpu_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
launch_kernel(stream,
|
||||
make_kernel<block_size>(
|
||||
DistributedHgemmKernel<InputType, InputType, AccumType, AccumType>{},
|
||||
dim3(grid_size),
|
||||
dim3(block_size),
|
||||
0,
|
||||
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
|
||||
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
|
||||
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
|
||||
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
|
||||
M, N, K, lda, ldb, ldc, ldd, alpha, beta));
|
||||
make_kernel<block_size>(
|
||||
DistributedHgemmKernel<InputType, InputType, AccumType, AccumType>{},
|
||||
dim3(grid_size),
|
||||
dim3(block_size),
|
||||
0,
|
||||
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
|
||||
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
|
||||
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
|
||||
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
ldd,
|
||||
alpha,
|
||||
beta));
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
auto gpu_end = std::chrono::high_resolution_clock::now();
|
||||
auto gpu_end = std::chrono::high_resolution_clock::now();
|
||||
double gpu_time_ms = std::chrono::duration<double, std::milli>(gpu_end - gpu_start).count();
|
||||
|
||||
// Get result
|
||||
d_d.FromDevice(h_d.data(), M * N * sizeof(AccumType));
|
||||
|
||||
// Verify
|
||||
bool passed = true;
|
||||
float max_error = 0;
|
||||
bool passed = true;
|
||||
float max_error = 0;
|
||||
index_t error_count = 0;
|
||||
|
||||
for(index_t i = 0; i < M * N; ++i) {
|
||||
for(index_t i = 0; i < M * N; ++i)
|
||||
{
|
||||
float error = std::abs(h_d[i] - h_d_ref[i]);
|
||||
max_error = std::max(max_error, error);
|
||||
if(error > 1e-2f) { // Relaxed tolerance for fp16
|
||||
if(error_count < 5) {
|
||||
max_error = std::max(max_error, error);
|
||||
if(error > 1e-2f)
|
||||
{ // Relaxed tolerance for fp16
|
||||
if(error_count < 5)
|
||||
{
|
||||
index_t m = i % M;
|
||||
index_t n = i / M;
|
||||
std::cout << "Error at [" << m << "," << n << "]: "
|
||||
<< h_d[i] << " vs " << h_d_ref[i]
|
||||
<< " (diff=" << error << ")\n";
|
||||
std::cout << "Error at [" << m << "," << n << "]: " << h_d[i] << " vs "
|
||||
<< h_d_ref[i] << " (diff=" << error << ")\n";
|
||||
}
|
||||
error_count++;
|
||||
}
|
||||
@@ -421,14 +445,15 @@ int main()
|
||||
passed = (error_count == 0);
|
||||
|
||||
// Calculate performance
|
||||
double gflops = 2.0 * M * N * K / 1e9;
|
||||
double gflops = 2.0 * M * N * K / 1e9;
|
||||
double gpu_tflops = gflops / (gpu_time_ms / 1000);
|
||||
double cpu_gflops = gflops / (cpu_time_ms / 1000);
|
||||
|
||||
std::cout << "Results:\n";
|
||||
std::cout << " Correctness: " << (passed ? "✓ PASSED" : "✗ FAILED") << "\n";
|
||||
std::cout << " Max error: " << max_error << "\n";
|
||||
if(!passed) std::cout << " Error count: " << error_count << "/" << M*N << "\n";
|
||||
if(!passed)
|
||||
std::cout << " Error count: " << error_count << "/" << M * N << "\n";
|
||||
std::cout << "\n";
|
||||
|
||||
std::cout << "Performance:\n";
|
||||
|
||||
@@ -18,101 +18,94 @@
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
template<typename DataType>
|
||||
template <typename DataType>
|
||||
struct TestADistributionKernel
|
||||
{
|
||||
static constexpr index_t kBlockSize = 256; // 4 warps
|
||||
static constexpr index_t MWarp = 2;
|
||||
static constexpr index_t NWarp = 2;
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kK = 16;
|
||||
static constexpr index_t kBlockSize = 256; // 4 warps
|
||||
static constexpr index_t MWarp = 2;
|
||||
static constexpr index_t NWarp = 2;
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kK = 16;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const DataType* a,
|
||||
DataType* debug_output,
|
||||
index_t lda) const
|
||||
CK_TILE_DEVICE void operator()(const DataType* a, DataType* debug_output, index_t lda) const
|
||||
{
|
||||
if(get_block_id() != 0)
|
||||
return;
|
||||
|
||||
const index_t tid = threadIdx.x;
|
||||
const index_t tid = threadIdx.x;
|
||||
const index_t warp_id = tid / 64;
|
||||
const index_t lane_id = tid % 64;
|
||||
|
||||
// Create tensor view for A (column-major)
|
||||
const auto a_tensor = make_naive_tensor_view<address_space_enum::global>(
|
||||
a,
|
||||
make_tuple(kM, kK),
|
||||
make_tuple(1, lda),
|
||||
number<1>{},
|
||||
number<1>{}
|
||||
);
|
||||
|
||||
a, make_tuple(kM, kK), make_tuple(1, lda), number<1>{}, number<1>{});
|
||||
|
||||
// A distribution WITH NWarp replication and MWarp in H-dimension
|
||||
// Based on 02_gemm pattern: include MWarp in H-tuple
|
||||
// R: NWarp replication
|
||||
// H0: MWarp × 16 threads = 2×16 = 32 M positions
|
||||
// H0: MWarp × 16 threads = 2×16 = 32 M positions
|
||||
// H1: 4×4 = 16 K elements
|
||||
constexpr auto a_distribution = make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<NWarp>, // R: REPLICATE across 2 N-warps
|
||||
tuple<sequence<MWarp, 16>, // H0 (M): 2 M-warps × 16 threads = 32 M
|
||||
sequence<4, 4>>, // H1 (K): 4×4 = 16 K elements
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>, // Ps_to_Hs: P0→(R,M), P1→(M,K)
|
||||
tuple<sequence<0, 0>, sequence<0, 1>>, // Ps_in_Hs: positions
|
||||
sequence<2>, // Ys_to_Hs: Y maps to K (dimension 2)
|
||||
sequence<1>>{} // Ys_in_Hs: Y at position 1 in K
|
||||
sequence<NWarp>, // R: REPLICATE across 2 N-warps
|
||||
tuple<sequence<MWarp, 16>, // H0 (M): 2 M-warps × 16 threads = 32 M
|
||||
sequence<4, 4>>, // H1 (K): 4×4 = 16 K elements
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>, // Ps_to_Hs: P0→(R,M), P1→(M,K)
|
||||
tuple<sequence<0, 0>, sequence<0, 1>>, // Ps_in_Hs: positions
|
||||
sequence<2>, // Ys_to_Hs: Y maps to K (dimension 2)
|
||||
sequence<1>>{} // Ys_in_Hs: Y at position 1 in K
|
||||
);
|
||||
|
||||
auto a_window = make_tile_window(
|
||||
a_tensor,
|
||||
make_tuple(number<kM>{}, number<kK>{}),
|
||||
{0, 0},
|
||||
a_distribution
|
||||
);
|
||||
a_tensor, make_tuple(number<kM>{}, number<kK>{}), {0, 0}, a_distribution);
|
||||
|
||||
const auto a_tile = load_tile(a_window);
|
||||
const auto a_tile = load_tile(a_window);
|
||||
const auto& thread_buffer = a_tile.get_thread_buffer();
|
||||
|
||||
// Calculate matrix coordinates using make_tensor_coordinate
|
||||
// This shows which matrix positions each thread accesses
|
||||
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if(tid == 0) {
|
||||
|
||||
if(tid == 0)
|
||||
{
|
||||
printf("\n=== Matrix Coverage (Tiled by Warp) ===\n");
|
||||
printf("Distribution covers 32×16 matrix (MWarp×16 threads × K)\n");
|
||||
printf("With NWarp=2 replication, pattern repeats\n");
|
||||
printf("Showing first 16 threads of each warp:\n\n");
|
||||
}
|
||||
|
||||
|
||||
__syncthreads();
|
||||
|
||||
|
||||
// Print warp by warp with calculated coordinates
|
||||
for(int w = 0; w < 4; ++w) {
|
||||
for(int w = 0; w < 4; ++w)
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
if(warp_id == w && lane_id == 0) {
|
||||
printf("\n--- Warp %d (M-warp %d, N-warp %d) ---\n",
|
||||
w, w/NWarp, w%NWarp);
|
||||
|
||||
if(warp_id == w && lane_id == 0)
|
||||
{
|
||||
printf("\n--- Warp %d (M-warp %d, N-warp %d) ---\n", w, w / NWarp, w % NWarp);
|
||||
}
|
||||
|
||||
|
||||
__syncthreads();
|
||||
|
||||
|
||||
// Print lanes sequentially within each warp
|
||||
for(int lane = 0; lane < 16; ++lane) {
|
||||
for(int lane = 0; lane < 16; ++lane)
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
if(warp_id == w && lane_id == lane) {
|
||||
|
||||
if(warp_id == w && lane_id == lane)
|
||||
{
|
||||
printf("W%d L%02d: ", w, lane);
|
||||
|
||||
|
||||
// For each Y element, just print the loaded value
|
||||
// The distribution handles the coordinate mapping internally
|
||||
for(int y_idx = 0; y_idx < thread_buffer.size(); ++y_idx) {
|
||||
for(int y_idx = 0; y_idx < thread_buffer.size(); ++y_idx)
|
||||
{
|
||||
float val = static_cast<float>(thread_buffer[y_idx]);
|
||||
int m = static_cast<int>(val) % 100;
|
||||
int k = static_cast<int>(val) / 100;
|
||||
|
||||
int m = static_cast<int>(val) % 100;
|
||||
int k = static_cast<int>(val) / 100;
|
||||
|
||||
printf("A[%2d,%2d] ", m, k);
|
||||
}
|
||||
printf("\n");
|
||||
@@ -122,7 +115,8 @@ struct TestADistributionKernel
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if(tid == 0) {
|
||||
if(tid == 0)
|
||||
{
|
||||
printf("\n=== Expected Pattern with NWarp Replication ===\n");
|
||||
printf("sequence<NWarp> replicates across N-warp dimension:\n");
|
||||
printf("Warp 0 (M-warp 0, N-warp 0): Loads some M-rows, K[0-15]\n");
|
||||
@@ -130,13 +124,16 @@ struct TestADistributionKernel
|
||||
printf("Warp 2 (M-warp 1, N-warp 0): Loads SAME as Warp 0 (NWarp replication!)\n");
|
||||
printf("Warp 3 (M-warp 1, N-warp 1): Loads SAME as Warp 1 (NWarp replication!)\n");
|
||||
printf("\nReplication pairs:\n");
|
||||
printf(" Warps 0 & 2 should be identical (same N-warp 0, replicated across M-warps)\n");
|
||||
printf(" Warps 1 & 3 should be identical (same N-warp 1, replicated across M-warps)\n");
|
||||
printf(
|
||||
" Warps 0 & 2 should be identical (same N-warp 0, replicated across M-warps)\n");
|
||||
printf(
|
||||
" Warps 1 & 3 should be identical (same N-warp 1, replicated across M-warps)\n");
|
||||
printf(" Warps 0 & 1 should be DIFFERENT (different N-warps)\n");
|
||||
}
|
||||
|
||||
// Store for verification
|
||||
for(int i = 0; i < thread_buffer.size(); ++i) {
|
||||
for(int i = 0; i < thread_buffer.size(); ++i)
|
||||
{
|
||||
debug_output[tid * 4 + i] = thread_buffer[i];
|
||||
}
|
||||
}
|
||||
@@ -148,8 +145,8 @@ int main()
|
||||
std::cout << "Test A Distribution with NWarp Replication\n";
|
||||
std::cout << "==================================================\n\n";
|
||||
|
||||
constexpr index_t M = 16;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t M = 16;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t lda = M;
|
||||
|
||||
using DataType = half_t;
|
||||
@@ -159,8 +156,10 @@ int main()
|
||||
std::vector<DataType> h_debug(256 * 4, -1);
|
||||
|
||||
// Initialize A[m,k] = m + k*100
|
||||
for(index_t k = 0; k < K; ++k) {
|
||||
for(index_t m = 0; m < M; ++m) {
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
for(index_t m = 0; m < M; ++m)
|
||||
{
|
||||
h_a[m + k * lda] = static_cast<DataType>(m + k * 100);
|
||||
}
|
||||
}
|
||||
@@ -173,14 +172,13 @@ int main()
|
||||
|
||||
stream_config stream;
|
||||
launch_kernel(stream,
|
||||
make_kernel<256>(
|
||||
TestADistributionKernel<DataType>{},
|
||||
dim3(1),
|
||||
dim3(256),
|
||||
0,
|
||||
static_cast<const DataType*>(d_a.GetDeviceBuffer()),
|
||||
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
|
||||
lda));
|
||||
make_kernel<256>(TestADistributionKernel<DataType>{},
|
||||
dim3(1),
|
||||
dim3(256),
|
||||
0,
|
||||
static_cast<const DataType*>(d_a.GetDeviceBuffer()),
|
||||
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
|
||||
lda));
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
@@ -188,40 +186,52 @@ int main()
|
||||
|
||||
// Verify NWarp replication: warps 0&2 identical, warps 1&3 identical
|
||||
bool passed = true;
|
||||
|
||||
|
||||
// Check warps 0 and 2 (same N-warp 0, replicated across M-warps)
|
||||
for(int lane = 0; lane < 64; ++lane) {
|
||||
for(int i = 0; i < 4; ++i) {
|
||||
for(int lane = 0; lane < 64; ++lane)
|
||||
{
|
||||
for(int i = 0; i < 4; ++i)
|
||||
{
|
||||
float warp0_val = h_debug[lane * 4 + i];
|
||||
float warp2_val = h_debug[(128 + lane) * 4 + i];
|
||||
if(std::abs(warp0_val - warp2_val) > 0.01f) {
|
||||
std::cout << "ERROR: Warp 0 and Warp 2 differ at lane " << lane << " element " << i << "\n";
|
||||
if(std::abs(warp0_val - warp2_val) > 0.01f)
|
||||
{
|
||||
std::cout << "ERROR: Warp 0 and Warp 2 differ at lane " << lane << " element " << i
|
||||
<< "\n";
|
||||
std::cout << " Warp 0: " << warp0_val << ", Warp 2: " << warp2_val << "\n";
|
||||
passed = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(!passed) break;
|
||||
if(!passed)
|
||||
break;
|
||||
}
|
||||
|
||||
// Check warps 1 and 3 (same N-warp 1, replicated across M-warps)
|
||||
if(passed) {
|
||||
for(int lane = 0; lane < 64; ++lane) {
|
||||
for(int i = 0; i < 4; ++i) {
|
||||
if(passed)
|
||||
{
|
||||
for(int lane = 0; lane < 64; ++lane)
|
||||
{
|
||||
for(int i = 0; i < 4; ++i)
|
||||
{
|
||||
float warp1_val = h_debug[(64 + lane) * 4 + i];
|
||||
float warp3_val = h_debug[(192 + lane) * 4 + i];
|
||||
if(std::abs(warp1_val - warp3_val) > 0.01f) {
|
||||
std::cout << "ERROR: Warp 1 and Warp 3 differ at lane " << lane << " element " << i << "\n";
|
||||
if(std::abs(warp1_val - warp3_val) > 0.01f)
|
||||
{
|
||||
std::cout << "ERROR: Warp 1 and Warp 3 differ at lane " << lane << " element "
|
||||
<< i << "\n";
|
||||
std::cout << " Warp 1: " << warp1_val << ", Warp 3: " << warp3_val << "\n";
|
||||
passed = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(!passed) break;
|
||||
if(!passed)
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if(passed) {
|
||||
if(passed)
|
||||
{
|
||||
std::cout << "\n✓ NWarp Replication verified:\n";
|
||||
std::cout << " Warps 0 & 2 load identical data (N-warp 0, replicated across M-warps)\n";
|
||||
std::cout << " Warps 1 & 3 load identical data (N-warp 1, replicated across M-warps)\n";
|
||||
|
||||
@@ -18,35 +18,27 @@
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
template<typename DataType>
|
||||
template <typename DataType>
|
||||
struct TestADistributionKernel
|
||||
{
|
||||
static constexpr index_t kBlockSize = 256; // 4 warps
|
||||
static constexpr index_t MWarp = 2;
|
||||
static constexpr index_t NWarp = 2;
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kK = 16;
|
||||
static constexpr index_t kBlockSize = 256; // 4 warps
|
||||
static constexpr index_t MWarp = 2;
|
||||
static constexpr index_t NWarp = 2;
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kK = 16;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const DataType* a,
|
||||
DataType* debug_output,
|
||||
index_t lda) const
|
||||
CK_TILE_DEVICE void operator()(const DataType* a, DataType* debug_output, index_t lda) const
|
||||
{
|
||||
if(get_block_id() != 0)
|
||||
return;
|
||||
|
||||
const index_t tid = threadIdx.x;
|
||||
const index_t tid = threadIdx.x;
|
||||
const index_t warp_id = tid / 64;
|
||||
const index_t lane_id = tid % 64;
|
||||
|
||||
// Create tensor view for A (column-major)
|
||||
const auto a_tensor = make_naive_tensor_view<address_space_enum::global>(
|
||||
a,
|
||||
make_tuple(kM, kK),
|
||||
make_tuple(1, lda),
|
||||
number<1>{},
|
||||
number<1>{}
|
||||
);
|
||||
|
||||
a, make_tuple(kM, kK), make_tuple(1, lda), number<1>{}, number<1>{});
|
||||
|
||||
// A distribution using EMBED API (like 02_gemm)
|
||||
// Separate block-level and warp-level distributions
|
||||
@@ -62,84 +54,84 @@ struct TestADistributionKernel
|
||||
// sequence<1>>{} // Ys_in_Hs: Y at position 1 in K
|
||||
// );
|
||||
|
||||
|
||||
// Step 1: Warp-level distribution (64 threads within one warp)
|
||||
constexpr auto a_warp_dstr_encode = tile_distribution_encoding<
|
||||
sequence<>, // No replication at warp level
|
||||
tuple<sequence<16>, // H0 (M): 16 M positions
|
||||
sequence<4, 4>>, // H1 (K): 4×4 = 16 K elements
|
||||
tuple<sequence<2, 1>>, // Ps_to_Hs: 2D P-space (64 threads)
|
||||
tuple<sequence<0, 0>>, // Ps_in_Hs
|
||||
sequence<2>, // Ys_to_Hs: Y maps to K
|
||||
sequence<1>>{}; // Ys_in_Hs
|
||||
|
||||
constexpr auto a_warp_dstr_encode =
|
||||
tile_distribution_encoding<sequence<>, // No replication at warp level
|
||||
tuple<sequence<16>, // H0 (M): 16 M positions
|
||||
sequence<4, 4>>, // H1 (K): 4×4 = 16 K elements
|
||||
tuple<sequence<2, 1>>, // Ps_to_Hs: 2D P-space (64 threads)
|
||||
tuple<sequence<0, 0>>, // Ps_in_Hs
|
||||
sequence<2>, // Ys_to_Hs: Y maps to K
|
||||
sequence<1>>{}; // Ys_in_Hs
|
||||
|
||||
// Step 2: Block-level outer distribution (warp organization)
|
||||
// Must have same NDimX as inner (2 dimensions: M and K)
|
||||
constexpr auto a_block_outer_dstr_encode = tile_distribution_encoding<
|
||||
sequence<NWarp>, // R: Replicate across N-warps
|
||||
tuple<sequence<MWarp>, sequence<>>, // H: MWarp in M-dim, 1 in K-dim
|
||||
tuple<sequence<0, 1>>, // Ps_to_Hs
|
||||
tuple<sequence<0, 0>>, // Ps_in_Hs
|
||||
sequence<>, // Ys_to_Hs: Y maps to both M and K
|
||||
sequence<>>{}; // Ys_in_Hs
|
||||
|
||||
constexpr auto a_block_outer_dstr_encode =
|
||||
tile_distribution_encoding<sequence<NWarp>, // R: Replicate across N-warps
|
||||
tuple<sequence<MWarp>, sequence<>>, // H: MWarp in M-dim, 1
|
||||
// in K-dim
|
||||
tuple<sequence<0, 1>>, // Ps_to_Hs
|
||||
tuple<sequence<0, 0>>, // Ps_in_Hs
|
||||
sequence<>, // Ys_to_Hs: Y maps to both M and K
|
||||
sequence<>>{}; // Ys_in_Hs
|
||||
|
||||
// Step 3: Embed warp-level into block-level
|
||||
constexpr auto a_block_dstr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encode, a_warp_dstr_encode);
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encode, a_warp_dstr_encode);
|
||||
|
||||
// Step 4: Create final distribution
|
||||
constexpr auto a_distribution = make_static_tile_distribution(a_block_dstr_encode);
|
||||
|
||||
auto a_window = make_tile_window(
|
||||
a_tensor,
|
||||
make_tuple(number<kM>{}, number<kK>{}),
|
||||
{0, 0},
|
||||
a_distribution
|
||||
);
|
||||
a_tensor, make_tuple(number<kM>{}, number<kK>{}), {0, 0}, a_distribution);
|
||||
|
||||
const auto a_tile = load_tile(a_window);
|
||||
const auto a_tile = load_tile(a_window);
|
||||
const auto& thread_buffer = a_tile.get_thread_buffer();
|
||||
|
||||
// Calculate matrix coordinates using make_tensor_coordinate
|
||||
// This shows which matrix positions each thread accesses
|
||||
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if(tid == 0) {
|
||||
|
||||
if(tid == 0)
|
||||
{
|
||||
printf("\n=== Matrix Coverage (Tiled by Warp) ===\n");
|
||||
printf("Distribution covers 32×16 matrix (MWarp×16 threads × K)\n");
|
||||
printf("With NWarp=2 replication, pattern repeats\n");
|
||||
printf("Showing first 16 threads of each warp:\n\n");
|
||||
}
|
||||
|
||||
|
||||
__syncthreads();
|
||||
|
||||
|
||||
// Print warp by warp with calculated coordinates
|
||||
for(int w = 0; w < 4; ++w) {
|
||||
for(int w = 0; w < 4; ++w)
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
if(warp_id == w && lane_id == 0) {
|
||||
printf("\n--- Warp %d (M-warp %d, N-warp %d) ---\n",
|
||||
w, w/NWarp, w%NWarp);
|
||||
|
||||
if(warp_id == w && lane_id == 0)
|
||||
{
|
||||
printf("\n--- Warp %d (M-warp %d, N-warp %d) ---\n", w, w / NWarp, w % NWarp);
|
||||
}
|
||||
|
||||
|
||||
__syncthreads();
|
||||
|
||||
|
||||
// Print lanes sequentially within each warp
|
||||
for(int lane = 0; lane < 16; ++lane) {
|
||||
for(int lane = 0; lane < 16; ++lane)
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
if(warp_id == w && lane_id == lane) {
|
||||
|
||||
if(warp_id == w && lane_id == lane)
|
||||
{
|
||||
printf("W%d L%02d: ", w, lane);
|
||||
|
||||
|
||||
// For each Y element, just print the loaded value
|
||||
// The distribution handles the coordinate mapping internally
|
||||
for(int y_idx = 0; y_idx < thread_buffer.size(); ++y_idx) {
|
||||
for(int y_idx = 0; y_idx < thread_buffer.size(); ++y_idx)
|
||||
{
|
||||
float val = static_cast<float>(thread_buffer[y_idx]);
|
||||
int m = static_cast<int>(val) % 100;
|
||||
int k = static_cast<int>(val) / 100;
|
||||
|
||||
int m = static_cast<int>(val) % 100;
|
||||
int k = static_cast<int>(val) / 100;
|
||||
|
||||
printf("A[%2d,%2d] ", m, k);
|
||||
}
|
||||
printf("\n");
|
||||
@@ -149,7 +141,8 @@ struct TestADistributionKernel
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if(tid == 0) {
|
||||
if(tid == 0)
|
||||
{
|
||||
printf("\n=== Expected Pattern with NWarp Replication ===\n");
|
||||
printf("sequence<NWarp> replicates across N-warp dimension:\n");
|
||||
printf("Warp 0 (M-warp 0, N-warp 0): Loads some M-rows, K[0-15]\n");
|
||||
@@ -157,13 +150,16 @@ struct TestADistributionKernel
|
||||
printf("Warp 2 (M-warp 1, N-warp 0): Loads SAME as Warp 0 (NWarp replication!)\n");
|
||||
printf("Warp 3 (M-warp 1, N-warp 1): Loads SAME as Warp 1 (NWarp replication!)\n");
|
||||
printf("\nReplication pairs:\n");
|
||||
printf(" Warps 0 & 2 should be identical (same N-warp 0, replicated across M-warps)\n");
|
||||
printf(" Warps 1 & 3 should be identical (same N-warp 1, replicated across M-warps)\n");
|
||||
printf(
|
||||
" Warps 0 & 2 should be identical (same N-warp 0, replicated across M-warps)\n");
|
||||
printf(
|
||||
" Warps 1 & 3 should be identical (same N-warp 1, replicated across M-warps)\n");
|
||||
printf(" Warps 0 & 1 should be DIFFERENT (different N-warps)\n");
|
||||
}
|
||||
|
||||
// Store for verification
|
||||
for(int i = 0; i < thread_buffer.size(); ++i) {
|
||||
for(int i = 0; i < thread_buffer.size(); ++i)
|
||||
{
|
||||
debug_output[tid * 4 + i] = thread_buffer[i];
|
||||
}
|
||||
}
|
||||
@@ -174,10 +170,11 @@ int main()
|
||||
std::cout << "\n==================================================\n";
|
||||
std::cout << "Test A Distribution using EMBED API\n";
|
||||
std::cout << "==================================================\n\n";
|
||||
std::cout << "Separates block-level (warp organization) from warp-level (thread organization)\n\n";
|
||||
std::cout
|
||||
<< "Separates block-level (warp organization) from warp-level (thread organization)\n\n";
|
||||
|
||||
constexpr index_t M = 16;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t M = 16;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t lda = M;
|
||||
|
||||
using DataType = half_t;
|
||||
@@ -187,8 +184,10 @@ int main()
|
||||
std::vector<DataType> h_debug(256 * 4, -1);
|
||||
|
||||
// Initialize A[m,k] = m + k*100
|
||||
for(index_t k = 0; k < K; ++k) {
|
||||
for(index_t m = 0; m < M; ++m) {
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
for(index_t m = 0; m < M; ++m)
|
||||
{
|
||||
h_a[m + k * lda] = static_cast<DataType>(m + k * 100);
|
||||
}
|
||||
}
|
||||
@@ -201,14 +200,13 @@ int main()
|
||||
|
||||
stream_config stream;
|
||||
launch_kernel(stream,
|
||||
make_kernel<256>(
|
||||
TestADistributionKernel<DataType>{},
|
||||
dim3(1),
|
||||
dim3(256),
|
||||
0,
|
||||
static_cast<const DataType*>(d_a.GetDeviceBuffer()),
|
||||
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
|
||||
lda));
|
||||
make_kernel<256>(TestADistributionKernel<DataType>{},
|
||||
dim3(1),
|
||||
dim3(256),
|
||||
0,
|
||||
static_cast<const DataType*>(d_a.GetDeviceBuffer()),
|
||||
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
|
||||
lda));
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
@@ -216,40 +214,52 @@ int main()
|
||||
|
||||
// Verify NWarp replication: warps 0&2 identical, warps 1&3 identical
|
||||
bool passed = true;
|
||||
|
||||
|
||||
// Check warps 0 and 2 (same N-warp 0, replicated across M-warps)
|
||||
for(int lane = 0; lane < 64; ++lane) {
|
||||
for(int i = 0; i < 4; ++i) {
|
||||
for(int lane = 0; lane < 64; ++lane)
|
||||
{
|
||||
for(int i = 0; i < 4; ++i)
|
||||
{
|
||||
float warp0_val = h_debug[lane * 4 + i];
|
||||
float warp2_val = h_debug[(128 + lane) * 4 + i];
|
||||
if(std::abs(warp0_val - warp2_val) > 0.01f) {
|
||||
std::cout << "ERROR: Warp 0 and Warp 2 differ at lane " << lane << " element " << i << "\n";
|
||||
if(std::abs(warp0_val - warp2_val) > 0.01f)
|
||||
{
|
||||
std::cout << "ERROR: Warp 0 and Warp 2 differ at lane " << lane << " element " << i
|
||||
<< "\n";
|
||||
std::cout << " Warp 0: " << warp0_val << ", Warp 2: " << warp2_val << "\n";
|
||||
passed = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(!passed) break;
|
||||
if(!passed)
|
||||
break;
|
||||
}
|
||||
|
||||
// Check warps 1 and 3 (same N-warp 1, replicated across M-warps)
|
||||
if(passed) {
|
||||
for(int lane = 0; lane < 64; ++lane) {
|
||||
for(int i = 0; i < 4; ++i) {
|
||||
if(passed)
|
||||
{
|
||||
for(int lane = 0; lane < 64; ++lane)
|
||||
{
|
||||
for(int i = 0; i < 4; ++i)
|
||||
{
|
||||
float warp1_val = h_debug[(64 + lane) * 4 + i];
|
||||
float warp3_val = h_debug[(192 + lane) * 4 + i];
|
||||
if(std::abs(warp1_val - warp3_val) > 0.01f) {
|
||||
std::cout << "ERROR: Warp 1 and Warp 3 differ at lane " << lane << " element " << i << "\n";
|
||||
if(std::abs(warp1_val - warp3_val) > 0.01f)
|
||||
{
|
||||
std::cout << "ERROR: Warp 1 and Warp 3 differ at lane " << lane << " element "
|
||||
<< i << "\n";
|
||||
std::cout << " Warp 1: " << warp1_val << ", Warp 3: " << warp3_val << "\n";
|
||||
passed = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(!passed) break;
|
||||
if(!passed)
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if(passed) {
|
||||
if(passed)
|
||||
{
|
||||
std::cout << "\n✓ NWarp Replication verified:\n";
|
||||
std::cout << " Warps 0 & 2 load identical data (N-warp 0, replicated across M-warps)\n";
|
||||
std::cout << " Warps 1 & 3 load identical data (N-warp 1, replicated across M-warps)\n";
|
||||
|
||||
@@ -18,34 +18,27 @@
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
template<typename DataType>
|
||||
template <typename DataType>
|
||||
struct TestBDistributionKernel
|
||||
{
|
||||
static constexpr index_t kBlockSize = 256; // 4 warps
|
||||
static constexpr index_t MWarp = 2;
|
||||
static constexpr index_t NWarp = 2;
|
||||
static constexpr index_t kK = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
static constexpr index_t kBlockSize = 256; // 4 warps
|
||||
static constexpr index_t MWarp = 2;
|
||||
static constexpr index_t NWarp = 2;
|
||||
static constexpr index_t kK = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const DataType* b,
|
||||
DataType* debug_output,
|
||||
index_t ldb) const
|
||||
CK_TILE_DEVICE void operator()(const DataType* b, DataType* debug_output, index_t ldb) const
|
||||
{
|
||||
if(get_block_id() != 0)
|
||||
return;
|
||||
|
||||
const index_t tid = threadIdx.x;
|
||||
const index_t tid = threadIdx.x;
|
||||
const index_t warp_id = tid / 64;
|
||||
const index_t lane_id = tid % 64;
|
||||
|
||||
// Create tensor view for B (row-major)
|
||||
const auto b_tensor = make_naive_tensor_view<address_space_enum::global>(
|
||||
b,
|
||||
make_tuple(kK, kN),
|
||||
make_tuple(ldb, 1),
|
||||
number<4>{},
|
||||
number<1>{}
|
||||
);
|
||||
b, make_tuple(kK, kN), make_tuple(ldb, 1), number<4>{}, number<1>{});
|
||||
|
||||
// B distribution WITH MWarp replication
|
||||
// R dimension (index 0) has MWarp=2 replicas
|
||||
@@ -65,84 +58,87 @@ struct TestBDistributionKernel
|
||||
// sequence<1>>{} // Ys_in_Hs: Y at position 1 in K
|
||||
// );
|
||||
|
||||
|
||||
// constexpr auto b_distribution = make_static_tile_distribution(
|
||||
// tile_distribution_encoding<
|
||||
// sequence<>, // No replication
|
||||
// tuple<sequence<4, 4>, // H0 (K): 4 groups of 4 consecutive K values
|
||||
// tuple<sequence<4, 4>, // H0 (K): 4 groups of 4 consecutive K
|
||||
// values
|
||||
// sequence<16>>, // H1 (N): 16 N positions
|
||||
// tuple<sequence<1, 2>>, // P-dims map to H-dims (P0->H1, P1->H0)
|
||||
// tuple<sequence<0, 0>>, // P positions in H-dims
|
||||
// sequence<1>, // Y maps to K dimension (H0)
|
||||
// sequence<1>>{} // Y at position 1 in H0 (the second 4 in sequence<4,4>)
|
||||
// sequence<1>>{} // Y at position 1 in H0 (the second 4 in
|
||||
// sequence<4,4>)
|
||||
// );
|
||||
|
||||
// constexpr auto b_block_outer_dstr_encoding =
|
||||
// tile_distribution_encoding<sequence<MWarp>,
|
||||
// tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
// tuple<sequence<0, 1>>,
|
||||
// tuple<sequence<NIterPerWarp, NWarp>,
|
||||
// sequence<KIterPerWarp>>, tuple<sequence<0, 1>>,
|
||||
// tuple<sequence<0, 1>>,
|
||||
// sequence<1, 2>,
|
||||
// sequence<0, 0>>{};
|
||||
|
||||
|
||||
// The key: Ps_to_Hs must include dimension 0 (the R dimension)!
|
||||
constexpr auto b_distribution = make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<MWarp>, // R: dimension 0, REPLICATE across 2 M-warps
|
||||
tuple<sequence<4, 4>, // H: dimension 1 (K): 4×4 = 16 K elements
|
||||
sequence<2, 16>>, // H: dimension 2 (N): 16 N positions
|
||||
tuple<sequence<2, 0>, sequence<1, 2>>, // Ps_to_Hs: P0→R(dim 0), P1→K(dim 1), P2→N(dim 2)
|
||||
tuple<sequence<0, 0>, sequence<0, 1>>, // Ps_in_Hs: positions
|
||||
sequence<1>, // Ys_to_Hs: Y maps to K (dimension 1)
|
||||
sequence<1>>{} // Ys_in_Hs: Y at position 1 in K
|
||||
sequence<MWarp>, // R: dimension 0, REPLICATE across 2 M-warps
|
||||
tuple<sequence<4, 4>, // H: dimension 1 (K): 4×4 = 16 K elements
|
||||
sequence<2, 16>>, // H: dimension 2 (N): 16 N positions
|
||||
tuple<sequence<2, 0>, sequence<1, 2>>, // Ps_to_Hs: P0→R(dim 0), P1→K(dim 1),
|
||||
// P2→N(dim 2)
|
||||
tuple<sequence<0, 0>, sequence<0, 1>>, // Ps_in_Hs: positions
|
||||
sequence<1>, // Ys_to_Hs: Y maps to K (dimension 1)
|
||||
sequence<1>>{} // Ys_in_Hs: Y at position 1 in K
|
||||
);
|
||||
|
||||
auto b_window = make_tile_window(
|
||||
b_tensor,
|
||||
make_tuple(number<kK>{}, number<kN>{}),
|
||||
{0, 0},
|
||||
b_distribution
|
||||
);
|
||||
b_tensor, make_tuple(number<kK>{}, number<kN>{}), {0, 0}, b_distribution);
|
||||
|
||||
const auto b_tile = load_tile(b_window);
|
||||
const auto b_tile = load_tile(b_window);
|
||||
const auto& thread_buffer = b_tile.get_thread_buffer();
|
||||
|
||||
// Sequential printing with synchronizations (copied from test_a)
|
||||
__syncthreads();
|
||||
|
||||
if(tid == 0) {
|
||||
|
||||
if(tid == 0)
|
||||
{
|
||||
printf("\n=== Matrix Coverage (Tiled by Warp) ===\n");
|
||||
printf("Distribution covers K×32 matrix (K × NWarp×16 threads)\n");
|
||||
printf("With MWarp=2 replication, pattern repeats\n");
|
||||
printf("Showing first 16 threads of each warp:\n\n");
|
||||
}
|
||||
|
||||
|
||||
__syncthreads();
|
||||
|
||||
|
||||
// Print warp by warp with calculated coordinates
|
||||
for(int w = 0; w < 4; ++w) {
|
||||
for(int w = 0; w < 4; ++w)
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
if(warp_id == w && lane_id == 0) {
|
||||
printf("\n--- Warp %d (M-warp %d, N-warp %d) ---\n",
|
||||
w, w/NWarp, w%NWarp);
|
||||
|
||||
if(warp_id == w && lane_id == 0)
|
||||
{
|
||||
printf("\n--- Warp %d (M-warp %d, N-warp %d) ---\n", w, w / NWarp, w % NWarp);
|
||||
}
|
||||
|
||||
|
||||
__syncthreads();
|
||||
|
||||
|
||||
// Print lanes sequentially within each warp
|
||||
for(int lane = 0; lane < 64; ++lane) {
|
||||
for(int lane = 0; lane < 64; ++lane)
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
if(warp_id == w && lane_id == lane) {
|
||||
|
||||
if(warp_id == w && lane_id == lane)
|
||||
{
|
||||
printf("W%d L%02d: ", w, lane);
|
||||
|
||||
|
||||
// Print loaded values
|
||||
for(int y_idx = 0; y_idx < thread_buffer.size(); ++y_idx) {
|
||||
for(int y_idx = 0; y_idx < thread_buffer.size(); ++y_idx)
|
||||
{
|
||||
float val = static_cast<float>(thread_buffer[y_idx]);
|
||||
int k = static_cast<int>(val) % 100;
|
||||
int n = static_cast<int>(val) / 100;
|
||||
|
||||
int k = static_cast<int>(val) % 100;
|
||||
int n = static_cast<int>(val) / 100;
|
||||
|
||||
printf("B[%2d,%2d] ", k, n);
|
||||
}
|
||||
printf("\n");
|
||||
@@ -152,7 +148,8 @@ struct TestBDistributionKernel
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if(tid == 0) {
|
||||
if(tid == 0)
|
||||
{
|
||||
printf("\n=== Observed Pattern ===\n");
|
||||
printf("Warps 0 & 1 are identical\n");
|
||||
printf("Warps 2 & 3 are identical\n");
|
||||
@@ -160,7 +157,8 @@ struct TestBDistributionKernel
|
||||
}
|
||||
|
||||
// Store for verification
|
||||
for(int i = 0; i < thread_buffer.size(); ++i) {
|
||||
for(int i = 0; i < thread_buffer.size(); ++i)
|
||||
{
|
||||
debug_output[tid * 4 + i] = thread_buffer[i];
|
||||
}
|
||||
}
|
||||
@@ -172,8 +170,8 @@ int main()
|
||||
std::cout << "Test B Distribution with MWarp Replication\n";
|
||||
std::cout << "==================================================\n\n";
|
||||
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t N = 32;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t N = 32;
|
||||
constexpr index_t ldb = N;
|
||||
|
||||
using DataType = half_t;
|
||||
@@ -183,8 +181,10 @@ int main()
|
||||
std::vector<DataType> h_debug(256 * 4, -1);
|
||||
|
||||
// Initialize B[k,n] = k + n*100 (row-major)
|
||||
for(index_t k = 0; k < K; ++k) {
|
||||
for(index_t n = 0; n < N; ++n) {
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
for(index_t n = 0; n < N; ++n)
|
||||
{
|
||||
h_b[k * ldb + n] = static_cast<DataType>(k + n * 100);
|
||||
}
|
||||
}
|
||||
@@ -197,14 +197,13 @@ int main()
|
||||
|
||||
stream_config stream;
|
||||
launch_kernel(stream,
|
||||
make_kernel<256>(
|
||||
TestBDistributionKernel<DataType>{},
|
||||
dim3(1),
|
||||
dim3(256),
|
||||
0,
|
||||
static_cast<const DataType*>(d_b.GetDeviceBuffer()),
|
||||
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
|
||||
ldb));
|
||||
make_kernel<256>(TestBDistributionKernel<DataType>{},
|
||||
dim3(1),
|
||||
dim3(256),
|
||||
0,
|
||||
static_cast<const DataType*>(d_b.GetDeviceBuffer()),
|
||||
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
|
||||
ldb));
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
@@ -212,40 +211,52 @@ int main()
|
||||
|
||||
// Verify: Based on your observation, warps 0&1 identical, warps 2&3 identical
|
||||
bool passed = true;
|
||||
|
||||
|
||||
// Check warps 0 and 1
|
||||
for(int lane = 0; lane < 64; ++lane) {
|
||||
for(int i = 0; i < 4; ++i) {
|
||||
for(int lane = 0; lane < 64; ++lane)
|
||||
{
|
||||
for(int i = 0; i < 4; ++i)
|
||||
{
|
||||
float warp0_val = h_debug[lane * 4 + i];
|
||||
float warp1_val = h_debug[(64 + lane) * 4 + i];
|
||||
if(std::abs(warp0_val - warp1_val) > 0.01f) {
|
||||
std::cout << "ERROR: Warp 0 and Warp 1 differ at lane " << lane << " element " << i << "\n";
|
||||
if(std::abs(warp0_val - warp1_val) > 0.01f)
|
||||
{
|
||||
std::cout << "ERROR: Warp 0 and Warp 1 differ at lane " << lane << " element " << i
|
||||
<< "\n";
|
||||
std::cout << " Warp 0: " << warp0_val << ", Warp 1: " << warp1_val << "\n";
|
||||
passed = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(!passed) break;
|
||||
if(!passed)
|
||||
break;
|
||||
}
|
||||
|
||||
// Check warps 2 and 3
|
||||
if(passed) {
|
||||
for(int lane = 0; lane < 64; ++lane) {
|
||||
for(int i = 0; i < 4; ++i) {
|
||||
if(passed)
|
||||
{
|
||||
for(int lane = 0; lane < 64; ++lane)
|
||||
{
|
||||
for(int i = 0; i < 4; ++i)
|
||||
{
|
||||
float warp2_val = h_debug[(128 + lane) * 4 + i];
|
||||
float warp3_val = h_debug[(192 + lane) * 4 + i];
|
||||
if(std::abs(warp2_val - warp3_val) > 0.01f) {
|
||||
std::cout << "ERROR: Warp 2 and Warp 3 differ at lane " << lane << " element " << i << "\n";
|
||||
if(std::abs(warp2_val - warp3_val) > 0.01f)
|
||||
{
|
||||
std::cout << "ERROR: Warp 2 and Warp 3 differ at lane " << lane << " element "
|
||||
<< i << "\n";
|
||||
std::cout << " Warp 2: " << warp2_val << ", Warp 3: " << warp3_val << "\n";
|
||||
passed = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(!passed) break;
|
||||
if(!passed)
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if(passed) {
|
||||
if(passed)
|
||||
{
|
||||
std::cout << "\n✓ Replication verified:\n";
|
||||
std::cout << " Warps 0 & 1 load identical data\n";
|
||||
std::cout << " Warps 2 & 3 load identical data\n";
|
||||
|
||||
@@ -18,123 +18,119 @@
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
template<typename DataType>
|
||||
template <typename DataType>
|
||||
struct TestBDistributionKernel
|
||||
{
|
||||
static constexpr index_t kBlockSize = 256; // 4 warps
|
||||
static constexpr index_t MWarp = 2;
|
||||
static constexpr index_t NWarp = 2;
|
||||
static constexpr index_t kK = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
static constexpr index_t kBlockSize = 256; // 4 warps
|
||||
static constexpr index_t MWarp = 2;
|
||||
static constexpr index_t NWarp = 2;
|
||||
static constexpr index_t kK = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const DataType* b,
|
||||
DataType* debug_output,
|
||||
index_t ldb) const
|
||||
CK_TILE_DEVICE void operator()(const DataType* b, DataType* debug_output, index_t ldb) const
|
||||
{
|
||||
if(get_block_id() != 0)
|
||||
return;
|
||||
|
||||
const index_t tid = threadIdx.x;
|
||||
const index_t tid = threadIdx.x;
|
||||
const index_t warp_id = tid / 64;
|
||||
const index_t lane_id = tid % 64;
|
||||
|
||||
// Create tensor view for B (row-major)
|
||||
const auto b_tensor = make_naive_tensor_view<address_space_enum::global>(
|
||||
b,
|
||||
make_tuple(kK, kN),
|
||||
make_tuple(ldb, 1),
|
||||
number<4>{},
|
||||
number<1>{}
|
||||
);
|
||||
b, make_tuple(kK, kN), make_tuple(ldb, 1), number<4>{}, number<1>{});
|
||||
|
||||
// B distribution using EMBED API (like 02_gemm)
|
||||
// Separate block-level and warp-level distributions
|
||||
|
||||
|
||||
// Step 1: Warp-level distribution (64 threads within one warp)
|
||||
constexpr auto b_warp_dstr_encode = tile_distribution_encoding<
|
||||
sequence<>, // No replication at warp level
|
||||
tuple<sequence<4, 4>, // H0 (K): 4×4 = 16 K elements
|
||||
sequence<16>>, // H1 (N): 16 N positions
|
||||
tuple<sequence<1, 2>>, // Ps_to_Hs: 1 sequence with 2 values (2D P-space)
|
||||
tuple<sequence<0, 0>>, // Ps_in_Hs
|
||||
sequence<1>, // Ys_to_Hs: Y maps to K
|
||||
sequence<1>>{}; // Ys_in_Hs
|
||||
|
||||
constexpr auto b_warp_dstr_encode =
|
||||
tile_distribution_encoding<sequence<>, // No replication at warp level
|
||||
tuple<sequence<4, 4>, // H0 (K): 4×4 = 16 K elements
|
||||
sequence<16>>, // H1 (N): 16 N positions
|
||||
tuple<sequence<1, 2>>, // Ps_to_Hs: 1 sequence with 2 values
|
||||
// (2D P-space)
|
||||
tuple<sequence<0, 0>>, // Ps_in_Hs
|
||||
sequence<1>, // Ys_to_Hs: Y maps to K
|
||||
sequence<1>>{}; // Ys_in_Hs
|
||||
|
||||
// Step 2: Block-level outer distribution (warp organization)
|
||||
// Must have same NDimX as inner (2 dimensions: K and N)
|
||||
constexpr auto b_block_outer_dstr_encode = tile_distribution_encoding<
|
||||
sequence<MWarp>, // R: Replicate across M-warps
|
||||
tuple<sequence<>, sequence<NWarp>>, // H: NWarp in N-dim, 1 in K-dim
|
||||
tuple<sequence<2, 0>>, // Ps_to_Hs
|
||||
tuple<sequence<0, 0>>, // Ps_in_Hs
|
||||
sequence<>, // Ys_to_Hs: Y maps to both K and N
|
||||
sequence<>>{}; // Ys_in_Hs
|
||||
constexpr auto b_block_outer_dstr_encode =
|
||||
tile_distribution_encoding<sequence<MWarp>, // R: Replicate across M-warps
|
||||
tuple<sequence<>, sequence<NWarp>>, // H: NWarp in N-dim, 1
|
||||
// in K-dim
|
||||
tuple<sequence<2, 0>>, // Ps_to_Hs
|
||||
tuple<sequence<0, 0>>, // Ps_in_Hs
|
||||
sequence<>, // Ys_to_Hs: Y maps to both K and N
|
||||
sequence<>>{}; // Ys_in_Hs
|
||||
|
||||
// constexpr auto b_distribution = make_static_tile_distribution(
|
||||
// tile_distribution_encoding<
|
||||
// sequence<MWarp>, // R: dimension 0, REPLICATE across 2 M-warps
|
||||
// tuple<sequence<4, 4>, // H: dimension 1 (K): 4×4 = 16 K elements
|
||||
// sequence<MWarp>, // R: dimension 0, REPLICATE across 2
|
||||
// M-warps tuple<sequence<4, 4>, // H: dimension 1 (K): 4×4 = 16 K
|
||||
// elements
|
||||
// sequence<2, 16>>, // H: dimension 2 (N): 16 N positions
|
||||
// tuple<sequence<2, 0>, sequence<1, 2>>, // Ps_to_Hs: P0→R(dim 0), P1→K(dim 1), P2→N(dim 2)
|
||||
// tuple<sequence<0, 0>, sequence<0, 1>>, // Ps_in_Hs: positions
|
||||
// tuple<sequence<2, 0>, sequence<1, 2>>, // Ps_to_Hs: P0→R(dim 0), P1→K(dim 1),
|
||||
// P2→N(dim 2) tuple<sequence<0, 0>, sequence<0, 1>>, // Ps_in_Hs: positions
|
||||
// sequence<1>, // Ys_to_Hs: Y maps to K (dimension 1)
|
||||
// sequence<1>>{} // Ys_in_Hs: Y at position 1 in K
|
||||
// );
|
||||
|
||||
|
||||
// Step 3: Embed warp-level into block-level
|
||||
constexpr auto b_block_dstr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encode, b_warp_dstr_encode);
|
||||
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encode, b_warp_dstr_encode);
|
||||
|
||||
// Step 4: Create final distribution
|
||||
constexpr auto b_distribution = make_static_tile_distribution(b_block_dstr_encode);
|
||||
|
||||
auto b_window = make_tile_window(
|
||||
b_tensor,
|
||||
make_tuple(number<kK>{}, number<kN>{}),
|
||||
{0, 0},
|
||||
b_distribution
|
||||
);
|
||||
b_tensor, make_tuple(number<kK>{}, number<kN>{}), {0, 0}, b_distribution);
|
||||
|
||||
const auto b_tile = load_tile(b_window);
|
||||
const auto b_tile = load_tile(b_window);
|
||||
const auto& thread_buffer = b_tile.get_thread_buffer();
|
||||
|
||||
// Sequential printing with synchronizations (copied from test_a)
|
||||
__syncthreads();
|
||||
|
||||
if(tid == 0) {
|
||||
|
||||
if(tid == 0)
|
||||
{
|
||||
printf("\n=== Matrix Coverage (Tiled by Warp) ===\n");
|
||||
printf("Distribution covers K×32 matrix (K × NWarp×16 threads)\n");
|
||||
printf("With MWarp=2 replication, pattern repeats\n");
|
||||
printf("Showing first 16 threads of each warp:\n\n");
|
||||
}
|
||||
|
||||
|
||||
__syncthreads();
|
||||
|
||||
|
||||
// Print warp by warp with calculated coordinates
|
||||
for(int w = 0; w < 4; ++w) {
|
||||
for(int w = 0; w < 4; ++w)
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
if(warp_id == w && lane_id == 0) {
|
||||
printf("\n--- Warp %d (M-warp %d, N-warp %d) ---\n",
|
||||
w, w/NWarp, w%NWarp);
|
||||
|
||||
if(warp_id == w && lane_id == 0)
|
||||
{
|
||||
printf("\n--- Warp %d (M-warp %d, N-warp %d) ---\n", w, w / NWarp, w % NWarp);
|
||||
}
|
||||
|
||||
|
||||
__syncthreads();
|
||||
|
||||
|
||||
// Print lanes sequentially within each warp
|
||||
for(int lane = 0; lane < 16; ++lane) {
|
||||
for(int lane = 0; lane < 16; ++lane)
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
if(warp_id == w && lane_id == lane) {
|
||||
|
||||
if(warp_id == w && lane_id == lane)
|
||||
{
|
||||
printf("W%d L%02d: ", w, lane);
|
||||
|
||||
|
||||
// Print loaded values
|
||||
for(int y_idx = 0; y_idx < thread_buffer.size(); ++y_idx) {
|
||||
for(int y_idx = 0; y_idx < thread_buffer.size(); ++y_idx)
|
||||
{
|
||||
float val = static_cast<float>(thread_buffer[y_idx]);
|
||||
int k = static_cast<int>(val) % 100;
|
||||
int n = static_cast<int>(val) / 100;
|
||||
|
||||
int k = static_cast<int>(val) % 100;
|
||||
int n = static_cast<int>(val) / 100;
|
||||
|
||||
printf("B[%2d,%2d] ", k, n);
|
||||
}
|
||||
printf("\n");
|
||||
@@ -144,7 +140,8 @@ struct TestBDistributionKernel
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if(tid == 0) {
|
||||
if(tid == 0)
|
||||
{
|
||||
printf("\n=== Observed Pattern ===\n");
|
||||
printf("Warps 0 & 1 are identical\n");
|
||||
printf("Warps 2 & 3 are identical\n");
|
||||
@@ -152,7 +149,8 @@ struct TestBDistributionKernel
|
||||
}
|
||||
|
||||
// Store for verification
|
||||
for(int i = 0; i < thread_buffer.size(); ++i) {
|
||||
for(int i = 0; i < thread_buffer.size(); ++i)
|
||||
{
|
||||
debug_output[tid * 4 + i] = thread_buffer[i];
|
||||
}
|
||||
}
|
||||
@@ -163,10 +161,11 @@ int main()
|
||||
std::cout << "\n==================================================\n";
|
||||
std::cout << "Test B Distribution using EMBED API\n";
|
||||
std::cout << "==================================================\n\n";
|
||||
std::cout << "Separates block-level (warp organization) from warp-level (thread organization)\n\n";
|
||||
std::cout
|
||||
<< "Separates block-level (warp organization) from warp-level (thread organization)\n\n";
|
||||
|
||||
constexpr index_t K = 32;
|
||||
constexpr index_t N = 32;
|
||||
constexpr index_t K = 32;
|
||||
constexpr index_t N = 32;
|
||||
constexpr index_t ldb = N;
|
||||
|
||||
using DataType = half_t;
|
||||
@@ -176,8 +175,10 @@ int main()
|
||||
std::vector<DataType> h_debug(256 * 4, -1);
|
||||
|
||||
// Initialize B[k,n] = k + n*100 (row-major)
|
||||
for(index_t k = 0; k < K; ++k) {
|
||||
for(index_t n = 0; n < N; ++n) {
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
for(index_t n = 0; n < N; ++n)
|
||||
{
|
||||
h_b[k * ldb + n] = static_cast<DataType>(k + n * 100);
|
||||
}
|
||||
}
|
||||
@@ -190,14 +191,13 @@ int main()
|
||||
|
||||
stream_config stream;
|
||||
launch_kernel(stream,
|
||||
make_kernel<256>(
|
||||
TestBDistributionKernel<DataType>{},
|
||||
dim3(1),
|
||||
dim3(256),
|
||||
0,
|
||||
static_cast<const DataType*>(d_b.GetDeviceBuffer()),
|
||||
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
|
||||
ldb));
|
||||
make_kernel<256>(TestBDistributionKernel<DataType>{},
|
||||
dim3(1),
|
||||
dim3(256),
|
||||
0,
|
||||
static_cast<const DataType*>(d_b.GetDeviceBuffer()),
|
||||
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
|
||||
ldb));
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
@@ -205,40 +205,52 @@ int main()
|
||||
|
||||
// Verify: Based on your observation, warps 0&1 identical, warps 2&3 identical
|
||||
bool passed = true;
|
||||
|
||||
|
||||
// Check warps 0 and 1
|
||||
for(int lane = 0; lane < 64; ++lane) {
|
||||
for(int i = 0; i < 4; ++i) {
|
||||
for(int lane = 0; lane < 64; ++lane)
|
||||
{
|
||||
for(int i = 0; i < 4; ++i)
|
||||
{
|
||||
float warp0_val = h_debug[lane * 4 + i];
|
||||
float warp1_val = h_debug[(64 + lane) * 4 + i];
|
||||
if(std::abs(warp0_val - warp1_val) > 0.01f) {
|
||||
std::cout << "ERROR: Warp 0 and Warp 1 differ at lane " << lane << " element " << i << "\n";
|
||||
if(std::abs(warp0_val - warp1_val) > 0.01f)
|
||||
{
|
||||
std::cout << "ERROR: Warp 0 and Warp 1 differ at lane " << lane << " element " << i
|
||||
<< "\n";
|
||||
std::cout << " Warp 0: " << warp0_val << ", Warp 1: " << warp1_val << "\n";
|
||||
passed = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(!passed) break;
|
||||
if(!passed)
|
||||
break;
|
||||
}
|
||||
|
||||
// Check warps 2 and 3
|
||||
if(passed) {
|
||||
for(int lane = 0; lane < 64; ++lane) {
|
||||
for(int i = 0; i < 4; ++i) {
|
||||
if(passed)
|
||||
{
|
||||
for(int lane = 0; lane < 64; ++lane)
|
||||
{
|
||||
for(int i = 0; i < 4; ++i)
|
||||
{
|
||||
float warp2_val = h_debug[(128 + lane) * 4 + i];
|
||||
float warp3_val = h_debug[(192 + lane) * 4 + i];
|
||||
if(std::abs(warp2_val - warp3_val) > 0.01f) {
|
||||
std::cout << "ERROR: Warp 2 and Warp 3 differ at lane " << lane << " element " << i << "\n";
|
||||
if(std::abs(warp2_val - warp3_val) > 0.01f)
|
||||
{
|
||||
std::cout << "ERROR: Warp 2 and Warp 3 differ at lane " << lane << " element "
|
||||
<< i << "\n";
|
||||
std::cout << " Warp 2: " << warp2_val << ", Warp 3: " << warp3_val << "\n";
|
||||
passed = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(!passed) break;
|
||||
if(!passed)
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if(passed) {
|
||||
if(passed)
|
||||
{
|
||||
std::cout << "\n✓ Replication verified:\n";
|
||||
std::cout << " Warps 0 & 1 load identical data\n";
|
||||
std::cout << " Warps 2 & 3 load identical data\n";
|
||||
|
||||
@@ -28,18 +28,18 @@
|
||||
using namespace ck_tile;
|
||||
|
||||
// Tile Sweeping HGEMM kernel with multiple warps
|
||||
template<typename ADataType, typename BDataType, typename CDataType, typename AccDataType>
|
||||
template <typename ADataType, typename BDataType, typename CDataType, typename AccDataType>
|
||||
struct TileSweepingHgemmKernel
|
||||
{
|
||||
static constexpr index_t kWaveSize = 64; // AMD wave size
|
||||
static constexpr index_t kWarpM = 16; // MFMA M dimension per warp
|
||||
static constexpr index_t kWarpN = 16; // MFMA N dimension per warp
|
||||
static constexpr index_t kWarpK = 16; // MFMA K dimension per instruction
|
||||
static constexpr index_t kWaveSize = 64; // AMD wave size
|
||||
static constexpr index_t kWarpM = 16; // MFMA M dimension per warp
|
||||
static constexpr index_t kWarpN = 16; // MFMA N dimension per warp
|
||||
static constexpr index_t kWarpK = 16; // MFMA K dimension per instruction
|
||||
|
||||
// Warp configuration: 2×2 warps per block
|
||||
static constexpr index_t MWarp = 2; // 2 warps in M dimension
|
||||
static constexpr index_t NWarp = 2; // 2 warps in N dimension
|
||||
static constexpr index_t kBlockSize = MWarp * NWarp * kWaveSize; // 256 threads
|
||||
static constexpr index_t MWarp = 2; // 2 warps in M dimension
|
||||
static constexpr index_t NWarp = 2; // 2 warps in N dimension
|
||||
static constexpr index_t kBlockSize = MWarp * NWarp * kWaveSize; // 256 threads
|
||||
|
||||
// No iterations - each warp computes exactly one 16×16 output tile
|
||||
|
||||
@@ -53,24 +53,24 @@ struct TileSweepingHgemmKernel
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t lda, // Leading dimension of A (column-major)
|
||||
index_t ldb, // Leading dimension of B (row-major)
|
||||
index_t ldc, // Leading dimension of C (column-major)
|
||||
index_t ldd, // Leading dimension of D (column-major)
|
||||
index_t lda, // Leading dimension of A (column-major)
|
||||
index_t ldb, // Leading dimension of B (row-major)
|
||||
index_t ldc, // Leading dimension of C (column-major)
|
||||
index_t ldd, // Leading dimension of D (column-major)
|
||||
AccDataType alpha,
|
||||
AccDataType beta) const
|
||||
{
|
||||
// Calculate which warp this thread belongs to within the block
|
||||
const index_t warp_id = get_warp_id();
|
||||
const index_t iMWarp = warp_id / NWarp; // M-warp index (0 or 1)
|
||||
const index_t iNWarp = warp_id % NWarp; // N-warp index (0 or 1)
|
||||
const index_t warp_id = get_warp_id();
|
||||
const index_t iMWarp = warp_id / NWarp; // M-warp index (0 or 1)
|
||||
const index_t iNWarp = warp_id % NWarp; // N-warp index (0 or 1)
|
||||
const index_t block_id = get_block_id();
|
||||
|
||||
|
||||
// Convert linear block_id to 2D grid coordinates
|
||||
const index_t num_blocks_n = N / (NWarp * kWarpN); // Number of blocks in N dimension
|
||||
const index_t block_m = block_id / num_blocks_n; // M-block index
|
||||
const index_t block_n = block_id % num_blocks_n; // N-block index
|
||||
|
||||
const index_t num_blocks_n = N / (NWarp * kWarpN); // Number of blocks in N dimension
|
||||
const index_t block_m = block_id / num_blocks_n; // M-block index
|
||||
const index_t block_n = block_id % num_blocks_n; // N-block index
|
||||
|
||||
// printf("Block %d (grid [%d,%d]), Warp %d (M-warp %d, N-warp %d)\n",
|
||||
// block_id, block_m, block_n, warp_id, iMWarp, iNWarp);
|
||||
|
||||
@@ -86,96 +86,91 @@ struct TileSweepingHgemmKernel
|
||||
// A is column-major: M×K with stride lda between columns
|
||||
const auto a_tensor = make_naive_tensor_view<address_space_enum::global>(
|
||||
a,
|
||||
make_tuple(M, K), // Shape: M×K
|
||||
make_tuple(1, lda), // Strides: column-major
|
||||
make_tuple(M, K), // Shape: M×K
|
||||
make_tuple(1, lda), // Strides: column-major
|
||||
number<1>{},
|
||||
number<1>{}
|
||||
);
|
||||
number<1>{});
|
||||
|
||||
// B is row-major: K×N with stride ldb between rows
|
||||
const auto b_tensor = make_naive_tensor_view<address_space_enum::global>(
|
||||
b,
|
||||
make_tuple(K, N), // Shape: K×N
|
||||
make_tuple(ldb, 1), // Strides: row-major
|
||||
make_tuple(K, N), // Shape: K×N
|
||||
make_tuple(ldb, 1), // Strides: row-major
|
||||
number<4>{},
|
||||
number<1>{}
|
||||
);
|
||||
number<1>{});
|
||||
|
||||
// C is column-major: M×N with stride ldc between columns
|
||||
const auto c_tensor = make_naive_tensor_view<address_space_enum::global>(
|
||||
c,
|
||||
make_tuple(M, N), // Shape: M×N
|
||||
make_tuple(1, ldc), // Strides: column-major
|
||||
make_tuple(M, N), // Shape: M×N
|
||||
make_tuple(1, ldc), // Strides: column-major
|
||||
number<1>{},
|
||||
number<1>{}
|
||||
);
|
||||
number<1>{});
|
||||
|
||||
// D is column-major: M×N with stride ldd between columns
|
||||
auto d_tensor = make_naive_tensor_view<address_space_enum::global>(
|
||||
d,
|
||||
make_tuple(M, N), // Shape: M×N
|
||||
make_tuple(1, ldd), // Strides: column-major
|
||||
make_tuple(M, N), // Shape: M×N
|
||||
make_tuple(1, ldd), // Strides: column-major
|
||||
number<1>{},
|
||||
number<1>{}
|
||||
);
|
||||
number<1>{});
|
||||
|
||||
// ============================================================================
|
||||
// TILE DISTRIBUTIONS using EMBED API (from verified tests)
|
||||
// ============================================================================
|
||||
|
||||
|
||||
// Step 1: Warp-level distribution (64 threads within one warp)
|
||||
constexpr auto a_warp_dstr_encode = tile_distribution_encoding<
|
||||
sequence<>, // No replication at warp level
|
||||
tuple<sequence<16>, // H0 (M): 16 M positions
|
||||
sequence<4, 4>>, // H1 (K): 4×4 = 16 K elements
|
||||
tuple<sequence<2, 1>>, // Ps_to_Hs: 2D P-space (64 threads)
|
||||
tuple<sequence<0, 0>>, // Ps_in_Hs
|
||||
sequence<2>, // Ys_to_Hs: Y maps to K
|
||||
sequence<1>>{}; // Ys_in_Hs
|
||||
|
||||
constexpr auto a_warp_dstr_encode =
|
||||
tile_distribution_encoding<sequence<>, // No replication at warp level
|
||||
tuple<sequence<16>, // H0 (M): 16 M positions
|
||||
sequence<4, 4>>, // H1 (K): 4×4 = 16 K elements
|
||||
tuple<sequence<2, 1>>, // Ps_to_Hs: 2D P-space (64 threads)
|
||||
tuple<sequence<0, 0>>, // Ps_in_Hs
|
||||
sequence<2>, // Ys_to_Hs: Y maps to K
|
||||
sequence<1>>{}; // Ys_in_Hs
|
||||
|
||||
// Step 2: Block-level outer distribution (warp organization)
|
||||
// Must have same NDimX as inner (2 dimensions: M and K)
|
||||
constexpr auto a_block_outer_dstr_encode = tile_distribution_encoding<
|
||||
sequence<NWarp>, // R: Replicate across N-warps
|
||||
tuple<sequence<MWarp>, sequence<>>, // H: MWarp in M-dim, 1 in K-dim
|
||||
tuple<sequence<0, 1>>, // Ps_to_Hs
|
||||
tuple<sequence<0, 0>>, // Ps_in_Hs
|
||||
sequence<>, // Ys_to_Hs: Y maps to both M and K
|
||||
sequence<>>{}; // Ys_in_Hs
|
||||
constexpr auto a_block_outer_dstr_encode =
|
||||
tile_distribution_encoding<sequence<NWarp>, // R: Replicate across N-warps
|
||||
tuple<sequence<MWarp>, sequence<>>, // H: MWarp in M-dim, 1
|
||||
// in K-dim
|
||||
tuple<sequence<0, 1>>, // Ps_to_Hs
|
||||
tuple<sequence<0, 0>>, // Ps_in_Hs
|
||||
sequence<>, // Ys_to_Hs: Y maps to both M and K
|
||||
sequence<>>{}; // Ys_in_Hs
|
||||
|
||||
// B warp-level distribution
|
||||
constexpr auto b_warp_dstr_encode = tile_distribution_encoding<
|
||||
sequence<>, // No replication at warp level
|
||||
tuple<sequence<4, 4>, // H0 (K): 4×4 = 16 K elements
|
||||
sequence<16>>, // H1 (N): 16 N positions
|
||||
tuple<sequence<1, 2>>, // Ps_to_Hs: 1 sequence with 2 values (2D P-space)
|
||||
tuple<sequence<0, 0>>, // Ps_in_Hs
|
||||
sequence<1>, // Ys_to_Hs: Y maps to K
|
||||
sequence<1>>{}; // Ys_in_Hs
|
||||
|
||||
constexpr auto b_warp_dstr_encode =
|
||||
tile_distribution_encoding<sequence<>, // No replication at warp level
|
||||
tuple<sequence<4, 4>, // H0 (K): 4×4 = 16 K elements
|
||||
sequence<16>>, // H1 (N): 16 N positions
|
||||
tuple<sequence<1, 2>>, // Ps_to_Hs: 1 sequence with 2 values
|
||||
// (2D P-space)
|
||||
tuple<sequence<0, 0>>, // Ps_in_Hs
|
||||
sequence<1>, // Ys_to_Hs: Y maps to K
|
||||
sequence<1>>{}; // Ys_in_Hs
|
||||
|
||||
// Step 2: Block-level outer distribution (warp organization)
|
||||
// Must have same NDimX as inner (2 dimensions: K and N)
|
||||
constexpr auto b_block_outer_dstr_encode = tile_distribution_encoding<
|
||||
sequence<MWarp>, // R: Replicate across M-warps
|
||||
tuple<sequence<>, sequence<NWarp>>, // H: NWarp in N-dim, 1 in K-dim
|
||||
tuple<sequence<2, 0>>, // Ps_to_Hs
|
||||
tuple<sequence<0, 0>>, // Ps_in_Hs
|
||||
sequence<>, // Ys_to_Hs: Y maps to both K and N
|
||||
sequence<>>{}; // Ys_in_Hs
|
||||
|
||||
constexpr auto b_block_outer_dstr_encode =
|
||||
tile_distribution_encoding<sequence<MWarp>, // R: Replicate across M-warps
|
||||
tuple<sequence<>, sequence<NWarp>>, // H: NWarp in N-dim, 1
|
||||
// in K-dim
|
||||
tuple<sequence<2, 0>>, // Ps_to_Hs
|
||||
tuple<sequence<0, 0>>, // Ps_in_Hs
|
||||
sequence<>, // Ys_to_Hs: Y maps to both K and N
|
||||
sequence<>>{}; // Ys_in_Hs
|
||||
|
||||
// Embed to create block-level distributions with replication
|
||||
constexpr auto a_block_dstr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encode, a_warp_dstr_encode);
|
||||
|
||||
constexpr auto b_block_dstr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encode, b_warp_dstr_encode);
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encode, a_warp_dstr_encode);
|
||||
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encode, b_warp_dstr_encode);
|
||||
|
||||
// /*direct approach*/
|
||||
// constexpr auto a_block_dstr_encode =
|
||||
// constexpr auto a_block_dstr_encode =
|
||||
// tile_distribution_encoding<
|
||||
// sequence<NWarp>, // R: REPLICATE across 2 N-warps
|
||||
// tuple<sequence<MWarp, 16>, // H0 (M): 2 M-warps × 16 threads = 32 M
|
||||
@@ -185,24 +180,23 @@ struct TileSweepingHgemmKernel
|
||||
// sequence<2>, // Ys_to_Hs: Y maps to K (dimension 2)
|
||||
// sequence<1>>{}; // Ys_in_Hs: Y at position 1 in K
|
||||
|
||||
|
||||
// // Direct approach (like test_b_distribution_with_replication.cpp)
|
||||
// constexpr auto b_block_dstr_encode =
|
||||
// constexpr auto b_block_dstr_encode =
|
||||
// tile_distribution_encoding<
|
||||
// sequence<MWarp>, // R: dimension 0, REPLICATE across 2 M-warps
|
||||
// tuple<sequence<4, 4>, // H: dimension 1 (K): 4×4 = 16 K elements
|
||||
// sequence<MWarp>, // R: dimension 0, REPLICATE across 2
|
||||
// M-warps tuple<sequence<4, 4>, // H: dimension 1 (K): 4×4 = 16 K
|
||||
// elements
|
||||
// sequence<2, 16>>, // H: dimension 2 (N): 16 N positions
|
||||
// tuple<sequence<2, 0>, sequence<1, 2>>, // Ps_to_Hs: P0→R(dim 0), P1→K(dim 1), P2→N(dim 2)
|
||||
// tuple<sequence<0, 0>, sequence<0, 1>>, // Ps_in_Hs: positions
|
||||
// tuple<sequence<2, 0>, sequence<1, 2>>, // Ps_to_Hs: P0→R(dim 0), P1→K(dim 1),
|
||||
// P2→N(dim 2) tuple<sequence<0, 0>, sequence<0, 1>>, // Ps_in_Hs: positions
|
||||
// sequence<1>, // Ys_to_Hs: Y maps to K (dimension 1)
|
||||
// sequence<1>>{}; // Ys_in_Hs: Y at position 1 in K
|
||||
// /*direct approach*/
|
||||
|
||||
|
||||
// Use block-level distributions for loading (includes replication)
|
||||
constexpr auto a_block_distribution = make_static_tile_distribution(a_block_dstr_encode);
|
||||
constexpr auto b_block_distribution = make_static_tile_distribution(b_block_dstr_encode);
|
||||
|
||||
|
||||
// C Distribution: Create block-level distribution for 32×32 output
|
||||
// No replication needed - each warp computes its own unique output region
|
||||
// 2D P-space for 4 warps: P[0] for M-warp, P[1] for N-warp
|
||||
@@ -215,52 +209,47 @@ struct TileSweepingHgemmKernel
|
||||
// sequence<1>, // No Y dimension for output
|
||||
// sequence<2>>{};
|
||||
|
||||
constexpr auto c_warp_dstr_encode = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<4, 4>, sequence<16>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<1>,
|
||||
sequence<1>>{};
|
||||
constexpr auto c_warp_dstr_encode =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<4, 4>, sequence<16>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<1>,
|
||||
sequence<1>>{};
|
||||
|
||||
constexpr auto c_block_outer_dstr_encode = tile_distribution_encoding<
|
||||
sequence<>, // No replication for output
|
||||
tuple<sequence<MWarp>, // H0: M iterations
|
||||
sequence<NWarp>>, // H1: N iterations
|
||||
tuple<sequence<2, 1>>, // Ps_to_Hs
|
||||
tuple<sequence<0, 0>>, // Ps_in_Hs
|
||||
sequence<>, // Ys_to_Hs: Y maps to BOTH M and N
|
||||
sequence<>>{}; // Ys_in_Hs
|
||||
constexpr auto c_block_outer_dstr_encode =
|
||||
tile_distribution_encoding<sequence<>, // No replication for output
|
||||
tuple<sequence<MWarp>, // H0: M iterations
|
||||
sequence<NWarp>>, // H1: N iterations
|
||||
tuple<sequence<2, 1>>, // Ps_to_Hs
|
||||
tuple<sequence<0, 0>>, // Ps_in_Hs
|
||||
sequence<>, // Ys_to_Hs: Y maps to BOTH M and N
|
||||
sequence<>>{}; // Ys_in_Hs
|
||||
|
||||
constexpr auto c_block_dstr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encode, c_warp_dstr_encode);
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encode, c_warp_dstr_encode);
|
||||
|
||||
|
||||
constexpr auto c_block_distribution = make_static_tile_distribution(c_block_dstr_encode);
|
||||
|
||||
// Create block-level windows (one K-chunk at a time)
|
||||
// A: 32×16 (all M-rows × one K-chunk)
|
||||
// B: 16×32 (one K-chunk × all N-columns)
|
||||
auto a_block_window = make_tile_window(
|
||||
a_tensor,
|
||||
make_tuple(number<MWarp * kWarpM>{}, number<kWarpK>{}), // 32×16
|
||||
{block_m * (MWarp * kWarpM), 0},
|
||||
a_block_distribution
|
||||
);
|
||||
auto a_block_window =
|
||||
make_tile_window(a_tensor,
|
||||
make_tuple(number<MWarp * kWarpM>{}, number<kWarpK>{}), // 32×16
|
||||
{block_m * (MWarp * kWarpM), 0},
|
||||
a_block_distribution);
|
||||
|
||||
auto b_block_window = make_tile_window(
|
||||
b_tensor,
|
||||
make_tuple(number<kWarpK>{}, number<NWarp * kWarpN>{}), // 16×32
|
||||
{0, block_n * (NWarp * kWarpN)},
|
||||
b_block_distribution
|
||||
);
|
||||
auto b_block_window =
|
||||
make_tile_window(b_tensor,
|
||||
make_tuple(number<kWarpK>{}, number<NWarp * kWarpN>{}), // 16×32
|
||||
{0, block_n * (NWarp * kWarpN)},
|
||||
b_block_distribution);
|
||||
|
||||
// Create block-level accumulator tile (covers all 4 warps)
|
||||
auto c_block_tile = make_static_distributed_tensor<AccDataType>(c_block_distribution);
|
||||
set_tile(c_block_tile, AccDataType{0});
|
||||
|
||||
|
||||
// Main K-loop
|
||||
const index_t num_k_loops = K / kWarpK;
|
||||
for(index_t k_iter = 0; k_iter < num_k_loops; ++k_iter)
|
||||
@@ -269,13 +258,14 @@ struct TileSweepingHgemmKernel
|
||||
// Each warp gets its correct portion based on the distribution encoding
|
||||
const auto a_tile = load_tile(a_block_window);
|
||||
const auto b_tile = load_tile(b_block_window);
|
||||
|
||||
|
||||
// Perform MFMA: C += A * B
|
||||
// Each warp updates its portion of the block tile
|
||||
WarpGemm{}(c_block_tile, a_tile, b_tile);
|
||||
|
||||
// // Move windows to next K chunk
|
||||
if(k_iter < num_k_loops - 1) {
|
||||
if(k_iter < num_k_loops - 1)
|
||||
{
|
||||
move_tile_window(a_block_window, {0, kWarpK});
|
||||
move_tile_window(b_block_window, {kWarpK, 0});
|
||||
}
|
||||
@@ -289,53 +279,58 @@ struct TileSweepingHgemmKernel
|
||||
{
|
||||
auto c_block_window = make_tile_window(
|
||||
c_tensor,
|
||||
make_tuple(number<MWarp * kWarpM>{}, number<NWarp * kWarpN>{}), // 32×32
|
||||
make_tuple(number<MWarp * kWarpM>{}, number<NWarp * kWarpN>{}), // 32×32
|
||||
{block_m * (MWarp * kWarpM), block_n * (NWarp * kWarpN)},
|
||||
c_block_distribution
|
||||
);
|
||||
c_block_distribution);
|
||||
|
||||
const auto c_input_block_tile = load_tile(c_block_window);
|
||||
|
||||
tile_elementwise_inout(
|
||||
[beta](const auto& c_val, auto& acc_val) {
|
||||
acc_val += beta * c_val;
|
||||
},
|
||||
c_input_block_tile, c_block_tile);
|
||||
[beta](const auto& c_val, auto& acc_val) { acc_val += beta * c_val; },
|
||||
c_input_block_tile,
|
||||
c_block_tile);
|
||||
}
|
||||
|
||||
// Store final result to D (entire block)
|
||||
auto d_block_window = make_tile_window(
|
||||
d_tensor,
|
||||
make_tuple(number<MWarp * kWarpM>{}, number<NWarp * kWarpN>{}), // 32×32
|
||||
make_tuple(number<MWarp * kWarpM>{}, number<NWarp * kWarpN>{}), // 32×32
|
||||
{block_m * (MWarp * kWarpM), block_n * (NWarp * kWarpN)},
|
||||
c_block_distribution
|
||||
);
|
||||
c_block_distribution);
|
||||
|
||||
store_tile(d_block_window, c_block_tile);
|
||||
}
|
||||
};
|
||||
|
||||
// CPU reference for verification
|
||||
template<typename InType, typename AccType>
|
||||
void reference_gemm_mixed(const std::vector<InType>& a, // Column-major
|
||||
const std::vector<InType>& b, // Row-major
|
||||
const std::vector<AccType>& c, // Column-major
|
||||
std::vector<AccType>& d, // Column-major
|
||||
index_t M, index_t N, index_t K,
|
||||
index_t lda, index_t ldb, index_t ldc, index_t ldd,
|
||||
AccType alpha, AccType beta)
|
||||
template <typename InType, typename AccType>
|
||||
void reference_gemm_mixed(const std::vector<InType>& a, // Column-major
|
||||
const std::vector<InType>& b, // Row-major
|
||||
const std::vector<AccType>& c, // Column-major
|
||||
std::vector<AccType>& d, // Column-major
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t lda,
|
||||
index_t ldb,
|
||||
index_t ldc,
|
||||
index_t ldd,
|
||||
AccType alpha,
|
||||
AccType beta)
|
||||
{
|
||||
// D = alpha * A * B + beta * C
|
||||
for(index_t n = 0; n < N; ++n) {
|
||||
for(index_t m = 0; m < M; ++m) {
|
||||
for(index_t n = 0; n < N; ++n)
|
||||
{
|
||||
for(index_t m = 0; m < M; ++m)
|
||||
{
|
||||
AccType sum = 0;
|
||||
|
||||
// Compute A * B
|
||||
for(index_t k = 0; k < K; ++k) {
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
// A is column-major: A[m,k] = a[m + k*lda]
|
||||
// B is row-major: B[k,n] = b[k*ldb + n]
|
||||
sum += static_cast<AccType>(a[m + k * lda]) *
|
||||
static_cast<AccType>(b[k * ldb + n]);
|
||||
sum += static_cast<AccType>(a[m + k * lda]) * static_cast<AccType>(b[k * ldb + n]);
|
||||
}
|
||||
|
||||
// D[m,n] = alpha * sum + beta * C[m,n]
|
||||
@@ -346,30 +341,38 @@ void reference_gemm_mixed(const std::vector<InType>& a, // Column-major
|
||||
}
|
||||
|
||||
// Helper to fill matrix with random values
|
||||
template<typename T>
|
||||
template <typename T>
|
||||
void fill_random(std::vector<T>& data, T min_val = -1, T max_val = 1)
|
||||
{
|
||||
for(auto& val : data) {
|
||||
val = static_cast<T>(min_val + (max_val - min_val) *
|
||||
static_cast<float>(rand()) / RAND_MAX);
|
||||
for(auto& val : data)
|
||||
{
|
||||
val = static_cast<T>(min_val + (max_val - min_val) * static_cast<float>(rand()) / RAND_MAX);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to print matrix (for debugging)
|
||||
template<typename T>
|
||||
void print_matrix(const std::vector<T>& mat, index_t rows, index_t cols,
|
||||
index_t ld, bool col_major = true, const std::string& name = "Matrix")
|
||||
template <typename T>
|
||||
void print_matrix(const std::vector<T>& mat,
|
||||
index_t rows,
|
||||
index_t cols,
|
||||
index_t ld,
|
||||
bool col_major = true,
|
||||
const std::string& name = "Matrix")
|
||||
{
|
||||
std::cout << name << " (" << rows << "×" << cols << "):\n";
|
||||
for(index_t i = 0; i < std::min(rows, index_t(8)); ++i) {
|
||||
for(index_t j = 0; j < std::min(cols, index_t(8)); ++j) {
|
||||
for(index_t i = 0; i < std::min(rows, index_t(8)); ++i)
|
||||
{
|
||||
for(index_t j = 0; j < std::min(cols, index_t(8)); ++j)
|
||||
{
|
||||
index_t idx = col_major ? (i + j * ld) : (i * ld + j);
|
||||
std::cout << std::setw(8) << std::setprecision(3) << mat[idx] << " ";
|
||||
}
|
||||
if(cols > 8) std::cout << "...";
|
||||
if(cols > 8)
|
||||
std::cout << "...";
|
||||
std::cout << "\n";
|
||||
}
|
||||
if(rows > 8) std::cout << "...\n";
|
||||
if(rows > 8)
|
||||
std::cout << "...\n";
|
||||
std::cout << "\n";
|
||||
}
|
||||
|
||||
@@ -392,16 +395,16 @@ int main()
|
||||
constexpr index_t K = 32;
|
||||
|
||||
// Leading dimensions
|
||||
constexpr index_t lda = M; // Column-major
|
||||
constexpr index_t ldb = N; // Row-major
|
||||
constexpr index_t ldc = M; // Column-major
|
||||
constexpr index_t ldd = M; // Column-major
|
||||
constexpr index_t lda = M; // Column-major
|
||||
constexpr index_t ldb = N; // Row-major
|
||||
constexpr index_t ldc = M; // Column-major
|
||||
constexpr index_t ldd = M; // Column-major
|
||||
|
||||
using InputType = half_t; // fp16
|
||||
using AccumType = float; // fp32
|
||||
using InputType = half_t; // fp16
|
||||
using AccumType = float; // fp32
|
||||
|
||||
constexpr AccumType alpha = 2.0f;
|
||||
constexpr AccumType beta = 1.5f;
|
||||
constexpr AccumType beta = 1.5f;
|
||||
|
||||
std::cout << "Problem configuration:\n";
|
||||
std::cout << " M×N×K: " << M << "×" << N << "×" << K << "\n";
|
||||
@@ -409,7 +412,7 @@ int main()
|
||||
std::cout << " B: row-major, ldb=" << ldb << " (fp16)\n";
|
||||
std::cout << " C/D: column-major, ldc=" << ldc << ", ldd=" << ldd << " (fp32)\n";
|
||||
std::cout << " alpha=" << alpha << ", beta=" << beta << "\n";
|
||||
std::cout << " Total FLOPs: " << 2*M*N*K << "\n\n";
|
||||
std::cout << " Total FLOPs: " << 2 * M * N * K << "\n\n";
|
||||
|
||||
// Host memory
|
||||
std::vector<InputType> h_a(M * K);
|
||||
@@ -427,7 +430,7 @@ int main()
|
||||
// CPU reference
|
||||
auto cpu_start = std::chrono::high_resolution_clock::now();
|
||||
reference_gemm_mixed(h_a, h_b, h_c, h_d_ref, M, N, K, lda, ldb, ldc, ldd, alpha, beta);
|
||||
auto cpu_end = std::chrono::high_resolution_clock::now();
|
||||
auto cpu_end = std::chrono::high_resolution_clock::now();
|
||||
double cpu_time_ms = std::chrono::duration<double, std::milli>(cpu_end - cpu_start).count();
|
||||
|
||||
// Device memory
|
||||
@@ -442,34 +445,44 @@ int main()
|
||||
d_d.ToDevice(h_d.data(), M * N * sizeof(AccumType));
|
||||
|
||||
// Launch kernel
|
||||
constexpr index_t block_size = 256; // 4 warps (2×2 configuration)
|
||||
constexpr index_t tiles_per_block_m = 2; // MWarp (no iterations)
|
||||
constexpr index_t tiles_per_block_n = 2; // NWarp (no iterations)
|
||||
constexpr index_t block_size = 256; // 4 warps (2×2 configuration)
|
||||
constexpr index_t tiles_per_block_m = 2; // MWarp (no iterations)
|
||||
constexpr index_t tiles_per_block_n = 2; // NWarp (no iterations)
|
||||
const index_t grid_size = (M / (tiles_per_block_m * 16)) * (N / (tiles_per_block_n * 16));
|
||||
|
||||
std::cout << "Launching kernel:\n";
|
||||
std::cout << " Grid: " << grid_size << " blocks\n";
|
||||
std::cout << " Block: " << block_size << " threads (4 warps in 2×2 config)\n";
|
||||
std::cout << " Each warp computes: ONE 16×16 output tile\n";
|
||||
std::cout << " Each block computes: " << tiles_per_block_m*16 << "×" << tiles_per_block_n*16 << " output\n";
|
||||
std::cout << " Total output tiles: " << (M/16) << "×" << (N/16) << "\n";
|
||||
std::cout << " MFMA instructions per warp: " << (K/16) << "\n\n";
|
||||
std::cout << " Each block computes: " << tiles_per_block_m * 16 << "×"
|
||||
<< tiles_per_block_n * 16 << " output\n";
|
||||
std::cout << " Total output tiles: " << (M / 16) << "×" << (N / 16) << "\n";
|
||||
std::cout << " MFMA instructions per warp: " << (K / 16) << "\n\n";
|
||||
|
||||
stream_config stream;
|
||||
|
||||
// Warmup
|
||||
for(int i = 0; i < 5; ++i) {
|
||||
for(int i = 0; i < 5; ++i)
|
||||
{
|
||||
launch_kernel(stream,
|
||||
make_kernel<block_size>(
|
||||
TileSweepingHgemmKernel<InputType, InputType, AccumType, AccumType>{},
|
||||
dim3(grid_size),
|
||||
dim3(block_size),
|
||||
0,
|
||||
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
|
||||
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
|
||||
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
|
||||
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
|
||||
M, N, K, lda, ldb, ldc, ldd, alpha, beta));
|
||||
make_kernel<block_size>(
|
||||
TileSweepingHgemmKernel<InputType, InputType, AccumType, AccumType>{},
|
||||
dim3(grid_size),
|
||||
dim3(block_size),
|
||||
0,
|
||||
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
|
||||
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
|
||||
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
|
||||
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
ldd,
|
||||
alpha,
|
||||
beta));
|
||||
}
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
@@ -477,40 +490,50 @@ int main()
|
||||
auto gpu_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
launch_kernel(stream,
|
||||
make_kernel<block_size>(
|
||||
TileSweepingHgemmKernel<InputType, InputType, AccumType, AccumType>{},
|
||||
dim3(grid_size),
|
||||
dim3(block_size),
|
||||
0,
|
||||
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
|
||||
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
|
||||
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
|
||||
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
|
||||
M, N, K, lda, ldb, ldc, ldd, alpha, beta));
|
||||
make_kernel<block_size>(
|
||||
TileSweepingHgemmKernel<InputType, InputType, AccumType, AccumType>{},
|
||||
dim3(grid_size),
|
||||
dim3(block_size),
|
||||
0,
|
||||
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
|
||||
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
|
||||
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
|
||||
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
ldd,
|
||||
alpha,
|
||||
beta));
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
auto gpu_end = std::chrono::high_resolution_clock::now();
|
||||
auto gpu_end = std::chrono::high_resolution_clock::now();
|
||||
double gpu_time_ms = std::chrono::duration<double, std::milli>(gpu_end - gpu_start).count();
|
||||
|
||||
// Get result
|
||||
d_d.FromDevice(h_d.data(), M * N * sizeof(AccumType));
|
||||
|
||||
// Verify
|
||||
bool passed = true;
|
||||
float max_error = 0;
|
||||
bool passed = true;
|
||||
float max_error = 0;
|
||||
index_t error_count = 0;
|
||||
|
||||
for(index_t i = 0; i < M * N; ++i) {
|
||||
for(index_t i = 0; i < M * N; ++i)
|
||||
{
|
||||
float error = std::abs(h_d[i] - h_d_ref[i]);
|
||||
max_error = std::max(max_error, error);
|
||||
if(error > 1e-2f) { // Relaxed tolerance for fp16
|
||||
if(error_count < 5) {
|
||||
max_error = std::max(max_error, error);
|
||||
if(error > 1e-2f)
|
||||
{ // Relaxed tolerance for fp16
|
||||
if(error_count < 5)
|
||||
{
|
||||
index_t m = i % M;
|
||||
index_t n = i / M;
|
||||
std::cout << "Error at [" << m << "," << n << "]: "
|
||||
<< h_d[i] << " vs " << h_d_ref[i]
|
||||
<< " (diff=" << error << ")\n";
|
||||
std::cout << "Error at [" << m << "," << n << "]: " << h_d[i] << " vs "
|
||||
<< h_d_ref[i] << " (diff=" << error << ")\n";
|
||||
}
|
||||
error_count++;
|
||||
}
|
||||
@@ -519,14 +542,15 @@ int main()
|
||||
passed = (error_count == 0);
|
||||
|
||||
// Calculate performance
|
||||
double gflops = 2.0 * M * N * K / 1e9;
|
||||
double gflops = 2.0 * M * N * K / 1e9;
|
||||
double gpu_tflops = gflops / (gpu_time_ms / 1000);
|
||||
double cpu_gflops = gflops / (cpu_time_ms / 1000);
|
||||
|
||||
std::cout << "Results:\n";
|
||||
std::cout << " Correctness: " << (passed ? "✓ PASSED" : "✗ FAILED") << "\n";
|
||||
std::cout << " Max error: " << max_error << "\n";
|
||||
if(!passed) std::cout << " Error count: " << error_count << "/" << M*N << "\n";
|
||||
if(!passed)
|
||||
std::cout << " Error count: " << error_count << "/" << M * N << "\n";
|
||||
std::cout << "\n";
|
||||
|
||||
std::cout << "Performance:\n";
|
||||
|
||||
@@ -16,25 +16,23 @@
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
template<typename DataType>
|
||||
template <typename DataType>
|
||||
struct TestADistributionYRepetitionKernel
|
||||
{
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t MWarp = 2;
|
||||
static constexpr index_t NWarp = 2;
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t MWarp = 2;
|
||||
static constexpr index_t NWarp = 2;
|
||||
static constexpr index_t MIterPerWarp = 2;
|
||||
static constexpr index_t KIterPerWarp = 1;
|
||||
static constexpr index_t kM = 64; // 2 warps × 2 iters × 16
|
||||
static constexpr index_t kK = 64;
|
||||
static constexpr index_t kM = 64; // 2 warps × 2 iters × 16
|
||||
static constexpr index_t kK = 64;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const DataType* a,
|
||||
DataType* debug_output,
|
||||
index_t lda) const
|
||||
CK_TILE_DEVICE void operator()(const DataType* a, DataType* debug_output, index_t lda) const
|
||||
{
|
||||
if(get_block_id() != 0)
|
||||
return;
|
||||
|
||||
const index_t tid = threadIdx.x;
|
||||
const index_t tid = threadIdx.x;
|
||||
const index_t warp_id = tid / 64;
|
||||
const index_t lane_id = tid % 64;
|
||||
|
||||
@@ -52,56 +50,56 @@ struct TestADistributionYRepetitionKernel
|
||||
// sequence<>, // Ys_to_Hs: Y maps to both M and K
|
||||
// sequence<>>{}; // Ys_in_Hs
|
||||
|
||||
|
||||
// A distribution with Y-repetition (from tutorial_07)
|
||||
constexpr auto a_warp_dstr_encode = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<16>, sequence<4, 4>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
constexpr auto a_warp_dstr_encode =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<16>, sequence<4, 4>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
|
||||
constexpr auto a_block_outer_dstr_encode = tile_distribution_encoding<
|
||||
sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto a_block_outer_dstr_encode =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encode, a_warp_dstr_encode);
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encode, a_warp_dstr_encode);
|
||||
|
||||
constexpr auto a_distribution = make_static_tile_distribution(a_block_dstr_encode);
|
||||
|
||||
auto a_window = make_tile_window(
|
||||
a_tensor, make_tuple(number<kM>{}, number<kK>{}),
|
||||
{0, 0}, a_distribution);
|
||||
a_tensor, make_tuple(number<kM>{}, number<kK>{}), {0, 0}, a_distribution);
|
||||
|
||||
const auto a_tile = load_tile(a_window);
|
||||
const auto a_tile = load_tile(a_window);
|
||||
const auto& thread_buffer = a_tile.get_thread_buffer();
|
||||
|
||||
// Print from all warps sequentially
|
||||
__syncthreads();
|
||||
|
||||
if(tid == 0) {
|
||||
|
||||
if(tid == 0)
|
||||
{
|
||||
printf("\n=== A Distribution with Y-Repetition Test ===\n");
|
||||
printf("Matrix: 64×16 (MWarp=2, MIterPerWarp=2, each warp loads 2×16 tiles)\n");
|
||||
printf("Input: A[m,k] = m + k*100 (unique values)\n\n");
|
||||
}
|
||||
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for(int w = 0; w < 4; ++w) {
|
||||
|
||||
for(int w = 0; w < 4; ++w)
|
||||
{
|
||||
__syncthreads();
|
||||
if(warp_id == w && lane_id == 0) {
|
||||
printf("Warp %d (M-warp %d, N-warp %d):\n",
|
||||
w, w/NWarp, w%NWarp);
|
||||
if(warp_id == w && lane_id == 0)
|
||||
{
|
||||
printf("Warp %d (M-warp %d, N-warp %d):\n", w, w / NWarp, w % NWarp);
|
||||
printf(" Thread buffer size: %d\n", static_cast<int>(thread_buffer.size()));
|
||||
printf(" Values: ");
|
||||
for(int i = 0; i < thread_buffer.size(); ++i) {
|
||||
for(int i = 0; i < thread_buffer.size(); ++i)
|
||||
{
|
||||
printf("%.0f ", static_cast<float>(thread_buffer[i]));
|
||||
}
|
||||
printf("\n");
|
||||
@@ -110,7 +108,8 @@ struct TestADistributionYRepetitionKernel
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if(tid == 0) {
|
||||
if(tid == 0)
|
||||
{
|
||||
printf("\n=== Expected Pattern ===\n");
|
||||
printf("Each warp should load 8 elements (2 M-iters × 1 K-iter × 4 warp elements)\n");
|
||||
printf("Warp 0 (M-warp 0): Should have M-rows [0-15] and [16-31]\n");
|
||||
@@ -120,7 +119,8 @@ struct TestADistributionYRepetitionKernel
|
||||
}
|
||||
|
||||
// Store for verification
|
||||
for(int i = 0; i < thread_buffer.size(); ++i) {
|
||||
for(int i = 0; i < thread_buffer.size(); ++i)
|
||||
{
|
||||
debug_output[tid * 8 + i] = thread_buffer[i];
|
||||
}
|
||||
}
|
||||
@@ -132,8 +132,8 @@ int main()
|
||||
std::cout << "Test A Distribution with Y-Dimension Repetition\n";
|
||||
std::cout << "==================================================\n\n";
|
||||
|
||||
constexpr index_t M = 64;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t M = 64;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t lda = M;
|
||||
|
||||
using DataType = half_t;
|
||||
@@ -142,8 +142,10 @@ int main()
|
||||
std::vector<DataType> h_debug(256 * 8, -1);
|
||||
|
||||
// Initialize A[m,k] = m + k*100 (unique for each position)
|
||||
for(index_t k = 0; k < K; ++k) {
|
||||
for(index_t m = 0; m < M; ++m) {
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
for(index_t m = 0; m < M; ++m)
|
||||
{
|
||||
h_a[m + k * lda] = static_cast<DataType>(m + k * 100);
|
||||
}
|
||||
}
|
||||
@@ -156,12 +158,13 @@ int main()
|
||||
|
||||
stream_config stream;
|
||||
launch_kernel(stream,
|
||||
make_kernel<256>(
|
||||
TestADistributionYRepetitionKernel<DataType>{},
|
||||
dim3(1), dim3(256), 0,
|
||||
static_cast<const DataType*>(d_a.GetDeviceBuffer()),
|
||||
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
|
||||
lda));
|
||||
make_kernel<256>(TestADistributionYRepetitionKernel<DataType>{},
|
||||
dim3(1),
|
||||
dim3(256),
|
||||
0,
|
||||
static_cast<const DataType*>(d_a.GetDeviceBuffer()),
|
||||
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
|
||||
lda));
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
|
||||
@@ -18,24 +18,23 @@
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
template<typename DataType>
|
||||
template <typename DataType>
|
||||
struct TestAYSlicingKernel
|
||||
{
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t MWarp = 2;
|
||||
static constexpr index_t NWarp = 2;
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t MWarp = 2;
|
||||
static constexpr index_t NWarp = 2;
|
||||
static constexpr index_t MIterPerWarp = 2;
|
||||
static constexpr index_t KIterPerWarp = 1;
|
||||
static constexpr index_t kM = 64; // 2 warps × 2 iters × 16
|
||||
static constexpr index_t kK = 16; // Fixed to match distribution coverage
|
||||
static constexpr index_t kM = 64; // 2 warps × 2 iters × 16
|
||||
static constexpr index_t kK = 16; // Fixed to match distribution coverage
|
||||
|
||||
CK_TILE_DEVICE void operator()(const DataType* a,
|
||||
index_t lda) const
|
||||
CK_TILE_DEVICE void operator()(const DataType* a, index_t lda) const
|
||||
{
|
||||
if(get_block_id() != 0)
|
||||
return;
|
||||
|
||||
const index_t tid = threadIdx.x;
|
||||
const index_t tid = threadIdx.x;
|
||||
const index_t warp_id = tid / 64;
|
||||
const index_t lane_id = tid % 64;
|
||||
|
||||
@@ -44,51 +43,50 @@ struct TestAYSlicingKernel
|
||||
a, make_tuple(kM, kK), make_tuple(1, lda), number<1>{}, number<1>{});
|
||||
|
||||
// A warp-level distribution
|
||||
constexpr auto a_warp_dstr_encode = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<16>, sequence<4, 4>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
constexpr auto a_warp_dstr_encode =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<16>, sequence<4, 4>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
|
||||
// A block-level outer distribution with Y-repetition
|
||||
constexpr auto a_block_outer_dstr_encode = tile_distribution_encoding<
|
||||
sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>,
|
||||
sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto a_block_outer_dstr_encode =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encode, a_warp_dstr_encode);
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encode, a_warp_dstr_encode);
|
||||
|
||||
constexpr auto a_block_distribution = make_static_tile_distribution(a_block_dstr_encode);
|
||||
|
||||
// Get Y-dimension information
|
||||
using AWarpDstr = decltype(make_static_tile_distribution(a_warp_dstr_encode));
|
||||
constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
|
||||
// Create window and load full block tile
|
||||
auto a_window = make_tile_window(
|
||||
a_tensor, make_tuple(number<kM>{}, number<kK>{}),
|
||||
{0, 0}, a_block_distribution);
|
||||
a_tensor, make_tuple(number<kM>{}, number<kK>{}), {0, 0}, a_block_distribution);
|
||||
|
||||
const auto a_block_tile = load_tile(a_window);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if(tid == 0) {
|
||||
|
||||
if(tid == 0)
|
||||
{
|
||||
printf("\n=== A Y-Slicing Test ===\n");
|
||||
printf("Block tile: %d×%d (M×K)\n", kM, kK);
|
||||
printf("MIterPerWarp=%d, KIterPerWarp=%d\n", MIterPerWarp, KIterPerWarp);
|
||||
printf("Input: A[m,k] = m*1000 + k (unique values)\n\n");
|
||||
}
|
||||
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Test Y-slicing for each warp and iteration
|
||||
@@ -97,7 +95,7 @@ struct TestAYSlicingKernel
|
||||
// Extract warp tensor for this iteration
|
||||
auto a_warp_tensor = make_static_distributed_tensor<DataType>(
|
||||
make_static_tile_distribution(a_warp_dstr_encode));
|
||||
|
||||
|
||||
// CORRECTED: kIter first, then mIter (matching the Ys_to_Hs order)
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
@@ -106,17 +104,25 @@ struct TestAYSlicingKernel
|
||||
const auto& warp_buffer = a_warp_tensor.get_thread_buffer();
|
||||
|
||||
// Print from each warp sequentially
|
||||
for(int w = 0; w < 4; ++w) {
|
||||
for(int w = 0; w < 4; ++w)
|
||||
{
|
||||
__syncthreads();
|
||||
if(warp_id == w && lane_id == 0) {
|
||||
printf("Warp %d (M-warp %d, N-warp %d), MIter=%d, KIter=%d:\n",
|
||||
w, w/NWarp, w%NWarp, static_cast<int>(mIter), static_cast<int>(kIter));
|
||||
if(warp_id == w && lane_id == 0)
|
||||
{
|
||||
printf("Warp %d (M-warp %d, N-warp %d), MIter=%d, KIter=%d:\n",
|
||||
w,
|
||||
w / NWarp,
|
||||
w % NWarp,
|
||||
static_cast<int>(mIter),
|
||||
static_cast<int>(kIter));
|
||||
printf(" Buffer size: %d\n", static_cast<int>(warp_buffer.size()));
|
||||
printf(" Values: ");
|
||||
for(int i = 0; i < warp_buffer.size() && i < 16; ++i) {
|
||||
for(int i = 0; i < warp_buffer.size() && i < 16; ++i)
|
||||
{
|
||||
printf("%.0f ", static_cast<float>(warp_buffer[i]));
|
||||
}
|
||||
if(warp_buffer.size() > 16) printf("...");
|
||||
if(warp_buffer.size() > 16)
|
||||
printf("...");
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
@@ -125,7 +131,8 @@ struct TestAYSlicingKernel
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if(tid == 0) {
|
||||
if(tid == 0)
|
||||
{
|
||||
printf("\n=== Expected Pattern ===\n");
|
||||
printf("Each warp should get 4 elements per iteration (16 M × 16 K / 64 threads)\n");
|
||||
printf("Warp 0, MIter=0: Should have values from A[0:16, 0:16]\n");
|
||||
@@ -144,8 +151,8 @@ int main()
|
||||
std::cout << "Test A Y-Slicing with get_y_sliced_thread_data\n";
|
||||
std::cout << "==================================================\n\n";
|
||||
|
||||
constexpr index_t M = 64; // 2 warps × 2 iters × 16
|
||||
constexpr index_t K = 16; // Match distribution coverage
|
||||
constexpr index_t M = 64; // 2 warps × 2 iters × 16
|
||||
constexpr index_t K = 16; // Match distribution coverage
|
||||
constexpr index_t lda = M;
|
||||
|
||||
using DataType = half_t;
|
||||
@@ -154,8 +161,10 @@ int main()
|
||||
|
||||
// Initialize A[m,k] = m*1000 + k (easy to identify position)
|
||||
auto counter = 0;
|
||||
for(index_t m = 0; m < M; ++m) {
|
||||
for(index_t k = 0; k < K; ++k) {
|
||||
for(index_t m = 0; m < M; ++m)
|
||||
{
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
h_a[m + k * lda] = static_cast<DataType>(counter++);
|
||||
}
|
||||
}
|
||||
@@ -166,18 +175,19 @@ int main()
|
||||
std::cout << " A[16,0] = " << static_cast<float>(h_a[16]) << "\n";
|
||||
std::cout << " A[32,0] = " << static_cast<float>(h_a[32]) << "\n";
|
||||
std::cout << " A[48,0] = " << static_cast<float>(h_a[48]) << "\n";
|
||||
std::cout << " A[0,15] = " << static_cast<float>(h_a[15*lda]) << "\n\n";
|
||||
std::cout << " A[0,15] = " << static_cast<float>(h_a[15 * lda]) << "\n\n";
|
||||
|
||||
DeviceMem d_a(M * K * sizeof(DataType));
|
||||
d_a.ToDevice(h_a.data(), M * K * sizeof(DataType));
|
||||
|
||||
stream_config stream;
|
||||
launch_kernel(stream,
|
||||
make_kernel<256>(
|
||||
TestAYSlicingKernel<DataType>{},
|
||||
dim3(1), dim3(256), 0,
|
||||
static_cast<const DataType*>(d_a.GetDeviceBuffer()),
|
||||
lda));
|
||||
make_kernel<256>(TestAYSlicingKernel<DataType>{},
|
||||
dim3(1),
|
||||
dim3(256),
|
||||
0,
|
||||
static_cast<const DataType*>(d_a.GetDeviceBuffer()),
|
||||
lda));
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
|
||||
@@ -16,26 +16,25 @@
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
template<typename DataType>
|
||||
template <typename DataType>
|
||||
struct TestBDistributionYRepetitionKernel
|
||||
{
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t MWarp = 2;
|
||||
static constexpr index_t NWarp = 2;
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t MWarp = 2;
|
||||
static constexpr index_t NWarp = 2;
|
||||
static constexpr index_t NIterPerWarp = 2;
|
||||
static constexpr index_t KIterPerWarp = 1;
|
||||
static constexpr index_t kK = 64;
|
||||
static constexpr index_t kN = 64; // 2 warps × 2 iters × 16
|
||||
static constexpr index_t kK = 64;
|
||||
static constexpr index_t kN = 64; // 2 warps × 2 iters × 16
|
||||
|
||||
CK_TILE_DEVICE void operator()(const DataType* b,
|
||||
DataType* debug_output,
|
||||
index_t ldb) const
|
||||
CK_TILE_DEVICE void operator()(const DataType* b, DataType* debug_output, index_t ldb) const
|
||||
{
|
||||
if(get_block_id() != 0)
|
||||
return;
|
||||
|
||||
//each warp is 64 x 4 items and 4 warps total and 2 iteration, so totally it becomes 64 x 32 we don't cover the whole matrix
|
||||
const index_t tid = threadIdx.x;
|
||||
// each warp is 64 x 4 items and 4 warps total and 2 iteration, so totally it becomes 64 x
|
||||
// 32 we don't cover the whole matrix
|
||||
const index_t tid = threadIdx.x;
|
||||
const index_t warp_id = tid / 64;
|
||||
const index_t lane_id = tid % 64;
|
||||
|
||||
@@ -44,55 +43,55 @@ struct TestBDistributionYRepetitionKernel
|
||||
b, make_tuple(kK, kN), make_tuple(ldb, 1), number<4>{}, number<1>{});
|
||||
|
||||
// B distribution with Y-repetition (from tutorial_07)
|
||||
constexpr auto b_warp_dstr_encode = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<4, 4>, sequence<16>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<1>,
|
||||
sequence<1>>{};
|
||||
constexpr auto b_warp_dstr_encode =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<4, 4>, sequence<16>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<1>,
|
||||
sequence<1>>{};
|
||||
|
||||
constexpr auto b_block_outer_dstr_encode = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<KIterPerWarp>,
|
||||
sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<2, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_outer_dstr_encode =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<KIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<2, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto b_block_dstr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encode, b_warp_dstr_encode);
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encode, b_warp_dstr_encode);
|
||||
|
||||
constexpr auto b_distribution = make_static_tile_distribution(b_block_dstr_encode);
|
||||
|
||||
auto b_window = make_tile_window(
|
||||
b_tensor, make_tuple(number<kK>{}, number<kN>{}),
|
||||
{0, 0}, b_distribution);
|
||||
b_tensor, make_tuple(number<kK>{}, number<kN>{}), {0, 0}, b_distribution);
|
||||
|
||||
const auto b_tile = load_tile(b_window);
|
||||
const auto b_tile = load_tile(b_window);
|
||||
const auto& thread_buffer = b_tile.get_thread_buffer();
|
||||
|
||||
// Print from all warps sequentially
|
||||
__syncthreads();
|
||||
|
||||
if(tid == 0) {
|
||||
|
||||
if(tid == 0)
|
||||
{
|
||||
printf("\n=== B Distribution with Y-Repetition Test ===\n");
|
||||
printf("Matrix: 16×64 (NWarp=2, NIterPerWarp=2, each warp loads 2×16 tiles)\n");
|
||||
printf("Input: B[k,n] = k + n*100 (unique values)\n\n");
|
||||
}
|
||||
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for(int w = 0; w < 4; ++w) {
|
||||
|
||||
for(int w = 0; w < 4; ++w)
|
||||
{
|
||||
__syncthreads();
|
||||
if(warp_id == w && lane_id == 0) {
|
||||
printf("Warp %d (M-warp %d, N-warp %d):\n",
|
||||
w, w/NWarp, w%NWarp);
|
||||
if(warp_id == w && lane_id == 0)
|
||||
{
|
||||
printf("Warp %d (M-warp %d, N-warp %d):\n", w, w / NWarp, w % NWarp);
|
||||
printf(" Thread buffer size: %d\n", static_cast<int>(thread_buffer.size()));
|
||||
printf(" Values: ");
|
||||
for(int i = 0; i < thread_buffer.size(); ++i) {
|
||||
for(int i = 0; i < thread_buffer.size(); ++i)
|
||||
{
|
||||
printf("%.0f ", static_cast<float>(thread_buffer[i]));
|
||||
}
|
||||
printf("\n");
|
||||
@@ -101,7 +100,8 @@ struct TestBDistributionYRepetitionKernel
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if(tid == 0) {
|
||||
if(tid == 0)
|
||||
{
|
||||
printf("\n=== Expected Pattern ===\n");
|
||||
printf("Each warp should load 8 elements (2 N-iters × 1 K-iter × 4 warp elements)\n");
|
||||
printf("Warp 0 (N-warp 0): Should have N-cols [0-15] and [16-31]\n");
|
||||
@@ -111,7 +111,8 @@ struct TestBDistributionYRepetitionKernel
|
||||
}
|
||||
|
||||
// Store for verification
|
||||
for(int i = 0; i < thread_buffer.size(); ++i) {
|
||||
for(int i = 0; i < thread_buffer.size(); ++i)
|
||||
{
|
||||
debug_output[tid * 8 + i] = thread_buffer[i];
|
||||
}
|
||||
}
|
||||
@@ -123,8 +124,8 @@ int main()
|
||||
std::cout << "Test B Distribution with Y-Dimension Repetition\n";
|
||||
std::cout << "==================================================\n\n";
|
||||
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t ldb = N;
|
||||
|
||||
using DataType = half_t;
|
||||
@@ -134,8 +135,10 @@ int main()
|
||||
|
||||
// Initialize B[k,n] = k + n*100 (unique for each position)
|
||||
auto counter = 0;
|
||||
for(index_t k = 0; k < K; ++k) {
|
||||
for(index_t n = 0; n < N; ++n) {
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
for(index_t n = 0; n < N; ++n)
|
||||
{
|
||||
h_b[k * ldb + n] = static_cast<DataType>(counter++);
|
||||
}
|
||||
}
|
||||
@@ -148,12 +151,13 @@ int main()
|
||||
|
||||
stream_config stream;
|
||||
launch_kernel(stream,
|
||||
make_kernel<256>(
|
||||
TestBDistributionYRepetitionKernel<DataType>{},
|
||||
dim3(1), dim3(256), 0,
|
||||
static_cast<const DataType*>(d_b.GetDeviceBuffer()),
|
||||
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
|
||||
ldb));
|
||||
make_kernel<256>(TestBDistributionYRepetitionKernel<DataType>{},
|
||||
dim3(1),
|
||||
dim3(256),
|
||||
0,
|
||||
static_cast<const DataType*>(d_b.GetDeviceBuffer()),
|
||||
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
|
||||
ldb));
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
|
||||
@@ -18,24 +18,23 @@
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
template<typename DataType>
|
||||
template <typename DataType>
|
||||
struct TestBYSlicingKernel
|
||||
{
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t MWarp = 2;
|
||||
static constexpr index_t NWarp = 2;
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t MWarp = 2;
|
||||
static constexpr index_t NWarp = 2;
|
||||
static constexpr index_t NIterPerWarp = 2;
|
||||
static constexpr index_t KIterPerWarp = 1;
|
||||
static constexpr index_t kK = 16; // Fixed to match distribution coverage
|
||||
static constexpr index_t kN = 64; // 2 warps × 2 iters × 16
|
||||
static constexpr index_t kK = 16; // Fixed to match distribution coverage
|
||||
static constexpr index_t kN = 64; // 2 warps × 2 iters × 16
|
||||
|
||||
CK_TILE_DEVICE void operator()(const DataType* b,
|
||||
index_t ldb) const
|
||||
CK_TILE_DEVICE void operator()(const DataType* b, index_t ldb) const
|
||||
{
|
||||
if(get_block_id() != 0)
|
||||
return;
|
||||
|
||||
const index_t tid = threadIdx.x;
|
||||
const index_t tid = threadIdx.x;
|
||||
const index_t warp_id = tid / 64;
|
||||
const index_t lane_id = tid % 64;
|
||||
|
||||
@@ -44,51 +43,50 @@ struct TestBYSlicingKernel
|
||||
b, make_tuple(kK, kN), make_tuple(ldb, 1), number<4>{}, number<1>{});
|
||||
|
||||
// B warp-level distribution
|
||||
constexpr auto b_warp_dstr_encode = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<4, 4>, sequence<16>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<1>,
|
||||
sequence<1>>{};
|
||||
constexpr auto b_warp_dstr_encode =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<4, 4>, sequence<16>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<1>,
|
||||
sequence<1>>{};
|
||||
|
||||
// B block-level outer distribution with Y-repetition
|
||||
constexpr auto b_block_outer_dstr_encode = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<KIterPerWarp>,
|
||||
sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<2, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_outer_dstr_encode =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<KIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<2, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto b_block_dstr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encode, b_warp_dstr_encode);
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encode, b_warp_dstr_encode);
|
||||
|
||||
constexpr auto b_block_distribution = make_static_tile_distribution(b_block_dstr_encode);
|
||||
|
||||
// Get Y-dimension information
|
||||
using BWarpDstr = decltype(make_static_tile_distribution(b_warp_dstr_encode));
|
||||
constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
|
||||
// Create window and load full block tile
|
||||
auto b_window = make_tile_window(
|
||||
b_tensor, make_tuple(number<kK>{}, number<kN>{}),
|
||||
{0, 0}, b_block_distribution);
|
||||
b_tensor, make_tuple(number<kK>{}, number<kN>{}), {0, 0}, b_block_distribution);
|
||||
|
||||
const auto b_block_tile = load_tile(b_window);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if(tid == 0) {
|
||||
|
||||
if(tid == 0)
|
||||
{
|
||||
printf("\n=== B Y-Slicing Test ===\n");
|
||||
printf("Block tile: %d×%d (K×N)\n", kK, kN);
|
||||
printf("NIterPerWarp=%d, KIterPerWarp=%d\n", NIterPerWarp, KIterPerWarp);
|
||||
printf("Input: B[k,n] = k*1000 + n (unique values)\n\n");
|
||||
}
|
||||
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Test Y-slicing for each warp and iteration
|
||||
@@ -97,7 +95,7 @@ struct TestBYSlicingKernel
|
||||
// Extract warp tensor for this iteration
|
||||
auto b_warp_tensor = make_static_distributed_tensor<DataType>(
|
||||
make_static_tile_distribution(b_warp_dstr_encode));
|
||||
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter, nIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
@@ -105,17 +103,25 @@ struct TestBYSlicingKernel
|
||||
const auto& warp_buffer = b_warp_tensor.get_thread_buffer();
|
||||
|
||||
// Print from each warp sequentially
|
||||
for(int w = 0; w < 4; ++w) {
|
||||
for(int w = 0; w < 4; ++w)
|
||||
{
|
||||
__syncthreads();
|
||||
if(warp_id == w && lane_id == 0) {
|
||||
printf("Warp %d (M-warp %d, N-warp %d), NIter=%d, KIter=%d:\n",
|
||||
w, w/NWarp, w%NWarp, static_cast<int>(nIter), static_cast<int>(kIter));
|
||||
if(warp_id == w && lane_id == 0)
|
||||
{
|
||||
printf("Warp %d (M-warp %d, N-warp %d), NIter=%d, KIter=%d:\n",
|
||||
w,
|
||||
w / NWarp,
|
||||
w % NWarp,
|
||||
static_cast<int>(nIter),
|
||||
static_cast<int>(kIter));
|
||||
printf(" Buffer size: %d\n", static_cast<int>(warp_buffer.size()));
|
||||
printf(" Values: ");
|
||||
for(int i = 0; i < warp_buffer.size() && i < 16; ++i) {
|
||||
for(int i = 0; i < warp_buffer.size() && i < 16; ++i)
|
||||
{
|
||||
printf("%.0f ", static_cast<float>(warp_buffer[i]));
|
||||
}
|
||||
if(warp_buffer.size() > 16) printf("...");
|
||||
if(warp_buffer.size() > 16)
|
||||
printf("...");
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
@@ -124,7 +130,8 @@ struct TestBYSlicingKernel
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if(tid == 0) {
|
||||
if(tid == 0)
|
||||
{
|
||||
printf("\n=== Expected Pattern ===\n");
|
||||
printf("Each warp should get 4 elements per iteration (16 K × 16 N / 64 threads)\n");
|
||||
printf("Warp 0, NIter=0: Should have values from B[0:16, 0:16]\n");
|
||||
@@ -142,8 +149,8 @@ int main()
|
||||
std::cout << "Test B Y-Slicing with get_y_sliced_thread_data\n";
|
||||
std::cout << "==================================================\n\n";
|
||||
|
||||
constexpr index_t K = 16; // Match distribution coverage
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t K = 16; // Match distribution coverage
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t ldb = N;
|
||||
|
||||
using DataType = half_t;
|
||||
@@ -152,8 +159,10 @@ int main()
|
||||
|
||||
// Initialize B[k,n] = k*1000 + n (easy to identify position)
|
||||
auto counter = 0;
|
||||
for(index_t k = 0; k < K; ++k) {
|
||||
for(index_t n = 0; n < N; ++n) {
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
for(index_t n = 0; n < N; ++n)
|
||||
{
|
||||
h_b[k * ldb + n] = static_cast<DataType>(counter++);
|
||||
}
|
||||
}
|
||||
@@ -164,18 +173,19 @@ int main()
|
||||
std::cout << " B[0,16] = " << static_cast<float>(h_b[16]) << "\n";
|
||||
std::cout << " B[0,32] = " << static_cast<float>(h_b[32]) << "\n";
|
||||
std::cout << " B[0,48] = " << static_cast<float>(h_b[48]) << "\n";
|
||||
std::cout << " B[15,0] = " << static_cast<float>(h_b[15*ldb]) << "\n\n";
|
||||
std::cout << " B[15,0] = " << static_cast<float>(h_b[15 * ldb]) << "\n\n";
|
||||
|
||||
DeviceMem d_b(K * N * sizeof(DataType));
|
||||
d_b.ToDevice(h_b.data(), K * N * sizeof(DataType));
|
||||
|
||||
stream_config stream;
|
||||
launch_kernel(stream,
|
||||
make_kernel<256>(
|
||||
TestBYSlicingKernel<DataType>{},
|
||||
dim3(1), dim3(256), 0,
|
||||
static_cast<const DataType*>(d_b.GetDeviceBuffer()),
|
||||
ldb));
|
||||
make_kernel<256>(TestBYSlicingKernel<DataType>{},
|
||||
dim3(1),
|
||||
dim3(256),
|
||||
0,
|
||||
static_cast<const DataType*>(d_b.GetDeviceBuffer()),
|
||||
ldb));
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
|
||||
@@ -30,23 +30,23 @@
|
||||
using namespace ck_tile;
|
||||
|
||||
// Tile Sweeping HGEMM kernel with Y-dimension repetition
|
||||
template<typename ADataType, typename BDataType, typename CDataType, typename AccDataType>
|
||||
template <typename ADataType, typename BDataType, typename CDataType, typename AccDataType>
|
||||
struct TileSweepingYRepetitionHgemmKernel
|
||||
{
|
||||
static constexpr index_t kWaveSize = 64; // AMD wave size
|
||||
static constexpr index_t kWarpM = 16; // MFMA M dimension per warp
|
||||
static constexpr index_t kWarpN = 16; // MFMA N dimension per warp
|
||||
static constexpr index_t kWarpK = 16; // MFMA K dimension per instruction
|
||||
static constexpr index_t kWaveSize = 64; // AMD wave size
|
||||
static constexpr index_t kWarpM = 16; // MFMA M dimension per warp
|
||||
static constexpr index_t kWarpN = 16; // MFMA N dimension per warp
|
||||
static constexpr index_t kWarpK = 16; // MFMA K dimension per instruction
|
||||
|
||||
// Warp configuration: 2×2 warps per block (SAME as Tutorial 06)
|
||||
static constexpr index_t MWarp = 2; // 2 warps in M dimension
|
||||
static constexpr index_t NWarp = 2; // 2 warps in N dimension
|
||||
static constexpr index_t kBlockSize = MWarp * NWarp * kWaveSize; // 256 threads
|
||||
static constexpr index_t MWarp = 2; // 2 warps in M dimension
|
||||
static constexpr index_t NWarp = 2; // 2 warps in N dimension
|
||||
static constexpr index_t kBlockSize = MWarp * NWarp * kWaveSize; // 256 threads
|
||||
|
||||
// NEW: Tile iterations per warp (Y-dimension repetition)
|
||||
static constexpr index_t MIterPerWarp = 2; // Each warp sweeps 2 tiles in M
|
||||
static constexpr index_t NIterPerWarp = 2; // Each warp sweeps 2 tiles in N
|
||||
static constexpr index_t KIterPerWarp = 1; // K handled in main loop
|
||||
static constexpr index_t MIterPerWarp = 2; // Each warp sweeps 2 tiles in M
|
||||
static constexpr index_t NIterPerWarp = 2; // Each warp sweeps 2 tiles in N
|
||||
static constexpr index_t KIterPerWarp = 1; // K handled in main loop
|
||||
|
||||
// Use ck_tile's WarpGemm for MFMA
|
||||
using WarpGemm = WarpGemmMfmaF16F16F32M16N16K16;
|
||||
@@ -58,28 +58,28 @@ struct TileSweepingYRepetitionHgemmKernel
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t lda, // Leading dimension of A (column-major)
|
||||
index_t ldb, // Leading dimension of B (row-major)
|
||||
index_t ldc, // Leading dimension of C (column-major)
|
||||
index_t ldd, // Leading dimension of D (column-major)
|
||||
index_t lda, // Leading dimension of A (column-major)
|
||||
index_t ldb, // Leading dimension of B (row-major)
|
||||
index_t ldc, // Leading dimension of C (column-major)
|
||||
index_t ldd, // Leading dimension of D (column-major)
|
||||
AccDataType alpha,
|
||||
AccDataType beta) const
|
||||
{
|
||||
// Calculate which warp this thread belongs to within the block
|
||||
[[maybe_unused]] const index_t warp_id = get_warp_id();
|
||||
[[maybe_unused]] const index_t iMWarp = warp_id / NWarp; // M-warp index (0 or 1)
|
||||
[[maybe_unused]] const index_t iNWarp = warp_id % NWarp; // N-warp index (0 or 1)
|
||||
[[maybe_unused]] const index_t iMWarp = warp_id / NWarp; // M-warp index (0 or 1)
|
||||
[[maybe_unused]] const index_t iNWarp = warp_id % NWarp; // N-warp index (0 or 1)
|
||||
|
||||
// Calculate base offset for this block
|
||||
// Each block now computes (MWarp × MIterPerWarp × kWarpM) × (NWarp × NIterPerWarp × kWarpN)
|
||||
const index_t kMPerBlock = MWarp * MIterPerWarp * kWarpM; // 2×2×16 = 64
|
||||
const index_t kNPerBlock = NWarp * NIterPerWarp * kWarpN; // 2×2×16 = 64
|
||||
|
||||
const index_t kMPerBlock = MWarp * MIterPerWarp * kWarpM; // 2×2×16 = 64
|
||||
const index_t kNPerBlock = NWarp * NIterPerWarp * kWarpN; // 2×2×16 = 64
|
||||
|
||||
// Calculate block position in 2D grid
|
||||
const index_t num_blocks_n = N / kNPerBlock;
|
||||
const index_t block_m = get_block_id() / num_blocks_n;
|
||||
const index_t block_n = get_block_id() % num_blocks_n;
|
||||
|
||||
const index_t block_m = get_block_id() / num_blocks_n;
|
||||
const index_t block_n = get_block_id() % num_blocks_n;
|
||||
|
||||
const index_t m_block_base = block_m * kMPerBlock;
|
||||
const index_t n_block_base = block_n * kNPerBlock;
|
||||
|
||||
@@ -89,111 +89,89 @@ struct TileSweepingYRepetitionHgemmKernel
|
||||
|
||||
// Create tensor views for matrices
|
||||
const auto a_tensor = make_naive_tensor_view<address_space_enum::global>(
|
||||
a,
|
||||
make_tuple(M, K),
|
||||
make_tuple(1, lda),
|
||||
number<1>{},
|
||||
number<1>{}
|
||||
);
|
||||
a, make_tuple(M, K), make_tuple(1, lda), number<1>{}, number<1>{});
|
||||
|
||||
const auto b_tensor = make_naive_tensor_view<address_space_enum::global>(
|
||||
b,
|
||||
make_tuple(K, N),
|
||||
make_tuple(ldb, 1),
|
||||
number<4>{},
|
||||
number<1>{}
|
||||
);
|
||||
b, make_tuple(K, N), make_tuple(ldb, 1), number<4>{}, number<1>{});
|
||||
|
||||
const auto c_tensor = make_naive_tensor_view<address_space_enum::global>(
|
||||
c,
|
||||
make_tuple(M, N),
|
||||
make_tuple(1, ldc),
|
||||
number<1>{},
|
||||
number<1>{}
|
||||
);
|
||||
c, make_tuple(M, N), make_tuple(1, ldc), number<1>{}, number<1>{});
|
||||
|
||||
auto d_tensor = make_naive_tensor_view<address_space_enum::global>(
|
||||
d,
|
||||
make_tuple(M, N),
|
||||
make_tuple(1, ldd),
|
||||
number<1>{},
|
||||
number<1>{}
|
||||
);
|
||||
d, make_tuple(M, N), make_tuple(1, ldd), number<1>{}, number<1>{});
|
||||
|
||||
// ============================================================================
|
||||
// TILE DISTRIBUTIONS with Y-DIMENSION REPETITION (following 02_gemm pattern)
|
||||
// ============================================================================
|
||||
|
||||
|
||||
// A Distribution: Block-level with Y-repetition
|
||||
// Warp-level distribution (same as Tutorial 06)
|
||||
constexpr auto a_warp_dstr_encode = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<16>, sequence<4, 4>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
constexpr auto a_warp_dstr_encode =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<16>, sequence<4, 4>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
|
||||
// constexpr auto a_block_outer_dstr_encoding =
|
||||
// tile_distribution_encoding<sequence<NWarp>,
|
||||
// tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
// tuple<sequence<1, 0>>,
|
||||
// tuple<sequence<MIterPerWarp, MWarp>,
|
||||
// sequence<KIterPerWarp>>, tuple<sequence<1, 0>>,
|
||||
// tuple<sequence<1, 0>>,
|
||||
// sequence<1, 2>,
|
||||
// sequence<0, 0>>{};
|
||||
|
||||
|
||||
// Block-level outer distribution with Y-repetition
|
||||
constexpr auto a_block_outer_dstr_encode = tile_distribution_encoding<
|
||||
sequence<NWarp>, // Replicate across N-warps
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>, // H0: 2 iters × 2 warps in M
|
||||
tuple<sequence<0, 1>>, // Ps_to_Hs
|
||||
tuple<sequence<0, 1>>, // Ps_in_Hs
|
||||
sequence<1, 2>, // Ys_to_Hs: Y maps to BOTH M and K
|
||||
sequence<0, 0>>{}; // Ys_in_Hs
|
||||
|
||||
constexpr auto a_block_dstr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encode, a_warp_dstr_encode);
|
||||
constexpr auto a_block_outer_dstr_encode =
|
||||
tile_distribution_encoding<sequence<NWarp>, // Replicate across N-warps
|
||||
tuple<sequence<MIterPerWarp, MWarp>,
|
||||
sequence<KIterPerWarp>>, // H0: 2 iters × 2 warps in M
|
||||
tuple<sequence<0, 1>>, // Ps_to_Hs
|
||||
tuple<sequence<0, 1>>, // Ps_in_Hs
|
||||
sequence<1, 2>, // Ys_to_Hs: Y maps to BOTH M and K
|
||||
sequence<0, 0>>{}; // Ys_in_Hs
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encode, a_warp_dstr_encode);
|
||||
|
||||
// B Distribution: Block-level with Y-repetition
|
||||
constexpr auto b_warp_dstr_encode = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<4, 4>, sequence<16>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<1>,
|
||||
sequence<1>>{};
|
||||
|
||||
constexpr auto b_warp_dstr_encode =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<4, 4>, sequence<16>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<1>,
|
||||
sequence<1>>{};
|
||||
|
||||
// constexpr auto b_block_outer_dstr_encode =
|
||||
// tile_distribution_encoding<sequence<MWarp>,
|
||||
// tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
// tuple<sequence<0, 1>>,
|
||||
// tuple<sequence<NIterPerWarp, NWarp>,
|
||||
// sequence<KIterPerWarp>>, tuple<sequence<0, 1>>,
|
||||
// tuple<sequence<0, 1>>,
|
||||
// sequence<1, 2>,
|
||||
// sequence<0, 0>>{};
|
||||
|
||||
constexpr auto b_block_outer_dstr_encode = tile_distribution_encoding<
|
||||
sequence<MWarp>, // Replicate across M-warps
|
||||
tuple<sequence<KIterPerWarp>, // H0: 2 iters × 2 warps in N
|
||||
sequence<NIterPerWarp, NWarp>>, // H1: 1 K-chunk
|
||||
tuple<sequence<2, 0>>, // Ps_to_Hs
|
||||
tuple<sequence<1, 0>>, // Ps_in_Hs
|
||||
sequence<1, 2>, // Ys_to_Hs: Y maps to BOTH N and K
|
||||
sequence<0, 0>>{}; // Ys_in_Hs
|
||||
constexpr auto b_block_outer_dstr_encode =
|
||||
tile_distribution_encoding<sequence<MWarp>, // Replicate across M-warps
|
||||
tuple<sequence<KIterPerWarp>, // H0: 2 iters × 2 warps in N
|
||||
sequence<NIterPerWarp, NWarp>>, // H1: 1 K-chunk
|
||||
tuple<sequence<2, 0>>, // Ps_to_Hs
|
||||
tuple<sequence<1, 0>>, // Ps_in_Hs
|
||||
sequence<1, 2>, // Ys_to_Hs: Y maps to BOTH N and K
|
||||
sequence<0, 0>>{}; // Ys_in_Hs
|
||||
|
||||
constexpr auto b_block_dstr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encode, b_warp_dstr_encode);
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encode, b_warp_dstr_encode);
|
||||
|
||||
// // C Distribution: Block-level with Y-repetition for output
|
||||
constexpr auto c_warp_dstr_encode = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<4, 4>, sequence<16>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<1>,
|
||||
sequence<1>>{};
|
||||
constexpr auto c_warp_dstr_encode =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<4, 4>, sequence<16>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<1>,
|
||||
sequence<1>>{};
|
||||
|
||||
// constexpr auto c_block_outer_dstr_encode = tile_distribution_encoding<
|
||||
// sequence<>,
|
||||
@@ -203,18 +181,17 @@ struct TileSweepingYRepetitionHgemmKernel
|
||||
// sequence<1, 2>,
|
||||
// sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_outer_dstr_encode = tile_distribution_encoding<
|
||||
sequence<>, // No replication for output
|
||||
tuple<sequence<MIterPerWarp, MWarp>, // H0: M iterations
|
||||
sequence<NIterPerWarp, NWarp>>, // H1: N iterations
|
||||
tuple<sequence<2, 1>>, // Ps_to_Hs
|
||||
tuple<sequence<1, 1>>, // Ps_in_Hs
|
||||
sequence<1, 2>, // Ys_to_Hs: Y maps to BOTH M and N
|
||||
sequence<0, 0>>{}; // Ys_in_Hs
|
||||
constexpr auto c_block_outer_dstr_encode =
|
||||
tile_distribution_encoding<sequence<>, // No replication for output
|
||||
tuple<sequence<MIterPerWarp, MWarp>, // H0: M iterations
|
||||
sequence<NIterPerWarp, NWarp>>, // H1: N iterations
|
||||
tuple<sequence<2, 1>>, // Ps_to_Hs
|
||||
tuple<sequence<1, 1>>, // Ps_in_Hs
|
||||
sequence<1, 2>, // Ys_to_Hs: Y maps to BOTH M and N
|
||||
sequence<0, 0>>{}; // Ys_in_Hs
|
||||
|
||||
constexpr auto c_block_dstr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encode, c_warp_dstr_encode);
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encode, c_warp_dstr_encode);
|
||||
|
||||
// Create distributions
|
||||
constexpr auto a_block_distribution = make_static_tile_distribution(a_block_dstr_encode);
|
||||
@@ -226,72 +203,71 @@ struct TileSweepingYRepetitionHgemmKernel
|
||||
using BWarpDstr = decltype(make_static_tile_distribution(b_warp_dstr_encode));
|
||||
using CWarpDstr = decltype(make_static_tile_distribution(c_warp_dstr_encode));
|
||||
|
||||
constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// // Create block-level windows
|
||||
auto a_block_window = make_tile_window(
|
||||
a_tensor,
|
||||
make_tuple(number<kMPerBlock>{}, number<kWarpK>{}),
|
||||
{m_block_base, 0},
|
||||
a_block_distribution
|
||||
);
|
||||
// // Create block-level windows
|
||||
auto a_block_window = make_tile_window(a_tensor,
|
||||
make_tuple(number<kMPerBlock>{}, number<kWarpK>{}),
|
||||
{m_block_base, 0},
|
||||
a_block_distribution);
|
||||
|
||||
auto b_block_window = make_tile_window(
|
||||
b_tensor,
|
||||
make_tuple(number<kWarpK>{}, number<kNPerBlock>{}),
|
||||
{0, n_block_base},
|
||||
b_block_distribution
|
||||
);
|
||||
auto b_block_window = make_tile_window(b_tensor,
|
||||
make_tuple(number<kWarpK>{}, number<kNPerBlock>{}),
|
||||
{0, n_block_base},
|
||||
b_block_distribution);
|
||||
|
||||
// // Create block-level accumulator tile
|
||||
// // Create block-level accumulator tile
|
||||
auto c_block_tile = make_static_distributed_tensor<AccDataType>(c_block_distribution);
|
||||
set_tile(c_block_tile, AccDataType{0});
|
||||
|
||||
// // Main K-loop
|
||||
// // Main K-loop
|
||||
const index_t num_k_loops = K / kWarpK;
|
||||
for(index_t k_iter = 0; k_iter < num_k_loops; ++k_iter)
|
||||
{
|
||||
// Load entire block tiles (all iterations at once)
|
||||
const auto a_block_tile = load_tile(a_block_window);
|
||||
const auto b_block_tile = load_tile(b_block_window);
|
||||
|
||||
|
||||
// Nested loops over tile iterations using Y-slicing (like 02_gemm)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// Extract A warp tensor for this M-iteration using Y-slicing
|
||||
auto a_warp_tensor = make_static_distributed_tensor<ADataType>(
|
||||
make_static_tile_distribution(a_warp_dstr_encode));
|
||||
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// Extract B warp tensor for this N-iteration using Y-slicing
|
||||
auto b_warp_tensor = make_static_distributed_tensor<BDataType>(
|
||||
make_static_tile_distribution(b_warp_dstr_encode));
|
||||
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter, nIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
|
||||
// Extract C warp tensor for this M,N iteration
|
||||
auto c_warp_tensor = make_static_distributed_tensor<AccDataType>(
|
||||
make_static_tile_distribution(c_warp_dstr_encode));
|
||||
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
|
||||
// Warp GEMM: C[mIter, nIter] += A[mIter, kIter] * B[nIter, kIter]
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
|
||||
// Write C warp tensor back to block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
@@ -302,7 +278,8 @@ struct TileSweepingYRepetitionHgemmKernel
|
||||
});
|
||||
|
||||
// Move windows to next K chunk
|
||||
if(k_iter < num_k_loops - 1) {
|
||||
if(k_iter < num_k_loops - 1)
|
||||
{
|
||||
move_tile_window(a_block_window, {0, kWarpK});
|
||||
move_tile_window(b_block_window, {kWarpK, 0});
|
||||
}
|
||||
@@ -314,62 +291,67 @@ struct TileSweepingYRepetitionHgemmKernel
|
||||
// Add beta * C if needed
|
||||
if(std::abs(beta) > 1e-6f)
|
||||
{
|
||||
auto c_block_window = make_tile_window(
|
||||
c_tensor,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
{m_block_base, n_block_base},
|
||||
c_block_distribution
|
||||
);
|
||||
auto c_block_window =
|
||||
make_tile_window(c_tensor,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
{m_block_base, n_block_base},
|
||||
c_block_distribution);
|
||||
|
||||
const auto c_input_block_tile = load_tile(c_block_window);
|
||||
|
||||
tile_elementwise_inout(
|
||||
[beta](const auto& c_val, auto& acc_val) {
|
||||
acc_val += beta * c_val;
|
||||
},
|
||||
c_input_block_tile, c_block_tile);
|
||||
[beta](const auto& c_val, auto& acc_val) { acc_val += beta * c_val; },
|
||||
c_input_block_tile,
|
||||
c_block_tile);
|
||||
}
|
||||
|
||||
// Store final result to D
|
||||
auto d_block_window = make_tile_window(
|
||||
d_tensor,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
{m_block_base, n_block_base},
|
||||
c_block_distribution
|
||||
);
|
||||
auto d_block_window =
|
||||
make_tile_window(d_tensor,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
{m_block_base, n_block_base},
|
||||
c_block_distribution);
|
||||
|
||||
store_tile(d_block_window, c_block_tile);
|
||||
}
|
||||
};
|
||||
|
||||
// CPU reference for verification
|
||||
template<typename InType, typename AccType>
|
||||
template <typename InType, typename AccType>
|
||||
void reference_gemm_mixed(const std::vector<InType>& a,
|
||||
const std::vector<InType>& b,
|
||||
const std::vector<AccType>& c,
|
||||
std::vector<AccType>& d,
|
||||
index_t M, index_t N, index_t K,
|
||||
index_t lda, index_t ldb, index_t ldc, index_t ldd,
|
||||
AccType alpha, AccType beta)
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t lda,
|
||||
index_t ldb,
|
||||
index_t ldc,
|
||||
index_t ldd,
|
||||
AccType alpha,
|
||||
AccType beta)
|
||||
{
|
||||
for(index_t n = 0; n < N; ++n) {
|
||||
for(index_t m = 0; m < M; ++m) {
|
||||
for(index_t n = 0; n < N; ++n)
|
||||
{
|
||||
for(index_t m = 0; m < M; ++m)
|
||||
{
|
||||
AccType sum = 0;
|
||||
for(index_t k = 0; k < K; ++k) {
|
||||
sum += static_cast<AccType>(a[m + k * lda]) *
|
||||
static_cast<AccType>(b[k * ldb + n]);
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
sum += static_cast<AccType>(a[m + k * lda]) * static_cast<AccType>(b[k * ldb + n]);
|
||||
}
|
||||
d[m + n * ldd] = alpha * sum + beta * c[m + n * ldc];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fill every element of `data` with a pseudo-random value spread linearly over
// [min_val, max_val] (both ends reachable).
//
// Each element is computed as min_val + (max_val - min_val) * u, where
// u = rand() / RAND_MAX lies in [0, 1]. The values come from rand(), so the
// sequence is whatever the current srand() state produces; an empty vector is
// left untouched.
template <typename T>
void fill_random(std::vector<T>& data, T min_val = -1, T max_val = 1)
{
    // Loop-invariant width of the target interval.
    const auto span = max_val - min_val;

    for(auto& val : data)
    {
        // Normalise the draw to [0, 1] *before* scaling. Multiplying the span by
        // the raw rand() value first would create an intermediate as large as
        // span * RAND_MAX, which is needlessly fragile for narrow element types.
        const float unit = static_cast<float>(rand()) / RAND_MAX;
        val              = static_cast<T>(min_val + span * unit);
    }
}
|
||||
|
||||
@@ -390,7 +372,7 @@ int main()
|
||||
// constexpr index_t M = 128;
|
||||
// constexpr index_t N = 128;
|
||||
// constexpr index_t K = 64;
|
||||
|
||||
|
||||
constexpr index_t M = 128;
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t K = 64;
|
||||
@@ -404,13 +386,13 @@ int main()
|
||||
using AccumType = float;
|
||||
|
||||
constexpr AccumType alpha = 2.0f;
|
||||
constexpr AccumType beta = 1.5f;
|
||||
constexpr AccumType beta = 1.5f;
|
||||
|
||||
std::cout << "Problem configuration:\n";
|
||||
std::cout << " M×N×K: " << M << "×" << N << "×" << K << "\n";
|
||||
std::cout << " Block output: 64×64 (2 warps × 2 iters × 16)\n";
|
||||
std::cout << " Warp output: 32×32 (2 iters × 16 in each dim)\n";
|
||||
std::cout << " Total blocks: " << (M/64) << "×" << (N/64) << "\n\n";
|
||||
std::cout << " Total blocks: " << (M / 64) << "×" << (N / 64) << "\n\n";
|
||||
|
||||
// Host memory
|
||||
std::vector<InputType> h_a(M * K);
|
||||
@@ -427,7 +409,7 @@ int main()
|
||||
// CPU reference
|
||||
auto cpu_start = std::chrono::high_resolution_clock::now();
|
||||
reference_gemm_mixed(h_a, h_b, h_c, h_d_ref, M, N, K, lda, ldb, ldc, ldd, alpha, beta);
|
||||
auto cpu_end = std::chrono::high_resolution_clock::now();
|
||||
auto cpu_end = std::chrono::high_resolution_clock::now();
|
||||
double cpu_time_ms = std::chrono::duration<double, std::milli>(cpu_end - cpu_start).count();
|
||||
|
||||
// Device memory
|
||||
@@ -443,7 +425,7 @@ int main()
|
||||
|
||||
// Launch kernel
|
||||
constexpr index_t block_size = 256;
|
||||
const index_t grid_size = (M / 64) * (N / 64); // 64×64 per block
|
||||
const index_t grid_size = (M / 64) * (N / 64); // 64×64 per block
|
||||
|
||||
std::cout << "Launching kernel:\n";
|
||||
std::cout << " Grid: " << grid_size << " blocks\n";
|
||||
@@ -454,59 +436,80 @@ int main()
|
||||
stream_config stream;
|
||||
|
||||
// Warmup
|
||||
for(int i = 0; i < 5; ++i) {
|
||||
launch_kernel(stream,
|
||||
make_kernel<block_size>(
|
||||
TileSweepingYRepetitionHgemmKernel<InputType, InputType, AccumType, AccumType>{},
|
||||
dim3(grid_size),
|
||||
dim3(block_size),
|
||||
0,
|
||||
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
|
||||
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
|
||||
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
|
||||
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
|
||||
M, N, K, lda, ldb, ldc, ldd, alpha, beta));
|
||||
for(int i = 0; i < 5; ++i)
|
||||
{
|
||||
launch_kernel(
|
||||
stream,
|
||||
make_kernel<block_size>(
|
||||
TileSweepingYRepetitionHgemmKernel<InputType, InputType, AccumType, AccumType>{},
|
||||
dim3(grid_size),
|
||||
dim3(block_size),
|
||||
0,
|
||||
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
|
||||
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
|
||||
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
|
||||
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
ldd,
|
||||
alpha,
|
||||
beta));
|
||||
}
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
// Timed run
|
||||
auto gpu_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
launch_kernel(stream,
|
||||
make_kernel<block_size>(
|
||||
TileSweepingYRepetitionHgemmKernel<InputType, InputType, AccumType, AccumType>{},
|
||||
dim3(grid_size),
|
||||
dim3(block_size),
|
||||
0,
|
||||
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
|
||||
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
|
||||
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
|
||||
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
|
||||
M, N, K, lda, ldb, ldc, ldd, alpha, beta));
|
||||
launch_kernel(
|
||||
stream,
|
||||
make_kernel<block_size>(
|
||||
TileSweepingYRepetitionHgemmKernel<InputType, InputType, AccumType, AccumType>{},
|
||||
dim3(grid_size),
|
||||
dim3(block_size),
|
||||
0,
|
||||
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
|
||||
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
|
||||
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
|
||||
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
ldd,
|
||||
alpha,
|
||||
beta));
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
auto gpu_end = std::chrono::high_resolution_clock::now();
|
||||
auto gpu_end = std::chrono::high_resolution_clock::now();
|
||||
double gpu_time_ms = std::chrono::duration<double, std::milli>(gpu_end - gpu_start).count();
|
||||
|
||||
// Get result
|
||||
d_d.FromDevice(h_d.data(), M * N * sizeof(AccumType));
|
||||
|
||||
// Verify
|
||||
bool passed = true;
|
||||
float max_error = 0;
|
||||
bool passed = true;
|
||||
float max_error = 0;
|
||||
index_t error_count = 0;
|
||||
|
||||
for(index_t i = 0; i < M * N; ++i) {
|
||||
for(index_t i = 0; i < M * N; ++i)
|
||||
{
|
||||
float error = std::abs(h_d[i] - h_d_ref[i]);
|
||||
max_error = std::max(max_error, error);
|
||||
if(error > 1e-2f) {
|
||||
if(error_count < 5) {
|
||||
max_error = std::max(max_error, error);
|
||||
if(error > 1e-2f)
|
||||
{
|
||||
if(error_count < 5)
|
||||
{
|
||||
index_t m = i % M;
|
||||
index_t n = i / M;
|
||||
std::cout << "Error at [" << m << "," << n << "]: "
|
||||
<< h_d[i] << " vs " << h_d_ref[i]
|
||||
<< " (diff=" << error << ")\n";
|
||||
std::cout << "Error at [" << m << "," << n << "]: " << h_d[i] << " vs "
|
||||
<< h_d_ref[i] << " (diff=" << error << ")\n";
|
||||
}
|
||||
error_count++;
|
||||
}
|
||||
@@ -514,14 +517,15 @@ int main()
|
||||
|
||||
passed = (error_count == 0);
|
||||
|
||||
double gflops = 2.0 * M * N * K / 1e9;
|
||||
double gflops = 2.0 * M * N * K / 1e9;
|
||||
double gpu_tflops = gflops / (gpu_time_ms / 1000);
|
||||
double cpu_gflops = gflops / (cpu_time_ms / 1000);
|
||||
|
||||
std::cout << "Results:\n";
|
||||
std::cout << " Correctness: " << (passed ? "✓ PASSED" : "✗ FAILED") << "\n";
|
||||
std::cout << " Max error: " << max_error << "\n";
|
||||
if(!passed) std::cout << " Error count: " << error_count << "/" << M*N << "\n";
|
||||
if(!passed)
|
||||
std::cout << " Error count: " << error_count << "/" << M * N << "\n";
|
||||
std::cout << "\n";
|
||||
|
||||
std::cout << "Performance:\n";
|
||||
|
||||
Reference in New Issue
Block a user