Merge remote-tracking branch 'origin/develop' into andriy/ck_tile/basic-tutorials

This commit is contained in:
Andriy Roshchenko
2026-02-04 23:43:18 +00:00
315 changed files with 18027 additions and 2452 deletions

View File

@@ -96,11 +96,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
8,
8,
0,
S<8, 32, 1>,
S<8, 16, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
8,
0,
1,
@@ -108,7 +108,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
S<1, 32, 1, 8>,
S<8, 8, 8>,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3>;
ck::BlockGemmPipelineVersion::v1>;
int main(int argc, char* argv[])
{
@@ -174,6 +174,29 @@ int main(int argc, char* argv[])
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1 || stride == 0)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<std::size_t>(col);
}
else
{
return static_cast<std::size_t>(row);
}
}
else
return static_cast<std::size_t>(stride);
};
StrideA = f_get_default_stride(M, K, StrideA, A0Layout{});
StrideB = f_get_default_stride(K, N, StrideB, B0Layout{});
StrideD = f_get_default_stride(M, N, StrideD, D0Layout{});
StrideE = f_get_default_stride(M, N, StrideE, ELayout{});
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{}));

View File

@@ -94,11 +94,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
8,
8,
0,
S<8, 32, 1>,
S<8, 16, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
8,
0,
1,
@@ -106,7 +106,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
S<1, 32, 1, 8>,
S<8, 8, 8>,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3>;
ck::BlockGemmPipelineVersion::v1>;
int main(int argc, char* argv[])
{
@@ -133,7 +133,7 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
@@ -170,6 +170,28 @@ int main(int argc, char* argv[])
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1 || stride == 0)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<std::size_t>(col);
}
else
{
return static_cast<std::size_t>(row);
}
}
else
return static_cast<std::size_t>(stride);
};
StrideA = f_get_default_stride(M, K, StrideA, A0Layout{});
StrideB = f_get_default_stride(K, N, StrideB, B0Layout{});
StrideE = f_get_default_stride(M, N, StrideE, ELayout{});
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{}));

View File

@@ -141,11 +141,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
8,
8,
0,
S<4, 64, 1>,
S<4, 16, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
1,
1,
8,
8,
0,
1,
@@ -233,6 +233,29 @@ int main(int argc, char* argv[])
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1 || stride == 0)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<std::size_t>(col);
}
else
{
return static_cast<std::size_t>(row);
}
}
else
return static_cast<std::size_t>(stride);
};
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
StrideD = f_get_default_stride(M, N, StrideD, DLayout{});
StrideE = f_get_default_stride(M, N, StrideE, ELayout{});
Tensor<ADataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<ADataType> a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));

View File

@@ -95,11 +95,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
8,
8,
0,
S<8, 32, 1>,
S<8, 16, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
8,
0,
1,
@@ -107,7 +107,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
S<1, 32, 1, 8>,
S<8, 8, 8>,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3>;
ck::BlockGemmPipelineVersion::v1>;
int main(int argc, char* argv[])
{
@@ -173,6 +173,29 @@ int main(int argc, char* argv[])
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1 || stride == 0)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<std::size_t>(col);
}
else
{
return static_cast<std::size_t>(row);
}
}
else
return static_cast<std::size_t>(stride);
};
StrideA = f_get_default_stride(M, K, StrideA, A0Layout{});
StrideB = f_get_default_stride(K, N, StrideB, B0Layout{});
StrideD = f_get_default_stride(M, N, StrideD, D0Layout{});
StrideE = f_get_default_stride(M, N, StrideE, ELayout{});
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));

View File

@@ -5,4 +5,11 @@ if (NOT GPU_TARGETS MATCHES "gfx11")
add_custom_target(example_convnd_activ_xdl_convinvscale)
add_example_executable(example_convnd_fwd_xdl_convinvscale_fp8 convnd_fwd_xdl_convinvscale_fp8.cpp)
add_example_dependencies(example_convnd_activ_xdl_convinvscale example_convnd_fwd_xdl_convinvscale_fp8)
endif()
endif()
# WMMA
if (GPU_TARGETS MATCHES "gfx12")
add_custom_target(example_convnd_activ_wmma_convinvscale)
add_example_executable(example_convnd_fwd_wmma_convinvscale_fp8 convnd_fwd_wmma_convinvscale_fp8.cpp)
add_example_dependencies(example_convnd_activ_wmma_convinvscale example_convnd_fwd_wmma_convinvscale_fp8)
endif()

View File

@@ -0,0 +1,98 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "convnd_fwd_convinvscale_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
using InDataType = ck::f8_t;
using WeiDataType = ck::f8_t;
using AccDataType = float;
using CShuffleDataType = float;
using DsDataType = ck::Tuple<>;
using OutDataType = ck::f8_t;
using AComputeDataType = ck::f8_t;
using BComputeDataType = ck::f8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ConvInvscale;
// Problem specialization: Default = no compile-time filter/stride special case;
// MNKPadding pads the implicit GEMM's M/N/K so arbitrary problem sizes fit the
// fixed 64x64x32 tile chosen below.
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Tuned WMMA (CShuffle V3) device instance for grouped N-D forward convolution:
// 64-thread blocks compute a 64x64 (MxN) output tile over K-chunks of 32, built
// from 16x16 WMMA tiles repeated 4x in M and 2x in N.
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename DsLayout,
typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
NDimSpatial, // NDimSpatial
InLayout, // ALayout
WeiLayout, // BLayout
DsLayout, // DsLayout (empty tuple for ConvInvScale)
OutLayout, // ELayout
InDataType, // ADataType
WeiDataType, // BDataType
AccDataType, // AccDataType
CShuffleDataType, // CShuffleDataType
DsDataType, // DsDataType (empty tuple)
OutDataType, // EDataType
InElementOp, // AElementwiseOperation
WeiElementOp, // BElementwiseOperation
OutElementOp, // CDEElementwiseOperation
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
64, // BlockSize
64, // MPerBlock
64, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma
16, // NPerWmma
4, // MRepeat
2, // NRepeat
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMRepeatPerShuffle
1, // CShuffleNRepeatPerShuffle
S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
1, // CDEBlockTransferScalarPerVector_NPerBlock
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
true, // UseThreadTileTransfer
AComputeDataType, // AComputeDataType
BComputeDataType, // BComputeDataType
1>; // NumGroupsToMerge
#include "run_convnd_fwd_convinvscale_example.inc"
/// @brief Example entry point: runs the WMMA conv-fwd + ConvInvscale example.
/// @return 0 on success or when skipped on unsupported hardware, 1 on failure.
int main(int argc, char* argv[])
{
    // The WMMA instance above targets gfx12 only; exit cleanly (status 0) on
    // other architectures so CI runs on them are treated as a skip, not a fail.
    if(!ck::is_gfx12_supported())
    {
        std::cout << "This kernel supports gfx12 only" << std::endl;
        return 0;
    }
    // run_convnd_fwd_example returns true on pass; map to process exit code.
    return run_convnd_fwd_example(argc, argv) ? 0 : 1;
}

View File

@@ -15,3 +15,19 @@ if (NOT GPU_TARGETS MATCHES "gfx11")
add_example_executable(example_convnd_fwd_xdl_convscale_bf8_fp8 convnd_fwd_xdl_convscale_bf8_fp8.cpp)
add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_bf8_fp8)
endif()
# WMMA
if (GPU_TARGETS MATCHES "gfx12")
add_custom_target(example_convnd_activ_wmma_convscale)
add_example_executable(example_convnd_fwd_wmma_convscale_fp8 convnd_fwd_wmma_convscale_fp8.cpp)
add_example_dependencies(example_convnd_activ_wmma_convscale example_convnd_fwd_wmma_convscale_fp8)
add_example_executable(example_convnd_fwd_wmma_convscale_bf8 convnd_fwd_wmma_convscale_bf8.cpp)
add_example_dependencies(example_convnd_activ_wmma_convscale example_convnd_fwd_wmma_convscale_bf8)
add_example_executable(example_convnd_fwd_wmma_convscale_fp8_bf8 convnd_fwd_wmma_convscale_fp8_bf8.cpp)
add_example_dependencies(example_convnd_activ_wmma_convscale example_convnd_fwd_wmma_convscale_fp8_bf8)
add_example_executable(example_convnd_fwd_wmma_convscale_bf8_fp8 convnd_fwd_wmma_convscale_bf8_fp8.cpp)
add_example_dependencies(example_convnd_activ_wmma_convscale example_convnd_fwd_wmma_convscale_bf8_fp8)
endif()

View File

@@ -0,0 +1,98 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "convnd_fwd_convscale_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
using InDataType = ck::bf8_t;
using WeiDataType = ck::bf8_t;
using AccDataType = float;
using CShuffleDataType = float;
using DsDataType = ck::Tuple<>;
using OutDataType = ck::f8_t;
using AComputeDataType = InDataType;
using BComputeDataType = AComputeDataType;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ConvScale;
// Problem specialization: Default = no compile-time filter/stride special case;
// MNKPadding pads the implicit GEMM's M/N/K so arbitrary problem sizes fit the
// fixed 64x64x32 tile chosen below.
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Tuned WMMA (CShuffle V3) device instance for grouped N-D forward convolution:
// 64-thread blocks compute a 64x64 (MxN) output tile over K-chunks of 32, built
// from 16x16 WMMA tiles repeated 4x in M and 2x in N.
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename DsLayout,
typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
NDimSpatial, // NDimSpatial
InLayout, // ALayout
WeiLayout, // BLayout
DsLayout, // DsLayout (empty tuple for ConvScale)
OutLayout, // ELayout
InDataType, // ADataType
WeiDataType, // BDataType
AccDataType, // AccDataType
CShuffleDataType, // CShuffleDataType
DsDataType, // DsDataType (empty tuple)
OutDataType, // EDataType
InElementOp, // AElementwiseOperation
WeiElementOp, // BElementwiseOperation
OutElementOp, // CDEElementwiseOperation
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
64, // BlockSize
64, // MPerBlock
64, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma
16, // NPerWmma
4, // MRepeat
2, // NRepeat
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMRepeatPerShuffle
1, // CShuffleNRepeatPerShuffle
S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
1, // CDEBlockTransferScalarPerVector_NPerBlock
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
true, // UseThreadTileTransfer
AComputeDataType, // AComputeDataType
BComputeDataType, // BComputeDataType
1>; // NumGroupsToMerge
#include "run_convnd_fwd_convscale_example.inc"
/// @brief Example entry point: runs the WMMA conv-fwd + ConvScale example
///        (bf8 in / bf8 weights / f8 out).
/// @return 0 on success or when skipped on unsupported hardware, 1 on failure.
int main(int argc, char* argv[])
{
    // The WMMA instance above targets gfx12 only; exit cleanly (status 0) on
    // other architectures so CI runs on them are treated as a skip, not a fail.
    if(!ck::is_gfx12_supported())
    {
        std::cout << "This kernel supports gfx12 only" << std::endl;
        return 0;
    }
    // run_convnd_fwd_example returns true on pass; map to process exit code.
    return run_convnd_fwd_example(argc, argv) ? 0 : 1;
}

View File

@@ -0,0 +1,98 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "convnd_fwd_convscale_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
using InDataType = ck::bf8_t;
using WeiDataType = ck::f8_t;
using AccDataType = float;
using CShuffleDataType = float;
using DsDataType = ck::Tuple<>;
using OutDataType = ck::f8_t;
using AComputeDataType = ck::bf8_t;
using BComputeDataType = ck::f8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ConvScale;
// Problem specialization: Default = no compile-time filter/stride special case;
// MNKPadding pads the implicit GEMM's M/N/K so arbitrary problem sizes fit the
// fixed 64x64x32 tile chosen below.
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Tuned WMMA (CShuffle V3) device instance for grouped N-D forward convolution:
// 64-thread blocks compute a 64x64 (MxN) output tile over K-chunks of 32, built
// from 16x16 WMMA tiles repeated 4x in M and 2x in N.
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename DsLayout,
typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
NDimSpatial, // NDimSpatial
InLayout, // ALayout
WeiLayout, // BLayout
DsLayout, // DsLayout (empty tuple for ConvScale)
OutLayout, // ELayout
InDataType, // ADataType
WeiDataType, // BDataType
AccDataType, // AccDataType
CShuffleDataType, // CShuffleDataType
DsDataType, // DsDataType (empty tuple)
OutDataType, // EDataType
InElementOp, // AElementwiseOperation
WeiElementOp, // BElementwiseOperation
OutElementOp, // CDEElementwiseOperation
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
64, // BlockSize
64, // MPerBlock
64, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma
16, // NPerWmma
4, // MRepeat
2, // NRepeat
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMRepeatPerShuffle
1, // CShuffleNRepeatPerShuffle
S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
1, // CDEBlockTransferScalarPerVector_NPerBlock
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
true, // UseThreadTileTransfer
AComputeDataType, // AComputeDataType
BComputeDataType, // BComputeDataType
1>; // NumGroupsToMerge
#include "run_convnd_fwd_convscale_example.inc"
/// @brief Example entry point: runs the WMMA conv-fwd + ConvScale example
///        (bf8 in / f8 weights / f8 out).
/// @return 0 on success or when skipped on unsupported hardware, 1 on failure.
int main(int argc, char* argv[])
{
    // The WMMA instance above targets gfx12 only; exit cleanly (status 0) on
    // other architectures so CI runs on them are treated as a skip, not a fail.
    if(!ck::is_gfx12_supported())
    {
        std::cout << "This kernel supports gfx12 only" << std::endl;
        return 0;
    }
    // run_convnd_fwd_example returns true on pass; map to process exit code.
    return run_convnd_fwd_example(argc, argv) ? 0 : 1;
}

View File

@@ -0,0 +1,98 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "convnd_fwd_convscale_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
using InDataType = ck::f8_t;
using WeiDataType = ck::f8_t;
using AccDataType = float;
using CShuffleDataType = float;
using DsDataType = ck::Tuple<>;
using OutDataType = ck::f8_t;
using AComputeDataType = ck::f8_t;
using BComputeDataType = ck::f8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ConvScale;
// Problem specialization: Default = no compile-time filter/stride special case;
// MNKPadding pads the implicit GEMM's M/N/K so arbitrary problem sizes fit the
// fixed 64x64x32 tile chosen below.
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Tuned WMMA (CShuffle V3) device instance for grouped N-D forward convolution:
// 64-thread blocks compute a 64x64 (MxN) output tile over K-chunks of 32, built
// from 16x16 WMMA tiles repeated 4x in M and 2x in N.
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename DsLayout,
typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
NDimSpatial, // NDimSpatial
InLayout, // ALayout
WeiLayout, // BLayout
DsLayout, // DsLayout (empty tuple for ConvScale)
OutLayout, // ELayout
InDataType, // ADataType
WeiDataType, // BDataType
AccDataType, // AccDataType
CShuffleDataType, // CShuffleDataType
DsDataType, // DsDataType (empty tuple)
OutDataType, // EDataType
InElementOp, // AElementwiseOperation
WeiElementOp, // BElementwiseOperation
OutElementOp, // CDEElementwiseOperation
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
64, // BlockSize
64, // MPerBlock
64, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma
16, // NPerWmma
4, // MRepeat
2, // NRepeat
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMRepeatPerShuffle
1, // CShuffleNRepeatPerShuffle
S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
1, // CDEBlockTransferScalarPerVector_NPerBlock
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
true, // UseThreadTileTransfer
AComputeDataType, // AComputeDataType
BComputeDataType, // BComputeDataType
1>; // NumGroupsToMerge
#include "run_convnd_fwd_convscale_example.inc"
/// @brief Example entry point: runs the WMMA conv-fwd + ConvScale example
///        (f8 in / f8 weights / f8 out).
/// @return 0 on success or when skipped on unsupported hardware, 1 on failure.
int main(int argc, char* argv[])
{
    // The WMMA instance above targets gfx12 only; exit cleanly (status 0) on
    // other architectures so CI runs on them are treated as a skip, not a fail.
    if(!ck::is_gfx12_supported())
    {
        std::cout << "This kernel supports gfx12 only" << std::endl;
        return 0;
    }
    // run_convnd_fwd_example returns true on pass; map to process exit code.
    return run_convnd_fwd_example(argc, argv) ? 0 : 1;
}

View File

@@ -0,0 +1,98 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "convnd_fwd_convscale_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
using InDataType = ck::f8_t;
using WeiDataType = ck::bf8_t;
using AccDataType = float;
using CShuffleDataType = float;
using DsDataType = ck::Tuple<>;
using OutDataType = ck::f8_t;
using AComputeDataType = ck::f8_t;
using BComputeDataType = ck::bf8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ConvScale;
// Problem specialization: Default = no compile-time filter/stride special case;
// MNKPadding pads the implicit GEMM's M/N/K so arbitrary problem sizes fit the
// fixed 64x64x32 tile chosen below.
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Tuned WMMA (CShuffle V3) device instance for grouped N-D forward convolution:
// 64-thread blocks compute a 64x64 (MxN) output tile over K-chunks of 32, built
// from 16x16 WMMA tiles repeated 4x in M and 2x in N.
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename DsLayout,
typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
NDimSpatial, // NDimSpatial
InLayout, // ALayout
WeiLayout, // BLayout
DsLayout, // DsLayout (empty tuple for ConvScale)
OutLayout, // ELayout
InDataType, // ADataType
WeiDataType, // BDataType
AccDataType, // AccDataType
CShuffleDataType, // CShuffleDataType
DsDataType, // DsDataType (empty tuple)
OutDataType, // EDataType
InElementOp, // AElementwiseOperation
WeiElementOp, // BElementwiseOperation
OutElementOp, // CDEElementwiseOperation
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
64, // BlockSize
64, // MPerBlock
64, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma
16, // NPerWmma
4, // MRepeat
2, // NRepeat
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMRepeatPerShuffle
1, // CShuffleNRepeatPerShuffle
S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
1, // CDEBlockTransferScalarPerVector_NPerBlock
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
true, // UseThreadTileTransfer
AComputeDataType, // AComputeDataType
BComputeDataType, // BComputeDataType
1>; // NumGroupsToMerge
#include "run_convnd_fwd_convscale_example.inc"
/// @brief Example entry point: runs the WMMA conv-fwd + ConvScale example
///        (f8 in / bf8 weights / f8 out).
/// @return 0 on success or when skipped on unsupported hardware, 1 on failure.
int main(int argc, char* argv[])
{
    // The WMMA instance above targets gfx12 only; exit cleanly (status 0) on
    // other architectures so CI runs on them are treated as a skip, not a fail.
    if(!ck::is_gfx12_supported())
    {
        std::cout << "This kernel supports gfx12 only" << std::endl;
        return 0;
    }
    // run_convnd_fwd_example returns true on pass; map to process exit code.
    return run_convnd_fwd_example(argc, argv) ? 0 : 1;
}

View File

@@ -5,4 +5,11 @@ if (NOT GPU_TARGETS MATCHES "gfx11")
add_custom_target(example_convnd_activ_xdl_convscale_add)
add_example_executable(example_convnd_fwd_xdl_convscale_add_fp8 convnd_fwd_xdl_convscale_add_fp8.cpp)
add_example_dependencies(example_convnd_activ_xdl_convscale_add example_convnd_fwd_xdl_convscale_add_fp8)
endif()
endif()
# WMMA
if (GPU_TARGETS MATCHES "gfx12")
add_custom_target(example_convnd_activ_wmma_convscale_add)
add_example_executable(example_convnd_fwd_wmma_convscale_add_fp8 convnd_fwd_wmma_convscale_add_fp8.cpp)
add_example_dependencies(example_convnd_activ_wmma_convscale_add example_convnd_fwd_wmma_convscale_add_fp8)
endif()

View File

@@ -0,0 +1,99 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/utility/tuple.hpp"
#include "convnd_fwd_convscale_add_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
using InDataType = ck::f8_t;
using WeiDataType = ck::f8_t;
using AccDataType = float;
using CShuffleDataType = float;
using DsDataType = float;
using OutDataType = ck::f8_t;
using AComputeDataType = ck::f8_t;
using BComputeDataType = ck::f8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ConvScaleAdd;
// Problem specialization: Default = no compile-time filter/stride special case;
// MNKPadding pads the implicit GEMM's M/N/K so arbitrary problem sizes fit the
// fixed 64x64x32 tile chosen below.
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Tuned WMMA (CShuffle V3) device instance for grouped N-D forward convolution
// with one extra D tensor (the ConvScaleAdd addend): 64-thread blocks compute a
// 64x64 (MxN) output tile over K-chunks of 32, built from 16x16 WMMA tiles
// repeated 4x in M and 2x in N.
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename DsLayout,
typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
NDimSpatial, // NDimSpatial
InLayout, // ALayout
WeiLayout, // BLayout
ck::Tuple<DsLayout>, // DsLayout
OutLayout, // ELayout
InDataType, // ADataType
WeiDataType, // BDataType
AccDataType, // AccDataType
CShuffleDataType, // CShuffleDataType
ck::Tuple<DsDataType>, // DsDataType
OutDataType, // EDataType
InElementOp, // AElementwiseOperation
WeiElementOp, // BElementwiseOperation
OutElementOp, // CDEElementwiseOperation
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
64, // BlockSize
64, // MPerBlock
64, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma
16, // NPerWmma
4, // MRepeat
2, // NRepeat
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMRepeatPerShuffle
1, // CShuffleNRepeatPerShuffle
S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
1, // CDEBlockTransferScalarPerVector_NPerBlock
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
true, // UseThreadTileTransfer
AComputeDataType, // AComputeDataType
BComputeDataType, // BComputeDataType
1>; // NumGroupsToMerge
#include "run_convnd_fwd_convscale_add_example.inc"
/// @brief Example entry point: runs the WMMA conv-fwd + ConvScaleAdd example.
/// @return 0 on success or when skipped on unsupported hardware, 1 on failure.
int main(int argc, char* argv[])
{
    // The WMMA instance above targets gfx12 only; exit cleanly (status 0) on
    // other architectures so CI runs on them are treated as a skip, not a fail.
    if(!ck::is_gfx12_supported())
    {
        std::cout << "This kernel supports gfx12 only" << std::endl;
        return 0;
    }
    // run_convnd_fwd_example returns true on pass; map to process exit code.
    return run_convnd_fwd_example(argc, argv) ? 0 : 1;
}

View File

