mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#5237 (commit ef10dc6)
[CK_TILE, CK_BUILDER] Add two-stage bwd weight kernels to CK Tile profiler (#5237) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation PR #4797 added CK Tile bwd weight kernels to the CK Profiler. The two-stage kernels were not supported in the initial PR. This PR adds the missing bwd weight two-stage kernels to the CK Profiler. ## Technical Details Extended the CK Tile conv builder factory to also build the elementwise ops required for the two-stage kernels. Extended the CK Builder for CK Tile instance to accept the two-stage flag as part of the algorithm configuration. ## Test Plan Added unit tests for CK Builder that verify the two-stage kernel construction. ## Test Result If CI passes, the added unit tests are passing. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
fc2f95620d
commit
e2f5ab8000
@@ -155,6 +155,7 @@ concept TileOptimizationsDescriptor = requires(T t) {
|
||||
{ t.num_groups_to_merge } -> std::convertible_to<int>;
|
||||
{ t.split_image } -> std::convertible_to<bool>;
|
||||
{ t.explicit_gemm } -> std::convertible_to<bool>;
|
||||
{ t.two_stage } -> std::convertible_to<bool>;
|
||||
};
|
||||
|
||||
// Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this
|
||||
@@ -295,6 +296,7 @@ concept SpecifiesTileOptimizations = requires {
|
||||
{ T::optimizations.num_groups_to_merge } -> std::convertible_to<int>;
|
||||
{ T::optimizations.split_image } -> std::convertible_to<bool>;
|
||||
{ T::optimizations.explicit_gemm } -> std::convertible_to<bool>;
|
||||
{ T::optimizations.two_stage } -> std::convertible_to<bool>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/builder/versions.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
@@ -68,6 +70,10 @@ struct ConvTileFactory
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
|
||||
|
||||
using ConvOutDataType = std::conditional_t<OPTIMIZATIONS.two_stage,
|
||||
typename Types::AccDataType,
|
||||
typename Types::EDataType>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadM,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadN,
|
||||
@@ -103,7 +109,7 @@ struct ConvTileFactory
|
||||
typename Types::BDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::AccDataType,
|
||||
typename Types::EDataType,
|
||||
ConvOutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
@@ -126,4 +132,33 @@ struct ConvTileFactory
|
||||
ConvEpilogue>::Instance;
|
||||
};
|
||||
|
||||
/// @brief Factory that assembles the elementwise kernel type paired with a two-stage
///        convolution kernel.
///
/// The resulting `Instance` is a `ck_tile::ElementWiseKernel` that applies
/// `ck_tile::element_wise::UnaryConvert` to values of `Types::AccDataType` and produces
/// values of `Types::EDataType`.
///
/// @tparam SIGNATURE Convolution signature; only `SIGNATURE.data_type` is read here.
/// @tparam ALGORITHM Algorithm descriptor that supplies the thread-block and block-GEMM
///                   tile information.
/// @tparam VERSION   API version tag; it is not referenced inside this struct's body.
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION = LATEST_API_VERSION>
|
||||
struct ElementwiseOpTileFactory
|
||||
{
|
||||
// Tile descriptions derived from the algorithm at compile time.
static constexpr auto BLOCK = internal::SetTileThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto BLOCK_GEMM = internal::SetTileBlockGemm<ALGORITHM>();
|
||||
|
||||
using Types = internal::TileConvTensorTypes<SIGNATURE.data_type>;
|
||||
// The kernel's input and its intermediate (workspace) type are both the accumulator type.
using XDataType = Types::AccDataType;
|
||||
using WorkspaceDataType = Types::AccDataType;
|
||||
using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
|
||||
// The kernel's output type is the convolution's E data type.
using YDataType = Types::EDataType;
|
||||
// Each shape below is a single-entry sequence holding the product of the M and N extents,
// i.e. the two-dimensional GEMM tiling is flattened into one dimension.
using BlockTile = ck_tile::sequence<BLOCK.per_block.m * BLOCK.per_block.n>;
|
||||
using BlockWarps = ck_tile::sequence<BLOCK_GEMM.warps.m * BLOCK_GEMM.warps.n>;
|
||||
using WarpTile = ck_tile::sequence<BLOCK_GEMM.warp_tile.m * BLOCK_GEMM.warp_tile.n>;
|
||||
using ElementwiseShape =
|
||||
ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceDataType>;
|
||||
|
||||
// Conversion from X -> Y.
|
||||
using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
|
||||
WorkspaceDataType,
|
||||
YDataType,
|
||||
ElementwiseShape,
|
||||
XElementwiseOperation>;
|
||||
|
||||
// Final kernel type exposed to callers.
using Instance = ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
|
||||
@@ -34,6 +34,7 @@ struct TileOptimizations
|
||||
int num_groups_to_merge = 1;
|
||||
bool split_image = false;
|
||||
bool explicit_gemm = false;
|
||||
bool two_stage = false;
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
@@ -181,7 +182,8 @@ consteval TileOptimizations SetTileOptimizations()
|
||||
|
||||
return TileOptimizations{.num_groups_to_merge = OPT.num_groups_to_merge,
|
||||
.split_image = OPT.split_image,
|
||||
.explicit_gemm = OPT.explicit_gemm};
|
||||
.explicit_gemm = OPT.explicit_gemm,
|
||||
.two_stage = OPT.two_stage};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
|
||||
@@ -91,6 +91,100 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Launches a two-stage convolution: the convolution kernel followed by an
///        elementwise kernel.
///
/// The convolution kernel is pointed at a temporary device workspace instead of the
/// caller's weight buffer. The elementwise kernel then reads that workspace and writes
/// into the caller's original weight pointer.
///
/// @param conv           Convolution kernel instance (stage one).
/// @param elementwise_op Elementwise kernel instance (stage two); its type must expose
///                       `ComputeDataType`, `YDataType`, `Problem::BlockShape`,
///                       `BlockSize()` and `IsSupportedArgument()`.
/// @param args           Problem description; also supplies `k_batch`.
/// @param input          Input tensor pointer forwarded to the host args.
/// @param weight         Weight tensor pointer; final destination of the elementwise kernel.
/// @param output         Output tensor pointer forwarded to the host args.
/// @param s_conf         Stream configuration used for the memset and both launches.
/// @returns `RunResult::not_supported(...)` if either kernel rejects its arguments,
///          otherwise `RunResult::from_runtime(...)` of the timed launch.
template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename OutDataType>
|
||||
[[nodiscard]] RunResult run(CkTileConvInstance<SIGNATURE> auto& conv,
|
||||
auto& elementwise_op,
|
||||
const Args<SIGNATURE>& args,
|
||||
InDataType* input,
|
||||
WeiDataType* weight,
|
||||
OutDataType* output,
|
||||
const ck_tile::stream_config s_conf)
|
||||
{
|
||||
using Conv = std::remove_reference_t<decltype(conv)>;
|
||||
using ElementwiseOp = std::remove_reference_t<decltype(elementwise_op)>;
|
||||
using WorkspaceDataType = typename ElementwiseOp::ComputeDataType;
|
||||
using CDataType = typename ElementwiseOp::YDataType;
|
||||
using BlockShape = typename ElementwiseOp::Problem::BlockShape;
|
||||
|
||||
const auto param = args.to_ck_tile_conv_param();
|
||||
|
||||
ck_tile::GroupedConvHostArgs<InDataType*, WeiDataType*, OutDataType*, ck_tile::PassThrough>
|
||||
host_args(param, input, weight, {}, output, args.k_batch);
|
||||
|
||||
// Set-up for elementwise op kernel.
|
||||
// Product of all filter spatial lengths.
const ck_tile::index_t spatial_lengths_accum =
|
||||
std::accumulate(host_args.filter_spatial_lengths_.begin(),
|
||||
host_args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<ck_tile::index_t>());
|
||||
// Workspace holds G * K * C * (filter spatial product) elements of WorkspaceDataType.
// NOTE(review): the buffer is a local object; presumably it is released when this function
// returns, so the launch below must have completed by then - confirm.
ck_tile::DeviceMem ws_m_n_dev_buf(host_args.G_ * host_args.K_ * host_args.C_ *
|
||||
spatial_lengths_accum * sizeof(WorkspaceDataType));
|
||||
ck_tile::GroupedConvBwdWeightHostArgs ws_args =
|
||||
ck_tile::GroupedConvBwdWeightHostArgs(host_args);
|
||||
// Remember the caller's weight pointer, then redirect the convolution to the workspace.
auto c_ptr = ws_args.wei_ptr;
|
||||
ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
|
||||
|
||||
auto kargs = Conv::MakeKernelArgs(ws_args);
|
||||
const dim3 grids = Conv::GridSize(kargs);
|
||||
const dim3 blocks = Conv::BlockSize();
|
||||
|
||||
if(!Conv::IsSupportedArgument(kargs))
|
||||
return RunResult::not_supported("unsupported ck_tile arguments");
|
||||
|
||||
// The workspace is viewed as a 2D tensor of shape (G * K) x (C * filter spatial product).
// NOTE(review): these products are computed in ck_tile::index_t; if that is a 32-bit type
// they could overflow for very large problems - confirm.
ck_tile::index_t total_elements = 1;
|
||||
std::vector<ck_tile::index_t> shape = {
|
||||
static_cast<ck_tile::index_t>(host_args.G_ * host_args.K_),
|
||||
static_cast<ck_tile::index_t>(host_args.C_ * spatial_lengths_accum)};
|
||||
|
||||
for(auto d : shape)
|
||||
total_elements *= d;
|
||||
|
||||
const ck_tile::index_t kBlockSize = ElementwiseOp::BlockSize();
|
||||
|
||||
// Grid size is the ceiling of total_elements / elements_per_block.
constexpr ck_tile::index_t elements_per_block = BlockShape::kBlockM;
|
||||
ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
|
||||
|
||||
// The elementwise kernel's single input is the workspace written by the convolution.
auto input_tensors = ck_tile::make_tuple(static_cast<WorkspaceDataType*>(ws_args.wei_ptr));
|
||||
auto input_size = ck_tile::make_tuple(shape[0], shape[1]);
|
||||
|
||||
// Check if the kernel configuration is supported
|
||||
if(!ElementwiseOp::IsSupportedArgument(input_size))
|
||||
{
|
||||
return RunResult::not_supported("unsupported ck_tile arguments for elementwise op");
|
||||
}
|
||||
|
||||
// Passed to launch_kernel_time_mask below. Zeroes the workspace when k_batch > 1
// (presumably because split-K partial results are accumulated into it - confirm).
auto preprocess = [&]() {
|
||||
if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
|
||||
{
|
||||
if(args.k_batch > 1)
|
||||
{
|
||||
ck_tile::hip_check_error(
|
||||
hipMemsetAsync(ws_args.wei_ptr,
|
||||
0,
|
||||
shape[0] * shape[1] * sizeof(WorkspaceDataType),
|
||||
s_conf.stream_id_));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// The same occupancy value, derived from the convolution's pipeline scheduler, is used
// for both kernels.
constexpr index_t minimum_occupancy =
|
||||
Conv::GemmPipeline::Scheduler == ck_tile::GemmPipelineScheduler::Intrawave ? 1 : 2;
|
||||
|
||||
// Stage one: convolution into the workspace. Stage two: elementwise kernel from the
// workspace into the caller's weight buffer (c_ptr).
return RunResult::from_runtime(ck_tile::launch_kernel_time_mask(
|
||||
s_conf,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<minimum_occupancy>(conv, grids, blocks, 0, kargs),
|
||||
ck_tile::make_kernel<minimum_occupancy>(elementwise_op,
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(shape[1], 1), // Input Stride
|
||||
ck_tile::make_tuple(shape[1], 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<CDataType*>(c_ptr))));
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// @brief Concept for checking whether a convolution is invoked like CK Tile.
|
||||
@@ -149,4 +243,28 @@ template <auto SIGNATURE>
|
||||
s_conf);
|
||||
}
|
||||
|
||||
/// @brief `run()` specialization for two-stage backwards weight convolution and CK Tile.
|
||||
///
|
||||
/// @tparam SIGNATURE Backwards weight convolution signature.
|
||||
/// @param conv           Convolution kernel instance.
/// @param elementwise_op Elementwise kernel instance run after the convolution.
/// @param args           Problem description forwarded to `detail::run()`.
/// @param inputs         Supplies the `input` and `output` tensors, both passed as read-only.
/// @param outputs        Supplies the `weight` tensor, passed as the writable result.
/// @param s_conf         Stream configuration; defaults to a value-initialized config.
/// @returns RunResult about how the operation completed (or not).
|
||||
///
|
||||
/// @see run()
|
||||
template <auto SIGNATURE>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
[[nodiscard]] RunResult run(CkTileConvInstance<SIGNATURE> auto& conv,
|
||||
auto& elementwise_op,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs,
|
||||
const ck_tile::stream_config s_conf = {})
|
||||
{
|
||||
// Pointers are type-erased to void before being handed to the shared implementation.
return detail::run(conv,
|
||||
elementwise_op,
|
||||
args,
|
||||
static_cast<const void*>(inputs.input),
|
||||
static_cast<void*>(outputs.weight),
|
||||
static_cast<const void*>(inputs.output),
|
||||
s_conf);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
@@ -25,8 +25,10 @@ TEST(BwdDataConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D
|
||||
.with_tile_thread_block(TileThreadBlock_64x64x64)
|
||||
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
|
||||
.with_tile_transfer(TileTransfer_4x4x4)
|
||||
.with_tile_optimizations(TileOptimizations{
|
||||
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
|
||||
.with_tile_optimizations(TileOptimizations{.num_groups_to_merge = 1,
|
||||
.split_image = false,
|
||||
.explicit_gemm = false,
|
||||
.two_stage = false});
|
||||
|
||||
using Builder = ConvBuilder<BwdDataConvSignature, BwdDataConvAlgorithm>;
|
||||
run_ck_tile_test<Builder>({
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
namespace ckf = ck_tile::builder::factory;
|
||||
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
using ck_tile::test::MatchesReference;
|
||||
@@ -31,12 +32,49 @@ constexpr auto ALGORITHM =
|
||||
.with_tile_thread_block(cku::TileThreadBlock_64x64x64)
|
||||
.with_tile_block_gemm(cku::TileBlockGemmDesc_16x16_v3_intrawave)
|
||||
.with_tile_transfer(cku::TileTransfer_4x4x4)
|
||||
.with_tile_optimizations(ckt::TileOptimizations{
|
||||
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
|
||||
.with_tile_optimizations(ckt::TileOptimizations{.num_groups_to_merge = 1,
|
||||
.split_image = false,
|
||||
.explicit_gemm = false,
|
||||
.two_stage = false});
|
||||
|
||||
constexpr auto TWO_STAGE_ALGORITHM =
|
||||
cku::ConvAlgorithm_Tile_GroupedConvolutionKernel{}
|
||||
.with_tile_specializations(ckb::TileConvSpecialization::DEFAULT)
|
||||
.with_tile_thread_block(cku::TileThreadBlock_64x64x64)
|
||||
.with_tile_block_gemm(cku::TileBlockGemmDesc_16x16_v3_intrawave)
|
||||
.with_tile_transfer(cku::TileTransfer_4x4x4)
|
||||
.with_tile_optimizations(ckt::TileOptimizations{.num_groups_to_merge = 1,
|
||||
.split_image = false,
|
||||
.explicit_gemm = false,
|
||||
.two_stage = true});
|
||||
|
||||
constexpr ckt::Args<SIGNATURE> Args = {
|
||||
.lengths =
|
||||
{
|
||||
.batch_size = 2,
|
||||
.groups = 4,
|
||||
.input_channels = 32,
|
||||
.output_channels = 48,
|
||||
.image = {.width = 32, .height = 56},
|
||||
.filter = {.width = 3, .height = 3},
|
||||
},
|
||||
.filter_strides = {.width = 1, .height = 1},
|
||||
.filter_dilation = {.width = 1, .height = 1},
|
||||
.input_left_pad = {.width = 0, .height = 0},
|
||||
.input_right_pad = {.width = 0, .height = 0},
|
||||
.a_elementwise_op = {},
|
||||
.b_elementwise_op = {},
|
||||
.cde_elementwise_op = {},
|
||||
};
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
using TwoStageBuilder = ckb::ConvBuilder<SIGNATURE, TWO_STAGE_ALGORITHM>;
|
||||
using TwoStageInstance = TwoStageBuilder::Instance;
|
||||
using ElementwiseOpBuilder = ckf::ElementwiseOpTileFactory<SIGNATURE, TWO_STAGE_ALGORITHM>;
|
||||
using ElementwiseOpInstance = ElementwiseOpBuilder::Instance;
|
||||
|
||||
using Reference = ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
|
||||
|
||||
TEST(BwdWeight_2D_FP16_NHWGC, Create)
|
||||
@@ -61,38 +99,47 @@ TEST(BwdWeight_2D_FP16_NHWGC, Create)
|
||||
});
|
||||
}
|
||||
|
||||
// Builds the elementwise kernel for the two-stage algorithm and passes the name fragments
// expected for it to run_ck_tile_test (presumably matched against the instance string -
// confirm against the helper).
TEST(ElementWiseOp, CreateBwdWeightTwoStageElementwiseOp)
|
||||
{
|
||||
// The numeric fragment follows the field order used by ElementWiseShape::GetName():
// kBlockM, kWarpM, kVectorM, kWarpPerBlockM, kThreadPerWarpM, kRepeatM, kBlockSize.
cku::run_ck_tile_test<ElementwiseOpBuilder>({"elementwise_kernel",
|
||||
"4096_256_4_4_64_4_256",
|
||||
"UnaryConvert",
|
||||
"kPad_1",
|
||||
"ElementWiseDefaultPolicy"});
|
||||
}
|
||||
|
||||
TEST(BwdWeight_2D_FP16_NHWGC, Execution)
|
||||
{
|
||||
ckt::Args<SIGNATURE> args = {
|
||||
.lengths =
|
||||
{
|
||||
.batch_size = 2,
|
||||
.groups = 4,
|
||||
.input_channels = 32,
|
||||
.output_channels = 48,
|
||||
.image = {.width = 32, .height = 56},
|
||||
.filter = {.width = 3, .height = 3},
|
||||
},
|
||||
.filter_strides = {.width = 1, .height = 1},
|
||||
.filter_dilation = {.width = 1, .height = 1},
|
||||
.input_left_pad = {.width = 0, .height = 0},
|
||||
.input_right_pad = {.width = 0, .height = 0},
|
||||
.a_elementwise_op = {},
|
||||
.b_elementwise_op = {},
|
||||
.cde_elementwise_op = {},
|
||||
};
|
||||
auto inputs = ckt::alloc_inputs(Args);
|
||||
auto outputs = ckt::alloc_outputs(Args);
|
||||
auto reference = ckt::alloc_outputs(Args);
|
||||
|
||||
auto inputs = ckt::alloc_inputs(args);
|
||||
auto outputs = ckt::alloc_outputs(args);
|
||||
auto reference = ckt::alloc_outputs(args);
|
||||
|
||||
ckt::init_inputs(args, inputs.get());
|
||||
ckt::init_inputs(Args, inputs.get());
|
||||
|
||||
auto conv = Instance{};
|
||||
EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun());
|
||||
EXPECT_THAT(ckt::run(conv, Args, inputs.get(), outputs.get()), SuccessfulRun());
|
||||
|
||||
auto ref_conv = Reference{};
|
||||
EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun());
|
||||
EXPECT_THAT(ckt::run(ref_conv, Args, inputs.get(), reference.get()), SuccessfulRun());
|
||||
|
||||
EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get()));
|
||||
EXPECT_THAT(outputs.get(), MatchesReference(Args, reference.get()));
|
||||
}
|
||||
|
||||
// Runs the two-stage kernel pair and the reference convolution on the same inputs and
// checks that the two results match.
TEST(BwdWeight_TwoStage_2D_FP16_NHWGC, Execution)
|
||||
{
|
||||
auto inputs = ckt::alloc_inputs(Args);
|
||||
auto outputs = ckt::alloc_outputs(Args);
|
||||
auto reference = ckt::alloc_outputs(Args);
|
||||
|
||||
ckt::init_inputs(Args, inputs.get());
|
||||
|
||||
// Stage-one convolution kernel and stage-two elementwise kernel.
auto conv = TwoStageInstance{};
|
||||
auto elementwise_op = ElementwiseOpInstance{};
|
||||
|
||||
EXPECT_THAT(ckt::run(conv, elementwise_op, Args, inputs.get(), outputs.get()), SuccessfulRun());
|
||||
|
||||
// Reference result written to a separate output allocation.
auto ref_conv = Reference{};
|
||||
EXPECT_THAT(ckt::run(ref_conv, Args, inputs.get(), reference.get()), SuccessfulRun());
|
||||
|
||||
EXPECT_THAT(outputs.get(), MatchesReference(Args, reference.get()));
|
||||
}
|
||||
|
||||
@@ -24,8 +24,10 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1
|
||||
.with_tile_thread_block(TileThreadBlock_64x64x64)
|
||||
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
|
||||
.with_tile_transfer(TileTransfer_4x4x4)
|
||||
.with_tile_optimizations(TileOptimizations{
|
||||
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
|
||||
.with_tile_optimizations(TileOptimizations{.num_groups_to_merge = 1,
|
||||
.split_image = false,
|
||||
.explicit_gemm = false,
|
||||
.two_stage = false});
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_ck_tile_test<Builder>({
|
||||
|
||||
@@ -31,8 +31,10 @@ constexpr auto ALGORITHM =
|
||||
.with_tile_thread_block(cku::FwdTileThreadBlock_64x64x64)
|
||||
.with_tile_block_gemm(cku::TileBlockGemmDesc_16x16_v3_intrawave)
|
||||
.with_tile_transfer(cku::FwdTileTransfer_4x4x4)
|
||||
.with_tile_optimizations(ckt::TileOptimizations{
|
||||
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
|
||||
.with_tile_optimizations(ckt::TileOptimizations{.num_groups_to_merge = 1,
|
||||
.split_image = false,
|
||||
.explicit_gemm = false,
|
||||
.two_stage = false});
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
@@ -377,6 +377,8 @@ struct TileOptimizations
|
||||
bool split_image;
|
||||
// Explicit gemm for 1x1, stride=0, pad=0 cases
|
||||
bool explicit_gemm;
|
||||
// Two-stage kernels
|
||||
bool two_stage;
|
||||
};
|
||||
static_assert(ckb::TileOptimizationsDescriptor<TileOptimizations>);
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ class ConvInstanceTemplateParams:
|
||||
warp_tile,
|
||||
double_smem_buffer,
|
||||
num_wave_groups,
|
||||
is_two_stage_instance,
|
||||
pipeline_version,
|
||||
scheduler,
|
||||
scalar_per_vector,
|
||||
@@ -27,6 +28,7 @@ class ConvInstanceTemplateParams:
|
||||
self.warp_tile = warp_tile
|
||||
self.double_smem_buffer = double_smem_buffer
|
||||
self.num_wave_groups = num_wave_groups
|
||||
self.is_two_stage_instance = is_two_stage_instance
|
||||
self.pipeline_version = pipeline_version
|
||||
self.scheduler = scheduler
|
||||
self.scalar_per_vector = scalar_per_vector
|
||||
@@ -39,7 +41,8 @@ class ConvInstanceTemplateParams:
|
||||
explicit_gemm = "true" if self.explicit_gemm else "false"
|
||||
split_image = "true" if self.split_image else "false"
|
||||
num_groups_to_merge = str(self.num_groups_to_merge)
|
||||
return f"ckt::TileOptimizations{{.num_groups_to_merge = {num_groups_to_merge}, .split_image = {split_image}, .explicit_gemm = {explicit_gemm}}}"
|
||||
two_stage_instance = "true" if self.is_two_stage_instance else "false"
|
||||
return f"ckt::TileOptimizations{{.num_groups_to_merge = {num_groups_to_merge}, .split_image = {split_image}, .explicit_gemm = {explicit_gemm}, .two_stage = {two_stage_instance}}}"
|
||||
|
||||
def get_specialization(self):
|
||||
namespace = "ckb::TileConvSpecialization::"
|
||||
@@ -270,6 +273,8 @@ def parse_fwd_instances(instances, problem_name):
|
||||
print(f"Skipping instance {instance_id} with ASYNC_V4 since it's not supported yet.")
|
||||
continue
|
||||
|
||||
is_two_stage = False
|
||||
|
||||
conv = ConvInstanceTemplateParams(
|
||||
spec,
|
||||
[m_per_block, n_per_block, k_per_block],
|
||||
@@ -277,6 +282,7 @@ def parse_fwd_instances(instances, problem_name):
|
||||
[m_per_xdl, n_per_xdl, k_per_xdl],
|
||||
double_smem_buffer,
|
||||
num_wave_groups,
|
||||
is_two_stage,
|
||||
pipeline_version,
|
||||
scheduler,
|
||||
[a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector],
|
||||
@@ -343,7 +349,7 @@ def parse_bwd_weight_instances(instances, problem_name):
|
||||
num_groups_to_merge = 1
|
||||
|
||||
# Block GEMM pipeline parameters
|
||||
blk_gemm_pipeline_schduler = args[6]
|
||||
block_gemm_pipeline_scheduler = args[6]
|
||||
blk_gemm_pipeline_version = args[7]
|
||||
else:
|
||||
spec = args[11]
|
||||
@@ -372,20 +378,29 @@ def parse_bwd_weight_instances(instances, problem_name):
|
||||
num_groups_to_merge = int(args[44])
|
||||
|
||||
# Block GEMM pipeline parameters
|
||||
blk_gemm_pipeline_schduler = args[39]
|
||||
block_gemm_pipeline_scheduler = args[39]
|
||||
blk_gemm_pipeline_version = args[40]
|
||||
elif is_two_stage_instance:
|
||||
print(f"Skipping instance {instance_id} with device op {device_op_name} since it's not supported yet.")
|
||||
continue
|
||||
if len(args) != 46:
|
||||
raise RuntimeError(f"Wrong number of parameters in the TwoStage instance string: {instance}\n" +
|
||||
f"Expected 46 parameters for TwoStage instance. Found {len(args)} parameters.")
|
||||
|
||||
num_groups_to_merge = args[41]
|
||||
|
||||
# Block GEMM pipeline parameters
|
||||
block_gemm_pipeline_scheduler = args[39]
|
||||
blk_gemm_pipeline_version = args[40]
|
||||
|
||||
else:
|
||||
# Regular V1 XDL CShuffle instance
|
||||
if len(args) != 43:
|
||||
raise RuntimeError(f"Wrong number of parameters in the XDL CShuffle instance string: {instance}")
|
||||
raise RuntimeError(f"Wrong number of parameters in the XDL CShuffle instance string: {instance}\n" +
|
||||
f"Expected 43 parameters for V1 instance. Found {len(args)} parameters.")
|
||||
|
||||
num_groups_to_merge = 1
|
||||
|
||||
# Block GEMM pipeline parameters
|
||||
blk_gemm_pipeline_schduler = "Intrawave"
|
||||
block_gemm_pipeline_scheduler = "Intrawave"
|
||||
blk_gemm_pipeline_version = "v1"
|
||||
|
||||
# Common part to all solvers.
|
||||
@@ -393,15 +408,15 @@ def parse_bwd_weight_instances(instances, problem_name):
|
||||
# Sanity check for Block GEMM pipeline parameters
|
||||
# Scheduler must be either Intrawave or Interwave.
|
||||
# Version must be from v1 to v5
|
||||
if blk_gemm_pipeline_schduler not in ["Intrawave", "Interwave"]:
|
||||
raise RuntimeError(f"Invalid Block GEMM pipeline scheduler: {blk_gemm_pipeline_schduler} in instance: {instance}")
|
||||
if block_gemm_pipeline_scheduler not in ["Intrawave", "Interwave"]:
|
||||
raise RuntimeError(f"Invalid Block GEMM pipeline scheduler: {block_gemm_pipeline_scheduler} in instance: {instance}")
|
||||
if blk_gemm_pipeline_version not in ["v1", "v2", "v3", "v4", "v5"]:
|
||||
raise RuntimeError(f"Invalid Block GEMM pipeline version: {blk_gemm_pipeline_version} in instance: {instance}")
|
||||
|
||||
split_image = instance.find("Large") != -1
|
||||
double_smem_buffer = blk_gemm_pipeline_version == "v4"
|
||||
num_wave_groups = 1
|
||||
scheduler = blk_gemm_pipeline_schduler
|
||||
scheduler = block_gemm_pipeline_scheduler
|
||||
pipeline_version = blk_gemm_pipeline_version.upper()
|
||||
|
||||
# Old CK pipeline version V5 maps to V6 for CK Tile
|
||||
@@ -428,6 +443,7 @@ def parse_bwd_weight_instances(instances, problem_name):
|
||||
[m_per_xdl, n_per_xdl, k_per_xdl],
|
||||
double_smem_buffer,
|
||||
num_wave_groups,
|
||||
is_two_stage_instance,
|
||||
pipeline_version,
|
||||
scheduler,
|
||||
[a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector],
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
namespace ckf = ck_tile::builder::factory;
|
||||
|
||||
namespace ck_tile::builder::profiling {
|
||||
|
||||
|
||||
@@ -1,7 +1,21 @@
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using ConvInstance = Builder::Instance;
|
||||
|
||||
auto conv = ConvInstance{};
|
||||
|
||||
auto result = [&]<auto Sig, auto Alg>() {
|
||||
if constexpr(ConvDirectionIsBackwardWeight<Sig> && Alg.optimizations.two_stage)
|
||||
{
|
||||
using ElementwiseOpBuilder = ckf::ElementwiseOpTileFactory<Sig, Alg>;
|
||||
using ElementwiseOpInstance = ElementwiseOpBuilder::Instance;
|
||||
auto elementwise_op = ElementwiseOpInstance{};
|
||||
return ckt::run(conv, elementwise_op, args, inputs, outputs, s_conf);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ckt::run(conv, args, inputs, outputs, s_conf);
|
||||
}
|
||||
}.template operator()<SIGNATURE, ALGORITHM>();
|
||||
|
||||
auto conv = Instance{};
|
||||
ckt::RunResult result = ckt::run(conv, args, inputs, outputs, s_conf);
|
||||
return std::make_tuple(result.is_supported(), result.runtime, conv.GetInstanceString());
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp"
|
||||
@@ -108,6 +109,19 @@ struct ElementWiseKernel
|
||||
ignore = input_sizes;
|
||||
return true;
|
||||
}
|
||||
|
||||
/// @brief Builds the kernel's name from the literal "elementwise_kernel", the problem name,
///        the literal "policy" and the policy name, combined by `concat` with '_' as its
///        first argument.
/// @returns The composed name string.
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "elementwise_kernel",
|
||||
Problem::GetName(),
|
||||
"policy",
|
||||
Policy::GetName()
|
||||
);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
/// @brief Returns the same string as GetName().
[[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -24,6 +24,11 @@ struct ElementWiseDefaultPolicy
|
||||
sequence<0, 3>>{} // Yield
|
||||
);
|
||||
}
|
||||
|
||||
/// @brief Returns the fixed policy name "ElementWiseDefaultPolicy".
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
return "ElementWiseDefaultPolicy";
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -21,6 +22,19 @@ struct ElementWisePipelineProblem
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
using ElementWiseOperation = remove_cvref_t<ElementWiseOperation_>;
|
||||
static constexpr bool kPad = kPad_;
|
||||
|
||||
/// @brief Builds the problem's name from the block-shape name, the literal "op", the
///        elementwise operation's `name`, the literal "kPad" and the `kPad` flag, combined
///        by `concat` with '_' as its first argument.
/// @returns The composed name string.
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_',
|
||||
BlockShape::GetName(),
|
||||
"op",
|
||||
ElementWiseOperation::name,
|
||||
"kPad",
|
||||
kPad
|
||||
);
|
||||
// clang-format on
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -25,6 +26,15 @@ struct ElementWiseShape
|
||||
|
||||
static constexpr index_t kBlockSize =
|
||||
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
|
||||
|
||||
/// @brief Builds the shape's name from the literal "shape" followed by the shape constants,
///        combined by `concat` with '_' as its first argument.
/// @returns The composed name string.
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
// Field order: kBlockM, kWarpM, kVectorM, kWarpPerBlockM, kThreadPerWarpM, kRepeatM, kBlockSize.
return concat('_', "shape",
|
||||
kBlockM, kWarpM, kVectorM, kWarpPerBlockM, kThreadPerWarpM, kRepeatM, kBlockSize
|
||||
);
|
||||
// clang-format on
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user