-
Notifications
You must be signed in to change notification settings - Fork 25k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add multi_dense_vector value access to scripts (#116610)
This adds value access to multi_dense_vector values in scripts. The users will get: - Count of vectors per field - Magnitudes of all the individual vectors - Access to each vector with an iterator I will happily take design critiques around how these are exposed in scripting. I initially though of just providing directly `float[][]` access, but this seems to have some unfavorable behavior around creating a TON of garbage. The reason is that each field could have a different number of vectors, so allocating a new collection of `float[dim]` for every field seemed rough. Generally, when scripting or using the vectors, an iterator should be enough and I have the iterator backed by a simple buffer to keep garbage down.
- Loading branch information
Showing
17 changed files
with
1,330 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
178 changes: 178 additions & 0 deletions
178
...mlRestTest/resources/rest-api-spec/test/painless/181_multi_dense_vector_dv_fields_api.yml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
setup: | ||
- requires: | ||
capabilities: | ||
- method: POST | ||
path: /_search | ||
capabilities: [ multi_dense_vector_script_access ] | ||
test_runner_features: capabilities | ||
reason: "Support for multi dense vector field script access capability required" | ||
- skip: | ||
features: headers | ||
|
||
- do: | ||
indices.create: | ||
index: test-index | ||
body: | ||
settings: | ||
number_of_shards: 1 | ||
mappings: | ||
properties: | ||
vector: | ||
type: multi_dense_vector | ||
dims: 5 | ||
byte_vector: | ||
type: multi_dense_vector | ||
dims: 5 | ||
element_type: byte | ||
bit_vector: | ||
type: multi_dense_vector | ||
dims: 40 | ||
element_type: bit | ||
- do: | ||
index: | ||
index: test-index | ||
id: "1" | ||
body: | ||
vector: [[230.0, 300.33, -34.8988, 15.555, -200.0], [-0.5, 100.0, -13, 14.8, -156.0]] | ||
byte_vector: [[8, 5, -15, 1, -7], [-1, 115, -3, 4, -128]] | ||
bit_vector: [[8, 5, -15, 1, -7], [-1, 115, -3, 4, -128]] | ||
|
||
- do: | ||
index: | ||
index: test-index | ||
id: "3" | ||
body: | ||
vector: [[0.5, 111.3, -13.0, 14.8, -156.0]] | ||
byte_vector: [[2, 18, -5, 0, -124]] | ||
bit_vector: [[2, 18, -5, 0, -124]] | ||
|
||
- do: | ||
indices.refresh: {} | ||
--- | ||
"Test vector magnitude equality": | ||
- skip: | ||
features: close_to | ||
|
||
- do: | ||
headers: | ||
Content-Type: application/json | ||
search: | ||
rest_total_hits_as_int: true | ||
body: | ||
query: | ||
script_score: | ||
query: {match_all: {} } | ||
script: | ||
source: "doc['vector'].magnitudes[0]" | ||
|
||
- match: {hits.total: 2} | ||
|
||
- match: {hits.hits.0._id: "1"} | ||
- close_to: {hits.hits.0._score: {value: 429.6021, error: 0.01}} | ||
|
||
- match: {hits.hits.1._id: "3"} | ||
- close_to: {hits.hits.1._score: {value: 192.6447, error: 0.01}} | ||
|
||
- do: | ||
headers: | ||
Content-Type: application/json | ||
search: | ||
rest_total_hits_as_int: true | ||
body: | ||
query: | ||
script_score: | ||
query: {match_all: {} } | ||
script: | ||
source: "doc['byte_vector'].magnitudes[0]" | ||
|
||
- match: {hits.total: 2} | ||
|
||
- match: {hits.hits.0._id: "3"} | ||
- close_to: {hits.hits.0._score: {value: 125.41531, error: 0.01}} | ||
|
||
- match: {hits.hits.1._id: "1"} | ||
- close_to: {hits.hits.1._score: {value: 19.07878, error: 0.01}} | ||
|
||
- do: | ||
headers: | ||
Content-Type: application/json | ||
search: | ||
rest_total_hits_as_int: true | ||
body: | ||
query: | ||
script_score: | ||
query: {match_all: {} } | ||
script: | ||
source: "doc['bit_vector'].magnitudes[0]" | ||
|
||
- match: {hits.total: 2} | ||
|
||
- match: {hits.hits.0._id: "1"} | ||
- close_to: {hits.hits.0._score: {value: 3.872983, error: 0.01}} | ||
|
||
- match: {hits.hits.1._id: "3"} | ||
- close_to: {hits.hits.1._score: {value: 3.464101, error: 0.01}} | ||
--- | ||
"Test vector value scoring": | ||
- skip: | ||
features: close_to | ||
|
||
- do: | ||
headers: | ||
Content-Type: application/json | ||
search: | ||
rest_total_hits_as_int: true | ||
body: | ||
query: | ||
script_score: | ||
query: {match_all: {} } | ||
script: | ||
source: "doc['vector'].vectorValues.next()[0];" | ||
|
||
- match: {hits.total: 2} | ||
|
||
- match: {hits.hits.0._id: "1"} | ||
- close_to: {hits.hits.0._score: {value: 230, error: 0.01}} | ||
|
||
- match: {hits.hits.1._id: "3"} | ||
- close_to: {hits.hits.1._score: {value: 0.5, error: 0.01}} | ||
|
||
- do: | ||
headers: | ||
Content-Type: application/json | ||
search: | ||
rest_total_hits_as_int: true | ||
body: | ||
query: | ||
script_score: | ||
query: {match_all: {} } | ||
script: | ||
source: "doc['byte_vector'].vectorValues.next()[0];" | ||
|
||
- match: {hits.total: 2} | ||
|
||
- match: {hits.hits.0._id: "1"} | ||
- close_to: {hits.hits.0._score: {value: 8, error: 0.01}} | ||
|
||
- match: {hits.hits.1._id: "3"} | ||
- close_to: {hits.hits.1._score: {value: 2, error: 0.01}} | ||
|
||
- do: | ||
headers: | ||
Content-Type: application/json | ||
search: | ||
rest_total_hits_as_int: true | ||
body: | ||
query: | ||
script_score: | ||
query: {match_all: {} } | ||
script: | ||
source: "doc['bit_vector'].vectorValues.next()[0];" | ||
|
||
- match: {hits.total: 2} | ||
|
||
- match: {hits.hits.0._id: "1"} | ||
- close_to: {hits.hits.0._score: {value: 8, error: 0.01}} | ||
|
||
- match: {hits.hits.1._id: "3"} | ||
- close_to: {hits.hits.1._score: {value: 2, error: 0.01}} |
81 changes: 81 additions & 0 deletions
81
...src/main/java/org/elasticsearch/index/mapper/vectors/MultiDenseVectorScriptDocValues.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the "Elastic License | ||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side | ||
* Public License v 1"; you may not use this file except in compliance with, at | ||
* your election, the "Elastic License 2.0", the "GNU Affero General Public | ||
* License v3.0 only", or the "Server Side Public License, v 1". | ||
*/ | ||
|
||
package org.elasticsearch.index.mapper.vectors; | ||
|
||
import org.apache.lucene.util.BytesRef; | ||
import org.elasticsearch.index.fielddata.ScriptDocValues; | ||
import org.elasticsearch.script.field.vectors.MultiDenseVector; | ||
|
||
import java.util.Iterator; | ||
|
||
public class MultiDenseVectorScriptDocValues extends ScriptDocValues<BytesRef> { | ||
|
||
public static final String MISSING_VECTOR_FIELD_MESSAGE = "A document doesn't have a value for a multi-vector field!"; | ||
|
||
private final int dims; | ||
protected final MultiDenseVectorSupplier dvSupplier; | ||
|
||
public MultiDenseVectorScriptDocValues(MultiDenseVectorSupplier supplier, int dims) { | ||
super(supplier); | ||
this.dvSupplier = supplier; | ||
this.dims = dims; | ||
} | ||
|
||
public int dims() { | ||
return dims; | ||
} | ||
|
||
private MultiDenseVector getCheckedVector() { | ||
MultiDenseVector vector = dvSupplier.getInternal(); | ||
if (vector == null) { | ||
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE); | ||
} | ||
return vector; | ||
} | ||
|
||
/** | ||
* Get multi-dense vector's value as an array of floats | ||
*/ | ||
public Iterator<float[]> getVectorValues() { | ||
return getCheckedVector().getVectors(); | ||
} | ||
|
||
/** | ||
* Get dense vector's magnitude | ||
*/ | ||
public float[] getMagnitudes() { | ||
return getCheckedVector().getMagnitudes(); | ||
} | ||
|
||
@Override | ||
public BytesRef get(int index) { | ||
throw new UnsupportedOperationException( | ||
"accessing a multi-vector field's value through 'get' or 'value' is not supported, use 'vectorValues' or 'magnitudes' instead." | ||
); | ||
} | ||
|
||
@Override | ||
public int size() { | ||
MultiDenseVector mdv = dvSupplier.getInternal(); | ||
if (mdv != null) { | ||
return mdv.size(); | ||
} | ||
return 0; | ||
} | ||
|
||
public interface MultiDenseVectorSupplier extends Supplier<BytesRef> { | ||
@Override | ||
default BytesRef getInternal(int index) { | ||
throw new UnsupportedOperationException(); | ||
} | ||
|
||
MultiDenseVector getInternal(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.