Skip to content

Commit

Permalink
PayloadMethodArgumentResolver supports Optional
Browse files Browse the repository at this point in the history
  • Loading branch information
rstoyanchev committed Feb 13, 2023
1 parent 5b79a57 commit ccbb4bd
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@
package org.springframework.messaging.handler.annotation.support;

import java.lang.annotation.Annotation;
import java.util.Optional;

import org.springframework.core.MethodParameter;
import org.springframework.core.annotation.AnnotationUtils;
Expand All @@ -27,6 +28,7 @@
import org.springframework.messaging.converter.SmartMessageConverter;
import org.springframework.messaging.handler.annotation.Payload;
import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;
Expand Down Expand Up @@ -113,24 +115,30 @@ public Object resolveArgument(MethodParameter parameter, Message<?> message) thr
throw new IllegalStateException("@Payload SpEL expressions not supported by this resolver");
}

boolean isOptionalTargetClass = (parameter.getParameterType() == Optional.class);
Object payload = message.getPayload();
if (isEmptyPayload(payload)) {
if (ann == null || ann.required()) {
if ((ann == null || ann.required()) && !isOptionalTargetClass) {
String paramName = getParameterName(parameter);
BindingResult bindingResult = new BeanPropertyBindingResult(payload, paramName);
bindingResult.addError(new ObjectError(paramName, "Payload value must not be empty"));
throw new MethodArgumentNotValidException(message, parameter, bindingResult);
}
else {
return null;
return (isOptionalTargetClass ? Optional.empty() : null);
}
}

if (payload instanceof Optional<?> optional) {
payload = optional.get();
message = MessageBuilder.createMessage(payload, message.getHeaders());
}

Class<?> targetClass = resolveTargetClass(parameter, message);
Class<?> payloadClass = payload.getClass();
if (ClassUtils.isAssignable(targetClass, payloadClass)) {
validate(message, parameter, payload);
return payload;
return (isOptionalTargetClass ? Optional.of(payload) : payload);
}
else {
if (this.converter instanceof SmartMessageConverter smartConverter) {
Expand All @@ -144,7 +152,7 @@ public Object resolveArgument(MethodParameter parameter, Message<?> message) thr
payloadClass.getName() + "] to [" + targetClass.getName() + "] for " + message);
}
validate(message, parameter, payload);
return payload;
return (isOptionalTargetClass ? Optional.of(payload) : payload);
}
}

Expand All @@ -161,11 +169,14 @@ protected boolean isEmptyPayload(@Nullable Object payload) {
if (payload == null) {
return true;
}
else if (payload instanceof byte[]) {
return ((byte[]) payload).length == 0;
else if (payload instanceof byte[] bytes) {
return bytes.length == 0;
}
else if (payload instanceof String s) {
return !StringUtils.hasText(s);
}
else if (payload instanceof String) {
return !StringUtils.hasText((String) payload);
else if (payload instanceof Optional<?> optional) {
return optional.isEmpty();
}
else {
return false;
Expand All @@ -184,7 +195,7 @@ else if (payload instanceof String) {
* @since 5.2
*/
protected Class<?> resolveTargetClass(MethodParameter parameter, Message<?> message) {
return parameter.getParameterType();
return parameter.nestedIfOptional().getNestedParameterType();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.util.Locale;
import java.util.Optional;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -61,6 +62,8 @@ public class PayloadMethodArgumentResolverTests {

private MethodParameter paramWithSpelExpression;

private MethodParameter paramOptional;

private MethodParameter paramNotAnnotated;

private MethodParameter paramValidatedNotAnnotated;
Expand All @@ -74,16 +77,17 @@ public void setup() throws Exception {

Method payloadMethod = PayloadMethodArgumentResolverTests.class.getDeclaredMethod(
"handleMessage", String.class, String.class, Locale.class,
String.class, String.class, String.class, String.class);
String.class, Optional.class, String.class, String.class, String.class);

this.paramAnnotated = new SynthesizingMethodParameter(payloadMethod, 0);
this.paramAnnotatedNotRequired = new SynthesizingMethodParameter(payloadMethod, 1);
this.paramAnnotatedRequired = new SynthesizingMethodParameter(payloadMethod, 2);
this.paramWithSpelExpression = new SynthesizingMethodParameter(payloadMethod, 3);
this.paramValidated = new SynthesizingMethodParameter(payloadMethod, 4);
this.paramOptional = new SynthesizingMethodParameter(payloadMethod, 4);
this.paramValidated = new SynthesizingMethodParameter(payloadMethod, 5);
this.paramValidated.initParameterNameDiscovery(new DefaultParameterNameDiscoverer());
this.paramValidatedNotAnnotated = new SynthesizingMethodParameter(payloadMethod, 5);
this.paramNotAnnotated = new SynthesizingMethodParameter(payloadMethod, 6);
this.paramValidatedNotAnnotated = new SynthesizingMethodParameter(payloadMethod, 6);
this.paramNotAnnotated = new SynthesizingMethodParameter(payloadMethod, 7);
}

@Test
Expand Down Expand Up @@ -127,13 +131,33 @@ public void resolveNotRequired() throws Exception {
Message<?> emptyByteArrayMessage = MessageBuilder.withPayload(new byte[0]).build();
assertThat(this.resolver.resolveArgument(this.paramAnnotatedNotRequired, emptyByteArrayMessage)).isNull();

Message<?> emptyStringMessage = MessageBuilder.withPayload("").build();
Message<?> emptyStringMessage = MessageBuilder.withPayload(" ").build();
assertThat(this.resolver.resolveArgument(this.paramAnnotatedNotRequired, emptyStringMessage)).isNull();
assertThat(((Optional<?>) this.resolver.resolveArgument(this.paramOptional, emptyStringMessage)).isEmpty()).isTrue();

Message<?> emptyOptionalMessage = MessageBuilder.withPayload(Optional.empty()).build();
assertThat(this.resolver.resolveArgument(this.paramAnnotatedNotRequired, emptyOptionalMessage)).isNull();

Message<?> notEmptyMessage = MessageBuilder.withPayload("ABC".getBytes()).build();
assertThat(this.resolver.resolveArgument(this.paramAnnotatedNotRequired, notEmptyMessage)).isEqualTo("ABC");
}

@Test
public void resolveOptionalTarget() throws Exception {
Message<?> message = MessageBuilder.withPayload("ABC".getBytes()).build();
Object actual = this.resolver.resolveArgument(paramOptional, message);

assertThat(((Optional<?>) actual).get()).isEqualTo("ABC");
}

@Test
public void resolveOptionalSource() throws Exception {
Message<?> message = MessageBuilder.withPayload(Optional.of("ABC".getBytes())).build();
Object actual = this.resolver.resolveArgument(paramAnnotated, message);

assertThat(actual).isEqualTo("ABC");
}

@Test
public void resolveNonConvertibleParam() {
Message<?> notEmptyMessage = MessageBuilder.withPayload(123).build();
Expand Down Expand Up @@ -218,6 +242,7 @@ private void handleMessage(
@Payload(required=false) String paramNotRequired,
@Payload(required=true) Locale nonConvertibleRequiredParam,
@Payload("foo.bar") String paramWithSpelExpression,
@Payload Optional<String> optionalParam,
@MyValid @Payload String validParam,
@Validated String validParamNotAnnotated,
String paramNotAnnotated) {
Expand Down

0 comments on commit ccbb4bd

Please sign in to comment.