Merge commit '4a49dac7c6fff9ffe4d275bed761a79e51188f3c' into develop

This commit is contained in:
assistant-librarian[bot]
2025-08-28 13:19:19 +00:00
parent b36f983558
commit c8234bd387
6 changed files with 30 additions and 16 deletions

View File

@@ -78,7 +78,6 @@ float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
CDEElementWise,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
@@ -98,8 +97,8 @@ float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -121,7 +120,7 @@ float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -77,10 +77,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
// we intentionally do not use pipeline for this example and let the kernel be composite of
// Problem and Policy
constexpr ck_tile::index_t kBlockSize = Shape::BlockSize;
auto blockSize = Kernel::BlockSize();
// Print configuration information
std::cout << "block size (number of threads per block) " << kBlockSize << std::endl;
std::cout << "block size (number of threads per block) " << blockSize << std::endl;
std::cout << "wave size (number of threads per wave) " << ck_tile::get_warp_size() << std::endl;
std::cout << "block waves (number of waves per block) " << BlockWaves::at(ck_tile::number<0>{})
<< " " << BlockWaves::at(ck_tile::number<1>{}) << std::endl;
@@ -103,7 +103,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
launch_kernel(ck_tile::stream_config{nullptr, true, warmup, repeat, 1},
ck_tile::make_kernel<1>(Kernel{},
kGridSize,
kBlockSize,
blockSize,
0,
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),

View File

@@ -27,8 +27,9 @@ struct TileCopyShape
static constexpr index_t ThreadTile_N = ThreadTile::at(number<1>{});
// Wave tile dimensions
static constexpr index_t Wave_Tile_M = WaveTile::at(number<0>{});
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t Wave_Tile_N = WaveTile::at(number<1>{});
static constexpr index_t Wave_Tile_M = ThreadTile_M * ThreadTile_N * WaveSize / Wave_Tile_N;
// Block tile dimensions
static constexpr index_t Block_Tile_M = BlockTile::at(number<0>{});
@@ -45,7 +46,6 @@ struct TileCopyShape
Block_Tile_N / (Waves_Per_Block_N * Wave_Tile_N);
// Hardware configuration
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t BlockSize = Waves_Per_Block_M * Waves_Per_Block_N * WaveSize;
// Configuration validation
@@ -60,8 +60,10 @@ struct TileCopyShape
"Invalid wave configuration for N dimension");
// Ensure wave tile dimensions align with wave size
#if defined(__HIP_DEVICE_COMPILE__)
static_assert(Wave_Tile_M / ThreadTile_M * Wave_Tile_N / ThreadTile_N == WaveSize,
"(Wave_Tile_M/ThreadTile_M) * (Wave_Tile_N/ThreadTile_N) != WaveSize");
#endif
};
/**
@@ -200,6 +202,19 @@ struct ElementWiseTileCopyKernel
using XDataType = typename Problem::XDataType;
using Policy = ck_tile::remove_cvref_t<Policy_>;
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
CK_TILE_HOST static auto BlockSize()
{
if(ck_tile::is_wave32())
{
return kBlockSize / 2;
}
else
{
return kBlockSize;
}
}
CK_TILE_DEVICE void operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N) const
{
using S = typename Problem::BlockShape;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -130,7 +130,7 @@ struct FusedMoeGemmKernel
// static_assert(kBlockPerCu > 0);
using BlockShape = typename Pipeline::BlockShape; // this is FusedMoeGemmShape
static constexpr index_t BlockSize_ = BlockShape::BlockSize;
static constexpr index_t kBlockSize = BlockShape::BlockSize;
using ADataType = typename Pipeline::Problem::ADataType;
using GDataType = typename Pipeline::Problem::GDataType;
@@ -231,7 +231,7 @@ struct FusedMoeGemmKernel
return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(BlockSize_); }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }

View File

@@ -1,5 +1,5 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -487,7 +487,7 @@ struct GroupedConvolutionBackwardDataKernel
using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
using InDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using WeiDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
@@ -530,7 +530,7 @@ struct GroupedConvolutionBackwardDataKernel
return dim3(kargs.grid_size_, kargs.GemmBatch, kargs.k_batch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST static constexpr GroupedConvBwdDataKernelArgsSpecialized
MakeKernelArgs(const GroupedConvBwdDataHostArgs& hostArgs)