Merge commit '9fcc1ee9fd9730efd865f530afde505f2556954d' into develop

This commit is contained in:
assistant-librarian[bot]
2025-08-18 17:12:50 +00:00
parent d436787ed0
commit 68b20e1d4f
113 changed files with 610 additions and 531 deletions

View File

@@ -66,5 +66,5 @@ float add_rmsnorm2d_rdquant_fwd_(const S& s, A a)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}

View File

@@ -111,7 +111,6 @@ class TestCkTileBatchedGemm : public ::testing::Test
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
@@ -124,8 +123,8 @@ class TestCkTileBatchedGemm : public ::testing::Test
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -144,7 +143,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
}
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -137,11 +137,11 @@ class TestCkTileBatchedTranspose // N C H W layout_in==
Config::BlockTile::at(1)};
auto kargs = Kernel::MakeKargs(host_args);
auto sc = ck_tile::stream_config{};
const dim3 grid_size = Kernel::GridSize(host_args);
constexpr dim3 block_size = Kernel::BlockSize();
ck_tile::launch_kernel(
sc, ck_tile::make_kernel<block_size.x, 1>(Kernel{}, grid_size, block_size, 0, kargs));
auto sc = ck_tile::stream_config{};
const dim3 grid_size = Kernel::GridSize(host_args);
const dim3 block_size = Kernel::BlockSize();
ck_tile::launch_kernel(sc,
ck_tile::make_kernel<1>(Kernel{}, grid_size, block_size, 0, kargs));
y_dev.FromDevice(y_host.data());
ck_tile::reference_batched_transpose<DataType>(x_host, y_ref, layout_in, layout_out);

View File

@@ -118,19 +118,17 @@ class TestCkTileElementwise : public ::testing::Test
"The kernel configuration is not supported for the given input size.");
}
ck_tile::launch_kernel(
s,
ck_tile::make_kernel<TestElementWiseShape::kBlockSize, // MaxThreadPerBlock
kBlockPerCu> // MinBlockPerCu
(ew_kernel,
grid,
block,
0, // actual shared memory
lens,
strides, // input strides
strides, // output strides
d_x_ptrs_tuple,
p_y_device));
ck_tile::launch_kernel(s,
ck_tile::make_kernel<kBlockPerCu> // MinBlockPerCu
(ew_kernel,
grid,
block,
0, // actual shared memory
lens,
strides, // input strides
strides, // output strides
d_x_ptrs_tuple,
p_y_device));
d_y_mem.FromDevice(h_y.data());

View File

@@ -77,7 +77,6 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
@@ -93,8 +92,8 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -114,7 +113,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -91,7 +91,6 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
DsLayout,
ELayout,
CDEElementWise,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
@@ -114,7 +113,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -165,15 +164,13 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
ave_time = ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};

View File

