Skip to content

Commit

Permalink
Add multi_dense_vector value access to scripts (#116610)
Browse files Browse the repository at this point in the history
This adds value access to multi_dense_vector values in scripts. The
users will get:

 - Count of vectors per field
 - Magnitudes of all the individual vectors
 - Access to each vector with an iterator

I will happily take design critiques around how these are exposed in
scripting.

I initially thought of just providing direct `float[][]` access, but
this seems to have some unfavorable behavior around creating a TON of
garbage. The reason is that each field could have a different number of
vectors, so allocating a new collection of `float[dim]` for every field
seemed rough. 

Generally, when scripting or using the vectors, an iterator should be
enough and I have the iterator backed by a simple buffer to keep garbage
down.
  • Loading branch information
benwtrent authored Nov 14, 2024
1 parent 2d2ad00 commit 55be3ac
Show file tree
Hide file tree
Showing 17 changed files with 1,330 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,21 @@ class org.elasticsearch.script.field.SeqNoDocValuesField @dynamic_type {
class org.elasticsearch.script.field.VersionDocValuesField @dynamic_type {
}

# Script-visible members of MultiDenseVector: the magnitudes of the individual
# vectors, an iterator over the vectors, and emptiness/dimension/count accessors.
class org.elasticsearch.script.field.vectors.MultiDenseVector {
MultiDenseVector EMPTY
float[] getMagnitudes()

Iterator getVectors()
boolean isEmpty()
int getDims()
int size()
}

# Script-visible accessors that return a MultiDenseVector from a doc-values field.
# NOTE(review): the one-argument get presumably takes a default value to return
# when the field has no value -- confirm against the Java class.
class org.elasticsearch.script.field.vectors.MultiDenseVectorDocValuesField {
MultiDenseVector get()
MultiDenseVector get(MultiDenseVector)
}

class org.elasticsearch.script.field.vectors.DenseVector {
DenseVector EMPTY
float getMagnitude()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ class org.elasticsearch.index.mapper.vectors.DenseVectorScriptDocValues {
float getMagnitude()
}

# Exposes the two getters scripts reach as doc['field'].vectorValues and
# doc['field'].magnitudes.
class org.elasticsearch.index.mapper.vectors.MultiDenseVectorScriptDocValues {
Iterator getVectorValues()
float[] getMagnitudes()
}

class org.apache.lucene.util.BytesRef {
byte[] bytes
int offset
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# Shared fixture: requires the multi_dense_vector_script_access search capability,
# creates a single-shard index with three multi_dense_vector fields, and indexes
# two documents (ids "1" and "3").
setup:
- requires:
capabilities:
- method: POST
path: /_search
capabilities: [ multi_dense_vector_script_access ]
test_runner_features: capabilities
reason: "Support for multi dense vector field script access capability required"
- skip:
features: headers

# One field with no explicit element_type, one byte field, one bit field.
# The bit field declares dims: 40 and is indexed below as 5 byte values per
# vector (5 * 8 = 40 bits).
- do:
indices.create:
index: test-index
body:
settings:
number_of_shards: 1
mappings:
properties:
vector:
type: multi_dense_vector
dims: 5
byte_vector:
type: multi_dense_vector
dims: 5
element_type: byte
bit_vector:
type: multi_dense_vector
dims: 40
element_type: bit
# Document "1": two vectors per field.
- do:
index:
index: test-index
id: "1"
body:
vector: [[230.0, 300.33, -34.8988, 15.555, -200.0], [-0.5, 100.0, -13, 14.8, -156.0]]
byte_vector: [[8, 5, -15, 1, -7], [-1, 115, -3, 4, -128]]
bit_vector: [[8, 5, -15, 1, -7], [-1, 115, -3, 4, -128]]

# Document "3": a single vector per field.
- do:
index:
index: test-index
id: "3"
body:
vector: [[0.5, 111.3, -13.0, 14.8, -156.0]]
byte_vector: [[2, 18, -5, 0, -124]]
bit_vector: [[2, 18, -5, 0, -124]]

- do:
indices.refresh: {}
---
# Scores every document with the magnitude of its FIRST vector (magnitudes[0])
# and checks the score of each hit for each of the three fields.
"Test vector magnitude equality":
- skip:
features: close_to

# Float field: expected scores are the Euclidean norms of the first vectors,
# e.g. sqrt(230^2 + 300.33^2 + 34.8988^2 + 15.555^2 + 200^2) ~= 429.60 for doc "1".
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "doc['vector'].magnitudes[0]"

- match: {hits.total: 2}

- match: {hits.hits.0._id: "1"}
- close_to: {hits.hits.0._score: {value: 429.6021, error: 0.01}}

- match: {hits.hits.1._id: "3"}
- close_to: {hits.hits.1._score: {value: 192.6447, error: 0.01}}

# Byte field: doc "3" has the larger norm (sqrt(15729) ~= 125.42) and is therefore
# the first hit; doc "1" scores sqrt(364) ~= 19.08.
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "doc['byte_vector'].magnitudes[0]"

- match: {hits.total: 2}

- match: {hits.hits.0._id: "3"}
- close_to: {hits.hits.0._score: {value: 125.41531, error: 0.01}}

- match: {hits.hits.1._id: "1"}
- close_to: {hits.hits.1._score: {value: 19.07878, error: 0.01}}

# Bit field: expected scores equal the square root of the number of set bits in
# the first vector (15 set bits -> ~3.873 for doc "1", 12 set bits -> ~3.464 for doc "3").
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "doc['bit_vector'].magnitudes[0]"

- match: {hits.total: 2}

- match: {hits.hits.0._id: "1"}
- close_to: {hits.hits.0._score: {value: 3.872983, error: 0.01}}

- match: {hits.hits.1._id: "3"}
- close_to: {hits.hits.1._score: {value: 3.464101, error: 0.01}}
---
# Scores every document with the first component of the first vector returned by
# the vectorValues iterator, and checks that value for each of the three fields.
"Test vector value scoring":
- skip:
features: close_to

# Float field: first components are 230.0 (doc "1") and 0.5 (doc "3").
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "doc['vector'].vectorValues.next()[0];"

- match: {hits.total: 2}

- match: {hits.hits.0._id: "1"}
- close_to: {hits.hits.0._score: {value: 230, error: 0.01}}

- match: {hits.hits.1._id: "3"}
- close_to: {hits.hits.1._score: {value: 0.5, error: 0.01}}

# Byte field: first components are 8 (doc "1") and 2 (doc "3").
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "doc['byte_vector'].vectorValues.next()[0];"

- match: {hits.total: 2}

- match: {hits.hits.0._id: "1"}
- close_to: {hits.hits.0._score: {value: 8, error: 0.01}}

- match: {hits.hits.1._id: "3"}
- close_to: {hits.hits.1._score: {value: 2, error: 0.01}}

# Bit field: the expected values match the first indexed byte of each vector
# (8 and 2), i.e. the same numbers as the byte field above.
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "doc['bit_vector'].vectorValues.next()[0];"

- match: {hits.total: 2}

- match: {hits.hits.0._id: "1"}
- close_to: {hits.hits.0._score: {value: 8, error: 0.01}}

- match: {hits.hits.1._id: "3"}
- close_to: {hits.hits.1._score: {value: 2, error: 0.01}}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.mapper.vectors;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.index.fielddata.ScriptDocValues;
import org.elasticsearch.script.field.vectors.MultiDenseVector;

import java.util.Iterator;

/**
 * Script doc-values wrapper for a multi-vector field. Values are reached through
 * {@link #getVectorValues()} and {@link #getMagnitudes()}; the generic
 * {@link #get(int)} accessor is deliberately unsupported.
 */
public class MultiDenseVectorScriptDocValues extends ScriptDocValues<BytesRef> {

    public static final String MISSING_VECTOR_FIELD_MESSAGE = "A document doesn't have a value for a multi-vector field!";

    private final int dims;
    protected final MultiDenseVectorSupplier dvSupplier;

    /**
     * @param supplier source of the {@link MultiDenseVector} this instance reads from
     * @param dims     value reported by {@link #dims()}
     */
    public MultiDenseVectorScriptDocValues(MultiDenseVectorSupplier supplier, int dims) {
        super(supplier);
        this.dims = dims;
        this.dvSupplier = supplier;
    }

    /**
     * @return the {@code dims} value passed to the constructor
     */
    public int dims() {
        return this.dims;
    }

    /**
     * Fetches the vector from the supplier, rejecting a missing value.
     *
     * @throws IllegalArgumentException if the supplier returns {@code null}
     */
    private MultiDenseVector requireVector() {
        final MultiDenseVector current = dvSupplier.getInternal();
        if (current != null) {
            return current;
        }
        throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
    }

    /**
     * Returns an iterator over the individual vectors, each as a {@code float[]}.
     *
     * @throws IllegalArgumentException if the supplier has no vector
     */
    public Iterator<float[]> getVectorValues() {
        return requireVector().getVectors();
    }

    /**
     * Returns the magnitudes reported by the underlying multi-vector.
     *
     * @throws IllegalArgumentException if the supplier has no vector
     */
    public float[] getMagnitudes() {
        return requireVector().getMagnitudes();
    }

    /**
     * Always throws: indexed access is not supported for multi-vector fields.
     */
    @Override
    public BytesRef get(int index) {
        throw new UnsupportedOperationException(
            "accessing a multi-vector field's value through 'get' or 'value' is not supported, use 'vectorValues' or 'magnitudes' instead."
        );
    }

    /**
     * @return the size of the supplied multi-vector, or {@code 0} when the supplier returns {@code null}
     */
    @Override
    public int size() {
        final MultiDenseVector current = dvSupplier.getInternal();
        return current == null ? 0 : current.size();
    }

    /**
     * Supplier that hands out a whole {@link MultiDenseVector}; the index-based
     * accessor inherited from {@code Supplier} is unsupported by default.
     */
    public interface MultiDenseVectorSupplier extends Supplier<BytesRef> {
        @Override
        default BytesRef getInternal(int index) {
            throw new UnsupportedOperationException();
        }

        /** @return the multi-vector, or {@code null} (callers in this class null-check the result) */
        MultiDenseVector getInternal();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,37 +9,44 @@

package org.elasticsearch.index.mapper.vectors;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReader;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.fielddata.LeafFieldData;
import org.elasticsearch.index.fielddata.SortedBinaryDocValues;
import org.elasticsearch.script.field.DocValuesScriptFieldFactory;
import org.elasticsearch.script.field.vectors.BitMultiDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.ByteMultiDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.FloatMultiDenseVectorDocValuesField;

import java.io.IOException;

final class MultiVectorDVLeafFieldData implements LeafFieldData {
private final LeafReader reader;
private final String field;
private final IndexVersion indexVersion;
private final DenseVectorFieldMapper.ElementType elementType;
private final int dims;

MultiVectorDVLeafFieldData(
LeafReader reader,
String field,
IndexVersion indexVersion,
DenseVectorFieldMapper.ElementType elementType,
int dims
) {
MultiVectorDVLeafFieldData(LeafReader reader, String field, DenseVectorFieldMapper.ElementType elementType, int dims) {
this.reader = reader;
this.field = field;
this.indexVersion = indexVersion;
this.elementType = elementType;
this.dims = dims;
}

@Override
public DocValuesScriptFieldFactory getScriptFieldFactory(String name) {
// TODO
return null;
try {
BinaryDocValues values = DocValues.getBinary(reader, field);
BinaryDocValues magnitudeValues = DocValues.getBinary(reader, field + MultiDenseVectorFieldMapper.VECTOR_MAGNITUDES_SUFFIX);
return switch (elementType) {
case BYTE -> new ByteMultiDenseVectorDocValuesField(values, magnitudeValues, name, elementType, dims);
case FLOAT -> new FloatMultiDenseVectorDocValuesField(values, magnitudeValues, name, elementType, dims);
case BIT -> new BitMultiDenseVectorDocValuesField(values, magnitudeValues, name, elementType, dims);
};
} catch (IOException e) {
throw new IllegalStateException("Cannot load doc values for multi-vector field!", e);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public ValuesSourceType getValuesSourceType() {

@Override
public MultiVectorDVLeafFieldData load(LeafReaderContext context) {
return new MultiVectorDVLeafFieldData(context.reader(), fieldName, indexVersion, elementType, dims);
return new MultiVectorDVLeafFieldData(context.reader(), fieldName, elementType, dims);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,24 @@ public static void decodeDenseVector(IndexVersion indexVersion, BytesRef vectorB
}
}

/**
 * Decodes a byte slice of little-endian floats into a {@code float[]}.
 *
 * @param magnitudes slice whose {@code length} is asserted to be a multiple of {@link Float#BYTES}
 * @return a new array holding {@code magnitudes.length / Float.BYTES} floats read from the slice
 */
public static float[] getMultiMagnitudes(BytesRef magnitudes) {
    assert magnitudes.length % Float.BYTES == 0;
    final int count = magnitudes.length / Float.BYTES;
    final float[] decoded = new float[count];
    // Bulk-read all floats through a float view of the slice.
    ByteBuffer.wrap(magnitudes.bytes, magnitudes.offset, magnitudes.length).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().get(decoded);
    return decoded;
}

/**
 * Fills the first {@code numVectors} rows of {@code multiVectorValue} with
 * little-endian floats read sequentially from {@code vectorBR}.
 *
 * @param vectorBR         encoded vectors; must not be {@code null}
 * @param numVectors       number of rows to fill
 * @param multiVectorValue destination rows; each row is filled to its own length
 * @throws IllegalArgumentException if {@code vectorBR} is {@code null}
 */
public static void decodeMultiDenseVector(BytesRef vectorBR, int numVectors, float[][] multiVectorValue) {
    if (vectorBR == null) {
        throw new IllegalArgumentException(MultiDenseVectorScriptDocValues.MISSING_VECTOR_FIELD_MESSAGE);
    }
    final ByteBuffer raw = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
    final FloatBuffer floats = raw.order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
    int row = 0;
    while (row < numVectors) {
        // Each bulk get advances the buffer by the row's length.
        floats.get(multiVectorValue[row]);
        row++;
    }
}

}
Loading

0 comments on commit 55be3ac

Please sign in to comment.