-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathManifold.swift
202 lines (184 loc) · 8.95 KB
/
Manifold.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import _Differentiation
/// A point on a differentiable manifold with a `retract` map centered around `self`.
///
/// This protocol helps you define manifolds with custom tangent vectors. Instructions:
/// 1. Define a type `C: ManifoldCoordinate`, without specifying a `TangentVector`. (Swift
///    generates a `TangentVector` automatically, usually not the `TangentVector` that you want for
///    your manifold).
/// 2. Define `C.LocalCoordinate` to be the `TangentVector` type that you want for your manifold,
///    and define `C.retract` and `C.localCoordinate` to be the retraction and inverse retraction
///    for this `TangentVector`.
/// 3. Define a type `M: Manifold` that wraps `C`. The `Manifold` protocol automatically gives `M`
///    the desired `TangentVector`.
/// See "SwiftFusion/doc/DifferentiableManifoldRecipe.md" for more detailed instructions.
public protocol ManifoldCoordinate: Differentiable {
/// The local coordinate type of the manifold.
///
/// This is the `TangentVector` of the `Manifold` wrapper type.
///
/// Note that this is not the same type as `Self.TangentVector`.
associatedtype LocalCoordinate: Vector
/// Diffeomorphism between a neighborhood of `LocalCoordinate.zero` and `Self`.
///
/// Satisfies the following properties:
/// - `retract(LocalCoordinate.zero) == self`
/// - There exists an open set `B` around `LocalCoordinate.zero` such that
///   `localCoordinate(retract(b)) == b` for all `b \in B`.
///
/// Differentiable with respect to `local` only; `self` is the fixed center of the chart.
///
/// - Parameter local: A local coordinate, expressed in the chart centered at `self`.
/// - Returns: The point on the manifold that `local` maps to.
@differentiable(wrt: local)
func retract(_ local: LocalCoordinate) -> Self
/// Inverse of `retract`.
///
/// Satisfies the following properties:
/// - `localCoordinate(self) == LocalCoordinate.zero`
/// - There exists an open set `B` around `self` such that `retract(localCoordinate(b)) == b` for
///   all `b \in B`.
///
/// Differentiable with respect to `global` only; `self` is the fixed center of the chart.
///
/// - Parameter global: A point on the manifold.
/// - Returns: The local coordinate of `global` in the chart centered at `self`.
@differentiable(wrt: global)
func localCoordinate(_ global: Self) -> LocalCoordinate
}
/// A point on a differentiable manifold.
///
/// A conforming type wraps a `Coordinate` value; implementers supply the two `coordinateStorage`
/// requirements below.
public protocol Manifold: Differentiable {
/// The manifold's global coordinate system.
associatedtype Coordinate: ManifoldCoordinate
/// The coordinate of `self`.
///
/// Note: The distinction between `coordinateStorage` and `coordinate` is a workaround until we
/// can define default derivatives for protocol requirements (TF-982). Until then, implementers
/// of this protocol must define `coordinateStorage`, and clients of this protocol must access
/// `coordinate`. This allows us to define default derivatives for `coordinate` that translate
/// between the `ManifoldCoordinate` tangent space and the `Manifold` tangent space.
var coordinateStorage: Coordinate { get set }
/// Creates a manifold point with coordinate `coordinateStorage`.
///
/// Note: The distinction between `init(coordinateStorage:)` and `init(coordinate:)` is a
/// workaround until we can define default derivatives for protocol requirements (TF-982). Until
/// then, implementers of this protocol must define `init(coordinateStorage:)`, and clients of
/// this protocol must call `init(coordinate:)`. This allows us to define default derivatives for
/// `init(coordinate:)` that translate between the `ManifoldCoordinate` tangent space and the
/// `Manifold` tangent space.
///
/// - Parameter coordinateStorage: The global coordinate of the new manifold point.
init(coordinateStorage: Coordinate)
}
/// Methods for converting between manifolds and their coordinates.
///
/// To enable these, you must explicitly write
///
///     public typealias TangentVector = <local coordinate type>
///
/// in your manifold type.
extension Manifold where Self.TangentVector == Coordinate.LocalCoordinate {
  /// The coordinate of `self`.
  ///
  /// The value is `coordinateStorage`. Its derivative is supplied by `vjpCoordinate`, so that
  /// tangent vectors of `self` are expressed in local coordinates.
  @differentiable
  public var coordinate: Coordinate {
    return coordinateStorage
  }

  /// A custom derivative of `coordinate` that converts from the global coordinate system's
  /// tangent vector to the local coordinate system's tangent vector, so that all functions on this
  /// manifold using `coordinate` have derivatives involving local coordinates.
  ///
  /// - Returns: The coordinate of `self`, and a pullback mapping `Coordinate.TangentVector`
  ///   cotangents to `TangentVector` (local coordinate) cotangents.
  @derivative(of: coordinate)
  @usableFromInline
  func vjpCoordinate()
    -> (value: Coordinate, pullback: (Coordinate.TangentVector) -> TangentVector)
  {
    // Explanation of this pullback:
    //
    // Let `f: Manifold -> Coordinate` be `f(x) = x.coordinateStorage`.
    //
    // `differential(at: x, in: f)` is a linear approximation of how changes in tangent vectors
    // around `x` lead to changes in global coordinates around `x.coordinateStorage`.
    //
    // `x.coordinateStorage.retract: TangentVector -> Coordinate` defines _exactly_ how local
    // coordinates around zero map to global coordinates around `x`.
    //
    // Therefore,
    // `differential(at: x, in: f) = differential(at: zero, in: x.coordinateStorage.retract)`.
    //
    // The pullback is the dual map of the differential, so taking duals of both sides gives:
    // `pullback(at: x, in: f) = pullback(at: zero, in: x.coordinateStorage.retract)`.
    return (
      value: coordinateStorage,
      pullback: pullback(at: zeroTangentVector) { self.coordinateStorage.retract($0) }
    )
  }

