Skip to content

Commit

Permalink
fixed eval func
Browse files Browse the repository at this point in the history
  • Loading branch information
cove9988 committed Jan 12, 2025
1 parent 0c3d2ce commit 2709024
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 50 deletions.
73 changes: 42 additions & 31 deletions src/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,50 @@
import pandas as pd
import time
import datetime
from stable_baselines3 import PPO
from src.ppo_model import ForexTradingEnv, load_data
import logging
logging.basicConfig(level=logging.CRITICAL, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
from stable_baselines3 import PPO
from src.ppo_model import ForexTradingEnv
from src.util.read_config import EnvConfig
from src.util.logger_config import setup_logging

# Flat evaluation script: replay a saved PPO model over up to `run_time` weekly
# CSV splits, render each finished episode and print a per-episode summary.
features = ['open', 'high', 'low', 'close', 'minute', 'hour', 'day', 'macd', 'boll_ub', 'boll_lb', 'rsi_30', 'dx_30', 'close_30_sma', 'close_60_sma']
model_file = '/home/paulg/github/tradesformer/data/model/AUDUSD/weekly/AUDUSD_2024_120.zip'
# csv_file = "/home/paulg/github/tradesformer/data/split/EURUSD/weekly/EURUSD_2024_103.csv"
data_directory = "/home/paulg/github/tradesformer/data/split/AUDUSD/weekly"
csv_files = glob.glob(os.path.join(data_directory, "*.csv"))
run_time = 10  # maximum number of CSV files to evaluate
_run = 1
for file in csv_files:
    if _run > run_time:
        break
    # Build a fresh environment from this CSV file and bind the saved model to it.
    env = ForexTradingEnv(file, features)
    model = PPO.load(model_file, env=env)
    observation, info = env.reset()
    done = False
    total_buy = 0
    total_sell = 0
    total_rewards = 0
    mode = 'graph'  # 'graph', 'human'
    while not done:
        action, _states = model.predict(observation)
        observation, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        # Accumulate the episode statistics; previously these counters were
        # initialised but never updated, so the summary always printed zeros.
        total_rewards += reward
        # Presumably 1 = buy and 2 = sell -- TODO confirm against the
        # environment's action encoding.
        if action == 1:
            total_buy += 1
        elif action == 2:
            total_sell += 1
    # NOTE(review): render is issued once per finished episode; confirm it was
    # not meant to run on every step.
    env.render(mode=mode)
    print(f'------rewards:{total_rewards}-----buy:{total_buy}--sell:{total_sell}------')
    _run += 1
logger = logging.getLogger(__name__)

def eval(data_directory, env_config_file, model_file, asset, run_time = 10, mode = 'graph', sequence_length=24):
    """Evaluate a saved PPO model on up to ``run_time`` CSV files of a directory.

    For every selected file a new ``ForexTradingEnv`` is created, the model is
    loaded against it, one episode is played until it terminates or is
    truncated, the episode is rendered and a summary line (total reward, buy
    count, sell count) is printed.

    Args:
        data_directory: Directory whose ``*.csv`` files are evaluated.
        env_config_file: Path passed to ``EnvConfig``; its ``observation_list``
            env parameter provides the feature list.
        model_file: Saved model path passed to ``PPO.load``.
        asset: Asset identifier forwarded to ``ForexTradingEnv``.
        run_time: Maximum number of CSV files to evaluate.
        mode: Render mode forwarded to ``env.render``.
        sequence_length: Forwarded to ``ForexTradingEnv``.
    """
    # glob yields files in arbitrary order; sort so the files that fall under
    # the run_time cap are the same on every run.
    csv_files = sorted(glob.glob(os.path.join(data_directory, "*.csv")))
    cf = EnvConfig(env_config_file)
    features = cf.env_parameters("observation_list")
    print(features)

    for run_index, file in enumerate(csv_files, start=1):
        if run_index > run_time:
            break
        # Build a fresh environment from this CSV file and bind the model to it.
        env = ForexTradingEnv(file, cf, asset, features, sequence_length)
        model = PPO.load(model_file, env=env)
        observation, info = env.reset()
        done = False
        total_buy = 0
        total_sell = 0
        total_rewards = 0
        step = 0
        while not done:
            action, _states = model.predict(observation)
            observation, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            print(f'step:{step} rwd:{reward} action:{action} ')
            step += 1
            total_rewards += reward
            # Action 2 used to be added to the buy counter, leaving the sell
            # counter at zero. Presumably 1 = buy and 2 = sell -- TODO confirm
            # against the environment's action encoding.
            if action == 1:
                total_buy += 1
            elif action == 2:
                total_sell += 1
        # NOTE(review): render is issued once per finished episode; confirm it
        # was not meant to run on every step.
        env.render(mode = mode)
        print(f'------rewards:{total_rewards}-----buy:{total_buy}--sell:{total_sell}------')

# %%
if __name__ == "__main__":
    # Script entry point: evaluate the saved weekly AUDUSD model on the weekly
    # AUDUSD splits, using paths inside the author's local checkout.
    asset = "AUDUSD"
    data_directory = "/home/paulg/github/tradesformer/data/split/AUDUSD/weekly"
    model_file = '/home/paulg/github/tradesformer/data/model/AUDUSD/weekly/AUDUSD_2024_120.zip'
    env_config_file = '/home/paulg/github/tradesformer/src/configure.json'

    # Logging is configured before the evaluation starts: ERROR level for the
    # console handler, INFO level for the file handler.
    setup_logging(asset=asset, console_level=logging.ERROR, file_level=logging.INFO)

    # Evaluate at most five CSV files and render each episode in 'graph' mode.
    eval(
        data_directory,
        env_config_file,
        model_file,
        asset,
        run_time=5,
        mode='graph',
        sequence_length=24,
    )
24 changes: 6 additions & 18 deletions src/ppo_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def __init__(self, file, cf, asset, features, sequence_length = 24, logger_show
self.balance = self.balance_initial
self.positions = []
self.ticket_id = 0
self.total_rewards = 0
self.action_aggregator =ActionAggregator()
# self.reward_calculator = RewardCalculator(
# self.data, cf, self.shaping_reward, self.stop_loss, self.profit_taken, self.backward_window
Expand All @@ -182,6 +183,7 @@ def reset(self, seed=None, options=None):
self.balance = self.balance_initial
self.positions = []
self.ticket_id = 0
self.total_rewards = 0
self.action_aggregator =ActionAggregator()

# self.current_step = np.random.randint(self.sequence_length, self.max_steps)
Expand Down Expand Up @@ -282,10 +284,11 @@ def step(self, action):
position_reward, closed = self._calculate_reward(position)
if not closed: open_positon += 1
reward += position_reward

self.total_rewards += position_reward
# Execute action
_action, stability_reward = self.action_aggregator.add_action(action)
reward += stability_reward
self.total_rewards += stability_reward
# logger.info(f"Step:{self.current_step}: action: {action}, real: {_action} stability reward:{stability_reward} ")
if _action in (1, 2) and open_positon < self.max_current_holding :
self.ticket_id += 1
Expand Down Expand Up @@ -319,7 +322,7 @@ def step(self, action):
logger.info(f"Step:{self.current_step} Tkt:{position['Ticket']} {position['Type']} Rwd:{position['pips']} SL:{position['SL']} PT:{position['PT']}")
else:
reward -= 1 # no open any position, encourage open position

self.total_rewards -= 1
# Move to the next time step
self.current_step += 1

Expand All @@ -337,7 +340,7 @@ def step(self, action):
if position["Type"] == "Buy":
buy +=1

logger.info(f'--- Position:{len(self.positions)}/Buy:{buy} Balance: {self.balance} step {self.current_step }')
logger.info(f'--- Position:{len(self.positions)}/Buy:{buy} TtlRwds: {self.total_rewards} Balance: {self.balance} step {self.current_step }')
# Additional info
info = {}
truncated = False
Expand Down Expand Up @@ -367,18 +370,3 @@ def render(self, mode='human', title=None, **kwargs):
p.plot()


def eval(model_file,env):
    """Play one evaluation episode of a saved PPO model on ``env`` and render it.

    Args:
        model_file: Saved model path passed to ``PPO.load``.
        env: Environment to evaluate on; it is reset, stepped until the
            episode ends and then rendered.
    """
    model = PPO.load(model_file, env=env)
    observation, info = env.reset()
    done = False

    while not done:
        action, _states = model.predict(observation)
        observation, reward, terminated, truncated, info = env.step(action)
        # An episode is over on either signal; checking only `terminated`
        # would keep stepping an episode that was truncated.
        done = terminated or truncated
    # NOTE(review): render is issued once per finished episode; confirm it was
    # not meant to run on every step.
    env.render()

    # Nothing is saved here, so the former "Model saved" message was
    # misleading; report what actually happened instead.
    logger.info("Evaluation finished for model '%s'", model_file)

4 changes: 3 additions & 1 deletion src/util/logger_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,6 @@ def setup_logging(asset, console_level=logging.INFO, file_level=logging.DEBUG):
logger.addHandler(file_handler)

# Optional: Avoid duplicate logs in libraries by disabling propagation
logging.getLogger("matplotlib").setLevel(logging.WARNING) # Example: Suppress matplotlib logs
logging.getLogger("matplotlib").setLevel(logging.ERROR) # Suppress matplotlib logs
logging.getLogger("torch").setLevel(logging.ERROR)
logging.getLogger("mplfinance").setLevel(logging.ERROR)

0 comments on commit 2709024

Please sign in to comment.