Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfixing the AVX2 Extract8+16 codes, where there's lines like [...] #4

Merged
merged 1 commit into from
Jul 13, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 20 additions & 29 deletions src/arch/intsimdmatrixavx2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,54 +132,45 @@ static inline __m128i load64_to_128(const int8_t *wi_) {
}

#if defined(FAST_FLOAT)
static inline void ExtractResults8(__m256i result, const int8_t *wi, const float *scales,
float *v) {
__m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg

static inline void ExtractResults8(__m256i result, const int8_t *wi,
const float *scales, float *v) {
__m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg
__m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
__m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
__m256d scale01234567 = _mm256_loadu_ps(scales);
//~ __m256d scale4567 = _mm256_loadu_ps(scales + 8);
__m256 scale01234567 = _mm256_loadu_ps(scales);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result = _mm256_add_epi32(result, w256); // result += bias * 127
__m256 res01234567 = _mm256_cvtepi32_ps(_mm256_castsi256_si128(result));
__m256 res01234567 = _mm256_cvtepi32_ps(result);
result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
__m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
res01234567 = _mm256_mul_pd(res01234567, scale01234567);
//~ res4567 = _mm256_mul_pd(res4567, scale4567);
res01234567 = _mm256_mul_ps(res01234567, scale01234567);
_mm256_storeu_ps(v, res01234567);
//~ _mm256_storeu_pd(v + 4, res4567);
}

static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi,
const float *&scales, float *&v) {
static inline void ExtractResults16(__m256i result0, __m256i result1,
const int8_t *&wi, const float *&scales,
float *&v) {
__m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
// 8x8bit vals in bottom of 128bit reg
const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
const __m256i bias_scale =
_mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
__m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
__m256d scale0123 = _mm256_loadu_ps(scales);
__m256d scale4567 = _mm256_loadu_ps(scales + 8);
__m256 scale01234567 = _mm256_loadu_ps(scales);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result0 = _mm256_add_epi32(result0, w256); // result += bias * 127
__m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
__m256 res01234567 = _mm256_cvtepi32_ps(result0);
result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
__m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
res0123 = _mm256_mul_pd(res0123, scale0123);
res4567 = _mm256_mul_pd(res4567, scale4567);
_mm256_storeu_ps(v, res0123);
_mm256_storeu_ps(v + 8, res4567);
res01234567 = _mm256_mul_ps(res01234567, scale01234567);
_mm256_storeu_ps(v, res01234567);
w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
scale0123 = _mm256_loadu_ps(scales + 16);
scale4567 = _mm256_loadu_ps(scales + 24);
scale01234567 = _mm256_loadu_ps(scales + 8);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result1 = _mm256_add_epi32(result1, w256); // result += bias * 127
res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
res01234567 = _mm256_cvtepi32_ps(result1);
result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
res0123 = _mm256_mul_pd(res0123, scale0123);
res4567 = _mm256_mul_pd(res4567, scale4567);
_mm256_storeu_ps(v + 16, res0123);
_mm256_storeu_ps(v + 24, res4567);
res01234567 = _mm256_mul_ps(res01234567, scale01234567);
_mm256_storeu_ps(v + 8, res01234567);
wi += 16;
scales += 16;
v += 16;
Expand Down