mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
Fix build errors on Windows (#2456)
* Fix build errors on Windows
* Correct clang-format
---------
Co-authored-by: Lin, Qun <Quentin.Lin+amdeng@amd.com>
This commit is contained in:
@@ -68,3 +68,6 @@ endif()
|
|||||||
|
|
||||||
target_compile_options(gtest PRIVATE ${GTEST_CXX_FLAGS})
|
target_compile_options(gtest PRIVATE ${GTEST_CXX_FLAGS})
|
||||||
target_compile_options(gtest_main PRIVATE ${GTEST_CXX_FLAGS})
|
target_compile_options(gtest_main PRIVATE ${GTEST_CXX_FLAGS})
|
||||||
|
target_compile_definitions(gtest PRIVATE GTEST_HAS_SEH=0)
|
||||||
|
target_compile_definitions(gtest_main PRIVATE GTEST_HAS_SEH=0)
|
||||||
|
|
||||||
|
|||||||
@@ -403,10 +403,10 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
|
|||||||
return (pass);
|
return (pass);
|
||||||
};
|
};
|
||||||
|
|
||||||
static const double epsilon = std::numeric_limits<float>::epsilon();
|
|
||||||
|
|
||||||
int main(int argc, char* argv[])
|
int main(int argc, char* argv[])
|
||||||
{
|
{
|
||||||
|
static const double epsilon = std::numeric_limits<float>::epsilon();
|
||||||
|
|
||||||
bool pass = true;
|
bool pass = true;
|
||||||
|
|
||||||
if(argc > 1)
|
if(argc > 1)
|
||||||
|
|||||||
@@ -314,11 +314,10 @@ bool bnorm_infer_nhwc_test(bool do_verification,
|
|||||||
return (pass);
|
return (pass);
|
||||||
};
|
};
|
||||||
|
|
||||||
static const double epsilon = std::numeric_limits<float>::epsilon();
|
|
||||||
|
|
||||||
int main(int argc, char* argv[])
|
int main(int argc, char* argv[])
|
||||||
{
|
{
|
||||||
bool pass = true;
|
static const double epsilon = std::numeric_limits<float>::epsilon();
|
||||||
|
bool pass = true;
|
||||||
|
|
||||||
if(argc > 1)
|
if(argc > 1)
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
|
|||||||
return (pass);
|
return (pass);
|
||||||
};
|
};
|
||||||
|
|
||||||
const double epsilon = std::numeric_limits<float>::epsilon();
|
|
||||||
static const double averageFactor = 0.1;
|
|
||||||
|
|
||||||
int main(int argc, char* argv[])
|
int main(int argc, char* argv[])
|
||||||
{
|
{
|
||||||
bool pass = true;
|
const double epsilon = std::numeric_limits<float>::epsilon();
|
||||||
|
static const double averageFactor = 0.1;
|
||||||
|
bool pass = true;
|
||||||
|
|
||||||
if(argc > 1)
|
if(argc > 1)
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
|
|||||||
return (pass);
|
return (pass);
|
||||||
};
|
};
|
||||||
|
|
||||||
const double epsilon = std::numeric_limits<float>::epsilon();
|
|
||||||
static const double averageFactor = 0.1;
|
|
||||||
|
|
||||||
int main(int argc, char* argv[])
|
int main(int argc, char* argv[])
|
||||||
{
|
{
|
||||||
bool pass = true;
|
const double epsilon = std::numeric_limits<float>::epsilon();
|
||||||
|
static const double averageFactor = 0.1;
|
||||||
|
bool pass = true;
|
||||||
|
|
||||||
if(argc > 1)
|
if(argc > 1)
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -128,6 +128,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
|||||||
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
|
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
|
||||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||||
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
|
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
|
||||||
|
target_link_libraries(${EXAMPLE_NAME} PRIVATE getopt::getopt)
|
||||||
add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
|
add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
|
||||||
set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} )
|
set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} )
|
||||||
add_dependencies(examples ${EXAMPLE_NAME})
|
add_dependencies(examples ${EXAMPLE_NAME})
|
||||||
|
|||||||
@@ -1396,8 +1396,8 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32>
|
|||||||
#if defined(__gfx94__)
|
#if defined(__gfx94__)
|
||||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||||
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||||
bit_cast<long>(reg_a),
|
bit_cast<int64_t>(reg_a),
|
||||||
bit_cast<long>(reg_b),
|
bit_cast<int64_t>(reg_b),
|
||||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
@@ -1427,8 +1427,8 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
|
|||||||
{
|
{
|
||||||
#if defined(__gfx94__)
|
#if defined(__gfx94__)
|
||||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
||||||
bit_cast<long>(reg_a),
|
bit_cast<int64_t>(reg_a),
|
||||||
bit_cast<long>(reg_b),
|
bit_cast<int64_t>(reg_b),
|
||||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
@@ -1459,8 +1459,8 @@ struct intrin_mfma_f32_32x32x16bf8bf8<32, 32>
|
|||||||
#if defined(__gfx94__)
|
#if defined(__gfx94__)
|
||||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||||
__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||||
bit_cast<long>(reg_a),
|
bit_cast<int64_t>(reg_a),
|
||||||
bit_cast<long>(reg_b),
|
bit_cast<int64_t>(reg_b),
|
||||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
@@ -1490,8 +1490,8 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
|
|||||||
{
|
{
|
||||||
#if defined(__gfx94__)
|
#if defined(__gfx94__)
|
||||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
|
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
|
||||||
bit_cast<long>(reg_a),
|
bit_cast<int64_t>(reg_a),
|
||||||
bit_cast<long>(reg_b),
|
bit_cast<int64_t>(reg_b),
|
||||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
@@ -1522,8 +1522,8 @@ struct intrin_mfma_f32_32x32x16f8bf8<32, 32>
|
|||||||
#if defined(__gfx94__)
|
#if defined(__gfx94__)
|
||||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||||
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
||||||
bit_cast<long>(reg_a),
|
bit_cast<int64_t>(reg_a),
|
||||||
bit_cast<long>(reg_b),
|
bit_cast<int64_t>(reg_b),
|
||||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
@@ -1553,8 +1553,8 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
|
|||||||
{
|
{
|
||||||
#if defined(__gfx94__)
|
#if defined(__gfx94__)
|
||||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
|
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
|
||||||
bit_cast<long>(reg_a),
|
bit_cast<int64_t>(reg_a),
|
||||||
bit_cast<long>(reg_b),
|
bit_cast<int64_t>(reg_b),
|
||||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
@@ -1585,8 +1585,8 @@ struct intrin_mfma_f32_32x32x16bf8f8<32, 32>
|
|||||||
#if defined(__gfx94__)
|
#if defined(__gfx94__)
|
||||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||||
__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
||||||
bit_cast<long>(reg_a),
|
bit_cast<int64_t>(reg_a),
|
||||||
bit_cast<long>(reg_b),
|
bit_cast<int64_t>(reg_b),
|
||||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
@@ -1616,8 +1616,8 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
|
|||||||
{
|
{
|
||||||
#if defined(__gfx94__)
|
#if defined(__gfx94__)
|
||||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
|
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
|
||||||
bit_cast<long>(reg_a),
|
bit_cast<int64_t>(reg_a),
|
||||||
bit_cast<long>(reg_b),
|
bit_cast<int64_t>(reg_b),
|
||||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
|||||||
@@ -8,6 +8,7 @@
|
|||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <string_view>
|
#include <string_view>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ __device__ void block_sync_lds_direct_load()
|
|||||||
{
|
{
|
||||||
#ifdef __gfx12__
|
#ifdef __gfx12__
|
||||||
asm volatile("\
|
asm volatile("\
|
||||||
s_wait_vmcnt 0x0 \n \
|
s_wait_loadcnt 0x0 \n \
|
||||||
s_wait_dscnt 0x0 \n \
|
s_wait_dscnt 0x0 \n \
|
||||||
s_barrier_signal -1 \n \
|
s_barrier_signal -1 \n \
|
||||||
s_barrier_wait -1 \
|
s_barrier_wait -1 \
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
|||||||
// clang-format off
|
// clang-format off
|
||||||
using P_ = GemmPipeline;
|
using P_ = GemmPipeline;
|
||||||
|
|
||||||
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>,
|
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>(),
|
||||||
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
||||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
||||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
|
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
|
||||||
|
|||||||
@@ -196,7 +196,7 @@ struct GemmKernel
|
|||||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||||
{
|
{
|
||||||
// clang-format off
|
// clang-format off
|
||||||
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
|
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>(), GemmPipeline::GetName());
|
||||||
// clang-format on
|
// clang-format on
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
|||||||
// clang-format off
|
// clang-format off
|
||||||
using P_ = GemmPipeline;
|
using P_ = GemmPipeline;
|
||||||
|
|
||||||
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>,
|
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>(),
|
||||||
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
||||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
||||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
|
concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
|
||||||
@@ -95,7 +95,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
|||||||
return dim3(grid_size, 1, 1);
|
return dim3(grid_size, 1, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
CK_TILE_HOST static constexpr auto
|
CK_TILE_HOST static auto
|
||||||
GridSize(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs)
|
GridSize(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs)
|
||||||
{
|
{
|
||||||
index_t grid_size = 0;
|
index_t grid_size = 0;
|
||||||
|
|||||||
@@ -1095,16 +1095,16 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
|
|||||||
#if defined(__gfx94__) or defined(__gfx95__)
|
#if defined(__gfx94__) or defined(__gfx95__)
|
||||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
||||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
|
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
|
||||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
|
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
|
||||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
|
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
|
||||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||||
#else
|
#else
|
||||||
ck_tile::ignore = c_vec;
|
ck_tile::ignore = c_vec;
|
||||||
ck_tile::ignore = a_vec;
|
ck_tile::ignore = a_vec;
|
||||||
@@ -1119,16 +1119,16 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
|
|||||||
#if defined(__gfx94__) or defined(__gfx95__)
|
#if defined(__gfx94__) or defined(__gfx95__)
|
||||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
||||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
|
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
|
||||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
|
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
|
||||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
|
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
|
||||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||||
#else
|
#else
|
||||||
ck_tile::ignore = a_vec;
|
ck_tile::ignore = a_vec;
|
||||||
ck_tile::ignore = b_vec;
|
ck_tile::ignore = b_vec;
|
||||||
@@ -1254,16 +1254,16 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
|||||||
#if defined(__gfx94__) or defined(__gfx95__)
|
#if defined(__gfx94__) or defined(__gfx95__)
|
||||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
||||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
||||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||||
#elif defined(__gfx908__) || defined(__gfx90a__)
|
#elif defined(__gfx908__) || defined(__gfx90a__)
|
||||||
static_for<0, 8, 1>{}([&](auto k) {
|
static_for<0, 8, 1>{}([&](auto k) {
|
||||||
float a_f32 =
|
float a_f32 =
|
||||||
@@ -1289,16 +1289,16 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
|||||||
#if defined(__gfx94__) or defined(__gfx95__)
|
#if defined(__gfx94__) or defined(__gfx95__)
|
||||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
||||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
||||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||||
#elif defined(__gfx908__) || defined(__gfx90a__)
|
#elif defined(__gfx908__) || defined(__gfx90a__)
|
||||||
CVecType c_vec{0.f};
|
CVecType c_vec{0.f};
|
||||||
static_for<0, 8, 1>{}([&](auto k) {
|
static_for<0, 8, 1>{}([&](auto k) {
|
||||||
@@ -1580,7 +1580,7 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
|
|||||||
{
|
{
|
||||||
#if defined(__gfx94__) or defined(__gfx95__)
|
#if defined(__gfx94__) or defined(__gfx95__)
|
||||||
c_vec = __builtin_amdgcn_mfma_i32_32x32x16_i8(
|
c_vec = __builtin_amdgcn_mfma_i32_32x32x16_i8(
|
||||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||||
#elif defined(__gfx908__) || defined(__gfx90a__)
|
#elif defined(__gfx908__) || defined(__gfx90a__)
|
||||||
static_for<0, 8, 1>{}([&](auto k) {
|
static_for<0, 8, 1>{}([&](auto k) {
|
||||||
float a_f32 =
|
float a_f32 =
|
||||||
@@ -1650,7 +1650,7 @@ struct WarpGemmAttributeMfmaImpl_i32_16x16x32_i8
|
|||||||
{
|
{
|
||||||
#if defined(__gfx94__) or defined(__gfx95__)
|
#if defined(__gfx94__) or defined(__gfx95__)
|
||||||
c_vec = __builtin_amdgcn_mfma_i32_16x16x32_i8(
|
c_vec = __builtin_amdgcn_mfma_i32_16x16x32_i8(
|
||||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||||
#else
|
#else
|
||||||
ck_tile::ignore = c_vec;
|
ck_tile::ignore = c_vec;
|
||||||
ck_tile::ignore = a_vec;
|
ck_tile::ignore = a_vec;
|
||||||
@@ -1709,7 +1709,7 @@ struct WarpGemmAttributeMfmaImpl_i32_16x16x64_i8
|
|||||||
{
|
{
|
||||||
#if defined(__gfx95__)
|
#if defined(__gfx95__)
|
||||||
c_vec = __builtin_amdgcn_mfma_i32_16x16x64_i8(
|
c_vec = __builtin_amdgcn_mfma_i32_16x16x64_i8(
|
||||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||||
#else
|
#else
|
||||||
ck_tile::ignore = c_vec;
|
ck_tile::ignore = c_vec;
|
||||||
ck_tile::ignore = a_vec;
|
ck_tile::ignore = a_vec;
|
||||||
@@ -1767,8 +1767,8 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x32_i8
|
|||||||
else
|
else
|
||||||
{
|
{
|
||||||
#if defined(__gfx95__)
|
#if defined(__gfx95__)
|
||||||
c_vec =
|
c_vec = __builtin_amdgcn_mfma_i32_32x32x32_i8(
|
||||||
__builtin_amdgcn_mfma_i32_32x32x32_i8(a_vec, bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
a_vec, bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||||
#else
|
#else
|
||||||
ck_tile::ignore = c_vec;
|
ck_tile::ignore = c_vec;
|
||||||
ck_tile::ignore = a_vec;
|
ck_tile::ignore = a_vec;
|
||||||
|
|||||||
@@ -6,7 +6,9 @@
|
|||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <typeinfo>
|
#include <typeinfo>
|
||||||
|
#if defined(__unix__)
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "ck/ck.hpp"
|
#include "ck/ck.hpp"
|
||||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||||
@@ -213,7 +215,9 @@ int profile_gemm_impl(int do_verification,
|
|||||||
instance_id++;
|
instance_id++;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(__unix__)
|
||||||
sleep(2);
|
sleep(2);
|
||||||
|
#endif
|
||||||
|
|
||||||
// Run the best instance again
|
// Run the best instance again
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
#include <initializer_list>
|
#include <initializer_list>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
#include <inttypes.h>
|
||||||
|
|
||||||
#include "profiler/profile_batched_gemm_b_scale_impl.hpp"
|
#include "profiler/profile_batched_gemm_b_scale_impl.hpp"
|
||||||
#include "profiler_operation_registry.hpp"
|
#include "profiler_operation_registry.hpp"
|
||||||
@@ -114,7 +115,7 @@ int profile_batched_gemm_b_scale(int argc, char* argv[])
|
|||||||
n_iter = std::stoi(argv[18]);
|
n_iter = std::stoi(argv[18]);
|
||||||
rotating = std::stoull(argv[19]) * 1024 * 1024;
|
rotating = std::stoull(argv[19]) * 1024 * 1024;
|
||||||
|
|
||||||
printf("n_warmup:%d, n_iter:%d, rotating:%lu\n", n_warmup, n_iter, rotating);
|
printf("n_warmup:%d, n_iter:%d, rotating:%" PRIu64 "\n", n_warmup, n_iter, rotating);
|
||||||
}
|
}
|
||||||
|
|
||||||
using F32 = float;
|
using F32 = float;
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
#include <initializer_list>
|
#include <initializer_list>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
#include <inttypes.h>
|
||||||
|
|
||||||
#include "profiler/profile_gemm_b_scale_impl.hpp"
|
#include "profiler/profile_gemm_b_scale_impl.hpp"
|
||||||
#include "profiler_operation_registry.hpp"
|
#include "profiler_operation_registry.hpp"
|
||||||
@@ -100,7 +101,7 @@ int profile_gemm_b_scale(int argc, char* argv[])
|
|||||||
n_iter = std::stoi(argv[17]);
|
n_iter = std::stoi(argv[17]);
|
||||||
rotating = std::stoull(argv[18]) * 1024 * 1024;
|
rotating = std::stoull(argv[18]) * 1024 * 1024;
|
||||||
|
|
||||||
printf("n_warmup:%d, n_iter:%d, rotating:%lu\n", n_warmup, n_iter, rotating);
|
printf("n_warmup:%d, n_iter:%d, rotating:%" PRIu64 "\n", n_warmup, n_iter, rotating);
|
||||||
}
|
}
|
||||||
|
|
||||||
using F32 = float;
|
using F32 = float;
|
||||||
|
|||||||
@@ -140,8 +140,8 @@ union pixel
|
|||||||
{
|
{
|
||||||
struct __attribute__((packed))
|
struct __attribute__((packed))
|
||||||
{
|
{
|
||||||
unsigned int r : 6;
|
ushort r : 6;
|
||||||
unsigned int c : 10;
|
ushort c : 10;
|
||||||
};
|
};
|
||||||
ushort data;
|
ushort data;
|
||||||
};
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user