|
19 | 19 | import java.io.Serializable;
|
20 | 20 | import java.lang.annotation.Retention;
|
21 | 21 | import java.lang.annotation.RetentionPolicy;
|
| 22 | +import java.lang.reflect.Method; |
22 | 23 | import java.util.ArrayList;
|
23 | 24 | import java.util.Arrays;
|
| 25 | +import java.util.Iterator; |
24 | 26 | import java.util.List;
|
| 27 | +import java.util.Map; |
| 28 | +import java.util.concurrent.ConcurrentHashMap; |
25 | 29 | import java.util.function.Consumer;
|
26 | 30 | import java.util.function.Supplier;
|
27 | 31 |
|
|
31 | 35 | import org.junit.jupiter.api.extension.ExtendWith;
|
32 | 36 |
|
33 | 37 | import org.springframework.aop.Advisor;
|
| 38 | +import org.springframework.aop.Pointcut; |
34 | 39 | import org.springframework.aop.support.DefaultPointcutAdvisor;
|
35 | 40 | import org.springframework.aop.support.JdkRegexpMethodPointcut;
|
| 41 | +import org.springframework.aop.support.Pointcuts; |
| 42 | +import org.springframework.aop.support.StaticMethodMatcherPointcut; |
36 | 43 | import org.springframework.beans.factory.annotation.Autowired;
|
37 | 44 | import org.springframework.beans.factory.config.BeanDefinition;
|
38 | 45 | import org.springframework.context.annotation.AdviceMode;
|
|
60 | 67 | import org.springframework.security.authorization.AuthorizationManager;
|
61 | 68 | import org.springframework.security.authorization.method.AuthorizationInterceptorsOrder;
|
62 | 69 | import org.springframework.security.authorization.method.AuthorizationManagerBeforeMethodInterceptor;
|
| 70 | +import org.springframework.security.authorization.method.AuthorizationProxyMethodInterceptor; |
| 71 | +import org.springframework.security.authorization.method.AuthorizeReturnObject; |
63 | 72 | import org.springframework.security.authorization.method.MethodInvocationResult;
|
64 | 73 | import org.springframework.security.authorization.method.PrePostTemplateDefaults;
|
| 74 | +import org.springframework.security.config.Customizer; |
65 | 75 | import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig;
|
66 | 76 | import org.springframework.security.config.core.GrantedAuthorityDefaults;
|
67 | 77 | import org.springframework.security.config.test.SpringTestContext;
|
|
75 | 85 | import org.springframework.test.context.ContextConfiguration;
|
76 | 86 | import org.springframework.test.context.TestExecutionListeners;
|
77 | 87 | import org.springframework.test.context.junit.jupiter.SpringExtension;
|
| 88 | +import org.springframework.util.ClassUtils; |
78 | 89 | import org.springframework.web.context.ConfigurableWebApplicationContext;
|
79 | 90 | import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
|
80 | 91 |
|
81 | 92 | import static org.assertj.core.api.Assertions.assertThat;
|
82 | 93 | import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
| 94 | +import static org.assertj.core.api.Assertions.assertThatNoException; |
83 | 95 | import static org.mockito.ArgumentMatchers.any;
|
84 | 96 | import static org.mockito.Mockito.atLeastOnce;
|
85 | 97 | import static org.mockito.Mockito.mock;
|
@@ -662,6 +674,79 @@ public void methodWhenPostFilterMetaAnnotationThenFilters() {
|
662 | 674 | .containsExactly("dave");
|
663 | 675 | }
|
664 | 676 |
|
| 677 | + @Test |
| 678 | + @WithMockUser(authorities = "airplane:read") |
| 679 | + public void findByIdWhenAuthorizedResultThenAuthorizes() { |
| 680 | + this.spring.register(AuthorizeResultConfig.class).autowire(); |
| 681 | + FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class); |
| 682 | + Flight flight = flights.findById("1"); |
| 683 | + assertThatNoException().isThrownBy(flight::getAltitude); |
| 684 | + assertThatNoException().isThrownBy(flight::getSeats); |
| 685 | + } |
| 686 | + |
| 687 | + @Test |
| 688 | + @WithMockUser(authorities = "seating:read") |
| 689 | + public void findByIdWhenUnauthorizedResultThenDenies() { |
| 690 | + this.spring.register(AuthorizeResultConfig.class).autowire(); |
| 691 | + FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class); |
| 692 | + Flight flight = flights.findById("1"); |
| 693 | + assertThatNoException().isThrownBy(flight::getSeats); |
| 694 | + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(flight::getAltitude); |
| 695 | + } |
| 696 | + |
| 697 | + @Test |
| 698 | + @WithMockUser(authorities = "seating:read") |
| 699 | + public void findAllWhenUnauthorizedResultThenDenies() { |
| 700 | + this.spring.register(AuthorizeResultConfig.class).autowire(); |
| 701 | + FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class); |
| 702 | + flights.findAll().forEachRemaining((flight) -> { |
| 703 | + assertThatNoException().isThrownBy(flight::getSeats); |
| 704 | + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(flight::getAltitude); |
| 705 | + }); |
| 706 | + } |
| 707 | + |
| 708 | + @Test |
| 709 | + public void removeWhenAuthorizedResultThenRemoves() { |
| 710 | + this.spring.register(AuthorizeResultConfig.class).autowire(); |
| 711 | + FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class); |
| 712 | + flights.remove("1"); |
| 713 | + } |
| 714 | + |
| 715 | + @Test |
| 716 | + @WithMockUser(authorities = "airplane:read") |
| 717 | + public void findAllWhenPostFilterThenFilters() { |
| 718 | + this.spring.register(AuthorizeResultConfig.class).autowire(); |
| 719 | + FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class); |
| 720 | + flights.findAll() |
| 721 | + .forEachRemaining((flight) -> assertThat(flight.getPassengers()).extracting(Passenger::getName) |
| 722 | + .doesNotContain("Kevin Mitnick")); |
| 723 | + } |
| 724 | + |
| 725 | + @Test |
| 726 | + @WithMockUser(authorities = "airplane:read") |
| 727 | + public void findAllWhenPreFilterThenFilters() { |
| 728 | + this.spring.register(AuthorizeResultConfig.class).autowire(); |
| 729 | + FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class); |
| 730 | + flights.findAll().forEachRemaining((flight) -> { |
| 731 | + flight.board(new ArrayList<>(List.of("John"))); |
| 732 | + assertThat(flight.getPassengers()).extracting(Passenger::getName).doesNotContain("John"); |
| 733 | + flight.board(new ArrayList<>(List.of("John Doe"))); |
| 734 | + assertThat(flight.getPassengers()).extracting(Passenger::getName).contains("John Doe"); |
| 735 | + }); |
| 736 | + } |
| 737 | + |
| 738 | + @Test |
| 739 | + @WithMockUser(authorities = "seating:read") |
| 740 | + public void findAllWhenNestedPreAuthorizeThenAuthorizes() { |
| 741 | + this.spring.register(AuthorizeResultConfig.class).autowire(); |
| 742 | + FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class); |
| 743 | + flights.findAll().forEachRemaining((flight) -> { |
| 744 | + List<Passenger> passengers = flight.getPassengers(); |
| 745 | + passengers.forEach((passenger) -> assertThatExceptionOfType(AccessDeniedException.class) |
| 746 | + .isThrownBy(passenger::getName)); |
| 747 | + }); |
| 748 | + } |
| 749 | + |
665 | 750 | private static Consumer<ConfigurableWebApplicationContext> disallowBeanOverriding() {
|
666 | 751 | return (context) -> ((AnnotationConfigWebApplicationContext) context).setAllowBeanDefinitionOverriding(false);
|
667 | 752 | }
|
@@ -1061,4 +1146,130 @@ List<String> resultsContainDave(List<String> list) {
|
1061 | 1146 |
|
1062 | 1147 | }
|
1063 | 1148 |
|
| 1149 | + @EnableMethodSecurity |
| 1150 | + @Configuration |
| 1151 | + static class AuthorizeResultConfig { |
| 1152 | + |
| 1153 | + @Bean |
| 1154 | + static Customizer<AuthorizationProxyMethodInterceptor> returnObject() { |
| 1155 | + return (interceptor) -> { |
| 1156 | + Pointcut pointcut = interceptor.getPointcut(); |
| 1157 | + interceptor.setPointcut(Pointcuts.intersection(new NotValueReturnTypePointcut(), pointcut)); |
| 1158 | + }; |
| 1159 | + } |
| 1160 | + |
| 1161 | + @Bean |
| 1162 | + FlightRepository flights() { |
| 1163 | + FlightRepository flights = new FlightRepository(); |
| 1164 | + Flight one = new Flight("1", 35000d, 35); |
| 1165 | + one.board(new ArrayList<>(List.of("Marie Curie", "Kevin Mitnick", "Ada Lovelace"))); |
| 1166 | + flights.save(one); |
| 1167 | + Flight two = new Flight("2", 32000d, 72); |
| 1168 | + two.board(new ArrayList<>(List.of("Albert Einstein"))); |
| 1169 | + flights.save(two); |
| 1170 | + return flights; |
| 1171 | + } |
| 1172 | + |
| 1173 | + @Bean |
| 1174 | + RoleHierarchy roleHierarchy() { |
| 1175 | + return RoleHierarchyImpl.withRolePrefix("").role("airplane:read").implies("seating:read").build(); |
| 1176 | + } |
| 1177 | + |
| 1178 | + private static class NotValueReturnTypePointcut extends StaticMethodMatcherPointcut { |
| 1179 | + |
| 1180 | + @Override |
| 1181 | + public boolean matches(Method method, Class<?> targetClass) { |
| 1182 | + return !ClassUtils.isSimpleValueType(method.getReturnType()); |
| 1183 | + } |
| 1184 | + |
| 1185 | + } |
| 1186 | + |
| 1187 | + } |
| 1188 | + |
| 1189 | + @AuthorizeReturnObject |
| 1190 | + static class FlightRepository { |
| 1191 | + |
| 1192 | + private final Map<String, Flight> flights = new ConcurrentHashMap<>(); |
| 1193 | + |
| 1194 | + Iterator<Flight> findAll() { |
| 1195 | + return this.flights.values().iterator(); |
| 1196 | + } |
| 1197 | + |
| 1198 | + Flight findById(String id) { |
| 1199 | + return this.flights.get(id); |
| 1200 | + } |
| 1201 | + |
| 1202 | + Flight save(Flight flight) { |
| 1203 | + this.flights.put(flight.getId(), flight); |
| 1204 | + return flight; |
| 1205 | + } |
| 1206 | + |
| 1207 | + void remove(String id) { |
| 1208 | + this.flights.remove(id); |
| 1209 | + } |
| 1210 | + |
| 1211 | + } |
| 1212 | + |
| 1213 | + @AuthorizeReturnObject |
| 1214 | + static class Flight { |
| 1215 | + |
| 1216 | + private final String id; |
| 1217 | + |
| 1218 | + private final Double altitude; |
| 1219 | + |
| 1220 | + private final Integer seats; |
| 1221 | + |
| 1222 | + private final List<Passenger> passengers = new ArrayList<>(); |
| 1223 | + |
| 1224 | + Flight(String id, Double altitude, Integer seats) { |
| 1225 | + this.id = id; |
| 1226 | + this.altitude = altitude; |
| 1227 | + this.seats = seats; |
| 1228 | + } |
| 1229 | + |
| 1230 | + String getId() { |
| 1231 | + return this.id; |
| 1232 | + } |
| 1233 | + |
| 1234 | + @PreAuthorize("hasAuthority('airplane:read')") |
| 1235 | + Double getAltitude() { |
| 1236 | + return this.altitude; |
| 1237 | + } |
| 1238 | + |
| 1239 | + @PreAuthorize("hasAuthority('seating:read')") |
| 1240 | + Integer getSeats() { |
| 1241 | + return this.seats; |
| 1242 | + } |
| 1243 | + |
| 1244 | + @PostAuthorize("hasAuthority('seating:read')") |
| 1245 | + @PostFilter("filterObject.name != 'Kevin Mitnick'") |
| 1246 | + List<Passenger> getPassengers() { |
| 1247 | + return this.passengers; |
| 1248 | + } |
| 1249 | + |
| 1250 | + @PreAuthorize("hasAuthority('seating:read')") |
| 1251 | + @PreFilter("filterObject.contains(' ')") |
| 1252 | + void board(List<String> passengers) { |
| 1253 | + for (String passenger : passengers) { |
| 1254 | + this.passengers.add(new Passenger(passenger)); |
| 1255 | + } |
| 1256 | + } |
| 1257 | + |
| 1258 | + } |
| 1259 | + |
| 1260 | + public static class Passenger { |
| 1261 | + |
| 1262 | + String name; |
| 1263 | + |
| 1264 | + public Passenger(String name) { |
| 1265 | + this.name = name; |
| 1266 | + } |
| 1267 | + |
| 1268 | + @PreAuthorize("hasAuthority('airplane:read')") |
| 1269 | + public String getName() { |
| 1270 | + return this.name; |
| 1271 | + } |
| 1272 | + |
| 1273 | + } |
| 1274 | + |
1064 | 1275 | }
|
0 commit comments