mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4272 (commit 52def72)
feat: add new optimized tutorial kernels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add 01_naive_gemm baseline implementation - Add 02_padding_k_first with PADDING_K_FIRST + MFMA_32x32x16 - Add 03_mfma_16x16x16 with PADDING_K_FIRST + MFMA_16x16x16 - Share common reference_gemm.hpp in parent gemm/ directory ## Proposed changes Please describe the motivation behind the pull request, whether it enables a new feature or fixes a bug. If there are associated pull requests or issues, please link them to the pull request. ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [ ] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [ ] I have added inline documentation which enables the maintainers with understanding the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [ ] I have run `clang-format` on all changed files - [ ] Any dependent changes have been merged ## Discussion If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered
This commit is contained in:
committed by
assistant-librarian[bot]
parent
42973fd546
commit
1bf66006c9
@@ -1,10 +0,0 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_executable(tile_tutorial_naive_gemm practice_gemm.cpp)
|
||||
|
||||
target_compile_options(tile_tutorial_naive_gemm PRIVATE
|
||||
-mllvm -enable-noalias-to-md-conversion=0
|
||||
)
|
||||
|
||||
add_dependencies(tutorials tile_tutorial_naive_gemm)
|
||||
@@ -1,92 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <typename Problem_, typename Policy_ = PracticeGemmHostPolicy>
|
||||
struct PracticeGemmHostPipeline
|
||||
{
|
||||
using ADataType = typename Problem_::ADataType;
|
||||
using BDataType = typename Problem_::BDataType;
|
||||
using CDataType = typename Problem_::CDataType;
|
||||
using AccDataType = typename Problem_::AccDataType;
|
||||
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
using BlockTile = typename Problem::Shape::BlockTile;
|
||||
using WaveTile = typename Problem::Shape::WaveTile;
|
||||
|
||||
template <typename ADRAMTensorView, typename BDRAMTensorView, typename CDRAMTensorView>
|
||||
CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram,
|
||||
const BDRAMTensorView& b_dram,
|
||||
CDRAMTensorView& c_dram) const
|
||||
{
|
||||
|
||||
// Size of the entire problem
|
||||
const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // M x K
|
||||
const auto N = c_dram.get_tensor_descriptor().get_length(number<1>{}); // M x N
|
||||
const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // M x K
|
||||
|
||||
// Size of the block tile
|
||||
const auto MPerBlock = BlockTile::at(number<0>{});
|
||||
const auto NPerBlock = BlockTile::at(number<1>{});
|
||||
const auto KPerBlock = BlockTile::at(number<2>{});
|
||||
|
||||
// Number of block tile in the N direction to cover C (resultant) matrix
|
||||
const auto num_tile_n = integer_divide_ceil(N, NPerBlock);
|
||||
// Number of block tile in the M direction to cover C (resultant) matrix
|
||||
const auto num_tile_m = integer_divide_ceil(M, MPerBlock);
|
||||
|
||||
// if(get_thread_id() == 0 && get_block_id() == 0)
|
||||
// {
|
||||
// printf("num_tile_m: %d, num_tile_n: %d\n", num_tile_m, num_tile_n);
|
||||
// printf("total number of tiles: %d\n", num_tile_m * num_tile_n);
|
||||
// }
|
||||
|
||||
// Get block id
|
||||
const auto id_block =
|
||||
get_block_id(); // 0 to (M_block/BlockTile_M) * (N_block/BlockTile_N) - 1
|
||||
|
||||
// Map block id to tile id
|
||||
const auto block2tile = Policy::MakeBlock2TileMap(num_tile_m, num_tile_n);
|
||||
|
||||
const auto tile_id = block2tile(id_block);
|
||||
|
||||
const auto tile_id_m = tile_id.at(number<0>{});
|
||||
const auto tile_id_n = tile_id.at(number<1>{});
|
||||
|
||||
// if(get_thread_id() == 0 && get_block_id() == 15)
|
||||
// {
|
||||
// printf("tile_id_m: %d, tile_id_n: %d\n", tile_id_m, tile_id_n);
|
||||
// }
|
||||
|
||||
const auto tile_origin_m = tile_id_m * MPerBlock;
|
||||
const auto tile_origin_n = tile_id_n * NPerBlock;
|
||||
|
||||
// create a tile window over dram for A and B
|
||||
const auto a_block_window = make_tile_window(
|
||||
a_dram, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {tile_origin_m, 0});
|
||||
|
||||
const auto b_block_window = make_tile_window(
|
||||
b_dram, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {tile_origin_n, 0});
|
||||
|
||||
constexpr auto block_gemm_pipeline =
|
||||
Policy::template GetPracticeGemmBlockPipeline<Problem>();
|
||||
|
||||
int num_loops_k = integer_divide_ceil(K, KPerBlock);
|
||||
|
||||
__shared__ char p_smem_char[block_gemm_pipeline.GetStaticLDSSize()];
|
||||
const auto c_block_tile =
|
||||
block_gemm_pipeline(a_block_window, b_block_window, num_loops_k, p_smem_char);
|
||||
auto c_window = make_tile_window(c_dram,
|
||||
make_tuple(number<MPerBlock>{}, number<NPerBlock>{}),
|
||||
{tile_origin_m, tile_origin_n});
|
||||
store_tile(c_window, c_block_tile);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -1,54 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
#include "../block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp"
|
||||
#include "../block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
typename AccDataType_,
|
||||
typename Shape_>
|
||||
struct PracticeGemmHostProblem
|
||||
{
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using CDataType = CDataType_;
|
||||
using AccDataType = AccDataType_;
|
||||
using Shape = remove_cvref_t<Shape_>;
|
||||
};
|
||||
|
||||
struct PracticeGemmHostPolicy
|
||||
{
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
|
||||
{
|
||||
const auto unmerge = make_merge_transform(make_tuple(N0, M0));
|
||||
|
||||
return [unmerge](index_t block_id) {
|
||||
multi_index<2> unmerged;
|
||||
unmerge.calculate_lower_index(unmerged, make_multi_index(block_id));
|
||||
|
||||
return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetPracticeGemmBlockPipeline()
|
||||
{
|
||||
using PracticeGemmBlockPipelineProblem_ =
|
||||
PracticeGemmBlockPipelineProblem<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
typename Problem::AccDataType,
|
||||
typename Problem::Shape>;
|
||||
return PracticeGemmBlockPipelineAGmemBGmemCreg<PracticeGemmBlockPipelineProblem_>{};
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -1,141 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "practice_gemm.hpp"
|
||||
#include "reference_gemm.hpp"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
// TODO: GemmTypeConfig
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using CDataType = float;
|
||||
using AccDataType = float;
|
||||
|
||||
// Setup simple argument parser for M, N, K
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "512", "m dimension")
|
||||
.insert("n", "256", "n dimension")
|
||||
.insert("k", "64", "k dimension")
|
||||
.insert("v", "1", "verification: 0=off, 1=on");
|
||||
|
||||
auto result = arg_parser.parse(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
// Get problem dimensions from command line
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
ck_tile::index_t verification = arg_parser.get_int("v");
|
||||
|
||||
ck_tile::index_t stride_a = K;
|
||||
ck_tile::index_t stride_b = K;
|
||||
ck_tile::index_t stride_c = N;
|
||||
|
||||
auto a_lengths = std::array<ck_tile::index_t, 2>{M, K};
|
||||
auto b_lengths = std::array<ck_tile::index_t, 2>{N, K};
|
||||
auto c_lengths = std::array<ck_tile::index_t, 2>{M, N};
|
||||
|
||||
auto a_strides = std::array<ck_tile::index_t, 2>{stride_a, 1};
|
||||
auto b_strides = std::array<ck_tile::index_t, 2>{stride_b, 1};
|
||||
auto c_strides = std::array<ck_tile::index_t, 2>{stride_c, 1};
|
||||
|
||||
// tensors on host (cpu)
|
||||
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
|
||||
ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides);
|
||||
ck_tile::HostTensor<CDataType> c_host(c_lengths, c_strides);
|
||||
|
||||
// initialize tensors
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_host);
|
||||
c_host.SetZero();
|
||||
|
||||
// Print the tensors using the new print_first_n member function
|
||||
// std::cout << "Tensor A (first 10 elements): ";
|
||||
// a_host.print_first_n(10);
|
||||
// std::cout << std::endl;
|
||||
|
||||
// std::cout << "Tensor B (first 10 elements): ";
|
||||
// b_host.print_first_n(10);
|
||||
// std::cout << std::endl;
|
||||
|
||||
// std::cout << "Tensor C (first 10 elements): ";
|
||||
// c_host.print_first_n(10);
|
||||
// std::cout << std::endl;
|
||||
|
||||
// Create device tensors of same size as host tensors and copy data
|
||||
ck_tile::DeviceMem a_device(a_host);
|
||||
ck_tile::DeviceMem b_device(b_host);
|
||||
ck_tile::DeviceMem c_device(c_host);
|
||||
|
||||
// TODO: BlockTileConfig
|
||||
using BlockTile = ck_tile::sequence<256, 128, 32>;
|
||||
using WaveTile = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
std::cout << "Creating PracticeGemmShape, PracticeGemmProblem, PracticeGemmPolicy" << std::endl;
|
||||
using PracticeGemmShape = ck_tile::PracticeGemmShape<BlockTile, WaveTile>;
|
||||
std::cout << "PracticeGemmShape: " << PracticeGemmShape::GetName() << std::endl;
|
||||
using PracticeGemmHostProblem = ck_tile::
|
||||
PracticeGemmHostProblem<ADataType, BDataType, CDataType, AccDataType, PracticeGemmShape>;
|
||||
using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy;
|
||||
|
||||
ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) *
|
||||
ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N);
|
||||
|
||||
std::cout << "Total number of thread blocks: " << kGridSize << std::endl;
|
||||
constexpr ck_tile::index_t kBlockPerCU = 1; // 1 block per CU
|
||||
|
||||
// Block size is now derived from the shape configuration
|
||||
constexpr ck_tile::index_t kBlockSize = PracticeGemmShape::kBlockSize;
|
||||
std::cout << "Number of threads per block: " << kBlockSize << std::endl;
|
||||
std::cout << "Number of blocks per compute unit: " << kBlockPerCU << std::endl;
|
||||
|
||||
using gemm_kernel =
|
||||
ck_tile::PracticeGemmKernel<PracticeGemmHostProblem, PracticeGemmHostPolicy>;
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, 0, 1},
|
||||
ck_tile::make_kernel<kBlockPerCU>(gemm_kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<ADataType*>(a_device.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c));
|
||||
|
||||
auto pass = true;
|
||||
|
||||
if(verification)
|
||||
{
|
||||
// reference gemm
|
||||
ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
|
||||
reference_basic_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_host, b_host, c_host_ref);
|
||||
ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
|
||||
c_device.FromDevice(c_host_dev.mData.data());
|
||||
pass &= ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3);
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
|
||||
}
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
return !pass;
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp"
|
||||
#include "host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockTile_, typename WaveTile_>
|
||||
struct PracticeGemmShape
|
||||
{
|
||||
using BlockTile = remove_cvref_t<BlockTile_>;
|
||||
using WaveTile = remove_cvref_t<WaveTile_>;
|
||||
|
||||
static constexpr index_t BlockTile_M = BlockTile::at(number<0>{});
|
||||
static constexpr index_t BlockTile_N = BlockTile::at(number<1>{});
|
||||
static constexpr index_t BlockTile_K = BlockTile::at(number<2>{});
|
||||
|
||||
static constexpr index_t WaveTile_M = WaveTile::at(number<0>{});
|
||||
static constexpr index_t WaveTile_N = WaveTile::at(number<1>{});
|
||||
static constexpr index_t WaveTile_K = WaveTile::at(number<2>{});
|
||||
|
||||
// Thread block configuration
|
||||
static constexpr index_t kWarpSize = 64; // AMD GPU warp size (also called wavefront)
|
||||
static constexpr index_t kBlockSize = 256; // Total threads per block (4 warps × 64 threads)
|
||||
|
||||
CK_TILE_HOST static std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "practice_gemm_shape",
|
||||
concat('x', BlockTile_M, BlockTile_N, BlockTile_K),
|
||||
concat('x', WaveTile_M, WaveTile_N, WaveTile_K));
|
||||
// clang-format on
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_>
|
||||
struct PracticeGemmKernel
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
// Derive block size from the shape configuration
|
||||
static constexpr index_t kBlockSize = Problem::Shape::kBlockSize;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const typename Problem::ADataType* p_a,
|
||||
const typename Problem::BDataType* p_b,
|
||||
typename Problem::CDataType* p_c,
|
||||
const index_t M,
|
||||
const index_t N,
|
||||
const index_t K,
|
||||
const index_t stride_a,
|
||||
const index_t stride_b,
|
||||
const index_t stride_c) const
|
||||
{
|
||||
|
||||
auto a_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_a, make_tuple(M, K), make_tuple(stride_a, 1), number<8>{}, number<1>{});
|
||||
|
||||
auto b_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_b, make_tuple(N, K), make_tuple(stride_b, 1), number<8>{}, number<1>{});
|
||||
|
||||
const auto c_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_c, make_tuple(M, N), make_tuple(stride_c, 1), number<8>{}, number<1>{});
|
||||
|
||||
PracticeGemmHostPipeline<Problem, Policy>{}(a_dram, b_dram, c_dram);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -6,5 +6,4 @@ include_directories(AFTER
|
||||
)
|
||||
|
||||
add_subdirectory(00_copy_kernel)
|
||||
add_subdirectory(01_naive_gemm)
|
||||
|
||||
add_subdirectory(gemm)
|
||||
|
||||
17
tutorial/ck_tile/gemm/01_naive_gemm/CMakeLists.txt
Normal file
17
tutorial/ck_tile/gemm/01_naive_gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,17 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
set(EXAMPLE_NAIVE_GEMM "tile_tutorial_naive_gemm")
|
||||
|
||||
message(DEBUG "adding example ${EXAMPLE_NAIVE_GEMM}")
|
||||
|
||||
add_executable(${EXAMPLE_NAIVE_GEMM} EXCLUDE_FROM_ALL practice_gemm.cpp)
|
||||
target_include_directories(${EXAMPLE_NAIVE_GEMM} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set(EXAMPLE_NAIVE_GEMM_COMPILE_OPTIONS)
|
||||
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND EXAMPLE_NAIVE_GEMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal -Wno-ctad-maybe-unsupported)
|
||||
|
||||
target_compile_options(${EXAMPLE_NAIVE_GEMM} PRIVATE ${EXAMPLE_NAIVE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_dependencies(tutorials ${EXAMPLE_NAIVE_GEMM})
|
||||
@@ -3,34 +3,34 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "block_gemm_pipeline_agmem_bgmem_creg_policy.hpp"
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = PracticeGemmBlockPolicy>
|
||||
struct PracticeGemmBlockPipelineAGmemBGmemCreg
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, typename Policy = BlockGemmPipelineAGmemBGmemCRegPolicy>
|
||||
struct BlockGemmPipelineAGmemBGmemCReg
|
||||
{
|
||||
using ADataType = typename Problem::ADataType;
|
||||
using BDataType = typename Problem::BDataType;
|
||||
using CDataType = typename Problem::CDataType;
|
||||
using AccDataType = typename Problem::AccDataType;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using BlockTile = typename Problem::Shape::BlockTile;
|
||||
using WaveTile = typename Problem::Shape::WaveTile;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockTile::at(number<0>{});
|
||||
static constexpr index_t NPerBlock = BlockTile::at(number<1>{});
|
||||
static constexpr index_t KPerBlock = BlockTile::at(number<2>{});
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t MPerWave = WaveTile::at(number<0>{});
|
||||
static constexpr index_t NPerWave = WaveTile::at(number<1>{});
|
||||
static constexpr index_t KPerWave = WaveTile::at(number<2>{});
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
using BlockGemm =
|
||||
remove_cvref_t<decltype(Policy::template GetPracticeWaveGemmPipeline<Problem>())>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLDSSize()
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
|
||||
{
|
||||
return integer_divide_ceil(
|
||||
sizeof(ADataType) *
|
||||
@@ -52,9 +52,9 @@ struct PracticeGemmBlockPipelineAGmemBGmemCreg
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// -----------------------------------------------------------------------------------------
|
||||
@@ -82,38 +82,38 @@ struct PracticeGemmBlockPipelineAGmemBGmemCreg
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// A LDS tile window for store
|
||||
auto a_copy_lds_window =
|
||||
make_tile_window(a_lds_block,
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
a_copy_dram_window.get_tile_distribution());
|
||||
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
|
||||
// B LDS tile window for store
|
||||
auto b_copy_lds_window =
|
||||
make_tile_window(b_lds_block,
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
// A LDS tile for block GEMM
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_lds_block, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
@@ -131,28 +131,29 @@ struct PracticeGemmBlockPipelineAGmemBGmemCreg
|
||||
BBlockTile b_block_tile;
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, KPerBlock);
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kKPerBlock);
|
||||
|
||||
// -------------------------------------------------------------------------------------
|
||||
// Gemm pipeline start
|
||||
|
||||
// Initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// non-prefetch
|
||||
index_t iCounter = num_loop;
|
||||
|
||||
while(iCounter > 0)
|
||||
{
|
||||
a_block_tile = load_tile(a_copy_dram_window); // from DRAM to registers
|
||||
b_block_tile = load_tile(b_copy_dram_window); // from DRAM to registers
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
store_tile(a_copy_lds_window, a_block_tile); // from registers to LDS
|
||||
store_tile(b_copy_lds_window, b_block_tile); // from registers to LDS
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); // from LDS to registers
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
|
||||
iCounter--;
|
||||
@@ -3,41 +3,27 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "../warp_level/block_gemm_asmem_bsmem_creg.hpp"
|
||||
|
||||
#include "../warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp"
|
||||
#include "../warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
typename AccDataType_,
|
||||
typename Shape_>
|
||||
struct PracticeGemmBlockPipelineProblem
|
||||
// Default policy for BlockGemmPipelineAGmemBGmemCReg
|
||||
// Default policy class should not be templated, put template on member functions instead
|
||||
struct BlockGemmPipelineAGmemBGmemCRegPolicy
|
||||
{
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using CDataType = CDataType_;
|
||||
using AccDataType = AccDataType_;
|
||||
using Shape = Shape_;
|
||||
};
|
||||
|
||||
struct PracticeGemmBlockPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetPracticeWaveGemmPipeline()
|
||||
{
|
||||
return PracticeGemmWarpPipelineASmemBSmemCreg<Problem>{};
|
||||
}
|
||||
|
||||
// 3d + no padding (NAIVE_IMPLEMENTATION)
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::Shape::BlockTile::at(number<0>{});
|
||||
constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{});
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
@@ -52,14 +38,16 @@ struct PracticeGemmBlockPolicy
|
||||
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
|
||||
// 3d + no padding (NAIVE_IMPLEMENTATION)
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::Shape::BlockTile::at(number<1>{});
|
||||
constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{});
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
@@ -81,14 +69,12 @@ struct PracticeGemmBlockPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BlockGemm = remove_cvref_t<decltype(GetPracticeWaveGemmPipeline<Problem>())>;
|
||||
constexpr index_t kMWarp = BlockGemm::MWarp;
|
||||
constexpr index_t kNWarp = BlockGemm::NWarp;
|
||||
constexpr index_t kBlockSize = kMWarp * kNWarp * get_warp_size();
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::Shape::BlockTile::at(number<0>{});
|
||||
constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{});
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -98,25 +84,23 @@ struct PracticeGemmBlockPolicy
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>, // replication
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, // hierarchy
|
||||
tuple<sequence<1>, sequence<1, 2>>, // parallelism
|
||||
tuple<sequence<1>, sequence<2, 0>>, // paralleism
|
||||
sequence<1, 2>, // yield
|
||||
sequence<0, 1>>{}); // yield
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BlockGemm = remove_cvref_t<decltype(GetPracticeWaveGemmPipeline<Problem>())>;
|
||||
constexpr index_t kMWarp = BlockGemm::MWarp;
|
||||
constexpr index_t kNWarp = BlockGemm::NWarp;
|
||||
constexpr index_t kBlockSize = kMWarp * kNWarp * get_warp_size();
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::Shape::BlockTile::at(number<1>{});
|
||||
constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{});
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(BDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -133,6 +117,12 @@ struct PracticeGemmBlockPolicy
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
return BlockGemmASmemBSmemCReg<Problem>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
72
tutorial/ck_tile/gemm/01_naive_gemm/host_level/grid_gemm.hpp
Normal file
72
tutorial/ck_tile/gemm/01_naive_gemm/host_level/grid_gemm.hpp
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy>
|
||||
struct GridGemm
|
||||
{
|
||||
using ADataType = typename Problem::ADataType;
|
||||
using BDataType = typename Problem::BDataType;
|
||||
using CDataType = typename Problem::CDataType;
|
||||
using AccDataType = typename Problem::AccDataType;
|
||||
using CElementFunction = typename Problem::CElementFunction;
|
||||
|
||||
static constexpr auto kMPerBlock = Policy::kMPerBlock;
|
||||
static constexpr auto kNPerBlock = Policy::kNPerBlock;
|
||||
static constexpr auto kKPerBlock = Policy::kKPerBlock;
|
||||
|
||||
template <typename AGridTensorView, typename BGridTensorView, typename CGridTensorView>
|
||||
CK_TILE_DEVICE void operator()(const AGridTensorView& a_grid,
|
||||
const BGridTensorView& b_grid,
|
||||
CGridTensorView& c_grid,
|
||||
const CElementFunction& c_element_func) const
|
||||
{
|
||||
const auto M = a_grid.get_tensor_descriptor().get_length(number<0>{});
|
||||
const auto N = c_grid.get_tensor_descriptor().get_length(number<1>{});
|
||||
const auto K = a_grid.get_tensor_descriptor().get_length(number<1>{});
|
||||
|
||||
// divide problem
|
||||
const auto id_block = get_block_id();
|
||||
|
||||
const auto num_tile_m = integer_divide_ceil(M, kMPerBlock);
|
||||
const auto num_tile_n = integer_divide_ceil(N, kNPerBlock);
|
||||
|
||||
const auto block2tile = Policy::template MakeBlock2TileMap<Problem>(num_tile_m, num_tile_n);
|
||||
|
||||
const auto id_tile = block2tile(id_block);
|
||||
|
||||
const auto iM = __builtin_amdgcn_readfirstlane(id_tile.template at<0>() * kMPerBlock);
|
||||
const auto iN = __builtin_amdgcn_readfirstlane(id_tile.template at<1>() * kNPerBlock);
|
||||
|
||||
// A block window
|
||||
auto a_block_window = make_tile_window(
|
||||
a_grid, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {iM, 0});
|
||||
|
||||
// B block window
|
||||
auto b_block_window = make_tile_window(
|
||||
b_grid, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {iN, 0});
|
||||
|
||||
constexpr auto block_gemm_pipeline = Policy::template GetBlockGemmPipeline<Problem>();
|
||||
|
||||
__shared__ char p_smem_char[block_gemm_pipeline.GetStaticLdsSize()];
|
||||
|
||||
const auto acc_block_tile =
|
||||
block_gemm_pipeline(a_block_window, b_block_window, K / kKPerBlock, p_smem_char);
|
||||
|
||||
// cast to CDataType and apply CElementFunction
|
||||
const auto c_block_tile = tile_elementwise_in(
|
||||
[&](const auto& acc) { return c_element_func(type_convert<CDataType>(acc)); },
|
||||
acc_block_tile);
|
||||
|
||||
// store C
|
||||
auto c_window = make_tile_window(
|
||||
c_grid, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, iN});
|
||||
|
||||
store_tile(c_window, c_block_tile);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
155
tutorial/ck_tile/gemm/01_naive_gemm/practice_gemm.cpp
Normal file
155
tutorial/ck_tile/gemm/01_naive_gemm/practice_gemm.cpp
Normal file
@@ -0,0 +1,155 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "practice_gemm.hpp"
|
||||
#include "../reference_gemm.hpp"
|
||||
|
||||
/*
|
||||
* Naive GEMM implementation (no optimizations)
|
||||
* A [M, K]
|
||||
* B [N, K]
|
||||
* C [M, N]
|
||||
*/
|
||||
|
||||
// elementwise lambda
|
||||
struct CElementFunction
|
||||
{
|
||||
template <typename X>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const X& x) const
|
||||
{
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
|
||||
ck_tile::index_t verification = 0;
|
||||
ck_tile::index_t M = 3328;
|
||||
ck_tile::index_t N = 4096;
|
||||
ck_tile::index_t K = 4096;
|
||||
|
||||
if(argc == 2)
|
||||
{
|
||||
verification = std::stoi(argv[1]);
|
||||
}
|
||||
if(argc == 5)
|
||||
{
|
||||
verification = std::stoi(argv[1]);
|
||||
M = std::stoi(argv[2]);
|
||||
N = std::stoi(argv[3]);
|
||||
K = std::stoi(argv[4]);
|
||||
}
|
||||
|
||||
printf("*** Naive implementation test ***\n");
|
||||
|
||||
const ck_tile::index_t Lda = K;
|
||||
const ck_tile::index_t Ldb = K;
|
||||
const ck_tile::index_t Ldc = N;
|
||||
|
||||
const auto a_lengths = std::array<ck_tile::index_t, 2>{M, K};
|
||||
const auto a_strides = std::array<ck_tile::index_t, 2>{Lda, 1};
|
||||
|
||||
const auto b_lengths = std::array<ck_tile::index_t, 2>{N, K};
|
||||
const auto b_strides = std::array<ck_tile::index_t, 2>{Ldb, 1};
|
||||
|
||||
const auto c_lengths = std::array<ck_tile::index_t, 2>{M, N};
|
||||
const auto c_strides = std::array<ck_tile::index_t, 2>{Ldc, 1};
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
|
||||
ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides);
|
||||
ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
|
||||
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_host);
|
||||
|
||||
ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
a_buf.ToDevice(a_host.mData.data());
|
||||
b_buf.ToDevice(b_host.mData.data());
|
||||
|
||||
// Alignment
|
||||
constexpr ck_tile::index_t kAAlignment = 8;
|
||||
constexpr ck_tile::index_t kBAlignment = 8;
|
||||
constexpr ck_tile::index_t kCAlignment = 8;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 256;
|
||||
|
||||
constexpr ck_tile::index_t kGemmMPerBlock = 256;
|
||||
constexpr ck_tile::index_t kGemmKPerBlock = 32;
|
||||
constexpr ck_tile::index_t kGemmNPerBlock = 128;
|
||||
|
||||
ck_tile::index_t kGridSize = (M / kGemmMPerBlock) * (N / kGemmNPerBlock);
|
||||
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
constexpr ck_tile::index_t kWarpSize = 64; // AMD GPU warp size
|
||||
constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
|
||||
constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / kWarpSize;
|
||||
constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
|
||||
|
||||
using gemm_kernel = ck_tile::Gemm<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CElementFunction,
|
||||
kAAlignment,
|
||||
kBAlignment,
|
||||
kCAlignment,
|
||||
kBlockSize,
|
||||
kGemmMPerBlock,
|
||||
kGemmNPerBlock,
|
||||
kGemmKPerBlock>;
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, 5, 1000},
|
||||
ck_tile::make_kernel<kBlockPerCu>(gemm_kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<ADataType*>(a_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
Lda,
|
||||
Ldb,
|
||||
Ldc,
|
||||
CElementFunction{}));
|
||||
auto pass = true;
|
||||
|
||||
if(verification)
|
||||
{
|
||||
// reference gemm
|
||||
ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
|
||||
reference_basic_gemm<ADataType, ADataType, AccDataType, CDataType>(
|
||||
a_host, b_host, c_host_ref);
|
||||
c_buf.FromDevice(c_host_dev.mData.data());
|
||||
pass &= ck_tile::check_err(c_host_dev, c_host_ref);
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
|
||||
}
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
return !pass;
|
||||
}
|
||||
139
tutorial/ck_tile/gemm/01_naive_gemm/practice_gemm.hpp
Normal file
139
tutorial/ck_tile/gemm/01_naive_gemm/practice_gemm.hpp
Normal file
@@ -0,0 +1,139 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
#include "block_level/block_gemm_pipeline_agmem_bgmem_creg.hpp"
|
||||
#include "host_level/grid_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename AccDataType_,
|
||||
typename CDataType_,
|
||||
typename CElementFunction_>
|
||||
struct GridGemmProblem
|
||||
{
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using AccDataType = AccDataType_;
|
||||
using CDataType = CDataType_;
|
||||
|
||||
using CElementFunction = CElementFunction_;
|
||||
};
|
||||
|
||||
template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
|
||||
struct TileGemmShape
|
||||
{
|
||||
static constexpr index_t kM = kMPerTile;
|
||||
static constexpr index_t kN = kNPerTile;
|
||||
static constexpr index_t kK = kKPerTile;
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
index_t kBlockSize_,
|
||||
typename BlockGemmShape_>
|
||||
struct BlockGemmPipelineProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
};
|
||||
|
||||
// C = A * B
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename CElementFunction,
|
||||
index_t kAAlignment,
|
||||
index_t kBAlignment,
|
||||
index_t kCAlignment,
|
||||
index_t kBlockSize_,
|
||||
index_t kMPerBlock_,
|
||||
index_t kNPerBlock_,
|
||||
index_t kKPerBlock_>
|
||||
struct Gemm
|
||||
{
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
|
||||
using GridGemmProblem_ =
|
||||
GridGemmProblem<ADataType, BDataType, AccDataType, CDataType, CElementFunction>;
|
||||
|
||||
struct GridGemmPolicy
|
||||
{
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kMPerBlock = kMPerBlock_;
|
||||
static constexpr index_t kNPerBlock = kNPerBlock_;
|
||||
static constexpr index_t kKPerBlock = kKPerBlock_;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
|
||||
{
|
||||
const auto unmerge = make_merge_transform(make_tuple(N0, M0));
|
||||
|
||||
return [unmerge](index_t block_id) {
|
||||
multi_index<2> unmerged;
|
||||
unmerge.calculate_lower_index(unmerged, make_multi_index(block_id));
|
||||
|
||||
return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemmPipeline()
|
||||
{
|
||||
using BlockGemmPipelineProblem_ =
|
||||
BlockGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
kBlockSize,
|
||||
TileGemmShape<kMPerBlock, kNPerBlock, kKPerBlock>>;
|
||||
return BlockGemmPipelineAGmemBGmemCReg<BlockGemmPipelineProblem_>{};
|
||||
}
|
||||
};
|
||||
|
||||
using GridGemm_ = GridGemm<GridGemmProblem_, GridGemmPolicy>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
const index_t M,
|
||||
const index_t N,
|
||||
const index_t K,
|
||||
const index_t Lda,
|
||||
const index_t Ldb,
|
||||
const index_t Ldc,
|
||||
const CElementFunction& c_element_func) const
|
||||
{
|
||||
const auto a_dram = [&] {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
p_a, make_tuple(M, K), make_tuple(Lda, 1), number<kAAlignment>{}, number<1>{});
|
||||
}();
|
||||
|
||||
const auto b_dram = [&] {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
p_b, make_tuple(N, K), make_tuple(Ldb, 1), number<kBAlignment>{}, number<1>{});
|
||||
}();
|
||||
|
||||
const auto c_dram = [&] {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
p_c, make_tuple(M, N), make_tuple(Ldc, 1), number<kCAlignment>{}, number<1>{});
|
||||
}();
|
||||
|
||||
GridGemm_{}(a_dram, b_dram, c_dram, c_element_func);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -4,18 +4,21 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "block_gemm_asmem_bsmem_creg_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = PracticeGemmWarpPolicy>
|
||||
struct PracticeGemmWarpPipelineASmemBSmemCreg
|
||||
// A is block window on shared memory
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem, typename Policy = BlockGemmASmemBSmemCRegPolicy>
|
||||
struct BlockGemmASmemBSmemCReg
|
||||
{
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using WaveGemmShape = remove_cvref_t<typename Problem::Shape>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using WarpGemm = remove_cvref_t<
|
||||
decltype(Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<0>())>;
|
||||
@@ -58,16 +61,14 @@ struct PracticeGemmWarpPipelineASmemBSmemCreg
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == WaveGemmShape::BlockTile_M &&
|
||||
NPerBlock == WaveGemmShape::BlockTile_N &&
|
||||
KPerBlock == WaveGemmShape::BlockTile_K,
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
#if !defined(ENABLE_PREFETCH)
|
||||
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
@@ -116,20 +117,17 @@ struct PracticeGemmWarpPipelineASmemBSmemCreg
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
#endif
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// Read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// Read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
|
||||
// Read C warp tensor from C block tensor
|
||||
@@ -165,13 +163,62 @@ struct PracticeGemmWarpPipelineASmemBSmemCreg
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == WaveGemmShape::BlockTile_M &&
|
||||
NPerBlock == WaveGemmShape::BlockTile_N &&
|
||||
KPerBlock == WaveGemmShape::BlockTile_K,
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
// Construct A-warp-window
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
{a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM,
|
||||
a_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN,
|
||||
b_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
static_assert(std::is_same_v<CDataType, typename WarpGemm::CDataType>, "wrong!");
|
||||
|
||||
@@ -191,6 +238,46 @@ struct PracticeGemmWarpPipelineASmemBSmemCreg
|
||||
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
|
||||
// Hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// Read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// Read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
|
||||
// Read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
// Warp GEMM
|
||||
if constexpr(KIterPerWarp == 0)
|
||||
{
|
||||
// c = a * b
|
||||
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
// c += a * b
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
|
||||
// Write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
@@ -10,14 +10,16 @@ namespace ck_tile {
|
||||
|
||||
// Default policy for BlockGemmASmemBSmemCReg
|
||||
// Default policy class should not be templated, put template on member functions instead
|
||||
struct PracticeGemmWarpPolicy
|
||||
struct BlockGemmASmemBSmemCRegPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
// NAIVE_IMPLEMENTATION uses 4x1 warp configuration
|
||||
constexpr index_t kMWarp = 4;
|
||||
constexpr index_t kNWarp = 1;
|
||||
|
||||
// NAIVE_IMPLEMENTATION uses mfma m32 n32 k8
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
@@ -25,6 +27,13 @@ struct PracticeGemmWarpPolicy
|
||||
return make_tuple(
|
||||
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(
|
||||
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
|
||||
17
tutorial/ck_tile/gemm/02_padding_k_first/CMakeLists.txt
Normal file
17
tutorial/ck_tile/gemm/02_padding_k_first/CMakeLists.txt
Normal file
@@ -0,0 +1,17 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
set(EXAMPLE_PADDING_K_FIRST "tile_tutorial_padding_k_first")
|
||||
|
||||
message(DEBUG "adding example ${EXAMPLE_PADDING_K_FIRST}")
|
||||
|
||||
add_executable(${EXAMPLE_PADDING_K_FIRST} EXCLUDE_FROM_ALL gemm.cpp)
|
||||
target_include_directories(${EXAMPLE_PADDING_K_FIRST} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set(EXAMPLE_PADDING_K_FIRST_COMPILE_OPTIONS)
|
||||
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND EXAMPLE_PADDING_K_FIRST_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal -Wno-ctad-maybe-unsupported)
|
||||
|
||||
target_compile_options(${EXAMPLE_PADDING_K_FIRST} PRIVATE ${EXAMPLE_PADDING_K_FIRST_COMPILE_OPTIONS})
|
||||
|
||||
add_dependencies(tutorials ${EXAMPLE_PADDING_K_FIRST})
|
||||
@@ -0,0 +1,285 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "block_gemm_asmem_bsmem_creg_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block window on shared memory
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem, typename Policy = BlockGemmASmemBSmemCRegPolicy>
|
||||
struct BlockGemmASmemBSmemCReg
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using WarpGemm = remove_cvref_t<
|
||||
decltype(Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<0>())>;
|
||||
static constexpr index_t MWarp =
|
||||
Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<1>();
|
||||
static constexpr index_t NWarp =
|
||||
Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<2>();
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
static constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockWindowTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
[[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
|
||||
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
|
||||
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
|
||||
std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
// Construct A-warp-window
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
{a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM,
|
||||
a_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN,
|
||||
b_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// Read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// Read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
|
||||
// Read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// Warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// Write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockWindowTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
|
||||
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
|
||||
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
// Construct A-warp-window
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
{a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM,
|
||||
a_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN,
|
||||
b_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
static_assert(std::is_same_v<CDataType, typename WarpGemm::CDataType>, "wrong!");
|
||||
|
||||
// Construct C-Block-Tensor
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
|
||||
// Hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// Read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// Read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
|
||||
// Read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
// Warp GEMM
|
||||
if constexpr(KIterPerWarp == 0)
|
||||
{
|
||||
// c = a * b
|
||||
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
// c += a * b
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
|
||||
// Write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Policy for BlockGemmASmemBSmemCReg with MFMA_32x32x16 (8x2) instruction
|
||||
struct BlockGemmASmemBSmemCRegPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
// KERNEL_A uses 4x1 warp configuration
|
||||
constexpr index_t kMWarp = 4;
|
||||
constexpr index_t kNWarp = 1;
|
||||
|
||||
// KERNEL_A uses mfma m32 n32 k16 (8x2 variant)
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(
|
||||
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(
|
||||
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,166 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "block_gemm_pipeline_agmem_bgmem_creg_policy.hpp"
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, typename Policy = BlockGemmPipelineAGmemBGmemCRegPolicy>
|
||||
struct BlockGemmPipelineAGmemBGmemCReg
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
|
||||
{
|
||||
return integer_divide_ceil(
|
||||
sizeof(ADataType) *
|
||||
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
|
||||
16) *
|
||||
16 +
|
||||
sizeof(BDataType) *
|
||||
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// -----------------------------------------------------------------------------------------
|
||||
// Definitions of all needed tiles
|
||||
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
|
||||
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
constexpr index_t a_lds_block_space_size_aligned =
|
||||
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
|
||||
16;
|
||||
|
||||
// B tile in LDS
|
||||
BDataType* p_b_lds = static_cast<BDataType*>(
|
||||
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
|
||||
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// A LDS tile window for store
|
||||
auto a_copy_lds_window =
|
||||
make_tile_window(a_lds_block,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
a_copy_dram_window.get_tile_distribution());
|
||||
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
|
||||
// B LDS tile window for store
|
||||
auto b_copy_lds_window =
|
||||
make_tile_window(b_lds_block,
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
// A LDS tile for block GEMM
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
|
||||
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
|
||||
ABlockTile a_block_tile;
|
||||
BBlockTile b_block_tile;
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kKPerBlock);
|
||||
|
||||
// -------------------------------------------------------------------------------------
|
||||
// Gemm pipeline start
|
||||
|
||||
// Initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// non-prefetch
|
||||
index_t iCounter = num_loop;
|
||||
|
||||
while(iCounter > 0)
|
||||
{
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,129 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "block_gemm_asmem_bsmem_creg.hpp"
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Policy for BlockGemmPipelineAGmemBGmemCReg with PADDING_K_FIRST optimization
|
||||
struct BlockGemmPipelineAGmemBGmemCRegPolicy
|
||||
{
|
||||
// 3d + PADDING_K_FIRST - adds padding to K dimension to avoid bank conflicts
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
// PADDING_K_FIRST: stride is (kKPerBlock / kKPack + 1) * kKPack instead of kKPerBlock
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
|
||||
// 3d + no padding for B (PADDING_K_FIRST only pads A in version2)
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
// B uses same layout as NAIVE (no padding)
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kNPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
// coalesce reading for each blocks
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(BDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
// coalesce reading for each blocks
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
return BlockGemmASmemBSmemCReg<Problem>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
158
tutorial/ck_tile/gemm/02_padding_k_first/gemm.cpp
Normal file
158
tutorial/ck_tile/gemm/02_padding_k_first/gemm.cpp
Normal file
@@ -0,0 +1,158 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm.hpp"
|
||||
#include "../reference_gemm.hpp"
|
||||
|
||||
/*
|
||||
* KERNEL_A: GEMM with PADDING_K_FIRST + MFMA_32x32x16 (8x2)
|
||||
* A [M, K]
|
||||
* B [N, K]
|
||||
* C [M, N]
|
||||
*/
|
||||
|
||||
// elementwise lambda
|
||||
struct CElementFunction
|
||||
{
|
||||
template <typename X>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const X& x) const
|
||||
{
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
|
||||
ck_tile::index_t verification = 0;
|
||||
ck_tile::index_t M = 3328;
|
||||
ck_tile::index_t N = 4096;
|
||||
ck_tile::index_t K = 4096;
|
||||
|
||||
if(argc == 2)
|
||||
{
|
||||
verification = std::stoi(argv[1]);
|
||||
}
|
||||
if(argc == 5)
|
||||
{
|
||||
verification = std::stoi(argv[1]);
|
||||
M = std::stoi(argv[2]);
|
||||
N = std::stoi(argv[3]);
|
||||
K = std::stoi(argv[4]);
|
||||
}
|
||||
|
||||
printf("*** Kernel A test ***\n");
|
||||
printf(" --> Using PADDING_K_FIRST\n");
|
||||
printf(" --> Using mfma_32x32x(8x2)\n");
|
||||
|
||||
const ck_tile::index_t Lda = K;
|
||||
const ck_tile::index_t Ldb = K;
|
||||
const ck_tile::index_t Ldc = N;
|
||||
|
||||
const auto a_lengths = std::array<ck_tile::index_t, 2>{M, K};
|
||||
const auto a_strides = std::array<ck_tile::index_t, 2>{Lda, 1};
|
||||
|
||||
const auto b_lengths = std::array<ck_tile::index_t, 2>{N, K};
|
||||
const auto b_strides = std::array<ck_tile::index_t, 2>{Ldb, 1};
|
||||
|
||||
const auto c_lengths = std::array<ck_tile::index_t, 2>{M, N};
|
||||
const auto c_strides = std::array<ck_tile::index_t, 2>{Ldc, 1};
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
|
||||
ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides);
|
||||
ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
|
||||
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_host);
|
||||
|
||||
ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
a_buf.ToDevice(a_host.mData.data());
|
||||
b_buf.ToDevice(b_host.mData.data());
|
||||
|
||||
// Alignment
|
||||
constexpr ck_tile::index_t kAAlignment = 8;
|
||||
constexpr ck_tile::index_t kBAlignment = 8;
|
||||
constexpr ck_tile::index_t kCAlignment = 8;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 256;
|
||||
|
||||
constexpr ck_tile::index_t kGemmMPerBlock = 256;
|
||||
constexpr ck_tile::index_t kGemmKPerBlock = 32;
|
||||
constexpr ck_tile::index_t kGemmNPerBlock = 128;
|
||||
|
||||
ck_tile::index_t kGridSize = (M / kGemmMPerBlock) * (N / kGemmNPerBlock);
|
||||
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
constexpr ck_tile::index_t kWarpSize = 64; // AMD GPU warp size
|
||||
constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
|
||||
constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / kWarpSize;
|
||||
constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
|
||||
|
||||
using gemm_kernel = ck_tile::Gemm<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CElementFunction,
|
||||
kAAlignment,
|
||||
kBAlignment,
|
||||
kCAlignment,
|
||||
kBlockSize,
|
||||
kGemmMPerBlock,
|
||||
kGemmNPerBlock,
|
||||
kGemmKPerBlock>;
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, 5, 1000},
|
||||
ck_tile::make_kernel<kBlockPerCu>(gemm_kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<ADataType*>(a_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
Lda,
|
||||
Ldb,
|
||||
Ldc,
|
||||
CElementFunction{}));
|
||||
auto pass = true;
|
||||
|
||||
if(verification)
|
||||
{
|
||||
// reference gemm
|
||||
ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
|
||||
reference_basic_gemm<ADataType, ADataType, AccDataType, CDataType>(
|
||||
a_host, b_host, c_host_ref);
|
||||
c_buf.FromDevice(c_host_dev.mData.data());
|
||||
pass &= ck_tile::check_err(c_host_dev, c_host_ref);
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
|
||||
}
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
return !pass;
|
||||
}
|
||||
139
tutorial/ck_tile/gemm/02_padding_k_first/gemm.hpp
Normal file
139
tutorial/ck_tile/gemm/02_padding_k_first/gemm.hpp
Normal file
@@ -0,0 +1,139 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
#include "block_gemm_pipeline_agmem_bgmem_creg.hpp"
|
||||
#include "grid_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename AccDataType_,
|
||||
typename CDataType_,
|
||||
typename CElementFunction_>
|
||||
struct GridGemmProblem
|
||||
{
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using AccDataType = AccDataType_;
|
||||
using CDataType = CDataType_;
|
||||
|
||||
using CElementFunction = CElementFunction_;
|
||||
};
|
||||
|
||||
template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
|
||||
struct TileGemmShape
|
||||
{
|
||||
static constexpr index_t kM = kMPerTile;
|
||||
static constexpr index_t kN = kNPerTile;
|
||||
static constexpr index_t kK = kKPerTile;
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
index_t kBlockSize_,
|
||||
typename BlockGemmShape_>
|
||||
struct BlockGemmPipelineProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
};
|
||||
|
||||
// C = A * B
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename CElementFunction,
|
||||
index_t kAAlignment,
|
||||
index_t kBAlignment,
|
||||
index_t kCAlignment,
|
||||
index_t kBlockSize_,
|
||||
index_t kMPerBlock_,
|
||||
index_t kNPerBlock_,
|
||||
index_t kKPerBlock_>
|
||||
struct Gemm
|
||||
{
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
|
||||
using GridGemmProblem_ =
|
||||
GridGemmProblem<ADataType, BDataType, AccDataType, CDataType, CElementFunction>;
|
||||
|
||||
struct GridGemmPolicy
|
||||
{
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kMPerBlock = kMPerBlock_;
|
||||
static constexpr index_t kNPerBlock = kNPerBlock_;
|
||||
static constexpr index_t kKPerBlock = kKPerBlock_;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
|
||||
{
|
||||
const auto unmerge = make_merge_transform(make_tuple(N0, M0));
|
||||
|
||||
return [unmerge](index_t block_id) {
|
||||
multi_index<2> unmerged;
|
||||
unmerge.calculate_lower_index(unmerged, make_multi_index(block_id));
|
||||
|
||||
return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemmPipeline()
|
||||
{
|
||||
using BlockGemmPipelineProblem_ =
|
||||
BlockGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
kBlockSize,
|
||||
TileGemmShape<kMPerBlock, kNPerBlock, kKPerBlock>>;
|
||||
return BlockGemmPipelineAGmemBGmemCReg<BlockGemmPipelineProblem_>{};
|
||||
}
|
||||
};
|
||||
|
||||
using GridGemm_ = GridGemm<GridGemmProblem_, GridGemmPolicy>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
const index_t M,
|
||||
const index_t N,
|
||||
const index_t K,
|
||||
const index_t Lda,
|
||||
const index_t Ldb,
|
||||
const index_t Ldc,
|
||||
const CElementFunction& c_element_func) const
|
||||
{
|
||||
const auto a_dram = [&] {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
p_a, make_tuple(M, K), make_tuple(Lda, 1), number<kAAlignment>{}, number<1>{});
|
||||
}();
|
||||
|
||||
const auto b_dram = [&] {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
p_b, make_tuple(N, K), make_tuple(Ldb, 1), number<kBAlignment>{}, number<1>{});
|
||||
}();
|
||||
|
||||
const auto c_dram = [&] {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
p_c, make_tuple(M, N), make_tuple(Ldc, 1), number<kCAlignment>{}, number<1>{});
|
||||
}();
|
||||
|
||||
GridGemm_{}(a_dram, b_dram, c_dram, c_element_func);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
72
tutorial/ck_tile/gemm/02_padding_k_first/grid_gemm.hpp
Normal file
72
tutorial/ck_tile/gemm/02_padding_k_first/grid_gemm.hpp
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy>
|
||||
struct GridGemm
|
||||
{
|
||||
using ADataType = typename Problem::ADataType;
|
||||
using BDataType = typename Problem::BDataType;
|
||||
using CDataType = typename Problem::CDataType;
|
||||
using AccDataType = typename Problem::AccDataType;
|
||||
using CElementFunction = typename Problem::CElementFunction;
|
||||
|
||||
static constexpr auto kMPerBlock = Policy::kMPerBlock;
|
||||
static constexpr auto kNPerBlock = Policy::kNPerBlock;
|
||||
static constexpr auto kKPerBlock = Policy::kKPerBlock;
|
||||
|
||||
template <typename AGridTensorView, typename BGridTensorView, typename CGridTensorView>
|
||||
CK_TILE_DEVICE void operator()(const AGridTensorView& a_grid,
|
||||
const BGridTensorView& b_grid,
|
||||
CGridTensorView& c_grid,
|
||||
const CElementFunction& c_element_func) const
|
||||
{
|
||||
const auto M = a_grid.get_tensor_descriptor().get_length(number<0>{});
|
||||
const auto N = c_grid.get_tensor_descriptor().get_length(number<1>{});
|
||||
const auto K = a_grid.get_tensor_descriptor().get_length(number<1>{});
|
||||
|
||||
// divide problem
|
||||
const auto id_block = get_block_id();
|
||||
|
||||
const auto num_tile_m = integer_divide_ceil(M, kMPerBlock);
|
||||
const auto num_tile_n = integer_divide_ceil(N, kNPerBlock);
|
||||
|
||||
const auto block2tile = Policy::template MakeBlock2TileMap<Problem>(num_tile_m, num_tile_n);
|
||||
|
||||
const auto id_tile = block2tile(id_block);
|
||||
|
||||
const auto iM = __builtin_amdgcn_readfirstlane(id_tile.template at<0>() * kMPerBlock);
|
||||
const auto iN = __builtin_amdgcn_readfirstlane(id_tile.template at<1>() * kNPerBlock);
|
||||
|
||||
// A block window
|
||||
auto a_block_window = make_tile_window(
|
||||
a_grid, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {iM, 0});
|
||||
|
||||
// B block window
|
||||
auto b_block_window = make_tile_window(
|
||||
b_grid, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {iN, 0});
|
||||
|
||||
constexpr auto block_gemm_pipeline = Policy::template GetBlockGemmPipeline<Problem>();
|
||||
|
||||
__shared__ char p_smem_char[block_gemm_pipeline.GetStaticLdsSize()];
|
||||
|
||||
const auto acc_block_tile =
|
||||
block_gemm_pipeline(a_block_window, b_block_window, K / kKPerBlock, p_smem_char);
|
||||
|
||||
// cast to CDataType and apply CElementFunction
|
||||
const auto c_block_tile = tile_elementwise_in(
|
||||
[&](const auto& acc) { return c_element_func(type_convert<CDataType>(acc)); },
|
||||
acc_block_tile);
|
||||
|
||||
// store C
|
||||
auto c_window = make_tile_window(
|
||||
c_grid, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, iN});
|
||||
|
||||
store_tile(c_window, c_block_tile);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
17
tutorial/ck_tile/gemm/03_mfma_16x16x16/CMakeLists.txt
Normal file
17
tutorial/ck_tile/gemm/03_mfma_16x16x16/CMakeLists.txt
Normal file
@@ -0,0 +1,17 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
set(EXAMPLE_MFMA_16X16X16 "tile_tutorial_mfma_16x16x16")
|
||||
|
||||
message(DEBUG "adding example ${EXAMPLE_MFMA_16X16X16}")
|
||||
|
||||
add_executable(${EXAMPLE_MFMA_16X16X16} EXCLUDE_FROM_ALL gemm.cpp)
|
||||
target_include_directories(${EXAMPLE_MFMA_16X16X16} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set(EXAMPLE_MFMA_16X16X16_COMPILE_OPTIONS)
|
||||
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND EXAMPLE_MFMA_16X16X16_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal -Wno-ctad-maybe-unsupported)
|
||||
|
||||
target_compile_options(${EXAMPLE_MFMA_16X16X16} PRIVATE ${EXAMPLE_MFMA_16X16X16_COMPILE_OPTIONS})
|
||||
|
||||
add_dependencies(tutorials ${EXAMPLE_MFMA_16X16X16})
|
||||
@@ -0,0 +1,285 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "block_gemm_asmem_bsmem_creg_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block window on shared memory
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem, typename Policy = BlockGemmASmemBSmemCRegPolicy>
|
||||
struct BlockGemmASmemBSmemCReg
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using WarpGemm = remove_cvref_t<
|
||||
decltype(Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<0>())>;
|
||||
static constexpr index_t MWarp =
|
||||
Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<1>();
|
||||
static constexpr index_t NWarp =
|
||||
Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<2>();
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
static constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockWindowTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
[[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
|
||||
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
|
||||
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
|
||||
std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
// Construct A-warp-window
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
{a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM,
|
||||
a_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN,
|
||||
b_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// Read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// Read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
|
||||
// Read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// Warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// Write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockWindowTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
|
||||
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
|
||||
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
// Construct A-warp-window
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
{a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM,
|
||||
a_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN,
|
||||
b_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
static_assert(std::is_same_v<CDataType, typename WarpGemm::CDataType>, "wrong!");
|
||||
|
||||
// Construct C-Block-Tensor
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
|
||||
// Hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// Read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// Read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
|
||||
// Read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
// Warp GEMM
|
||||
if constexpr(KIterPerWarp == 0)
|
||||
{
|
||||
// c = a * b
|
||||
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
// c += a * b
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
|
||||
// Write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Policy for BlockGemmASmemBSmemCReg with MFMA_16x16x16 instruction
|
||||
struct BlockGemmASmemBSmemCRegPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
// KERNEL_B uses 4x1 warp configuration
|
||||
constexpr index_t kMWarp = 4;
|
||||
constexpr index_t kNWarp = 1;
|
||||
|
||||
// KERNEL_B uses mfma m16 n16 k16
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(
|
||||
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(
|
||||
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,166 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "block_gemm_pipeline_agmem_bgmem_creg_policy.hpp"
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, typename Policy = BlockGemmPipelineAGmemBGmemCRegPolicy>
|
||||
struct BlockGemmPipelineAGmemBGmemCReg
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
|
||||
{
|
||||
return integer_divide_ceil(
|
||||
sizeof(ADataType) *
|
||||
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
|
||||
16) *
|
||||
16 +
|
||||
sizeof(BDataType) *
|
||||
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// -----------------------------------------------------------------------------------------
|
||||
// Definitions of all needed tiles
|
||||
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
|
||||
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
constexpr index_t a_lds_block_space_size_aligned =
|
||||
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
|
||||
16;
|
||||
|
||||
// B tile in LDS
|
||||
BDataType* p_b_lds = static_cast<BDataType*>(
|
||||
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
|
||||
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// A LDS tile window for store
|
||||
auto a_copy_lds_window =
|
||||
make_tile_window(a_lds_block,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
a_copy_dram_window.get_tile_distribution());
|
||||
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
|
||||
// B LDS tile window for store
|
||||
auto b_copy_lds_window =
|
||||
make_tile_window(b_lds_block,
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
// A LDS tile for block GEMM
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
|
||||
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
|
||||
ABlockTile a_block_tile;
|
||||
BBlockTile b_block_tile;
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kKPerBlock);
|
||||
|
||||
// -------------------------------------------------------------------------------------
|
||||
// Gemm pipeline start
|
||||
|
||||
// Initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// non-prefetch
|
||||
index_t iCounter = num_loop;
|
||||
|
||||
while(iCounter > 0)
|
||||
{
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,129 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "block_gemm_asmem_bsmem_creg.hpp"
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Policy for BlockGemmPipelineAGmemBGmemCReg with PADDING_K_FIRST optimization
|
||||
struct BlockGemmPipelineAGmemBGmemCRegPolicy
|
||||
{
|
||||
// 3d + PADDING_K_FIRST - adds padding to K dimension to avoid bank conflicts
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
// PADDING_K_FIRST: stride is (kKPerBlock / kKPack + 1) * kKPack instead of kKPerBlock
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
|
||||
// 3d + no padding for B (PADDING_K_FIRST only pads A in version2)
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
// B uses same layout as NAIVE (no padding)
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kNPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
// coalesce reading for each blocks
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(BDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
// coalesce reading for each blocks
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
return BlockGemmASmemBSmemCReg<Problem>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
158
tutorial/ck_tile/gemm/03_mfma_16x16x16/gemm.cpp
Normal file
158
tutorial/ck_tile/gemm/03_mfma_16x16x16/gemm.cpp
Normal file
@@ -0,0 +1,158 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm.hpp"
|
||||
#include "../reference_gemm.hpp"
|
||||
|
||||
/*
|
||||
* KERNEL_B: GEMM with PADDING_K_FIRST + MFMA_16x16x16
|
||||
* A [M, K]
|
||||
* B [N, K]
|
||||
* C [M, N]
|
||||
*/
|
||||
|
||||
// elementwise lambda
|
||||
struct CElementFunction
|
||||
{
|
||||
template <typename X>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const X& x) const
|
||||
{
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
|
||||
ck_tile::index_t verification = 0;
|
||||
ck_tile::index_t M = 3328;
|
||||
ck_tile::index_t N = 4096;
|
||||
ck_tile::index_t K = 4096;
|
||||
|
||||
if(argc == 2)
|
||||
{
|
||||
verification = std::stoi(argv[1]);
|
||||
}
|
||||
if(argc == 5)
|
||||
{
|
||||
verification = std::stoi(argv[1]);
|
||||
M = std::stoi(argv[2]);
|
||||
N = std::stoi(argv[3]);
|
||||
K = std::stoi(argv[4]);
|
||||
}
|
||||
|
||||
printf("*** Kernel B test ***\n");
|
||||
printf(" --> Using PADDING_K_FIRST\n");
|
||||
printf(" --> Using mfma_16x16x16\n");
|
||||
|
||||
const ck_tile::index_t Lda = K;
|
||||
const ck_tile::index_t Ldb = K;
|
||||
const ck_tile::index_t Ldc = N;
|
||||
|
||||
const auto a_lengths = std::array<ck_tile::index_t, 2>{M, K};
|
||||
const auto a_strides = std::array<ck_tile::index_t, 2>{Lda, 1};
|
||||
|
||||
const auto b_lengths = std::array<ck_tile::index_t, 2>{N, K};
|
||||
const auto b_strides = std::array<ck_tile::index_t, 2>{Ldb, 1};
|
||||
|
||||
const auto c_lengths = std::array<ck_tile::index_t, 2>{M, N};
|
||||
const auto c_strides = std::array<ck_tile::index_t, 2>{Ldc, 1};
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
|
||||
ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides);
|
||||
ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
|
||||
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_host);
|
||||
|
||||
ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
a_buf.ToDevice(a_host.mData.data());
|
||||
b_buf.ToDevice(b_host.mData.data());
|
||||
|
||||
// Alignment
|
||||
constexpr ck_tile::index_t kAAlignment = 8;
|
||||
constexpr ck_tile::index_t kBAlignment = 8;
|
||||
constexpr ck_tile::index_t kCAlignment = 8;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 256;
|
||||
|
||||
constexpr ck_tile::index_t kGemmMPerBlock = 256;
|
||||
constexpr ck_tile::index_t kGemmKPerBlock = 32;
|
||||
constexpr ck_tile::index_t kGemmNPerBlock = 128;
|
||||
|
||||
ck_tile::index_t kGridSize = (M / kGemmMPerBlock) * (N / kGemmNPerBlock);
|
||||
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
constexpr ck_tile::index_t kWarpSize = 64; // AMD GPU warp size
|
||||
constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
|
||||
constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / kWarpSize;
|
||||
constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
|
||||
|
||||
using gemm_kernel = ck_tile::Gemm<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CElementFunction,
|
||||
kAAlignment,
|
||||
kBAlignment,
|
||||
kCAlignment,
|
||||
kBlockSize,
|
||||
kGemmMPerBlock,
|
||||
kGemmNPerBlock,
|
||||
kGemmKPerBlock>;
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, 5, 1000},
|
||||
ck_tile::make_kernel<kBlockPerCu>(gemm_kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<ADataType*>(a_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
Lda,
|
||||
Ldb,
|
||||
Ldc,
|
||||
CElementFunction{}));
|
||||
auto pass = true;
|
||||
|
||||
if(verification)
|
||||
{
|
||||
// reference gemm
|
||||
ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
|
||||
reference_basic_gemm<ADataType, ADataType, AccDataType, CDataType>(
|
||||
a_host, b_host, c_host_ref);
|
||||
c_buf.FromDevice(c_host_dev.mData.data());
|
||||
pass &= ck_tile::check_err(c_host_dev, c_host_ref);
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
|
||||
}
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
return !pass;
|
||||
}
|
||||
139
tutorial/ck_tile/gemm/03_mfma_16x16x16/gemm.hpp
Normal file
139
tutorial/ck_tile/gemm/03_mfma_16x16x16/gemm.hpp
Normal file
@@ -0,0 +1,139 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
#include "block_gemm_pipeline_agmem_bgmem_creg.hpp"
|
||||
#include "grid_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename AccDataType_,
|
||||
typename CDataType_,
|
||||
typename CElementFunction_>
|
||||
struct GridGemmProblem
|
||||
{
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using AccDataType = AccDataType_;
|
||||
using CDataType = CDataType_;
|
||||
|
||||
using CElementFunction = CElementFunction_;
|
||||
};
|
||||
|
||||
template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
|
||||
struct TileGemmShape
|
||||
{
|
||||
static constexpr index_t kM = kMPerTile;
|
||||
static constexpr index_t kN = kNPerTile;
|
||||
static constexpr index_t kK = kKPerTile;
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
index_t kBlockSize_,
|
||||
typename BlockGemmShape_>
|
||||
struct BlockGemmPipelineProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
};
|
||||
|
||||
// C = A * B
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename CElementFunction,
|
||||
index_t kAAlignment,
|
||||
index_t kBAlignment,
|
||||
index_t kCAlignment,
|
||||
index_t kBlockSize_,
|
||||
index_t kMPerBlock_,
|
||||
index_t kNPerBlock_,
|
||||
index_t kKPerBlock_>
|
||||
struct Gemm
|
||||
{
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
|
||||
using GridGemmProblem_ =
|
||||
GridGemmProblem<ADataType, BDataType, AccDataType, CDataType, CElementFunction>;
|
||||
|
||||
struct GridGemmPolicy
|
||||
{
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kMPerBlock = kMPerBlock_;
|
||||
static constexpr index_t kNPerBlock = kNPerBlock_;
|
||||
static constexpr index_t kKPerBlock = kKPerBlock_;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
|
||||
{
|
||||
const auto unmerge = make_merge_transform(make_tuple(N0, M0));
|
||||
|
||||
return [unmerge](index_t block_id) {
|
||||
multi_index<2> unmerged;
|
||||
unmerge.calculate_lower_index(unmerged, make_multi_index(block_id));
|
||||
|
||||
return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemmPipeline()
|
||||
{
|
||||
using BlockGemmPipelineProblem_ =
|
||||
BlockGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
kBlockSize,
|
||||
TileGemmShape<kMPerBlock, kNPerBlock, kKPerBlock>>;
|
||||
return BlockGemmPipelineAGmemBGmemCReg<BlockGemmPipelineProblem_>{};
|
||||
}
|
||||
};
|
||||
|
||||
using GridGemm_ = GridGemm<GridGemmProblem_, GridGemmPolicy>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
const index_t M,
|
||||
const index_t N,
|
||||
const index_t K,
|
||||
const index_t Lda,
|
||||
const index_t Ldb,
|
||||
const index_t Ldc,
|
||||
const CElementFunction& c_element_func) const
|
||||
{
|
||||
const auto a_dram = [&] {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
p_a, make_tuple(M, K), make_tuple(Lda, 1), number<kAAlignment>{}, number<1>{});
|
||||
}();
|
||||
|
||||
const auto b_dram = [&] {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
p_b, make_tuple(N, K), make_tuple(Ldb, 1), number<kBAlignment>{}, number<1>{});
|
||||
}();
|
||||
|
||||
const auto c_dram = [&] {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
p_c, make_tuple(M, N), make_tuple(Ldc, 1), number<kCAlignment>{}, number<1>{});
|
||||
}();
|
||||
|
||||
GridGemm_{}(a_dram, b_dram, c_dram, c_element_func);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
72
tutorial/ck_tile/gemm/03_mfma_16x16x16/grid_gemm.hpp
Normal file
72
tutorial/ck_tile/gemm/03_mfma_16x16x16/grid_gemm.hpp
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy>
|
||||
struct GridGemm
|
||||
{
|
||||
using ADataType = typename Problem::ADataType;
|
||||
using BDataType = typename Problem::BDataType;
|
||||
using CDataType = typename Problem::CDataType;
|
||||
using AccDataType = typename Problem::AccDataType;
|
||||
using CElementFunction = typename Problem::CElementFunction;
|
||||
|
||||
static constexpr auto kMPerBlock = Policy::kMPerBlock;
|
||||
static constexpr auto kNPerBlock = Policy::kNPerBlock;
|
||||
static constexpr auto kKPerBlock = Policy::kKPerBlock;
|
||||
|
||||
template <typename AGridTensorView, typename BGridTensorView, typename CGridTensorView>
|
||||
CK_TILE_DEVICE void operator()(const AGridTensorView& a_grid,
|
||||
const BGridTensorView& b_grid,
|
||||
CGridTensorView& c_grid,
|
||||
const CElementFunction& c_element_func) const
|
||||
{
|
||||
const auto M = a_grid.get_tensor_descriptor().get_length(number<0>{});
|
||||
const auto N = c_grid.get_tensor_descriptor().get_length(number<1>{});
|
||||
const auto K = a_grid.get_tensor_descriptor().get_length(number<1>{});
|
||||
|
||||
// divide problem
|
||||
const auto id_block = get_block_id();
|
||||
|
||||
const auto num_tile_m = integer_divide_ceil(M, kMPerBlock);
|
||||
const auto num_tile_n = integer_divide_ceil(N, kNPerBlock);
|
||||
|
||||
const auto block2tile = Policy::template MakeBlock2TileMap<Problem>(num_tile_m, num_tile_n);
|
||||
|
||||
const auto id_tile = block2tile(id_block);
|
||||
|
||||
const auto iM = __builtin_amdgcn_readfirstlane(id_tile.template at<0>() * kMPerBlock);
|
||||
const auto iN = __builtin_amdgcn_readfirstlane(id_tile.template at<1>() * kNPerBlock);
|
||||
|
||||
// A block window
|
||||
auto a_block_window = make_tile_window(
|
||||
a_grid, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {iM, 0});
|
||||
|
||||
// B block window
|
||||
auto b_block_window = make_tile_window(
|
||||
b_grid, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {iN, 0});
|
||||
|
||||
constexpr auto block_gemm_pipeline = Policy::template GetBlockGemmPipeline<Problem>();
|
||||
|
||||
__shared__ char p_smem_char[block_gemm_pipeline.GetStaticLdsSize()];
|
||||
|
||||
const auto acc_block_tile =
|
||||
block_gemm_pipeline(a_block_window, b_block_window, K / kKPerBlock, p_smem_char);
|
||||
|
||||
// cast to CDataType and apply CElementFunction
|
||||
const auto c_block_tile = tile_elementwise_in(
|
||||
[&](const auto& acc) { return c_element_func(type_convert<CDataType>(acc)); },
|
||||
acc_block_tile);
|
||||
|
||||
// store C
|
||||
auto c_window = make_tile_window(
|
||||
c_grid, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, iN});
|
||||
|
||||
store_tile(c_window, c_block_tile);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
10
tutorial/ck_tile/gemm/CMakeLists.txt
Normal file
10
tutorial/ck_tile/gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
include_directories(AFTER
|
||||
${CMAKE_CURRENT_LIST_DIR}
|
||||
)
|
||||
|
||||
add_subdirectory(01_naive_gemm)
|
||||
add_subdirectory(02_padding_k_first)
|
||||
add_subdirectory(03_mfma_16x16x16)
|
||||
@@ -32,5 +32,6 @@ void reference_basic_gemm(const ck_tile::HostTensor<ADataType>& a_m_k,
|
||||
}
|
||||
};
|
||||
|
||||
ck_tile::make_ParallelTensorFunctor(f, c_m_n.mDesc.get_lengths()[0])(1);
|
||||
ck_tile::make_ParallelTensorFunctor(f, c_m_n.mDesc.get_lengths()[0])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
Reference in New Issue
Block a user