@@ -8,4 +8,11 @@ if (NOT GPU_TARGETS MATCHES "gfx11")
add_example_executable(example_convnd_fwd_xdl_convscale_amax_fp8 convnd_fwd_xdl_convscale_amax_fp8.cpp)
add_example_dependencies(example_convnd_activ_xdl_convscale_reduce example_convnd_fwd_xdl_convscale_amax_fp8)
endif()
endif()
# WMMA
if (GPU_TARGETS MATCHES "gfx12")
add_custom_target(example_convnd_activ_wmma_convscale_reduce)
add_example_executable(example_convnd_fwd_wmma_convscale_amax_fp8 convnd_fwd_wmma_convscale_amax_fp8.cpp)
add_example_dependencies(example_convnd_activ_wmma_convscale_reduce example_convnd_fwd_wmma_convscale_amax_fp8)
endif()

View File

@@ -0,0 +1,94 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "convnd_fwd_convscale_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
using InDataType = ck::f8_t;
using WeiDataType = ck::f8_t;
using AccDataType = float;
using CShuffleDataType = float;
using ConvOutDataType = float; // data type of convolution result
using OutDataType = ck::f8_t; // data type of final result
using AComputeDataType = ck::f8_t;
using BComputeDataType = ck::f8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ConvScale;
// Problem specialization: Default = no compile-time filter/stride special case;
// MNKPadding pads the implicit GEMM's M/N/K so arbitrary problem sizes fit the
// fixed 64x64x32 tile chosen below.
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Tuned WMMA (CShuffle V3) device instance for grouped N-D forward convolution.
// Note: E (output) uses ConvOutDataType (float) here — the f8 conversion is done
// by the subsequent amax/scale reduction stage, not by this kernel.
template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
NDimSpatial, // NDimSpatial
InLayout, // ALayout
WeiLayout, // BLayout
ck::Tuple<>, // DsLayout
OutLayout, // ELayout
InDataType, // ADataType
WeiDataType, // BDataType
AccDataType, // AccDataType
CShuffleDataType, // CShuffleDataType
ck::Tuple<>, // DsDataType
ConvOutDataType, // EDataType
InElementOp, // AElementwiseOperation
WeiElementOp, // BElementwiseOperation
OutElementOp, // CDEElementwiseOperation
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
64, // BlockSize
64, // MPerBlock
64, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma
16, // NPerWmma
4, // MRepeat
2, // NRepeat
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMRepeatPerShuffle
1, // CShuffleNRepeatPerShuffle
S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
1, // CDEBlockTransferScalarPerVector_NPerBlock
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
true, // UseThreadTileTransfer
AComputeDataType, // AComputeDataType
BComputeDataType, // BComputeDataType
1>; // NumGroupsToMerge
#include "run_convnd_fwd_example.inc"
/// @brief Example entry point: runs the WMMA conv-fwd + ConvScale + amax
///        reduction example.
/// @return 0 on success or when skipped on unsupported hardware, 1 on failure.
int main(int argc, char* argv[])
{
    // The WMMA instance above targets gfx12 only; exit cleanly (status 0) on
    // other architectures so CI runs on them are treated as a skip, not a fail.
    if(!ck::is_gfx12_supported())
    {
        std::cout << "This kernel supports gfx12 only" << std::endl;
        return 0;
    }
    // run_convnd_fwd_example returns true on pass; map to process exit code.
    return run_convnd_fwd_example(argc, argv) ? 0 : 1;
}

View File

@@ -6,3 +6,10 @@ if (NOT GPU_TARGETS MATCHES "gfx11")
add_example_executable(example_convnd_fwd_xdl_convscale_relu_fp8 convnd_fwd_xdl_convscale_relu_fp8.cpp)
add_example_dependencies(example_convnd_activ_xdl_convscale_relu example_convnd_fwd_xdl_convscale_relu_fp8)
endif()
# WMMA
if (GPU_TARGETS MATCHES "gfx12")
add_custom_target(example_convnd_activ_wmma_convscale_relu)
add_example_executable(example_convnd_fwd_wmma_convscale_relu_fp8 convnd_fwd_wmma_convscale_relu_fp8.cpp)
add_example_dependencies(example_convnd_activ_wmma_convscale_relu example_convnd_fwd_wmma_convscale_relu_fp8)
endif()

View File

@@ -0,0 +1,98 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "convnd_fwd_convscale_relu_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
using InDataType = ck::f8_t;
using WeiDataType = ck::f8_t;
using AccDataType = float;
using CShuffleDataType = float;
using DsDataType = ck::Tuple<>;
using OutDataType = ck::f8_t;
using AComputeDataType = ck::f8_t;
using BComputeDataType = ck::f8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ConvScaleRelu;
// Problem specialization: Default = no compile-time filter/stride special case;
// MNKPadding pads the implicit GEMM's M/N/K so arbitrary problem sizes fit the
// fixed 64x64x32 tile chosen below.
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Tuned WMMA (CShuffle V3) device instance for grouped N-D forward convolution:
// 64-thread blocks compute a 64x64 (MxN) output tile over K-chunks of 32, built
// from 16x16 WMMA tiles repeated 4x in M and 2x in N.
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename DsLayout,
typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
NDimSpatial, // NDimSpatial
InLayout, // ALayout
WeiLayout, // BLayout
DsLayout, // DsLayout (empty tuple for ConvScaleRelu)
OutLayout, // ELayout
InDataType, // ADataType
WeiDataType, // BDataType
AccDataType, // AccDataType
CShuffleDataType, // CShuffleDataType
DsDataType, // DsDataType (empty tuple)
OutDataType, // EDataType
InElementOp, // AElementwiseOperation
WeiElementOp, // BElementwiseOperation
OutElementOp, // CDEElementwiseOperation
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
64, // BlockSize
64, // MPerBlock
64, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma
16, // NPerWmma
4, // MRepeat
2, // NRepeat
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMRepeatPerShuffle
1, // CShuffleNRepeatPerShuffle
S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
1, // CDEBlockTransferScalarPerVector_NPerBlock
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
true, // UseThreadTileTransfer
AComputeDataType, // AComputeDataType
BComputeDataType, // BComputeDataType
1>; // NumGroupsToMerge
#include "run_convnd_fwd_convscale_relu_example.inc"
// Example entry point.
//
// This example is built for gfx12 WMMA hardware only; on other GPUs print a
// notice and return 0 (treated as a skip rather than a failure).
int main(int argc, char* argv[])
{
    if(!ck::is_gfx12_supported())
    {
        // Fixed grammar ("support" -> "supports"); '\n' avoids the needless
        // stream flush of std::endl.
        std::cout << "This kernel supports gfx12 only\n";
        return 0;
    }
    // run_convnd_fwd_example returns true on success; map to process exit code.
    return run_convnd_fwd_example(argc, argv) ? 0 : 1;
}

View File

@@ -37,4 +37,10 @@ add_example_executable(example_convnd_fwd_xdl_dynamic_passthrough_fp16 convnd_fw
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_passthrough_fp16)
# Logistic
add_example_executable(example_convnd_fwd_xdl_dynamic_logistic_fp16 convnd_fwd_xdl_dynamic_logistic_fp16.cpp)
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_logistic_fp16)
add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_logistic_fp16)
# WMMA
add_custom_target(example_convnd_activ_dynamic_unary_wmma)
# PassThrough
add_example_executable(example_convnd_fwd_wmma_dynamic_passthrough_fp16 convnd_fwd_wmma_dynamic_passthrough_fp16.cpp)
add_example_dependencies(example_convnd_activ_dynamic_unary_wmma example_convnd_fwd_wmma_dynamic_passthrough_fp16)

View File

@@ -0,0 +1,245 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <type_traits>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
constexpr ck::index_t NDimSpatial = 3;
using InDataType = ck::half_t;
using WeiDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = ck::half_t;
using OutDataType = ck::half_t;
using AComputeDataType = ck::half_t;
using BComputeDataType = ck::half_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
// Use correct tensor layouts for WMMA (matching working tests)
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using DynamicElementOp = ck::tensor_operation::element_wise::DynamicUnaryOp;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Concrete gfx12 WMMA grouped conv-forward device op for the dynamic-unary
// activation example: 3-D convolution (NDimSpatial = 3), NDHWGC/GKZYXC/NDHWGK
// layouts, fp16 input/weight/output with float accumulation, and a
// DynamicUnaryOp output elementwise op selected at runtime. Arguments are
// labelled with the device-op parameter they bind.
using DeviceGroupedConvNDActivInstance =
    ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
        NDimSpatial,       // NDimSpatial
        InLayout,          // ALayout
        WeiLayout,         // BLayout
        ck::Tuple<>,       // DsLayout
        OutLayout,         // ELayout
        InDataType,        // ADataType
        WeiDataType,       // BDataType
        AccDataType,       // AccDataType
        CShuffleDataType,  // CShuffleDataType
        ck::Tuple<>,       // DsDataType
        OutDataType,       // EDataType
        InElementOp,       // AElementwiseOperation
        WeiElementOp,      // BElementwiseOperation
        DynamicElementOp,  // CDEElementwiseOperation
        ConvSpec,          // ConvForwardSpecialization
        GemmSpec,          // GemmSpecialization
        64,                // BlockSize
        64,                // MPerBlock
        64,                // NPerBlock
        32,                // KPerBlock
        8,                 // AK1
        8,                 // BK1
        16,                // MPerWmma
        16,                // NPerWmma
        4,                 // MRepeat
        2,                 // NRepeat
        S<4, 16, 1>,       // ABlockTransferThreadClusterLengths_AK0_M_AK1
        S<1, 0, 2>,        // ABlockTransferThreadClusterArrangeOrder
        S<1, 0, 2>,        // ABlockTransferSrcAccessOrder
        2,                 // ABlockTransferSrcVectorDim
        1,                 // ABlockTransferSrcScalarPerVector
        8,                 // ABlockTransferDstScalarPerVector_AK1
        1,                 // ABlockLdsExtraM
        S<4, 16, 1>,       // BBlockTransferThreadClusterLengths_BK0_N_BK1
        S<1, 0, 2>,        // BBlockTransferThreadClusterArrangeOrder
        S<1, 0, 2>,        // BBlockTransferSrcAccessOrder
        2,                 // BBlockTransferSrcVectorDim
        1,                 // BBlockTransferSrcScalarPerVector
        8,                 // BBlockTransferDstScalarPerVector_BK1
        1,                 // BBlockLdsExtraN
        1,                 // CShuffleMRepeatPerShuffle
        1,                 // CShuffleNRepeatPerShuffle
        S<1, 16, 1, 4>,    // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        1,                 // CDEBlockTransferScalarPerVector_NPerBlock
        ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
        ck::BlockGemmPipelineVersion::v1,          // BlkGemmPipelineVer
        true,              // UseThreadTileTransfer
        AComputeDataType,  // AComputeDataType
        BComputeDataType,  // BComputeDataType
        1>;                // NumGroupsToMerge
/// Runs one grouped convolution-forward problem on the device and optionally
/// verifies the result against the CPU reference implementation.
///
/// @param do_verification   when true, also run ReferenceConvFwd on the host
///                          and compare the device output against it
/// @param init_method       0: leave tensors uninitialized (timing-only runs),
///                          1: small random integers in [-2, 2],
///                          otherwise: small random floats
/// @param time_kernel       forwarded into StreamConfig to enable kernel timing
/// @param conv_param        convolution problem (strides/dilations/pads, sizes)
/// @param in_g_n_c_wis_desc   host descriptor for the input  (G,N,C,spatial...)
/// @param wei_g_k_c_xs_desc   host descriptor for the weight (G,K,C,filter...)
/// @param out_g_n_k_wos_desc  host descriptor for the output (G,N,K,spatial...)
/// @param in_element_op / wei_element_op / out_element_op  elementwise functors
/// @return true when verification passes, or unconditionally when verification
///         is disabled
/// @throws std::runtime_error if the device op rejects the argument
template <ck::index_t NDimSpatial,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename InElementOp,
          typename WeiElementOp,
          typename OutElementOp,
          typename DeviceConvNDFwdInstance>
bool run_grouped_conv(bool do_verification,
                      int init_method,
                      bool time_kernel,
                      const ck::utils::conv::ConvParam& conv_param,
                      const ck::HostTensorDescriptor& in_g_n_c_wis_desc,
                      const ck::HostTensorDescriptor& wei_g_k_c_xs_desc,
                      const ck::HostTensorDescriptor& out_g_n_k_wos_desc,
                      const InElementOp& in_element_op,
                      const WeiElementOp& wei_element_op,
                      const OutElementOp& out_element_op)
{
    ck::Tensor<InDataType> in(in_g_n_c_wis_desc);
    ck::Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
    ck::Tensor<OutDataType> out_host(out_g_n_k_wos_desc);
    ck::Tensor<OutDataType> out_device(out_g_n_k_wos_desc);

    std::cout << "in: " << in.mDesc << std::endl;
    std::cout << "wei: " << wei.mDesc << std::endl;
    std::cout << "out: " << out_host.mDesc << std::endl;

    // Host-side initialization; case 0 deliberately leaves the buffers as-is.
    switch(init_method)
    {
    case 0: break;
    case 1:
        in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
        wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 2});
        break;
    default:
        in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-1.0, 1.0});
        wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.05, 0.05});
    }

    // Allocate device buffers and upload input/weight. The output buffer is
    // written by the kernel and read back only when verifying.
    ck::DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
    ck::DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
    ck::DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize());

    in_device_buf.ToDevice(in.mData.data());
    wei_device_buf.ToDevice(wei.mData.data());

    // Flatten the descriptor lengths/strides and conv parameters into the
    // fixed-size std::arrays expected by the device op's MakeArgument API.
    std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
    std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
    std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
    std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
    std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
    std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
    std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
    std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
    std::array<ck::index_t, NDimSpatial> input_left_pads{};
    std::array<ck::index_t, NDimSpatial> input_right_pads{};

    auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };

    copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
    copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
    copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
    copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
    copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
    copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
    copy(conv_param.conv_filter_strides_, conv_filter_strides);
    copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
    copy(conv_param.input_left_pads_, input_left_pads);
    copy(conv_param.input_right_pads_, input_right_pads);

    // do Conv
    auto conv     = DeviceConvNDFwdInstance{};
    auto invoker  = conv.MakeInvoker();
    // The empty std::arrays are the (absent) D-tensor pointers and descriptors:
    // this helper only supports ops with zero additional D inputs.
    auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
                                      wei_device_buf.GetDeviceBuffer(),
                                      std::array<const void*, 0>{},
                                      out_device_buf.GetDeviceBuffer(),
                                      a_g_n_c_wis_lengths,
                                      a_g_n_c_wis_strides,
                                      b_g_k_c_xs_lengths,
                                      b_g_k_c_xs_strides,
                                      std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{{}},
                                      std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{{}},
                                      e_g_n_k_wos_lengths,
                                      e_g_n_k_wos_strides,
                                      conv_filter_strides,
                                      conv_filter_dilations,
                                      input_left_pads,
                                      input_right_pads,
                                      in_element_op,
                                      wei_element_op,
                                      out_element_op);

    if(!conv.IsSupportedArgument(argument))
    {
        throw std::runtime_error("The device op with the specified compilation parameters does "
                                 "not support this convolution problem.");
    }

    float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

    // avg_time is in milliseconds: flop/1e9/ms == TFlop/s, bytes/1e6/ms == GB/s.
    std::size_t flop      = conv_param.GetFlops();
    std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();

    float tflops     = static_cast<float>(flop) / 1.E9 / avg_time;
    float gb_per_sec = num_btype / 1.E6 / avg_time;
    std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << conv.GetTypeString() << std::endl;

    if(do_verification)
    {
        auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
                                                                     InDataType,
                                                                     WeiDataType,
                                                                     OutDataType,
                                                                     InElementOp,
                                                                     WeiElementOp,
                                                                     OutElementOp>();

        auto ref_invoker  = ref_conv.MakeInvoker();
        auto ref_argument = ref_conv.MakeArgument(in,
                                                  wei,
                                                  out_host,
                                                  conv_param.conv_filter_strides_,
                                                  conv_param.conv_filter_dilations_,
                                                  conv_param.input_left_pads_,
                                                  conv_param.input_right_pads_,
                                                  in_element_op,
                                                  wei_element_op,
                                                  out_element_op);

        ref_invoker.Run(ref_argument);

        out_device_buf.FromDevice(out_device.mData.data());

        // Tolerances 1e-3 / 0.1 -- presumably rtol/atol per check_err's
        // signature; confirm against ck/library/utility/check_err.hpp.
        return ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-3, 0.1);
    }

    return true;
}

View File

@@ -0,0 +1,12 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "convnd_fwd_activ_dynamic_unary_wmma_common.hpp"
#include "../run_convnd_activ_dynamic_example.inc"
// Entry point for the dynamic-unary example driven with a PassThrough
// (identity) output elementwise op.
int main(int argc, char* argv[])
{
    const auto out_element_op = ck::tensor_operation::element_wise::PassThrough{};

    // run_convnd_example yields a truthy value on success; map success -> 0.
    const bool succeeded = run_convnd_example(argc, argv, out_element_op);
    return succeeded ? 0 : 1;
}

View File

@@ -47,6 +47,12 @@ bool run_convnd_example(int argc, char* argv[], const OutElementOp& out_element_
conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
}
if(std::is_same_v<OutElementOp, ck::tensor_operation::element_wise::SoftRelu> &&
init_method != 2)
{
std::cout << "Running SoftRelu op with int initialization. Risk of overflow.\n\n";
}
const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{};

View File

@@ -106,7 +106,7 @@ struct bias_info
return info;
}
friend std::ostream& operator<<(std::ostream& os, const bias_info& bi)
friend std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, const bias_info& bi)
{
bi.serialize(os);
return os;

View File

@@ -78,12 +78,14 @@ QSCALE_MAP = {
"no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
"kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE",
}
QSCALE_CHECK_MAP = {
"no": "quant_scale_enum::no_scale",
"pertensor": "quant_scale_enum::pertensor",
"blockscale": "quant_scale_enum::blockscale",
"kv_blockscale": "quant_scale_enum::kv_blockscale",
}
BIAS_MAP = {

View File

@@ -630,6 +630,7 @@ class KernelComponentFactory:
if dtype in ["fp16", "bf16"]:
return {
128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
256 : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
} # fmt: skip
elif dtype in ["fp8bf16"]:
return {
@@ -676,7 +677,7 @@ class KernelComponentFactory:
kv_lookup_table,
) in itertools.product(
["t", "f"],
["pertensor"],
["pertensor", "kv_blockscale"],
get_mask_map(mask_impl).keys(),
["no"],
SUPPORTED_KV_MEMORY_LAYOUT,
@@ -739,6 +740,10 @@ def get_fwd_blobs(
for page_size in SUPPORTED_PAGE_SIZE:
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
continue
# kv_blockscale requires page_size >= kN0 (tile.F_bn0)
# This ensures all tokens in a main loop iteration belong to the same page
if pipeline.F_qscale == "kv_blockscale" and page_size < tile.F_bn0:
continue
k = FmhaFwdKernel(
F_idx=0,
F_hdim=hdim,

View File

@@ -602,6 +602,13 @@ struct fmha_batch_prefill_args
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset;
// KV_BLOCKSCALE: per-page K/V descales (Q per-tensor, K/V per-page)
// k_descale_ptr/v_descale_ptr are reused for KV_BLOCKSCALE mode:
// k_descale_ptr: [num_block, num_kv_head] - points to k block descale
// v_descale_ptr: [num_block, num_kv_head] - points to v block descale
ck_tile::index_t nblock_stride_kv_block_descale = 0; // Stride along num_block dimension
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
};
template <typename FmhaKernel>
@@ -1225,7 +1232,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.p_drop,
args.s_randval,
args.drop_seed_offset,
args.sink_ptr);
args.sink_ptr,
args.nblock_stride_kv_block_descale,
args.nhead_stride_kv_block_descale);
}
else
{ // create batch mode kernel arguments
@@ -1278,7 +1287,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.p_drop,
args.s_randval,
args.drop_seed_offset,
args.sink_ptr);
args.sink_ptr,
args.nblock_stride_kv_block_descale,
args.nhead_stride_kv_block_descale);
}
}();

View File

@@ -191,7 +191,7 @@ struct mask_info
return area;
}
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
friend std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, const mask_info& mi)
{
mi.serialize(os);
return os;

View File

@@ -8,12 +8,16 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
// keep sync with BlockAttentionQuantScaleEnum
enum class quant_scale_enum
{
no_scale = 0,
pertensor = 1,
blockscale,
no_scale = 0,
pertensor = 1,
blockscale = 2,
kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
};
struct quant_scale_info
@@ -28,6 +32,8 @@ struct quant_scale_info
os << "pt";
else if(type == quant_scale_enum::blockscale)
os << "bs";
else if(type == quant_scale_enum::kv_blockscale)
os << "kvbs";
}
static quant_scale_info decode(std::string str)
@@ -45,6 +51,10 @@ struct quant_scale_info
{
info.type = quant_scale_enum::blockscale;
}
else if(str == "kvbs" || str == "3")
{
info.type = quant_scale_enum::kv_blockscale;
}
else
{
throw std::invalid_argument("invalid quant scale value: " + str);
@@ -58,3 +68,4 @@ struct quant_scale_info
return os;
}
};
#pragma clang diagnostic pop

View File

@@ -59,7 +59,8 @@ float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
false, // PreshuffleQuant
false, // APreshuffleQuant
false, // BPreshuffleQuant
GemmConfig::PreshuffleB,
ALayout,
BLayout,
@@ -202,7 +203,8 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
false, // PreshuffleQuant
false, // APreshuffleQuant
false, // BPreshuffleQuant
GemmConfig::PreshuffleB,
ALayout,
BLayout,

View File

@@ -44,7 +44,8 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
false, // PreshuffleQuant
false, // APreshuffleQuant
false, // BPreshuffleQuant
GemmConfig::PreshuffleB,
ALayout,
BLayout,
@@ -210,7 +211,8 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
false, // PreshuffleQuant
false, // APreshuffleQuant
false, // BPreshuffleQuant
GemmConfig::PreshuffleB,
ALayout,
BLayout,

View File

@@ -21,7 +21,6 @@ if(has_supported_gpu)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1")
add_executable(tile_example_flatmm_basic flatmm_basic.cpp)
target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})

