mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
lds a,b ok
This commit is contained in:
@@ -503,10 +503,6 @@ include_directories(BEFORE
|
||||
)
|
||||
|
||||
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
|
||||
if(BUILD_DEV)
|
||||
add_compile_options(-Werror)
|
||||
add_compile_options(-Weverything)
|
||||
endif()
|
||||
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
|
||||
|
||||
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
|
||||
|
||||
@@ -66,7 +66,7 @@ else()
|
||||
-Wunreachable-code
|
||||
-Wunused
|
||||
-Wno-reserved-identifier
|
||||
-Werror
|
||||
# -Werror
|
||||
-Wno-option-ignored
|
||||
-Wsign-compare
|
||||
-Wno-extra-semi-stmt
|
||||
|
||||
@@ -117,6 +117,10 @@ int run_gemm_example_with_layouts(int argc,
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
|
||||
// ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
|
||||
// ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
|
||||
// ck_tile::FillConstant<ADataType>{1.f}(a_m_k);
|
||||
// ck_tile::FillConstant<BDataType>{1.f}(b_k_n);
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
|
||||
@@ -211,15 +211,16 @@ struct FillNormalDistributionIntegerValue
|
||||
template <typename T>
|
||||
struct FillMonotonicSeq
|
||||
{
|
||||
T init_value_{0};
|
||||
T init_value_{-1024};
|
||||
T step_{1};
|
||||
|
||||
template <typename ForwardIter>
|
||||
void operator()(ForwardIter first, ForwardIter last) const
|
||||
{
|
||||
std::generate(first, last, [=, n = init_value_]() mutable {
|
||||
T step_start = init_value_;
|
||||
std::generate(first, last, [&, n = init_value_]() mutable {
|
||||
auto tmp = n;
|
||||
n += step_;
|
||||
if (n > step_start + 2047) {step_start += step_; n = step_start;}
|
||||
return tmp;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -42,9 +42,6 @@ struct BlockGemmASmemBSmemCRegV1
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
// if(threadIdx.x == 0 && blockIdx.x==0) {
|
||||
// printf("MPerBlock %d NPerBlock %d KPerBlock %d \n", MPerBlock, NPerBlock, KPerBlock);
|
||||
// }
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
@@ -56,12 +53,12 @@ struct BlockGemmASmemBSmemCRegV1
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
// constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
// constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
// constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
// const index_t iMWarp = get_warp_id() / NWarp;
|
||||
// const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
// if(threadIdx.x == 0 && blockIdx.x==0) {
|
||||
// printf("MWarp %d NWarp %d MIterPerWarp %d NIterPerWarp %d KIterPerWarp %d MPerBlockPerIter %d NPerBlockPerIter %d KPerBlockPerIter %d \n", MWarp, NWarp, MIterPerWarp, NIterPerWarp, KIterPerWarp, MPerBlockPerIter, NPerBlockPerIter, KPerBlockPerIter);
|
||||
@@ -69,91 +66,69 @@ struct BlockGemmASmemBSmemCRegV1
|
||||
// MWarp 2 NWarp 2 MIterPerWarp 4 NIterPerWarp 4 KIterPerWarp 4 MPerBlockPerIter 64 NPerBlockPerIter 64 KPerBlockPerIter 8
|
||||
|
||||
|
||||
// construct A-warp-window
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
|
||||
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
|
||||
#if 0 // FIXME: using array will cause register spill
|
||||
array<array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
|
||||
{a_warp_window_tmp}};
|
||||
|
||||
for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
|
||||
{
|
||||
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
{
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
}
|
||||
}
|
||||
#else
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
#endif
|
||||
|
||||
// construct B-warp-window
|
||||
make_tuple(MPerBlock, KPerBlock),
|
||||
{0, 0},
|
||||
Policy::template MakeALDSTileDistribution<Problem>());
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
|
||||
b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
|
||||
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
|
||||
make_tuple(NPerBlock, KPerBlock),
|
||||
{0, 0},
|
||||
Policy::template MakeBLDSTileDistribution<Problem>());
|
||||
|
||||
#if 0 // FIXME: using array will cause register spill
|
||||
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
|
||||
{b_warp_window_tmp}};
|
||||
auto a_block_tensor = load_tile(a_warp_window_tmp);
|
||||
auto b_block_tensor = load_tile(b_warp_window_tmp);
|
||||
|
||||
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
|
||||
{
|
||||
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
{
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
}
|
||||
}
|
||||
#else
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
// if (threadIdx.x == 0) {
|
||||
// printf("0\n");
|
||||
// constexpr auto span_2d = decltype(a_block_tensor)::get_distributed_spans();
|
||||
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
|
||||
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
|
||||
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// printf("%f %f,", type_convert<float>(a_block_tensor(i_j_idx)), type_convert<float>(b_block_tensor(i_j_idx)));
|
||||
// });
|
||||
// printf("\n");
|
||||
// });
|
||||
// }
|
||||
// __syncthreads();
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using BWarpDstr = typename WG::BWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
#endif
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using BWarpTensor = typename WG::BWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block window
|
||||
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
// read A warp tensor from A Block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
@@ -173,6 +148,36 @@ struct BlockGemmASmemBSmemCRegV1
|
||||
});
|
||||
});
|
||||
});
|
||||
// constexpr auto c_warp_y_lengths =
|
||||
// to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
// constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
// // hot loop:
|
||||
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// // read A warp tensor from A block window
|
||||
|
||||
// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// // read B warp tensor from B Block window
|
||||
// // const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
|
||||
// // read C warp tensor from C block tensor
|
||||
// CWarpTensor c_warp_tensor;
|
||||
|
||||
// c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
// merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
// merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// // warp GEMM
|
||||
// WG{}(c_warp_tensor, a_warp_tensor(mIter, kIter), b_warp_tensor(nIter, kIter));
|
||||
|
||||
// // write C warp tensor into C block tensor
|
||||
// c_block_tensor.set_y_sliced_thread_data(
|
||||
// merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
// merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
// c_warp_tensor.get_thread_buffer());
|
||||
// });
|
||||
// });
|
||||
// });
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
@@ -217,5 +222,72 @@ struct BlockGemmASmemBSmemCRegV1
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
// construct A-warp-window
|
||||
// auto a_warp_window_tmp = make_tile_window(
|
||||
// a_block_window.get_bottom_tensor_view(),
|
||||
// make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
// a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
|
||||
// make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
// #if 0 // FIXME: using array will cause register spill
|
||||
// array<array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
|
||||
// {a_warp_window_tmp}};
|
||||
|
||||
// for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
|
||||
// {
|
||||
// for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
// {
|
||||
// move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
// {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
// }
|
||||
// }
|
||||
// #else
|
||||
// statically_indexed_array<
|
||||
// statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
// MIterPerWarp>
|
||||
// a_warp_windows;
|
||||
|
||||
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
// a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
|
||||
// move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
// {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
// });
|
||||
// });
|
||||
// #endif
|
||||
|
||||
// construct B-warp-window
|
||||
// auto b_warp_window_tmp = make_tile_window(
|
||||
// b_block_window.get_bottom_tensor_view(),
|
||||
// make_tuple(number<WG::kN>{}, number<WG::kK>{}),
|
||||
// b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
|
||||
// make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
|
||||
|
||||
// #if 0 // FIXME: using array will cause register spill
|
||||
// array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
|
||||
// {b_warp_window_tmp}};
|
||||
|
||||
// for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
|
||||
// {
|
||||
// for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
// {
|
||||
// move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
// {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
// }
|
||||
// }
|
||||
// #else
|
||||
// statically_indexed_array<
|
||||
// statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
// NIterPerWarp>
|
||||
// b_warp_windows;
|
||||
|
||||
// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
// b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
// move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
// {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
// });
|
||||
// });
|
||||
// #endif
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -40,7 +40,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
|
||||
}
|
||||
#else
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 2, 2);
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 2, 2);
|
||||
// return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
#endif
|
||||
}
|
||||
@@ -55,6 +55,96 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
|
||||
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALDSTileDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
static_assert(false, "Unsupported tensor_layout right now.");
|
||||
}
|
||||
else
|
||||
{
|
||||
//Number<krepeat>{}, Number<klane>{}, Number<Kpack>{}))),
|
||||
constexpr index_t K2 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K1 = 2;
|
||||
constexpr index_t K0 = KPerBlock / K1 / K2;
|
||||
//Number<mrepeat>{}, Number<mwaves>{}, Number<MPerXdl>{}))),
|
||||
constexpr index_t M2 = 32; // MPERXDL
|
||||
constexpr index_t M1 = 2; //MWAVE
|
||||
// coalesce reading for each blocks
|
||||
if constexpr(get_warp_size() % (M2 * K0) == 0)
|
||||
{
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<2>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>, sequence<1, 2>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported shape right now.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBLDSTileDistribution()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
static_assert(false, "Unsupported tensor_layout right now.");
|
||||
}
|
||||
else
|
||||
{
|
||||
//Number<krepeat>{}, Number<klane>{}, Number<Kpack>{}))),
|
||||
constexpr index_t K2 = 16 / sizeof(BDataType);
|
||||
constexpr index_t K1 = 2;
|
||||
constexpr index_t K0 = KPerBlock / K1 / K2;
|
||||
//Number<mrepeat>{}, Number<mwaves>{}, Number<MPerXdl>{}))),
|
||||
constexpr index_t N2 = 32; // MPERXDL
|
||||
constexpr index_t N1 = 2; //MWAVE
|
||||
// coalesce reading for each blocks
|
||||
if constexpr(get_warp_size() % (N2 * K0) == 0)
|
||||
{
|
||||
static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t N0 = NPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<2>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported shape right now.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -133,7 +133,16 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
// global read 0
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
|
||||
// if (threadIdx.x == 0) {
|
||||
// constexpr auto span_2d = decltype(a_block_tile)::get_distributed_spans();
|
||||
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
|
||||
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
|
||||
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// printf("%f,", type_convert<float>(a_block_tile(i_j_idx)));
|
||||
// });
|
||||
// printf("\n");
|
||||
// });
|
||||
// }
|
||||
{
|
||||
// move to 1
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
@@ -170,7 +179,17 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile));
|
||||
}
|
||||
}
|
||||
|
||||
// __syncthreads();
|
||||
// if (threadIdx.x == 0) {
|
||||
// for (int j = 0; j < 256; j++) {
|
||||
// for(int i = 0; i < 32; i++) {
|
||||
// int ik0 = i /8;
|
||||
// int ik1 = i % 8;
|
||||
// printf("%f,", type_convert<float>(p_b_lds[ik1 + j * 8 + ik0 * 8 * 256]));
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
// }
|
||||
index_t iCounter = num_loop - 1;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
@@ -219,6 +238,17 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
|
||||
// if (threadIdx.x == 0) {
|
||||
// constexpr auto span_2d = decltype(c_block_tile)::get_distributed_spans();
|
||||
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
|
||||
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
|
||||
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// if(abs(type_convert<float>(c_block_tile(i_j_idx)) - 32) > 0.1)
|
||||
// printf("%d %f,", threadIdx.x, type_convert<float>(c_block_tile(i_j_idx)));
|
||||
// });
|
||||
// printf("\n");
|
||||
// });
|
||||
// }
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
|
||||
make_tuple(number<(kMPerBlock) * 8>{}, number<8>{}, number<1>{}),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
@@ -77,7 +77,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / 8>{}, number<kNPerBlock>{}, number<8>{}),
|
||||
make_tuple(number<(kNPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
|
||||
make_tuple(number<(kNPerBlock) * 8>{}, number<8>{}, number<1>{}),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
@@ -130,74 +130,74 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
}
|
||||
#elif 1
|
||||
// fake XOR
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
// template <typename Problem>
|
||||
// CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
// {
|
||||
// using namespace ck_tile;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
// using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
// constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
// constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(number<kMPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
|
||||
number<kKPerBlock>{});
|
||||
// constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
|
||||
// make_tuple(number<kMPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
|
||||
// number<kKPerBlock>{});
|
||||
|
||||
constexpr index_t kK1 = 16 / sizeof(ADataType);
|
||||
// constexpr index_t kK1 = 16 / sizeof(ADataType);
|
||||
|
||||
constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_d1_d2_d3,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<kMPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
|
||||
make_pass_through_transform(2)),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}));
|
||||
// constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
|
||||
// a_lds_block_desc_d1_d2_d3,
|
||||
// make_tuple(
|
||||
// make_xor_transform(make_tuple(number<kMPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
|
||||
// make_pass_through_transform(2)),
|
||||
// make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
// make_tuple(sequence<0, 2>{}, sequence<1>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
|
||||
a_lds_block_desc_d4_d5_d6,
|
||||
make_tuple(make_merge_transform(make_tuple(number<kMPerBlock / 2>{}, number<2>{})),
|
||||
make_pass_through_transform(kKPerBlock)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
// constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
|
||||
// a_lds_block_desc_d4_d5_d6,
|
||||
// make_tuple(make_merge_transform(make_tuple(number<kMPerBlock / 2>{}, number<2>{})),
|
||||
// make_pass_through_transform(kKPerBlock)),
|
||||
// make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
// make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc_m_k;
|
||||
}
|
||||
// return a_lds_block_desc_m_k;
|
||||
// }
|
||||
|
||||
// fake XOR
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
// // fake XOR
|
||||
// template <typename Problem>
|
||||
// CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
// {
|
||||
// using namespace ck_tile;
|
||||
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
// using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
// constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
// constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(number<kNPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
|
||||
number<kKPerBlock>{});
|
||||
// constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
|
||||
// make_tuple(number<kNPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
|
||||
// number<kKPerBlock>{});
|
||||
|
||||
constexpr index_t kK1 = 16 / sizeof(BDataType);
|
||||
// constexpr index_t kK1 = 16 / sizeof(BDataType);
|
||||
|
||||
constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_d1_d2_d3,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<kNPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
|
||||
make_pass_through_transform(2)),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}));
|
||||
// constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
|
||||
// b_lds_block_desc_d1_d2_d3,
|
||||
// make_tuple(
|
||||
// make_xor_transform(make_tuple(number<kNPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
|
||||
// make_pass_through_transform(2)),
|
||||
// make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
// make_tuple(sequence<0, 2>{}, sequence<1>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
|
||||
b_lds_block_desc_d4_d5_d6,
|
||||
make_tuple(make_merge_transform(make_tuple(number<kNPerBlock / 2>{}, number<2>{})),
|
||||
make_pass_through_transform(kKPerBlock)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
// constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
|
||||
// b_lds_block_desc_d4_d5_d6,
|
||||
// make_tuple(make_merge_transform(make_tuple(number<kNPerBlock / 2>{}, number<2>{})),
|
||||
// make_pass_through_transform(kKPerBlock)),
|
||||
// make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
// make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return b_lds_block_desc_n_k;
|
||||
}
|
||||
// return b_lds_block_desc_n_k;
|
||||
// }
|
||||
#endif
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
Reference in New Issue
Block a user