[CK_TILE] optimize moe sorting kernel, boost large context case up to 20x (#2153)

* combine 2-3 as single stage

* support zeroing

* improve long tokens

* update specialization

* b16 ws

* 8bit topk optimize

* update 15 example
This commit is contained in:
carlushuang
2025-05-06 17:32:07 +08:00
committed by GitHub
parent b8fa27bfef
commit 4e9b76f88c
15 changed files with 1216 additions and 115 deletions

View File

@@ -154,4 +154,13 @@ __host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_addres
#pragma clang diagnostic pop
}
// Returns the shared-memory capacity constant selected at preprocessing time:
// 163840 when __gfx950__ is defined, 65536 otherwise.
// Presumably the per-workgroup LDS size in bytes (160 KiB / 64 KiB) — TODO confirm unit.
// NOTE(review): the selection depends on whether __gfx950__ is defined for the
// compilation pass that instantiates this function; verify that host-side
// callers observe the value they expect.
CK_TILE_HOST_DEVICE constexpr index_t get_smem_capacity()
{
#if defined(__gfx950__)
    constexpr index_t smem_capacity = 160 * 1024; // 163840
#else
    constexpr index_t smem_capacity = 64 * 1024; // 65536
#endif
    return smem_capacity;
}
} // namespace ck_tile

View File

@@ -0,0 +1,65 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
namespace ck_tile {
// Spin-wait style synchronization helper built on an array of 32-bit counters
// addressed as base_ptr[offset].
//
// Every waiting method has only the thread with threadIdx.x == 0 poll the
// counter, followed by a __syncthreads() so the remaining threads of the block
// are held until the poll succeeds. Because of that __syncthreads(), each
// method must be reached by all threads of the block (no divergent call sites).
//
// NOTE(review): only threadIdx.x is tested, so this presumably assumes a 1-D
// thread block — with a multi-dimensional block several threads would execute
// the polling / atomic path. Confirm against the launch configuration.
// NOTE(review): loads are relaxed and no explicit device-scope fence is issued
// here; if the counter is used to publish data written by other workgroups,
// verify that callers provide the required memory ordering.
struct workgroup_barrier
{
    // ptr: base address of the counter array; it is stored, not copied or owned.
    CK_TILE_DEVICE workgroup_barrier(uint32_t* ptr) : base_ptr(ptr) {}

    // Relaxed atomic load of the counter at base_ptr[offset].
    CK_TILE_DEVICE uint32_t ld(uint32_t offset = 0)
    {
        return __atomic_load_n(base_ptr + offset, __ATOMIC_RELAXED);
    }

    // Block until base_ptr[offset] == value.
    CK_TILE_DEVICE void wait_eq(uint32_t value, uint32_t offset = 0)
    {
        if(threadIdx.x == 0)
        {
            while(ld(offset) != value) {}
        }
        __syncthreads();
    }

    // Block while base_ptr[offset] < value, i.e. until the counter is >= value.
    CK_TILE_DEVICE void wait_lt(uint32_t value, uint32_t offset = 0)
    {
        if(threadIdx.x == 0)
        {
            while(ld(offset) < value) {}
        }
        __syncthreads();
    }

    // Block until base_ptr[offset] could be atomically swapped from `compare`
    // to `value` (compare-and-swap retried until the old value equals `compare`).
    CK_TILE_DEVICE void wait_set(uint32_t compare, uint32_t value, uint32_t offset = 0)
    {
        if(threadIdx.x == 0)
        {
            while(atomicCAS(base_ptr + offset, compare, value) != compare) {}
        }
        __syncthreads();
    }

    // Enter critical zone: spin until the slot goes 0 -> 1.
    // Assumes the buffer is zero when the kernel is launched.
    // (Arguments follow wait_set's (compare, value, offset) order; the previous
    // call passed `offset` as the compare value and 1 as the slot index.)
    CK_TILE_DEVICE void aquire(uint32_t offset = 0) { wait_set(0, 1, offset); }

    // Exit critical zone: swap the slot back 1 -> 0.
    // Assumes the buffer is zero when the kernel is launched.
    CK_TILE_DEVICE void release(uint32_t offset = 0) { wait_set(1, 0, offset); }

    // Atomically add 1 to base_ptr[offset]. The leading __syncthreads() makes
    // sure every thread of the block has arrived before thread 0 increments.
    CK_TILE_DEVICE void inc(uint32_t offset = 0)
    {
        __syncthreads();
        if(threadIdx.x == 0)
        {
            atomicAdd(base_ptr + offset, 1);
        }
    }

    // Base address of the counter array (not owned).
    uint32_t* base_ptr;
};
} // namespace ck_tile