mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
rewrite save o
This commit is contained in:
@@ -500,8 +500,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
auto o_dev = o_buf.ToHost<ODataType>();
|
||||
auto c_dev = c_buf.ToHost<ADataType>();
|
||||
std::cout << std::endl;
|
||||
// std::cout << o_dev << std::endl;
|
||||
// std::cout << c_dev << std::endl;
|
||||
std::cout << o_dev << std::endl;
|
||||
// int count = 0;
|
||||
// std::cout << "[";
|
||||
// for(int i = 0; i < tokens; i++)
|
||||
|
||||
@@ -81,7 +81,7 @@ struct indexing_adaptor
|
||||
#if Using_Gather
|
||||
pre_up_index_ = idx_up[number<0>{}];
|
||||
pre_low_index_ = idx_low(number<0>{});
|
||||
#if 0
|
||||
#if 1
|
||||
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
|
||||
{
|
||||
printf("\n first index from %d to %d \n", idx_up[number<0>{}], idx_low(number<0>{}));
|
||||
@@ -93,8 +93,8 @@ struct indexing_adaptor
|
||||
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
|
||||
const UpIdxDiff& idx_diff_up,
|
||||
LowIdx& /*idx_low*/,
|
||||
const UpIdx& /*idx_up*/) const
|
||||
LowIdx& idx_low,
|
||||
const UpIdx& idx_up) const
|
||||
{
|
||||
// TODO: nothing changed here
|
||||
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
|
||||
@@ -109,14 +109,16 @@ struct indexing_adaptor
|
||||
|
||||
pre_up_index_ = up_index;
|
||||
pre_low_index_ = low_index;
|
||||
#if 0
|
||||
#if 1
|
||||
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
|
||||
{
|
||||
printf("\n index form %d to %d, diff from %d to %d \n",
|
||||
printf("\n index form %d to %d, idx_diff_low %d, idx_diff_up: %d, idx_low: %d, idx_up: %d \n",
|
||||
up_index,
|
||||
low_index,
|
||||
idx_diff_low(number<0>{}),
|
||||
idx_diff_up[number<0>{}],
|
||||
idx_diff_low(number<0>{}));
|
||||
idx_low(number<0>{}),
|
||||
idx_up.at(number<0>{}));
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@@ -252,13 +252,6 @@ struct FusedMoeGemmGlKernel
|
||||
index_t idx_n0 =
|
||||
__builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_N0);
|
||||
|
||||
// const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
|
||||
// const auto sorted_token_id = a_coord[number<0>{}] + idx_m0; // start block_m
|
||||
// // position
|
||||
|
||||
// auto topk_weight =
|
||||
// reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)[sorted_token_id];
|
||||
|
||||
const index_t* sorted_token_ids_ptr =
|
||||
reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr);
|
||||
|
||||
@@ -375,18 +368,17 @@ struct FusedMoeGemmGlKernel
|
||||
}();
|
||||
|
||||
const auto w_window = [&]() {
|
||||
const TopkWeightDataType* w_ptr = reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr);
|
||||
const auto w_view_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
const TopkWeightDataType* w_ptr =
|
||||
reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr);
|
||||
const auto w_view_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
w_ptr,
|
||||
make_tuple(kargs.max_num_tokens_padded),
|
||||
make_tuple(1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
const auto w_window_ = make_tile_window(
|
||||
w_view_,
|
||||
make_tuple(number<BlockShape::Block_M0>{}),
|
||||
{idx_m0});
|
||||
const auto w_window_ =
|
||||
make_tile_window(w_view_, make_tuple(number<BlockShape::Block_M0>{}), {idx_m0});
|
||||
return w_window_;
|
||||
}();
|
||||
|
||||
|
||||
@@ -348,22 +348,28 @@ struct FusedMoeGemmPipeline_General
|
||||
while(iCounter1 > 0)
|
||||
{
|
||||
clear_tile(o_acc);
|
||||
block_sync_lds();
|
||||
block_sync_lds_direct_load();
|
||||
gemm_1(o_acc, y, d);
|
||||
block_sync_lds();
|
||||
|
||||
move_tile_window(d_global_to_dram_window, {kN1, 0});
|
||||
d = load_tile(d_global_to_dram_window);
|
||||
|
||||
// move out window and save data
|
||||
tile_elementwise_inout([&weight](auto& x) { x = x * type_convert<float>(weight); },
|
||||
o_acc);
|
||||
auto o = cast_tile<ODataType>(o_acc);
|
||||
store_tile(o_window_, o);
|
||||
move_tile_window(o_window_, {kN1, 0});
|
||||
store_tile(o_alds_win, o);
|
||||
block_sync_lds();
|
||||
save_o();
|
||||
|
||||
move_tile_window(o_window_, {0, kN1});
|
||||
|
||||
iCounter1--;
|
||||
}
|
||||
// tail
|
||||
{
|
||||
clear_tile(o_acc);
|
||||
block_sync_lds();
|
||||
block_sync_lds_direct_load();
|
||||
gemm_1(o_acc, y, d);
|
||||
|
||||
// block_sync_lds();
|
||||
|
||||
Reference in New Issue
Block a user