[CK TILE] GEMM with packed i4 (#1885)

* [CK TILE] GEMM with packed i4

* Fixes

* fixes

* fixes

* fixes
This commit is contained in:
Bartłomiej Kocot
2025-02-20 09:59:49 +01:00
committed by GitHub
parent 824e2c1737
commit 4d9973ec8e
32 changed files with 882 additions and 305 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -153,12 +153,12 @@ struct array<T, 0>
CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); }
};
template <typename>
template <typename, typename>
struct vector_traits;
// specialization for array
template <typename T, index_t N>
struct vector_traits<array<T, N>>
struct vector_traits<array<T, N>, void>
{
using scalar_type = T;
static constexpr index_t vector_size = N;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -149,17 +149,24 @@ struct thread_buffer {
};
// clang-format on
template <typename>
template <typename, typename>
struct vector_traits;
// specialization for thread_buffer
template <typename T, index_t N>
struct vector_traits<thread_buffer<T, N>>
struct vector_traits<thread_buffer<T, N>, std::enable_if_t<!std::is_class_v<T>>>
{
using scalar_type = T;
static constexpr index_t vector_size = N;
};
// specialization for thread_buffer whose element type T is a class type:
// scalar_type is taken from the nested T::type instead of T itself,
// while vector_size is still the buffer length N
template <typename T, index_t N>
struct vector_traits<thread_buffer<T, N>, std::enable_if_t<std::is_class_v<T>>>
{
using scalar_type = typename T::type;
static constexpr index_t vector_size = N;
};
#endif
} // namespace ck_tile

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -294,7 +294,7 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
#undef TP_COM_
};
template <typename>
template <typename, typename = void>
struct vector_traits;
// specialization for array