Merge commit '9fcc1ee9fd9730efd865f530afde505f2556954d' into develop

This commit is contained in:
assistant-librarian[bot]
2025-08-18 17:12:50 +00:00
parent d436787ed0
commit 68b20e1d4f
113 changed files with 610 additions and 531 deletions

View File

@@ -60,13 +60,30 @@ enum struct memory_operation_enum : std::uint16_t
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
{
#if defined(__GFX9__) || (!defined(__HIP_DEVICE_COMPILE__) && !defined(CK_TILE_WAVE32_ENABLED))
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
return 64;
#else
return 32;
#endif
}
CK_TILE_HOST bool is_wave32()
{
hipDeviceProp_t props{};
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
{
return false;
}
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess)
{
return false;
}
return props.major > 9;
}
CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; }
CK_TILE_DEVICE index_t get_block_size() { return blockDim.x; }

View File

@@ -274,12 +274,6 @@
#define CK_TILE_WA_ISSUE_2028 0
#endif
#ifndef CK_TILE_WAVE32_ENABLED
#if defined(__gfx11__) || defined(__gfx12__)
#define CK_TILE_WAVE32_ENABLED
#endif
#endif
// Y pointed to R, we don't see a valuable use case.
// Will enforce encoding to check Y not pointed to R if set to zero
#ifndef CK_TILE_ENC_SUPPORT_Y_TO_R

View File

