mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
[rocm-libraries] ROCm/rocm-libraries#4406 (commit 61f9f90)
[CK] CK Tile grouped convolution direct load ## Motivation CK Tile grouped convolution forward direct load support. ## Technical Details Basic pipeline for direct load and new instances for forward for v1 and v4 pipelines. ## Test Plan test_grouped_convnd_fwd_tile ## Test Result CI pending ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. AICK-130
This commit is contained in:
committed by
assistant-librarian[bot]
parent
0cafa68b6f
commit
27e0a34e0f
@@ -58,6 +58,8 @@ consteval BlockGemmSpec SetBlockGemm()
|
||||
case PipelineVersion::V3: version = ck::BlockGemmPipelineVersion::v3; break;
|
||||
case PipelineVersion::V4: version = ck::BlockGemmPipelineVersion::v4; break;
|
||||
case PipelineVersion::V5: version = ck::BlockGemmPipelineVersion::v5; break;
|
||||
case PipelineVersion::ASYNC_V1: throw "PipelineVersion::ASYNC can be used only for CK TILE.";
|
||||
case PipelineVersion::ASYNC_V4: throw "PipelineVersion::ASYNC can be used only for CK TILE.";
|
||||
case PipelineVersion::V6: throw "PipelineVersion::V6 is supported only for CK Tile.";
|
||||
case PipelineVersion::WEIGHT_ONLY:
|
||||
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM.";
|
||||
@@ -93,6 +95,8 @@ consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
|
||||
case PipelineVersion::V3: throw "PipelineVersion::V3 is used only for stream-K.";
|
||||
case PipelineVersion::V4: return ck_pipeline::v4;
|
||||
case PipelineVersion::V5: throw "PipelineVersion::V5 cannot be used for gridwise GEMM.";
|
||||
case PipelineVersion::ASYNC_V1: throw "PipelineVersion::ASYNC can be used only for CK TILE.";
|
||||
case PipelineVersion::ASYNC_V4: throw "PipelineVersion::ASYNC can be used only for CK TILE.";
|
||||
case PipelineVersion::V6: throw "PipelineVersion::V6 can be used only for CK TILE.";
|
||||
case PipelineVersion::WEIGHT_ONLY: return ck_pipeline::weight_only;
|
||||
default: throw "Unknown GridwiseGemmPipelineVersion";
|
||||
@@ -139,6 +143,8 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
|
||||
case PipelineVersion::V3: return ck_pipeline::v3;
|
||||
case PipelineVersion::V4: return ck_pipeline::v4;
|
||||
case PipelineVersion::V5: return ck_pipeline::v5;
|
||||
case PipelineVersion::ASYNC_V1: throw "PipelineVersion::ASYNC can be used only for CK TILE.";
|
||||
case PipelineVersion::ASYNC_V4: throw "PipelineVersion::ASYNC can be used only for CK TILE.";
|
||||
case PipelineVersion::V6: throw "PipelineVersion::V6 is supported only for CK Tile.";
|
||||
case PipelineVersion::WEIGHT_ONLY:
|
||||
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version.";
|
||||
|
||||
@@ -98,6 +98,20 @@ struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V6>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_ASYNC>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompAsync<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TilePipelineType<ck_tile::GemmPipeline::BASIC_ASYNC_V1>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegAsyncV1<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion()
|
||||
{
|
||||
@@ -111,6 +125,8 @@ consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion()
|
||||
case PipelineVersion::V4: return ck_tile_pipeline::COMPUTE_V4;
|
||||
case PipelineVersion::V5: return ck_tile_pipeline::COMPUTE_V5;
|
||||
case PipelineVersion::V6: return ck_tile_pipeline::COMPUTE_V6;
|
||||
case PipelineVersion::ASYNC_V1: return ck_tile_pipeline::BASIC_ASYNC_V1;
|
||||
case PipelineVersion::ASYNC_V4: return ck_tile_pipeline::COMPUTE_ASYNC;
|
||||
case PipelineVersion::WEIGHT_ONLY:
|
||||
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version.";
|
||||
default: throw "Unknown block GEMM PipelineVersion";
|
||||
|
||||
@@ -54,14 +54,15 @@ void init_tensor_buffer_uniform_int(void* buf,
|
||||
|
||||
// we might be asked to generate int values on fp data types that don't have the required
|
||||
// precision
|
||||
if(static_cast<ck_type>(max_value - 1) == static_cast<ck_type>(min_value))
|
||||
if(static_cast<ck_type>(max_value - 1) <= static_cast<ck_type>(min_value) &&
|
||||
static_cast<ck_type>(max_value - 1) >= static_cast<ck_type>(min_value))
|
||||
{
|
||||
throw std::runtime_error("Error while filling device tensor with random integer data: "
|
||||
"insufficient precision in specified range");
|
||||
}
|
||||
size_t packed_size = ck::packed_size_v<ck_type>;
|
||||
fill_tensor_uniform_rand_int_values<<<256, 256>>>(
|
||||
static_cast<ck_type>(buf), min_value, max_value, (size * packed_size) / sizeof(ck_type));
|
||||
static_cast<ck_type*>(buf), min_value, max_value, (size * packed_size) / sizeof(ck_type));
|
||||
}
|
||||
|
||||
/// @brief Initialize tensor data with a uniform float distribution
|
||||
|
||||
@@ -158,6 +158,8 @@ enum class PipelineVersion
|
||||
V4,
|
||||
V5,
|
||||
V6,
|
||||
ASYNC_V1,
|
||||
ASYNC_V4,
|
||||
WEIGHT_ONLY
|
||||
};
|
||||
|
||||
@@ -330,6 +332,8 @@ inline std::string_view to_string(PipelineVersion ver)
|
||||
case V4: return "V4";
|
||||
case V5: return "V5";
|
||||
case V6: return "V6";
|
||||
case ASYNC_V1: return "ASYNC_V1";
|
||||
case ASYNC_V4: return "ASYNC_V4";
|
||||
case WEIGHT_ONLY: return "WEIGHT_ONLY";
|
||||
default: return "Unknown";
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user