  /// Creates a manifold point with coordinate `coordinate`.
  ///
  /// Forwards to `init(coordinateStorage:)`. Its derivative is supplied by `vjpInit`, so that
  /// tangent vectors of the result are expressed in local coordinates.
  ///
  /// - Parameter coordinate: The global coordinate of the new manifold point.
  @differentiable
  public init(coordinate: Coordinate) { self.init(coordinateStorage: coordinate) }

  /// A custom derivative of `init(coordinate:)` that converts from the local coordinate system's
  /// tangent vector to the global coordinate system's tangent vector, so that all functions
  /// producing instances of this manifold using `init(coordinate:)` have derivatives involving
  /// local coordinates.
  ///
  /// - Parameter coordinate: The global coordinate of the new manifold point.
  /// - Returns: The new manifold point, and a pullback mapping `TangentVector` (local coordinate)
  ///   cotangents to `Coordinate.TangentVector` cotangents.
  @derivative(of: init(coordinate:))
  @usableFromInline
  static func vjpInit(coordinate: Coordinate)
    -> (value: Self, pullback: (TangentVector) -> Coordinate.TangentVector)
  {
    // Explanation of this pullback:
    //
    // Let `g: Coordinate -> Manifold` be `g(x) = Self(coordinateStorage: x)`.
    //
    // `differential(at: x, in: g)` is a linear approximation of how changes in global
    // coordinates around `x` lead to changes in local coordinates around
    // `Self(coordinateStorage: x)`.
    //
    // `x.localCoordinate: Coordinate -> LocalCoordinate` defines _exactly_ how global
    // coordinates around `x` map to local coordinates.
    //
    // Therefore, `differential(at: x, in: g) = differential(at: x, in: x.localCoordinate)`.
    //
    // The pullback is the dual map of the differential, so taking duals of both sides gives:
    // `pullback(at: x, in: g) = pullback(at: x, in: x.localCoordinate)`.
    return (
      value: Self(coordinateStorage: coordinate),
      pullback: pullback(at: coordinate) { coordinate.localCoordinate($0) }
    )
  }
}
/// Manifold operations provided for free, each one forwarding to the matching
/// `ManifoldCoordinate` operation on the wrapped coordinate.
extension Manifold where Self.TangentVector == Coordinate.LocalCoordinate {
  /// Maps a neighborhood of `TangentVector.zero` diffeomorphically onto a neighborhood of `self`.
  ///
  /// Properties:
  /// - `retract(TangentVector.zero) == self`
  /// - For some open set `B` containing `TangentVector.zero`,
  ///   `localCoordinate(retract(b)) == b` holds for every `b \in B`.
  ///
  /// - Parameter local: A tangent vector, expressed in the chart centered at `self`.
  /// - Returns: The manifold point that `local` maps to.
  @differentiable(wrt: local)
  public func retract(_ local: TangentVector) -> Self {
    let retracted = coordinate.retract(local)
    return Self(coordinate: retracted)
  }

  /// Hand-written derivative of `retract` with respect to `local`.
  ///
  /// Swift AD could derive this, but mathematically the derivative is the identity, so the
  /// pullback simply hands its argument back, which is cheaper.
  @derivative(of: retract, wrt: local)
  @usableFromInline
  func vjpRetract(_ local: TangentVector) -> (value: Self, pullback: (TangentVector) -> TangentVector) {
    let point = retract(local)
    return (value: point, pullback: { cotangent in cotangent })
  }

  /// The inverse of `retract`.
  ///
  /// Properties:
  /// - `localCoordinate(self) == TangentVector.zero`
  /// - For some open set `B` containing `self`, `retract(localCoordinate(b)) == b` holds for
  ///   every `b \in B`.
  ///
  /// - Parameter global: A point on the manifold.
  /// - Returns: The tangent vector locating `global` in the chart centered at `self`.
  @differentiable(wrt: global)
  public func localCoordinate(_ global: Self) -> TangentVector {
    let target = global.coordinate
    return coordinate.localCoordinate(target)
  }

  /// Hand-written derivative of `localCoordinate` with respect to `global`.
  ///
  /// Swift AD could derive this, but mathematically the derivative is the identity, so the
  /// pullback simply hands its argument back, which is cheaper.
  @derivative(of: localCoordinate, wrt: global)
  @usableFromInline
  func vjpLocalCoordinate(_ global: Self)
    -> (value: TangentVector, pullback: (TangentVector) -> TangentVector)
  {
    let local = localCoordinate(global)
    return (value: local, pullback: { cotangent in cotangent })
  }
}