Skip to content

Commit

Permalink
test(gateway): make sure we are using the latest block in the statefu…
Browse files Browse the repository at this point in the history
…l validator
  • Loading branch information
yair-starkware committed Jul 9, 2024
1 parent 72dc6dc commit 1886f59
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 6 deletions.
2 changes: 1 addition & 1 deletion crates/gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ blockifier= { workspace = true , features = ["testing"] }
cairo-lang-starknet-classes.workspace = true
cairo-vm.workspace = true
hyper.workspace = true
mockall.workspace = true
num-traits.workspace = true
papyrus_config.workspace = true
papyrus_rpc.workspace = true
Expand All @@ -38,6 +37,7 @@ validator.workspace = true

[dev-dependencies]
assert_matches.workspace = true
mockall.workspace = true
num-bigint.workspace = true
pretty_assertions.workspace = true
rstest.workspace = true
Expand Down
3 changes: 3 additions & 0 deletions crates/gateway/src/state_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use blockifier::blockifier::block::BlockInfo;
use blockifier::execution::contract_class::ContractClass;
use blockifier::state::errors::StateError;
use blockifier::state::state_api::{StateReader as BlockifierStateReader, StateResult};
#[cfg(test)]
use mockall::automock;
use starknet_api::block::BlockNumber;
use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce};
use starknet_api::state::StorageKey;
Expand All @@ -11,6 +13,7 @@ pub trait MempoolStateReader: BlockifierStateReader + Send + Sync {
fn get_block_info(&self) -> Result<BlockInfo, StateError>;
}

#[cfg_attr(test, automock)]
pub trait StateReaderFactory: Send + Sync {
fn get_state_reader_from_latest_block(&self) -> Box<dyn MempoolStateReader>;
fn get_state_reader(&self, block_number: BlockNumber) -> Box<dyn MempoolStateReader>;
Expand Down
6 changes: 3 additions & 3 deletions crates/gateway/src/stateful_transaction_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use blockifier::execution::contract_class::ClassInfo;
use blockifier::state::cached_state::CachedState;
use blockifier::transaction::account_transaction::AccountTransaction;
use blockifier::versioned_constants::VersionedConstants;
use mockall::predicate::*;
use mockall::*;
#[cfg(test)]
use mockall::automock;
use starknet_api::rpc_transaction::RPCTransaction;
use starknet_api::transaction::TransactionHash;

Expand All @@ -26,7 +26,7 @@ pub struct StatefulTransactionValidator {

type BlockifierStatefulValidator = GenericBlockifierStatefulValidator<Box<dyn MempoolStateReader>>;

#[automock]
#[cfg_attr(test, automock)]
pub trait StatefulTransactionValidatorTrait {
fn perform_validations(
&mut self,
Expand Down
28 changes: 26 additions & 2 deletions crates/gateway/src/stateful_transaction_validator_test.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::sync::Arc;

use blockifier::blockifier::stateful_validator::StatefulValidatorError;
use blockifier::context::BlockContext;
use blockifier::test_utils::CairoVersion;
use blockifier::transaction::errors::{TransactionFeeError, TransactionPreValidationError};
use mockall::predicate::eq;
use num_bigint::BigUint;
use rstest::rstest;
use starknet_api::felt;
Expand All @@ -14,6 +17,7 @@ use test_utils::starknet_api_test_utils::{
use crate::compilation::compile_contract_class;
use crate::config::StatefulTransactionValidatorConfig;
use crate::errors::{StatefulTransactionValidatorError, StatefulTransactionValidatorResult};
use crate::state_reader::{MockStateReaderFactory, StateReaderFactory};
use crate::state_reader_test_utils::local_test_state_reader_factory;
use crate::stateful_transaction_validator::{
MockStatefulTransactionValidatorTrait, StatefulTransactionValidator,
Expand Down Expand Up @@ -73,7 +77,27 @@ fn test_stateful_tx_validator(

#[test]
fn test_instantiate_validator() {
let state_reader_factory = local_test_state_reader_factory(CairoVersion::Cairo1, false);
// Using Arc and cloning because mock requires moving the state_reader_factory, but we need it
// twice.
let state_reader_factory =
Arc::new(local_test_state_reader_factory(CairoVersion::Cairo1, false));
let latest_block = state_reader_factory.state_reader.block_info.block_number;

let mut mock_state_reader_factory = MockStateReaderFactory::new();

// Make sure stateful_validator uses the latest block in the initiall call.
let state_reader_factory_clone = state_reader_factory.clone();
mock_state_reader_factory
.expect_get_state_reader_from_latest_block()
.returning(move || state_reader_factory_clone.get_state_reader_from_latest_block());

// Make sure stateful_validator uses the latest block in the following calls to the
// state_reader.
mock_state_reader_factory
.expect_get_state_reader()
.with(eq(latest_block))
.returning(move |bn| state_reader_factory.get_state_reader(bn));

let block_context = &BlockContext::create_for_testing();
let stateful_validator = StatefulTransactionValidator {
config: StatefulTransactionValidatorConfig {
Expand All @@ -83,6 +107,6 @@ fn test_instantiate_validator() {
chain_info: block_context.chain_info().clone().into(),
},
};
let blockifier_validator = stateful_validator.instantiate_validator(&state_reader_factory);
let blockifier_validator = stateful_validator.instantiate_validator(&mock_state_reader_factory);
assert!(blockifier_validator.is_ok());
}

0 comments on commit 1886f59

Please sign in to comment.