mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
Grouped convolution backward weight special vector size loads (#1772)
* Grouped convolution backward weight special vector size loads
* Instnaces and tests
* Fixes
* Add 7 and 13 special cases
* fix comments
* Fix
* Fix2
* fixes
* fix atomic add bf16
[ROCm/composable_kernel commit: fd46a01d8b]
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -1558,14 +1558,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
if(!(arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0 &&
|
||||
const bool is_w_pad_zero = arg.input_left_pads_[NDimSpatial - 1] == 0 &&
|
||||
arg.input_right_pads_[NDimSpatial - 1] == 0;
|
||||
const auto X = arg.filter_spatial_lengths_[NDimSpatial - 1];
|
||||
const bool XC_access_allowed = arg.Conv_G_ == 1 &&
|
||||
(arg.Conv_C_ * X) % BBlockTransferSrcScalarPerVector == 0 &&
|
||||
is_w_pad_zero;
|
||||
|
||||
if(!((arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0 || XC_access_allowed) &&
|
||||
arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
if(!(arg.Conv_K_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideA_ == 1))
|
||||
if(!(arg.Conv_K_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideA_ == 1 &&
|
||||
NumGroupsToMerge > 1))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(!(arg.Conv_C_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideB_ == 1))
|
||||
if(!(arg.Conv_C_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideB_ == 1 &&
|
||||
NumGroupsToMerge > 1))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -584,6 +584,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t>)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -53,7 +53,20 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
static constexpr auto I7 = Number<7>{};
|
||||
static constexpr auto I8 = Number<8>{};
|
||||
static constexpr auto I10 = Number<10>{};
|
||||
static constexpr auto I12 = Number<12>{};
|
||||
static constexpr auto I13 = Number<13>{};
|
||||
static constexpr auto I14 = Number<14>{};
|
||||
static constexpr auto I16 = Number<16>{};
|
||||
|
||||
static constexpr index_t PackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
|
||||
@@ -198,9 +211,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
src_oob_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template SetAsType<bool>(src_data_idx_seq, is_src_valid);
|
||||
|
||||
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
|
||||
using src_vector_t = typename src_vector_type::type;
|
||||
|
||||
using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
|
||||
using dst_vector_t = typename dst_vector_type::type;
|
||||
dst_vector_type op_r_v;
|
||||
@@ -234,14 +244,63 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
|
||||
using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
|
||||
|
||||
auto src_vector_container = src_vector_type{
|
||||
src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, true)};
|
||||
using VectorSizeLookupTable = Tuple<Sequence<>,
|
||||
Sequence<I1>,
|
||||
Sequence<I2>,
|
||||
Sequence<I2, I1>,
|
||||
Sequence<I4>,
|
||||
Sequence<I4, I1>,
|
||||
Sequence<I4, I2>,
|
||||
Sequence<I4, I2, I1>,
|
||||
Sequence<I8>,
|
||||
Sequence<I8, I1>,
|
||||
Sequence<I8, I2>,
|
||||
Sequence<I8, I2, I1>,
|
||||
Sequence<I8, I4>,
|
||||
Sequence<I8, I4, I1>,
|
||||
Sequence<I8, I4, I2>,
|
||||
Sequence<I8, I4, I2, I1>,
|
||||
Sequence<I16>>;
|
||||
using VectorOffsetsLookupTable = Tuple<Sequence<>,
|
||||
Sequence<I0>,
|
||||
Sequence<I0>,
|
||||
Sequence<I0, I2>,
|
||||
Sequence<I0>,
|
||||
Sequence<I0, I4>,
|
||||
Sequence<I0, I4>,
|
||||
Sequence<I0, I4, I6>,
|
||||
Sequence<I0>,
|
||||
Sequence<I0, I8>,
|
||||
Sequence<I0, I8>,
|
||||
Sequence<I0, I8, I10>,
|
||||
Sequence<I0, I8>,
|
||||
Sequence<I0, I8, I12>,
|
||||
Sequence<I0, I8, I12>,
|
||||
Sequence<I0, I8, I12, I14>,
|
||||
Sequence<I0>>;
|
||||
|
||||
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
|
||||
// apply the src elementwise op and convert to DstData under the hood if needed
|
||||
src_element_op_(op_r_v.template AsType<dst_elem_op_vec_t>()(idx),
|
||||
src_vector_container.template AsType<src_elem_op_vec_t>()[idx]);
|
||||
});
|
||||
static_for<0, tuple_element_t<SrcScalarPerVector, VectorSizeLookupTable>::Size(), 1>{}(
|
||||
[&](auto v_idx) {
|
||||
constexpr auto VectorLoadSize =
|
||||
tuple_element_t<SrcScalarPerVector, VectorSizeLookupTable>::At(v_idx);
|
||||
constexpr auto LoadOffset =
|
||||
tuple_element_t<SrcScalarPerVector, VectorOffsetsLookupTable>::At(v_idx);
|
||||
|
||||
using src_vector_container = vector_type_maker_t<SrcData, VectorLoadSize>;
|
||||
using src_vector_container_t = typename src_vector_container::type;
|
||||
|
||||
src_vector_container src_vector =
|
||||
src_vector_container{src_buf.template Get<src_vector_container_t>(
|
||||
src_coord_.GetOffset() / PackedSize + LoadOffset, true)};
|
||||
|
||||
static_for<0, VectorLoadSize / elem_op_vec_len, 1>{}([&](auto idx) {
|
||||
// apply the src elementwise op and convert to DstData under the hood if
|
||||
// needed
|
||||
src_element_op_(
|
||||
op_r_v.template AsType<dst_elem_op_vec_t>()(idx + LoadOffset),
|
||||
src_vector.template AsType<src_elem_op_vec_t>()[idx]);
|
||||
});
|
||||
});
|
||||
|
||||
// copy data from src_vector_container into src_thread_scratch_
|
||||
src_thread_scratch_tuple_(thread_scratch_id)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -314,6 +314,76 @@ struct vector_type<T, 2, typename std::enable_if_t<is_native_type<T>()>>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 3, typename std::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
typedef T d3_t __attribute__((ext_vector_type(3)));
|
||||
|
||||
using type = d3_t;
|
||||
|
||||
union
|
||||
{
|
||||
d3_t d3_;
|
||||
StaticallyIndexedArray<d1_t, 3> d1x3_;
|
||||
StaticallyIndexedArray<d2_t, 1> d2x1_;
|
||||
StaticallyIndexedArray<d3_t, 1> d3x1_;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d3_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x3_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x1_;
|
||||
}
|
||||
else if constexpr(is_same<X, d3_t>::value)
|
||||
{
|
||||
return data_.d3x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d3_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x3_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x1_;
|
||||
}
|
||||
else if constexpr(is_same<X, d3_t>::value)
|
||||
{
|
||||
return data_.d3x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
@@ -384,6 +454,158 @@ struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 5, typename std::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d4_t __attribute__((ext_vector_type(4)));
|
||||
typedef T d5_t __attribute__((ext_vector_type(5)));
|
||||
|
||||
using type = d5_t;
|
||||
|
||||
union
|
||||
{
|
||||
d5_t d5_;
|
||||
StaticallyIndexedArray<d1_t, 5> d1x5_;
|
||||
StaticallyIndexedArray<d4_t, 1> d4x1_;
|
||||
StaticallyIndexedArray<d5_t, 1> d5x1_;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d4_t>::value || is_same<X, d5_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x5_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x1_;
|
||||
}
|
||||
else if constexpr(is_same<X, d5_t>::value)
|
||||
{
|
||||
return data_.d5x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d4_t>::value || is_same<X, d5_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x5_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x1_;
|
||||
}
|
||||
else if constexpr(is_same<X, d5_t>::value)
|
||||
{
|
||||
return data_.d5x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 7, typename std::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
typedef T d4_t __attribute__((ext_vector_type(4)));
|
||||
typedef T d7_t __attribute__((ext_vector_type(7)));
|
||||
|
||||
using type = d7_t;
|
||||
|
||||
union
|
||||
{
|
||||
d7_t d7_;
|
||||
StaticallyIndexedArray<d1_t, 7> d1x7_;
|
||||
StaticallyIndexedArray<d2_t, 3> d2x3_;
|
||||
StaticallyIndexedArray<d4_t, 1> d4x1_;
|
||||
StaticallyIndexedArray<d7_t, 1> d7x1_;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d7_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x7_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x3_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x1_;
|
||||
}
|
||||
else if constexpr(is_same<X, d7_t>::value)
|
||||
{
|
||||
return data_.d7x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d7_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x7_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x3_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x1_;
|
||||
}
|
||||
else if constexpr(is_same<X, d7_t>::value)
|
||||
{
|
||||
return data_.d7x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
@@ -466,6 +688,88 @@ struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 13, typename std::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d4_t __attribute__((ext_vector_type(4)));
|
||||
typedef T d8_t __attribute__((ext_vector_type(8)));
|
||||
typedef T d13_t __attribute__((ext_vector_type(13)));
|
||||
|
||||
using type = d13_t;
|
||||
|
||||
union
|
||||
{
|
||||
d13_t d13_;
|
||||
StaticallyIndexedArray<d1_t, 13> d1x13_;
|
||||
StaticallyIndexedArray<d4_t, 3> d4x3_;
|
||||
StaticallyIndexedArray<d8_t, 1> d8x1_;
|
||||
StaticallyIndexedArray<d13_t, 1> d13x1_;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d4_t>::value ||
|
||||
is_same<X, d8_t>::value || is_same<X, d13_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x13_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x3_;
|
||||
}
|
||||
else if constexpr(is_same<X, d8_t>::value)
|
||||
{
|
||||
return data_.d8x1_;
|
||||
}
|
||||
else if constexpr(is_same<X, d13_t>::value)
|
||||
{
|
||||
return data_.d13x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d4_t>::value ||
|
||||
is_same<X, d8_t>::value || is_same<X, d13_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x13_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x3_;
|
||||
}
|
||||
else if constexpr(is_same<X, d8_t>::value)
|
||||
{
|
||||
return data_.d8x1_;
|
||||
}
|
||||
else if constexpr(is_same<X, d13_t>::value)
|
||||
{
|
||||
return data_.d13x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user