Skip to content
This repository has been archived by the owner on Jun 29, 2019. It is now read-only.

Commit

Permalink
add at_q_b (tested)
Browse files Browse the repository at this point in the history
  • Loading branch information
lindahua committed Feb 7, 2013
1 parent 35a3ade commit a0e9ae0
Show file tree
Hide file tree
Showing 3 changed files with 504 additions and 17 deletions.
202 changes: 185 additions & 17 deletions src/Metrics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@ export
# distance classes
Euclidean,
SqEuclidean,
Chebyshev,
Cityblock,
Chebyshev,
Minkowski,
Mahalanobis,

Hamming,
CosineDist,
Expand All @@ -22,11 +21,19 @@ export
KLDivergence,
JSDivergence,

WeightedEuclidean,
WeightedSqEuclidean,
WeightedCityblock,
WeightedMinkowski,
WeightedHamming,
SqMahalanobis,
Mahalanobis,

# convenient functions
euclidean,
sqeuclidean,
chebyshev,
cityblock,
chebyshev,
minkowski,
mahalanobis,

Expand All @@ -36,14 +43,27 @@ export
chisq_dist,
kl_divergence,
js_divergence,

weighted_euclidean,
weighted_sqeuclidean,
weighted_cityblock,
weighted_minkowski,
weighted_hamming,
sqmahalanobis,
mahalanobis,

# generic functions
result_type,
colwise,
pairwise,
evaluate
evaluate,

# other convenient functions
At_Q_B, At_Q_A, A_Q_Bt, A_Q_At


include("at_q_b.jl")

###########################################################
#
# Metric types
Expand Down Expand Up @@ -71,6 +91,7 @@ abstract SemiMetric <: PreMetric
#
abstract Metric <: SemiMetric


type Euclidean <: Metric end
type SqEuclidean <: SemiMetric end
type Chebyshev <: Metric end
Expand All @@ -88,6 +109,34 @@ type ChiSqDist <: SemiMetric end
type KLDivergence <: PreMetric end
type JSDivergence <: SemiMetric end

type WeightedEuclidean{T<:FloatingPoint} <: Metric
weights::Vector{T}
end

type WeightedSqEuclidean{T<:FloatingPoint} <: SemiMetric
weights::Vector{T}
end

type WeightedCityblock{T<:FloatingPoint} <: Metric
weights::Vector{T}
end

type WeightedMinkowski{T<:FloatingPoint} <: Metric
p::Real
weights::Vector{T}
end

type WeightedHamming{T<:FloatingPoint} <: Metric
weights::Vector{T}
end

type Mahalanobis{T} <: Metric
qmat::Matrix{T}
end

type SqMahalanobis{T} <: SemiMetric
qmat::Matrix{T}
end


###########################################################
Expand All @@ -99,6 +148,15 @@ type JSDivergence <: SemiMetric end
result_type(::PreMetric, T1::Type, T2::Type) = promote_type(T1, T2)
result_type(::Hamming, T1::Type, T2::Type) = Int

result_type{T}(::WeightedEuclidean{T}, T1::Type, T2::Type) = T
result_type{T}(::WeightedSqEuclidean{T}, T1::Type, T2::Type) = T
result_type{T}(::WeightedCityblock{T}, T1::Type, T2::Type) = T
result_type{T}(::WeightedMinkowski{T}, T1::Type, T2::Type) = T
result_type{T}(::WeightedHamming{T}, T1::Type, T2::Type) = T

result_type{T}(::Mahalanobis{T}, T1::Type, T2::Type) = T
result_type{T}(::SqMahalanobis{T}, T1::Type, T2::Type) = T


###########################################################
#
Expand All @@ -123,7 +181,7 @@ function get_colwise_dims(r::Array, a::Matrix, b::Matrix)
throw(ArgumentError("The sizes of a and b must match."))
end
if length(r) != size(a, 2)
throw(ArgumentError("Invalid size of r."))
throw(ArgumentError("Incorrect size of r."))
end
return size(a)
end
Expand All @@ -133,7 +191,7 @@ function get_colwise_dims(r::Array, a::Vector, b::Matrix)
throw(ArgumentError("The length of a must match the number of rows in b."))
end
if length(r) != size(b, 2)
throw(ArgumentError("Invalid size of r."))
throw(ArgumentError("Incorrect size of r."))
end
return size(b)
end
Expand All @@ -143,7 +201,7 @@ function get_colwise_dims(r::Array, a::Matrix, b::Vector)
throw(ArgumentError("The length of b must match the number of rows in a."))
end
if length(r) != size(a, 2)
throw(ArgumentError("Invalid size of r."))
throw(ArgumentError("Incorrect size of r."))
end
return size(a)
end
Expand All @@ -155,20 +213,76 @@ function get_pairwise_dims(r::Matrix, a::Matrix, b::Matrix)
throw(ArgumentError("The numbers of rows in a and b must match."))
end
if !(size(r) == (na, nb))
throw(ArgumentError("Invalid size of r."))
throw(ArgumentError("Incorrect size of r."))
end
return (ma, na, nb)
end

function get_pairwise_dims(r::Matrix, a::Matrix)
m, n = size(a)
if !(size(r) == (n, n))
throw(ArgumentError("Invalid size of r."))
throw(ArgumentError("Incorrect size of r."))
end
return (m, n)
end


# for weighted metrics

