mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
code cleanup
This commit is contained in:
@@ -58,43 +58,33 @@ struct BlockDropout<true, IsWG32_, IsStoreRandval_>
|
||||
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
index_t seqlen_qk_start)
|
||||
{
|
||||
if constexpr(IsDropout)
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
constexpr index_t kMPerStep = MWarp * WG::kM;
|
||||
constexpr index_t kNPerStep = NWarp * WG::kN;
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
constexpr index_t kMPerStep = MWarp * WG::kM;
|
||||
constexpr index_t kNPerStep = NWarp * WG::kN;
|
||||
|
||||
const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
|
||||
auto randval_dram_window = [&]() {
|
||||
if constexpr(IsFwd)
|
||||
{
|
||||
return make_tile_window(
|
||||
randval_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
|
||||
{block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(
|
||||
randval_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
|
||||
{seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
|
||||
}
|
||||
}();
|
||||
const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
|
||||
auto randval_dram_window = [&]() {
|
||||
if constexpr(IsFwd)
|
||||
{
|
||||
return make_tile_window(
|
||||
randval_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
|
||||
{block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(
|
||||
randval_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
|
||||
{seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
|
||||
}
|
||||
}();
|
||||
|
||||
return randval_dram_window;
|
||||
}
|
||||
else
|
||||
{
|
||||
(void)randval_dram_block_window_tmp;
|
||||
(void)seqlen_qk_start;
|
||||
|
||||
return make_null_tile_window(make_tuple(number<0>{}, number<0>{}));
|
||||
}
|
||||
return randval_dram_window;
|
||||
}
|
||||
|
||||
template <typename BlockGemm>
|
||||
|
||||
Reference in New Issue
Block a user