Setup build environment. Format source code.

This commit is contained in:
Andriy Roshchenko
2026-01-22 01:22:51 +00:00
parent 576956298c
commit 483c7696c0
17 changed files with 1902 additions and 1535 deletions

View File

@@ -9,10 +9,10 @@ include_directories(AFTER
# Each stage builds on the previous one
# Stage 00: Hello ck_tile - First program
add_subdirectory(stage_00_hello_ck_tile)
#add_subdirectory(stage_00_hello_ck_tile)
# Stage 01: Tile Distribution - Understanding thread-to-data mapping
add_subdirectory(stage_01_tile_distribution)
#add_subdirectory(stage_01_tile_distribution)
# Stage 02: 2D Tensors and Windows
if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/stage_02_2d_tensors/CMakeLists.txt)

View File

@@ -25,91 +25,85 @@ using namespace ck_tile;
// Note: These functions demonstrate the API but may be scalarized by the compiler
// when returning by value. For true vectorization, use get_vectorized_elements inline.
template<typename DataType>
CK_TILE_DEVICE thread_buffer<DataType, 2>
vectorized_read_2(const DataType* p_data, index_t offset)
template <typename DataType>
CK_TILE_DEVICE thread_buffer<DataType, 2> vectorized_read_2(const DataType* p_data, index_t offset)
{
auto view = make_naive_tensor_view<address_space_enum::global>(
p_data,
make_tuple(12), // total elements
make_tuple(1), // stride
number<2>{}, // GuaranteedLastDimensionVectorLength
number<1>{} // GuaranteedLastDimensionVectorStride
make_tuple(12), // total elements
make_tuple(1), // stride
number<2>{}, // GuaranteedLastDimensionVectorLength
number<1>{} // GuaranteedLastDimensionVectorStride
);
auto desc = view.get_tensor_descriptor();
auto desc = view.get_tensor_descriptor();
auto coord = make_tensor_coordinate(desc, make_tuple(offset));
return view.template get_vectorized_elements<thread_buffer<DataType, 2>>(coord, 0);
}
template<typename DataType>
CK_TILE_DEVICE thread_buffer<DataType, 4>
vectorized_read_4(const DataType* p_data, index_t offset)
template <typename DataType>
CK_TILE_DEVICE thread_buffer<DataType, 4> vectorized_read_4(const DataType* p_data, index_t offset)
{
auto view = make_naive_tensor_view<address_space_enum::global>(
p_data,
make_tuple(12), // total elements
make_tuple(1), // stride
number<4>{}, // GuaranteedLastDimensionVectorLength
number<1>{} // GuaranteedLastDimensionVectorStride
make_tuple(12), // total elements
make_tuple(1), // stride
number<4>{}, // GuaranteedLastDimensionVectorLength
number<1>{} // GuaranteedLastDimensionVectorStride
);
auto desc = view.get_tensor_descriptor();
auto desc = view.get_tensor_descriptor();
auto coord = make_tensor_coordinate(desc, make_tuple(offset));
return view.template get_vectorized_elements<thread_buffer<DataType, 4>>(coord, 0);
}
template<typename DataType>
CK_TILE_DEVICE void
template <typename DataType>
CK_TILE_DEVICE void
vectorized_write_4(DataType* p_data, index_t offset, thread_buffer<DataType, 4> buffer)
{
auto view = make_naive_tensor_view<address_space_enum::global>(
p_data,
make_tuple(12), // total elements
make_tuple(1) // stride
auto view = make_naive_tensor_view<address_space_enum::global>(p_data,
make_tuple(12), // total elements
make_tuple(1) // stride
);
auto desc = view.get_tensor_descriptor();
auto desc = view.get_tensor_descriptor();
auto coord = make_tensor_coordinate(desc, make_tuple(offset));
view.set_vectorized_elements(coord, 0, buffer);
}
// Additional functions with fp16 to demonstrate vectorization with smaller types
CK_TILE_DEVICE thread_buffer<half_t, 4>
vectorized_read_4_fp16(const half_t* p_data, index_t offset)
CK_TILE_DEVICE thread_buffer<half_t, 4> vectorized_read_4_fp16(const half_t* p_data, index_t offset)
{
auto view = make_naive_tensor_view<address_space_enum::global>(
p_data,
make_tuple(24), // total elements (more for fp16)
make_tuple(1) // stride
make_tuple(24), // total elements (more for fp16)
make_tuple(1) // stride
);
auto desc = view.get_tensor_descriptor();
auto desc = view.get_tensor_descriptor();
auto coord = make_tensor_coordinate(desc, make_tuple(offset));
return view.template get_vectorized_elements<thread_buffer<half_t, 4>>(coord, 0);
}
CK_TILE_DEVICE thread_buffer<half_t, 8>
vectorized_read_8_fp16(const half_t* p_data, index_t offset)
CK_TILE_DEVICE thread_buffer<half_t, 8> vectorized_read_8_fp16(const half_t* p_data, index_t offset)
{
auto view = make_naive_tensor_view<address_space_enum::global>(
p_data,
make_tuple(24), // total elements
make_tuple(1) // stride
auto view = make_naive_tensor_view<address_space_enum::global>(p_data,
make_tuple(24), // total elements
make_tuple(1) // stride
);
auto desc = view.get_tensor_descriptor();
auto desc = view.get_tensor_descriptor();
auto coord = make_tensor_coordinate(desc, make_tuple(offset));
return view.template get_vectorized_elements<thread_buffer<half_t, 8>>(coord, 0);
}
// The kernel that demonstrates all fundamental concepts
template<typename DataType>
template <typename DataType>
struct TensorFundamentalsKernel
{
static constexpr index_t kBlockSize = 64;
CK_TILE_DEVICE void operator()(const DataType* p_input,
DataType* p_output,
index_t H, index_t W, index_t C) const
CK_TILE_DEVICE void
operator()(const DataType* p_input, DataType* p_output, index_t H, index_t W, index_t C) const
{
// Only thread 0 for clean output
if(get_thread_id() != 0) return;
if(get_thread_id() != 0)
return;
printf("\n=== TENSOR FUNDAMENTALS IN CK_TILE ===\n\n");
@@ -123,25 +117,31 @@ struct TensorFundamentalsKernel
// It contains: lengths (shape) + strides (memory layout)
// Create a descriptor for [H,W,C] tensor in row-major layout
auto hwc_descriptor = make_naive_tensor_descriptor(
make_tuple(H, W, C), // lengths: [2, 3, 2]
make_tuple(W*C, C, 1) // strides: [6, 2, 1] for row-major
);
auto hwc_descriptor =
make_naive_tensor_descriptor(make_tuple(H, W, C), // lengths: [2, 3, 2]
make_tuple(W * C, C, 1) // strides: [6, 2, 1] for row-major
);
// Access descriptor properties
auto lengths = hwc_descriptor.get_lengths();
// Note: Descriptors don't expose strides directly after transformation
printf("Descriptor for [H=%ld, W=%ld, C=%ld] tensor:\n",
static_cast<long>(H), static_cast<long>(W), static_cast<long>(C));
static_cast<long>(H),
static_cast<long>(W),
static_cast<long>(C));
printf(" Lengths: [%ld, %ld, %ld]\n",
static_cast<long>(lengths.at(number<0>{})),
static_cast<long>(lengths.at(number<1>{})),
static_cast<long>(lengths.at(number<2>{})));
printf(" Strides: [%ld, %ld, %ld] (row-major)\n",
static_cast<long>(W*C), static_cast<long>(C), static_cast<long>(1));
static_cast<long>(W * C),
static_cast<long>(C),
static_cast<long>(1));
printf(" Memory formula: offset = h*%ld + w*%ld + c*%ld\n\n",
static_cast<long>(W*C), static_cast<long>(C), static_cast<long>(1));
static_cast<long>(W * C),
static_cast<long>(C),
static_cast<long>(1));
//==================================================================
// PART 2: TENSOR VIEW - Three Creation Methods
@@ -152,46 +152,51 @@ struct TensorFundamentalsKernel
// Method 1: Explicit strides (most control)
printf("Method 1: make_naive_tensor_view with explicit strides\n");
auto view1 = make_naive_tensor_view<address_space_enum::global>(
p_input, // GPU memory pointer
make_tuple(H, W, C), // lengths
make_tuple(W*C, C, 1) // explicit strides
p_input, // GPU memory pointer
make_tuple(H, W, C), // lengths
make_tuple(W * C, C, 1) // explicit strides
);
printf(" Created view with shape [%ld,%ld,%ld] and strides [%ld,%ld,%ld]\n",
static_cast<long>(H), static_cast<long>(W), static_cast<long>(C),
static_cast<long>(W*C), static_cast<long>(C), static_cast<long>(1));
static_cast<long>(H),
static_cast<long>(W),
static_cast<long>(C),
static_cast<long>(W * C),
static_cast<long>(C),
static_cast<long>(1));
// Method 2: Packed/contiguous (auto-computes row-major strides)
printf("\nMethod 2: make_naive_tensor_view_packed (auto strides)\n");
auto view2 = make_naive_tensor_view_packed<address_space_enum::global>(
p_input, // GPU memory pointer
make_tuple(H, W, C) // lengths only, strides auto-computed
p_input, // GPU memory pointer
make_tuple(H, W, C) // lengths only, strides auto-computed
);
printf(" Created packed view - strides computed automatically\n");
printf(" For row-major: last dim stride=1, each dim stride = next_dim_stride * next_dim_length\n");
printf(" For row-major: last dim stride=1, each dim stride = next_dim_stride * "
"next_dim_length\n");
// Method 3: From existing descriptor
printf("\nMethod 3: make_tensor_view from descriptor\n");
auto view3 = make_tensor_view<address_space_enum::global>(
p_input, // GPU memory pointer
hwc_descriptor // existing descriptor
);
auto view3 =
make_tensor_view<address_space_enum::global>(p_input, // GPU memory pointer
hwc_descriptor // existing descriptor
);
printf(" Created view using pre-existing descriptor\n");
// Demonstrate all three views access the same data
printf("\nVerifying all three methods create equivalent views:\n");
{
auto coord_test = make_tensor_coordinate(
view1.get_tensor_descriptor(), make_tuple(0, 1, 0));
auto coord_test =
make_tensor_coordinate(view1.get_tensor_descriptor(), make_tuple(0, 1, 0));
auto val1 = view1.template get_vectorized_elements<thread_buffer<DataType, 1>>(
coord_test, 0)[number<0>{}];
auto coord_test2 = make_tensor_coordinate(
view2.get_tensor_descriptor(), make_tuple(0, 1, 0));
auto coord_test2 =
make_tensor_coordinate(view2.get_tensor_descriptor(), make_tuple(0, 1, 0));
auto val2 = view2.template get_vectorized_elements<thread_buffer<DataType, 1>>(
coord_test2, 0)[number<0>{}];
auto coord_test3 = make_tensor_coordinate(
view3.get_tensor_descriptor(), make_tuple(0, 1, 0));
auto coord_test3 =
make_tensor_coordinate(view3.get_tensor_descriptor(), make_tuple(0, 1, 0));
auto val3 = view3.template get_vectorized_elements<thread_buffer<DataType, 1>>(
coord_test3, 0)[number<0>{}];
@@ -217,11 +222,12 @@ struct TensorFundamentalsKernel
// Coordinate can compute its linear offset
index_t offset = coord.get_offset();
printf("Coordinate [1,2,0] maps to linear offset: %ld\n",
static_cast<long>(offset));
printf("Coordinate [1,2,0] maps to linear offset: %ld\n", static_cast<long>(offset));
printf(" Calculation: 1*%ld + 2*%ld + 0*%ld = %ld\n\n",
static_cast<long>(W*C), static_cast<long>(C),
static_cast<long>(1), static_cast<long>(offset));
static_cast<long>(W * C),
static_cast<long>(C),
static_cast<long>(1),
static_cast<long>(offset));
//==================================================================
// PART 4: ELEMENT ACCESS - The Critical Pattern
@@ -237,8 +243,8 @@ struct TensorFundamentalsKernel
// get_vectorized_elements returns thread_buffer<T,N>, not T!
auto buffer = view1.template get_vectorized_elements<thread_buffer<DataType, 1>>(
read_coord, // coordinate
0 // linear_offset (usually 0)
read_coord, // coordinate
0 // linear_offset (usually 0)
);
// Extract actual value from thread_buffer
@@ -261,10 +267,7 @@ struct TensorFundamentalsKernel
// Write to output view
auto output_view = make_naive_tensor_view<address_space_enum::global>(
p_output,
make_tuple(H, W, C),
make_tuple(W*C, C, 1)
);
p_output, make_tuple(H, W, C), make_tuple(W * C, C, 1));
output_view.set_vectorized_elements(write_coord, 0, write_buffer);
@@ -284,8 +287,8 @@ struct TensorFundamentalsKernel
// Create a flattened view for easier vectorized access
auto flat_view = make_naive_tensor_view<address_space_enum::global>(
p_input,
make_tuple(H*W*C), // [12] - all elements in linear order
make_tuple(1) // stride = 1 (contiguous)
make_tuple(H * W * C), // [12] - all elements in linear order
make_tuple(1) // stride = 1 (contiguous)
);
auto flat_desc = flat_view.get_tensor_descriptor();
@@ -294,10 +297,10 @@ struct TensorFundamentalsKernel
{
// Call the vectorized_read_2 function (easy to disassemble in debugger)
auto buffer = vectorized_read_2(p_input, 0);
DataType val0 = buffer[number<0>{}];
DataType val1 = buffer[number<1>{}];
printf(" Position [0]: Read 2 elements in one operation\n");
printf(" buffer[0] = %.0f\n", static_cast<float>(val0));
printf(" buffer[1] = %.0f\n", static_cast<float>(val1));
@@ -310,7 +313,7 @@ struct TensorFundamentalsKernel
{
// Call the vectorized_read_4 function (easy to disassemble in debugger)
auto buffer = vectorized_read_4(p_input, 4);
printf(" Position [4]: Read 4 elements in one operation\n");
printf(" buffer[0] = %.0f\n", static_cast<float>(buffer[number<0>{}]));
printf(" buffer[1] = %.0f\n", static_cast<float>(buffer[number<1>{}]));
@@ -329,10 +332,10 @@ struct TensorFundamentalsKernel
write_buffer[number<1>{}] = 101.0f;
write_buffer[number<2>{}] = 102.0f;
write_buffer[number<3>{}] = 103.0f;
// Call the vectorized_write_4 function (easy to disassemble in debugger)
vectorized_write_4(p_output, 4, write_buffer);
printf(" Position [4-7]: Wrote 4 elements in one operation\n");
printf(" Wrote: 100, 101, 102, 103\n");
printf(" ✓ 4x faster than writing elements individually!\n");
@@ -343,30 +346,32 @@ struct TensorFundamentalsKernel
printf("Example 4: Vectorized copy - INLINE usage (TRUE vectorization)\n");
{
auto output_flat_view = make_naive_tensor_view<address_space_enum::global>(
p_output,
make_tuple(H*W*C),
make_tuple(1)
);
p_output, make_tuple(H * W * C), make_tuple(1));
auto out_flat_desc = output_flat_view.get_tensor_descriptor();
// Copy first 8 elements using vector size 4 (2 iterations)
// THIS is where real vectorization happens - inline, no function calls!
printf(" Copying first 8 elements using 2 vectorized operations:\n");
for(index_t i = 0; i < 8; i += 4) {
auto in_coord = make_tensor_coordinate(flat_desc, make_tuple(i));
for(index_t i = 0; i < 8; i += 4)
{
auto in_coord = make_tensor_coordinate(flat_desc, make_tuple(i));
auto out_coord = make_tensor_coordinate(out_flat_desc, make_tuple(i));
// Read 4 elements - INLINE vectorized load (not through function)
auto buffer = flat_view.template get_vectorized_elements<
thread_buffer<DataType, 4>>(in_coord, 0);
auto buffer =
flat_view.template get_vectorized_elements<thread_buffer<DataType, 4>>(in_coord,
0);
// Write 4 elements (skip positions 4-7 which we already wrote)
if(i != 4) {
if(i != 4)
{
output_flat_view.set_vectorized_elements(out_coord, 0, buffer);
}
printf(" Iteration %ld: Copied elements [%ld-%ld]\n",
static_cast<long>(i/4), static_cast<long>(i), static_cast<long>(i+3));
printf(" Iteration %ld: Copied elements [%ld-%ld]\n",
static_cast<long>(i / 4),
static_cast<long>(i),
static_cast<long>(i + 3));
}
printf(" ✓ Copied 8 elements with only 2 memory operations!\n");
printf(" ✓ THIS loop shows true vectorization in assembly!\n\n");
@@ -390,45 +395,39 @@ struct TensorFundamentalsKernel
// Create two different views of the same memory
// View A: [H, W, C] = [2, 3, 2]
auto view_hwc = make_naive_tensor_view<address_space_enum::global>(
p_input,
make_tuple(H, W, C),
make_tuple(W*C, C, 1)
);
p_input, make_tuple(H, W, C), make_tuple(W * C, C, 1));
// View B: [HW, C] = [6, 2] - flattened spatial dimensions
auto view_hw_c = make_naive_tensor_view<address_space_enum::global>(
p_input,
make_tuple(H*W, C),
make_tuple(C, 1)
);
p_input, make_tuple(H * W, C), make_tuple(C, 1));
printf("Two views of same memory:\n");
printf(" View A: [H=%ld, W=%ld, C=%ld]\n",
static_cast<long>(H), static_cast<long>(W), static_cast<long>(C));
printf(" View B: [HW=%ld, C=%ld]\n",
static_cast<long>(H*W), static_cast<long>(C));
static_cast<long>(H),
static_cast<long>(W),
static_cast<long>(C));
printf(" View B: [HW=%ld, C=%ld]\n", static_cast<long>(H * W), static_cast<long>(C));
// Show they access the same data
printf("\nAccessing same element through different views:\n");
// Access element at h=1, w=1, c=0 through View A
auto desc_a = view_hwc.get_tensor_descriptor();
auto desc_a = view_hwc.get_tensor_descriptor();
auto coord_a = make_tensor_coordinate(desc_a, make_tuple(1, 1, 0));
auto buffer_a = view_hwc.template get_vectorized_elements<thread_buffer<DataType, 1>>(
coord_a, 0);
auto buffer_a =
view_hwc.template get_vectorized_elements<thread_buffer<DataType, 1>>(coord_a, 0);
DataType val_a = buffer_a[number<0>{}];
// Access same element through View B at hw=4 (1*3+1), c=0
auto desc_b = view_hw_c.get_tensor_descriptor();
auto desc_b = view_hw_c.get_tensor_descriptor();
auto coord_b = make_tensor_coordinate(desc_b, make_tuple(4, 0));
auto buffer_b = view_hw_c.template get_vectorized_elements<thread_buffer<DataType, 1>>(
coord_b, 0);
auto buffer_b =
view_hw_c.template get_vectorized_elements<thread_buffer<DataType, 1>>(coord_b, 0);
DataType val_b = buffer_b[number<0>{}];
printf(" View A[1,1,0] = %.0f\n", static_cast<float>(val_a));
printf(" View B[4,0] = %.0f (same value!)\n", static_cast<float>(val_b));
printf(" Both access offset %ld in memory\n\n",
static_cast<long>(coord_a.get_offset()));
printf(" Both access offset %ld in memory\n\n", static_cast<long>(coord_a.get_offset()));
//==================================================================
// PART 6: PRACTICAL EXAMPLE - Copy with Views
@@ -438,24 +437,26 @@ struct TensorFundamentalsKernel
// Create output view
auto output_view = make_naive_tensor_view<address_space_enum::global>(
p_output,
make_tuple(H, W, C),
make_tuple(W*C, C, 1)
);
p_output, make_tuple(H, W, C), make_tuple(W * C, C, 1));
auto out_desc = output_view.get_tensor_descriptor();
// Copy all elements using tensor_view API
index_t count = 0;
for(index_t h = 0; h < H; h++) {
for(index_t w = 0; w < W; w++) {
for(index_t c = 0; c < C; c++) {
for(index_t h = 0; h < H; h++)
{
for(index_t w = 0; w < W; w++)
{
for(index_t c = 0; c < C; c++)
{
// Read from input
auto in_coord = make_tensor_coordinate(desc, make_tuple(h, w, c));
auto in_buffer = view1.template get_vectorized_elements<
thread_buffer<DataType, 1>>(in_coord, 0);
auto in_buffer =
view1.template get_vectorized_elements<thread_buffer<DataType, 1>>(in_coord,
0);
// Write to output (except [0,0,1] which we already wrote as 99)
if(!(h == 0 && w == 0 && c == 1)) {
if(!(h == 0 && w == 0 && c == 1))
{
auto out_coord = make_tensor_coordinate(out_desc, make_tuple(h, w, c));
output_view.set_vectorized_elements(out_coord, 0, in_buffer);
}
@@ -489,7 +490,8 @@ int main()
// Initialize HIP
int device_count;
hip_check_error(hipGetDeviceCount(&device_count));
if(device_count == 0) {
if(device_count == 0)
{
std::cerr << "No GPU devices found!\n";
return 1;
}
@@ -500,36 +502,42 @@ int main()
std::cout << "Using GPU: " << props.name << "\n";
// Small tensor for demonstration
constexpr index_t H = 2;
constexpr index_t W = 3;
constexpr index_t C = 2;
constexpr index_t H = 2;
constexpr index_t W = 3;
constexpr index_t C = 2;
constexpr index_t size = H * W * C;
std::cout << "\nTensor configuration:\n";
std::cout << " Shape: [" << H << ", " << W << ", " << C << "]\n";
std::cout << " Total elements: " << size << "\n";
std::cout << " Layout: Row-major (strides = [" << W*C << ", " << C << ", 1])\n\n";
std::cout << " Layout: Row-major (strides = [" << W * C << ", " << C << ", 1])\n\n";
// Create test data: 1, 2, 3, 4, ... 12
std::vector<float> h_input(size);
std::iota(h_input.begin(), h_input.end(), 1.0f);
std::cout << "Input data (row-major memory order):\n";
for(index_t i = 0; i < size; ++i) {
if(i % C == 0 && i > 0) std::cout << " | ";
for(index_t i = 0; i < size; ++i)
{
if(i % C == 0 && i > 0)
std::cout << " | ";
std::cout << std::setw(2) << h_input[i] << " ";
}
std::cout << "\n";
std::cout << "\nLogical view [H,W,C]:\n";
for(index_t h = 0; h < H; h++) {
for(index_t h = 0; h < H; h++)
{
std::cout << " H=" << h << ": ";
for(index_t w = 0; w < W; w++) {
for(index_t w = 0; w < W; w++)
{
std::cout << "[";
for(index_t c = 0; c < C; c++) {
for(index_t c = 0; c < C; c++)
{
index_t idx = h * W * C + w * C + c;
std::cout << std::setw(2) << h_input[idx];
if(c < C-1) std::cout << ",";
if(c < C - 1)
std::cout << ",";
}
std::cout << "] ";
}
@@ -555,14 +563,15 @@ int main()
std::cout << "=====================================\n";
launch_kernel(stream,
make_kernel<block_size>(
TensorFundamentalsKernel<float>{},
dim3(1), // 1 block
dim3(block_size), // 64 threads
0, // no shared memory
static_cast<const float*>(d_input.GetDeviceBuffer()),
static_cast<float*>(d_output.GetDeviceBuffer()),
H, W, C));
make_kernel<block_size>(TensorFundamentalsKernel<float>{},
dim3(1), // 1 block
dim3(block_size), // 64 threads
0, // no shared memory
static_cast<const float*>(d_input.GetDeviceBuffer()),
static_cast<float*>(d_output.GetDeviceBuffer()),
H,
W,
C));
hip_check_error(hipDeviceSynchronize());
std::cout << "=====================================\n";
@@ -574,17 +583,19 @@ int main()
// Verify results
std::cout << "\nOutput verification:\n";
bool passed = true;
for(index_t i = 0; i < size; ++i) {
float expected = (i == 1) ? 99.0f : h_input[i]; // We wrote 99 to position [0,0,1]
if(std::abs(h_output[i] - expected) > 1e-6f) {
for(index_t i = 0; i < size; ++i)
{
float expected = (i == 1) ? 99.0f : h_input[i]; // We wrote 99 to position [0,0,1]
if(std::abs(h_output[i] - expected) > 1e-6f)
{
passed = false;
std::cout << " ✗ Mismatch at index " << i
<< ": expected " << expected
<< ", got " << h_output[i] << "\n";
std::cout << " ✗ Mismatch at index " << i << ": expected " << expected << ", got "
<< h_output[i] << "\n";
}
}
if(passed) {
if(passed)
{
std::cout << " ✓ All elements correct!\n";
std::cout << " ✓ output[0,0,1] = 99 (modified as expected)\n";
std::cout << " ✓ All other elements copied correctly\n";

View File

@@ -22,7 +22,7 @@
using namespace ck_tile;
template<typename DataType>
template <typename DataType>
struct TensorAdaptorsKernel
{
static constexpr index_t kBlockSize = 64;
@@ -33,43 +33,43 @@ struct TensorAdaptorsKernel
printf("PART 1: make_single_stage_tensor_adaptor\n");
printf("=========================================\n\n");
printf("Purpose: Create a tensor adaptor with transformations applied in a single stage.\n");
printf(
"Purpose: Create a tensor adaptor with transformations applied in a single stage.\n");
printf("This is the foundation for building complex layout transformations.\n\n");
// Example 1.1: Simple dimension split (Unmerge)
printf("Example 1.1: Split M dimension for tiling\n");
printf("------------------------------------------\n");
{
constexpr index_t M = 128;
constexpr index_t K = 64;
constexpr index_t M = 128;
constexpr index_t K = 64;
constexpr index_t M0 = 4;
constexpr index_t M1 = 32;
printf("Input layout: [M=%ld, K=%ld]\n", static_cast<long>(M), static_cast<long>(K));
printf("Goal: Split M into [M0=%ld, M1=%ld] for tiling\n",
static_cast<long>(M0), static_cast<long>(M1));
printf("Goal: Split M into [M0=%ld, M1=%ld] for tiling\n",
static_cast<long>(M0),
static_cast<long>(M1));
auto transforms = make_tuple(
make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
make_pass_through_transform(number<K>{})
);
auto transforms =
make_tuple(make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
make_pass_through_transform(number<K>{}));
auto lower_dims = make_tuple(sequence<0>{}, sequence<1>{});
auto upper_dims = make_tuple(sequence<0, 1>{}, sequence<2>{});
auto adaptor = make_single_stage_tensor_adaptor(
transforms, lower_dims, upper_dims
);
auto adaptor = make_single_stage_tensor_adaptor(transforms, lower_dims, upper_dims);
printf("\nAdaptor created:\n");
printf(" Input: [M, K] = [%ld, %ld]\n",
static_cast<long>(M), static_cast<long>(K));
printf(" Input: [M, K] = [%ld, %ld]\n", static_cast<long>(M), static_cast<long>(K));
printf(" Output: [M0, M1, K] = [%ld, %ld, %ld]\n",
static_cast<long>(M0), static_cast<long>(M1), static_cast<long>(K));
static_cast<long>(M0),
static_cast<long>(M1),
static_cast<long>(K));
auto top_idx = make_tuple(1, 16, 32);
auto top_idx = make_tuple(1, 16, 32);
auto bottom_idx = adaptor.calculate_bottom_index(top_idx);
printf("\nTest: [M0=1, M1=16, K=32] -> [M=%ld, K=%ld]\n",
static_cast<long>(bottom_idx[number<0>{}]),
static_cast<long>(bottom_idx[number<1>{}]));
@@ -81,8 +81,8 @@ struct TensorAdaptorsKernel
printf("Example 1.2: GEMM C Matrix Tiling (Interleaved)\n");
printf("------------------------------------------------\n");
{
constexpr index_t M = 256;
constexpr index_t N = 256;
constexpr index_t M = 256;
constexpr index_t N = 256;
constexpr index_t M0 = 4;
constexpr index_t M1 = 64;
constexpr index_t N0 = 4;
@@ -90,25 +90,23 @@ struct TensorAdaptorsKernel
printf("Input: [M=%ld, N=%ld]\n", static_cast<long>(M), static_cast<long>(N));
printf("Output: [M0=%ld, N0=%ld, M1=%ld, N1=%ld] (interleaved)\n",
static_cast<long>(M0), static_cast<long>(N0),
static_cast<long>(M1), static_cast<long>(N1));
static_cast<long>(M0),
static_cast<long>(N0),
static_cast<long>(M1),
static_cast<long>(N1));
auto transforms = make_tuple(
make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
make_unmerge_transform(make_tuple(number<N0>{}, number<N1>{}))
);
auto transforms =
make_tuple(make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
make_unmerge_transform(make_tuple(number<N0>{}, number<N1>{})));
auto lower_dims = make_tuple(sequence<0>{}, sequence<1>{});
auto upper_dims = make_tuple(
sequence<0, 2>{}, // M splits to dims 0,2
sequence<1, 3>{} // N splits to dims 1,3
auto upper_dims = make_tuple(sequence<0, 2>{}, // M splits to dims 0,2
sequence<1, 3>{} // N splits to dims 1,3
);
auto adaptor = make_single_stage_tensor_adaptor(
transforms, lower_dims, upper_dims
);
auto adaptor = make_single_stage_tensor_adaptor(transforms, lower_dims, upper_dims);
auto top_idx = make_tuple(2, 3, 16, 32);
auto top_idx = make_tuple(2, 3, 16, 32);
auto bottom_idx = adaptor.calculate_bottom_index(top_idx);
printf("\nTest: [M0=2, N0=3, M1=16, N1=32] -> [M=%ld, N=%ld]\n",
static_cast<long>(bottom_idx[number<0>{}]),
@@ -130,43 +128,44 @@ struct TensorAdaptorsKernel
printf("Example 2.1: Two-Stage Hierarchical Tiling\n");
printf("-------------------------------------------\n");
{
constexpr index_t M = 256;
constexpr index_t K = 128;
constexpr index_t M = 256;
constexpr index_t K = 128;
constexpr index_t M0 = 4;
constexpr index_t M1 = 64;
constexpr index_t K0 = 4;
constexpr index_t K1 = 32;
printf("Stage 1: [M=%ld, K=%ld] -> [M0=%ld, M1=%ld, K=%ld]\n",
static_cast<long>(M), static_cast<long>(K),
static_cast<long>(M0), static_cast<long>(M1), static_cast<long>(K));
static_cast<long>(M),
static_cast<long>(K),
static_cast<long>(M0),
static_cast<long>(M1),
static_cast<long>(K));
auto stage1_adaptor = make_single_stage_tensor_adaptor(
make_tuple(
make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
make_pass_through_transform(number<K>{})
),
make_tuple(make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
make_pass_through_transform(number<K>{})),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{})
);
make_tuple(sequence<0, 1>{}, sequence<2>{}));
printf("Stage 2: [M0=%ld, M1=%ld, K=%ld] -> [M0=%ld, M1=%ld, K0=%ld, K1=%ld]\n",
static_cast<long>(M0), static_cast<long>(M1), static_cast<long>(K),
static_cast<long>(M0), static_cast<long>(M1),
static_cast<long>(K0), static_cast<long>(K1));
static_cast<long>(M0),
static_cast<long>(M1),
static_cast<long>(K),
static_cast<long>(M0),
static_cast<long>(M1),
static_cast<long>(K0),
static_cast<long>(K1));
auto final_adaptor = transform_tensor_adaptor(
stage1_adaptor,
make_tuple(
make_pass_through_transform(number<M0>{}),
make_pass_through_transform(number<M1>{}),
make_unmerge_transform(make_tuple(number<K0>{}, number<K1>{}))
),
make_tuple(make_pass_through_transform(number<M0>{}),
make_pass_through_transform(number<M1>{}),
make_unmerge_transform(make_tuple(number<K0>{}, number<K1>{}))),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2, 3>{})
);
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2, 3>{}));
auto top_idx = make_tuple(2, 32, 3, 16);
auto top_idx = make_tuple(2, 32, 3, 16);
auto bottom_idx = final_adaptor.calculate_bottom_index(top_idx);
printf("\nTest: [M0=2, M1=32, K0=3, K1=16] -> [M=%ld, K=%ld]\n",
static_cast<long>(bottom_idx[number<0>{}]),
@@ -187,49 +186,53 @@ struct TensorAdaptorsKernel
printf("Example 3.1: Chain Two Adaptors\n");
printf("--------------------------------\n");
{
constexpr index_t M = 128;
constexpr index_t K = 64;
constexpr index_t M = 128;
constexpr index_t K = 64;
constexpr index_t M0 = 4;
constexpr index_t M1 = 32;
constexpr index_t K0 = 4;
constexpr index_t K1 = 16;
printf("Adaptor A: [M=%ld, K=%ld] -> [M0=%ld, M1=%ld, K=%ld]\n",
static_cast<long>(M), static_cast<long>(K),
static_cast<long>(M0), static_cast<long>(M1), static_cast<long>(K));
static_cast<long>(M),
static_cast<long>(K),
static_cast<long>(M0),
static_cast<long>(M1),
static_cast<long>(K));
auto adaptor_a = make_single_stage_tensor_adaptor(
make_tuple(
make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
make_pass_through_transform(number<K>{})
),
make_tuple(make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
make_pass_through_transform(number<K>{})),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{})
);
make_tuple(sequence<0, 1>{}, sequence<2>{}));
printf("Adaptor B: [M0=%ld, M1=%ld, K=%ld] -> [M0=%ld, M1=%ld, K0=%ld, K1=%ld]\n",
static_cast<long>(M0), static_cast<long>(M1), static_cast<long>(K),
static_cast<long>(M0), static_cast<long>(M1),
static_cast<long>(K0), static_cast<long>(K1));
static_cast<long>(M0),
static_cast<long>(M1),
static_cast<long>(K),
static_cast<long>(M0),
static_cast<long>(M1),
static_cast<long>(K0),
static_cast<long>(K1));
auto adaptor_b = make_single_stage_tensor_adaptor(
make_tuple(
make_pass_through_transform(number<M0>{}),
make_pass_through_transform(number<M1>{}),
make_unmerge_transform(make_tuple(number<K0>{}, number<K1>{}))
),
make_tuple(make_pass_through_transform(number<M0>{}),
make_pass_through_transform(number<M1>{}),
make_unmerge_transform(make_tuple(number<K0>{}, number<K1>{}))),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2, 3>{})
);
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2, 3>{}));
auto chained = chain_tensor_adaptors(adaptor_a, adaptor_b);
printf("\nChained: [M=%ld, K=%ld] -> [M0=%ld, M1=%ld, K0=%ld, K1=%ld]\n",
static_cast<long>(M), static_cast<long>(K),
static_cast<long>(M0), static_cast<long>(M1),
static_cast<long>(K0), static_cast<long>(K1));
static_cast<long>(M),
static_cast<long>(K),
static_cast<long>(M0),
static_cast<long>(M1),
static_cast<long>(K0),
static_cast<long>(K1));
auto top_idx = make_tuple(2, 16, 3, 8);
auto top_idx = make_tuple(2, 16, 3, 8);
auto bottom_idx = chained.calculate_bottom_index(top_idx);
printf("Test: [M0=2, M1=16, K0=3, K1=8] -> [M=%ld, K=%ld]\n",
static_cast<long>(bottom_idx[number<0>{}]),
@@ -245,32 +248,33 @@ struct TensorAdaptorsKernel
printf("PART 4: Real-World GEMM Tiling Example\n");
printf("=======================================\n\n");
constexpr index_t M = 256;
constexpr index_t N = 256;
constexpr index_t MWaves = 4;
constexpr index_t NWaves = 4;
constexpr index_t M = 256;
constexpr index_t N = 256;
constexpr index_t MWaves = 4;
constexpr index_t NWaves = 4;
constexpr index_t MPerXDL = 16;
constexpr index_t NPerXDL = 16;
constexpr index_t M0 = M / (MWaves * MPerXDL);
constexpr index_t N0 = N / (NWaves * NPerXDL);
constexpr index_t M0 = M / (MWaves * MPerXDL);
constexpr index_t N0 = N / (NWaves * NPerXDL);
printf("GEMM C Matrix: [M=%ld, N=%ld]\n",
static_cast<long>(M), static_cast<long>(N));
printf("GEMM C Matrix: [M=%ld, N=%ld]\n", static_cast<long>(M), static_cast<long>(N));
printf("Tiling: [M0=%ld, N0=%ld, M1=%ld, N1=%ld, M2=%ld, N2=%ld]\n",
static_cast<long>(M0), static_cast<long>(N0),
static_cast<long>(MWaves), static_cast<long>(NWaves),
static_cast<long>(MPerXDL), static_cast<long>(NPerXDL));
static_cast<long>(M0),
static_cast<long>(N0),
static_cast<long>(MWaves),
static_cast<long>(NWaves),
static_cast<long>(MPerXDL),
static_cast<long>(NPerXDL));
auto adaptor = make_single_stage_tensor_adaptor(
make_tuple(
make_unmerge_transform(make_tuple(number<M0>{}, number<MWaves>{}, number<MPerXDL>{})),
make_unmerge_transform(make_tuple(number<N0>{}, number<NWaves>{}, number<NPerXDL>{}))
),
make_tuple(make_unmerge_transform(
make_tuple(number<M0>{}, number<MWaves>{}, number<MPerXDL>{})),
make_unmerge_transform(
make_tuple(number<N0>{}, number<NWaves>{}, number<NPerXDL>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0, 2, 4>{}, sequence<1, 3, 5>{})
);
make_tuple(sequence<0, 2, 4>{}, sequence<1, 3, 5>{}));
auto top_idx = make_tuple(2, 3, 1, 2, 8, 12);
auto top_idx = make_tuple(2, 3, 1, 2, 8, 12);
auto bottom_idx = adaptor.calculate_bottom_index(top_idx);
printf("\nTest: [M0=2, N0=3, M1=1, N1=2, M2=8, N2=12] -> [M=%ld, N=%ld]\n",
static_cast<long>(bottom_idx[number<0>{}]),
@@ -288,8 +292,8 @@ struct TensorAdaptorsKernel
printf("Demonstrating padding transform with coordinate mapping.\n\n");
// Original size: 10 elements, pad to 16
constexpr index_t OrigSize = 10;
constexpr index_t PadRight = 6;
constexpr index_t OrigSize = 10;
constexpr index_t PadRight = 6;
constexpr index_t TotalSize = OrigSize + PadRight;
printf("Original size: %ld elements\n", static_cast<long>(OrigSize));
@@ -301,30 +305,35 @@ struct TensorAdaptorsKernel
make_naive_tensor_descriptor_packed(make_tuple(number<OrigSize>{})),
make_tuple(make_right_pad_transform(number<OrigSize>{}, number<PadRight>{})),
make_tuple(sequence<0>{}),
make_tuple(sequence<0>{})
);
make_tuple(sequence<0>{}));
printf("Coordinate mapping and memory reads:\n");
printf("------------------------------------\n\n");
printf("Real area (indices 0-9):\n");
for(index_t i = 0; i < OrigSize; i++) {
auto coord = make_tensor_coordinate(desc_padded, make_tuple(i));
for(index_t i = 0; i < OrigSize; i++)
{
auto coord = make_tensor_coordinate(desc_padded, make_tuple(i));
index_t offset = coord.get_offset();
DataType val = p_data[offset];
DataType val = p_data[offset];
printf(" Index %ld -> offset %ld -> value %.1f (real data)\n",
static_cast<long>(i), static_cast<long>(offset), static_cast<float>(val));
static_cast<long>(i),
static_cast<long>(offset),
static_cast<float>(val));
}
printf("\nPadded area (indices 10-15):\n");
for(index_t i = OrigSize; i < TotalSize; i++) {
auto coord = make_tensor_coordinate(desc_padded, make_tuple(i));
for(index_t i = OrigSize; i < TotalSize; i++)
{
auto coord = make_tensor_coordinate(desc_padded, make_tuple(i));
index_t offset = coord.get_offset();
DataType val = p_data[offset];
DataType val = p_data[offset];
printf(" Index %ld -> offset %ld -> value %.1f (wraps around)\n",
static_cast<long>(i), static_cast<long>(offset), static_cast<float>(val));
static_cast<long>(i),
static_cast<long>(offset),
static_cast<float>(val));
}
printf("\nKey Observations:\n");
@@ -344,13 +353,11 @@ struct TensorAdaptorsKernel
printf("Demonstrating replicate transform with complete coordinate mapping.\n\n");
// Start with flattened 1D tensor
constexpr index_t Size = 16; // H*W = 2*8
constexpr index_t Size = 16; // H*W = 2*8
printf("Step 1: Create initial 1D descriptor [Size=%ld]\n", static_cast<long>(Size));
auto desc = make_naive_tensor_descriptor_packed(
make_tuple(number<Size>{})
);
auto desc = make_naive_tensor_descriptor_packed(make_tuple(number<Size>{}));
printf(" Initial: [16] (flattened)\n\n");
@@ -362,11 +369,11 @@ struct TensorAdaptorsKernel
auto desc_stage1 = transform_tensor_descriptor(
desc,
make_tuple(
make_replicate_transform(make_tuple(number<8>{})), // Broadcast to 8
make_unmerge_transform(make_tuple(number<8>{}, number<2>{})) // Split 16 -> [8,2]
),
make_tuple(sequence<>{}, sequence<0>{}), // Replicate has no input, Unmerge uses dim 0
make_tuple(sequence<0>{}, sequence<1, 2>{}) // Rep0=dim0, Unmerge produces dims 1,2
make_replicate_transform(make_tuple(number<8>{})), // Broadcast to 8
make_unmerge_transform(make_tuple(number<8>{}, number<2>{})) // Split 16 -> [8,2]
),
make_tuple(sequence<>{}, sequence<0>{}), // Replicate has no input, Unmerge uses dim 0
make_tuple(sequence<0>{}, sequence<1, 2>{}) // Rep0=dim0, Unmerge produces dims 1,2
);
printf("\n After Stage 1: [Rep0=8, Dim0=8, Dim1=2]\n");
@@ -378,55 +385,63 @@ struct TensorAdaptorsKernel
auto desc_final = transform_tensor_descriptor(
desc_stage1,
make_tuple(
make_merge_transform(make_tuple(number<8>{}, number<8>{})), // Merge Rep0, Dim0
make_pass_through_transform(number<2>{}) // Dim1 unchanged
),
make_tuple(sequence<0, 1>{}, sequence<2>{}), // Merge dims 0,1; pass-through dim 2
make_tuple(sequence<0>{}, sequence<1>{}) // Output: [Merged, Dim1]
make_merge_transform(make_tuple(number<8>{}, number<8>{})), // Merge Rep0, Dim0
make_pass_through_transform(number<2>{}) // Dim1 unchanged
),
make_tuple(sequence<0, 1>{}, sequence<2>{}), // Merge dims 0,1; pass-through dim 2
make_tuple(sequence<0>{}, sequence<1>{}) // Output: [Merged, Dim1]
);
printf("\n Final: [Merged=64, Dim1=2]\n\n");
// Comprehensive coordinate testing - ALL coordinates
printf("COORDINATE MAPPING TEST - ALL %ld coordinates:\n",
static_cast<long>(64 * 2));
printf("COORDINATE MAPPING TEST - ALL %ld coordinates:\n", static_cast<long>(64 * 2));
printf("=======================================================\n");
printf("Format: [Merged, Dim1] -> memory_offset\n\n");
auto lengths_final = desc_final.get_lengths();
index_t merged_len = lengths_final[number<0>{}];
index_t dim1_len = lengths_final[number<1>{}];
index_t dim1_len = lengths_final[number<1>{}];
printf("Descriptor: [Merged=%ld, Dim1=%ld] = %ld total coordinates\n",
static_cast<long>(merged_len), static_cast<long>(dim1_len),
static_cast<long>(merged_len),
static_cast<long>(dim1_len),
static_cast<long>(merged_len * dim1_len));
printf("Memory: Only 16 locations (broadcasting effect!)\n\n");
// Print ALL coordinates to show broadcasting pattern
index_t count = 0;
for(index_t merged = 0; merged < merged_len; merged++) {
for(index_t dim1 = 0; dim1 < dim1_len; dim1++) {
auto coord = make_tensor_coordinate(desc_final, make_tuple(merged, dim1));
for(index_t merged = 0; merged < merged_len; merged++)
{
for(index_t dim1 = 0; dim1 < dim1_len; dim1++)
{
auto coord = make_tensor_coordinate(desc_final, make_tuple(merged, dim1));
index_t offset = coord.get_offset();
printf(" [%2ld, %ld] -> offset %2ld",
static_cast<long>(merged), static_cast<long>(dim1),
static_cast<long>(merged),
static_cast<long>(dim1),
static_cast<long>(offset));
// Add newline every 4 coordinates for readability
count++;
if(count % 4 == 0) {
if(count % 4 == 0)
{
printf("\n");
} else {
}
else
{
printf(" | ");
}
}
}
if(count % 4 != 0) printf("\n");
if(count % 4 != 0)
printf("\n");
printf("\nKey Observations:\n");
printf(" - Total coordinates: %ld (Merged=%ld × Dim1=%ld)\n",
static_cast<long>(merged_len * dim1_len),
static_cast<long>(merged_len), static_cast<long>(dim1_len));
static_cast<long>(merged_len),
static_cast<long>(dim1_len));
printf(" - Memory locations: 16 (original size)\n");
printf(" - Broadcasting ratio: %ld:1 (each memory location accessed by %ld coordinates)\n",
static_cast<long>((merged_len * dim1_len) / 16),
@@ -436,7 +451,8 @@ struct TensorAdaptorsKernel
CK_TILE_DEVICE void operator()(const DataType* p_data) const
{
if(get_thread_id() != 0) return;
if(get_thread_id() != 0)
return;
printf("\n=== TENSOR ADAPTORS IN CK_TILE ===\n\n");
@@ -476,7 +492,8 @@ int main()
int device_count;
hip_check_error(hipGetDeviceCount(&device_count));
if(device_count == 0) {
if(device_count == 0)
{
std::cerr << "No GPU devices found!\n";
return 1;
}
@@ -488,13 +505,15 @@ int main()
// Allocate data for padding example (16 elements, but only first 10 have real data)
constexpr index_t data_size = 16;
std::vector<float> h_data(data_size, 0.0f); // Initialize all to 0
std::iota(h_data.begin(), h_data.begin() + 10, 1.0f); // First 10: 1,2,3,...,10
std::vector<float> h_data(data_size, 0.0f); // Initialize all to 0
std::iota(h_data.begin(), h_data.begin() + 10, 1.0f); // First 10: 1,2,3,...,10
std::cout << "\nTest data (first 10 real, last 6 padding zeros): ";
for(size_t i = 0; i < h_data.size(); i++) {
for(size_t i = 0; i < h_data.size(); i++)
{
std::cout << h_data[i];
if(i < h_data.size() - 1) std::cout << " ";
if(i < h_data.size() - 1)
std::cout << " ";
}
std::cout << "\n";
@@ -508,12 +527,11 @@ int main()
std::cout << "=====================================\n";
launch_kernel(stream,
make_kernel<block_size>(
TensorAdaptorsKernel<float>{},
dim3(1),
dim3(block_size),
0,
static_cast<const float*>(d_data.GetDeviceBuffer())));
make_kernel<block_size>(TensorAdaptorsKernel<float>{},
dim3(1),
dim3(block_size),
0,
static_cast<const float*>(d_data.GetDeviceBuffer())));
hip_check_error(hipDeviceSynchronize());
std::cout << "=====================================\n";

View File

@@ -21,16 +21,16 @@
using namespace ck_tile;
template<typename DataType>
template <typename DataType>
struct PaddingTileKernel
{
static constexpr index_t kBlockSize = 64;
CK_TILE_DEVICE void operator()(const DataType* p_data,
index_t orig_size,
index_t padded_size) const
CK_TILE_DEVICE void
operator()(const DataType* p_data, index_t orig_size, index_t padded_size) const
{
if(get_thread_id() != 0) return;
if(get_thread_id() != 0)
return;
printf("\n=== PADDING WITH TILE WINDOWS ===\n\n");
@@ -38,24 +38,17 @@ struct PaddingTileKernel
printf("Padded size: %ld\n\n", static_cast<long>(padded_size));
// Step 1: Create original descriptor (runtime)
auto desc_orig = make_naive_tensor_descriptor(
make_tuple(orig_size),
make_tuple(1)
);
auto desc_orig = make_naive_tensor_descriptor(make_tuple(orig_size), make_tuple(1));
// Step 2: Apply padding transform
index_t pad_amount = padded_size - orig_size;
auto desc_padded = transform_tensor_descriptor(
desc_orig,
make_tuple(make_right_pad_transform(orig_size, pad_amount)),
make_tuple(sequence<0>{}),
make_tuple(sequence<0>{})
);
auto desc_padded =
transform_tensor_descriptor(desc_orig,
make_tuple(make_right_pad_transform(orig_size, pad_amount)),
make_tuple(sequence<0>{}),
make_tuple(sequence<0>{}));
auto tensor_simple = make_tensor_view<address_space_enum::global>(
p_data,
desc_padded
);
auto tensor_simple = make_tensor_view<address_space_enum::global>(p_data, desc_padded);
printf("Created tensor_view (simple API, no identity value)\n");
printf(" - Padded reads will wrap around to existing data\n\n");
@@ -63,26 +56,28 @@ struct PaddingTileKernel
// Step 5: Read tiles using get_vectorized_elements
constexpr index_t tile_size = 8;
printf("Reading tiles of size %ld using get_vectorized_elements:\n\n",
printf("Reading tiles of size %ld using get_vectorized_elements:\n\n",
static_cast<long>(tile_size));
// Load tiles covering the entire padded range
index_t num_tiles = (padded_size + tile_size - 1) / tile_size;
for(index_t tile_idx = 0; tile_idx < num_tiles; tile_idx++) {
for(index_t tile_idx = 0; tile_idx < num_tiles; tile_idx++)
{
// Use get_vectorized_elements directly on tensor_view
printf("Tile %ld (indices %ld-%ld):\n",
static_cast<long>(tile_idx),
static_cast<long>(tile_idx * tile_size),
static_cast<long>(tile_idx * tile_size + tile_size - 1));
printf(" Values: ");
// Use static_for to access elements with compile-time indices
static_for<0, tile_size, 1>{}([&](auto i) {
index_t global_idx = tile_idx * tile_size + i;
auto coord = make_tensor_coordinate(desc_padded, make_tuple(global_idx));
auto buffer = tensor_simple.template get_vectorized_elements<
thread_buffer<DataType, 1>>(coord, 0);
auto coord = make_tensor_coordinate(desc_padded, make_tuple(global_idx));
auto buffer =
tensor_simple.template get_vectorized_elements<thread_buffer<DataType, 1>>(
coord, 0);
// static_for<0, 4, 1>{}([&](auto j) {
// DataType val = buffer[number<j>{}];
// printf("%.1f ", static_cast<float>(val));
@@ -91,11 +86,12 @@ struct PaddingTileKernel
printf("%.1f ", static_cast<float>(val));
});
printf("\n");
// Check if this tile contains padding
index_t tile_start = tile_idx * tile_size;
index_t tile_end = tile_start + tile_size;
if(tile_end > orig_size) {
index_t tile_end = tile_start + tile_size;
if(tile_end > orig_size)
{
printf(" Note: Elements %ld-%ld are padded (return identity value 0.0)\n",
static_cast<long>(orig_size - tile_start),
static_cast<long>(tile_size - 1));
@@ -108,7 +104,6 @@ struct PaddingTileKernel
printf(" - Out-of-bounds accesses return identity value (0.0)\n");
printf(" - get_vectorized_elements properly handles padding\n");
printf(" - This is the pattern used in pooling/convolution kernels\n\n");
}
};
@@ -120,7 +115,8 @@ int main()
int device_count;
hip_check_error(hipGetDeviceCount(&device_count));
if(device_count == 0) {
if(device_count == 0)
{
std::cerr << "No GPU devices found!\n";
return 1;
}
@@ -131,14 +127,15 @@ int main()
std::cout << "Using GPU: " << props.name << "\n";
// Create test data: 10 real elements
constexpr index_t orig_size = 10;
constexpr index_t orig_size = 10;
constexpr index_t padded_size = 16;
std::vector<float> h_data(orig_size);
std::iota(h_data.begin(), h_data.end(), 1.0f); // 1, 2, 3, ..., 10
std::iota(h_data.begin(), h_data.end(), 1.0f); // 1, 2, 3, ..., 10
std::cout << "\nTest data (" << orig_size << " elements): ";
for(auto val : h_data) {
for(auto val : h_data)
{
std::cout << val << " ";
}
std::cout << "\n";
@@ -154,14 +151,13 @@ int main()
std::cout << "=====================================\n";
launch_kernel(stream,
make_kernel<block_size>(
PaddingTileKernel<float>{},
dim3(1),
dim3(block_size),
0,
static_cast<const float*>(d_data.GetDeviceBuffer()),
orig_size,
padded_size));
make_kernel<block_size>(PaddingTileKernel<float>{},
dim3(1),
dim3(block_size),
0,
static_cast<const float*>(d_data.GetDeviceBuffer()),
orig_size,
padded_size));
hip_check_error(hipDeviceSynchronize());
std::cout << "=====================================\n";

View File

@@ -21,7 +21,7 @@
using namespace ck_tile;
template<typename DataType>
template <typename DataType>
struct DescriptorVsAdaptorKernel
{
static constexpr index_t kBlockSize = 64;
@@ -38,24 +38,23 @@ struct DescriptorVsAdaptorKernel
printf("Example 1.1: Matrix Tiling [M, K] -> [M0, M1, K]\n");
printf("------------------------------------------------\n");
{
constexpr index_t M = 128;
constexpr index_t K = 64;
constexpr index_t M = 128;
constexpr index_t K = 64;
constexpr index_t M0 = 4;
constexpr index_t M1 = 32;
printf("Input: [M=%ld, K=%ld]\n", static_cast<long>(M), static_cast<long>(K));
printf("Output: [M0=%ld, M1=%ld, K=%ld]\n",
static_cast<long>(M0), static_cast<long>(M1), static_cast<long>(K));
printf("Output: [M0=%ld, M1=%ld, K=%ld]\n",
static_cast<long>(M0),
static_cast<long>(M1),
static_cast<long>(K));
// Create adaptor - only transformation logic
auto adaptor = make_single_stage_tensor_adaptor(
make_tuple(
make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
make_pass_through_transform(number<K>{})
),
make_tuple(make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
make_pass_through_transform(number<K>{})),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{})
);
make_tuple(sequence<0, 1>{}, sequence<2>{}));
printf("\nAdaptor properties:\n");
printf(" - Stores: Transformation logic only\n");
@@ -64,9 +63,9 @@ struct DescriptorVsAdaptorKernel
printf(" - Cannot do: Calculate memory offsets\n\n");
// Test coordinate mapping
auto top_idx = make_tuple(2, 16, 32);
auto top_idx = make_tuple(2, 16, 32);
auto bottom_idx = adaptor.calculate_bottom_index(top_idx);
printf("Coordinate mapping test:\n");
printf(" Input: [M0=%ld, M1=%ld, K=%ld]\n",
static_cast<long>(top_idx.template get<0>()),
@@ -90,23 +89,18 @@ struct DescriptorVsAdaptorKernel
// Define a generic 2D tiling pattern
auto create_tiling_adaptor = [](auto M0, auto M1, auto N0, auto N1) {
return make_single_stage_tensor_adaptor(
make_tuple(
make_unmerge_transform(make_tuple(M0, M1)),
make_unmerge_transform(make_tuple(N0, N1))
),
make_tuple(make_unmerge_transform(make_tuple(M0, M1)),
make_unmerge_transform(make_tuple(N0, N1))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0, 1>{}, sequence<2, 3>{})
);
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}));
};
// Use same pattern for different matrix sizes
[[maybe_unused]] auto adaptor_64x64 = create_tiling_adaptor(
number<4>{}, number<16>{}, number<4>{}, number<16>{}
);
[[maybe_unused]] auto adaptor_128x128 = create_tiling_adaptor(
number<8>{}, number<16>{}, number<8>{}, number<16>{}
);
[[maybe_unused]] auto adaptor_64x64 =
create_tiling_adaptor(number<4>{}, number<16>{}, number<4>{}, number<16>{});
[[maybe_unused]] auto adaptor_128x128 =
create_tiling_adaptor(number<8>{}, number<16>{}, number<8>{}, number<16>{});
printf("Created two adaptors with same pattern:\n");
printf(" - 64x64 matrix: [64, 64] -> [4, 16, 4, 16]\n");
@@ -133,12 +127,11 @@ struct DescriptorVsAdaptorKernel
constexpr index_t K = 64;
// Create descriptor - includes memory information
auto desc = make_naive_tensor_descriptor_packed(
make_tuple(number<M>{}, number<K>{})
);
auto desc = make_naive_tensor_descriptor_packed(make_tuple(number<M>{}, number<K>{}));
printf("Created descriptor for [M=%ld, K=%ld] matrix\n",
static_cast<long>(M), static_cast<long>(K));
static_cast<long>(M),
static_cast<long>(K));
auto space_size = desc.get_element_space_size();
printf("\nDescriptor properties:\n");
@@ -150,13 +143,15 @@ struct DescriptorVsAdaptorKernel
// Calculate memory offset
auto offset1 = desc.calculate_offset(make_tuple(10, 20));
auto offset2 = desc.calculate_offset(make_tuple(0, 0));
auto offset3 = desc.calculate_offset(make_tuple(M-1, K-1));
auto offset3 = desc.calculate_offset(make_tuple(M - 1, K - 1));
printf("Memory offset calculations:\n");
printf(" [10, 20] -> offset %ld (10*64 + 20)\n", static_cast<long>(offset1));
printf(" [0, 0] -> offset %ld (first element)\n", static_cast<long>(offset2));
printf(" [%ld, %ld] -> offset %ld (last element)\n",
static_cast<long>(M-1), static_cast<long>(K-1), static_cast<long>(offset3));
static_cast<long>(M - 1),
static_cast<long>(K - 1),
static_cast<long>(offset3));
}
printf("\n");
@@ -165,32 +160,30 @@ struct DescriptorVsAdaptorKernel
printf("Example 2.2: Transforming a Descriptor\n");
printf("---------------------------------------\n");
{
constexpr index_t M = 256;
constexpr index_t K = 128;
constexpr index_t M = 256;
constexpr index_t K = 128;
constexpr index_t M0 = 4;
constexpr index_t M1 = 64;
printf("Step 1: Create initial descriptor\n");
auto desc_initial = make_naive_tensor_descriptor_packed(
make_tuple(number<M>{}, number<K>{})
);
auto desc_initial =
make_naive_tensor_descriptor_packed(make_tuple(number<M>{}, number<K>{}));
printf(" Initial: [M=%ld, K=%ld]\n", static_cast<long>(M), static_cast<long>(K));
printf(" Memory size: %ld elements\n\n",
printf(" Memory size: %ld elements\n\n",
static_cast<long>(desc_initial.get_element_space_size()));
printf("Step 2: Transform to add tiling\n");
auto desc_tiled = transform_tensor_descriptor(
desc_initial,
make_tuple(
make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
make_pass_through_transform(number<K>{})
),
make_tuple(make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
make_pass_through_transform(number<K>{})),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{})
);
make_tuple(sequence<0, 1>{}, sequence<2>{}));
printf(" Transformed: [M, K] -> [M0=%ld, M1=%ld, K=%ld]\n",
static_cast<long>(M0), static_cast<long>(M1), static_cast<long>(K));
static_cast<long>(M0),
static_cast<long>(M1),
static_cast<long>(K));
printf(" Memory size preserved: %ld elements\n\n",
static_cast<long>(desc_tiled.get_element_space_size()));
@@ -217,9 +210,7 @@ struct DescriptorVsAdaptorKernel
constexpr index_t M = 64;
constexpr index_t K = 32;
auto desc = make_naive_tensor_descriptor_packed(
make_tuple(number<M>{}, number<K>{})
);
auto desc = make_naive_tensor_descriptor_packed(make_tuple(number<M>{}, number<K>{}));
printf("Descriptor: [M=%ld, K=%ld]\n\n", static_cast<long>(M), static_cast<long>(K));
@@ -243,9 +234,7 @@ struct DescriptorVsAdaptorKernel
constexpr index_t M = 64;
constexpr index_t K = 32;
auto desc = make_naive_tensor_descriptor_packed(
make_tuple(number<M>{}, number<K>{})
);
auto desc = make_naive_tensor_descriptor_packed(make_tuple(number<M>{}, number<K>{}));
printf("Scenario: Iterate through a row of tiles\n");
printf("Descriptor: [M=%ld, K=%ld]\n\n", static_cast<long>(M), static_cast<long>(K));
@@ -286,28 +275,25 @@ struct DescriptorVsAdaptorKernel
printf("Example 3.3: Moving Coordinates with Transformations\n");
printf("----------------------------------------------------\n");
{
constexpr index_t M = 128;
constexpr index_t K = 64;
constexpr index_t M = 128;
constexpr index_t K = 64;
constexpr index_t M0 = 4;
constexpr index_t M1 = 32;
// Create tiled descriptor
auto desc = make_naive_tensor_descriptor_packed(
make_tuple(number<M>{}, number<K>{})
);
auto desc = make_naive_tensor_descriptor_packed(make_tuple(number<M>{}, number<K>{}));
auto desc_tiled = transform_tensor_descriptor(
desc,
make_tuple(
make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
make_pass_through_transform(number<K>{})
),
make_tuple(make_unmerge_transform(make_tuple(number<M0>{}, number<M1>{})),
make_pass_through_transform(number<K>{})),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{})
);
make_tuple(sequence<0, 1>{}, sequence<2>{}));
printf("Tiled descriptor: [M, K] -> [M0=%ld, M1=%ld, K=%ld]\n\n",
static_cast<long>(M0), static_cast<long>(M1), static_cast<long>(K));
static_cast<long>(M0),
static_cast<long>(M1),
static_cast<long>(K));
// Create coordinate
auto coord = make_tensor_coordinate(desc_tiled, make_tuple(1, 8, 16));
@@ -360,7 +346,8 @@ struct DescriptorVsAdaptorKernel
CK_TILE_DEVICE void operator()() const
{
if(get_thread_id() != 0) return;
if(get_thread_id() != 0)
return;
printf("\n=== TENSOR DESCRIPTOR VS TENSOR ADAPTOR ===\n\n");
@@ -400,7 +387,8 @@ int main()
int device_count;
hip_check_error(hipGetDeviceCount(&device_count));
if(device_count == 0) {
if(device_count == 0)
{
std::cerr << "No GPU devices found!\n";
return 1;
}
@@ -416,12 +404,9 @@ int main()
std::cout << "\nLaunching kernel...\n";
std::cout << "=====================================\n";
launch_kernel(stream,
make_kernel<block_size>(
DescriptorVsAdaptorKernel<float>{},
dim3(1),
dim3(block_size),
0));
launch_kernel(
stream,
make_kernel<block_size>(DescriptorVsAdaptorKernel<float>{}, dim3(1), dim3(block_size), 0));
hip_check_error(hipDeviceSynchronize());
std::cout << "=====================================\n";

View File

@@ -26,16 +26,16 @@
using namespace ck_tile;
// Distributed HGEMM kernel using proper tile_distribution
template<typename ADataType, typename BDataType, typename CDataType, typename AccDataType>
template <typename ADataType, typename BDataType, typename CDataType, typename AccDataType>
struct DistributedHgemmKernel
{
static constexpr index_t kWaveSize = 64; // AMD wave size
static constexpr index_t kBlockM = 16; // MFMA M dimension
static constexpr index_t kBlockN = 16; // MFMA N dimension
static constexpr index_t kBlockK = 16; // MFMA K dimension per instruction
static constexpr index_t kWaveSize = 64; // AMD wave size
static constexpr index_t kBlockM = 16; // MFMA M dimension
static constexpr index_t kBlockN = 16; // MFMA N dimension
static constexpr index_t kBlockK = 16; // MFMA K dimension per instruction
// Use ck_tile's WarpGemm for MFMA
using WarpGemm = WarpGemmMfmaF16F16F32M16N16K16;
using WarpGemm = WarpGemmMfmaF16F16F32M16N16K16;
static constexpr index_t kBlockSize = kWaveSize;
CK_TILE_DEVICE void operator()(const ADataType* a,
@@ -45,18 +45,19 @@ struct DistributedHgemmKernel
index_t M,
index_t N,
index_t K,
index_t lda, // Leading dimension of A (column-major)
index_t ldb, // Leading dimension of B (row-major)
index_t ldc, // Leading dimension of C (column-major)
index_t ldd, // Leading dimension of D (column-major)
index_t lda, // Leading dimension of A (column-major)
index_t ldb, // Leading dimension of B (row-major)
index_t ldc, // Leading dimension of C (column-major)
index_t ldd, // Leading dimension of D (column-major)
AccDataType alpha,
AccDataType beta) const
{
// Calculate which 16×16 block this wave computes
// const index_t wave_id = get_block_id() * get_block_size() / kWaveSize + threadIdx.x / kWaveSize;
// const index_t wave_id = get_block_id() * get_block_size() / kWaveSize + threadIdx.x /
// kWaveSize;
const index_t wave_id = get_warp_id();
const index_t wave_m = wave_id / (N / kBlockN);
const index_t wave_n = wave_id % (N / kBlockN);
const index_t wave_m = wave_id / (N / kBlockN);
const index_t wave_n = wave_id % (N / kBlockN);
const index_t m_offset = wave_m * kBlockM;
const index_t n_offset = wave_n * kBlockN;
@@ -73,89 +74,81 @@ struct DistributedHgemmKernel
// A is column-major: M×K with stride lda between columns
const auto a_tensor = make_naive_tensor_view<address_space_enum::global>(
a,
make_tuple(M, K), // Shape: M×K
make_tuple(1, lda), // Strides: column-major
make_tuple(M, K), // Shape: M×K
make_tuple(1, lda), // Strides: column-major
number<1>{},
number<1>{}
);
number<1>{});
// B is row-major: K×N with stride ldb between rows
const auto b_tensor = make_naive_tensor_view<address_space_enum::global>(
b,
make_tuple(K, N), // Shape: K×N
make_tuple(ldb, 1), // Strides: row-major
make_tuple(K, N), // Shape: K×N
make_tuple(ldb, 1), // Strides: row-major
number<4>{},
number<1>{}
);
number<1>{});
// C is column-major: M×N with stride ldc between columns
const auto c_tensor = make_naive_tensor_view<address_space_enum::global>(
c,
make_tuple(M, N), // Shape: M×N
make_tuple(1, ldc), // Strides: column-major
make_tuple(M, N), // Shape: M×N
make_tuple(1, ldc), // Strides: column-major
number<1>{},
number<1>{}
);
number<1>{});
// D is column-major: M×N with stride ldd between columns
auto d_tensor = make_naive_tensor_view<address_space_enum::global>(
d,
make_tuple(M, N), // Shape: M×N
make_tuple(1, ldd), // Strides: column-major
make_tuple(M, N), // Shape: M×N
make_tuple(1, ldd), // Strides: column-major
number<1>{},
number<1>{}
);
number<1>{});
// Use our tested custom distributions from test_a_distribution.cpp and test_b_distribution.cpp
// A: Column-major M×K with each thread loading 4 consecutive K values from one M position
// Use our tested custom distributions from test_a_distribution.cpp and
// test_b_distribution.cpp A: Column-major M×K with each thread loading 4 consecutive K
// values from one M position
constexpr auto a_distribution = make_static_tile_distribution(
tile_distribution_encoding<
sequence<>, // No replication
tuple<sequence<16>, // H0 (M): 16 lanes for M
sequence<4, 4>>, // H1 (K): 4 lanes × 4 per lane
tuple<sequence<2, 1>>, // P-dims map to H-dims
tuple<sequence<0, 0>>, // P positions in H-dims
sequence<2>, // Y maps to K dimension only
sequence<1>>{} // Y at position 1
tile_distribution_encoding<sequence<>, // No replication
tuple<sequence<16>, // H0 (M): 16 lanes for M
sequence<4, 4>>, // H1 (K): 4 lanes × 4 per lane
tuple<sequence<2, 1>>, // P-dims map to H-dims
tuple<sequence<0, 0>>, // P positions in H-dims
sequence<2>, // Y maps to K dimension only
sequence<1>>{} // Y at position 1
);
// B: Row-major K×N with each thread loading 4 consecutive K values from one N position
constexpr auto b_distribution = make_static_tile_distribution(
tile_distribution_encoding<
sequence<>, // No replication
tuple<sequence<4, 4>, // H0 (K): 4 groups of 4 consecutive K values
sequence<16>>, // H1 (N): 16 N positions
tuple<sequence<1, 2>>, // P-dims map to H-dims (P0->H1, P1->H0)
tuple<sequence<0, 0>>, // P positions in H-dims
sequence<1>, // Y maps to K dimension (H0)
sequence<1>>{} // Y at position 1 in H0 (the second 4 in sequence<4,4>)
sequence<>, // No replication
tuple<sequence<4, 4>, // H0 (K): 4 groups of 4 consecutive K values
sequence<16>>, // H1 (N): 16 N positions
tuple<sequence<1, 2>>, // P-dims map to H-dims (P0->H1, P1->H0)
tuple<sequence<0, 0>>, // P positions in H-dims
sequence<1>, // Y maps to K dimension (H0)
sequence<1>>{} // Y at position 1 in H0 (the second 4 in sequence<4,4>)
);
// Create windows for A and B that we'll move along K
auto a_window = make_tile_window(
a_tensor,
make_tuple(number<kBlockM>{}, number<kBlockK>{}),
{m_offset, 0},
a_distribution
);
auto a_window = make_tile_window(a_tensor,
make_tuple(number<kBlockM>{}, number<kBlockK>{}),
{m_offset, 0},
a_distribution);
auto b_window = make_tile_window(
b_tensor,
make_tuple(number<kBlockK>{}, number<kBlockN>{}),
{0, n_offset},
b_distribution
);
auto b_window = make_tile_window(b_tensor,
make_tuple(number<kBlockK>{}, number<kBlockN>{}),
{0, n_offset},
b_distribution);
// C distribution (column-major M×N output) - tested in test_c_distribution.cpp
constexpr auto c_distribution = make_static_tile_distribution(
tile_distribution_encoding<
sequence<>, // No replication
tuple<sequence<4, 4>, // H0 (M): 4 groups of 4 consecutive M values
sequence<16>>, // H1 (N): 16 N positions
tuple<sequence<1, 2>>, // P-dims map to H-dims (P0->H1, P1->H0)
tuple<sequence<0, 0>>, // P positions in H-dims
sequence<1>, // Y maps to M dimension (H0)
sequence<1>>{} // Y at position 1 in H0 (the second 4 in sequence<4,4>)
sequence<>, // No replication
tuple<sequence<4, 4>, // H0 (M): 4 groups of 4 consecutive M values
sequence<16>>, // H1 (N): 16 N positions
tuple<sequence<1, 2>>, // P-dims map to H-dims (P0->H1, P1->H0)
tuple<sequence<0, 0>>, // P positions in H-dims
sequence<1>, // Y maps to M dimension (H0)
sequence<1>>{} // Y at position 1 in H0 (the second 4 in sequence<4,4>)
);
// Create accumulator using our tested C distribution
@@ -178,9 +171,10 @@ struct DistributedHgemmKernel
// Move windows to next K chunk using the move API
// This efficiently updates window_origin_ without recreating the window
if(k_iter < num_k_loops - 1) {
a_window.move({0, kBlockK}); // Move K forward for A
b_window.move({kBlockK, 0}); // Move K forward for B
if(k_iter < num_k_loops - 1)
{
a_window.move({0, kBlockK}); // Move K forward for A
b_window.move({kBlockK, 0}); // Move K forward for B
}
}
@@ -191,57 +185,60 @@ struct DistributedHgemmKernel
// Load C, apply beta, and add to result
if(std::abs(beta) > 1e-6f)
{
auto c_window = make_tile_window(
c_tensor,
make_tuple(number<kBlockM>{}, number<kBlockN>{}),
{m_offset, n_offset},
c_distribution
);
auto c_window = make_tile_window(c_tensor,
make_tuple(number<kBlockM>{}, number<kBlockN>{}),
{m_offset, n_offset},
c_distribution);
const auto c_tile = load_tile(c_window);
// Apply beta * C + acc using ck_tile's elementwise API
// This combines two tiles with a lambda function
tile_elementwise_inout(
[beta](const auto& c_val, auto& acc_val) {
acc_val += beta * c_val;
},
c_tile, acc_tile);
[beta](const auto& c_val, auto& acc_val) { acc_val += beta * c_val; },
c_tile,
acc_tile);
}
// Store final result to D
auto d_window = make_tile_window(
d_tensor,
make_tuple(number<kBlockM>{}, number<kBlockN>{}),
{m_offset, n_offset},
c_distribution
);
auto d_window = make_tile_window(d_tensor,
make_tuple(number<kBlockM>{}, number<kBlockN>{}),
{m_offset, n_offset},
c_distribution);
store_tile(d_window, acc_tile);
}
};
// CPU reference for verification
template<typename InType, typename AccType>
void reference_gemm_mixed(const std::vector<InType>& a, // Column-major
const std::vector<InType>& b, // Row-major
const std::vector<AccType>& c, // Column-major
std::vector<AccType>& d, // Column-major
index_t M, index_t N, index_t K,
index_t lda, index_t ldb, index_t ldc, index_t ldd,
AccType alpha, AccType beta)
template <typename InType, typename AccType>
void reference_gemm_mixed(const std::vector<InType>& a, // Column-major
const std::vector<InType>& b, // Row-major
const std::vector<AccType>& c, // Column-major
std::vector<AccType>& d, // Column-major
index_t M,
index_t N,
index_t K,
index_t lda,
index_t ldb,
index_t ldc,
index_t ldd,
AccType alpha,
AccType beta)
{
// D = alpha * A * B + beta * C
for(index_t n = 0; n < N; ++n) {
for(index_t m = 0; m < M; ++m) {
for(index_t n = 0; n < N; ++n)
{
for(index_t m = 0; m < M; ++m)
{
AccType sum = 0;
// Compute A * B
for(index_t k = 0; k < K; ++k) {
for(index_t k = 0; k < K; ++k)
{
// A is column-major: A[m,k] = a[m + k*lda]
// B is row-major: B[k,n] = b[k*ldb + n]
sum += static_cast<AccType>(a[m + k * lda]) *
static_cast<AccType>(b[k * ldb + n]);
sum += static_cast<AccType>(a[m + k * lda]) * static_cast<AccType>(b[k * ldb + n]);
}
// D[m,n] = alpha * sum + beta * C[m,n]
@@ -252,30 +249,38 @@ void reference_gemm_mixed(const std::vector<InType>& a, // Column-major
}
// Helper to fill matrix with random values
template<typename T>
template <typename T>
void fill_random(std::vector<T>& data, T min_val = -1, T max_val = 1)
{
for(auto& val : data) {
val = static_cast<T>(min_val + (max_val - min_val) *
static_cast<float>(rand()) / RAND_MAX);
for(auto& val : data)
{
val = static_cast<T>(min_val + (max_val - min_val) * static_cast<float>(rand()) / RAND_MAX);
}
}
// Helper to print matrix (for debugging)
template<typename T>
void print_matrix(const std::vector<T>& mat, index_t rows, index_t cols,
index_t ld, bool col_major = true, const std::string& name = "Matrix")
template <typename T>
void print_matrix(const std::vector<T>& mat,
index_t rows,
index_t cols,
index_t ld,
bool col_major = true,
const std::string& name = "Matrix")
{
std::cout << name << " (" << rows << "×" << cols << "):\n";
for(index_t i = 0; i < std::min(rows, index_t(8)); ++i) {
for(index_t j = 0; j < std::min(cols, index_t(8)); ++j) {
for(index_t i = 0; i < std::min(rows, index_t(8)); ++i)
{
for(index_t j = 0; j < std::min(cols, index_t(8)); ++j)
{
index_t idx = col_major ? (i + j * ld) : (i * ld + j);
std::cout << std::setw(8) << std::setprecision(3) << mat[idx] << " ";
}
if(cols > 8) std::cout << "...";
if(cols > 8)
std::cout << "...";
std::cout << "\n";
}
if(rows > 8) std::cout << "...\n";
if(rows > 8)
std::cout << "...\n";
std::cout << "\n";
}
@@ -298,16 +303,16 @@ int main()
constexpr index_t K = 64;
// Leading dimensions
constexpr index_t lda = M; // Column-major
constexpr index_t ldb = N; // Row-major
constexpr index_t ldc = M; // Column-major
constexpr index_t ldd = M; // Column-major
constexpr index_t lda = M; // Column-major
constexpr index_t ldb = N; // Row-major
constexpr index_t ldc = M; // Column-major
constexpr index_t ldd = M; // Column-major
using InputType = half_t; // fp16
using AccumType = float; // fp32
using InputType = half_t; // fp16
using AccumType = float; // fp32
constexpr AccumType alpha = 2.0f;
constexpr AccumType beta = 1.5f;
constexpr AccumType beta = 1.5f;
std::cout << "Problem configuration:\n";
std::cout << " M×N×K: " << M << "×" << N << "×" << K << "\n";
@@ -315,7 +320,7 @@ int main()
std::cout << " B: row-major, ldb=" << ldb << " (fp16)\n";
std::cout << " C/D: column-major, ldc=" << ldc << ", ldd=" << ldd << " (fp32)\n";
std::cout << " alpha=" << alpha << ", beta=" << beta << "\n";
std::cout << " Total FLOPs: " << 2*M*N*K << "\n\n";
std::cout << " Total FLOPs: " << 2 * M * N * K << "\n\n";
// Host memory
std::vector<InputType> h_a(M * K);
@@ -333,7 +338,7 @@ int main()
// CPU reference
auto cpu_start = std::chrono::high_resolution_clock::now();
reference_gemm_mixed(h_a, h_b, h_c, h_d_ref, M, N, K, lda, ldb, ldc, ldd, alpha, beta);
auto cpu_end = std::chrono::high_resolution_clock::now();
auto cpu_end = std::chrono::high_resolution_clock::now();
double cpu_time_ms = std::chrono::duration<double, std::milli>(cpu_end - cpu_start).count();
// Device memory
@@ -348,30 +353,39 @@ int main()
d_d.ToDevice(h_d.data(), M * N * sizeof(AccumType));
// Launch kernel
constexpr index_t block_size = 64; // One wave
const index_t grid_size = (M / 16) * (N / 16); // One wave per 16×16 output block
constexpr index_t block_size = 64; // One wave
const index_t grid_size = (M / 16) * (N / 16); // One wave per 16×16 output block
std::cout << "Launching kernel:\n";
std::cout << " Grid: " << grid_size << " blocks\n";
std::cout << " Block: " << block_size << " threads (1 wave)\n";
std::cout << " Output blocks: " << (M/16) << "×" << (N/16) << " = " << grid_size << "\n";
std::cout << " MFMA instructions per block: " << K/16 << "\n\n";
std::cout << " Output blocks: " << (M / 16) << "×" << (N / 16) << " = " << grid_size << "\n";
std::cout << " MFMA instructions per block: " << K / 16 << "\n\n";
stream_config stream;
// Warmup
for(int i = 0; i < 5; ++i) {
for(int i = 0; i < 5; ++i)
{
launch_kernel(stream,
make_kernel<block_size>(
DistributedHgemmKernel<InputType, InputType, AccumType, AccumType>{},
dim3(grid_size),
dim3(block_size),
0,
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
M, N, K, lda, ldb, ldc, ldd, alpha, beta));
make_kernel<block_size>(
DistributedHgemmKernel<InputType, InputType, AccumType, AccumType>{},
dim3(grid_size),
dim3(block_size),
0,
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
M,
N,
K,
lda,
ldb,
ldc,
ldd,
alpha,
beta));
}
hip_check_error(hipDeviceSynchronize());
@@ -379,40 +393,50 @@ int main()
auto gpu_start = std::chrono::high_resolution_clock::now();
launch_kernel(stream,
make_kernel<block_size>(
DistributedHgemmKernel<InputType, InputType, AccumType, AccumType>{},
dim3(grid_size),
dim3(block_size),
0,
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
M, N, K, lda, ldb, ldc, ldd, alpha, beta));
make_kernel<block_size>(
DistributedHgemmKernel<InputType, InputType, AccumType, AccumType>{},
dim3(grid_size),
dim3(block_size),
0,
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
M,
N,
K,
lda,
ldb,
ldc,
ldd,
alpha,
beta));
hip_check_error(hipDeviceSynchronize());
auto gpu_end = std::chrono::high_resolution_clock::now();
auto gpu_end = std::chrono::high_resolution_clock::now();
double gpu_time_ms = std::chrono::duration<double, std::milli>(gpu_end - gpu_start).count();
// Get result
d_d.FromDevice(h_d.data(), M * N * sizeof(AccumType));
// Verify
bool passed = true;
float max_error = 0;
bool passed = true;
float max_error = 0;
index_t error_count = 0;
for(index_t i = 0; i < M * N; ++i) {
for(index_t i = 0; i < M * N; ++i)
{
float error = std::abs(h_d[i] - h_d_ref[i]);
max_error = std::max(max_error, error);
if(error > 1e-2f) { // Relaxed tolerance for fp16
if(error_count < 5) {
max_error = std::max(max_error, error);
if(error > 1e-2f)
{ // Relaxed tolerance for fp16
if(error_count < 5)
{
index_t m = i % M;
index_t n = i / M;
std::cout << "Error at [" << m << "," << n << "]: "
<< h_d[i] << " vs " << h_d_ref[i]
<< " (diff=" << error << ")\n";
std::cout << "Error at [" << m << "," << n << "]: " << h_d[i] << " vs "
<< h_d_ref[i] << " (diff=" << error << ")\n";
}
error_count++;
}
@@ -421,14 +445,15 @@ int main()
passed = (error_count == 0);
// Calculate performance
double gflops = 2.0 * M * N * K / 1e9;
double gflops = 2.0 * M * N * K / 1e9;
double gpu_tflops = gflops / (gpu_time_ms / 1000);
double cpu_gflops = gflops / (cpu_time_ms / 1000);
std::cout << "Results:\n";
std::cout << " Correctness: " << (passed ? "✓ PASSED" : "✗ FAILED") << "\n";
std::cout << " Max error: " << max_error << "\n";
if(!passed) std::cout << " Error count: " << error_count << "/" << M*N << "\n";
if(!passed)
std::cout << " Error count: " << error_count << "/" << M * N << "\n";
std::cout << "\n";
std::cout << "Performance:\n";

View File

@@ -18,101 +18,94 @@
using namespace ck_tile;
template<typename DataType>
template <typename DataType>
struct TestADistributionKernel
{
static constexpr index_t kBlockSize = 256; // 4 warps
static constexpr index_t MWarp = 2;
static constexpr index_t NWarp = 2;
static constexpr index_t kM = 16;
static constexpr index_t kK = 16;
static constexpr index_t kBlockSize = 256; // 4 warps
static constexpr index_t MWarp = 2;
static constexpr index_t NWarp = 2;
static constexpr index_t kM = 16;
static constexpr index_t kK = 16;
CK_TILE_DEVICE void operator()(const DataType* a,
DataType* debug_output,
index_t lda) const
CK_TILE_DEVICE void operator()(const DataType* a, DataType* debug_output, index_t lda) const
{
if(get_block_id() != 0)
return;
const index_t tid = threadIdx.x;
const index_t tid = threadIdx.x;
const index_t warp_id = tid / 64;
const index_t lane_id = tid % 64;
// Create tensor view for A (column-major)
const auto a_tensor = make_naive_tensor_view<address_space_enum::global>(
a,
make_tuple(kM, kK),
make_tuple(1, lda),
number<1>{},
number<1>{}
);
a, make_tuple(kM, kK), make_tuple(1, lda), number<1>{}, number<1>{});
// A distribution WITH NWarp replication and MWarp in H-dimension
// Based on 02_gemm pattern: include MWarp in H-tuple
// R: NWarp replication
// H0: MWarp × 16 threads = 2×16 = 32 M positions
// H0: MWarp × 16 threads = 2×16 = 32 M positions
// H1: 4×4 = 16 K elements
constexpr auto a_distribution = make_static_tile_distribution(
tile_distribution_encoding<
sequence<NWarp>, // R: REPLICATE across 2 N-warps
tuple<sequence<MWarp, 16>, // H0 (M): 2 M-warps × 16 threads = 32 M
sequence<4, 4>>, // H1 (K): 4×4 = 16 K elements
tuple<sequence<0, 1>, sequence<2, 1>>, // Ps_to_Hs: P0→(R,M), P1→(M,K)
tuple<sequence<0, 0>, sequence<0, 1>>, // Ps_in_Hs: positions
sequence<2>, // Ys_to_Hs: Y maps to K (dimension 2)
sequence<1>>{} // Ys_in_Hs: Y at position 1 in K
sequence<NWarp>, // R: REPLICATE across 2 N-warps
tuple<sequence<MWarp, 16>, // H0 (M): 2 M-warps × 16 threads = 32 M
sequence<4, 4>>, // H1 (K): 4×4 = 16 K elements
tuple<sequence<0, 1>, sequence<2, 1>>, // Ps_to_Hs: P0→(R,M), P1→(M,K)
tuple<sequence<0, 0>, sequence<0, 1>>, // Ps_in_Hs: positions
sequence<2>, // Ys_to_Hs: Y maps to K (dimension 2)
sequence<1>>{} // Ys_in_Hs: Y at position 1 in K
);
auto a_window = make_tile_window(
a_tensor,
make_tuple(number<kM>{}, number<kK>{}),
{0, 0},
a_distribution
);
a_tensor, make_tuple(number<kM>{}, number<kK>{}), {0, 0}, a_distribution);
const auto a_tile = load_tile(a_window);
const auto a_tile = load_tile(a_window);
const auto& thread_buffer = a_tile.get_thread_buffer();
// Calculate matrix coordinates using make_tensor_coordinate
// This shows which matrix positions each thread accesses
__syncthreads();
if(tid == 0) {
if(tid == 0)
{
printf("\n=== Matrix Coverage (Tiled by Warp) ===\n");
printf("Distribution covers 32×16 matrix (MWarp×16 threads × K)\n");
printf("With NWarp=2 replication, pattern repeats\n");
printf("Showing first 16 threads of each warp:\n\n");
}
__syncthreads();
// Print warp by warp with calculated coordinates
for(int w = 0; w < 4; ++w) {
for(int w = 0; w < 4; ++w)
{
__syncthreads();
if(warp_id == w && lane_id == 0) {
printf("\n--- Warp %d (M-warp %d, N-warp %d) ---\n",
w, w/NWarp, w%NWarp);
if(warp_id == w && lane_id == 0)
{
printf("\n--- Warp %d (M-warp %d, N-warp %d) ---\n", w, w / NWarp, w % NWarp);
}
__syncthreads();
// Print lanes sequentially within each warp
for(int lane = 0; lane < 16; ++lane) {
for(int lane = 0; lane < 16; ++lane)
{
__syncthreads();
if(warp_id == w && lane_id == lane) {
if(warp_id == w && lane_id == lane)
{
printf("W%d L%02d: ", w, lane);
// For each Y element, just print the loaded value
// The distribution handles the coordinate mapping internally
for(int y_idx = 0; y_idx < thread_buffer.size(); ++y_idx) {
for(int y_idx = 0; y_idx < thread_buffer.size(); ++y_idx)
{
float val = static_cast<float>(thread_buffer[y_idx]);
int m = static_cast<int>(val) % 100;
int k = static_cast<int>(val) / 100;
int m = static_cast<int>(val) % 100;
int k = static_cast<int>(val) / 100;
printf("A[%2d,%2d] ", m, k);
}
printf("\n");
@@ -122,7 +115,8 @@ struct TestADistributionKernel
__syncthreads();
if(tid == 0) {
if(tid == 0)
{
printf("\n=== Expected Pattern with NWarp Replication ===\n");
printf("sequence<NWarp> replicates across N-warp dimension:\n");
printf("Warp 0 (M-warp 0, N-warp 0): Loads some M-rows, K[0-15]\n");
@@ -130,13 +124,16 @@ struct TestADistributionKernel
printf("Warp 2 (M-warp 1, N-warp 0): Loads SAME as Warp 0 (NWarp replication!)\n");
printf("Warp 3 (M-warp 1, N-warp 1): Loads SAME as Warp 1 (NWarp replication!)\n");
printf("\nReplication pairs:\n");
printf(" Warps 0 & 2 should be identical (same N-warp 0, replicated across M-warps)\n");
printf(" Warps 1 & 3 should be identical (same N-warp 1, replicated across M-warps)\n");
printf(
" Warps 0 & 2 should be identical (same N-warp 0, replicated across M-warps)\n");
printf(
" Warps 1 & 3 should be identical (same N-warp 1, replicated across M-warps)\n");
printf(" Warps 0 & 1 should be DIFFERENT (different N-warps)\n");
}
// Store for verification
for(int i = 0; i < thread_buffer.size(); ++i) {
for(int i = 0; i < thread_buffer.size(); ++i)
{
debug_output[tid * 4 + i] = thread_buffer[i];
}
}
@@ -148,8 +145,8 @@ int main()
std::cout << "Test A Distribution with NWarp Replication\n";
std::cout << "==================================================\n\n";
constexpr index_t M = 16;
constexpr index_t K = 16;
constexpr index_t M = 16;
constexpr index_t K = 16;
constexpr index_t lda = M;
using DataType = half_t;
@@ -159,8 +156,10 @@ int main()
std::vector<DataType> h_debug(256 * 4, -1);
// Initialize A[m,k] = m + k*100
for(index_t k = 0; k < K; ++k) {
for(index_t m = 0; m < M; ++m) {
for(index_t k = 0; k < K; ++k)
{
for(index_t m = 0; m < M; ++m)
{
h_a[m + k * lda] = static_cast<DataType>(m + k * 100);
}
}
@@ -173,14 +172,13 @@ int main()
stream_config stream;
launch_kernel(stream,
make_kernel<256>(
TestADistributionKernel<DataType>{},
dim3(1),
dim3(256),
0,
static_cast<const DataType*>(d_a.GetDeviceBuffer()),
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
lda));
make_kernel<256>(TestADistributionKernel<DataType>{},
dim3(1),
dim3(256),
0,
static_cast<const DataType*>(d_a.GetDeviceBuffer()),
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
lda));
hip_check_error(hipDeviceSynchronize());
@@ -188,40 +186,52 @@ int main()
// Verify NWarp replication: warps 0&2 identical, warps 1&3 identical
bool passed = true;
// Check warps 0 and 2 (same N-warp 0, replicated across M-warps)
for(int lane = 0; lane < 64; ++lane) {
for(int i = 0; i < 4; ++i) {
for(int lane = 0; lane < 64; ++lane)
{
for(int i = 0; i < 4; ++i)
{
float warp0_val = h_debug[lane * 4 + i];
float warp2_val = h_debug[(128 + lane) * 4 + i];
if(std::abs(warp0_val - warp2_val) > 0.01f) {
std::cout << "ERROR: Warp 0 and Warp 2 differ at lane " << lane << " element " << i << "\n";
if(std::abs(warp0_val - warp2_val) > 0.01f)
{
std::cout << "ERROR: Warp 0 and Warp 2 differ at lane " << lane << " element " << i
<< "\n";
std::cout << " Warp 0: " << warp0_val << ", Warp 2: " << warp2_val << "\n";
passed = false;
break;
}
}
if(!passed) break;
if(!passed)
break;
}
// Check warps 1 and 3 (same N-warp 1, replicated across M-warps)
if(passed) {
for(int lane = 0; lane < 64; ++lane) {
for(int i = 0; i < 4; ++i) {
if(passed)
{
for(int lane = 0; lane < 64; ++lane)
{
for(int i = 0; i < 4; ++i)
{
float warp1_val = h_debug[(64 + lane) * 4 + i];
float warp3_val = h_debug[(192 + lane) * 4 + i];
if(std::abs(warp1_val - warp3_val) > 0.01f) {
std::cout << "ERROR: Warp 1 and Warp 3 differ at lane " << lane << " element " << i << "\n";
if(std::abs(warp1_val - warp3_val) > 0.01f)
{
std::cout << "ERROR: Warp 1 and Warp 3 differ at lane " << lane << " element "
<< i << "\n";
std::cout << " Warp 1: " << warp1_val << ", Warp 3: " << warp3_val << "\n";
passed = false;
break;
}
}
if(!passed) break;
if(!passed)
break;
}
}
if(passed) {
if(passed)
{
std::cout << "\n✓ NWarp Replication verified:\n";
std::cout << " Warps 0 & 2 load identical data (N-warp 0, replicated across M-warps)\n";
std::cout << " Warps 1 & 3 load identical data (N-warp 1, replicated across M-warps)\n";

View File

@@ -18,35 +18,27 @@
using namespace ck_tile;
template<typename DataType>
template <typename DataType>
struct TestADistributionKernel
{
static constexpr index_t kBlockSize = 256; // 4 warps
static constexpr index_t MWarp = 2;
static constexpr index_t NWarp = 2;
static constexpr index_t kM = 16;
static constexpr index_t kK = 16;
static constexpr index_t kBlockSize = 256; // 4 warps
static constexpr index_t MWarp = 2;
static constexpr index_t NWarp = 2;
static constexpr index_t kM = 16;
static constexpr index_t kK = 16;
CK_TILE_DEVICE void operator()(const DataType* a,
DataType* debug_output,
index_t lda) const
CK_TILE_DEVICE void operator()(const DataType* a, DataType* debug_output, index_t lda) const
{
if(get_block_id() != 0)
return;
const index_t tid = threadIdx.x;
const index_t tid = threadIdx.x;
const index_t warp_id = tid / 64;
const index_t lane_id = tid % 64;
// Create tensor view for A (column-major)
const auto a_tensor = make_naive_tensor_view<address_space_enum::global>(
a,
make_tuple(kM, kK),
make_tuple(1, lda),
number<1>{},
number<1>{}
);
a, make_tuple(kM, kK), make_tuple(1, lda), number<1>{}, number<1>{});
// A distribution using EMBED API (like 02_gemm)
// Separate block-level and warp-level distributions
@@ -62,84 +54,84 @@ struct TestADistributionKernel
// sequence<1>>{} // Ys_in_Hs: Y at position 1 in K
// );
// Step 1: Warp-level distribution (64 threads within one warp)
constexpr auto a_warp_dstr_encode = tile_distribution_encoding<
sequence<>, // No replication at warp level
tuple<sequence<16>, // H0 (M): 16 M positions
sequence<4, 4>>, // H1 (K): 4×4 = 16 K elements
tuple<sequence<2, 1>>, // Ps_to_Hs: 2D P-space (64 threads)
tuple<sequence<0, 0>>, // Ps_in_Hs
sequence<2>, // Ys_to_Hs: Y maps to K
sequence<1>>{}; // Ys_in_Hs
constexpr auto a_warp_dstr_encode =
tile_distribution_encoding<sequence<>, // No replication at warp level
tuple<sequence<16>, // H0 (M): 16 M positions
sequence<4, 4>>, // H1 (K): 4×4 = 16 K elements
tuple<sequence<2, 1>>, // Ps_to_Hs: 2D P-space (64 threads)
tuple<sequence<0, 0>>, // Ps_in_Hs
sequence<2>, // Ys_to_Hs: Y maps to K
sequence<1>>{}; // Ys_in_Hs
// Step 2: Block-level outer distribution (warp organization)
// Must have same NDimX as inner (2 dimensions: M and K)
constexpr auto a_block_outer_dstr_encode = tile_distribution_encoding<
sequence<NWarp>, // R: Replicate across N-warps
tuple<sequence<MWarp>, sequence<>>, // H: MWarp in M-dim, 1 in K-dim
tuple<sequence<0, 1>>, // Ps_to_Hs
tuple<sequence<0, 0>>, // Ps_in_Hs
sequence<>, // Ys_to_Hs: Y maps to both M and K
sequence<>>{}; // Ys_in_Hs
constexpr auto a_block_outer_dstr_encode =
tile_distribution_encoding<sequence<NWarp>, // R: Replicate across N-warps
tuple<sequence<MWarp>, sequence<>>, // H: MWarp in M-dim, 1
// in K-dim
tuple<sequence<0, 1>>, // Ps_to_Hs
tuple<sequence<0, 0>>, // Ps_in_Hs
sequence<>, // Ys_to_Hs: Y maps to both M and K
sequence<>>{}; // Ys_in_Hs
// Step 3: Embed warp-level into block-level
constexpr auto a_block_dstr_encode =
detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encode, a_warp_dstr_encode);
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encode, a_warp_dstr_encode);
// Step 4: Create final distribution
constexpr auto a_distribution = make_static_tile_distribution(a_block_dstr_encode);
auto a_window = make_tile_window(
a_tensor,
make_tuple(number<kM>{}, number<kK>{}),
{0, 0},
a_distribution
);
a_tensor, make_tuple(number<kM>{}, number<kK>{}), {0, 0}, a_distribution);
const auto a_tile = load_tile(a_window);
const auto a_tile = load_tile(a_window);
const auto& thread_buffer = a_tile.get_thread_buffer();
// Calculate matrix coordinates using make_tensor_coordinate
// This shows which matrix positions each thread accesses
__syncthreads();
if(tid == 0) {
if(tid == 0)
{
printf("\n=== Matrix Coverage (Tiled by Warp) ===\n");
printf("Distribution covers 32×16 matrix (MWarp×16 threads × K)\n");
printf("With NWarp=2 replication, pattern repeats\n");
printf("Showing first 16 threads of each warp:\n\n");
}
__syncthreads();
// Print warp by warp with calculated coordinates
for(int w = 0; w < 4; ++w) {
for(int w = 0; w < 4; ++w)
{
__syncthreads();
if(warp_id == w && lane_id == 0) {
printf("\n--- Warp %d (M-warp %d, N-warp %d) ---\n",
w, w/NWarp, w%NWarp);
if(warp_id == w && lane_id == 0)
{
printf("\n--- Warp %d (M-warp %d, N-warp %d) ---\n", w, w / NWarp, w % NWarp);
}
__syncthreads();
// Print lanes sequentially within each warp
for(int lane = 0; lane < 16; ++lane) {
for(int lane = 0; lane < 16; ++lane)
{
__syncthreads();
if(warp_id == w && lane_id == lane) {
if(warp_id == w && lane_id == lane)
{
printf("W%d L%02d: ", w, lane);
// For each Y element, just print the loaded value
// The distribution handles the coordinate mapping internally
for(int y_idx = 0; y_idx < thread_buffer.size(); ++y_idx) {
for(int y_idx = 0; y_idx < thread_buffer.size(); ++y_idx)
{
float val = static_cast<float>(thread_buffer[y_idx]);
int m = static_cast<int>(val) % 100;
int k = static_cast<int>(val) / 100;
int m = static_cast<int>(val) % 100;
int k = static_cast<int>(val) / 100;
printf("A[%2d,%2d] ", m, k);
}
printf("\n");
@@ -149,7 +141,8 @@ struct TestADistributionKernel
__syncthreads();
if(tid == 0) {
if(tid == 0)
{
printf("\n=== Expected Pattern with NWarp Replication ===\n");
printf("sequence<NWarp> replicates across N-warp dimension:\n");
printf("Warp 0 (M-warp 0, N-warp 0): Loads some M-rows, K[0-15]\n");
@@ -157,13 +150,16 @@ struct TestADistributionKernel
printf("Warp 2 (M-warp 1, N-warp 0): Loads SAME as Warp 0 (NWarp replication!)\n");
printf("Warp 3 (M-warp 1, N-warp 1): Loads SAME as Warp 1 (NWarp replication!)\n");
printf("\nReplication pairs:\n");
printf(" Warps 0 & 2 should be identical (same N-warp 0, replicated across M-warps)\n");
printf(" Warps 1 & 3 should be identical (same N-warp 1, replicated across M-warps)\n");
printf(
" Warps 0 & 2 should be identical (same N-warp 0, replicated across M-warps)\n");
printf(
" Warps 1 & 3 should be identical (same N-warp 1, replicated across M-warps)\n");
printf(" Warps 0 & 1 should be DIFFERENT (different N-warps)\n");
}
// Store for verification
for(int i = 0; i < thread_buffer.size(); ++i) {
for(int i = 0; i < thread_buffer.size(); ++i)
{
debug_output[tid * 4 + i] = thread_buffer[i];
}
}
@@ -174,10 +170,11 @@ int main()
std::cout << "\n==================================================\n";
std::cout << "Test A Distribution using EMBED API\n";
std::cout << "==================================================\n\n";
std::cout << "Separates block-level (warp organization) from warp-level (thread organization)\n\n";
std::cout
<< "Separates block-level (warp organization) from warp-level (thread organization)\n\n";
constexpr index_t M = 16;
constexpr index_t K = 16;
constexpr index_t M = 16;
constexpr index_t K = 16;
constexpr index_t lda = M;
using DataType = half_t;
@@ -187,8 +184,10 @@ int main()
std::vector<DataType> h_debug(256 * 4, -1);
// Initialize A[m,k] = m + k*100
for(index_t k = 0; k < K; ++k) {
for(index_t m = 0; m < M; ++m) {
for(index_t k = 0; k < K; ++k)
{
for(index_t m = 0; m < M; ++m)
{
h_a[m + k * lda] = static_cast<DataType>(m + k * 100);
}
}
@@ -201,14 +200,13 @@ int main()
stream_config stream;
launch_kernel(stream,
make_kernel<256>(
TestADistributionKernel<DataType>{},
dim3(1),
dim3(256),
0,
static_cast<const DataType*>(d_a.GetDeviceBuffer()),
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
lda));
make_kernel<256>(TestADistributionKernel<DataType>{},
dim3(1),
dim3(256),
0,
static_cast<const DataType*>(d_a.GetDeviceBuffer()),
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
lda));
hip_check_error(hipDeviceSynchronize());
@@ -216,40 +214,52 @@ int main()
// Verify NWarp replication: warps 0&2 identical, warps 1&3 identical
bool passed = true;
// Check warps 0 and 2 (same N-warp 0, replicated across M-warps)
for(int lane = 0; lane < 64; ++lane) {
for(int i = 0; i < 4; ++i) {
for(int lane = 0; lane < 64; ++lane)
{
for(int i = 0; i < 4; ++i)
{
float warp0_val = h_debug[lane * 4 + i];
float warp2_val = h_debug[(128 + lane) * 4 + i];
if(std::abs(warp0_val - warp2_val) > 0.01f) {
std::cout << "ERROR: Warp 0 and Warp 2 differ at lane " << lane << " element " << i << "\n";
if(std::abs(warp0_val - warp2_val) > 0.01f)
{
std::cout << "ERROR: Warp 0 and Warp 2 differ at lane " << lane << " element " << i
<< "\n";
std::cout << " Warp 0: " << warp0_val << ", Warp 2: " << warp2_val << "\n";
passed = false;
break;
}
}
if(!passed) break;
if(!passed)
break;
}
// Check warps 1 and 3 (same N-warp 1, replicated across M-warps)
if(passed) {
for(int lane = 0; lane < 64; ++lane) {
for(int i = 0; i < 4; ++i) {
if(passed)
{
for(int lane = 0; lane < 64; ++lane)
{
for(int i = 0; i < 4; ++i)
{
float warp1_val = h_debug[(64 + lane) * 4 + i];
float warp3_val = h_debug[(192 + lane) * 4 + i];
if(std::abs(warp1_val - warp3_val) > 0.01f) {
std::cout << "ERROR: Warp 1 and Warp 3 differ at lane " << lane << " element " << i << "\n";
if(std::abs(warp1_val - warp3_val) > 0.01f)
{
std::cout << "ERROR: Warp 1 and Warp 3 differ at lane " << lane << " element "
<< i << "\n";
std::cout << " Warp 1: " << warp1_val << ", Warp 3: " << warp3_val << "\n";
passed = false;
break;
}
}
if(!passed) break;
if(!passed)
break;
}
}
if(passed) {
if(passed)
{
std::cout << "\n✓ NWarp Replication verified:\n";
std::cout << " Warps 0 & 2 load identical data (N-warp 0, replicated across M-warps)\n";
std::cout << " Warps 1 & 3 load identical data (N-warp 1, replicated across M-warps)\n";

View File

@@ -18,34 +18,27 @@
using namespace ck_tile;
template<typename DataType>
template <typename DataType>
struct TestBDistributionKernel
{
static constexpr index_t kBlockSize = 256; // 4 warps
static constexpr index_t MWarp = 2;
static constexpr index_t NWarp = 2;
static constexpr index_t kK = 16;
static constexpr index_t kN = 16;
static constexpr index_t kBlockSize = 256; // 4 warps
static constexpr index_t MWarp = 2;
static constexpr index_t NWarp = 2;
static constexpr index_t kK = 16;
static constexpr index_t kN = 16;
CK_TILE_DEVICE void operator()(const DataType* b,
DataType* debug_output,
index_t ldb) const
CK_TILE_DEVICE void operator()(const DataType* b, DataType* debug_output, index_t ldb) const
{
if(get_block_id() != 0)
return;
const index_t tid = threadIdx.x;
const index_t tid = threadIdx.x;
const index_t warp_id = tid / 64;
const index_t lane_id = tid % 64;
// Create tensor view for B (row-major)
const auto b_tensor = make_naive_tensor_view<address_space_enum::global>(
b,
make_tuple(kK, kN),
make_tuple(ldb, 1),
number<4>{},
number<1>{}
);
b, make_tuple(kK, kN), make_tuple(ldb, 1), number<4>{}, number<1>{});
// B distribution WITH MWarp replication
// R dimension (index 0) has MWarp=2 replicas
@@ -65,84 +58,87 @@ struct TestBDistributionKernel
// sequence<1>>{} // Ys_in_Hs: Y at position 1 in K
// );
// constexpr auto b_distribution = make_static_tile_distribution(
// tile_distribution_encoding<
// sequence<>, // No replication
// tuple<sequence<4, 4>, // H0 (K): 4 groups of 4 consecutive K values
// tuple<sequence<4, 4>, // H0 (K): 4 groups of 4 consecutive K
// values
// sequence<16>>, // H1 (N): 16 N positions
// tuple<sequence<1, 2>>, // P-dims map to H-dims (P0->H1, P1->H0)
// tuple<sequence<0, 0>>, // P positions in H-dims
// sequence<1>, // Y maps to K dimension (H0)
// sequence<1>>{} // Y at position 1 in H0 (the second 4 in sequence<4,4>)
// sequence<1>>{} // Y at position 1 in H0 (the second 4 in
// sequence<4,4>)
// );
// constexpr auto b_block_outer_dstr_encoding =
// tile_distribution_encoding<sequence<MWarp>,
// tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
// tuple<sequence<0, 1>>,
// tuple<sequence<NIterPerWarp, NWarp>,
// sequence<KIterPerWarp>>, tuple<sequence<0, 1>>,
// tuple<sequence<0, 1>>,
// sequence<1, 2>,
// sequence<0, 0>>{};
// The key: Ps_to_Hs must include dimension 0 (the R dimension)!
constexpr auto b_distribution = make_static_tile_distribution(
tile_distribution_encoding<
sequence<MWarp>, // R: dimension 0, REPLICATE across 2 M-warps
tuple<sequence<4, 4>, // H: dimension 1 (K): 4×4 = 16 K elements
sequence<2, 16>>, // H: dimension 2 (N): 16 N positions
tuple<sequence<2, 0>, sequence<1, 2>>, // Ps_to_Hs: P0→R(dim 0), P1→K(dim 1), P2→N(dim 2)
tuple<sequence<0, 0>, sequence<0, 1>>, // Ps_in_Hs: positions
sequence<1>, // Ys_to_Hs: Y maps to K (dimension 1)
sequence<1>>{} // Ys_in_Hs: Y at position 1 in K
sequence<MWarp>, // R: dimension 0, REPLICATE across 2 M-warps
tuple<sequence<4, 4>, // H: dimension 1 (K): 4×4 = 16 K elements
sequence<2, 16>>, // H: dimension 2 (N): 16 N positions
tuple<sequence<2, 0>, sequence<1, 2>>, // Ps_to_Hs: P0→R(dim 0), P1→K(dim 1),
// P2→N(dim 2)
tuple<sequence<0, 0>, sequence<0, 1>>, // Ps_in_Hs: positions
sequence<1>, // Ys_to_Hs: Y maps to K (dimension 1)
sequence<1>>{} // Ys_in_Hs: Y at position 1 in K
);
auto b_window = make_tile_window(
b_tensor,
make_tuple(number<kK>{}, number<kN>{}),
{0, 0},
b_distribution
);
b_tensor, make_tuple(number<kK>{}, number<kN>{}), {0, 0}, b_distribution);
const auto b_tile = load_tile(b_window);
const auto b_tile = load_tile(b_window);
const auto& thread_buffer = b_tile.get_thread_buffer();
// Sequential printing with synchronizations (copied from test_a)
__syncthreads();
if(tid == 0) {
if(tid == 0)
{
printf("\n=== Matrix Coverage (Tiled by Warp) ===\n");
printf("Distribution covers K×32 matrix (K × NWarp×16 threads)\n");
printf("With MWarp=2 replication, pattern repeats\n");
printf("Showing first 16 threads of each warp:\n\n");
}
__syncthreads();
// Print warp by warp with calculated coordinates
for(int w = 0; w < 4; ++w) {
for(int w = 0; w < 4; ++w)
{
__syncthreads();
if(warp_id == w && lane_id == 0) {
printf("\n--- Warp %d (M-warp %d, N-warp %d) ---\n",
w, w/NWarp, w%NWarp);
if(warp_id == w && lane_id == 0)
{
printf("\n--- Warp %d (M-warp %d, N-warp %d) ---\n", w, w / NWarp, w % NWarp);
}
__syncthreads();
// Print lanes sequentially within each warp
for(int lane = 0; lane < 64; ++lane) {
for(int lane = 0; lane < 64; ++lane)
{
__syncthreads();
if(warp_id == w && lane_id == lane) {
if(warp_id == w && lane_id == lane)
{
printf("W%d L%02d: ", w, lane);
// Print loaded values
for(int y_idx = 0; y_idx < thread_buffer.size(); ++y_idx) {
for(int y_idx = 0; y_idx < thread_buffer.size(); ++y_idx)
{
float val = static_cast<float>(thread_buffer[y_idx]);
int k = static_cast<int>(val) % 100;
int n = static_cast<int>(val) / 100;
int k = static_cast<int>(val) % 100;
int n = static_cast<int>(val) / 100;
printf("B[%2d,%2d] ", k, n);
}
printf("\n");
@@ -152,7 +148,8 @@ struct TestBDistributionKernel
__syncthreads();
if(tid == 0) {
if(tid == 0)
{
printf("\n=== Observed Pattern ===\n");
printf("Warps 0 & 1 are identical\n");
printf("Warps 2 & 3 are identical\n");
@@ -160,7 +157,8 @@ struct TestBDistributionKernel
}
// Store for verification
for(int i = 0; i < thread_buffer.size(); ++i) {
for(int i = 0; i < thread_buffer.size(); ++i)
{
debug_output[tid * 4 + i] = thread_buffer[i];
}
}
@@ -172,8 +170,8 @@ int main()
std::cout << "Test B Distribution with MWarp Replication\n";
std::cout << "==================================================\n\n";
constexpr index_t K = 16;
constexpr index_t N = 32;
constexpr index_t K = 16;
constexpr index_t N = 32;
constexpr index_t ldb = N;
using DataType = half_t;
@@ -183,8 +181,10 @@ int main()
std::vector<DataType> h_debug(256 * 4, -1);
// Initialize B[k,n] = k + n*100 (row-major)
for(index_t k = 0; k < K; ++k) {
for(index_t n = 0; n < N; ++n) {
for(index_t k = 0; k < K; ++k)
{
for(index_t n = 0; n < N; ++n)
{
h_b[k * ldb + n] = static_cast<DataType>(k + n * 100);
}
}
@@ -197,14 +197,13 @@ int main()
stream_config stream;
launch_kernel(stream,
make_kernel<256>(
TestBDistributionKernel<DataType>{},
dim3(1),
dim3(256),
0,
static_cast<const DataType*>(d_b.GetDeviceBuffer()),
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
ldb));
make_kernel<256>(TestBDistributionKernel<DataType>{},
dim3(1),
dim3(256),
0,
static_cast<const DataType*>(d_b.GetDeviceBuffer()),
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
ldb));
hip_check_error(hipDeviceSynchronize());
@@ -212,40 +211,52 @@ int main()
// Verify: Based on your observation, warps 0&1 identical, warps 2&3 identical
bool passed = true;
// Check warps 0 and 1
for(int lane = 0; lane < 64; ++lane) {
for(int i = 0; i < 4; ++i) {
for(int lane = 0; lane < 64; ++lane)
{
for(int i = 0; i < 4; ++i)
{
float warp0_val = h_debug[lane * 4 + i];
float warp1_val = h_debug[(64 + lane) * 4 + i];
if(std::abs(warp0_val - warp1_val) > 0.01f) {
std::cout << "ERROR: Warp 0 and Warp 1 differ at lane " << lane << " element " << i << "\n";
if(std::abs(warp0_val - warp1_val) > 0.01f)
{
std::cout << "ERROR: Warp 0 and Warp 1 differ at lane " << lane << " element " << i
<< "\n";
std::cout << " Warp 0: " << warp0_val << ", Warp 1: " << warp1_val << "\n";
passed = false;
break;
}
}
if(!passed) break;
if(!passed)
break;
}
// Check warps 2 and 3
if(passed) {
for(int lane = 0; lane < 64; ++lane) {
for(int i = 0; i < 4; ++i) {
if(passed)
{
for(int lane = 0; lane < 64; ++lane)
{
for(int i = 0; i < 4; ++i)
{
float warp2_val = h_debug[(128 + lane) * 4 + i];
float warp3_val = h_debug[(192 + lane) * 4 + i];
if(std::abs(warp2_val - warp3_val) > 0.01f) {
std::cout << "ERROR: Warp 2 and Warp 3 differ at lane " << lane << " element " << i << "\n";
if(std::abs(warp2_val - warp3_val) > 0.01f)
{
std::cout << "ERROR: Warp 2 and Warp 3 differ at lane " << lane << " element "
<< i << "\n";
std::cout << " Warp 2: " << warp2_val << ", Warp 3: " << warp3_val << "\n";
passed = false;
break;
}
}
if(!passed) break;
if(!passed)
break;
}
}
if(passed) {
if(passed)
{
std::cout << "\n✓ Replication verified:\n";
std::cout << " Warps 0 & 1 load identical data\n";
std::cout << " Warps 2 & 3 load identical data\n";

View File

@@ -18,123 +18,119 @@
using namespace ck_tile;
template<typename DataType>
template <typename DataType>
struct TestBDistributionKernel
{
static constexpr index_t kBlockSize = 256; // 4 warps
static constexpr index_t MWarp = 2;
static constexpr index_t NWarp = 2;
static constexpr index_t kK = 16;
static constexpr index_t kN = 16;
static constexpr index_t kBlockSize = 256; // 4 warps
static constexpr index_t MWarp = 2;
static constexpr index_t NWarp = 2;
static constexpr index_t kK = 16;
static constexpr index_t kN = 16;
CK_TILE_DEVICE void operator()(const DataType* b,
DataType* debug_output,
index_t ldb) const
CK_TILE_DEVICE void operator()(const DataType* b, DataType* debug_output, index_t ldb) const
{
if(get_block_id() != 0)
return;
const index_t tid = threadIdx.x;
const index_t tid = threadIdx.x;
const index_t warp_id = tid / 64;
const index_t lane_id = tid % 64;
// Create tensor view for B (row-major)
const auto b_tensor = make_naive_tensor_view<address_space_enum::global>(
b,
make_tuple(kK, kN),
make_tuple(ldb, 1),
number<4>{},
number<1>{}
);
b, make_tuple(kK, kN), make_tuple(ldb, 1), number<4>{}, number<1>{});
// B distribution using EMBED API (like 02_gemm)
// Separate block-level and warp-level distributions
// Step 1: Warp-level distribution (64 threads within one warp)
constexpr auto b_warp_dstr_encode = tile_distribution_encoding<
sequence<>, // No replication at warp level
tuple<sequence<4, 4>, // H0 (K): 4×4 = 16 K elements
sequence<16>>, // H1 (N): 16 N positions
tuple<sequence<1, 2>>, // Ps_to_Hs: 1 sequence with 2 values (2D P-space)
tuple<sequence<0, 0>>, // Ps_in_Hs
sequence<1>, // Ys_to_Hs: Y maps to K
sequence<1>>{}; // Ys_in_Hs
constexpr auto b_warp_dstr_encode =
tile_distribution_encoding<sequence<>, // No replication at warp level
tuple<sequence<4, 4>, // H0 (K): 4×4 = 16 K elements
sequence<16>>, // H1 (N): 16 N positions
tuple<sequence<1, 2>>, // Ps_to_Hs: 1 sequence with 2 values
// (2D P-space)
tuple<sequence<0, 0>>, // Ps_in_Hs
sequence<1>, // Ys_to_Hs: Y maps to K
sequence<1>>{}; // Ys_in_Hs
// Step 2: Block-level outer distribution (warp organization)
// Must have same NDimX as inner (2 dimensions: K and N)
constexpr auto b_block_outer_dstr_encode = tile_distribution_encoding<
sequence<MWarp>, // R: Replicate across M-warps
tuple<sequence<>, sequence<NWarp>>, // H: NWarp in N-dim, 1 in K-dim
tuple<sequence<2, 0>>, // Ps_to_Hs
tuple<sequence<0, 0>>, // Ps_in_Hs
sequence<>, // Ys_to_Hs: Y maps to both K and N
sequence<>>{}; // Ys_in_Hs
constexpr auto b_block_outer_dstr_encode =
tile_distribution_encoding<sequence<MWarp>, // R: Replicate across M-warps
tuple<sequence<>, sequence<NWarp>>, // H: NWarp in N-dim, 1
// in K-dim
tuple<sequence<2, 0>>, // Ps_to_Hs
tuple<sequence<0, 0>>, // Ps_in_Hs
sequence<>, // Ys_to_Hs: Y maps to both K and N
sequence<>>{}; // Ys_in_Hs
// constexpr auto b_distribution = make_static_tile_distribution(
// tile_distribution_encoding<
// sequence<MWarp>, // R: dimension 0, REPLICATE across 2 M-warps
// tuple<sequence<4, 4>, // H: dimension 1 (K): 4×4 = 16 K elements
// sequence<MWarp>, // R: dimension 0, REPLICATE across 2
// M-warps tuple<sequence<4, 4>, // H: dimension 1 (K): 4×4 = 16 K
// elements
// sequence<2, 16>>, // H: dimension 2 (N): 16 N positions
// tuple<sequence<2, 0>, sequence<1, 2>>, // Ps_to_Hs: P0→R(dim 0), P1→K(dim 1), P2→N(dim 2)
// tuple<sequence<0, 0>, sequence<0, 1>>, // Ps_in_Hs: positions
// tuple<sequence<2, 0>, sequence<1, 2>>, // Ps_to_Hs: P0→R(dim 0), P1→K(dim 1),
// P2→N(dim 2) tuple<sequence<0, 0>, sequence<0, 1>>, // Ps_in_Hs: positions
// sequence<1>, // Ys_to_Hs: Y maps to K (dimension 1)
// sequence<1>>{} // Ys_in_Hs: Y at position 1 in K
// );
// Step 3: Embed warp-level into block-level
constexpr auto b_block_dstr_encode =
detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encode, b_warp_dstr_encode);
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encode, b_warp_dstr_encode);
// Step 4: Create final distribution
constexpr auto b_distribution = make_static_tile_distribution(b_block_dstr_encode);
auto b_window = make_tile_window(
b_tensor,
make_tuple(number<kK>{}, number<kN>{}),
{0, 0},
b_distribution
);
b_tensor, make_tuple(number<kK>{}, number<kN>{}), {0, 0}, b_distribution);
const auto b_tile = load_tile(b_window);
const auto b_tile = load_tile(b_window);
const auto& thread_buffer = b_tile.get_thread_buffer();
// Sequential printing with synchronizations (copied from test_a)
__syncthreads();
if(tid == 0) {
if(tid == 0)
{
printf("\n=== Matrix Coverage (Tiled by Warp) ===\n");
printf("Distribution covers K×32 matrix (K × NWarp×16 threads)\n");
printf("With MWarp=2 replication, pattern repeats\n");
printf("Showing first 16 threads of each warp:\n\n");
}
__syncthreads();
// Print warp by warp with calculated coordinates
for(int w = 0; w < 4; ++w) {
for(int w = 0; w < 4; ++w)
{
__syncthreads();
if(warp_id == w && lane_id == 0) {
printf("\n--- Warp %d (M-warp %d, N-warp %d) ---\n",
w, w/NWarp, w%NWarp);
if(warp_id == w && lane_id == 0)
{
printf("\n--- Warp %d (M-warp %d, N-warp %d) ---\n", w, w / NWarp, w % NWarp);
}
__syncthreads();
// Print lanes sequentially within each warp
for(int lane = 0; lane < 16; ++lane) {
for(int lane = 0; lane < 16; ++lane)
{
__syncthreads();
if(warp_id == w && lane_id == lane) {
if(warp_id == w && lane_id == lane)
{
printf("W%d L%02d: ", w, lane);
// Print loaded values
for(int y_idx = 0; y_idx < thread_buffer.size(); ++y_idx) {
for(int y_idx = 0; y_idx < thread_buffer.size(); ++y_idx)
{
float val = static_cast<float>(thread_buffer[y_idx]);
int k = static_cast<int>(val) % 100;
int n = static_cast<int>(val) / 100;
int k = static_cast<int>(val) % 100;
int n = static_cast<int>(val) / 100;
printf("B[%2d,%2d] ", k, n);
}
printf("\n");
@@ -144,7 +140,8 @@ struct TestBDistributionKernel
__syncthreads();
if(tid == 0) {
if(tid == 0)
{
printf("\n=== Observed Pattern ===\n");
printf("Warps 0 & 1 are identical\n");
printf("Warps 2 & 3 are identical\n");
@@ -152,7 +149,8 @@ struct TestBDistributionKernel
}
// Store for verification
for(int i = 0; i < thread_buffer.size(); ++i) {
for(int i = 0; i < thread_buffer.size(); ++i)
{
debug_output[tid * 4 + i] = thread_buffer[i];
}
}
@@ -163,10 +161,11 @@ int main()
std::cout << "\n==================================================\n";
std::cout << "Test B Distribution using EMBED API\n";
std::cout << "==================================================\n\n";
std::cout << "Separates block-level (warp organization) from warp-level (thread organization)\n\n";
std::cout
<< "Separates block-level (warp organization) from warp-level (thread organization)\n\n";
constexpr index_t K = 32;
constexpr index_t N = 32;
constexpr index_t K = 32;
constexpr index_t N = 32;
constexpr index_t ldb = N;
using DataType = half_t;
@@ -176,8 +175,10 @@ int main()
std::vector<DataType> h_debug(256 * 4, -1);
// Initialize B[k,n] = k + n*100 (row-major)
for(index_t k = 0; k < K; ++k) {
for(index_t n = 0; n < N; ++n) {
for(index_t k = 0; k < K; ++k)
{
for(index_t n = 0; n < N; ++n)
{
h_b[k * ldb + n] = static_cast<DataType>(k + n * 100);
}
}
@@ -190,14 +191,13 @@ int main()
stream_config stream;
launch_kernel(stream,
make_kernel<256>(
TestBDistributionKernel<DataType>{},
dim3(1),
dim3(256),
0,
static_cast<const DataType*>(d_b.GetDeviceBuffer()),
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
ldb));
make_kernel<256>(TestBDistributionKernel<DataType>{},
dim3(1),
dim3(256),
0,
static_cast<const DataType*>(d_b.GetDeviceBuffer()),
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
ldb));
hip_check_error(hipDeviceSynchronize());
@@ -205,40 +205,52 @@ int main()
// Verify: Based on your observation, warps 0&1 identical, warps 2&3 identical
bool passed = true;
// Check warps 0 and 1
for(int lane = 0; lane < 64; ++lane) {
for(int i = 0; i < 4; ++i) {
for(int lane = 0; lane < 64; ++lane)
{
for(int i = 0; i < 4; ++i)
{
float warp0_val = h_debug[lane * 4 + i];
float warp1_val = h_debug[(64 + lane) * 4 + i];
if(std::abs(warp0_val - warp1_val) > 0.01f) {
std::cout << "ERROR: Warp 0 and Warp 1 differ at lane " << lane << " element " << i << "\n";
if(std::abs(warp0_val - warp1_val) > 0.01f)
{
std::cout << "ERROR: Warp 0 and Warp 1 differ at lane " << lane << " element " << i
<< "\n";
std::cout << " Warp 0: " << warp0_val << ", Warp 1: " << warp1_val << "\n";
passed = false;
break;
}
}
if(!passed) break;
if(!passed)
break;
}
// Check warps 2 and 3
if(passed) {
for(int lane = 0; lane < 64; ++lane) {
for(int i = 0; i < 4; ++i) {
if(passed)
{
for(int lane = 0; lane < 64; ++lane)
{
for(int i = 0; i < 4; ++i)
{
float warp2_val = h_debug[(128 + lane) * 4 + i];
float warp3_val = h_debug[(192 + lane) * 4 + i];
if(std::abs(warp2_val - warp3_val) > 0.01f) {
std::cout << "ERROR: Warp 2 and Warp 3 differ at lane " << lane << " element " << i << "\n";
if(std::abs(warp2_val - warp3_val) > 0.01f)
{
std::cout << "ERROR: Warp 2 and Warp 3 differ at lane " << lane << " element "
<< i << "\n";
std::cout << " Warp 2: " << warp2_val << ", Warp 3: " << warp3_val << "\n";
passed = false;
break;
}
}
if(!passed) break;
if(!passed)
break;
}
}
if(passed) {
if(passed)
{
std::cout << "\n✓ Replication verified:\n";
std::cout << " Warps 0 & 1 load identical data\n";
std::cout << " Warps 2 & 3 load identical data\n";

View File

@@ -28,18 +28,18 @@
using namespace ck_tile;
// Tile Sweeping HGEMM kernel with multiple warps
template<typename ADataType, typename BDataType, typename CDataType, typename AccDataType>
template <typename ADataType, typename BDataType, typename CDataType, typename AccDataType>
struct TileSweepingHgemmKernel
{
static constexpr index_t kWaveSize = 64; // AMD wave size
static constexpr index_t kWarpM = 16; // MFMA M dimension per warp
static constexpr index_t kWarpN = 16; // MFMA N dimension per warp
static constexpr index_t kWarpK = 16; // MFMA K dimension per instruction
static constexpr index_t kWaveSize = 64; // AMD wave size
static constexpr index_t kWarpM = 16; // MFMA M dimension per warp
static constexpr index_t kWarpN = 16; // MFMA N dimension per warp
static constexpr index_t kWarpK = 16; // MFMA K dimension per instruction
// Warp configuration: 2×2 warps per block
static constexpr index_t MWarp = 2; // 2 warps in M dimension
static constexpr index_t NWarp = 2; // 2 warps in N dimension
static constexpr index_t kBlockSize = MWarp * NWarp * kWaveSize; // 256 threads
static constexpr index_t MWarp = 2; // 2 warps in M dimension
static constexpr index_t NWarp = 2; // 2 warps in N dimension
static constexpr index_t kBlockSize = MWarp * NWarp * kWaveSize; // 256 threads
// No iterations - each warp computes exactly one 16×16 output tile
@@ -53,24 +53,24 @@ struct TileSweepingHgemmKernel
index_t M,
index_t N,
index_t K,
index_t lda, // Leading dimension of A (column-major)
index_t ldb, // Leading dimension of B (row-major)
index_t ldc, // Leading dimension of C (column-major)
index_t ldd, // Leading dimension of D (column-major)
index_t lda, // Leading dimension of A (column-major)
index_t ldb, // Leading dimension of B (row-major)
index_t ldc, // Leading dimension of C (column-major)
index_t ldd, // Leading dimension of D (column-major)
AccDataType alpha,
AccDataType beta) const
{
// Calculate which warp this thread belongs to within the block
const index_t warp_id = get_warp_id();
const index_t iMWarp = warp_id / NWarp; // M-warp index (0 or 1)
const index_t iNWarp = warp_id % NWarp; // N-warp index (0 or 1)
const index_t warp_id = get_warp_id();
const index_t iMWarp = warp_id / NWarp; // M-warp index (0 or 1)
const index_t iNWarp = warp_id % NWarp; // N-warp index (0 or 1)
const index_t block_id = get_block_id();
// Convert linear block_id to 2D grid coordinates
const index_t num_blocks_n = N / (NWarp * kWarpN); // Number of blocks in N dimension
const index_t block_m = block_id / num_blocks_n; // M-block index
const index_t block_n = block_id % num_blocks_n; // N-block index
const index_t num_blocks_n = N / (NWarp * kWarpN); // Number of blocks in N dimension
const index_t block_m = block_id / num_blocks_n; // M-block index
const index_t block_n = block_id % num_blocks_n; // N-block index
// printf("Block %d (grid [%d,%d]), Warp %d (M-warp %d, N-warp %d)\n",
// block_id, block_m, block_n, warp_id, iMWarp, iNWarp);
@@ -86,96 +86,91 @@ struct TileSweepingHgemmKernel
// A is column-major: M×K with stride lda between columns
const auto a_tensor = make_naive_tensor_view<address_space_enum::global>(
a,
make_tuple(M, K), // Shape: M×K
make_tuple(1, lda), // Strides: column-major
make_tuple(M, K), // Shape: M×K
make_tuple(1, lda), // Strides: column-major
number<1>{},
number<1>{}
);
number<1>{});
// B is row-major: K×N with stride ldb between rows
const auto b_tensor = make_naive_tensor_view<address_space_enum::global>(
b,
make_tuple(K, N), // Shape: K×N
make_tuple(ldb, 1), // Strides: row-major
make_tuple(K, N), // Shape: K×N
make_tuple(ldb, 1), // Strides: row-major
number<4>{},
number<1>{}
);
number<1>{});
// C is column-major: M×N with stride ldc between columns
const auto c_tensor = make_naive_tensor_view<address_space_enum::global>(
c,
make_tuple(M, N), // Shape: M×N
make_tuple(1, ldc), // Strides: column-major
make_tuple(M, N), // Shape: M×N
make_tuple(1, ldc), // Strides: column-major
number<1>{},
number<1>{}
);
number<1>{});
// D is column-major: M×N with stride ldd between columns
auto d_tensor = make_naive_tensor_view<address_space_enum::global>(
d,
make_tuple(M, N), // Shape: M×N
make_tuple(1, ldd), // Strides: column-major
make_tuple(M, N), // Shape: M×N
make_tuple(1, ldd), // Strides: column-major
number<1>{},
number<1>{}
);
number<1>{});
// ============================================================================
// TILE DISTRIBUTIONS using EMBED API (from verified tests)
// ============================================================================
// Step 1: Warp-level distribution (64 threads within one warp)
constexpr auto a_warp_dstr_encode = tile_distribution_encoding<
sequence<>, // No replication at warp level
tuple<sequence<16>, // H0 (M): 16 M positions
sequence<4, 4>>, // H1 (K): 4×4 = 16 K elements
tuple<sequence<2, 1>>, // Ps_to_Hs: 2D P-space (64 threads)
tuple<sequence<0, 0>>, // Ps_in_Hs
sequence<2>, // Ys_to_Hs: Y maps to K
sequence<1>>{}; // Ys_in_Hs
constexpr auto a_warp_dstr_encode =
tile_distribution_encoding<sequence<>, // No replication at warp level
tuple<sequence<16>, // H0 (M): 16 M positions
sequence<4, 4>>, // H1 (K): 4×4 = 16 K elements
tuple<sequence<2, 1>>, // Ps_to_Hs: 2D P-space (64 threads)
tuple<sequence<0, 0>>, // Ps_in_Hs
sequence<2>, // Ys_to_Hs: Y maps to K
sequence<1>>{}; // Ys_in_Hs
// Step 2: Block-level outer distribution (warp organization)
// Must have same NDimX as inner (2 dimensions: M and K)
constexpr auto a_block_outer_dstr_encode = tile_distribution_encoding<
sequence<NWarp>, // R: Replicate across N-warps
tuple<sequence<MWarp>, sequence<>>, // H: MWarp in M-dim, 1 in K-dim
tuple<sequence<0, 1>>, // Ps_to_Hs
tuple<sequence<0, 0>>, // Ps_in_Hs
sequence<>, // Ys_to_Hs: Y maps to both M and K
sequence<>>{}; // Ys_in_Hs
constexpr auto a_block_outer_dstr_encode =
tile_distribution_encoding<sequence<NWarp>, // R: Replicate across N-warps
tuple<sequence<MWarp>, sequence<>>, // H: MWarp in M-dim, 1
// in K-dim
tuple<sequence<0, 1>>, // Ps_to_Hs
tuple<sequence<0, 0>>, // Ps_in_Hs
sequence<>, // Ys_to_Hs: Y maps to both M and K
sequence<>>{}; // Ys_in_Hs
// B warp-level distribution
constexpr auto b_warp_dstr_encode = tile_distribution_encoding<
sequence<>, // No replication at warp level
tuple<sequence<4, 4>, // H0 (K): 4×4 = 16 K elements
sequence<16>>, // H1 (N): 16 N positions
tuple<sequence<1, 2>>, // Ps_to_Hs: 1 sequence with 2 values (2D P-space)
tuple<sequence<0, 0>>, // Ps_in_Hs
sequence<1>, // Ys_to_Hs: Y maps to K
sequence<1>>{}; // Ys_in_Hs
constexpr auto b_warp_dstr_encode =
tile_distribution_encoding<sequence<>, // No replication at warp level
tuple<sequence<4, 4>, // H0 (K): 4×4 = 16 K elements
sequence<16>>, // H1 (N): 16 N positions
tuple<sequence<1, 2>>, // Ps_to_Hs: 1 sequence with 2 values
// (2D P-space)
tuple<sequence<0, 0>>, // Ps_in_Hs
sequence<1>, // Ys_to_Hs: Y maps to K
sequence<1>>{}; // Ys_in_Hs
// Step 2: Block-level outer distribution (warp organization)
// Must have same NDimX as inner (2 dimensions: K and N)
constexpr auto b_block_outer_dstr_encode = tile_distribution_encoding<
sequence<MWarp>, // R: Replicate across M-warps
tuple<sequence<>, sequence<NWarp>>, // H: NWarp in N-dim, 1 in K-dim
tuple<sequence<2, 0>>, // Ps_to_Hs
tuple<sequence<0, 0>>, // Ps_in_Hs
sequence<>, // Ys_to_Hs: Y maps to both K and N
sequence<>>{}; // Ys_in_Hs
constexpr auto b_block_outer_dstr_encode =
tile_distribution_encoding<sequence<MWarp>, // R: Replicate across M-warps
tuple<sequence<>, sequence<NWarp>>, // H: NWarp in N-dim, 1
// in K-dim
tuple<sequence<2, 0>>, // Ps_to_Hs
tuple<sequence<0, 0>>, // Ps_in_Hs
sequence<>, // Ys_to_Hs: Y maps to both K and N
sequence<>>{}; // Ys_in_Hs
// Embed to create block-level distributions with replication
constexpr auto a_block_dstr_encode =
detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encode, a_warp_dstr_encode);
constexpr auto b_block_dstr_encode =
detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encode, b_warp_dstr_encode);
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encode, a_warp_dstr_encode);
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encode, b_warp_dstr_encode);
// /*direct approach*/
// constexpr auto a_block_dstr_encode =
// constexpr auto a_block_dstr_encode =
// tile_distribution_encoding<
// sequence<NWarp>, // R: REPLICATE across 2 N-warps
// tuple<sequence<MWarp, 16>, // H0 (M): 2 M-warps × 16 threads = 32 M
@@ -185,24 +180,23 @@ struct TileSweepingHgemmKernel
// sequence<2>, // Ys_to_Hs: Y maps to K (dimension 2)
// sequence<1>>{}; // Ys_in_Hs: Y at position 1 in K
// // Direct approach (like test_b_distribution_with_replication.cpp)
// constexpr auto b_block_dstr_encode =
// constexpr auto b_block_dstr_encode =
// tile_distribution_encoding<
// sequence<MWarp>, // R: dimension 0, REPLICATE across 2 M-warps
// tuple<sequence<4, 4>, // H: dimension 1 (K): 4×4 = 16 K elements
// sequence<MWarp>, // R: dimension 0, REPLICATE across 2
// M-warps tuple<sequence<4, 4>, // H: dimension 1 (K): 4×4 = 16 K
// elements
// sequence<2, 16>>, // H: dimension 2 (N): 16 N positions
// tuple<sequence<2, 0>, sequence<1, 2>>, // Ps_to_Hs: P0→R(dim 0), P1→K(dim 1), P2→N(dim 2)
// tuple<sequence<0, 0>, sequence<0, 1>>, // Ps_in_Hs: positions
// tuple<sequence<2, 0>, sequence<1, 2>>, // Ps_to_Hs: P0→R(dim 0), P1→K(dim 1),
// P2→N(dim 2) tuple<sequence<0, 0>, sequence<0, 1>>, // Ps_in_Hs: positions
// sequence<1>, // Ys_to_Hs: Y maps to K (dimension 1)
// sequence<1>>{}; // Ys_in_Hs: Y at position 1 in K
// /*direct approach*/
// Use block-level distributions for loading (includes replication)
constexpr auto a_block_distribution = make_static_tile_distribution(a_block_dstr_encode);
constexpr auto b_block_distribution = make_static_tile_distribution(b_block_dstr_encode);
// C Distribution: Create block-level distribution for 32×32 output
// No replication needed - each warp computes its own unique output region
// 2D P-space for 4 warps: P[0] for M-warp, P[1] for N-warp
@@ -215,52 +209,47 @@ struct TileSweepingHgemmKernel
// sequence<1>, // No Y dimension for output
// sequence<2>>{};
constexpr auto c_warp_dstr_encode = tile_distribution_encoding<
sequence<>,
tuple<sequence<4, 4>, sequence<16>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<1>,
sequence<1>>{};
constexpr auto c_warp_dstr_encode =
tile_distribution_encoding<sequence<>,
tuple<sequence<4, 4>, sequence<16>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<1>,
sequence<1>>{};
constexpr auto c_block_outer_dstr_encode = tile_distribution_encoding<
sequence<>, // No replication for output
tuple<sequence<MWarp>, // H0: M iterations
sequence<NWarp>>, // H1: N iterations
tuple<sequence<2, 1>>, // Ps_to_Hs
tuple<sequence<0, 0>>, // Ps_in_Hs
sequence<>, // Ys_to_Hs: Y maps to BOTH M and N
sequence<>>{}; // Ys_in_Hs
constexpr auto c_block_outer_dstr_encode =
tile_distribution_encoding<sequence<>, // No replication for output
tuple<sequence<MWarp>, // H0: M iterations
sequence<NWarp>>, // H1: N iterations
tuple<sequence<2, 1>>, // Ps_to_Hs
tuple<sequence<0, 0>>, // Ps_in_Hs
sequence<>, // Ys_to_Hs: Y maps to BOTH M and N
sequence<>>{}; // Ys_in_Hs
constexpr auto c_block_dstr_encode =
detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encode, c_warp_dstr_encode);
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encode, c_warp_dstr_encode);
constexpr auto c_block_distribution = make_static_tile_distribution(c_block_dstr_encode);
// Create block-level windows (one K-chunk at a time)
// A: 32×16 (all M-rows × one K-chunk)
// B: 16×32 (one K-chunk × all N-columns)
auto a_block_window = make_tile_window(
a_tensor,
make_tuple(number<MWarp * kWarpM>{}, number<kWarpK>{}), // 32×16
{block_m * (MWarp * kWarpM), 0},
a_block_distribution
);
auto a_block_window =
make_tile_window(a_tensor,
make_tuple(number<MWarp * kWarpM>{}, number<kWarpK>{}), // 32×16
{block_m * (MWarp * kWarpM), 0},
a_block_distribution);
auto b_block_window = make_tile_window(
b_tensor,
make_tuple(number<kWarpK>{}, number<NWarp * kWarpN>{}), // 16×32
{0, block_n * (NWarp * kWarpN)},
b_block_distribution
);
auto b_block_window =
make_tile_window(b_tensor,
make_tuple(number<kWarpK>{}, number<NWarp * kWarpN>{}), // 16×32
{0, block_n * (NWarp * kWarpN)},
b_block_distribution);
// Create block-level accumulator tile (covers all 4 warps)
auto c_block_tile = make_static_distributed_tensor<AccDataType>(c_block_distribution);
set_tile(c_block_tile, AccDataType{0});
// Main K-loop
const index_t num_k_loops = K / kWarpK;
for(index_t k_iter = 0; k_iter < num_k_loops; ++k_iter)
@@ -269,13 +258,14 @@ struct TileSweepingHgemmKernel
// Each warp gets its correct portion based on the distribution encoding
const auto a_tile = load_tile(a_block_window);
const auto b_tile = load_tile(b_block_window);
// Perform MFMA: C += A * B
// Each warp updates its portion of the block tile
WarpGemm{}(c_block_tile, a_tile, b_tile);
// // Move windows to next K chunk
if(k_iter < num_k_loops - 1) {
if(k_iter < num_k_loops - 1)
{
move_tile_window(a_block_window, {0, kWarpK});
move_tile_window(b_block_window, {kWarpK, 0});
}
@@ -289,53 +279,58 @@ struct TileSweepingHgemmKernel
{
auto c_block_window = make_tile_window(
c_tensor,
make_tuple(number<MWarp * kWarpM>{}, number<NWarp * kWarpN>{}), // 32×32
make_tuple(number<MWarp * kWarpM>{}, number<NWarp * kWarpN>{}), // 32×32
{block_m * (MWarp * kWarpM), block_n * (NWarp * kWarpN)},
c_block_distribution
);
c_block_distribution);
const auto c_input_block_tile = load_tile(c_block_window);
tile_elementwise_inout(
[beta](const auto& c_val, auto& acc_val) {
acc_val += beta * c_val;
},
c_input_block_tile, c_block_tile);
[beta](const auto& c_val, auto& acc_val) { acc_val += beta * c_val; },
c_input_block_tile,
c_block_tile);
}
// Store final result to D (entire block)
auto d_block_window = make_tile_window(
d_tensor,
make_tuple(number<MWarp * kWarpM>{}, number<NWarp * kWarpN>{}), // 32×32
make_tuple(number<MWarp * kWarpM>{}, number<NWarp * kWarpN>{}), // 32×32
{block_m * (MWarp * kWarpM), block_n * (NWarp * kWarpN)},
c_block_distribution
);
c_block_distribution);
store_tile(d_block_window, c_block_tile);
}
};
// CPU reference for verification
template<typename InType, typename AccType>
void reference_gemm_mixed(const std::vector<InType>& a, // Column-major
const std::vector<InType>& b, // Row-major
const std::vector<AccType>& c, // Column-major
std::vector<AccType>& d, // Column-major
index_t M, index_t N, index_t K,
index_t lda, index_t ldb, index_t ldc, index_t ldd,
AccType alpha, AccType beta)
template <typename InType, typename AccType>
void reference_gemm_mixed(const std::vector<InType>& a, // Column-major
const std::vector<InType>& b, // Row-major
const std::vector<AccType>& c, // Column-major
std::vector<AccType>& d, // Column-major
index_t M,
index_t N,
index_t K,
index_t lda,
index_t ldb,
index_t ldc,
index_t ldd,
AccType alpha,
AccType beta)
{
// D = alpha * A * B + beta * C
for(index_t n = 0; n < N; ++n) {
for(index_t m = 0; m < M; ++m) {
for(index_t n = 0; n < N; ++n)
{
for(index_t m = 0; m < M; ++m)
{
AccType sum = 0;
// Compute A * B
for(index_t k = 0; k < K; ++k) {
for(index_t k = 0; k < K; ++k)
{
// A is column-major: A[m,k] = a[m + k*lda]
// B is row-major: B[k,n] = b[k*ldb + n]
sum += static_cast<AccType>(a[m + k * lda]) *
static_cast<AccType>(b[k * ldb + n]);
sum += static_cast<AccType>(a[m + k * lda]) * static_cast<AccType>(b[k * ldb + n]);
}
// D[m,n] = alpha * sum + beta * C[m,n]
@@ -346,30 +341,38 @@ void reference_gemm_mixed(const std::vector<InType>& a, // Column-major
}
// Helper to fill matrix with random values
template<typename T>
template <typename T>
void fill_random(std::vector<T>& data, T min_val = -1, T max_val = 1)
{
for(auto& val : data) {
val = static_cast<T>(min_val + (max_val - min_val) *
static_cast<float>(rand()) / RAND_MAX);
for(auto& val : data)
{
val = static_cast<T>(min_val + (max_val - min_val) * static_cast<float>(rand()) / RAND_MAX);
}
}
// Helper to print matrix (for debugging)
template<typename T>
void print_matrix(const std::vector<T>& mat, index_t rows, index_t cols,
index_t ld, bool col_major = true, const std::string& name = "Matrix")
template <typename T>
void print_matrix(const std::vector<T>& mat,
index_t rows,
index_t cols,
index_t ld,
bool col_major = true,
const std::string& name = "Matrix")
{
std::cout << name << " (" << rows << "×" << cols << "):\n";
for(index_t i = 0; i < std::min(rows, index_t(8)); ++i) {
for(index_t j = 0; j < std::min(cols, index_t(8)); ++j) {
for(index_t i = 0; i < std::min(rows, index_t(8)); ++i)
{
for(index_t j = 0; j < std::min(cols, index_t(8)); ++j)
{
index_t idx = col_major ? (i + j * ld) : (i * ld + j);
std::cout << std::setw(8) << std::setprecision(3) << mat[idx] << " ";
}
if(cols > 8) std::cout << "...";
if(cols > 8)
std::cout << "...";
std::cout << "\n";
}
if(rows > 8) std::cout << "...\n";
if(rows > 8)
std::cout << "...\n";
std::cout << "\n";
}
@@ -392,16 +395,16 @@ int main()
constexpr index_t K = 32;
// Leading dimensions
constexpr index_t lda = M; // Column-major
constexpr index_t ldb = N; // Row-major
constexpr index_t ldc = M; // Column-major
constexpr index_t ldd = M; // Column-major
constexpr index_t lda = M; // Column-major
constexpr index_t ldb = N; // Row-major
constexpr index_t ldc = M; // Column-major
constexpr index_t ldd = M; // Column-major
using InputType = half_t; // fp16
using AccumType = float; // fp32
using InputType = half_t; // fp16
using AccumType = float; // fp32
constexpr AccumType alpha = 2.0f;
constexpr AccumType beta = 1.5f;
constexpr AccumType beta = 1.5f;
std::cout << "Problem configuration:\n";
std::cout << " M×N×K: " << M << "×" << N << "×" << K << "\n";
@@ -409,7 +412,7 @@ int main()
std::cout << " B: row-major, ldb=" << ldb << " (fp16)\n";
std::cout << " C/D: column-major, ldc=" << ldc << ", ldd=" << ldd << " (fp32)\n";
std::cout << " alpha=" << alpha << ", beta=" << beta << "\n";
std::cout << " Total FLOPs: " << 2*M*N*K << "\n\n";
std::cout << " Total FLOPs: " << 2 * M * N * K << "\n\n";
// Host memory
std::vector<InputType> h_a(M * K);
@@ -427,7 +430,7 @@ int main()
// CPU reference
auto cpu_start = std::chrono::high_resolution_clock::now();
reference_gemm_mixed(h_a, h_b, h_c, h_d_ref, M, N, K, lda, ldb, ldc, ldd, alpha, beta);
auto cpu_end = std::chrono::high_resolution_clock::now();
auto cpu_end = std::chrono::high_resolution_clock::now();
double cpu_time_ms = std::chrono::duration<double, std::milli>(cpu_end - cpu_start).count();
// Device memory
@@ -442,34 +445,44 @@ int main()
d_d.ToDevice(h_d.data(), M * N * sizeof(AccumType));
// Launch kernel
constexpr index_t block_size = 256; // 4 warps (2×2 configuration)
constexpr index_t tiles_per_block_m = 2; // MWarp (no iterations)
constexpr index_t tiles_per_block_n = 2; // NWarp (no iterations)
constexpr index_t block_size = 256; // 4 warps (2×2 configuration)
constexpr index_t tiles_per_block_m = 2; // MWarp (no iterations)
constexpr index_t tiles_per_block_n = 2; // NWarp (no iterations)
const index_t grid_size = (M / (tiles_per_block_m * 16)) * (N / (tiles_per_block_n * 16));
std::cout << "Launching kernel:\n";
std::cout << " Grid: " << grid_size << " blocks\n";
std::cout << " Block: " << block_size << " threads (4 warps in 2×2 config)\n";
std::cout << " Each warp computes: ONE 16×16 output tile\n";
std::cout << " Each block computes: " << tiles_per_block_m*16 << "×" << tiles_per_block_n*16 << " output\n";
std::cout << " Total output tiles: " << (M/16) << "×" << (N/16) << "\n";
std::cout << " MFMA instructions per warp: " << (K/16) << "\n\n";
std::cout << " Each block computes: " << tiles_per_block_m * 16 << "×"
<< tiles_per_block_n * 16 << " output\n";
std::cout << " Total output tiles: " << (M / 16) << "×" << (N / 16) << "\n";
std::cout << " MFMA instructions per warp: " << (K / 16) << "\n\n";
stream_config stream;
// Warmup
for(int i = 0; i < 5; ++i) {
for(int i = 0; i < 5; ++i)
{
launch_kernel(stream,
make_kernel<block_size>(
TileSweepingHgemmKernel<InputType, InputType, AccumType, AccumType>{},
dim3(grid_size),
dim3(block_size),
0,
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
M, N, K, lda, ldb, ldc, ldd, alpha, beta));
make_kernel<block_size>(
TileSweepingHgemmKernel<InputType, InputType, AccumType, AccumType>{},
dim3(grid_size),
dim3(block_size),
0,
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
M,
N,
K,
lda,
ldb,
ldc,
ldd,
alpha,
beta));
}
hip_check_error(hipDeviceSynchronize());
@@ -477,40 +490,50 @@ int main()
auto gpu_start = std::chrono::high_resolution_clock::now();
launch_kernel(stream,
make_kernel<block_size>(
TileSweepingHgemmKernel<InputType, InputType, AccumType, AccumType>{},
dim3(grid_size),
dim3(block_size),
0,
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
M, N, K, lda, ldb, ldc, ldd, alpha, beta));
make_kernel<block_size>(
TileSweepingHgemmKernel<InputType, InputType, AccumType, AccumType>{},
dim3(grid_size),
dim3(block_size),
0,
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
M,
N,
K,
lda,
ldb,
ldc,
ldd,
alpha,
beta));
hip_check_error(hipDeviceSynchronize());
auto gpu_end = std::chrono::high_resolution_clock::now();
auto gpu_end = std::chrono::high_resolution_clock::now();
double gpu_time_ms = std::chrono::duration<double, std::milli>(gpu_end - gpu_start).count();
// Get result
d_d.FromDevice(h_d.data(), M * N * sizeof(AccumType));
// Verify
bool passed = true;
float max_error = 0;
bool passed = true;
float max_error = 0;
index_t error_count = 0;
for(index_t i = 0; i < M * N; ++i) {
for(index_t i = 0; i < M * N; ++i)
{
float error = std::abs(h_d[i] - h_d_ref[i]);
max_error = std::max(max_error, error);
if(error > 1e-2f) { // Relaxed tolerance for fp16
if(error_count < 5) {
max_error = std::max(max_error, error);
if(error > 1e-2f)
{ // Relaxed tolerance for fp16
if(error_count < 5)
{
index_t m = i % M;
index_t n = i / M;
std::cout << "Error at [" << m << "," << n << "]: "
<< h_d[i] << " vs " << h_d_ref[i]
<< " (diff=" << error << ")\n";
std::cout << "Error at [" << m << "," << n << "]: " << h_d[i] << " vs "
<< h_d_ref[i] << " (diff=" << error << ")\n";
}
error_count++;
}
@@ -519,14 +542,15 @@ int main()
passed = (error_count == 0);
// Calculate performance
double gflops = 2.0 * M * N * K / 1e9;
double gflops = 2.0 * M * N * K / 1e9;
double gpu_tflops = gflops / (gpu_time_ms / 1000);
double cpu_gflops = gflops / (cpu_time_ms / 1000);
std::cout << "Results:\n";
std::cout << " Correctness: " << (passed ? "✓ PASSED" : "✗ FAILED") << "\n";
std::cout << " Max error: " << max_error << "\n";
if(!passed) std::cout << " Error count: " << error_count << "/" << M*N << "\n";
if(!passed)
std::cout << " Error count: " << error_count << "/" << M * N << "\n";
std::cout << "\n";
std::cout << "Performance:\n";

View File

@@ -16,25 +16,23 @@
using namespace ck_tile;
template<typename DataType>
template <typename DataType>
struct TestADistributionYRepetitionKernel
{
static constexpr index_t kBlockSize = 256;
static constexpr index_t MWarp = 2;
static constexpr index_t NWarp = 2;
static constexpr index_t kBlockSize = 256;
static constexpr index_t MWarp = 2;
static constexpr index_t NWarp = 2;
static constexpr index_t MIterPerWarp = 2;
static constexpr index_t KIterPerWarp = 1;
static constexpr index_t kM = 64; // 2 warps × 2 iters × 16
static constexpr index_t kK = 64;
static constexpr index_t kM = 64; // 2 warps × 2 iters × 16
static constexpr index_t kK = 64;
CK_TILE_DEVICE void operator()(const DataType* a,
DataType* debug_output,
index_t lda) const
CK_TILE_DEVICE void operator()(const DataType* a, DataType* debug_output, index_t lda) const
{
if(get_block_id() != 0)
return;
const index_t tid = threadIdx.x;
const index_t tid = threadIdx.x;
const index_t warp_id = tid / 64;
const index_t lane_id = tid % 64;
@@ -52,56 +50,56 @@ struct TestADistributionYRepetitionKernel
// sequence<>, // Ys_to_Hs: Y maps to both M and K
// sequence<>>{}; // Ys_in_Hs
// A distribution with Y-repetition (from tutorial_07)
constexpr auto a_warp_dstr_encode = tile_distribution_encoding<
sequence<>,
tuple<sequence<16>, sequence<4, 4>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
constexpr auto a_warp_dstr_encode =
tile_distribution_encoding<sequence<>,
tuple<sequence<16>, sequence<4, 4>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
constexpr auto a_block_outer_dstr_encode = tile_distribution_encoding<
sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_outer_dstr_encode =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode =
detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encode, a_warp_dstr_encode);
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encode, a_warp_dstr_encode);
constexpr auto a_distribution = make_static_tile_distribution(a_block_dstr_encode);
auto a_window = make_tile_window(
a_tensor, make_tuple(number<kM>{}, number<kK>{}),
{0, 0}, a_distribution);
a_tensor, make_tuple(number<kM>{}, number<kK>{}), {0, 0}, a_distribution);
const auto a_tile = load_tile(a_window);
const auto a_tile = load_tile(a_window);
const auto& thread_buffer = a_tile.get_thread_buffer();
// Print from all warps sequentially
__syncthreads();
if(tid == 0) {
if(tid == 0)
{
printf("\n=== A Distribution with Y-Repetition Test ===\n");
printf("Matrix: 64×16 (MWarp=2, MIterPerWarp=2, each warp loads 2×16 tiles)\n");
printf("Input: A[m,k] = m + k*100 (unique values)\n\n");
}
__syncthreads();
for(int w = 0; w < 4; ++w) {
for(int w = 0; w < 4; ++w)
{
__syncthreads();
if(warp_id == w && lane_id == 0) {
printf("Warp %d (M-warp %d, N-warp %d):\n",
w, w/NWarp, w%NWarp);
if(warp_id == w && lane_id == 0)
{
printf("Warp %d (M-warp %d, N-warp %d):\n", w, w / NWarp, w % NWarp);
printf(" Thread buffer size: %d\n", static_cast<int>(thread_buffer.size()));
printf(" Values: ");
for(int i = 0; i < thread_buffer.size(); ++i) {
for(int i = 0; i < thread_buffer.size(); ++i)
{
printf("%.0f ", static_cast<float>(thread_buffer[i]));
}
printf("\n");
@@ -110,7 +108,8 @@ struct TestADistributionYRepetitionKernel
__syncthreads();
if(tid == 0) {
if(tid == 0)
{
printf("\n=== Expected Pattern ===\n");
printf("Each warp should load 8 elements (2 M-iters × 1 K-iter × 4 warp elements)\n");
printf("Warp 0 (M-warp 0): Should have M-rows [0-15] and [16-31]\n");
@@ -120,7 +119,8 @@ struct TestADistributionYRepetitionKernel
}
// Store for verification
for(int i = 0; i < thread_buffer.size(); ++i) {
for(int i = 0; i < thread_buffer.size(); ++i)
{
debug_output[tid * 8 + i] = thread_buffer[i];
}
}
@@ -132,8 +132,8 @@ int main()
std::cout << "Test A Distribution with Y-Dimension Repetition\n";
std::cout << "==================================================\n\n";
constexpr index_t M = 64;
constexpr index_t K = 64;
constexpr index_t M = 64;
constexpr index_t K = 64;
constexpr index_t lda = M;
using DataType = half_t;
@@ -142,8 +142,10 @@ int main()
std::vector<DataType> h_debug(256 * 8, -1);
// Initialize A[m,k] = m + k*100 (unique for each position)
for(index_t k = 0; k < K; ++k) {
for(index_t m = 0; m < M; ++m) {
for(index_t k = 0; k < K; ++k)
{
for(index_t m = 0; m < M; ++m)
{
h_a[m + k * lda] = static_cast<DataType>(m + k * 100);
}
}
@@ -156,12 +158,13 @@ int main()
stream_config stream;
launch_kernel(stream,
make_kernel<256>(
TestADistributionYRepetitionKernel<DataType>{},
dim3(1), dim3(256), 0,
static_cast<const DataType*>(d_a.GetDeviceBuffer()),
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
lda));
make_kernel<256>(TestADistributionYRepetitionKernel<DataType>{},
dim3(1),
dim3(256),
0,
static_cast<const DataType*>(d_a.GetDeviceBuffer()),
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
lda));
hip_check_error(hipDeviceSynchronize());

View File

@@ -18,24 +18,23 @@
using namespace ck_tile;
template<typename DataType>
template <typename DataType>
struct TestAYSlicingKernel
{
static constexpr index_t kBlockSize = 256;
static constexpr index_t MWarp = 2;
static constexpr index_t NWarp = 2;
static constexpr index_t kBlockSize = 256;
static constexpr index_t MWarp = 2;
static constexpr index_t NWarp = 2;
static constexpr index_t MIterPerWarp = 2;
static constexpr index_t KIterPerWarp = 1;
static constexpr index_t kM = 64; // 2 warps × 2 iters × 16
static constexpr index_t kK = 16; // Fixed to match distribution coverage
static constexpr index_t kM = 64; // 2 warps × 2 iters × 16
static constexpr index_t kK = 16; // Fixed to match distribution coverage
CK_TILE_DEVICE void operator()(const DataType* a,
index_t lda) const
CK_TILE_DEVICE void operator()(const DataType* a, index_t lda) const
{
if(get_block_id() != 0)
return;
const index_t tid = threadIdx.x;
const index_t tid = threadIdx.x;
const index_t warp_id = tid / 64;
const index_t lane_id = tid % 64;
@@ -44,51 +43,50 @@ struct TestAYSlicingKernel
a, make_tuple(kM, kK), make_tuple(1, lda), number<1>{}, number<1>{});
// A warp-level distribution
constexpr auto a_warp_dstr_encode = tile_distribution_encoding<
sequence<>,
tuple<sequence<16>, sequence<4, 4>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
constexpr auto a_warp_dstr_encode =
tile_distribution_encoding<sequence<>,
tuple<sequence<16>, sequence<4, 4>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
// A block-level outer distribution with Y-repetition
constexpr auto a_block_outer_dstr_encode = tile_distribution_encoding<
sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>,
sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_outer_dstr_encode =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode =
detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encode, a_warp_dstr_encode);
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encode, a_warp_dstr_encode);
constexpr auto a_block_distribution = make_static_tile_distribution(a_block_dstr_encode);
// Get Y-dimension information
using AWarpDstr = decltype(make_static_tile_distribution(a_warp_dstr_encode));
constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
// Create window and load full block tile
auto a_window = make_tile_window(
a_tensor, make_tuple(number<kM>{}, number<kK>{}),
{0, 0}, a_block_distribution);
a_tensor, make_tuple(number<kM>{}, number<kK>{}), {0, 0}, a_block_distribution);
const auto a_block_tile = load_tile(a_window);
__syncthreads();
if(tid == 0) {
if(tid == 0)
{
printf("\n=== A Y-Slicing Test ===\n");
printf("Block tile: %d×%d (M×K)\n", kM, kK);
printf("MIterPerWarp=%d, KIterPerWarp=%d\n", MIterPerWarp, KIterPerWarp);
printf("Input: A[m,k] = m*1000 + k (unique values)\n\n");
}
__syncthreads();
// Test Y-slicing for each warp and iteration
@@ -97,7 +95,7 @@ struct TestAYSlicingKernel
// Extract warp tensor for this iteration
auto a_warp_tensor = make_static_distributed_tensor<DataType>(
make_static_tile_distribution(a_warp_dstr_encode));
// CORRECTED: kIter first, then mIter (matching the Ys_to_Hs order)
a_warp_tensor.get_thread_buffer() = a_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
@@ -106,17 +104,25 @@ struct TestAYSlicingKernel
const auto& warp_buffer = a_warp_tensor.get_thread_buffer();
// Print from each warp sequentially
for(int w = 0; w < 4; ++w) {
for(int w = 0; w < 4; ++w)
{
__syncthreads();
if(warp_id == w && lane_id == 0) {
printf("Warp %d (M-warp %d, N-warp %d), MIter=%d, KIter=%d:\n",
w, w/NWarp, w%NWarp, static_cast<int>(mIter), static_cast<int>(kIter));
if(warp_id == w && lane_id == 0)
{
printf("Warp %d (M-warp %d, N-warp %d), MIter=%d, KIter=%d:\n",
w,
w / NWarp,
w % NWarp,
static_cast<int>(mIter),
static_cast<int>(kIter));
printf(" Buffer size: %d\n", static_cast<int>(warp_buffer.size()));
printf(" Values: ");
for(int i = 0; i < warp_buffer.size() && i < 16; ++i) {
for(int i = 0; i < warp_buffer.size() && i < 16; ++i)
{
printf("%.0f ", static_cast<float>(warp_buffer[i]));
}
if(warp_buffer.size() > 16) printf("...");
if(warp_buffer.size() > 16)
printf("...");
printf("\n");
}
}
@@ -125,7 +131,8 @@ struct TestAYSlicingKernel
__syncthreads();
if(tid == 0) {
if(tid == 0)
{
printf("\n=== Expected Pattern ===\n");
printf("Each warp should get 4 elements per iteration (16 M × 16 K / 64 threads)\n");
printf("Warp 0, MIter=0: Should have values from A[0:16, 0:16]\n");
@@ -144,8 +151,8 @@ int main()
std::cout << "Test A Y-Slicing with get_y_sliced_thread_data\n";
std::cout << "==================================================\n\n";
constexpr index_t M = 64; // 2 warps × 2 iters × 16
constexpr index_t K = 16; // Match distribution coverage
constexpr index_t M = 64; // 2 warps × 2 iters × 16
constexpr index_t K = 16; // Match distribution coverage
constexpr index_t lda = M;
using DataType = half_t;
@@ -154,8 +161,10 @@ int main()
// Initialize A[m,k] = m*1000 + k (easy to identify position)
auto counter = 0;
for(index_t m = 0; m < M; ++m) {
for(index_t k = 0; k < K; ++k) {
for(index_t m = 0; m < M; ++m)
{
for(index_t k = 0; k < K; ++k)
{
h_a[m + k * lda] = static_cast<DataType>(counter++);
}
}
@@ -166,18 +175,19 @@ int main()
std::cout << " A[16,0] = " << static_cast<float>(h_a[16]) << "\n";
std::cout << " A[32,0] = " << static_cast<float>(h_a[32]) << "\n";
std::cout << " A[48,0] = " << static_cast<float>(h_a[48]) << "\n";
std::cout << " A[0,15] = " << static_cast<float>(h_a[15*lda]) << "\n\n";
std::cout << " A[0,15] = " << static_cast<float>(h_a[15 * lda]) << "\n\n";
DeviceMem d_a(M * K * sizeof(DataType));
d_a.ToDevice(h_a.data(), M * K * sizeof(DataType));
stream_config stream;
launch_kernel(stream,
make_kernel<256>(
TestAYSlicingKernel<DataType>{},
dim3(1), dim3(256), 0,
static_cast<const DataType*>(d_a.GetDeviceBuffer()),
lda));
make_kernel<256>(TestAYSlicingKernel<DataType>{},
dim3(1),
dim3(256),
0,
static_cast<const DataType*>(d_a.GetDeviceBuffer()),
lda));
hip_check_error(hipDeviceSynchronize());

View File

@@ -16,26 +16,25 @@
using namespace ck_tile;
template<typename DataType>
template <typename DataType>
struct TestBDistributionYRepetitionKernel
{
static constexpr index_t kBlockSize = 256;
static constexpr index_t MWarp = 2;
static constexpr index_t NWarp = 2;
static constexpr index_t kBlockSize = 256;
static constexpr index_t MWarp = 2;
static constexpr index_t NWarp = 2;
static constexpr index_t NIterPerWarp = 2;
static constexpr index_t KIterPerWarp = 1;
static constexpr index_t kK = 64;
static constexpr index_t kN = 64; // 2 warps × 2 iters × 16
static constexpr index_t kK = 64;
static constexpr index_t kN = 64; // 2 warps × 2 iters × 16
CK_TILE_DEVICE void operator()(const DataType* b,
DataType* debug_output,
index_t ldb) const
CK_TILE_DEVICE void operator()(const DataType* b, DataType* debug_output, index_t ldb) const
{
if(get_block_id() != 0)
return;
//each warp is 64 x 4 items and 4 warps total and 2 iteration, so totally it becomes 64 x 32 we don't cover the whole matrix
const index_t tid = threadIdx.x;
// each warp is 64 x 4 items and 4 warps total and 2 iteration, so totally it becomes 64 x
// 32 we don't cover the whole matrix
const index_t tid = threadIdx.x;
const index_t warp_id = tid / 64;
const index_t lane_id = tid % 64;
@@ -44,55 +43,55 @@ struct TestBDistributionYRepetitionKernel
b, make_tuple(kK, kN), make_tuple(ldb, 1), number<4>{}, number<1>{});
// B distribution with Y-repetition (from tutorial_07)
constexpr auto b_warp_dstr_encode = tile_distribution_encoding<
sequence<>,
tuple<sequence<4, 4>, sequence<16>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<1>,
sequence<1>>{};
constexpr auto b_warp_dstr_encode =
tile_distribution_encoding<sequence<>,
tuple<sequence<4, 4>, sequence<16>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<1>,
sequence<1>>{};
constexpr auto b_block_outer_dstr_encode = tile_distribution_encoding<
sequence<MWarp>,
tuple<sequence<KIterPerWarp>,
sequence<NIterPerWarp, NWarp>>,
tuple<sequence<2, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_outer_dstr_encode =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<KIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<2, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode =
detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encode, b_warp_dstr_encode);
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encode, b_warp_dstr_encode);
constexpr auto b_distribution = make_static_tile_distribution(b_block_dstr_encode);
auto b_window = make_tile_window(
b_tensor, make_tuple(number<kK>{}, number<kN>{}),
{0, 0}, b_distribution);
b_tensor, make_tuple(number<kK>{}, number<kN>{}), {0, 0}, b_distribution);
const auto b_tile = load_tile(b_window);
const auto b_tile = load_tile(b_window);
const auto& thread_buffer = b_tile.get_thread_buffer();
// Print from all warps sequentially
__syncthreads();
if(tid == 0) {
if(tid == 0)
{
printf("\n=== B Distribution with Y-Repetition Test ===\n");
printf("Matrix: 16×64 (NWarp=2, NIterPerWarp=2, each warp loads 2×16 tiles)\n");
printf("Input: B[k,n] = k + n*100 (unique values)\n\n");
}
__syncthreads();
for(int w = 0; w < 4; ++w) {
for(int w = 0; w < 4; ++w)
{
__syncthreads();
if(warp_id == w && lane_id == 0) {
printf("Warp %d (M-warp %d, N-warp %d):\n",
w, w/NWarp, w%NWarp);
if(warp_id == w && lane_id == 0)
{
printf("Warp %d (M-warp %d, N-warp %d):\n", w, w / NWarp, w % NWarp);
printf(" Thread buffer size: %d\n", static_cast<int>(thread_buffer.size()));
printf(" Values: ");
for(int i = 0; i < thread_buffer.size(); ++i) {
for(int i = 0; i < thread_buffer.size(); ++i)
{
printf("%.0f ", static_cast<float>(thread_buffer[i]));
}
printf("\n");
@@ -101,7 +100,8 @@ struct TestBDistributionYRepetitionKernel
__syncthreads();
if(tid == 0) {
if(tid == 0)
{
printf("\n=== Expected Pattern ===\n");
printf("Each warp should load 8 elements (2 N-iters × 1 K-iter × 4 warp elements)\n");
printf("Warp 0 (N-warp 0): Should have N-cols [0-15] and [16-31]\n");
@@ -111,7 +111,8 @@ struct TestBDistributionYRepetitionKernel
}
// Store for verification
for(int i = 0; i < thread_buffer.size(); ++i) {
for(int i = 0; i < thread_buffer.size(); ++i)
{
debug_output[tid * 8 + i] = thread_buffer[i];
}
}
@@ -123,8 +124,8 @@ int main()
std::cout << "Test B Distribution with Y-Dimension Repetition\n";
std::cout << "==================================================\n\n";
constexpr index_t K = 64;
constexpr index_t N = 64;
constexpr index_t K = 64;
constexpr index_t N = 64;
constexpr index_t ldb = N;
using DataType = half_t;
@@ -134,8 +135,10 @@ int main()
// Initialize B[k,n] = k + n*100 (unique for each position)
auto counter = 0;
for(index_t k = 0; k < K; ++k) {
for(index_t n = 0; n < N; ++n) {
for(index_t k = 0; k < K; ++k)
{
for(index_t n = 0; n < N; ++n)
{
h_b[k * ldb + n] = static_cast<DataType>(counter++);
}
}
@@ -148,12 +151,13 @@ int main()
stream_config stream;
launch_kernel(stream,
make_kernel<256>(
TestBDistributionYRepetitionKernel<DataType>{},
dim3(1), dim3(256), 0,
static_cast<const DataType*>(d_b.GetDeviceBuffer()),
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
ldb));
make_kernel<256>(TestBDistributionYRepetitionKernel<DataType>{},
dim3(1),
dim3(256),
0,
static_cast<const DataType*>(d_b.GetDeviceBuffer()),
static_cast<DataType*>(d_debug.GetDeviceBuffer()),
ldb));
hip_check_error(hipDeviceSynchronize());

View File

@@ -18,24 +18,23 @@
using namespace ck_tile;
template<typename DataType>
template <typename DataType>
struct TestBYSlicingKernel
{
static constexpr index_t kBlockSize = 256;
static constexpr index_t MWarp = 2;
static constexpr index_t NWarp = 2;
static constexpr index_t kBlockSize = 256;
static constexpr index_t MWarp = 2;
static constexpr index_t NWarp = 2;
static constexpr index_t NIterPerWarp = 2;
static constexpr index_t KIterPerWarp = 1;
static constexpr index_t kK = 16; // Fixed to match distribution coverage
static constexpr index_t kN = 64; // 2 warps × 2 iters × 16
static constexpr index_t kK = 16; // Fixed to match distribution coverage
static constexpr index_t kN = 64; // 2 warps × 2 iters × 16
CK_TILE_DEVICE void operator()(const DataType* b,
index_t ldb) const
CK_TILE_DEVICE void operator()(const DataType* b, index_t ldb) const
{
if(get_block_id() != 0)
return;
const index_t tid = threadIdx.x;
const index_t tid = threadIdx.x;
const index_t warp_id = tid / 64;
const index_t lane_id = tid % 64;
@@ -44,51 +43,50 @@ struct TestBYSlicingKernel
b, make_tuple(kK, kN), make_tuple(ldb, 1), number<4>{}, number<1>{});
// B warp-level distribution
constexpr auto b_warp_dstr_encode = tile_distribution_encoding<
sequence<>,
tuple<sequence<4, 4>, sequence<16>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<1>,
sequence<1>>{};
constexpr auto b_warp_dstr_encode =
tile_distribution_encoding<sequence<>,
tuple<sequence<4, 4>, sequence<16>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<1>,
sequence<1>>{};
// B block-level outer distribution with Y-repetition
constexpr auto b_block_outer_dstr_encode = tile_distribution_encoding<
sequence<MWarp>,
tuple<sequence<KIterPerWarp>,
sequence<NIterPerWarp, NWarp>>,
tuple<sequence<2, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_outer_dstr_encode =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<KIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<2, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode =
detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encode, b_warp_dstr_encode);
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encode, b_warp_dstr_encode);
constexpr auto b_block_distribution = make_static_tile_distribution(b_block_dstr_encode);
// Get Y-dimension information
using BWarpDstr = decltype(make_static_tile_distribution(b_warp_dstr_encode));
constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
// Create window and load full block tile
auto b_window = make_tile_window(
b_tensor, make_tuple(number<kK>{}, number<kN>{}),
{0, 0}, b_block_distribution);
b_tensor, make_tuple(number<kK>{}, number<kN>{}), {0, 0}, b_block_distribution);
const auto b_block_tile = load_tile(b_window);
__syncthreads();
if(tid == 0) {
if(tid == 0)
{
printf("\n=== B Y-Slicing Test ===\n");
printf("Block tile: %d×%d (K×N)\n", kK, kN);
printf("NIterPerWarp=%d, KIterPerWarp=%d\n", NIterPerWarp, KIterPerWarp);
printf("Input: B[k,n] = k*1000 + n (unique values)\n\n");
}
__syncthreads();
// Test Y-slicing for each warp and iteration
@@ -97,7 +95,7 @@ struct TestBYSlicingKernel
// Extract warp tensor for this iteration
auto b_warp_tensor = make_static_distributed_tensor<DataType>(
make_static_tile_distribution(b_warp_dstr_encode));
b_warp_tensor.get_thread_buffer() = b_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<kIter, nIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
@@ -105,17 +103,25 @@ struct TestBYSlicingKernel
const auto& warp_buffer = b_warp_tensor.get_thread_buffer();
// Print from each warp sequentially
for(int w = 0; w < 4; ++w) {
for(int w = 0; w < 4; ++w)
{
__syncthreads();
if(warp_id == w && lane_id == 0) {
printf("Warp %d (M-warp %d, N-warp %d), NIter=%d, KIter=%d:\n",
w, w/NWarp, w%NWarp, static_cast<int>(nIter), static_cast<int>(kIter));
if(warp_id == w && lane_id == 0)
{
printf("Warp %d (M-warp %d, N-warp %d), NIter=%d, KIter=%d:\n",
w,
w / NWarp,
w % NWarp,
static_cast<int>(nIter),
static_cast<int>(kIter));
printf(" Buffer size: %d\n", static_cast<int>(warp_buffer.size()));
printf(" Values: ");
for(int i = 0; i < warp_buffer.size() && i < 16; ++i) {
for(int i = 0; i < warp_buffer.size() && i < 16; ++i)
{
printf("%.0f ", static_cast<float>(warp_buffer[i]));
}
if(warp_buffer.size() > 16) printf("...");
if(warp_buffer.size() > 16)
printf("...");
printf("\n");
}
}
@@ -124,7 +130,8 @@ struct TestBYSlicingKernel
__syncthreads();
if(tid == 0) {
if(tid == 0)
{
printf("\n=== Expected Pattern ===\n");
printf("Each warp should get 4 elements per iteration (16 K × 16 N / 64 threads)\n");
printf("Warp 0, NIter=0: Should have values from B[0:16, 0:16]\n");
@@ -142,8 +149,8 @@ int main()
std::cout << "Test B Y-Slicing with get_y_sliced_thread_data\n";
std::cout << "==================================================\n\n";
constexpr index_t K = 16; // Match distribution coverage
constexpr index_t N = 64;
constexpr index_t K = 16; // Match distribution coverage
constexpr index_t N = 64;
constexpr index_t ldb = N;
using DataType = half_t;
@@ -152,8 +159,10 @@ int main()
// Initialize B[k,n] = k*1000 + n (easy to identify position)
auto counter = 0;
for(index_t k = 0; k < K; ++k) {
for(index_t n = 0; n < N; ++n) {
for(index_t k = 0; k < K; ++k)
{
for(index_t n = 0; n < N; ++n)
{
h_b[k * ldb + n] = static_cast<DataType>(counter++);
}
}
@@ -164,18 +173,19 @@ int main()
std::cout << " B[0,16] = " << static_cast<float>(h_b[16]) << "\n";
std::cout << " B[0,32] = " << static_cast<float>(h_b[32]) << "\n";
std::cout << " B[0,48] = " << static_cast<float>(h_b[48]) << "\n";
std::cout << " B[15,0] = " << static_cast<float>(h_b[15*ldb]) << "\n\n";
std::cout << " B[15,0] = " << static_cast<float>(h_b[15 * ldb]) << "\n\n";
DeviceMem d_b(K * N * sizeof(DataType));
d_b.ToDevice(h_b.data(), K * N * sizeof(DataType));
stream_config stream;
launch_kernel(stream,
make_kernel<256>(
TestBYSlicingKernel<DataType>{},
dim3(1), dim3(256), 0,
static_cast<const DataType*>(d_b.GetDeviceBuffer()),
ldb));
make_kernel<256>(TestBYSlicingKernel<DataType>{},
dim3(1),
dim3(256),
0,
static_cast<const DataType*>(d_b.GetDeviceBuffer()),
ldb));
hip_check_error(hipDeviceSynchronize());

View File

@@ -30,23 +30,23 @@
using namespace ck_tile;
// Tile Sweeping HGEMM kernel with Y-dimension repetition
template<typename ADataType, typename BDataType, typename CDataType, typename AccDataType>
template <typename ADataType, typename BDataType, typename CDataType, typename AccDataType>
struct TileSweepingYRepetitionHgemmKernel
{
static constexpr index_t kWaveSize = 64; // AMD wave size
static constexpr index_t kWarpM = 16; // MFMA M dimension per warp
static constexpr index_t kWarpN = 16; // MFMA N dimension per warp
static constexpr index_t kWarpK = 16; // MFMA K dimension per instruction
static constexpr index_t kWaveSize = 64; // AMD wave size
static constexpr index_t kWarpM = 16; // MFMA M dimension per warp
static constexpr index_t kWarpN = 16; // MFMA N dimension per warp
static constexpr index_t kWarpK = 16; // MFMA K dimension per instruction
// Warp configuration: 2×2 warps per block (SAME as Tutorial 06)
static constexpr index_t MWarp = 2; // 2 warps in M dimension
static constexpr index_t NWarp = 2; // 2 warps in N dimension
static constexpr index_t kBlockSize = MWarp * NWarp * kWaveSize; // 256 threads
static constexpr index_t MWarp = 2; // 2 warps in M dimension
static constexpr index_t NWarp = 2; // 2 warps in N dimension
static constexpr index_t kBlockSize = MWarp * NWarp * kWaveSize; // 256 threads
// NEW: Tile iterations per warp (Y-dimension repetition)
static constexpr index_t MIterPerWarp = 2; // Each warp sweeps 2 tiles in M
static constexpr index_t NIterPerWarp = 2; // Each warp sweeps 2 tiles in N
static constexpr index_t KIterPerWarp = 1; // K handled in main loop
static constexpr index_t MIterPerWarp = 2; // Each warp sweeps 2 tiles in M
static constexpr index_t NIterPerWarp = 2; // Each warp sweeps 2 tiles in N
static constexpr index_t KIterPerWarp = 1; // K handled in main loop
// Use ck_tile's WarpGemm for MFMA
using WarpGemm = WarpGemmMfmaF16F16F32M16N16K16;
@@ -58,28 +58,28 @@ struct TileSweepingYRepetitionHgemmKernel
index_t M,
index_t N,
index_t K,
index_t lda, // Leading dimension of A (column-major)
index_t ldb, // Leading dimension of B (row-major)
index_t ldc, // Leading dimension of C (column-major)
index_t ldd, // Leading dimension of D (column-major)
index_t lda, // Leading dimension of A (column-major)
index_t ldb, // Leading dimension of B (row-major)
index_t ldc, // Leading dimension of C (column-major)
index_t ldd, // Leading dimension of D (column-major)
AccDataType alpha,
AccDataType beta) const
{
// Calculate which warp this thread belongs to within the block
[[maybe_unused]] const index_t warp_id = get_warp_id();
[[maybe_unused]] const index_t iMWarp = warp_id / NWarp; // M-warp index (0 or 1)
[[maybe_unused]] const index_t iNWarp = warp_id % NWarp; // N-warp index (0 or 1)
[[maybe_unused]] const index_t iMWarp = warp_id / NWarp; // M-warp index (0 or 1)
[[maybe_unused]] const index_t iNWarp = warp_id % NWarp; // N-warp index (0 or 1)
// Calculate base offset for this block
// Each block now computes (MWarp × MIterPerWarp × kWarpM) × (NWarp × NIterPerWarp × kWarpN)
const index_t kMPerBlock = MWarp * MIterPerWarp * kWarpM; // 2×2×16 = 64
const index_t kNPerBlock = NWarp * NIterPerWarp * kWarpN; // 2×2×16 = 64
const index_t kMPerBlock = MWarp * MIterPerWarp * kWarpM; // 2×2×16 = 64
const index_t kNPerBlock = NWarp * NIterPerWarp * kWarpN; // 2×2×16 = 64
// Calculate block position in 2D grid
const index_t num_blocks_n = N / kNPerBlock;
const index_t block_m = get_block_id() / num_blocks_n;
const index_t block_n = get_block_id() % num_blocks_n;
const index_t block_m = get_block_id() / num_blocks_n;
const index_t block_n = get_block_id() % num_blocks_n;
const index_t m_block_base = block_m * kMPerBlock;
const index_t n_block_base = block_n * kNPerBlock;
@@ -89,111 +89,89 @@ struct TileSweepingYRepetitionHgemmKernel
// Create tensor views for matrices
const auto a_tensor = make_naive_tensor_view<address_space_enum::global>(
a,
make_tuple(M, K),
make_tuple(1, lda),
number<1>{},
number<1>{}
);
a, make_tuple(M, K), make_tuple(1, lda), number<1>{}, number<1>{});
const auto b_tensor = make_naive_tensor_view<address_space_enum::global>(
b,
make_tuple(K, N),
make_tuple(ldb, 1),
number<4>{},
number<1>{}
);
b, make_tuple(K, N), make_tuple(ldb, 1), number<4>{}, number<1>{});
const auto c_tensor = make_naive_tensor_view<address_space_enum::global>(
c,
make_tuple(M, N),
make_tuple(1, ldc),
number<1>{},
number<1>{}
);
c, make_tuple(M, N), make_tuple(1, ldc), number<1>{}, number<1>{});
auto d_tensor = make_naive_tensor_view<address_space_enum::global>(
d,
make_tuple(M, N),
make_tuple(1, ldd),
number<1>{},
number<1>{}
);
d, make_tuple(M, N), make_tuple(1, ldd), number<1>{}, number<1>{});
// ============================================================================
// TILE DISTRIBUTIONS with Y-DIMENSION REPETITION (following 02_gemm pattern)
// ============================================================================
// A Distribution: Block-level with Y-repetition
// Warp-level distribution (same as Tutorial 06)
constexpr auto a_warp_dstr_encode = tile_distribution_encoding<
sequence<>,
tuple<sequence<16>, sequence<4, 4>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
constexpr auto a_warp_dstr_encode =
tile_distribution_encoding<sequence<>,
tuple<sequence<16>, sequence<4, 4>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
// constexpr auto a_block_outer_dstr_encoding =
//     tile_distribution_encoding<sequence<NWarp>,
//                                tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
//                                tuple<sequence<1, 0>>,
//                                tuple<sequence<1, 0>>,
//                                sequence<1, 2>,
//                                sequence<0, 0>>{};
// Block-level outer distribution with Y-repetition
constexpr auto a_block_outer_dstr_encode = tile_distribution_encoding<
sequence<NWarp>, // Replicate across N-warps
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>, // H0: 2 iters × 2 warps in M
tuple<sequence<0, 1>>, // Ps_to_Hs
tuple<sequence<0, 1>>, // Ps_in_Hs
sequence<1, 2>, // Ys_to_Hs: Y maps to BOTH M and K
sequence<0, 0>>{}; // Ys_in_Hs
constexpr auto a_block_dstr_encode =
detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encode, a_warp_dstr_encode);
constexpr auto a_block_outer_dstr_encode =
tile_distribution_encoding<sequence<NWarp>, // Replicate across N-warps
tuple<sequence<MIterPerWarp, MWarp>,
sequence<KIterPerWarp>>, // H0: 2 iters × 2 warps in M
tuple<sequence<0, 1>>, // Ps_to_Hs
tuple<sequence<0, 1>>, // Ps_in_Hs
sequence<1, 2>, // Ys_to_Hs: Y maps to BOTH M and K
sequence<0, 0>>{}; // Ys_in_Hs
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encode, a_warp_dstr_encode);
// B Distribution: Block-level with Y-repetition
constexpr auto b_warp_dstr_encode = tile_distribution_encoding<
sequence<>,
tuple<sequence<4, 4>, sequence<16>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<1>,
sequence<1>>{};
constexpr auto b_warp_dstr_encode =
tile_distribution_encoding<sequence<>,
tuple<sequence<4, 4>, sequence<16>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<1>,
sequence<1>>{};
// constexpr auto b_block_outer_dstr_encode =
//     tile_distribution_encoding<sequence<MWarp>,
//                                tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
//                                tuple<sequence<0, 1>>,
//                                tuple<sequence<0, 1>>,
//                                sequence<1, 2>,
//                                sequence<0, 0>>{};
constexpr auto b_block_outer_dstr_encode = tile_distribution_encoding<
sequence<MWarp>, // Replicate across M-warps
tuple<sequence<KIterPerWarp>, // H0: 2 iters × 2 warps in N
sequence<NIterPerWarp, NWarp>>, // H1: 1 K-chunk
tuple<sequence<2, 0>>, // Ps_to_Hs
tuple<sequence<1, 0>>, // Ps_in_Hs
sequence<1, 2>, // Ys_to_Hs: Y maps to BOTH N and K
sequence<0, 0>>{}; // Ys_in_Hs
constexpr auto b_block_outer_dstr_encode =
tile_distribution_encoding<sequence<MWarp>, // Replicate across M-warps
tuple<sequence<KIterPerWarp>, // H0: 2 iters × 2 warps in N
sequence<NIterPerWarp, NWarp>>, // H1: 1 K-chunk
tuple<sequence<2, 0>>, // Ps_to_Hs
tuple<sequence<1, 0>>, // Ps_in_Hs
sequence<1, 2>, // Ys_to_Hs: Y maps to BOTH N and K
sequence<0, 0>>{}; // Ys_in_Hs
constexpr auto b_block_dstr_encode =
detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encode, b_warp_dstr_encode);
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encode, b_warp_dstr_encode);
// // C Distribution: Block-level with Y-repetition for output
constexpr auto c_warp_dstr_encode = tile_distribution_encoding<
sequence<>,
tuple<sequence<4, 4>, sequence<16>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<1>,
sequence<1>>{};
constexpr auto c_warp_dstr_encode =
tile_distribution_encoding<sequence<>,
tuple<sequence<4, 4>, sequence<16>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<1>,
sequence<1>>{};
// constexpr auto c_block_outer_dstr_encode = tile_distribution_encoding<
// sequence<>,
@@ -203,18 +181,17 @@ struct TileSweepingYRepetitionHgemmKernel
// sequence<1, 2>,
// sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encode = tile_distribution_encoding<
sequence<>, // No replication for output
tuple<sequence<MIterPerWarp, MWarp>, // H0: M iterations
sequence<NIterPerWarp, NWarp>>, // H1: N iterations
tuple<sequence<2, 1>>, // Ps_to_Hs
tuple<sequence<1, 1>>, // Ps_in_Hs
sequence<1, 2>, // Ys_to_Hs: Y maps to BOTH M and N
sequence<0, 0>>{}; // Ys_in_Hs
constexpr auto c_block_outer_dstr_encode =
tile_distribution_encoding<sequence<>, // No replication for output
tuple<sequence<MIterPerWarp, MWarp>, // H0: M iterations
sequence<NIterPerWarp, NWarp>>, // H1: N iterations
tuple<sequence<2, 1>>, // Ps_to_Hs
tuple<sequence<1, 1>>, // Ps_in_Hs
sequence<1, 2>, // Ys_to_Hs: Y maps to BOTH M and N
sequence<0, 0>>{}; // Ys_in_Hs
constexpr auto c_block_dstr_encode =
detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encode, c_warp_dstr_encode);
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encode, c_warp_dstr_encode);
// Create distributions
constexpr auto a_block_distribution = make_static_tile_distribution(a_block_dstr_encode);
@@ -226,72 +203,71 @@ struct TileSweepingYRepetitionHgemmKernel
using BWarpDstr = decltype(make_static_tile_distribution(b_warp_dstr_encode));
using CWarpDstr = decltype(make_static_tile_distribution(c_warp_dstr_encode));
constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// // Create block-level windows
auto a_block_window = make_tile_window(
a_tensor,
make_tuple(number<kMPerBlock>{}, number<kWarpK>{}),
{m_block_base, 0},
a_block_distribution
);
// // Create block-level windows
auto a_block_window = make_tile_window(a_tensor,
make_tuple(number<kMPerBlock>{}, number<kWarpK>{}),
{m_block_base, 0},
a_block_distribution);
auto b_block_window = make_tile_window(
b_tensor,
make_tuple(number<kWarpK>{}, number<kNPerBlock>{}),
{0, n_block_base},
b_block_distribution
);
auto b_block_window = make_tile_window(b_tensor,
make_tuple(number<kWarpK>{}, number<kNPerBlock>{}),
{0, n_block_base},
b_block_distribution);
// // Create block-level accumulator tile
auto c_block_tile = make_static_distributed_tensor<AccDataType>(c_block_distribution);
set_tile(c_block_tile, AccDataType{0});
// // Main K-loop
const index_t num_k_loops = K / kWarpK;
for(index_t k_iter = 0; k_iter < num_k_loops; ++k_iter)
{
// Load entire block tiles (all iterations at once)
const auto a_block_tile = load_tile(a_block_window);
const auto b_block_tile = load_tile(b_block_window);
// Nested loops over tile iterations using Y-slicing (like 02_gemm)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// Extract A warp tensor for this M-iteration using Y-slicing
auto a_warp_tensor = make_static_distributed_tensor<ADataType>(
make_static_tile_distribution(a_warp_dstr_encode));
a_warp_tensor.get_thread_buffer() = a_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// Extract B warp tensor for this N-iteration using Y-slicing
auto b_warp_tensor = make_static_distributed_tensor<BDataType>(
make_static_tile_distribution(b_warp_dstr_encode));
b_warp_tensor.get_thread_buffer() = b_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<kIter, nIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// Extract C warp tensor for this M,N iteration
auto c_warp_tensor = make_static_distributed_tensor<AccDataType>(
make_static_tile_distribution(c_warp_dstr_encode));
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// Warp GEMM: C[mIter, nIter] += A[mIter, kIter] * B[nIter, kIter]
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// Write C warp tensor back to block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
@@ -302,7 +278,8 @@ struct TileSweepingYRepetitionHgemmKernel
});
// Move windows to next K chunk
if(k_iter < num_k_loops - 1) {
if(k_iter < num_k_loops - 1)
{
move_tile_window(a_block_window, {0, kWarpK});
move_tile_window(b_block_window, {kWarpK, 0});
}
@@ -314,62 +291,67 @@ struct TileSweepingYRepetitionHgemmKernel
// Add beta * C if needed
if(std::abs(beta) > 1e-6f)
{
auto c_block_window = make_tile_window(
c_tensor,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
{m_block_base, n_block_base},
c_block_distribution
);
auto c_block_window =
make_tile_window(c_tensor,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
{m_block_base, n_block_base},
c_block_distribution);
const auto c_input_block_tile = load_tile(c_block_window);
tile_elementwise_inout(
[beta](const auto& c_val, auto& acc_val) {
acc_val += beta * c_val;
},
c_input_block_tile, c_block_tile);
[beta](const auto& c_val, auto& acc_val) { acc_val += beta * c_val; },
c_input_block_tile,
c_block_tile);
}
// Store final result to D
auto d_block_window = make_tile_window(
d_tensor,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
{m_block_base, n_block_base},
c_block_distribution
);
auto d_block_window =
make_tile_window(d_tensor,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
{m_block_base, n_block_base},
c_block_distribution);
store_tile(d_block_window, c_block_tile);
}
};
// CPU reference for verification
template<typename InType, typename AccType>
template <typename InType, typename AccType>
void reference_gemm_mixed(const std::vector<InType>& a,
const std::vector<InType>& b,
const std::vector<AccType>& c,
std::vector<AccType>& d,
index_t M, index_t N, index_t K,
index_t lda, index_t ldb, index_t ldc, index_t ldd,
AccType alpha, AccType beta)
index_t M,
index_t N,
index_t K,
index_t lda,
index_t ldb,
index_t ldc,
index_t ldd,
AccType alpha,
AccType beta)
{
for(index_t n = 0; n < N; ++n) {
for(index_t m = 0; m < M; ++m) {
for(index_t n = 0; n < N; ++n)
{
for(index_t m = 0; m < M; ++m)
{
AccType sum = 0;
for(index_t k = 0; k < K; ++k) {
sum += static_cast<AccType>(a[m + k * lda]) *
static_cast<AccType>(b[k * ldb + n]);
for(index_t k = 0; k < K; ++k)
{
sum += static_cast<AccType>(a[m + k * lda]) * static_cast<AccType>(b[k * ldb + n]);
}
d[m + n * ldd] = alpha * sum + beta * c[m + n * ldc];
}
}
}
// Fill a vector with pseudo-random values uniformly drawn from
// [min_val, max_val]. Uses rand(), so the sequence is deterministic unless
// the caller seeds with srand(); adequate for generating test data.
// NOTE(review): the defaults (-1, 1) assume a signed T -- an unsigned T
// would wrap min_val; callers here pass explicit bounds.
template <typename T>
void fill_random(std::vector<T>& data, T min_val = -1, T max_val = 1)
{
    for(auto& val : data)
    {
        // rand()/RAND_MAX lies in [0, 1]; scale and shift into [min_val, max_val].
        val = static_cast<T>(min_val + (max_val - min_val) * static_cast<float>(rand()) / RAND_MAX);
    }
}
@@ -390,7 +372,7 @@ int main()
// constexpr index_t M = 128;
// constexpr index_t N = 128;
// constexpr index_t K = 64;
constexpr index_t M = 128;
constexpr index_t N = 128;
constexpr index_t K = 64;
@@ -404,13 +386,13 @@ int main()
using AccumType = float;
constexpr AccumType alpha = 2.0f;
constexpr AccumType beta = 1.5f;
constexpr AccumType beta = 1.5f;
std::cout << "Problem configuration:\n";
std::cout << " M×N×K: " << M << "×" << N << "×" << K << "\n";
std::cout << " Block output: 64×64 (2 warps × 2 iters × 16)\n";
std::cout << " Warp output: 32×32 (2 iters × 16 in each dim)\n";
std::cout << " Total blocks: " << (M/64) << "×" << (N/64) << "\n\n";
std::cout << " Total blocks: " << (M / 64) << "×" << (N / 64) << "\n\n";
// Host memory
std::vector<InputType> h_a(M * K);
@@ -427,7 +409,7 @@ int main()
// CPU reference
auto cpu_start = std::chrono::high_resolution_clock::now();
reference_gemm_mixed(h_a, h_b, h_c, h_d_ref, M, N, K, lda, ldb, ldc, ldd, alpha, beta);
auto cpu_end = std::chrono::high_resolution_clock::now();
auto cpu_end = std::chrono::high_resolution_clock::now();
double cpu_time_ms = std::chrono::duration<double, std::milli>(cpu_end - cpu_start).count();
// Device memory
@@ -443,7 +425,7 @@ int main()
// Launch kernel
constexpr index_t block_size = 256;
const index_t grid_size = (M / 64) * (N / 64); // 64×64 per block
const index_t grid_size = (M / 64) * (N / 64); // 64×64 per block
std::cout << "Launching kernel:\n";
std::cout << " Grid: " << grid_size << " blocks\n";
@@ -454,59 +436,80 @@ int main()
stream_config stream;
// Warmup
for(int i = 0; i < 5; ++i) {
launch_kernel(stream,
make_kernel<block_size>(
TileSweepingYRepetitionHgemmKernel<InputType, InputType, AccumType, AccumType>{},
dim3(grid_size),
dim3(block_size),
0,
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
M, N, K, lda, ldb, ldc, ldd, alpha, beta));
for(int i = 0; i < 5; ++i)
{
launch_kernel(
stream,
make_kernel<block_size>(
TileSweepingYRepetitionHgemmKernel<InputType, InputType, AccumType, AccumType>{},
dim3(grid_size),
dim3(block_size),
0,
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
M,
N,
K,
lda,
ldb,
ldc,
ldd,
alpha,
beta));
}
hip_check_error(hipDeviceSynchronize());
// Timed run
auto gpu_start = std::chrono::high_resolution_clock::now();
launch_kernel(stream,
make_kernel<block_size>(
TileSweepingYRepetitionHgemmKernel<InputType, InputType, AccumType, AccumType>{},
dim3(grid_size),
dim3(block_size),
0,
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
M, N, K, lda, ldb, ldc, ldd, alpha, beta));
launch_kernel(
stream,
make_kernel<block_size>(
TileSweepingYRepetitionHgemmKernel<InputType, InputType, AccumType, AccumType>{},
dim3(grid_size),
dim3(block_size),
0,
static_cast<const InputType*>(d_a.GetDeviceBuffer()),
static_cast<const InputType*>(d_b.GetDeviceBuffer()),
static_cast<const AccumType*>(d_c.GetDeviceBuffer()),
static_cast<AccumType*>(d_d.GetDeviceBuffer()),
M,
N,
K,
lda,
ldb,
ldc,
ldd,
alpha,
beta));
hip_check_error(hipDeviceSynchronize());
auto gpu_end = std::chrono::high_resolution_clock::now();
auto gpu_end = std::chrono::high_resolution_clock::now();
double gpu_time_ms = std::chrono::duration<double, std::milli>(gpu_end - gpu_start).count();
// Get result
d_d.FromDevice(h_d.data(), M * N * sizeof(AccumType));
// Verify
bool passed = true;
float max_error = 0;
bool passed = true;
float max_error = 0;
index_t error_count = 0;
for(index_t i = 0; i < M * N; ++i) {
for(index_t i = 0; i < M * N; ++i)
{
float error = std::abs(h_d[i] - h_d_ref[i]);
max_error = std::max(max_error, error);
if(error > 1e-2f) {
if(error_count < 5) {
max_error = std::max(max_error, error);
if(error > 1e-2f)
{
if(error_count < 5)
{
index_t m = i % M;
index_t n = i / M;
std::cout << "Error at [" << m << "," << n << "]: "
<< h_d[i] << " vs " << h_d_ref[i]
<< " (diff=" << error << ")\n";
std::cout << "Error at [" << m << "," << n << "]: " << h_d[i] << " vs "
<< h_d_ref[i] << " (diff=" << error << ")\n";
}
error_count++;
}
@@ -514,14 +517,15 @@ int main()
passed = (error_count == 0);
double gflops = 2.0 * M * N * K / 1e9;
double gflops = 2.0 * M * N * K / 1e9;
double gpu_tflops = gflops / (gpu_time_ms / 1000);
double cpu_gflops = gflops / (cpu_time_ms / 1000);
std::cout << "Results:\n";
std::cout << " Correctness: " << (passed ? "✓ PASSED" : "✗ FAILED") << "\n";
std::cout << " Max error: " << max_error << "\n";
if(!passed) std::cout << " Error count: " << error_count << "/" << M*N << "\n";
if(!passed)
std::cout << " Error count: " << error_count << "/" << M * N << "\n";
std::cout << "\n";
std::cout << "Performance:\n";