@@ -15,9 +15,9 @@
namespace ck_tile {
template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename... Args>
template <int MinBlockPerCu, typename Kernel, typename... Args>
#if CK_TILE_USE_LAUNCH_BOUNDS
__launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
__launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
#endif
__global__ void kentry(Args... args)
{
@@ -35,15 +35,11 @@ __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
//
// the "static __device__ operator()(some_arg)" is the entry point of KernelImpl
//
template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
typename KernelImpl,
typename... Args>
template <int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU, typename KernelImpl, typename... Args>
CK_TILE_HOST auto
make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
{
const auto kernel = kentry<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;
const auto kernel = kentry<MinBlockPerCu, KernelImpl, Args...>;
return [=](const stream_config& s) {
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
};

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -53,6 +53,7 @@ struct AddRmsnorm2dRdquantFwd
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};

View File

@@ -34,6 +34,8 @@ struct BatchedTransposeKernel
using Type = typename Problem::DataType;
static constexpr index_t kBlockSize = Problem::kBlockSize;
struct BatchedTransposeKargs
{
const void* p_input;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -20,11 +20,10 @@ struct BatchedTransposeLdsProblem
static constexpr index_t kRowWarps_ = NumWarps::at(number<0>{});
static constexpr index_t kColWarps_ = NumWarps::at(number<1>{});
static constexpr index_t kBlockSize_ = get_warp_size() * kRowWarps_ * kColWarps_;
static constexpr index_t kRowPerBlock_ = BlockTile::at(number<0>{});
static constexpr index_t kColPerBlock_ = BlockTile::at(number<1>{});
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kBlockSize = get_warp_size() * kRowWarps_ * kColWarps_;
// warps per block
static constexpr index_t kLeadNumWarps = kColWarps_;
static constexpr index_t kSecondNumWarps = kRowWarps_;

View File

@@ -20,6 +20,8 @@ struct ElementWiseKernel
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using ElementWiseOperation = ck_tile::remove_cvref_t<typename Problem::ElementWiseOperation>;
static constexpr index_t kBlockSize = Problem::BlockShape::kBlockSize;
template <typename... XDataType, typename Dims>
CK_TILE_DEVICE void operator()(Dims lens,
Dims input_strides,

View File

@@ -17,7 +17,6 @@ template <typename ADataType_,
typename DsLayout_,
typename ELayout_,
typename CDElementwise_,
index_t kBlockSize_,
index_t kM_,
index_t kN_,
index_t MWave_,
@@ -40,7 +39,7 @@ struct CShuffleEpilogueProblem
using DsLayout = remove_cvref_t<DsLayout_>;
using ELayout = remove_cvref_t<ELayout_>;
using CDElementwise = remove_cvref_t<CDElementwise_>;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
static constexpr index_t MWave = MWave_;

View File

@@ -91,13 +91,13 @@ struct FlatmmKernel
using FlatmmPipeline = remove_cvref_t<FlatmmPipeline_>;
using BlockGemmShape =
remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>; // TileFlatmmShape
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
static constexpr index_t kBlockSize = FlatmmPipeline::BlockSize;
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
@@ -127,7 +127,7 @@ struct FlatmmKernel
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST static constexpr KernelArgs
MakeKernelArgs(const FlatmmHostArgs<NumDTensor>& hostArgs)

View File

@@ -237,15 +237,16 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
{
using TileShape = typename Problem::BlockGemmShape;
using TileShape = typename Problem::BlockGemmShape;
constexpr index_t scale = get_warp_size() == 32 ? 2 : 1;
if constexpr(TileShape::WarpTile::at(I1) == 32)
{
return TileShape::WarpTile::at(I2) / 2;
return TileShape::WarpTile::at(I2) * scale / 2;
}
else
{
static_assert(TileShape::WarpTile::at(I1) == 16);
return TileShape::WarpTile::at(I2) / 4;
return TileShape::WarpTile::at(I2) * scale / 4;
}
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -24,9 +24,10 @@ namespace ck_tile {
template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaBatchPrefillWithPagedKVCacheKernel
{
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -16,6 +16,7 @@ struct FmhaFwdAppendKVKernel
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;

View File

@@ -25,9 +25,10 @@ namespace ck_tile {
template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdKernel
{
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;

View File

@@ -30,6 +30,7 @@ struct FmhaFwdPagedKVKernel
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -14,6 +14,7 @@ struct FmhaFwdSplitKVCombineKernel
static constexpr index_t kNumWarps = FmhaPipeline::kNumWarps;
static constexpr index_t kBlockSize = FmhaPipeline::kBlockSize;
static constexpr index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
static constexpr index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -26,6 +26,7 @@ struct FmhaFwdSplitKVKernel
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -213,7 +213,7 @@ struct MoeSortingKernel
using Hargs = MoeSortingHostArgs;
static constexpr index_t BLOCK_SIZE = 256;
static constexpr index_t kBlockSize = 256;
static constexpr index_t OCCUPANCY = 2; // hard coded
struct Kargs
@@ -487,8 +487,8 @@ struct MoeSortingKernel
vector_type* p_buf = reinterpret_cast<vector_type*>(buf);
auto zero_ = vector_type{0};
for(long_index_t i = (blockIdx.x - 1) * BLOCK_SIZE + threadIdx.x; i < total_elems;
i += (gridDim.x - 1) * BLOCK_SIZE)
for(long_index_t i = (blockIdx.x - 1) * kBlockSize + threadIdx.x; i < total_elems;
i += (gridDim.x - 1) * kBlockSize)
{
p_buf[i] = zero_;
}
@@ -1419,7 +1419,7 @@ template <typename Problem_>
struct MoeSortingClearWorkspaceKernel
{
using Problem = remove_cvref_t<Problem_>;
static constexpr index_t BLOCK_SIZE = Problem::BlockSize;
static constexpr index_t kBlockSize = Problem::BlockSize;
static constexpr index_t OCCUPANCY = Problem::Occu;
using Hargs = MoeSortingHostArgs;
@@ -1461,7 +1461,7 @@ struct MoeSortingClearWorkspaceKernel
CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; }
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
// in byte
CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; }
@@ -1499,8 +1499,8 @@ struct MoeSortingClearWorkspaceKernel
vector_type* p_expert_mesh = reinterpret_cast<vector_type*>(kargs.p_expert_mesh);
auto zero_ = vector_type{0};
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elems;
i += gridDim.x * BLOCK_SIZE)
for(index_t i = blockIdx.x * kBlockSize + threadIdx.x; i < total_elems;
i += gridDim.x * kBlockSize)
{
p_expert_mesh[i] = zero_;
}
@@ -1560,7 +1560,7 @@ struct MoeSortingMultiPhaseKernel_P0
using WeightType = typename Problem::WeightType;
using MeshType = typename Problem::MeshType;
static constexpr index_t BLOCK_SIZE = 256;
static constexpr index_t kBlockSize = 256;
static constexpr index_t OCCUPANCY = 2; // hard coded
typedef MoeSortingHostArgs MoeSortingKargs;
@@ -1604,7 +1604,7 @@ struct MoeSortingMultiPhaseKernel_P0
CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; }
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
// in byte
CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; }
@@ -1647,8 +1647,8 @@ struct MoeSortingMultiPhaseKernel_P0
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
#pragma unroll Problem::SubTokenTile
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem;
i += gridDim.x * BLOCK_SIZE)
for(index_t i = blockIdx.x * kBlockSize + threadIdx.x; i < total_elem;
i += gridDim.x * kBlockSize)
{
auto x = p_topk_ids[i];
static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) {
@@ -1678,7 +1678,7 @@ struct MoeSortingMultiPhaseKernel_P1
using WeightType = typename Problem::WeightType;
using MeshType = typename Problem::MeshType;
static constexpr index_t BLOCK_SIZE = 256;
static constexpr index_t kBlockSize = 256;
static constexpr index_t OCCUPANCY = 2; // hard coded
typedef MoeSortingHostArgs MoeSortingKargs;
@@ -1709,12 +1709,12 @@ struct MoeSortingMultiPhaseKernel_P1
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { return dim3(h.num_experts); }
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
// in byte
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
{
return BLOCK_SIZE / get_warp_size() * sizeof(IndexType);
return kBlockSize / get_warp_size() * sizeof(IndexType);
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
@@ -1756,7 +1756,7 @@ struct MoeSortingMultiPhaseKernel_P1
r_t* p_expert_mesh = reinterpret_cast<r_t*>(
reinterpret_cast<MeshType*>(kargs.p_expert_mesh) + eid * mesh_stride);
int loops = (mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
int loops = (mesh_stride / index_pack + kBlockSize - 1) / kBlockSize;
if constexpr(Problem::LocalExpertMasking)
{
@@ -1768,7 +1768,7 @@ struct MoeSortingMultiPhaseKernel_P1
index_t cnt = 0; // per-wave cnt
for(int i = 0; i < loops; i++)
{
int position = i * BLOCK_SIZE + threadIdx.x;
int position = i * kBlockSize + threadIdx.x;
r_t v{0};
if(position < (mesh_stride / index_pack))
v = p_expert_mesh[position];
@@ -1792,7 +1792,7 @@ struct MoeSortingMultiPhaseKernel_P1
if(threadIdx.x == 0)
{
index_t c = 0;
for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++)
for(auto i = 0; i < (kBlockSize / get_warp_size()); i++)
{
c += s[i];
}
@@ -1811,7 +1811,7 @@ struct MoeSortingMultiPhaseKernel_P01
using WeightType = typename Problem::WeightType;
using MeshType = typename Problem::MeshType;
static constexpr index_t BLOCK_SIZE = 256;
static constexpr index_t kBlockSize = 256;
static constexpr index_t OCCUPANCY = 2; // hard coded
typedef MoeSortingHostArgs MoeSortingKargs;
@@ -1878,12 +1878,12 @@ struct MoeSortingMultiPhaseKernel_P01
CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; }
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
CK_TILE_HOST static constexpr auto WGCounts(const Hargs& h)
{
index_t total_elem = h.tokens * h.topk / Problem::SubTokenTile;
index_t elem_cnt = (total_elem + BLOCK_SIZE - 1) / BLOCK_SIZE;
index_t elem_cnt = (total_elem + kBlockSize - 1) / kBlockSize;
// no more than grid_size
return min(elem_cnt, GridSize(h));
@@ -1892,7 +1892,7 @@ struct MoeSortingMultiPhaseKernel_P01
// in byte
CK_TILE_HOST static constexpr auto GetSmemSize()
{
return BLOCK_SIZE / get_warp_size() * sizeof(IndexType);
return kBlockSize / get_warp_size() * sizeof(IndexType);
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
@@ -1921,7 +1921,7 @@ struct MoeSortingMultiPhaseKernel_P01
if constexpr(Problem::LocalToken)
{
index_t total_elem = rounded_tokens * kargs.topk / Problem::SubTokenTile;
index_t elem_cnt = (total_elem + BLOCK_SIZE - 1) / BLOCK_SIZE;
index_t elem_cnt = (total_elem + kBlockSize - 1) / kBlockSize;
// no more than grid_size
return min(elem_cnt, kargs.wg_count);
@@ -1940,8 +1940,8 @@ struct MoeSortingMultiPhaseKernel_P01
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
#pragma unroll Problem::SubTokenTile
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem;
i += BLOCK_SIZE * gridDim.x)
for(index_t i = blockIdx.x * kBlockSize + threadIdx.x; i < total_elem;
i += kBlockSize * gridDim.x)
{
auto x = p_topk_ids[i];
static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) {
@@ -1996,7 +1996,7 @@ struct MoeSortingMultiPhaseKernel_P01
auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
int loops = (kargs.mesh_stride / index_pack + kBlockSize - 1) / kBlockSize;
if constexpr(Problem::LocalExpertMasking)
{
@@ -2008,7 +2008,7 @@ struct MoeSortingMultiPhaseKernel_P01
index_t cnt = 0; // per-wave cnt
for(int i = 0; i < loops; i++)
{
int position = i * BLOCK_SIZE + threadIdx.x;
int position = i * kBlockSize + threadIdx.x;
r_t v{0};
if(position < (kargs.mesh_stride / index_pack))
v = p_expert_mesh[position];
@@ -2033,7 +2033,7 @@ struct MoeSortingMultiPhaseKernel_P01
if(threadIdx.x == 0)
{
index_t c = 0;
for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++)
for(auto i = 0; i < (kBlockSize / get_warp_size()); i++)
{
c += s[i];
}
@@ -2055,7 +2055,7 @@ struct MoeSortingMultiPhaseKernel_P2
using WeightType = typename Problem::WeightType;
using MeshType = typename Problem::MeshType;
static constexpr index_t BLOCK_SIZE = 256;
static constexpr index_t kBlockSize = 256;
static constexpr index_t OCCUPANCY = 2; // hard coded
typedef MoeSortingHostArgs MoeSortingKargs;
@@ -2123,17 +2123,17 @@ struct MoeSortingMultiPhaseKernel_P2
return dim3(h.num_experts + get_num_cu() * OCCUPANCY);
#else
// use 1 block to cumsum
return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, kBlockSize * 16));
#endif
}
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
// in byte
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
{
// return 2 * BLOCK_SIZE * sizeof(IndexType);
return (4 + 2 * BLOCK_SIZE / get_warp_size()) * sizeof(IndexType);
// return 2 * kBlockSize * sizeof(IndexType);
return (4 + 2 * kBlockSize / get_warp_size()) * sizeof(IndexType);
}
// reduce single pixel within a wave
@@ -2142,7 +2142,7 @@ struct MoeSortingMultiPhaseKernel_P2
if(blockIdx.x > 0)
{
#if MOE_SORTING_FMOE_2D_BUF
impl::moe_buf_set_zero_kernel_2d<BLOCK_SIZE>(kargs.p_moe_buf,
impl::moe_buf_set_zero_kernel_2d<kBlockSize>(kargs.p_moe_buf,
kargs.tokens,
kargs.moe_buf_interm_dim,
kargs.moe_buf_elem_bytes,
@@ -2150,7 +2150,7 @@ struct MoeSortingMultiPhaseKernel_P2
gridDim.x - 1);
return;
#else
impl::moe_buf_set_zero_kernel<BLOCK_SIZE>(
impl::moe_buf_set_zero_kernel<kBlockSize>(
reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
kargs.moe_buf_bytes,
blockIdx.x - 1);
@@ -2167,7 +2167,7 @@ struct MoeSortingMultiPhaseKernel_P2
reinterpret_cast<IndexType*>(kargs.p_total_tokens_post_pad);
IndexType* p_sorted_expert_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
const index_t loops = (kargs.num_experts + kBlockSize - 1) / kBlockSize;
index_t wave_id = threadIdx.x / get_warp_size();
index_t lane_id = threadIdx.x % get_warp_size();
@@ -2176,7 +2176,7 @@ struct MoeSortingMultiPhaseKernel_P2
for(index_t i = 0; i < loops; i++)
{
index_t position = i * BLOCK_SIZE + threadIdx.x;
index_t position = i * kBlockSize + threadIdx.x;
IndexType a_ = 0; // token count for a expert
IndexType b_ = 0; // mask for a expert
if(position < kargs.num_experts)
@@ -2221,15 +2221,15 @@ struct MoeSortingMultiPhaseKernel_P2
if(lane_id == get_warp_size() - 1)
{
s[4 + wave_id] = cumsum_a;
s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b;
s[4 + wave_id + kBlockSize / get_warp_size()] = cumsum_b;
}
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
IndexType prev_a = s[4 + i_w];
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()];
IndexType prev_b = s[4 + i_w + kBlockSize / get_warp_size()];
prev_a = wave_id > i_w ? prev_a : 0; // mask out
prev_b = wave_id > i_w ? prev_b : 0; // mask out
cumsum_a += prev_a;
@@ -2240,7 +2240,7 @@ struct MoeSortingMultiPhaseKernel_P2
cumsum_a += prev_cumsum_a;
cumsum_b += prev_cumsum_b;
if(threadIdx.x == BLOCK_SIZE - 1)
if(threadIdx.x == kBlockSize - 1)
{
s[2] = cumsum_a; // store the last cumsum
s[3] = cumsum_b;
@@ -2297,7 +2297,7 @@ struct MoeSortingMultiPhaseKernel_P3
using WeightType = typename Problem::WeightType;
using MeshType = typename Problem::MeshType;
static constexpr index_t BLOCK_SIZE = 256;
static constexpr index_t kBlockSize = 256;
static constexpr index_t OCCUPANCY = 2; // hard coded
typedef MoeSortingHostArgs MoeSortingKargs;
@@ -2341,12 +2341,12 @@ struct MoeSortingMultiPhaseKernel_P3
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { return dim3(h.num_experts); }
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
// in byte
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
{
return (4 + BLOCK_SIZE / get_warp_size()) * sizeof(IndexType);
return (4 + kBlockSize / get_warp_size()) * sizeof(IndexType);
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
@@ -2391,11 +2391,11 @@ struct MoeSortingMultiPhaseKernel_P3
}
// cumsum one by one
int loops = (kargs.mesh_stride + BLOCK_SIZE - 1) / BLOCK_SIZE;
int loops = (kargs.mesh_stride + kBlockSize - 1) / kBlockSize;
int prev_cumsum = 0;
for(int i = 0; i < loops; i++)
{
int i_token = i * BLOCK_SIZE + threadIdx.x;
int i_token = i * kBlockSize + threadIdx.x;
IndexType x = 0;
if(i_token < tokens)
{
@@ -2414,13 +2414,13 @@ struct MoeSortingMultiPhaseKernel_P3
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
IndexType prev = s[4 + i_w];
prev = wave_id > i_w ? prev : 0; // mask out
cumsum += prev;
});
cumsum += prev_cumsum; // add previous round cumsum
if(threadIdx.x == BLOCK_SIZE - 1)
if(threadIdx.x == kBlockSize - 1)
{
s[0] = cumsum;
}
@@ -2441,7 +2441,7 @@ struct MoeSortingMultiPhaseKernel_P3
}
}
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE)
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += kBlockSize)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor);
@@ -2457,9 +2457,9 @@ namespace impl {
// we use dynamic LDS size here
CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_)
{
constexpr index_t BLOCK_SIZE = 256; // hardcoded 256
constexpr index_t kBlockSize = 256; // hardcoded 256
const index_t expert_cumsum_elem = num_experts_ + 1;
return (4 + 2 * BLOCK_SIZE / get_warp_size() + expert_cumsum_elem) * sizeof(int);
return (4 + 2 * kBlockSize / get_warp_size() + expert_cumsum_elem) * sizeof(int);
}
} // namespace impl
@@ -2473,7 +2473,7 @@ struct MoeSortingMultiPhaseKernel_P23
using WeightType = typename Problem::WeightType;
using MeshType = typename Problem::MeshType;
static constexpr index_t BLOCK_SIZE = 256;
static constexpr index_t kBlockSize = 256;
static constexpr index_t OCCUPANCY = 2; // hard coded
typedef MoeSortingHostArgs MoeSortingKargs;
@@ -2563,18 +2563,18 @@ struct MoeSortingMultiPhaseKernel_P23
return dim3(h.num_experts + get_num_cu() * OCCUPANCY);
#else
// use 1 block to cumsum
// return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
return dim3(h.num_experts + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
// return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, kBlockSize * 16));
return dim3(h.num_experts + ck_tile::integer_divide_ceil(h.moe_buf_bytes, kBlockSize * 16));
#endif
}
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
// only use this at host !
CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
{
const auto smem_23 = impl::moe_sorting_get_smem_size_p23(h.num_experts);
const auto smem_sf = BLOCK_SIZE * 4 * sizeof(IndexType);
const auto smem_sf = kBlockSize * 4 * sizeof(IndexType);
return max(smem_23, smem_sf);
}
@@ -2595,7 +2595,7 @@ struct MoeSortingMultiPhaseKernel_P23
if(static_cast<index_t>(blockIdx.x) >= kargs.num_experts)
{
#if MOE_SORTING_FMOE_2D_BUF
impl::moe_buf_set_zero_kernel_2d<BLOCK_SIZE>(kargs.p_moe_buf,
impl::moe_buf_set_zero_kernel_2d<kBlockSize>(kargs.p_moe_buf,
tokens,
kargs.moe_buf_interm_dim,
kargs.moe_buf_elem_bytes,
@@ -2603,7 +2603,7 @@ struct MoeSortingMultiPhaseKernel_P23
gridDim.x - kargs.num_experts);
return;
#else
impl::moe_buf_set_zero_kernel<BLOCK_SIZE>(
impl::moe_buf_set_zero_kernel<kBlockSize>(
reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
kargs.moe_buf_bytes,
blockIdx.x - kargs.num_experts);
@@ -2618,13 +2618,13 @@ struct MoeSortingMultiPhaseKernel_P23
const IndexType* p_local_expert_mask =
static_cast<const IndexType*>(kargs.p_local_expert_mask);
IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size();
IndexType* p_expert_cumsum_smem = s + 4 + 2 * kBlockSize / get_warp_size();
IndexType* p_total_tokens_post_pad =
reinterpret_cast<IndexType*>(kargs.p_total_tokens_post_pad);
IndexType* p_sorted_expert_ids =
reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
const index_t loops = (kargs.num_experts + kBlockSize - 1) / kBlockSize;
index_t wave_id = threadIdx.x / get_warp_size();
index_t lane_id = threadIdx.x % get_warp_size();
@@ -2633,7 +2633,7 @@ struct MoeSortingMultiPhaseKernel_P23
for(index_t i = 0; i < loops; i++)
{
index_t position = i * BLOCK_SIZE + threadIdx.x;
index_t position = i * kBlockSize + threadIdx.x;
IndexType a_ = 0; // token count for a expert
IndexType b_ = 0; // mask for a expert
if(position < kargs.num_experts)
@@ -2678,15 +2678,15 @@ struct MoeSortingMultiPhaseKernel_P23
if(lane_id == get_warp_size() - 1)
{
s[4 + wave_id] = cumsum_a;
s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b;
s[4 + wave_id + kBlockSize / get_warp_size()] = cumsum_b;
}
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
IndexType prev_a = s[4 + i_w];
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()];
IndexType prev_b = s[4 + i_w + kBlockSize / get_warp_size()];
prev_a = wave_id > i_w ? prev_a : 0; // mask out
prev_b = wave_id > i_w ? prev_b : 0; // mask out
cumsum_a += prev_a;
@@ -2697,7 +2697,7 @@ struct MoeSortingMultiPhaseKernel_P23
cumsum_a += prev_cumsum_a;
cumsum_b += prev_cumsum_b;
if(threadIdx.x == BLOCK_SIZE - 1)
if(threadIdx.x == kBlockSize - 1)
{
s[2] = cumsum_a; // store the last cumsum
s[3] = cumsum_b;
@@ -2758,7 +2758,7 @@ struct MoeSortingMultiPhaseKernel_P23
IndexType* s = reinterpret_cast<IndexType*>(smem);
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
IndexType* p_sorted_token_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_token_ids);
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size();
IndexType* p_expert_cumsum_smem = s + 4 + 2 * kBlockSize / get_warp_size();
const WeightType* p_weights = static_cast<const WeightType*>(kargs.p_weights);
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
@@ -2795,13 +2795,13 @@ struct MoeSortingMultiPhaseKernel_P23
constexpr index_t index_pack = Problem::SubTokenTile; // always packed
using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
using d_t = ext_vector_t<index_t, index_pack>;
int loops = (mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
int loops = (mesh_stride / index_pack + kBlockSize - 1) / kBlockSize;
int prev_cumsum = 0;
for(int i = 0; i < loops; i++)
{
int i_token_pack = i * BLOCK_SIZE + threadIdx.x;
int i_token_pack = i * kBlockSize + threadIdx.x;
r_t x_v = 0;
if(i_token_pack < (tokens + index_pack - 1) / index_pack)
{
@@ -2819,7 +2819,7 @@ struct MoeSortingMultiPhaseKernel_P23
static_for<0, index_pack, 1>{}([&](auto j_) {
constexpr auto j = j_.value;
x_r[j] = reinterpret_cast<MeshType*>(s)[threadIdx.x + j * BLOCK_SIZE];
x_r[j] = reinterpret_cast<MeshType*>(s)[threadIdx.x + j * kBlockSize];
});
}
#else
@@ -2830,7 +2830,7 @@ struct MoeSortingMultiPhaseKernel_P23
#pragma unroll
for(int j = 0; j < index_pack / 2; j++)
{
int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x + j * BLOCK_SIZE;
int i_token = i * kBlockSize * index_pack + threadIdx.x + j * kBlockSize;
index_t x = x_d[j];
int i_topk = x - 1; // topk of this token
int i_show = x != 0 ? 1 : 0; // has this token or not
@@ -2845,13 +2845,13 @@ struct MoeSortingMultiPhaseKernel_P23
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
IndexType prev = s[4 + i_w];
prev = wave_id > i_w ? prev : 0; // mask out
cumsum += prev;
});
cumsum += prev_cumsum; // add previous round cumsum
if(threadIdx.x == BLOCK_SIZE - 1)
if(threadIdx.x == kBlockSize - 1)
{
s[0] = cumsum;
}
@@ -2896,13 +2896,13 @@ struct MoeSortingMultiPhaseKernel_P23
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
IndexType prev = s[4 + i_w];
prev = wave_id > i_w ? prev : 0; // mask out
cumsum += prev;
});
cumsum += prev_cumsum; // add previous round cumsum
if(threadIdx.x == BLOCK_SIZE - 1)
if(threadIdx.x == kBlockSize - 1)
{
s[0] = cumsum;
}
@@ -2912,10 +2912,10 @@ struct MoeSortingMultiPhaseKernel_P23
int position = cumsum - cumsum_store;
static_for<0, index_pack, 1>{}([&](auto j_) {
constexpr auto j = j_.value;
// int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x + j *
// BLOCK_SIZE;
// int i_token = i * kBlockSize * index_pack + threadIdx.x + j *
// kBlockSize;
int i_token =
i * BLOCK_SIZE * index_pack + threadIdx.x * index_pack + j;
i * kBlockSize * index_pack + threadIdx.x * index_pack + j;
if(i_show[j])
{
@@ -2932,7 +2932,7 @@ struct MoeSortingMultiPhaseKernel_P23
});
#if 0
int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x * 2 + j * BLOCK_SIZE * 2;
int i_token = i * kBlockSize * index_pack + threadIdx.x * 2 + j * kBlockSize * 2;
index_t x = x_d[j];
index_t x0 = static_cast<index_t>(x & 0xffff);
index_t x1 = static_cast<index_t>(x >> 16);
@@ -2951,13 +2951,13 @@ struct MoeSortingMultiPhaseKernel_P23
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
IndexType prev = s[4 + i_w];
prev = wave_id > i_w ? prev : 0; // mask out
cumsum += prev;
});
cumsum += prev_cumsum; // add previous round cumsum
if(threadIdx.x == BLOCK_SIZE - 1)
if(threadIdx.x == kBlockSize - 1)
{
s[0] = cumsum;
}
@@ -2996,7 +2996,7 @@ struct MoeSortingMultiPhaseKernel_P23
}
}
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE)
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += kBlockSize)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor);

View File

@@ -64,6 +64,7 @@ struct BatchedGemmKernel
/// functions.
using UniversalGemmKernel =
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
@@ -121,9 +122,16 @@ struct BatchedGemmKernel
return dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch);
}
CK_TILE_HOST static constexpr auto BlockSize() -> dim3
CK_TILE_HOST static auto BlockSize() -> dim3
{
return dim3(UniversalGemmKernel::KernelBlockSize);
if(ck_tile::is_wave32())
{
return dim3(UniversalGemmKernel::kBlockSize / 2);
}
else
{
return dim3(UniversalGemmKernel::kBlockSize);
}
}
CK_TILE_HOST static constexpr BatchedGemmKernelArgs

View File

@@ -113,6 +113,7 @@ struct GemmKernel
static constexpr index_t NumATensor = 1;
static constexpr index_t NumBTensor = 1;
static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
CK_TILE_HOST static auto GetName() -> const std::string
{

View File

@@ -86,6 +86,7 @@ struct GemmKernelMultiD
/// functions.
using UniversalGemmKernel =
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;

View File

@@ -128,7 +128,7 @@ struct GroupedGemmKernel
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
using Kernel = GroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
@@ -155,7 +155,7 @@ struct GroupedGemmKernel
return group_count * sizeof(GemmTransKernelArg);
}
CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); }
CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { return dim3(kBlockSize); }
/**
* @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
@@ -166,10 +166,10 @@ struct GroupedGemmKernel
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
{
using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*;
const auto kernel = kentry<KernelBlockSize, 1, Kernel, ConstantPointer, index_t>;
const auto kernel = kentry<1, Kernel, ConstantPointer, index_t>;
int occupancy;
HIP_CHECK_ERROR(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0));
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
const int grid_size = get_available_compute_units(s) * occupancy;
return dim3(grid_size, 1, 1);
}

View File

@@ -196,7 +196,7 @@ struct UniversalGemmKernel
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
// Get the persistent kernel if the pipeline has it available
struct has_persistent_kernel
@@ -275,15 +275,26 @@ struct UniversalGemmKernel
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
{
using Kernel = UniversalGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
const auto kernel = kentry<KernelBlockSize, 1, Kernel, KernelArgs>;
const auto kernel = kentry<1, Kernel, KernelArgs>;
int occupancy;
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0));
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize().x, 0));
const int grid_size = get_available_compute_units(s) * occupancy;
return dim3(grid_size, 1, 1);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
CK_TILE_HOST static auto BlockSize()
{
if(ck_tile::is_wave32())
{
return dim3(kBlockSize / 2);
}
else
{
return dim3(kBlockSize);
}
}
CK_TILE_HOST static constexpr KernelArgs
MakeKernelArgs(const UniversalGemmHostArgs<NumATensor, NumBTensor, NumDTensor>& hostArgs)
@@ -371,7 +382,9 @@ struct UniversalGemmKernel
}
}
bool AsTesnorIsValid = {true};
const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA<true>()
: GemmPipeline::template GetVectorSizeA<false>();
bool AsTesnorIsValid = {true};
static_for<0, NumATensor, 1>{}([&](auto index) {
using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
@@ -387,7 +400,7 @@ struct UniversalGemmKernel
}
AsTesnorIsValid = false;
}
if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
if(kargs.K % vectorSizeA != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
@@ -407,7 +420,7 @@ struct UniversalGemmKernel
}
AsTesnorIsValid = false;
}
if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
if(kargs.M % vectorSizeA != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
@@ -418,7 +431,9 @@ struct UniversalGemmKernel
}
});
bool BsTesnorIsValid = {true};
bool BsTesnorIsValid = {true};
const auto vectorSizeB = is_wave32() ? GemmPipeline::template GetVectorSizeB<true>()
: GemmPipeline::template GetVectorSizeB<false>();
static_for<0, NumBTensor, 1>{}([&](auto index) {
using BiLayout = remove_cvref_t<std::tuple_element_t<index.value, BsLayout>>;
if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
@@ -432,7 +447,7 @@ struct UniversalGemmKernel
}
BsTesnorIsValid = false;
}
if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
if(kargs.N % vectorSizeB != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
@@ -454,7 +469,7 @@ struct UniversalGemmKernel
}
BsTesnorIsValid = false;
}
if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
if(kargs.K % vectorSizeB != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{

View File

@@ -127,8 +127,16 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
}
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr index_t APackedSize =

View File

@@ -124,8 +124,16 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
}
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }

View File

@@ -61,8 +61,16 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
}
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr bool kPadM = Problem::kPadM;

View File

@@ -176,8 +176,16 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
}
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }

View File

@@ -36,8 +36,16 @@ struct GemmPipelineAGmemBGmemCRegV1
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; }
static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; }
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return Problem::VectorSizeA;
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return Problem::VectorSizeB;
}
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }

View File

@@ -305,11 +305,15 @@ struct UniversalGemmBasePolicy
* @tparam XPerTile The contiguous Tile dimension size.
* @return Maximum DRAM vector load size.
*/
template <typename Problem, typename DataType, index_t MNPerBlock, index_t XPerTile>
template <typename Problem,
typename DataType,
index_t MNPerBlock,
index_t XPerTile,
bool IsWave32Host>
CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize()
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t BlockSize = IsWave32Host ? Problem::kBlockSize / 2 : Problem::kBlockSize;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
@@ -349,7 +353,7 @@ struct UniversalGemmBasePolicy
}
}
template <typename Problem>
template <typename Problem, bool IsWave32Host = false>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
@@ -359,15 +363,23 @@ struct UniversalGemmBasePolicy
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, KPerBlock>();
return GetGlobalVectorLoadSize<Problem,
ADataType,
MPerBlock,
KPerBlock,
IsWave32Host>();
}
else
{
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, MPerBlock>();
return GetGlobalVectorLoadSize<Problem,
ADataType,
MPerBlock,
MPerBlock,
IsWave32Host>();
}
}
template <typename Problem>
template <typename Problem, bool IsWave32Host = false>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
@@ -377,11 +389,19 @@ struct UniversalGemmBasePolicy
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, NPerBlock>();
return GetGlobalVectorLoadSize<Problem,
BDataType,
NPerBlock,
NPerBlock,
IsWave32Host>();
}
else
{
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, KPerBlock>();
return GetGlobalVectorLoadSize<Problem,
BDataType,
NPerBlock,
KPerBlock,
IsWave32Host>();
}
}

