mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Move the call to mask.GetTileRangeAlongX() out of the pipelines and into the kernel
This commit is contained in:
@@ -889,10 +889,17 @@ struct HstuAttentionFwdKernel
|
||||
};
|
||||
}();
|
||||
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(i_m0,
|
||||
number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kN0>{});
|
||||
|
||||
return HstuAttentionPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
seqlen_k_start,
|
||||
seqlen_k_end,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
kargs.scale_p,
|
||||
@@ -917,10 +924,17 @@ struct HstuAttentionFwdKernel
|
||||
};
|
||||
}();
|
||||
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(i_m0,
|
||||
number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kN0>{});
|
||||
|
||||
return HstuAttentionPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
seqlen_k_start,
|
||||
seqlen_k_end,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
kargs.scale_p,
|
||||
|
||||
@@ -133,6 +133,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask& mask,
|
||||
float scale_s, // scaling value exerted on the immediate Q@K result
|
||||
float scale_p, // scaling value exerted on the SiLu result
|
||||
@@ -186,8 +188,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
auto q_tile = load_tile(q_dram_window);
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
@@ -467,6 +467,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask mask,
|
||||
float scale_s, // scaling value exerted on the immediate Q@K result
|
||||
float scale_p, // scaling value exerted on the SiLU result
|
||||
@@ -482,6 +484,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
seqlen_k_start,
|
||||
seqlen_k_end,
|
||||
mask,
|
||||
scale_s,
|
||||
scale_p,
|
||||
|
||||
@@ -133,6 +133,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask& mask,
|
||||
float scale_s, // scaling value exerted on the immediate Q@K result
|
||||
float scale_p, // scaling value exerted on the SiLu result
|
||||
@@ -186,8 +188,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
auto q_tile = load_tile(q_dram_window);
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
@@ -444,6 +444,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask mask,
|
||||
float scale_s, // scaling value exerted on the immediate Q@K result
|
||||
float scale_p, // scaling value exerted on the SiLU result
|
||||
@@ -459,6 +461,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
seqlen_k_start,
|
||||
seqlen_k_end,
|
||||
mask,
|
||||
scale_s,
|
||||
scale_p,
|
||||
|
||||
@@ -133,6 +133,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask& mask,
|
||||
float scale_s, // scaling value exerted on the immediate Q@K result
|
||||
float scale_p, // scaling value exerted on the SiLu result
|
||||
@@ -199,8 +201,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
auto q_tile = load_tile(q_dram_window);
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
@@ -593,6 +593,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask mask,
|
||||
float scale_s, // scaling value exerted on the immediate Q@K result
|
||||
float scale_p, // scaling value exerted on the SiLU result
|
||||
@@ -608,6 +610,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
seqlen_k_start,
|
||||
seqlen_k_end,
|
||||
mask,
|
||||
scale_s,
|
||||
scale_p,
|
||||
|
||||
@@ -133,6 +133,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask& mask,
|
||||
float scale_s, // scaling value exerted on the immediate Q@K result
|
||||
float scale_p, // scaling value exerted on the SiLu result
|
||||
@@ -199,8 +201,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
auto q_tile = load_tile(q_dram_window);
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
@@ -591,6 +591,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask mask,
|
||||
float scale_s, // scaling value exerted on the immediate Q@K result
|
||||
float scale_p, // scaling value exerted on the SiLU result
|
||||
@@ -606,6 +608,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
seqlen_k_start,
|
||||
seqlen_k_end,
|
||||
mask,
|
||||
scale_s,
|
||||
scale_p,
|
||||
|
||||
Reference in New Issue
Block a user