feat(tile_window): print content of tile window for easier debugging (#2827)

* feat(tile_window): add function to print content of tile window of static length, given a 2D range

* chore: make documentation less verbose
This commit is contained in:
Aviral Goel
2025-09-16 18:47:21 -04:00
committed by GitHub
parent 48e08c6429
commit 2723dbd332

View File

@@ -887,6 +887,58 @@ struct tile_window_with_static_lengths
this->window_lengths_ = window_lengths;
this->bottom_tensor_view_ = bottom_tensor_view;
}
/**
 * @brief Debug helper that dumps a rectangular sub-range of this tile window via printf.
 *
 * Emits a header line containing the label, the requested range and the window
 * origin, followed by one output line per row in which every element is written
 * as "label[i,j] = value". The header shows the upper bounds inclusively
 * (end_i - 1 and end_j - 1), while the parameters themselves are exclusive.
 *
 * @tparam DataType Element data type (e.g., fp16_t, float, bf8_t). Each value is
 *                  converted with static_cast<float> before being printed.
 * @param start_i First row to print (inclusive), as an offset from the window origin
 * @param end_i   One past the last row to print (exclusive)
 * @param start_j First column to print (inclusive), as an offset from the window origin
 * @param end_j   One past the last column to print (exclusive)
 * @param label   Optional tag prepended to the header and to every element
 *
 * @note Tested on fp16. Custom types may need adjustments.
 * @note NOTE(review): every caller that executes this function prints the whole
 *       range; presumably callers should guard it to a single thread -- confirm.
 * @example tile_window.template print_tile_window_range<fp16_t>(0, 4, 0, 8, "A");
 */
template <typename DataType>
CK_TILE_DEVICE void print_tile_window_range(index_t start_i,
                                            index_t end_i,
                                            index_t start_j,
                                            index_t end_j,
                                            const char* label = "") const
{
    // Each element is fetched through a two-wide thread buffer, but only slot 0
    // is ever printed.
    // NOTE(review): presumably this requests one element more than is needed per
    // fetch -- confirm that is intentional and safe at the last column.
    using ElementBuffer = thread_buffer<DataType, 2>;

    const auto& bottom_view = this->get_bottom_tensor_view();
    const auto origin       = this->get_window_origin();

    // Header: upper bounds are reported inclusively, hence the "- 1".
    printf("%s Window Range [%d:%d, %d:%d] (origin: %d, %d):\n",
           label,
           start_i,
           end_i - 1,
           start_j,
           end_j - 1,
           origin[0],
           origin[1]);

    for(index_t row = start_i; row < end_i; ++row)
    {
        for(index_t col = start_j; col < end_j; ++col)
        {
            // Translate the window-relative (row, col) into a coordinate of the
            // underlying tensor by offsetting with the window origin.
            auto element_coord = make_tensor_coordinate(
                bottom_view.get_tensor_descriptor(),
                make_tuple(origin[0] + row, origin[1] + col));

            auto fetched =
                bottom_view.template get_vectorized_elements<ElementBuffer>(element_coord, 0);

            // Only the first slot of the fetched buffer is the requested element.
            auto element = fetched.at(number<0>{});
            printf(" %s[%d,%d] = %f", label, row, col, static_cast<float>(element));
        }
        // Terminate the current row.
        printf("\n");
    }
    // Blank line separating this dump from subsequent output.
    printf("\n");
}
};
template <typename TensorView_, typename WindowLengths_>