Skip to content

Commit

Permalink
Add pretrained version of DenseRAE
Browse files Browse the repository at this point in the history
  • Loading branch information
ProfFan committed Nov 29, 2020
1 parent 5a0ee08 commit fbe71f0
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 4 deletions.
6 changes: 3 additions & 3 deletions Scripts/Fan10.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ struct Fan10: ParsableCommand {

func getTrainingDataEM(
from dataset: OISTBeeVideo,
numberForeground: Int = 3000,
numberBackground: Int = 3000
numberForeground: Int = 300,
numberBackground: Int = 300
) -> [LikelihoodModel.Datum] {
let bgBoxes = dataset.makeBackgroundBoundingBoxes(patchSize: (40, 70), batchSize: numberBackground).map {
(frame: $0.frame, type: LikelihoodModel.PatchType.bg, obb: $0.obb)
Expand Down Expand Up @@ -58,7 +58,7 @@ struct Fan10: ParsableCommand {
let generator = ARC4RandomNumberGenerator(seed: 42)
var em = MonteCarloEM<LikelihoodModel>(sourceOfEntropy: generator)

let trainingDataset = OISTBeeVideo(directory: dataDir, length: 100)!
let trainingDataset = OISTBeeVideo(directory: dataDir, length: 30)!

let trainingData = getTrainingDataEM(from: trainingDataset)

Expand Down
2 changes: 1 addition & 1 deletion Scripts/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import PenguinParallelWithFoundation

struct Scripts: ParsableCommand {
static var configuration = CommandConfiguration(
subcommands: [Fan01.self, Fan02.self, Fan03.self, Fan04.self, Fan05.self,
subcommands: [Fan01.self, Fan02.self, Fan03.self, Fan04.self, Fan05.self, Fan10.self,
Frank01.self, Frank02.self, Frank03.self, Frank04.self])
}

Expand Down
46 changes: 46 additions & 0 deletions Sources/BeeTracking/AppearanceRAE.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import SwiftFusion
import TensorFlow
import PythonKit

// MARK: - The Regularized Autoencoder model.

Expand Down Expand Up @@ -300,3 +301,48 @@ public struct DenseRAELoss {
return totalLoss
}
}

/// Pretrained version of DenseRAE. Note this is created because Swift does not allow inheritance of structs.
public struct PretrainedDenseRAE: AppearanceModelEncoder {
public var inner: DenseRAE

/// The constructor that only does loading of the pretrained weights.
public init(from imageBatch: Tensor<Double>, given: HyperParameters?) {
let shape = imageBatch.shape
precondition(imageBatch.rank == 4, "Wrong image shape \(shape)")
let (_, H_, W_, C_) = (shape[0], shape[1], shape[2], shape[3])

if let params = given {
var encoder = DenseRAE(
imageHeight: H_, imageWidth: W_, imageChannels: C_,
hiddenDimension: params.hiddenDimension, latentDimension: params.latentDimension
)

let np = Python.import("numpy")

encoder.load(weights: np.load(params.weightFile, allow_pickle: true))
inner = encoder
} else {
inner = DenseRAE(
imageHeight: H_, imageWidth: W_, imageChannels: C_,
hiddenDimension: 1, latentDimension: 1
)
fatalError("Must provide hyperparameters to pretrained network")
}
}

/// Constructor that does training of the network
public init(trainFrom imageBatch: Tensor<Double>, given: HyperParameters?) {
inner = DenseRAE(
from: imageBatch, given: (given != nil) ? (hiddenDimension: given!.hiddenDimension, latentDimension: given!.latentDimension) : nil
)
}

@differentiable
public func encode(_ imageBatch: Tensor<Double>) -> Tensor<Double> {
inner.encode(imageBatch)
}

/// Initialize given an image batch
public typealias HyperParameters = (hiddenDimension: Int, latentDimension: Int, weightFile: String)
}

0 comments on commit fbe71f0

Please sign in to comment.