save tuning

This commit is contained in:
zanzhang
2025-04-23 10:46:05 +08:00
parent bcf5bb41be
commit 9b94c7bfc7
14 changed files with 330 additions and 116 deletions

View File

@@ -21,7 +21,7 @@ string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
# generate a list of kernels, but not actually emit files at config sta
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt --receipt 200
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
@@ -45,7 +45,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
add_custom_command(
OUTPUT ${FMHA_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR}
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 200
)
add_custom_command(

View File

@@ -3,11 +3,11 @@
# generate kernel instances to speed up compilation
FWD_DTYPE_MAP = {
"fp16" : "FmhaFwdFp16",
# "fp16" : "FmhaFwdFp16",
"bf16" : "FmhaFwdBf16",
"fp8" : "FmhaFwdFp8",
"fp8fp16": "FmhaFwdFp8Fp16",
"fp8bf16": "FmhaFwdFp8Bf16"
# "fp8" : "FmhaFwdFp8",
# "fp8fp16": "FmhaFwdFp8Fp16",
# "fp8bf16": "FmhaFwdFp8Bf16"
}
BWD_DTYPE_MAP = {

View File

@@ -170,9 +170,9 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); return hipPeekAtLastError() == hipSuccess; }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); return hipPeekAtLastError() == hipSuccess; }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); return hipPeekAtLastError() == hipSuccess; }}
);
}}
@@ -545,12 +545,6 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
cond &= dpad == dvpad
if not cond:
continue
# aiter::mha_bwd C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
cond &= dpad == dvpad
if not cond:
continue
api_pool.register_dq_dk_dv_traits(k.api_trait())
gen.append(k)
@@ -688,11 +682,6 @@ def get_bwd_dot_do_o_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaB
cond &= mode == "group"
if not cond:
continue
# aiter::mha_bwd C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
if not cond:
continue
gen.append(k)
return gen
@@ -845,11 +834,6 @@ def get_bwd_convert_dq_blobs(kernel_filter : Optional[str], receipt) -> List[Fmh
cond &= mode == "group"
if not cond:
continue
# aiter::mha_bwd C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
if not cond:
continue
gen.append(k)
return gen

View File

@@ -118,7 +118,7 @@ FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
{F_inner_dispatch}
}}
"""
@@ -288,7 +288,7 @@ class FmhaFwdApiPool:
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
@@ -413,18 +413,21 @@ class FmhaFwdKernel:
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '128' : FmhaFwdTileSize(64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1),
'128' : FmhaFwdTileSize(64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1),
# '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 32, 16, 16, 32, -1),
# '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
# '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
# '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
# '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
}
else:
return None
@@ -441,31 +444,32 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
pipelines = []
if dtype in ['fp16', 'bf16']:
for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
if hdim == 256:
if False:
# if True:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
# the below two is used for hdim vectorize load
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
else:
if bias == "bias":
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
else:
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
if receipt == 1 and bias != "bias":
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
# if bias == "bias":
# # TODO: rocm 6.2 compiler problem if using qr_async for bias case
# pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
# else:
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
# if receipt == 1 and bias != "bias":
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']:
# no need lse/dropout kernels
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
@@ -493,10 +497,6 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
if hdim == 192 and tile.F_bn1 == 128:
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't':
continue
k = FmhaFwdKernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
@@ -537,13 +537,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
cond &= mode == 'group'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# aiter::mha_fwd C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
cond &= pipeline.F_bias == 'no'
cond &= pipeline.F_lse == 'f'
cond &= pipeline.F_dropout == 'f'
if not cond:
continue
api_pool.register_traits(k.api_trait())

View File

@@ -181,7 +181,7 @@ class FmhaFwdAppendKVApiPool:
F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners)
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes)

View File

@@ -253,8 +253,8 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a
<< std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }}
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); return hipPeekAtLastError() == hipSuccess; }},
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); return hipPeekAtLastError() == hipSuccess; }}
);
}}
@@ -439,13 +439,8 @@ class FmhaFwdSplitKVCombinePipeline:
pn = pad_name()
n = f'{self.tag}'
if pn != '' : n += f'_{pn}'
else: n += '_npad'
if self.F_lse == 't' : n += '_lse'
else: n += '_nlse'
if self.F_squant == 't' : n += '_squant'
else: n += '_nsquant'
return n
class FmhaFwdSplitKVApiPool:
@@ -481,7 +476,7 @@ class FmhaFwdSplitKVApiPool:
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners)
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
@@ -738,13 +733,6 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# aiter::mha_fwd_splikv C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
@@ -803,11 +791,6 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis
cond &= mode == "group"
if not cond:
continue
# aiter::mha_fwd_splikv C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
if not cond:
continue
gen.append(k)
return gen

View File

@@ -724,6 +724,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
// Assume bias is in [-1.f, 1.f] in original fp32
ck_tile::FillUniformDistribution<BiasDataType>{-qscale_bias, qscale_bias, seed}(bias_host);
}
else if(init_method == "cnst") // fill with constant
{
ck_tile::FillConstant<QDataType>{1.f}(q_host);
ck_tile::FillIncConstant<KDataType>{1.f}(k_host);
ck_tile::FillConstant<KDataType>{1.f}(knew_host);
// ck_tile::FillConstant<VDataType>{1.f}(v_host);
// ck_tile::FillConstant<VDataType>{1.f}(vnew_host);
// ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
// ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
// ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(knew_host);
ck_tile::FillIncConstant<VDataType>{1.f}(v_host);
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(vnew_host);
}
if(bias.type == bias_enum::alibi)
{
auto slopes = ck_tile::get_alibi_slopes<SaccDataType>(nhead);
@@ -1375,6 +1390,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::identity{},
ck_tile::scales(scale_s));
if(bias.type == bias_enum::elementwise_bias)
{
// elementwise bias
@@ -1516,6 +1532,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
bool cur_pass = ck_tile::check_err(
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
q_host_ref.savetxt("query.bin");
k_host_ref.savetxt("key.bin");
v_host_ref.savetxt("value.bin");
s_host_ref.savetxt("s.bin");
p_host_ref.savetxt("p.bin");
o_host_result.savetxt("output-fmha.bin");
o_host_ref.savetxt("output-ref.bin");
pass &= cur_pass;
if(!cur_pass)
{

View File

@@ -109,8 +109,8 @@ if __name__ == "__main__":
" 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \
" 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \
" 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \
" 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" + \
" 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration"
" 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration"
)
args = parser.parse_args()

View File

@@ -346,12 +346,12 @@ struct FillStepRange
template <typename T>
struct FillConstant
{
T value_{0};
float value_{1.f};
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
std::fill(first, last, value_);
std::fill(first, last, ck_tile::type_convert<T>(value_));
}
template <typename ForwardRange>
@@ -364,6 +364,30 @@ struct FillConstant
}
};
template <typename T>
struct FillIncConstant
{
float value_{1.f};
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
(void)last;
for (int i = 0; i < 128; ++i) {
std::fill(first + 128 * i, first + 128 * (i + 1), ck_tile::type_convert<T>(value_ * (i + 1)));
}
}
template <typename ForwardRange>
auto operator()(ForwardRange&& range) const -> std::void_t<
decltype(std::declval<const FillIncConstant&>()(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
{
(*this)(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range)));
}
};
//----------------------------------------------------------------------------------------------
/// @brief Transforms given input to fit 2:4 structured sparsity pattern so
/// every subgroup of 4 elements contain at most 2 non-zero elements

View File

@@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask
{
auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
const index_t x_per_split = ck_tile::max(1, x_total / num_splits);
const index_t split_start = x_per_split * i_split;
const index_t split_end = ck_tile::min(x_total, split_start + x_per_split);
const index_t split_end = (i_split == num_splits - 1 ? x_total : split_start + x_per_split);
return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
ck_tile::min(origin_end, split_end));

View File

@@ -8,6 +8,7 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
#include <cwchar>
namespace ck_tile {
@@ -301,9 +302,57 @@ struct BlockFmhaPipelineQRKSVS
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
// {
// constexpr auto config = decltype(gemm_1)::Policy::template GetWarpGemmMWarpNWarp<Problem>();
// using WG = remove_cvref_t<decltype(config.template at<0>())>;
//
//
// constexpr int32_t MWarp = BlockFmhaShape::Gemm0BlockWarps::at(ck_tile::number<0>{});
// constexpr int32_t NWarp = BlockFmhaShape::Gemm0BlockWarps::at(ck_tile::number<1>{});
//
// constexpr int32_t kNPerBlock = BlockFmhaShape::kN1;
// constexpr int32_t kKPerBlock = BlockFmhaShape::kK1;
//
// constexpr int32_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
// constexpr int32_t KIterPerWarp = kKPerBlock / WG::kK;
//
// constexpr auto k_tile_outer_encode =
// ck_tile::tile_distribution_encoding<
// ck_tile::sequence<MWarp>,
// ck_tile::tuple<ck_tile::sequence<NIterPerWarp, NWarp>, ck_tile::sequence<KIterPerWarp>>,
// ck_tile::tuple<ck_tile::sequence<0, 1>>,
// ck_tile::tuple<ck_tile::sequence<0, 1>>,
// ck_tile::sequence<1, 2>,
// ck_tile::sequence<0, 0>>{};
//
// constexpr auto k_dram_block_dstr_encode = ck_tile::detail::make_embed_tile_distribution_encoding(
// k_tile_outer_encode, typename WG::BWarpDstrEncoding{});
//
// auto k_lds_dis = ck_tile::make_static_tile_distribution(k_dram_block_dstr_encode);
//
// auto b_warp_window_tmp = make_tile_window(
// k_lds_window.get_bottom_tensor_view(),
// make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
// k_lds_window.get_window_origin(),
// k_lds_dis);
//
// auto b_tile = load_tile(b_warp_window_tmp);
// constexpr auto v_spans = decltype(b_tile)::get_distributed_spans();
// sweep_tile_span(v_spans[number<0>{}], [&](auto idx0) {
// sweep_tile_span(v_spans[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// const auto tile_idx = get_x_indices_from_distributed_indices(
// b_tile.get_tile_distribution(), i_j_idx);
// printf("threadIdx.x[%d], k[%d, %d] = %f\n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), bf16_to_float(b_tile[i_j_idx]));
// // printf("threadIdx.x[%d], v_shuffle_tmp = %f \n", threadIdx.x, static_cast<float>(v_shuffle_tmp[i_j_idx])));
// });
// });
// }
if constexpr(k0_loops > 2)
{
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
block_sync_lds();
gemm_0(s_acc,
@@ -384,6 +433,7 @@ struct BlockFmhaPipelineQRKSVS
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
}
move_tile_window(bias_dram_window, {0, kN0});
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
@@ -392,13 +442,19 @@ struct BlockFmhaPipelineQRKSVS
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
// printf("threadIdx.x[%d], loop[%d], q_tile[%d], k[%d], %d \n", threadIdx.x, i_total_loops, q_origin.at(number<0>{}), k_origin.at(number<0>{}), static_cast<int32_t>(need_perpixel_check));
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
auto ret = mask.IsOutOfBound(row, col);
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
// printf("threadIdx.x[%d], q_tile[%d, %d], k[%d, %d], %d \n", threadIdx.x, row, q_origin.at(number<0>{}), col, k_origin.at(number<0>{}), static_cast<int32_t>(ret));
return ret;
});
}
}
@@ -409,7 +465,7 @@ struct BlockFmhaPipelineQRKSVS
sequence<1>{},
f_max,
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
block_tile_reduce_sync(m_local, f_max, bool_constant<true>{});
const auto m_old = m; // m{j-1}
tile_elementwise_inout(
@@ -441,6 +497,8 @@ struct BlockFmhaPipelineQRKSVS
auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
// const auto tile_idx = get_x_indices_from_distributed_indices(
// p_compute.get_tile_distribution(), make_tuple(idx0, idx1));
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
@@ -452,16 +510,31 @@ struct BlockFmhaPipelineQRKSVS
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
}
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
// printf("threadIdx.x[%d], s_acc[%d, %d] = %f p = %f m = %f \n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), static_cast<float>(s[i_j_idx]), static_cast<float>(p_compute[i_j_idx]), static_cast<float>(get_validated_m(m[i_idx])));
#else
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
#endif
});
});
// {
// constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
// sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
// sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
// const auto tile_idx = get_x_indices_from_distributed_indices(
// s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
// printf("threadIdx.x[%d], s_acc[%d, %d] = %f p = %f \n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), static_cast<float>(s_acc[i_j_idx]), static_cast<float>(p_compute[i_j_idx]));
// });
// });
// }
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<true>{});
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
@@ -503,10 +576,37 @@ struct BlockFmhaPipelineQRKSVS
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
// {
// constexpr auto v_spans = decltype(v_prefetch)::get_distributed_spans();
// sweep_tile_span(v_spans[number<0>{}], [&](auto idx0) {
// sweep_tile_span(v_spans[number<1>{}], [&](auto idx1) {
// const auto tile_idx = get_x_indices_from_distributed_indices(
// v_prefetch.get_tile_distribution(), make_tuple(idx0, idx1));
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
// printf("threadIdx.x[%d], v_prefetch[%d, %d] = %f\n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), bf16_to_float(v_prefetch[i_j_idx]));
// });
// });
//
// }
shuffle_tile(v_shuffle_tmp, v_prefetch);
store_tile(
v_lds_window,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
// {
// constexpr auto v_spans = decltype(v_shuffle_tmp)::get_distributed_spans();
// sweep_tile_span(v_spans[number<0>{}], [&](auto idx0) {
// sweep_tile_span(v_spans[number<1>{}], [&](auto idx1) {
// const auto tile_idx = get_x_indices_from_distributed_indices(
// v_shuffle_tmp.get_tile_distribution(), make_tuple(idx0, idx1));
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
// printf("threadIdx.x[%d], v_shuffle_tmp[%d, %d] = %f\n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), bf16_to_float(v_shuffle_tmp[i_j_idx]));
// // printf("threadIdx.x[%d], v_shuffle_tmp = %f \n", threadIdx.x, static_cast<float>(v_shuffle_tmp[i_j_idx])));
// });
// });
//
// }
}
else
{
@@ -515,16 +615,103 @@ struct BlockFmhaPipelineQRKSVS
}
move_tile_window(v_dram_window, {0, kK1});
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
// STAGE 3, KV gemm
if constexpr(k1_loops > 1)
{
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
// {
// constexpr auto config = decltype(gemm_1)::Policy::template GetWarpGemmMWarpNWarp<Problem>();
// using WG = remove_cvref_t<decltype(config.template at<0>())>;
//
//
// constexpr int32_t MWarp = BlockFmhaShape::Gemm1BlockWarps::at(ck_tile::number<0>{});
// constexpr int32_t NWarp = BlockFmhaShape::Gemm1BlockWarps::at(ck_tile::number<1>{});
//
// constexpr int32_t kNPerBlock = BlockFmhaShape::kN1;
// constexpr int32_t kKPerBlock = BlockFmhaShape::kK1;
//
// constexpr int32_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
// constexpr int32_t KIterPerWarp = kKPerBlock / WG::kK;
//
// constexpr auto k_tile_outer_encode =
// ck_tile::tile_distribution_encoding<
// ck_tile::sequence<MWarp>,
// ck_tile::tuple<ck_tile::sequence<NIterPerWarp, NWarp>, ck_tile::sequence<KIterPerWarp>>,
// ck_tile::tuple<ck_tile::sequence<0, 1>>,
// ck_tile::tuple<ck_tile::sequence<0, 1>>,
// ck_tile::sequence<1, 2>,
// ck_tile::sequence<0, 0>>{};
//
// constexpr auto k_dram_block_dstr_encode = ck_tile::detail::make_embed_tile_distribution_encoding(
// k_tile_outer_encode, typename WG::BWarpDstrEncoding{});
//
// auto v_lds_dis = ck_tile::make_static_tile_distribution(k_dram_block_dstr_encode);
//
// auto b_warp_window_tmp = make_tile_window(
// v_lds_window.get_bottom_tensor_view(),
// make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
// v_lds_window.get_window_origin(),
// v_lds_dis);
//
// auto b_tile = load_tile(b_warp_window_tmp);
// constexpr auto v_spans = decltype(b_tile)::get_distributed_spans();
// sweep_tile_span(v_spans[number<0>{}], [&](auto idx0) {
// sweep_tile_span(v_spans[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// const auto tile_idx = get_x_indices_from_distributed_indices(
// b_tile.get_tile_distribution(), i_j_idx);
// printf("threadIdx.x[%d], v_shuffle_tmp[%d, %d] = %f\n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), bf16_to_float(b_tile[i_j_idx]));
// // printf("threadIdx.x[%d], v_shuffle_tmp = %f \n", threadIdx.x, static_cast<float>(v_shuffle_tmp[i_j_idx])));
// });
// });
// }
if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
constexpr auto config = decltype(gemm_1)::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
auto v_lds_dis = ck_tile::make_static_tile_distribution(typename WG::BWarpDstrEncoding{});
auto b_warp_window_tmp = make_tile_window(
v_lds_window.get_bottom_tensor_view(),
make_tuple(number<128>{}, number<64>{}),
v_lds_window.get_window_origin(),
v_lds_dis);
auto b_tile = load_tile(b_warp_window_tmp);
printf("threadIdx.x[%d], v_shuffle_tmp[%d]\n", threadIdx.x, b_tile.get_thread_buffer_size());
for (int i = 0; i < b_tile.get_thread_buffer_size(); ++i) {
printf(" threadIdx.x[%d], v_shuffle_tmp[%f] \n", threadIdx.x, bf16_to_float(b_tile.get_thread_buffer()[i]));
}
printf("\n");
// sweep_tile(b_tile, [&](auto idx) {
// printf("threadIdx.x[%d], v_shuffle_tmp[%d] = %f\n", threadIdx.x, idx.value, bf16_to_float(b_tile[idx]));
// });
// sweep_tile_span(v_spans[number<0>{}], [&](auto idx0) {
// sweep_tile_span(v_spans[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// const auto tile_idx = get_x_indices_from_distributed_indices(
// b_tile.get_tile_distribution(), i_j_idx);
// printf("threadIdx.x[%d], v_shuffle_tmp[%d, %d] = %f\n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), bf16_to_float(b_tile[i_j_idx]));
// // printf("threadIdx.x[%d], v_shuffle_tmp = %f \n", threadIdx.x, static_cast<float>(v_shuffle_tmp[i_j_idx])));
// });
// };
}
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
const auto v = load_tile(v_dram_window); // load next v
block_sync_lds();
gemm_1(o_acc,
// gemm_1(o_acc,
gemm_0(o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_window);
@@ -551,11 +738,26 @@ struct BlockFmhaPipelineQRKSVS
// tail
{
block_sync_lds();
gemm_1(o_acc,
// gemm_1(o_acc,
gemm_0(o_acc,
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
v_lds_window);
block_sync_lds();
}
// {
// sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
// constexpr auto i_idx = make_tuple(idx0);
// sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
// const auto tile_idx = get_x_indices_from_distributed_indices(
// o_acc.get_tile_distribution(), make_tuple(idx0, idx1));
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
// printf("threadIdx.x[%d], o_acc[%d, %d] = %f p = %f m = %f \n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), static_cast<float>(o_acc[i_j_idx]), static_cast<float>(p_compute[i_j_idx]), static_cast<float>(l[i_idx]));
// });
// });
// }
} while(++i_total_loops < num_total_loop);
// store lse

View File

@@ -394,19 +394,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
using VDataType = remove_cvref_t<typename Problem::VDataType>;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kBlockSize = Problem::kBlockSize; //256
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; // 128
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // 64
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; //16
constexpr index_t kMaxVecLoad =
min(total_pixels, static_cast<index_t>(16 / sizeof(VDataType)));
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
min(total_pixels, static_cast<index_t>(16 / sizeof(VDataType))); // 8
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); // 2
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
? kMaxVecLoad //8
: (total_pixels / kMinVecLoad);
return kVecLoad;
return kVecLoad; //8
}
else
{
@@ -854,17 +854,17 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; // 128
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // 64
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N1 = GetAlignmentV<Problem>(); // 8
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; //16
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr index_t K3 = total_pixels / N1; // 16 / 8 = 2
constexpr index_t kKPack = GetSmemKPackV<Problem>(); // 8
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave // 4
if constexpr(get_warp_size() % (K2 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);

View File

@@ -51,11 +51,11 @@ struct BlockGemmARegBSmemCRegV2
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); //128 / 16
constexpr index_t KIterPerWarp = KPerBlock / WG::kK; // 32 / 32 = 1
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; // 128 / 8 = 16
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; // 32
const index_t iNWarp = get_warp_id() % NWarp;

View File

@@ -17,7 +17,7 @@ fi
cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm/ \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -v -save-temps -fPIE -Wno-gnu-line-marker" \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \
-D GPU_TARGETS=$GPU_TARGETS \