mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
[CK_TILE] FMHA Support hdim_v to as a Multiple of 32 (#2114)
* 160+192
* Add splitkv d160
* cleanup
* fix
* Add change log
* Fix CHANGELOG
* Use static_cast
* Update ignored instance
---------
Co-authored-by: asleepzzz <hanwen.chang@amd.com>
[ROCm/composable_kernel commit: b8212864cf]
This commit is contained in:
@@ -18,6 +18,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added support for FP16 2:4 structured sparsity to universal GEMM.
|
||||
* Added support for Split K for grouped convolution backward data.
|
||||
* Added logit soft-capping support for fMHA forward kernels.
|
||||
* Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv)
|
||||
* Added benchmarking support for tile engine GEMM.
|
||||
* Added Ping-pong scheduler support for GEMM operation along the K dimension.
|
||||
* Added rotating buffer feature for CK_Tile GEMM.
|
||||
|
||||
@@ -71,6 +71,7 @@ args:
|
||||
-drop_seed seed for random number generator (default:1)
|
||||
-drop_offset offset for random number generator (default:0)
|
||||
-drop_prefs seed and offset values are present on GPU; 0 - host, 1 - device/GPU (default:0)
|
||||
-num_splits number of splits for key/value. 0 to determine actual number by heuristic (default:1)
|
||||
-warmup number of iterations before benchmark the kernel (default:5)
|
||||
-repeat number of iterations to benchmark the kernel (default:20)
|
||||
```
|
||||
|
||||
@@ -282,18 +282,19 @@ class FmhaFwdApiPool:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
if trait.hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][trait.hdim] = list()
|
||||
hdim = trait.hdim, trait.bn1
|
||||
if hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][hdim] = list()
|
||||
|
||||
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
|
||||
self.pool[trait.dtype][hdim].append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
per_dtypes=str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case=str()
|
||||
for j, hdim in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][hdim]
|
||||
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][(hdim, hdim_v)]
|
||||
inners=str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
@@ -306,7 +307,7 @@ class FmhaFwdApiPool:
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
@@ -435,18 +436,20 @@ class FmhaFwdKernel:
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
(32, 32) : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
(64, 64) : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### (96, 128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
(128,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### (160,160) : FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1),
|
||||
(192,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### (192,192) : FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1),
|
||||
(256,256) : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
(64,64 ) : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
(128,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
(256,256) : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
@@ -454,7 +457,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
|
||||
def get_pipelines(dtype, hdim, hdim_v) -> List[FmhaFwdPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
@@ -463,7 +466,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256:
|
||||
if hdim == 256 and hdim_v == 256:
|
||||
# if True:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
@@ -507,15 +510,13 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
if d == None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
||||
tile = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in get_pipelines(dtype, hdim):
|
||||
for ((hdim, hdim_v), tile), mode in itertools.product(d.items(), MODE_MAP.keys()):
|
||||
for pipeline in get_pipelines(dtype, hdim, hdim_v):
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
if hdim == 192 and tile.F_bn1 == 128:
|
||||
if (hdim, hdim_v) == (192, 128) or hdim == 160:
|
||||
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
|
||||
if pipeline.F_bias != 'no' or pipeline.F_dropout == 't':
|
||||
continue
|
||||
|
||||
@@ -34,6 +34,7 @@ K0_MAX_SUBMAX_MAP = {
|
||||
64 : 64,
|
||||
96 : 128,
|
||||
128: 128,
|
||||
# 160: 160,
|
||||
256: 256
|
||||
}
|
||||
|
||||
@@ -638,6 +639,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
'64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
### '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
@@ -656,6 +658,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d
|
||||
'64' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
### '96' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'128' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
### '160' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'256' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
@@ -683,7 +686,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
|
||||
# TODO: use async pipeline when compiler is more stable
|
||||
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
|
||||
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128, 160]:
|
||||
# if True:
|
||||
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fmha_fwd.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
@@ -178,50 +178,30 @@ auto get_elimit<FmhaFwdFp8>(std::string init_method)
|
||||
}
|
||||
}
|
||||
|
||||
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits)
|
||||
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits)
|
||||
{
|
||||
// If we have enough to almost fill the SMs, then just use 1 split
|
||||
if(batch_nhead_mblocks >= 0.8f * num_SMs)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
max_splits = std::min({max_splits, num_SMs, num_n_blocks});
|
||||
max_splits = std::min({max_splits, num_SMs});
|
||||
float max_efficiency = 0.f;
|
||||
std::vector<float> efficiency;
|
||||
efficiency.reserve(max_splits);
|
||||
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
|
||||
// Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
|
||||
// we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
|
||||
// (i.e. it's 11 splits anyway).
|
||||
// So we check if the number of blocks per split is the same as the previous num_splits.
|
||||
auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
|
||||
return num_splits == 1 ||
|
||||
ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
|
||||
};
|
||||
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
|
||||
{
|
||||
if(!is_split_eligible(num_splits))
|
||||
float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
|
||||
float eff = n_waves / ceil(n_waves);
|
||||
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
|
||||
if(eff > max_efficiency)
|
||||
{
|
||||
efficiency.push_back(0.f);
|
||||
}
|
||||
else
|
||||
{
|
||||
float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
|
||||
float eff = n_waves / ceil(n_waves);
|
||||
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
|
||||
if(eff > max_efficiency)
|
||||
{
|
||||
max_efficiency = eff;
|
||||
}
|
||||
efficiency.push_back(eff);
|
||||
max_efficiency = eff;
|
||||
}
|
||||
efficiency.push_back(eff);
|
||||
}
|
||||
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
|
||||
{
|
||||
if(!is_split_eligible(num_splits))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
|
||||
{
|
||||
// printf("num_splits chosen = %d\n", num_splits);
|
||||
@@ -234,6 +214,7 @@ int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks,
|
||||
int override_num_splits_if_necessary(
|
||||
int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
|
||||
{
|
||||
(void)hdim_v;
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess)
|
||||
@@ -250,15 +231,13 @@ int override_num_splits_if_necessary(
|
||||
|
||||
// tile size should match the generate.py
|
||||
const int kM0 = 64;
|
||||
const int kN1 = hdim_v;
|
||||
|
||||
const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
|
||||
const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1);
|
||||
|
||||
if(num_splits < 1 && p_drop == 0.0f)
|
||||
{
|
||||
return num_splits_heuristic(
|
||||
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
|
||||
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 128);
|
||||
}
|
||||
|
||||
return num_splits;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -129,7 +129,10 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
|
||||
// set output vectors
|
||||
static_for<0, num_vec_out, 1>{}([&](auto i) {
|
||||
constexpr auto idx_y_out_tmp = generate_array(
|
||||
[&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
|
||||
[&](auto ii) {
|
||||
return ii == y_dim_vec_in ? static_cast<index_t>(idx_y_start[ii]) + i
|
||||
: static_cast<index_t>(idx_y_start[ii]);
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_y_out =
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -787,12 +787,29 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t N0 = kNPerBlock / N1; // P
|
||||
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
|
||||
if constexpr(get_warp_size() % (K2 * N0) == 0)
|
||||
if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
|
||||
{
|
||||
constexpr index_t kNPack = 32;
|
||||
static_assert(kNPerBlock % kNPack == 0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N2 = 2;
|
||||
constexpr index_t N1_m = kNPack / N2;
|
||||
constexpr index_t N0_m = kNPerBlock / kNPack;
|
||||
constexpr index_t K1 = get_warp_size() / N1_m;
|
||||
constexpr index_t K2_m = kKPerBlock / K1;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<N0_m, N1_m, N2>, sequence<K0, K1, K2_m>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>, // K0, K1 N0
|
||||
tuple<sequence<0>, sequence<1, 1>>,
|
||||
sequence<1, 2, 1>, // N0 K2 N2
|
||||
sequence<0, 2, 2>>{});
|
||||
}
|
||||
else if constexpr(get_warp_size() % (kKPack / K3 * N0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * N0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
@@ -860,12 +877,28 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
|
||||
if constexpr(get_warp_size() % (K2 * N0) == 0)
|
||||
if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
|
||||
{
|
||||
constexpr index_t kNPack = 32;
|
||||
static_assert(kNPerBlock % kNPack == 0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N2 = 2;
|
||||
constexpr index_t N1_m = kNPack / N2;
|
||||
constexpr index_t N0_m = kNPerBlock / kNPack;
|
||||
constexpr index_t K1 = get_warp_size() / N1_m;
|
||||
constexpr index_t K2_m = kKPerBlock / K1;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0_m, N1_m, N2>, sequence<K0, K1, K2_m>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>, // K0, K1 N0
|
||||
tuple<sequence<0>, sequence<1, 1>>,
|
||||
sequence<1, 1, 2>, // N0 K2 <-> N2
|
||||
sequence<0, 2, 2>>{});
|
||||
}
|
||||
else if constexpr(get_warp_size() % (kKPack / K3 * N0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * N0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
|
||||
Reference in New Issue
Block a user