View File

@@ -179,10 +179,11 @@ auto preShuffleWeight(ck_tile::HostTensor<dtype>& src)
const int K = src_lengths[0];
const int N = src_lengths[1];
constexpr int packed_size = ck_tile::numeric_traits<dtype>::PackedSize;
int KPack = 16 * packed_size; // fp4:32 or fp8:16
int NLane = N_Warp_Tile;
int KLane = 64 / NLane;
int K0 = K / (KLane * KPack);
int KPack =
std::is_same_v<dtype, ck_tile::pk_fp6x16_t> ? 32 : 16 * packed_size; // fp4/fp6:32 or fp8:16
int NLane = N_Warp_Tile;
int KLane = 64 / NLane;
int K0 = K / (KLane * KPack);
ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({N * K}, {1}));
@@ -295,7 +296,14 @@ int run_mx_flatmm_example(int argc, char* argv[])
}
else if(mx_prec == "fp6" || mx_prec == "fp6xfp6")
{
throw std::runtime_error("fp6xfp6 is not supported.");
if(persistent_opt == 0)
return run_mx_flatmm_with_layouts<ck_tile::pk_fp6x16_t,
ck_tile::pk_fp6x16_t,
ck_tile::fp16_t,
MXfp6_FlatmmConfig16,
false>(argc, argv, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only support non-persistent kernel now!");
}
else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
{

View File

@@ -44,6 +44,38 @@ struct MXfp4_FlatmmConfig16
static constexpr bool TiledMMAPermuteN = false;
};
// Compile-time tile configuration for the MX fp6 x fp6 flatmm example using
// 16x16 warp tiles. Parallels the neighbouring MXfp4/MXfp8 configs; field
// semantics are defined by the flatmm pipeline that consumes this struct.
struct MXfp6_FlatmmConfig16
{
    // Per-workgroup tile extents along M / N / K.
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 256;
    static constexpr ck_tile::index_t K_Tile = 256;

    // Warp grid within the workgroup (1 x 4 x 1).
    static constexpr ck_tile::index_t M_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 4;
    static constexpr ck_tile::index_t K_Warp = 1;

    // Per-warp tile extents; K_Warp_Tile = 128 is specific to this fp6 config.
    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = 128;

    // Padding disabled on every dimension: problem sizes are expected to be
    // tile-aligned -- confirm with the example's size checks.
    static constexpr bool kPadM = false;
    static constexpr bool kPadN = false;
    static constexpr bool kPadK = false;

    static constexpr bool TransposeC            = false;
    static constexpr bool UseStructuredSparsity = false;

    static constexpr int kBlockPerCu            = 1;
    static constexpr int TileParitionerGroupNum = 8;
    static constexpr int TileParitionerM01      = 4;

    static constexpr auto Scheduler                 = ck_tile::GemmPipelineScheduler::Default;
    static constexpr ck_tile::index_t NumWaveGroups = 1;
    static constexpr bool DoubleSmemBuffer          = false;

    // Derived: N repeats per warp = 256 / 16 / 4 = 4.
    static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;

    static constexpr bool TiledMMAPermuteN = false;
};
struct MXfp8_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 128;

View File

@@ -8,13 +8,14 @@ function(mx_flatmm_instance_generate FILE_LIST)
set(C_LAYOUT ROW)
set(FLATMM_CONFIG_FP4xFP4 "MXfp4_FlatmmConfig16")
set(FLATMM_CONFIG_FP8xFP8 "MXfp8_FlatmmConfig16")
set(FLATMM_CONFIG_FP6xFP6 "MXfp6_FlatmmConfig16")
set(FLATMM_CONFIG_FP8xFP4 "MXf8f4_FlatmmConfig16")
set(FLATMM_CONFIG_FP4xFP8 "MXf4f8_FlatmmConfig16")
# foreach(PERSISTENT false true)
# TODO: Persistent kernels are disabled due to compilation failures with some LLVM versions.
foreach(PERSISTENT false)
foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP8xFP4 FP4xFP8)
foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP6xFP6 FP8xFP4 FP4xFP8)
set(FLATMM_CONFIG ${FLATMM_CONFIG_${DATA_TYPE}})
string(REPLACE "x" ";" DATA_TYPE_AB ${DATA_TYPE})
list(GET DATA_TYPE_AB 0 A_DATA_TYPE)

View File

@@ -19,6 +19,7 @@
using FP4 = ck_tile::pk_fp4_t;
using FP8 = ck_tile::fp8_t;
using FP6 = ck_tile::pk_fp6x16_t;
using FP16 = ck_tile::fp16_t;
using BF16 = ck_tile::bf16_t;

View File

@@ -68,24 +68,47 @@ int run_mx_flatmm_with_layouts(int argc,
M / ScaleGranularityM, K / ScaleGranularityK, scale_stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<ScaleType> scale_b(ck_tile::host_tensor_descriptor(
K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout)));
if constexpr(std::is_same_v<ADataType, ck_tile::pk_fp6x16_t>)
{
auto a_buffer_bytes = a_host.get_element_space_size_in_bytes();
auto b_buffer_bytes = b_origin_host.get_element_space_size_in_bytes();
ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_a);
ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_b);
std::vector<int8_t> random_bufA(a_buffer_bytes);
std::vector<int8_t> random_bufB(b_buffer_bytes);
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<int> dis(1, 4);
if(init_method == 0)
{
ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host);
ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host);
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a);
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b);
}
else if(init_method == 1)
{
ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host);
ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a);
ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b);
for(size_t i = 0; i < a_buffer_bytes; ++i)
random_bufA[i] = static_cast<int8_t>(dis(gen));
for(size_t i = 0; i < b_buffer_bytes; ++i)
random_bufB[i] = static_cast<int8_t>(dis(gen));
memcpy(a_host.data(), random_bufA.data(), a_buffer_bytes);
memcpy(b_origin_host.data(), random_bufB.data(), b_buffer_bytes);
}
else
{
throw std::runtime_error("wrong! Unexpected init_method");
if(init_method == 0)
{
ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host);
ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host);
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a);
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b);
}
else if(init_method == 1)
{
ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host);
ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a);
ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b);
}
else
{
throw std::runtime_error("wrong! Unexpected init_method");
}
}
const auto b_shuffled_host = preShuffleWeight<FlatmmConfig::N_Warp_Tile>(b_origin_host);

View File

@@ -134,5 +134,65 @@ static auto _ = []() {
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"abquant",
"non-preshuffleb",
"preshufflequant",
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleBQuantPrefill<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"abquant",
"non-preshuffleb",
"preshufflequant",
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleBQuantPrefill<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"fp4", "abquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_fp4_t,
ck_tile::pk_fp4_t,
ck_tile::half_t,
float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_fp4_raw_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"fp4", "abquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_fp4_t,
ck_tile::pk_fp4_t,
ck_tile::half_t,
float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::pk_fp4_raw_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
return 0;
}();

View File

@@ -4,7 +4,7 @@
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigQuantPrefill<T>;
using GemmConfig = GemmConfigQuantDecode<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \

View File

@@ -4,7 +4,7 @@
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigQuantPrefill<T>;
using GemmConfig = GemmConfigQuantDecode<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \

View File

@@ -4,7 +4,7 @@
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigQuantPrefill<T>;
using GemmConfig = GemmConfigQuantDecode<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \

View File

@@ -4,7 +4,7 @@
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigQuantPrefill<T>;
using GemmConfig = GemmConfigQuantDecode<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \

View File

@@ -32,7 +32,7 @@ auto create_args(int argc, char* argv[])
.insert("prec",
"fp8",
"Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, "
"or bf8i4; for ABQuant: fp8, bf8")
"or bf8i4; for ABQuant: fp8, bf8, fp4")
.insert("warmup", "50", "Number of iterations before benchmarking the kernel")
.insert("repeat", "1000", "Number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")

View File

@@ -80,7 +80,8 @@ struct GemmConfigBase
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr bool PreshuffleQuant = false;
static constexpr bool APreshuffleQuant = false;
static constexpr bool BPreshuffleQuant = false;
static constexpr bool PreshuffleB = false;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool TiledMMAPermuteN = false;
@@ -157,7 +158,8 @@ struct GemmConfigPreshuffleQuantDecode : public GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
static constexpr bool PreshuffleQuant = true;
static constexpr bool APreshuffleQuant = true;
static constexpr bool BPreshuffleQuant = true;
};
template <typename PrecType>
@@ -187,7 +189,7 @@ template <typename PrecType>
struct GemmConfigPreshuffleB_PreshuffleBQuant_Decode
: public GemmConfigPreshuffleB_BQuant_Decode<PrecType>
{
static constexpr bool PreshuffleQuant = true;
static constexpr bool BPreshuffleQuant = true;
};
template <typename PrecType>
@@ -218,7 +220,7 @@ template <typename PrecType>
struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
: public GemmConfigPreshuffleB_BQuant_Prefill<PrecType>
{
static constexpr bool PreshuffleQuant = true;
static constexpr bool BPreshuffleQuant = true;
};
template <typename PrecType>
@@ -272,7 +274,7 @@ struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill<PrecType>
template <typename PrecType>
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill<PrecType>
{
static constexpr bool PreshuffleQuant = true;
static constexpr bool BPreshuffleQuant = true;
};
template <typename PrecType>

View File

@@ -9,6 +9,7 @@
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include "ck_tile/core/config.hpp"
#include "ck_tile/ops/common/utils.hpp"
@@ -33,11 +34,11 @@ template <typename GemmConfig,
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr bool transpose_c = QuantMode == ck_tile::QuantType::ABQuantGrouped;
using ComputeDataType = std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant,
typename TypeConfig::BDataType,
typename TypeConfig::ADataType>;
constexpr bool transpose_c =
GemmConfig::TransposeC; // QuantMode == ck_tile::QuantType::ABQuantGrouped;
// Use automatically determined compute type from
using ComputeDataType = void;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
@@ -50,14 +51,15 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
using GemmTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::PreshuffleQuant,
GemmConfig::APreshuffleQuant,
GemmConfig::BPreshuffleQuant,
GemmConfig::PreshuffleB,
ALayout,
BLayout,
CLayout,
QuantMode,
AQLayout, // for AQLayout
BQLayout, // for BQLayout
AQLayout,
BQLayout,
transpose_c,
GemmConfig::DoubleSmemBuffer>;
@@ -73,12 +75,15 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
GemmConfig::PreshuffleB == true,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true,
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::APreshuffleQuant == true,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>>>>;
std::conditional_t<
QuantMode == ck_tile::QuantType::ABQuantGrouped,
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>>>>>;
const ck_tile::index_t K_split = ck_tile::integer_least_multiple(args.K, GemmConfig::K_Tile);
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
@@ -146,7 +151,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
has_hot_loop_v,
tail_number_v>>>>;
using AQuantPipeline =
std::conditional_t<GemmConfig::PreshuffleQuant,
std::conditional_t<GemmConfig::APreshuffleQuant,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>;
@@ -180,30 +185,28 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
printf(
"TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN);
}
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
typename TypeConfig::ADataType,
std::conditional_t<
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
typename TypeConfig::ADataType,
typename TypeConfig::BDataType>,
ck_tile::tuple<>,
typename TypeConfig::AccDataType,
typename TypeConfig::CDataType,
ck_tile::tuple<>,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
transpose_c,
1,
false,
1,
TiledPermuteN>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<typename PipelineProblem::ComputeDataType,
typename PipelineProblem::ComputeDataType,
ck_tile::tuple<>,
typename TypeConfig::AccDataType,
typename TypeConfig::CDataType,
ck_tile::tuple<>,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
transpose_c,
1,
false,
1,
TiledPermuteN>>;
using Kernel =
ck_tile::QuantGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, QuantMode>;
@@ -212,11 +215,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(args.k_batch != 1)
{
throw std::runtime_error("split-k is not supported yet!");
}
// Split-K validation is handled by Kernel::IsSupportedArgument
// Split-K is only supported for BQuantGrouped without preshuffle
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
@@ -390,8 +390,10 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
std::cout << " Acc_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::AccDataType>::name
<< " C_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::CDataType>::name
<< " QuantMode = " << quant_type_to_string(QuantMode)
<< " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : "
<< " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : "
<< " APreshuffleQuant = " << (GemmConfig::APreshuffleQuant ? "true" : "false")
<< " : "
<< " BPreshuffleQuant = " << (GemmConfig::BPreshuffleQuant ? "true" : "false")
<< " : " << " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : "
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
@@ -536,21 +538,13 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
// Create BQ tensor with appropriate shape
std::unique_ptr<ck_tile::HostTensor<BQDataType>> bq_tensor_ptr = nullptr;
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant)
QuantMode == ck_tile::QuantType::ABQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout)));
}
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout)));
}
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
ck_tile::host_tensor_descriptor(1, 1, stride_BQ, is_row_major(bq_layout)));
}
std::mt19937 gen(42);
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
@@ -561,8 +555,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
{
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
b_k_n);
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
@@ -598,18 +591,26 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
}
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
{
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<ADataType, ck_tile::pk_fp4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
a_m_k);
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
b_k_n);
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
}
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>)
{
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
}
else
{
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
}
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*aq_tensor_ptr);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
@@ -657,184 +658,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
}
}
}
else if(init_method == 3)
{
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
}
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(*aq_tensor_ptr);
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
}
else
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(2.0f)}(*aq_tensor_ptr);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
}
}
}
else if(init_method == 4)
{
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
b_k_n);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
else if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
{
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
ck_tile::FillUniformDistribution<BQDataType>{125.f, 130.f, fill_seed(gen)}(
*bq_tensor_ptr);
}
else
{
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
}
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
a_m_k);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
}
ck_tile::FillUniformDistribution<AQDataType>{2.0f, 2.0f, fill_seed(gen)}(
*aq_tensor_ptr);
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
}
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
{
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
a_m_k);
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
b_k_n);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
}
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*aq_tensor_ptr);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 2.0f, fill_seed(gen)}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 2.0f, fill_seed(gen)}(b_k_n);
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*aq_tensor_ptr);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
}
else if(init_method == 5)
{
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
b_k_n);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
else if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
{
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
ck_tile::FillUniformDistribution<BQDataType>{125.f, 130.f, fill_seed(gen)}(
*bq_tensor_ptr);
}
else
{
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
}
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
a_m_k);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{1.0f, 1.0f, fill_seed(gen)}(a_m_k);
}
            // Fill aquant such that column j has value j+1 (1, 2, 3, ...)
for(ck_tile::index_t row = 0;
row < static_cast<ck_tile::index_t>(aq_tensor_ptr->get_length(0));
++row)
{
for(ck_tile::index_t col = 0;
col < static_cast<ck_tile::index_t>(aq_tensor_ptr->get_length(1));
++col)
{
(*aq_tensor_ptr)(row, col) = static_cast<AQDataType>(col + 1);
}
}
// std::cout << "aq_tensor_ptr: " << *aq_tensor_ptr << std::endl;
ck_tile::FillUniformDistribution<BDataType>{1.0f, 1.0f, fill_seed(gen)}(b_k_n);
}
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
{
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
a_m_k);
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
b_k_n);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
}
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*aq_tensor_ptr);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 2.0f, fill_seed(gen)}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 2.0f, fill_seed(gen)}(b_k_n);
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*aq_tensor_ptr);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
}
else
{
a_m_k.SetZero();
@@ -870,7 +693,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
{
if constexpr(GemmConfig::PreshuffleQuant)
if constexpr(GemmConfig::APreshuffleQuant)
{
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
ck_tile::shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / AQuantGroupSize::kK);
@@ -929,7 +752,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
ck_tile::HostTensor<BQDataType> bq_permuted_host =
ck_tile::bq_permuteN<GemmConfig>(*bq_tensor_ptr, BQuantGroupSize::kN);
if constexpr(GemmConfig::PreshuffleQuant)
if constexpr(GemmConfig::BPreshuffleQuant)
{
ck_tile::HostTensor<BQDataType> bq_shuffle_host = ck_tile::shuffle_bq(
&bq_permuted_host, GemmConfig::K_Tile / BQuantGroupSize::kK);
@@ -940,7 +763,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
bq_dev_buf_ptr->ToDevice(bq_permuted_host.data());
}
}
else if constexpr(GemmConfig::PreshuffleQuant)
else if constexpr(GemmConfig::BPreshuffleQuant)
{
ck_tile::HostTensor<BQDataType> bq_shuffle_host =
ck_tile::shuffle_bq(bq_tensor_ptr.get(), GemmConfig::K_Tile / BQuantGroupSize::kK);
@@ -988,10 +811,14 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
if(arg_parser.get_int("v") == 1)
{
std::cout << "Performing CPU verification..." << std::endl;
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
// Track start time for reference operation
auto start_reference_tick = std::chrono::high_resolution_clock::now();
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
ck_tile::reference_gemm_quant<ADataType,
@@ -1055,6 +882,9 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
a_m_k, *aq_tensor_ptr, b_k_n, *bq_tensor_ptr, c_m_n_host_ref);
}
// Track where we stop reference calculation, and start verification
auto start_verification_tick = std::chrono::high_resolution_clock::now();
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
@@ -1065,6 +895,9 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
// "Stop" our timer
auto verification_finished_tick = std::chrono::high_resolution_clock::now();
if(!pass)
{
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
@@ -1072,6 +905,21 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
<< std::endl;
}
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
// Calculate and display reference timing
using DurationType = std::chrono::duration<double>;
double reference_sec = std::chrono::duration_cast<DurationType>(verification_finished_tick -
start_reference_tick)
.count();
double verification_sec = std::chrono::duration_cast<DurationType>(
verification_finished_tick - start_verification_tick)
.count();
float reference_msec = static_cast<float>(reference_sec * 1e3);
float verification_msec = static_cast<float>(verification_sec * 1e3);
std::cout << std::fixed << std::setprecision(1) << "CPU reference GEMM took "
<< reference_msec << "ms, verification took " << verification_msec << "ms."
<< std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
@@ -1102,6 +950,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
}
if constexpr(std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_fp4_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf16_t>)
@@ -1121,7 +970,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::ABQuantGrouped) &&
!GemmConfig::PreshuffleQuant && !GemmConfig::PreshuffleB)
!GemmConfig::APreshuffleQuant && !GemmConfig::PreshuffleB)
{
if(a_layout == "R" && b_layout == "R")
{
@@ -1142,7 +991,8 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
arg_parser, Col{}, Row{}, Row{}, Col{}, Row{});
}
}
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped && !GemmConfig::PreshuffleQuant)
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped &&
!GemmConfig::APreshuffleQuant)
{
if(a_layout == "C" && b_layout == "C")
{

View File

@@ -0,0 +1,156 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

# CMakeLists.txt for sparse attention (Jenga and VSA)

# Use SUPPORTED_GPU_TARGETS directly
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
set(GPU_TARGETS ${SUPPORTED_GPU_TARGETS})
message(STATUS "Sparse Attention: SUPPORTED_GPU_TARGETS=${SUPPORTED_GPU_TARGETS}, INST_TARGETS=${INST_TARGETS}")

list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx12")
if(NOT INST_TARGETS)
    message(WARNING "Skipping Tile Engine Sparse Attention: No supported GPU targets found")
    return()
endif()

message(STATUS "Building Sparse Attention (Jenga & VSA) for targets: ${INST_TARGETS}")

# Code generation scripts; re-run configure whenever any of them changes.
file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
    ${CMAKE_CURRENT_LIST_DIR}/generate.py
    ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py
)
set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}")

