mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
fix
This commit is contained in:
@@ -187,11 +187,10 @@ size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_<dq_dk_dv_trait_{F_idx}, {F_arch.tag
|
||||
}}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_dq_prepare_ws_device_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(
|
||||
void* device_ws, const void* host_ws, size_t device_ws_size, size_t host_ws_size)
|
||||
bool fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
k_::PrepareWorkspaceDevice(device_ws, host_ws, device_ws_size, host_ws_size);
|
||||
return k_::NeedsZeroDqAcc();
|
||||
}}
|
||||
|
||||
template <>
|
||||
|
||||
@@ -479,10 +479,7 @@ size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_(void* cpu_ws,
|
||||
const ck_tile::index_t* seqstart_qs,
|
||||
const ck_tile::index_t* seqstart_ks);
|
||||
template <typename Traits_, typename Arch = void>
|
||||
void fmha_bwd_dq_dk_dv_dq_prepare_ws_device_(void* device_ws,
|
||||
const void* host_ws,
|
||||
size_t device_ws_size,
|
||||
size_t host_ws_size);
|
||||
bool fmha_bwd_dq_dk_dv_needs_zero_dq_acc_();
|
||||
|
||||
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
|
||||
struct fmha_bwd_dot_do_o_traits_
|
||||
@@ -629,10 +626,15 @@ struct fmha_bwd_launcher
|
||||
traits.seqstart_qs,
|
||||
traits.seqstart_ks);
|
||||
}
|
||||
workspace_size = host_ws_size + device_ws_size;
|
||||
prepare_workspace = [this](void* device_ws) {
|
||||
fmha_bwd_dq_dk_dv_dq_prepare_ws_device_<T1, Arch>(
|
||||
device_ws, ws_host.get(), device_ws_size, host_ws_size);
|
||||
workspace_size = host_ws_size + device_ws_size;
|
||||
const bool needs_zero_dq_acc = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<T1, Arch>();
|
||||
prepare_workspace = [this, needs_zero_dq_acc](void* device_ws) {
|
||||
if(host_ws_size > 0)
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpy(device_ws, ws_host.get(), host_ws_size, hipMemcpyHostToDevice));
|
||||
if(needs_zero_dq_acc)
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemset(static_cast<char*>(device_ws) + host_ws_size, 0, device_ws_size));
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user