Skip to content

Commit

Permalink
MLAS: more code cleanup (microsoft#7036)
Browse files Browse the repository at this point in the history
Change int32_t->ptrdiff_t when interacting with the threadpool.
Migrate more code from MlasMaskMoveAvx->MlasMaskMoveTableAvx.
Update more code to use FUNCTION_ENTRY macro.
tracysh authored Mar 17, 2021
1 parent 8e0970a commit 90642e7
Showing 23 changed files with 95 additions and 133 deletions.
3 changes: 2 additions & 1 deletion onnxruntime/core/mlas/inc/mlas.h
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@ Module Name:

#pragma once

#include <cstddef>
#include <cstdlib>
#include <cstdint>

@@ -362,7 +363,7 @@ struct MLAS_CONV_PARAMETERS {
size_t OutputSize;
size_t K;
MLAS_CONV_ALGORITHM Algorithm;
int32_t ThreadCount;
ptrdiff_t ThreadCount;
union {
struct {
CBLAS_TRANSPOSE TransB;
2 changes: 0 additions & 2 deletions onnxruntime/core/mlas/lib/amd64/DgemmKernelCommon.inc
Original file line number Diff line number Diff line change
@@ -23,7 +23,6 @@ FgemmElementShift EQU 3
FgemmElementSize EQU (1 SHL FgemmElementShift)
FgemmElementPtr EQU QWORD PTR
FgemmElementBcst EQU QWORD BCST
FgemmMaskMoveVector EQU MlasMaskMoveAvx64

;
; Define the typed instructions for double precision.
@@ -41,7 +40,6 @@ vmovapf EQU vmovapd
vmovsf EQU vmovsd
vmovupf EQU vmovupd
vmulpf EQU vmulpd
vpcmpgtf EQU vpcmpgtq
vxorpf EQU vxorpd

INCLUDE FgemmKernelCommon.inc
14 changes: 4 additions & 10 deletions onnxruntime/core/mlas/lib/amd64/FgemmKernelAvxCommon.inc
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@
;
;--

EXTERN FgemmMaskMoveVector:NEAR
EXTERN MlasMaskMoveTableAvx:NEAR

;
; Macro Description:
@@ -319,15 +319,9 @@ ENDIF
add rbp,FgemmYmmElementCount ; correct for over-subtract above

OutputMasked1xNBlock:
mov FgemmKernelFrame.CountN[rsp],rbp
IF FgemmElementSize EQ 4
vbroadcastsf xmm0,FgemmElementPtr FgemmKernelFrame.CountN[rsp]
ELSE
vmovddup xmm0,FgemmElementPtr FgemmKernelFrame.CountN[rsp]
ENDIF
vpcmpgtf xmm1,xmm0,XMMWORD PTR [FgemmMaskMoveVector+16]
vpcmpgtf xmm0,xmm0,XMMWORD PTR [FgemmMaskMoveVector]
vinsertf128 ymm0,ymm0,xmm1,1
neg rbp
lea rcx,MlasMaskMoveTableAvx+8*4
vmovdqu ymm0,YMMWORD PTR [rcx+rbp*FgemmElementSize]
test r15b,r15b ; ZeroMode?
jnz StoreMasked1xNBlock
EmitIfCountGE RowCount, 1, <vmaskmovpf ymm8,ymm0,YMMWORD PTR [r8]>
8 changes: 4 additions & 4 deletions onnxruntime/core/mlas/lib/amd64/FgemmKernelFma3Common.inc
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@
;
;--

EXTERN FgemmMaskMoveVector:NEAR
EXTERN MlasMaskMoveTableAvx:NEAR

;
; Macro Description:
@@ -359,9 +359,9 @@ ENDIF
add rbp,FgemmYmmElementCount ; correct for over-subtract above

OutputMasked1xNBlock:
mov FgemmKernelFrame.CountN[rsp],rbp
vbroadcastsf ymm0,FgemmElementPtr FgemmKernelFrame.CountN[rsp]
vpcmpgtf ymm0,ymm0,YMMWORD PTR [FgemmMaskMoveVector]
neg rbp
lea rcx,MlasMaskMoveTableAvx+8*4
vmovdqu ymm0,YMMWORD PTR [rcx+rbp*FgemmElementSize]
test r15b,r15b ; ZeroMode?
jnz MultiplyAlphaMasked1xNBlock
EmitIfCountGE RowCount, 1, <vmaskmovpf ymm4,ymm0,YMMWORD PTR [r8]>
2 changes: 0 additions & 2 deletions onnxruntime/core/mlas/lib/amd64/SgemmKernelCommon.inc
Original file line number Diff line number Diff line change
@@ -23,7 +23,6 @@ FgemmElementShift EQU 2
FgemmElementSize EQU (1 SHL FgemmElementShift)
FgemmElementPtr EQU DWORD PTR
FgemmElementBcst EQU DWORD BCST
FgemmMaskMoveVector EQU MlasMaskMoveAvx

;
; Define the typed instructions for single precision.
@@ -41,7 +40,6 @@ vmovapf EQU vmovaps
vmovsf EQU vmovss
vmovupf EQU vmovups
vmulpf EQU vmulps
vpcmpgtf EQU vpcmpgtd
vxorpf EQU vxorps

INCLUDE FgemmKernelCommon.inc
10 changes: 5 additions & 5 deletions onnxruntime/core/mlas/lib/compute.cpp
Original file line number Diff line number Diff line change
@@ -69,7 +69,7 @@ MLAS_INTERNAL_DATA const float MlasMinimumF32Value = std::numeric_limits<float>:
//

struct MLAS_SOFTMAX_WORK_BLOCK {
int32_t ThreadCountN;
ptrdiff_t ThreadCountN;
bool LogSoftmax;
const float* Input;
float* Output;
@@ -801,7 +801,7 @@ Return Value:
void
MlasComputeSoftmaxThreaded(
void* Context,
int32_t Index
ptrdiff_t Index
)
/*++
@@ -971,18 +971,18 @@ Return Value:
// another thread.
//

int32_t ThreadCountN = MlasGetMaximumThreadCount(ThreadPool);
ptrdiff_t ThreadCountN = MlasGetMaximumThreadCount(ThreadPool);

if (size_t(ThreadCountN) > N) {
ThreadCountN = int32_t(N);
ThreadCountN = ptrdiff_t(N);
}

constexpr size_t MinimumElementsPerThread = 16384;

size_t BlockCount = ((N * D) / MinimumElementsPerThread) + 1;

if (size_t(ThreadCountN) > BlockCount) {
ThreadCountN = int32_t(BlockCount);
ThreadCountN = ptrdiff_t(BlockCount);
}

WorkBlock.ThreadCountN = ThreadCountN;
34 changes: 13 additions & 21 deletions onnxruntime/core/mlas/lib/convolve.cpp
Original file line number Diff line number Diff line change
@@ -39,7 +39,7 @@ struct MLAS_CONV_WORK_BLOCK {
size_t StartN;
size_t CountN;
} Segments[MLAS_MAXIMUM_THREAD_COUNT];
int32_t TargetThreadCount;
ptrdiff_t TargetThreadCount;
};

void
@@ -609,7 +609,7 @@ Return Value:
void
MlasConvOperationThreaded(
void* Context,
int32_t Index
ptrdiff_t Index
)
/*++
@@ -645,7 +645,7 @@ Return Value:
void
MlasConvGemmDirectThreaded(
void* Context,
int32_t Index
ptrdiff_t Index
)
/*++
@@ -677,21 +677,13 @@ Return Value:
const size_t GroupCount = Parameters->GroupCount;
const size_t BatchGroupCount = Parameters->BatchCount * GroupCount;

const size_t TargetThreadCount = WorkBlock->TargetThreadCount;

const size_t BatchGroupCountPerThread = BatchGroupCount / TargetThreadCount;
const size_t BatchGroupCountExtra = BatchGroupCount % TargetThreadCount;

size_t BatchGroupStart;
size_t BatchGroupEnd;
size_t BatchGroupRemaining;

if (uint32_t(Index) < BatchGroupCountExtra) {
BatchGroupStart = (BatchGroupCountPerThread + 1) * Index;
BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread + 1;
} else {
BatchGroupStart = BatchGroupCountPerThread * Index + BatchGroupCountExtra;
BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread;
}
MlasPartitionWork(Index, WorkBlock->TargetThreadCount, BatchGroupCount,
&BatchGroupStart, &BatchGroupRemaining);

size_t BatchGroupEnd = BatchGroupStart + BatchGroupRemaining;

//
// Iterate over the batch and groups allocated to this thread.
@@ -889,10 +881,10 @@ Return Value:

const size_t BatchGroupCount = BatchCount * GroupCount;

int32_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool);
ptrdiff_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool);

if (size_t(TargetThreadCount) >= BatchGroupCount) {
TargetThreadCount = int32_t(BatchGroupCount);
TargetThreadCount = ptrdiff_t(BatchGroupCount);
}

MLAS_CONV_WORK_BLOCK WorkBlock;
@@ -1213,16 +1205,16 @@ Return Value:
// threaded path.
//

int32_t TargetThreadCount;
ptrdiff_t TargetThreadCount;
double Complexity = double(FilterCount) * double(OutputSize) * double(K);

if (Complexity < double(MLAS_SGEMM_THREAD_COMPLEXITY * MLAS_MAXIMUM_THREAD_COUNT)) {
TargetThreadCount = int32_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1;
TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1;
} else {
TargetThreadCount = MLAS_MAXIMUM_THREAD_COUNT;
}

int32_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool);
ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool);

if (TargetThreadCount >= MaximumThreadCount) {
TargetThreadCount = MaximumThreadCount;
24 changes: 12 additions & 12 deletions onnxruntime/core/mlas/lib/dgemm.cpp
Original file line number Diff line number Diff line change
@@ -32,8 +32,8 @@ Module Name:
//

struct MLAS_DGEMM_WORK_BLOCK {
int32_t ThreadCountM;
int32_t ThreadCountN;
ptrdiff_t ThreadCountM;
ptrdiff_t ThreadCountN;
CBLAS_TRANSPOSE TransA;
CBLAS_TRANSPOSE TransB;
size_t M;
@@ -751,7 +751,7 @@ Return Value:
void
MlasDgemmThreaded(
void* Context,
int32_t ThreadId
ptrdiff_t ThreadId
)
/*++
@@ -774,11 +774,11 @@ Return Value:
{
const auto* WorkBlock = (MLAS_DGEMM_WORK_BLOCK*)Context;

const int32_t ThreadCountM = WorkBlock->ThreadCountM;
const int32_t ThreadCountN = WorkBlock->ThreadCountN;
const ptrdiff_t ThreadCountM = WorkBlock->ThreadCountM;
const ptrdiff_t ThreadCountN = WorkBlock->ThreadCountN;

const int32_t ThreadIdM = ThreadId / ThreadCountN;
const int32_t ThreadIdN = ThreadId % ThreadCountN;
const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN;
const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN;

//
// Partition the operation along the M dimension.
@@ -864,15 +864,15 @@ Return Value:

const double Complexity = double(M) * double(N) * double(K);

int32_t TargetThreadCount;
ptrdiff_t TargetThreadCount;

if (Complexity < double(MLAS_DGEMM_THREAD_COMPLEXITY * MlasPlatform.MaximumThreadCount)) {
TargetThreadCount = int32_t(Complexity / double(MLAS_DGEMM_THREAD_COMPLEXITY)) + 1;
TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_DGEMM_THREAD_COMPLEXITY)) + 1;
} else {
TargetThreadCount = MlasPlatform.MaximumThreadCount;
}

int32_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool);
ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool);

if (TargetThreadCount >= MaximumThreadCount) {
TargetThreadCount = MaximumThreadCount;
@@ -891,7 +891,7 @@ Return Value:
MLAS_DGEMM_STRIDEN_THREAD_ALIGN;

if (size_t(TargetThreadCount) > BlockedN) {
TargetThreadCount = int32_t(BlockedN);
TargetThreadCount = ptrdiff_t(BlockedN);
}

WorkBlock->ThreadCountM = 1;
@@ -900,7 +900,7 @@ Return Value:
} else {

if (size_t(TargetThreadCount) > M) {
TargetThreadCount = int32_t(M);
TargetThreadCount = ptrdiff_t(M);
}

WorkBlock->ThreadCountM = TargetThreadCount;
16 changes: 8 additions & 8 deletions onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
@@ -733,9 +733,9 @@ struct MLAS_PLATFORM {
MLAS_QUANTIZE_LINEAR_U8_KERNEL* QuantizeLinearU8Kernel;
uint32_t NchwcBlockSize;
uint32_t PreferredBufferAlignment;
uint32_t MaximumThreadCount;
int32_t MaximumThreadCount;
#else
static constexpr uint32_t MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT;
static constexpr int32_t MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT;
#endif

#if defined(MLAS_TARGET_ARM64)
@@ -753,19 +753,19 @@ typedef
void
(MLAS_THREADED_ROUTINE)(
void* Context,
int32_t Index
ptrdiff_t Index
);

void
MlasExecuteThreaded(
MLAS_THREADED_ROUTINE* ThreadedRoutine,
void* Context,
int32_t Iterations,
ptrdiff_t Iterations,
MLAS_THREADPOOL* ThreadPool
);

inline
int32_t
ptrdiff_t
MlasGetMaximumThreadCount(
MLAS_THREADPOOL* ThreadPool
)
@@ -786,8 +786,8 @@ MlasGetMaximumThreadCount(
inline
void
MlasPartitionWork(
int32_t ThreadId,
int32_t ThreadCount,
ptrdiff_t ThreadId,
ptrdiff_t ThreadCount,
size_t TotalWork,
size_t* WorkIndex,
size_t* WorkRemaining
@@ -796,7 +796,7 @@ MlasPartitionWork(
const size_t WorkPerThread = TotalWork / ThreadCount;
const size_t WorkPerThreadExtra = TotalWork % ThreadCount;

if (uint32_t(ThreadId) < WorkPerThreadExtra) {
if (size_t(ThreadId) < WorkPerThreadExtra) {
*WorkIndex = (WorkPerThread + 1) * ThreadId;
*WorkRemaining = WorkPerThread + 1;
} else {
6 changes: 0 additions & 6 deletions onnxruntime/core/mlas/lib/platform.cpp
Original file line number Diff line number Diff line change
@@ -41,12 +41,6 @@ MLAS_PLATFORM MlasPlatform;

MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const uint32_t MlasMaskMoveAvx[8], 32) = { 0, 1, 2, 3, 4, 5, 6, 7 };

//
// Stores a vector to build a conditional load/store mask for vmaskmovpd.
//

MLAS_INTERNAL_DATA MLAS_DECLSPEC_ALIGN(const uint64_t MlasMaskMoveAvx64[4], 32) = { 0, 1, 2, 3 };

//
// Stores a table of AVX vmaskmovps/vmaskmovpd load/store masks.
//
Loading
Oops, something went wrong.

0 comments on commit 90642e7

Please sign in to comment.