function get_colwise_dims(d::Int, r::Array, a::Matrix, b::Matrix)
if !(size(a, 1) == size(b, 1) == d)
throw(ArgumentError("Incorrect vector dimensions."))
end
if length(r) != size(a, 2)
throw(ArgumentError("Incorrect size of r."))
end
return size(a)
end

function get_colwise_dims(d::Int, r::Array, a::Vector, b::Matrix)
if !(length(a) == size(b, 1) == d)
throw(ArgumentError("Incorrect vector dimensions."))
end
if length(r) != size(b, 2)
throw(ArgumentError("Incorrect size of r."))
end
return size(b)
end

function get_colwise_dims(d::Int, r::Array, a::Matrix, b::Vector)
if !(size(a, 1) == length(b) == d)
throw(ArgumentError("Incorrect vector dimensions."))
end
if length(r) != size(a, 2)
throw(ArgumentError("Incorrect size of r."))
end
return size(a)
end

function get_pairwise_dims(d::Int, r::Matrix, a::Matrix, b::Matrix)
na = size(a, 2)
nb = size(b, 2)
if !(size(a, 1) == size(b, 1) == d)
throw(ArgumentError("Incorrect vector dimensions."))
end
if !(size(r) == (na, nb))
throw(ArgumentError("Incorrect size of r."))
end
return (d, na, nb)
end

function get_pairwise_dims(d::Int, r::Matrix, a::Matrix)
n = size(a, 2)
if !(size(a, 1) == size(b, 1) == d)
throw(ArgumentError("Incorrect vector dimensions."))
end
if !(size(r) == (n, n))
throw(ArgumentError("Incorrect size of r."))
end
return (d, n)
end



###########################################################
#
Expand Down Expand Up @@ -308,34 +422,33 @@ end
sqeuclidean(a::Vector, b::Vector) = evaluate(SqEuclidean(), a, b)

function colwise!(r::Array, dist::SqEuclidean, a::Matrix, b::Matrix)
get_colwise_dims(r, a, b)
@devec r[:] = sum(sqr(a - b), 1)
end

function colwise!(r::Array, dist::SqEuclidean, a::Vector, b::Matrix)
for j = 1 : size(b, 2)
n = get_colwise_dims(r, a, b)
for j = 1 : n
@devec r[j] = sum(sqr(a - b[:,j]))
end
end

function pairwise!(r::Matrix, dist::SqEuclidean, a::Matrix, b::Matrix)
m, na, nb = get_pairwise_dims(r, a, b)
At_mul_B(r, a, b)
@devec sa2 = sum(sqr(a), 1)
@devec sb2 = sum(sqr(b), 1)

m = size(a, 2)
n = size(b, 2)
for j = 1 : n
for i = 1 : m
for j = 1 : nb
for i = 1 : na
r[i,j] = sa2[i] + sb2[j] - 2 * r[i,j]
end
end
end

function pairwise!(r::Matrix, dist::SqEuclidean, a::Matrix)
m, n = get_pairwise_dims(r, a)
At_mul_B(r, a, a)
@devec sa2 = sum(sqr(a), 1)

n = size(a, 2)
for j = 1 : n
for i = 1 : j-1
r[i,j] = r[j,i]
Expand Down Expand Up @@ -921,6 +1034,61 @@ function pairwise!(r::Matrix, dist::JSDivergence, a::Matrix)
end


# Weighted squared Euclidean

function evaluate{T<:FloatingPoint}(dist::WeightedSqEuclidean{T}, a::Vector, b::Vector)
w = dist.weights
@devec r = sum(sqr(a - b) .* w)
return r
end

weighted_sqeuclidean(a::Vector, b::Vector, w::Vector) = evaluate(WeightedSqEuclidean(w), a, b)

function colwise!{T<:FloatingPoint}(r::Array, dist::WeightedSqEuclidean{T}, a::Matrix, b::Matrix)
get_colwise_dims(r, a, b)
w = dist.weights
for j = 1 : n
@devec r[j] = sum(sqr(a[:,j] - b[:,j]) .* w)
end
end

function colwise!{T<:FloatingPoint}(r::Array, dist::WeightedSqEuclidean{T}, a::Vector, b::Matrix)
n = get_colwise_dims(r, a, b)
w = dist.weights
for j = 1 : n
@devec r[j] = sum(sqr(a - b[:,j]) .* w)
end
end

function pairwise!{T<:FloatingPoint}(r::Matrix, dist::WeightedSqEuclidean{T}, a::Matrix, b::Matrix)
m, na, nb = get_pairwise_dims(r, a, b)
At_mul_B(r, a, b)
@devec sa2 = sum(sqr(a), 1)
@devec sb2 = sum(sqr(b), 1)
for j = 1 : nb
for i = 1 : na
r[i,j] = sa2[i] + sb2[j] - 2 * r[i,j]
end
end
end

function pairwise!{T<:FloatingPoint}(r::Matrix, dist::WeightedSqEuclidean{T}, a::Matrix)
m, n = get_pairwise_dims(r, a)
At_mul_B(r, a, a)
@devec sa2 = sum(sqr(a), 1)
for j = 1 : n
for i = 1 : j-1
r[i,j] = r[j,i]
end
r[j,j] = 0
for i = j+1 : n
r[i,j] = sa2[i] + sa2[j] - 2 * r[i,j]
end
end
end



end # module end


Loading

0 comments on commit a0e9ae0

Please sign in to comment.