[CK] CK Tile grouped convolution direct load (#4406)

## Motivation

CK Tile grouped convolution forward direct load support.

## Technical Details

Basic pipeline for direct load and new instances for forward for v1 and
v4 pipelines.

## Test Plan

test_grouped_convnd_fwd_tile

## Test Result

CI pending

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
AICK-130
This commit is contained in:
Bartłomiej Kocot
2026-02-09 22:08:57 +01:00
committed by GitHub
parent fdb1a08e6f
commit 23b32f1ff8
29 changed files with 739 additions and 56 deletions

View File

@@ -58,6 +58,8 @@ consteval BlockGemmSpec SetBlockGemm()
case PipelineVersion::V3: version = ck::BlockGemmPipelineVersion::v3; break;
case PipelineVersion::V4: version = ck::BlockGemmPipelineVersion::v4; break;
case PipelineVersion::V5: version = ck::BlockGemmPipelineVersion::v5; break;
case PipelineVersion::ASYNC_V1: throw "PipelineVersion::ASYNC can be used only for CK TILE.";
case PipelineVersion::ASYNC_V4: throw "PipelineVersion::ASYNC can be used only for CK TILE.";
case PipelineVersion::V6: throw "PipelineVersion::V6 is supported only for CK Tile.";
case PipelineVersion::WEIGHT_ONLY:
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM.";
@@ -93,6 +95,8 @@ consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
case PipelineVersion::V3: throw "PipelineVersion::V3 is used only for stream-K.";
case PipelineVersion::V4: return ck_pipeline::v4;
case PipelineVersion::V5: throw "PipelineVersion::V5 cannot be used for gridwise GEMM.";
case PipelineVersion::ASYNC_V1: throw "PipelineVersion::ASYNC can be used only for CK TILE.";
case PipelineVersion::ASYNC_V4: throw "PipelineVersion::ASYNC can be used only for CK TILE.";
case PipelineVersion::V6: throw "PipelineVersion::V6 can be used only for CK TILE.";
case PipelineVersion::WEIGHT_ONLY: return ck_pipeline::weight_only;
default: throw "Unknown GridwiseGemmPipelineVersion";
@@ -139,6 +143,8 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
case PipelineVersion::V3: return ck_pipeline::v3;
case PipelineVersion::V4: return ck_pipeline::v4;
case PipelineVersion::V5: return ck_pipeline::v5;
case PipelineVersion::ASYNC_V1: throw "PipelineVersion::ASYNC can be used only for CK TILE.";
case PipelineVersion::ASYNC_V4: throw "PipelineVersion::ASYNC can be used only for CK TILE.";
case PipelineVersion::V6: throw "PipelineVersion::V6 is supported only for CK Tile.";
case PipelineVersion::WEIGHT_ONLY:
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version.";

View File

@@ -98,6 +98,20 @@ struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V6>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6<PipelineProblem>;
};
template <>
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_ASYNC>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompAsync<PipelineProblem>;
};
template <>
struct TilePipelineType<ck_tile::GemmPipeline::BASIC_ASYNC_V1>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegAsyncV1<PipelineProblem>;
};
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion()
{
@@ -111,6 +125,8 @@ consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion()
case PipelineVersion::V4: return ck_tile_pipeline::COMPUTE_V4;
case PipelineVersion::V5: return ck_tile_pipeline::COMPUTE_V5;
case PipelineVersion::V6: return ck_tile_pipeline::COMPUTE_V6;
case PipelineVersion::ASYNC_V1: return ck_tile_pipeline::BASIC_ASYNC_V1;
case PipelineVersion::ASYNC_V4: return ck_tile_pipeline::COMPUTE_ASYNC;
case PipelineVersion::WEIGHT_ONLY:
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version.";
default: throw "Unknown block GEMM PipelineVersion";

View File

@@ -54,14 +54,15 @@ void init_tensor_buffer_uniform_int(void* buf,
// we might be asked to generate int values on fp data types that don't have the required
// precision
if(static_cast<ck_type>(max_value - 1) == static_cast<ck_type>(min_value))
if(static_cast<ck_type>(max_value - 1) <= static_cast<ck_type>(min_value) &&
static_cast<ck_type>(max_value - 1) >= static_cast<ck_type>(min_value))
{
throw std::runtime_error("Error while filling device tensor with random integer data: "
"insufficient precision in specified range");
}
size_t packed_size = ck::packed_size_v<ck_type>;
fill_tensor_uniform_rand_int_values<<<256, 256>>>(
static_cast<ck_type>(buf), min_value, max_value, (size * packed_size) / sizeof(ck_type));
static_cast<ck_type*>(buf), min_value, max_value, (size * packed_size) / sizeof(ck_type));
}
/// @brief Initialize tensor data with a uniform float distribution