# Adds one sparse-attention variant (kernel codegen + instances object library
# + example executable).
#   variant      - short tag used in file/target names (e.g. "jenga", "vsa")
#   display_name - human-readable name for status messages (e.g. "Jenga")
#   api          - value passed to generate.py --api (e.g. "fwd_jenga")
# The Jenga and VSA sections previously duplicated ~60 identical lines each;
# keeping the logic in one function guarantees the two variants stay in sync.
function(add_sparse_attn_variant variant display_name api)
    set(code_gen_args
        ${CMAKE_CURRENT_LIST_DIR}/generate.py
        --api ${api}
        --receipt 600
    )

    # Generate the list of kernels at configure time (list only, no sources).
    execute_process(
        COMMAND ${Python3_EXECUTABLE} ${code_gen_args}
        --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/${variant}_blob_list.txt
        RESULT_VARIABLE ret
    )
    if(ret AND NOT ret EQUAL 0)
        message(FATAL_ERROR "Failed to generate ${display_name} kernel list")
    endif()
    file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/${variant}_blob_list.txt gen_blobs)

    # Generate the kernel source files at build time.
    add_custom_command(
        OUTPUT ${gen_blobs}
        COMMAND ${Python3_EXECUTABLE} ${code_gen_args}
        --output_dir ${CMAKE_CURRENT_BINARY_DIR}
        DEPENDS ${CODE_GEN_SCRIPTS}
        COMMENT "Generate CK Tile ${display_name} Sparse Attention kernels"
    )
    message(STATUS "${display_name} kernel files to be generated: ${gen_blobs}")

    # Instances object library
    set(instances_target "tile_sparse_attn_${variant}_instances")
    add_library(${instances_target} OBJECT EXCLUDE_FROM_ALL
        ${gen_blobs}
        ${CMAKE_CURRENT_LIST_DIR}/${variant}_sparse_attention.cpp
    )
    target_include_directories(${instances_target} PRIVATE
        ${CMAKE_CURRENT_LIST_DIR}
        ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
    )
    set_source_files_properties(${gen_blobs} PROPERTIES LANGUAGE HIP)
    set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/${variant}_sparse_attention.cpp PROPERTIES LANGUAGE HIP)
    set_property(TARGET ${instances_target} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
    target_compile_options(${instances_target} PRIVATE
        -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
        -DCK_TILE_FMHA_FWD_FAST_EXP2
        -Wno-undefined-func-template
        -Wno-float-equal
    )

    # Example executable
    set(example_target "tile_example_${variant}_sparse_attn")
    message(DEBUG "adding example ${example_target}")
    add_executable(${example_target} EXCLUDE_FROM_ALL test_${variant}_sparse_attn.cpp)
    target_link_libraries(${example_target} ${instances_target})
    target_include_directories(${example_target} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
    target_compile_options(${example_target} PRIVATE
        -Wno-undefined-func-template
        -Wno-float-equal
    )
endfunction()

add_sparse_attn_variant(jenga "Jenga" fwd_jenga)
add_sparse_attn_variant(vsa "VSA" fwd_vsa)

set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)

View File

@@ -0,0 +1,3 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

View File

@@ -0,0 +1,73 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
FWD_DTYPE_MAP = {
"fp16": "FmhaSparseFwdFp16",
"bf16": "FmhaSparseFwdBf16",
}
_MASK_SIMPLIFIED_MAP = {
"s_no": "ck_tile::SimplifiedGenericAttentionMask<false>",
"s_mask": "ck_tile::SimplifiedGenericAttentionMask<true>",
}
_MASK_MAP = {
"no": "FmhaMasks::NoMask",
"causal": "FmhaMasks::CausalMask",
"generic": "FmhaMasks::GenericMask",
}
def get_mask_map(mask: str):
if mask == "generic":
return _MASK_MAP
elif mask == "simplified":
return _MASK_SIMPLIFIED_MAP
else:
assert False
return None
_MASK_CHECK_MAP = {
"no": "t.mask_type == mask_enum::no_mask",
"causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
"generic": "t.mask_type == mask_enum::window_generic",
}
_MASK_SIMPLIFIED_CHECK_MAP = {
"s_no": "t.mask_type == mask_enum::no_mask",
"s_mask": "t.mask_type != mask_enum::no_mask",
}
def get_mask_check_map(mask: str):
if mask == "generic":
return _MASK_CHECK_MAP
elif mask == "simplified":
return _MASK_SIMPLIFIED_CHECK_MAP
else:
assert False
return None
# Kernel mode -> C++ boolean literal; only batch mode is generated here
# (presumably the group-mode template flag — TODO confirm against the kernel).
MODE_MAP = {"batch": "false"}
# V-tensor layout -> C++ is_v_rowmajor boolean literal.
LAYOUT_MAP = {"row": "true", "col": "false"}
# Pipeline tag -> C++ pipeline class template.
PIPELINE_MAP = {
    "qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsyncJenga",
    "qr_async_vsa": "ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA",
}
# Pipeline tag -> C++ pipeline enum value (both async variants share one enum).
PIPELINE_ENUM_MAP = {
    "qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
    "qr_async_vsa": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
}
# Accepts both "t"/"f" strings and Python bools -> C++ boolean literals.
BOOL_MAP = {
    "t": "true",
    "f": "false",
    True: "true",
    False: "false",
}

View File

@@ -0,0 +1,3 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

View File

@@ -0,0 +1,867 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass, field
import fnmatch
import itertools
import os
import os.path as path
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.cpp_symbol_map import (
BOOL_MAP,
FWD_DTYPE_MAP,
LAYOUT_MAP,
MODE_MAP,
PIPELINE_ENUM_MAP,
PIPELINE_MAP,
get_mask_check_map,
get_mask_map,
)
GEN_DIR = ""
def update_file(file_path, content):
    """Write `content` to `file_path` only when it differs from what is on disk.

    Skipping identical writes keeps the file's mtime untouched, so the build
    system does not needlessly rebuild dependents of unchanged generated files.
    """
    current = ""
    if path.exists(file_path):
        with open(file_path, "r") as existing:
            current = existing.read()
    if current != content:
        with open(file_path, "w") as out:
            out.write(content)
# Bit width of each supported element type; used to compute how many elements
# fit in a 128-bit vector access for the hdim divisibility checks below.
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16}
# NOTE(review): presumably maps bk0 to its maximum sub-tile (bk0max) value —
# TODO confirm against the tile-shape definitions that consume it.
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "fmha_fwd_trek.hpp"
#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp"
#include "kernel/fmha_fwd_jenga_kernel.hpp"
"""
# NOTE: Jenga sparse attention kernel has the following restrictions enforced by static_assert:
# - Group mode: NOT supported (batch mode only)
# - Bias: NOT supported (NO_BIAS only)
# - LSE output: NOT supported (false only)
# - Dropout: NOT supported (false only)
# - Logits soft-cap: NOT supported (false only)
# - FP8 static quantization: NOT supported (NO_SCALE only)
# The template below hardcodes these unsupported features accordingly.
FMHA_FWD_KERNEL_BODY = """
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum,
// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
false, // has_logits_soft_cap - NOT supported
ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported
false, // store_lse - NOT supported
false, // has_dropout - NOT supported
false, // has_randval - NOT supported
ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported
{F_occupancy},
false>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported)
using fmha_mask_{F_idx} = {F_mask};
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
fmha_shape_{F_idx},
{F_mode},
fmha_variant_{F_idx},
fmha_mask_{F_idx},
{F_trload},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = {F_pipeline}<
fmha_pipeline_problem_{F_idx}>;
using fmha_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaSparseFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>;
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdJengaKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
#include <iostream>
template<>
float fmha_jenga_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_jenga_fwd_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << "{F_kernel_name}" << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
FMHA_FWD_API_FILENAME = "fmha_jenga_fwd_api.cpp"
FMHA_FWD_API = """
#include <cstdio>
#include <hip/hip_runtime.h>
namespace {{
bool get_num_cus(unsigned& num_cus) {{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device");
return false;
}}
hipDeviceProp_t props{{}};
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device properties");
return false;
}}
num_cus = props.multiProcessorCount;
return true;
}}
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
return batch * nheads * num_m_blocks * num_n_blocks;
}}
}} // namespace
float fmha_jenga_fwd(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{
float r = -1;
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
unsigned num_cus;
if (!get_num_cus(num_cus)) {{
return r;
}}
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
}};
const bool has_load_tr = ck_tile::is_load_tr_supported();
{F_dispatch}
return r;
}}
"""
FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{
{F_dtype_case}
}}
"""
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
return fmha_jenga_fwd_<trait_>(s, a);
}}
"""
@dataclass
class CppConstraint:
    """A C++ boolean guard expression; ``None`` renders as the literal ``true``."""

    bool_expr: str = None  # raw C++ expression text, or None for "no constraint"

    def __str__(self):
        # Render as C++ source; an absent expression is always satisfied.
        return "true" if self.bool_expr is None else str(self.bool_expr)

    def __and__(self, other):
        # Conjunction of two guards; parenthesize both sides to keep precedence.
        return CppConstraint(f"({self}) && ({other})")
@dataclass
class FmhaFwdApiTrait:
    """Per-kernel record used to emit one dispatch branch of the generated API.

    Field order/meaning is kept in sync with the C++ fmha_fwd_traits<> template
    so that fallback dispatch calls can be generated from it.
    """

    pipeline_tag: str
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim: str
    dtype: str  # data type
    mode: str  # value from MODE_MAP
    bm0: int  # tile size along q seqlen (block size)
    bn0: int  # tile size along qk seqlen
    bk0: int  # tile size along qk gemm unroll
    bn1: int  # tile size along v head_dim
    bk1: int  # tile size along kv gemm unroll
    bk0max: int  # total K0 length (see FmhaFwdTileSize.F_bk0max)
    vlayout: str  # row/col
    logits: str  # t/f
    mask: str  # value from MASK_MAP
    spad: str  # t/f
    skpad: str  # t/f
    dpad: str  # t/f
    dvpad: str  # t/f
    tr_load: str  # t/f
    constraint: CppConstraint

    @property
    def name(self) -> str:
        """Unique-ish identifier of this trait combination.

        Fix: the 7th tile component was a duplicated bn0; it is bn1, matching
        the bm0/bn0/bk0/bn1/bk1/bk0max field order used everywhere else.
        """
        return (
            f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-"
            + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
        )

    @property
    def scheck(self) -> str:
        """C++ guard for seqlen_q padding; any seqlen_q is always supported."""
        if self.mode == "group":
            return "true/*group mode spad always true*/"  # group mode only generate spad/skpad == true
        # batch mode: supported regardless of spad (both branches used to return "true")
        return "true"

    @property
    def seqtune(self) -> str:
        """C++ guard steering short sequences to smaller tiles; bm0==128 is the catch-all."""
        if self.bm0 == 128:
            return "true/*fall back to largest tile*/"
        else:
            return f"a.seqlen_q <= {self.bm0}"

    @property
    def skcheck(self) -> str:
        """C++ guard matching seqlen_k against the kernel's k-seqlen padding support."""
        if self.mode == "group":
            return "true/*group mode skpad always true*/"  # group mode only generate spad/skpad == true
        if self.skpad == "t":
            return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
        return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"

    @property
    def dcheck(self) -> str:
        """C++ guard on hdim_q vector-load alignment; generator only emits dpad == 't'."""
        vec = int((32 * 4) / DTYPE_BITS[self.dtype])
        if self.dpad == "t":
            return f"a.hdim_q % {vec} == 0"
        assert False  # dpad == 'f' kernels are never generated

    @property
    def dvcheck(self) -> str:
        """C++ guard on hdim_v vector-load alignment; generator only emits dvpad == 't'."""
        vec = int((32 * 4) / DTYPE_BITS[self.dtype])
        if self.dvpad == "t":
            return f"a.hdim_v % {vec} == 0"
        assert False  # dvpad == 'f' kernels are never generated
@dataclass
class FmhaFwdPipeline:
    """Pipeline options (layout/padding/mask/tr-load) for one kernel variant."""

    tag: str  # value from PIPELINE_MAP / PIPELINE_ENUM_MAP
    F_vlayout: str  # row/col
    F_spad: str  # t/f: pad q seqlen
    F_skpad: str  # t/f: pad k seqlen
    F_dpad: str  # t/f: pad hdim_q
    F_dvpad: str  # t/f: pad hdim_v
    F_logits: str  # t/f
    F_mask: str  # value from MASK_MAP
    F_trload: str  # t/f
    F_constraint: CppConstraint = field(default_factory=CppConstraint)  # extra dispatch guard

    @property
    def name(self) -> str:
        """Encode the options into the kernel-name suffix (append order is significant)."""

        def pad_name() -> str:
            # Collect enabled padding flags, e.g. spad+dpad+dvpad -> "psddv";
            # empty string when nothing is padded.
            n = ""
            if self.F_spad == "t":
                n += "s"
            if self.F_skpad == "t":
                n += "sk"
            if self.F_dpad == "t":
                n += "d"
            if self.F_dvpad == "t":
                n += "dv"
            if n != "":
                n = "p" + n
            return n

        pn = pad_name()
        n = f"{self.tag}_v{self.F_vlayout[0]}"
        if pn != "":
            n += f"_{pn}"
        else:
            n += "_npad"
        if self.F_logits == "t":
            n += "_logits"
        else:
            n += "_nlogits"
        n += "_nbias"  # bias is never generated here
        if self.F_mask[0:2] == "s_":
            # masks from the "s_*" naming scheme
            if self.F_mask == "s_mask":
                n += "_mask"
            else:
                n += "_nmask"
        else:
            # other schemes: encode first letter of the mask kind, "no" -> no mask
            if self.F_mask != "no":
                n += f"_m{self.F_mask[0]}"
            else:
                n += "_nmask"
        n += "_nskip"
        n += "_nsquant"
        if self.F_trload == "t":
            n += "_trload"
        else:
            n += "_ntrload"
        return n
class FmhaFwdApiPool:
    """Collects per-kernel traits and renders the C++ runtime dispatch API."""

    def __init__(self, mask_impl):
        # pool layout: dtype -> (hdim, hdim_v) -> [FmhaFwdApiTrait]
        self.pool = dict()
        self.mask_impl = mask_impl

    def register_traits(self, trait: FmhaFwdApiTrait) -> None:
        # TODO: do we need to check duplication?
        if trait.dtype not in self.pool.keys():
            self.pool[trait.dtype] = dict()
        # key is the (hdim_q, hdim_v) pair; bn1 is the v-head-dim tile used as hdim_v
        hdim = trait.hdim, trait.bn1
        if hdim not in self.pool[trait.dtype].keys():
            self.pool[trait.dtype][hdim] = list()
        self.pool[trait.dtype][hdim].append(copy.copy(trait))

    @property
    def api(self) -> str:
        """Render the complete dispatch-API translation unit from the pool."""
        # Branch condition per tr_load bucket: hardware-dependent for 't',
        # unconditional for 'f' (the fallback bucket).
        tr_load_cond_map = {"t": "has_load_tr", "f": "true"}
        per_tr_load = str()
        for tr_load in ["t", "f"]:
            per_dtypes = str()
            for i, dtype in enumerate(self.pool.keys()):
                per_hdim_case = str()
                for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
                    # only the traits belonging to the current tr_load bucket
                    traits = [
                        t
                        for t in self.pool[dtype][(hdim, hdim_v)]
                        if tr_load == t.tr_load
                    ]
                    inners = str()
                    for k, trait in enumerate(traits):
                        if_k = "if" if k == 0 else "else if"
                        inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
                            F_if=if_k,
                            F_vlayout=LAYOUT_MAP[trait.vlayout],
                            F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
                            # F_logits removed - hardcoded to false (NOT supported)
                            F_mask=get_mask_map(self.mask_impl)[trait.mask],
                            F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
                            F_trload=BOOL_MAP[trait.tr_load],
                            F_scheck=trait.scheck,
                            F_seqtune=trait.seqtune,
                            F_skcheck=trait.skcheck,
                            F_dcheck=trait.dcheck,
                            F_dvcheck=trait.dvcheck,
                            F_constraint=trait.constraint,
                            F_spad=BOOL_MAP[trait.spad],
                            F_skpad=BOOL_MAP[trait.skpad],
                            F_dpad=BOOL_MAP[trait.dpad],
                            F_dvpad=BOOL_MAP[trait.dvpad],
                            F_bm0=trait.bm0,
                            F_bn0=trait.bn0,
                            F_bk0=trait.bk0,
                            F_bn1=trait.bn1,
                            F_bk1=trait.bk1,
                            F_bk0max=trait.bk0max,
                            F_hdim=hdim,
                            F_dtype=FWD_DTYPE_MAP[dtype],
                        )
                    if_j = "if" if j == 0 else "else if"
                    per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
                        F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners
                    )
                if_i = "if" if i == 0 else "else if"
                per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
                    F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
                )
            # always a plain "if": matched branches return, so control reaches the
            # next bucket only when nothing in the previous one dispatched
            per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(
                F_if="if",
                F_trload_cond=tr_load_cond_map[tr_load],
                F_dtype_case=per_dtypes,
            )
        if not per_tr_load:
            # empty string we add some ignore to suppress warning in api
            per_tr_load += " (void)t ; (void)s ; (void)a;"
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)
@dataclass
class FmhaFwdTileSize:
    """Tile/warp configuration of one kernel instance."""

    F_bm0: int  # tile size along q seqlen (block size)
    F_bn0: int  # tile size along k seqlen
    F_bk0: int  # tile size along qk gemm unroll
    F_bn1: int  # tile size along v head_dim
    F_bk1: int  # tile size along kv gemm unroll
    F_bk0max: int  # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
    F_rm0: int  # number of warps for gemm0 along q seqlen
    F_rn0: int  # number of warps for gemm0 along k seqlen
    F_rk0: int  # number of warps for gemm0 along head dim q (not used)
    F_rm1: int  # number of warps for gemm1 along q seqlen
    F_rn1: int  # number of warps for gemm1 along head dim v
    F_rk1: int  # number of warps for gemm1 along k seqlen (not used)
    F_wm0: int  # gemm0 warp size along m
    F_wn0: int  # gemm0 warp size along n
    F_wk0: int  # gemm0 warp size along k
    F_wm1: int  # gemm1 warp size along m
    F_wn1: int  # gemm1 warp size along n
    F_wk1: int  # gemm1 warp size along k
    F_occupancy: int  # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
    F_constraint: CppConstraint = field(default_factory=CppConstraint)  # extra runtime dispatch guard

    @property
    def name(self) -> str:
        # Kernel-name fragment: b<tiles>_r<gemm0 warps>_r<gemm1 warps>_w<warp shapes>,
        # plus an _o<occupancy> suffix only when occupancy is overridden.
        return (
            f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
            + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
            + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
            + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
        )
@dataclass
class FmhaFwdKernel:
    """One concrete kernel instance; renders its .cpp source and its API trait."""

    F_idx: int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim: int  # hdim
    F_dtype: str  # data type
    F_mode: str  # value from MODE_MAP
    F_tile: FmhaFwdTileSize
    F_pipeline: FmhaFwdPipeline
    mask_impl: str

    @property
    def template(self) -> str:
        """Full source text of this kernel's generated translation unit."""
        # kernel_body removed - unused
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
            F_idx=self.F_idx,
            F_hdim=self.F_hdim,
            F_dtype=FWD_DTYPE_MAP[self.F_dtype],
            F_bm0=self.F_tile.F_bm0,
            F_bn0=self.F_tile.F_bn0,
            F_bk0=self.F_tile.F_bk0,
            F_bn1=self.F_tile.F_bn1,
            F_bk1=self.F_tile.F_bk1,
            F_bk0max=self.F_tile.F_bk0max,
            F_rm0=self.F_tile.F_rm0,
            F_rn0=self.F_tile.F_rn0,
            F_rk0=self.F_tile.F_rk0,
            F_rm1=self.F_tile.F_rm1,
            F_rn1=self.F_tile.F_rn1,
            F_rk1=self.F_tile.F_rk1,
            F_wm0=self.F_tile.F_wm0,
            F_wn0=self.F_tile.F_wn0,
            F_wk0=self.F_tile.F_wk0,
            F_wm1=self.F_tile.F_wm1,
            F_wn1=self.F_tile.F_wn1,
            F_wk1=self.F_tile.F_wk1,
            F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
            F_spad=BOOL_MAP[self.F_pipeline.F_spad],
            F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
            F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
            F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
            # F_logits removed - hardcoded to false in template (NOT supported)
            F_occupancy=self.F_tile.F_occupancy,
            F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
            F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
            F_mode=MODE_MAP[self.F_mode],
            F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
            F_trload=BOOL_MAP[self.F_pipeline.F_trload],
            F_kernel_name=self.name,
        )

    @property
    def name(self) -> str:
        """Kernel symbol/file stem built from dtype, mode, tile and pipeline names."""
        # TODO: we don't encode idx here
        return (
            f"fmha_jenga_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
            + self.F_tile.name
            + "_"
            + self.F_pipeline.name
        )

    @property
    def filename(self) -> str:
        # one generated translation unit per kernel instance
        return self.name + ".cpp"

    def api_trait(self) -> FmhaFwdApiTrait:
        """Build the dispatch-table record mirroring this kernel's template args."""
        return FmhaFwdApiTrait(
            pipeline_tag=self.F_pipeline.tag,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            mode=self.F_mode,
            bm0=self.F_tile.F_bm0,
            bn0=self.F_tile.F_bn0,
            bk0=self.F_tile.F_bk0,
            bn1=self.F_tile.F_bn1,
            bk1=self.F_tile.F_bk1,
            bk0max=self.F_tile.F_bk0max,
            vlayout=self.F_pipeline.F_vlayout,
            mask=self.F_pipeline.F_mask,
            logits=self.F_pipeline.F_logits,
            spad=self.F_pipeline.F_spad,
            skpad=self.F_pipeline.F_skpad,
            dpad=self.F_pipeline.F_dpad,
            dvpad=self.F_pipeline.F_dvpad,
            tr_load=self.F_pipeline.F_trload,
            # combined C++ guard: both the tile's and the pipeline's constraints must hold
            constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
        )
class KernelComponentFactory:
    # TODO: design a more practical way to do it
    # this is current supported tile size per hdim
    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        """Supported tile configurations keyed by (hdim_q, hdim_v); None if dtype unsupported."""
        if dtype == "fp16" or dtype == "bf16":
            return {
                # (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
                # (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
                # FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
                # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
                # Positional order: bm0, bn0, bk0, bn1, bk1, bk0max,
                # rm0, rn0, rk0, rm1, rn1, rk1, wm0, wn0, wk0, wm1, wn1, wk1, occupancy.
                (128, 128): [
                    FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),  # fmt: skip
                    FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),  # fmt: skip
                    FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),  # fmt: skip
                    FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),  # fmt: skip
                ],
                # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
                # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
                # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
                # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
            }
        else:
            return None

    # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
    # support this in future
    @staticmethod
    def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
        """Enumerate the pipeline variants to generate for one head-dim pair."""
        # this function will populate a list possible pipelines
        # TODO: the order of List matters! the later in this list will be also be checked later
        # NOTE: logits soft-cap is NOT supported by Jenga sparse attention (enforced by static_assert)
        pipelines = []
        if dtype in ["fp16", "bf16"]:
            for logits, mask in itertools.product(
                ["f"],  # logits soft-cap NOT supported, always false
                get_mask_map(mask_impl).keys(),
            ):
                if hdim == 256 and hdim_v == 256:
                    # jenga fmha only supports dim <= 192 for now.
                    continue
                # Positional order: tag, vlayout, spad, skpad, dpad, dvpad, logits, mask, trload.
                # First variant: no k-seqlen padding; second: k-seqlen padded fallback.
                pipelines.append(
                    FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, mask, "f")  # fmt: skip
                )
                pipelines.append(
                    FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, mask, "f")  # fmt: skip
                )
        else:
            assert False  # unsupported dtype should have been filtered earlier
        return pipelines
class CustomFactory(KernelComponentFactory):
    """Factory variant selected via CK_TILE_FMHA_FWD_CUSTOM_FACTORY=1; adds a guarded 64-wide tile."""

    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
        if dtype == "fp16" or dtype == "bf16":
            if (128, 128) in result.keys():
                # Prepend a bm0=64 tile whose C++ constraint dispatches it only when
                # the 128-wide launch would not occupy enough CUs (see generated API).
                result[(128, 128)].insert(
                    0,
                    FmhaFwdTileSize(
                        64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1,
                        16, 16, 16, 16, 16, 16, -1,
                        CppConstraint(
                            "get_num_blocks(128) < num_cus * min_cu_util_rate"
                        ),
                    ),
                )
        return result
