mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
Merge commit '9fcc1ee9fd9730efd865f530afde505f2556954d' into develop
This commit is contained in:
@@ -60,13 +60,30 @@ enum struct memory_operation_enum : std::uint16_t
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
|
||||
{
|
||||
#if defined(__GFX9__) || (!defined(__HIP_DEVICE_COMPILE__) && !defined(CK_TILE_WAVE32_ENABLED))
|
||||
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
|
||||
return 64;
|
||||
#else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST bool is_wave32()
|
||||
{
|
||||
hipDeviceProp_t props{};
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
return props.major > 9;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; }
|
||||
|
||||
CK_TILE_DEVICE index_t get_block_size() { return blockDim.x; }
|
||||
|
||||
@@ -274,12 +274,6 @@
|
||||
#define CK_TILE_WA_ISSUE_2028 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_WAVE32_ENABLED
|
||||
#if defined(__gfx11__) || defined(__gfx12__)
|
||||
#define CK_TILE_WAVE32_ENABLED
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Y pointed to R, we don't see a valuable use case.
|
||||
// Will enforce encoding to check Y not pointed to R if set to zero
|
||||
#ifndef CK_TILE_ENC_SUPPORT_Y_TO_R
|
||||
|
||||
@@ -15,9 +15,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename... Args>
|
||||
template <int MinBlockPerCu, typename Kernel, typename... Args>
|
||||
#if CK_TILE_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
|
||||
__launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
|
||||
#endif
|
||||
__global__ void kentry(Args... args)
|
||||
{
|
||||
@@ -35,15 +35,11 @@ __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
|
||||
//
|
||||
// the "static __device__ operator()(some_arg)" is the entry point of KernelImpl
|
||||
//
|
||||
template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
|
||||
int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
|
||||
typename KernelImpl,
|
||||
typename... Args>
|
||||
template <int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU, typename KernelImpl, typename... Args>
|
||||
CK_TILE_HOST auto
|
||||
make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
|
||||
{
|
||||
const auto kernel = kentry<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;
|
||||
|
||||
const auto kernel = kentry<MinBlockPerCu, KernelImpl, Args...>;
|
||||
return [=](const stream_config& s) {
|
||||
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
|
||||
};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -53,6 +53,7 @@ struct AddRmsnorm2dRdquantFwd
|
||||
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
|
||||
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
|
||||
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
|
||||
@@ -34,6 +34,8 @@ struct BatchedTransposeKernel
|
||||
|
||||
using Type = typename Problem::DataType;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
struct BatchedTransposeKargs
|
||||
{
|
||||
const void* p_input;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -20,11 +20,10 @@ struct BatchedTransposeLdsProblem
|
||||
|
||||
static constexpr index_t kRowWarps_ = NumWarps::at(number<0>{});
|
||||
static constexpr index_t kColWarps_ = NumWarps::at(number<1>{});
|
||||
static constexpr index_t kBlockSize_ = get_warp_size() * kRowWarps_ * kColWarps_;
|
||||
static constexpr index_t kRowPerBlock_ = BlockTile::at(number<0>{});
|
||||
static constexpr index_t kColPerBlock_ = BlockTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kBlockSize = get_warp_size() * kRowWarps_ * kColWarps_;
|
||||
// warps per block
|
||||
static constexpr index_t kLeadNumWarps = kColWarps_;
|
||||
static constexpr index_t kSecondNumWarps = kRowWarps_;
|
||||
|
||||
@@ -20,6 +20,8 @@ struct ElementWiseKernel
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
using ElementWiseOperation = ck_tile::remove_cvref_t<typename Problem::ElementWiseOperation>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::kBlockSize;
|
||||
|
||||
template <typename... XDataType, typename Dims>
|
||||
CK_TILE_DEVICE void operator()(Dims lens,
|
||||
Dims input_strides,
|
||||
|
||||
@@ -17,7 +17,6 @@ template <typename ADataType_,
|
||||
typename DsLayout_,
|
||||
typename ELayout_,
|
||||
typename CDElementwise_,
|
||||
index_t kBlockSize_,
|
||||
index_t kM_,
|
||||
index_t kN_,
|
||||
index_t MWave_,
|
||||
@@ -40,7 +39,7 @@ struct CShuffleEpilogueProblem
|
||||
using DsLayout = remove_cvref_t<DsLayout_>;
|
||||
using ELayout = remove_cvref_t<ELayout_>;
|
||||
using CDElementwise = remove_cvref_t<CDElementwise_>;
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
|
||||
static constexpr index_t kMPerBlock = kM_;
|
||||
static constexpr index_t kNPerBlock = kN_;
|
||||
static constexpr index_t MWave = MWave_;
|
||||
|
||||
@@ -91,13 +91,13 @@ struct FlatmmKernel
|
||||
using FlatmmPipeline = remove_cvref_t<FlatmmPipeline_>;
|
||||
using BlockGemmShape =
|
||||
remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>; // TileFlatmmShape
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
|
||||
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
|
||||
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
|
||||
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
|
||||
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
static constexpr index_t kBlockSize = FlatmmPipeline::BlockSize;
|
||||
|
||||
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
|
||||
@@ -127,7 +127,7 @@ struct FlatmmKernel
|
||||
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
CK_TILE_HOST static constexpr KernelArgs
|
||||
MakeKernelArgs(const FlatmmHostArgs<NumDTensor>& hostArgs)
|
||||
|
||||
@@ -237,15 +237,16 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
constexpr index_t scale = get_warp_size() == 32 ? 2 : 1;
|
||||
if constexpr(TileShape::WarpTile::at(I1) == 32)
|
||||
{
|
||||
return TileShape::WarpTile::at(I2) / 2;
|
||||
return TileShape::WarpTile::at(I2) * scale / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(TileShape::WarpTile::at(I1) == 16);
|
||||
return TileShape::WarpTile::at(I2) / 4;
|
||||
return TileShape::WarpTile::at(I2) * scale / 4;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -24,9 +24,10 @@ namespace ck_tile {
|
||||
template <typename FmhaPipeline_, typename EpiloguePipeline_>
|
||||
struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
{
|
||||
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
|
||||
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
|
||||
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
|
||||
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
|
||||
|
||||
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
|
||||
static_assert(kBlockPerCu > 0);
|
||||
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -16,6 +16,7 @@ struct FmhaFwdAppendKVKernel
|
||||
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
|
||||
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
|
||||
|
||||
static_assert(kBlockPerCu > 0);
|
||||
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
|
||||
|
||||
|
||||
@@ -25,9 +25,10 @@ namespace ck_tile {
|
||||
template <typename FmhaPipeline_, typename EpiloguePipeline_>
|
||||
struct FmhaFwdKernel
|
||||
{
|
||||
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
|
||||
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
|
||||
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
|
||||
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
|
||||
|
||||
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
|
||||
static_assert(kBlockPerCu > 0);
|
||||
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
|
||||
|
||||
@@ -30,6 +30,7 @@ struct FmhaFwdPagedKVKernel
|
||||
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
|
||||
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
|
||||
|
||||
static_assert(kBlockPerCu > 0);
|
||||
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -14,6 +14,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
static constexpr index_t kNumWarps = FmhaPipeline::kNumWarps;
|
||||
static constexpr index_t kBlockSize = FmhaPipeline::kBlockSize;
|
||||
static constexpr index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
|
||||
|
||||
static_assert(kBlockPerCu > 0);
|
||||
static constexpr index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -26,6 +26,7 @@ struct FmhaFwdSplitKVKernel
|
||||
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
|
||||
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
|
||||
|
||||
static_assert(kBlockPerCu > 0);
|
||||
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -213,7 +213,7 @@ struct MoeSortingKernel
|
||||
|
||||
using Hargs = MoeSortingHostArgs;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
struct Kargs
|
||||
@@ -487,8 +487,8 @@ struct MoeSortingKernel
|
||||
vector_type* p_buf = reinterpret_cast<vector_type*>(buf);
|
||||
auto zero_ = vector_type{0};
|
||||
|
||||
for(long_index_t i = (blockIdx.x - 1) * BLOCK_SIZE + threadIdx.x; i < total_elems;
|
||||
i += (gridDim.x - 1) * BLOCK_SIZE)
|
||||
for(long_index_t i = (blockIdx.x - 1) * kBlockSize + threadIdx.x; i < total_elems;
|
||||
i += (gridDim.x - 1) * kBlockSize)
|
||||
{
|
||||
p_buf[i] = zero_;
|
||||
}
|
||||
@@ -1419,7 +1419,7 @@ template <typename Problem_>
|
||||
struct MoeSortingClearWorkspaceKernel
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
static constexpr index_t BLOCK_SIZE = Problem::BlockSize;
|
||||
static constexpr index_t kBlockSize = Problem::BlockSize;
|
||||
static constexpr index_t OCCUPANCY = Problem::Occu;
|
||||
|
||||
using Hargs = MoeSortingHostArgs;
|
||||
@@ -1461,7 +1461,7 @@ struct MoeSortingClearWorkspaceKernel
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; }
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; }
|
||||
@@ -1499,8 +1499,8 @@ struct MoeSortingClearWorkspaceKernel
|
||||
vector_type* p_expert_mesh = reinterpret_cast<vector_type*>(kargs.p_expert_mesh);
|
||||
auto zero_ = vector_type{0};
|
||||
|
||||
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elems;
|
||||
i += gridDim.x * BLOCK_SIZE)
|
||||
for(index_t i = blockIdx.x * kBlockSize + threadIdx.x; i < total_elems;
|
||||
i += gridDim.x * kBlockSize)
|
||||
{
|
||||
p_expert_mesh[i] = zero_;
|
||||
}
|
||||
@@ -1560,7 +1560,7 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using MeshType = typename Problem::MeshType;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
typedef MoeSortingHostArgs MoeSortingKargs;
|
||||
@@ -1604,7 +1604,7 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; }
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; }
|
||||
@@ -1647,8 +1647,8 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
|
||||
#pragma unroll Problem::SubTokenTile
|
||||
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem;
|
||||
i += gridDim.x * BLOCK_SIZE)
|
||||
for(index_t i = blockIdx.x * kBlockSize + threadIdx.x; i < total_elem;
|
||||
i += gridDim.x * kBlockSize)
|
||||
{
|
||||
auto x = p_topk_ids[i];
|
||||
static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) {
|
||||
@@ -1678,7 +1678,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using MeshType = typename Problem::MeshType;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
typedef MoeSortingHostArgs MoeSortingKargs;
|
||||
@@ -1709,12 +1709,12 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { return dim3(h.num_experts); }
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
return BLOCK_SIZE / get_warp_size() * sizeof(IndexType);
|
||||
return kBlockSize / get_warp_size() * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
@@ -1756,7 +1756,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
r_t* p_expert_mesh = reinterpret_cast<r_t*>(
|
||||
reinterpret_cast<MeshType*>(kargs.p_expert_mesh) + eid * mesh_stride);
|
||||
|
||||
int loops = (mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
int loops = (mesh_stride / index_pack + kBlockSize - 1) / kBlockSize;
|
||||
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
@@ -1768,7 +1768,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
index_t cnt = 0; // per-wave cnt
|
||||
for(int i = 0; i < loops; i++)
|
||||
{
|
||||
int position = i * BLOCK_SIZE + threadIdx.x;
|
||||
int position = i * kBlockSize + threadIdx.x;
|
||||
r_t v{0};
|
||||
if(position < (mesh_stride / index_pack))
|
||||
v = p_expert_mesh[position];
|
||||
@@ -1792,7 +1792,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
index_t c = 0;
|
||||
for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++)
|
||||
for(auto i = 0; i < (kBlockSize / get_warp_size()); i++)
|
||||
{
|
||||
c += s[i];
|
||||
}
|
||||
@@ -1811,7 +1811,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using MeshType = typename Problem::MeshType;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
typedef MoeSortingHostArgs MoeSortingKargs;
|
||||
@@ -1878,12 +1878,12 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; }
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
|
||||
|
||||
CK_TILE_HOST static constexpr auto WGCounts(const Hargs& h)
|
||||
{
|
||||
index_t total_elem = h.tokens * h.topk / Problem::SubTokenTile;
|
||||
index_t elem_cnt = (total_elem + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
index_t elem_cnt = (total_elem + kBlockSize - 1) / kBlockSize;
|
||||
|
||||
// no more than grid_size
|
||||
return min(elem_cnt, GridSize(h));
|
||||
@@ -1892,7 +1892,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
// in byte
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize()
|
||||
{
|
||||
return BLOCK_SIZE / get_warp_size() * sizeof(IndexType);
|
||||
return kBlockSize / get_warp_size() * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
@@ -1921,7 +1921,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
index_t total_elem = rounded_tokens * kargs.topk / Problem::SubTokenTile;
|
||||
index_t elem_cnt = (total_elem + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
index_t elem_cnt = (total_elem + kBlockSize - 1) / kBlockSize;
|
||||
|
||||
// no more than grid_size
|
||||
return min(elem_cnt, kargs.wg_count);
|
||||
@@ -1940,8 +1940,8 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
|
||||
#pragma unroll Problem::SubTokenTile
|
||||
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem;
|
||||
i += BLOCK_SIZE * gridDim.x)
|
||||
for(index_t i = blockIdx.x * kBlockSize + threadIdx.x; i < total_elem;
|
||||
i += kBlockSize * gridDim.x)
|
||||
{
|
||||
auto x = p_topk_ids[i];
|
||||
static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) {
|
||||
@@ -1996,7 +1996,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
|
||||
auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
|
||||
|
||||
int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
int loops = (kargs.mesh_stride / index_pack + kBlockSize - 1) / kBlockSize;
|
||||
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
@@ -2008,7 +2008,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
index_t cnt = 0; // per-wave cnt
|
||||
for(int i = 0; i < loops; i++)
|
||||
{
|
||||
int position = i * BLOCK_SIZE + threadIdx.x;
|
||||
int position = i * kBlockSize + threadIdx.x;
|
||||
r_t v{0};
|
||||
if(position < (kargs.mesh_stride / index_pack))
|
||||
v = p_expert_mesh[position];
|
||||
@@ -2033,7 +2033,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
index_t c = 0;
|
||||
for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++)
|
||||
for(auto i = 0; i < (kBlockSize / get_warp_size()); i++)
|
||||
{
|
||||
c += s[i];
|
||||
}
|
||||
@@ -2055,7 +2055,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using MeshType = typename Problem::MeshType;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
typedef MoeSortingHostArgs MoeSortingKargs;
|
||||
@@ -2123,17 +2123,17 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
return dim3(h.num_experts + get_num_cu() * OCCUPANCY);
|
||||
#else
|
||||
// use 1 block to cumsum
|
||||
return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
|
||||
return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, kBlockSize * 16));
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
// return 2 * BLOCK_SIZE * sizeof(IndexType);
|
||||
return (4 + 2 * BLOCK_SIZE / get_warp_size()) * sizeof(IndexType);
|
||||
// return 2 * kBlockSize * sizeof(IndexType);
|
||||
return (4 + 2 * kBlockSize / get_warp_size()) * sizeof(IndexType);
|
||||
}
|
||||
|
||||
// reduce single pixel within a wave
|
||||
@@ -2142,7 +2142,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
if(blockIdx.x > 0)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
impl::moe_buf_set_zero_kernel_2d<BLOCK_SIZE>(kargs.p_moe_buf,
|
||||
impl::moe_buf_set_zero_kernel_2d<kBlockSize>(kargs.p_moe_buf,
|
||||
kargs.tokens,
|
||||
kargs.moe_buf_interm_dim,
|
||||
kargs.moe_buf_elem_bytes,
|
||||
@@ -2150,7 +2150,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
gridDim.x - 1);
|
||||
return;
|
||||
#else
|
||||
impl::moe_buf_set_zero_kernel<BLOCK_SIZE>(
|
||||
impl::moe_buf_set_zero_kernel<kBlockSize>(
|
||||
reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
|
||||
kargs.moe_buf_bytes,
|
||||
blockIdx.x - 1);
|
||||
@@ -2167,7 +2167,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
reinterpret_cast<IndexType*>(kargs.p_total_tokens_post_pad);
|
||||
IndexType* p_sorted_expert_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
|
||||
|
||||
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
const index_t loops = (kargs.num_experts + kBlockSize - 1) / kBlockSize;
|
||||
index_t wave_id = threadIdx.x / get_warp_size();
|
||||
index_t lane_id = threadIdx.x % get_warp_size();
|
||||
|
||||
@@ -2176,7 +2176,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
|
||||
for(index_t i = 0; i < loops; i++)
|
||||
{
|
||||
index_t position = i * BLOCK_SIZE + threadIdx.x;
|
||||
index_t position = i * kBlockSize + threadIdx.x;
|
||||
IndexType a_ = 0; // token count for a expert
|
||||
IndexType b_ = 0; // mask for a expert
|
||||
if(position < kargs.num_experts)
|
||||
@@ -2221,15 +2221,15 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum_a;
|
||||
s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b;
|
||||
s[4 + wave_id + kBlockSize / get_warp_size()] = cumsum_b;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev_a = s[4 + i_w];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()];
|
||||
IndexType prev_b = s[4 + i_w + kBlockSize / get_warp_size()];
|
||||
prev_a = wave_id > i_w ? prev_a : 0; // mask out
|
||||
prev_b = wave_id > i_w ? prev_b : 0; // mask out
|
||||
cumsum_a += prev_a;
|
||||
@@ -2240,7 +2240,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
cumsum_a += prev_cumsum_a;
|
||||
cumsum_b += prev_cumsum_b;
|
||||
|
||||
if(threadIdx.x == BLOCK_SIZE - 1)
|
||||
if(threadIdx.x == kBlockSize - 1)
|
||||
{
|
||||
s[2] = cumsum_a; // store the last cumsum
|
||||
s[3] = cumsum_b;
|
||||
@@ -2297,7 +2297,7 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using MeshType = typename Problem::MeshType;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
typedef MoeSortingHostArgs MoeSortingKargs;
|
||||
@@ -2341,12 +2341,12 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { return dim3(h.num_experts); }
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
return (4 + BLOCK_SIZE / get_warp_size()) * sizeof(IndexType);
|
||||
return (4 + kBlockSize / get_warp_size()) * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
@@ -2391,11 +2391,11 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
}
|
||||
|
||||
// cumsum one by one
|
||||
int loops = (kargs.mesh_stride + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
int loops = (kargs.mesh_stride + kBlockSize - 1) / kBlockSize;
|
||||
int prev_cumsum = 0;
|
||||
for(int i = 0; i < loops; i++)
|
||||
{
|
||||
int i_token = i * BLOCK_SIZE + threadIdx.x;
|
||||
int i_token = i * kBlockSize + threadIdx.x;
|
||||
IndexType x = 0;
|
||||
if(i_token < tokens)
|
||||
{
|
||||
@@ -2414,13 +2414,13 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
});
|
||||
cumsum += prev_cumsum; // add previous round cumsum
|
||||
if(threadIdx.x == BLOCK_SIZE - 1)
|
||||
if(threadIdx.x == kBlockSize - 1)
|
||||
{
|
||||
s[0] = cumsum;
|
||||
}
|
||||
@@ -2441,7 +2441,7 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
}
|
||||
}
|
||||
|
||||
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE)
|
||||
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += kBlockSize)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor);
|
||||
@@ -2457,9 +2457,9 @@ namespace impl {
|
||||
// we use dynamic LDS size here
|
||||
CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_)
|
||||
{
|
||||
constexpr index_t BLOCK_SIZE = 256; // hardcoded 256
|
||||
constexpr index_t kBlockSize = 256; // hardcoded 256
|
||||
const index_t expert_cumsum_elem = num_experts_ + 1;
|
||||
return (4 + 2 * BLOCK_SIZE / get_warp_size() + expert_cumsum_elem) * sizeof(int);
|
||||
return (4 + 2 * kBlockSize / get_warp_size() + expert_cumsum_elem) * sizeof(int);
|
||||
}
|
||||
} // namespace impl
|
||||
|
||||
@@ -2473,7 +2473,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using MeshType = typename Problem::MeshType;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
typedef MoeSortingHostArgs MoeSortingKargs;
|
||||
@@ -2563,18 +2563,18 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
return dim3(h.num_experts + get_num_cu() * OCCUPANCY);
|
||||
#else
|
||||
// use 1 block to cumsum
|
||||
// return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
|
||||
return dim3(h.num_experts + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
|
||||
// return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, kBlockSize * 16));
|
||||
return dim3(h.num_experts + ck_tile::integer_divide_ceil(h.moe_buf_bytes, kBlockSize * 16));
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
|
||||
|
||||
// only use this at host !
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
|
||||
{
|
||||
const auto smem_23 = impl::moe_sorting_get_smem_size_p23(h.num_experts);
|
||||
const auto smem_sf = BLOCK_SIZE * 4 * sizeof(IndexType);
|
||||
const auto smem_sf = kBlockSize * 4 * sizeof(IndexType);
|
||||
return max(smem_23, smem_sf);
|
||||
}
|
||||
|
||||
@@ -2595,7 +2595,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
if(static_cast<index_t>(blockIdx.x) >= kargs.num_experts)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
impl::moe_buf_set_zero_kernel_2d<BLOCK_SIZE>(kargs.p_moe_buf,
|
||||
impl::moe_buf_set_zero_kernel_2d<kBlockSize>(kargs.p_moe_buf,
|
||||
tokens,
|
||||
kargs.moe_buf_interm_dim,
|
||||
kargs.moe_buf_elem_bytes,
|
||||
@@ -2603,7 +2603,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
gridDim.x - kargs.num_experts);
|
||||
return;
|
||||
#else
|
||||
impl::moe_buf_set_zero_kernel<BLOCK_SIZE>(
|
||||
impl::moe_buf_set_zero_kernel<kBlockSize>(
|
||||
reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
|
||||
kargs.moe_buf_bytes,
|
||||
blockIdx.x - kargs.num_experts);
|
||||
@@ -2618,13 +2618,13 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
const IndexType* p_local_expert_mask =
|
||||
static_cast<const IndexType*>(kargs.p_local_expert_mask);
|
||||
IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size();
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * kBlockSize / get_warp_size();
|
||||
IndexType* p_total_tokens_post_pad =
|
||||
reinterpret_cast<IndexType*>(kargs.p_total_tokens_post_pad);
|
||||
IndexType* p_sorted_expert_ids =
|
||||
reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
|
||||
|
||||
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
const index_t loops = (kargs.num_experts + kBlockSize - 1) / kBlockSize;
|
||||
index_t wave_id = threadIdx.x / get_warp_size();
|
||||
index_t lane_id = threadIdx.x % get_warp_size();
|
||||
|
||||
@@ -2633,7 +2633,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
|
||||
for(index_t i = 0; i < loops; i++)
|
||||
{
|
||||
index_t position = i * BLOCK_SIZE + threadIdx.x;
|
||||
index_t position = i * kBlockSize + threadIdx.x;
|
||||
IndexType a_ = 0; // token count for a expert
|
||||
IndexType b_ = 0; // mask for a expert
|
||||
if(position < kargs.num_experts)
|
||||
@@ -2678,15 +2678,15 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum_a;
|
||||
s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b;
|
||||
s[4 + wave_id + kBlockSize / get_warp_size()] = cumsum_b;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev_a = s[4 + i_w];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()];
|
||||
IndexType prev_b = s[4 + i_w + kBlockSize / get_warp_size()];
|
||||
prev_a = wave_id > i_w ? prev_a : 0; // mask out
|
||||
prev_b = wave_id > i_w ? prev_b : 0; // mask out
|
||||
cumsum_a += prev_a;
|
||||
@@ -2697,7 +2697,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
cumsum_a += prev_cumsum_a;
|
||||
cumsum_b += prev_cumsum_b;
|
||||
|
||||
if(threadIdx.x == BLOCK_SIZE - 1)
|
||||
if(threadIdx.x == kBlockSize - 1)
|
||||
{
|
||||
s[2] = cumsum_a; // store the last cumsum
|
||||
s[3] = cumsum_b;
|
||||
@@ -2758,7 +2758,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
|
||||
IndexType* p_sorted_token_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_token_ids);
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size();
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * kBlockSize / get_warp_size();
|
||||
const WeightType* p_weights = static_cast<const WeightType*>(kargs.p_weights);
|
||||
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
|
||||
|
||||
@@ -2795,13 +2795,13 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
constexpr index_t index_pack = Problem::SubTokenTile; // always packed
|
||||
using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
|
||||
using d_t = ext_vector_t<index_t, index_pack>;
|
||||
int loops = (mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
int loops = (mesh_stride / index_pack + kBlockSize - 1) / kBlockSize;
|
||||
|
||||
int prev_cumsum = 0;
|
||||
|
||||
for(int i = 0; i < loops; i++)
|
||||
{
|
||||
int i_token_pack = i * BLOCK_SIZE + threadIdx.x;
|
||||
int i_token_pack = i * kBlockSize + threadIdx.x;
|
||||
r_t x_v = 0;
|
||||
if(i_token_pack < (tokens + index_pack - 1) / index_pack)
|
||||
{
|
||||
@@ -2819,7 +2819,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
|
||||
static_for<0, index_pack, 1>{}([&](auto j_) {
|
||||
constexpr auto j = j_.value;
|
||||
x_r[j] = reinterpret_cast<MeshType*>(s)[threadIdx.x + j * BLOCK_SIZE];
|
||||
x_r[j] = reinterpret_cast<MeshType*>(s)[threadIdx.x + j * kBlockSize];
|
||||
});
|
||||
}
|
||||
#else
|
||||
@@ -2830,7 +2830,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
#pragma unroll
|
||||
for(int j = 0; j < index_pack / 2; j++)
|
||||
{
|
||||
int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x + j * BLOCK_SIZE;
|
||||
int i_token = i * kBlockSize * index_pack + threadIdx.x + j * kBlockSize;
|
||||
index_t x = x_d[j];
|
||||
int i_topk = x - 1; // topk of this token
|
||||
int i_show = x != 0 ? 1 : 0; // has this token or not
|
||||
@@ -2845,13 +2845,13 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
});
|
||||
cumsum += prev_cumsum; // add previous round cumsum
|
||||
if(threadIdx.x == BLOCK_SIZE - 1)
|
||||
if(threadIdx.x == kBlockSize - 1)
|
||||
{
|
||||
s[0] = cumsum;
|
||||
}
|
||||
@@ -2896,13 +2896,13 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
});
|
||||
cumsum += prev_cumsum; // add previous round cumsum
|
||||
if(threadIdx.x == BLOCK_SIZE - 1)
|
||||
if(threadIdx.x == kBlockSize - 1)
|
||||
{
|
||||
s[0] = cumsum;
|
||||
}
|
||||
@@ -2912,10 +2912,10 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
int position = cumsum - cumsum_store;
|
||||
static_for<0, index_pack, 1>{}([&](auto j_) {
|
||||
constexpr auto j = j_.value;
|
||||
// int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x + j *
|
||||
// BLOCK_SIZE;
|
||||
// int i_token = i * kBlockSize * index_pack + threadIdx.x + j *
|
||||
// kBlockSize;
|
||||
int i_token =
|
||||
i * BLOCK_SIZE * index_pack + threadIdx.x * index_pack + j;
|
||||
i * kBlockSize * index_pack + threadIdx.x * index_pack + j;
|
||||
|
||||
if(i_show[j])
|
||||
{
|
||||
@@ -2932,7 +2932,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
});
|
||||
|
||||
#if 0
|
||||
int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x * 2 + j * BLOCK_SIZE * 2;
|
||||
int i_token = i * kBlockSize * index_pack + threadIdx.x * 2 + j * kBlockSize * 2;
|
||||
index_t x = x_d[j];
|
||||
index_t x0 = static_cast<index_t>(x & 0xffff);
|
||||
index_t x1 = static_cast<index_t>(x >> 16);
|
||||
@@ -2951,13 +2951,13 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
});
|
||||
cumsum += prev_cumsum; // add previous round cumsum
|
||||
if(threadIdx.x == BLOCK_SIZE - 1)
|
||||
if(threadIdx.x == kBlockSize - 1)
|
||||
{
|
||||
s[0] = cumsum;
|
||||
}
|
||||
@@ -2996,7 +2996,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
}
|
||||
}
|
||||
|
||||
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE)
|
||||
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += kBlockSize)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor);
|
||||
|
||||
@@ -64,6 +64,7 @@ struct BatchedGemmKernel
|
||||
/// functions.
|
||||
using UniversalGemmKernel =
|
||||
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
||||
static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
|
||||
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
@@ -121,9 +122,16 @@ struct BatchedGemmKernel
|
||||
return dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() -> dim3
|
||||
CK_TILE_HOST static auto BlockSize() -> dim3
|
||||
{
|
||||
return dim3(UniversalGemmKernel::KernelBlockSize);
|
||||
if(ck_tile::is_wave32())
|
||||
{
|
||||
return dim3(UniversalGemmKernel::kBlockSize / 2);
|
||||
}
|
||||
else
|
||||
{
|
||||
return dim3(UniversalGemmKernel::kBlockSize);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr BatchedGemmKernelArgs
|
||||
|
||||
@@ -113,6 +113,7 @@ struct GemmKernel
|
||||
|
||||
static constexpr index_t NumATensor = 1;
|
||||
static constexpr index_t NumBTensor = 1;
|
||||
static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
|
||||
|
||||
CK_TILE_HOST static auto GetName() -> const std::string
|
||||
{
|
||||
|
||||
@@ -86,6 +86,7 @@ struct GemmKernelMultiD
|
||||
/// functions.
|
||||
using UniversalGemmKernel =
|
||||
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
||||
static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
|
||||
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
|
||||
@@ -128,7 +128,7 @@ struct GroupedGemmKernel
|
||||
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
|
||||
using Kernel = GroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
|
||||
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
@@ -155,7 +155,7 @@ struct GroupedGemmKernel
|
||||
return group_count * sizeof(GemmTransKernelArg);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { return dim3(kBlockSize); }
|
||||
|
||||
/**
|
||||
* @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
|
||||
@@ -166,10 +166,10 @@ struct GroupedGemmKernel
|
||||
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
|
||||
{
|
||||
using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*;
|
||||
const auto kernel = kentry<KernelBlockSize, 1, Kernel, ConstantPointer, index_t>;
|
||||
const auto kernel = kentry<1, Kernel, ConstantPointer, index_t>;
|
||||
int occupancy;
|
||||
HIP_CHECK_ERROR(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0));
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
|
||||
const int grid_size = get_available_compute_units(s) * occupancy;
|
||||
return dim3(grid_size, 1, 1);
|
||||
}
|
||||
|
||||
@@ -196,7 +196,7 @@ struct UniversalGemmKernel
|
||||
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
|
||||
// Get the persistent kernel if the pipeline has it available
|
||||
struct has_persistent_kernel
|
||||
@@ -275,15 +275,26 @@ struct UniversalGemmKernel
|
||||
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
|
||||
{
|
||||
using Kernel = UniversalGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
|
||||
const auto kernel = kentry<KernelBlockSize, 1, Kernel, KernelArgs>;
|
||||
const auto kernel = kentry<1, Kernel, KernelArgs>;
|
||||
int occupancy;
|
||||
hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0));
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize().x, 0));
|
||||
|
||||
const int grid_size = get_available_compute_units(s) * occupancy;
|
||||
return dim3(grid_size, 1, 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
CK_TILE_HOST static auto BlockSize()
|
||||
{
|
||||
if(ck_tile::is_wave32())
|
||||
{
|
||||
return dim3(kBlockSize / 2);
|
||||
}
|
||||
else
|
||||
{
|
||||
return dim3(kBlockSize);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr KernelArgs
|
||||
MakeKernelArgs(const UniversalGemmHostArgs<NumATensor, NumBTensor, NumDTensor>& hostArgs)
|
||||
@@ -371,7 +382,9 @@ struct UniversalGemmKernel
|
||||
}
|
||||
}
|
||||
|
||||
bool AsTesnorIsValid = {true};
|
||||
const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA<true>()
|
||||
: GemmPipeline::template GetVectorSizeA<false>();
|
||||
bool AsTesnorIsValid = {true};
|
||||
static_for<0, NumATensor, 1>{}([&](auto index) {
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
|
||||
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
|
||||
@@ -387,7 +400,7 @@ struct UniversalGemmKernel
|
||||
}
|
||||
AsTesnorIsValid = false;
|
||||
}
|
||||
if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
|
||||
if(kargs.K % vectorSizeA != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
@@ -407,7 +420,7 @@ struct UniversalGemmKernel
|
||||
}
|
||||
AsTesnorIsValid = false;
|
||||
}
|
||||
if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
|
||||
if(kargs.M % vectorSizeA != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
@@ -418,7 +431,9 @@ struct UniversalGemmKernel
|
||||
}
|
||||
});
|
||||
|
||||
bool BsTesnorIsValid = {true};
|
||||
bool BsTesnorIsValid = {true};
|
||||
const auto vectorSizeB = is_wave32() ? GemmPipeline::template GetVectorSizeB<true>()
|
||||
: GemmPipeline::template GetVectorSizeB<false>();
|
||||
static_for<0, NumBTensor, 1>{}([&](auto index) {
|
||||
using BiLayout = remove_cvref_t<std::tuple_element_t<index.value, BsLayout>>;
|
||||
if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
|
||||
@@ -432,7 +447,7 @@ struct UniversalGemmKernel
|
||||
}
|
||||
BsTesnorIsValid = false;
|
||||
}
|
||||
if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
|
||||
if(kargs.N % vectorSizeB != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
@@ -454,7 +469,7 @@ struct UniversalGemmKernel
|
||||
}
|
||||
BsTesnorIsValid = false;
|
||||
}
|
||||
if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
|
||||
if(kargs.K % vectorSizeB != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
|
||||
@@ -127,8 +127,16 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
|
||||
}
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
|
||||
}
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
|
||||
@@ -124,8 +124,16 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
|
||||
}
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
|
||||
}
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
|
||||
@@ -61,8 +61,16 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
|
||||
}
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
|
||||
}
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
|
||||
@@ -176,8 +176,16 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
|
||||
}
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
|
||||
}
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
|
||||
@@ -36,8 +36,16 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; }
|
||||
static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; }
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return Problem::VectorSizeA;
|
||||
}
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return Problem::VectorSizeB;
|
||||
}
|
||||
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
|
||||
@@ -305,11 +305,15 @@ struct UniversalGemmBasePolicy
|
||||
* @tparam XPerTile The contiguous Tile dimension size.
|
||||
* @return Maximum DRAM vector load size.
|
||||
*/
|
||||
template <typename Problem, typename DataType, index_t MNPerBlock, index_t XPerTile>
|
||||
template <typename Problem,
|
||||
typename DataType,
|
||||
index_t MNPerBlock,
|
||||
index_t XPerTile,
|
||||
bool IsWave32Host>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize()
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t BlockSize = IsWave32Host ? Problem::kBlockSize / 2 : Problem::kBlockSize;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
@@ -349,7 +353,7 @@ struct UniversalGemmBasePolicy
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, bool IsWave32Host = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
@@ -359,15 +363,23 @@ struct UniversalGemmBasePolicy
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, KPerBlock>();
|
||||
return GetGlobalVectorLoadSize<Problem,
|
||||
ADataType,
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
IsWave32Host>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, MPerBlock>();
|
||||
return GetGlobalVectorLoadSize<Problem,
|
||||
ADataType,
|
||||
MPerBlock,
|
||||
MPerBlock,
|
||||
IsWave32Host>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, bool IsWave32Host = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
@@ -377,11 +389,19 @@ struct UniversalGemmBasePolicy
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, NPerBlock>();
|
||||
return GetGlobalVectorLoadSize<Problem,
|
||||
BDataType,
|
||||
NPerBlock,
|
||||
NPerBlock,
|
||||
IsWave32Host>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, KPerBlock>();
|
||||
return GetGlobalVectorLoadSize<Problem,
|
||||
BDataType,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
IsWave32Host>();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -59,13 +59,15 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
|
||||
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return PipelinePolicy::template GetVectorSizeA<Problem>();
|
||||
return PipelinePolicy::template GetVectorSizeA<Problem, IsWave32Host>();
|
||||
}
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return PipelinePolicy::template GetVectorSizeB<Problem>();
|
||||
return PipelinePolicy::template GetVectorSizeB<Problem, IsWave32Host>();
|
||||
}
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
|
||||
@@ -76,13 +76,15 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
|
||||
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return PipelinePolicy::template GetVectorSizeA<Problem>();
|
||||
return PipelinePolicy::template GetVectorSizeA<Problem, IsWave32Host>();
|
||||
}
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return PipelinePolicy::template GetVectorSizeB<Problem>();
|
||||
return PipelinePolicy::template GetVectorSizeB<Problem, IsWave32Host>();
|
||||
}
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
|
||||
@@ -99,15 +99,15 @@ struct AQuantGemmKernelArgs
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct AQuantGemmKernel
|
||||
{
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using AQLayout = remove_cvref_t<typename GemmPipeline::AQLayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr bool Preshuffle = GemmPipeline::Preshuffle;
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using AQLayout = remove_cvref_t<typename GemmPipeline::AQLayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr bool Preshuffle = GemmPipeline::Preshuffle;
|
||||
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using AQDataType = remove_cvref_t<typename GemmPipeline::AQDataType>;
|
||||
@@ -131,7 +131,7 @@ struct AQuantGemmKernel
|
||||
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
CK_TILE_HOST static constexpr AQuantGemmKernelArgs
|
||||
MakeKernelArgs(const AQuantGemmHostArgs& hostArgs)
|
||||
|
||||
@@ -354,7 +354,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor;
|
||||
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
|
||||
using InDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using WeiDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
@@ -393,7 +393,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized
|
||||
MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs)
|
||||
|
||||
@@ -361,7 +361,7 @@ struct GroupedConvolutionForwardKernel
|
||||
using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor;
|
||||
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
|
||||
using InDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using WeiDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
@@ -398,7 +398,7 @@ struct GroupedConvolutionForwardKernel
|
||||
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized
|
||||
MakeKernelArgs(const GroupedConvFwdHostArgs& hostArgs)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -31,6 +31,7 @@ struct ImageToColumn
|
||||
|
||||
static constexpr index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
|
||||
static constexpr index_t kKPerBlock = Problem::BlockShape::kKPerBlock;
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::kBlockSize;
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
|
||||
@@ -14,11 +14,10 @@ struct TileImageToColumnShape
|
||||
static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
|
||||
static constexpr index_t kKPerThread = ThreadTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
|
||||
static constexpr index_t kKPerWarp = WarpTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
|
||||
static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread;
|
||||
static constexpr index_t kKThreadPerWarp = kKPerWarp / kKPerThread;
|
||||
static constexpr index_t kKThreadPerWarp = get_warp_size() / kMThreadPerWarp;
|
||||
static constexpr index_t kKPerWarp = kKPerThread * kKThreadPerWarp;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockTile::at(number<0>{});
|
||||
static constexpr index_t kKPerBlock = BlockTile::at(number<1>{});
|
||||
|
||||
@@ -76,9 +76,9 @@ struct Layernorm2dFwd
|
||||
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
|
||||
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
|
||||
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -25,6 +25,8 @@ struct Reduce
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
|
||||
private:
|
||||
// Helper function to calculate optimal vector size for input tensor
|
||||
template <typename InputShape, typename ReduceDims>
|
||||
|
||||
@@ -70,6 +70,7 @@ struct Rmsnorm2dFwd
|
||||
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
|
||||
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
|
||||
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
|
||||
@@ -48,6 +48,7 @@ struct MoeSmoothquant
|
||||
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
|
||||
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
|
||||
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
|
||||
@@ -45,6 +45,7 @@ struct Smoothquant
|
||||
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
|
||||
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
|
||||
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -34,6 +34,8 @@ struct TopkSoftmaxKernel
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using IndexType = typename Problem::IndexType;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::BlockSize;
|
||||
|
||||
struct TopkSoftmaxKargs
|
||||
{
|
||||
const void* p_input;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -117,7 +117,7 @@ struct naive_attention_fwd_kernel
|
||||
std::is_same_v<KType, fp8_t> && std::is_same_v<VType, fp8_t>;
|
||||
|
||||
static constexpr int v_per_token_quant_group_size = 64;
|
||||
|
||||
static constexpr int kBlockSize = 256;
|
||||
// TODO: hardcode
|
||||
using SoftmaxType = float; // always using float to do softmax compute
|
||||
using QuantComputeType = float; // used for quant/dequant scale compute
|
||||
@@ -254,7 +254,7 @@ struct naive_attention_fwd_kernel
|
||||
__device__ T load(int i_s, int i_h, int i_d) { return base_ptr[get_offset(i_s, i_h, i_d)]; }
|
||||
};
|
||||
|
||||
__device__ __host__ static constexpr int get_block_size() { return 256; }
|
||||
__device__ __host__ static constexpr int get_block_size() { return kBlockSize; }
|
||||
|
||||
// for simpliciy, 1 WG always compute 1 token along q, compute all token along kv
|
||||
// compute all hdim from q, compute WG_SIZE hdim from v
|
||||
|
||||
Reference in New Issue
Block a user