mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
update code
This commit is contained in:
@@ -24,19 +24,20 @@
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F4 = ck::f4x2_pk_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
using F4 = ck::f4x2_pk_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = F4;
|
||||
using A1DataType = XDataType;
|
||||
using A1DataType = XPackedDataType;
|
||||
using B0DataType = F4;
|
||||
using B1DataType = XDataType;
|
||||
using B1DataType = XPackedDataType;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
@@ -170,7 +171,9 @@ using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
|
||||
#if 0
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
@@ -213,14 +216,14 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
ScaleBlockSize, 256,
|
||||
MPerBlock, 128, 128,
|
||||
32, 32,
|
||||
MPerBlock, 256, KPerBlock,
|
||||
16, 16,
|
||||
8, 2,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
|
||||
16, 16,
|
||||
8, 4,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
#endif
|
||||
|
||||
@@ -328,22 +331,22 @@ int main(int argc, char* argv[])
|
||||
expert_ids.savetxt("expert_ids.txt", "int");
|
||||
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
|
||||
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}));
|
||||
Tensor<A1DataType> a1_t_k_k(
|
||||
Tensor<XDataType> a1_t_k_k(
|
||||
HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize},
|
||||
{(topk * Scale_Stride_AM), Scale_Stride_AM, 1}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<B1DataType> b1_e_n_k(
|
||||
Tensor<XDataType> b1_e_n_k(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
|
||||
{(N * Scale_Stride_BN), 1, Scale_Stride_BN}));
|
||||
// B preshuffle
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
|
||||
// A, B Scale preshuffle
|
||||
Tensor<A1DataType> a_scale_sorted(HostTensorDescriptor(
|
||||
Tensor<XDataType> a_scale_sorted(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<A1DataType> a_scale_preshuffled(HostTensorDescriptor(
|
||||
Tensor<XDataType> a_scale_preshuffled(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<B1DataType> b_scale_preshuffled(
|
||||
Tensor<XDataType> b_scale_preshuffled(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
|
||||
{N * Scale_Stride_BN, 1, Scale_Stride_BN}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
|
||||
@@ -364,50 +367,50 @@ int main(int argc, char* argv[])
|
||||
case 1:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_2<A1DataType>{0, 1});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_2<B1DataType>{0, 1});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_2<XDataType>{0, 1});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_2<XDataType>{0, 1});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-1, 1});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
|
||||
break;
|
||||
case 4:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
|
||||
break;
|
||||
case 5:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 6:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
default:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0.0, 1.0});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
}
|
||||
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) *
|
||||
@@ -415,35 +418,37 @@ int main(int argc, char* argv[])
|
||||
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
|
||||
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.mDesc.GetElementSpaceSize() / 2);
|
||||
DeviceMem a1_device_buf(sizeof(A1DataType) * a_scale_sorted.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize() / 2);
|
||||
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
// A scale sorted
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int tokenid = sorted_token_ids.mData[i] & 0x00FFFFFF;
|
||||
int topkid = (sorted_token_ids.mData[i] >> 24) & 0x000000FF;
|
||||
int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF;
|
||||
int topk_id = (sorted_token_ids.mData[i] >> 24) & 0x000000FF;
|
||||
|
||||
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
|
||||
{
|
||||
if(tokenid == tokens)
|
||||
if(token_id == tokens)
|
||||
{
|
||||
a_scale_sorted(i, k) = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_sorted(i, k) = a1_t_k_k(tokenid, topkid, k);
|
||||
a_scale_sorted(i, k) = a1_t_k_k(token_id, topk_id, k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
preShuffleBuffer<ck::is_same_v<A0Layout, Row>>(
|
||||
a_scale_sorted.mData.data(), a_scale_preshuffled.mData.data(), sorted_size, K);
|
||||
preShuffleBuffer<ck::is_same_v<B0Layout, Row>>(
|
||||
b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K);
|
||||
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
|
||||
a_scale_preshuffled.mData.data(),
|
||||
sorted_size,
|
||||
K / ScaleBlockSize);
|
||||
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Row>>(
|
||||
b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize);
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
@@ -614,9 +619,9 @@ int main(int argc, char* argv[])
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeMXGemm2<A0DataType,
|
||||
A1DataType,
|
||||
XDataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
XDataType,
|
||||
D2DataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
|
||||
@@ -76,6 +76,12 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
|
||||
// Hardcode to 2, for better 8-bit access pattern
|
||||
|
||||
static constexpr index_t MXdlPack = 2;
|
||||
static constexpr index_t NXdlPack = 2;
|
||||
static constexpr index_t KXdlPack = 2;
|
||||
|
||||
using HotLoopInstList = ck::BlockwiseGemmXdlops_pipeline_hotloop_inst<
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
@@ -127,7 +133,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
|
||||
|
||||
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
|
||||
|
||||
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KThreadChunk * xdlops_a_idx[I0]);
|
||||
return make_tuple(0, waveId_m, 0, xdlops_a_idx[I1], KThreadChunk * xdlops_a_idx[I0]);
|
||||
}
|
||||
|
||||
__device__ static auto CalculateBThreadOriginDataIndex()
|
||||
@@ -138,7 +144,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
|
||||
|
||||
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
|
||||
|
||||
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KThreadChunk * xdlops_b_idx[I0]);
|
||||
return make_tuple(0, waveId_n, 0, xdlops_b_idx[I1], KThreadChunk * xdlops_b_idx[I0]);
|
||||
}
|
||||
|
||||
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
|
||||
@@ -170,7 +176,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
|
||||
return make_tuple(c_thread_m, c_thread_n);
|
||||
}
|
||||
|
||||
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
|
||||
using Tuple5 = decltype(CalculateAThreadOriginDataIndex());
|
||||
|
||||
/**
|
||||
* @brief Constructor for BlockwiseGemmXdlops_mx_pipeline_base.
|
||||
@@ -190,8 +196,8 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
|
||||
* repeat dimensions.
|
||||
*/
|
||||
__host__ __device__
|
||||
BlockwiseGemmXdlops_mx_pipeline_base(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
|
||||
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
|
||||
BlockwiseGemmXdlops_mx_pipeline_base(Tuple5 a_origin = CalculateAThreadOriginDataIndex(),
|
||||
Tuple5 b_origin = CalculateBThreadOriginDataIndex())
|
||||
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
|
||||
{
|
||||
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
|
||||
@@ -327,49 +333,63 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
|
||||
|
||||
__host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; }
|
||||
|
||||
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
|
||||
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
|
||||
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_m3_k;
|
||||
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_n3_k;
|
||||
|
||||
protected:
|
||||
// M1, N1 as double buffer index
|
||||
// Read buffer + Compute buffer
|
||||
// A[M0, M1, M2, KPack]
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
|
||||
make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPack / APackedSize>{}),
|
||||
make_tuple(Number<KPack / APackedSize>{},
|
||||
Number<KRepeat * MRepeat * KPack / APackedSize>{},
|
||||
Number<MRepeat * KPack / APackedSize>{},
|
||||
I1));
|
||||
// A[M0, M1, M2, M3, KPack]
|
||||
static constexpr auto a_thread_desc_ =
|
||||
make_naive_tensor_descriptor(make_tuple(Number<MRepeat / MXdlPack>{},
|
||||
I1,
|
||||
Number<MXdlPack>{},
|
||||
Number<KRepeat>{},
|
||||
Number<KPack>{}),
|
||||
make_tuple(Number<KPack * MXdlPack>{},
|
||||
Number<KRepeat * MRepeat * KPack>{},
|
||||
Number<MRepeat * KPack>{},
|
||||
Number<KPack>{},
|
||||
I1));
|
||||
|
||||
// B[N0, N1, N2, KPack]
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
|
||||
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack / BPackedSize>{}),
|
||||
make_tuple(Number<KPack / BPackedSize>{},
|
||||
Number<KRepeat * NRepeat * KPack / BPackedSize>{},
|
||||
Number<NRepeat * KPack / BPackedSize>{},
|
||||
I1));
|
||||
// B[N0, N1, N2, N3, KPack]
|
||||
static constexpr auto b_thread_desc_ =
|
||||
make_naive_tensor_descriptor(make_tuple(Number<NRepeat / NXdlPack>{},
|
||||
I1,
|
||||
Number<KRepeat>{},
|
||||
Number<NXdlPack>{},
|
||||
Number<KPack>{}),
|
||||
make_tuple(Number<KPack * NXdlPack>{},
|
||||
Number<KRepeat * NRepeat * KPack>{},
|
||||
Number<NRepeat * KPack>{},
|
||||
Number<KPack>{},
|
||||
I1));
|
||||
|
||||
// C[M, N, NumRegXdlops]
|
||||
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
|
||||
static constexpr auto c_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat / MXdlPack>{},
|
||||
Number<NRepeat / NXdlPack>{},
|
||||
Number<MXdlPack>{},
|
||||
Number<NXdlPack>{},
|
||||
xdlops_gemm.GetRegSizePerXdlops()));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeTypeA,
|
||||
decltype(a_block_desc_m0_m1_m2_k),
|
||||
decltype(a_block_desc_m0_m1_m2_m3_k),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, KThreadChunk>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
Sequence<1, 1, 1, 1, KThreadChunk>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
4,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
|
||||
ComputeTypeB,
|
||||
decltype(b_block_desc_n0_n1_n2_k),
|
||||
decltype(b_block_desc_n0_n1_n2_n3_k),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, 1, 1, KThreadChunk>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
Sequence<1, 1, 1, 1, KThreadChunk>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
4,
|
||||
B_K1,
|
||||
B_K1>;
|
||||
|
||||
|
||||
@@ -495,7 +495,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf, auto a_buf) {
|
||||
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
|
||||
// Prefetch a_scales to buf 1
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
@@ -683,8 +683,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}; // LoopFunc
|
||||
|
||||
LoopFunc(I0, I1, I0);
|
||||
LoopFunc(I1, I0, I1);
|
||||
LoopFunc(I0, I1);
|
||||
LoopFunc(I1, I0);
|
||||
|
||||
i += 2;
|
||||
} while(i < (num_loop - 2));
|
||||
|
||||
@@ -277,16 +277,35 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineSched
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Prefetch a_scales to buf 0
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0, I0),
|
||||
a_scale_thread_bufs(I0));
|
||||
// Prefetch a_scales
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
|
||||
auto a_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
a_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf_copy);
|
||||
|
||||
a_scale_thread_buf(I0)(Number<a_scale_offset>{}) =
|
||||
a_scale_thread_buf_copy[Number<0>{}];
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(0, ScalesPerKBlockSize, 0));
|
||||
make_multi_index(-MPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
// Prefetch b_scales to buf 0
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
@@ -329,15 +348,34 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineSched
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
// Prefetch a_scales to buf 1
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0, I0),
|
||||
a_scale_thread_bufs(I1));
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
|
||||
auto a_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
a_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf_copy);
|
||||
|
||||
a_scale_thread_buf(I1)(Number<a_scale_offset>{}) =
|
||||
a_scale_thread_buf_copy[Number<0>{}];
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(0, ScalesPerKBlockSize, 0));
|
||||
make_multi_index(-MPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
// Prefetch b_scales to buf 1
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -194,6 +194,22 @@ struct GridwiseMoeGemmMX
|
||||
static constexpr auto NXdlPack = 2;
|
||||
static constexpr auto KXdlPack = 2;
|
||||
|
||||
static constexpr index_t APackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> ||
|
||||
is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
static constexpr index_t BPackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> ||
|
||||
is_same_v<remove_cvref_t<BDataType>, f4x2_pk_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
static constexpr bool is_single_rate_mfma = false;
|
||||
static constexpr auto is_scale_mfma = true;
|
||||
using mfma_selector = MfmaSelector<ComputeTypeA,
|
||||
@@ -232,22 +248,6 @@ struct GridwiseMoeGemmMX
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
static constexpr index_t APackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> ||
|
||||
is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
static constexpr index_t BPackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> ||
|
||||
is_same_v<remove_cvref_t<BDataType>, f4x2_pk_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N)
|
||||
{
|
||||
const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
|
||||
@@ -317,7 +317,11 @@ struct GridwiseMoeGemmMX
|
||||
return math::integer_divide_ceil(N, NPerBlock);
|
||||
}
|
||||
|
||||
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
|
||||
template <index_t MNXdlPerWave,
|
||||
index_t MNWaves,
|
||||
index_t MNXdlPack,
|
||||
index_t MNPerXdl,
|
||||
typename TileDesc_K0_MN_K1>
|
||||
__host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
|
||||
{
|
||||
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
|
||||
@@ -326,10 +330,12 @@ struct GridwiseMoeGemmMX
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_K0_MN_K1{},
|
||||
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
|
||||
make_unmerge_transform(make_tuple(Number<MNXdlPerWave / MNXdlPack>{},
|
||||
Number<MNWaves>{},
|
||||
Number<MNXdlPack>{},
|
||||
Number<MNPerXdl>{}))),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
|
||||
make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{}));
|
||||
}
|
||||
|
||||
__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
|
||||
@@ -513,16 +519,18 @@ struct GridwiseMoeGemmMX
|
||||
|
||||
template <typename ABlockDesc_AK0_M_AK1>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
|
||||
MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&)
|
||||
{
|
||||
return MakeGemmMmaTileDescriptor<MXdlPerWave, MWave, MPerXdl>(ABlockDesc_AK0_M_AK1{});
|
||||
return MakeGemmMmaTileDescriptor<MXdlPerWave, MWave, MXdlPack, MPerXdl>(
|
||||
ABlockDesc_AK0_M_AK1{});
|
||||
}
|
||||
|
||||
template <typename BBlockDesc_BK0_N_BK1>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
|
||||
MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&)
|
||||
{
|
||||
return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
|
||||
return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NXdlPack, NPerXdl>(
|
||||
BBlockDesc_BK0_N_BK1{});
|
||||
}
|
||||
|
||||
template <typename ELayout>
|
||||
@@ -789,7 +797,7 @@ struct GridwiseMoeGemmMX
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
a_scale_k_split_offset =
|
||||
k_id * karg.KRead / (ScaleBlockSize / PackedSize) * karg.StrideScaleA;
|
||||
k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA;
|
||||
}
|
||||
|
||||
// Calculate B scale offset
|
||||
@@ -939,8 +947,11 @@ struct GridwiseMoeGemmMX
|
||||
__device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
|
||||
{
|
||||
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}));
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(Number<NXdlPerWave / NXdlPack>{},
|
||||
I1,
|
||||
Number<NXdlPack>{},
|
||||
Number<KRepeat>{},
|
||||
Number<BK1Value>{}));
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
|
||||
@@ -969,9 +980,9 @@ struct GridwiseMoeGemmMX
|
||||
AccDataType,
|
||||
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
|
||||
decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
|
||||
decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
|
||||
decltype(MakeAMmaTileDescriptor_M0_M1_M2_M3_K(
|
||||
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())),
|
||||
decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
|
||||
decltype(MakeBMmaTileDescriptor_N0_N1_N2_N3_K(
|
||||
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())),
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
@@ -1395,21 +1406,26 @@ struct GridwiseMoeGemmMX
|
||||
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bpreshuffled),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
|
||||
Sequence<1, 2, 0, 3>,
|
||||
3,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(b_grid_desc_bpreshuffled,
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
auto b_blockwise_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2<BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bpreshuffled),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
Sequence<Number<NXdlPerWave / NXdlPack>{},
|
||||
I1,
|
||||
Number<NXdlPack>{},
|
||||
Number<KRepeat>{},
|
||||
Number<BK1Value>{}>,
|
||||
Sequence<1, 2, 0, 3>,
|
||||
4,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(
|
||||
b_grid_desc_bpreshuffled,
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
// Cast after lds
|
||||
@@ -1511,15 +1527,17 @@ struct GridwiseMoeGemmMX
|
||||
BScaleDataType,
|
||||
BScaleDataType,
|
||||
decltype(b_scale_grid_desc_bn_ak),
|
||||
decltype(BlockwiseGemmPipe::b_scale_thread_desc_copy),
|
||||
Sequence<1, 1>, // SliceLengths
|
||||
Sequence<0, 1>, // DimAccessOrder
|
||||
1, // SrcVectorDim
|
||||
1, // SrcScalarPerVector
|
||||
1,
|
||||
decltype(BlockwiseGemmPipe::b_scale_thread_desc),
|
||||
Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
|
||||
Sequence<0, 1, 2>, // DimAccessOrder
|
||||
2, // SrcVectorDim
|
||||
KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
|
||||
1, // SrcScalarStrideInVector
|
||||
true>(
|
||||
b_scale_grid_desc_bn_ak,
|
||||
make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, thread_offset_k));
|
||||
make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
|
||||
0,
|
||||
thread_offset_shuffled / scale_pack_size_b));
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
@@ -1958,17 +1976,17 @@ struct GridwiseMoeGemmMX
|
||||
problem.NPadded,
|
||||
problem.StrideC);
|
||||
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockSize) /
|
||||
ScalesPerXdlopsRunPerThread,
|
||||
ScalesPerXdlopsRunPerThread),
|
||||
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize),
|
||||
ScalesPerXdlopsRunPerThread,
|
||||
1));
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(problem.N, math::integer_divide_ceil(problem.K, ScaleBlockSize)),
|
||||
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize), 1));
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
|
||||
make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerBlock),
|
||||
math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
|
||||
(KXdlPack * 64 / MPerXdl),
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a));
|
||||
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(problem.N / (NXdlPack * NPerXdl),
|
||||
math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
|
||||
(KXdlPack * 64 / NPerXdl),
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b));
|
||||
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
@@ -2145,62 +2163,48 @@ struct GridwiseMoeGemmMX
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
|
||||
auto thread_offset_shuffled =
|
||||
get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
|
||||
|
||||
auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) /
|
||||
mfma.selected_mfma.num_threads_per_blk;
|
||||
|
||||
auto a_thread_offset_m = get_thread_local_1d_id() % MPerXdl + waveId_m * MPerXdl;
|
||||
auto a_thread_offset_m = waveId_m;
|
||||
|
||||
// get each thread's offset in the scale tensor
|
||||
const index_t token_scale_pos = block_m_id * MPerBlock;
|
||||
if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
|
||||
return;
|
||||
|
||||
StaticallyIndexedArray<index_t, MXdlPerWave> scale_gather_offsets;
|
||||
static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
|
||||
const index_t fused_token =
|
||||
p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWave + a_thread_offset_m];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
if constexpr(!IsInputGemm)
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
scale_gather_offsets(m0) =
|
||||
token_offset * math::integer_divide_ceil(problem.K, ScaleBlockSize);
|
||||
});
|
||||
|
||||
auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2_gather<
|
||||
auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
|
||||
AScaleDataType,
|
||||
AScaleDataType,
|
||||
decltype(a_scale_grid_desc_am_ak),
|
||||
decltype(BlockwiseGemmPipe::a_scale_thread_desc),
|
||||
Sequence<1, 1, 1>, // SliceLengths
|
||||
Sequence<0, 1, 2>, // DimAccessOrder
|
||||
2, // SrcVectorDim
|
||||
1, // SrcScalarPerVector
|
||||
1, // SrcScalarStrideInVector
|
||||
true,
|
||||
MXdlPerWave,
|
||||
KRepeat>(
|
||||
a_scale_grid_desc_am_ak, make_multi_index(0, thread_offset_k, 0), scale_gather_offsets);
|
||||
Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
|
||||
Sequence<0, 1, 2>, // DimAccessOrder
|
||||
2, // SrcVectorDim
|
||||
KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
|
||||
1, // SrcScalarStrideInVector
|
||||
true>(a_scale_grid_desc_am_ak,
|
||||
make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
|
||||
0,
|
||||
thread_offset_shuffled / scale_pack_size_a));
|
||||
|
||||
// B scale load
|
||||
auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + waveId_n * NPerXdl;
|
||||
auto b_thread_offset_n = waveId_n;
|
||||
|
||||
auto b_scale_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2<BScaleDataType,
|
||||
BScaleDataType,
|
||||
decltype(b_scale_grid_desc_bn_ak),
|
||||
decltype(BlockwiseGemmPipe::b_scale_thread_desc_copy),
|
||||
Sequence<1, 1>, // SliceLengths
|
||||
Sequence<0, 1>, // DimAccessOrder
|
||||
1, // SrcVectorDim
|
||||
1, // SrcScalarPerVector
|
||||
1,
|
||||
true>(
|
||||
b_scale_grid_desc_bn_ak,
|
||||
make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, thread_offset_k));
|
||||
auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
|
||||
BScaleDataType,
|
||||
BScaleDataType,
|
||||
decltype(b_scale_grid_desc_bn_ak),
|
||||
decltype(BlockwiseGemmPipe::b_scale_thread_desc),
|
||||
Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
|
||||
Sequence<0, 1, 2>, // DimAccessOrder
|
||||
2, // SrcVectorDim
|
||||
KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
|
||||
1, // SrcScalarStrideInVector
|
||||
true>(b_scale_grid_desc_bn_ak,
|
||||
make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
|
||||
0,
|
||||
thread_offset_shuffled / scale_pack_size_b));
|
||||
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
@@ -2231,15 +2235,18 @@ struct GridwiseMoeGemmMX
|
||||
BScaleDataType,
|
||||
BScaleDataType,
|
||||
decltype(b_scale_grid_desc_bn_ak),
|
||||
decltype(BlockwiseGemmPipe::b_scale_thread_desc_copy),
|
||||
Sequence<1, 1>, // SliceLengths
|
||||
Sequence<0, 1>, // DimAccessOrder
|
||||
1, // SrcVectorDim
|
||||
1, // SrcScalarPerVector
|
||||
1,
|
||||
decltype(BlockwiseGemmPipe::b_scale_thread_desc),
|
||||
Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
|
||||
Sequence<0, 1, 2>, // DimAccessOrder
|
||||
2, // SrcVectorDim
|
||||
KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
|
||||
1, // SrcScalarStrideInVector
|
||||
true>(
|
||||
b_scale_grid_desc_bn_ak,
|
||||
make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, thread_offset_k));
|
||||
make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
|
||||
0,
|
||||
thread_offset_shuffled / scale_pack_size_b));
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
|
||||
Reference in New Issue
Block a user