mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
correct clang-format
This commit is contained in:
@@ -304,8 +304,9 @@ int run_moe_gemm_example_with_layouts(int argc,
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
[[maybe_unused]] const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, 1 /*kbatch*/, max_accumulated_value);
|
||||
[[maybe_unused]] const auto rtol_atol =
|
||||
calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, 1 /*kbatch*/, max_accumulated_value);
|
||||
c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());
|
||||
|
||||
const float rtol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
|
||||
|
||||
@@ -1264,7 +1264,7 @@ struct MoeFlatmmKernel
|
||||
scatter_token_id =
|
||||
scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
|
||||
c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
|
||||
c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens);
|
||||
c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user