mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
[CK_TILE] Adding support for TiledPermuteN on preshuffle Block Scale Gemm (#3019)
* Adding support for TiledPermuteN
* Adding test
* resolving remod.py
---------
Co-authored-by: root <root@banff-cyxtera-s73-2.ctr.dcgpu>
[ROCm/composable_kernel commit: 0584399571]
This commit is contained in:
@@ -143,7 +143,11 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
transpose_c,
|
||||
ck_tile::memory_operation_enum::set>>;
|
||||
ck_tile::memory_operation_enum::set,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
GemmConfig::TiledMMAPermuteN>>;
|
||||
using Kernel =
|
||||
ck_tile::QuantGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, QuantMode>;
|
||||
|
||||
|
||||
@@ -93,6 +93,7 @@ struct GemmConfigBase
|
||||
static constexpr bool PreshuffleQuant = false;
|
||||
static constexpr bool PreshuffleB = false;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -164,6 +165,9 @@ struct GemmConfigPreshuffleB_Bquant_decode : public GemmConfigBase
|
||||
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -184,6 +188,9 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
|
||||
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
|
||||
@@ -5,40 +5,7 @@
|
||||
#include <random>
|
||||
#include <stdexcept>
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
|
||||
template <typename T>
|
||||
auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
|
||||
{
|
||||
if(t->get_lengths().size() != 2)
|
||||
{
|
||||
throw std::runtime_error("Host tensor is not rank 2 tensor.");
|
||||
}
|
||||
int m_ = t->get_lengths()[0];
|
||||
int aqk_ = t->get_lengths()[1];
|
||||
if(aqk_ % block_aq_k != 0)
|
||||
{
|
||||
throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({m_, aqk_ / block_aq_k, block_aq_k});
|
||||
std::copy(t->begin(), t->end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {1, 0, 2});
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
#include "ck_tile/host/shuffle_utils.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
@@ -390,7 +357,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
if constexpr(GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
|
||||
shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize);
|
||||
ck_tile::shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize);
|
||||
aq_dev_buf_ptr->ToDevice(aq_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
@@ -412,25 +379,26 @@ int run_gemm_example_with_layouts(int argc,
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
if constexpr(GemmConfig::PreshuffleB)
|
||||
{
|
||||
if constexpr(GemmConfig::TiledMMAPermuteN)
|
||||
{
|
||||
printf("PreshuffleB with TiledMMAPermuteN\n");
|
||||
b_k_n_dev = ck_tile::shuffle_b_permuteN<GemmConfig>(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("PreshuffleB without TiledMMAPermuteN\n");
|
||||
b_k_n_dev = ck_tile::shuffle_b<GemmConfig>(b_k_n);
|
||||
}
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
|
||||
if constexpr(GemmConfig::PreshuffleB)
|
||||
{
|
||||
b_k_n_dev = shuffle_b<GemmConfig>(b_k_n);
|
||||
}
|
||||
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(GemmConfig::PreshuffleB)
|
||||
{
|
||||
b_k_n_dev = shuffle_b<GemmConfig>(b_k_n);
|
||||
}
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
}
|
||||
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
@@ -438,7 +406,15 @@ int run_gemm_example_with_layouts(int argc,
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data());
|
||||
if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN)
|
||||
{
|
||||
printf("Preshuffle BQ with TiledMMAPermuteN \n");
|
||||
ck_tile::HostTensor<BQDataType> bq_shuffle_host =
|
||||
ck_tile::shuffle_bq_permuteN<GemmConfig>(*bq_tensor_ptr);
|
||||
bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data());
|
||||
}
|
||||
|
||||
invoke_gemm<GemmConfig,
|
||||
|
||||
Reference in New Issue
Block a user