def get_fwd_blobs(
    kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
    """Enumerate all kernel instances to generate plus the matching dispatch pool.

    Combinations are pruned by tile shape, pipeline tag, an optional fnmatch
    kernel_filter ("" = keep all), optdim_list (hdim whitelist, [-1] = all),
    and integration-specific `receipt` rules.
    """
    gen = list()
    api_pool = FmhaFwdApiPool(mask_impl)
    # env var opt-in for the extra small-CU-count tile
    factory = (
        CustomFactory
        if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1"
        else KernelComponentFactory
    )
    # Only generate fp16/bf16 kernels for now.
    # NOTE: Jenga sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert)
    for dtype in ["fp16", "bf16"]:
        d = factory.get_hdim_tile_size_dict(dtype)
        if d is None:
            continue
        for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]):
            for tile, pipeline in itertools.product(
                tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
            ):
                # only the largest 128x128 tiles with the qr_async pipeline are emitted
                if tile.F_bm0 != 128 or tile.F_bn0 != 128:
                    continue
                if pipeline.tag != "qr_async":
                    continue
                k = FmhaFwdKernel(
                    F_idx=2,
                    F_hdim=hdim,
                    F_dtype=dtype,
                    F_mode=mode,
                    F_tile=tile,
                    F_pipeline=pipeline,
                    mask_impl=mask_impl,
                )
                if kernel_filter != "":
                    if not fnmatch.fnmatch(k.name, kernel_filter):
                        continue
                if optdim_list != [-1]:
                    if hdim not in optdim_list:
                        continue
                # 2 - Flash attention integration
                if receipt in (2, 3):
                    cond = dtype in ["fp16", "bf16"]
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue
                # PyTorch integration
                elif receipt == 4:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= pipeline.F_vlayout == "row"
                    cond &= mode == "batch"
                    cond &= pipeline.F_logits == "f"
                    if not cond:
                        continue
                # Aiter(mha_fwd) integration
                elif receipt == 100:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= mode == "batch"
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue
                # Aiter(mha_varlen_fwd) integration
                elif receipt == 200:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= mode == "group"
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue
                # aiter::mha_fwd C++ api integration
                elif receipt == 600:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue
                api_pool.register_traits(k.api_trait())
                gen.append(k)
    return (api_pool, gen)
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
    """Emit one kernel's generated source file into autogen_dir (only if changed)."""
    target = autogen_dir / kernel.filename
    update_file(target, kernel.template)
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
    """Emit the generated dispatch-API source file into autogen_dir (only if changed)."""
    destination = autogen_dir / FMHA_FWD_API_FILENAME
    update_file(destination, api_pool.api)
def write_blobs(
    output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
    """Generate every kernel source file plus the dispatch API file into output_dir."""
    api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
    for one_kernel in kernels:
        write_single_fwd_kernel(one_kernel, output_dir)
    write_fwd_api(api_pool, output_dir)
def list_blobs(
    file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
    """Append the paths of every file that would be generated to file_path."""
    _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
    base = file_path.parent / GEN_DIR
    with file_path.open("a") as out:
        for one_kernel in kernels:
            out.write(f"{base / one_kernel.filename}\n")
        out.write(f"{base / FMHA_FWD_API_FILENAME}\n")

View File

@@ -0,0 +1,867 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass, field
import fnmatch
import itertools
import os
import os.path as path
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.cpp_symbol_map import (
BOOL_MAP,
FWD_DTYPE_MAP,
LAYOUT_MAP,
MODE_MAP,
PIPELINE_ENUM_MAP,
PIPELINE_MAP,
get_mask_check_map,
get_mask_map,
)
GEN_DIR = ""
def update_file(file_path, content):
    """Write `content` to `file_path` only when it differs from what is on disk.

    Skipping identical writes leaves the file untouched, which avoids
    triggering unnecessary rebuilds.
    """
    current = ""
    if path.exists(file_path):
        with open(file_path, "r") as fh:
            current = fh.read()
    if current != content:
        with open(file_path, "w") as fh:
            fh.write(content)
# Bit width of each supported element type; used to derive vector-load widths.
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16}
# Presumably maps bk0max to a per-iteration k0 sub-tile maximum -- unused in
# this file; TODO confirm against the pipeline headers.
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}
# Common preamble (license + includes) prepended to every generated file.
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "fmha_fwd_trek.hpp"
#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp"
#include "kernel/fmha_fwd_vsa_kernel.hpp"
"""
# NOTE: VSA sparse attention kernel has the following restrictions enforced by static_assert:
# - Group mode: NOT supported (batch mode only)
# - Bias: NOT supported (NO_BIAS only)
# - LSE output: NOT supported (false only)
# - Dropout: NOT supported (false only)
# - Logits soft-cap: NOT supported (false only)
# - FP8 static quantization: NOT supported (NO_SCALE only)
# The template below hardcodes these unsupported features accordingly.
# str.format template for one generated kernel translation unit.  All features
# the VSA kernel rejects via static_assert (bias, LSE, dropout, logits
# soft-cap, FP8 quant) are hardcoded off; doubled braces are literal braces.
FMHA_FWD_KERNEL_BODY = """
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum,
// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
false, // has_logits_soft_cap - NOT supported
ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported
false, // store_lse - NOT supported
false, // has_dropout - NOT supported
false, // has_randval - NOT supported
ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported
{F_occupancy},
false>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported)
using fmha_mask_{F_idx} = {F_mask};
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaSparseFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
fmha_shape_{F_idx},
{F_mode},
fmha_variant_{F_idx},
fmha_mask_{F_idx},
{F_trload},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA<
fmha_pipeline_problem_{F_idx}>;
using fmha_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaSparseFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>;
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdVSAKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
#include <iostream>
template<>
float fmha_vsa_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_vsa_fwd_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << "{F_kernel_name}" << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
# File name of the generated dispatch-API translation unit.
FMHA_FWD_API_FILENAME = "fmha_vsa_fwd_api.cpp"
# str.format template for the top-level C++ dispatch function.  It queries the
# device CU count so tile-selection constraints can react to expected occupancy.
FMHA_FWD_API = """
#include <cstdio>
#include <hip/hip_runtime.h>
namespace {{
bool get_num_cus(unsigned& num_cus) {{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device");
return false;
}}
hipDeviceProp_t props{{}};
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device properties");
return false;
}}
num_cus = props.multiProcessorCount;
return true;
}}
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
return batch * nheads * num_m_blocks * num_n_blocks;
}}
}} // namespace
float fmha_vsa_fwd(fmha_vsa_fwd_traits t, fmha_vsa_fwd_args a, const ck_tile::stream_config& s){{
float r = -1;
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
unsigned num_cus;
if (!get_num_cus(num_cus)) {{
return r;
}}
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
}};
const bool has_load_tr = ck_tile::is_load_tr_supported();
{F_dispatch}
return r;
}}
"""
# Dispatch cascade templates: tr_load bucket -> dtype -> (hdim_q, hdim_v) -> kernel.
FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{
{F_dtype_case}
}}
"""
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
# Innermost guard + concrete kernel instantiation.
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>;
return fmha_vsa_fwd_<trait_>(s, a);
}}
"""
@dataclass
class CppConstraint:
    """A C++ boolean guard expression; ``None`` renders as the literal ``true``."""

    bool_expr: str = None  # raw C++ expression text, or None for "no constraint"

    def __str__(self):
        # Render as C++ source; an absent expression is always satisfied.
        return "true" if self.bool_expr is None else str(self.bool_expr)

    def __and__(self, other):
        # Conjunction of two guards; parenthesize both sides to keep precedence.
        return CppConstraint(f"({self}) && ({other})")
@dataclass
class FmhaFwdApiTrait:
    """Per-kernel record used to emit one dispatch branch of the generated API.

    Field order/meaning is kept in sync with the C++ fmha_fwd_traits<> template
    so that fallback dispatch calls can be generated from it.
    """

    pipeline_tag: str
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim: str
    dtype: str  # data type
    mode: str  # value from MODE_MAP
    bm0: int  # tile size along q seqlen (block size)
    bn0: int  # tile size along qk seqlen
    bk0: int  # tile size along qk gemm unroll
    bn1: int  # tile size along v head_dim
    bk1: int  # tile size along kv gemm unroll
    bk0max: int  # total K0 length (see FmhaFwdTileSize.F_bk0max)
    vlayout: str  # row/col
    logits: str  # t/f
    mask: str  # value from MASK_MAP
    spad: str  # t/f
    skpad: str  # t/f
    dpad: str  # t/f
    dvpad: str  # t/f
    tr_load: str  # t/f
    constraint: CppConstraint

    @property
    def name(self) -> str:
        """Unique-ish identifier of this trait combination.

        Fix: the 7th tile component was a duplicated bn0; it is bn1, matching
        the bm0/bn0/bk0/bn1/bk1/bk0max field order used everywhere else.
        """
        return (
            f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-"
            + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
        )

    @property
    def scheck(self) -> str:
        """C++ guard for seqlen_q padding; any seqlen_q is always supported."""
        if self.mode == "group":
            return "true/*group mode spad always true*/"  # group mode only generate spad/skpad == true
        # batch mode: supported regardless of spad (both branches used to return "true")
        return "true"

    @property
    def seqtune(self) -> str:
        """C++ guard steering short sequences to smaller tiles; bm0==128 is the catch-all."""
        if self.bm0 == 128:
            return "true/*fall back to largest tile*/"
        else:
            return f"a.seqlen_q <= {self.bm0}"

    @property
    def skcheck(self) -> str:
        """C++ guard matching seqlen_k against the kernel's k-seqlen padding support."""
        if self.mode == "group":
            return "true/*group mode skpad always true*/"  # group mode only generate spad/skpad == true
        if self.skpad == "t":
            return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
        return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"

    @property
    def dcheck(self) -> str:
        """C++ guard on hdim_q vector-load alignment; generator only emits dpad == 't'."""
        vec = int((32 * 4) / DTYPE_BITS[self.dtype])
        if self.dpad == "t":
            return f"a.hdim_q % {vec} == 0"
        assert False  # dpad == 'f' kernels are never generated

    @property
    def dvcheck(self) -> str:
        """C++ guard on hdim_v vector-load alignment; generator only emits dvpad == 't'."""
        vec = int((32 * 4) / DTYPE_BITS[self.dtype])
        if self.dvpad == "t":
            return f"a.hdim_v % {vec} == 0"
        assert False  # dvpad == 'f' kernels are never generated
@dataclass
class FmhaFwdPipeline:
    """Pipeline options (layout/padding/mask/tr-load) for one kernel variant."""

    tag: str  # value from PIPELINE_MAP / PIPELINE_ENUM_MAP
    F_vlayout: str  # row/col
    F_spad: str  # t/f: pad q seqlen
    F_skpad: str  # t/f: pad k seqlen
    F_dpad: str  # t/f: pad hdim_q
    F_dvpad: str  # t/f: pad hdim_v
    F_logits: str  # t/f
    F_mask: str  # value from MASK_MAP
    F_trload: str  # t/f
    F_constraint: CppConstraint = field(default_factory=CppConstraint)  # extra dispatch guard

    @property
    def name(self) -> str:
        """Encode the options into the kernel-name suffix (append order is significant)."""

        def pad_name() -> str:
            # Collect enabled padding flags, e.g. spad+dpad+dvpad -> "psddv";
            # empty string when nothing is padded.
            n = ""
            if self.F_spad == "t":
                n += "s"
            if self.F_skpad == "t":
                n += "sk"
            if self.F_dpad == "t":
                n += "d"
            if self.F_dvpad == "t":
                n += "dv"
            if n != "":
                n = "p" + n
            return n

        pn = pad_name()
        n = f"{self.tag}_v{self.F_vlayout[0]}"
        if pn != "":
            n += f"_{pn}"
        else:
            n += "_npad"
        if self.F_logits == "t":
            n += "_logits"
        else:
            n += "_nlogits"
        n += "_nbias"  # bias is never generated here
        if self.F_mask[0:2] == "s_":
            # masks from the "s_*" naming scheme
            if self.F_mask == "s_mask":
                n += "_mask"
            else:
                n += "_nmask"
        else:
            # other schemes: encode first letter of the mask kind, "no" -> no mask
            if self.F_mask != "no":
                n += f"_m{self.F_mask[0]}"
            else:
                n += "_nmask"
        n += "_nskip"
        n += "_nsquant"
        if self.F_trload == "t":
            n += "_trload"
        else:
            n += "_ntrload"
        return n
class FmhaFwdApiPool:
    """Collects per-kernel traits and renders the C++ runtime dispatch API."""

    def __init__(self, mask_impl):
        # pool layout: dtype -> (hdim, hdim_v) -> [FmhaFwdApiTrait]
        self.pool = dict()
        self.mask_impl = mask_impl

    def register_traits(self, trait: FmhaFwdApiTrait) -> None:
        # TODO: do we need to check duplication?
        if trait.dtype not in self.pool.keys():
            self.pool[trait.dtype] = dict()
        # key is the (hdim_q, hdim_v) pair; bn1 is the v-head-dim tile used as hdim_v
        hdim = trait.hdim, trait.bn1
        if hdim not in self.pool[trait.dtype].keys():
            self.pool[trait.dtype][hdim] = list()
        self.pool[trait.dtype][hdim].append(copy.copy(trait))

    @property
    def api(self) -> str:
        """Render the complete dispatch-API translation unit from the pool."""
        # Branch condition per tr_load bucket: hardware-dependent for 't',
        # unconditional for 'f' (the fallback bucket).
        tr_load_cond_map = {"t": "has_load_tr", "f": "true"}
        per_tr_load = str()
        for tr_load in ["t", "f"]:
            per_dtypes = str()
            for i, dtype in enumerate(self.pool.keys()):
                per_hdim_case = str()
                for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
                    # only the traits belonging to the current tr_load bucket
                    traits = [
                        t
                        for t in self.pool[dtype][(hdim, hdim_v)]
                        if tr_load == t.tr_load
                    ]
                    inners = str()
                    for k, trait in enumerate(traits):
                        if_k = "if" if k == 0 else "else if"
                        inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
                            F_if=if_k,
                            F_vlayout=LAYOUT_MAP[trait.vlayout],
                            F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
                            # F_logits removed - hardcoded to false (NOT supported)
                            F_mask=get_mask_map(self.mask_impl)[trait.mask],
                            F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
                            F_trload=BOOL_MAP[trait.tr_load],
                            F_scheck=trait.scheck,
                            F_seqtune=trait.seqtune,
                            F_skcheck=trait.skcheck,
                            F_dcheck=trait.dcheck,
                            F_dvcheck=trait.dvcheck,
                            F_constraint=trait.constraint,
                            F_spad=BOOL_MAP[trait.spad],
                            F_skpad=BOOL_MAP[trait.skpad],
                            F_dpad=BOOL_MAP[trait.dpad],
                            F_dvpad=BOOL_MAP[trait.dvpad],
                            F_bm0=trait.bm0,
                            F_bn0=trait.bn0,
                            F_bk0=trait.bk0,
                            F_bn1=trait.bn1,
                            F_bk1=trait.bk1,
                            F_bk0max=trait.bk0max,
                            F_hdim=hdim,
                            F_dtype=FWD_DTYPE_MAP[dtype],
                        )
                    if_j = "if" if j == 0 else "else if"
                    per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
                        F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners
                    )
                if_i = "if" if i == 0 else "else if"
                per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
                    F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
                )
            # always a plain "if": matched branches return, so control reaches the
            # next bucket only when nothing in the previous one dispatched
            per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(
                F_if="if",
                F_trload_cond=tr_load_cond_map[tr_load],
                F_dtype_case=per_dtypes,
            )
        if not per_tr_load:
            # empty string we add some ignore to suppress warning in api
            per_tr_load += " (void)t ; (void)s ; (void)a;"
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load)
@dataclass
class FmhaFwdTileSize:
    """Tile/warp configuration of one kernel instance."""

    F_bm0: int  # tile size along q seqlen (block size)
    F_bn0: int  # tile size along k seqlen
    F_bk0: int  # tile size along qk gemm unroll
    F_bn1: int  # tile size along v head_dim
    F_bk1: int  # tile size along kv gemm unroll
    F_bk0max: int  # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
    F_rm0: int  # number of warps for gemm0 along q seqlen
    F_rn0: int  # number of warps for gemm0 along k seqlen
    F_rk0: int  # number of warps for gemm0 along head dim q (not used)
    F_rm1: int  # number of warps for gemm1 along q seqlen
    F_rn1: int  # number of warps for gemm1 along head dim v
    F_rk1: int  # number of warps for gemm1 along k seqlen (not used)
    F_wm0: int  # gemm0 warp size along m
    F_wn0: int  # gemm0 warp size along n
    F_wk0: int  # gemm0 warp size along k
    F_wm1: int  # gemm1 warp size along m
    F_wn1: int  # gemm1 warp size along n
    F_wk1: int  # gemm1 warp size along k
    F_occupancy: int  # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
    F_constraint: CppConstraint = field(default_factory=CppConstraint)  # extra runtime dispatch guard

    @property
    def name(self) -> str:
        # Kernel-name fragment: b<tiles>_r<gemm0 warps>_r<gemm1 warps>_w<warp shapes>,
        # plus an _o<occupancy> suffix only when occupancy is overridden.
        return (
            f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
            + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
            + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
            + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
        )
@dataclass
class FmhaFwdKernel:
    """One concrete kernel instance; renders its .cpp source and its API trait."""

    F_idx: int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim: int  # hdim
    F_dtype: str  # data type
    F_mode: str  # value from MODE_MAP
    F_tile: FmhaFwdTileSize
    F_pipeline: FmhaFwdPipeline
    mask_impl: str

    @property
    def template(self) -> str:
        """Full source text of this kernel's generated translation unit."""
        # kernel_body removed - unused
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
            F_idx=self.F_idx,
            F_hdim=self.F_hdim,
            F_dtype=FWD_DTYPE_MAP[self.F_dtype],
            F_bm0=self.F_tile.F_bm0,
            F_bn0=self.F_tile.F_bn0,
            F_bk0=self.F_tile.F_bk0,
            F_bn1=self.F_tile.F_bn1,
            F_bk1=self.F_tile.F_bk1,
            F_bk0max=self.F_tile.F_bk0max,
            F_rm0=self.F_tile.F_rm0,
            F_rn0=self.F_tile.F_rn0,
            F_rk0=self.F_tile.F_rk0,
            F_rm1=self.F_tile.F_rm1,
            F_rn1=self.F_tile.F_rn1,
            F_rk1=self.F_tile.F_rk1,
            F_wm0=self.F_tile.F_wm0,
            F_wn0=self.F_tile.F_wn0,
            F_wk0=self.F_tile.F_wk0,
            F_wm1=self.F_tile.F_wm1,
            F_wn1=self.F_tile.F_wn1,
            F_wk1=self.F_tile.F_wk1,
            F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
            F_spad=BOOL_MAP[self.F_pipeline.F_spad],
            F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
            F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
            F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
            # F_logits removed - hardcoded to false in template (NOT supported)
            F_occupancy=self.F_tile.F_occupancy,
            F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
            F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
            F_mode=MODE_MAP[self.F_mode],
            F_pipeline=PIPELINE_MAP[self.F_pipeline.tag],
            F_trload=BOOL_MAP[self.F_pipeline.F_trload],
            F_kernel_name=self.name,
        )

    @property
    def name(self) -> str:
        """Kernel symbol/file stem built from dtype, mode, tile and pipeline names."""
        # TODO: we don't encode idx here
        return (
            f"fmha_vsa_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
            + self.F_tile.name
            + "_"
            + self.F_pipeline.name
        )

    @property
    def filename(self) -> str:
        # one generated translation unit per kernel instance
        return self.name + ".cpp"

    def api_trait(self) -> FmhaFwdApiTrait:
        """Build the dispatch-table record mirroring this kernel's template args."""
        return FmhaFwdApiTrait(
            pipeline_tag=self.F_pipeline.tag,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            mode=self.F_mode,
            bm0=self.F_tile.F_bm0,
            bn0=self.F_tile.F_bn0,
            bk0=self.F_tile.F_bk0,
            bn1=self.F_tile.F_bn1,
            bk1=self.F_tile.F_bk1,
            bk0max=self.F_tile.F_bk0max,
            vlayout=self.F_pipeline.F_vlayout,
            mask=self.F_pipeline.F_mask,
            logits=self.F_pipeline.F_logits,
            spad=self.F_pipeline.F_spad,
            skpad=self.F_pipeline.F_skpad,
            dpad=self.F_pipeline.F_dpad,
            dvpad=self.F_pipeline.F_dvpad,
            tr_load=self.F_pipeline.F_trload,
            # combined C++ guard: both the tile's and the pipeline's constraints must hold
            constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
        )
class KernelComponentFactory:
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        """Supported tile configurations keyed by (hdim_q, hdim_v); None if dtype unsupported."""
        if dtype == "fp16" or dtype == "bf16":
            return {
                # (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
                # (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
                # FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
                # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
                # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
                # Positional order: bm0, bn0, bk0, bn1, bk1, bk0max,
                # rm0, rn0, rk0, rm1, rn1, rk1, wm0, wn0, wk0, wm1, wn1, wk1, occupancy.
                (128, 128): [
                    FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),  # fmt: skip
                    FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),  # fmt: skip
                    FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),  # fmt: skip
                    FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),  # fmt: skip
                ],
                # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
                # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
                # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
                # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
            }
        else:
            return None
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
@staticmethod
def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# NOTE: logits soft-cap is NOT supported by VSA sparse attention (enforced by static_assert)
pipelines = []
if dtype in ["fp16", "bf16"]:
for logits, mask in itertools.product(
["f"], # logits soft-cap NOT supported, always false
get_mask_map(mask_impl).keys(),
):
if hdim == 256 and hdim_v == 256:
# vsa fmha only supports dim <= 192 for now.
continue
pipelines.append(
FmhaFwdPipeline(
"qr_async_vsa",
"row",
"t",
"f",
"t",
"t",
logits,
mask,
"f",
)
)
pipelines.append(
FmhaFwdPipeline(
"qr_async_vsa",
"row",
"t",
"t",
"t",
"t",
logits,
mask,
"f",
)
)
else:
assert False
return pipelines
class CustomFactory(KernelComponentFactory):
    """Factory that prepends an extra occupancy-constrained tile to the defaults."""

    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        # Start from the baseline configuration, then prefer a 64x128 tile for
        # (128, 128) when its C++ dispatch constraint holds at runtime.
        result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
        if dtype in ("fp16", "bf16") and (128, 128) in result:
            extra_tile = FmhaFwdTileSize(
                # fmt: off
                64, 128, 64, 128, 64, 128,
                4, 1, 1, 4, 1, 1,
                16, 16, 16, 16, 16, 16,
                -1,
                # fmt: on
                CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"),
            )
            result[(128, 128)].insert(0, extra_tile)
        return result