View File

@@ -158,6 +158,8 @@ enum class PipelineVersion
V4,
V5,
V6,
ASYNC_V1,
ASYNC_V4,
WEIGHT_ONLY
};
@@ -330,6 +332,8 @@ inline std::string_view to_string(PipelineVersion ver)
case V4: return "V4";
case V5: return "V5";
case V6: return "V6";
case ASYNC_V1: return "ASYNC_V1";
case ASYNC_V4: return "ASYNC_V4";
case WEIGHT_ONLY: return "WEIGHT_ONLY";
default: return "Unknown";
}

View File

@@ -234,4 +234,30 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Stride
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 64, 64, Default, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 128, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 64, 64, Filter1x1Stride1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>

View File

@@ -225,4 +225,30 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Stride
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 64, 64, Default, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 128, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 64, 64, Filter1x1Stride1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>

View File

@@ -234,4 +234,30 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Stride
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 64, 64, Default, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 128, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 64, 64, Filter1x1Stride1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>

View File

@@ -225,4 +225,30 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Stride
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 64, 64, Default, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 128, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 64, 64, Filter1x1Stride1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>

View File

@@ -38,4 +38,12 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 3
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>

View File

@@ -38,4 +38,12 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 3
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>

View File

@@ -38,4 +38,12 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 3
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>

View File

@@ -38,4 +38,12 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 3
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>

View File

