-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add example and command line interface around it
- Loading branch information
1 parent
4b58738
commit 9ae9e78
Showing
10 changed files
with
194 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,58 @@ | ||
{-# LANGUAGE DeriveDataTypeable #-} | ||
{-# OPTIONS_GHC -fno-cse #-} | ||
module Main where | ||
|
||
import NeuralNet.Problem | ||
import NeuralNet.Net | ||
import NeuralNet.Activation | ||
import NeuralNet.Example | ||
import NeuralNet.External | ||
import System.IO () | ||
import System.Random | ||
import Data.List | ||
import System.Console.CmdArgs | ||
|
||
|
||
data RunOptions = RunOptions {trainPath :: FilePath | ||
,testPath :: FilePath | ||
,constInitWeights :: Bool | ||
,numIterations :: Int | ||
,learningRate :: Double} | ||
deriving (Eq, Show, Data, Typeable) | ||
|
||
optionsDef :: RunOptions | ||
optionsDef = RunOptions | ||
{trainPath = def &= typ "TRAINFILE" &= argPos 0 | ||
,testPath = def &= typ "TESTFILE" &= argPos 1 | ||
,constInitWeights = def &= name "c" &= name "const-init-weights" &= help "Use 0 for initial weights" | ||
,numIterations = 1000 &= name "i" &= name "num-iterations" &= help "Number of training iterations" | ||
,learningRate = 0.005 &= name "l" &= name "learning-rate" &= help "Learning rate (alpha)"} &= | ||
help "Simple Neural Net trainer" &= | ||
summary "NeuralNet v0.0.0, (C) Daniel Holmes" | ||
|
||
main :: IO () | ||
main = putStrLn "Hello" | ||
main = do | ||
options <- cmdArgs optionsDef | ||
problem <- loadProblem options | ||
gen <- createWeightInit options | ||
let (nn, steps) = runProblem gen problem (\yh y -> ((round yh) :: Int) == ((round y) :: Int)) | ||
putStrLn (intercalate "\n" (map formatStepLine steps)) | ||
print nn | ||
|
||
formatStepLine :: RunStep -> String | ||
formatStepLine s = show (runStepIteration s) ++ ") " ++ show (runStepCost s) ++ " " ++ show accuracy ++ "%" | ||
where accuracy = (round (100.0 * runStepAccuracy s)) :: Int | ||
|
||
createWeightInit :: RunOptions -> IO WeightInitialiser | ||
createWeightInit options = case constInitWeights options of | ||
True -> return (Const 0) | ||
False -> do | ||
stdGen <- getStdGen | ||
return (Random stdGen) | ||
|
||
loadProblem :: RunOptions -> IO Problem | ||
loadProblem (RunOptions {trainPath = train, testPath = test, numIterations = n, learningRate = a}) = do | ||
trainSet <- loadExampleSet train | ||
testSet <- loadExampleSet test | ||
let nnDef = createLogRegDefinition (exampleSetN testSet) Sigmoid | ||
return (createProblem nnDef trainSet testSet a n) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
1,2,1 | ||
0,4,0 | ||
1,3,1 | ||
1,0,1 | ||
0,1,0 | ||
1,5,1 | ||
1,58,1 | ||
0,72,0 | ||
0,11,0 | ||
1,6,1 | ||
1,100,1 | ||
1,7000,1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
1,2,1 | ||
1,3,1 | ||
0,2,0 | ||
1,5,1 | ||
0,7,0 | ||
1,5,1 | ||
0,6,0 | ||
1,88,1 | ||
0,99,0 | ||
0,11,0 | ||
1,23,1 | ||
0,2,0 | ||
1,1,1 | ||
0,1,0 | ||
1,100,1 | ||
1,6,1 | ||
0,7,0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
module NeuralNet.External (loadExampleSet) where | ||
|
||
import NeuralNet.Example | ||
import Text.CSV | ||
import System.FilePath.Posix | ||
import Data.Char | ||
|
||
|
||
loadExampleSet :: FilePath -> IO ExampleSet | ||
loadExampleSet p = | ||
case map toLower (takeExtension p) of | ||
".csv" -> loadCsvExampleSet p | ||
_ -> error ("Don't know how to load example set from " ++ show p) | ||
|
||
loadCsvExampleSet :: FilePath -> IO ExampleSet | ||
loadCsvExampleSet p = do | ||
result <- parseCSVFromFile p | ||
case result of | ||
Right c -> return (createExampleSet (csvToExamples c)) | ||
_ -> error ("Error reading" ++ p) | ||
|
||
csvToExamples :: CSV -> [Example] | ||
csvToExamples records = map (\r -> (init r, last r)) doubleRecords | ||
where doubleRecords = map (map read) records |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,62 +1,86 @@ | ||
module NeuralNet.Problem (Problem (), createProblem, runProblem) where | ||
module NeuralNet.Problem ( | ||
Problem (), | ||
RunStep (), | ||
createProblem, | ||
runProblem, | ||
runStepIteration, | ||
runStepAccuracy, | ||
runStepCost, | ||
problemTestSet | ||
) where | ||
|
||
import System.Random | ||
import NeuralNet.Example | ||
import NeuralNet.Net | ||
import NeuralNet.Layer | ||
import NeuralNet.Cost | ||
import NeuralNet.Train | ||
import Data.Matrix | ||
|
||
|
||
type NumIterations = Int | ||
|
||
type LearningRate = Double | ||
|
||
data Problem = Problem NeuralNetDefinition ExampleSet LearningRate NumIterations | ||
data Problem = Problem NeuralNetDefinition ExampleSet ExampleSet LearningRate NumIterations | ||
deriving (Show, Eq) | ||
|
||
|
||
type IterationNum = Int | ||
|
||
type Cost = Double | ||
type Accuracy = Double | ||
data RunStep = RunStep IterationNum Cost Accuracy | ||
deriving (Show, Eq) | ||
|
||
runStepIteration :: RunStep -> IterationNum | ||
runStepIteration (RunStep i _ _) = i | ||
|
||
data RunStep = RunStep IterationNum Cost | ||
runStepCost :: RunStep -> Cost | ||
runStepCost (RunStep _ c _) = c | ||
|
||
runStepAccuracy :: RunStep -> Accuracy | ||
runStepAccuracy (RunStep _ _ a) = a | ||
|
||
problemNNDef :: Problem -> NeuralNetDefinition | ||
problemNNDef (Problem d _ _ _) = d | ||
problemNNDef (Problem d _ _ _ _) = d | ||
|
||
problemTrainSet :: Problem -> ExampleSet | ||
problemTrainSet (Problem _ t _ _ _) = t | ||
|
||
problemExampleSet :: Problem -> ExampleSet | ||
problemExampleSet (Problem _ e _ _) = e | ||
problemTestSet :: Problem -> ExampleSet | ||
problemTestSet (Problem _ _ t _ _) = t | ||
|
||
problemLearningRate :: Problem -> LearningRate | ||
problemLearningRate (Problem _ _ l _) = l | ||
problemLearningRate (Problem _ _ _ l _) = l | ||
|
||
problemNumIterations :: Problem -> NumIterations | ||
problemNumIterations (Problem _ _ _ i) = i | ||
problemNumIterations (Problem _ _ _ _ i) = i | ||
|
||
createProblem :: NeuralNetDefinition -> ExampleSet -> LearningRate -> NumIterations -> Problem | ||
createProblem def examples learningRate numIterations | ||
| not (isExampleSetCompatibleWithNNDef examples def) = error "Examples not compatible with nn" | ||
createProblem :: NeuralNetDefinition -> ExampleSet -> ExampleSet -> LearningRate -> NumIterations -> Problem | ||
createProblem def trainSet testSet learningRate numIterations | ||
| not (isExampleSetCompatibleWithNNDef trainSet def) = error "trainSet not compatible with nn" | ||
| not (isExampleSetCompatibleWithNNDef testSet def) = error "testSet not compatible with nn" | ||
| numIterations <= 0 = error "Must provide positive numIterations" | ||
| learningRate <= 0 = error "Must provide positive learningRate" | ||
| otherwise = Problem def examples learningRate numIterations | ||
| otherwise = Problem def trainSet testSet learningRate numIterations | ||
|
||
runProblem :: StdGen -> Problem -> (NeuralNet, [RunStep]) | ||
runProblem g p = (resultNN, allSteps) | ||
runProblem :: WeightInitialiser -> Problem -> (Double -> Double -> Bool) -> (NeuralNet, [RunStep]) | ||
runProblem g p accuracyCheck = (resultNN, tail allSteps) | ||
where | ||
startNN = initNN g (problemNNDef p) | ||
startStep = RunStep 0 1 | ||
startStep = RunStep 0 1 0 | ||
iterations = [1..(problemNumIterations p)] | ||
allNNAndSteps = reverse (foldl (\steps@((nn,_):_) i -> runProblemStep p i nn : steps) [(startNN, startStep)] iterations) | ||
allNNAndSteps = reverse (foldl (\steps@((nn,_):_) i -> runProblemStep p i nn accuracyCheck : steps) [(startNN, startStep)] iterations) | ||
allSteps = map snd allNNAndSteps | ||
resultNN = fst (last allNNAndSteps) | ||
|
||
runProblemStep :: Problem -> IterationNum -> NeuralNet -> (NeuralNet, RunStep) | ||
runProblemStep p i nn = (newNN, RunStep i cost) | ||
runProblemStep :: Problem -> IterationNum -> NeuralNet -> (Double -> Double -> Bool) -> (NeuralNet, RunStep) | ||
runProblemStep p i nn accuracyCheck = (newNN, RunStep i cost accuracy) | ||
where | ||
exampleSet = problemExampleSet p | ||
exampleSet = problemTrainSet p | ||
forwardSteps = nnForwardSet nn exampleSet | ||
al = forwardPropA (last forwardSteps) | ||
cost = computeCost al (exampleSetY exampleSet) | ||
y = exampleSetY exampleSet | ||
cost = computeCost al y | ||
grads = nnBackward nn forwardSteps exampleSet | ||
newNN = updateNNParams nn grads (problemLearningRate p) | ||
accuracy = fromIntegral (length (filter (uncurry accuracyCheck) (zip (toList al) (toList y)))) / fromIntegral (ncols y) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters