Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved #4 / 3494: AVX2 bugfixes + no code duplication for the integer workhorses in there #5

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
307 changes: 54 additions & 253 deletions src/arch/intsimdmatrixavx2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,253 +132,52 @@ static inline __m128i load64_to_128(const int8_t *wi_) {
}

#if defined(FAST_FLOAT)
static inline void ExtractResults8(__m256i result, const int8_t *wi, const float *scales,
float *v) {
__m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg
__m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
__m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
__m256d scale01234567 = _mm256_loadu_ps(scales);
//~ __m256d scale4567 = _mm256_loadu_ps(scales + 8);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result = _mm256_add_epi32(result, w256); // result += bias * 127
__m256 res01234567 = _mm256_cvtepi32_ps(_mm256_castsi256_si128(result));
result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
__m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
res01234567 = _mm256_mul_pd(res01234567, scale01234567);
//~ res4567 = _mm256_mul_pd(res4567, scale4567);
_mm256_storeu_ps(v, res01234567);
//~ _mm256_storeu_pd(v + 4, res4567);
}

static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi,
const float *&scales, float *&v) {
__m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
// 8x8bit vals in bottom of 128bit reg
const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
__m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
__m256d scale0123 = _mm256_loadu_ps(scales);
__m256d scale4567 = _mm256_loadu_ps(scales + 8);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result0 = _mm256_add_epi32(result0, w256); // result += bias * 127
__m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
__m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
res0123 = _mm256_mul_pd(res0123, scale0123);
res4567 = _mm256_mul_pd(res4567, scale4567);
_mm256_storeu_ps(v, res0123);
_mm256_storeu_ps(v + 8, res4567);
w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
scale0123 = _mm256_loadu_ps(scales + 16);
scale4567 = _mm256_loadu_ps(scales + 24);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result1 = _mm256_add_epi32(result1, w256); // result += bias * 127
res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
res0123 = _mm256_mul_pd(res0123, scale0123);
res4567 = _mm256_mul_pd(res4567, scale4567);
_mm256_storeu_ps(v + 16, res0123);
_mm256_storeu_ps(v + 24, res4567);
wi += 16;
scales += 16;
v += 16;
}