@@ -60,7 +60,6 @@ class ConvInstanceTemplateParams:
def get_block_gemm_desc(self):
double_smem_buffer = "true" if self.double_smem_buffer else "false"
pipeline_version = self.pipeline_version[-1:]
scheduler = (
"INTRAWAVE" if self.scheduler.find("Intrawave") != -1 else "INTERWAVE"
)
@@ -69,7 +68,7 @@ class ConvInstanceTemplateParams:
.warp_tile = {{.m = {self.warp_tile[0]}, .n = {self.warp_tile[1]}, .k = {self.warp_tile[2]}}},
.double_smem_buffer = {double_smem_buffer},
.num_wave_groups = {self.num_wave_groups},
.pipeline_version = ckb::PipelineVersion::V{pipeline_version},
.pipeline_version = ckb::PipelineVersion::{self.pipeline_version},
.scheduler = ckb::PipelineScheduler::{scheduler}}}"""
def get_block_transfer(self):
@@ -180,6 +179,16 @@ def parse_fwd_instances(instances, problem_name):
pipeline_version = (
"v1" if instance.find("BlkGemmPipelineVersion") == -1 else args[15]
)
# Replace pipeline if Direct Load
if instance.find("DirectLoad") != -1:
if instance.find("BlkGemmPipelineVersion: v1") != -1:
pipeline_version = "ASYNC_V1"
elif instance.find("BlkGemmPipelineVersion: v4") != -1:
pipeline_version = "ASYNC_V4"
else:
raise RuntimeError("not supported pipeline for direct load")
else:
pipeline_version = f"""V{pipeline_version[-1:]}"""
m_warp = int(m_per_block / (m_per_xdl * m_xdl_per_wave))
n_warp = int(n_per_block / (n_per_xdl * n_xdl_per_wave))

View File

@@ -51,6 +51,7 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_async_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"

View File

@@ -134,6 +134,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr bool Async = true;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
@@ -529,13 +531,16 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
}
};
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
@@ -556,7 +561,67 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
}
public:
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const index_t num_loop,
void* __restrict__ p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
element_wise::PassThrough{},
b_dram_block_window_tmp,
element_wise::PassThrough{},
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
ck_tile::make_tuple(a_dram_block_window_tmp),
a_element_func,
ck_tile::make_tuple(b_dram_block_window_tmp),
b_element_func,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
public:
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const index_t num_loop,
@@ -567,10 +632,10 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
ck_tile::make_tuple(a_dram_block_window_tmp),
element_wise::PassThrough{},
ck_tile::make_tuple(b_dram_block_window_tmp),
element_wise::PassThrough{},
num_loop,
p_smem);
};

View File

@@ -118,6 +118,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr bool Async = false;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{

View File

@@ -144,6 +144,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr bool Async = false;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{

View File

@@ -73,6 +73,8 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr bool Async = false;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{

View File

@@ -129,6 +129,8 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6<Problem>
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr bool Async = false;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{

View File

@@ -179,6 +179,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr bool Async = false;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{

View File

@@ -0,0 +1,361 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/host/concat.hpp"
#include "gemm_pipeline_agmem_bgmem_creg_v1.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompAsyncDefaultPolicy>
struct GemmPipelineAGmemBGmemCRegAsyncV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Problem>
{
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
static constexpr bool Async = true;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
}
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool Preshuffle = Problem::Preshuffle;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
static constexpr index_t kLdsAlignmentInBytes = 16;
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
{
// clang-format off
return "BASIC_ASYNC_V1";
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "pipeline_AGmemBGmemCRegAsyncV1",
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
concat('x', kPadM, kPadN, kPadK));
// clang-format on
}
// For the basic gemm pipelien DoubleSmemBuffer set to be false naturally.
static constexpr bool DoubleSmemBuffer = false;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
};
template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{
using Base = PipelineImplBase;
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
// TODO support multi-ABD
static_assert(1 == std::tuple_size_v<AsDramBlockWindowTmp>);
static_assert(1 == std::tuple_size_v<BsDramBlockWindowTmp>);
using ADramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
using BDramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
// TODO currently fused elementwise are not supported
ignore = a_element_func;
ignore = b_element_func;
static_assert(std::is_same_v<AElementFunction, element_wise::PassThrough>);
static_assert(std::is_same_v<BElementFunction, element_wise::PassThrough>);
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"Data Type conflict on A and B matrix input data type.");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major
? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
////////////// global window & register /////////////////
// A DRAM tile window(s) for load
auto a_tile_windows =
make_tile_window(a_dram_block_window_tmp[I0{}].get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp[I0{}].get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// B DRAM window(s) for load
auto b_tile_windows =
make_tile_window(b_dram_block_window_tmp[I0{}].get_bottom_tensor_view(),
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
b_dram_block_window_tmp[I0{}].get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// this pipeline has a pair of LDS buffers per logical tile
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
// set up LDS tile shapes
constexpr auto a_lds_shape = []() {
if constexpr(is_a_load_tr_v)
return make_tuple(number<kKPerBlock>{}, number<kMPerBlock>{});
else
return make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{});
}();
constexpr auto b_lds_shape = []() {
if constexpr(is_b_load_tr_v)
return make_tuple(number<kKPerBlock>{}, number<kNPerBlock>{});
else
return make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{});
}();
// LDS tile windows for storing, one per LDS buffer
auto a_copy_lds_window = make_tile_window(a_lds_block, a_lds_shape, {0, 0});
auto b_copy_lds_window = make_tile_window(b_lds_block, b_lds_shape, {0, 0});
// Block GEMM
auto block_gemm = BlockGemm();
// Acc register tile
auto c_block_tile = block_gemm.MakeCBlockTile();
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
// tile distribution for the register tiles
constexpr auto ALdsTileDistr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto BLdsTileDistr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
// register tiles; double buffering -> a register tile corresponds to a LDS tile window
ALdsTile a_block_tile;
BLdsTile b_block_tile;
constexpr auto a_lds_input_tile_distr = [ALdsTileDistr]() {
if constexpr(is_a_load_tr_v)
return make_static_tile_distribution(
typename InputTileDistributionTraits<
typename decltype(ALdsTileDistr)::DstrEncode,
typename Problem::ADataType>::TransposedDstrEncode{});
else
return ALdsTileDistr;
}();
constexpr auto b_lds_input_tile_distr = [BLdsTileDistr]() {
if constexpr(is_b_load_tr_v)
return make_static_tile_distribution(
typename InputTileDistributionTraits<
typename decltype(BLdsTileDistr)::DstrEncode,
typename Problem::BDataType>::TransposedDstrEncode{});
else
return BLdsTileDistr;
}();
// LDS tile windows for reading;
// they share the data pointer with the LDS windows for storing
// but also associate with a distribution to produce a register tile when reading
auto a_lds_ld_window =
make_tile_window(a_lds_block, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
auto b_lds_ld_window =
make_tile_window(b_lds_block, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
static_assert((!(is_tile_window_linear_v<decltype(a_lds_ld_window)>) &&
!(is_tile_window_linear_v<decltype(b_lds_ld_window)>)),
"LDS windows must not be linear");
// Global Prefetch
Base::GlobalPrefetchAsync(a_copy_lds_window, a_tile_windows, a_dram_tile_window_step);
Base::GlobalPrefetchAsync(b_copy_lds_window, b_tile_windows, b_dram_tile_window_step);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
block_sync_lds_direct_load();
index_t iCounter = num_loop - 1;
while(iCounter > 0)
{
Base::LocalPrefetch(a_block_tile, a_lds_ld_window, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile, b_lds_ld_window, is_b_load_tr_v);
block_sync_lds();
Base::GlobalPrefetchAsync(
a_copy_lds_window, a_tile_windows, a_dram_tile_window_step);
Base::GlobalPrefetchAsync(
b_copy_lds_window, b_tile_windows, b_dram_tile_window_step);
// GEMM i
block_gemm(c_block_tile, a_block_tile, b_block_tile);
block_sync_lds_direct_load();
iCounter--;
}
// tail
{
Base::LocalPrefetch(a_block_tile, a_lds_ld_window, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile, b_lds_ld_window, is_b_load_tr_v);
// GEMM num_loop - 1
block_gemm(c_block_tile, a_block_tile, b_block_tile);
}
return c_block_tile;
}
};
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
return PipelineImpl<Scheduler>{}.operator()(a_dram_block_window_tmp,
element_wise::PassThrough{},
b_dram_block_window_tmp,
element_wise::PassThrough{},
num_loop,
p_smem);
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
ck_tile::make_tuple(b_dram_block_window_tmp),
num_loop,
p_smem);
}
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
return PipelineImpl<Scheduler>{}.operator()(a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem);
}
};
} // namespace ck_tile

View File

@@ -71,6 +71,8 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
static constexpr bool Async = false;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{

View File

@@ -71,6 +71,8 @@ struct GemmPipelineAGmemBGmemCRegV2 : public BaseGemmPipelineAGmemBGmemCRegV2<Pr
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
static constexpr bool Async = false;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{

View File

@@ -15,7 +15,8 @@ enum struct GemmPipeline
MEMORY,
BASIC_V1,
BASIC_V2,
PRESHUFFLE_V2
PRESHUFFLE_V2,
BASIC_ASYNC_V1
};
} // namespace ck_tile

View File

@@ -731,6 +731,13 @@ struct GroupedConvolutionBackwardDataKernel
CK_TILE_HOST static bool
IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized& kargs)
{
if constexpr(GemmPipeline_::Async)
{
if(get_device_name() != "gfx950")
{
return false;
}
}
if constexpr(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value)
{
@@ -1128,17 +1135,36 @@ struct GroupedConvolutionBackwardDataKernel
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
RunGemm(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr,
kargs,
splitted_k,
i_m,
i_n,
i_k,
group_id);
if constexpr(GemmPipeline_::Async)
{
#if defined(__gfx950__)
RunGemm(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr,
kargs,
splitted_k,
i_m,
i_n,
i_k,
group_id);
#endif
}
else
{
RunGemm(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr,
kargs,
splitted_k,
i_m,
i_n,
i_k,
group_id);
}
}
};

View File

@@ -508,6 +508,13 @@ struct GroupedConvolutionBackwardWeightKernel
CK_TILE_HOST static bool
IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
{
if constexpr(GemmPipeline_::Async)
{
if(get_device_name() != "gfx950")
{
return false;
}
}
if(kargs.k_batch < 1)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
@@ -899,7 +906,18 @@ struct GroupedConvolutionBackwardWeightKernel
__shared__ char smem_ptr[GetSmemSize()];
RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr, kargs, num_loop, i_m, i_n, i_k);
if constexpr(GemmPipeline_::Async)
{
#if defined(__gfx950__)
RunGemm(
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr, kargs, num_loop, i_m, i_n, i_k);
#endif
}
else
{
RunGemm(
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr, kargs, num_loop, i_m, i_n, i_k);
}
}
}
};

View File

@@ -654,6 +654,14 @@ struct GroupedConvolutionForwardKernel
CK_TILE_HOST static bool IsSupportedArgument(const GroupedConvFwdKernelArgsSpecialized& kargs)
{
if constexpr(GemmPipeline_::Async)
{
if(get_device_name() != "gfx950")
{
return false;
}
}
if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value) ||
!IsSplitKSupported)
@@ -1141,19 +1149,40 @@ struct GroupedConvolutionForwardKernel
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
RunGemm(a_ptr,
b_ptr,
ds_ptr_with_offsets,
c_ptr,
smem_ptr,
a_desc,
b_desc,
c_desc,
kargs.GemmK,
kargs.k_batch,
i_m,
i_n,
kargs.elfunc);
if constexpr(GemmPipeline_::Async)
{
#if defined(__gfx950__)
RunGemm(a_ptr,
b_ptr,
ds_ptr_with_offsets,
c_ptr,
smem_ptr,
a_desc,
b_desc,
c_desc,
kargs.GemmK,
kargs.k_batch,
i_m,
i_n,
kargs.elfunc);
#endif
}
else
{
RunGemm(a_ptr,
b_ptr,
ds_ptr_with_offsets,
c_ptr,
smem_ptr,
a_desc,
b_desc,
c_desc,
kargs.GemmK,
kargs.k_batch,
i_m,
i_n,
kargs.elfunc);
}
}
}
};

View File

@@ -15,8 +15,7 @@
#include "ck_tile/builder/testing/conv/reference.hpp"
#include "ck_tile/builder/conv_builder.hpp"
// Temporary disable builder validate since we don't have deduced rtol, atol support
#define ENABLE_BUILDER_VALIDATE 0
#define ENABLE_BUILDER_VALIDATE 1
namespace ck_tile::builder::profiling {
@@ -168,11 +167,8 @@ run_grouped_conv_forward_tile_algs(const ckt::Args<SIGNATURE>& args,
#else
HIP_CHECK_ERROR(
hipMemcpy(&out.data()[0], outputs.output, output_bytes_num, hipMemcpyDeviceToHost));
valid = ck_tile::check_err(out, ref, "Error: Incorrect results!", rtol, atol);
valid = ck_tile::check_err(out, ref, "Error: Incorrect results!");
#endif
std::cout << "Relative error threshold: " << rtol
<< " Absolute error threshold: " << atol << std::endl;
}
else
{

View File

@@ -69,9 +69,9 @@ class TestGroupedConvndFwdTile : public ::testing::Test
auto inputs = alloc_inputs(args);
auto outputs = alloc_outputs(args);
ckt::init_tensor_buffer_uniform_fp(
ckt::init_tensor_buffer_uniform_int(
inputs.get().input, args.make_input_descriptor(), -5, 5);
ckt::init_tensor_buffer_uniform_fp(
ckt::init_tensor_buffer_uniform_int(
inputs.get().weight, args.make_weight_descriptor(), -5, 5);
std::cout << args.make_input_descriptor() << std::endl;