update coherence

This commit is contained in:
zanzhang
2025-10-29 20:35:25 +08:00
parent 96b89dc4ae
commit 237363809d
3 changed files with 36 additions and 4 deletions

View File

@@ -1326,6 +1326,19 @@ enum struct amd_buffer_coherence_enum
glc = 1,
slc = 2,
glc_slc = 3,
// s[1:0] System Cache Level: 0=warp, 1=group, 2=device, 3=system
// bit0 = sc0, bit1 = nt, bit2 = swz?, bit4 = sc1
//
WAVE_NT0 = 0,
WAVE_NT1 = 2,
GROUP_NT0 = 1,
GROUP_NT1 = 3,
DEVICE_NT0 = 8,
DEVICE_NT1 = 10,
SYSTEM_NT0 = 9,
SYSTEM_NT1 = 11,
};
template <index_t N,

View File

@@ -1186,6 +1186,17 @@ enum struct amd_buffer_coherence_enum
glc = 1,
slc = 2,
glc_slc = 3,
// s[1:0] System Cache Level: 0=warp, 1=group, 2=device, 3=system
// bit0 = sc0, bit1 = nt, bit2 = swz?, bit4 = sc1
//
WAVE_NT0 = 0,
WAVE_NT1 = 2,
GROUP_NT0 = 1,
GROUP_NT1 = 3,
DEVICE_NT0 = 8,
DEVICE_NT1 = 10,
SYSTEM_NT0 = 9,
SYSTEM_NT1 = 11,
};
template <index_t N,

View File

@@ -569,7 +569,9 @@ struct FlatmmKernel
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
return make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::set,
amd_buffer_coherence_enum::SYSTEM_NT1>(
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_A, 1),
@@ -578,7 +580,9 @@ struct FlatmmKernel
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
return make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::set,
amd_buffer_coherence_enum::SYSTEM_NT1>(
a_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
make_tuple(kargs.stride_A, 1),
@@ -628,7 +632,9 @@ struct FlatmmKernel
const auto& e_tensor_view = [&]() {
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
return make_naive_tensor_view<address_space_enum::global,
DstInMemOp,
amd_buffer_coherence_enum::SYSTEM_NT1>(
e_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_E, 1),
@@ -637,7 +643,9 @@ struct FlatmmKernel
}
else
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
return make_naive_tensor_view<address_space_enum::global,
DstInMemOp,
amd_buffer_coherence_enum::SYSTEM_NT1>(
e_ptr,
make_tuple(kargs.N, kargs.M),
make_tuple(kargs.stride_E, 1),