mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Fix the vector load & fix the gfx950 compv4 error (#2831)
[ROCm/composable_kernel commit: 1894a0dbc3]
This commit is contained in:
@@ -1335,8 +1335,10 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
static_assert(
|
||||
(std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, fp16_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
|
||||
(std::is_same<T, bf16_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
|
||||
(std::is_same<T, int32_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
@@ -1449,14 +1451,19 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence)));
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
else
|
||||
{
|
||||
// use fp32 load to mimic fp16 load
|
||||
fp32x4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
// N >= 8: build from fp32x4 chunks
|
||||
thread_buffer<float, N / 2> tmp;
|
||||
|
||||
static_for<0, (N / 8), 1>{}([&](auto i) {
|
||||
constexpr index_t chunk = i;
|
||||
tmp.template get_as<fp32x4_t>()(i) = llvm_amdgcn_raw_buffer_load_fp32x4(
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + (chunk * 4) * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
});
|
||||
return bit_cast<rtn_type>(tmp);
|
||||
}
|
||||
}
|
||||
@@ -1486,13 +1493,19 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence)));
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
else
|
||||
{
|
||||
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
// N >= 8: build from fp32x4 chunks
|
||||
thread_buffer<float, N / 2> tmp;
|
||||
|
||||
static_for<0, (N / 8), 1>{}([&](auto i) {
|
||||
constexpr index_t chunk = i;
|
||||
tmp.template get_as<fp32x4_t>()(i) = llvm_amdgcn_raw_buffer_load_fp32x4(
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + (chunk * 4) * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
});
|
||||
return bit_cast<rtn_type>(tmp);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,17 +20,18 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
// using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
|
||||
constexpr bool single_load_tr_length =
|
||||
(DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType)) ==
|
||||
(WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size());
|
||||
constexpr index_t vector_size =
|
||||
DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType);
|
||||
constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size();
|
||||
constexpr auto wg_attr_num_access =
|
||||
((is_a_load_tr<Problem> || is_b_load_tr<Problem>) && !single_load_tr_length)
|
||||
? WGAttrNumAccessEnum::Double
|
||||
: WGAttrNumAccessEnum::Single;
|
||||
!(is_a_load_tr<Problem> || is_b_load_tr<Problem>) ? WGAttrNumAccessEnum::Single
|
||||
: vector_size == thread_elements ? WGAttrNumAccessEnum::Single
|
||||
: vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double
|
||||
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
|
||||
: WGAttrNumAccessEnum::Invalid;
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
|
||||
@@ -138,6 +138,7 @@ template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16,
|
||||
using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8<WGAttrNumAccessEnum::Quad>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8<WGAttrNumAccessEnum::Quad>; };
|
||||
|
||||
//WMMA cases
|
||||
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 16, TransposeC, false> { using Type =WarpGemmWmma_f32_16x16x16_f8_f8<TransposeC>; };
|
||||
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 16, TransposeC, false> { using Type =WarpGemmWmma_f32_16x16x16_bf8_bf8<TransposeC>; };
|
||||
|
||||
Reference in New Issue
Block a user