[CK_TILE] not using structures under ck_tile/ops for ck_tile/host (#1834)

* not using structures under ck_tile/ops for ck_tile/host

* update as constexpr function

* Rename fn

* Update other examples.

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
Co-authored-by: Adam Osewski <Adam.Osewski@amd.com>
This commit is contained in:
carlushuang
2025-01-24 15:35:54 +08:00
committed by GitHub
parent 052a72655c
commit 5b9b083dbc
4 changed files with 67 additions and 109 deletions

View File

@@ -3,6 +3,13 @@
#pragma once
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
@@ -128,16 +135,19 @@ int run_grouped_gemm_example_with_layouts(int argc,
const ck_tile::index_t N = Ns[i];
const ck_tile::index_t K = Ks[i];
stride_As[i] = ck_tile::get_default_stride(M, N, stride_As[i], a_layout);
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], b_layout);
stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], CLayout{});
stride_As[i] =
ck_tile::get_default_stride(M, N, stride_As[i], is_row_major(a_layout));
stride_Bs[i] =
ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
stride_Cs[i] =
ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
ck_tile::host_tensor_descriptor(M, K, stride_As[i], a_layout)));
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], b_layout)));
ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout))));
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));
ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{}))));
std::cout << "gemm[" << i << "]"
<< " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc
@@ -177,8 +187,8 @@ int run_grouped_gemm_example_with_layouts(int argc,
{
for(int i = 0; i < group_count; ++i)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);