Browse Source

Make sure to save the tokens the Client might return when its session is restored (#3378)

* Use `ClientSessionDelegate` to ensure tokens are always updated.

Refreshed tokens on client restoration might not have been stored to disk if the token refresh happened before `RustMatrixClient` was built and the `ClientDelegate` was set in it.

Using `ClientSessionDelegate` should ensure the tokens refreshed callback is called at any point in time.

* Improve how assigning the Client works, fix docs

* Fix review comments
pull/3350/head
Jorge Martin Espinosa 2 weeks ago committed by GitHub
parent
commit
2c8b0d0b95
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      features/logout/impl/src/main/kotlin/io/element/android/features/logout/impl/DefaultLogoutUseCase.kt
  2. 2
      features/logout/impl/src/main/kotlin/io/element/android/features/logout/impl/LogoutPresenter.kt
  3. 2
      features/logout/impl/src/main/kotlin/io/element/android/features/logout/impl/direct/DefaultDirectLogoutPresenter.kt
  4. 4
      features/logout/impl/src/test/kotlin/io/element/android/features/logout/impl/LogoutPresenterTest.kt
  5. 4
      features/logout/impl/src/test/kotlin/io/element/android/features/logout/impl/direct/DefaultDirectLogoutPresenterTest.kt
  6. 3
      libraries/matrix/api/src/main/kotlin/io/element/android/libraries/matrix/api/MatrixClient.kt
  7. 133
      libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/RustClientSessionDelegate.kt
  8. 99
      libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/RustMatrixClient.kt
  9. 4
      libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/RustMatrixClientFactory.kt
  10. 6
      libraries/matrix/test/src/main/kotlin/io/element/android/libraries/matrix/test/FakeMatrixClient.kt

2
features/logout/impl/src/main/kotlin/io/element/android/features/logout/impl/DefaultLogoutUseCase.kt

@ -33,7 +33,7 @@ class DefaultLogoutUseCase @Inject constructor( @@ -33,7 +33,7 @@ class DefaultLogoutUseCase @Inject constructor(
return if (currentSession != null) {
matrixClientProvider.getOrRestore(currentSession)
.getOrThrow()
.logout(ignoreSdkError = true)
.logout(userInitiated = true, ignoreSdkError = true)
} else {
error("No session to sign out")
}

2
features/logout/impl/src/main/kotlin/io/element/android/features/logout/impl/LogoutPresenter.kt

@ -104,7 +104,7 @@ class LogoutPresenter @Inject constructor( @@ -104,7 +104,7 @@ class LogoutPresenter @Inject constructor(
ignoreSdkError: Boolean,
) = launch {
suspend {
matrixClient.logout(ignoreSdkError)
matrixClient.logout(userInitiated = true, ignoreSdkError)
}.runCatchingUpdatingState(logoutAction)
}
}

2
features/logout/impl/src/main/kotlin/io/element/android/features/logout/impl/direct/DefaultDirectLogoutPresenter.kt

@ -86,7 +86,7 @@ class DefaultDirectLogoutPresenter @Inject constructor( @@ -86,7 +86,7 @@ class DefaultDirectLogoutPresenter @Inject constructor(
ignoreSdkError: Boolean,
) = launch {
suspend {
matrixClient.logout(ignoreSdkError)
matrixClient.logout(userInitiated = true, ignoreSdkError)
}.runCatchingUpdatingState(logoutAction)
}
}

4
features/logout/impl/src/test/kotlin/io/element/android/features/logout/impl/LogoutPresenterTest.kt

@ -144,7 +144,7 @@ class LogoutPresenterTest { @@ -144,7 +144,7 @@ class LogoutPresenterTest {
@Test
fun `present - logout with error then cancel`() = runTest {
val matrixClient = FakeMatrixClient().apply {
logoutLambda = { _ ->
logoutLambda = { _, _ ->
throw A_THROWABLE
}
}
@ -172,7 +172,7 @@ class LogoutPresenterTest { @@ -172,7 +172,7 @@ class LogoutPresenterTest {
@Test
fun `present - logout with error then force`() = runTest {
val matrixClient = FakeMatrixClient().apply {
logoutLambda = { ignoreSdkError ->
logoutLambda = { ignoreSdkError, _ ->
if (!ignoreSdkError) {
throw A_THROWABLE
} else {

4
features/logout/impl/src/test/kotlin/io/element/android/features/logout/impl/direct/DefaultDirectLogoutPresenterTest.kt

@ -125,7 +125,7 @@ class DefaultDirectLogoutPresenterTest { @@ -125,7 +125,7 @@ class DefaultDirectLogoutPresenterTest {
@Test
fun `present - logout with error then cancel`() = runTest {
val matrixClient = FakeMatrixClient().apply {
logoutLambda = { _ ->
logoutLambda = { _, _ ->
throw A_THROWABLE
}
}
@ -153,7 +153,7 @@ class DefaultDirectLogoutPresenterTest { @@ -153,7 +153,7 @@ class DefaultDirectLogoutPresenterTest {
@Test
fun `present - logout with error then force`() = runTest {
val matrixClient = FakeMatrixClient().apply {
logoutLambda = { ignoreSdkError ->
logoutLambda = { ignoreSdkError, _ ->
if (!ignoreSdkError) {
throw A_THROWABLE
} else {

3
libraries/matrix/api/src/main/kotlin/io/element/android/libraries/matrix/api/MatrixClient.kt

@ -88,9 +88,10 @@ interface MatrixClient : Closeable { @@ -88,9 +88,10 @@ interface MatrixClient : Closeable {
* Logout the user.
* Returns an optional URL. When the URL is there, it should be presented to the user after logout for
* Relying Party (RP) initiated logout on their account page.
* @param userInitiated if false, the logout came from the HS, no request will be made and the session entry will be kept in the store.
* @param ignoreSdkError if true, the SDK will ignore any error and delete the session data anyway.
*/
suspend fun logout(ignoreSdkError: Boolean): String?
suspend fun logout(userInitiated: Boolean, ignoreSdkError: Boolean): String?
/**
* Retrieve the user profile, will also eventually emit a new value to [userProfile].

133
libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/RustClientSessionDelegate.kt

@ -0,0 +1,133 @@ @@ -0,0 +1,133 @@
/*
* Copyright (c) 2024 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.element.android.libraries.matrix.impl
import io.element.android.libraries.core.coroutine.CoroutineDispatchers
import io.element.android.libraries.matrix.impl.mapper.toSessionData
import io.element.android.libraries.matrix.impl.paths.getSessionPaths
import io.element.android.libraries.matrix.impl.util.anonymizedTokens
import io.element.android.libraries.sessionstorage.api.SessionStore
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.launch
import org.matrix.rustcomponents.sdk.ClientDelegate
import org.matrix.rustcomponents.sdk.ClientSessionDelegate
import org.matrix.rustcomponents.sdk.Session
import timber.log.Timber
import java.util.concurrent.atomic.AtomicBoolean
/**
* This class is responsible for handling the session data for the Rust SDK.
*
* It implements both [ClientSessionDelegate] and [ClientDelegate] to react to session data updates and auth errors.
*
* IMPORTANT: you must set the [client] property as soon as possible so [didReceiveAuthError] can work properly.
*/
@OptIn(ExperimentalCoroutinesApi::class)
class RustClientSessionDelegate(
private val sessionStore: SessionStore,
private val appCoroutineScope: CoroutineScope,
coroutineDispatchers: CoroutineDispatchers,
) : ClientSessionDelegate, ClientDelegate {
private val clientLog = Timber.tag("$this")
// Used to ensure several calls to `didReceiveAuthError` don't trigger multiple logouts
private val isLoggingOut = AtomicBoolean(false)
// To make sure only one coroutine affecting the token persistence can run at a time
private val updateTokensDispatcher = coroutineDispatchers.io.limitedParallelism(1)
// This Client needs to be set up as soon as possible so `didReceiveAuthError` can work properly.
private var client: RustMatrixClient? = null
/**
* Sets the [ClientDelegate] for the [RustMatrixClient], and keeps a reference to the client so it can be used later.
*/
fun bindClient(client: RustMatrixClient) {
this.client = client
client.setDelegate(this)
}
override fun saveSessionInKeychain(session: Session) {
appCoroutineScope.launch(updateTokensDispatcher) {
val existingData = sessionStore.getSession(session.userId) ?: return@launch
val (anonymizedAccessToken, anonymizedRefreshToken) = session.anonymizedTokens()
clientLog.d(
"Saving new session data with token: access token '$anonymizedAccessToken' and refresh token '$anonymizedRefreshToken'. " +
"Was token valid: ${existingData.isTokenValid}"
)
val newData = session.toSessionData(
isTokenValid = true,
loginType = existingData.loginType,
passphrase = existingData.passphrase,
sessionPaths = existingData.getSessionPaths(),
)
sessionStore.updateData(newData)
clientLog.d("Saved new session data with access token: '$anonymizedAccessToken'.")
}.invokeOnCompletion {
if (it != null) {
clientLog.e(it, "Failed to save new session data.")
}
}
}
override fun didReceiveAuthError(isSoftLogout: Boolean) {
clientLog.w("didReceiveAuthError(isSoftLogout=$isSoftLogout)")
if (isLoggingOut.getAndSet(true).not()) {
clientLog.v("didReceiveAuthError -> do the cleanup")
// TODO handle isSoftLogout parameter.
appCoroutineScope.launch(updateTokensDispatcher) {
val currentClient = client
if (currentClient == null) {
clientLog.w("didReceiveAuthError -> no client, exiting")
isLoggingOut.set(false)
return@launch
}
val existingData = sessionStore.getSession(currentClient.sessionId.value)
val (anonymizedAccessToken, anonymizedRefreshToken) = existingData.anonymizedTokens()
clientLog.d(
"Removing session data with access token '$anonymizedAccessToken' " +
"and refresh token '$anonymizedRefreshToken'."
)
if (existingData != null) {
// Set isTokenValid to false
val newData = existingData.copy(isTokenValid = false)
sessionStore.updateData(newData)
clientLog.d("Invalidated session data with access token: '$anonymizedAccessToken'.")
} else {
clientLog.d("No session data found.")
}
client?.logout(userInitiated = false, ignoreSdkError = true)
}.invokeOnCompletion {
if (it != null) {
clientLog.e(it, "Failed to remove session data.")
}
}
} else {
clientLog.v("didReceiveAuthError -> already cleaning up")
}
}
override fun didRefreshTokens() {
// This is done in `saveSessionInKeychain(Session)` instead.
}
override fun retrieveSessionFromKeychain(userId: String): Session {
// This should never be called, as it's only used for multi-process setups
error("retrieveSessionFromKeychain should never be called for Android")
}
}

99
libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/RustMatrixClient.kt

@ -51,12 +51,10 @@ import io.element.android.libraries.matrix.api.user.MatrixUser @@ -51,12 +51,10 @@ import io.element.android.libraries.matrix.api.user.MatrixUser
import io.element.android.libraries.matrix.api.verification.SessionVerificationService
import io.element.android.libraries.matrix.impl.core.toProgressWatcher
import io.element.android.libraries.matrix.impl.encryption.RustEncryptionService
import io.element.android.libraries.matrix.impl.mapper.toSessionData
import io.element.android.libraries.matrix.impl.media.RustMediaLoader
import io.element.android.libraries.matrix.impl.notification.RustNotificationService
import io.element.android.libraries.matrix.impl.notificationsettings.RustNotificationSettingsService
import io.element.android.libraries.matrix.impl.oidc.toRustAction
import io.element.android.libraries.matrix.impl.paths.getSessionPaths
import io.element.android.libraries.matrix.impl.pushers.RustPushersService
import io.element.android.libraries.matrix.impl.room.RoomContentForwarder
import io.element.android.libraries.matrix.impl.room.RoomSyncSubscriber
@ -69,7 +67,6 @@ import io.element.android.libraries.matrix.impl.sync.RustSyncService @@ -69,7 +67,6 @@ import io.element.android.libraries.matrix.impl.sync.RustSyncService
import io.element.android.libraries.matrix.impl.usersearch.UserProfileMapper
import io.element.android.libraries.matrix.impl.usersearch.UserSearchResultMapper
import io.element.android.libraries.matrix.impl.util.SessionPathsProvider
import io.element.android.libraries.matrix.impl.util.anonymizedTokens
import io.element.android.libraries.matrix.impl.util.cancelAndDestroy
import io.element.android.libraries.matrix.impl.util.mxCallbackFlow
import io.element.android.libraries.matrix.impl.verification.RustSessionVerificationService
@ -100,7 +97,6 @@ import kotlinx.coroutines.withContext @@ -100,7 +97,6 @@ import kotlinx.coroutines.withContext
import kotlinx.coroutines.withTimeout
import org.matrix.rustcomponents.sdk.BackupState
import org.matrix.rustcomponents.sdk.Client
import org.matrix.rustcomponents.sdk.ClientDelegate
import org.matrix.rustcomponents.sdk.ClientException
import org.matrix.rustcomponents.sdk.IgnoredUsersListener
import org.matrix.rustcomponents.sdk.NotificationProcessSetup
@ -111,7 +107,6 @@ import org.matrix.rustcomponents.sdk.use @@ -111,7 +107,6 @@ import org.matrix.rustcomponents.sdk.use
import timber.log.Timber
import java.io.File
import java.util.Optional
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.time.Duration
import kotlin.time.Duration.Companion.INFINITE
import kotlin.time.Duration.Companion.seconds
@ -130,6 +125,7 @@ class RustMatrixClient( @@ -130,6 +125,7 @@ class RustMatrixClient(
private val baseDirectory: File,
baseCacheDirectory: File,
private val clock: SystemClock,
sessionDelegate: RustClientSessionDelegate,
) : MatrixClient {
override val sessionId: UserId = UserId(client.userId())
override val deviceId: String = client.deviceId()
@ -138,8 +134,6 @@ class RustMatrixClient( @@ -138,8 +134,6 @@ class RustMatrixClient(
private val innerRoomListService = syncService.roomListService()
private val sessionDispatcher = dispatchers.io.limitedParallelism(64)
// To make sure only one coroutine affecting the token persistence can run at a time
private val tokenRefreshDispatcher = sessionDispatcher.limitedParallelism(1)
private val rustSyncService = RustSyncService(syncService, sessionCoroutineScope)
private val pushersService = RustPushersService(
client = client,
@ -164,72 +158,6 @@ class RustMatrixClient( @@ -164,72 +158,6 @@ class RustMatrixClient(
private val sessionPathsProvider = SessionPathsProvider(sessionStore)
private val isLoggingOut = AtomicBoolean(false)
private val clientDelegate = object : ClientDelegate {
private val clientLog get() = Timber.tag(this@RustMatrixClient.toString())
override fun didReceiveAuthError(isSoftLogout: Boolean) {
clientLog.w("didReceiveAuthError(isSoftLogout=$isSoftLogout)")
if (isLoggingOut.getAndSet(true).not()) {
clientLog.v("didReceiveAuthError -> do the cleanup")
// TODO handle isSoftLogout parameter.
appCoroutineScope.launch(tokenRefreshDispatcher) {
val existingData = sessionStore.getSession(client.userId())
val (anonymizedAccessToken, anonymizedRefreshToken) = existingData.anonymizedTokens()
clientLog.d(
"Removing session data with access token '$anonymizedAccessToken' " +
"and refresh token '$anonymizedRefreshToken'."
)
if (existingData != null) {
// Set isTokenValid to false
val newData = client.session().toSessionData(
isTokenValid = false,
loginType = existingData.loginType,
passphrase = existingData.passphrase,
sessionPaths = existingData.getSessionPaths(),
)
sessionStore.updateData(newData)
clientLog.d("Removed session data with access token: '$anonymizedAccessToken'.")
} else {
clientLog.d("No session data found.")
}
doLogout(doRequest = false, removeSession = false, ignoreSdkError = false)
}.invokeOnCompletion {
if (it != null) {
clientLog.e(it, "Failed to remove session data.")
}
}
} else {
clientLog.v("didReceiveAuthError -> already cleaning up")
}
}
override fun didRefreshTokens() {
clientLog.w("didRefreshTokens()")
appCoroutineScope.launch(tokenRefreshDispatcher) {
val existingData = sessionStore.getSession(client.userId()) ?: return@launch
val (anonymizedAccessToken, anonymizedRefreshToken) = client.session().anonymizedTokens()
clientLog.d(
"Saving new session data with token: access token '$anonymizedAccessToken' and refresh token '$anonymizedRefreshToken'. " +
"Was token valid: ${existingData.isTokenValid}"
)
val newData = client.session().toSessionData(
isTokenValid = true,
loginType = existingData.loginType,
passphrase = existingData.passphrase,
sessionPaths = existingData.getSessionPaths(),
)
sessionStore.updateData(newData)
clientLog.d("Saved new session data with access token: '$anonymizedAccessToken'.")
}.invokeOnCompletion {
if (it != null) {
clientLog.e(it, "Failed to save new session data.")
}
}
}
}
private val roomSyncSubscriber: RoomSyncSubscriber = RoomSyncSubscriber(innerRoomListService, dispatchers)
override val roomListService: RoomListService = RustRoomListService(
@ -271,7 +199,7 @@ class RustMatrixClient( @@ -271,7 +199,7 @@ class RustMatrixClient(
private val roomMembershipObserver = RoomMembershipObserver()
private val clientDelegateTaskHandle: TaskHandle? = client.setDelegate(clientDelegate)
private val clientDelegateTaskHandle: TaskHandle? = client.setDelegate(sessionDelegate)
private val _userProfile: MutableStateFlow<MatrixUser> = MutableStateFlow(
MatrixUser(
@ -295,6 +223,9 @@ class RustMatrixClient( @@ -295,6 +223,9 @@ class RustMatrixClient(
.stateIn(sessionCoroutineScope, started = SharingStarted.Eagerly, initialValue = persistentListOf())
init {
// Make sure the session delegate has a reference to the client to be able to logout on auth error
sessionDelegate.bindClient(this)
sessionCoroutineScope.launch {
// Force a refresh of the profile
getUserProfile()
@ -536,21 +467,11 @@ class RustMatrixClient( @@ -536,21 +467,11 @@ class RustMatrixClient(
deleteSessionDirectory(deleteCryptoDb = false)
}
override suspend fun logout(ignoreSdkError: Boolean): String? = doLogout(
doRequest = true,
removeSession = true,
ignoreSdkError = ignoreSdkError,
)
private suspend fun doLogout(
doRequest: Boolean,
removeSession: Boolean,
ignoreSdkError: Boolean,
): String? {
override suspend fun logout(userInitiated: Boolean, ignoreSdkError: Boolean): String? {
var result: String? = null
syncService.stop()
withContext(sessionDispatcher) {
if (doRequest) {
if (userInitiated) {
try {
result = client.logout()
} catch (failure: Throwable) {
@ -564,7 +485,7 @@ class RustMatrixClient( @@ -564,7 +485,7 @@ class RustMatrixClient(
}
close()
deleteSessionDirectory(deleteCryptoDb = true)
if (removeSession) {
if (userInitiated) {
sessionStore.removeSession(sessionId.value)
}
}
@ -615,6 +536,10 @@ class RustMatrixClient( @@ -615,6 +536,10 @@ class RustMatrixClient(
})
}.buffer(Channel.UNLIMITED)
internal fun setDelegate(delegate: RustClientSessionDelegate) {
client.setDelegate(delegate)
}
private suspend fun File.getCacheSize(
includeCryptoDb: Boolean = false,
): Long = withContext(sessionDispatcher) {

4
libraries/matrix/impl/src/main/kotlin/io/element/android/libraries/matrix/impl/RustMatrixClientFactory.kt

@ -56,6 +56,8 @@ class RustMatrixClientFactory @Inject constructor( @@ -56,6 +56,8 @@ class RustMatrixClientFactory @Inject constructor(
private val appPreferencesStore: AppPreferencesStore,
) {
suspend fun create(sessionData: SessionData): RustMatrixClient = withContext(coroutineDispatchers.io) {
val sessionDelegate = RustClientSessionDelegate(sessionStore, appCoroutineScope, coroutineDispatchers)
val client = getBaseClientBuilder(
sessionPaths = sessionData.getSessionPaths(),
passphrase = sessionData.passphrase,
@ -67,6 +69,7 @@ class RustMatrixClientFactory @Inject constructor( @@ -67,6 +69,7 @@ class RustMatrixClientFactory @Inject constructor(
)
.homeserverUrl(sessionData.homeserverUrl)
.username(sessionData.userId)
.setSessionDelegate(sessionDelegate)
.use { it.build() }
client.restoreSession(sessionData.toSession())
@ -86,6 +89,7 @@ class RustMatrixClientFactory @Inject constructor( @@ -86,6 +89,7 @@ class RustMatrixClientFactory @Inject constructor(
baseDirectory = baseDirectory,
baseCacheDirectory = cacheDirectory,
clock = clock,
sessionDelegate = sessionDelegate,
).also {
Timber.tag(it.toString()).d("Creating Client with access token '$anonymizedAccessToken' and refresh token '$anonymizedRefreshToken'")
}

6
libraries/matrix/test/src/main/kotlin/io/element/android/libraries/matrix/test/FakeMatrixClient.kt

@ -122,7 +122,7 @@ class FakeMatrixClient( @@ -122,7 +122,7 @@ class FakeMatrixClient(
var getRoomInfoFlowLambda = { _: RoomId ->
flowOf<Optional<MatrixRoomInfo>>(Optional.empty())
}
var logoutLambda: (Boolean) -> String? = {
var logoutLambda: (Boolean, Boolean) -> String? = { _, _ ->
null
}
@ -170,8 +170,8 @@ class FakeMatrixClient( @@ -170,8 +170,8 @@ class FakeMatrixClient(
clearCacheLambda()
}
override suspend fun logout(ignoreSdkError: Boolean): String? = simulateLongTask {
return logoutLambda(ignoreSdkError)
override suspend fun logout(userInitiated: Boolean, ignoreSdkError: Boolean): String? = simulateLongTask {
return logoutLambda(ignoreSdkError, userInitiated)
}
override fun close() = Unit

Loading…
Cancel
Save