Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Account for enum member removal and optimize #2082

Merged
merged 1 commit into from
Dec 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

package software.amazon.smithy.model.transform;

import static java.util.stream.Collectors.toList;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
Expand All @@ -27,21 +25,18 @@
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.TreeSet;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.loader.TopologicalShapeSort;
import software.amazon.smithy.model.shapes.AbstractShapeBuilder;
import software.amazon.smithy.model.shapes.CollectionShape;
import software.amazon.smithy.model.shapes.MemberShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.model.shapes.ShapeVisitor;
import software.amazon.smithy.model.shapes.SimpleShape;
import software.amazon.smithy.model.shapes.StructureShape;
import software.amazon.smithy.model.shapes.UnionShape;
import software.amazon.smithy.model.traits.MixinTrait;
import software.amazon.smithy.utils.Pair;
import software.amazon.smithy.utils.SetUtils;

/**
* Replaces shapes while ensuring that the transformed model is in a
Expand Down Expand Up @@ -84,7 +79,7 @@ final class ReplaceShapes {
}

Model transform(ModelTransformer transformer, Model model) {
List<Shape> shouldReplace = determineShapesToReplace(model);
Collection<Shape> shouldReplace = determineShapesToReplace(model);
if (shouldReplace.isEmpty()) {
return model;
}
Expand All @@ -99,47 +94,47 @@ Model transform(ModelTransformer transformer, Model model) {
// is also updated to reference the updated member. Note that the updated container
// shapes will be a modified version of shapes present in the shouldReplace Set
// over shapes in the provided model.
getUpdatedContainers(model, shouldReplace).forEach(builder::addShape);
builder.addShapes(getUpdatedContainers(model, shouldReplace));

// Builds the model, then returns a model that removes any shapes that
// need to be removed after mapping over the shapes.
return transformer.removeShapes(builder.build(), getShapesToRemove(model, shouldReplace));
return transformer.removeShapes(builder.build(), getRemovedMembers(model, shouldReplace));
}

private List<Shape> determineShapesToReplace(Model model) {
return replacements.stream()
// Only replace shapes if they don't exist in the model or if they are
// different than the current shape in the model.
//
// This prevents infinite recursion when this transformer and the
// RemoveShapes transformer recursively call each other. It also
// prevents unnecessary allocations.
.filter(shape -> !model.getShape(shape.getId())
.filter(original -> original.equals(shape))
.isPresent())
// Sort the replacements to ensure that members come after container shapes.
// This ensures that updates to members take precedence over updates to containers.
.sorted((a, b) -> {
if (a.isMemberShape() ^ b.isMemberShape()) {
return a.isMemberShape() ? 1 : -1;
} else {
return 0;
}
})
.collect(toList());
private Collection<Shape> determineShapesToReplace(Model model) {
// Sort the replacements to ensure that members come after container shapes.
// This ensures that updates to members take precedence over updates to containers.
Set<Shape> result = new TreeSet<>((a, b) -> {
if (a.isMemberShape() ^ b.isMemberShape()) {
return a.isMemberShape() ? 1 : -1;
} else {
return a.compareTo(b);
}
});

for (Shape shape : replacements) {
// Only replace shapes if they don't exist in the model or if they are different from the current shape.
// This check prevents infinite recursion when this transformer and the RemoveShapes transformer
// recursively call each other, and it prevents unnecessary work.
if (!Objects.equals(shape, model.getShape(shape.getId()).orElse(null))) {
result.add(shape);
}
}

return result;
}

private void assertShapeTypeChangesSound(Model model, List<Shape> shouldReplace) {
private void assertShapeTypeChangesSound(Model model, Collection<Shape> shouldReplace) {
// Throws if any mappings attempted to change a shape's type.
shouldReplace.stream()
.flatMap(previous -> Pair.flatMapStream(previous, p -> model.getShape(p.getId())))
.filter(pair -> pair.getLeft().getType() != pair.getRight().getType())
.filter(pair -> !isReplacementValid(pair.getLeft(), pair.getRight()))
.forEach(pair -> {
for (Shape shape : shouldReplace) {
model.getShape(shape.getId()).ifPresent(previousShape -> {
if (shape.getType() != previousShape.getType() && !isReplacementValid(shape, previousShape)) {
throw new ModelTransformException(String.format(
"Cannot change the type of %s from %s to %s",
pair.getLeft().getId(), pair.getRight().getType(), pair.getLeft().getType()));
});
previousShape.getId(), previousShape.getType(), shape.getType()));
}
});
}
}

private boolean isReplacementValid(Shape left, Shape right) {
Expand All @@ -154,19 +149,19 @@ private boolean isReplacementValid(Shape left, Shape right) {
}
}

private Model.Builder createReplacedModelBuilder(Model model, List<Shape> shouldReplace) {
private Model.Builder createReplacedModelBuilder(Model model, Collection<Shape> shouldReplace) {
// Add member shapes to the builder. This builder is mutated
// by the visitor, which will ensure that newly added members
// show up in the model.
Model.Builder builder = model.toBuilder();
shouldReplace.forEach(shape -> {
for (Shape shape : shouldReplace) {
builder.addShape(shape);
builder.addShapes(shape.members());
});
}
return builder;
}

private void updateMixins(Model model, Model.Builder builder, List<Shape> replacements) {
private void updateMixins(Model model, Model.Builder builder, Collection<Shape> replacements) {
// Create a map to function as a mutable kind of intermediate model index so that as
// shapes are updated and built, they're used as mixins in shapes that depend on it.
Map<ShapeId, Shape> updatedShapes = new HashMap<>();
Expand All @@ -178,10 +173,11 @@ private void updateMixins(Model model, Model.Builder builder, List<Shape> replac
TopologicalShapeSort sorter = new TopologicalShapeSort();

// Add shapes that are mixins or use mixins.
model.shapes()
.filter(shape -> !shape.isMemberShape())
.filter(shape -> shape.hasTrait(MixinTrait.class) || !shape.getMixins().isEmpty())
.forEach(sorter::enqueue);
for (Shape shape : model.toSet()) {
if (!shape.isMemberShape() && (shape.hasTrait(MixinTrait.class) || !shape.getMixins().isEmpty())) {
sorter.enqueue(shape);
}
}

// Add _all_ of the replacements in case mixins or the Mixin trait were removed from updated shapes.
for (Shape shape : replacements) {
Expand All @@ -196,8 +192,8 @@ private void updateMixins(Model model, Model.Builder builder, List<Shape> replac
List<ShapeId> sorted = sorter.dequeueSortedShapes();
for (ShapeId toRebuild : sorted) {
Shape shape = updatedShapes.containsKey(toRebuild)
? updatedShapes.get(toRebuild)
: model.expectShape(toRebuild);
? updatedShapes.get(toRebuild)
: model.expectShape(toRebuild);
if (!shape.getMixins().isEmpty()) {
// We don't clear mixins here because a shape might have an inherited
// mixin member that was updated with an applied trait. Clearing mixins
Expand All @@ -206,8 +202,8 @@ private void updateMixins(Model model, Model.Builder builder, List<Shape> replac
AbstractShapeBuilder<?, ?> shapeBuilder = Shape.shapeToBuilder(shape);
for (ShapeId mixin : shape.getMixins()) {
Shape mixinShape = updatedShapes.containsKey(mixin)
? updatedShapes.get(mixin)
: model.expectShape(mixin);
? updatedShapes.get(mixin)
: model.expectShape(mixin);
shapeBuilder.addMixin(mixinShape);
}
Shape rebuilt = shapeBuilder.build();
Expand All @@ -217,21 +213,26 @@ private void updateMixins(Model model, Model.Builder builder, List<Shape> replac
}
}

private Set<Shape> getShapesToRemove(Model model, List<Shape> shouldReplace) {
// Ensure that when members are removed from a container shape
// (e.g., a structure with fewer members), the removed members are
// removed from the model.
return shouldReplace.stream()
.flatMap(shape -> Pair.flatMapStream(shape, s -> model.getShape(s.getId())))
.flatMap(pair -> {
RemoveShapesVisitor removeShapesVisitor = new RemoveShapesVisitor(pair.getRight());
return pair.getLeft().accept(removeShapesVisitor).stream();
})
.collect(Collectors.toSet());
private Set<Shape> getRemovedMembers(Model model, Collection<Shape> beingReplaced) {
// Ensure that when members are removed from a container shape (e.g., a structure with fewer members),
// the removed members are removed from the model.
Set<Shape> removedMembers = new HashSet<>();
for (Shape currentShape : beingReplaced) {
// Find the previous shape by ID from the model, and if present, determine what members were removed.
model.getShape(currentShape.getId()).ifPresent(previousShape -> {
Map<String, MemberShape> currentMembers = currentShape.getAllMembers();
for (MemberShape previousMember : previousShape.members()) {
if (!currentMembers.containsKey(previousMember.getMemberName())) {
removedMembers.add(previousMember);
}
}
});
}
return removedMembers;
}

private Set<Shape> getUpdatedContainers(Model model, List<Shape> shouldReplace) {
// Determine which shapes being updated are members, and group them by their container shapes.
private Set<Shape> getUpdatedContainers(Model model, Collection<Shape> shouldReplace) {
// Determine what shapes being updated are members, and group them by their container shapes.
Map<Shape, List<MemberShape>> containerToMemberMapping = new HashMap<>();
for (Shape shape : shouldReplace) {
shape.asMemberShape().ifPresent(member -> {
Expand Down Expand Up @@ -266,7 +267,7 @@ private static boolean isMemberPresent(MemberShape member, Shape shape) {
return shape.getMember(member.getMemberName()).filter(m -> m == member).isPresent();
}

private Optional<Shape> findMostUpToDateShape(ShapeId shapeId, Model model, List<Shape> shouldReplace) {
private Optional<Shape> findMostUpToDateShape(ShapeId shapeId, Model model, Collection<Shape> shouldReplace) {
// Shapes in the replacement set take precedence over shapes in the previous model.
// This accounts for newly added shapes and not overwriting changes also made to the
// container shape.
Expand All @@ -277,48 +278,4 @@ private Optional<Shape> findMostUpToDateShape(ShapeId shapeId, Model model, List
}
return model.getShape(shapeId);
}

/**
* Gets the member shapes of structures and unions that were
* removed as a result of mapping. These removed members need to also be
* removed from the Model.
*/
private static final class RemoveShapesVisitor extends ShapeVisitor.Default<Collection<Shape>> {

private final Shape previous;

RemoveShapesVisitor(Shape previous) {
this.previous = previous;
}

@Override
public Collection<Shape> getDefault(Shape shape) {
return SetUtils.of();
}

@Override
public Collection<Shape> unionShape(UnionShape shape) {
// Use previous.members() in case of a type change to prevent needing to call asUnion.
return onNamedMemberContainer(shape.getAllMembers(), previous.members());
}

@Override
public Collection<Shape> structureShape(StructureShape shape) {
// Use previous.members() in case of a type change to prevent needing to call asStructure.
return onNamedMemberContainer(shape.getAllMembers(), previous.members());
}

private Collection<Shape> onNamedMemberContainer(
Map<String, MemberShape> members,
Collection<MemberShape> previous
) {
List<Shape> result = new ArrayList<>();
for (MemberShape previousMember : previous) {
if (!members.containsKey(previousMember.getMemberName())) {
result.add(previousMember);
}
}
return result;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.junit.jupiter.api.Test;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.SourceLocation;
import software.amazon.smithy.model.shapes.EnumShape;
import software.amazon.smithy.model.shapes.IntegerShape;
import software.amazon.smithy.model.shapes.ListShape;
import software.amazon.smithy.model.shapes.LongShape;
Expand Down Expand Up @@ -392,4 +393,25 @@ public void doesNotOverwriteOtherContainerUpdatesWhenAlsoUpdatingMembers() {
// Ensure that the list shape has the new member.
assertThat(result.getShape(containerId).get().asListShape().get().getMember(), Matchers.is(newMember));
}

@Test
public void removingEnumMemberRemovesMemberFromUpdatedModel() {
EnumShape shapeA = EnumShape.builder()
.id("example#Foo")
.addMember("a", "A")
.addMember("b", "B")
.build();
EnumShape shapeB = shapeA.toBuilder().removeMember("b").build();

Model modelA = Model.builder().addShape(shapeA).build();

ReplaceShapes replaceShapes = new ReplaceShapes(Collections.singleton(shapeB));
Model modelB = replaceShapes.transform(ModelTransformer.create(), modelA);

assertEquals(modelB.expectShape(shapeB.getId()), shapeB);

// This previously would have failed because ReplaceShapes only removed members when they were removed from
// structures or unions. We now handle member removal generically instead.
assertThat(modelB.getShape(shapeA.getAllMembers().get("b").getId()), Matchers.equalTo(Optional.empty()));
}
}
Loading