Skip to content

Commit

Permalink
Added different sorting rule for classification solver
Browse files Browse the repository at this point in the history
  • Loading branch information
dferens committed May 19, 2015
1 parent ca7e05c commit 42009b2
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions src/clj/asols/solver.clj
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,25 @@
[ret# elapsed#]))

(defprotocol SolverProtocol
(get-metrics [this net entries]))
(get-metrics [this net entries])
(sort-cases [this cases]))

(defrecord ClassificationSolver [train-opts mutation-opts]
SolverProtocol
(get-metrics [_ net entries]
(trainer/calc-ca net entries)))
(trainer/calc-ca net entries))
(sort-cases [_ cases]
(sort-by
(fn [case]
[(- 1.0 (:train-metrics case))
(:train-cost case)])
cases)))

(defrecord RegressionSolver [train-opts mutation-opts]
SolverProtocol
(get-metrics [_ _ _] nil))
(get-metrics [_ _ _] nil)
(sort-cases [_ cases]
(sort-by :train-cost cases)))

(defn- converged?
[_ solving]
Expand Down Expand Up @@ -100,12 +109,9 @@
(defn- make-combined-cases
[solver net cases]
(let [select-count 5
{base-cost :train-cost} (first (for [{m :mutation :as case} cases
:when (= ::m/identity (:operation m))]
case))
selected-cases (->> cases
(filter #(< (:train-cost %) base-cost))
(sort-by :train-cost)
(sort-cases solver)
(take-while #(not= ::m/identity (:operation (:mutation %))))
(take select-count))]
(for [select-count (range 2 (inc (count selected-cases)))
:let [merge-cases (take select-count selected-cases)
Expand Down Expand Up @@ -137,7 +143,7 @@
progress-chan (make-progress-chan out-chan (count mutations))
[cases ms-took] (solve-mutations solver net mutations tpool progress-chan)
all-cases (concat cases (make-combined-cases solver net cases))
[best-case & other-cases] (sort-by :train-cost all-cases)]
[best-case & other-cases] (sort-cases solver all-cases)]
(make-solving solver net best-case other-cases ms-took))
(catch InterruptedException _
(debug "Detected thread interrupt"))))
Expand Down

0 comments on commit 42009b2

Please sign in to comment.