mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 02:27:57 +00:00
Merge commit '2aec38f9ec67bfbdccbdb3a5c25913e5a9ba6136' into develop
This commit is contained in:
@@ -1446,29 +1446,35 @@ struct FmhaFwdKernel
|
||||
auto o_acc_tile = [&]() {
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
{
|
||||
return FmhaPipeline{}(
|
||||
q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales{kargs.scale_p}, // p_compute_element_func
|
||||
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
auto o_acc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
|
||||
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{kargs.scale_o});
|
||||
else
|
||||
return ck_tile::scales{kargs.scale_o};
|
||||
}();
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales{kargs.scale_p}, // p_compute_element_func
|
||||
o_acc_element_func, // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -559,6 +559,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(shuffled_bias_tile, bias_tile);
|
||||
// SGrad and Bias share the same address in LDS, so wait until the ds loads issued on the
|
||||
// previous iteration have finished before reusing LDS.
|
||||
block_sync_lds();
|
||||
store_tile(bias_lds_write_window, shuffled_bias_tile);
|
||||
block_sync_lds();
|
||||
auto bias_s_tile = load_tile(bias_s_lds_read_window);
|
||||
@@ -814,6 +817,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(shuffled_bias_tile, bias_tile);
|
||||
// SGrad and Bias share the same address in LDS, so wait until the ds loads issued in the
|
||||
// hot loop have finished before reusing LDS.
|
||||
block_sync_lds();
|
||||
store_tile(bias_lds_write_window, shuffled_bias_tile);
|
||||
block_sync_lds();
|
||||
auto bias_s_tile = load_tile(bias_s_lds_read_window);
|
||||
@@ -956,6 +962,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
return cast_tile<BiasGradDataType>(ds);
|
||||
}
|
||||
}();
|
||||
// Finish loading bias_s to reuse LDS.
|
||||
block_sync_lds();
|
||||
store_tile(bias_lds_write_window, dbias);
|
||||
block_sync_lds();
|
||||
auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
|
||||
@@ -975,11 +983,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
|
||||
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
// SGrad and BiasGrad use the same address in LDS.
|
||||
block_sync_lds();
|
||||
}
|
||||
// SGrad and Bias/BiasGrad share the same address in LDS. Before reusing LDS, wait until the
|
||||
// bias/dbias loads have finished or, when bias is not used, the ds loads from the hot loop.
|
||||
block_sync_lds();
|
||||
store_tile(ds_lds_window, ds_gemm);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
@@ -698,6 +698,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer();
|
||||
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
// SGrad and BiasGrad share the same address in LDS, so wait until the dbias load has
|
||||
// finished before reusing LDS.
|
||||
block_sync_lds();
|
||||
}
|
||||
store_tile(ds_lds_window, ds_gemm);
|
||||
}
|
||||
s_waitcnt</*vmcnt=*/0>();
|
||||
|
||||
@@ -656,6 +656,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer();
|
||||
dk_acc = gemm_3(dst_reg_tensor, qt_reg_tensor);
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
// SGrad and BiasGrad share the same address in LDS, so wait until the dbias load has
|
||||
// finished before reusing LDS.
|
||||
block_sync_lds();
|
||||
}
|
||||
store_tile(ds_lds_window, ds_gemm);
|
||||
}
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
|
||||
@@ -1941,7 +1941,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
|
||||
constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt;
|
||||
constexpr index_t smem_size_stage0_1 = smem_size_v;
|
||||
constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot +
|
||||
constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + smem_size_dot +
|
||||
smem_size_do + smem_size_lse + smem_size_d +
|
||||
max(smem_size_bias, smem_size_ds);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user