mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK] CK Tile grouped convolution direct load
This commit is contained in:
@@ -58,6 +58,8 @@ consteval BlockGemmSpec SetBlockGemm()
|
||||
case PipelineVersion::V3: version = ck::BlockGemmPipelineVersion::v3; break;
|
||||
case PipelineVersion::V4: version = ck::BlockGemmPipelineVersion::v4; break;
|
||||
case PipelineVersion::V5: version = ck::BlockGemmPipelineVersion::v5; break;
|
||||
case PipelineVersion::ASYNC_V1: throw "PipelineVersion::ASYNC can be used only for CK TILE.";
|
||||
case PipelineVersion::ASYNC_V4: throw "PipelineVersion::ASYNC can be used only for CK TILE.";
|
||||
case PipelineVersion::V6: throw "PipelineVersion::V6 is supported only for CK Tile.";
|
||||
case PipelineVersion::WEIGHT_ONLY:
|
||||
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM.";
|
||||
@@ -93,6 +95,8 @@ consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
|
||||
case PipelineVersion::V3: throw "PipelineVersion::V3 is used only for stream-K.";
|
||||
case PipelineVersion::V4: return ck_pipeline::v4;
|
||||
case PipelineVersion::V5: throw "PipelineVersion::V5 cannot be used for gridwise GEMM.";
|
||||
case PipelineVersion::ASYNC_V1: throw "PipelineVersion::ASYNC can be used only for CK TILE.";
|
||||
case PipelineVersion::ASYNC_V4: throw "PipelineVersion::ASYNC can be used only for CK TILE.";
|
||||
case PipelineVersion::V6: throw "PipelineVersion::V6 can be used only for CK TILE.";
|
||||
case PipelineVersion::WEIGHT_ONLY: return ck_pipeline::weight_only;
|
||||
default: throw "Unknown GridwiseGemmPipelineVersion";
|
||||
|
||||
@@ -98,6 +98,20 @@ struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V6>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_ASYNC>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompAsync<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TilePipelineType<ck_tile::GemmPipeline::BASIC_ASYNC_V1>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegAsyncV1<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion()
|
||||
{
|
||||
@@ -111,6 +125,8 @@ consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion()
|
||||
case PipelineVersion::V4: return ck_tile_pipeline::COMPUTE_V4;
|
||||
case PipelineVersion::V5: return ck_tile_pipeline::COMPUTE_V5;
|
||||
case PipelineVersion::V6: return ck_tile_pipeline::COMPUTE_V6;
|
||||
case PipelineVersion::ASYNC_V1: return ck_tile_pipeline::BASIC_ASYNC_V1;
|
||||
case PipelineVersion::ASYNC_V4: return ck_tile_pipeline::COMPUTE_ASYNC;
|
||||
case PipelineVersion::WEIGHT_ONLY:
|
||||
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version.";
|
||||
default: throw "Unknown block GEMM PipelineVersion";
|
||||
|
||||
@@ -158,6 +158,8 @@ enum class PipelineVersion
|
||||
V4,
|
||||
V5,
|
||||
V6,
|
||||
ASYNC_V1,
|
||||
ASYNC_V4,
|
||||
WEIGHT_ONLY
|
||||
};
|
||||
|
||||
@@ -330,6 +332,8 @@ inline std::string_view to_string(PipelineVersion ver)
|
||||
case V4: return "V4";
|
||||
case V5: return "V5";
|
||||
case V6: return "V6";
|
||||
case ASYNC_V1: return "ASYNC_V1";
|
||||
case ASYNC_V4: return "ASYNC_V4";
|
||||
case WEIGHT_ONLY: return "WEIGHT_ONLY";
|
||||
default: return "Unknown";
|
||||
}
|
||||
|
||||
@@ -234,4 +234,30 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Stride
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 64, 64, Default, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 128, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 64, 64, Filter1x1Stride1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
@@ -225,4 +225,30 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Stride
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 64, 64, Default, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 128, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 64, 64, Filter1x1Stride1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
@@ -234,4 +234,30 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Stride
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 64, 64, Default, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 128, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 64, 64, Filter1x1Stride1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
@@ -225,4 +225,30 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 64, 64, Filter1x1Stride
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 16, 128, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 16, 256, 64, Filter1x1Stride1Pad0, 16, 16, 1, 4, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 32, 256, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 64, 64, Default, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 128, 64, Default, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 32, 16, 64, Default, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 64, 64, Filter1x1Stride1Pad0, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 128, 64, Filter1x1Stride1Pad0, 32, 32, 1, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 32, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
@@ -38,4 +38,12 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 3
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
@@ -38,4 +38,12 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 3
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
@@ -38,4 +38,12 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 3
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
@@ -38,4 +38,12 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 3
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
|
||||
@@ -60,7 +60,6 @@ class ConvInstanceTemplateParams:
|
||||
|
||||
def get_block_gemm_desc(self):
|
||||
double_smem_buffer = "true" if self.double_smem_buffer else "false"
|
||||
pipeline_version = self.pipeline_version[-1:]
|
||||
scheduler = (
|
||||
"INTRAWAVE" if self.scheduler.find("Intrawave") != -1 else "INTERWAVE"
|
||||
)
|
||||
@@ -69,7 +68,7 @@ class ConvInstanceTemplateParams:
|
||||
.warp_tile = {{.m = {self.warp_tile[0]}, .n = {self.warp_tile[1]}, .k = {self.warp_tile[2]}}},
|
||||
.double_smem_buffer = {double_smem_buffer},
|
||||
.num_wave_groups = {self.num_wave_groups},
|
||||
.pipeline_version = ckb::PipelineVersion::V{pipeline_version},
|
||||
.pipeline_version = ckb::PipelineVersion::{self.pipeline_version},
|
||||
.scheduler = ckb::PipelineScheduler::{scheduler}}}"""
|
||||
|
||||
def get_block_transfer(self):
|
||||
@@ -180,6 +179,16 @@ def parse_fwd_instances(instances, problem_name):
|
||||
pipeline_version = (
|
||||
"v1" if instance.find("BlkGemmPipelineVersion") == -1 else args[15]
|
||||
)
|
||||
# Replace pipeline if Direct Load
|
||||
if instance.find("DirectLoad") != -1:
|
||||
if instance.find("BlkGemmPipelineVersion: v1") != -1:
|
||||
pipeline_version = "ASYNC_V1"
|
||||
elif instance.find("BlkGemmPipelineVersion: v4") != -1:
|
||||
pipeline_version = "ASYNC_V4"
|
||||
else:
|
||||
raise RuntimeError("not supported pipeline for direct load")
|
||||
else:
|
||||
pipeline_version = f"""V{pipeline_version[-1:]}"""
|
||||
|
||||
m_warp = int(m_per_block / (m_per_xdl * m_xdl_per_wave))
|
||||
n_warp = int(n_per_block / (n_per_xdl * n_xdl_per_wave))
|
||||
|
||||
@@ -51,6 +51,7 @@
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_async.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
|
||||
|
||||
@@ -134,6 +134,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr bool Async = true;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
@@ -529,13 +531,16 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
@@ -556,7 +561,67 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
element_wise::PassThrough{},
|
||||
b_dram_block_window_tmp,
|
||||
element_wise::PassThrough{},
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
a_element_func,
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
@@ -567,10 +632,10 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
element_wise::PassThrough{},
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
element_wise::PassThrough{},
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
|
||||
@@ -126,6 +126,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr bool Async = false;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
|
||||
@@ -144,6 +144,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr bool Async = false;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
|
||||
@@ -73,6 +73,8 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr bool Async = false;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
|
||||
@@ -129,6 +129,8 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6<Problem>
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr bool Async = false;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
|
||||
@@ -179,6 +179,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr bool Async = false;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
|
||||
@@ -0,0 +1,594 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem>
|
||||
struct BaseGemmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 1;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
static constexpr bool UsePersistentKernel = false;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
|
||||
{
|
||||
return TailNumber::Empty;
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
}
|
||||
};
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
|
||||
struct GemmPipelineAGmemBGmemCRegAsyncV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Problem>
|
||||
{
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr bool Async = true;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
|
||||
}
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
|
||||
}
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool Preshuffle = Problem::Preshuffle;
|
||||
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
static constexpr index_t kLdsAlignmentInBytes = 16;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
|
||||
{
|
||||
// clang-format off
|
||||
return "BASIC_ASYNC_V1";
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AGmemBGmemCRegAsyncV1",
|
||||
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
|
||||
concat('x', kPadM, kPadN, kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
// For the basic gemm pipelien DoubleSmemBuffer set to be false naturally.
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <GemmPipelineScheduler Scheduler>
|
||||
struct PipelineImpl : public PipelineImplBase
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr bool is_a_col_major =
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
static_assert(is_b_row_major
|
||||
? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
|
||||
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
constexpr index_t a_lds_block_space_size_aligned =
|
||||
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
|
||||
kLdsAlignmentInBytes) *
|
||||
kLdsAlignmentInBytes;
|
||||
|
||||
// B tile in LDS
|
||||
BDataType* p_b_lds = static_cast<BDataType*>(
|
||||
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
|
||||
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
// Tile distribution for load from lds
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
// A DRAM tile window for load
|
||||
// A LDS tile window for store
|
||||
// A LDS tile for block GEMM
|
||||
auto&& [as_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
|
||||
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
|
||||
|
||||
// B DRAM tile window for load
|
||||
// B LDS tile window for store
|
||||
// B LDS tile for block GEMM
|
||||
auto&& [bs_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
|
||||
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
auto elementwise_As_res =
|
||||
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each B0,
|
||||
// B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
auto elementwise_Bs_res =
|
||||
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
{
|
||||
// move to 1
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a
|
||||
// tuple as input.
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a
|
||||
// tuple as input.
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
store_tile(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
store_tile(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
}
|
||||
|
||||
index_t iCounter = num_loop - 1;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
// global read i + 1
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
block_sync_lds();
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res);
|
||||
store_tile(a_copy_lds_window, a_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res);
|
||||
store_tile(b_copy_lds_window, b_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
// GEMM num_loop - 1
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineImpl<GemmPipelineScheduler::Interwave> : public PipelineImplBase
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr bool is_a_col_major =
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
static_assert(is_b_row_major
|
||||
? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
|
||||
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
constexpr index_t a_lds_block_space_size_aligned =
|
||||
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
|
||||
kLdsAlignmentInBytes) *
|
||||
kLdsAlignmentInBytes;
|
||||
|
||||
// B tile in LDS
|
||||
BDataType* p_b_lds = static_cast<BDataType*>(
|
||||
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
|
||||
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
// // Tile distribution for load from lds
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
// A DRAM tile window for load
|
||||
// A LDS tile window for store
|
||||
// A LDS tile for block GEMM
|
||||
auto&& [as_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
|
||||
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
|
||||
|
||||
// B DRAM tile window for load
|
||||
// B LDS tile window for store
|
||||
// B LDS tile for block GEMM
|
||||
auto&& [bs_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
|
||||
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
auto elementwise_As_res =
|
||||
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each B0,
|
||||
// B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
auto elementwise_Bs_res =
|
||||
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
{
|
||||
// move to 1
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a
|
||||
// tuple as input.
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a
|
||||
// tuple as input.
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
store_tile(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
store_tile(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
}
|
||||
|
||||
index_t iCounter = num_loop - 1;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
// global read i + 1
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
block_sync_lds();
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res);
|
||||
store_tile(a_copy_lds_window, a_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res);
|
||||
store_tile(b_copy_lds_window, b_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
// GEMM num_loop - 1
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](auto& e, const ADataType & a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](auto& e, const BDataType & b) { e = b; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.operator()(a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -71,6 +71,8 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr bool Async = false;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
|
||||
@@ -71,6 +71,8 @@ struct GemmPipelineAGmemBGmemCRegV2 : public BaseGemmPipelineAGmemBGmemCRegV2<Pr
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr bool Async = false;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
|
||||
@@ -731,6 +731,13 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
CK_TILE_HOST static bool
|
||||
IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized& kargs)
|
||||
{
|
||||
if constexpr(GemmPipeline_::Async)
|
||||
{
|
||||
if(get_device_name() != "gfx950")
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if constexpr(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value)
|
||||
{
|
||||
|
||||
@@ -508,6 +508,13 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
CK_TILE_HOST static bool
|
||||
IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
|
||||
{
|
||||
if constexpr(GemmPipeline_::Async)
|
||||
{
|
||||
if(get_device_name() != "gfx950")
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if(kargs.k_batch < 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
|
||||
@@ -654,6 +654,14 @@ struct GroupedConvolutionForwardKernel
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const GroupedConvFwdKernelArgsSpecialized& kargs)
|
||||
{
|
||||
if constexpr(GemmPipeline_::Async)
|
||||
{
|
||||
if(get_device_name() != "gfx950")
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value) ||
|
||||
!IsSplitKSupported)
|
||||
|
||||
Reference in New Issue
Block a user