mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-05 20:55:59 +00:00
f8 mfma issue
This commit is contained in:
@@ -38,6 +38,12 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_streamk_v3)
|
||||
|
||||
add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3)
|
||||
set(GEMM_OPTIONS)
|
||||
# list(APPEND GEMM_OPTIONS -mllvm -greedy-reverse-local-assignment=1)
|
||||
list(APPEND GEMM_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
target_compile_options(example_gemm_xdl_bf16_v3 PRIVATE ${GEMM_OPTIONS})
|
||||
target_compile_options(example_gemm_xdl_fp8_v3 PRIVATE ${GEMM_OPTIONS})
|
||||
|
||||
|
||||
list(APPEND gpu_list gfx942 gfx950)
|
||||
set(target 0)
|
||||
|
||||
@@ -28,10 +28,10 @@ using DeviceGemmV2Instance =
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
PassThrough, PassThrough, PassThrough, GemmDefault,
|
||||
256,
|
||||
224, 256,
|
||||
128, 128,
|
||||
128, 16, 16,
|
||||
16, 16,
|
||||
7, 8,
|
||||
4, 4,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
|
||||
@@ -178,7 +178,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
|
||||
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
|
||||
|
||||
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
constexpr auto mfma_cycle = NPerXDL == 16 ? 32 : 64;
|
||||
constexpr auto ds_read_a_issue_cycle =
|
||||
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
|
||||
@@ -1120,7 +1120,11 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 32, 32>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_32x32x64f8f6f4;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_32x32x16f8f8;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -1132,7 +1136,11 @@ struct MfmaSelector
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_16x16x128f8f6f4;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x32f8f8;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
|
||||
@@ -954,11 +954,11 @@ struct vector_type<T, 128, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
StaticallyIndexedArray<d32_t, 4> d32x4_;
|
||||
StaticallyIndexedArray<d64_t, 2> d64x2_;
|
||||
StaticallyIndexedArray<d128_t, 1> d128x1_;
|
||||
} data_;
|
||||
} data_ = {d128_t{0}};
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
|
||||
__attribute__((host)) __attribute__((device)) constexpr vector_type() {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
__attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { (void)v; }
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
@@ -1082,11 +1082,11 @@ struct vector_type<T, 256, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
StaticallyIndexedArray<d64_t, 4> d64x4_;
|
||||
StaticallyIndexedArray<d128_t, 2> d128x2_;
|
||||
StaticallyIndexedArray<d256_t, 1> d256x1_;
|
||||
} data_;
|
||||
} data_ = {d256_t{0}};
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
|
||||
__attribute__((host)) __attribute__((device)) constexpr vector_type() {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
__attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { (void)v; }
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
|
||||
Reference in New Issue
Block a user