View File

@@ -59,13 +59,15 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return PipelinePolicy::template GetVectorSizeA<Problem>();
return PipelinePolicy::template GetVectorSizeA<Problem, IsWave32Host>();
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return PipelinePolicy::template GetVectorSizeB<Problem>();
return PipelinePolicy::template GetVectorSizeB<Problem, IsWave32Host>();
}
static constexpr bool kPadM = Problem::kPadM;

View File

@@ -76,13 +76,15 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return PipelinePolicy::template GetVectorSizeA<Problem>();
return PipelinePolicy::template GetVectorSizeA<Problem, IsWave32Host>();
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return PipelinePolicy::template GetVectorSizeB<Problem>();
return PipelinePolicy::template GetVectorSizeB<Problem, IsWave32Host>();
}
static constexpr bool kPadM = Problem::kPadM;

View File

@@ -99,15 +99,15 @@ struct AQuantGemmKernelArgs
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct AQuantGemmKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using AQLayout = remove_cvref_t<typename GemmPipeline::AQLayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
static constexpr bool Preshuffle = GemmPipeline::Preshuffle;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using AQLayout = remove_cvref_t<typename GemmPipeline::AQLayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
static constexpr bool Preshuffle = GemmPipeline::Preshuffle;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using AQDataType = remove_cvref_t<typename GemmPipeline::AQDataType>;
@@ -131,7 +131,7 @@ struct AQuantGemmKernel
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST static constexpr AQuantGemmKernelArgs
MakeKernelArgs(const AQuantGemmHostArgs& hostArgs)

