Skip to content

Commit

Permalink
Rotate quick restore QR code and web socket.
Browse files Browse the repository at this point in the history
  • Loading branch information
cody-signal authored and greyson-signal committed Dec 12, 2024
1 parent 57502fb commit 2eabf03
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,12 @@ private fun RestoreViaQrScreen(
) {
AnimatedContent(
targetState = state.qrState,
contentKey = { it::class },
contentAlignment = Alignment.Center,
label = "qr-code-progress"
label = "qr-code-progress",
modifier = Modifier
.fillMaxWidth()
.fillMaxHeight()
) { qrState ->
when (qrState) {
is RestoreViaQrViewModel.QrState.Loaded -> {
Expand All @@ -184,7 +188,9 @@ private fun RestoreViaQrScreen(
}

RestoreViaQrViewModel.QrState.Loading -> {
CircularProgressIndicator(modifier = Modifier.size(48.dp))
Box(contentAlignment = Alignment.Center) {
CircularProgressIndicator(modifier = Modifier.size(48.dp))
}
}

is RestoreViaQrViewModel.QrState.Scanned,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@
package org.thoughtcrime.securesms.registrationv3.ui.restore

import androidx.lifecycle.ViewModel
import kotlinx.coroutines.CoroutineExceptionHandler
import androidx.lifecycle.viewModelScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import org.signal.core.util.logging.Log
import org.signal.registration.proto.RegistrationProvisionMessage
import org.thoughtcrime.securesms.backup.v2.MessageBackupTier
Expand All @@ -31,15 +36,30 @@ class RestoreViaQrViewModel : ViewModel() {

val state: StateFlow<RestoreViaQrState> = store

private var socketHandle: Closeable
private var socketHandles: MutableList<Closeable> = mutableListOf()
private var startNewSocketJob: Job? = null

init {
socketHandle = start()
restart()
}

fun restart() {
socketHandle.close()
socketHandle = start()
SignalStore.registration.restoreMethodToken = null
shutdown()

startNewSocket()

startNewSocketJob = viewModelScope.launch(Dispatchers.IO) {
var count = 0
while (count < 5 && isActive) {
delay(ProvisioningSocket.LIFESPAN / 2)
if (isActive) {
startNewSocket()
count++
Log.d(TAG, "Started next websocket count: $count")
}
}
}
}

fun handleRegistrationFailure() {
Expand All @@ -61,20 +81,66 @@ class RestoreViaQrViewModel : ViewModel() {
}

override fun onCleared() {
socketHandle.close()
shutdown()
}

private fun startNewSocket() {
synchronized(socketHandles) {
socketHandles += start()

if (socketHandles.size > 2) {
socketHandles.removeAt(0).close()
}
}
}

private fun shutdown() {
startNewSocketJob?.cancel()
synchronized(socketHandles) {
socketHandles.forEach { it.close() }
socketHandles.clear()
}
}

private fun start(): Closeable {
SignalStore.registration.restoreMethodToken = null
store.update { it.copy(qrState = QrState.Loading) }
store.update {
if (it.qrState !is QrState.Loaded) {
it.copy(qrState = QrState.Loading)
} else {
it
}
}

return ProvisioningSocket.start(
identityKeyPair = IdentityKeyUtil.generateIdentityKeyPair(),
configuration = AppDependencies.signalServiceNetworkAccess.getConfiguration(),
handler = CoroutineExceptionHandler { _, _ -> store.update { it.copy(qrState = QrState.Failed) } }
handler = { id, t ->
store.update {
if (it.currentSocketId == null || it.currentSocketId == id) {
Log.w(TAG, "Current socket [$id] has failed, stopping automatic connects", t)
shutdown()
it.copy(currentSocketId = null, qrState = QrState.Failed)
} else {
Log.i(TAG, "Old socket [$id] failed, ignoring")
it
}
}
}
) { socket ->
val url = socket.getProvisioningUrl()
store.update { it.copy(qrState = QrState.Loaded(qrData = QrCodeData.forData(data = url, supportIconOverlay = false))) }
store.update {
Log.d(TAG, "Updating QR code with data from [${socket.id}]")

it.copy(
currentSocketId = socket.id,
qrState = QrState.Loaded(
qrData = QrCodeData.forData(
data = url,
supportIconOverlay = false
)
)
)
}

val result = socket.getRegistrationProvisioningMessage()

Expand All @@ -94,8 +160,15 @@ class RestoreViaQrViewModel : ViewModel() {
SignalStore.backup.usedBackupMediaSpace = result.message.backupSizeBytes
}
store.update { it.copy(isRegistering = true, provisioningMessage = result.message, qrState = QrState.Scanned) }
shutdown()
} else {
store.update { it.copy(showProvisioningError = true, qrState = QrState.Scanned) }
store.update {
if (it.currentSocketId == socket.id) {
it.copy(showProvisioningError = true, qrState = QrState.Scanned)
} else {
it
}
}
}
}
}
Expand All @@ -105,7 +178,8 @@ class RestoreViaQrViewModel : ViewModel() {
val qrState: QrState = QrState.Loading,
val provisioningMessage: RegistrationProvisionMessage? = null,
val showProvisioningError: Boolean = false,
val showRegistrationError: Boolean = false
val showRegistrationError: Boolean = false,
val currentSocketId: Int? = null
)

sealed interface QrState {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,38 +44,44 @@ import kotlin.time.Duration.Companion.seconds
* A provisional web socket for communicating with a primary device during registration.
*/
class ProvisioningSocket private constructor(
val id: Int,
identityKeyPair: IdentityKeyPair,
configuration: SignalServiceConfiguration,
private val scope: CoroutineScope
) {
companion object {
private val TAG = Log.tag(ProvisioningSocket::class)

@Volatile private var nextSocketId = 1000

val LIFESPAN = 90.seconds

fun start(
identityKeyPair: IdentityKeyPair,
configuration: SignalServiceConfiguration,
handler: CoroutineExceptionHandler,
handler: ProvisioningSocketExceptionHandler,
block: suspend CoroutineScope.(ProvisioningSocket) -> Unit
): Closeable {
val scope = CoroutineScope(Dispatchers.IO) + SupervisorJob() + handler
val socketId = nextSocketId++
val scope = CoroutineScope(Dispatchers.IO) + SupervisorJob() + CoroutineExceptionHandler { _, t -> handler.handleException(socketId, t) }

scope.launch {
var socket: ProvisioningSocket? = null
try {
socket = ProvisioningSocket(identityKeyPair, configuration, scope)
socket = ProvisioningSocket(socketId, identityKeyPair, configuration, scope)
socket.connect()
block(socket)
} catch (e: CancellationException) {
val rootCause = e.getRootCause()
if (rootCause == null) {
Log.i(TAG, "Scope canceled expectedly, fail silently, ${e.toMinimalString()}")
Log.i(TAG, "[$socketId] Scope canceled expectedly, fail silently, ${e.toMinimalString()}")
throw e
} else {
Log.w(TAG, "Unable to maintain web socket, ${rootCause.toMinimalString()}", rootCause)
Log.w(TAG, "[$socketId] Unable to maintain web socket, ${rootCause.toMinimalString()}", rootCause)
throw rootCause
}
} finally {
Log.d(TAG, "Closing web socket")
Log.d(TAG, "[$socketId] Closing web socket")
socket?.close()
}
}
Expand Down Expand Up @@ -144,42 +150,50 @@ class ProvisioningSocket private constructor(
private var lastKeepAliveId: Long = 0

override fun onOpen(webSocket: WebSocket, response: Response) {
Log.d(TAG, "[onOpen]")
Log.d(TAG, "[$id] [onOpen]")
keepAliveJob = scope.launch { keepAlive(webSocket) }

val timeoutJob = scope.launch {
delay(10.seconds)
scope.cancel("Did not receive device id within 10 seconds", SocketTimeoutException("No device id received"))
}

val webSocketExpireJob = scope.launch {
delay(LIFESPAN)
scope.cancel("Did not complete a registration within ${LIFESPAN.inWholeSeconds} seconds", SocketTimeoutException("No provisioning message received"))
}

scope.launch {
provisioningUrlDeferral.await()
timeoutJob.cancel()

provisioningMessageDeferral.await()
webSocketExpireJob.cancel()
}
}

override fun onMessage(webSocket: WebSocket, bytes: ByteString) {
val message: WebSocketMessage = WebSocketMessage.ADAPTER.decode(bytes)

if (message.response != null && message.response.id == lastKeepAliveId) {
Log.d(TAG, "[onMessage] Keep alive received")
Log.d(TAG, "[$id] [onMessage] Keep alive received")
return
}

if (message.request == null) {
Log.w(TAG, "[onMessage] Received null request")
Log.w(TAG, "[$id] [onMessage] Received null request")
return
}

val success = webSocket.send(message.request.toResponse().encode().toByteString())

if (!success) {
Log.w(TAG, "[onMessage] Failed to send response")
Log.w(TAG, "[$id] [onMessage] Failed to send response")
webSocket.close(1000, "OK")
return
}

Log.d(TAG, "[onMessage] Processing request")
Log.d(TAG, "[$id] [onMessage] Processing request")

if (message.request.verb == "PUT" && message.request.body != null) {
when (message.request.path) {
Expand All @@ -197,19 +211,19 @@ class ProvisioningSocket private constructor(
provisioningMessageDeferral.complete(result)
}

else -> Log.w(TAG, "Unknown path requested")
else -> Log.w(TAG, "[$id] Unknown path requested")
}
} else {
Log.w(TAG, "Invalid data")
Log.w(TAG, "[$id] Invalid data")
}
}

override fun onClosing(webSocket: WebSocket, code: Int, reason: String) {
scope.launch {
Log.i(TAG, "[onClosing] code: $code reason: $reason")
Log.i(TAG, "[$id] [onClosing] code: $code reason: $reason")

if (code != 1000) {
Log.w(TAG, "Remote side is closing with non-normal code $code")
Log.w(TAG, "[$id] Remote side is closing with non-normal code $code")
webSocket.close(1000, "Remote closed with code $code")
}

Expand All @@ -219,7 +233,7 @@ class ProvisioningSocket private constructor(

override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
scope.launch {
Log.w(TAG, "[onFailure] Failed", t)
Log.w(TAG, "[$id] [onFailure] Failed", t)
webSocket.close(1000, "Failed ${t.message}")

scope.cancel(CancellationException("WebSocket Failure", t))
Expand All @@ -233,10 +247,10 @@ class ProvisioningSocket private constructor(
}

private suspend fun keepAlive(webSocket: WebSocket) {
Log.i(TAG, "[keepAlive] Starting")
Log.i(TAG, "[$id] [keepAlive] Starting")
while (true) {
delay(30.seconds)
Log.i(TAG, "[keepAlive] Sending...")
Log.i(TAG, "[$id] [keepAlive] Sending...")

val id = System.currentTimeMillis()
val message = WebSocketMessage(
Expand All @@ -249,7 +263,7 @@ class ProvisioningSocket private constructor(
)

if (!webSocket.send(message.encodeByteString())) {
Log.w(TAG, "[keepAlive] Send failed")
Log.w(TAG, "[${this@ProvisioningSocket.id}] [keepAlive] Send failed")
} else {
lastKeepAliveId = id
}
Expand All @@ -267,4 +281,8 @@ class ProvisioningSocket private constructor(
)
}
}

fun interface ProvisioningSocketExceptionHandler {
fun handleException(id: Int, exception: Throwable)
}
}

0 comments on commit 2eabf03

Please sign in to comment.