def get_fwd_blobs(
    kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
    """Enumerate every VSA forward kernel to generate.

    Returns (api_pool, kernels): the dispatch-API pool with each kernel's
    trait registered, plus the list of kernel objects to render to disk.
    """
    gen = list()
    api_pool = FmhaFwdApiPool(mask_impl)
    # Opt-in factory override via environment variable.
    factory = (
        CustomFactory
        if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1"
        else KernelComponentFactory
    )
    # Only generate fp16/bf16 kernels for now.
    # NOTE: VSA sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert)
    for dtype in ["fp16", "bf16"]:
        d = factory.get_hdim_tile_size_dict(dtype)
        if d is None:
            continue
        for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]):
            for tile, pipeline in itertools.product(
                tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
            ):
                # VSA currently ships only the 128x128 qr_async_vsa instances.
                if tile.F_bm0 != 128 or tile.F_bn0 != 128:
                    continue
                if pipeline.tag != "qr_async_vsa":
                    continue
                k = FmhaFwdKernel(
                    F_idx=1,
                    F_hdim=hdim,
                    F_dtype=dtype,
                    F_mode=mode,
                    F_tile=tile,
                    F_pipeline=pipeline,
                    mask_impl=mask_impl,
                )
                # Fix: treat None the same as "" (no filter). The previous
                # `kernel_filter != ""` check let a None filter reach
                # fnmatch() and raise TypeError, despite the Optional[str]
                # signature.
                if kernel_filter:
                    if not fnmatch.fnmatch(k.name, kernel_filter):
                        continue
                # [-1] means "no head-dim restriction".
                if optdim_list != [-1]:
                    if hdim not in optdim_list:
                        continue
                # 2 - Flash attention integration
                if receipt in (2, 3):
                    cond = dtype in ["fp16", "bf16"]
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue
                # PyTorch integration
                elif receipt == 4:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= pipeline.F_vlayout == "row"
                    cond &= mode == "batch"
                    cond &= pipeline.F_logits == "f"
                    if not cond:
                        continue
                # Aiter(mha_fwd) integration
                elif receipt == 100:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= mode == "batch"
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue
                # Aiter(mha_varlen_fwd) integration
                elif receipt == 200:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= mode == "group"
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue
                # aiter::mha_fwd C++ api integration
                elif receipt == 600:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= pipeline.F_vlayout == "row"
                    if not cond:
                        continue
                api_pool.register_traits(k.api_trait())
                gen.append(k)
    return (api_pool, gen)
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
    # Render one kernel instance into its own .cpp file under autogen_dir
    # (update_file presumably skips rewriting unchanged content — confirm).
    update_file(autogen_dir / kernel.filename, kernel.template)
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
    # Emit the single dispatch-API translation unit covering all kernels.
    update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
def write_blobs(
    output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
    # Generate every kernel source file plus the dispatch API file.
    api_pool, kernel_list = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
    for kern in kernel_list:
        write_single_fwd_kernel(kern, output_dir)
    write_fwd_api(api_pool, output_dir)
def list_blobs(
    file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
    # Append the full path of every file that generation would produce,
    # one per line, ending with the dispatch API file.
    _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
    gen_root = file_path.parent / GEN_DIR
    entries = [str(gen_root / kern.filename) for kern in kernels]
    entries.append(str(gen_root / FMHA_FWD_API_FILENAME))
    with file_path.open("a") as f:
        f.write("\n".join(entries) + "\n")

View File

@@ -0,0 +1,328 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "01_fmha/mask.hpp"
#include <type_traits>
#include <utility>
#include <variant>
namespace ck_tile {
// Transposed-load fast path is gated on gfx95-class hardware support.
inline bool is_load_tr_supported() { return is_gfx95_supported(); }
} // namespace ck_tile
// Tag type selecting the fp16 type configuration below.
struct FmhaSparseFwdFp16
{
};
// Tag type selecting the bf16 type configuration below.
struct FmhaSparseFwdBf16
{
};
// Primary template; only the two tag-type specializations below exist.
template <typename DataType>
struct FmhaSparseFwdTypeConfig;
// fp16 kernel type configuration: half-precision storage with fp32 accumulation.
template <>
struct FmhaSparseFwdTypeConfig<FmhaSparseFwdFp16>
{
    using QDataType           = ck_tile::half_t;
    using KDataType           = ck_tile::half_t;
    using VDataType           = ck_tile::half_t;
    using SaccDataType        = float; // data type for first gemm accumulation
    using SMPLComputeDataType = float; // data type for reduction, softmax
    using PDataType           = ck_tile::half_t; // data type for A matrix of second gemm
    using OaccDataType        = float; // data type for second gemm accumulation
    using ODataType           = ck_tile::half_t;
    // Note: The following types are required by BlockFmhaPipelineProblem but not used
    // by sparse attention (bias, dropout, LSE are not supported).
    using BiasDataType          = ck_tile::half_t;
    using RandValOutputDataType = uint8_t;
    using LSEDataType           = float;
};
// bf16 kernel type configuration: bfloat16 storage with fp32 accumulation.
template <>
struct FmhaSparseFwdTypeConfig<FmhaSparseFwdBf16>
{
    using QDataType           = ck_tile::bf16_t;
    using KDataType           = ck_tile::bf16_t;
    using VDataType           = ck_tile::bf16_t;
    using SaccDataType        = float; // data type for first gemm accumulation
    using SMPLComputeDataType = float; // data type for reduction, softmax
    using PDataType           = ck_tile::bf16_t; // data type for A matrix of second gemm
    using OaccDataType        = float; // data type for second gemm accumulation
    using ODataType           = ck_tile::bf16_t;
    // Note: The following types are required by BlockFmhaPipelineProblem but not used
    // by sparse attention (bias, dropout, LSE are not supported).
    using BiasDataType          = ck_tile::bf16_t;
    using RandValOutputDataType = uint8_t;
    using LSEDataType           = float;
};
// Attention-mask type aliases used when instantiating pipelines.
struct FmhaMasks
{
    using NoMask      = ck_tile::GenericAttentionMask<false>;
    using GenericMask = ck_tile::GenericAttentionMask<true, true>;
    using CausalMask  = ck_tile::GenericAttentionMask<true, false>;
};
// jenga
// Host-side argument bundle for the Jenga sparse-attention forward launch.
// Pointers reference device memory; strides are in elements.
struct fmha_jenga_fwd_args
{
    const void* q_ptr;
    const void* k_ptr;
    const void* v_ptr;
    const void* block_relation_onehot_ptr; // one-hot block map [B,H,Q_blk,K_blk], 1=active
    void* o_ptr;
    ck_tile::index_t seqlen_q;
    ck_tile::index_t seqlen_k;
    ck_tile::index_t batch;
    ck_tile::index_t max_seqlen_q; // used to size the launch grid
    ck_tile::index_t hdim_q;
    ck_tile::index_t hdim_v;
    ck_tile::index_t nhead_q;
    ck_tile::index_t nhead_k; // GQA/MQA: nhead_q must be a multiple of nhead_k
    float scale_s;            // softmax scale applied to QK^T
    ck_tile::index_t stride_q;
    ck_tile::index_t stride_k;
    ck_tile::index_t stride_v;
    ck_tile::index_t stride_o;
    ck_tile::index_t nhead_stride_q;
    ck_tile::index_t nhead_stride_k;
    ck_tile::index_t nhead_stride_v;
    ck_tile::index_t nhead_stride_o;
    ck_tile::index_t batch_stride_q;
    ck_tile::index_t batch_stride_k;
    ck_tile::index_t batch_stride_v;
    ck_tile::index_t batch_stride_o;
    ck_tile::index_t window_size_left;
    ck_tile::index_t window_size_right;
    ck_tile::index_t mask_type;
    // Dropout is not supported for sparse attention; keep args minimal.
};
// vsa
// Host-side argument bundle for the VSA sparse-attention forward launch.
// Identical to the Jenga args except the sparsity description: VSA uses a
// delta-encoded LUT plus a per-Q-block valid count instead of a one-hot map.
struct fmha_vsa_fwd_args
{
    const void* q_ptr;
    const void* k_ptr;
    const void* v_ptr;
    const void* lut_ptr; // delta-encoded K-block indices per Q-block, int32 [B,H,Q_blk,K_blk]
    const void* valid_block_num_ptr; // valid K-block count per Q-block, int32 [B,H,Q_blk]
    void* o_ptr;
    ck_tile::index_t seqlen_q;
    ck_tile::index_t seqlen_k;
    ck_tile::index_t batch;
    ck_tile::index_t max_seqlen_q; // used to size the launch grid
    ck_tile::index_t hdim_q;
    ck_tile::index_t hdim_v;
    ck_tile::index_t nhead_q;
    ck_tile::index_t nhead_k; // GQA/MQA: nhead_q must be a multiple of nhead_k
    float scale_s;            // softmax scale applied to QK^T
    ck_tile::index_t stride_q;
    ck_tile::index_t stride_k;
    ck_tile::index_t stride_v;
    ck_tile::index_t stride_o;
    ck_tile::index_t nhead_stride_q;
    ck_tile::index_t nhead_stride_k;
    ck_tile::index_t nhead_stride_v;
    ck_tile::index_t nhead_stride_o;
    ck_tile::index_t batch_stride_q;
    ck_tile::index_t batch_stride_k;
    ck_tile::index_t batch_stride_v;
    ck_tile::index_t batch_stride_o;
    ck_tile::index_t window_size_left;
    ck_tile::index_t window_size_right;
    ck_tile::index_t mask_type;
    // Dropout is not supported for sparse attention; keep args minimal.
};
// Translate host-side Jenga args into the kernel's kargs and launch grid.
template <typename FmhaKernel>
auto fmha_fwd_create_kargs_and_grids(fmha_jenga_fwd_args args)
{
    // GQA/MQA requires nhead_q to be a multiple of nhead_k.
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = FmhaKernel::MakeKargs(args.q_ptr,
                                       args.k_ptr,
                                       args.v_ptr,
                                       args.block_relation_onehot_ptr,
                                       args.o_ptr,
                                       args.seqlen_q,
                                       args.seqlen_k,
                                       args.hdim_q,
                                       args.hdim_v,
                                       args.nhead_q,
                                       args.nhead_q / args.nhead_k, // Q-to-KV head ratio
                                       args.scale_s,
                                       args.stride_q,
                                       args.stride_k,
                                       args.stride_v,
                                       args.stride_o,
                                       args.nhead_stride_q,
                                       args.nhead_stride_k,
                                       args.nhead_stride_v,
                                       args.nhead_stride_o,
                                       args.batch_stride_q,
                                       args.batch_stride_k,
                                       args.batch_stride_v,
                                       args.batch_stride_o,
                                       args.window_size_left,
                                       args.window_size_right,
                                       args.mask_type);
    // Grid is sized from max_seqlen_q so variable-length batches share one launch.
    dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
    return ck_tile::make_tuple(kargs, grids);
}
// Translate host-side VSA args into the kernel's kargs and launch grid.
template <typename FmhaKernel>
auto fmha_fwd_create_kargs_and_grids(fmha_vsa_fwd_args args)
{
    // GQA/MQA requires nhead_q to be a multiple of nhead_k.
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = FmhaKernel::MakeKargs(args.q_ptr,
                                       args.k_ptr,
                                       args.v_ptr,
                                       args.lut_ptr,
                                       args.valid_block_num_ptr,
                                       args.o_ptr,
                                       args.seqlen_q,
                                       args.seqlen_k,
                                       args.hdim_q,
                                       args.hdim_v,
                                       args.nhead_q,
                                       args.nhead_q / args.nhead_k, // Q-to-KV head ratio
                                       args.scale_s,
                                       args.stride_q,
                                       args.stride_k,
                                       args.stride_v,
                                       args.stride_o,
                                       args.nhead_stride_q,
                                       args.nhead_stride_k,
                                       args.nhead_stride_v,
                                       args.nhead_stride_o,
                                       args.batch_stride_q,
                                       args.batch_stride_k,
                                       args.batch_stride_v,
                                       args.batch_stride_o,
                                       args.window_size_left,
                                       args.window_size_right,
                                       args.mask_type);
    // Grid is sized from max_seqlen_q so variable-length batches share one launch.
    dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
    return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match internal kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
          typename DataType_,
          ck_tile::index_t kM0_,
          ck_tile::index_t kN0_,
          ck_tile::index_t kK0_,
          ck_tile::index_t kN1_,
          ck_tile::index_t kK1_,
          ck_tile::index_t kK0BlockLength_,
          bool kIsVLayoutRowMajor_,
          ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
          bool kHasLogitsSoftCap_,
          typename FmhaMask_,
          bool kPadS_,
          bool kPadSK_,
          bool kPadD_,
          bool kPadDv_,
          bool kUseTrLoad_>
struct fmha_jenga_fwd_traits_
{
    static constexpr ck_tile::index_t HDim = HDim_;
    using DataType                         = ck_tile::remove_cvref_t<DataType_>;
    // Tile shape selectors (M0/N0/K0 for gemm0, N1/K1 for gemm1).
    static constexpr ck_tile::index_t kM0            = kM0_;
    static constexpr ck_tile::index_t kN0            = kN0_;
    static constexpr ck_tile::index_t kK0            = kK0_;
    static constexpr ck_tile::index_t kN1            = kN1_;
    static constexpr ck_tile::index_t kK1            = kK1_;
    static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
    static constexpr bool kIsVLayoutRowMajor         = kIsVLayoutRowMajor_;
    static constexpr auto FmhaPipelineEnum           = FmhaPipelineEnum_;
    static constexpr bool kHasLogitsSoftCap          = kHasLogitsSoftCap_;
    using FmhaMask                                   = ck_tile::remove_cvref_t<FmhaMask_>;
    // Per-dimension padding flags (seqlen_q / seqlen_k / hdim_q / hdim_v).
    static constexpr bool kPadS     = kPadS_;
    static constexpr bool kPadSK    = kPadSK_;
    static constexpr bool kPadD     = kPadD_;
    static constexpr bool kPadDv    = kPadDv_;
    static constexpr bool kUseTrLoad = kUseTrLoad_;
};
// Runtime trait selection for the Jenga forward dispatcher.
struct fmha_jenga_fwd_traits
{
    int hdim_q;            // head dim of Q/K
    int hdim_v;            // head dim of V/O
    std::string data_type; // "fp16" or "bf16"
    bool is_v_rowmajor;    // V layout is (seqlen, hdim) row-major when true
    mask_enum mask_type;
    // TODO: padding check is inside this api
};
// Dispatch entry points; definitions come from the generated kernel instances.
float fmha_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&);
template <typename Traits_>
float fmha_jenga_fwd_(const ck_tile::stream_config&, fmha_jenga_fwd_args);
float fmha_jenga_fwd(fmha_jenga_fwd_args, const ck_tile::stream_config&);
// VSA uses the same traits structure as Jenga; aliases for clarity
template <ck_tile::index_t HDim_,
          typename DataType_,
          ck_tile::index_t kM0_,
          ck_tile::index_t kN0_,
          ck_tile::index_t kK0_,
          ck_tile::index_t kN1_,
          ck_tile::index_t kK1_,
          ck_tile::index_t kK0BlockLength_,
          bool kIsVLayoutRowMajor_,
          ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
          bool kHasLogitsSoftCap_,
          typename FmhaMask_,
          bool kPadS_,
          bool kPadSK_,
          bool kPadD_,
          bool kPadDv_,
          bool kUseTrLoad_>
using fmha_vsa_fwd_traits_ = fmha_jenga_fwd_traits_<HDim_,
                                                    DataType_,
                                                    kM0_,
                                                    kN0_,
                                                    kK0_,
                                                    kN1_,
                                                    kK1_,
                                                    kK0BlockLength_,
                                                    kIsVLayoutRowMajor_,
                                                    FmhaPipelineEnum_,
                                                    kHasLogitsSoftCap_,
                                                    FmhaMask_,
                                                    kPadS_,
                                                    kPadSK_,
                                                    kPadD_,
                                                    kPadDv_,
                                                    kUseTrLoad_>;
// VSA reuses the Jenga runtime traits as-is.
using fmha_vsa_fwd_traits = fmha_jenga_fwd_traits;
// Dispatch entry points; definitions come from the generated kernel instances.
float fmha_vsa_fwd(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&);
template <typename Traits_>
float fmha_vsa_fwd_(const ck_tile::stream_config&, fmha_vsa_fwd_args);
float fmha_vsa_fwd(fmha_vsa_fwd_args, const ck_tile::stream_config&);

View File

@@ -0,0 +1,166 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
import argparse
import importlib.util
import pkgutil
from enum import IntEnum
from pathlib import Path
from typing import List, Optional

import codegen.ops
class HandlerId(IntEnum):
    # Index into a (list_blobs, write_blobs) handler tuple.
    LIST_BLOBS = 0
    WRITE_BLOBS = 1
# inspect all modules under 'codegen.ops' and register API handlers
ops = []
for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
    # Fix: Loader.load_module() is deprecated (removed in Python 3.12);
    # use the spec-based loading protocol instead. Also drops the unused
    # full_module_name variable the old code computed.
    spec = importer.find_spec(module_name)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    ops.append(module)
unwanted_prefix = "fmha_"
# Map API name (module name minus the 'fmha_' prefix, when present) to its
# (list_blobs, write_blobs) handler pair; indexed via HandlerId.
handlers = dict(
    [
        (
            op.__name__[len(unwanted_prefix) :]
            if op.__name__.startswith(unwanted_prefix)
            else op.__name__,
            (op.list_blobs, op.write_blobs),
        )
        for op in ops
    ]
)
assert 0 < len(handlers)
def write_blobs(
    output_dir: Optional[str],
    api_list: List[str],
    filters_list: List[str],
    optdim_list: List[int],
    receipt,
    mask_impl,
) -> None:
    # Resolve the destination directory (default: next to this script), make
    # sure it exists, then dispatch each requested API to its WRITE_BLOBS
    # handler with its paired kernel filter.
    target = Path(output_dir) if output_dir is not None else Path(__file__).parent
    target.mkdir(parents=True, exist_ok=True)
    for api, kernel_filter in zip(api_list, filters_list):
        handlers[api][HandlerId.WRITE_BLOBS](
            target, kernel_filter, receipt, optdim_list, mask_impl
        )
# list all the files that will be generated
def list_blobs(
    output_file: Optional[str],
    api_list: List[str],
    filters_list: List[str],
    optdim_list: List[int],
    receipt,
    mask_impl,
) -> None:
    # Truncate the listing file, then let each API's LIST_BLOBS handler
    # append the paths it would generate.
    assert output_file is not None
    file_path = Path(output_file)
    file_path.write_text("")  # create / empty the file before handlers append
    for api, kernel_filter in zip(api_list, filters_list):
        handlers[api][HandlerId.LIST_BLOBS](
            file_path, kernel_filter, receipt, optdim_list, mask_impl
        )
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="generate",
        description="gen API for CK fmha kernel",
    )
    # -d/--direction and -a/--api are aliases; the merged value lands in
    # args.direction (argparse uses the first long option name).
    parser.add_argument(
        "-d",
        "--direction",  # we keep 'direction' option for backward compatibility
        "-a",
        "--api",
        default="fwd_jenga",
        required=False,
        help="supply API(s) to generate (default: fwd). separated by comma.",
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        required=False,
        help="write all the blobs into a directory",
    )
    parser.add_argument(
        "-l", "--list_blobs", required=False, help="list all the kernels to a file"
    )
    # TODO: if using filter, must apply same value to output_dir and list_blobs
    parser.add_argument(
        "-f",
        "--filter",
        default="",
        required=False,
        help="filter out kernels that need to generate, using fnmatch module",
    )
    parser.add_argument(
        "-m",
        "--mask",
        default="simplified",
        required=False,
        help="mask implementation, simplified/generic",
    )
    parser.add_argument(
        "-r",
        "--receipt",
        default=0,
        required=False,
        help="codegen receipt. 0: generate only 8xhdim coverage\n"
        + "  1: generate more instance to cover all hdim\n"
        + "  2: Only generate instance for Flash attention integration\n"
        + "  4: Only generate instance for PyTorch integration\n"
        + "  100-199: Only generate instance for Aiter(mha_fwd) integration\n"
        + "  200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n"
        + "  300-399: Only generate instance for Aiter(mha_bwd) integration\n"
        + "  400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n"
        + "  600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration",
    )
    parser.add_argument(
        "--optdim",
        default="-1",
        required=False,
        help="only optimize the hdim in the list. separated by comma. -1 is the default choice"
        + "eg. --optdim=32,64,128,256",
    )
    args = parser.parse_args()
    api_list = args.direction.split(",")
    filter_list = args.filter.split(",")
    # Pad the filter list with "" so zip() pairs every API with a filter.
    filter_list.extend([""] * (len(api_list) - len(filter_list)))
    optdim_list = [int(hdim) for hdim in args.optdim.split(",")]
    # -l selects "list mode" (write target paths only); otherwise generate.
    if args.list_blobs is not None:
        list_blobs(
            args.list_blobs,
            api_list,
            filter_list,
            optdim_list,
            int(args.receipt),
            mask_impl=args.mask,
        )
    else:
        write_blobs(
            args.output_dir,
            api_list,
            filter_list,
            optdim_list,
            int(args.receipt),
            mask_impl=args.mask,
        )

View File

@@ -0,0 +1,199 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "jenga_sparse_attention.h"
#include "fmha_fwd_trek.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/device_memory.hpp"
#include <type_traits>
// Host-side driver: uploads Q/K/V and the one-hot block-relation map, launches
// the Jenga sparse-attention forward kernel, and copies the result back into Y.
// Returns Y (the same tensor the caller passed in, now filled).
template <typename DataType_>
ck_tile::HostTensor<DataType_>
jenga_sparse_attention(const ck_tile::HostTensor<DataType_>& TQ,
                       const ck_tile::HostTensor<DataType_>& TK,
                       const ck_tile::HostTensor<DataType_>& TV,
                       const ck_tile::HostTensor<uint8_t>& Tblock_relation_onehot,
                       ck_tile::HostTensor<DataType_>& Y,
                       int batch,
                       int nhead,
                       int nhead_k,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       bool i_perm, // input layout: true = BHSD, false = BSHD
                       bool o_perm, // output layout: true = BHSD, false = BSHD
                       int max_seqlen_q,
                       int max_seqlen_k,
                       int log_level)
{
    static_assert(std::is_same_v<DataType_, ck_tile::half_t> ||
                      std::is_same_v<DataType_, ck_tile::bf16_t>,
                  "Jenga sparse attention supports fp16/bf16 only.");
    // Determine data type string based on template parameter
    std::string data_type = "fp16";
    if constexpr(std::is_same_v<DataType_, ck_tile::bf16_t>)
    {
        data_type = "bf16";
    }
    // 0 means "use the actual sequence length".
    if(max_seqlen_q == 0)
        max_seqlen_q = seqlen_q;
    if(max_seqlen_k == 0)
        max_seqlen_k = seqlen_k; // NOTE(review): normalized but never used below — confirm intent
    bool is_v_rowmajor = true;
    // Standard 1/sqrt(d) attention scaling.
    float scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));
    // "0" selects the no-mask variant; sparsity comes from the block map instead.
    std::string msk_str = "0";
    mask_info mask      = mask_info::decode(msk_str, seqlen_q, seqlen_k);
    const ck_tile::index_t shape_seqlen_q = seqlen_q;
    const ck_tile::index_t shape_seqlen_k = seqlen_k;
    ck_tile::stream_config stream_config{nullptr,
                                         false, // time_kernel
                                         log_level,
                                         0,
                                         1,
                                         false};
    // Create device memory and copy data to device
    ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes());
    ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes());
    ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes());
    ck_tile::DeviceMem block_relation_buf(Tblock_relation_onehot.get_element_space_size_in_bytes());
    ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes());
    q_buf.ToDevice(TQ.data());
    k_buf.ToDevice(TK.data());
    v_buf.ToDevice(TV.data());
    block_relation_buf.ToDevice(Tblock_relation_onehot.data());
    // Fill the kernel argument struct; all strides are derived from the
    // i_perm/o_perm layout flags (element units, packed tensors assumed).
    const auto init_args = [&](auto& args) {
        assert(nhead % nhead_k == 0);
        const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
        const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
        const ck_tile::index_t stride_v = [&]() {
            if(is_v_rowmajor)
                return i_perm ? hdim_v : nhead_k * hdim_v;
            else
                return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
        }();
        const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
        // setup nhead_stride_* arguments
        const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
        const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q;
        const ck_tile::index_t nhead_stride_v = [&]() {
            if(is_v_rowmajor)
                return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
            else
                return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
        }();
        const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
        // setup batch_stride_* arguments
        const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
        const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q;
        const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k;
        const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
        // Use device buffer pointers instead of host tensor data pointers
        args.q_ptr                     = q_buf.GetDeviceBuffer();
        args.k_ptr                     = k_buf.GetDeviceBuffer();
        args.v_ptr                     = v_buf.GetDeviceBuffer();
        args.block_relation_onehot_ptr = block_relation_buf.GetDeviceBuffer();
        args.batch                     = batch;
        args.seqlen_q                  = shape_seqlen_q; // batch mode only
        args.hdim_q                    = hdim_q;
        args.hdim_v                    = hdim_v;
        args.nhead_q                   = nhead;
        args.nhead_k                   = nhead_k;
        args.stride_q                  = stride_q;
        args.stride_k                  = stride_k;
        args.stride_v                  = stride_v;
        args.nhead_stride_q            = nhead_stride_q;
        args.nhead_stride_k            = nhead_stride_k;
        args.nhead_stride_v            = nhead_stride_v;
        args.batch_stride_q            = batch_stride_q;
        args.batch_stride_k            = batch_stride_k;
        args.batch_stride_v            = batch_stride_v;
        args.o_ptr                     = o_buf.GetDeviceBuffer();
        args.seqlen_k                  = shape_seqlen_k; // batch mode only
        args.max_seqlen_q              = max_seqlen_q;
        args.scale_s                   = scale_s;
        args.stride_o                  = stride_o;
        args.nhead_stride_o            = nhead_stride_o;
        args.batch_stride_o            = batch_stride_o;
        args.window_size_left          = mask.left;
        args.window_size_right         = mask.right;
        args.mask_type                 = static_cast<ck_tile::index_t>(mask.type);
        // Dropout not supported for sparse attention.
    };
    // Fill the runtime trait struct used for kernel dispatch.
    const auto init_traits = [&](auto& traits) {
        traits.hdim_q        = hdim_q;
        traits.hdim_v        = hdim_v;
        traits.data_type     = data_type;
        traits.is_v_rowmajor = is_v_rowmajor;
        traits.mask_type     = mask.type;
    };
    fmha_jenga_fwd_traits fmha_traits;
    init_traits(fmha_traits);
    fmha_jenga_fwd_args args;
    init_args(args);
    fmha_jenga_fwd(fmha_traits, args, stream_config);
    // Copy output back to host without changing tensor shape
    o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes());
    return Y;
}
// Explicit template instantiations
// Only fp16 and bf16 are instantiated here, matching the static_assert in the
// template definition above.
template ck_tile::HostTensor<ck_tile::half_t>
jenga_sparse_attention<ck_tile::half_t>(const ck_tile::HostTensor<ck_tile::half_t>&,
                                        const ck_tile::HostTensor<ck_tile::half_t>&,
                                        const ck_tile::HostTensor<ck_tile::half_t>&,
                                        const ck_tile::HostTensor<uint8_t>&,
                                        ck_tile::HostTensor<ck_tile::half_t>&,
                                        int,
                                        int,
                                        int,
                                        int,
                                        int,
                                        int,
                                        int,
                                        bool,
                                        bool,
                                        int,
                                        int,
                                        int);
