mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK][CK TILE] Add has hot loop check for pipeline v1 (#4407)
## Motivation
Add a has-hot-loop check for pipeline v1 (v1 basic and v1 basic async). Enable more tests that have been fixed by this change.
## Technical Details
Previously the hot loop was executed without checking the number of loops (`num_loop`).
## Test Plan
test_grouped_convnd_fwd_tile
## Test Result
Passed
## Submission Checklist
- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

AICK-651 AICK-663
This commit is contained in:
@@ -53,7 +53,7 @@ void init_tensor_buffer_uniform_int(void* buf,
|
||||
using ck_type = factory::internal::DataTypeToCK<DT>::type;
|
||||
|
||||
// we might be asked to generate int values on fp data types that don't have the required
|
||||
// precision
|
||||
// precision. Check using >= and <= because == is not allowed for floats.
|
||||
if(static_cast<ck_type>(max_value - 1) <= static_cast<ck_type>(min_value) &&
|
||||
static_cast<ck_type>(max_value - 1) >= static_cast<ck_type>(min_value))
|
||||
{
|
||||
|
||||
@@ -16,4 +16,5 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_instance_library(device_grouped_conv_fwd_tile_instances ${GROUPED_CONV_FWD_TILE})
|
||||
target_include_directories(device_grouped_conv_fwd_tile_instances PRIVATE
|
||||
"${PROJECT_SOURCE_DIR}/experimental/builder/test/utils")
|
||||
target_compile_options(device_grouped_conv_fwd_tile_instances PRIVATE -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=0)
|
||||
endif()
|
||||
|
||||
@@ -58,9 +58,10 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pa
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 2, 2, 2, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1>
|
||||
# LargeTensor is temporarily disabled due to failures
|
||||
# DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
# DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 2, 2, 2, 1, 1>
|
||||
# DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32>
|
||||
|
||||
@@ -58,9 +58,10 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pa
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 2, 2, 2, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1>
|
||||
# LargeTensor is temporarily disabled due to failures
|
||||
# DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
# DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 2, 2, 2, 1, 1>
|
||||
# DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32>
|
||||
|
||||
@@ -58,8 +58,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 16, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 16, Default, 32, 32, 4, 2, 4, 4, 4, 1, 1>
|
||||
# LargeTensor is temporarily disabled due to failures
|
||||
# DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 16, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
# DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 16, Default, 32, 32, 4, 2, 4, 4, 4, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32>
|
||||
|
||||
@@ -58,9 +58,10 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pa
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 2, 2, 2, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1>
|
||||
# LargeTensor is temporarily disabled due to failures
|
||||
# DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
# DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 2, 2, 2, 1, 1>
|
||||
# DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32>
|
||||
|
||||
@@ -58,9 +58,10 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pa
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 2, 2, 2, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1>
|
||||
# LargeTensor is temporarily disabled due to failures
|
||||
# DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
# DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 2, 2, 2, 1, 1>
|
||||
# DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32>
|
||||
|
||||
@@ -58,8 +58,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 1, 2, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 16, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 16, Default, 32, 32, 4, 2, 4, 4, 4, 1, 1>
|
||||
# LargeTensor is temporarily disabled due to failures
|
||||
# DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 16, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
# DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<256, 256, 128, 16, Default, 32, 32, 4, 2, 4, 4, 4, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32>
|
||||
|
||||
@@ -7,7 +7,8 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Stride1Pad
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
# LargeTensor is temporarily disabled due to failures
|
||||
# DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32>
|
||||
|
||||
@@ -7,7 +7,8 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Stride1Pad
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
# LargeTensor is temporarily disabled due to failures
|
||||
# DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32>
|
||||
|
||||
@@ -7,7 +7,8 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Stride1Pad
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
# LargeTensor is temporarily disabled due to failures
|
||||
# DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32>
|
||||
|
||||
@@ -7,7 +7,8 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Stride1Pad
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
# LargeTensor is temporarily disabled due to failures
|
||||
# DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32>
|
||||
|
||||
@@ -7,7 +7,8 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Stride1Pad
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
# LargeTensor is temporarily disabled due to failures
|
||||
# DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32>
|
||||
@@ -39,11 +40,11 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
|
||||
@@ -7,7 +7,8 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 32, 64, 32, Filter1x1Stride1Pad
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Filter1x1Stride1Pad0, 16, 16, 2, 2, 1, 2, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
# LargeTensor is temporarily disabled due to failures
|
||||
# DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<64, 64, 64, 32, Default, 32, 32, 2, 2, 1, 1, 1, 1, 1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 8>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 16>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<64, 64, 16, 16, Default, 16, 16, 4, 1, 4, 1, 1, 1, 1, 32>
|
||||
|
||||
@@ -85,6 +85,13 @@ __device__ inline auto amd_wave_read_first_lane(const Object& obj)
|
||||
return out;
|
||||
}
|
||||
|
||||
// Host-only overload of amd_wave_read_first_lane: takes `v` by value and
// returns it unchanged.
// NOTE(review): presumably this exists so that host/device helpers calling
// amd_wave_read_first_lane also compile for the host -- confirm against callers.
|
||||
template <typename T>
|
||||
__host__ inline T amd_wave_read_first_lane(T v)
|
||||
{
|
||||
return v;
|
||||
}
|
||||
|
||||
// 128 bit SGPRs to supply buffer resource in buffer instructions
|
||||
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
|
||||
struct __attribute__((packed)) buffer_resource
|
||||
|
||||
@@ -81,6 +81,13 @@ __device__ inline auto amd_wave_read_first_lane(const Object& obj)
|
||||
return out;
|
||||
}
|
||||
|
||||
// Host-only overload of amd_wave_read_first_lane: takes `v` by value and
// returns it unchanged.
// NOTE(review): presumably this exists so that host/device helpers calling
// amd_wave_read_first_lane also compile for the host -- confirm against callers.
|
||||
template <typename T>
|
||||
__host__ inline T amd_wave_read_first_lane(T v)
|
||||
{
|
||||
return v;
|
||||
}
|
||||
|
||||
// 128 bit SGPRs to supply buffer resource in buffer instructions
|
||||
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
|
||||
struct __attribute__((packed)) buffer_resource
|
||||
|
||||
@@ -44,15 +44,20 @@ struct BaseGemmPipelineAgBgCrCompAsync
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
// Use amd_wave_read_first_lane to avoid higher resource usage.
|
||||
// It forces to store these values in SGPR.
|
||||
// Compiler cannot deduce if one path is used for all threads
|
||||
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
|
||||
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
|
||||
// Handle all the valid cases.
|
||||
if(has_hot_loop)
|
||||
if(has_hot_loop_first_lane)
|
||||
{
|
||||
if(tail_number == TailNumber::Three)
|
||||
if(tail_number_first_lane == TailNumber::Three)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Three>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Two)
|
||||
else if(tail_number_first_lane == TailNumber::Two)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Two>{});
|
||||
@@ -60,12 +65,12 @@ struct BaseGemmPipelineAgBgCrCompAsync
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_number == TailNumber::Three)
|
||||
if(tail_number_first_lane == TailNumber::Three)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Three>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Two)
|
||||
else if(tail_number_first_lane == TailNumber::Two)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Two>{});
|
||||
@@ -430,7 +435,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
Base::GlobalPrefetchAsync(
|
||||
b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);
|
||||
|
||||
if(HasHotLoop)
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
// we have had 3 global prefetches so far, indexed (0, 1, 2).
|
||||
index_t i_global_read = amd_wave_read_first_lane(3);
|
||||
|
||||
@@ -46,6 +46,12 @@ struct BaseGemmPipelineAgBgCrCompV3
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
// Use amd_wave_read_first_lane to avoid higher resource usage.
|
||||
// It forces to store these values in SGPR.
|
||||
// Compiler cannot deduce if one path is used for all threads
|
||||
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
|
||||
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
|
||||
|
||||
constexpr auto scenarios = []() {
|
||||
if constexpr(Problem::BlockGemmShape::NumWarps == 8)
|
||||
return std::array<std::pair<bool, ck_tile::TailNumber>, 5>{
|
||||
@@ -62,7 +68,8 @@ struct BaseGemmPipelineAgBgCrCompV3
|
||||
std::make_pair(false, TailNumber::Even),
|
||||
};
|
||||
}();
|
||||
if(has_hot_loop == scenarios[I].first && tail_number == scenarios[I].second)
|
||||
if(has_hot_loop_first_lane == scenarios[I].first &&
|
||||
tail_number_first_lane == scenarios[I].second)
|
||||
return run_func(bool_constant<scenarios[I].first>{}, constant<scenarios[I].second>{});
|
||||
else if constexpr(I + 1 < scenarios.size())
|
||||
return TailHandler<I + 1>(run_func, has_hot_loop, tail_number);
|
||||
|
||||
@@ -47,15 +47,20 @@ struct BaseGemmPipelineAgBgCrCompV4
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
// Use amd_wave_read_first_lane to avoid higher resource usage.
|
||||
// It forces to store these values in SGPR.
|
||||
// Compiler cannot deduce if one path is used for all threads
|
||||
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
|
||||
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
|
||||
// Handle all the valid cases.
|
||||
if(has_hot_loop)
|
||||
if(has_hot_loop_first_lane)
|
||||
{
|
||||
if(tail_number == TailNumber::Three)
|
||||
if(tail_number_first_lane == TailNumber::Three)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Three>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Two)
|
||||
else if(tail_number_first_lane == TailNumber::Two)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Two>{});
|
||||
@@ -63,12 +68,12 @@ struct BaseGemmPipelineAgBgCrCompV4
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_number == TailNumber::Three)
|
||||
if(tail_number_first_lane == TailNumber::Three)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Three>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Two)
|
||||
else if(tail_number_first_lane == TailNumber::Two)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Two>{});
|
||||
|
||||
@@ -43,15 +43,20 @@ struct BaseGemmPipelineAgBgCrCompV6
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
// Use amd_wave_read_first_lane to avoid higher resource usage.
|
||||
// It forces to store these values in SGPR.
|
||||
// Compiler cannot deduce if one path is used for all threads
|
||||
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
|
||||
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
|
||||
// Handle all the valid cases.
|
||||
if(has_hot_loop)
|
||||
if(has_hot_loop_first_lane)
|
||||
{
|
||||
if(tail_number == TailNumber::Odd)
|
||||
if(tail_number_first_lane == TailNumber::Odd)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Even)
|
||||
else if(tail_number_first_lane == TailNumber::Even)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Even>{});
|
||||
@@ -59,12 +64,12 @@ struct BaseGemmPipelineAgBgCrCompV6
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_number == TailNumber::Odd)
|
||||
if(tail_number_first_lane == TailNumber::Odd)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Even)
|
||||
else if(tail_number_first_lane == TailNumber::Even)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Even>{});
|
||||
@@ -567,7 +572,7 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6<Problem>
|
||||
BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v);
|
||||
BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v);
|
||||
|
||||
if(HasHotLoop)
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
|
||||
@@ -93,9 +93,14 @@ struct BaseGemmPipelineAgBgCrMem
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
// Use amd_wave_read_first_lane to avoid higher resource usage.
|
||||
// It forces to store these values in SGPR.
|
||||
// Compiler cannot deduce if one path is used for all threads
|
||||
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
|
||||
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
|
||||
// Wrap the hot_loop dispatch first.
|
||||
auto tail_dispatch = [&](auto tail_num_constant) {
|
||||
if(has_hot_loop)
|
||||
if(has_hot_loop_first_lane)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, tail_num_constant);
|
||||
}
|
||||
@@ -106,7 +111,7 @@ struct BaseGemmPipelineAgBgCrMem
|
||||
};
|
||||
|
||||
#define CHECK_TAIL_NUMBER(TAIL_NUMBER, PREFETCH_VALUE) \
|
||||
else if(tail_number == TailNumber::TAIL_NUMBER) \
|
||||
else if(tail_number_first_lane == TailNumber::TAIL_NUMBER) \
|
||||
{ \
|
||||
if constexpr(PrefetchStages > PREFETCH_VALUE) \
|
||||
{ \
|
||||
@@ -114,11 +119,11 @@ struct BaseGemmPipelineAgBgCrMem
|
||||
} \
|
||||
}
|
||||
// Handle all the valid cases.
|
||||
if(tail_number == TailNumber::One)
|
||||
if(tail_number_first_lane == TailNumber::One)
|
||||
{
|
||||
return tail_dispatch(integral_constant<TailNumber, TailNumber::One>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Full)
|
||||
else if(tail_number_first_lane == TailNumber::Full)
|
||||
{
|
||||
return tail_dispatch(integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ namespace ck_tile {
|
||||
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompAsyncDefaultPolicy>
|
||||
struct GemmPipelineAGmemBGmemCRegAsyncV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAGmemBGmemCRegV1<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
@@ -117,7 +118,8 @@ struct GemmPipelineAGmemBGmemCRegAsyncV1 : public BaseGemmPipelineAGmemBGmemCReg
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
template <bool HasHotLoop,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
@@ -268,25 +270,28 @@ struct GemmPipelineAGmemBGmemCRegAsyncV1 : public BaseGemmPipelineAGmemBGmemCReg
|
||||
|
||||
block_sync_lds_direct_load();
|
||||
|
||||
index_t iCounter = num_loop - 1;
|
||||
while(iCounter > 0)
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
Base::LocalPrefetch(a_block_tile, a_lds_ld_window, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile, b_lds_ld_window, is_b_load_tr_v);
|
||||
index_t iCounter = num_loop - 1;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
Base::LocalPrefetch(a_block_tile, a_lds_ld_window, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile, b_lds_ld_window, is_b_load_tr_v);
|
||||
|
||||
block_sync_lds();
|
||||
block_sync_lds();
|
||||
|
||||
Base::GlobalPrefetchAsync(
|
||||
a_copy_lds_window, a_tile_windows, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetchAsync(
|
||||
b_copy_lds_window, b_tile_windows, b_dram_tile_window_step);
|
||||
Base::GlobalPrefetchAsync(
|
||||
a_copy_lds_window, a_tile_windows, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetchAsync(
|
||||
b_copy_lds_window, b_tile_windows, b_dram_tile_window_step);
|
||||
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_block_tile, b_block_tile);
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_block_tile, b_block_tile);
|
||||
|
||||
block_sync_lds_direct_load();
|
||||
block_sync_lds_direct_load();
|
||||
|
||||
iCounter--;
|
||||
iCounter--;
|
||||
}
|
||||
}
|
||||
|
||||
// tail
|
||||
@@ -311,12 +316,18 @@ struct GemmPipelineAGmemBGmemCRegAsyncV1 : public BaseGemmPipelineAGmemBGmemCReg
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.operator()(a_dram_block_window_tmp,
|
||||
element_wise::PassThrough{},
|
||||
b_dram_block_window_tmp,
|
||||
element_wise::PassThrough{},
|
||||
num_loop,
|
||||
p_smem);
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto RunPipeline = [&](auto hot_loop_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop>(
|
||||
a_dram_block_window_tmp,
|
||||
element_wise::PassThrough{},
|
||||
b_dram_block_window_tmp,
|
||||
element_wise::PassThrough{},
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
@@ -349,12 +360,17 @@ struct GemmPipelineAGmemBGmemCRegAsyncV1 : public BaseGemmPipelineAGmemBGmemCReg
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.operator()(a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto RunPipeline = [&](auto hot_loop_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop>(a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -19,7 +19,10 @@ struct BaseGemmPipelineAGmemBGmemCRegV1
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
|
||||
{
|
||||
@@ -27,9 +30,21 @@ struct BaseGemmPipelineAGmemBGmemCRegV1
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool has_hot_loop)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
// Use amd_wave_read_first_lane to avoid higher resource usage.
|
||||
// It forces these values to be stored in SGPRs.
|
||||
// The compiler cannot deduce whether all threads take the same path.
|
||||
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
|
||||
|
||||
if(has_hot_loop_first_lane)
|
||||
{
|
||||
return run_func(ck_tile::bool_constant<true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_func(ck_tile::bool_constant<false>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -39,6 +54,7 @@ struct BaseGemmPipelineAGmemBGmemCRegV1
|
||||
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
|
||||
struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAGmemBGmemCRegV1<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
@@ -137,7 +153,8 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
template <bool HasHotLoop,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
@@ -216,6 +233,14 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
auto&& [bs_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
|
||||
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
|
||||
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
|
||||
@@ -238,10 +263,10 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
// move to 1
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a
|
||||
// tuple as input.
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(as_copy_dram_window, a_dram_tile_window_step);
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a
|
||||
// tuple as input.
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
@@ -273,54 +298,57 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
}
|
||||
}
|
||||
|
||||
index_t iCounter = num_loop - 1;
|
||||
while(iCounter > 0)
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
// global read i + 1
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
block_sync_lds();
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_a_col_major)
|
||||
index_t iCounter = num_loop - 1;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res);
|
||||
store_tile(a_copy_lds_window, a_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
// global read i + 1
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res);
|
||||
store_tile(b_copy_lds_window, b_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
iCounter--;
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(as_copy_dram_window, a_dram_tile_window_step);
|
||||
move_tile_window(bs_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res);
|
||||
store_tile(a_copy_lds_window, a_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res);
|
||||
store_tile(b_copy_lds_window, b_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
}
|
||||
|
||||
// tail
|
||||
@@ -340,7 +368,8 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
template <bool HasHotLoop,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
@@ -476,50 +505,53 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
}
|
||||
}
|
||||
|
||||
index_t iCounter = num_loop - 1;
|
||||
while(iCounter > 0)
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
// global read i + 1
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
block_sync_lds();
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_a_col_major)
|
||||
index_t iCounter = num_loop - 1;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res);
|
||||
store_tile(a_copy_lds_window, a_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
// global read i + 1
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
block_sync_lds();
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res);
|
||||
store_tile(b_copy_lds_window, b_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
iCounter--;
|
||||
// move to i + 2
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res);
|
||||
store_tile(a_copy_lds_window, a_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res);
|
||||
store_tile(b_copy_lds_window, b_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
}
|
||||
|
||||
// tail
|
||||
@@ -543,13 +575,18 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](auto& e, const ADataType & a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](auto& e, const BDataType & b) { e = b; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto RunPipeline = [&](auto hot_loop_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop>(
|
||||
a_dram_block_window_tmp,
|
||||
element_wise::PassThrough{},
|
||||
b_dram_block_window_tmp,
|
||||
element_wise::PassThrough{},
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
@@ -582,12 +619,17 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.operator()(a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto RunPipeline = [&](auto hot_loop_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop>(a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1135,6 +1135,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
// Disable Async for archs other than gfx950
|
||||
if constexpr(GemmPipeline_::Async)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
|
||||
@@ -906,6 +906,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// Disable Async for archs other than gfx950
|
||||
if constexpr(GemmPipeline_::Async)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
|
||||
@@ -1149,6 +1149,7 @@ struct GroupedConvolutionForwardKernel
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// Disable Async for archs other than gfx950
|
||||
if constexpr(GemmPipeline_::Async)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
|
||||
#include "../../experimental/builder/test/utils/conv_algorithm_type_utils.hpp"
|
||||
#include "grouped_convolution_signatures.hpp"
|
||||
#include "common.hpp"
|
||||
#include "ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp"
|
||||
|
||||
#include "ck_tile/builder/testing/filter_extent.hpp"
|
||||
@@ -96,6 +97,29 @@ auto parse_conv_args(int arg_idx, char* const argv[])
|
||||
return args;
|
||||
}
|
||||
|
||||
template <auto SIGNATURE>
|
||||
void run_cpu_validation(const ckt::Args<SIGNATURE>& args,
|
||||
const ckt::Outputs<SIGNATURE>& outputs,
|
||||
const ckt::Outputs<SIGNATURE>& reference)
|
||||
{
|
||||
using DataType =
|
||||
std::conditional_t<SIGNATURE.data_type == ckb::DataType::FP32,
|
||||
float,
|
||||
std::conditional_t<SIGNATURE.data_type == ckb::DataType::FP16,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bfloat16_t>>;
|
||||
const auto conv_param = args.to_ck_tile_conv_param();
|
||||
|
||||
const std::size_t output_bytes_num = conv_param.template GetOutputByte<DataType>();
|
||||
std::vector<DataType> out(output_bytes_num / sizeof(DataType));
|
||||
std::vector<DataType> ref(output_bytes_num / sizeof(DataType));
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpy(&ref.data()[0], reference.output, output_bytes_num, hipMemcpyDeviceToHost));
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpy(&out.data()[0], outputs.output, output_bytes_num, hipMemcpyDeviceToHost));
|
||||
ck_tile::check_err(out, ref, "Error: Incorrect results!");
|
||||
}
|
||||
|
||||
/// @brief `run_grouped_conv_forward_tile_algs()` runs all grouped conv fwd instances.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
@@ -114,39 +138,19 @@ run_grouped_conv_forward_tile_algs(const ckt::Args<SIGNATURE>& args,
|
||||
float avg_time;
|
||||
bool valid = true;
|
||||
|
||||
auto reference = ckt::alloc_outputs(args);
|
||||
using ReferenceInstance =
|
||||
typename ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
|
||||
auto ref_conv = ReferenceInstance{};
|
||||
[[maybe_unused]] auto ref_result = ckt::run(ref_conv, args, inputs, reference.get());
|
||||
|
||||
#if ENABLE_BUILDER_VALIDATE == 0
|
||||
using DataType =
|
||||
std::conditional_t<SIGNATURE.data_type == ckb::DataType::FP32,
|
||||
float,
|
||||
std::conditional_t<SIGNATURE.data_type == ckb::DataType::FP16,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bfloat16_t>>;
|
||||
const auto conv_param = args.to_ck_tile_conv_param();
|
||||
|
||||
const std::size_t output_bytes_num = conv_param.template GetOutputByte<DataType>();
|
||||
std::vector<DataType> out(output_bytes_num / sizeof(DataType));
|
||||
std::vector<DataType> ref(output_bytes_num / sizeof(DataType));
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpy(&ref.data()[0], reference.get().output, output_bytes_num, hipMemcpyDeviceToHost));
|
||||
|
||||
const ck_tile::index_t GemmK = std::accumulate(conv_param.filter_spatial_lengths_.cbegin(),
|
||||
conv_param.filter_spatial_lengths_.cend(),
|
||||
1,
|
||||
std::multiplies<ck_tile::index_t>()) *
|
||||
conv_param.C_;
|
||||
float max_accumulated_value = *std::max_element(ref.begin(), ref.end());
|
||||
const auto rtol = ck_tile::get_relative_threshold<DataType, DataType, float>(GemmK);
|
||||
const auto atol =
|
||||
ck_tile::get_absolute_threshold<DataType, DataType, float>(max_accumulated_value, GemmK);
|
||||
#endif
|
||||
|
||||
[[maybe_unused]] auto run_alg = [&](auto&& run_alg_func) {
|
||||
auto reference = ckt::alloc_outputs(args);
|
||||
using ReferenceInstance =
|
||||
typename ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
|
||||
auto ref_conv = ReferenceInstance{};
|
||||
auto ref_result = ckt::run(ref_conv, args, inputs, reference.get());
|
||||
auto run_alg = [&](auto&& run_alg_func) {
|
||||
std::tie(is_supported, avg_time, op_name) = run_alg_func(args, inputs, outputs, s_conf);
|
||||
if(is_supported)
|
||||
{
|
||||
@@ -155,20 +159,27 @@ run_grouped_conv_forward_tile_algs(const ckt::Args<SIGNATURE>& args,
|
||||
std::cout << "Perf: " << std::setw(10) << avg_time << " ms," << " " << op_name
|
||||
<< std::endl;
|
||||
|
||||
#if ENABLE_BUILDER_VALIDATE
|
||||
const auto errors = ckt::validate(args, outputs, reference.get()).get_errors();
|
||||
for(const auto& error : errors)
|
||||
ckt::ValidationReport report;
|
||||
ckt::Outputs<SIGNATURE>::reflect(
|
||||
args,
|
||||
[&](std::string_view name, const auto& desc, void* ckt::Outputs<SIGNATURE>::*ptr) {
|
||||
report.check(name,
|
||||
desc,
|
||||
outputs.*ptr,
|
||||
reference.get().*ptr,
|
||||
ck::profiler::get_rtol<DataType>(),
|
||||
ck::profiler::get_atol<DataType>());
|
||||
});
|
||||
|
||||
for(const auto& error : report.get_errors())
|
||||
{
|
||||
valid = false;
|
||||
std::cout << "Number of incorrect values: " << error.wrong_elements
|
||||
<< " Is all zero:" << error.is_all_zero()
|
||||
<< " max err: " << error.max_error << std::endl;
|
||||
// Also run the CPU verification to get the actual values
|
||||
run_cpu_validation<SIGNATURE>(args, outputs, reference.get());
|
||||
}
|
||||
#else
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpy(&out.data()[0], outputs.output, output_bytes_num, hipMemcpyDeviceToHost));
|
||||
valid = ck_tile::check_err(out, ref, "Error: Incorrect results!");
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -11,11 +11,6 @@
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "profiler/grouped_convolution_forward_tile_algs.hpp"
|
||||
|
||||
// TODO: Remove limitation of conv fwd gpu reference which does not support right pad
|
||||
#define CK_CONV_FWD_REF_SKIP_RIGHT_PAD_CASES 1
|
||||
// TODO: Remove this limitation after gpu reference fix
|
||||
#define ENABLE_BHALF_GROUPED_CONV_FWD_TESTS 0
|
||||
|
||||
static ck::index_t args_mask = 0xffff;
|
||||
static ck::index_t instance_index = -1;
|
||||
|
||||
@@ -103,17 +98,6 @@ class TestGroupedConvndFwdTile : public ::testing::Test
|
||||
const std::vector<std::size_t>& input_left_pads,
|
||||
const std::vector<std::size_t>& input_right_pads)
|
||||
{
|
||||
#if CK_CONV_FWD_REF_SKIP_RIGHT_PAD_CASES
|
||||
bool without_right_pad = true;
|
||||
for(const std::size_t& right_pad : input_right_pads)
|
||||
{
|
||||
without_right_pad &= right_pad == 0;
|
||||
}
|
||||
if(!without_right_pad)
|
||||
{
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
ckt::Args<SIGNATURE> args = {
|
||||
.lengths =
|
||||
{
|
||||
@@ -155,12 +139,13 @@ using KernelTypes2d = ::testing::Types<SignatureDetails<2,
|
||||
ckb::DataType::FP32,
|
||||
ckb::TensorLayout::NHWGC,
|
||||
ckb::TensorLayout::GKYXC,
|
||||
ckb::TensorLayout::NHWGK>,
|
||||
SignatureDetails<2,
|
||||
ckb::DataType::BF16,
|
||||
ckb::DataType::FP32,
|
||||
ckb::TensorLayout::NHWGC,
|
||||
ckb::TensorLayout::GKYXC,
|
||||
ckb::TensorLayout::NHWGK>>;
|
||||
#if ENABLE_BHALF_GROUPED_CONV_FWD_TESTS
|
||||
SignatureDetails < 2, ckb::DataType::BF16, ckb::DataType::FP32, ckb::TensorLayout::NHWGC,
|
||||
ckb::TensorLayout::GKYXC, ckb::TensorLayout::NHWGK >>
|
||||
;
|
||||
#endif
|
||||
|
||||
using KernelTypes3d = ::testing::Types<SignatureDetails<3,
|
||||
ckb::DataType::FP32,
|
||||
@@ -173,12 +158,13 @@ using KernelTypes3d = ::testing::Types<SignatureDetails<3,
|
||||
ckb::DataType::FP32,
|
||||
ckb::TensorLayout::NDHWGC,
|
||||
ckb::TensorLayout::GKZYXC,
|
||||
ckb::TensorLayout::NDHWGK>,
|
||||
SignatureDetails<3,
|
||||
ckb::DataType::BF16,
|
||||
ckb::DataType::FP32,
|
||||
ckb::TensorLayout::NDHWGC,
|
||||
ckb::TensorLayout::GKZYXC,
|
||||
ckb::TensorLayout::NDHWGK>>;
|
||||
#if ENABLE_BHALF_GROUPED_CONV_FWD_TESTS
|
||||
SignatureDetails < 3, ckb::DataType::BF16, ckb::DataType::FP32, ckb::TensorLayout::NDHWGC,
|
||||
ckb::TensorLayout::GKZYXC, ckb::TensorLayout::NDHWGK >>
|
||||
;
|
||||
#endif
|
||||
|
||||
template <typename SignatureDetailsType>
|
||||
class TestGroupedConvndFwdTile2d : public TestGroupedConvndFwdTile<SignatureDetailsType>
|
||||
|
||||
Reference in New Issue
Block a user