mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
Grouped convolution backward weight special vector size loads (#1772)
* Grouped convolution backward weight special vector size loads * Instnaces and tests * Fixes * Add 7 and 13 special cases * fix comments * Fix * Fix2 * fixes * fix atomic add bf16
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -53,7 +53,20 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
static constexpr auto I7 = Number<7>{};
|
||||
static constexpr auto I8 = Number<8>{};
|
||||
static constexpr auto I10 = Number<10>{};
|
||||
static constexpr auto I12 = Number<12>{};
|
||||
static constexpr auto I13 = Number<13>{};
|
||||
static constexpr auto I14 = Number<14>{};
|
||||
static constexpr auto I16 = Number<16>{};
|
||||
|
||||
static constexpr index_t PackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
|
||||
@@ -198,9 +211,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
src_oob_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template SetAsType<bool>(src_data_idx_seq, is_src_valid);
|
||||
|
||||
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
|
||||
using src_vector_t = typename src_vector_type::type;
|
||||
|
||||
using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
|
||||
using dst_vector_t = typename dst_vector_type::type;
|
||||
dst_vector_type op_r_v;
|
||||
@@ -234,14 +244,63 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
|
||||
using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
|
||||
|
||||
auto src_vector_container = src_vector_type{
|
||||
src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, true)};
|
||||
using VectorSizeLookupTable = Tuple<Sequence<>,
|
||||
Sequence<I1>,
|
||||
Sequence<I2>,
|
||||
Sequence<I2, I1>,
|
||||
Sequence<I4>,
|
||||
Sequence<I4, I1>,
|
||||
Sequence<I4, I2>,
|
||||
Sequence<I4, I2, I1>,
|
||||
Sequence<I8>,
|
||||
Sequence<I8, I1>,
|
||||
Sequence<I8, I2>,
|
||||
Sequence<I8, I2, I1>,
|
||||
Sequence<I8, I4>,
|
||||
Sequence<I8, I4, I1>,
|
||||
Sequence<I8, I4, I2>,
|
||||
Sequence<I8, I4, I2, I1>,
|
||||
Sequence<I16>>;
|
||||
using VectorOffsetsLookupTable = Tuple<Sequence<>,
|
||||
Sequence<I0>,
|
||||
Sequence<I0>,
|
||||
Sequence<I0, I2>,
|
||||
Sequence<I0>,
|
||||
Sequence<I0, I4>,
|
||||
Sequence<I0, I4>,
|
||||
Sequence<I0, I4, I6>,
|
||||
Sequence<I0>,
|
||||
Sequence<I0, I8>,
|
||||
Sequence<I0, I8>,
|
||||
Sequence<I0, I8, I10>,
|
||||
Sequence<I0, I8>,
|
||||
Sequence<I0, I8, I12>,
|
||||
Sequence<I0, I8, I12>,
|
||||
Sequence<I0, I8, I12, I14>,
|
||||
Sequence<I0>>;
|
||||
|
||||
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
|
||||
// apply the src elementwise op and convert to DstData under the hood if needed
|
||||
src_element_op_(op_r_v.template AsType<dst_elem_op_vec_t>()(idx),
|
||||
src_vector_container.template AsType<src_elem_op_vec_t>()[idx]);
|
||||
});
|
||||
static_for<0, tuple_element_t<SrcScalarPerVector, VectorSizeLookupTable>::Size(), 1>{}(
|
||||
[&](auto v_idx) {
|
||||
constexpr auto VectorLoadSize =
|
||||
tuple_element_t<SrcScalarPerVector, VectorSizeLookupTable>::At(v_idx);
|
||||
constexpr auto LoadOffset =
|
||||
tuple_element_t<SrcScalarPerVector, VectorOffsetsLookupTable>::At(v_idx);
|
||||
|
||||
using src_vector_container = vector_type_maker_t<SrcData, VectorLoadSize>;
|
||||
using src_vector_container_t = typename src_vector_container::type;
|
||||
|
||||
src_vector_container src_vector =
|
||||
src_vector_container{src_buf.template Get<src_vector_container_t>(
|
||||
src_coord_.GetOffset() / PackedSize + LoadOffset, true)};
|
||||
|
||||
static_for<0, VectorLoadSize / elem_op_vec_len, 1>{}([&](auto idx) {
|
||||
// apply the src elementwise op and convert to DstData under the hood if
|
||||
// needed
|
||||
src_element_op_(
|
||||
op_r_v.template AsType<dst_elem_op_vec_t>()(idx + LoadOffset),
|
||||
src_vector.template AsType<src_elem_op_vec_t>()[idx]);
|
||||
});
|
||||
});
|
||||
|
||||
// copy data from src_vector_container into src_thread_scratch_
|
||||
src_thread_scratch_tuple_(thread_scratch_id)
|
||||
|
||||
Reference in New Issue
Block a user