template ck_tile::HostTensor<ck_tile::bf16_t>
jenga_sparse_attention<ck_tile::bf16_t>(const ck_tile::HostTensor<ck_tile::bf16_t>&,
                                        const ck_tile::HostTensor<ck_tile::bf16_t>&,
                                        const ck_tile::HostTensor<ck_tile::bf16_t>&,
                                        const ck_tile::HostTensor<uint8_t>&,
                                        ck_tile::HostTensor<ck_tile::bf16_t>&,
                                        int,
                                        int,
                                        int,
                                        int,
                                        int,
                                        int,
                                        int,
                                        bool,
                                        bool,
                                        int,
                                        int,
                                        int);

View File

@@ -0,0 +1,48 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <optional>
#include <cstdint>
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
/// @brief Run Jenga sparse attention on device and return the filled output.
/// @param Tblock_relation_onehot one-hot [B,H,Q_blk,K_blk] map of active blocks
/// @param i_perm input layout: true = BHSD, false = BSHD (o_perm: same for output)
/// @param max_seqlen_q / max_seqlen_k pass 0 to use the actual sequence lengths
/// Only fp16/bf16 instantiations are provided.
template <typename DataType_>
ck_tile::HostTensor<DataType_>
jenga_sparse_attention(const ck_tile::HostTensor<DataType_>& TQ,
                       const ck_tile::HostTensor<DataType_>& TK,
                       const ck_tile::HostTensor<DataType_>& TV,
                       const ck_tile::HostTensor<uint8_t>& Tblock_relation_onehot,
                       ck_tile::HostTensor<DataType_>& Y,
                       int batch,
                       int nhead,
                       int nhead_k,
                       int seqlen_q,
                       int seqlen_k,
                       int hdim_q,
                       int hdim_v,
                       bool i_perm,
                       bool o_perm,
                       int max_seqlen_q,
                       int max_seqlen_k,
                       int log_level = 0);
/// @brief Run VSA sparse attention on device and return the filled output.
/// Sparsity is described by a delta-encoded K-block LUT plus a per-Q-block
/// valid-block count (both int32), instead of Jenga's one-hot map.
template <typename DataType_>
ck_tile::HostTensor<DataType_> vsa_sparse_attention(
    const ck_tile::HostTensor<DataType_>& TQ,
    const ck_tile::HostTensor<DataType_>& TK,
    const ck_tile::HostTensor<DataType_>& TV,
    const ck_tile::HostTensor<int32_t>& TKV_block_idx, // LUT must be int32_t
    const ck_tile::HostTensor<int32_t>& TKV_blocks,    // valid_block_num must be int32_t
    ck_tile::HostTensor<DataType_>& Y,
    int batch,
    int nhead,
    int nhead_k,
    int seqlen_q,
    int seqlen_k,
    int hdim_q,
    int hdim_v,
    bool i_perm,
    bool o_perm,
    int max_seqlen_q,
    int max_seqlen_k,
    int log_level = 0);

View File

@@ -0,0 +1,423 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Test for jenga_sparse_attention function
#include <iostream>
#include <vector>
#include <cmath>
#include <random>
#include <string>
#include <algorithm>
#include <numeric>
#include <chrono>
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "jenga_sparse_attention.h"
// ============================================================================
// Helper Functions
// ============================================================================
// Allocate a Q/K/V host tensor; i_perm selects BHSD layout, otherwise BSHD.
template <typename T>
ck_tile::HostTensor<T> make_qkv_tensor(ck_tile::index_t batch,
                                       ck_tile::index_t nhead,
                                       ck_tile::index_t seqlen,
                                       ck_tile::index_t hdim,
                                       bool i_perm)
{
    return i_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen, hdim})
                  : ck_tile::HostTensor<T>({batch, seqlen, nhead, hdim});
}
// Copy a BHSD- or BSHD-shaped 4D tensor into a freshly allocated BHSD tensor.
template <typename T>
ck_tile::HostTensor<T> to_bhsd(const ck_tile::HostTensor<T>& tensor, bool is_bhsd)
{
    const auto lens = tensor.get_lengths();
    const ck_tile::index_t batch  = lens[0];
    const ck_tile::index_t nhead  = is_bhsd ? lens[1] : lens[2];
    const ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1];
    const ck_tile::index_t hdim   = lens[3];

    ck_tile::HostTensor<T> out({batch, nhead, seqlen, hdim});
    for(ck_tile::index_t b = 0; b < batch; ++b)
        for(ck_tile::index_t h = 0; h < nhead; ++h)
            for(ck_tile::index_t s = 0; s < seqlen; ++s)
                for(ck_tile::index_t d = 0; d < hdim; ++d)
                    out(b, h, s, d) = is_bhsd ? tensor(b, h, s, d) : tensor(b, s, h, d);
    return out;
}
// Get error tolerance based on data type
template <typename T>
auto get_error_tolerance()
{
double rtol = 1e-2;
double atol = 4e-2;
if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
{
// bf16 accumulation/rounding can be noisier in sparse patterns
atol = 2e-1;
rtol = 2e-1;
}
return ck_tile::make_tuple(rtol, atol);
}
// Convert a sample to float for tolerance comparison.
template <typename T>
float to_float_for_compare(T value)
{
    return static_cast<float>(value);
}
// bf16 needs an explicit raw-bit conversion when the custom data type
// (with a float conversion operator) is not enabled.
template <>
float to_float_for_compare<ck_tile::bf16_t>(ck_tile::bf16_t value)
{
#if CK_TILE_USE_CUSTOM_DATA_TYPE
    return static_cast<float>(value);
#else
    return ck_tile::bf16_to_float_raw(ck_tile::bit_cast<ck_tile::bf16_raw_t>(value));
#endif
}
// ============================================================================
// Command line argument parser
// ============================================================================
// Returns (parse_ok, parser); all option values are read back via get_*().
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("v", "1", "0:no validation, 1:cpu validation")
        .insert("b", "1", "batch size")
        .insert("h", "4", "num of head for q")
        .insert("h_k", "-1", "num of head for k/v, -1 means equal to h")
        .insert("s", "4096", "seqlen_q")
        .insert("s_k", "-1", "seqlen_k, -1 means equal to s")
        .insert("d", "128", "head dim for q, k")
        .insert("d_v", "-1", "head dim for v, -1 means equal to d")
        .insert("block_size", "128", "block size for sparse attention (BLKQ=BLKK)")
        .insert("sparsity", "0.5", "sparsity ratio (0.0 = dense, 1.0 = fully sparse)")
        .insert("prec", "fp16", "data type: fp16/bf16")
        .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d")
        .insert("operm", "1", "permute output")
        .insert("seed", "42", "random seed")
        .insert("warmup", "5", "warmup iterations")
        .insert("repeat", "20", "benchmark iterations")
        .insert("kname", "0", "print kernel name");
    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
// ============================================================================
// Main Test Function
// ============================================================================
template <typename T>
bool run_test(const ck_tile::ArgParser& arg_parser)
{
// Parse arguments
int do_validation = arg_parser.get_int("v");
ck_tile::index_t batch = arg_parser.get_int("b");
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
ck_tile::index_t block_size = arg_parser.get_int("block_size");
float sparsity = arg_parser.get_float("sparsity");
bool i_perm = arg_parser.get_bool("iperm");
bool o_perm = arg_parser.get_bool("operm");
uint32_t seed = arg_parser.get_uint32("seed");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int kname = arg_parser.get_int("kname");
// Handle default values
if(nhead_k < 0)
nhead_k = nhead;
if(seqlen_k < 0)
seqlen_k = seqlen_q;
if(hdim_v < 0)
hdim_v = hdim_q;
ck_tile::index_t BLKQ = block_size;
ck_tile::index_t BLKK = block_size;
if(block_size != 128 || hdim_q != 128 || hdim_v != 128)
{
std::cout << "\n>>> TEST SKIPPED <<<" << std::endl;
std::cout << "Jenga kernel instances are generated for block_size=128 and hdim=128 only."
<< std::endl;
std::cout << "TEST SKIPPED" << std::endl;
return true;
}
// Calculate number of Q and K blocks
ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;
std::cout << "============================================================" << std::endl;
std::cout << "[Jenga Sparse Attention Test]" << std::endl;
std::cout << "============================================================" << std::endl;
std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k
<< std::endl;
std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl;
std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl;
std::cout << " block_size: " << block_size << " (BLKQ=" << BLKQ << ", BLKK=" << BLKK << ")"
<< std::endl;
std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks
<< std::endl;
std::cout << " sparsity: " << sparsity << std::endl;
std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl;
// Create host tensors (using BHSD layout when i_perm=true)
ck_tile::HostTensor<T> q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
ck_tile::HostTensor<T> k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
ck_tile::HostTensor<T> v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
ck_tile::HostTensor<T> output_host =
o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
: ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});
ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
// Block relation onehot: [B, H, Q_blocks, K_blocks]
ck_tile::HostTensor<uint8_t> block_relation_onehot({batch, nhead, num_q_blocks, num_k_blocks});
// Initialize tensors with random values
std::cout << "\nInitializing tensors..." << std::endl;
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 1}(k_host);
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);
// Initialize block_relation_onehot with sparse pattern
std::mt19937 rng(seed + 100);
std::uniform_real_distribution<float> dist(0.0f, 1.0f);
ck_tile::index_t total_blocks = 0;
ck_tile::index_t active_blocks = 0;
for(ck_tile::index_t b = 0; b < batch; ++b)
{
for(ck_tile::index_t h = 0; h < nhead; ++h)
{
for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
{
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
{
total_blocks++;
bool is_diagonal = (qb == kb && qb < num_k_blocks);
bool random_active = (dist(rng) > sparsity);
if(is_diagonal || random_active)
{
block_relation_onehot(b, h, qb, kb) = static_cast<uint8_t>(1);
active_blocks++;
}
else
{
block_relation_onehot(b, h, qb, kb) = static_cast<uint8_t>(0);
}
}
}
}
}
float actual_sparsity =
1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
<< total_blocks << " blocks active)" << std::endl;
// Run kernel
std::cout << "\n--- Running Jenga sparse attention kernel ---" << std::endl;
try
{
if(kname)
{
jenga_sparse_attention<T>(q_host,
k_host,
v_host,
block_relation_onehot,
output_host,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
i_perm,
o_perm,
seqlen_q,
seqlen_k,
1);
}
// Warmup
for(int i = 0; i < warmup; ++i)
{
jenga_sparse_attention<T>(q_host,
k_host,
v_host,
block_relation_onehot,
output_host,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
i_perm,
o_perm,
seqlen_q,
seqlen_k,
0);
}
// Benchmark
[[maybe_unused]] auto sync_status1 = hipDeviceSynchronize();
auto start = std::chrono::high_resolution_clock::now();
for(int i = 0; i < repeat; ++i)
{
jenga_sparse_attention<T>(q_host,
k_host,
v_host,
block_relation_onehot,
output_host,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
i_perm,
o_perm,
seqlen_q,
seqlen_k,
0);
}
[[maybe_unused]] auto sync_status2 = hipDeviceSynchronize();
auto end = std::chrono::high_resolution_clock::now();
double avg_time_ms =
std::chrono::duration<double, std::milli>(end - start).count() / repeat;
std::cout << "\n>>>> Jenga sparse attention average time: " << avg_time_ms << " ms <<<<"
<< std::endl;
}
catch(const std::exception& e)
{
std::cerr << "Error during kernel execution: " << e.what() << std::endl;
return false;
}
// Validation
bool pass = true;
if(do_validation)
{
std::cout << "\n--- Performing CPU validation ---" << std::endl;
float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));
std::cout << "Computing reference output..." << std::endl;
auto q_ref = to_bhsd(q_host, i_perm);
auto k_ref = to_bhsd(k_host, i_perm);
auto v_ref = to_bhsd(v_host, i_perm);
ck_tile::reference_blocked_attention<T, uint8_t>(
q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale);
// Compare results
auto [rtol, atol] = get_error_tolerance<T>();
float max_diff = 0.0f;
float max_rel_diff = 0.0f;
size_t num_errors = 0;
auto output_host_bhsd = to_bhsd(output_host, o_perm);
for(size_t i = 0; i < output_host_bhsd.mData.size(); ++i)
{
float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]);
float ref_val = to_float_for_compare(output_ref.mData[i]);
float diff = std::abs(gpu_val - ref_val);
float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff;
max_diff = std::max(max_diff, diff);
max_rel_diff = std::max(max_rel_diff, rel_diff);
if(diff > atol && rel_diff > rtol)
{
num_errors++;
if(num_errors <= 5)
{
std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val
<< ", Ref=" << ref_val << ", Diff=" << diff << std::endl;
}
}
}
std::cout << "\nValidation results:" << std::endl;
std::cout << " Max absolute difference: " << max_diff << std::endl;
std::cout << " Max relative difference: " << max_rel_diff << std::endl;
std::cout << " Number of mismatches: " << num_errors << " / "
<< output_host_bhsd.mData.size() << std::endl;
if(num_errors == 0)
{
std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl;
}
else
{
std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl;
pass = false;
}
}
std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl;
return pass;
}
// ============================================================================
// Main
// ============================================================================
int main(int argc, char* argv[])
{
    // Parse command-line options first; abort with a diagnostic on bad input.
    auto [parsed_ok, arg_parser] = create_args(argc, argv);
    if(!parsed_ok)
    {
        std::cerr << "Failed to parse arguments" << std::endl;
        return -1;
    }

    // Dispatch to the typed test runner based on the requested precision.
    // Exit code 0 means the test passed (or was skipped), -1 otherwise.
    const std::string prec = arg_parser.get_str("prec");
    if(prec == "fp16")
    {
        return run_test<ck_tile::half_t>(arg_parser) ? 0 : -1;
    }
    if(prec == "bf16")
    {
        return run_test<ck_tile::bf16_t>(arg_parser) ? 0 : -1;
    }

    std::cerr << "Unsupported precision: " << prec << std::endl;
    return -1;
}

View File

