mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
Do not use warpSize as compile time constant as it is removed (#2320)
* Do not use warpSize as compile time constant as it is removed
* Update tile_image_to_column_shape.hpp
update warpSize usage.
* clean-up all use of warpSize, make sure code builds
* fix
---------
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>
[ROCm/composable_kernel commit: 4c57157d50]
This commit is contained in:
committed by
GitHub
parent
af88547b60
commit
a4517b0a9d
@@ -35,7 +35,7 @@ CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta)
|
||||
#elif 1
|
||||
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
|
||||
|
||||
const uint32_t wrap_around_lane_delta = warpSize - lane_delta;
|
||||
const uint32_t wrap_around_lane_delta = get_warp_size() - lane_delta;
|
||||
|
||||
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
|
||||
(__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast<int32_t>(v_local));
|
||||
|
||||
@@ -95,7 +95,7 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
// constexpr index_t Block_M = Problem::BlockShape::Block_M0;
|
||||
// constexpr index_t Block_K = Problem::BlockShape::Block_K0;
|
||||
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
// constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
|
||||
|
||||
constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
|
||||
@@ -104,11 +104,11 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
|
||||
static_assert(Block_K % KVector == 0);
|
||||
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
|
||||
if constexpr(LanesPerK >= warpSize)
|
||||
if constexpr(LanesPerK >= WarpSize)
|
||||
{
|
||||
// need multiple waves to load K
|
||||
static_assert(LanesPerK % warpSize == 0);
|
||||
constexpr index_t wavesPerK = LanesPerK / warpSize;
|
||||
static_assert(LanesPerK % WarpSize == 0);
|
||||
constexpr index_t wavesPerK = LanesPerK / WarpSize;
|
||||
if constexpr(wavesPerK > NumWarps)
|
||||
{
|
||||
// TODO: need multiple issues along K to load all data
|
||||
@@ -121,11 +121,11 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
make_tuple(number<NumIssues>{}, // m0
|
||||
number<wavesPerM>{}, // m1
|
||||
number<wavesPerK>{}, // k0
|
||||
number<warpSize>{}, // k1
|
||||
number<WarpSize>{}, // k1
|
||||
number<KVector>{}), // k2
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // k0
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
|
||||
number<wavesPerK*(WarpSize * KVector + KPad)>{}, // m1
|
||||
number<WarpSize * KVector + KPad>{}, // k0
|
||||
number<KVector>{}, // k1
|
||||
number<1>{}), // k2
|
||||
number<KVector>{}, // lds store vector(actually no explicit store)
|
||||
@@ -136,7 +136,7 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<NumIssues>{}),
|
||||
make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
|
||||
make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
|
||||
make_merge_transform(make_tuple(number<WarpSize>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
@@ -146,8 +146,8 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
else
|
||||
{
|
||||
// lanes within a wave load different M but same K
|
||||
static_assert(warpSize % LanesPerK == 0);
|
||||
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
|
||||
static_assert(WarpSize % LanesPerK == 0);
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m
|
||||
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
@@ -156,9 +156,9 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
number<NumWarps>{}, // m2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
|
||||
number<Block_K>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // m2
|
||||
number<WarpSize * KVector + KPad>{}, // m2
|
||||
number<KVector>{}, // k0
|
||||
number<1>{}), // k1
|
||||
number<KVector>{}, // lds store vector(actually no explicit store)
|
||||
|
||||
@@ -448,19 +448,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
constexpr index_t kPad = KPack;
|
||||
|
||||
static_assert(warpSize * KVector >= kKPerBlock &&
|
||||
warpSize * KVector % kKPerBlock == 0);
|
||||
static_assert(WarpSize * KVector >= kKPerBlock &&
|
||||
WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK = kKPerBlock / KVector;
|
||||
constexpr index_t LaneGroups = warpSize / LanesPerK;
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK;
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
|
||||
return NumIssues * NumWarps * (warpSize * KVector + kPad);
|
||||
return NumIssues * NumWarps * (WarpSize * KVector + kPad);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -516,18 +516,18 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
constexpr index_t kPad =
|
||||
KPack; // for async-copy, this pad is between warps. Optimize this for lds_read speed
|
||||
|
||||
static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0);
|
||||
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK =
|
||||
kKPerBlock / KVector; // how many lane (within a wave) to load K
|
||||
constexpr index_t LaneGroups =
|
||||
warpSize /
|
||||
WarpSize /
|
||||
LanesPerK; // how many groups (within a wave), they may load different N, but same K
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
@@ -538,9 +538,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
number<NumWarps>{}, // n2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{},
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
|
||||
number<kKPerBlock>{},
|
||||
number<warpSize * KVector + kPad>{},
|
||||
number<WarpSize * KVector + kPad>{},
|
||||
number<KVector>{},
|
||||
number<1>{}),
|
||||
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
|
||||
@@ -569,18 +569,18 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
constexpr index_t kPad = KPack; // for async-copy, this pad is between warps
|
||||
|
||||
static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0);
|
||||
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
|
||||
constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
// constexpr index_t SingleKSize = NumIssues * NumWarps * (warpSize * KVector + kPad);
|
||||
// constexpr index_t SingleKSize = NumIssues * NumWarps * (WarpSize * KVector + kPad);
|
||||
// constexpr index_t SingleVSize =
|
||||
// MakeVLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
constexpr index_t BufferSize =
|
||||
@@ -594,8 +594,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
number<kKPerBlock / KPack>{}, // k0
|
||||
number<KPack>{}), // k1
|
||||
make_tuple(number<BufferSize>{},
|
||||
number<NumWarps*(warpSize * KVector + kPad)>{},
|
||||
number<warpSize * KVector + kPad>{},
|
||||
number<NumWarps*(WarpSize * KVector + kPad)>{},
|
||||
number<WarpSize * KVector + kPad>{},
|
||||
number<kKPerBlock>{},
|
||||
number<KPack>{},
|
||||
number<1>{}),
|
||||
@@ -746,13 +746,13 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
|
||||
static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0);
|
||||
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
|
||||
constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
|
||||
|
||||
@@ -101,7 +101,7 @@ struct FusedMoeGemmShape
|
||||
static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1;
|
||||
static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1;
|
||||
|
||||
static constexpr index_t BlockSize = warpSize * NumWarps;
|
||||
static constexpr index_t BlockSize = WarpSize * NumWarps;
|
||||
|
||||
// some assert
|
||||
static_assert(Block_M0 == Block_M1);
|
||||
|
||||
@@ -381,7 +381,7 @@ struct MoeSortingKernel
|
||||
}
|
||||
|
||||
// reduce single pixel within a wave
|
||||
template <typename T, typename F, index_t wave_size_ = warpSize>
|
||||
template <typename T, typename F, index_t wave_size_ = WarpSize>
|
||||
__device__ static constexpr T wave_reduce(T local, F reduce_f, number<wave_size_> = {})
|
||||
{
|
||||
// constexpr int wave_size = 64;
|
||||
@@ -618,7 +618,7 @@ struct MoeSortingKernel
|
||||
{
|
||||
const index_t prefill_token = topk_mdiv.div(numel);
|
||||
// TODO: only support expert-tile like 8, 16, 32
|
||||
static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile;
|
||||
static constexpr index_t experts_per_wave = WarpSize / Problem::ExpertTile;
|
||||
{
|
||||
index_t eid = tid / experts_per_wave;
|
||||
index_t expert_offset = cumsum[eid] +
|
||||
@@ -686,7 +686,7 @@ struct MoeSortingKernel
|
||||
void* smem) const
|
||||
{
|
||||
const index_t tid = static_cast<index_t>(threadIdx.x);
|
||||
const index_t wid = __builtin_amdgcn_readfirstlane(tid / warpSize);
|
||||
const index_t wid = __builtin_amdgcn_readfirstlane(tid / WarpSize);
|
||||
const index_t lid = __lane_id();
|
||||
constexpr index_t block_size = 256; // blockDim.x;
|
||||
const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor;
|
||||
@@ -791,7 +791,7 @@ struct MoeSortingKernel
|
||||
// NOTE: under this block can never use __syncthreads!
|
||||
int i_e_ = 0;
|
||||
int local_cumsum_ = 0;
|
||||
for(; i_e_ < num_experts; i_e_ += warpSize)
|
||||
for(; i_e_ < num_experts; i_e_ += WarpSize)
|
||||
{
|
||||
int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0);
|
||||
int local_cnt = smem_cumsum(i_e_ + lid + 1);
|
||||
@@ -836,7 +836,7 @@ struct MoeSortingKernel
|
||||
// cumsum padded in case local cumsum is zero, but
|
||||
// pre_sumsum has value, which will result int
|
||||
// zero local cumsum(but we want at least padded)
|
||||
wave_cumsum<int, warpSize>(local_cumsum_);
|
||||
wave_cumsum<int, WarpSize>(local_cumsum_);
|
||||
|
||||
if((i_e_ + lid) < num_experts)
|
||||
smem_cumsum(i_e_ + lid + 1) = local_cumsum_;
|
||||
@@ -844,7 +844,7 @@ struct MoeSortingKernel
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
local_masking += pre_cumsum_masking;
|
||||
wave_cumsum<int, warpSize>(local_masking);
|
||||
wave_cumsum<int, WarpSize>(local_masking);
|
||||
if((i_e_ + lid) < num_experts)
|
||||
smem_cumdup(i_e_ + lid + 1) = local_masking;
|
||||
}
|
||||
@@ -854,7 +854,7 @@ struct MoeSortingKernel
|
||||
// than 0(which is not we want)
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
}
|
||||
if((lid + i_e_ - warpSize) == (num_experts - 1))
|
||||
if((lid + i_e_ - WarpSize) == (num_experts - 1))
|
||||
{
|
||||
*p_total_tokens_post_pad = local_cumsum_;
|
||||
}
|
||||
@@ -1091,7 +1091,7 @@ CK_TILE_HOST_DEVICE index_t moe_sorting_mp_sem_smem_size()
|
||||
return chunk * sizeof(index_t);
|
||||
};
|
||||
|
||||
template <typename T, typename F, index_t wave_size_ = warpSize>
|
||||
template <typename T, typename F, index_t wave_size_ = WarpSize>
|
||||
CK_TILE_DEVICE constexpr T moe_sorting_wave_reduce(T local, F reduce_f, number<wave_size_> = {})
|
||||
{
|
||||
// constexpr int wave_size = 64;
|
||||
@@ -1456,7 +1456,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
return BLOCK_SIZE / warpSize * sizeof(IndexType);
|
||||
return BLOCK_SIZE / WarpSize * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
@@ -1498,8 +1498,8 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
|
||||
}
|
||||
|
||||
index_t lane_id = threadIdx.x % warpSize;
|
||||
index_t wave_id = threadIdx.x / warpSize;
|
||||
index_t lane_id = threadIdx.x % WarpSize;
|
||||
index_t wave_id = threadIdx.x / WarpSize;
|
||||
|
||||
// reduce cross wave
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
@@ -1512,7 +1512,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
index_t c = 0;
|
||||
for(auto i = 0; i < (BLOCK_SIZE / warpSize); i++)
|
||||
for(auto i = 0; i < (BLOCK_SIZE / WarpSize); i++)
|
||||
{
|
||||
c += s[i];
|
||||
}
|
||||
@@ -1601,7 +1601,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
// in byte
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize()
|
||||
{
|
||||
return BLOCK_SIZE / warpSize * sizeof(IndexType);
|
||||
return BLOCK_SIZE / WarpSize * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
@@ -1685,8 +1685,8 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
|
||||
}
|
||||
|
||||
index_t lane_id = threadIdx.x % warpSize;
|
||||
index_t wave_id = threadIdx.x / warpSize;
|
||||
index_t lane_id = threadIdx.x % WarpSize;
|
||||
index_t wave_id = threadIdx.x / WarpSize;
|
||||
|
||||
// reduce cross wave
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
@@ -1700,7 +1700,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
index_t c = 0;
|
||||
for(auto i = 0; i < (BLOCK_SIZE / warpSize); i++)
|
||||
for(auto i = 0; i < (BLOCK_SIZE / WarpSize); i++)
|
||||
{
|
||||
c += s[i];
|
||||
}
|
||||
@@ -1777,7 +1777,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
// return 2 * BLOCK_SIZE * sizeof(IndexType);
|
||||
return (4 + 2 * BLOCK_SIZE / warpSize) * sizeof(IndexType);
|
||||
return (4 + 2 * BLOCK_SIZE / WarpSize) * sizeof(IndexType);
|
||||
}
|
||||
|
||||
// reduce single pixel within a wave
|
||||
@@ -1802,8 +1802,8 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
IndexType* p_sorted_expert_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
|
||||
|
||||
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
index_t wave_id = threadIdx.x / warpSize;
|
||||
index_t lane_id = threadIdx.x % warpSize;
|
||||
index_t wave_id = threadIdx.x / WarpSize;
|
||||
index_t lane_id = threadIdx.x % WarpSize;
|
||||
|
||||
IndexType prev_cumsum_a = 0;
|
||||
IndexType prev_cumsum_b = 0;
|
||||
@@ -1848,22 +1848,22 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
IndexType cumsum_b = b_;
|
||||
|
||||
// Note: we first cumsum local round, then add previous cumsum
|
||||
impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_a);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_b);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_a);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_b);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
if(lane_id == WarpSize - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum_a;
|
||||
s[4 + wave_id + BLOCK_SIZE / warpSize] = cumsum_b;
|
||||
s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev_a = s[4 + i_w];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / warpSize];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize];
|
||||
prev_a = wave_id > i_w ? prev_a : 0; // mask out
|
||||
prev_b = wave_id > i_w ? prev_b : 0; // mask out
|
||||
cumsum_a += prev_a;
|
||||
@@ -1978,7 +1978,7 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
return (4 + BLOCK_SIZE / warpSize) * sizeof(IndexType);
|
||||
return (4 + BLOCK_SIZE / WarpSize) * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
@@ -1995,8 +1995,8 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
|
||||
|
||||
int eid = blockIdx.x;
|
||||
int wave_id = threadIdx.x / warpSize;
|
||||
int lane_id = threadIdx.x % warpSize;
|
||||
int wave_id = threadIdx.x / WarpSize;
|
||||
int lane_id = threadIdx.x % WarpSize;
|
||||
int e_start = p_expert_cumsum[eid];
|
||||
int e_end = p_expert_cumsum[eid + 1];
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
@@ -2026,17 +2026,17 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
int i_topk = x - 1; // topk of this token
|
||||
int i_show = x != 0 ? 1 : 0; // has this token or not
|
||||
int cumsum = i_show;
|
||||
impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
|
||||
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
if(lane_id == WarpSize - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
@@ -2081,7 +2081,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_)
|
||||
{
|
||||
constexpr index_t BLOCK_SIZE = 256; // hardcoded 256
|
||||
const index_t expert_cumsum_elem = num_experts_ + 1;
|
||||
return (4 + 2 * BLOCK_SIZE / warpSize + expert_cumsum_elem) * sizeof(int);
|
||||
return (4 + 2 * BLOCK_SIZE / WarpSize + expert_cumsum_elem) * sizeof(int);
|
||||
}
|
||||
} // namespace impl
|
||||
|
||||
@@ -2186,15 +2186,15 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
const IndexType* p_local_expert_mask =
|
||||
static_cast<const IndexType*>(kargs.p_local_expert_mask);
|
||||
IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / warpSize;
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize;
|
||||
IndexType* p_total_tokens_post_pad =
|
||||
reinterpret_cast<IndexType*>(kargs.p_total_tokens_post_pad);
|
||||
IndexType* p_sorted_expert_ids =
|
||||
reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
|
||||
|
||||
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
index_t wave_id = threadIdx.x / warpSize;
|
||||
index_t lane_id = threadIdx.x % warpSize;
|
||||
index_t wave_id = threadIdx.x / WarpSize;
|
||||
index_t lane_id = threadIdx.x % WarpSize;
|
||||
|
||||
IndexType prev_cumsum_a = 0;
|
||||
IndexType prev_cumsum_b = 0;
|
||||
@@ -2239,22 +2239,22 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
IndexType cumsum_b = b_;
|
||||
|
||||
// Note: we first cumsum local round, then add previous cumsum
|
||||
impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_a);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_b);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_a);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_b);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
if(lane_id == WarpSize - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum_a;
|
||||
s[4 + wave_id + BLOCK_SIZE / warpSize] = cumsum_b;
|
||||
s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev_a = s[4 + i_w];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / warpSize];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize];
|
||||
prev_a = wave_id > i_w ? prev_a : 0; // mask out
|
||||
prev_b = wave_id > i_w ? prev_b : 0; // mask out
|
||||
cumsum_a += prev_a;
|
||||
@@ -2324,13 +2324,13 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
|
||||
IndexType* p_sorted_token_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_token_ids);
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / warpSize;
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize;
|
||||
const WeightType* p_weights = static_cast<const WeightType*>(kargs.p_weights);
|
||||
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
|
||||
|
||||
int eid = blockIdx.x;
|
||||
int wave_id = threadIdx.x / warpSize;
|
||||
int lane_id = threadIdx.x % warpSize;
|
||||
int wave_id = threadIdx.x / WarpSize;
|
||||
int lane_id = threadIdx.x % WarpSize;
|
||||
int e_start = p_expert_cumsum_smem[eid];
|
||||
int e_end = p_expert_cumsum_smem[eid + 1];
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
@@ -2390,17 +2390,17 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
int i_topk = x - 1; // topk of this token
|
||||
int i_show = x != 0 ? 1 : 0; // has this token or not
|
||||
int cumsum = i_show;
|
||||
impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
|
||||
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
if(lane_id == WarpSize - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
@@ -2441,17 +2441,17 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
cumsum_store += i_show[j];
|
||||
});
|
||||
int cumsum = cumsum_store;
|
||||
impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
|
||||
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
if(lane_id == WarpSize - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
@@ -2496,17 +2496,17 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
int i_topk_1 = x1 - 1; // topk of this token
|
||||
int i_show_1 = x1 != 0 ? 1 : 0; // has this token or not
|
||||
int cumsum = i_show_0 + i_show_1;
|
||||
impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
|
||||
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
if(lane_id == WarpSize - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
|
||||
@@ -303,7 +303,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
|
||||
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
|
||||
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
|
||||
|
||||
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
|
||||
@@ -312,11 +312,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
|
||||
static_assert(Block_K % KVector == 0);
|
||||
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
|
||||
if constexpr(LanesPerK >= warpSize)
|
||||
if constexpr(LanesPerK >= WarpSize)
|
||||
{
|
||||
// need multiple waves to load K
|
||||
static_assert(LanesPerK % warpSize == 0);
|
||||
constexpr index_t wavesPerK = LanesPerK / warpSize;
|
||||
static_assert(LanesPerK % WarpSize == 0);
|
||||
constexpr index_t wavesPerK = LanesPerK / WarpSize;
|
||||
if constexpr(wavesPerK > NumWarps)
|
||||
{
|
||||
// TODO: need multiple issues along K to load all data
|
||||
@@ -329,11 +329,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
make_tuple(number<NumIssues>{}, // m0
|
||||
number<wavesPerM>{}, // m1
|
||||
number<wavesPerK>{}, // k0
|
||||
number<warpSize>{}, // k1
|
||||
number<WarpSize>{}, // k1
|
||||
number<KVector>{}), // k2
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // k0
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
|
||||
number<wavesPerK*(WarpSize * KVector + KPad)>{}, // m1
|
||||
number<WarpSize * KVector + KPad>{}, // k0
|
||||
number<KVector>{}, // k1
|
||||
number<1>{}), // k2
|
||||
number<KVector>{}, // lds store vector(actually no explicit store)
|
||||
@@ -344,7 +344,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<NumIssues>{}),
|
||||
make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
|
||||
make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
|
||||
make_merge_transform(make_tuple(number<WarpSize>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
@@ -354,8 +354,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
else
|
||||
{
|
||||
// lanes within a wave load different M but same K
|
||||
static_assert(warpSize % LanesPerK == 0);
|
||||
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
|
||||
static_assert(WarpSize % LanesPerK == 0);
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m
|
||||
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
@@ -364,9 +364,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
number<NumWarps>{}, // m2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
|
||||
number<Block_K>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // m2
|
||||
number<WarpSize * KVector + KPad>{}, // m2
|
||||
number<KVector>{}, // k0
|
||||
number<1>{}), // k1
|
||||
number<KVector>{}, // lds store vector(actually no explicit store)
|
||||
@@ -398,7 +398,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
|
||||
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
|
||||
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
|
||||
|
||||
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
|
||||
@@ -407,11 +407,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
|
||||
static_assert(Block_K % KVector == 0);
|
||||
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
|
||||
if constexpr(LanesPerK >= warpSize)
|
||||
if constexpr(LanesPerK >= WarpSize)
|
||||
{
|
||||
// need multiple waves to load K
|
||||
static_assert(LanesPerK % warpSize == 0);
|
||||
constexpr index_t wavesPerK = LanesPerK / warpSize;
|
||||
static_assert(LanesPerK % WarpSize == 0);
|
||||
constexpr index_t wavesPerK = LanesPerK / WarpSize;
|
||||
if constexpr(wavesPerK >= NumWarps)
|
||||
{
|
||||
// TODO: need multiple issues along K to load all data
|
||||
@@ -424,11 +424,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
make_tuple(number<NumIssues>{}, // m0
|
||||
number<wavesPerM>{}, // m1
|
||||
number<wavesPerK>{}, // k0
|
||||
number<warpSize>{}, // k1
|
||||
number<WarpSize>{}, // k1
|
||||
number<KVector>{}), // k2
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // k0
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
|
||||
number<wavesPerK*(WarpSize * KVector + KPad)>{}, // m1
|
||||
number<WarpSize * KVector + KPad>{}, // k0
|
||||
number<KVector>{}, // k1
|
||||
number<1>{}), // k2
|
||||
number<KPack>{}, // lds load vector
|
||||
@@ -439,7 +439,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(number<NumIssues>{}, number<wavesPerM>{})),
|
||||
make_merge_transform(make_tuple(
|
||||
number<wavesPerK>{}, number<warpSize>{}, number<KVector>{}))),
|
||||
number<wavesPerK>{}, number<WarpSize>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
@@ -449,8 +449,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
else
|
||||
{
|
||||
// lanes within a wave load different M but same K
|
||||
static_assert(warpSize % LanesPerK == 0);
|
||||
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
|
||||
static_assert(WarpSize % LanesPerK == 0);
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m
|
||||
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
@@ -459,9 +459,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
number<NumWarps>{}, // m2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
|
||||
number<Block_K>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // m2
|
||||
number<WarpSize * KVector + KPad>{}, // m2
|
||||
number<KVector>{}, // k0
|
||||
number<1>{}), // k1
|
||||
number<KPack>{}, // lds load vector
|
||||
|
||||
@@ -26,7 +26,7 @@ struct TileImageToColumnShape
|
||||
static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp;
|
||||
static constexpr index_t kKWarpPerBlock = kKPerBlock / kKPerWarp;
|
||||
|
||||
static constexpr index_t kBlockSize = warpSize * kMWarpPerBlock * kKWarpPerBlock;
|
||||
static constexpr index_t kBlockSize = get_warp_size() * kMWarpPerBlock * kKWarpPerBlock;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -250,7 +250,7 @@ struct BlockNormReduceCrossWarpSync
|
||||
// | w0 | w1 | w2 | w3 | -----> | w0123 |
|
||||
//
|
||||
// -> also store data from every wave into LDS
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / WarpSize;
|
||||
return num_warps * 4 * thread_buf_size * sizeof(float);
|
||||
}
|
||||
|
||||
@@ -276,7 +276,7 @@ struct BlockNormReduceCrossWarpSync
|
||||
const index_t lane_id = get_lane_id();
|
||||
const index_t warp_id = get_warp_id();
|
||||
constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / WarpSize;
|
||||
const index_t smem_offset = warp_id;
|
||||
|
||||
// skip if nonthing to do
|
||||
|
||||
@@ -210,7 +210,7 @@ struct BlockReduce2dCrossWarpSync
|
||||
// | w0 | w1 | w2 | w3 | -----> | w0123 |
|
||||
//
|
||||
// -> also store data from every wave into LDS
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
|
||||
return num_warps * thread_buf_size * sizeof(DataType);
|
||||
}
|
||||
|
||||
@@ -226,7 +226,7 @@ struct BlockReduce2dCrossWarpSync
|
||||
const index_t lane_id = get_lane_id();
|
||||
const index_t warp_id = get_warp_id();
|
||||
constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
|
||||
const index_t smem_offset = warp_id;
|
||||
|
||||
// skip if nonthing to do
|
||||
|
||||
Reference in New Issue
Block a user