mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
supporting prefill shapes for preshuffle block scale gemm (#2975)
* debugging * debugging for prefill shapes * comment unused code * fix for prefill shapes * clearing up the code * add int4 to universal gemm example * clang formatted * adding test for prefill shapes in block scale gemm * lil improv on the block pipeline * Address Review Comment --------- Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -75,6 +75,13 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "int4")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
|
||||
@@ -194,10 +194,7 @@ struct WeightPreshuffleInvoker
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
throw std::runtime_error("split-k is not supported yet!");
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -300,16 +300,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
if constexpr(preshuffle)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
@@ -353,6 +345,10 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
}
|
||||
}();
|
||||
// shuffled buffer B for device implementation
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::permute_vectors_i4x4_b(b_shuffle_host);
|
||||
}
|
||||
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
|
||||
@@ -4,8 +4,18 @@ This folder contains examples of quant GEMMs using the ck_tile tile-programming
|
||||
|
||||
- AQuant kernel with blocks of A matrix sharing scales: custom GEMM pipeline
|
||||
- BQuant kernel with blocks of B matrix sharing scales: custom GEMM pipeline
|
||||
- Row and Column-wise scaled: scaling implemented in Epilogue
|
||||
- Tensor-wise scaled: scaling implemented in Epilogue
|
||||
- Row and Column-wise scaled: All of the rowwise elements in A Matrix and columwise elements in B Matrix will share the same quantization element and the elementwisde operation will complete in epilogue.
|
||||
- Tensor-wise scaled: Share the same scalar scale across the whole tensor of A or B
|
||||
|
||||
---
|
||||
|
||||
## Features
|
||||
|
||||
- **Preshuffled GEMM**: Shuffle the GEMM of B (weight) matrix in the warp layout and bypass the shared memory to do the GEMM calculation. Best performance solution for GEMM.
|
||||
- **TransposeC**: Transpose the C Matrix Output layout to have the best coalesced scale reading
|
||||
- **Preshuffled Quant**: Preshuffle the input matrix to load multiple Quant warp blocks along the selected dimension.
|
||||
- **Precision**: Supports fp16, bf16, fp8, bf8, int4 (for B Matrix).
|
||||
- **Validation**: CPU/GPU validation and error tolerance options.
|
||||
|
||||
## build
|
||||
```
|
||||
|
||||
@@ -47,6 +47,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
QuantMode,
|
||||
ALayout, // for AQLayout
|
||||
BLayout, // for BQLayout
|
||||
false,
|
||||
GemmConfig::DoubleSmemBuffer>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<typename TypeConfig::ADataType,
|
||||
@@ -450,4 +451,4 @@ int run_gemm_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigQuant>(argc, argv); }
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigPreshuffleB_Bquant_prefill>(argc, argv); }
|
||||
|
||||
@@ -166,6 +166,26 @@ struct GemmConfigPreshuffleB_Bquant_decode : public GemmConfigBase
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_ = ADataType_,
|
||||
typename CDataType_ = ADataType_,
|
||||
@@ -261,7 +281,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
|
||||
.insert("rotating_count", "1000", "rotating count, defaults to 1")
|
||||
.insert("quant_mode", "aquant", "Choose aquant (default), bquant, tensor or rowcol");
|
||||
.insert("quant_mode", "bquant", "Choose aquant (default), bquant, tensor or rowcol");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
Reference in New Issue
Block a user