@@ -0,0 +1,486 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Test for vsa_sparse_attention function
// Based on the Python test: test_jenga_attention.py
#include <iostream>
#include <vector>
#include <cmath>
#include <random>
#include <string>
#include <algorithm>
#include <numeric>
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "jenga_sparse_attention.h"
#include "fmha_fwd_trek.hpp"
// ============================================================================
// Helper Functions
// ============================================================================
// Allocate a 4-D host tensor for Q/K/V.
// Layout is BHSD when i_perm is set, otherwise BSHD.
template <typename T>
ck_tile::HostTensor<T> make_qkv_tensor(ck_tile::index_t batch,
                                       ck_tile::index_t nhead,
                                       ck_tile::index_t seqlen,
                                       ck_tile::index_t hdim,
                                       bool i_perm)
{
    return i_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen, hdim})
                  : ck_tile::HostTensor<T>({batch, seqlen, nhead, hdim});
}
// Normalize a tensor that is either already BHSD (is_bhsd == true) or BSHD
// into a freshly allocated BHSD tensor, copying element by element.
template <typename T>
ck_tile::HostTensor<T> to_bhsd(const ck_tile::HostTensor<T>& tensor, bool is_bhsd)
{
    const auto lens = tensor.get_lengths();

    const ck_tile::index_t batch  = lens[0];
    const ck_tile::index_t nhead  = is_bhsd ? lens[1] : lens[2];
    const ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1];
    const ck_tile::index_t hdim   = lens[3];

    ck_tile::HostTensor<T> out({batch, nhead, seqlen, hdim});
    for(ck_tile::index_t b = 0; b < batch; ++b)
        for(ck_tile::index_t h = 0; h < nhead; ++h)
            for(ck_tile::index_t s = 0; s < seqlen; ++s)
                for(ck_tile::index_t d = 0; d < hdim; ++d)
                {
                    // BHSD input copies straight through; BSHD swaps the
                    // head and sequence indices on the read side.
                    out(b, h, s, d) = is_bhsd ? tensor(b, h, s, d) : tensor(b, s, h, d);
                }
    return out;
}
// Convert a dense block map into the kernel's LUT format (mirrors
// triton_block_map_to_lut_kernel): for each (b, h, q) row, the active K-block
// indices k_0 < k_1 < ... are delta-encoded as
//   lut[..., i] = k_i - k_{i-1}   (with k_{-1} taken as 0),
// and valid_block_num records how many entries of that lut row are meaningful.
template <typename T>
void block_map_to_lut(
    const ck_tile::HostTensor<T>& block_map, // [B, H, Q_blocks, K_blocks]
    ck_tile::HostTensor<int32_t>& lut, // [B, H, Q_blocks, K_blocks] - int32_t for kernel
    ck_tile::HostTensor<int32_t>& valid_block_num, // [B, H, Q_blocks] - int32_t for kernel
    ck_tile::index_t num_block_k)
{
    const auto dims = block_map.get_lengths();

    const ck_tile::index_t num_b = dims[0];
    const ck_tile::index_t num_h = dims[1];
    const ck_tile::index_t num_q = dims[2];

    for(ck_tile::index_t ib = 0; ib < num_b; ++ib)
        for(ck_tile::index_t ih = 0; ih < num_h; ++ih)
            for(ck_tile::index_t iq = 0; iq < num_q; ++iq)
            {
                int32_t count = 0; // entries written so far for this row
                int32_t last  = 0; // previous active K-block index (0 initially)
                for(ck_tile::index_t ik = 0; ik < num_block_k; ++ik)
                {
                    // Treat any map value above 0.5 as an active block.
                    if(static_cast<float>(block_map(ib, ih, iq, ik)) > 0.5f)
                    {
                        lut(ib, ih, iq, count) = static_cast<int32_t>(ik) - last;
                        last                   = static_cast<int32_t>(ik);
                        ++count;
                    }
                }
                valid_block_num(ib, ih, iq) = count;
            }
}
// Get error tolerance based on data type
template <typename T>
auto get_error_tolerance()
{
double rtol = 1e-2;
double atol = 4e-2;
if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
{
// bf16 accumulation/rounding can be noisier in sparse patterns
atol = 2e-1;
rtol = 2e-1;
}
return ck_tile::make_tuple(rtol, atol);
}
// Widen a value to float for tolerance comparison.
// Generic path: rely on the type's ordinary float conversion.
template <typename T>
float to_float_for_compare(T v)
{
    return static_cast<float>(v);
}
// bf16 specialization. Depending on the build flavor, ck_tile::bf16_t either
// provides a direct float conversion (CK_TILE_USE_CUSTOM_DATA_TYPE builds) or
// is a raw 16-bit storage type whose bits must be widened manually.
template <>
float to_float_for_compare<ck_tile::bf16_t>(ck_tile::bf16_t value)
{
#if CK_TILE_USE_CUSTOM_DATA_TYPE
    // Custom-data-type build: bf16_t has an explicit float conversion.
    return static_cast<float>(value);
#else
    // Raw-storage build: reinterpret the bits and widen via the raw helper.
    return ck_tile::bf16_to_float_raw(ck_tile::bit_cast<ck_tile::bf16_raw_t>(value));
#endif
}
// ============================================================================
// Command line argument parser
// ============================================================================
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "0:no validation, 1:cpu validation")
.insert("b", "1", "batch size")
.insert("h", "4", "num of head for q")
.insert("h_k", "-1", "num of head for k/v, -1 means equal to h")
.insert("s", "4096", "seqlen_q")
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
.insert("d", "128", "head dim for q, k")
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
.insert("block_size", "128", "block size for sparse attention (BLKQ=BLKK)")
.insert("sparsity", "0.5", "sparsity ratio (0.0 = dense, 1.0 = fully sparse)")
.insert("prec", "fp16", "data type: fp16/bf16")
.insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d")
.insert("operm", "1", "permute output")
.insert("seed", "42", "random seed")
.insert("warmup", "5", "warmup iterations")
.insert("repeat", "20", "benchmark iterations")
.insert("kname", "0", "print kernel name");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// ============================================================================
// Main Test Function
// ============================================================================
// Runs one end-to-end VSA sparse-attention test for element type T.
// Steps: parse CLI options -> build random Q/K/V and a random block-sparse
// mask -> convert the mask to the kernel's delta-encoded LUT -> launch the
// GPU kernel (optional kernel-name print, warmup, timed repeats) ->
// optionally validate against the CPU reference. Returns true on pass (or
// skip), false on kernel failure or validation mismatch.
template <typename T>
bool run_test(const ck_tile::ArgParser& arg_parser)
{
    // Parse arguments
    int do_validation           = arg_parser.get_int("v");
    ck_tile::index_t batch      = arg_parser.get_int("b");
    ck_tile::index_t nhead      = arg_parser.get_int("h");
    ck_tile::index_t nhead_k    = arg_parser.get_int("h_k");
    ck_tile::index_t seqlen_q   = arg_parser.get_int("s");
    ck_tile::index_t seqlen_k   = arg_parser.get_int("s_k");
    ck_tile::index_t hdim_q     = arg_parser.get_int("d");
    ck_tile::index_t hdim_v     = arg_parser.get_int("d_v");
    ck_tile::index_t block_size = arg_parser.get_int("block_size");
    float sparsity              = arg_parser.get_float("sparsity");
    bool i_perm                 = arg_parser.get_bool("iperm");
    bool o_perm                 = arg_parser.get_bool("operm");
    uint32_t seed               = arg_parser.get_uint32("seed");
    int warmup                  = arg_parser.get_int("warmup");
    int repeat                  = arg_parser.get_int("repeat");
    int kname                   = arg_parser.get_int("kname");
    // Handle default values (-1 means "inherit from the paired option")
    if(nhead_k < 0)
        nhead_k = nhead;
    if(seqlen_k < 0)
        seqlen_k = seqlen_q;
    if(hdim_v < 0)
        hdim_v = hdim_q;
    ck_tile::index_t BLKQ = block_size;
    ck_tile::index_t BLKK = block_size;
    // Only one kernel configuration is generated; skip (not fail) otherwise.
    if(block_size != 128 || hdim_q != 128 || hdim_v != 128)
    {
        std::cout << "\n>>> TEST SKIPPED <<<" << std::endl;
        std::cout << "VSA kernel instances are generated for block_size=128 and hdim=128 only."
                  << std::endl;
        std::cout << "TEST SKIPPED" << std::endl;
        return true;
    }
    // Calculate number of Q and K blocks (ceiling division)
    ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
    ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;
    std::cout << "============================================================" << std::endl;
    std::cout << "[VSA Sparse Attention Test]" << std::endl;
    std::cout << "============================================================" << std::endl;
    std::cout << "  Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k
              << std::endl;
    std::cout << "  seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl;
    std::cout << "  hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl;
    std::cout << "  block_size: " << block_size << " (BLKQ=" << BLKQ << ", BLKK=" << BLKK << ")"
              << std::endl;
    std::cout << "  num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks
              << std::endl;
    std::cout << "  sparsity: " << sparsity << std::endl;
    std::cout << "  i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl;
    // Create host tensors (using BHSD layout when i_perm=true)
    // Q: [B, H, S_q, D]
    // K: [B, H_k, S_k, D]
    // V: [B, H_k, S_k, D_v]
    ck_tile::HostTensor<T> q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
    ck_tile::HostTensor<T> k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
    ck_tile::HostTensor<T> v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
    // Kernel output honors o_perm; the reference below is always BHSD.
    ck_tile::HostTensor<T> output_host =
        o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
               : ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});
    ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
    // Block relation onehot: [B, H, Q_blocks, K_blocks]
    ck_tile::HostTensor<uint8_t> block_relation_onehot({batch, nhead, num_q_blocks, num_k_blocks});
    // LUT and valid_block_num (output of block_map_to_lut) - must be int32_t for kernel
    ck_tile::HostTensor<int32_t> lut_host({batch, nhead, num_q_blocks, num_k_blocks});
    ck_tile::HostTensor<int32_t> valid_block_num_host({batch, nhead, num_q_blocks});
    // Initialize tensors with random values (seed offsets keep Q/K/V distinct)
    std::cout << "\nInitializing tensors..." << std::endl;
    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 1}(k_host);
    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);
    // Initialize block_relation_onehot with sparse pattern
    std::mt19937 rng(seed + 100);
    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
    ck_tile::index_t total_blocks  = 0;
    ck_tile::index_t active_blocks = 0;
    for(ck_tile::index_t b = 0; b < batch; ++b)
    {
        for(ck_tile::index_t h = 0; h < nhead; ++h)
        {
            for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
            {
                for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
                {
                    total_blocks++;
                    // Each Q block always attends to its diagonal K block (if exists)
                    // Plus random blocks based on sparsity
                    bool is_diagonal   = (qb == kb && qb < num_k_blocks);
                    bool random_active = (dist(rng) > sparsity);
                    if(is_diagonal || random_active)
                    {
                        block_relation_onehot(b, h, qb, kb) = static_cast<uint8_t>(1);
                        active_blocks++;
                    }
                    else
                    {
                        block_relation_onehot(b, h, qb, kb) = static_cast<uint8_t>(0);
                    }
                }
            }
        }
    }
    // Realized density: diagonal blocks are forced on, so the measured
    // sparsity is at most the requested ratio.
    float actual_sparsity =
        1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
    std::cout << "  Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
              << total_blocks << " blocks active)" << std::endl;
    // Convert block_relation_onehot to LUT format
    std::cout << "Converting block map to LUT format..." << std::endl;
    block_map_to_lut(block_relation_onehot, lut_host, valid_block_num_host, num_k_blocks);
    // vsa_sparse_attention handles device memory internally
    // Run kernel
    std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl;
    try
    {
        // Print kernel name once by invoking with log_level=1.
        // This is separate from warmup/benchmark to avoid polluting timing.
        if(kname)
        {
            vsa_sparse_attention<T>(q_host,
                                    k_host,
                                    v_host,
                                    lut_host,
                                    valid_block_num_host,
                                    output_host,
                                    batch,
                                    nhead,
                                    nhead_k,
                                    seqlen_q,
                                    seqlen_k,
                                    hdim_q,
                                    hdim_v,
                                    i_perm,
                                    o_perm,
                                    seqlen_q,
                                    seqlen_k,
                                    1);
        }
        // Warmup
        for(int i = 0; i < warmup; ++i)
        {
            vsa_sparse_attention<T>(q_host,
                                    k_host,
                                    v_host,
                                    lut_host,
                                    valid_block_num_host,
                                    output_host,
                                    batch,
                                    nhead,
                                    nhead_k,
                                    seqlen_q,
                                    seqlen_k,
                                    hdim_q,
                                    hdim_v,
                                    i_perm,
                                    o_perm,
                                    seqlen_q,
                                    seqlen_k,
                                    0);
        }
        // Benchmark
        // NOTE(review): each vsa_sparse_attention call also performs
        // host<->device copies internally, so the reported average time
        // includes transfer cost, not just kernel time.
        [[maybe_unused]] auto sync_status1 = hipDeviceSynchronize();
        auto start = std::chrono::high_resolution_clock::now();
        for(int i = 0; i < repeat; ++i)
        {
            vsa_sparse_attention<T>(q_host,
                                    k_host,
                                    v_host,
                                    lut_host,
                                    valid_block_num_host,
                                    output_host,
                                    batch,
                                    nhead,
                                    nhead_k,
                                    seqlen_q,
                                    seqlen_k,
                                    hdim_q,
                                    hdim_v,
                                    i_perm,
                                    o_perm,
                                    seqlen_q,
                                    seqlen_k,
                                    0);
        }
        [[maybe_unused]] auto sync_status2 = hipDeviceSynchronize();
        auto end = std::chrono::high_resolution_clock::now();
        double avg_time_ms =
            std::chrono::duration<double, std::milli>(end - start).count() / repeat;
        std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<"
                  << std::endl;
    }
    catch(const std::exception& e)
    {
        std::cerr << "Error during kernel execution: " << e.what() << std::endl;
        return false;
    }
    // Note: vsa_sparse_attention already returns output in output_host
    // Validation
    bool pass = true;
    if(do_validation)
    {
        std::cout << "\n--- Performing CPU validation ---" << std::endl;
        // Compute scale factor (standard attention scaling 1/sqrt(hdim_q))
        float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));
        // Run reference implementation. It consumes the dense onehot map,
        // whereas the kernel consumed the LUT derived from the same map.
        std::cout << "Computing reference output..." << std::endl;
        auto q_ref = to_bhsd(q_host, i_perm);
        auto k_ref = to_bhsd(k_host, i_perm);
        auto v_ref = to_bhsd(v_host, i_perm);
        ck_tile::reference_blocked_attention<T, uint8_t>(
            q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale);
        // Compare results: an element counts as a mismatch only when it fails
        // BOTH the absolute and the relative tolerance.
        auto [rtol, atol] = get_error_tolerance<T>();
        float max_diff     = 0.0f;
        float max_rel_diff = 0.0f;
        size_t num_errors  = 0;
        auto output_host_bhsd = to_bhsd(output_host, o_perm);
        for(size_t i = 0; i < output_host_bhsd.mData.size(); ++i)
        {
            float gpu_val  = to_float_for_compare(output_host_bhsd.mData[i]);
            float ref_val  = to_float_for_compare(output_ref.mData[i]);
            float diff     = std::abs(gpu_val - ref_val);
            float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff;
            max_diff       = std::max(max_diff, diff);
            max_rel_diff   = std::max(max_rel_diff, rel_diff);
            if(diff > atol && rel_diff > rtol)
            {
                num_errors++;
                // Only print the first few mismatches to keep logs readable.
                if(num_errors <= 5)
                {
                    std::cout << "  Mismatch at index " << i << ": GPU=" << gpu_val
                              << ", Ref=" << ref_val << ", Diff=" << diff << std::endl;
                }
            }
        }
        std::cout << "\nValidation results:" << std::endl;
        std::cout << "  Max absolute difference: " << max_diff << std::endl;
        std::cout << "  Max relative difference: " << max_rel_diff << std::endl;
        std::cout << "  Number of mismatches: " << num_errors << " / "
                  << output_host_bhsd.mData.size() << std::endl;
        if(num_errors == 0)
        {
            std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl;
        }
        else
        {
            std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl;
            pass = false;
        }
    }
    std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl;
    return pass;
}
// ============================================================================
// Main
// ============================================================================
int main(int argc, char* argv[])
{
    // Parse CLI options up front; bail out with a diagnostic on bad input.
    auto [ok, arg_parser] = create_args(argc, argv);
    if(!ok)
    {
        std::cerr << "Failed to parse arguments" << std::endl;
        return -1;
    }

    // Select the element type from the requested precision and run the test.
    const std::string prec = arg_parser.get_str("prec");
    bool passed            = false;
    if(prec == "fp16")
    {
        passed = run_test<ck_tile::half_t>(arg_parser);
    }
    else if(prec == "bf16")
    {
        passed = run_test<ck_tile::bf16_t>(arg_parser);
    }
    else
    {
        std::cerr << "Unsupported precision: " << prec << std::endl;
        return -1;
    }

    // Exit code 0 on pass (or skip), -1 on failure.
    return passed ? 0 : -1;
}

View File

@@ -0,0 +1,205 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "jenga_sparse_attention.h"
#include "fmha_fwd_trek.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/device_memory.hpp"
#include <type_traits>
// Launches the VSA (sparse attention) forward kernel on device.
// Copies Q/K/V plus the delta-encoded LUT (TKV_block_idx) and per-row active
// block counts (TKV_blocks) host->device, fills in the fmha_vsa_fwd
// traits/args (strides derived from the i_perm/o_perm layouts), runs
// fmha_vsa_fwd, then copies the result back into Y. Y is written in place and
// also returned (by value, i.e. a copy of the host tensor).
template <typename DataType_>
ck_tile::HostTensor<DataType_>
vsa_sparse_attention(const ck_tile::HostTensor<DataType_>& TQ,
                     const ck_tile::HostTensor<DataType_>& TK,
                     const ck_tile::HostTensor<DataType_>& TV,
                     const ck_tile::HostTensor<int32_t>& TKV_block_idx,
                     const ck_tile::HostTensor<int32_t>& TKV_blocks,
                     ck_tile::HostTensor<DataType_>& Y,
                     int batch,
                     int nhead,
                     int nhead_k,
                     int seqlen_q,
                     int seqlen_k,
                     int hdim_q,
                     int hdim_v,
                     bool i_perm,
                     bool o_perm,
                     int max_seqlen_q,
                     int max_seqlen_k,
                     int log_level)
{
    static_assert(std::is_same_v<DataType_, ck_tile::half_t> ||
                      std::is_same_v<DataType_, ck_tile::bf16_t>,
                  "VSA sparse attention supports fp16/bf16 only.");
    // Determine data type string based on template parameter
    std::string data_type = "fp16";
    if constexpr(std::is_same_v<DataType_, ck_tile::bf16_t>)
    {
        data_type = "bf16";
    }
    // 0 means "use the actual sequence length".
    if(max_seqlen_q == 0)
        max_seqlen_q = seqlen_q;
    if(max_seqlen_k == 0)
        max_seqlen_k = seqlen_k;
    bool is_v_rowmajor = true;
    // NOTE(review): 1.0 is a double literal, so this divide runs in double
    // before narrowing to float; harmless, but 1.0f would avoid the promotion.
    float scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));
    // Mask string "0" — decoded by mask_info; presumably selects "no mask".
    // TODO confirm against mask_info::decode.
    std::string msk_str = "0";
    mask_info mask      = mask_info::decode(msk_str, seqlen_q, seqlen_k);
    const ck_tile::index_t shape_seqlen_q = seqlen_q;
    const ck_tile::index_t shape_seqlen_k = seqlen_k;
    ck_tile::stream_config stream_config{nullptr,
                                         false, // time_kernel
                                         log_level,
                                         0,
                                         1,
                                         false};
    // Create device memory and copy data to device
    ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes());
    ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes());
    ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes());
    ck_tile::DeviceMem lut_buf(TKV_block_idx.get_element_space_size_in_bytes());
    ck_tile::DeviceMem valid_block_num_buf(TKV_blocks.get_element_space_size_in_bytes());
    ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes());
    q_buf.ToDevice(TQ.data());
    k_buf.ToDevice(TK.data());
    v_buf.ToDevice(TV.data());
    lut_buf.ToDevice(TKV_block_idx.data());
    valid_block_num_buf.ToDevice(TKV_blocks.data());
    // Populate the kernel argument struct. Strides follow the chosen layout:
    // i_perm selects BHSD (packed head dim fastest) vs BSHD for inputs, and
    // o_perm does the same for the output.
    const auto init_args = [&](auto& args) {
        assert(nhead % nhead_k == 0);
        const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
        const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
        const ck_tile::index_t stride_v = [&]() {
            if(is_v_rowmajor)
                return i_perm ? hdim_v : nhead_k * hdim_v;
            else
                return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
        }();
        const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
        // setup nhead_stride_* arguments
        const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
        const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q;
        const ck_tile::index_t nhead_stride_v = [&]() {
            if(is_v_rowmajor)
                return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
            else
                return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
        }();
        const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
        // setup batch_stride_* arguments
        const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
        const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q;
        const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k;
        const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
        // Use device buffer pointers instead of host tensor data pointers
        args.q_ptr = q_buf.GetDeviceBuffer();
        args.k_ptr = k_buf.GetDeviceBuffer();
        args.v_ptr = v_buf.GetDeviceBuffer();
        // NOTE(review): no explicit strides are passed for lut /
        // valid_block_num; the kernel presumably assumes packed
        // [B, H, Q_blocks(, K_blocks)] layouts — confirm against fmha_vsa_fwd.
        args.lut_ptr             = lut_buf.GetDeviceBuffer();
        args.valid_block_num_ptr = valid_block_num_buf.GetDeviceBuffer();
        args.batch = batch;
        args.seqlen_q = shape_seqlen_q; // batch mode only
        args.hdim_q = hdim_q;
        args.hdim_v = hdim_v;
        args.nhead_q = nhead;
        args.nhead_k = nhead_k;
        args.stride_q = stride_q;
        args.stride_k = stride_k;
        args.stride_v = stride_v;
        args.nhead_stride_q = nhead_stride_q;
        args.nhead_stride_k = nhead_stride_k;
        args.nhead_stride_v = nhead_stride_v;
        args.batch_stride_q = batch_stride_q;
        args.batch_stride_k = batch_stride_k;
        args.batch_stride_v = batch_stride_v;
        args.o_ptr = o_buf.GetDeviceBuffer();
        args.seqlen_k = shape_seqlen_k; // batch mode only
        args.max_seqlen_q = max_seqlen_q;
        args.scale_s = scale_s;
        args.stride_o = stride_o;
        args.nhead_stride_o = nhead_stride_o;
        args.batch_stride_o = batch_stride_o;
        args.window_size_left = mask.left;
        args.window_size_right = mask.right;
        args.mask_type = static_cast<ck_tile::index_t>(mask.type);
        // Dropout not supported for sparse attention.
    };
    // Traits select the generated kernel instance (dtype, head dims, V layout,
    // mask type).
    const auto init_traits = [&](auto& traits) {
        traits.hdim_q = hdim_q;
        traits.hdim_v = hdim_v;
        traits.data_type = data_type;
        traits.is_v_rowmajor = is_v_rowmajor;
        traits.mask_type = mask.type;
    };
    fmha_vsa_fwd_traits fmha_traits;
    init_traits(fmha_traits);
    fmha_vsa_fwd_args args;
    init_args(args);
    fmha_vsa_fwd(fmha_traits, args, stream_config);
    // Copy output back to host without changing tensor shape
    o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes());
    return Y;
}
// Explicit template instantiations — the only element types this translation
// unit exposes, matching the fp16/bf16 static_assert in the template body.
// fp16 instantiation
template ck_tile::HostTensor<ck_tile::half_t>
vsa_sparse_attention<ck_tile::half_t>(const ck_tile::HostTensor<ck_tile::half_t>&,
                                      const ck_tile::HostTensor<ck_tile::half_t>&,
                                      const ck_tile::HostTensor<ck_tile::half_t>&,
                                      const ck_tile::HostTensor<int32_t>&,
                                      const ck_tile::HostTensor<int32_t>&,
                                      ck_tile::HostTensor<ck_tile::half_t>&,
                                      int,
                                      int,
                                      int,
                                      int,
                                      int,
                                      int,
                                      int,
                                      bool,
                                      bool,
                                      int,
                                      int,
                                      int);
// bf16 instantiation
template ck_tile::HostTensor<ck_tile::bf16_t>
vsa_sparse_attention<ck_tile::bf16_t>(const ck_tile::HostTensor<ck_tile::bf16_t>&,
                                      const ck_tile::HostTensor<ck_tile::bf16_t>&,
                                      const ck_tile::HostTensor<ck_tile::bf16_t>&,
                                      const ck_tile::HostTensor<int32_t>&,
                                      const ck_tile::HostTensor<int32_t>&,
                                      ck_tile::HostTensor<ck_tile::bf16_t>&,
                                      int,
                                      int,
                                      int,
                                      int,
                                      int,
                                      int,
                                      int,
                                      bool,
                                      bool,
                                      int,
                                      int,
                                      int);

View File

@@ -32,3 +32,5 @@ add_subdirectory(40_streamk_gemm)
add_subdirectory(41_batched_contraction)
add_subdirectory(99_toy_example)
add_subdirectory(99_toy_tutorial)
add_subdirectory(50_sparse_attn)