diff --git a/fda-query-service/rest-service/src/main/java/at/tuwien/endpoint/AbstractEndpoint.java b/fda-query-service/rest-service/src/main/java/at/tuwien/endpoint/AbstractEndpoint.java index 8da23987b629668dc2cd1fa679775771c270bf7f..af0232f142d04469d981aedf100e7aa8b1e4ebde 100644 --- a/fda-query-service/rest-service/src/main/java/at/tuwien/endpoint/AbstractEndpoint.java +++ b/fda-query-service/rest-service/src/main/java/at/tuwien/endpoint/AbstractEndpoint.java @@ -1,21 +1,25 @@ package at.tuwien.endpoint; import at.tuwien.SortType; +import at.tuwien.api.database.query.ExecuteStatementDto; import at.tuwien.entities.database.Database; import at.tuwien.entities.database.table.Table; import at.tuwien.entities.identifier.Identifier; -import at.tuwien.exception.DatabaseNotFoundException; -import at.tuwien.exception.IdentifierNotFoundException; -import at.tuwien.exception.PaginationException; -import at.tuwien.exception.SortException; +import at.tuwien.exception.*; import at.tuwien.service.DatabaseService; import at.tuwien.service.IdentifierService; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.io.FileUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.core.Authentication; +import java.io.File; +import java.io.IOException; +import java.nio.charset.Charset; import java.security.Principal; import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import static at.tuwien.entities.identifier.VisibilityType.EVERYONE; @@ -99,6 +103,25 @@ public abstract class AbstractEndpoint { } } + protected void validateForbiddenStatements(ExecuteStatementDto data) throws QueryMalformedException, + QueryStoreException { + final StringBuilder regex = new StringBuilder("["); + try { + FileUtils.readLines(new File("src/main/resources/forbidden.txt"), Charset.defaultCharset()) + .forEach(regex::append); + } catch (IOException e) { + log.error("Failed to load forbidden keywords list, reason {}", e.getMessage()); + throw new QueryStoreException("Failed to load forbidden keywords list", e); + } + final Pattern pattern = Pattern.compile(regex + "]"); + final Matcher matcher = pattern.matcher(data.getStatement()); + final boolean found = matcher.find(); + if (found) { + log.error("Query contains blacklisted character"); + throw new QueryMalformedException("Query contains blacklisted character"); + } + } + protected Boolean hasQueuePermission(Long containerId, Long databaseId, Long tableId, String permissionCode, Principal principal) { log.trace("validate queue permission, containerId={}, databaseId={}, tableId={}, permissionCode={}, principal={}", diff --git a/fda-query-service/rest-service/src/main/java/at/tuwien/endpoint/QueryEndpoint.java b/fda-query-service/rest-service/src/main/java/at/tuwien/endpoint/QueryEndpoint.java index 0699f5b758213fc5e8cc60c280eed63fe16bae81..3abdccb05c4079688a4baa84a60dd778f1156d0d 100644 --- a/fda-query-service/rest-service/src/main/java/at/tuwien/endpoint/QueryEndpoint.java +++ b/fda-query-service/rest-service/src/main/java/at/tuwien/endpoint/QueryEndpoint.java @@ -20,6 +20,8 @@ import org.springframework.web.bind.annotation.*; import javax.validation.Valid; import javax.validation.constraints.NotNull; import java.security.Principal; +import java.util.regex.Matcher; +import java.util.regex.Pattern; @Log4j2 @RestController @@ -63,6 +65,7 @@ public class QueryEndpoint extends AbstractEndpoint { log.error("Failed to execute query: is empty"); throw new QueryMalformedException("Failed to execute query"); } + validateForbiddenStatements(data); validateDataParams(page, size, sortDirection, sortColumn); /* execute */ final QueryResultDto result = queryService.execute(containerId, databaseId, data, QueryTypeDto.QUERY, diff --git a/fda-query-service/rest-service/src/main/resources/forbidden.txt b/fda-query-service/rest-service/src/main/resources/forbidden.txt new file mode 100644 index 0000000000000000000000000000000000000000..89bdcae71069ee0e80035dbc2a0143d570ab253d --- /dev/null +++ b/fda-query-service/rest-service/src/main/resources/forbidden.txt @@ -0,0 +1 @@ + \* \ No newline at end of file diff --git a/fda-query-service/rest-service/src/test/java/at/tuwien/BaseUnitTest.java b/fda-query-service/rest-service/src/test/java/at/tuwien/BaseUnitTest.java index c54769404bc35fe1bfa03c472de3a0cfc84dda0c..34a86ce7d1fae92db962bb8fb341fbad56d50afc 100644 --- a/fda-query-service/rest-service/src/test/java/at/tuwien/BaseUnitTest.java +++ b/fda-query-service/rest-service/src/test/java/at/tuwien/BaseUnitTest.java @@ -2,6 +2,8 @@ package at.tuwien; import at.tuwien.api.database.query.QueryBriefDto; import at.tuwien.api.database.query.QueryDto; +import at.tuwien.api.database.query.QueryResultDto; +import at.tuwien.api.user.UserDetailsDto; import at.tuwien.api.user.UserDto; import at.tuwien.entities.container.image.ContainerImageDate; import at.tuwien.entities.database.table.columns.concepts.Concept; @@ -16,11 +18,17 @@ import at.tuwien.entities.database.Database; import at.tuwien.entities.database.table.Table; import at.tuwien.entities.database.table.columns.TableColumn; import at.tuwien.entities.database.table.columns.TableColumnType; +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.core.authority.SimpleGrantedAuthority; +import org.springframework.security.core.userdetails.UserDetails; import org.springframework.test.context.TestPropertySource; +import java.security.Principal; import java.time.Instant; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import static java.time.temporal.ChronoUnit.*; @@ -30,27 +38,41 @@ public abstract class BaseUnitTest { public final static long USER_1_ID = 1; public final static String USER_1_USERNAME = "junit"; public final static String USER_1_EMAIL = "junit@example.com"; + public final static String USER_1_PASSWORD = "password"; + public final static Instant USER_1_CREATED = Instant.now().minus(1, HOURS); + public final static User USER_1 = User.builder() .id(USER_1_ID) .username(USER_1_USERNAME) .email(USER_1_EMAIL) .emailVerified(true) .themeDark(false) - .password("password") + .password(USER_1_PASSWORD) .roles(Collections.singletonList(RoleType.ROLE_RESEARCHER)) .created(USER_1_CREATED) .lastModified(USER_1_CREATED) .build(); + public final static UserDto USER_1_DTO = UserDto.builder() .id(USER_1_ID) .username(USER_1_USERNAME) .email(USER_1_EMAIL) .emailVerified(true) .themeDark(false) - .password("password") + .password(USER_1_PASSWORD) .build(); + public final static UserDetails USER_1_DETAILS = UserDetailsDto.builder() + .username(USER_1_USERNAME) + .email(USER_1_EMAIL) + .password(USER_1_PASSWORD) + .authorities(List.of(new SimpleGrantedAuthority("ROLE_RESEARCHER"))) + .build(); + + public final static Principal USER_1_PRINCIPAL = new UsernamePasswordAuthenticationToken(USER_1_DETAILS, + USER_1_PASSWORD, USER_1_DETAILS.getAuthorities()); + public final static String DATABASE_NET = "fda-userdb"; public final static String BROKER_IMAGE = "fda-broker-service:latest"; @@ -1890,4 +1912,24 @@ public abstract class BaseUnitTest { .exchange(DATABASE_3_EXCHANGE) .build(); + public final static Long QUERY_1_RESULT_ID = 1L; + public final static Long QUERY_1_RESULT_NUMBER = 2L; + public final static List<Map<String, Object>> QUERY_1_RESULT_RESULT = List.of( + new HashMap<>() {{ + put("location", "Albury"); + put("lat", -36.0653583); + put("lng", 146.9112214); + }}, new HashMap<>() {{ + put("location", "Sydney"); + put("lat", -33.847927); + put("lng", 150.6517942); + }}); + + public final static QueryResultDto QUERY_1_RESULT_DTO = QueryResultDto.builder() + .id(QUERY_1_RESULT_ID) + .resultNumber(QUERY_1_RESULT_NUMBER) + .result(QUERY_1_RESULT_RESULT) + .build(); + + } diff --git a/fda-query-service/rest-service/src/test/java/at/tuwien/endpoint/QueryEndpointUnitTest.java b/fda-query-service/rest-service/src/test/java/at/tuwien/endpoint/QueryEndpointUnitTest.java index 7d53a289675c58b64d13628c7b8aca173d1920ad..86d07a94fa9b329b844cb7f32906f9d12bb51443 100644 --- a/fda-query-service/rest-service/src/test/java/at/tuwien/endpoint/QueryEndpointUnitTest.java +++ b/fda-query-service/rest-service/src/test/java/at/tuwien/endpoint/QueryEndpointUnitTest.java @@ -1,15 +1,34 @@ package at.tuwien.endpoint; import at.tuwien.BaseUnitTest; +import at.tuwien.SortType; +import at.tuwien.api.database.query.ExecuteStatementDto; +import at.tuwien.api.database.query.QueryResultDto; +import at.tuwien.api.database.query.QueryTypeDto; import at.tuwien.config.ReadyConfig; +import at.tuwien.exception.*; import at.tuwien.listener.impl.RabbitMqListenerImpl; +import at.tuwien.repository.jpa.ContainerRepository; +import at.tuwien.repository.jpa.DatabaseRepository; +import at.tuwien.repository.jpa.ImageRepository; +import at.tuwien.service.QueryService; +import at.tuwien.service.StoreService; import com.rabbitmq.client.Channel; import lombok.extern.log4j.Log4j2; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.mock.mockito.MockBean; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; import org.springframework.test.context.junit.jupiter.SpringExtension; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.when; + @Log4j2 @SpringBootTest @ExtendWith(SpringExtension.class) @@ -24,4 +43,75 @@ public class QueryEndpointUnitTest extends BaseUnitTest { @MockBean private RabbitMqListenerImpl rabbitMqListener; + @MockBean + private ImageRepository imageRepository; + + @MockBean + private ContainerRepository containerRepository; + + @MockBean + private DatabaseRepository databaseRepository; + + @MockBean + private QueryService queryService; + + @MockBean + private StoreService storeService; + + @Autowired + private QueryEndpoint queryEndpoint; + + @Test + public void execute_forbiddenKeyword_fails() throws UserNotFoundException, QueryStoreException, + TableMalformedException, DatabaseConnectionException, QueryMalformedException, ColumnParseException, + DatabaseNotFoundException, ImageNotSupportedException, ContainerNotFoundException { + final ExecuteStatementDto request = ExecuteStatementDto.builder() + .statement("SELECT w.* FROM `weather_aus` w") + .build(); + final Long page = 0L; + final Long size = 2L; + final SortType sortDirection = SortType.ASC; + final String sortColumn = "location"; + + /* mock */ + when(databaseRepository.findByContainerIdAndDatabaseId(CONTAINER_1_ID, DATABASE_1_ID)) + .thenReturn(Optional.of(DATABASE_1)); + when(queryService.execute(CONTAINER_1_ID, DATABASE_1_ID, request, QueryTypeDto.QUERY, + USER_1_PRINCIPAL, page, size, sortDirection, sortColumn)) + .thenReturn(QUERY_1_RESULT_DTO); + + /* test */ + assertThrows(QueryMalformedException.class, () -> { + queryEndpoint.execute(CONTAINER_1_ID, DATABASE_1_ID, request, page, size, USER_1_PRINCIPAL, sortDirection, + sortColumn); + }); + } + + @Test + public void execute_forbiddenKeyword2_fails() throws UserNotFoundException, QueryStoreException, + TableMalformedException, DatabaseConnectionException, QueryMalformedException, ColumnParseException, + DatabaseNotFoundException, ImageNotSupportedException, ContainerNotFoundException { + final ExecuteStatementDto request = ExecuteStatementDto.builder() + .statement("SELECT * FROM `weather_aus` w") + .build(); + final Long page = 0L; + final Long size = 2L; + final SortType sortDirection = SortType.ASC; + final String sortColumn = "location"; + + /* mock */ + when(databaseRepository.findByContainerIdAndDatabaseId(CONTAINER_1_ID, DATABASE_1_ID)) + .thenReturn(Optional.of(DATABASE_1)); + when(queryService.execute(CONTAINER_1_ID, DATABASE_1_ID, request, QueryTypeDto.QUERY, + USER_1_PRINCIPAL, page, size, sortDirection, sortColumn)) + .thenReturn(QUERY_1_RESULT_DTO); + + /* test */ + assertThrows(QueryMalformedException.class, () -> { + queryEndpoint.execute(CONTAINER_1_ID, DATABASE_1_ID, request, page, size, USER_1_PRINCIPAL, sortDirection, + sortColumn); + }); + } + + }