// Computes part of matrix.vector v = Wu. Computes N=64 results.
// The weights *must* be arranged so that consecutive reads from wi
// provides (num_in/kNumInputsPerGroup groups of (N output dim groups of
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
// bias weights, before continuing with any more weights.
// u must be padded out with zeros to
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
static void PartialMatrixDotVector64(const int8_t *wi, const float *scales, const int8_t *u,
int num_in, float *v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
__m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
__m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
// Initialize all the results to 0.
__m256i result0 = _mm256_setzero_si256();
__m256i result1 = _mm256_setzero_si256();
__m256i result2 = _mm256_setzero_si256();
__m256i result3 = _mm256_setzero_si256();
__m256i result4 = _mm256_setzero_si256();
__m256i result5 = _mm256_setzero_si256();
__m256i result6 = _mm256_setzero_si256();
__m256i result7 = _mm256_setzero_si256();
// Iterate over the input (u), one registerful at a time.
for (int j = 0; j < num_in;) {
__m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
// Inputs are processed in groups of kNumInputsPerGroup, replicated
// kNumInputGroups times.
for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
// Replicate the low 32 bits (4 inputs) 8 times.
__m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
// Rotate the inputs in groups of 4, so the next 4 inputs are ready.
inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
__m256i weights, reps;
// Mul-add, with horizontal add of the 4 inputs to each of the results.
MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
}
}
ExtractResults16(result0, result1, wi, scales, v);
ExtractResults16(result2, result3, wi, scales, v);
ExtractResults16(result4, result5, wi, scales, v);
ExtractResults16(result6, result7, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=32 results.
// For details see PartialMatrixDotVector64 with N=32.
static void PartialMatrixDotVector32(const int8_t *wi, const float *scales, const int8_t *u,
int num_in, float *v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
__m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
__m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
// Initialize all the results to 0.
__m256i result0 = _mm256_setzero_si256();
__m256i result1 = _mm256_setzero_si256();
__m256i result2 = _mm256_setzero_si256();
__m256i result3 = _mm256_setzero_si256();
// Iterate over the input (u), one registerful at a time.
for (int j = 0; j < num_in;) {
__m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
// Inputs are processed in groups of kNumInputsPerGroup, replicated
// kNumInputGroups times.
for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
// Replicate the low 32 bits (4 inputs) 8 times.
__m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
// Rotate the inputs in groups of 4, so the next 4 inputs are ready.
inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
__m256i weights, reps;
// Mul-add, with horizontal add of the 4 inputs to each of the results.
MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
}
}
ExtractResults16(result0, result1, wi, scales, v);
ExtractResults16(result2, result3, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=16 results.
// For details see PartialMatrixDotVector64 with N=16.
static void PartialMatrixDotVector16(const int8_t *wi, const float *scales, const int8_t *u,
int num_in, float *v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
__m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
__m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
// Initialize all the results to 0.
__m256i result0 = _mm256_setzero_si256();
__m256i result1 = _mm256_setzero_si256();
// Iterate over the input (u), one registerful at a time.
for (int j = 0; j < num_in;) {
__m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
// Inputs are processed in groups of kNumInputsPerGroup, replicated
// kNumInputGroups times.
for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
// Replicate the low 32 bits (4 inputs) 8 times.
__m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
// Rotate the inputs in groups of 4, so the next 4 inputs are ready.
inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
__m256i weights, reps;
// Mul-add, with horizontal add of the 4 inputs to each of the results.
MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
}
}
ExtractResults16(result0, result1, wi, scales, v);
static inline void ExtractResults8(__m256i result, const int8_t* wi, const TFloat* scales,
TFloat* v) {
__m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg
__m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
__m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
__m256 scale01234567 = _mm256_loadu_ps(scales);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result = _mm256_add_epi32(result, w256); // result += bias * 127
__m256 res01234567 = _mm256_cvtepi32_ps(result);
result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
res01234567 = _mm256_mul_ps(res01234567, scale01234567);
_mm256_storeu_ps(v, res01234567);
}

// Computes part of matrix.vector v = Wu. Computes N=8 results.
// For details see PartialMatrixDotVector64 with N=8.
static inline void PartialMatrixDotVector8(const int8_t *wi, const float *scales, const int8_t *u,
int num_in, float *v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
__m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
__m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
// Initialize all the results to 0.
__m256i result0 = _mm256_setzero_si256();
// Iterate over the input (u), one registerful at a time.
for (int j = 0; j < num_in;) {
__m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
// Inputs are processed in groups of kNumInputsPerGroup, replicated
// kNumInputGroups times.
for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
// Replicate the low 32 bits (4 inputs) 8 times.
__m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
// Rotate the inputs in groups of 4, so the next 4 inputs are ready.
inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
__m256i weights, reps;
// Mul-add, with horizontal add of the 4 inputs to each of the results.
MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
}
}
ExtractResults8(result0, wi, scales, v);
static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t*& wi,
const TFloat*& scales, TFloat*& v) {
__m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(wi));
// 8x8bit vals in bottom of 128bit reg
const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
__m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
__m256 scale01234567 = _mm256_loadu_ps(scales);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result0 = _mm256_add_epi32(result0, w256); // result += bias * 127
__m256 res01234567 = _mm256_cvtepi32_ps(result0);
result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
res01234567 = _mm256_mul_ps(res01234567, scale01234567);
_mm256_storeu_ps(v, res01234567);
w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
scale01234567 = _mm256_loadu_ps(scales + 8);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result1 = _mm256_add_epi32(result1, w256); // result += bias * 127
res01234567 = _mm256_cvtepi32_ps(result1);
result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
res01234567 = _mm256_mul_ps(res01234567, scale01234567);
_mm256_storeu_ps(v + 8, res01234567);
wi += 16;
scales += 16;
v += 16;
}

static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const float *scales,
const int8_t *u, float *v) {
const int num_out = dim1;
const int num_in = dim2 - 1;
// Each call to a partial_func_ produces group_size outputs, except the
// last one, which can produce less.
const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
int output = 0;

int w_step = (rounded_num_in + 1) * group_size;

// Run with this group size, until it would produce too much output, then
// switch to a smaller size.
for (; output + group_size <= rounded_num_out; output += group_size) {
PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
wi += w_step;
scales += group_size;
v += group_size;
}
group_size /= 2;
w_step /= 2;

if (output + group_size <= rounded_num_out) {
PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
wi += w_step;
scales += group_size;
v += group_size;
output += group_size;
}
group_size /= 2;
w_step /= 2;

if (output + group_size <= rounded_num_out) {
PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
wi += w_step;
scales += group_size;
v += group_size;
output += group_size;
}
group_size /= 2;
w_step /= 2;

if (output + group_size <= rounded_num_out) {
PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
}
}
#else
static inline void ExtractResults8(__m256i result, const int8_t *wi, const double *scales,
double *v) {

static inline void ExtractResults8(__m256i result, const int8_t *wi, const TFloat *scales,
TFloat *v) {
__m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg
__m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
__m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
Expand All @@ -396,7 +195,7 @@ static inline void ExtractResults8(__m256i result, const int8_t *wi, const doubl
}

static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi,
const double *&scales, double *&v) {
const TFloat *&scales, TFloat *&v) {
__m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
// 8x8bit vals in bottom of 128bit reg
const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
Expand Down Expand Up @@ -430,15 +229,17 @@ static inline void ExtractResults16(__m256i result0, __m256i result1, const int8
v += 16;
}

#endif

// Computes part of matrix.vector v = Wu. Computes N=64 results.
// The weights *must* be arranged so that consecutive reads from wi
// provides (num_in/kNumInputsPerGroup groups of (N output dim groups of
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
// bias weights, before continuing with any more weights.
// u must be padded out with zeros to
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
static void PartialMatrixDotVector64(const int8_t *wi, const double *scales, const int8_t *u,
int num_in, double *v) {
static void PartialMatrixDotVector64(const int8_t *wi, const TFloat *scales, const int8_t *u,
int num_in, TFloat *v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
__m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
Expand Down Expand Up @@ -482,8 +283,8 @@ static void PartialMatrixDotVector64(const int8_t *wi, const double *scales, con

// Computes part of matrix.vector v = Wu. Computes N=32 results.
// For details see PartialMatrixDotVector64 with N=32.
static void PartialMatrixDotVector32(const int8_t *wi, const double *scales, const int8_t *u,
int num_in, double *v) {
static void PartialMatrixDotVector32(const int8_t *wi, const TFloat *scales, const int8_t *u,
int num_in, TFloat *v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
__m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
Expand Down Expand Up @@ -517,8 +318,8 @@ static void PartialMatrixDotVector32(const int8_t *wi, const double *scales, con

// Computes part of matrix.vector v = Wu. Computes N=16 results.
// For details see PartialMatrixDotVector64 with N=16.
static void PartialMatrixDotVector16(const int8_t *wi, const double *scales, const int8_t *u,
int num_in, double *v) {
static void PartialMatrixDotVector16(const int8_t *wi, const TFloat *scales, const int8_t *u,
int num_in, TFloat *v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
__m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
Expand Down Expand Up @@ -547,8 +348,8 @@ static void PartialMatrixDotVector16(const int8_t *wi, const double *scales, con

// Computes part of matrix.vector v = Wu. Computes N=8 results.
// For details see PartialMatrixDotVector64 with N=8.
static inline void PartialMatrixDotVector8(const int8_t *wi, const double *scales, const int8_t *u,
int num_in, double *v) {
static inline void PartialMatrixDotVector8(const int8_t *wi, const TFloat *scales, const int8_t *u,
int num_in, TFloat *v) {
// Register containing 16-bit ones for horizontal add with 16->32 bit
// conversion.
__m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
Expand All @@ -573,8 +374,8 @@ static inline void PartialMatrixDotVector8(const int8_t *wi, const double *scale
ExtractResults8(result0, wi, scales, v);
}

static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *scales,
const int8_t *u, double *v) {
static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const TFloat *scales,
const int8_t *u, TFloat *v) {
const int num_out = dim1;
const int num_in = dim2 - 1;
// Each call to a partial_func_ produces group_size outputs, except the
Expand Down Expand Up @@ -621,7 +422,7 @@ static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *
PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
}
}
#endif


static const IntSimdMatrix simdMatrix = {
// Function.
Expand Down