mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
feat(tile_window): print content of tile window for easier debugging (#2827)
* feat(tile_window): add function to print content of tile window of static length, given a 2D range * chore: make documentation less verbose
This commit is contained in:
@@ -887,6 +887,58 @@ struct tile_window_with_static_lengths
|
||||
this->window_lengths_ = window_lengths;
|
||||
this->bottom_tensor_view_ = bottom_tensor_view;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Print tile window elements for debugging.
|
||||
*
|
||||
* @tparam DataType Element data type (e.g., fp16_t, float, bf8_t)
|
||||
* @param start_i Starting row (inclusive)
|
||||
* @param end_i Ending row (exclusive)
|
||||
* @param start_j Starting column (inclusive)
|
||||
* @param end_j Ending column (exclusive)
|
||||
* @param label Optional output label
|
||||
*
|
||||
* @note Tested on fp16. Custom types may need adjustments.
|
||||
* @example tile_window.template print_tile_window_range<fp16_t>(0, 4, 0, 8, "A");
|
||||
*/
|
||||
template <typename DataType>
|
||||
CK_TILE_DEVICE void print_tile_window_range(index_t start_i,
|
||||
index_t end_i,
|
||||
index_t start_j,
|
||||
index_t end_j,
|
||||
const char* label = "") const
|
||||
{
|
||||
const auto& tensor_view = this->get_bottom_tensor_view();
|
||||
const auto window_origin = this->get_window_origin();
|
||||
|
||||
printf("%s Window Range [%d:%d, %d:%d] (origin: %d, %d):\n",
|
||||
label,
|
||||
start_i,
|
||||
end_i - 1,
|
||||
start_j,
|
||||
end_j - 1,
|
||||
window_origin[0],
|
||||
window_origin[1]);
|
||||
|
||||
for(index_t i = start_i; i < end_i; i++)
|
||||
{
|
||||
for(index_t j = start_j; j < end_j; j++)
|
||||
{
|
||||
// Create coordinate for this element relative to window origin
|
||||
auto coord =
|
||||
make_tensor_coordinate(tensor_view.get_tensor_descriptor(),
|
||||
make_tuple(window_origin[0] + i, window_origin[1] + j));
|
||||
|
||||
// Get the element using thread buffer type directly
|
||||
using ThreadBuf = thread_buffer<DataType, 2>;
|
||||
auto buf = tensor_view.template get_vectorized_elements<ThreadBuf>(coord, 0);
|
||||
auto value = buf.at(number<0>{}); // Extract first element from thread buffer
|
||||
printf(" %s[%d,%d] = %f", label, i, j, static_cast<float>(value));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TensorView_, typename WindowLengths_>
|
||||
|
||||
Reference in New Issue
Block a user