mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Merge commit 'c254f3d7b4cccae5c884b419842a01eec4ed74fc' into develop
This commit is contained in:
@@ -75,54 +75,17 @@ struct layernorm2d_fwd_traits_
|
||||
using SmoothScaleDataType = ck_tile::remove_cvref_t<SmoothScaleDataType_>;
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
||||
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
||||
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
||||
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
||||
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
using ThreadPerBlock = ck_tile::sequence<ThreadPerBlock_M_, ThreadPerBlock_N_>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
|
||||
|
||||
@@ -304,6 +304,14 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshufflePrefill_Wmma : public GemmConfigPreshufflePrefill<PrecType>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
|
||||
struct GemmTypeConfig;
|
||||
|
||||
@@ -344,6 +352,24 @@ struct GemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
||||
{
|
||||
|
||||
@@ -91,7 +91,11 @@ int main(int argc, char* argv[])
|
||||
|
||||
try
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_gemm_example<GemmConfigPreshufflePrefill_Wmma>(arg_parser);
|
||||
#else
|
||||
return !run_gemm_example<GemmConfigPreshufflePrefill>(arg_parser);
|
||||
#endif
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
@@ -90,61 +92,6 @@ void permute_tensor_b(Tensor& tensor)
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Tensor>
|
||||
void permute_vectors_i4x4_b(Tensor& tensor)
|
||||
{
|
||||
const ck_tile::index_t K = tensor.get_length(0);
|
||||
const ck_tile::index_t N = tensor.get_length(1);
|
||||
// vector pk_i4x4 permute
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int j = 0; j < K; j += 8)
|
||||
{
|
||||
int8_t input[8];
|
||||
|
||||
for(int k = 0; k < 4; k++)
|
||||
{
|
||||
int8_t i4x2 = tensor(j + k * 2, i).data;
|
||||
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
|
||||
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
|
||||
}
|
||||
|
||||
// permute 01234567->20643175
|
||||
{
|
||||
int8_t hi = input[2];
|
||||
int8_t lo = input[0];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 0, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[6];
|
||||
int8_t lo = input[4];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 2, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[3];
|
||||
int8_t lo = input[1];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 4, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[7];
|
||||
int8_t lo = input[5];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 6, i) = i4x2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename Invoker,
|
||||
typename ADataType,
|
||||
@@ -229,16 +176,43 @@ template <typename GemmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
|
||||
if(ck_tile::is_gfx12_supported())
|
||||
{
|
||||
// TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase<gfx12>
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK0PerLane = 2;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
kABK0PerLane,
|
||||
GemmConfig::K_Warp_Tile / divisor / kABK0PerLane});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
|
||||
}
|
||||
else
|
||||
{
|
||||
int divisor = 1;
|
||||
if(ck_tile::is_gfx11_supported())
|
||||
{
|
||||
divisor = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
@@ -399,7 +373,7 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
BLayout,
|
||||
CLayout>(b_k_n_dev);
|
||||
}
|
||||
permute_vectors_i4x4_b(b_k_n_dev);
|
||||
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
}
|
||||
else
|
||||
|
||||
@@ -5,7 +5,7 @@ KNAME=1
|
||||
export CK_WARMUP=0
|
||||
export CK_REPEAT=1
|
||||
|
||||
COMMON_ARGS='-v=2 -warmup=0 -repeat=1'
|
||||
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
|
||||
|
||||
run_tests() {
|
||||
for m in 512 1024; do
|
||||
@@ -32,5 +32,8 @@ run_tests "fp16"
|
||||
run_tests "bf16"
|
||||
run_tests "fp8"
|
||||
run_tests "bf8"
|
||||
run_tests "fp16i4"
|
||||
run_tests "fp8i4"
|
||||
run_tests "bf8i4"
|
||||
|
||||
set +x
|
||||
|
||||
@@ -5,11 +5,8 @@
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
#include "run_gemm_example.inc"
|
||||
#include "run_gemm_example_common.hpp"
|
||||
@@ -58,7 +55,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int32_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "pk_int4_t")
|
||||
else if(data_type == "fp16i4")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
@@ -74,6 +71,36 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
throw std::runtime_error("Unsupported pipeline for this operation !!!");
|
||||
}
|
||||
}
|
||||
else if(data_type == "fp8i4")
|
||||
{
|
||||
if constexpr(GemmConfig<ck_tile::fp8_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
Invoker,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported pipeline for this operation !!!");
|
||||
}
|
||||
}
|
||||
else if(data_type == "bf8i4")
|
||||
{
|
||||
if constexpr(GemmConfig<ck_tile::bf8_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
Invoker,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported pipeline for this operation !!!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
|
||||
@@ -71,11 +71,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
constexpr bool kTwoPass = true;
|
||||
|
||||
using BlockWarps = ck_tile::sequence<2, 2>;
|
||||
using BlockTile = ck_tile::sequence<2, 128>;
|
||||
using WarpTile = ck_tile::sequence<1, 64>;
|
||||
using Vector = ck_tile::sequence<1, 1>;
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
using BlockTile = ck_tile::sequence<2, 128>;
|
||||
using Vector = ck_tile::sequence<1, 1>;
|
||||
using ThreadPerBlock = ck_tile::sequence<2, 128>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
|
||||
|
||||
using PipelineTraits =
|
||||
ck_tile::Rmsnorm2dFwdTraits<true, // kPadN
|
||||
|
||||
@@ -75,54 +75,17 @@ struct rmsnorm2d_fwd_traits_
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
|
||||
using UnquantYDataType = ck_tile::remove_cvref_t<UnquantYDataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
||||
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
||||
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
||||
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
||||
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
using ThreadPerBlock = ck_tile::sequence<ThreadPerBlock_M_, ThreadPerBlock_N_>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveInvRms = kSaveInvRms_;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -80,55 +80,17 @@ struct add_rmsnorm2d_rdquant_fwd_traits_
|
||||
using InputDataType = ck_tile::remove_cvref_t<InputDataType_>;
|
||||
using QuantizedDataType = ck_tile::remove_cvref_t<QuantizedDataType_>;
|
||||
|
||||
static constexpr auto WarpSize = ck_tile::get_warp_size();
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize;
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (WarpSize / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(WarpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / WarpSize);
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % WarpSize == 0);
|
||||
return ThreadPerBlock_N_ / WarpSize;
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
||||
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
||||
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
using ThreadPerBlock = ck_tile::sequence<ThreadPerBlock_M_, ThreadPerBlock_N_>;
|
||||
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
||||
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveX = kSaveX_;
|
||||
|
||||
@@ -99,12 +99,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
constexpr bool kThreePass = true;
|
||||
|
||||
using BlockWarps = ck_tile::sequence<4, 1>;
|
||||
using BlockTile = ck_tile::sequence<4, 128>;
|
||||
using WarpTile = ck_tile::sequence<1, 64>;
|
||||
using Vector = ck_tile::sequence<1, 1>;
|
||||
using BlockTile = ck_tile::sequence<4, 128>;
|
||||
using Vector = ck_tile::sequence<1, 1>;
|
||||
using ThreadPerBlock = ck_tile::sequence<4, 64>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
|
||||
using Problem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
GammaDataType,
|
||||
|
||||
@@ -94,12 +94,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
constexpr bool kTwoPass = true;
|
||||
|
||||
using BlockWarps = ck_tile::sequence<2, 2>;
|
||||
using BlockTile = ck_tile::sequence<2, 128>;
|
||||
using WarpTile = ck_tile::sequence<1, 64>;
|
||||
using Vector = ck_tile::sequence<1, 1>;
|
||||
using BlockTile = ck_tile::sequence<2, 128>;
|
||||
using Vector = ck_tile::sequence<1, 1>;
|
||||
using ThreadPerBlock = ck_tile::sequence<2, 128>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
|
||||
using Problem = ck_tile::SmoothquantPipelineProblem<XDataType,
|
||||
SmoothScaleDataType,
|
||||
ComputeDataType,
|
||||
|
||||
@@ -49,54 +49,16 @@ struct smoothquant_traits_
|
||||
{
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
||||
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
||||
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
||||
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
||||
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
using ThreadPerBlock = ck_tile::sequence<ThreadPerBlock_M_, ThreadPerBlock_N_>;
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
|
||||
@@ -38,54 +38,17 @@ struct moe_smoothquant_traits_
|
||||
using InputType = ck_tile::remove_cvref_t<InputType_>;
|
||||
using OutputType = ck_tile::remove_cvref_t<OutputType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
||||
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
||||
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
using ThreadPerBlock = ck_tile::sequence<ThreadPerBlock_M_, ThreadPerBlock_N_>;
|
||||
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
||||
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
|
||||
@@ -190,6 +190,30 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
|
||||
static constexpr bool kPadK = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 32 / sizeof(PrecType);
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool kPadK = true;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
@@ -266,16 +290,43 @@ template <typename GemmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
|
||||
if(ck_tile::is_gfx12_supported())
|
||||
{
|
||||
// TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase<gfx12>
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK0PerLane = 2;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
kABK0PerLane,
|
||||
GemmConfig::K_Warp_Tile / divisor / kABK0PerLane});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
|
||||
}
|
||||
else
|
||||
{
|
||||
int divisor = 1;
|
||||
if(ck_tile::is_gfx11_supported())
|
||||
{
|
||||
divisor = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
|
||||
@@ -230,5 +230,9 @@ int run_grouped_gemm_example(int argc, char* argv[])
|
||||
}
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_grouped_gemm_example<GemmConfigPreshuffleDecode_Wmma>(argc, argv);
|
||||
#else
|
||||
return !run_grouped_gemm_example<GemmConfigPreshuffleDecode>(argc, argv);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -268,6 +268,9 @@ int main(int argc, char* argv[])
|
||||
|
||||
try
|
||||
{
|
||||
#if defined(CK_TILE_USE_WMMA)
|
||||
return !run_flatmm_example<FlatmmConfig16_Wmma>(argc, argv);
|
||||
#else
|
||||
int warp_tile = arg_parser.get_int("warp_tile");
|
||||
if(warp_tile == 0)
|
||||
{
|
||||
@@ -285,6 +288,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return !run_flatmm_example<FlatmmConfig32_950>(argc, argv);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
|
||||
@@ -86,6 +86,14 @@ struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FlatmmConfig16_Wmma : public FlatmmConfig16<DataType>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
template <typename ADataType>
|
||||
struct GemmBasicTypeConfig;
|
||||
|
||||
@@ -183,8 +191,10 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
#if !defined(CK_TILE_USE_WMMA)
|
||||
.insert(
|
||||
"warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)")
|
||||
#endif
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "flatmm_basic.json", "json file name to dump results");
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
|
||||
@@ -43,15 +43,40 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
|
||||
int divisor = ck_tile::is_wave32() ? (FlatmmConfig::N_Warp_Tile == 32 ? 1 : 2)
|
||||
: (FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4);
|
||||
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
k_ / FlatmmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
FlatmmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
if(ck_tile::is_gfx12_supported())
|
||||
{
|
||||
// TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase<gfx12>
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK0PerLane = 2;
|
||||
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
k_ / FlatmmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
kABK0PerLane,
|
||||
FlatmmConfig::K_Warp_Tile / divisor / kABK0PerLane});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
|
||||
}
|
||||
else
|
||||
{
|
||||
int divisor = 1;
|
||||
if(ck_tile::is_gfx11_supported())
|
||||
{
|
||||
divisor = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
k_ / FlatmmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
FlatmmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
|
||||
@@ -228,4 +228,4 @@ int run_gemm_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigDecode>(argc, argv); }
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigQuant>(argc, argv); }
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#pragma once
|
||||
#include <random>
|
||||
#include <stdexcept>
|
||||
#include "../00_shared/host_tensor_utils.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
@@ -217,7 +218,16 @@ int run_gemm_example_with_layouts(int argc,
|
||||
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
|
||||
}
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::HostTensor<ADataType> a_m_k_dev = a_m_k;
|
||||
ck_tile::permute_vectors_i4x4_b(a_m_k_dev);
|
||||
a_m_k_dev_buf.ToDevice(a_m_k_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
}
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
#include <random>
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
@@ -208,7 +209,17 @@ int run_gemm_example_with_layouts(int argc,
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
#include <random>
|
||||
#include <stdexcept>
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
@@ -308,7 +309,17 @@ int run_gemm_example_with_layouts(int argc,
|
||||
aq_dev_buf.ToDevice(aq_tensor.data());
|
||||
}
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<ADataType> a_m_k_dev = a_m_k;
|
||||
ck_tile::permute_vectors_i4x4_b(a_m_k_dev);
|
||||
a_m_k_dev_buf.ToDevice(a_m_k_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
}
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
@@ -125,7 +125,7 @@ CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t_signed_conversion(const pk_in
|
||||
float x_h = ((x_u8 & 0xf0) >> 4);
|
||||
|
||||
x_l = x_l > 7 ? x_l - 16 : x_l;
|
||||
x_h = x_l > 7 ? x_l - 16 : x_l;
|
||||
x_h = x_h > 7 ? x_h - 16 : x_h;
|
||||
|
||||
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
|
||||
fp32x2_t res = {x_h, x_l};
|
||||
|
||||
78
include/ck_tile/host/permute_pk_int4.hpp
Normal file
78
include/ck_tile/host/permute_pk_int4.hpp
Normal file
@@ -0,0 +1,78 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c), Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Permute packed int4 vectors for device implementation compatibility
|
||||
*
|
||||
* This function transforms 4 pk_int4_t values from original layout to hardware-optimized layout:
|
||||
* - Original layout (4 pk_int4_t): 0x76543210
|
||||
* - Transformed layout (4 pk_int4_t): 0x75316420
|
||||
*
|
||||
* Each pk_int4_t contains two 4-bit values packed in the high and low nibbles of an int8_t
|
||||
*
|
||||
* Example:
|
||||
* - Input: 0x76, 0x54, 0x32, 0x10
|
||||
* - Output: 0x75, 0x31, 0x64, 0x20
|
||||
*
|
||||
* @note Input tensor length must be a multiple of 4
|
||||
*
|
||||
* This transformation is required before transferring B matrix data (of type pk_int4_t) to device.
|
||||
* The device conversion functions (i4_to_half4, i4_to_bhalf4, amd_assembly_i4_to_fp8x8,
|
||||
* amd_assembly_i4_to_bf8x8) require data in 0x75316420 order to correctly convert pk_int4_t to
|
||||
* other numeric types.
|
||||
*/
|
||||
template <typename Tensor>
|
||||
void permute_vectors_i4x4_b(Tensor& tensor)
|
||||
{
|
||||
auto tensor_row_buf = tensor.data();
|
||||
for(size_t idx = 0; idx < tensor.size(); idx += 4)
|
||||
{
|
||||
int8_t input[8];
|
||||
|
||||
for(int k = 0; k < 4; k++)
|
||||
{
|
||||
int8_t i4x2 = bit_cast<int8_t>(tensor_row_buf[idx + k]);
|
||||
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
|
||||
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
|
||||
}
|
||||
|
||||
// permute 0x76543210 => 0x75316420
|
||||
{
|
||||
int8_t hi = input[2];
|
||||
int8_t lo = input[0];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor_row_buf[idx + 0] = bit_cast<pk_int4_t>(i4x2);
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[6];
|
||||
int8_t lo = input[4];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor_row_buf[idx + 1] = bit_cast<pk_int4_t>(i4x2);
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[3];
|
||||
int8_t lo = input[1];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor_row_buf[idx + 2] = bit_cast<pk_int4_t>(i4x2);
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[7];
|
||||
int8_t lo = input[5];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor_row_buf[idx + 3] = bit_cast<pk_int4_t>(i4x2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -50,7 +50,7 @@ CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
@@ -63,7 +63,7 @@ CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -35,43 +35,69 @@ namespace ck_tile {
|
||||
+-----------+-----------+-----------+-----------+-----------+
|
||||
// clang-format on
|
||||
*/
|
||||
template <typename BlockTile_, // block size, seq<M, N>
|
||||
typename WarpPerBlock_, // num warps along seq<M, N>
|
||||
typename WarpTile_, // warp size, seq<M, N>
|
||||
typename Vector_> // contiguous pixels(vector size) along seq<M, N>)>
|
||||
template <typename BlockTile_, // block size, seq<M, N>
|
||||
typename ThreadPerBlock_, // num threads along seq<M, N>
|
||||
typename Vector_> // contiguous pixels(vector size) along seq<M, N>)>
|
||||
struct Generic2dBlockShape
|
||||
{
|
||||
// block size
|
||||
static constexpr index_t Block_M = BlockTile_::at(number<0>{});
|
||||
static constexpr index_t Block_N = BlockTile_::at(number<1>{});
|
||||
|
||||
// num warps along seq<M, N>, within each block
|
||||
static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{});
|
||||
static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{});
|
||||
|
||||
// warp size
|
||||
static constexpr index_t Warp_M = WarpTile_::at(number<0>{});
|
||||
static constexpr index_t Warp_N = WarpTile_::at(number<1>{});
|
||||
|
||||
static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
|
||||
static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);
|
||||
// repeat of each thread along seq<M, N>
|
||||
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
|
||||
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
|
||||
static constexpr index_t Block_M = BlockTile_::at(number<0>{});
|
||||
static constexpr index_t Block_N = BlockTile_::at(number<1>{});
|
||||
static constexpr index_t ThreadPerBlock_M = ThreadPerBlock_::at(number<0>{});
|
||||
static constexpr index_t ThreadPerBlock_N = ThreadPerBlock_::at(number<1>{});
|
||||
static constexpr index_t BlockSize = ThreadPerBlock_M * ThreadPerBlock_N;
|
||||
|
||||
// vector size along seq<M, N>
|
||||
static constexpr index_t Vector_M = Vector_::at(number<0>{});
|
||||
static constexpr index_t Vector_N = Vector_::at(number<1>{});
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N <= get_warp_size();
|
||||
static_assert((ThreadPerBlock_M * ThreadPerBlock_N) % get_warp_size() == 0);
|
||||
static constexpr index_t total_warps = (ThreadPerBlock_M * ThreadPerBlock_N) / get_warp_size();
|
||||
|
||||
// num warps along seq<M, N>, within each block
|
||||
static constexpr index_t WarpPerBlock_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(get_warp_size() % ThreadPerBlock_N == 0);
|
||||
return total_warps * (get_warp_size() / ThreadPerBlock_N);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N / get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr index_t WarpPerBlock_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(get_warp_size() % ThreadPerBlock_N == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N % get_warp_size() == 0);
|
||||
return ThreadPerBlock_N / get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
// warp size
|
||||
static constexpr index_t Warp_M = ThreadPerBlock_M / WarpPerBlock_M * Vector_M;
|
||||
static constexpr index_t Warp_N = ThreadPerBlock_N / WarpPerBlock_N * Vector_N;
|
||||
static_assert(Warp_M % Vector_M == 0);
|
||||
static_assert(Warp_N % Vector_N == 0);
|
||||
// num of threads along seq<M, N>, within each warp
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
|
||||
static constexpr index_t ThreadPerBlock_M = Block_M / Repeat_M / Vector_M;
|
||||
static constexpr index_t ThreadPerBlock_N = Block_N / Repeat_N / Vector_N;
|
||||
static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
|
||||
static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);
|
||||
|
||||
static constexpr index_t BlockSize = ThreadPerBlock_M * ThreadPerBlock_N;
|
||||
// repeat of each thread along seq<M, N>
|
||||
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
|
||||
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
|
||||
|
||||
// num of threads along seq<M, N>, within each warp
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -4,15 +4,29 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
namespace element_wise {
|
||||
|
||||
// Fast int4x4 to fp16x8_t data type conversion based on paper
|
||||
// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
|
||||
// (https://arxiv.org/abs/2211.10017) and implementation:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
/**
|
||||
* @brief Fast int4x4 to fp16x8_t data type conversion based on paper
|
||||
* "Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production"
|
||||
* @see https://arxiv.org/abs/2211.10017
|
||||
* @see
|
||||
* https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
*
|
||||
* This function converts 4 4-bit integers into 4 fp16 values.
|
||||
* @note `int q` contains 4 bytes, low 4 bits of each byte represent an int4.
|
||||
* @note This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to fp16(-8)
|
||||
* @note The output ordering differs from input ordering. For example, when input is 0x76543210,
|
||||
* the output sequence will be fp16(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor
|
||||
* must be preprocessed with permute_vectors_i4x4_b on the host side before using this
|
||||
* function.
|
||||
*
|
||||
* @see permute_vectors_i4x4_b
|
||||
*/
|
||||
CK_TILE_DEVICE fp16x4_t i4_to_half4(int q)
|
||||
{
|
||||
const int LO = 0x000f000f;
|
||||
@@ -46,6 +60,18 @@ CK_TILE_DEVICE fp16x4_t i4_to_half4(int q)
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief This function dequantizes 4 int4 values into 4 fp16 values and applies scaling.
|
||||
*
|
||||
* @note `int q` contains 4 bytes, low 4 bits of each byte represent an int4.
|
||||
* @note This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to fp16(-8)
|
||||
* @note The output ordering differs from input ordering. For example, when input is 0x76543210,
|
||||
* the output sequence will be fp16(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor
|
||||
* must be preprocessed with permute_vectors_i4x4_b on the host side before using this
|
||||
* function.
|
||||
*
|
||||
* @see permute_vectors_i4x4_b
|
||||
*/
|
||||
CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale)
|
||||
{
|
||||
const int LO = 0x000f000f;
|
||||
@@ -81,6 +107,18 @@ CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale)
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief This function converts 4 4-bit integers into 4 bf16 values.
|
||||
*
|
||||
* @note `int q` contains 4 bytes, low 4 bits of each byte represent an int4.
|
||||
* @note This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to bf16(-8)
|
||||
* @note The output ordering differs from input ordering. For example, when input is 0x76543210,
|
||||
* the output sequence will be bf16(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor
|
||||
* must be preprocessed with permute_vectors_i4x4_b on the host side before using this
|
||||
* function.
|
||||
*
|
||||
* @see permute_vectors_i4x4_b
|
||||
*/
|
||||
CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
|
||||
{
|
||||
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
|
||||
@@ -110,37 +148,55 @@ CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief This function converts 8 packed 4-bit integers into 8 fp8 values.
|
||||
*
|
||||
* @note `int q` contains 4 bytes, each byte represents 2 int4.
|
||||
* @note This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to fp8(-8)
|
||||
* @note The output ordering differs from input ordering. For example, when input is 0x76543210,
|
||||
* the output sequence will be fp8(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor
|
||||
* must be preprocessed with permute_vectors_i4x4_b on the host side before using this
|
||||
* function.
|
||||
*
|
||||
* @see permute_vectors_i4x4_b
|
||||
*/
|
||||
CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a)
|
||||
{
|
||||
uint32_t src = static_cast<uint32_t>(a), src_hi;
|
||||
uint32_t fp8x4_lo, fp8x4_hi;
|
||||
float tmp_0, tmp_1;
|
||||
// register values [3, 2, 1, 0]
|
||||
static constexpr uint32_t reg0 = 0xd2d4d6d8;
|
||||
// register values [7, 6, 5, 4]
|
||||
static constexpr uint32_t reg1 = 0xc0c8ccd0;
|
||||
// register values [-1, -2, -3, -4]
|
||||
static constexpr uint32_t reg2 = 0x4C484000;
|
||||
// register values [-5, -6, -7, -8]
|
||||
static constexpr uint32_t reg3 = 0x56545250;
|
||||
|
||||
asm volatile("v_lshrrev_b32 %[v_hi_src], 4, %[v_src]\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_3\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_3\n"
|
||||
"v_cvt_pk_fp8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
|
||||
uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel;
|
||||
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_2\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_2\n"
|
||||
"v_cvt_pk_fp8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0]\n"
|
||||
uint32_t dict_sel = a & 0x07070707;
|
||||
uint32_t sign = a >> 1;
|
||||
asm volatile("v_and_or_b32 %0, %1, %2, %3"
|
||||
: "=v"(final_sel)
|
||||
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
|
||||
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_1\n"
|
||||
"v_cvt_pk_fp8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
|
||||
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
|
||||
tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
|
||||
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src]\n"
|
||||
"v_cvt_pk_fp8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0]\n"
|
||||
: [v_tmp_0] "+v"(tmp_0),
|
||||
[v_tmp_1] "+v"(tmp_1),
|
||||
[v_hi_src] "+v"(src_hi),
|
||||
[v_dst_lo] "+v"(fp8x4_lo),
|
||||
[v_dst_hi] "+v"(fp8x4_hi),
|
||||
[v_src] "+v"(src)
|
||||
:);
|
||||
a >>= 4;
|
||||
dict_sel = a & 0x07070707;
|
||||
sign = a >> 1;
|
||||
asm volatile("v_and_or_b32 %0, %1, %2, %3"
|
||||
: "=v"(final_sel)
|
||||
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
|
||||
|
||||
return bit_cast<fp8x8_t>(((static_cast<uint64_t>(fp8x4_hi) << 32) | fp8x4_lo));
|
||||
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
|
||||
tmp_res_odd = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
|
||||
auto tmp_res_low = __builtin_amdgcn_perm(tmp_res_odd, tmp_res_even, 0x06040200);
|
||||
auto tmp_res_high = __builtin_amdgcn_perm(tmp_res_odd, tmp_res_even, 0x07050301);
|
||||
|
||||
return bit_cast<fp8x8_t>((static_cast<uint64_t>(tmp_res_high) << 32) | tmp_res_low);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src)
|
||||
@@ -157,37 +213,55 @@ CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src)
|
||||
return res;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(int a)
|
||||
/**
|
||||
* @brief This function converts 8 packed 4-bit integers into 8 bf8 values.
|
||||
*
|
||||
* @note `int q` contains 4 bytes, each byte represents 2 int4.
|
||||
* @note This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to bf8(-8)
|
||||
* @note The output ordering differs from input ordering. For example, when input is 0x76543210,
|
||||
* the output sequence will be bf8(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor
|
||||
* must be preprocessed with permute_vectors_i4x4_b on the host side before using this
|
||||
* function.
|
||||
*
|
||||
* @see permute_vectors_i4x4_b
|
||||
*/
|
||||
CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a)
|
||||
{
|
||||
uint32_t src = static_cast<uint32_t>(a), src_hi;
|
||||
uint32_t bf8x4_lo, bf8x4_hi;
|
||||
float tmp_0, tmp_1;
|
||||
// register values [3, 2, 1, 0]
|
||||
static constexpr uint32_t reg0 = 0Xc9cacbcc;
|
||||
// register values [7, 6, 5, 4]
|
||||
static constexpr uint32_t reg1 = 0Xc0c4c6c8;
|
||||
// register values [11, 10, 9, 8]
|
||||
static constexpr uint32_t reg2 = 0X46444000;
|
||||
// register values [15, 14, 13, 12]
|
||||
static constexpr uint32_t reg3 = 0X4b4a4948;
|
||||
|
||||
asm volatile("v_lshrrev_b32 %[v_hi_src], 4, %[v_src]\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_3\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_3\n"
|
||||
"v_cvt_pk_bf8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
|
||||
uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel;
|
||||
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_2\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_2\n"
|
||||
"v_cvt_pk_bf8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0]\n"
|
||||
uint32_t dict_sel = a & 0x07070707;
|
||||
uint32_t sign = a >> 1;
|
||||
asm volatile("v_and_or_b32 %0, %1, %2, %3"
|
||||
: "=v"(final_sel)
|
||||
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
|
||||
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_1\n"
|
||||
"v_cvt_pk_bf8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
|
||||
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
|
||||
tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
|
||||
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src]\n"
|
||||
"v_cvt_pk_bf8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0]\n"
|
||||
: [v_tmp_0] "+v"(tmp_0),
|
||||
[v_tmp_1] "+v"(tmp_1),
|
||||
[v_hi_src] "+v"(src_hi),
|
||||
[v_dst_lo] "+v"(bf8x4_lo),
|
||||
[v_dst_hi] "+v"(bf8x4_hi),
|
||||
[v_src] "+v"(src)
|
||||
:);
|
||||
a >>= 4;
|
||||
dict_sel = a & 0x07070707;
|
||||
sign = a >> 1;
|
||||
asm volatile("v_and_or_b32 %0, %1, %2, %3"
|
||||
: "=v"(final_sel)
|
||||
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
|
||||
|
||||
return bit_cast<bf8x8_t>(((static_cast<uint64_t>(bf8x4_hi) << 32) | bf8x4_lo));
|
||||
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
|
||||
tmp_res_odd = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
|
||||
auto tmp_res_low = __builtin_amdgcn_perm(tmp_res_odd, tmp_res_even, 0x06040200);
|
||||
auto tmp_res_high = __builtin_amdgcn_perm(tmp_res_odd, tmp_res_even, 0x07050301);
|
||||
|
||||
return bit_cast<bf8x8_t>((static_cast<uint64_t>(tmp_res_high) << 32) | tmp_res_low);
|
||||
}
|
||||
|
||||
struct PassThroughPack8
|
||||
@@ -209,12 +283,12 @@ struct PassThroughPack8
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_int4x4_t& x) const
|
||||
{
|
||||
y = amd_assembly_i4_to_fp8x8(bit_cast<int>(x));
|
||||
y = amd_assembly_i4_to_fp8x8(bit_cast<uint32_t>(x));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(bf8x8_t& y, const pk_int4x4_t& x) const
|
||||
{
|
||||
y = amd_assembly_i4_to_bf8x8(bit_cast<int>(x));
|
||||
y = amd_assembly_i4_to_bf8x8(bit_cast<uint32_t>(x));
|
||||
}
|
||||
constexpr const static bool is_pack8_invocable = true;
|
||||
};
|
||||
|
||||
@@ -127,7 +127,10 @@ struct FlatmmKernel
|
||||
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize()
|
||||
{
|
||||
return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr KernelArgs
|
||||
MakeKernelArgs(const FlatmmHostArgs<NumDTensor>& hostArgs)
|
||||
|
||||
@@ -185,11 +185,11 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
|
||||
@@ -237,8 +237,12 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
#if defined(__gfx11__)
|
||||
constexpr index_t scale = 4;
|
||||
#else
|
||||
constexpr index_t scale = get_warp_size() == 32 ? 2 : 1;
|
||||
#endif
|
||||
if constexpr(TileShape::WarpTile::at(I1) == 32)
|
||||
{
|
||||
return TileShape::WarpTile::at(I2) * scale / 2;
|
||||
@@ -342,7 +346,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
|
||||
@@ -350,8 +354,13 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
|
||||
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
|
||||
#if defined(__gfx11__)
|
||||
constexpr index_t KRepeatInWave = 2;
|
||||
#else
|
||||
constexpr index_t KRepeatInWave = 1;
|
||||
#endif
|
||||
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = 1;
|
||||
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
|
||||
@@ -362,16 +371,15 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<WaveRepeat>, // ?
|
||||
sequence<WaveRepeat, KRepeatInWave>, // ?
|
||||
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
|
||||
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<0, 1, 2>, sequence<1, 2>>, // which direction
|
||||
tuple<sequence<0, 1, 1>, sequence<2, 2>>, // which index
|
||||
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
|
||||
tuple<sequence<0, 1, 1>, sequence<1, 2, 2>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
|
||||
@@ -89,14 +89,19 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
#if defined(__gfx11__)
|
||||
constexpr index_t scale = 4;
|
||||
#else
|
||||
constexpr index_t scale = get_warp_size() == 32 ? 2 : 1;
|
||||
#endif
|
||||
if constexpr(TileShape::WarpTile::at(I1) == 32)
|
||||
{
|
||||
return TileShape::WarpTile::at(I2) / 2;
|
||||
return TileShape::WarpTile::at(I2) * scale / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(TileShape::WarpTile::at(I1) == 16);
|
||||
return TileShape::WarpTile::at(I2) / 4;
|
||||
return TileShape::WarpTile::at(I2) * scale / 4;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -192,7 +197,7 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
@@ -200,8 +205,13 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
|
||||
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
|
||||
#if defined(__gfx11__)
|
||||
constexpr index_t KRepeatInWave = 2;
|
||||
#else
|
||||
constexpr index_t KRepeatInWave = 1;
|
||||
#endif
|
||||
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = 1;
|
||||
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
|
||||
@@ -212,16 +222,15 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<WaveRepeat>, // ?
|
||||
sequence<WaveRepeat, KRepeatInWave>, // ?
|
||||
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
|
||||
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<0, 1, 2>, sequence<1, 2>>, // which direction
|
||||
tuple<sequence<0, 1, 1>, sequence<2, 2>>, // which index
|
||||
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
|
||||
tuple<sequence<0, 1, 1>, sequence<1, 2, 2>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
|
||||
@@ -189,11 +189,11 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
|
||||
@@ -146,10 +146,14 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
static constexpr index_t mfma_per_wg = 1;
|
||||
#endif
|
||||
static constexpr index_t dsread_per_wg =
|
||||
WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize;
|
||||
static_assert((WG::kM * WG::kK * sizeof(ADataType) / WaveSize) % Problem::VectorLoadSize == 0);
|
||||
|
||||
static constexpr index_t dsread_num_perK = dsread_per_wg * MIterPerWarp;
|
||||
max(index_t(WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize), 1);
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
static_assert((WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize) %
|
||||
Problem::VectorLoadSize ==
|
||||
0);
|
||||
#endif
|
||||
static constexpr index_t dsread_num_perK =
|
||||
WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize / Problem::VectorLoadSize;
|
||||
static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp);
|
||||
static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp;
|
||||
static constexpr index_t Aload_num_perK = dswrite_num_perK;
|
||||
@@ -499,12 +503,12 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename AElementFunction>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>>,
|
||||
|
||||
@@ -181,9 +181,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
static constexpr index_t MWarp = Traits::MWarp;
|
||||
static constexpr index_t NWarp = Traits::NWarp;
|
||||
|
||||
static constexpr auto Scheduler = Traits::Scheduler;
|
||||
static constexpr uint8_t kA_cvt_scale = std::is_same_v<ADataType, pk_int4_t> ? 16 : 1;
|
||||
static constexpr uint8_t kB_cvt_scale = std::is_same_v<BDataType, pk_int4_t> ? 16 : 1;
|
||||
static constexpr auto Scheduler = Traits::Scheduler;
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
@@ -451,7 +449,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] *
|
||||
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
|
||||
scale_reg_f);
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -471,7 +469,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
[&](auto c_row) {
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] *
|
||||
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
|
||||
scale_reg_f);
|
||||
});
|
||||
}
|
||||
else
|
||||
@@ -556,7 +554,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
reg_offset_for_row_data] +=
|
||||
(c_warp_tensor
|
||||
.get_thread_buffer()[reg_offset_for_row_data] *
|
||||
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
|
||||
scale_reg_f);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,9 +179,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
static constexpr index_t MWarp = Traits::MWarp;
|
||||
static constexpr index_t NWarp = Traits::NWarp;
|
||||
|
||||
static constexpr auto Scheduler = Traits::Scheduler;
|
||||
static constexpr uint8_t kA_cvt_scale = std::is_same_v<ADataType, pk_int4_t> ? 16 : 1;
|
||||
static constexpr uint8_t kB_cvt_scale = std::is_same_v<BDataType, pk_int4_t> ? 16 : 1;
|
||||
static constexpr auto Scheduler = Traits::Scheduler;
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
@@ -384,8 +382,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg);
|
||||
static_for<0, WarpGemm::kM / 2, 1>{}([&](auto c_row) {
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f *
|
||||
kA_cvt_scale * kB_cvt_scale);
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -80,55 +80,17 @@ struct add_rmsnorm2d_rdquant_fwd_traits_
|
||||
using InputDataType = ck_tile::remove_cvref_t<InputDataType_>;
|
||||
using QuantizedDataType = ck_tile::remove_cvref_t<QuantizedDataType_>;
|
||||
|
||||
static constexpr auto WarpSize = ck_tile::get_warp_size();
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize;
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (WarpSize / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(WarpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / WarpSize);
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % WarpSize == 0);
|
||||
return ThreadPerBlock_N_ / WarpSize;
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
||||
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
||||
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
using ThreadPerBlock = ck_tile::sequence<ThreadPerBlock_M_, ThreadPerBlock_N_>;
|
||||
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
||||
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveX = kSaveX_;
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_gemm_aquant_utils.hpp"
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
@@ -336,7 +337,17 @@ bool run_gemm_test_with_layouts(int argc,
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<ADataType> a_m_k_dev = a_m_k;
|
||||
ck_tile::permute_vectors_i4x4_b(a_m_k_dev);
|
||||
a_m_k_dev_buf.ToDevice(a_m_k_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
}
|
||||
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
|
||||
@@ -75,54 +75,17 @@ struct layernorm2d_fwd_traits_
|
||||
using SmoothScaleDataType = ck_tile::remove_cvref_t<SmoothScaleDataType_>;
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
||||
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
||||
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
||||
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
||||
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
using ThreadPerBlock = ck_tile::sequence<ThreadPerBlock_M_, ThreadPerBlock_N_>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
|
||||
|
||||
@@ -38,54 +38,16 @@ struct moe_smoothquant_traits_
|
||||
using InputType = ck_tile::remove_cvref_t<InputType_>;
|
||||
using OutputType = ck_tile::remove_cvref_t<OutputType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
||||
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
||||
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
||||
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
||||
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
using ThreadPerBlock = ck_tile::sequence<ThreadPerBlock_M_, ThreadPerBlock_N_>;
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
|
||||
@@ -74,54 +74,17 @@ struct rmsnorm2d_fwd_traits_
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
|
||||
using UnquantYDataType = ck_tile::remove_cvref_t<UnquantYDataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
||||
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
||||
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
||||
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
||||
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
using ThreadPerBlock = ck_tile::sequence<ThreadPerBlock_M_, ThreadPerBlock_N_>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveInvRms = kSaveInvRms_;
|
||||
|
||||
@@ -49,54 +49,17 @@ struct smoothquant_traits_
|
||||
{
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
||||
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
||||
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
using ThreadPerBlock = ck_tile::sequence<ThreadPerBlock_M_, ThreadPerBlock_N_>;
|
||||
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
||||
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
|
||||
Reference in New Issue
Block a user