diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index a3248e2a5e..0bb5408772 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -34,6 +34,7 @@ args: if not equal to h, then this is GQA/MQA case -s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328) total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary + also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode) -s_k seqlen_k, -1 means equal to s (default:-1) -d head dim for q, k (default:128) -d_v head dim for v, -1 means equal to d (default:-1) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 74cb3657e6..91fc07d831 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -44,11 +44,18 @@ auto create_args(int argc, char* argv[]) "-1", "num of head, for k/v, -1 means equal to h\n" "if not equal to h, then this is GQA/MQA case") - .insert("s", - "3328", - "seqlen_q. if group-mode, means the average value of seqlen_q\n" - "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary") + .insert( + "s", + "3328", + "seqlen_q. if group-mode, means the average value of seqlen_q\n" + "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n" + "also with \"-s=s0,s1,s2...\" comma seperated int to set per batch seqlen(group-mode)") .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("s_kpad", + "-1", + "seqlen_k stride between 2 tokens, currently used in group-mode only\n" + "for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n" + "along seqlen, instead of packed. same as xformer kv_padding") .insert("d", "128", "head dim for q, k") .insert("d_v", "-1", "head dim for v, -1 means equal to d") .insert("scale_s", @@ -103,6 +110,7 @@ auto create_args(int argc, char* argv[]) "11939", "random seed used for initializing input tensors. 0 for " "non-deterministic seed") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel"); @@ -177,10 +185,20 @@ bool run(const ck_tile::ArgParser& arg_parser) return false; } - ck_tile::index_t seqlen_q = arg_parser.get_int("s"); - ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); - if(seqlen_k < 0) - seqlen_k = seqlen_q; + auto [seqlen_qs, seqlen_ks, seqlen_kpads] = decode_seqlen(mode, + batch, + arg_parser.get_str("s"), + arg_parser.get_str("s_k"), + arg_parser.get_str("s_kpad")); + +#if 0 + // clang-format off + std::cout << "seqlen_qs:"; for(auto xx : seqlen_qs) { std::cout << xx << ","; } std::cout << std::endl; + std::cout << "seqlen_ks:"; for(auto xx : seqlen_ks) { std::cout << xx << ","; } std::cout << std::endl; + std::cout << "seqlen_kpads:"; for(auto xx : seqlen_kpads) { std::cout << xx << ","; } std::cout << std::endl; + // clang-format on +#endif + ck_tile::index_t hdim_q = arg_parser.get_int("d"); ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); if(hdim_v < 0) @@ -229,7 +247,8 @@ bool run(const ck_tile::ArgParser& arg_parser) bool lse = arg_parser.get_bool("lse"); bias_info bias = bias_info::decode(arg_parser.get_str("bias")); - mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); + mask_info mask = mask_info::decode( + arg_parser.get_str("mask"), seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore std::string init_method = arg_parser.get_str("init"); std::optional seed = arg_parser.get_uint32("seed"); @@ -242,11 +261,16 @@ bool run(const ck_tile::ArgParser& arg_parser) int stream_repeat = arg_parser.get_int("repeat"); bool kname = arg_parser.get_bool("kname"); - ck_tile::stream_config stream_config{ - nullptr, true, /* log_level = */ (kname ? 1 : 0), stream_warmup, stream_repeat}; + ck_tile::stream_config stream_config{nullptr, + true, + /* log_level = */ (kname ? 1 : 0), + stream_warmup, + stream_repeat, + arg_parser.get_str("timer") == std::string("gpu")}; - const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); - const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k); + const auto seqstart_q_host = to_seqstarts(seqlen_qs); + const auto seqstart_k_host = to_seqstarts(seqlen_ks); + const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); using TypeConfig = FmhaFwdTypeConfig; @@ -302,9 +326,11 @@ bool run(const ck_tile::ArgParser& arg_parser) // host memory for storing all the tensor elements const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); const ck_tile::index_t shape_seqlen_q = - (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back()); + (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back()); const ck_tile::index_t shape_seqlen_k = - (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); + (mode == mode_enum::batch ? seqlen_ks[0] + : (seqlen_kpads[0] < 0 ? seqstart_k_host.back() + : seqstart_k_with_padding_host.back())); ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); @@ -407,6 +433,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqlen_k_buf(seqlen_kpads[0] < 0 ? 0 : seqlen_ks.size() * sizeof(int32_t)); ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); q_buf.ToDevice(q_host.data()); @@ -414,7 +441,9 @@ bool run(const ck_tile::ArgParser& arg_parser) v_buf.ToDevice(v_host.data()); bias_buf.ToDevice(bias_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); - seqstart_k.ToDevice(seqstart_k_host.data()); + seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data() + : seqstart_k_with_padding_host.data()); + seqlen_k_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr : seqlen_ks.data()); alibi_slope_buf.ToDevice(alibi_slope_host.data()); // clang-format off @@ -430,7 +459,9 @@ bool run(const ck_tile::ArgParser& arg_parser) const std::string prec = arg_parser.get_str("prec"); std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch - << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k + << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0] << "/" << seqlen_ks[0] + << (seqlen_kpads[0] < 0 ? "" + : (std::string("(") + std::to_string(seqlen_kpads[0]) + ")")) << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias << ", lse:" << lse << ", squant:" << squant << ", mask:" << mask << ", v:" << vlayout << std::flush; @@ -460,7 +491,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return ck_tile::identity{}; }(); - auto fmha_args = [&]() { + auto fmha_args = [&, k_paddings_ = seqlen_kpads]() { assert(nhead % nhead_k == 0); /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, /// seqlen_k] in this example, hence both the 'batch_stride_bias' & @@ -506,7 +537,7 @@ bool run(const ck_tile::ArgParser& arg_parser) o_buf.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(), seqstart_k.GetDeviceBuffer(), - nullptr, + k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer(), shape_seqlen_q, shape_seqlen_k, batch, @@ -576,7 +607,10 @@ bool run(const ck_tile::ArgParser& arg_parser) // adjust matrix index according to the mode const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); - const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]); + const ck_tile::index_t key_offset = + (mode == mode_enum::batch + ? 0 + : (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] : seqstart_k_with_padding_host[wb])); const auto v_host_ref_lengths = std::array{nhead, hdim_v, real_seqlen_k}; @@ -661,7 +695,7 @@ bool run(const ck_tile::ArgParser& arg_parser) else { return ck_tile::Alibi{ - 0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::VERTICAL}; + 0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT}; } }(); @@ -671,7 +705,8 @@ bool run(const ck_tile::ArgParser& arg_parser) for(auto i_h = 0; i_h < nhead; i_h++) { SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h); - alibi_host.slope = current_slope; + alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL ? current_slope + : -current_slope; for(auto i_r = 0; i_r < real_seqlen_q; i_r++) { for(auto i_c = 0; i_c < real_seqlen_k; i_c++) diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 51fecd07b5..f0180d6db7 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -78,6 +78,11 @@ BOOL_MAP = { "f" : "false" } +TILE_PARTITIONER_MAP = { + "shb" : "ck_tile::FmhaFwdTilePartitioner_SHB", + "hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS", +} + DIRECTIONS = ["fwd"] GEN_DIR = "" # in Cmake, have to generate files in same folder @@ -107,7 +112,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_dvpad}, {F_bias}, {F_lse}, - {F_squant}, + {F_squant}, {F_occupancy}>; using fmha_mask_{F_idx} = {F_mask}; @@ -136,7 +141,7 @@ using fmha_epilogue_{F_idx} = {F_spad}, {F_dvpad}>>; using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdKernel, + ck_tile::FmhaFwdKernel<{F_tile_partitioner}, fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>; @@ -154,7 +159,7 @@ float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); constexpr dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, k_{{}}, grids, blocks, 0, kargs); + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} """ @@ -389,6 +394,12 @@ class FmhaFwdKernel: F_pipeline : FmhaFwdPipeline mask_impl : str + def get_tp(self) -> str: + if self.F_mode == 'group': + return 'hbs' + else: + return 'shb' + @property def template(self) -> str: kernel_body = str() @@ -413,7 +424,7 @@ class FmhaFwdKernel: F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_squant = BOOL_MAP[self.F_pipeline.F_squant], @@ -421,12 +432,13 @@ class FmhaFwdKernel: F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline.tag]) + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], + F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()]) @property def name(self) -> str: # TODO: we don't encode idx here - return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" +\ + return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \ self.F_tile.name + '_' + self.F_pipeline.name @property diff --git a/example/ck_tile/01_fmha/script/smoke_test.sh b/example/ck_tile/01_fmha/script/smoke_test.sh index 2c4bb562a3..21f679e11a 100755 --- a/example/ck_tile/01_fmha/script/smoke_test.sh +++ b/example/ck_tile/01_fmha/script/smoke_test.sh @@ -28,6 +28,7 @@ $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS done done diff --git a/example/ck_tile/01_fmha/utils.hpp b/example/ck_tile/01_fmha/utils.hpp index e10ae617dc..737efd8256 100644 --- a/example/ck_tile/01_fmha/utils.hpp +++ b/example/ck_tile/01_fmha/utils.hpp @@ -4,12 +4,14 @@ #pragma once #include +#include #include #include #include #include #include #include +#include #include "ck_tile/core/container/span.hpp" @@ -37,12 +39,14 @@ std::vector to_seqstarts(ck_tile::span seqlens) std::vector generate_seqlens(mode_enum mode, unsigned count, - int32_t seqlens_sum, + int32_t seqlen_avg, + int32_t seqlen_max = -1, // if not negative, clamp max std::optional seed = std::nullopt) { assert(0 < count); - std::vector seqlens(count, seqlens_sum); + std::vector seqlens( + count, seqlen_max > 0 ? (seqlen_avg < seqlen_max ? seqlen_avg : seqlen_max) : seqlen_avg); if(mode == mode_enum::group && 1 < count) { @@ -55,7 +59,7 @@ std::vector generate_seqlens(mode_enum mode, std::uniform_int_distribution step_dist(1, count - 1); auto next_step = std::bind(step_dist, std::ref(random_engine)); - for(unsigned repeat = seqlens_sum * (count / 2); 0 < repeat; --repeat) + for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat) { const size_type to_decrease = next_idx(); // make sure each elements of seqlens is always greater than 0 @@ -66,6 +70,11 @@ std::vector generate_seqlens(mode_enum mode, const size_type to_increase = (to_decrease + next_step()) % count; + if(seqlen_max > 0 && seqlens[to_increase] >= seqlen_max) + { + continue; + } + --seqlens[to_decrease]; ++seqlens[to_increase]; } @@ -76,10 +85,91 @@ std::vector generate_seqlens(mode_enum mode, std::vector generate_seqstarts(mode_enum mode, unsigned count, - int32_t seqlens_sum, + int32_t seqlen_avg, + int32_t seqlen_max = -1, std::optional seed = std::nullopt) { - return to_seqstarts(generate_seqlens(mode, count, seqlens_sum, seed)); + return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_max, seed)); +} + +/* + * decode the seqlen string from cmdline + * example (assume batch=3) + * q_val=1,2,3 k_val=4,5,6 -> OK + * q_val=1,2,3 -> OK, k same as q + * q_val=1,2 -> OK, q will rand remaining 1 element, k same as q + * q_val=1,2 k_val=4,5 -> OK, q/k will rand remaining 1 element + * q_val=1,2,3,4 -> OK, but ignore exceed one + * + * q_val=1,2 k_val=4,5,6 -> not OK, k must have same splits with q + * q_val=1,2 k_val=4 -> not OK, k must have same splits with q + */ +std::tuple, + std::vector, + std::vector> +decode_seqlen(mode_enum mode, + ck_tile::index_t batch, + std::string q_val, + std::string k_val, + std::string k_pad_val, + std::optional seed = std::nullopt) +{ +#define _S2I_(str_) static_cast(std::atoi((str_).c_str())) + if(mode == mode_enum::batch) + { + ck_tile::index_t q = _S2I_(q_val); + ck_tile::index_t k = _S2I_(k_val); + auto s_q = std::vector(batch, q); + auto s_k = std::vector(batch, k < 0 ? q : k); + auto s_kpad = std::vector(batch, -1); // TODO: batch not support k_padding + return std::make_tuple(s_q, s_k, s_kpad); + } + else + { + ck_tile::index_t idx = 0; + std::string::size_type pos_q = 0; + std::string::size_type pos_k = 0; + std::string::size_type pos_kp = 0; + std::vector s_q; + std::vector s_k; + std::vector s_kpad; + while(true) + { + auto found_q = q_val.find(',', pos_q); + auto found_k = k_val.find(',', pos_k); + auto found_kp = k_pad_val.find(',', pos_kp); + + ck_tile::index_t q = _S2I_( + q_val.substr(pos_q, found_q == std::string::npos ? found_q : found_q - pos_q)); + ck_tile::index_t k = _S2I_( + k_val.substr(pos_k, found_k == std::string::npos ? found_k : found_k - pos_k)); + ck_tile::index_t kp = _S2I_(k_pad_val.substr( + pos_kp, found_kp == std::string::npos ? found_kp : found_kp - pos_kp)); + + s_q.push_back(q); + s_k.push_back(k < 0 ? q : k); + s_kpad.push_back(kp); + idx++; + if(found_q == std::string::npos || idx >= batch) + { + break; + } + pos_q = found_q + 1; + pos_k = found_k == std::string::npos ? pos_k : found_k + 1; + pos_kp = found_kp == std::string::npos ? pos_kp : found_kp + 1; + } + if(idx < batch) + { + auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), s_kpad.back(), seed); + auto rem_k = generate_seqlens(mode, batch - idx, s_k.back(), s_kpad.back(), seed); + + s_q.insert(s_q.end(), rem_q.begin(), rem_q.end()); + s_k.insert(s_k.end(), rem_k.begin(), rem_k.end()); + s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back()); + } + return std::make_tuple(s_q, s_k, s_kpad); + } +#undef _S2I_ } int env_get_int(const char* var_name, int default_int) @@ -87,6 +177,6 @@ int env_get_int(const char* var_name, int default_int) char* v = getenv(var_name); int r = default_int; if(v) - r = atoi(v); + r = std::atoi(v); return r; } diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 53f42a7421..ac2f0cab9c 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -29,6 +29,25 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz return __builtin_bit_cast(int32x4_t, res); } +namespace impl { +// below type indicate the data type used for buffer load inline asm +// clang-format off +template struct buffer_load_trait; + +template struct buffer_load_trait<16, T> { using payload_t = fp32x4_t; }; +template struct buffer_load_trait<8 , T> { using payload_t = fp32x2_t; }; +template struct buffer_load_trait<4 , T> { using payload_t = float; }; +template struct buffer_load_trait<2 , T> { using payload_t = float; }; +template struct buffer_load_trait<1 , T> { using payload_t = float; }; + +#if CK_TILE_BUFFER_LOAD_RAW_BF16_WA +template<> struct buffer_load_trait<16, thread_buffer> { using payload_t = bf16x8_t; }; +template<> struct buffer_load_trait<8 , thread_buffer> { using payload_t = bf16x4_t; }; +template<> struct buffer_load_trait<4 , thread_buffer> { using payload_t = bf16x2_t; }; +#endif +// clang-format on +} // namespace impl + // TODO: glc/slc/... template struct buffer_load; @@ -48,7 +67,7 @@ struct buffer_load<16> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 16); - using mbuf_t = fp32x4_t; + using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4" : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) @@ -68,7 +87,7 @@ struct buffer_load<8> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 8); - using mbuf_t = fp32x2_t; + using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4" : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) @@ -88,7 +107,7 @@ struct buffer_load<4> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 4); - using mbuf_t = float; + using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4" : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) @@ -108,7 +127,7 @@ struct buffer_load<2> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually - using mbuf_t = float; + using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4" : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) @@ -128,7 +147,7 @@ struct buffer_load<1> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 4); - using mbuf_t = float; + using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4" : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) @@ -152,7 +171,7 @@ struct buffer_load_if<16> { static_assert(sizeof(T) == 16); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = fp32x4_t; + using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; static_assert(sizeof(mbuf_t) == sizeof(T)); asm volatile( "v_cmpx_le_u32 exec, 1, %5\n" @@ -177,7 +196,7 @@ struct buffer_load_if<8> { static_assert(sizeof(T) == 8); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = fp32x2_t; + using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; asm volatile( "v_cmpx_le_u32 exec, 1, %5\n" "buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n" @@ -201,7 +220,7 @@ struct buffer_load_if<4> { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; + using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; asm volatile( "v_cmpx_le_u32 exec, 1, %5\n" "buffer_load_dword %0, %1, %2, %3 offen offset:%4\n" @@ -225,7 +244,7 @@ struct buffer_load_if<2> { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; + using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; asm volatile( "v_cmpx_le_u32 exec, 1, %5\n" "buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n" @@ -249,7 +268,7 @@ struct buffer_load_if<1> { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; + using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; asm volatile( "v_cmpx_le_u32 exec, 1, %5\n" "buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n" diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 601aad19bd..10045d8f7d 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -171,3 +171,7 @@ #ifndef CK_TILE_FMHA_FWD_FAST_EXP2 #define CK_TILE_FMHA_FWD_FAST_EXP2 0 #endif + +#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA +#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1 +#endif diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 0c4a778226..98a3bb7d7f 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -20,3 +20,4 @@ #include "ck_tile/host/reference/reference_reduce.hpp" #include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/stream_config.hpp" +#include "ck_tile/host/timer.hpp" diff --git a/include/ck_tile/host/device_memory.hpp b/include/ck_tile/host/device_memory.hpp index 91463a06a9..7c8549f74f 100644 --- a/include/ck_tile/host/device_memory.hpp +++ b/include/ck_tile/host/device_memory.hpp @@ -27,7 +27,14 @@ struct DeviceMem DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {} DeviceMem(std::size_t mem_size) : mMemSize(mem_size) { - HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + if(mMemSize != 0) + { + HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + } + else + { + mpDeviceBuf = nullptr; + } } void Realloc(std::size_t mem_size) { @@ -36,7 +43,14 @@ struct DeviceMem HIP_CHECK_ERROR(hipFree(mpDeviceBuf)); } mMemSize = mem_size; - HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + if(mMemSize != 0) + { + HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + } + else + { + mpDeviceBuf = nullptr; + } } void* GetDeviceBuffer() const { return mpDeviceBuf; } std::size_t GetBufferSize() const { return mMemSize; } @@ -47,15 +61,18 @@ struct DeviceMem HIP_CHECK_ERROR( hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); } - else - { - throw std::runtime_error("ToDevice with an empty pointer"); - } + // else + // { + // throw std::runtime_error("ToDevice with an empty pointer"); + // } } void ToDevice(const void* p, const std::size_t cpySize) const { - HIP_CHECK_ERROR( - hipMemcpy(mpDeviceBuf, const_cast(p), cpySize, hipMemcpyHostToDevice)); + if(mpDeviceBuf) + { + HIP_CHECK_ERROR( + hipMemcpy(mpDeviceBuf, const_cast(p), cpySize, hipMemcpyHostToDevice)); + } } void FromDevice(void* p) const { @@ -63,14 +80,17 @@ struct DeviceMem { HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); } - else - { - throw std::runtime_error("FromDevice with an empty pointer"); - } + // else + // { + // throw std::runtime_error("FromDevice with an empty pointer"); + // } } void FromDevice(void* p, const std::size_t cpySize) const { - HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost)); + if(mpDeviceBuf) + { + HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost)); + } } void SetZero() const { @@ -82,13 +102,16 @@ struct DeviceMem template void SetValue(T x) const { - if(mMemSize % sizeof(T) != 0) + if(mpDeviceBuf) { - throw std::runtime_error("wrong! not entire DeviceMem will be set"); - } + if(mMemSize % sizeof(T) != 0) + { + throw std::runtime_error("wrong! not entire DeviceMem will be set"); + } - // TODO: call a gpu kernel to set the value (?) - set_buffer_value<<<1, 1024>>>(static_cast(mpDeviceBuf), x, mMemSize / sizeof(T)); + // TODO: call a gpu kernel to set the value (?) + set_buffer_value<<<1, 1024>>>(static_cast(mpDeviceBuf), x, mMemSize / sizeof(T)); + } } ~DeviceMem() { diff --git a/include/ck_tile/host/kernel_launch.hpp b/include/ck_tile/host/kernel_launch.hpp index 7053888abd..e9c5a0c254 100644 --- a/include/ck_tile/host/kernel_launch.hpp +++ b/include/ck_tile/host/kernel_launch.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/hip_check_error.hpp" +#include "ck_tile/host/timer.hpp" #include #include @@ -14,153 +15,92 @@ template -CK_TILE_HOST float launch_and_time_kernel(const stream_config& s, - F kernel, - dim3 grid_dim, - dim3 block_dim, - std::size_t lds_byte, - Args... args) -{ -#if CK_TILE_TIME_KERNEL - if(s.time_kernel_) - { - // warm up - for(int i = 0; i < s.cold_niters_; ++i) - { - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - } - - const int nrepeat = s.nrepeat_; - hipEvent_t start, stop; - - HIP_CHECK_ERROR(hipEventCreate(&start)); - HIP_CHECK_ERROR(hipEventCreate(&stop)); - - HIP_CHECK_ERROR(hipDeviceSynchronize()); - HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_)); - - for(int i = 0; i < nrepeat; ++i) - { - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - } - - HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_)); - HIP_CHECK_ERROR(hipEventSynchronize(stop)); - - float total_time = 0; - - HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop)); - - return total_time / nrepeat; - } - else - { - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - return 0; - } -#else - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - return 0; -#endif -} - -template -CK_TILE_HOST float launch_and_time_kernel_with_preprocess(const stream_config& s, - PreProcessFunc preprocess, - F kernel, - dim3 grid_dim, - dim3 block_dim, - std::size_t lds_byte, - Args... args) -{ -#if CK_TILE_TIME_KERNEL - if(s.time_kernel_) - { -#if CK_TILE_DEBUG_LOG - printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", - __func__, - grid_dim.x, - grid_dim.y, - grid_dim.z, - block_dim.x, - block_dim.y, - block_dim.z); - - printf("Warm up 1 time\n"); -#endif - // warm up - preprocess(); - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - - const int nrepeat = 10; -#if CK_TILE_DEBUG_LOG - printf("Start running %d times...\n", nrepeat); -#endif - hipEvent_t start, stop; - - HIP_CHECK_ERROR(hipEventCreate(&start)); - HIP_CHECK_ERROR(hipEventCreate(&stop)); - - HIP_CHECK_ERROR(hipDeviceSynchronize()); - HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_)); - - for(int i = 0; i < nrepeat; ++i) - { - preprocess(); - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - } - - HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_)); - HIP_CHECK_ERROR(hipEventSynchronize(stop)); - - float total_time = 0; - - HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop)); - - return total_time / nrepeat; - } - else - { - preprocess(); - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - - return 0; - } -#else - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - - return 0; -#endif + Kernel{}(args...); } +// +// return a anonymous functor(lambda) to be called later +// the KernelImpl should be a class without non-static data member, or let's say +// can be instantiate with "KernelImpl{}" +// +// the "static __device__ operator()(some_arg)" is the entry point of KernelImpl +// template -CK_TILE_HOST float launch_kernel(const stream_config& s, - KernelImpl kernel_impl, - dim3 grid_dim, - dim3 block_dim, - std::size_t dynamic_smem_byte, - Args... args) +CK_TILE_HOST auto +make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) { const auto kernel = kentry; - return launch_and_time_kernel( - s, kernel, grid_dim, block_dim, dynamic_smem_byte, kernel_impl, args...); + return [=](const stream_config& s) { + kernel<<>>(args...); + }; } + +// clang-format off +/* + * launch_kernel() + * + * this is the function to launch arbitrary number of kernels with optional timer(selected by stream_config) + * the callables should have signature as "operator()(const stream_config& s){ ... }" to call + * + * the simplest way is pass in a lambda function, with "[=](const stream_config& s){ call_your_kernel_here() }" + * as signature, for the callable (pay attention to the capture list) + * + * e.g. + * ck_tile::launch_kernel(s, + * [=](const stream_config& s){ hipMemset(ptr, 0, size) }, + * [=](const stream_config& s){ some_kernel<<>>(arg); } + * ); + * + * if you use ck_tile kernel, or similiar to this style (structure with "static __device__ operator()(...){}") + * you can pass your kernel to ck_tile::make_kernel(), which will create a anonymous functor for you, + * then pass it to ck_tile::launch_kernel() + * + * e.g. + * ck_tile::launch_kernel(s, + * ck_tile::make_kernel(kernel_0{}, grids0, blocks0, 0, kargs0), + * ck_tile::make_kernel(kernel_1{}, grids1, blocks1, 0, kargs1), + * ...); + **/ +// clang-format on +template +CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables) +{ + // clang-format off + if(!s.time_kernel_) { + (callables(s),...); hip_check_error(hipGetLastError()); + return 0; + } + if(s.is_gpu_timer_) { + gpu_timer timer {}; + + // warmup + for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); + + timer.start(s.stream_id_); + for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); + timer.stop(s.stream_id_); + + return timer.duration() / s.nrepeat_; + } + else { + cpu_timer timer {}; + + // warmup + for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); + + timer.start(s.stream_id_); + for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); + timer.stop(s.stream_id_); + + return timer.duration() / s.nrepeat_; + } + // clang-format on +} + } // namespace ck_tile diff --git a/include/ck_tile/host/stream_config.hpp b/include/ck_tile/host/stream_config.hpp index d29c6f0fa1..47cf0fd5e4 100644 --- a/include/ck_tile/host/stream_config.hpp +++ b/include/ck_tile/host/stream_config.hpp @@ -6,6 +6,22 @@ #include namespace ck_tile { +/* + * construct this structure with behavior as: + * + * // create stream config with default stream(NULL), and not timing the kernel + * stream_config s = stream_config{}; + * + * // create stream config with _some_stream_id_, and not timing the kernel + * stream_config s = stream_config{_some_stream_id_}; + * + * // create stream config with _some_stream_id_, and benchmark with warmup/repeat as default + * stream_config s = stream_config{_some_stream_id_, true}; + * + * // create stream config with _some_stream_id_, and benchmark using cpu timer + * stream_config s = stream_config{_some_stream_id_, true, 0, 3, 10, false}; + **/ + struct stream_config { hipStream_t stream_id_ = nullptr; @@ -13,5 +29,6 @@ struct stream_config int log_level_ = 0; int cold_niters_ = 3; int nrepeat_ = 10; + bool is_gpu_timer_ = true; // keep compatible }; } // namespace ck_tile diff --git a/include/ck_tile/host/timer.hpp b/include/ck_tile/host/timer.hpp new file mode 100644 index 0000000000..e2baeaef7c --- /dev/null +++ b/include/ck_tile/host/timer.hpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/host/hip_check_error.hpp" +#include +#include +#include + +namespace ck_tile { + +struct gpu_timer +{ + CK_TILE_HOST gpu_timer() + { + HIP_CHECK_ERROR(hipEventCreate(&start_evt)); + HIP_CHECK_ERROR(hipEventCreate(&stop_evt)); + } + + CK_TILE_HOST ~gpu_timer() noexcept(false) + { + HIP_CHECK_ERROR(hipEventDestroy(start_evt)); + HIP_CHECK_ERROR(hipEventDestroy(stop_evt)); + } + + CK_TILE_HOST void start(const hipStream_t& s) + { + HIP_CHECK_ERROR(hipDeviceSynchronize()); + HIP_CHECK_ERROR(hipEventRecord(start_evt, s)); + } + + CK_TILE_HOST void stop(const hipStream_t& s) + { + HIP_CHECK_ERROR(hipEventRecord(stop_evt, s)); + HIP_CHECK_ERROR(hipEventSynchronize(stop_evt)); + } + // return in ms + CK_TILE_HOST float duration() const + { + float ms = 0; + HIP_CHECK_ERROR(hipEventElapsedTime(&ms, start_evt, stop_evt)); + return ms; + } + + private: + hipEvent_t start_evt, stop_evt; +}; + +struct cpu_timer +{ + // torch.utils.benchmark.Timer(), there is a sync inside each timer callback + CK_TILE_HOST void start(const hipStream_t&) + { + HIP_CHECK_ERROR(hipDeviceSynchronize()); + start_tick = std::chrono::high_resolution_clock::now(); + } + // torch.utils.benchmark.Timer(), there is a sync inside each timer callback + CK_TILE_HOST void stop(const hipStream_t&) + { + HIP_CHECK_ERROR(hipDeviceSynchronize()); + stop_tick = std::chrono::high_resolution_clock::now(); + } + // return in ms + CK_TILE_HOST float duration() const + { + double sec = + std::chrono::duration_cast>(stop_tick - start_tick) + .count(); + return static_cast(sec * 1e3); + } + + private: + std::chrono::time_point start_tick; + std::chrono::time_point stop_tick; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/block/block_position_encoding.hpp b/include/ck_tile/ops/fmha/block/block_position_encoding.hpp index 9c6c353908..c2fdaf3a1a 100644 --- a/include/ck_tile/ops/fmha/block/block_position_encoding.hpp +++ b/include/ck_tile/ops/fmha/block/block_position_encoding.hpp @@ -23,13 +23,13 @@ VERTICAL: [0] 1 2 3 4 5 [0] 1 2 3 4 5 -TOP_LEFT: +TOP_LEFT(but negative): [0] 1 2 3 4 5 1 [0] 1 2 3 4 2 1 [0] 1 2 3 3 2 1 [0] 1 2 -FROM_BOTTOM_RIGHT: +FROM_BOTTOM_RIGHT(but negative): 2 1 [0] 1 2 3 3 2 1 [0] 1 2 4 3 2 1 [0] 1 @@ -54,7 +54,7 @@ struct Alibi index_t x_total_, AlibiMode mode_ = AlibiMode::VERTICAL) { - slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope; + slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope_; shift_left_up = [&]() { if(RowMajor) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 10ce7395ad..9992d56ea9 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -76,7 +76,7 @@ struct FmhaFwdKernel return n.empty() ? n : std::string("p") + n; }(); return _SS_("fmha_fwd_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s::name) + - "_" + (kIsGroupMode ? "group" : "batch") + "_" + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_" "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" + "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + @@ -702,7 +702,7 @@ struct FmhaFwdKernel else { return Alibi{ - slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::VERTICAL}; + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; } } else diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp index 52f458c72e..e40b006685 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp @@ -18,10 +18,12 @@ struct FmhaFwdTilePartitioner static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1; static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1; - __host__ static constexpr auto GridSize(ck_tile::index_t batch_size_, - ck_tile::index_t nhead_, - ck_tile::index_t seqlen_q_, - ck_tile::index_t hdim_v_) + static constexpr const char* name = "shb"; + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) { // TODO: this may need tuning return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) * @@ -51,4 +53,53 @@ struct FmhaFwdTilePartitioner } }; +template +using FmhaFwdTilePartitioner_SHB = FmhaFwdTilePartitioner; + +template +struct FmhaFwdTilePartitioner_HBS +{ + using BlockFmhaShape = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0; + static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0; + static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0; + static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1; + static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1; + + static constexpr const char* name = "hbs"; + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) + { + // TODO: this may need tuning + return dim3(nhead_, + batch_size_, + ck_tile::integer_divide_ceil(seqlen_q_, kM0) * + ck_tile::integer_divide_ceil(hdim_v_, kN1)); + } + + CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v) + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); + + const index_t i_block = blockIdx.z; + const index_t i_nhead = blockIdx.x; + const index_t i_batch = blockIdx.y; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } +}; + } // namespace ck_tile diff --git a/test/position_embedding/position_embedding.cpp b/test/position_embedding/position_embedding.cpp index e295ec454a..4e13225dd3 100644 --- a/test/position_embedding/position_embedding.cpp +++ b/test/position_embedding/position_embedding.cpp @@ -131,74 +131,74 @@ int main() 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5}); - rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, 4, 5, - 1, 0, 1, 2, 3, 4, - 2, 1, 0, 1, 2, 3, - 3, 2, 1, 0, 1, 2}); + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, -4, -5, + -1, 0, -1, -2, -3, -4, + -2, -1, 0, -1, -2, -3, + -3, -2, -1, 0, -1, -2}); - rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, - 1, 0, 1, 2, - 2, 1, 0, 1, - 3, 2, 1, 0, - 4, 3, 2, 1, - 5, 4, 3, 2}); + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, + -1, 0, -1, -2, + -2, -1, 0, -1, + -3, -2, -1, 0, + -4, -3, -2, -1, + -5, -4, -3, -2}); - rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, - 1, 0, 1, - 2, 1, 0}); + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, + -1, 0, -1, + -2, -1, 0}); - rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 1, 0, 1, 2, 3, - 3, 2, 1, 0, 1, 2, - 4, 3, 2, 1, 0, 1, - 5, 4, 3, 2, 1, 0}); + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -1, 0, -1, -2, -3, + -3, -2, -1, 0, -1, -2, + -4, -3, -2, -1, 0, -1, + -5, -4, -3, -2, -1, 0}); - rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 3, 4, 5, - 1, 2, 3, 4, - 0, 1, 2, 3, - 1, 0, 1, 2, - 2, 1, 0, 1, - 3, 2, 1, 0}); + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -3, -4, -5, + -1, -2, -3, -4, + 0, -1, -2, -3, + -1, 0, -1, -2, + -2, -1, 0, -1, + -3, -2, -1, 0}); - rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {0, 1, 2, - 1, 0, 1, - 2, 1, 0}); + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, { 0, -1, -2, + -1, 0, -1, + -2, -1, 0}); rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::VERTICAL, {0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5}); - rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, 4, 5, - 1, 0, 1, 2, 3, 4, - 2, 1, 0, 1, 2, 3, - 3, 2, 1, 0, 1, 2}); + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, -4, -5, + -1, 0, -1, -2, -3, -4, + -2, -1, 0, -1, -2, -3, + -3, -2, -1, 0, -1, -2}); - rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, - 1, 0, 1, 2, - 2, 1, 0, 1, - 3, 2, 1, 0, - 4, 3, 2, 1, - 5, 4, 3, 2}); + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, + -1, 0, -1, -2, + -2, -1, 0, -1, + -3, -2, -1, 0, + -4, -3, -2, -1, + -5, -4, -3, -2}); - rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, - 1, 0, 1, - 2, 1, 0}); + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, + -1, 0, -1, + -2, -1, 0}); - rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 1, 0, 1, 2, 3, - 3, 2, 1, 0, 1, 2, - 4, 3, 2, 1, 0, 1, - 5, 4, 3, 2, 1, 0}); + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -1, 0, -1, -2, -3, + -3, -2, -1, 0, -1, -2, + -4, -3, -2, -1, 0, -1, + -5, -4, -3, -2, -1, 0}); - rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 3, 4, 5, - 1, 2, 3, 4, - 0, 1, 2, 3, - 1, 0, 1, 2, - 2, 1, 0, 1, - 3, 2, 1, 0}); + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -3, -4, -5, + -1, -2, -3, -4, + 0, -1, -2, -3, + -1, 0, -1, -2, + -2, -1, 0, -1, + -3, -2, -1, 0}); - rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {0, 1, 2, - 1, 0, 1, - 2, 1, 0}); + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, { 0, -1, -2, + -1, 0, -1, + -2, -1, 0}); rtn &= test_alibi_slope_generation(8, {0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625}); rtn &= test_alibi_slope_generation(16, {0.7071067811865476, 0.5, 0.35355339059327384, 0.25000000000000006, 0.17677669529663692,