Skip to content

Commit

Permalink
Problem setup
Browse files Browse the repository at this point in the history
  • Loading branch information
danielholmes committed Feb 17, 2018
1 parent ee25de3 commit 4b58738
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 9 deletions.
6 changes: 3 additions & 3 deletions src/NeuralNet/Cost.hs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
module NeuralNet.Cost (cost) where
module NeuralNet.Cost (computeCost) where

import Data.Matrix
import NeuralNet.Matrix


cost :: Matrix Double -> Matrix Double -> Double
cost al y
computeCost :: Matrix Double -> Matrix Double -> Double
computeCost al y
| nrows al /= 1 = error "al should be a row vector"
| nrows y /= 1 = error "y should be a row vector"
| ncols y /= ncols al = error "y and al should be same length"
Expand Down
42 changes: 38 additions & 4 deletions src/NeuralNet/Problem.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
module NeuralNet.Problem (Problem (), createProblem, runProblem) where

import System.Random
import NeuralNet.Example
import NeuralNet.Net
import NeuralNet.Layer
import NeuralNet.Cost
import NeuralNet.Train


type NumIterations = Int
Expand All @@ -15,14 +19,44 @@ type IterationNum = Int

type Cost = Double

type Accuracy = Double
data RunStep = RunStep IterationNum Cost

data RunStep = RunStep IterationNum Cost Accuracy

problemNNDef :: Problem -> NeuralNetDefinition
problemNNDef (Problem d _ _ _) = d

problemExampleSet :: Problem -> ExampleSet
problemExampleSet (Problem _ e _ _) = e

problemLearningRate :: Problem -> LearningRate
problemLearningRate (Problem _ _ l _) = l

problemNumIterations :: Problem -> NumIterations
problemNumIterations (Problem _ _ _ i) = i

createProblem :: NeuralNetDefinition -> ExampleSet -> LearningRate -> NumIterations -> Problem
createProblem def examples learningRate numIterations
| not (isExampleSetCompatibleWithNNDef examples def) = error "Examples not compatible with nn"
| numIterations <= 0 = error "Must provide positive numIterations"
| learningRate <= 0 = error "Must provide positive learningRate"
| otherwise = Problem def examples learningRate numIterations

runProblem :: Problem -> (Double -> Double) -> (NeuralNet, [RunStep])
runProblem _ _ = (error "TODO", [RunStep 1 1 1])
runProblem :: StdGen -> Problem -> (NeuralNet, [RunStep])
runProblem g p = (resultNN, allSteps)
where
startNN = initNN g (problemNNDef p)
startStep = RunStep 0 1
iterations = [1..(problemNumIterations p)]
allNNAndSteps = reverse (foldl (\steps@((nn,_):_) i -> runProblemStep p i nn : steps) [(startNN, startStep)] iterations)
allSteps = map snd allNNAndSteps
resultNN = fst (last allNNAndSteps)

runProblemStep :: Problem -> IterationNum -> NeuralNet -> (NeuralNet, RunStep)
runProblemStep p i nn = (newNN, RunStep i cost)
where
exampleSet = problemExampleSet p
forwardSteps = nnForwardSet nn exampleSet
al = forwardPropA (last forwardSteps)
cost = computeCost al (exampleSetY exampleSet)
grads = nnBackward nn forwardSteps exampleSet
newNN = updateNNParams nn grads (problemLearningRate p)
4 changes: 2 additions & 2 deletions test/NeuralNet/CostSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import NeuralNet.Cost
costSpec :: SpecWith ()
costSpec =
describe "NeuralNet.Cost" $
describe "cost" $ do
describe "computeCost" $ do
it "calculates correctly for example" $
let
al = fromLists [[0.8, 0.9, 0.4]]
y = fromLists [[1, 1, 1]]
in cost al y `shouldBe` 0.41493159961539694
in computeCost al y `shouldBe` 0.41493159961539694

0 comments on commit 4b58738

Please sign in to comment.