mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
@@ -45,7 +45,6 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
set(EXAMPLE_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
target_compile_options(example_gemm_xdl_fp8_pk_i4_v3 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
target_compile_options(example_gemm_xdl_fp8_pk_i4_bpreshuffle_v3 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp"
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F8 = ck::f8_t;
|
||||
using I4 = ck::pk_i4_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
@@ -5,8 +5,6 @@ add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
|
||||
add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp)
|
||||
add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp)
|
||||
add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp)
|
||||
# add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp)
|
||||
# add_example_executable(example_moe_gemm2_xdl_pk_i4 moe_gemm2_xdl_pk_i4.cpp)
|
||||
|
||||
list(APPEND gpu_list gfx942)
|
||||
set(target 0)
|
||||
|
||||
@@ -472,7 +472,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"}, {BlockGemmPipelineVersion::v2, "v2"}, {BlockGemmPipelineVersion::v3, "v3"}};
|
||||
{BlockGemmPipelineVersion::v1, "v1"}, {BlockGemmPipelineVersion::v2, "v2"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceMoeGEmm"
|
||||
|
||||
@@ -81,9 +81,27 @@ __device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale)
|
||||
|
||||
__device__ inline f8x4_t i4_to_f8x4(int q)
|
||||
{
|
||||
#if 0
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
|
||||
int lo = amd_assembly_and_b32(q, LO);
|
||||
int hi = amd_assembly_and_b32(q, HI);
|
||||
|
||||
float f32_0 = amd_assemble_cvt_f32_i4(lo);
|
||||
float f32_1 = amd_assemble_cvt_f32_i4(lo >> 16);
|
||||
float f32_2 = amd_assemble_cvt_f32_i4(hi);
|
||||
float f32_3 = amd_assemble_cvt_f32_i4(hi >> 16);
|
||||
|
||||
return amd_assembly_cvt_f8_to_f32(f32_0, f32_1, f32_2, f32_3);
|
||||
#else
|
||||
// [0, 1, 2, 3] encoded as FP8
|
||||
static constexpr uint32_t POS_E4M3s_TABLE1 = 0x2C282000;
|
||||
// [4, 5, 6, 7] encoded as FP8
|
||||
static constexpr uint32_t POS_E4M3s_TABLE2 = 0x36343230;
|
||||
// [-8, -7, -6, -5] encoded as FP8
|
||||
static constexpr uint32_t NEG_E4M3s_TABLE1 = 0xB2B4B6B8;
|
||||
// [-4, -3, -2, -1] encoded as FP8
|
||||
static constexpr uint32_t NEG_E4M3s_TABLE2 = 0xA0A8ACB0;
|
||||
|
||||
uint32_t tmp_pos, tmp_neg, tmp_res, final_sel;
|
||||
@@ -103,6 +121,7 @@ __device__ inline f8x4_t i4_to_f8x4(int q)
|
||||
res.template AsType<f8x4_t>()(Number<0>{}) = bit_cast<f8x4_t>(tmp_res);
|
||||
|
||||
return res.template AsType<f8x4_t>()[Number<0>{}];
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ inline f8x8_t i4_to_fp8x8(int q)
|
||||
|
||||
Reference in New Issue
Block a user