mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 02:27:57 +00:00
Merge commit '0db21053e68817a50b0ed0ceea87e88228ab2475' into develop
This commit is contained in:
@@ -809,20 +809,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
ck_tile::stream_config stream_config_v{
|
||||
nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
|
||||
|
||||
printf("\nfmha_bwd_traits: hdim_q=%d, hdim_v=%d, data_type=%s, is_group_mode=%d, mask_type=%d, "
|
||||
"bias_type=%d, has_dbias=%d, has_dropout=%d, is_store_randval=%d, is_deterministic=%d\n",
|
||||
fmha_traits.hdim_q,
|
||||
fmha_traits.hdim_v,
|
||||
fmha_traits.data_type.c_str(),
|
||||
fmha_traits.is_group_mode,
|
||||
static_cast<int>(fmha_traits.mask_type),
|
||||
static_cast<int>(fmha_traits.bias_type),
|
||||
fmha_traits.has_dbias,
|
||||
fmha_traits.has_dropout,
|
||||
fmha_traits.is_store_randval,
|
||||
fmha_traits.is_deterministic);
|
||||
fflush(stdout);
|
||||
fmha_bwd(fmha_traits, fmha_args, stream_config_v);
|
||||
|
||||
dq_buf.FromDevice(dq_host.data());
|
||||
|
||||
@@ -1276,26 +1276,46 @@ llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
|
||||
index_t offset,
|
||||
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
|
||||
|
||||
template <bool pre_nop = false>
|
||||
CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t /*soffset*/,
|
||||
index_t ioffset /*max 0xFFF*/,
|
||||
index_t /*flag*/ = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
template <unsigned num_dwords, bool pre_nop = false>
|
||||
CK_TILE_DEVICE void async_buffer_load_dwordxn_v(void* smem,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t /*soffset*/,
|
||||
index_t ioffset /*max 0xFFF*/,
|
||||
index_t /*flag*/ = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"buffer_load_dword %1, %2, 0 offen offset:%3 lds"
|
||||
: "=r"(smem) /*dummy dependency for smem*/
|
||||
: "v"(voffset), "s"(rsrc), "n"(ioffset)
|
||||
#define CK_TILE_ASYNC_LOAD_WITH_INSTR(instr) \
|
||||
if constexpr(pre_nop) \
|
||||
asm volatile("s_nop 4\n" instr " %1, %2, 0 offen offset:%3 lds" \
|
||||
: "=r"(smem) /*dummy dependency for smem*/ \
|
||||
: "v"(voffset), "s"(rsrc), "n"(ioffset) \
|
||||
: "memory"); \
|
||||
else \
|
||||
asm volatile(instr " %1, %2, 0 offen offset:%3 lds" \
|
||||
: "=r"(smem) /*dummy dependency for smem*/ \
|
||||
: "v"(voffset), "s"(rsrc), "n"(ioffset) \
|
||||
: "memory");
|
||||
|
||||
if constexpr(num_dwords == 1)
|
||||
{
|
||||
CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dword");
|
||||
}
|
||||
#if defined(__gfx950__)
|
||||
else if constexpr(num_dwords == 3)
|
||||
{
|
||||
CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dwordx3");
|
||||
}
|
||||
else if constexpr(num_dwords == 4)
|
||||
{
|
||||
CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dwordx4");
|
||||
}
|
||||
#endif
|
||||
else
|
||||
asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds"
|
||||
: "=r"(smem) /*dummy dependency for smem*/
|
||||
: "v"(voffset), "s"(rsrc), "n"(ioffset)
|
||||
: "memory");
|
||||
{
|
||||
static_assert(false, "wrong! not implemented data width");
|
||||
}
|
||||
#undef CK_TILE_ASYNC_LOAD_WITH_INSTR
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
|
||||
@@ -1766,15 +1786,18 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T* smem,
|
||||
index_t src_immediate_addr_offset = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
|
||||
constexpr index_t num_bytes = sizeof(T) * N;
|
||||
constexpr index_t num_words = num_bytes / 4;
|
||||
static_assert(num_bytes % 4 == 0 && (num_words == 1 || num_words == 3 || num_words == 4),
|
||||
"wrong! only support in dword, dwordx3, dwordx4");
|
||||
|
||||
async_buffer_load_dword_v(smem,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
0,
|
||||
bool_constant<pre_nop>{});
|
||||
async_buffer_load_dwordxn_v<num_words>(smem,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
0,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
|
||||
@@ -1144,26 +1144,46 @@ llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
|
||||
index_t offset,
|
||||
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
|
||||
|
||||
template <bool pre_nop = false>
|
||||
CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t /*soffset*/,
|
||||
index_t ioffset /*max 0xFFF*/,
|
||||
index_t /*flag*/ = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
template <unsigned num_dwords, bool pre_nop = false>
|
||||
CK_TILE_DEVICE void async_buffer_load_dwordxn_v(void* smem,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t /*soffset*/,
|
||||
index_t ioffset /*max 0xFFF*/,
|
||||
index_t /*flag*/ = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"buffer_load_dword %1, %2, 0 offen offset:%3 lds"
|
||||
: "=r"(smem) /*dummy dependency for smem*/
|
||||
: "v"(voffset), "s"(rsrc), "n"(ioffset)
|
||||
#define CK_TILE_ASYNC_LOAD_WITH_INSTR(instr) \
|
||||
if constexpr(pre_nop) \
|
||||
asm volatile("s_nop 4\n" instr " %1, %2, 0 offen offset:%3 lds" \
|
||||
: "=r"(smem) /*dummy dependency for smem*/ \
|
||||
: "v"(voffset), "s"(rsrc), "n"(ioffset) \
|
||||
: "memory"); \
|
||||
else \
|
||||
asm volatile(instr " %1, %2, 0 offen offset:%3 lds" \
|
||||
: "=r"(smem) /*dummy dependency for smem*/ \
|
||||
: "v"(voffset), "s"(rsrc), "n"(ioffset) \
|
||||
: "memory");
|
||||
|
||||
if constexpr(num_dwords == 1)
|
||||
{
|
||||
CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dword");
|
||||
}
|
||||
#if defined(__gfx950__)
|
||||
else if constexpr(num_dwords == 3)
|
||||
{
|
||||
CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dwordx3");
|
||||
}
|
||||
else if constexpr(num_dwords == 4)
|
||||
{
|
||||
CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dwordx4");
|
||||
}
|
||||
#endif
|
||||
else
|
||||
asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds"
|
||||
: "=r"(smem) /*dummy dependency for smem*/
|
||||
: "v"(voffset), "s"(rsrc), "n"(ioffset)
|
||||
: "memory");
|
||||
{
|
||||
static_assert(false, "wrong! not implemented data width");
|
||||
}
|
||||
#undef CK_TILE_ASYNC_LOAD_WITH_INSTR
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
|
||||
@@ -1536,15 +1556,18 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
|
||||
index_t src_immediate_addr_offset = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
|
||||
constexpr index_t num_bytes = sizeof(T) * N;
|
||||
constexpr index_t num_words = num_bytes / 4;
|
||||
static_assert(num_bytes % 4 == 0 && (num_words == 1 || num_words == 3 || num_words == 4),
|
||||
"wrong! only support in dword, dwordx3, dwordx4");
|
||||
|
||||
async_buffer_load_dword_v(smem,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
0,
|
||||
bool_constant<pre_nop>{});
|
||||
async_buffer_load_dwordxn_v<num_words>(smem,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
0,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
|
||||
@@ -98,9 +98,18 @@ CK_TILE_DEVICE index_t get_block_1d_id() { return blockIdx.x; }
|
||||
// Use these instead
|
||||
CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); }
|
||||
|
||||
CK_TILE_DEVICE index_t get_warp_id()
|
||||
template <bool ReturnSgpr = true>
|
||||
CK_TILE_DEVICE index_t get_warp_id(bool_constant<ReturnSgpr> = {})
|
||||
{
|
||||
return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
|
||||
const index_t warp_id = threadIdx.x / get_warp_size();
|
||||
if constexpr(ReturnSgpr)
|
||||
{
|
||||
return __builtin_amdgcn_readfirstlane(warp_id);
|
||||
}
|
||||
else
|
||||
{
|
||||
return warp_id;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }
|
||||
|
||||
@@ -288,8 +288,11 @@ struct tile_window_with_static_distribution
|
||||
sizeof(LdsDataType) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
m0_set_with_memory(m0_init_value); // This should be wave independent
|
||||
// Use VALU so the compiler can optimize redundant/repeated computations
|
||||
const index_t m0_init_value =
|
||||
size_per_buf + size_per_wave * get_warp_id(/*ReturnSgpr=*/bool_constant<false>{});
|
||||
m0_set_with_memory(
|
||||
__builtin_amdgcn_readfirstlane(m0_init_value)); // This should be wave independent
|
||||
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
|
||||
@@ -649,8 +649,12 @@ struct FmhaFwdAppendKVKernel
|
||||
{0, i_n0});
|
||||
|
||||
// If kApplyRoPe is false, we set the rotary_dim to 0
|
||||
auto rotary_dim = kApplyRoPE ? kargs.rotary_dim : 0;
|
||||
|
||||
auto rotary_dim = [&]() {
|
||||
if constexpr(kApplyRoPE)
|
||||
return kargs.rotary_dim;
|
||||
else
|
||||
return 0;
|
||||
}();
|
||||
FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
i_page_block_k,
|
||||
|
||||
@@ -347,22 +347,19 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(Policy::template TransformXDramTensorView<QDataType>(
|
||||
bias_dram_block_window_tmp.get_bottom_tensor_view()),
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, bias_origin.at(number<1>{})},
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
|
||||
auto bias_lds = make_tensor_view<address_space_enum::lds>(
|
||||
bias_lds_ptr, Policy::template MakeBiasLdsWriteBlockDescriptor<Problem>());
|
||||
bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor<Problem>());
|
||||
auto bias_lds_write_window =
|
||||
make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
auto bias_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
bias_lds_ptr, Policy::template MakeBiasLdsReadBlockDescriptor<Problem>());
|
||||
auto bias_s_lds_read_window =
|
||||
make_tile_window(bias_lds_read,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
make_tile_window(bias_lds_write_window.get_bottom_tensor_view(),
|
||||
bias_lds_write_window.get_window_lengths(),
|
||||
bias_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
@@ -500,8 +497,11 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
async_load_tile(bias_lds_write_window, bias_dram_window);
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
const auto bias_tile = load_tile(bias_dram_window);
|
||||
auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(shuffled_bias_tile, bias_tile);
|
||||
store_tile(bias_lds_write_window, shuffled_bias_tile);
|
||||
block_sync_lds();
|
||||
auto bias_s_tile = load_tile(bias_s_lds_read_window);
|
||||
tile_elementwise_inout(
|
||||
|
||||
@@ -323,22 +323,19 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(Policy::template TransformXDramTensorView<QDataType>(
|
||||
bias_dram_block_window_tmp.get_bottom_tensor_view()),
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}), seqlen_kv_start},
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
|
||||
auto bias_lds = make_tensor_view<address_space_enum::lds>(
|
||||
bias_lds_ptr, Policy::template MakeBiasLdsWriteBlockDescriptor<Problem>());
|
||||
bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor<Problem>());
|
||||
auto bias_lds_write_window =
|
||||
make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
auto bias_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
bias_lds_ptr, Policy::template MakeBiasLdsReadBlockDescriptor<Problem>());
|
||||
auto bias_s_lds_read_window =
|
||||
make_tile_window(bias_lds_read,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
make_tile_window(bias_lds_write_window.get_bottom_tensor_view(),
|
||||
bias_lds_write_window.get_window_lengths(),
|
||||
bias_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
@@ -490,8 +487,11 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
async_load_tile(bias_lds_write_window, bias_dram_window);
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
const auto bias_tile = load_tile(bias_dram_window);
|
||||
auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(shuffled_bias_tile, bias_tile);
|
||||
store_tile(bias_lds_write_window, shuffled_bias_tile);
|
||||
block_sync_lds();
|
||||
auto bias_s_tile = load_tile(bias_s_lds_read_window);
|
||||
tile_elementwise_inout(
|
||||
|
||||
@@ -551,11 +551,9 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
Problem::BlockFmhaShape::kQKHeaddim>();
|
||||
}
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsWriteBlockDescriptor()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsBlockDescriptor()
|
||||
{
|
||||
return MakeXLdsWriteBlockDescriptor<typename Problem::BiasDataType,
|
||||
Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0>();
|
||||
return BlockFmhaBwdPipelineDefaultPolicy::MakeBiasLdsBlockDescriptor<Problem>();
|
||||
}
|
||||
|
||||
template <typename Problem, bool Transposed = false>
|
||||
@@ -684,13 +682,6 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kQKHeaddim>();
|
||||
}
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsReadBlockDescriptor()
|
||||
{
|
||||
return MakeXLdsReadBlockDescriptor<typename Problem::BiasDataType,
|
||||
Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSliceBlockDescriptor()
|
||||
@@ -966,25 +957,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
constexpr index_t N1 = min(static_cast<index_t>(GetAlignmentBias<Problem>()),
|
||||
kMPerBlock * kNPerBlock / kBlockSize);
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t M0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M1 = get_warp_size() / N0;
|
||||
constexpr index_t M2 = kMPerBlock / M1 / M0;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>>{});
|
||||
return BlockFmhaBwdPipelineDefaultPolicy::MakeShuffledBiasTileDistribution<Problem>();
|
||||
}
|
||||
|
||||
template <typename BlockGemm>
|
||||
@@ -1048,7 +1021,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
{
|
||||
if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
return sizeof(typename Problem::BiasDataType) *
|
||||
MakeBiasLdsWriteBlockDescriptor<Problem>().get_element_space_size();
|
||||
MakeBiasLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -364,7 +364,13 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
if constexpr(AsyncCopy)
|
||||
{
|
||||
return 4 / sizeof(KDataType);
|
||||
#if defined(__gfx950__)
|
||||
constexpr index_t MaxLoadSizeInBytes = 4 * 4; // dwordx4
|
||||
#else
|
||||
constexpr index_t MaxLoadSizeInBytes = 4; // dword
|
||||
#endif
|
||||
|
||||
return MaxLoadSizeInBytes / sizeof(KDataType);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user