Merge remote-tracking branch 'origin/lwpck-3447' into mmflat

This commit is contained in:
AviralGoelAMD
2025-07-16 21:32:40 +00:00
23 changed files with 90 additions and 79 deletions

9
Jenkinsfile vendored
View File

@@ -234,11 +234,6 @@ def cmake_build(Map conf=[:]){
def build_type_debug = (conf.get("build_type",'release') == 'debug')
// use special compiler for gfx950
if ( check_arch() == 7){
compiler = "/llvm-project/build/bin/clang++"
}
//cmake_env can overwrite default CXX variables.
def cmake_envs = "CXX=${compiler} CXXFLAGS='-Werror' " + conf.get("cmake_ex_env","")
@@ -1352,12 +1347,12 @@ pipeline {
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx950" \
-DCMAKE_CXX_COMPILER=/llvm-project/build/bin/clang++ \
-DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
}
steps{
Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub22.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
cleanWs()
}
}

View File

@@ -68,3 +68,6 @@ endif()
target_compile_options(gtest PRIVATE ${GTEST_CXX_FLAGS})
target_compile_options(gtest_main PRIVATE ${GTEST_CXX_FLAGS})
target_compile_definitions(gtest PRIVATE GTEST_HAS_SEH=0)
target_compile_definitions(gtest_main PRIVATE GTEST_HAS_SEH=0)

View File

