mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
format
This commit is contained in:
@@ -208,7 +208,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::HostTensor<IndexDataType> num_sorted_tiles_host({1});
|
||||
|
||||
#if 0
|
||||
# if 1
|
||||
#if 1
|
||||
ck_tile::FillStepRange<ADataType>{-.5f, .5f, 0.01f}(a_host);
|
||||
ck_tile::FillStepRange<GDataType>{-.5f, .5f, 0.01f}(g_host);
|
||||
ck_tile::FillStepRange<DDataType, false>{.5f, -.5f, -0.01f}(d_host);
|
||||
@@ -217,7 +217,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::FillStepRange<DScaleDataType>{0.f, 1.f, 0.01f}(sd_host);
|
||||
ck_tile::FillStepRange<YSmoothScaleDataType>{0.f, 1.f, 0.01f}(sy_host);
|
||||
ck_tile::FillStepRange<TopkWeightDataType>{-.5f, .5f, 0.01f}(topk_weight_host);
|
||||
# else
|
||||
#else
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
|
||||
ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f}(g_host);
|
||||
ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f}(d_host);
|
||||
@@ -226,7 +226,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f}(sd_host);
|
||||
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f}(sy_host);
|
||||
ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f}(topk_weight_host);
|
||||
# endif
|
||||
#endif
|
||||
|
||||
// permute weight
|
||||
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
|
||||
@@ -266,7 +266,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
|
||||
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
|
||||
|
||||
# if 0
|
||||
#if 0
|
||||
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
|
||||
topk_ids_host,
|
||||
topk_weight_host,
|
||||
@@ -319,7 +319,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
return 1;
|
||||
# endif
|
||||
#endif
|
||||
|
||||
#endif
|
||||
(void)balance;
|
||||
|
||||
@@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
|
||||
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
|
||||
{
|
||||
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
|
||||
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 128, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
// clang-format on
|
||||
|
||||
@@ -34,11 +34,14 @@ struct fmoe_ // traits, ugly name, only used for internal
|
||||
using TopkWeightDataType = ck_tile::remove_cvref_t<typename TypeConfig::TopkWeightDataType>;
|
||||
using IndexDataType = ck_tile::remove_cvref_t<typename TypeConfig::IndexDataType>;
|
||||
|
||||
static constexpr ck_tile::index_t BT_ = BlockTIle_::at(ck_tile::number<0>{}); // block token(block_m0, block_m1)
|
||||
static constexpr ck_tile::index_t BT_ =
|
||||
BlockTIle_::at(ck_tile::number<0>{}); // block token(block_m0, block_m1)
|
||||
static constexpr ck_tile::index_t BI_ =
|
||||
BlockTIle_::at(ck_tile::number<1>{}); // block intermediate (block_n0, block_k1)
|
||||
static constexpr ck_tile::index_t BH_ = BlockTIle_::at(ck_tile::number<2>{}); // block hidden(block_k0)
|
||||
static constexpr ck_tile::index_t BD_ = BlockTIle_::at(ck_tile::number<3>{}); // block down(block_n1)
|
||||
static constexpr ck_tile::index_t BH_ =
|
||||
BlockTIle_::at(ck_tile::number<2>{}); // block hidden(block_k0)
|
||||
static constexpr ck_tile::index_t BD_ =
|
||||
BlockTIle_::at(ck_tile::number<3>{}); // block down(block_n1)
|
||||
|
||||
using BlockTile_0 = ck_tile::sequence<BT_, BI_, BH_>;
|
||||
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
// clang-format off
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
|
||||
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 128, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
// clang-format on
|
||||
|
||||
@@ -216,7 +216,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f}(sy_host);
|
||||
ck_tile::FillUniformDistribution<TopkWeightDataType>{0.0f, 1.0f}(topk_weight_host);
|
||||
|
||||
|
||||
// permute weight
|
||||
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
|
||||
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
|
||||
|
||||
@@ -66,448 +66,27 @@ struct FusedMoeGemmPipeline_FlatmmGl
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr const char* name = "flatmm_uk";
|
||||
static constexpr const char* name = "flatmm_gl";
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
constexpr index_t smem_0 = Policy::template GetUK_1<Problem>().GetSmemSize();
|
||||
constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
|
||||
constexpr index_t smem_bridge =
|
||||
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
|
||||
return max(smem_0, max(smem_1, smem_bridge));
|
||||
return smem_bridge;
|
||||
}
|
||||
|
||||
// this is the thread-offset along row/col
|
||||
// Builds the A-matrix global tile distribution from the policy
// (Policy::MakeGlobalTileDistribution_A<Problem>()) and returns the result of
// its calculate_index() call.
// NOTE(review): the neighbouring comment calls this "the thread-offset along
// row/col"; the exact layout of the returned index is defined by the
// distribution type, which is not visible here -- confirm before relying on it.
CK_TILE_HOST_DEVICE static auto GetACoord()
|
||||
{
|
||||
// The distribution is a compile-time object; only the index lookup is a runtime value.
constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
|
||||
const auto a_coord = a_dist.calculate_index();
|
||||
return a_coord;
|
||||
}
|
||||
|
||||
// this is the thread-offset along row/col
|
||||
// Output-side counterpart of GetACoord(): builds the O global tile distribution
// from the policy (Policy::MakeOGlobalTileDistribution<Problem>()) and returns
// the result of its calculate_index() call.
// NOTE(review): the meaning of the returned index depends on the distribution
// type, which is defined elsewhere -- confirm against the policy.
CK_TILE_HOST_DEVICE static auto GetOCoord()
|
||||
{
|
||||
// The distribution is a compile-time object; only the index lookup is a runtime value.
constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
|
||||
const auto o_coord = o_dist.calculate_index();
|
||||
return o_coord;
|
||||
}
|
||||
|
||||
// Returns the compile-time count MRepeat = Block_M0 / (BlockSize / (Block_K0 / kAlignmentA)).
// GetRowCoords_A() derives the same MRepeat and uses it as the length of the
// coordinate array it returns, so this value is the number of row coordinates
// produced per thread there.
CK_TILE_DEVICE constexpr auto GetNumRowCoords_A()
|
||||
{
|
||||
// Block_K0 split into chunks of kAlignmentA elements.
// presumably one chunk per thread along K -- TODO confirm
constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA;
|
||||
// Block threads divided by KLans.
constexpr index_t MLans = BlockShape::BlockSize / KLans;
|
||||
// Block_M0 divided by MLans.
constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;
|
||||
|
||||
return MRepeat;
|
||||
}
|
||||
|
||||
// TODO: properly support scatter/gather
|
||||
// Computes the row coordinates this thread uses for the A tile.
//
// base_offset: value added to every coordinate (the caller in operator()
//              passes sorted_tile_id * BlockShape::Block_M0).
//
// Returns an array of MRepeat coordinates:
//   coords[i] = threadIdx.x / KLans + base_offset + i * MLans
CK_TILE_DEVICE auto GetRowCoords_A(index_t base_offset)
|
||||
{
|
||||
// Same KLans / MLans / MRepeat derivation as GetNumRowCoords_A().
constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA;
|
||||
constexpr index_t MLans = BlockShape::BlockSize / KLans;
|
||||
constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;
|
||||
|
||||
// Every group of KLans consecutive thread ids shares the same base coordinate.
auto base_coord = threadIdx.x / KLans + base_offset;
|
||||
|
||||
array<index_t, MRepeat> coords;
|
||||
// Compile-time loop: successive entries are MLans apart.
static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; });
|
||||
|
||||
return coords;
|
||||
}
|
||||
|
||||
// Gathers one entry of sorted_token_ids_ptr for each coordinate in `coords`.
//
// coords:               array-like of indices into sorted_token_ids_ptr; its
//                       size() must be usable in a constant expression.
// sorted_token_ids_ptr: table that is indexed by each coordinate.
//
// Returns array<index_t, coords.size()> with row_ids[i] = sorted_token_ids_ptr[coords[i]].
// No bounds checking is performed on the coordinates.
template <typename ROW_COORDS>
|
||||
CK_TILE_DEVICE auto GetRowID_A(const ROW_COORDS coords,
|
||||
const IndexDataType* sorted_token_ids_ptr)
|
||||
{
|
||||
constexpr index_t n_size = coords.size();
|
||||
|
||||
array<index_t, n_size> row_ids;
|
||||
// Compile-time loop over every coordinate.
static_for<0, n_size, 1>{}([&](auto i) {
|
||||
row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // base_coord + i * MLans;
|
||||
});
|
||||
|
||||
return row_ids;
|
||||
}
|
||||
|
||||
// TODO: properly support scatter/gather
|
||||
// Computes the row coordinates that operator() passes to GetWeightScale().
//
// base_offset: value added to every coordinate (the caller in operator()
//              passes sorted_tile_id * BlockShape::Block_M0).
//
// Returns an array of WarpGemmRepeat_M coordinates:
//   coords[i] = threadIdx.x % 16 + base_offset + i * 16
//
// This overload is distinct from the parameterless GetRowCoords_O() in the
// same struct, which uses a different formula.
CK_TILE_DEVICE auto GetRowCoords_O(index_t base_offset)
|
||||
{
|
||||
// Hard-coded to 16 rather than derived from the warp-gemm configuration (see TODO).
constexpr index_t WarpGemmLane_M = 16; // TODO: use 16x16
|
||||
constexpr index_t WarpGemmRepeat_M = BlockShape::Block_M0 / WarpGemmLane_M;
|
||||
|
||||
// Thread ids that are equal modulo 16 share the same base coordinate.
auto base_coord = threadIdx.x % WarpGemmLane_M + base_offset;
|
||||
|
||||
array<index_t, WarpGemmRepeat_M> coords;
|
||||
// Compile-time loop: successive entries are WarpGemmLane_M apart.
static_for<0, WarpGemmRepeat_M, 1>{}(
|
||||
[&](auto i) { coords.at(i) = base_coord + i * WarpGemmLane_M; });
|
||||
|
||||
return coords;
|
||||
}
|
||||
|
||||
// Gathers one entry of sorted_weight_ptr for each coordinate in `coords`.
//
// coords:            array-like of indices into sorted_weight_ptr; its size()
//                    must be usable in a constant expression.
// sorted_weight_ptr: table that is indexed by each coordinate.
//
// Returns array<TopkWeightDataType, coords.size()> with
// w[i] = sorted_weight_ptr[coords[i]]. No bounds checking is performed.
template <typename ROW_COORDS>
|
||||
CK_TILE_DEVICE auto GetWeightScale(const ROW_COORDS coords,
|
||||
const TopkWeightDataType* sorted_weight_ptr)
|
||||
{
|
||||
constexpr index_t n_size = coords.size();
|
||||
|
||||
array<TopkWeightDataType, n_size> w;
|
||||
// Compile-time loop over every coordinate.
static_for<0, n_size, 1>{}([&](auto i) {
|
||||
w.at(i) = sorted_weight_ptr[coords[i]]; // base_coord + i * MLans;
|
||||
});
|
||||
|
||||
return w;
|
||||
}
|
||||
|
||||
// Parameterless overload: computes row coordinates from the Block_N1 / Block_M1
// block dimensions, with no base offset.
//
// Returns an array of MRepeat coordinates:
//   coords[i] = threadIdx.x / NLans + i * MLans
//
// NOTE(review): NLans is derived with kAlignmentA, whereas the O-coordinate
// computation in operator() divides Block_N1 by kAlignmentO -- confirm which
// alignment is intended here.
// NOTE(review): no call to this overload is visible in this pipeline's
// operator(), which calls the base_offset overload instead.
CK_TILE_DEVICE auto GetRowCoords_O()
|
||||
{
|
||||
// Block_N1 split into chunks of kAlignmentA elements.
constexpr index_t NLans = BlockShape::Block_N1 / kAlignmentA;
|
||||
constexpr index_t MLans = BlockShape::BlockSize / NLans;
|
||||
constexpr index_t MRepeat = BlockShape::Block_M1 / MLans;
|
||||
|
||||
// Every group of NLans consecutive thread ids shares the same base coordinate.
auto base_coord = threadIdx.x / NLans;
|
||||
|
||||
array<index_t, MRepeat> coords;
|
||||
// Compile-time loop: successive entries are MLans apart.
static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; });
|
||||
|
||||
return coords;
|
||||
}
|
||||
/*
|
||||
struct FusedMoeGemmKargs
|
||||
{
|
||||
const void* a_ptr; // [m, k], input token
|
||||
const void* a_scale_ptr; // [m, 1], token scale
|
||||
const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
|
||||
const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
|
||||
const void* g_scale_ptr; // [e, 1, n], gate(up) scale
|
||||
const void* d_scale_ptr; // [e, 1, k], down scale
|
||||
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
|
||||
void* o_ptr; // [m, k], output token
|
||||
|
||||
const void* sorted_token_ids_ptr;
|
||||
const void* sorted_weight_ptr;
|
||||
const void* sorted_expert_ids_ptr;
|
||||
const void* num_sorted_tiles_ptr;
|
||||
|
||||
index_t hidden_size; // k
|
||||
index_t intermediate_size; // n (TP slice this)
|
||||
index_t num_tokens; // input number of tokens for current iteration
|
||||
index_t num_experts; // number of groups
|
||||
index_t topk; // need this?
|
||||
|
||||
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
|
||||
};
|
||||
*/
|
||||
|
||||
template <typename Karg>
|
||||
CK_TILE_DEVICE auto operator()(const Karg& kargs,
|
||||
CK_TILE_LDS_ADDR void* smem,
|
||||
index_t sorted_tile_id,
|
||||
index_t intermediate_tile_id)
|
||||
{
|
||||
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
|
||||
ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size;
|
||||
// w1 (Down, N size)
|
||||
ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size / hidden_radio_0;
|
||||
ignore = kargs;
|
||||
ignore = smem;
|
||||
ignore = sorted_tile_id;
|
||||
ignore = intermediate_tile_id;
|
||||
|
||||
index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W
|
||||
index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0; // divide K in W
|
||||
index_t nr_1 = kargs.hidden_size / BlockShape::Warp_N1;
|
||||
index_t kr_1 = shared_intermediate_size_1 / BlockShape::Warp_K1;
|
||||
|
||||
const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
|
||||
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
|
||||
index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size;
|
||||
index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size;
|
||||
|
||||
// nr*kr*w
|
||||
index_t interm_idx_nr = __builtin_amdgcn_readfirstlane(
|
||||
intermediate_tile_id *
|
||||
BlockShape::Block_Nr0); // intermediate_tile_id * Block_N / (N in W)
|
||||
|
||||
// printf("bid:%d,%d, sorted_tile_id:%d(, intermediate_tile_id:%d, expert_id:%d,
|
||||
// interm_idx_nr:%d\n", static_cast<int>(blockIdx.x),
|
||||
// static_cast<int>(blockIdx.y), sorted_tile_id, intermediate_tile_id, expert_id,
|
||||
// interm_idx_nr);
|
||||
|
||||
auto row_coords_a = GetRowCoords_A(sorted_tile_id * BlockShape::Block_M0);
|
||||
auto row_ids_a = GetRowID_A(
|
||||
row_coords_a, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
|
||||
auto a_coords = generate_tuple(
|
||||
[&](auto i) {
|
||||
return row_ids_a[i] * kargs.stride_token +
|
||||
threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
|
||||
},
|
||||
number<row_ids_a.size()>{});
|
||||
auto a_res =
|
||||
make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
|
||||
kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
|
||||
|
||||
auto g_win = [&]() {
|
||||
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
|
||||
static_cast<long_index_t>(expert_id) * expert_stride_0 +
|
||||
interm_idx_nr * kr_0 * BlockShape::Block_W0;
|
||||
auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
g_ptr,
|
||||
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
|
||||
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
|
||||
number<kAlignmentG>{},
|
||||
number<1>{});
|
||||
|
||||
// number<BlockShape::Block_Nr0>{}.fff();
|
||||
// number<kAlignmentG>{}.zzz();
|
||||
auto g_window_ = make_tile_window_linear_raw(
|
||||
g_view_,
|
||||
make_tuple(number<BlockShape::Block_Nr0>{},
|
||||
number<BlockShape::Block_Kr0>{},
|
||||
number<BlockShape::Block_W0>{}),
|
||||
{0, 0, 0},
|
||||
Policy::template MakeGlobalTileDistribution_G<Problem>(),
|
||||
sequence<0, 1, 1>{});
|
||||
return g_window_;
|
||||
}();
|
||||
// number<decltype(g_win)::NumAccess_NonLinear>{}.rrr2();
|
||||
auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
|
||||
auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); },
|
||||
number<decltype(g_win)::NumAccess_NonLinear>{});
|
||||
|
||||
const auto d_win = [&]() {
|
||||
const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
|
||||
static_cast<long_index_t>(expert_id) * expert_stride_1 +
|
||||
interm_idx_nr * BlockShape::Block_W1;
|
||||
// note interm_idx_nr is along the gemm-k dim of 2nd gemm
|
||||
|
||||
const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
d_ptr,
|
||||
make_tuple(nr_1, kr_1, BlockShape::Block_W1),
|
||||
make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
|
||||
number<kAlignmentD>{},
|
||||
number<1>{});
|
||||
|
||||
const auto d_window_ = make_tile_window_linear_raw(
|
||||
d_view_,
|
||||
make_tuple(number<BlockShape::Block_Nr1>{},
|
||||
number<BlockShape::Block_Kr1>{},
|
||||
number<BlockShape::Block_W1>{}),
|
||||
{0, 0, 0},
|
||||
Policy::template MakeGlobalTileDistribution_D<Problem>(),
|
||||
sequence<0, 1, 1>{});
|
||||
return d_window_;
|
||||
}();
|
||||
auto d_res = d_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
|
||||
#if 0
|
||||
auto d_coords = generate_tuple([&](auto i) {
|
||||
return d_win.cached_coords_[i].get_offset(); },
|
||||
number<decltype(d_win)::NumAccess_NonLinear>{});
|
||||
#else
|
||||
// TODO: load D order is N0.K0...127, N64.K0...127, N0.K128...255, N64.K128...255
|
||||
// block-k=512, block-n=128
|
||||
// |<----- W_ ----->|
|
||||
// Nr(2)*Nw(4)* Kr *Kr0(4)*Kr1(4) * [Kl(4)*Nl(16)*Kv(8)]->one issue
|
||||
// y p y y p p y
|
||||
// 1 2 0(imm)
|
||||
auto d_coords = [&]() {
|
||||
constexpr index_t Nr_ = 2;
|
||||
constexpr index_t Nw_ = 4;
|
||||
constexpr index_t Kr0_ = 4;
|
||||
constexpr index_t Kr1_ = 4;
|
||||
constexpr index_t Kl_ = 4;
|
||||
constexpr index_t Nl_ = 16;
|
||||
constexpr index_t Kv_ = 8;
|
||||
constexpr index_t W_ = Kl_ * Nl_ * Kv_;
|
||||
constexpr index_t num_offsets_ = Nr_ * Kr0_;
|
||||
index_t base_os_ = (threadIdx.x % 64) * Kv_ + (threadIdx.x / 64) * Kr0_ * Kr1_ * W_;
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr auto i_nr_ = number<i % Nr_>{};
|
||||
constexpr auto i_kr0_ = number<i / Nr_>{};
|
||||
|
||||
return i_nr_ * shared_intermediate_size_1 * Nw_ * Nl_ + i_kr0_ * Kr1_ * W_ +
|
||||
base_os_;
|
||||
},
|
||||
number<num_offsets_>{});
|
||||
}();
|
||||
#endif
|
||||
auto o_coords = generate_tuple(
|
||||
[&](auto i) {
|
||||
return row_ids_a[i] * kargs.stride_token +
|
||||
threadIdx.x % (BlockShape::Block_N1 / kAlignmentO) * kAlignmentO;
|
||||
},
|
||||
number<row_ids_a.size()>{});
|
||||
|
||||
auto o_flags =
|
||||
generate_tuple([&](auto i) { return cmp_lt_to_exec(row_ids_a[i], kargs.num_tokens); },
|
||||
number<row_ids_a.size()>{});
|
||||
|
||||
auto bridge_sst_win = [&]() {
|
||||
constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc<Problem>();
|
||||
constexpr auto dist_ = Policy::template GetUK_0<Problem>().MakeCBlockDist();
|
||||
return make_tile_window_linear(
|
||||
make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<YDataType*>(smem),
|
||||
desc_),
|
||||
desc_.get_lengths(),
|
||||
{0, 0},
|
||||
dist_);
|
||||
}();
|
||||
auto o_res =
|
||||
make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr),
|
||||
kargs.num_tokens * kargs.stride_token * sizeof(ODataType));
|
||||
|
||||
auto row_coords_o = GetRowCoords_O(sorted_tile_id * BlockShape::Block_M0);
|
||||
auto w_scale = GetWeightScale(
|
||||
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
|
||||
#if 0
|
||||
printf("bid:%d,%d, tid:%d, sorted_tile_id:%d(, intermediate_tile_id:%d, e:%d, "
|
||||
"interm_idx_nr:%d, coords:a:%d,%d,%d, row_ids_a:%d,%d,%d, (%d)g_coords:%d.%d.%d, "
|
||||
"o_coords:%d,%d,%d,%d,%d,%d,%d,%d(%d,%d,%d,%d,%d,%d,%d,%d)\n",
|
||||
static_cast<int>(blockIdx.x),
|
||||
static_cast<int>(blockIdx.y),
|
||||
static_cast<int>(threadIdx.x),
|
||||
sorted_tile_id,
|
||||
intermediate_tile_id,
|
||||
expert_id,
|
||||
interm_idx_nr,
|
||||
row_coords_a[0],
|
||||
row_coords_a[1],
|
||||
row_coords_a[7],
|
||||
row_ids_a[0],
|
||||
row_ids_a[1],
|
||||
row_ids_a[7],
|
||||
kr_0 * BlockShape::Block_W0,
|
||||
g_coords[number<0>{}],
|
||||
g_coords[number<1>{}],
|
||||
g_coords[number<7>{}],
|
||||
o_coords[number<0>{}],
|
||||
o_coords[number<1>{}],
|
||||
o_coords[number<2>{}],
|
||||
o_coords[number<3>{}],
|
||||
o_coords[number<4>{}],
|
||||
o_coords[number<5>{}],
|
||||
o_coords[number<6>{}],
|
||||
o_coords[number<7>{}],
|
||||
// (row_ids_a[0] >= kargs.num_tokens ? 1 : 0),
|
||||
// (row_ids_a[1] >= kargs.num_tokens ? 1 : 0),
|
||||
// (row_ids_a[2] >= kargs.num_tokens ? 1 : 0),
|
||||
// (row_ids_a[3] >= kargs.num_tokens ? 1 : 0),
|
||||
// (row_ids_a[4] >= kargs.num_tokens ? 1 : 0),
|
||||
// (row_ids_a[5] >= kargs.num_tokens ? 1 : 0),
|
||||
// (row_ids_a[6] >= kargs.num_tokens ? 1 : 0),
|
||||
// (row_ids_a[7] >= kargs.num_tokens ? 1 : 0)
|
||||
|
||||
(row_ids_a[0] < kargs.num_tokens && static_cast<index_t>(o_coords[number<0>{}]) >=
|
||||
(kargs.num_tokens * kargs.stride_token)
|
||||
? 7777
|
||||
: 0),
|
||||
(row_ids_a[1] < kargs.num_tokens && static_cast<index_t>(o_coords[number<1>{}]) >=
|
||||
(kargs.num_tokens * kargs.stride_token)
|
||||
? 7777
|
||||
: 0),
|
||||
(row_ids_a[2] < kargs.num_tokens && static_cast<index_t>(o_coords[number<2>{}]) >=
|
||||
(kargs.num_tokens * kargs.stride_token)
|
||||
? 7777
|
||||
: 0),
|
||||
(row_ids_a[3] < kargs.num_tokens && static_cast<index_t>(o_coords[number<3>{}]) >=
|
||||
(kargs.num_tokens * kargs.stride_token)
|
||||
? 7777
|
||||
: 0),
|
||||
(row_ids_a[4] < kargs.num_tokens && static_cast<index_t>(o_coords[number<4>{}]) >=
|
||||
(kargs.num_tokens * kargs.stride_token)
|
||||
? 7777
|
||||
: 0),
|
||||
(row_ids_a[5] < kargs.num_tokens && static_cast<index_t>(o_coords[number<5>{}]) >=
|
||||
(kargs.num_tokens * kargs.stride_token)
|
||||
? 7777
|
||||
: 0),
|
||||
(row_ids_a[6] < kargs.num_tokens && static_cast<index_t>(o_coords[number<6>{}]) >=
|
||||
(kargs.num_tokens * kargs.stride_token)
|
||||
? 7777
|
||||
: 0),
|
||||
(row_ids_a[7] < kargs.num_tokens && static_cast<index_t>(o_coords[number<7>{}]) >=
|
||||
(kargs.num_tokens * kargs.stride_token)
|
||||
? 7777
|
||||
: 0)
|
||||
|
||||
);
|
||||
#endif
|
||||
auto uk_0 = Policy::template GetUK_0<Problem>();
|
||||
auto acc_0 = uk_0(a_res,
|
||||
a_coords,
|
||||
g_res,
|
||||
g_coords,
|
||||
smem,
|
||||
kargs.hidden_size,
|
||||
BlockShape::Block_K0, // tile offset for B matrix each unroll
|
||||
BlockShape::Block_Kr0 *
|
||||
BlockShape::Block_W0); // tile offset for B matrix each unroll
|
||||
|
||||
// return ;
|
||||
//sweep_tile(acc_0,
|
||||
// [&](auto idx) { typename Problem::GateActivation{}(acc_0(idx), acc_0[idx]); });
|
||||
sweep_tile(acc_0,
|
||||
[&](auto idx0, auto idx1) {
|
||||
fp32x2_t v_ {acc_0(idx0), acc_0(idx1)};
|
||||
typename Problem::GateActivation{}(v_, v_);
|
||||
acc_0(idx0) = v_.x;
|
||||
acc_0(idx1) = v_.y;
|
||||
},
|
||||
sequence<1, 2>{});
|
||||
|
||||
#if 0
|
||||
printf("bid:%d,%d, tid:%d, sorted_tile_id:%d(, intermediate_tile_id:%d, e:%d, "
|
||||
"interm_idx_nr:%d, coords:a:%d,%d,%d, row_ids_a:%d,%d,%d, (%d)g_coords:%d.%d.%d, bridge_sst_win:%d"
|
||||
"acc:%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f,%.1f\n",
|
||||
static_cast<int>(blockIdx.x),
|
||||
static_cast<int>(blockIdx.y),
|
||||
static_cast<int>(threadIdx.x),
|
||||
sorted_tile_id,
|
||||
intermediate_tile_id,
|
||||
expert_id,
|
||||
interm_idx_nr,
|
||||
row_coords_a[0],
|
||||
row_coords_a[1],
|
||||
row_coords_a[7],
|
||||
row_ids_a[0],
|
||||
row_ids_a[1],
|
||||
row_ids_a[7],
|
||||
kr_0 * BlockShape::Block_W0,
|
||||
g_coords[number<0>{}],
|
||||
g_coords[number<1>{}],
|
||||
g_coords[number<7>{}],
|
||||
bridge_sst_win.cached_coords_[number<0>{}].get_offset(),
|
||||
acc_0.get_thread_buffer()[number<0>{}],
|
||||
acc_0.get_thread_buffer()[number<1>{}],
|
||||
acc_0.get_thread_buffer()[number<2>{}],
|
||||
acc_0.get_thread_buffer()[number<3>{}],
|
||||
acc_0.get_thread_buffer()[number<4>{}],
|
||||
acc_0.get_thread_buffer()[number<5>{}],
|
||||
acc_0.get_thread_buffer()[number<6>{}],
|
||||
acc_0.get_thread_buffer()[number<7>{}],
|
||||
acc_0.get_thread_buffer()[number<8 + 0>{}],
|
||||
acc_0.get_thread_buffer()[number<8 + 1>{}],
|
||||
acc_0.get_thread_buffer()[number<8 + 2>{}],
|
||||
acc_0.get_thread_buffer()[number<8 + 3>{}],
|
||||
acc_0.get_thread_buffer()[number<8 + 4>{}],
|
||||
acc_0.get_thread_buffer()[number<8 + 5>{}],
|
||||
acc_0.get_thread_buffer()[number<8 + 6>{}],
|
||||
acc_0.get_thread_buffer()[number<8 + 7>{}]);
|
||||
#endif
|
||||
|
||||
auto y_pre = cast_tile<YDataType>(acc_0);
|
||||
store_tile(bridge_sst_win, y_pre);
|
||||
block_sync_lds();
|
||||
|
||||
auto uk_1 = Policy::template GetUK_1<Problem>();
|
||||
uk_1(d_res,
|
||||
d_coords,
|
||||
o_res,
|
||||
o_coords,
|
||||
o_flags,
|
||||
smem,
|
||||
kargs.hidden_size, // total n number
|
||||
w_scale,
|
||||
BlockShape::Block_Nr1 * kr_1 * BlockShape::Block_W1, // along N
|
||||
BlockShape::Block_N1); // along N
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user