Skip to content

Commit

Permalink
Mapping Row=> Bean and deserialization
Browse files Browse the repository at this point in the history
  • Loading branch information
clun committed Dec 5, 2024
1 parent 7e1039b commit 38d60ba
Show file tree
Hide file tree
Showing 11 changed files with 216 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,30 @@
import com.datastax.astra.client.tables.definition.indexes.TableVectorIndexDefinition;
import com.datastax.astra.client.tables.definition.rows.Row;
import com.datastax.astra.client.tables.exceptions.TooManyRowsToCountException;
import com.datastax.astra.client.tables.mapping.EntityTable;
import com.datastax.astra.internal.api.DataAPIData;
import com.datastax.astra.internal.api.DataAPIResponse;
import com.datastax.astra.internal.api.DataAPIStatus;
import com.datastax.astra.internal.command.AbstractCommandRunner;
import com.datastax.astra.internal.command.CommandObserver;
import com.datastax.astra.internal.reflection.EntityBeanDefinition;
import com.datastax.astra.internal.reflection.EntityFieldDefinition;
import com.datastax.astra.internal.serdes.DataAPISerializer;
import com.datastax.astra.internal.serdes.tables.RowSerializer;
import com.datastax.astra.internal.utils.Assert;
import com.dtsx.astra.sdk.utils.Utils;
import com.fasterxml.jackson.databind.JavaType;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
Expand Down Expand Up @@ -568,17 +573,62 @@ public Optional<T> findOne(Filter filter, TableFindOneOptions findOneOptions) {
Command findOne = Command.create("findOne").withFilter(filter);
if (findOneOptions != null) {
findOne.withSort(findOneOptions.getSortArray())
.withProjection(findOneOptions.getProjectionArray())
.withOptions(new Document()
.appendIfNotNull(INPUT_INCLUDE_SIMILARITY, findOneOptions.includeSimilarity())
.appendIfNotNull(INPUT_INCLUDE_SORT_VECTOR, findOneOptions.includeSortVector())
.withProjection(findOneOptions.getProjectionArray())
.withOptions(new Document()
.appendIfNotNull(INPUT_INCLUDE_SIMILARITY, findOneOptions.includeSimilarity())
// not exposed in FindOne
//.appendIfNotNull(INPUT_INCLUDE_SORT_VECTOR, findOneOptions.includeSortVector())
);
}

DataAPIData data = runCommand(findOne, findOneOptions).getData();
if (data.getDocument() == null) {
return Optional.empty();
return Optional
.ofNullable(data.getDocument()
.map(Row.class))
.map(this::mapFromRow);
}

@SuppressWarnings("unchecked")
public T mapFromRow(Row row) {
try {
Class<T> rowClass = getRowClass();
if (rowClass == Row.class) {
return (T) row;
}
EntityBeanDefinition<T> beanDef = new EntityBeanDefinition<>(rowClass);
T input = rowClass.getDeclaredConstructor().newInstance();

for (EntityFieldDefinition fieldDef : beanDef.getFields().values()) {
String columnName = fieldDef.getColumnName() != null ?
fieldDef.getColumnName() :
fieldDef.getName();
Object columnValue = row.columnMap.get(columnName);
if (columnValue == null) {
continue; // Handle nulls as needed
}

// Use the JavaType directly
JavaType javaType = fieldDef.getJavaType();

// Convert the column value to the field's type
Object value = getSerializer()
.getMapper()
.convertValue(columnValue, javaType);

// Set the value to the bean
if (fieldDef.getSetter() != null) {
fieldDef.getSetter().invoke(input, value);
} else {
Field field = rowClass.getDeclaredField(fieldDef.getName());
field.setAccessible(true);
field.set(input, value);
}
}

return input;
} catch (Exception e) {
throw new RuntimeException("Failed to map row to bean", e);
}
return Optional.ofNullable(data.getDocument().map(getRowClass()));
}

public Optional<T> findOne(TableFindOneOptions findOneOptions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public class TableFindOneOptions extends BaseOptions<TableFindOneOptions> {
/**
* Flag to include sortVector in the result when operating a semantic search.
*/
Boolean includeSortVector;
//Boolean includeSortVector;

/**
* Adding this on top of sort(Sort[] s) to allow for a more fluent API.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.datastax.astra.client.tables.definition.TableDuration;
import com.datastax.astra.internal.serdes.DataAPISerializer;
import com.datastax.astra.internal.serdes.tables.RowSerializer;
import com.datastax.astra.internal.utils.Assert;
import com.fasterxml.jackson.annotation.JsonAnyGetter;
import com.fasterxml.jackson.annotation.JsonAnySetter;
import lombok.NonNull;
Expand All @@ -46,6 +47,7 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;

Expand Down Expand Up @@ -267,6 +269,7 @@ public <K,V> Row addMap(String key, Map<K, V> myMap) {
* @throws ClassCastException if the value of the given key is not of type T
*/
public <T> T get(@NonNull final String key, @NonNull final Class<T> clazz) {
Assert.hasLength(key, "key");
return clazz.cast(SERIALIZER.convertValue(columnMap.get(key), clazz));
}

Expand Down Expand Up @@ -310,6 +313,10 @@ public Long getBigInt(final String key) {
return Long.parseLong(String.valueOf(get(key)));
}

public DataAPIVector getVector(final String key) {
return get(key, DataAPIVector.class);
}

/**
* Gets the value of the given key as a Boolean.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
* #L%
*/

import com.datastax.astra.client.core.vector.DataAPIVector;
import com.datastax.astra.client.exceptions.DataAPIErrorDescriptor;
import com.datastax.astra.client.tables.definition.columns.ColumnDefinition;
import com.datastax.astra.internal.serdes.DataAPISerializer;
Expand All @@ -44,20 +45,34 @@ public class DataAPIStatus {
public transient Map<String, Object> payload = new HashMap<>();

/**
* PrimaryKey Schema returned
* Returned when insertMany with flag
*/
private LinkedHashMap<String, ColumnDefinition> primaryKeySchema;
private List<DataAPIDocumentResponse> documentResponses;

/**
* Returned when insertMany with flag
* Sort Vector returned if flag include sortVector is set to true
*/
private List<DataAPIDocumentResponse> documentResponses;
private DataAPIVector sortVector;

/**
* Warnings returned
*/
private List<DataAPIErrorDescriptor> warnings;

// ----------------------
// Tables Specifics
// ----------------------

/**
* PrimaryKey Schema returned
*/
private LinkedHashMap<String, ColumnDefinition> primaryKeySchema;

/**
* PrimaryKey Schema returned
*/
private LinkedHashMap<String, ColumnDefinition> projectionSchema;

/**
* Inserted ids.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ public EntityBeanDefinition(Class<T> clazz) {
EntityFieldDefinition field = new EntityFieldDefinition();
field.setName(property.getName());
field.setType(property.getPrimaryType().getRawClass());
field.setJavaType(property.getPrimaryType());

if (Map.class.isAssignableFrom(field.getType())) {
JavaType keyType = property.getPrimaryType().getBindings().getBoundType(0);
JavaType valueType = property.getPrimaryType().getBindings().getBoundType(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
import com.datastax.astra.client.core.query.SortOrder;
import com.datastax.astra.client.core.vector.SimilarityMetric;
import com.datastax.astra.client.tables.definition.columns.ColumnTypes;
import com.fasterxml.jackson.databind.JavaType;
import lombok.Data;

import java.lang.reflect.Method;
import java.lang.reflect.Type;

@Data
public class EntityFieldDefinition {
Expand All @@ -37,6 +39,7 @@ public class EntityFieldDefinition {
private Method setter;
private Class<?> genericValueType;
private Class<?> genericKeyType;
private JavaType javaType;

// --- Table Hints --

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,22 @@ public DataAPIVectorDeserializer() {
@Override
public DataAPIVector deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
final JsonToken t = p.currentToken();
if (t == JsonToken.VALUE_EMBEDDED_OBJECT) {

// Expecting {"$binary":"PszMzb8ZmZo+TMzN"}
if (t == JsonToken.START_OBJECT) {
String fieldName = p.nextFieldName();
if ("$binary".equals(fieldName)) {
p.nextToken(); // Move to the value of $binary
byte[] base64Value = p.getBinaryValue();
p.nextToken(); // Move past the value
p.nextToken(); // Move past END_OBJECT
return new DataAPIVector(unpack(ctxt, base64Value));
}
// Understands [0.4, -0.6, 0.2]
} else if (t == JsonToken.START_ARRAY) {
float[] floats = ctxt.readValue(p, float[].class);
return new DataAPIVector(floats);
} else if (t == JsonToken.VALUE_EMBEDDED_OBJECT) {
Object emb = p.getEmbeddedObject();
if (emb instanceof byte[]) {
return new DataAPIVector(unpack(ctxt, (byte[]) emb));
Expand All @@ -47,6 +62,7 @@ public DataAPIVector deserialize(JsonParser p, DeserializationContext ctxt) thro
} else if (t == JsonToken.VALUE_STRING) {
return new DataAPIVector(unpack(ctxt, p.getBinaryValue()));
}

return new DataAPIVector((float[]) ctxt.handleUnexpectedToken(_valueClass, p));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@

import java.text.SimpleDateFormat;
import java.time.Duration;
import java.time.Instant;
import java.time.LocalTime;
import java.time.format.DateTimeFormatter;

Expand Down Expand Up @@ -99,7 +100,6 @@ public ObjectMapper getMapper() {
.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false)
.disable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS)
.setDateFormat(new SimpleDateFormat("yyyy-MM-dd"))
.registerModule(new JavaTimeModule())
.registerModule(new Jdk8Module())
.setSerializationInclusion(JsonInclude.Include.NON_NULL)
.setAnnotationIntrospector(new JacksonAnnotationIntrospector());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package com.datastax.astra.test.unit;

import com.datastax.astra.client.core.options.DataAPIClientOptions;
import com.datastax.astra.client.core.vector.DataAPIVector;
import com.datastax.astra.internal.serdes.tables.RowSerializer;
import org.junit.jupiter.api.Test;

import java.time.Instant;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;

public class DataApiVectorSerializationTest {

@Test
public void binarySerialization() {
byte[] bytes = "hello".getBytes();
RowSerializer serializer = new RowSerializer();
String json = serializer.marshall(bytes);
// {"$binary":"aGVsbG8="}

System.out.println(json);
byte[] bytes2 = serializer.unMarshallBean(json, byte[].class);
assertThat(new String(bytes2)).isEqualTo("hello");
}

@Test
public void testVectorSerialization() {
DataAPIClientOptions.getSerdesOptions().encodeDataApiVectorsAsBase64(true);;
RowSerializer serializer = new RowSerializer();
DataAPIVector vector = new DataAPIVector(new float[]{0.4f, -0.6f, 0.2f});
System.out.println(serializer.marshall(vector));

String json1 = "{\"$binary\":\"PszMzb8ZmZo+TMzN\"}";
DataAPIVector vector2 = serializer.unMarshallBean(json1, DataAPIVector.class);
System.out.println(vector2.getEmbeddings());

String json2 = "[0.4, -0.6, 0.2]";
}

@Test
public void serializationInstant() {
String sample = "2024-12-04T15:04:07.203Z";
Instant.parse(sample);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package com.datastax.astra.client.tables;

import com.datastax.astra.client.DataAPIClients;
import com.datastax.astra.client.core.query.Filter;
import com.datastax.astra.client.core.query.Sort;
import com.datastax.astra.client.core.vector.DataAPIVector;
import com.datastax.astra.client.databases.Database;
import com.datastax.astra.client.tables.commands.options.TableFindOneOptions;
import com.datastax.astra.client.tables.definition.rows.Row;
import com.datastax.astra.internal.serdes.tables.RowSerializer;

import java.util.Optional;

import static com.datastax.astra.client.core.query.Filters.and;
import static com.datastax.astra.client.core.query.Filters.eq;
import static com.datastax.astra.client.core.query.Filters.gt;

public class FindOne {
public static void main(String[] args) {
Database db = DataAPIClients.localDbWithDefaultKeyspace();
// Database astraDb = new DataAPIClient(token).getDatabase(endpoint);

Table<Row> table = db.getTable("games");

Filter filter = and(
eq("match_id", "mtch_0"),
gt("round", 1),
eq("winner", "Victor"));

TableFindOneOptions options = new TableFindOneOptions()
//.projection(include("match_id", "winner", "field3"))
.sort(Sort.vector("m_vector", new DataAPIVector(new float[] {0.4f, -0.6f, 0.2f})))
.includeSimilarity(true);

Optional<Row> row = table.findOne(filter, options);
row.ifPresent(r -> {
System.out.println("Row: " + r);
DataAPIVector v = r.getVector("m_vector");
System.out.println(r.getInstant("when"));
});

Table<Game> tableGame = db.getTable("games", Game.class);
Optional<Game> row2 = tableGame.findOne(filter, options);
row2.ifPresent(game -> {
System.out.println("game: " + game.getVector().dimension());
System.out.println(game.getFighters());
System.out.println(game.getMatchId());
});
}
}
Loading

0 comments on commit 38d60ba

Please sign in to comment.