Skip to content

Commit

Permalink
Keras model evaluate.
Browse files Browse the repository at this point in the history
  • Loading branch information
Oceania2018 committed Nov 14, 2020
1 parent 9d08a31 commit d59db72
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public TensorLikeDataAdapter(TensorLikeDataAdapterArgs args)
this.args = args;
_process_tensorlike();
num_samples = args.X.shape[0];
var batch_size = args.BatchSize;
var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
_batch_size = batch_size;
_size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0f)));
num_full_batches = num_samples / batch_size;
Expand Down
84 changes: 84 additions & 0 deletions src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine.DataAdapters;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
{
public partial class Model
{
/// <summary>
/// Returns the loss value & metrics values for the model in test mode.
/// </summary>
/// <param name="x"></param>
/// <param name="y"></param>
/// <param name="batch_size"></param>
/// <param name="verbose"></param>
/// <param name="steps"></param>
/// <param name="max_queue_size"></param>
/// <param name="workers"></param>
/// <param name="use_multiprocessing"></param>
/// <param name="return_dict"></param>
public void evaluate(NDArray x, NDArray y,
int batch_size = -1,
int verbose = 1,
int steps = -1,
int max_queue_size = 10,
int workers = 1,
bool use_multiprocessing = false,
bool return_dict = false)
{
data_handler = new DataHandler(new DataHandlerArgs
{
X = x,
Y = y,
BatchSize = batch_size,
StepsPerEpoch = steps,
InitialEpoch = 0,
Epochs = 1,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
Model = this,
StepsPerExecution = _steps_per_execution
});

Console.WriteLine($"Testing...");
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
// reset_metrics();
// callbacks.on_epoch_begin(epoch)
// data_handler.catch_stop_iteration();
IEnumerable<(string, Tensor)> results = null;
foreach (var step in data_handler.steps())
{
// callbacks.on_train_batch_begin(step)
results = test_function(iterator);
}
Console.WriteLine($"iterator: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
}
}

IEnumerable<(string, Tensor)> test_function(OwnedIterator iterator)
{
var data = iterator.next();
var outputs = test_step(data[0], data[1]);
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
return outputs;
}

List<(string, Tensor)> test_step(Tensor x, Tensor y)
{
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
var y_pred = Apply(x, is_training: false);
var loss = compiled_loss.Call(y, y_pred);

compiled_metrics.update_state(y, y_pred);

return metrics.Select(x => (x.Name, x.result())).ToList();
}
}
}
1 change: 1 addition & 0 deletions src/TensorFlowNET.Keras/Engine/Model.Fit.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public void fit(NDArray x, NDArray y,
stop_training = false;
_train_counter.assign(0);
bool first_step = true;
Console.WriteLine($"Training...");
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
// reset_metrics();
Expand Down
13 changes: 13 additions & 0 deletions src/TensorFlowNET.Keras/Engine/Model.Save.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using System.Collections.Generic;
using Tensorflow.Keras.Metrics;

namespace Tensorflow.Keras.Engine
{
public partial class Model
{
public void save(string path)
{

}
}
}
2 changes: 2 additions & 0 deletions tensorflowlib/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ PM> Install-Package TensorFlow.NET
PM> Install-Package SciSharp.TensorFlow.Redist
```

Add `<RuntimeIdentifier>win-x64</RuntimeIdentifier>` to a `PropertyGroup` in your `.csproj` when targeting `.NET 472`.

### Run in Linux

Download Linux pre-built library and unzip `libtensorflow.so` and `libtensorflow_framework.so` into current running directory.
Expand Down

0 comments on commit d59db72

Please sign in to comment.