Skip to content

Commit

Permalink
Fix race condition in SnowflakeDatabaseMetaData.java show command (#1212
Browse files Browse the repository at this point in the history
)

* fix show command race condition

* add test for race condition

* add test for race condition

* remove showCommand from code
  • Loading branch information
sfc-gh-ext-simba-lb authored Dec 16, 2022
1 parent 729a2ec commit fcdae49
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,6 @@ public class SnowflakeDatabaseMetaData implements DatabaseMetaData {

private boolean stringsQuoted = false;

private String showCommand;

// Package-private function for displaying show command (for testing only)
String getShowCommand() {
return showCommand;
}

SnowflakeDatabaseMetaData(Connection connection) throws SQLException {
logger.debug("public SnowflakeDatabaseMetaData(SnowflakeConnection connection)", false);

Expand Down Expand Up @@ -1418,29 +1411,29 @@ public ResultSet getTables(
final Pattern compiledSchemaPattern = Wildcard.toRegexPattern(schemaPattern, true);
final Pattern compiledTablePattern = Wildcard.toRegexPattern(tableNamePattern, true);

showCommand = null;
String showTablesCommand = null;
final boolean viewOnly =
inputValidTableTypes.size() == 1 && "VIEW".equalsIgnoreCase(inputValidTableTypes.get(0));
final boolean tableOnly =
inputValidTableTypes.size() == 1 && "TABLE".equalsIgnoreCase(inputValidTableTypes.get(0));
if (viewOnly) {
showCommand = "show /* JDBC:DatabaseMetaData.getTables() */ views";
showTablesCommand = "show /* JDBC:DatabaseMetaData.getTables() */ views";
} else if (tableOnly) {
showCommand = "show /* JDBC:DatabaseMetaData.getTables() */ tables";
showTablesCommand = "show /* JDBC:DatabaseMetaData.getTables() */ tables";
} else {
showCommand = "show /* JDBC:DatabaseMetaData.getTables() */ objects";
showTablesCommand = "show /* JDBC:DatabaseMetaData.getTables() */ objects";
}

// only add pattern if it is not empty and not matching all character.
if (tableNamePattern != null
&& !tableNamePattern.isEmpty()
&& !tableNamePattern.trim().equals("%")
&& !tableNamePattern.trim().equals(".*")) {
showCommand += " like '" + tableNamePattern + "'";
showTablesCommand += " like '" + tableNamePattern + "'";
}

if (catalog == null) {
showCommand += " in account";
showTablesCommand += " in account";
} else if (catalog.isEmpty()) {
return SnowflakeDatabaseMetaDataResultSet.getEmptyResultSet(GET_TABLES, statement);
} else {
Expand All @@ -1450,18 +1443,18 @@ public ResultSet getTables(
// a schema if the current schema a user is connected to is different
// given that we don't support show tables without a known schema.
if (schemaPattern == null || isSchemaNameWildcardPattern(schemaPattern)) {
showCommand += " in database \"" + catalogEscaped + "\"";
showTablesCommand += " in database \"" + catalogEscaped + "\"";
} else if (schemaPattern.isEmpty()) {
return SnowflakeDatabaseMetaDataResultSet.getEmptyResultSet(GET_TABLES, statement);
} else {
String schemaUnescaped = unescapeChars(schemaPattern);
showCommand += " in schema \"" + catalogEscaped + "\".\"" + schemaUnescaped + "\"";
showTablesCommand += " in schema \"" + catalogEscaped + "\".\"" + schemaUnescaped + "\"";
}
}

logger.debug("sql command to get table metadata: {}", showCommand);
logger.debug("sql command to get table metadata: {}", showTablesCommand);

resultSet = executeAndReturnEmptyResultIfNotFound(statement, showCommand, GET_TABLES);
resultSet = executeAndReturnEmptyResultIfNotFound(statement, showTablesCommand, GET_TABLES);
sendInBandTelemetryMetadataMetrics(
resultSet,
"getTables",
Expand Down Expand Up @@ -1608,37 +1601,37 @@ public ResultSet getColumns(
final Pattern compiledTablePattern = Wildcard.toRegexPattern(tableNamePattern, true);
final Pattern compiledColumnPattern = Wildcard.toRegexPattern(columnNamePattern, true);

showCommand = "show /* JDBC:DatabaseMetaData.getColumns() */ columns";
String showColumnsCommand = "show /* JDBC:DatabaseMetaData.getColumns() */ columns";

if (columnNamePattern != null
&& !columnNamePattern.isEmpty()
&& !columnNamePattern.trim().equals("%")
&& !columnNamePattern.trim().equals(".*")) {
showCommand += " like '" + columnNamePattern + "'";
showColumnsCommand += " like '" + columnNamePattern + "'";
}

if (catalog == null) {
showCommand += " in account";
showColumnsCommand += " in account";
} else if (catalog.isEmpty()) {
return SnowflakeDatabaseMetaDataResultSet.getEmptyResultSet(
extendedSet ? GET_COLUMNS_EXTENDED_SET : GET_COLUMNS, statement);
} else {
String catalogEscaped = escapeSqlQuotes(catalog);
if (schemaPattern == null || isSchemaNameWildcardPattern(schemaPattern)) {
showCommand += " in database \"" + catalogEscaped + "\"";
showColumnsCommand += " in database \"" + catalogEscaped + "\"";
} else if (schemaPattern.isEmpty()) {
return SnowflakeDatabaseMetaDataResultSet.getEmptyResultSet(
extendedSet ? GET_COLUMNS_EXTENDED_SET : GET_COLUMNS, statement);
} else {
String schemaUnescaped = unescapeChars(schemaPattern);
if (tableNamePattern == null || Wildcard.isWildcardPatternStr(tableNamePattern)) {
showCommand += " in schema \"" + catalogEscaped + "\".\"" + schemaUnescaped + "\"";
showColumnsCommand += " in schema \"" + catalogEscaped + "\".\"" + schemaUnescaped + "\"";
} else if (tableNamePattern.isEmpty()) {
return SnowflakeDatabaseMetaDataResultSet.getEmptyResultSet(
extendedSet ? GET_COLUMNS_EXTENDED_SET : GET_COLUMNS, statement);
} else {
String tableNameUnescaped = unescapeChars(tableNamePattern);
showCommand +=
showColumnsCommand +=
" in table \""
+ catalogEscaped
+ "\".\""
Expand All @@ -1650,11 +1643,11 @@ public ResultSet getColumns(
}
}

logger.debug("sql command to get column metadata: {}", showCommand);
logger.debug("sql command to get column metadata: {}", showColumnsCommand);

ResultSet resultSet =
executeAndReturnEmptyResultIfNotFound(
statement, showCommand, extendedSet ? GET_COLUMNS_EXTENDED_SET : GET_COLUMNS);
statement, showColumnsCommand, extendedSet ? GET_COLUMNS_EXTENDED_SET : GET_COLUMNS);
sendInBandTelemetryMetadataMetrics(
resultSet,
"getColumns",
Expand Down Expand Up @@ -2617,37 +2610,37 @@ public ResultSet getStreams(
final Pattern compiledSchemaPattern = Wildcard.toRegexPattern(schemaPattern, true);
final Pattern compiledStreamNamePattern = Wildcard.toRegexPattern(streamName, true);

String showCommand = "show streams";
String showStreamsCommand = "show streams";

if (streamName != null
&& !streamName.isEmpty()
&& !streamName.trim().equals("%")
&& !streamName.trim().equals(".*")) {
showCommand += " like '" + streamName + "'";
showStreamsCommand += " like '" + streamName + "'";
}

if (catalog == null) {
showCommand += " in account";
showStreamsCommand += " in account";
} else if (catalog.isEmpty()) {
return SnowflakeDatabaseMetaDataResultSet.getEmptyResultSet(GET_STREAMS, statement);
} else {
String catalogEscaped = escapeSqlQuotes(catalog);
if (schemaPattern == null || isSchemaNameWildcardPattern(schemaPattern)) {
showCommand += " in database \"" + catalogEscaped + "\"";
showStreamsCommand += " in database \"" + catalogEscaped + "\"";
} else if (schemaPattern.isEmpty()) {
return SnowflakeDatabaseMetaDataResultSet.getEmptyResultSet(GET_STREAMS, statement);
} else {
String schemaUnescaped = unescapeChars(schemaPattern);
if (streamName == null || Wildcard.isWildcardPatternStr(streamName)) {
showCommand += " in schema \"" + catalogEscaped + "\".\"" + schemaUnescaped + "\"";
showStreamsCommand += " in schema \"" + catalogEscaped + "\".\"" + schemaUnescaped + "\"";
}
}
}

logger.debug("sql command to get stream metadata: {}", showCommand);
logger.debug("sql command to get stream metadata: {}", showStreamsCommand);

ResultSet resultSet =
executeAndReturnEmptyResultIfNotFound(statement, showCommand, GET_STREAMS);
executeAndReturnEmptyResultIfNotFound(statement, showStreamsCommand, GET_STREAMS);
sendInBandTelemetryMetadataMetrics(
resultSet, "getStreams", originalCatalog, originalSchemaPattern, streamName, "none");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
import static org.junit.Assert.assertFalse;

import java.sql.*;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import net.snowflake.client.ConditionalIgnoreRule;
import net.snowflake.client.RunningOnGithubAction;
import net.snowflake.client.category.TestCategoryOthers;
Expand Down Expand Up @@ -52,19 +58,16 @@ public void testGetMetaDataUseConnectionCtx() throws SQLException {

// Searches for tables only in database JDBC_DB1 and schema JDBC_SCHEMA11
ResultSet resultSet = databaseMetaData.getTables(null, null, null, null);
// Assert the show command scopes to schema level
assertEquals(
"show /* JDBC:DatabaseMetaData.getTables() */ objects in schema \"JDBC_DB1\".\"JDBC_SCHEMA11\"",
databaseMetaData.unwrap(SnowflakeDatabaseMetaData.class).getShowCommand());
assertEquals(1, getSizeOfResultSet(resultSet));

// Assert the tables are retrieved at schema level
resultSet.next();
assertEquals("JDBC_DB1", resultSet.getString(1));
assertEquals("JDBC_SCHEMA11", resultSet.getString(2));
// Searches for tables only in database JDBC_DB1 and schema JDBC_SCHEMA11
resultSet = databaseMetaData.getColumns(null, null, null, null);
// Assert the show command scopes to schema level
assertEquals(
"show /* JDBC:DatabaseMetaData.getColumns() */ columns in schema \"JDBC_DB1\".\"JDBC_SCHEMA11\"",
databaseMetaData.unwrap(SnowflakeDatabaseMetaData.class).getShowCommand());
assertEquals(3, getSizeOfResultSet(resultSet));
// Assert the columns are retrieved at schema level
resultSet.next();
assertEquals("JDBC_DB1", resultSet.getString(1));
assertEquals("JDBC_SCHEMA11", resultSet.getString(2));
}

@Test
Expand Down Expand Up @@ -179,4 +182,33 @@ public void testGetFunctionColumns() throws SQLException {
resultSet = databaseMetaData.getFunctionColumns("%", "%", "%", "%");
assertEquals(0, getSizeOfResultSet(resultSet));
}

/** Tests that calling getTables() concurrently doesn't cause data race condition. */
@Test
@ConditionalIgnoreRule.ConditionalIgnore(condition = RunningOnGithubAction.class)
public void testGetTablesRaceCondition()
throws SQLException, ExecutionException, InterruptedException {
try (Connection connection = getConnection()) {
String database = connection.getCatalog();
String schema = connection.getSchema();
DatabaseMetaData databaseMetaData = connection.getMetaData();

// Create 10 threads, each calls getTables() concurrently
ExecutorService executorService = Executors.newFixedThreadPool(10);
List<Future<?>> futures = new ArrayList<>();
for (int i = 0; i < 10; i++) {
futures.add(
executorService.submit(
() -> {
try {
databaseMetaData.getTables(database, schema, null, null);
} catch (SQLException e) {
throw new RuntimeException(e);
}
}));
}
executorService.shutdown();
for (int i = 0; i < 10; i++) futures.get(i).get();
}
}
}

0 comments on commit fcdae49

Please sign in to comment.