@@ -10,6 +10,7 @@
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/core/numeric/math.hpp"
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
@@ -184,7 +185,6 @@ class TestCkTileGemmPipeline : public ::testing::Test
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
GemmPipeline::BlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
@@ -207,7 +207,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
constexpr dim3 blocks = Kernel::BlockSize();
dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -222,7 +222,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
}
ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {

View File

@@ -99,7 +99,6 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
@@ -114,8 +113,8 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(args.k_batch != 1)
{
@@ -139,7 +138,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -178,7 +178,6 @@ class TestCkTileGemmMultiD : public ::testing::Test
DsLayout,
ELayout,
CDEElementWise,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
@@ -192,8 +191,8 @@ class TestCkTileGemmMultiD : public ::testing::Test
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -212,7 +211,7 @@ class TestCkTileGemmMultiD : public ::testing::Test
}
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -183,7 +183,6 @@ class TestCkTileGemmPipeline : public ::testing::Test
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
GemmPipeline::BlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
@@ -206,7 +205,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -221,7 +220,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
}
ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {

View File

@@ -136,7 +136,6 @@ class TestCkTileGroupedGemm : public ::testing::Test
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GroupedGemKernelParam::M_Warp,
@@ -150,8 +149,8 @@ class TestCkTileGroupedGemm : public ::testing::Test
auto kargs = Kernel::MakeKargs(gemm_descs);
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs));
const dim3 grids = Kernel::GridSize(gemm_descs);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(gemm_descs);
const dim3 blocks = Kernel::BlockSize();
ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
@@ -169,7 +168,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<blocks.x, GroupedGemKernelParam::kBlockPerCu>(
ck_tile::make_kernel<GroupedGemKernelParam::kBlockPerCu>(
Kernel{},
grids,
blocks,
@@ -227,12 +226,6 @@ class TestCkTileGroupedGemm : public ::testing::Test
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<GroupedGemKernelParam::kPadM,
GroupedGemKernelParam::kPadN,
GroupedGemKernelParam::kPadK,
ALayout,
BLayout,
CLayout>;
using GemmUniversalTraits =
ck_tile::PersistentTileGemmUniversalTraits<GroupedGemKernelParam::kPadM,
GroupedGemKernelParam::kPadN,
@@ -242,8 +235,6 @@ class TestCkTileGroupedGemm : public ::testing::Test
BLayout,
CLayout,
TransposeC>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
@@ -268,7 +259,6 @@ class TestCkTileGroupedGemm : public ::testing::Test
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GroupedGemKernelParam::M_Warp,
@@ -279,8 +269,8 @@ class TestCkTileGroupedGemm : public ::testing::Test
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
if(s.log_level_ > 0)
{
@@ -291,7 +281,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
}
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, kBlockPerCu>(
ck_tile::make_kernel<kBlockPerCu>(
Kernel{},
grids,
blocks,

View File

@@ -97,13 +97,13 @@ class TestCkTileImageToColumn : public ::testing::Test
kargs.N * kargs.output_spatial_lengths[0] * kargs.output_spatial_lengths[1],
kargs.filter_spatial_lengths[0] * kargs.filter_spatial_lengths[1] * kargs.C,
kargs.G);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 2;
ck_tile::launch_kernel(
ck_tile::stream_config{},
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
// reference
ck_tile::reference_im2col<DataType, DataType, NDimSpatial>(in, out_host, conv_params);

View File

@@ -235,7 +235,7 @@ float layernorm2d_fwd_(const S& s, A a)
using Kernel = ck_tile::Layernorm2dFwd<Pipeline, Epilogue>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
@@ -243,7 +243,7 @@ float layernorm2d_fwd_(const S& s, A a)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
}}
"""

View File

@@ -76,17 +76,17 @@ class TestCkTileMemoryCopy : public ::testing::TestWithParam<std::tuple<int, int
constexpr ck_tile::index_t kBlockSize = 128;
constexpr ck_tile::index_t kBlockPerCu = 1;
auto ms = launch_kernel(ck_tile::stream_config{nullptr, true},
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
Kernel{},
kGridSize,
kBlockSize,
0,
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
m,
n,
warp_id));
auto ms = launch_kernel(
ck_tile::stream_config{nullptr, true},
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
kGridSize,
kBlockSize,
0,
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
m,
n,
warp_id));
auto bytes = 2 * m * n * sizeof(DataType);
std::cout << "elapsed: " << ms << " (ms)" << std::endl;

View File

@@ -64,7 +64,8 @@ struct TileCopy
using Problem = ck_tile::remove_cvref_t<Problem_>;
using XDataType = typename Problem::XDataType;
static constexpr bool AsyncCopy = Problem::AsyncCopy;
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
static constexpr bool AsyncCopy = Problem::AsyncCopy;
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeDRAMDistribution()

View File

@@ -61,5 +61,5 @@ float moe_smoothquant_(const S& s, A a)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}

View File

@@ -209,7 +209,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
@@ -227,7 +227,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
@@ -283,7 +283,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_size = kernel::GetSmemSize(a); \
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, lds_size, kargs); \
return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \
}()
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
@@ -334,15 +334,15 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
} \
}
#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \
[&]() { \
using problem_ = \
ck_tile::MoeSortingClearWorkspaceProblem<is_local_token_, block_size_, occu_>; \
using kernel = ck_tile::MoeSortingClearWorkspaceKernel<problem_>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \
[&]() { \
using problem_ = \
ck_tile::MoeSortingClearWorkspaceProblem<is_local_token_, block_size_, occu_>; \
using kernel = ck_tile::MoeSortingClearWorkspaceKernel<problem_>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)

View File

@@ -1,5 +1,5 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -115,11 +115,12 @@ struct matrix_core_swizzle_kernel
__host__ void operator()(const ck_tile::stream_config& s) const
{
ck_tile::kentry<BLOCK_SIZE, 1, kernel><<<grids, BLOCK_SIZE, 0, s.stream_id_>>>(a);
ck_tile::kentry<1, kernel><<<grids, BLOCK_SIZE, 0, s.stream_id_>>>(a);
}
struct kernel
{
static constexpr ck_tile::index_t kBlockSize = BLOCK_SIZE;
__device__ static constexpr auto get_src_dist()
{
using namespace ck_tile;

View File

@@ -54,11 +54,11 @@ float permute(permute_args a, const ck_tile::stream_config& s)
auto kargs = Kernel::MakeKargs(a);
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(a);
const dim3 blocks = Kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
float ave_time =
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}

View File

@@ -82,18 +82,18 @@ class TestCkTileReduce : public ::testing::Test
throw std::runtime_error("Wrong! Arguments not supported!\n");
}
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0},
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
Kernel{},
kGridSize,
kBlockSize,
0,
static_cast<XDataType*>(d_x_mem.GetDeviceBuffer()),
static_cast<YDataType*>(d_y_mem.GetDeviceBuffer()),
input_shape_tuple,
input_strides_tuple,
kept_dims,
reduce_dims));
ck_tile::launch_kernel(
ck_tile::stream_config{nullptr, false, 0},
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
kGridSize,
kBlockSize,
0,
static_cast<XDataType*>(d_x_mem.GetDeviceBuffer()),
static_cast<YDataType*>(d_y_mem.GetDeviceBuffer()),
input_shape_tuple,
input_strides_tuple,
kept_dims,
reduce_dims));
// Get results back
d_y_mem.FromDevice(h_y.data());

View File

@@ -246,7 +246,7 @@ float rmsnorm2d_fwd_(const S& s, A a)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
}}
"""

View File

@@ -57,5 +57,5 @@ float smoothquant_(const S& s, A a)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}

View File

@@ -13,11 +13,11 @@
\
auto kargs = kernel::MakeKargs(a); \
\
const dim3 grids = kernel::GridSize(a); \
constexpr dim3 blocks = kernel::BlockSize(); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(); \
\
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs)); \
float ave_time = \
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \
\
return ave_time;