Merge commit '2aec38f9ec67bfbdccbdb3a5c25913e5a9ba6136' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-19 07:12:19 +00:00
parent 6e7460a434
commit 2d48a99ddd
17 changed files with 287 additions and 162 deletions

View File

@@ -1446,29 +1446,35 @@ struct FmhaFwdKernel
auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant)
{
return FmhaPipeline{}(
q_dram_window,
identity{}, // q_element_func
k_dram_window,
identity{}, // k_element_func
v_dram_window,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales{kargs.scale_p}, // p_compute_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
dropout);
auto o_acc_element_func = [&]() {
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales{kargs.scale_o});
else
return ck_tile::scales{kargs.scale_o};
}();
return FmhaPipeline{}(q_dram_window,
identity{}, // q_element_func
k_dram_window,
identity{}, // k_element_func
v_dram_window,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales{kargs.scale_p}, // p_compute_element_func
o_acc_element_func, // o_acc_element_func
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr,
dropout);
}
else
{

View File

@@ -559,6 +559,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(shuffled_bias_tile, bias_tile);
// SGrad and Bias use the same address in LDS; finish loading ds from the previous
// iteration before reusing LDS.
block_sync_lds();
store_tile(bias_lds_write_window, shuffled_bias_tile);
block_sync_lds();
auto bias_s_tile = load_tile(bias_s_lds_read_window);
@@ -814,6 +817,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(shuffled_bias_tile, bias_tile);
// SGrad and Bias use the same address in LDS; finish loading ds in the hot loop before
// reusing LDS.
block_sync_lds();
store_tile(bias_lds_write_window, shuffled_bias_tile);
block_sync_lds();
auto bias_s_tile = load_tile(bias_s_lds_read_window);
@@ -956,6 +962,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
return cast_tile<BiasGradDataType>(ds);
}
}();
// Finish loading bias_s to reuse LDS.
block_sync_lds();
store_tile(bias_lds_write_window, dbias);
block_sync_lds();
auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
@@ -975,11 +983,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
if constexpr(kHasBiasGrad)
{
// SGrad and BiasGrad use the same address in LDS.
block_sync_lds();
}
// SGrad and Bias/BiasGrad use the same address in LDS; finish loading bias/dbias (or, when
// bias is not used, loading ds in the hot loop) before reusing LDS.
block_sync_lds();
store_tile(ds_lds_window, ds_gemm);
block_sync_lds();

View File

@@ -698,6 +698,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer();
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
if constexpr(kHasBiasGrad)
{
// SGrad and BiasGrad use the same address in LDS; finish loading dbias before reusing
// LDS.
block_sync_lds();
}
store_tile(ds_lds_window, ds_gemm);
}
s_waitcnt</*vmcnt=*/0>();

View File

@@ -656,6 +656,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer();
dk_acc = gemm_3(dst_reg_tensor, qt_reg_tensor);
if constexpr(kHasBiasGrad)
{
// SGrad and BiasGrad use the same address in LDS; finish loading dbias before reusing
// LDS.
block_sync_lds();
}
store_tile(ds_lds_window, ds_gemm);
}
__builtin_amdgcn_s_waitcnt(3952);

View File

@@ -1941,7 +1941,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt;
constexpr index_t smem_size_stage0_1 = smem_size_v;
constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot +
constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + smem_size_dot +
smem_size_do + smem_size_lse + smem_size_d +
max(smem_size_bias, smem_size_ds);