@@ -403,10 +403,10 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
return (pass);
};
static const double epsilon = std::numeric_limits<float>::epsilon();
int main(int argc, char* argv[])
{
static const double epsilon = std::numeric_limits<float>::epsilon();
bool pass = true;
if(argc > 1)

View File

@@ -314,11 +314,10 @@ bool bnorm_infer_nhwc_test(bool do_verification,
return (pass);
};
static const double epsilon = std::numeric_limits<float>::epsilon();
int main(int argc, char* argv[])
{
bool pass = true;
static const double epsilon = std::numeric_limits<float>::epsilon();
bool pass = true;
if(argc > 1)
{

View File

@@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
return (pass);
};
const double epsilon = std::numeric_limits<float>::epsilon();
static const double averageFactor = 0.1;
int main(int argc, char* argv[])
{
bool pass = true;
const double epsilon = std::numeric_limits<float>::epsilon();
static const double averageFactor = 0.1;
bool pass = true;
if(argc > 1)
{

View File

@@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
return (pass);
};
const double epsilon = std::numeric_limits<float>::epsilon();
static const double averageFactor = 0.1;
int main(int argc, char* argv[])
{
bool pass = true;
const double epsilon = std::numeric_limits<float>::epsilon();
static const double averageFactor = 0.1;
bool pass = true;
if(argc > 1)
{

View File

@@ -128,6 +128,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
target_link_libraries(${EXAMPLE_NAME} PRIVATE getopt::getopt)
add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} )
add_dependencies(examples ${EXAMPLE_NAME})

View File

@@ -243,8 +243,8 @@ struct GemmConfigPreshufle_1 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
@@ -265,8 +265,8 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;

View File

@@ -220,7 +220,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
auto [result, arg_parser] = create_args(argc, argv);
bool preshuffle = GemmConfig::Preshuffle;
if(preshuffle && a_layout != "R" && b_layout != "C")
if(preshuffle && (a_layout != "R" || b_layout != "C"))
{
throw std::runtime_error(
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");

View File

@@ -315,8 +315,16 @@ int run_gemm_example_with_layouts(int argc,
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-0.5f, 0.5f}(b_k_n);
if constexpr(preshuffle)
{
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
}
}
else if(init_method == 1)
{

View File

@@ -18,7 +18,7 @@ constexpr const char* DataTypeToString()
{
return "bf8";
}
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
{
return "bf16";
}

View File

@@ -467,7 +467,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack>{};
return make_naive_tensor_descriptor_packed(
make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber));
}
@@ -1474,7 +1474,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -1567,7 +1567,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride,
@@ -2185,7 +2185,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
get_warp_local_1d_id() % NWave,
0,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPack * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -2289,7 +2289,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
get_warp_local_1d_id() % NWave,
0,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPack * (get_thread_local_1d_id() % WarpSize)));
const BScaleDataType* p_b_scale_grid_up =
p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(

View File

@@ -1396,8 +1396,8 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32>
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
@@ -1427,8 +1427,8 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
{
#if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,
@@ -1459,8 +1459,8 @@ struct intrin_mfma_f32_32x32x16bf8bf8<32, 32>
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
@@ -1490,8 +1490,8 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
{
#if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,
@@ -1522,8 +1522,8 @@ struct intrin_mfma_f32_32x32x16f8bf8<32, 32>
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
@@ -1553,8 +1553,8 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
{
#if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,
@@ -1585,8 +1585,8 @@ struct intrin_mfma_f32_32x32x16bf8f8<32, 32>
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
@@ -1616,8 +1616,8 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
{
#if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,

View File

@@ -8,6 +8,7 @@
#include <cstring>
#include <string>
#include <string_view>
#include <map>
namespace ck {
namespace internal {

View File

@@ -33,7 +33,7 @@ __device__ void block_sync_lds_direct_load()
{
#ifdef __gfx12__
asm volatile("\
s_wait_vmcnt 0x0 \n \
s_wait_loadcnt 0x0 \n \
s_wait_dscnt 0x0 \n \
s_barrier_signal -1 \n \
s_barrier_wait -1 \

View File

@@ -74,7 +74,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
// clang-format off
using P_ = GemmPipeline;
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>,
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>(),
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));

View File

@@ -196,7 +196,7 @@ struct GemmKernel
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>(), GemmPipeline::GetName());
// clang-format on
}

View File

@@ -57,7 +57,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
// clang-format off
using P_ = GemmPipeline;
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>,
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>(),
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
@@ -95,7 +95,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
return dim3(grid_size, 1, 1);
}
CK_TILE_HOST static constexpr auto
CK_TILE_HOST static auto
GridSize(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs)
{
index_t grid_size = 0;

View File

@@ -1095,16 +1095,16 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
#if defined(__gfx94__) or defined(__gfx95__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
@@ -1119,16 +1119,16 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
#if defined(__gfx94__) or defined(__gfx95__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
@@ -1254,16 +1254,16 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
#if defined(__gfx94__) or defined(__gfx95__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
#elif defined(__gfx908__) || defined(__gfx90a__)
static_for<0, 8, 1>{}([&](auto k) {
float a_f32 =
@@ -1289,16 +1289,16 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
#if defined(__gfx94__) or defined(__gfx95__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
#elif defined(__gfx908__) || defined(__gfx90a__)
CVecType c_vec{0.f};
static_for<0, 8, 1>{}([&](auto k) {
@@ -1580,7 +1580,7 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
{
#if defined(__gfx94__) or defined(__gfx95__)
c_vec = __builtin_amdgcn_mfma_i32_32x32x16_i8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
#elif defined(__gfx908__) || defined(__gfx90a__)
static_for<0, 8, 1>{}([&](auto k) {
float a_f32 =
@@ -1650,7 +1650,7 @@ struct WarpGemmAttributeMfmaImpl_i32_16x16x32_i8
{
#if defined(__gfx94__) or defined(__gfx95__)
c_vec = __builtin_amdgcn_mfma_i32_16x16x32_i8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
@@ -1709,7 +1709,7 @@ struct WarpGemmAttributeMfmaImpl_i32_16x16x64_i8
{
#if defined(__gfx95__)
c_vec = __builtin_amdgcn_mfma_i32_16x16x64_i8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
@@ -1767,8 +1767,8 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x32_i8
else
{
#if defined(__gfx95__)
c_vec =
__builtin_amdgcn_mfma_i32_32x32x32_i8(a_vec, bit_cast<long>(b_vec), c_vec, 0, 0, 0);
c_vec = __builtin_amdgcn_mfma_i32_32x32x32_i8(
a_vec, bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;

View File

@@ -6,7 +6,9 @@
#include <iomanip>
#include <iostream>
#include <typeinfo>
#if defined(__unix__)
#include <unistd.h>
#endif
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
@@ -213,7 +215,9 @@ int profile_gemm_impl(int do_verification,
instance_id++;
}
#if defined(__unix__)
sleep(2);
#endif
// Run the best instance again
{

View File

@@ -5,6 +5,7 @@
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <inttypes.h>
#include "profiler/profile_batched_gemm_b_scale_impl.hpp"
#include "profiler_operation_registry.hpp"
@@ -114,7 +115,7 @@ int profile_batched_gemm_b_scale(int argc, char* argv[])
n_iter = std::stoi(argv[18]);
rotating = std::stoull(argv[19]) * 1024 * 1024;
printf("n_warmup:%d, n_iter:%d, rotating:%lu\n", n_warmup, n_iter, rotating);
printf("n_warmup:%d, n_iter:%d, rotating:%" PRIu64 "\n", n_warmup, n_iter, rotating);
}
using F32 = float;

View File

@@ -5,6 +5,7 @@
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <inttypes.h>
#include "profiler/profile_gemm_b_scale_impl.hpp"
#include "profiler_operation_registry.hpp"
@@ -100,7 +101,7 @@ int profile_gemm_b_scale(int argc, char* argv[])
n_iter = std::stoi(argv[17]);
rotating = std::stoull(argv[18]) * 1024 * 1024;
printf("n_warmup:%d, n_iter:%d, rotating:%lu\n", n_warmup, n_iter, rotating);
printf("n_warmup:%d, n_iter:%d, rotating:%" PRIu64 "\n", n_warmup, n_iter, rotating);
}
using F32 = float;

View File

@@ -140,8 +140,8 @@ union pixel
{
struct __attribute__((packed))
{
unsigned int r : 6;
unsigned int c : 10;
ushort r : 6;
ushort c : 10;
};
ushort data;
};