From 55be3ac273a32367800557138959b98c65be4360 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Thu, 14 Nov 2024 16:46:21 -0500 Subject: [PATCH] Add multi_dense_vector value access to scripts (#116610) This adds value access to multi_dense_vector values in scripts. The users will get: - Count of vectors per field - Magnitudes of all the individual vectors - Access to each vector with an iterator I will happily take design critiques around how these are exposed in scripting. I initially though of just providing directly `float[][]` access, but this seems to have some unfavorable behavior around creating a TON of garbage. The reason is that each field could have a different number of vectors, so allocating a new collection of `float[dim]` for every field seemed rough. Generally, when scripting or using the vectors, an iterator should be enough and I have the iterator backed by a simple buffer to keep garbage down. --- .../org.elasticsearch.script.fields.txt | 15 + .../painless/org.elasticsearch.txt | 5 + .../181_multi_dense_vector_dv_fields_api.yml | 178 +++++++++ .../MultiDenseVectorScriptDocValues.java | 81 ++++ .../vectors/MultiVectorDVLeafFieldData.java | 31 +- .../vectors/MultiVectorIndexFieldData.java | 2 +- .../mapper/vectors/VectorEncoderDecoder.java | 20 + .../action/search/SearchCapabilities.java | 3 + .../field/vectors/BitMultiDenseVector.java | 38 ++ .../BitMultiDenseVectorDocValuesField.java | 31 ++ .../field/vectors/ByteMultiDenseVector.java | 91 +++++ .../ByteMultiDenseVectorDocValuesField.java | 142 +++++++ .../field/vectors/FloatMultiDenseVector.java | 61 +++ .../FloatMultiDenseVectorDocValuesField.java | 143 +++++++ .../field/vectors/MultiDenseVector.java | 71 ++++ .../MultiDenseVectorDocValuesField.java | 57 +++ .../MultiDenseVectorScriptDocValuesTests.java | 374 ++++++++++++++++++ 17 files changed, 1330 insertions(+), 13 deletions(-) create mode 100644 modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/181_multi_dense_vector_dv_fields_api.yml create mode 100644 server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValues.java create mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVector.java create mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVectorDocValuesField.java create mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVector.java create mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVectorDocValuesField.java create mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVector.java create mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVectorDocValuesField.java create mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVector.java create mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVectorDocValuesField.java create mode 100644 server/src/test/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValuesTests.java diff --git a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.fields.txt b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.fields.txt index a739635e85a9c..875b9a1dac3e8 100644 --- a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.fields.txt +++ b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.fields.txt @@ -132,6 +132,21 @@ class org.elasticsearch.script.field.SeqNoDocValuesField @dynamic_type { class org.elasticsearch.script.field.VersionDocValuesField @dynamic_type { } +class org.elasticsearch.script.field.vectors.MultiDenseVector { + MultiDenseVector EMPTY + float[] getMagnitudes() + + Iterator getVectors() + boolean isEmpty() + int getDims() + int size() +} + +class org.elasticsearch.script.field.vectors.MultiDenseVectorDocValuesField { + MultiDenseVector get() + MultiDenseVector get(MultiDenseVector) +} + class org.elasticsearch.script.field.vectors.DenseVector { DenseVector EMPTY float getMagnitude() diff --git a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.txt b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.txt index 7ab9eb32852b6..b2db0d1006d40 100644 --- a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.txt +++ b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.txt @@ -123,6 +123,11 @@ class org.elasticsearch.index.mapper.vectors.DenseVectorScriptDocValues { float getMagnitude() } +class org.elasticsearch.index.mapper.vectors.MultiDenseVectorScriptDocValues { + Iterator getVectorValues() + float[] getMagnitudes() +} + class org.apache.lucene.util.BytesRef { byte[] bytes int offset diff --git a/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/181_multi_dense_vector_dv_fields_api.yml b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/181_multi_dense_vector_dv_fields_api.yml new file mode 100644 index 0000000000000..66cb3f3c46fcc --- /dev/null +++ b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/181_multi_dense_vector_dv_fields_api.yml @@ -0,0 +1,178 @@ +setup: + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ multi_dense_vector_script_access ] + test_runner_features: capabilities + reason: "Support for multi dense vector field script access capability required" + - skip: + features: headers + + - do: + indices.create: + index: test-index + body: + settings: + number_of_shards: 1 + mappings: + properties: + vector: + type: multi_dense_vector + dims: 5 + byte_vector: + type: multi_dense_vector + dims: 5 + element_type: byte + bit_vector: + type: multi_dense_vector + dims: 40 + element_type: bit + - do: + index: + index: test-index + id: "1" + body: + vector: [[230.0, 300.33, -34.8988, 15.555, -200.0], [-0.5, 100.0, -13, 14.8, -156.0]] + byte_vector: [[8, 5, -15, 1, -7], [-1, 115, -3, 4, -128]] + bit_vector: [[8, 5, -15, 1, -7], [-1, 115, -3, 4, -128]] + + - do: + index: + index: test-index + id: "3" + body: + vector: [[0.5, 111.3, -13.0, 14.8, -156.0]] + byte_vector: [[2, 18, -5, 0, -124]] + bit_vector: [[2, 18, -5, 0, -124]] + + - do: + indices.refresh: {} +--- +"Test vector magnitude equality": + - skip: + features: close_to + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "doc['vector'].magnitudes[0]" + + - match: {hits.total: 2} + + - match: {hits.hits.0._id: "1"} + - close_to: {hits.hits.0._score: {value: 429.6021, error: 0.01}} + + - match: {hits.hits.1._id: "3"} + - close_to: {hits.hits.1._score: {value: 192.6447, error: 0.01}} + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "doc['byte_vector'].magnitudes[0]" + + - match: {hits.total: 2} + + - match: {hits.hits.0._id: "3"} + - close_to: {hits.hits.0._score: {value: 125.41531, error: 0.01}} + + - match: {hits.hits.1._id: "1"} + - close_to: {hits.hits.1._score: {value: 19.07878, error: 0.01}} + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "doc['bit_vector'].magnitudes[0]" + + - match: {hits.total: 2} + + - match: {hits.hits.0._id: "1"} + - close_to: {hits.hits.0._score: {value: 3.872983, error: 0.01}} + + - match: {hits.hits.1._id: "3"} + - close_to: {hits.hits.1._score: {value: 3.464101, error: 0.01}} +--- +"Test vector value scoring": + - skip: + features: close_to + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "doc['vector'].vectorValues.next()[0];" + + - match: {hits.total: 2} + + - match: {hits.hits.0._id: "1"} + - close_to: {hits.hits.0._score: {value: 230, error: 0.01}} + + - match: {hits.hits.1._id: "3"} + - close_to: {hits.hits.1._score: {value: 0.5, error: 0.01}} + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "doc['byte_vector'].vectorValues.next()[0];" + + - match: {hits.total: 2} + + - match: {hits.hits.0._id: "1"} + - close_to: {hits.hits.0._score: {value: 8, error: 0.01}} + + - match: {hits.hits.1._id: "3"} + - close_to: {hits.hits.1._score: {value: 2, error: 0.01}} + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "doc['bit_vector'].vectorValues.next()[0];" + + - match: {hits.total: 2} + + - match: {hits.hits.0._id: "1"} + - close_to: {hits.hits.0._score: {value: 8, error: 0.01}} + + - match: {hits.hits.1._id: "3"} + - close_to: {hits.hits.1._score: {value: 2, error: 0.01}} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValues.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValues.java new file mode 100644 index 0000000000000..a91960832239f --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValues.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.mapper.vectors; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.index.fielddata.ScriptDocValues; +import org.elasticsearch.script.field.vectors.MultiDenseVector; + +import java.util.Iterator; + +public class MultiDenseVectorScriptDocValues extends ScriptDocValues { + + public static final String MISSING_VECTOR_FIELD_MESSAGE = "A document doesn't have a value for a multi-vector field!"; + + private final int dims; + protected final MultiDenseVectorSupplier dvSupplier; + + public MultiDenseVectorScriptDocValues(MultiDenseVectorSupplier supplier, int dims) { + super(supplier); + this.dvSupplier = supplier; + this.dims = dims; + } + + public int dims() { + return dims; + } + + private MultiDenseVector getCheckedVector() { + MultiDenseVector vector = dvSupplier.getInternal(); + if (vector == null) { + throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); + } + return vector; + } + + /** + * Get multi-dense vector's value as an array of floats + */ + public Iterator getVectorValues() { + return getCheckedVector().getVectors(); + } + + /** + * Get dense vector's magnitude + */ + public float[] getMagnitudes() { + return getCheckedVector().getMagnitudes(); + } + + @Override + public BytesRef get(int index) { + throw new UnsupportedOperationException( + "accessing a multi-vector field's value through 'get' or 'value' is not supported, use 'vectorValues' or 'magnitudes' instead." + ); + } + + @Override + public int size() { + MultiDenseVector mdv = dvSupplier.getInternal(); + if (mdv != null) { + return mdv.size(); + } + return 0; + } + + public interface MultiDenseVectorSupplier extends Supplier { + @Override + default BytesRef getInternal(int index) { + throw new UnsupportedOperationException(); + } + + MultiDenseVector getInternal(); + } +} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiVectorDVLeafFieldData.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiVectorDVLeafFieldData.java index cc6fb38274451..b9716d315f33a 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiVectorDVLeafFieldData.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiVectorDVLeafFieldData.java @@ -9,37 +9,44 @@ package org.elasticsearch.index.mapper.vectors; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.DocValues; import org.apache.lucene.index.LeafReader; -import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.fielddata.LeafFieldData; import org.elasticsearch.index.fielddata.SortedBinaryDocValues; import org.elasticsearch.script.field.DocValuesScriptFieldFactory; +import org.elasticsearch.script.field.vectors.BitMultiDenseVectorDocValuesField; +import org.elasticsearch.script.field.vectors.ByteMultiDenseVectorDocValuesField; +import org.elasticsearch.script.field.vectors.FloatMultiDenseVectorDocValuesField; + +import java.io.IOException; final class MultiVectorDVLeafFieldData implements LeafFieldData { private final LeafReader reader; private final String field; - private final IndexVersion indexVersion; private final DenseVectorFieldMapper.ElementType elementType; private final int dims; - MultiVectorDVLeafFieldData( - LeafReader reader, - String field, - IndexVersion indexVersion, - DenseVectorFieldMapper.ElementType elementType, - int dims - ) { + MultiVectorDVLeafFieldData(LeafReader reader, String field, DenseVectorFieldMapper.ElementType elementType, int dims) { this.reader = reader; this.field = field; - this.indexVersion = indexVersion; this.elementType = elementType; this.dims = dims; } @Override public DocValuesScriptFieldFactory getScriptFieldFactory(String name) { - // TODO - return null; + try { + BinaryDocValues values = DocValues.getBinary(reader, field); + BinaryDocValues magnitudeValues = DocValues.getBinary(reader, field + MultiDenseVectorFieldMapper.VECTOR_MAGNITUDES_SUFFIX); + return switch (elementType) { + case BYTE -> new ByteMultiDenseVectorDocValuesField(values, magnitudeValues, name, elementType, dims); + case FLOAT -> new FloatMultiDenseVectorDocValuesField(values, magnitudeValues, name, elementType, dims); + case BIT -> new BitMultiDenseVectorDocValuesField(values, magnitudeValues, name, elementType, dims); + }; + } catch (IOException e) { + throw new IllegalStateException("Cannot load doc values for multi-vector field!", e); + } } @Override diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiVectorIndexFieldData.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiVectorIndexFieldData.java index 65ef492ce052b..44a666e25a611 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiVectorIndexFieldData.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/MultiVectorIndexFieldData.java @@ -55,7 +55,7 @@ public ValuesSourceType getValuesSourceType() { @Override public MultiVectorDVLeafFieldData load(LeafReaderContext context) { - return new MultiVectorDVLeafFieldData(context.reader(), fieldName, indexVersion, elementType, dims); + return new MultiVectorDVLeafFieldData(context.reader(), fieldName, elementType, dims); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java index 9d09a7493d605..3db2d164846bd 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java @@ -84,4 +84,24 @@ public static void decodeDenseVector(IndexVersion indexVersion, BytesRef vectorB } } + public static float[] getMultiMagnitudes(BytesRef magnitudes) { + assert magnitudes.length % Float.BYTES == 0; + float[] multiMagnitudes = new float[magnitudes.length / Float.BYTES]; + ByteBuffer byteBuffer = ByteBuffer.wrap(magnitudes.bytes, magnitudes.offset, magnitudes.length).order(ByteOrder.LITTLE_ENDIAN); + for (int i = 0; i < magnitudes.length / Float.BYTES; i++) { + multiMagnitudes[i] = byteBuffer.getFloat(); + } + return multiMagnitudes; + } + + public static void decodeMultiDenseVector(BytesRef vectorBR, int numVectors, float[][] multiVectorValue) { + if (vectorBR == null) { + throw new IllegalArgumentException(MultiDenseVectorScriptDocValues.MISSING_VECTOR_FIELD_MESSAGE); + } + FloatBuffer fb = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer(); + for (int i = 0; i < numVectors; i++) { + fb.get(multiVectorValue[i]); + } + } + } diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java index 3bc1c467323a3..7b57481ad5716 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java @@ -38,6 +38,8 @@ private SearchCapabilities() {} private static final String MULTI_DENSE_VECTOR_FIELD_MAPPER = "multi_dense_vector_field_mapper"; /** Support propagating nested retrievers' inner_hits to top-level compound retrievers . */ private static final String NESTED_RETRIEVER_INNER_HITS_SUPPORT = "nested_retriever_inner_hits_support"; + /** Support multi-dense-vector script field access. */ + private static final String MULTI_DENSE_VECTOR_SCRIPT_ACCESS = "multi_dense_vector_script_access"; public static final Set CAPABILITIES; static { @@ -50,6 +52,7 @@ private SearchCapabilities() {} capabilities.add(NESTED_RETRIEVER_INNER_HITS_SUPPORT); if (MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled()) { capabilities.add(MULTI_DENSE_VECTOR_FIELD_MAPPER); + capabilities.add(MULTI_DENSE_VECTOR_SCRIPT_ACCESS); } if (Build.current().isSnapshot()) { capabilities.add(KQL_QUERY_SUPPORTED); diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVector.java new file mode 100644 index 0000000000000..24e19a803ff38 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVector.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.script.field.vectors; + +import org.apache.lucene.util.BytesRef; + +import java.util.Iterator; + +public class BitMultiDenseVector extends ByteMultiDenseVector { + public BitMultiDenseVector(Iterator vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) { + super(vectorValues, magnitudesBytes, numVecs, dims); + } + + @Override + public void checkDimensions(int qvDims) { + if (qvDims != dims) { + throw new IllegalArgumentException( + "The query vector has a different number of dimensions [" + + qvDims * Byte.SIZE + + "] than the document vectors [" + + dims * Byte.SIZE + + "]." + ); + } + } + + @Override + public int getDims() { + return dims * Byte.SIZE; + } +} diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVectorDocValuesField.java new file mode 100644 index 0000000000000..35a43eabb8f0c --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVectorDocValuesField.java @@ -0,0 +1,31 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.script.field.vectors; + +import org.apache.lucene.index.BinaryDocValues; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; + +public class BitMultiDenseVectorDocValuesField extends ByteMultiDenseVectorDocValuesField { + + public BitMultiDenseVectorDocValuesField( + BinaryDocValues input, + BinaryDocValues magnitudes, + String name, + ElementType elementType, + int dims + ) { + super(input, magnitudes, name, elementType, dims / 8); + } + + @Override + protected MultiDenseVector getVector() { + return new BitMultiDenseVector(vectorValue, magnitudesValue, numVecs, dims); + } +} diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVector.java new file mode 100644 index 0000000000000..e610d10146b2f --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVector.java @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.script.field.vectors; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder; + +import java.util.Iterator; + +public class ByteMultiDenseVector implements MultiDenseVector { + + protected final Iterator vectorValues; + protected final int numVecs; + protected final int dims; + + private Iterator floatDocVectors; + private float[] magnitudes; + private final BytesRef magnitudesBytes; + + public ByteMultiDenseVector(Iterator vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) { + assert magnitudesBytes.length == numVecs * Float.BYTES; + this.vectorValues = vectorValues; + this.numVecs = numVecs; + this.dims = dims; + this.magnitudesBytes = magnitudesBytes; + } + + @Override + public Iterator getVectors() { + if (floatDocVectors == null) { + floatDocVectors = new ByteToFloatIteratorWrapper(vectorValues, dims); + } + return floatDocVectors; + } + + @Override + public float[] getMagnitudes() { + if (magnitudes == null) { + magnitudes = VectorEncoderDecoder.getMultiMagnitudes(magnitudesBytes); + } + return magnitudes; + } + + @Override + public boolean isEmpty() { + return false; + } + + @Override + public int getDims() { + return dims; + } + + @Override + public int size() { + return numVecs; + } + + static class ByteToFloatIteratorWrapper implements Iterator { + private final Iterator byteIterator; + private final float[] buffer; + private final int dims; + + ByteToFloatIteratorWrapper(Iterator byteIterator, int dims) { + this.byteIterator = byteIterator; + this.buffer = new float[dims]; + this.dims = dims; + } + + @Override + public boolean hasNext() { + return byteIterator.hasNext(); + } + + @Override + public float[] next() { + byte[] next = byteIterator.next(); + for (int i = 0; i < dims; i++) { + buffer[i] = next[i]; + } + return buffer; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVectorDocValuesField.java new file mode 100644 index 0000000000000..d1e062e0a3dee --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVectorDocValuesField.java @@ -0,0 +1,142 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.script.field.vectors; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; +import org.elasticsearch.index.mapper.vectors.MultiDenseVectorScriptDocValues; + +import java.io.IOException; +import java.util.Iterator; + +public class ByteMultiDenseVectorDocValuesField extends MultiDenseVectorDocValuesField { + + protected final BinaryDocValues input; + private final BinaryDocValues magnitudes; + protected final int dims; + protected int numVecs; + protected Iterator vectorValue; + protected boolean decoded; + protected BytesRef value; + protected BytesRef magnitudesValue; + private byte[] buffer; + + public ByteMultiDenseVectorDocValuesField( + BinaryDocValues input, + BinaryDocValues magnitudes, + String name, + ElementType elementType, + int dims + ) { + super(name, elementType); + this.input = input; + this.dims = dims; + this.buffer = new byte[dims]; + this.magnitudes = magnitudes; + } + + @Override + public void setNextDocId(int docId) throws IOException { + decoded = false; + if (input.advanceExact(docId)) { + boolean magnitudesFound = magnitudes.advanceExact(docId); + assert magnitudesFound; + value = input.binaryValue(); + assert value.length % dims == 0; + numVecs = value.length / dims; + magnitudesValue = magnitudes.binaryValue(); + assert magnitudesValue.length == (numVecs * Float.BYTES); + } else { + value = null; + magnitudesValue = null; + vectorValue = null; + numVecs = 0; + } + } + + @Override + public MultiDenseVectorScriptDocValues toScriptDocValues() { + return new MultiDenseVectorScriptDocValues(this, dims); + } + + protected MultiDenseVector getVector() { + return new ByteMultiDenseVector(vectorValue, magnitudesValue, numVecs, dims); + } + + @Override + public MultiDenseVector get() { + if (isEmpty()) { + return MultiDenseVector.EMPTY; + } + decodeVectorIfNecessary(); + return getVector(); + } + + @Override + public MultiDenseVector get(MultiDenseVector defaultValue) { + if (isEmpty()) { + return defaultValue; + } + decodeVectorIfNecessary(); + return getVector(); + } + + @Override + public MultiDenseVector getInternal() { + return get(null); + } + + private void decodeVectorIfNecessary() { + if (decoded == false && value != null) { + vectorValue = new ByteVectorIterator(value, buffer, numVecs); + decoded = true; + } + } + + @Override + public int size() { + return value == null ? 0 : value.length / dims; + } + + @Override + public boolean isEmpty() { + return value == null; + } + + static class ByteVectorIterator implements Iterator { + private final byte[] buffer; + private final BytesRef vectorValues; + private final int size; + private int idx = 0; + + ByteVectorIterator(BytesRef vectorValues, byte[] buffer, int size) { + assert vectorValues.length == (buffer.length * size); + this.vectorValues = vectorValues; + this.size = size; + this.buffer = buffer; + } + + @Override + public boolean hasNext() { + return idx < size; + } + + @Override + public byte[] next() { + if (hasNext() == false) { + throw new IllegalArgumentException("No more elements in the iterator"); + } + System.arraycopy(vectorValues.bytes, vectorValues.offset + idx * buffer.length, buffer, 0, buffer.length); + idx++; + return buffer; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVector.java new file mode 100644 index 0000000000000..9ffe8b3b970c4 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVector.java @@ -0,0 +1,61 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.script.field.vectors; + +import org.apache.lucene.util.BytesRef; + +import java.util.Iterator; + +import static org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder.getMultiMagnitudes; + +public class FloatMultiDenseVector implements MultiDenseVector { + + private final BytesRef magnitudes; + private float[] magnitudesArray = null; + private final int dims; + private final int numVectors; + private final Iterator decodedDocVector; + + public FloatMultiDenseVector(Iterator decodedDocVector, BytesRef magnitudes, int numVectors, int dims) { + assert magnitudes.length == numVectors * Float.BYTES; + this.decodedDocVector = decodedDocVector; + this.magnitudes = magnitudes; + this.numVectors = numVectors; + this.dims = dims; + } + + @Override + public Iterator getVectors() { + return decodedDocVector; + } + + @Override + public float[] getMagnitudes() { + if (magnitudesArray == null) { + magnitudesArray = getMultiMagnitudes(magnitudes); + } + return magnitudesArray; + } + + @Override + public boolean isEmpty() { + return false; + } + + @Override + public int getDims() { + return dims; + } + + @Override + public int size() { + return numVectors; + } +} diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVectorDocValuesField.java new file mode 100644 index 0000000000000..356db58d989c5 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVectorDocValuesField.java @@ -0,0 +1,143 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.script.field.vectors; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; +import org.elasticsearch.index.mapper.vectors.MultiDenseVectorScriptDocValues; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; +import java.util.Iterator; + +public class FloatMultiDenseVectorDocValuesField extends MultiDenseVectorDocValuesField { + + private final BinaryDocValues input; + private final BinaryDocValues magnitudes; + private boolean decoded; + private final int dims; + private BytesRef value; + private BytesRef magnitudesValue; + private FloatVectorIterator vectorValues; + private int numVectors; + private float[] buffer; + + public FloatMultiDenseVectorDocValuesField( + BinaryDocValues input, + BinaryDocValues magnitudes, + String name, + ElementType elementType, + int dims + ) { + super(name, elementType); + this.input = input; + this.magnitudes = magnitudes; + this.dims = dims; + this.buffer = new float[dims]; + } + + @Override + public void setNextDocId(int docId) throws IOException { + decoded = false; + if (input.advanceExact(docId)) { + boolean magnitudesFound = magnitudes.advanceExact(docId); + assert magnitudesFound; + + value = input.binaryValue(); + assert value.length % (Float.BYTES * dims) == 0; + numVectors = value.length / (Float.BYTES * dims); + magnitudesValue = magnitudes.binaryValue(); + assert magnitudesValue.length == (Float.BYTES * numVectors); + } else { + value = null; + magnitudesValue = null; + numVectors = 0; + } + } + + @Override + public MultiDenseVectorScriptDocValues toScriptDocValues() { + return new MultiDenseVectorScriptDocValues(this, dims); + } + + @Override + public boolean isEmpty() { + return value == null; + } + + @Override + public MultiDenseVector get() { + if (isEmpty()) { + return MultiDenseVector.EMPTY; + } + decodeVectorIfNecessary(); + return new FloatMultiDenseVector(vectorValues, magnitudesValue, numVectors, dims); + } + + @Override + public MultiDenseVector get(MultiDenseVector defaultValue) { + if (isEmpty()) { + return defaultValue; + } + decodeVectorIfNecessary(); + return new FloatMultiDenseVector(vectorValues, magnitudesValue, numVectors, dims); + } + + @Override + public MultiDenseVector getInternal() { + return get(null); + } + + @Override + public int size() { + return value == null ? 0 : value.length / (Float.BYTES * dims); + } + + private void decodeVectorIfNecessary() { + if (decoded == false && value != null) { + vectorValues = new FloatVectorIterator(value, buffer, numVectors); + decoded = true; + } + } + + static class FloatVectorIterator implements Iterator { + private final float[] buffer; + private final FloatBuffer vectorValues; + private final int size; + private int idx = 0; + + FloatVectorIterator(BytesRef vectorValues, float[] buffer, int size) { + assert vectorValues.length == (buffer.length * Float.BYTES * size); + this.vectorValues = ByteBuffer.wrap(vectorValues.bytes, vectorValues.offset, vectorValues.length) + .order(ByteOrder.LITTLE_ENDIAN) + .asFloatBuffer(); + this.size = size; + this.buffer = buffer; + } + + @Override + public boolean hasNext() { + return idx < size; + } + + @Override + public float[] next() { + if (hasNext() == false) { + throw new IllegalArgumentException("No more elements in the iterator"); + } + vectorValues.get(buffer); + idx++; + return buffer; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVector.java new file mode 100644 index 0000000000000..85c851dbe545c --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVector.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.script.field.vectors; + +import java.util.Iterator; + +public interface MultiDenseVector { + + default void checkDimensions(int qvDims) { + checkDimensions(getDims(), qvDims); + } + + Iterator getVectors(); + + float[] getMagnitudes(); + + boolean isEmpty(); + + int getDims(); + + int size(); + + static void checkDimensions(int dvDims, int qvDims) { + if (dvDims != qvDims) { + throw new IllegalArgumentException( + "The query vector has a different number of dimensions [" + qvDims + "] than the document vectors [" + dvDims + "]." + ); + } + } + + private static String badQueryVectorType(Object queryVector) { + return "Cannot use vector [" + queryVector + "] with class [" + queryVector.getClass().getName() + "] as query vector"; + } + + MultiDenseVector EMPTY = new MultiDenseVector() { + public static final String MISSING_VECTOR_FIELD_MESSAGE = "Multi Dense vector value missing for a field," + + " use isEmpty() to check for a missing vector value"; + + @Override + public Iterator getVectors() { + throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); + } + + @Override + public float[] getMagnitudes() { + throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); + } + + @Override + public boolean isEmpty() { + return true; + } + + @Override + public int getDims() { + throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); + } + + @Override + public int size() { + return 0; + } + }; +} diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVectorDocValuesField.java new file mode 100644 index 0000000000000..61ae4304683c8 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVectorDocValuesField.java @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.script.field.vectors; + +import org.elasticsearch.index.mapper.vectors.MultiDenseVectorScriptDocValues; +import org.elasticsearch.script.field.AbstractScriptFieldFactory; +import org.elasticsearch.script.field.DocValuesScriptFieldFactory; +import org.elasticsearch.script.field.Field; + +import java.util.Iterator; + +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; + +public abstract class MultiDenseVectorDocValuesField extends AbstractScriptFieldFactory + implements + Field, + DocValuesScriptFieldFactory, + MultiDenseVectorScriptDocValues.MultiDenseVectorSupplier { + protected final String name; + protected final ElementType elementType; + + public MultiDenseVectorDocValuesField(String name, ElementType elementType) { + this.name = name; + this.elementType = elementType; + } + + @Override + public String getName() { + return name; + } + + public ElementType getElementType() { + return elementType; + } + + /** + * Get the DenseVector for a document if one exists, DenseVector.EMPTY otherwise + */ + public abstract MultiDenseVector get(); + + public abstract MultiDenseVector get(MultiDenseVector defaultValue); + + public abstract MultiDenseVectorScriptDocValues toScriptDocValues(); + + // DenseVector fields are single valued, so Iterable does not make sense. + @Override + public Iterator iterator() { + throw new UnsupportedOperationException("Cannot iterate over single valued multi_dense_vector field, use get() instead"); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValuesTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValuesTests.java new file mode 100644 index 0000000000000..ef316c5addefa --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValuesTests.java @@ -0,0 +1,374 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.mapper.vectors; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; +import org.elasticsearch.script.field.vectors.ByteMultiDenseVectorDocValuesField; +import org.elasticsearch.script.field.vectors.FloatMultiDenseVectorDocValuesField; +import org.elasticsearch.script.field.vectors.MultiDenseVector; +import org.elasticsearch.script.field.vectors.MultiDenseVectorDocValuesField; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.index.IndexVersionUtils; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Iterator; +import java.util.List; + +import static org.hamcrest.Matchers.containsString; + +public class MultiDenseVectorScriptDocValuesTests extends ESTestCase { + + public void testFloatGetVectorValueAndGetMagnitude() throws IOException { + int dims = 3; + float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; + float[][] expectedMagnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } }; + + for (IndexVersion indexVersion : List.of(IndexVersionUtils.randomCompatibleVersion(random()), IndexVersion.current())) { + BinaryDocValues docValues = wrap(vectors, ElementType.FLOAT, indexVersion); + BinaryDocValues magnitudeValues = wrap(expectedMagnitudes); + MultiDenseVectorDocValuesField field = new FloatMultiDenseVectorDocValuesField( + docValues, + magnitudeValues, + "test", + ElementType.FLOAT, + dims + ); + MultiDenseVectorScriptDocValues scriptDocValues = field.toScriptDocValues(); + for (int i = 0; i < vectors.length; i++) { + field.setNextDocId(i); + assertEquals(vectors[i].length, field.size()); + assertEquals(dims, scriptDocValues.dims()); + Iterator iterator = scriptDocValues.getVectorValues(); + float[] magnitudes = scriptDocValues.getMagnitudes(); + assertEquals(expectedMagnitudes[i].length, magnitudes.length); + for (int j = 0; j < vectors[i].length; j++) { + assertTrue(iterator.hasNext()); + assertArrayEquals(vectors[i][j], iterator.next(), 0.0001f); + assertEquals(expectedMagnitudes[i][j], magnitudes[j], 0.0001f); + } + } + } + } + + public void testByteGetVectorValueAndGetMagnitude() throws IOException { + int dims = 3; + float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; + float[][] expectedMagnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } }; + + BinaryDocValues docValues = wrap(vectors, ElementType.BYTE, IndexVersion.current()); + BinaryDocValues magnitudeValues = wrap(expectedMagnitudes); + MultiDenseVectorDocValuesField field = new ByteMultiDenseVectorDocValuesField( + docValues, + magnitudeValues, + "test", + ElementType.BYTE, + dims + ); + MultiDenseVectorScriptDocValues scriptDocValues = field.toScriptDocValues(); + for (int i = 0; i < vectors.length; i++) { + field.setNextDocId(i); + assertEquals(vectors[i].length, field.size()); + assertEquals(dims, scriptDocValues.dims()); + Iterator iterator = scriptDocValues.getVectorValues(); + float[] magnitudes = scriptDocValues.getMagnitudes(); + assertEquals(expectedMagnitudes[i].length, magnitudes.length); + for (int j = 0; j < vectors[i].length; j++) { + assertTrue(iterator.hasNext()); + assertArrayEquals(vectors[i][j], iterator.next(), 0.0001f); + assertEquals(expectedMagnitudes[i][j], magnitudes[j], 0.0001f); + } + } + } + + public void testFloatMetadataAndIterator() throws IOException { + int dims = 3; + IndexVersion indexVersion = IndexVersion.current(); + float[][][] vectors = new float[][][] { fill(new float[3][dims], ElementType.FLOAT), fill(new float[2][dims], ElementType.FLOAT) }; + float[][] magnitudes = new float[][] { new float[3], new float[2] }; + BinaryDocValues docValues = wrap(vectors, ElementType.FLOAT, indexVersion); + BinaryDocValues magnitudeValues = wrap(magnitudes); + + MultiDenseVectorDocValuesField field = new FloatMultiDenseVectorDocValuesField( + docValues, + magnitudeValues, + "test", + ElementType.FLOAT, + dims + ); + for (int i = 0; i < vectors.length; i++) { + field.setNextDocId(i); + MultiDenseVector dv = field.get(); + assertEquals(vectors[i].length, dv.size()); + assertFalse(dv.isEmpty()); + assertEquals(dims, dv.getDims()); + UnsupportedOperationException e = expectThrows(UnsupportedOperationException.class, field::iterator); + assertEquals("Cannot iterate over single valued multi_dense_vector field, use get() instead", e.getMessage()); + } + field.setNextDocId(vectors.length); + MultiDenseVector dv = field.get(); + assertEquals(dv, MultiDenseVector.EMPTY); + } + + public void testByteMetadataAndIterator() throws IOException { + int dims = 3; + IndexVersion indexVersion = IndexVersion.current(); + float[][][] vectors = new float[][][] { fill(new float[3][dims], ElementType.BYTE), fill(new float[2][dims], ElementType.BYTE) }; + float[][] magnitudes = new float[][] { new float[3], new float[2] }; + BinaryDocValues docValues = wrap(vectors, ElementType.BYTE, indexVersion); + BinaryDocValues magnitudeValues = wrap(magnitudes); + MultiDenseVectorDocValuesField field = new ByteMultiDenseVectorDocValuesField( + docValues, + magnitudeValues, + "test", + ElementType.BYTE, + dims + ); + for (int i = 0; i < vectors.length; i++) { + field.setNextDocId(i); + MultiDenseVector dv = field.get(); + assertEquals(vectors[i].length, dv.size()); + assertFalse(dv.isEmpty()); + assertEquals(dims, dv.getDims()); + UnsupportedOperationException e = expectThrows(UnsupportedOperationException.class, field::iterator); + assertEquals("Cannot iterate over single valued multi_dense_vector field, use get() instead", e.getMessage()); + } + field.setNextDocId(vectors.length); + MultiDenseVector dv = field.get(); + assertEquals(dv, MultiDenseVector.EMPTY); + } + + protected float[][] fill(float[][] vectors, ElementType elementType) { + for (float[] vector : vectors) { + for (int i = 0; i < vector.length; i++) { + vector[i] = elementType == ElementType.FLOAT ? randomFloat() : randomByte(); + } + } + return vectors; + } + + public void testFloatMissingValues() throws IOException { + int dims = 3; + float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; + float[][] magnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } }; + BinaryDocValues docValues = wrap(vectors, ElementType.FLOAT, IndexVersion.current()); + BinaryDocValues magnitudeValues = wrap(magnitudes); + MultiDenseVectorDocValuesField field = new FloatMultiDenseVectorDocValuesField( + docValues, + magnitudeValues, + "test", + ElementType.FLOAT, + dims + ); + MultiDenseVectorScriptDocValues scriptDocValues = field.toScriptDocValues(); + + field.setNextDocId(3); + assertEquals(0, field.size()); + Exception e = expectThrows(IllegalArgumentException.class, scriptDocValues::getVectorValues); + assertEquals("A document doesn't have a value for a multi-vector field!", e.getMessage()); + + e = expectThrows(IllegalArgumentException.class, scriptDocValues::getMagnitudes); + assertEquals("A document doesn't have a value for a multi-vector field!", e.getMessage()); + } + + public void testByteMissingValues() throws IOException { + int dims = 3; + float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; + float[][] magnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } }; + BinaryDocValues docValues = wrap(vectors, ElementType.BYTE, IndexVersion.current()); + BinaryDocValues magnitudeValues = wrap(magnitudes); + MultiDenseVectorDocValuesField field = new ByteMultiDenseVectorDocValuesField( + docValues, + magnitudeValues, + "test", + ElementType.BYTE, + dims + ); + MultiDenseVectorScriptDocValues scriptDocValues = field.toScriptDocValues(); + + field.setNextDocId(3); + assertEquals(0, field.size()); + Exception e = expectThrows(IllegalArgumentException.class, scriptDocValues::getVectorValues); + assertEquals("A document doesn't have a value for a multi-vector field!", e.getMessage()); + + e = expectThrows(IllegalArgumentException.class, scriptDocValues::getMagnitudes); + assertEquals("A document doesn't have a value for a multi-vector field!", e.getMessage()); + } + + public void testFloatGetFunctionIsNotAccessible() throws IOException { + int dims = 3; + float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; + float[][] magnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } }; + BinaryDocValues docValues = wrap(vectors, ElementType.FLOAT, IndexVersion.current()); + BinaryDocValues magnitudeValues = wrap(magnitudes); + MultiDenseVectorDocValuesField field = new FloatMultiDenseVectorDocValuesField( + docValues, + magnitudeValues, + "test", + ElementType.FLOAT, + dims + ); + MultiDenseVectorScriptDocValues scriptDocValues = field.toScriptDocValues(); + + field.setNextDocId(0); + Exception e = expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0)); + assertThat( + e.getMessage(), + containsString( + "accessing a multi-vector field's value through 'get' or 'value' is not supported," + + " use 'vectorValues' or 'magnitudes' instead." + ) + ); + } + + public void testByteGetFunctionIsNotAccessible() throws IOException { + int dims = 3; + float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; + float[][] magnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } }; + BinaryDocValues docValues = wrap(vectors, ElementType.BYTE, IndexVersion.current()); + BinaryDocValues magnitudeValues = wrap(magnitudes); + MultiDenseVectorDocValuesField field = new ByteMultiDenseVectorDocValuesField( + docValues, + magnitudeValues, + "test", + ElementType.BYTE, + dims + ); + MultiDenseVectorScriptDocValues scriptDocValues = field.toScriptDocValues(); + + field.setNextDocId(0); + Exception e = expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0)); + assertThat( + e.getMessage(), + containsString( + "accessing a multi-vector field's value through 'get' or 'value' is not supported," + + " use 'vectorValues' or 'magnitudes' instead." + ) + ); + } + + public static BinaryDocValues wrap(float[][] magnitudes) { + return new BinaryDocValues() { + int idx = -1; + int maxIdx = magnitudes.length; + + @Override + public BytesRef binaryValue() { + if (idx >= maxIdx) { + throw new IllegalStateException("max index exceeded"); + } + ByteBuffer magnitudeBuffer = ByteBuffer.allocate(magnitudes[idx].length * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (float magnitude : magnitudes[idx]) { + magnitudeBuffer.putFloat(magnitude); + } + return new BytesRef(magnitudeBuffer.array()); + } + + @Override + public boolean advanceExact(int target) { + idx = target; + if (target < maxIdx) { + return true; + } + return false; + } + + @Override + public int docID() { + return idx; + } + + @Override + public int nextDoc() { + return idx++; + } + + @Override + public int advance(int target) { + throw new IllegalArgumentException("not defined!"); + } + + @Override + public long cost() { + throw new IllegalArgumentException("not defined!"); + } + }; + } + + public static BinaryDocValues wrap(float[][][] vectors, ElementType elementType, IndexVersion indexVersion) { + return new BinaryDocValues() { + int idx = -1; + int maxIdx = vectors.length; + + @Override + public BytesRef binaryValue() { + if (idx >= maxIdx) { + throw new IllegalStateException("max index exceeded"); + } + return mockEncodeDenseVector(vectors[idx], elementType, indexVersion); + } + + @Override + public boolean advanceExact(int target) { + idx = target; + if (target < maxIdx) { + return true; + } + return false; + } + + @Override + public int docID() { + return idx; + } + + @Override + public int nextDoc() { + return idx++; + } + + @Override + public int advance(int target) { + throw new IllegalArgumentException("not defined!"); + } + + @Override + public long cost() { + throw new IllegalArgumentException("not defined!"); + } + }; + } + + public static BytesRef mockEncodeDenseVector(float[][] values, ElementType elementType, IndexVersion indexVersion) { + int dims = values[0].length; + if (elementType == ElementType.BIT) { + dims *= Byte.SIZE; + } + int numBytes = elementType.getNumBytes(dims); + ByteBuffer byteBuffer = elementType.createByteBuffer(indexVersion, numBytes * values.length); + for (float[] vector : values) { + for (float value : vector) { + if (elementType == ElementType.FLOAT) { + byteBuffer.putFloat(value); + } else if (elementType == ElementType.BYTE || elementType == ElementType.BIT) { + byteBuffer.put((byte) value); + } else { + throw new IllegalStateException("unknown element_type [" + elementType + "]"); + } + } + } + return new BytesRef(byteBuffer.array()); + } + +}