diff --git a/example/ck_tile/15_fused_moe/main.cpp b/example/ck_tile/15_fused_moe/main.cpp index 39620d0dc3..fcd1404dda 100644 --- a/example/ck_tile/15_fused_moe/main.cpp +++ b/example/ck_tile/15_fused_moe/main.cpp @@ -208,7 +208,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor num_sorted_tiles_host({1}); #if 0 -# if 1 +#if 1 ck_tile::FillStepRange{-.5f, .5f, 0.01f}(a_host); ck_tile::FillStepRange{-.5f, .5f, 0.01f}(g_host); ck_tile::FillStepRange{.5f, -.5f, -0.01f}(d_host); @@ -217,7 +217,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::FillStepRange{0.f, 1.f, 0.01f}(sd_host); ck_tile::FillStepRange{0.f, 1.f, 0.01f}(sy_host); ck_tile::FillStepRange{-.5f, .5f, 0.01f}(topk_weight_host); -# else +#else ck_tile::FillUniformDistribution{-.5f, .5f}(a_host); ck_tile::FillUniformDistribution{-.5f, .5f}(g_host); ck_tile::FillUniformDistribution{-.5f, .5f}(d_host); @@ -226,7 +226,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::FillUniformDistribution{-.5f, .5f}(sd_host); ck_tile::FillUniformDistribution{-.5f, .5f}(sy_host); ck_tile::FillUniformDistribution{-.5f, .5f}(topk_weight_host); -# endif +#endif // permute weight ck_tile::HostTensor g_perm_host = shuffle_moe_weight(g_host, prec_w, 1); @@ -266,7 +266,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor d_perm_host = shuffle_moe_weight(d_host, prec_w, 1); std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl; -# if 0 +#if 0 ck_tile::reference_moe_sorting( topk_ids_host, topk_weight_host, @@ -319,7 +319,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } return 1; -# endif +#endif #endif (void)balance; diff --git a/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api.cpp b/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api.cpp index c3e74a1945..20df0140f8 100644 --- a/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api.cpp +++ b/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api.cpp @@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile: if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1) { - using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0>; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0>; r = fused_moegemm_(s, a); } // clang-format on diff --git a/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api_traits.hpp b/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api_traits.hpp index b81f14495a..b4f124124b 100644 --- a/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api_traits.hpp +++ b/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api_traits.hpp @@ -34,11 +34,14 @@ struct fmoe_ // traits, ugly name, only used for internal using TopkWeightDataType = ck_tile::remove_cvref_t; using IndexDataType = ck_tile::remove_cvref_t; - static constexpr ck_tile::index_t BT_ = BlockTIle_::at(ck_tile::number<0>{}); // block token(block_m0, block_m1) + static constexpr ck_tile::index_t BT_ = + BlockTIle_::at(ck_tile::number<0>{}); // block token(block_m0, block_m1) static constexpr ck_tile::index_t BI_ = BlockTIle_::at(ck_tile::number<1>{}); // block intermediate (block_n0, block_k1) - static constexpr ck_tile::index_t BH_ = BlockTIle_::at(ck_tile::number<2>{}); // block hidden(block_k0) - static constexpr ck_tile::index_t BD_ = BlockTIle_::at(ck_tile::number<3>{}); // block down(block_n1) + static constexpr ck_tile::index_t BH_ = + BlockTIle_::at(ck_tile::number<2>{}); // block hidden(block_k0) + static constexpr ck_tile::index_t BD_ = + BlockTIle_::at(ck_tile::number<3>{}); // block down(block_n1) using BlockTile_0 = ck_tile::sequence; using WarpPerBlock_0 = ck_tile::remove_cvref_t; diff --git a/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp b/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp index 93f9c77869..c57fe4b194 100644 --- a/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp +++ b/example/ck_tile/16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp @@ -8,7 +8,7 @@ // clang-format off template float fused_moegemm_< - fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0> + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0> >(const ck_tile::stream_config& s, fused_moegemm_args a); // clang-format on diff --git a/example/ck_tile/16_fused_moe_general/main.cpp b/example/ck_tile/16_fused_moe_general/main.cpp index 8a88e4357f..6b066a20af 100644 --- a/example/ck_tile/16_fused_moe_general/main.cpp +++ b/example/ck_tile/16_fused_moe_general/main.cpp @@ -216,7 +216,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::FillUniformDistribution{-.5f, .5f}(sy_host); ck_tile::FillUniformDistribution{0.0f, 1.0f}(topk_weight_host); - // permute weight ck_tile::HostTensor g_perm_host = shuffle_moe_weight(g_host, prec_w, 1); ck_tile::HostTensor d_perm_host = shuffle_moe_weight(d_host, prec_w, 1); diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp index ab33459872..c0274235bb 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp @@ -66,448 +66,27 @@ struct FusedMoeGemmPipeline_FlatmmGl } }(); - static constexpr const char* name = "flatmm_uk"; + static constexpr const char* name = "flatmm_gl"; CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - constexpr index_t smem_0 = Policy::template GetUK_1().GetSmemSize(); - constexpr index_t smem_1 = Policy::template GetUK_1().GetSmemSize(); constexpr index_t smem_bridge = BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType); - return max(smem_0, max(smem_1, smem_bridge)); + return smem_bridge; } - // this is the thread-offset along row/col - CK_TILE_HOST_DEVICE static auto GetACoord() - { - constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A(); - const auto a_coord = a_dist.calculate_index(); - return a_coord; - } - - // this is the thread-offset along row/col - CK_TILE_HOST_DEVICE static auto GetOCoord() - { - constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution(); - const auto o_coord = o_dist.calculate_index(); - return o_coord; - } - - CK_TILE_DEVICE constexpr auto GetNumRowCoords_A() - { - constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA; - constexpr index_t MLans = BlockShape::BlockSize / KLans; - constexpr index_t MRepeat = BlockShape::Block_M0 / MLans; - - return MRepeat; - } - - // TODO: properlly support scatter/gather - CK_TILE_DEVICE auto GetRowCoords_A(index_t base_offset) - { - constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA; - constexpr index_t MLans = BlockShape::BlockSize / KLans; - constexpr index_t MRepeat = BlockShape::Block_M0 / MLans; - - auto base_coord = threadIdx.x / KLans + base_offset; - - array coords; - static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; }); - - return coords; - } - - template - CK_TILE_DEVICE auto GetRowID_A(const ROW_COORDS coords, - const IndexDataType* sorted_token_ids_ptr) - { - constexpr index_t n_size = coords.size(); - - array row_ids; - static_for<0, n_size, 1>{}([&](auto i) { - row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // base_coord + i * MLans; - }); - - return row_ids; - } - - // TODO: properlly support scatter/gather - CK_TILE_DEVICE auto GetRowCoords_O(index_t base_offset) - { - constexpr index_t WarpGemmLane_M = 16; // TODO: use 16x16 - constexpr index_t WarpGemmRepeat_M = BlockShape::Block_M0 / WarpGemmLane_M; - - auto base_coord = threadIdx.x % WarpGemmLane_M + base_offset; - - array coords; - static_for<0, WarpGemmRepeat_M, 1>{}( - [&](auto i) { coords.at(i) = base_coord + i * WarpGemmLane_M; }); - - return coords; - } - - template - CK_TILE_DEVICE auto GetWeightScale(const ROW_COORDS coords, - const TopkWeightDataType* sorted_weight_ptr) - { - constexpr index_t n_size = coords.size(); - - array w; - static_for<0, n_size, 1>{}([&](auto i) { - w.at(i) = sorted_weight_ptr[coords[i]]; // base_coord + i * MLans; - }); - - return w; - } - - CK_TILE_DEVICE auto GetRowCoords_O() - { - constexpr index_t NLans = BlockShape::Block_N1 / kAlignmentA; - constexpr index_t MLans = BlockShape::BlockSize / NLans; - constexpr index_t MRepeat = BlockShape::Block_M1 / MLans; - - auto base_coord = threadIdx.x / NLans; - - array coords; - static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; }); - - return coords; - } - /* - struct FusedMoeGemmKargs - { - const void* a_ptr; // [m, k], input token - const void* a_scale_ptr; // [m, 1], token scale - const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) - const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w]) - const void* g_scale_ptr; // [e, 1, n], gate(up) scale - const void* d_scale_ptr; // [e, 1, k], down scale - const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input - void* o_ptr; // [m, k], output token - - const void* sorted_token_ids_ptr; - const void* sorted_weight_ptr; - const void* sorted_expert_ids_ptr; - const void* num_sorted_tiles_ptr; - - index_t hidden_size; // k - index_t intermediate_size; // n (TP slice this) - index_t num_tokens; // input number of tokens for current iteration - index_t num_experts; // number of groups - index_t topk; // need this? - - index_t stride_token; // for input/output, stride for each row, should >= hidden_size - }; - */ + template CK_TILE_DEVICE auto operator()(const Karg& kargs, CK_TILE_LDS_ADDR void* smem, index_t sorted_tile_id, index_t intermediate_tile_id) { - constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2; - ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size; - // w1 (Down, N size) - ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size / hidden_radio_0; + ignore = kargs; + ignore = smem; + ignore = sorted_tile_id; + ignore = intermediate_tile_id; - index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W - index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0; // divide K in W - index_t nr_1 = kargs.hidden_size / BlockShape::Warp_N1; - index_t kr_1 = shared_intermediate_size_1 / BlockShape::Warp_K1; - - const IndexDataType expert_id = __builtin_amdgcn_readfirstlane( - reinterpret_cast(kargs.sorted_expert_ids_ptr)[sorted_tile_id]); - index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size; - index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size; - - // nr*kr*w - index_t interm_idx_nr = __builtin_amdgcn_readfirstlane( - intermediate_tile_id * - BlockShape::Block_Nr0); // intermediate_tile_id * Block_N / (N in W) - - // printf("bid:%d,%d, sorted_tile_id:%d(, intermediate_tile_id:%d, expert_id:%d, - // interm_idx_nr:%d\n", static_cast(blockIdx.x), - // static_cast(blockIdx.y), sorted_tile_id, intermediate_tile_id, expert_id, - // interm_idx_nr); - - auto row_coords_a = GetRowCoords_A(sorted_tile_id * BlockShape::Block_M0); - auto row_ids_a = GetRowID_A( - row_coords_a, reinterpret_cast(kargs.sorted_token_ids_ptr)); - auto a_coords = generate_tuple( - [&](auto i) { - return row_ids_a[i] * kargs.stride_token + - threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA; - }, - number{}); - auto a_res = - make_wave_buffer_resource(reinterpret_cast(kargs.a_ptr), - kargs.num_tokens * kargs.stride_token * sizeof(ADataType)); - - auto g_win = [&]() { - const GDataType* g_ptr = reinterpret_cast(kargs.g_ptr) + - static_cast(expert_id) * expert_stride_0 + - interm_idx_nr * kr_0 * BlockShape::Block_W0; - auto g_view_ = make_naive_tensor_view( - g_ptr, - make_tuple(nr_0, kr_0, number{}), - make_tuple(kr_0 * BlockShape::Block_W0, number{}, 1), - number{}, - number<1>{}); - - // number{}.fff(); - // number{}.zzz(); - auto g_window_ = make_tile_window_linear_raw( - g_view_, - make_tuple(number{}, - number{}, - number{}), - {0, 0, 0}, - Policy::template MakeGlobalTileDistribution_G(), - sequence<0, 1, 1>{}); - return g_window_; - }(); - // number{}.rrr2(); - auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_; - auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); }, - number{}); - - const auto d_win = [&]() { - const DDataType* d_ptr = reinterpret_cast(kargs.d_ptr) + - static_cast(expert_id) * expert_stride_1 + - interm_idx_nr * BlockShape::Block_W1; - // note interm_idx_nr is along the gemm-k dim of 2nd gemm - - const auto d_view_ = make_naive_tensor_view( - d_ptr, - make_tuple(nr_1, kr_1, BlockShape::Block_W1), - make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1), - number{}, - number<1>{}); - - const auto d_window_ = make_tile_window_linear_raw( - d_view_, - make_tuple(number{}, - number{}, - number{}), - {0, 0, 0}, - Policy::template MakeGlobalTileDistribution_D(), - sequence<0, 1, 1>{}); - return d_window_; - }(); - auto d_res = d_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_; -#if 0 - auto d_coords = generate_tuple([&](auto i) { - return d_win.cached_coords_[i].get_offset(); }, - number{}); -#else - // TODO: load D order is N0.K0...127, N64.K0...127, N0.K128...255, N64.K128...255 - // block-k=512, block-n=128 - // |<----- W_ ----->| - // Nr(2)*Nw(4)* Kr *Kr0(4)*Kr1(4) * [Kl(4)*Nl(16)*Kv(8)]->one issue - // y p y y p p y - // 1 2 0(imm) - auto d_coords = [&]() { - constexpr index_t Nr_ = 2; - constexpr index_t Nw_ = 4; - constexpr index_t Kr0_ = 4; - constexpr index_t Kr1_ = 4; - constexpr index_t Kl_ = 4; - constexpr index_t Nl_ = 16; - constexpr index_t Kv_ = 8; - constexpr index_t W_ = Kl_ * Nl_ * Kv_; - constexpr index_t num_offsets_ = Nr_ * Kr0_; - index_t base_os_ = (threadIdx.x % 64) * Kv_ + (threadIdx.x / 64) * Kr0_ * Kr1_ * W_; - return generate_tuple( - [&](auto i) { - constexpr auto i_nr_ = number{}; - constexpr auto i_kr0_ = number{}; - - return i_nr_ * shared_intermediate_size_1 * Nw_ * Nl_ + i_kr0_ * Kr1_ * W_ + - base_os_; - }, - number{}); - }(); -#endif - auto o_coords = generate_tuple( - [&](auto i) { - return row_ids_a[i] * kargs.stride_token + - threadIdx.x % (BlockShape::Block_N1 / kAlignmentO) * kAlignmentO; - }, - number{}); - - auto o_flags = - generate_tuple([&](auto i) { return cmp_lt_to_exec(row_ids_a[i], kargs.num_tokens); }, - number{}); - - auto bridge_sst_win = [&]() { - constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc(); - constexpr auto dist_ = Policy::template GetUK_0().MakeCBlockDist(); - return make_tile_window_linear( - make_tensor_view( - reinterpret_cast(smem), - desc_), - desc_.get_lengths(), - {0, 0}, - dist_); - }(); - auto o_res = - make_wave_buffer_resource(reinterpret_cast(kargs.o_ptr), - kargs.num_tokens * kargs.stride_token * sizeof(ODataType)); - - auto row_coords_o = GetRowCoords_O(sorted_tile_id * BlockShape::Block_M0); - auto w_scale = GetWeightScale( - row_coords_o, reinterpret_cast(kargs.sorted_weight_ptr)); -#if 0 - printf("bid:%d,%d, tid:%d, sorted_tile_id:%d(, intermediate_tile_id:%d, e:%d, " - "interm_idx_nr:%d, coords:a:%d,%d,%d, row_ids_a:%d,%d,%d, (%d)g_coords:%d.%d.%d, " - "o_coords:%d,%d,%d,%d,%d,%d,%d,%d(%d,%d,%d,%d,%d,%d,%d,%d)\n", - static_cast(blockIdx.x), - static_cast(blockIdx.y), - static_cast(threadIdx.x), - sorted_tile_id, - intermediate_tile_id, - expert_id, - interm_idx_nr, - row_coords_a[0], - row_coords_a[1], - row_coords_a[7], - row_ids_a[0], - row_ids_a[1], - row_ids_a[7], - kr_0 * BlockShape::Block_W0, - g_coords[number<0>{}], - g_coords[number<1>{}], - g_coords[number<7>{}], - o_coords[number<0>{}], - o_coords[number<1>{}], - o_coords[number<2>{}], - o_coords[number<3>{}], - o_coords[number<4>{}], - o_coords[number<5>{}], - o_coords[number<6>{}], - o_coords[number<7>{}], - // (row_ids_a[0] >= kargs.num_tokens ? 1 : 0), - // (row_ids_a[1] >= kargs.num_tokens ? 1 : 0), - // (row_ids_a[2] >= kargs.num_tokens ? 1 : 0), - // (row_ids_a[3] >= kargs.num_tokens ? 1 : 0), - // (row_ids_a[4] >= kargs.num_tokens ? 1 : 0), - // (row_ids_a[5] >= kargs.num_tokens ? 1 : 0), - // (row_ids_a[6] >= kargs.num_tokens ? 1 : 0), - // (row_ids_a[7] >= kargs.num_tokens ? 1 : 0) - - (row_ids_a[0] < kargs.num_tokens && static_cast(o_coords[number<0>{}]) >= - (kargs.num_tokens * kargs.stride_token) - ? 7777 - : 0), - (row_ids_a[1] < kargs.num_tokens && static_cast(o_coords[number<1>{}]) >= - (kargs.num_tokens * kargs.stride_token) - ? 7777 - : 0), - (row_ids_a[2] < kargs.num_tokens && static_cast(o_coords[number<2>{}]) >= - (kargs.num_tokens * kargs.stride_token) - ? 7777 - : 0), - (row_ids_a[3] < kargs.num_tokens && static_cast(o_coords[number<3>{}]) >= - (kargs.num_tokens * kargs.stride_token) - ? 7777 - : 0), - (row_ids_a[4] < kargs.num_tokens && static_cast(o_coords[number<4>{}]) >= - (kargs.num_tokens * kargs.stride_token) - ? 7777 - : 0), - (row_ids_a[5] < kargs.num_tokens && static_cast(o_coords[number<5>{}]) >= - (kargs.num_tokens * kargs.stride_token) - ? 7777 - : 0), - (row_ids_a[6] < kargs.num_tokens && static_cast(o_coords[number<6>{}]) >= - (kargs.num_tokens * kargs.stride_token) - ? 7777 - : 0), - (row_ids_a[7] < kargs.num_tokens && static_cast(o_coords[number<7>{}]) >= - (kargs.num_tokens * kargs.stride_token) - ? 7777 - : 0) - - ); -#endif - auto uk_0 = Policy::template GetUK_0(); - auto acc_0 = uk_0(a_res, - a_coords, - g_res, - g_coords, - smem, - kargs.hidden_size, - BlockShape::Block_K0, // tile offset for B matrix each unroll - BlockShape::Block_Kr0 * - BlockShape::Block_W0); // tile offset for B matrix each unroll - - // return ; - //sweep_tile(acc_0, - // [&](auto idx) { typename Problem::GateActivation{}(acc_0(idx), acc_0[idx]); }); - sweep_tile(acc_0, - [&](auto idx0, auto idx1) { - fp32x2_t v_ {acc_0(idx0), acc_0(idx1)}; - typename Problem::GateActivation{}(v_, v_); - acc_0(idx0) = v_.x; - acc_0(idx1) = v_.y; - }, - sequence<1, 2>{}); - -#if 0 - printf("bid:%d,%d, tid:%d, sorted_tile_id:%d(, intermediate_tile_id:%d, e:%d, " - "interm_idx_nr:%d, coords:a:%d,%d,%d, row_ids_a:%d,%d,%d, (%d)g_coords:%d.%d.%d, bridge_sst_win:%d" - "acc:%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f\n", - static_cast(blockIdx.x), - static_cast(blockIdx.y), - static_cast(threadIdx.x), - sorted_tile_id, - intermediate_tile_id, - expert_id, - interm_idx_nr, - row_coords_a[0], - row_coords_a[1], - row_coords_a[7], - row_ids_a[0], - row_ids_a[1], - row_ids_a[7], - kr_0 * BlockShape::Block_W0, - g_coords[number<0>{}], - g_coords[number<1>{}], - g_coords[number<7>{}], - bridge_sst_win.cached_coords_[number<0>{}].get_offset(), - acc_0.get_thread_buffer()[number<0>{}], - acc_0.get_thread_buffer()[number<1>{}], - acc_0.get_thread_buffer()[number<2>{}], - acc_0.get_thread_buffer()[number<3>{}], - acc_0.get_thread_buffer()[number<4>{}], - acc_0.get_thread_buffer()[number<5>{}], - acc_0.get_thread_buffer()[number<6>{}], - acc_0.get_thread_buffer()[number<7>{}], - acc_0.get_thread_buffer()[number<8 + 0>{}], - acc_0.get_thread_buffer()[number<8 + 1>{}], - acc_0.get_thread_buffer()[number<8 + 2>{}], - acc_0.get_thread_buffer()[number<8 + 3>{}], - acc_0.get_thread_buffer()[number<8 + 4>{}], - acc_0.get_thread_buffer()[number<8 + 5>{}], - acc_0.get_thread_buffer()[number<8 + 6>{}], - acc_0.get_thread_buffer()[number<8 + 7>{}]); -#endif - - auto y_pre = cast_tile(acc_0); - store_tile(bridge_sst_win, y_pre); - block_sync_lds(); - - auto uk_1 = Policy::template GetUK_1(); - uk_1(d_res, - d_coords, - o_res, - o_coords, - o_flags, - smem, - kargs.hidden_size, // total n number - w_scale, - BlockShape::Block_Nr1 * kr_1 * BlockShape::Block_W1, // along N - BlockShape::Block_N1); // along N } };