mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
temp save, change all instances to 1-wave
This commit is contained in:
@@ -1315,6 +1315,17 @@ enum struct amd_buffer_coherence_enum
|
||||
glc = 1,
|
||||
slc = 2,
|
||||
glc_slc = 3,
|
||||
// gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
|
||||
// SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
|
||||
// NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
|
||||
WAVE_NT0 = 0,
|
||||
WAVE_NT1 = 2,
|
||||
GROUP_NT0 = 1,
|
||||
GROUP_NT1 = 3,
|
||||
DEVICE_NT0 = 8,
|
||||
DEVICE_NT1 = 10,
|
||||
SYSTEM_NT0 = 9,
|
||||
SYSTEM_NT1 = 11,
|
||||
};
|
||||
|
||||
template <index_t N,
|
||||
|
||||
@@ -1183,6 +1183,17 @@ enum struct amd_buffer_coherence_enum
|
||||
glc = 1,
|
||||
slc = 2,
|
||||
glc_slc = 3,
|
||||
// gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
|
||||
// SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
|
||||
// NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
|
||||
WAVE_NT0 = 0,
|
||||
WAVE_NT1 = 2,
|
||||
GROUP_NT0 = 1,
|
||||
GROUP_NT1 = 3,
|
||||
DEVICE_NT0 = 8,
|
||||
DEVICE_NT1 = 10,
|
||||
SYSTEM_NT0 = 9,
|
||||
SYSTEM_NT1 = 11,
|
||||
};
|
||||
|
||||
template <index_t N,
|
||||
|
||||
@@ -712,6 +712,7 @@ struct FmhaFwdDecodeKernel
|
||||
{
|
||||
// reshape: (nhead_ratio_qk, seqlen_q, hdim_q) -> (nhead_ratio_qk * seqlen_q,
|
||||
// hdim_q)
|
||||
// We expect Q data reuse among different KVSplited in decode case.
|
||||
const auto view = make_naive_tensor_view<address_space_enum::global>(
|
||||
q_ptr,
|
||||
make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_q),
|
||||
@@ -755,7 +756,8 @@ struct FmhaFwdDecodeKernel
|
||||
}();
|
||||
|
||||
const auto make_k_dram = [&](const KDataType* data, index_t height) {
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
// We don't expect K data reuse among different blocks in decode case.
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global, memory_operation_enum::set, amd_buffer_coherence_enum::SYSTEM_NT1>(
|
||||
data, // will update this pointer if using paged-kvcache
|
||||
make_tuple(height, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_k, 1),
|
||||
@@ -781,7 +783,8 @@ struct FmhaFwdDecodeKernel
|
||||
const auto make_v_dram = [&](const VDataType* data, index_t length) {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
// We don't expect V data reuse among different blocks in decode case.
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global, memory_operation_enum::set, amd_buffer_coherence_enum::SYSTEM_NT1>(
|
||||
data, // will update this pointer if using paged-kvcache
|
||||
make_tuple(length, kargs.hdim_v),
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
|
||||
@@ -44,6 +44,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
|
||||
static constexpr index_t kNWarp = BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
|
||||
static constexpr index_t kNXdl = BlockFmhaShape::Gemm0WarpTile::at(number<1>{});
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
@@ -546,13 +548,21 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// In Nwarp=1 and NXdl=32, GEMM0 output naturally fit the input of GEMM1
|
||||
// Otherwise shuffle through LDS so that the tile layout is consistent with required by Gemm1
|
||||
auto s_new = [&](){
|
||||
if constexpr ( !((kNWarp==1) && (kNXdl == 32)) ){
|
||||
auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
|
||||
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
|
||||
// shuffle through LDS so that the tile layout is consistent with required by Gemm1
|
||||
store_tile(s_write_lds_window, s);
|
||||
block_sync_lds();
|
||||
auto s_new = load_tile(s_read_lds_window);
|
||||
store_tile(s_write_lds_window, s);
|
||||
block_sync_lds();
|
||||
return load_tile(s_read_lds_window);
|
||||
}
|
||||
else{
|
||||
return cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
}
|
||||
}();
|
||||
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
s_new,
|
||||
|
||||
@@ -157,7 +157,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
static_assert(MWarp == 1, "Check failed!");
|
||||
// static_assert(MWarp == 1, "Check failed!");
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
Reference in New Issue
Block a user