View File

@@ -354,7 +354,7 @@ struct GroupedConvolutionBackwardWeightKernel
using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
using InDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using WeiDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
@@ -393,7 +393,7 @@ struct GroupedConvolutionBackwardWeightKernel
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized
MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs)

View File

@@ -361,7 +361,7 @@ struct GroupedConvolutionForwardKernel
using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
using InDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using WeiDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
@@ -398,7 +398,7 @@ struct GroupedConvolutionForwardKernel
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized
MakeKernelArgs(const GroupedConvFwdHostArgs& hostArgs)

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -31,6 +31,7 @@ struct ImageToColumn
static constexpr index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
static constexpr index_t kKPerBlock = Problem::BlockShape::kKPerBlock;
static constexpr index_t kBlockSize = Problem::BlockShape::kBlockSize;
struct Kargs
{

View File

@@ -14,11 +14,10 @@ struct TileImageToColumnShape
static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
static constexpr index_t kKPerThread = ThreadTile::at(number<1>{});
static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
static constexpr index_t kKPerWarp = WarpTile::at(number<1>{});
static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread;
static constexpr index_t kKThreadPerWarp = kKPerWarp / kKPerThread;
static constexpr index_t kKThreadPerWarp = get_warp_size() / kMThreadPerWarp;
static constexpr index_t kKPerWarp = kKPerThread * kKThreadPerWarp;
static constexpr index_t kMPerBlock = BlockTile::at(number<0>{});
static constexpr index_t kKPerBlock = BlockTile::at(number<1>{});

View File

@@ -76,9 +76,9 @@ struct Layernorm2dFwd
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
struct Kargs
{

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -25,6 +25,8 @@ struct Reduce
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
private:
// Helper function to calculate optimal vector size for input tensor
template <typename InputShape, typename ReduceDims>

View File

@@ -70,6 +70,7 @@ struct Rmsnorm2dFwd
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};

View File

@@ -48,6 +48,7 @@ struct MoeSmoothquant
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};

View File

@@ -45,6 +45,7 @@ struct Smoothquant
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -34,6 +34,8 @@ struct TopkSoftmaxKernel
using WeightType = typename Problem::WeightType;
using IndexType = typename Problem::IndexType;
static constexpr index_t kBlockSize = Problem::BlockSize;
struct TopkSoftmaxKargs
{
const void* p_input;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -117,7 +117,7 @@ struct naive_attention_fwd_kernel
std::is_same_v<KType, fp8_t> && std::is_same_v<VType, fp8_t>;
static constexpr int v_per_token_quant_group_size = 64;
static constexpr int kBlockSize = 256;
// TODO: hardcode
using SoftmaxType = float; // always using float to do softmax compute
using QuantComputeType = float; // used for quant/dequant scale compute
@@ -254,7 +254,7 @@ struct naive_attention_fwd_kernel
__device__ T load(int i_s, int i_h, int i_d) { return base_ptr[get_offset(i_s, i_h, i_d)]; }
};
__device__ __host__ static constexpr int get_block_size() { return 256; }
__device__ __host__ static constexpr int get_block_size() { return kBlockSize; }
// for simpliciy, 1 WG always compute 1 token along q, compute all token along kv
// compute all hdim from q, compute WG_SIZE hdim from v