From 55be3ac273a32367800557138959b98c65be4360 Mon Sep 17 00:00:00 2001
From: Benjamin Trent
Date: Thu, 14 Nov 2024 16:46:21 -0500
Subject: [PATCH] Add multi_dense_vector value access to scripts (#116610)
This adds value access to multi_dense_vector values in scripts. The
users will get:
- Count of vectors per field
- Magnitudes of all the individual vectors
- Access to each vector with an iterator
I will happily take design critiques around how these are exposed in
scripting.
I initially thought of just providing direct `float[][]` access, but
this seems to have some unfavorable behavior around creating a TON of
garbage. The reason is that each field could have a different number of
vectors, so allocating a new collection of `float[dim]` for every field
seemed rough.
Generally, when scripting or using the vectors, an iterator should be
enough and I have the iterator backed by a simple buffer to keep garbage
down.
---
.../org.elasticsearch.script.fields.txt | 15 +
.../painless/org.elasticsearch.txt | 5 +
.../181_multi_dense_vector_dv_fields_api.yml | 178 +++++++++
.../MultiDenseVectorScriptDocValues.java | 81 ++++
.../vectors/MultiVectorDVLeafFieldData.java | 31 +-
.../vectors/MultiVectorIndexFieldData.java | 2 +-
.../mapper/vectors/VectorEncoderDecoder.java | 20 +
.../action/search/SearchCapabilities.java | 3 +
.../field/vectors/BitMultiDenseVector.java | 38 ++
.../BitMultiDenseVectorDocValuesField.java | 31 ++
.../field/vectors/ByteMultiDenseVector.java | 91 +++++
.../ByteMultiDenseVectorDocValuesField.java | 142 +++++++
.../field/vectors/FloatMultiDenseVector.java | 61 +++
.../FloatMultiDenseVectorDocValuesField.java | 143 +++++++
.../field/vectors/MultiDenseVector.java | 71 ++++
.../MultiDenseVectorDocValuesField.java | 57 +++
.../MultiDenseVectorScriptDocValuesTests.java | 374 ++++++++++++++++++
17 files changed, 1330 insertions(+), 13 deletions(-)
create mode 100644 modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/181_multi_dense_vector_dv_fields_api.yml
create mode 100644 server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValues.java
create mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVector.java
create mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVectorDocValuesField.java
create mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVector.java
create mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVectorDocValuesField.java
create mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVector.java
create mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVectorDocValuesField.java
create mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVector.java
create mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVectorDocValuesField.java
create mode 100644 server/src/test/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValuesTests.java
diff --git a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.fields.txt b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.fields.txt
index a739635e85a9c..875b9a1dac3e8 100644
--- a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.fields.txt
+++ b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.fields.txt
@@ -132,6 +132,21 @@ class org.elasticsearch.script.field.SeqNoDocValuesField @dynamic_type {
class org.elasticsearch.script.field.VersionDocValuesField @dynamic_type {
}
+class org.elasticsearch.script.field.vectors.MultiDenseVector {
+ MultiDenseVector EMPTY
+ float[] getMagnitudes()
+
+ Iterator getVectors()
+ boolean isEmpty()
+ int getDims()
+ int size()
+}
+
+class org.elasticsearch.script.field.vectors.MultiDenseVectorDocValuesField {
+ MultiDenseVector get()
+ MultiDenseVector get(MultiDenseVector)
+}
+
class org.elasticsearch.script.field.vectors.DenseVector {
DenseVector EMPTY
float getMagnitude()
diff --git a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.txt b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.txt
index 7ab9eb32852b6..b2db0d1006d40 100644
--- a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.txt
+++ b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.txt
@@ -123,6 +123,11 @@ class org.elasticsearch.index.mapper.vectors.DenseVectorScriptDocValues {
float getMagnitude()
}
+class org.elasticsearch.index.mapper.vectors.MultiDenseVectorScriptDocValues {
+ Iterator getVectorValues()
+ float[] getMagnitudes()
+}
+
class org.apache.lucene.util.BytesRef {
byte[] bytes
int offset
diff --git a/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/181_multi_dense_vector_dv_fields_api.yml b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/181_multi_dense_vector_dv_fields_api.yml
new file mode 100644
index 0000000000000..66cb3f3c46fcc
--- /dev/null
+++ b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/181_multi_dense_vector_dv_fields_api.yml
@@ -0,0 +1,178 @@
+setup:
+ - requires:
+ capabilities:
+ - method: POST
+ path: /_search
+ capabilities: [ multi_dense_vector_script_access ]
+ test_runner_features: capabilities
+ reason: "Support for multi dense vector field script access capability required"
+ - skip:
+ features: headers
+
+ - do:
+ indices.create:
+ index: test-index
+ body:
+ settings:
+ number_of_shards: 1
+ mappings:
+ properties:
+ vector:
+ type: multi_dense_vector
+ dims: 5
+ byte_vector:
+ type: multi_dense_vector
+ dims: 5
+ element_type: byte
+ bit_vector:
+ type: multi_dense_vector
+ dims: 40
+ element_type: bit
+ - do:
+ index:
+ index: test-index
+ id: "1"
+ body:
+ vector: [[230.0, 300.33, -34.8988, 15.555, -200.0], [-0.5, 100.0, -13, 14.8, -156.0]]
+ byte_vector: [[8, 5, -15, 1, -7], [-1, 115, -3, 4, -128]]
+ bit_vector: [[8, 5, -15, 1, -7], [-1, 115, -3, 4, -128]]
+
+ - do:
+ index:
+ index: test-index
+ id: "3"
+ body:
+ vector: [[0.5, 111.3, -13.0, 14.8, -156.0]]
+ byte_vector: [[2, 18, -5, 0, -124]]
+ bit_vector: [[2, 18, -5, 0, -124]]
+
+ - do:
+ indices.refresh: {}
+---
+"Test vector magnitude equality":
+ - skip:
+ features: close_to
+
+ - do:
+ headers:
+ Content-Type: application/json
+ search:
+ rest_total_hits_as_int: true
+ body:
+ query:
+ script_score:
+ query: {match_all: {} }
+ script:
+ source: "doc['vector'].magnitudes[0]"
+
+ - match: {hits.total: 2}
+
+ - match: {hits.hits.0._id: "1"}
+ - close_to: {hits.hits.0._score: {value: 429.6021, error: 0.01}}
+
+ - match: {hits.hits.1._id: "3"}
+ - close_to: {hits.hits.1._score: {value: 192.6447, error: 0.01}}
+
+ - do:
+ headers:
+ Content-Type: application/json
+ search:
+ rest_total_hits_as_int: true
+ body:
+ query:
+ script_score:
+ query: {match_all: {} }
+ script:
+ source: "doc['byte_vector'].magnitudes[0]"
+
+ - match: {hits.total: 2}
+
+ - match: {hits.hits.0._id: "3"}
+ - close_to: {hits.hits.0._score: {value: 125.41531, error: 0.01}}
+
+ - match: {hits.hits.1._id: "1"}
+ - close_to: {hits.hits.1._score: {value: 19.07878, error: 0.01}}
+
+ - do:
+ headers:
+ Content-Type: application/json
+ search:
+ rest_total_hits_as_int: true
+ body:
+ query:
+ script_score:
+ query: {match_all: {} }
+ script:
+ source: "doc['bit_vector'].magnitudes[0]"
+
+ - match: {hits.total: 2}
+
+ - match: {hits.hits.0._id: "1"}
+ - close_to: {hits.hits.0._score: {value: 3.872983, error: 0.01}}
+
+ - match: {hits.hits.1._id: "3"}
+ - close_to: {hits.hits.1._score: {value: 3.464101, error: 0.01}}
+---
+"Test vector value scoring":
+ - skip:
+ features: close_to
+
+ - do:
+ headers:
+ Content-Type: application/json
+ search:
+ rest_total_hits_as_int: true
+ body:
+ query:
+ script_score:
+ query: {match_all: {} }
+ script:
+ source: "doc['vector'].vectorValues.next()[0];"
+
+ - match: {hits.total: 2}
+
+ - match: {hits.hits.0._id: "1"}
+ - close_to: {hits.hits.0._score: {value: 230, error: 0.01}}
+
+ - match: {hits.hits.1._id: "3"}
+ - close_to: {hits.hits.1._score: {value: 0.5, error: 0.01}}
+
+ - do:
+ headers:
+ Content-Type: application/json
+ search:
+ rest_total_hits_as_int: true
+ body:
+ query:
+ script_score:
+ query: {match_all: {} }
+ script:
+ source: "doc['byte_vector'].vectorValues.next()[0];"
+
+ - match: {hits.total: 2}
+
+ - match: {hits.hits.0._id: "1"}
+ - close_to: {hits.hits.0._score: {value: 8, error: 0.01}}
+
+ - match: {hits.hits.1._id: "3"}
+ - close_to: {hits.hits.1._score: {value: 2, error: 0.01}}
+
+ - do:
+ headers:
+ Content-Type: application/json
+ search:
+ rest_total_hits_as_int: true
+ body:
+ query:
+ script_score:
+ query: {match_all: {} }
+ script:
+ source: "doc['bit_vector'].vectorValues.next()[0];"
+
+ - match: {hits.total: 2}
+
+ - match: {hits.hits.0._id: "1"}
+ - close_to: {hits.hits.0._score: {value: 8, error: 0.01}}
+
+ - match: {hits.hits.1._id: "3"}
+ - close_to: {hits.hits.1._score: {value: 2, error: 0.01}}
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValues.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValues.java
new file mode 100644
index 0000000000000..a91960832239f
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValues.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.index.mapper.vectors;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.index.fielddata.ScriptDocValues;
+import org.elasticsearch.script.field.vectors.MultiDenseVector;
+
+import java.util.Iterator;
+
+public class MultiDenseVectorScriptDocValues extends ScriptDocValues {
+
+ public static final String MISSING_VECTOR_FIELD_MESSAGE = "A document doesn't have a value for a multi-vector field!";
+
+ private final int dims;
+ protected final MultiDenseVectorSupplier dvSupplier;
+
+ public MultiDenseVectorScriptDocValues(MultiDenseVectorSupplier supplier, int dims) {
+ super(supplier);
+ this.dvSupplier = supplier;
+ this.dims = dims;
+ }
+
+ public int dims() {
+ return dims;
+ }
+
+ private MultiDenseVector getCheckedVector() {
+ MultiDenseVector vector = dvSupplier.getInternal();
+ if (vector == null) {
+ throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
+ }
+ return vector;
+ }
+
+ /**
+ * Get an iterator over the individual vectors of the multi-dense vector field
+ */
+ public Iterator getVectorValues() {
+ return getCheckedVector().getVectors();
+ }
+
+ /**
+ * Get the magnitudes of the individual vectors of the multi-dense vector field
+ */
+ public float[] getMagnitudes() {
+ return getCheckedVector().getMagnitudes();
+ }
+
+ @Override
+ public BytesRef get(int index) {
+ throw new UnsupportedOperationException(
+ "accessing a multi-vector field's value through 'get' or 'value' is not supported, use 'vectorValues' or 'magnitudes' instead."
+ );
+ }
+
+ @Override
+ public int size() {
+ MultiDenseVector mdv = dvSupplier.getInternal();
+ if (mdv != null) {
+ return mdv.size();
+ }
+ return 0;
+ }
+
+ public interface MultiDenseVectorSupplier extends Supplier {
+ @Override
+ default BytesRef getInternal(int index) {
+ throw new UnsupportedOperationException();
+ }
+
+ MultiDenseVector getInternal();
+ }
+}
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiVectorDVLeafFieldData.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiVectorDVLeafFieldData.java
index cc6fb38274451..b9716d315f33a 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiVectorDVLeafFieldData.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiVectorDVLeafFieldData.java
@@ -9,37 +9,44 @@
package org.elasticsearch.index.mapper.vectors;
+import org.apache.lucene.index.BinaryDocValues;
+import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReader;
-import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.fielddata.LeafFieldData;
import org.elasticsearch.index.fielddata.SortedBinaryDocValues;
import org.elasticsearch.script.field.DocValuesScriptFieldFactory;
+import org.elasticsearch.script.field.vectors.BitMultiDenseVectorDocValuesField;
+import org.elasticsearch.script.field.vectors.ByteMultiDenseVectorDocValuesField;
+import org.elasticsearch.script.field.vectors.FloatMultiDenseVectorDocValuesField;
+
+import java.io.IOException;
final class MultiVectorDVLeafFieldData implements LeafFieldData {
private final LeafReader reader;
private final String field;
- private final IndexVersion indexVersion;
private final DenseVectorFieldMapper.ElementType elementType;
private final int dims;
- MultiVectorDVLeafFieldData(
- LeafReader reader,
- String field,
- IndexVersion indexVersion,
- DenseVectorFieldMapper.ElementType elementType,
- int dims
- ) {
+ MultiVectorDVLeafFieldData(LeafReader reader, String field, DenseVectorFieldMapper.ElementType elementType, int dims) {
this.reader = reader;
this.field = field;
- this.indexVersion = indexVersion;
this.elementType = elementType;
this.dims = dims;
}
@Override
public DocValuesScriptFieldFactory getScriptFieldFactory(String name) {
- // TODO
- return null;
+ try {
+ BinaryDocValues values = DocValues.getBinary(reader, field);
+ BinaryDocValues magnitudeValues = DocValues.getBinary(reader, field + MultiDenseVectorFieldMapper.VECTOR_MAGNITUDES_SUFFIX);
+ return switch (elementType) {
+ case BYTE -> new ByteMultiDenseVectorDocValuesField(values, magnitudeValues, name, elementType, dims);
+ case FLOAT -> new FloatMultiDenseVectorDocValuesField(values, magnitudeValues, name, elementType, dims);
+ case BIT -> new BitMultiDenseVectorDocValuesField(values, magnitudeValues, name, elementType, dims);
+ };
+ } catch (IOException e) {
+ throw new IllegalStateException("Cannot load doc values for multi-vector field!", e);
+ }
}
@Override
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiVectorIndexFieldData.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiVectorIndexFieldData.java
index 65ef492ce052b..44a666e25a611 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiVectorIndexFieldData.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiVectorIndexFieldData.java
@@ -55,7 +55,7 @@ public ValuesSourceType getValuesSourceType() {
@Override
public MultiVectorDVLeafFieldData load(LeafReaderContext context) {
- return new MultiVectorDVLeafFieldData(context.reader(), fieldName, indexVersion, elementType, dims);
+ return new MultiVectorDVLeafFieldData(context.reader(), fieldName, elementType, dims);
}
@Override
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java
index 9d09a7493d605..3db2d164846bd 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java
@@ -84,4 +84,24 @@ public static void decodeDenseVector(IndexVersion indexVersion, BytesRef vectorB
}
}
+ public static float[] getMultiMagnitudes(BytesRef magnitudes) {
+ assert magnitudes.length % Float.BYTES == 0;
+ float[] multiMagnitudes = new float[magnitudes.length / Float.BYTES];
+ ByteBuffer byteBuffer = ByteBuffer.wrap(magnitudes.bytes, magnitudes.offset, magnitudes.length).order(ByteOrder.LITTLE_ENDIAN);
+ for (int i = 0; i < magnitudes.length / Float.BYTES; i++) {
+ multiMagnitudes[i] = byteBuffer.getFloat();
+ }
+ return multiMagnitudes;
+ }
+
+ public static void decodeMultiDenseVector(BytesRef vectorBR, int numVectors, float[][] multiVectorValue) {
+ if (vectorBR == null) {
+ throw new IllegalArgumentException(MultiDenseVectorScriptDocValues.MISSING_VECTOR_FIELD_MESSAGE);
+ }
+ FloatBuffer fb = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
+ for (int i = 0; i < numVectors; i++) {
+ fb.get(multiVectorValue[i]);
+ }
+ }
+
}
diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java
index 3bc1c467323a3..7b57481ad5716 100644
--- a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java
+++ b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java
@@ -38,6 +38,8 @@ private SearchCapabilities() {}
private static final String MULTI_DENSE_VECTOR_FIELD_MAPPER = "multi_dense_vector_field_mapper";
/** Support propagating nested retrievers' inner_hits to top-level compound retrievers . */
private static final String NESTED_RETRIEVER_INNER_HITS_SUPPORT = "nested_retriever_inner_hits_support";
+ /** Support multi-dense-vector script field access. */
+ private static final String MULTI_DENSE_VECTOR_SCRIPT_ACCESS = "multi_dense_vector_script_access";
public static final Set CAPABILITIES;
static {
@@ -50,6 +52,7 @@ private SearchCapabilities() {}
capabilities.add(NESTED_RETRIEVER_INNER_HITS_SUPPORT);
if (MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled()) {
capabilities.add(MULTI_DENSE_VECTOR_FIELD_MAPPER);
+ capabilities.add(MULTI_DENSE_VECTOR_SCRIPT_ACCESS);
}
if (Build.current().isSnapshot()) {
capabilities.add(KQL_QUERY_SUPPORTED);
diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVector.java
new file mode 100644
index 0000000000000..24e19a803ff38
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVector.java
@@ -0,0 +1,38 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.script.field.vectors;
+
+import org.apache.lucene.util.BytesRef;
+
+import java.util.Iterator;
+
+public class BitMultiDenseVector extends ByteMultiDenseVector {
+ public BitMultiDenseVector(Iterator vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) {
+ super(vectorValues, magnitudesBytes, numVecs, dims);
+ }
+
+ @Override
+ public void checkDimensions(int qvDims) {
+ if (qvDims != dims) {
+ throw new IllegalArgumentException(
+ "The query vector has a different number of dimensions ["
+ + qvDims * Byte.SIZE
+ + "] than the document vectors ["
+ + dims * Byte.SIZE
+ + "]."
+ );
+ }
+ }
+
+ @Override
+ public int getDims() {
+ return dims * Byte.SIZE;
+ }
+}
diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVectorDocValuesField.java
new file mode 100644
index 0000000000000..35a43eabb8f0c
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVectorDocValuesField.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.script.field.vectors;
+
+import org.apache.lucene.index.BinaryDocValues;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
+
+public class BitMultiDenseVectorDocValuesField extends ByteMultiDenseVectorDocValuesField {
+
+ public BitMultiDenseVectorDocValuesField(
+ BinaryDocValues input,
+ BinaryDocValues magnitudes,
+ String name,
+ ElementType elementType,
+ int dims
+ ) {
+ super(input, magnitudes, name, elementType, dims / 8);
+ }
+
+ @Override
+ protected MultiDenseVector getVector() {
+ return new BitMultiDenseVector(vectorValue, magnitudesValue, numVecs, dims);
+ }
+}
diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVector.java
new file mode 100644
index 0000000000000..e610d10146b2f
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVector.java
@@ -0,0 +1,91 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.script.field.vectors;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder;
+
+import java.util.Iterator;
+
+public class ByteMultiDenseVector implements MultiDenseVector {
+
+ protected final Iterator vectorValues;
+ protected final int numVecs;
+ protected final int dims;
+
+ private Iterator floatDocVectors;
+ private float[] magnitudes;
+ private final BytesRef magnitudesBytes;
+
+ public ByteMultiDenseVector(Iterator vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) {
+ assert magnitudesBytes.length == numVecs * Float.BYTES;
+ this.vectorValues = vectorValues;
+ this.numVecs = numVecs;
+ this.dims = dims;
+ this.magnitudesBytes = magnitudesBytes;
+ }
+
+ @Override
+ public Iterator getVectors() {
+ if (floatDocVectors == null) {
+ floatDocVectors = new ByteToFloatIteratorWrapper(vectorValues, dims);
+ }
+ return floatDocVectors;
+ }
+
+ @Override
+ public float[] getMagnitudes() {
+ if (magnitudes == null) {
+ magnitudes = VectorEncoderDecoder.getMultiMagnitudes(magnitudesBytes);
+ }
+ return magnitudes;
+ }
+
+ @Override
+ public boolean isEmpty() {
+ return false;
+ }
+
+ @Override
+ public int getDims() {
+ return dims;
+ }
+
+ @Override
+ public int size() {
+ return numVecs;
+ }
+
+ static class ByteToFloatIteratorWrapper implements Iterator {
+ private final Iterator byteIterator;
+ private final float[] buffer;
+ private final int dims;
+
+ ByteToFloatIteratorWrapper(Iterator byteIterator, int dims) {
+ this.byteIterator = byteIterator;
+ this.buffer = new float[dims];
+ this.dims = dims;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return byteIterator.hasNext();
+ }
+
+ @Override
+ public float[] next() {
+ byte[] next = byteIterator.next();
+ for (int i = 0; i < dims; i++) {
+ buffer[i] = next[i];
+ }
+ return buffer;
+ }
+ }
+}
diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVectorDocValuesField.java
new file mode 100644
index 0000000000000..d1e062e0a3dee
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVectorDocValuesField.java
@@ -0,0 +1,142 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.script.field.vectors;
+
+import org.apache.lucene.index.BinaryDocValues;
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
+import org.elasticsearch.index.mapper.vectors.MultiDenseVectorScriptDocValues;
+
+import java.io.IOException;
+import java.util.Iterator;
+
+public class ByteMultiDenseVectorDocValuesField extends MultiDenseVectorDocValuesField {
+
+ protected final BinaryDocValues input;
+ private final BinaryDocValues magnitudes;
+ protected final int dims;
+ protected int numVecs;
+ protected Iterator vectorValue;
+ protected boolean decoded;
+ protected BytesRef value;
+ protected BytesRef magnitudesValue;
+ private byte[] buffer;
+
+ public ByteMultiDenseVectorDocValuesField(
+ BinaryDocValues input,
+ BinaryDocValues magnitudes,
+ String name,
+ ElementType elementType,
+ int dims
+ ) {
+ super(name, elementType);
+ this.input = input;
+ this.dims = dims;
+ this.buffer = new byte[dims];
+ this.magnitudes = magnitudes;
+ }
+
+ @Override
+ public void setNextDocId(int docId) throws IOException {
+ decoded = false;
+ if (input.advanceExact(docId)) {
+ boolean magnitudesFound = magnitudes.advanceExact(docId);
+ assert magnitudesFound;
+ value = input.binaryValue();
+ assert value.length % dims == 0;
+ numVecs = value.length / dims;
+ magnitudesValue = magnitudes.binaryValue();
+ assert magnitudesValue.length == (numVecs * Float.BYTES);
+ } else {
+ value = null;
+ magnitudesValue = null;
+ vectorValue = null;
+ numVecs = 0;
+ }
+ }
+
+ @Override
+ public MultiDenseVectorScriptDocValues toScriptDocValues() {
+ return new MultiDenseVectorScriptDocValues(this, dims);
+ }
+
+ protected MultiDenseVector getVector() {
+ return new ByteMultiDenseVector(vectorValue, magnitudesValue, numVecs, dims);
+ }
+
+ @Override
+ public MultiDenseVector get() {
+ if (isEmpty()) {
+ return MultiDenseVector.EMPTY;
+ }
+ decodeVectorIfNecessary();
+ return getVector();
+ }
+
+ @Override
+ public MultiDenseVector get(MultiDenseVector defaultValue) {
+ if (isEmpty()) {
+ return defaultValue;
+ }
+ decodeVectorIfNecessary();
+ return getVector();
+ }
+
+ @Override
+ public MultiDenseVector getInternal() {
+ return get(null);
+ }
+
+ private void decodeVectorIfNecessary() {
+ if (decoded == false && value != null) {
+ vectorValue = new ByteVectorIterator(value, buffer, numVecs);
+ decoded = true;
+ }
+ }
+
+ @Override
+ public int size() {
+ return value == null ? 0 : value.length / dims;
+ }
+
+ @Override
+ public boolean isEmpty() {
+ return value == null;
+ }
+
+ static class ByteVectorIterator implements Iterator {
+ private final byte[] buffer;
+ private final BytesRef vectorValues;
+ private final int size;
+ private int idx = 0;
+
+ ByteVectorIterator(BytesRef vectorValues, byte[] buffer, int size) {
+ assert vectorValues.length == (buffer.length * size);
+ this.vectorValues = vectorValues;
+ this.size = size;
+ this.buffer = buffer;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return idx < size;
+ }
+
+ @Override
+ public byte[] next() {
+ if (hasNext() == false) {
+ throw new IllegalArgumentException("No more elements in the iterator");
+ }
+ System.arraycopy(vectorValues.bytes, vectorValues.offset + idx * buffer.length, buffer, 0, buffer.length);
+ idx++;
+ return buffer;
+ }
+ }
+}
diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVector.java
new file mode 100644
index 0000000000000..9ffe8b3b970c4
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVector.java
@@ -0,0 +1,61 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.script.field.vectors;
+
+import org.apache.lucene.util.BytesRef;
+
+import java.util.Iterator;
+
+import static org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder.getMultiMagnitudes;
+
+public class FloatMultiDenseVector implements MultiDenseVector {
+
+ private final BytesRef magnitudes;
+ private float[] magnitudesArray = null;
+ private final int dims;
+ private final int numVectors;
+ private final Iterator decodedDocVector;
+
+ public FloatMultiDenseVector(Iterator decodedDocVector, BytesRef magnitudes, int numVectors, int dims) {
+ assert magnitudes.length == numVectors * Float.BYTES;
+ this.decodedDocVector = decodedDocVector;
+ this.magnitudes = magnitudes;
+ this.numVectors = numVectors;
+ this.dims = dims;
+ }
+
+ @Override
+ public Iterator getVectors() {
+ return decodedDocVector;
+ }
+
+ @Override
+ public float[] getMagnitudes() {
+ if (magnitudesArray == null) {
+ magnitudesArray = getMultiMagnitudes(magnitudes);
+ }
+ return magnitudesArray;
+ }
+
+ @Override
+ public boolean isEmpty() {
+ return false;
+ }
+
+ @Override
+ public int getDims() {
+ return dims;
+ }
+
+ @Override
+ public int size() {
+ return numVectors;
+ }
+}
diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVectorDocValuesField.java
new file mode 100644
index 0000000000000..356db58d989c5
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVectorDocValuesField.java
@@ -0,0 +1,151 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.script.field.vectors;
+
+import org.apache.lucene.index.BinaryDocValues;
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
+import org.elasticsearch.index.mapper.vectors.MultiDenseVectorScriptDocValues;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.FloatBuffer;
+import java.util.Iterator;
+
+/**
+ * Doc-values backed script field for float multi-vector values. For each document it reads the
+ * encoded vectors from {@code input} and the encoded magnitudes from {@code magnitudes}, and
+ * exposes them as a {@link MultiDenseVector}.
+ */
+public class FloatMultiDenseVectorDocValuesField extends MultiDenseVectorDocValuesField {
+
+    private final BinaryDocValues input;
+    private final BinaryDocValues magnitudes;
+    private final int dims;
+    // Scratch array shared by every iterator created by this field; next() overwrites it in place.
+    private final float[] buffer;
+    // State of the current document; value is null when the document has no vectors.
+    private BytesRef value;
+    private BytesRef magnitudesValue;
+    private int numVectors;
+
+    public FloatMultiDenseVectorDocValuesField(
+        BinaryDocValues input,
+        BinaryDocValues magnitudes,
+        String name,
+        ElementType elementType,
+        int dims
+    ) {
+        super(name, elementType);
+        this.input = input;
+        this.magnitudes = magnitudes;
+        this.dims = dims;
+        this.buffer = new float[dims];
+    }
+
+    /**
+     * Positions both doc-values readers on {@code docId} and records the vector count, or clears
+     * the per-document state when the document has no value.
+     */
+    @Override
+    public void setNextDocId(int docId) throws IOException {
+        if (input.advanceExact(docId)) {
+            boolean magnitudesFound = magnitudes.advanceExact(docId);
+            assert magnitudesFound;
+
+            value = input.binaryValue();
+            assert value.length % (Float.BYTES * dims) == 0;
+            numVectors = value.length / (Float.BYTES * dims);
+            magnitudesValue = magnitudes.binaryValue();
+            assert magnitudesValue.length == (Float.BYTES * numVectors);
+        } else {
+            value = null;
+            magnitudesValue = null;
+            numVectors = 0;
+        }
+    }
+
+    @Override
+    public MultiDenseVectorScriptDocValues toScriptDocValues() {
+        return new MultiDenseVectorScriptDocValues(this, dims);
+    }
+
+    @Override
+    public boolean isEmpty() {
+        return value == null;
+    }
+
+    /** Returns the current document's vectors, or {@link MultiDenseVector#EMPTY} when it has none. */
+    @Override
+    public MultiDenseVector get() {
+        return get(MultiDenseVector.EMPTY);
+    }
+
+    /**
+     * Returns the current document's vectors, or {@code defaultValue} when it has none. Each call
+     * wraps a new iterator: the iterators are single-pass, so sharing one across calls would hand
+     * later callers an iterator that an earlier caller may already have consumed.
+     */
+    @Override
+    public MultiDenseVector get(MultiDenseVector defaultValue) {
+        if (isEmpty()) {
+            return defaultValue;
+        }
+        Iterator<float[]> vectors = new FloatVectorIterator(value, buffer, numVectors);
+        return new FloatMultiDenseVector(vectors, magnitudesValue, numVectors, dims);
+    }
+
+    @Override
+    public MultiDenseVector getInternal() {
+        return get(null);
+    }
+
+    @Override
+    public int size() {
+        return numVectors;
+    }
+
+    /**
+     * Single-pass iterator over the vectors of one document. Every call to {@link #next()} decodes
+     * the next vector into the shared {@code buffer} and returns that same array, so callers that
+     * need to keep a vector must copy it before advancing.
+     */
+    static class FloatVectorIterator implements Iterator<float[]> {
+        private final float[] buffer;
+        private final FloatBuffer vectorValues;
+        private final int size;
+        private int idx = 0;
+
+        FloatVectorIterator(BytesRef vectorValues, float[] buffer, int size) {
+            assert vectorValues.length == (buffer.length * Float.BYTES * size);
+            this.vectorValues = ByteBuffer.wrap(vectorValues.bytes, vectorValues.offset, vectorValues.length)
+                .order(ByteOrder.LITTLE_ENDIAN)
+                .asFloatBuffer();
+            this.size = size;
+            this.buffer = buffer;
+        }
+
+        @Override
+        public boolean hasNext() {
+            return idx < size;
+        }
+
+        @Override
+        public float[] next() {
+            if (hasNext() == false) {
+                throw new IllegalArgumentException("No more elements in the iterator");
+            }
+            vectorValues.get(buffer);
+            idx++;
+            return buffer;
+        }
+    }
+}
diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVector.java
new file mode 100644
index 0000000000000..85c851dbe545c
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVector.java
@@ -0,0 +1,92 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.script.field.vectors;
+
+import java.util.Iterator;
+
+/**
+ * The value of a multi-vector field for a single document, as exposed to scripts: a set of
+ * {@link #size()} vectors with {@link #getDims()} dimensions each.
+ */
+public interface MultiDenseVector {
+
+    /**
+     * Checks that a query vector with {@code qvDims} dimensions matches this value's dimensions.
+     *
+     * @throws IllegalArgumentException if the dimension counts differ
+     */
+    default void checkDimensions(int qvDims) {
+        checkDimensions(getDims(), qvDims);
+    }
+
+    /** Returns an iterator over the individual vectors of this value. */
+    Iterator<float[]> getVectors();
+
+    /** Returns the magnitude of each vector. */
+    float[] getMagnitudes();
+
+    /** Returns {@code true} when this is a placeholder for a missing value. */
+    boolean isEmpty();
+
+    /** Returns the number of dimensions of each vector. */
+    int getDims();
+
+    /** Returns the number of vectors in this value. */
+    int size();
+
+    /**
+     * Verifies that document and query vectors have the same number of dimensions.
+     *
+     * @param dvDims dimensions of the document vectors
+     * @param qvDims dimensions of the query vector
+     * @throws IllegalArgumentException if the two differ
+     */
+    static void checkDimensions(int dvDims, int qvDims) {
+        if (dvDims != qvDims) {
+            throw new IllegalArgumentException(
+                "The query vector has a different number of dimensions [" + qvDims + "] than the document vectors [" + dvDims + "]."
+            );
+        }
+    }
+
+    /**
+     * Placeholder for a missing value: reports itself as empty with size 0 and rejects every other
+     * accessor with an {@link IllegalArgumentException}.
+     */
+    MultiDenseVector EMPTY = new MultiDenseVector() {
+        private static final String MISSING_VECTOR_FIELD_MESSAGE = "Multi Dense vector value missing for a field,"
+            + " use isEmpty() to check for a missing vector value";
+
+        @Override
+        public Iterator<float[]> getVectors() {
+            throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
+        }
+
+        @Override
+        public float[] getMagnitudes() {
+            throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
+        }
+
+        @Override
+        public boolean isEmpty() {
+            return true;
+        }
+
+        @Override
+        public int getDims() {
+            throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
+        }
+
+        @Override
+        public int size() {
+            return 0;
+        }
+    };
+}
diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVectorDocValuesField.java
new file mode 100644
index 0000000000000..61ae4304683c8
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVectorDocValuesField.java
@@ -0,0 +1,65 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.script.field.vectors;
+
+import org.elasticsearch.index.mapper.vectors.MultiDenseVectorScriptDocValues;
+import org.elasticsearch.script.field.AbstractScriptFieldFactory;
+import org.elasticsearch.script.field.DocValuesScriptFieldFactory;
+import org.elasticsearch.script.field.Field;
+
+import java.util.Iterator;
+
+import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
+
+// NOTE(review): the supertypes below are written without type arguments; if AbstractScriptFieldFactory
+// and Field are generic, parameterize them with MultiDenseVector - confirm against their declarations.
+/**
+ * Base script field for multi-vector doc values, carrying the field name and element type.
+ */
+public abstract class MultiDenseVectorDocValuesField extends AbstractScriptFieldFactory
+    implements
+        Field,
+        DocValuesScriptFieldFactory,
+        MultiDenseVectorScriptDocValues.MultiDenseVectorSupplier {
+    protected final String name;
+    protected final ElementType elementType;
+
+    public MultiDenseVectorDocValuesField(String name, ElementType elementType) {
+        this.name = name;
+        this.elementType = elementType;
+    }
+
+    @Override
+    public String getName() {
+        return name;
+    }
+
+    public ElementType getElementType() {
+        return elementType;
+    }
+
+    /**
+     * Get the MultiDenseVector for a document if one exists, MultiDenseVector.EMPTY otherwise
+     */
+    public abstract MultiDenseVector get();
+
+    /**
+     * Get the MultiDenseVector for a document if one exists, {@code defaultValue} otherwise
+     */
+    public abstract MultiDenseVector get(MultiDenseVector defaultValue);
+
+    public abstract MultiDenseVectorScriptDocValues toScriptDocValues();
+
+    // The field holds one MultiDenseVector per document, so Iterable does not make sense.
+    @Override
+    public Iterator<MultiDenseVector> iterator() {
+        throw new UnsupportedOperationException("Cannot iterate over single valued multi_dense_vector field, use get() instead");
+    }
+}
diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValuesTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValuesTests.java
new file mode 100644
index 0000000000000..ef316c5addefa
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValuesTests.java
@@ -0,0 +1,384 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.index.mapper.vectors;
+
+import org.apache.lucene.index.BinaryDocValues;
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.index.IndexVersion;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
+import org.elasticsearch.script.field.vectors.ByteMultiDenseVectorDocValuesField;
+import org.elasticsearch.script.field.vectors.FloatMultiDenseVectorDocValuesField;
+import org.elasticsearch.script.field.vectors.MultiDenseVector;
+import org.elasticsearch.script.field.vectors.MultiDenseVectorDocValuesField;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.index.IndexVersionUtils;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.Iterator;
+import java.util.List;
+
+import static org.hamcrest.Matchers.containsString;
+
+/**
+ * Tests script access to multi-vector doc values through the float and byte doc-values fields.
+ */
+public class MultiDenseVectorScriptDocValuesTests extends ESTestCase {
+
+    public void testFloatGetVectorValueAndGetMagnitude() throws IOException {
+        int dims = 3;
+        float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } };
+        float[][] expectedMagnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } };
+
+        for (IndexVersion indexVersion : List.of(IndexVersionUtils.randomCompatibleVersion(random()), IndexVersion.current())) {
+            BinaryDocValues docValues = wrap(vectors, ElementType.FLOAT, indexVersion);
+            BinaryDocValues magnitudeValues = wrap(expectedMagnitudes);
+            MultiDenseVectorDocValuesField field = new FloatMultiDenseVectorDocValuesField(
+                docValues,
+                magnitudeValues,
+                "test",
+                ElementType.FLOAT,
+                dims
+            );
+            MultiDenseVectorScriptDocValues scriptDocValues = field.toScriptDocValues();
+            for (int i = 0; i < vectors.length; i++) {
+                field.setNextDocId(i);
+                assertEquals(vectors[i].length, field.size());
+                assertEquals(dims, scriptDocValues.dims());
+                Iterator<float[]> iterator = scriptDocValues.getVectorValues();
+                float[] magnitudes = scriptDocValues.getMagnitudes();
+                assertEquals(expectedMagnitudes[i].length, magnitudes.length);
+                for (int j = 0; j < vectors[i].length; j++) {
+                    assertTrue(iterator.hasNext());
+                    assertArrayEquals(vectors[i][j], iterator.next(), 0.0001f);
+                    assertEquals(expectedMagnitudes[i][j], magnitudes[j], 0.0001f);
+                }
+            }
+        }
+    }
+
+    public void testByteGetVectorValueAndGetMagnitude() throws IOException {
+        int dims = 3;
+        float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } };
+        float[][] expectedMagnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } };
+
+        BinaryDocValues docValues = wrap(vectors, ElementType.BYTE, IndexVersion.current());
+        BinaryDocValues magnitudeValues = wrap(expectedMagnitudes);
+        MultiDenseVectorDocValuesField field = new ByteMultiDenseVectorDocValuesField(
+            docValues,
+            magnitudeValues,
+            "test",
+            ElementType.BYTE,
+            dims
+        );
+        MultiDenseVectorScriptDocValues scriptDocValues = field.toScriptDocValues();
+        for (int i = 0; i < vectors.length; i++) {
+            field.setNextDocId(i);
+            assertEquals(vectors[i].length, field.size());
+            assertEquals(dims, scriptDocValues.dims());
+            Iterator<float[]> iterator = scriptDocValues.getVectorValues();
+            float[] magnitudes = scriptDocValues.getMagnitudes();
+            assertEquals(expectedMagnitudes[i].length, magnitudes.length);
+            for (int j = 0; j < vectors[i].length; j++) {
+                assertTrue(iterator.hasNext());
+                assertArrayEquals(vectors[i][j], iterator.next(), 0.0001f);
+                assertEquals(expectedMagnitudes[i][j], magnitudes[j], 0.0001f);
+            }
+        }
+    }
+
+    public void testFloatMetadataAndIterator() throws IOException {
+        int dims = 3;
+        IndexVersion indexVersion = IndexVersion.current();
+        float[][][] vectors = new float[][][] { fill(new float[3][dims], ElementType.FLOAT), fill(new float[2][dims], ElementType.FLOAT) };
+        float[][] magnitudes = new float[][] { new float[3], new float[2] };
+        BinaryDocValues docValues = wrap(vectors, ElementType.FLOAT, indexVersion);
+        BinaryDocValues magnitudeValues = wrap(magnitudes);
+
+        MultiDenseVectorDocValuesField field = new FloatMultiDenseVectorDocValuesField(
+            docValues,
+            magnitudeValues,
+            "test",
+            ElementType.FLOAT,
+            dims
+        );
+        for (int i = 0; i < vectors.length; i++) {
+            field.setNextDocId(i);
+            MultiDenseVector dv = field.get();
+            assertEquals(vectors[i].length, dv.size());
+            assertFalse(dv.isEmpty());
+            assertEquals(dims, dv.getDims());
+            UnsupportedOperationException e = expectThrows(UnsupportedOperationException.class, field::iterator);
+            assertEquals("Cannot iterate over single valued multi_dense_vector field, use get() instead", e.getMessage());
+        }
+        field.setNextDocId(vectors.length);
+        MultiDenseVector dv = field.get();
+        assertEquals(MultiDenseVector.EMPTY, dv);
+    }
+
+    public void testByteMetadataAndIterator() throws IOException {
+        int dims = 3;
+        IndexVersion indexVersion = IndexVersion.current();
+        float[][][] vectors = new float[][][] { fill(new float[3][dims], ElementType.BYTE), fill(new float[2][dims], ElementType.BYTE) };
+        float[][] magnitudes = new float[][] { new float[3], new float[2] };
+        BinaryDocValues docValues = wrap(vectors, ElementType.BYTE, indexVersion);
+        BinaryDocValues magnitudeValues = wrap(magnitudes);
+        MultiDenseVectorDocValuesField field = new ByteMultiDenseVectorDocValuesField(
+            docValues,
+            magnitudeValues,
+            "test",
+            ElementType.BYTE,
+            dims
+        );
+        for (int i = 0; i < vectors.length; i++) {
+            field.setNextDocId(i);
+            MultiDenseVector dv = field.get();
+            assertEquals(vectors[i].length, dv.size());
+            assertFalse(dv.isEmpty());
+            assertEquals(dims, dv.getDims());
+            UnsupportedOperationException e = expectThrows(UnsupportedOperationException.class, field::iterator);
+            assertEquals("Cannot iterate over single valued multi_dense_vector field, use get() instead", e.getMessage());
+        }
+        field.setNextDocId(vectors.length);
+        MultiDenseVector dv = field.get();
+        assertEquals(MultiDenseVector.EMPTY, dv);
+    }
+
+    /** Fills every component with a random float (FLOAT) or a random byte value (any other element type). */
+    protected float[][] fill(float[][] vectors, ElementType elementType) {
+        for (float[] vector : vectors) {
+            for (int i = 0; i < vector.length; i++) {
+                vector[i] = elementType == ElementType.FLOAT ? randomFloat() : randomByte();
+            }
+        }
+        return vectors;
+    }
+
+    public void testFloatMissingValues() throws IOException {
+        int dims = 3;
+        float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } };
+        float[][] magnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } };
+        BinaryDocValues docValues = wrap(vectors, ElementType.FLOAT, IndexVersion.current());
+        BinaryDocValues magnitudeValues = wrap(magnitudes);
+        MultiDenseVectorDocValuesField field = new FloatMultiDenseVectorDocValuesField(
+            docValues,
+            magnitudeValues,
+            "test",
+            ElementType.FLOAT,
+            dims
+        );
+        MultiDenseVectorScriptDocValues scriptDocValues = field.toScriptDocValues();
+
+        field.setNextDocId(3);
+        assertEquals(0, field.size());
+        Exception e = expectThrows(IllegalArgumentException.class, scriptDocValues::getVectorValues);
+        assertEquals("A document doesn't have a value for a multi-vector field!", e.getMessage());
+
+        e = expectThrows(IllegalArgumentException.class, scriptDocValues::getMagnitudes);
+        assertEquals("A document doesn't have a value for a multi-vector field!", e.getMessage());
+    }
+
+    public void testByteMissingValues() throws IOException {
+        int dims = 3;
+        float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } };
+        float[][] magnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } };
+        BinaryDocValues docValues = wrap(vectors, ElementType.BYTE, IndexVersion.current());
+        BinaryDocValues magnitudeValues = wrap(magnitudes);
+        MultiDenseVectorDocValuesField field = new ByteMultiDenseVectorDocValuesField(
+            docValues,
+            magnitudeValues,
+            "test",
+            ElementType.BYTE,
+            dims
+        );
+        MultiDenseVectorScriptDocValues scriptDocValues = field.toScriptDocValues();
+
+        field.setNextDocId(3);
+        assertEquals(0, field.size());
+        Exception e = expectThrows(IllegalArgumentException.class, scriptDocValues::getVectorValues);
+        assertEquals("A document doesn't have a value for a multi-vector field!", e.getMessage());
+
+        e = expectThrows(IllegalArgumentException.class, scriptDocValues::getMagnitudes);
+        assertEquals("A document doesn't have a value for a multi-vector field!", e.getMessage());
+    }
+
+    public void testFloatGetFunctionIsNotAccessible() throws IOException {
+        int dims = 3;
+        float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } };
+        float[][] magnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } };
+        BinaryDocValues docValues = wrap(vectors, ElementType.FLOAT, IndexVersion.current());
+        BinaryDocValues magnitudeValues = wrap(magnitudes);
+        MultiDenseVectorDocValuesField field = new FloatMultiDenseVectorDocValuesField(
+            docValues,
+            magnitudeValues,
+            "test",
+            ElementType.FLOAT,
+            dims
+        );
+        MultiDenseVectorScriptDocValues scriptDocValues = field.toScriptDocValues();
+
+        field.setNextDocId(0);
+        Exception e = expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0));
+        assertThat(
+            e.getMessage(),
+            containsString(
+                "accessing a multi-vector field's value through 'get' or 'value' is not supported,"
+                    + " use 'vectorValues' or 'magnitudes' instead."
+            )
+        );
+    }
+
+    public void testByteGetFunctionIsNotAccessible() throws IOException {
+        int dims = 3;
+        float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } };
+        float[][] magnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } };
+        BinaryDocValues docValues = wrap(vectors, ElementType.BYTE, IndexVersion.current());
+        BinaryDocValues magnitudeValues = wrap(magnitudes);
+        MultiDenseVectorDocValuesField field = new ByteMultiDenseVectorDocValuesField(
+            docValues,
+            magnitudeValues,
+            "test",
+            ElementType.BYTE,
+            dims
+        );
+        MultiDenseVectorScriptDocValues scriptDocValues = field.toScriptDocValues();
+
+        field.setNextDocId(0);
+        Exception e = expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0));
+        assertThat(
+            e.getMessage(),
+            containsString(
+                "accessing a multi-vector field's value through 'get' or 'value' is not supported,"
+                    + " use 'vectorValues' or 'magnitudes' instead."
+            )
+        );
+    }
+
+    /**
+     * Builds doc values where document {@code i} holds {@code magnitudes[i]} encoded as
+     * little-endian floats; documents at or past {@code magnitudes.length} have no value.
+     */
+    public static BinaryDocValues wrap(float[][] magnitudes) {
+        return new BinaryDocValues() {
+            int idx = -1;
+            int maxIdx = magnitudes.length;
+
+            @Override
+            public BytesRef binaryValue() {
+                if (idx >= maxIdx) {
+                    throw new IllegalStateException("max index exceeded");
+                }
+                ByteBuffer magnitudeBuffer = ByteBuffer.allocate(magnitudes[idx].length * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
+                for (float magnitude : magnitudes[idx]) {
+                    magnitudeBuffer.putFloat(magnitude);
+                }
+                return new BytesRef(magnitudeBuffer.array());
+            }
+
+            @Override
+            public boolean advanceExact(int target) {
+                idx = target;
+                return target < maxIdx;
+            }
+
+            @Override
+            public int docID() {
+                return idx;
+            }
+
+            @Override
+            public int nextDoc() {
+                return idx++;
+            }
+
+            @Override
+            public int advance(int target) {
+                throw new IllegalArgumentException("not defined!");
+            }
+
+            @Override
+            public long cost() {
+                throw new IllegalArgumentException("not defined!");
+            }
+        };
+    }
+
+    /**
+     * Builds doc values where document {@code i} holds {@code vectors[i]} encoded by
+     * {@link #mockEncodeDenseVector}; documents at or past {@code vectors.length} have no value.
+     */
+    public static BinaryDocValues wrap(float[][][] vectors, ElementType elementType, IndexVersion indexVersion) {
+        return new BinaryDocValues() {
+            int idx = -1;
+            int maxIdx = vectors.length;
+
+            @Override
+            public BytesRef binaryValue() {
+                if (idx >= maxIdx) {
+                    throw new IllegalStateException("max index exceeded");
+                }
+                return mockEncodeDenseVector(vectors[idx], elementType, indexVersion);
+            }
+
+            @Override
+            public boolean advanceExact(int target) {
+                idx = target;
+                return target < maxIdx;
+            }
+
+            @Override
+            public int docID() {
+                return idx;
+            }
+
+            @Override
+            public int nextDoc() {
+                return idx++;
+            }
+
+            @Override
+            public int advance(int target) {
+                throw new IllegalArgumentException("not defined!");
+            }
+
+            @Override
+            public long cost() {
+                throw new IllegalArgumentException("not defined!");
+            }
+        };
+    }
+
+    /**
+     * Writes the vectors back to back into a buffer obtained from the element type and wraps it in
+     * a {@link BytesRef}: a float per value for FLOAT, a single byte per value for BYTE and BIT.
+     */
+    public static BytesRef mockEncodeDenseVector(float[][] values, ElementType elementType, IndexVersion indexVersion) {
+        int dims = values[0].length;
+        if (elementType == ElementType.BIT) {
+            dims *= Byte.SIZE;
+        }
+        int numBytes = elementType.getNumBytes(dims);
+        ByteBuffer byteBuffer = elementType.createByteBuffer(indexVersion, numBytes * values.length);
+        for (float[] vector : values) {
+            for (float value : vector) {
+                if (elementType == ElementType.FLOAT) {
+                    byteBuffer.putFloat(value);
+                } else if (elementType == ElementType.BYTE || elementType == ElementType.BIT) {
+                    byteBuffer.put((byte) value);
+                } else {
+                    throw new IllegalStateException("unknown element_type [" + elementType + "]");
+                }
+            }
+        }
+        return new BytesRef(byteBuffer.array());
+    }
+
+}