mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Create invoker through the CK builder.
This commit is contained in:
@@ -2,15 +2,11 @@
|
||||
#include "ck_tile/builder/conv_builder.hpp"
|
||||
#include "../utils/types.hpp"
|
||||
|
||||
int main() {
|
||||
|
||||
int main()
|
||||
{
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckb_examples = ck_tile::builder::examples;
|
||||
|
||||
constexpr size_t m_tile = 128;
|
||||
constexpr size_t n_tile = 128;
|
||||
constexpr size_t k_tile = 32;
|
||||
|
||||
constexpr ckb_examples::ConvSignature FwdConvSignature
|
||||
{
|
||||
.spatial_dim = 2,
|
||||
@@ -20,10 +16,20 @@ int main() {
|
||||
};
|
||||
static_assert(ckb::ValidConvSignature<FwdConvSignature>);
|
||||
|
||||
// To get valid configuration parameters, refer to "device_grouped_conv_fwd_xdl_comp_instance.hpp".
|
||||
// This file contains the current instances of the kernel DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3.
|
||||
// Currently the builder has this kernel hard-coded.
|
||||
// In the future, we may need builders per kernel type since they typically have slightly different parameters.
|
||||
|
||||
constexpr ckb::ThreadBlock FwdThreadBlock
|
||||
{
|
||||
.block_size = 256,
|
||||
.submatrix = {.m = m_tile, .n = n_tile, .k = k_tile}
|
||||
.submatrix = {.m = 256, .n = 256, .k = 32} // Tile sizes
|
||||
};
|
||||
|
||||
constexpr ckb::ConvTuningParams FwdTuningParams
|
||||
{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl=32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4
|
||||
};
|
||||
|
||||
constexpr ckb_examples::BlockTransfer FwdBlockTransfer
|
||||
@@ -33,35 +39,36 @@ int main() {
|
||||
.thread_cluster_dims_c = {
|
||||
.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
.vector_transfer_a = {
|
||||
.src_vector_dim = 2, .src_scaler_per_vector = 8, .dest_scaler_per_vector_k1 = 8, .add_extra = true},
|
||||
.src_vector_dim = 2, .src_scaler_per_vector = 2, .dest_scaler_per_vector_k1 = 8, .add_extra = false},
|
||||
.vector_transfer_b = {
|
||||
.src_vector_dim = 2, .src_scaler_per_vector = 8, .dest_scaler_per_vector_k1 = 8, .add_extra = true},
|
||||
.src_vector_dim = 2, .src_scaler_per_vector = 8, .dest_scaler_per_vector_k1 = 8, .add_extra = false},
|
||||
.vector_transfer_c = {
|
||||
.m_xdl_per_wave_per_shuffle = 1, .n_xdl_per_wave_per_shuffle = 2, .scaler_per_vector = 8},
|
||||
.m_xdl_per_wave_per_shuffle = 1, .n_xdl_per_wave_per_shuffle = 1, .scaler_per_vector = 8},
|
||||
.a_thread_cluster_access_order = {1, 0, 2},
|
||||
.b_thread_cluster_access_order = {1, 0, 2},
|
||||
.a_source_access_order = {1, 0, 2},
|
||||
.b_source_access_order = {1, 0, 2}
|
||||
};
|
||||
|
||||
|
||||
constexpr ckb_examples::ConvAlgorithm FwdConvAlgorithm
|
||||
{
|
||||
.thread_block = FwdThreadBlock,
|
||||
.tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 1, .n_xdl_per_wave = 4},
|
||||
.tuning_params = FwdTuningParams,
|
||||
.block_transfer = FwdBlockTransfer,
|
||||
.pipeline_version = ckb::BlockGemmPipelineVersion::V1,
|
||||
.pipeline_version = ckb::BlockGemmPipelineVersion::V4,
|
||||
};
|
||||
|
||||
using Builder = ckb::ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
const auto kernel_string = Builder::Instance::TypeString();
|
||||
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
|
||||
// The invoker is the entrypoint to launch the kernel.
|
||||
// Creating the invoker triggers the validation of the builder configuration.
|
||||
//auto invoker = Builder::Instance::MakeInvoker();
|
||||
//(void)invoker;
|
||||
// Creating the invoker triggers the validation of the builder configuration,
|
||||
// that is, the combination of all builder parameters is checked at compile time.
|
||||
auto invoker = Builder::Instance::MakeInvoker();
|
||||
|
||||
// TODO: Prepare actual data and launch the kernel.
|
||||
(void)invoker;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -50,6 +50,8 @@ template <typename T>
|
||||
concept ConvTuningDescriptor = requires(T t) {
|
||||
{ t.ak1 } -> std::convertible_to<int>;
|
||||
{ t.bk1 } -> std::convertible_to<int>;
|
||||
{ t.m_per_xdl } -> std::convertible_to<int>;
|
||||
{ t.n_per_xdl } -> std::convertible_to<int>;
|
||||
{ t.m_xdl_per_wave } -> std::convertible_to<int>;
|
||||
{ t.n_xdl_per_wave } -> std::convertible_to<int>;
|
||||
};
|
||||
@@ -60,6 +62,8 @@ struct ConvTuningParams
|
||||
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
|
||||
int ak1 = 0;
|
||||
int bk1 = 0;
|
||||
int m_per_xdl = 0;
|
||||
int n_per_xdl = 0;
|
||||
int m_xdl_per_wave = 0;
|
||||
int n_xdl_per_wave = 0;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user