diff --git a/osu.Game.Tests/NonVisual/TaskChainTest.cs b/osu.Game.Tests/NonVisual/TaskChainTest.cs new file mode 100644 index 0000000000..d83eaafe20 --- /dev/null +++ b/osu.Game.Tests/NonVisual/TaskChainTest.cs @@ -0,0 +1,111 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +using System.Threading; +using System.Threading.Tasks; +using NUnit.Framework; +using osu.Game.Utils; + +namespace osu.Game.Tests.NonVisual +{ + [TestFixture] + public class TaskChainTest + { + private TaskChain taskChain; + private int currentTask; + private CancellationTokenSource globalCancellationToken; + + [SetUp] + public void Setup() + { + globalCancellationToken = new CancellationTokenSource(); + taskChain = new TaskChain(); + currentTask = 0; + } + + [TearDown] + public void TearDown() + { + globalCancellationToken?.Cancel(); + } + + [Test] + public async Task TestChainedTasksRunSequentially() + { + var task1 = addTask(); + var task2 = addTask(); + var task3 = addTask(); + + task3.mutex.Set(); + task2.mutex.Set(); + task1.mutex.Set(); + + await Task.WhenAll(task1.task, task2.task, task3.task); + + Assert.That(task1.task.Result, Is.EqualTo(1)); + Assert.That(task2.task.Result, Is.EqualTo(2)); + Assert.That(task3.task.Result, Is.EqualTo(3)); + } + + [Test] + public async Task TestChainedTaskWithIntermediateCancelRunsInSequence() + { + var task1 = addTask(); + var task2 = addTask(); + var task3 = addTask(); + + // Cancel task2, allow task3 to complete. + task2.cancellation.Cancel(); + task2.mutex.Set(); + task3.mutex.Set(); + + // Allow task3 to potentially complete. + Thread.Sleep(1000); + + // Allow task1 to complete. + task1.mutex.Set(); + + // Wait on both tasks. + await Task.WhenAll(task1.task, task3.task); + + Assert.That(task1.task.Result, Is.EqualTo(1)); + Assert.That(task2.task.IsCompleted, Is.False); + Assert.That(task3.task.Result, Is.EqualTo(2)); + } + + [Test] + public async Task TestChainedTaskDoesNotCompleteBeforeChildTasks() + { + var mutex = new ManualResetEventSlim(false); + + var task = taskChain.Add(async () => await Task.Run(() => mutex.Wait(globalCancellationToken.Token))); + + // Allow task to potentially complete + Thread.Sleep(1000); + + Assert.That(task.IsCompleted, Is.False); + + // Allow the task to complete. + mutex.Set(); + + await task; + } + + private (Task task, ManualResetEventSlim mutex, CancellationTokenSource cancellation) addTask() + { + var mutex = new ManualResetEventSlim(false); + var completionSource = new TaskCompletionSource(); + + var cancellationSource = new CancellationTokenSource(); + var token = CancellationTokenSource.CreateLinkedTokenSource(cancellationSource.Token, globalCancellationToken.Token); + + taskChain.Add(() => + { + mutex.Wait(globalCancellationToken.Token); + completionSource.SetResult(Interlocked.Increment(ref currentTask)); + }, token.Token); + + return (completionSource.Task, mutex, cancellationSource); + } + } +} diff --git a/osu.Game.Tests/Visual/Multiplayer/TestSceneMultiplayerRoomManager.cs b/osu.Game.Tests/Visual/Multiplayer/TestSceneMultiplayerRoomManager.cs index 7a3845cbf3..6de5704410 100644 --- a/osu.Game.Tests/Visual/Multiplayer/TestSceneMultiplayerRoomManager.cs +++ b/osu.Game.Tests/Visual/Multiplayer/TestSceneMultiplayerRoomManager.cs @@ -101,7 +101,7 @@ namespace osu.Game.Tests.Visual.Multiplayer }); }); - AddAssert("multiplayer room joined", () => roomContainer.Client.Room != null); + AddUntilStep("multiplayer room joined", () => roomContainer.Client.Room != null); } [Test] @@ -133,7 +133,7 @@ namespace osu.Game.Tests.Visual.Multiplayer }); }); - AddAssert("multiplayer room joined", () => roomContainer.Client.Room != null); + AddUntilStep("multiplayer room joined", () => roomContainer.Client.Room != null); } private TestMultiplayerRoomManager createRoomManager() diff --git a/osu.Game/Extensions/TaskExtensions.cs b/osu.Game/Extensions/TaskExtensions.cs new file mode 100644 index 0000000000..76a76c0c52 --- /dev/null +++ b/osu.Game/Extensions/TaskExtensions.cs @@ -0,0 +1,68 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +#nullable enable + +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace osu.Game.Extensions +{ + public static class TaskExtensions + { + /// + /// Add a continuation to be performed only after the attached task has completed. + /// + /// The previous task to be awaited on. + /// The action to run. + /// An optional cancellation token. Will only cancel the provided action, not the sequence. + /// A task representing the provided action. + public static Task ContinueWithSequential(this Task task, Action action, CancellationToken cancellationToken = default) => + task.ContinueWithSequential(() => Task.Run(action, cancellationToken), cancellationToken); + + /// + /// Add a continuation to be performed only after the attached task has completed. + /// + /// The previous task to be awaited on. + /// The continuation to run. Generally should be an async function. + /// An optional cancellation token. Will only cancel the provided action, not the sequence. + /// A task representing the provided action. + public static Task ContinueWithSequential(this Task task, Func continuationFunction, CancellationToken cancellationToken = default) + { + var tcs = new TaskCompletionSource(); + + task.ContinueWith(t => + { + // the previous task has finished execution or been cancelled, so we can run the provided continuation. + + if (cancellationToken.IsCancellationRequested) + { + tcs.SetCanceled(); + } + else + { + continuationFunction().ContinueWith(continuationTask => + { + if (cancellationToken.IsCancellationRequested || continuationTask.IsCanceled) + { + tcs.TrySetCanceled(); + } + else if (continuationTask.IsFaulted) + { + tcs.TrySetException(continuationTask.Exception); + } + else + { + tcs.TrySetResult(true); + } + }, cancellationToken: default); + } + }, cancellationToken: default); + + // importantly, we are not returning the continuation itself but rather a task which represents its status in sequential execution order. + // this will not be cancelled or completed until the previous task has also. + return tcs.Task; + } + } +} diff --git a/osu.Game/Online/Multiplayer/MultiplayerClient.cs b/osu.Game/Online/Multiplayer/MultiplayerClient.cs index 6908795510..493518ac80 100644 --- a/osu.Game/Online/Multiplayer/MultiplayerClient.cs +++ b/osu.Game/Online/Multiplayer/MultiplayerClient.cs @@ -126,20 +126,12 @@ namespace osu.Game.Online.Multiplayer return connection.InvokeAsync(nameof(IMultiplayerServer.JoinRoom), roomId); } - public override async Task LeaveRoom() + protected override Task LeaveRoomInternal() { if (!isConnected.Value) - { - // even if not connected, make sure the local room state can be cleaned up. - await base.LeaveRoom(); - return; - } + return Task.FromCanceled(new CancellationToken(true)); - if (Room == null) - return; - - await base.LeaveRoom(); - await connection.InvokeAsync(nameof(IMultiplayerServer.LeaveRoom)); + return connection.InvokeAsync(nameof(IMultiplayerServer.LeaveRoom)); } public override Task TransferHost(int userId) diff --git a/osu.Game/Online/Multiplayer/StatefulMultiplayerClient.cs b/osu.Game/Online/Multiplayer/StatefulMultiplayerClient.cs index f7c9193dfe..032493a5c6 100644 --- a/osu.Game/Online/Multiplayer/StatefulMultiplayerClient.cs +++ b/osu.Game/Online/Multiplayer/StatefulMultiplayerClient.cs @@ -4,9 +4,9 @@ #nullable enable using System; -using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Threading; using System.Threading.Tasks; using osu.Framework.Allocation; using osu.Framework.Bindables; @@ -109,30 +109,43 @@ namespace osu.Game.Online.Multiplayer }); } + private readonly TaskChain joinOrLeaveTaskChain = new TaskChain(); + private CancellationTokenSource? joinCancellationSource; + /// /// Joins the for a given API . /// /// The API . public async Task JoinRoom(Room room) { - if (Room != null) - throw new InvalidOperationException("Cannot join a multiplayer room while already in one."); + var cancellationSource = joinCancellationSource = new CancellationTokenSource(); - Debug.Assert(room.RoomID.Value != null); + await joinOrLeaveTaskChain.Add(async () => + { + if (Room != null) + throw new InvalidOperationException("Cannot join a multiplayer room while already in one."); - apiRoom = room; - playlistItemId = room.Playlist.SingleOrDefault()?.ID ?? 0; + Debug.Assert(room.RoomID.Value != null); - Room = await JoinRoom(room.RoomID.Value.Value); + // Join the server-side room. + var joinedRoom = await JoinRoom(room.RoomID.Value.Value); + Debug.Assert(joinedRoom != null); - Debug.Assert(Room != null); + // Populate users. + Debug.Assert(joinedRoom.Users != null); + await Task.WhenAll(joinedRoom.Users.Select(PopulateUser)); - var users = await getRoomUsers(); - Debug.Assert(users != null); + // Update the stored room (must be done on update thread for thread-safety). + await scheduleAsync(() => + { + Room = joinedRoom; + apiRoom = room; + playlistItemId = room.Playlist.SingleOrDefault()?.ID ?? 0; + }, cancellationSource.Token); - await Task.WhenAll(users.Select(PopulateUser)); - - updateLocalRoomSettings(Room.Settings); + // Update room settings. + await updateLocalRoomSettings(joinedRoom.Settings, cancellationSource.Token); + }, cancellationSource.Token); } /// @@ -142,23 +155,33 @@ namespace osu.Game.Online.Multiplayer /// The joined . protected abstract Task JoinRoom(long roomId); - public virtual Task LeaveRoom() + public Task LeaveRoom() { - Scheduler.Add(() => - { - if (Room == null) - return; + // The join may have not completed yet, so certain tasks that either update the room or reference the room should be cancelled. + // This includes the setting of Room itself along with the initial update of the room settings on join. + joinCancellationSource?.Cancel(); + // Leaving rooms is expected to occur instantaneously whilst the operation is finalised in the background. + // However a few members need to be reset immediately to prevent other components from entering invalid states whilst the operation hasn't yet completed. + // For example, if a room was left and the user immediately pressed the "create room" button, then the user could be taken into the lobby if the value of Room is not reset in time. + var scheduledReset = scheduleAsync(() => + { apiRoom = null; Room = null; CurrentMatchPlayingUserIds.Clear(); RoomUpdated?.Invoke(); - }, false); + }); - return Task.CompletedTask; + return joinOrLeaveTaskChain.Add(async () => + { + await scheduledReset; + await LeaveRoomInternal(); + }); } + protected abstract Task LeaveRoomInternal(); + /// /// Change the current settings. /// @@ -462,27 +485,6 @@ namespace osu.Game.Online.Multiplayer /// The to populate. protected async Task PopulateUser(MultiplayerRoomUser multiplayerUser) => multiplayerUser.User ??= await userLookupCache.GetUserAsync(multiplayerUser.UserID); - /// - /// Retrieve a copy of users currently in the joined in a thread-safe manner. - /// This should be used whenever accessing users from outside of an Update thread context (ie. when not calling ). - /// - /// A copy of users in the current room, or null if unavailable. - private Task?> getRoomUsers() - { - var tcs = new TaskCompletionSource?>(); - - // at some point we probably want to replace all these schedule calls with Room.LockForUpdate. - // for now, as this would require quite some consideration due to the number of accesses to the room instance, - // let's just add a manual schedule for the non-scheduled usages instead. - Scheduler.Add(() => - { - var users = Room?.Users.ToList(); - tcs.SetResult(users); - }, false); - - return tcs.Task; - } - /// /// Updates the local room settings with the given . /// @@ -490,34 +492,36 @@ namespace osu.Game.Online.Multiplayer /// This updates both the joined and the respective API . /// /// The new to update from. - private void updateLocalRoomSettings(MultiplayerRoomSettings settings) + /// The to cancel the update. + private Task updateLocalRoomSettings(MultiplayerRoomSettings settings, CancellationToken cancellationToken = default) => scheduleAsync(() => { if (Room == null) return; - Scheduler.Add(() => + Debug.Assert(apiRoom != null); + + // Update a few properties of the room instantaneously. + Room.Settings = settings; + apiRoom.Name.Value = Room.Settings.Name; + + // The playlist update is delayed until an online beatmap lookup (below) succeeds. + // In-order for the client to not display an outdated beatmap, the playlist is forcefully cleared here. + apiRoom.Playlist.Clear(); + + RoomUpdated?.Invoke(); + + var req = new GetBeatmapSetRequest(settings.BeatmapID, BeatmapSetLookupType.BeatmapId); + + req.Success += res => { - if (Room == null) + if (cancellationToken.IsCancellationRequested) return; - Debug.Assert(apiRoom != null); + updatePlaylist(settings, res); + }; - // Update a few properties of the room instantaneously. - Room.Settings = settings; - apiRoom.Name.Value = Room.Settings.Name; - - // The playlist update is delayed until an online beatmap lookup (below) succeeds. - // In-order for the client to not display an outdated beatmap, the playlist is forcefully cleared here. - apiRoom.Playlist.Clear(); - - RoomUpdated?.Invoke(); - - var req = new GetBeatmapSetRequest(settings.BeatmapID, BeatmapSetLookupType.BeatmapId); - req.Success += res => updatePlaylist(settings, res); - - api.Queue(req); - }, false); - } + api.Queue(req); + }, cancellationToken); private void updatePlaylist(MultiplayerRoomSettings settings, APIBeatmapSet onlineSet) { @@ -566,5 +570,31 @@ namespace osu.Game.Online.Multiplayer else CurrentMatchPlayingUserIds.Remove(userId); } + + private Task scheduleAsync(Action action, CancellationToken cancellationToken = default) + { + var tcs = new TaskCompletionSource(); + + Scheduler.Add(() => + { + if (cancellationToken.IsCancellationRequested) + { + tcs.SetCanceled(); + return; + } + + try + { + action(); + tcs.SetResult(true); + } + catch (Exception ex) + { + tcs.SetException(ex); + } + }); + + return tcs.Task; + } } } diff --git a/osu.Game/Screens/OnlinePlay/Multiplayer/MultiplayerRoomManager.cs b/osu.Game/Screens/OnlinePlay/Multiplayer/MultiplayerRoomManager.cs index 65d112a032..1e57847f04 100644 --- a/osu.Game/Screens/OnlinePlay/Multiplayer/MultiplayerRoomManager.cs +++ b/osu.Game/Screens/OnlinePlay/Multiplayer/MultiplayerRoomManager.cs @@ -87,7 +87,7 @@ namespace osu.Game.Screens.OnlinePlay.Multiplayer { if (t.IsCompletedSuccessfully) Schedule(() => onSuccess?.Invoke(room)); - else + else if (t.IsFaulted) { const string message = "Failed to join multiplayer room."; diff --git a/osu.Game/Tests/Visual/Multiplayer/MultiplayerTestScene.cs b/osu.Game/Tests/Visual/Multiplayer/MultiplayerTestScene.cs index a87b22affe..2e8c834c65 100644 --- a/osu.Game/Tests/Visual/Multiplayer/MultiplayerTestScene.cs +++ b/osu.Game/Tests/Visual/Multiplayer/MultiplayerTestScene.cs @@ -50,5 +50,13 @@ namespace osu.Game.Tests.Visual.Multiplayer if (joinRoom) RoomManager.Schedule(() => RoomManager.CreateRoom(Room)); }); + + public override void SetUpSteps() + { + base.SetUpSteps(); + + if (joinRoom) + AddUntilStep("wait for room join", () => Client.Room != null); + } } } diff --git a/osu.Game/Tests/Visual/Multiplayer/TestMultiplayerClient.cs b/osu.Game/Tests/Visual/Multiplayer/TestMultiplayerClient.cs index d0d41e56c3..379bb758c5 100644 --- a/osu.Game/Tests/Visual/Multiplayer/TestMultiplayerClient.cs +++ b/osu.Game/Tests/Visual/Multiplayer/TestMultiplayerClient.cs @@ -100,6 +100,8 @@ namespace osu.Game.Tests.Visual.Multiplayer return Task.FromResult(room); } + protected override Task LeaveRoomInternal() => Task.CompletedTask; + public override Task TransferHost(int userId) => ((IMultiplayerClient)this).HostChanged(userId); public override async Task ChangeSettings(MultiplayerRoomSettings settings) diff --git a/osu.Game/Utils/TaskChain.cs b/osu.Game/Utils/TaskChain.cs new file mode 100644 index 0000000000..df28faf9fb --- /dev/null +++ b/osu.Game/Utils/TaskChain.cs @@ -0,0 +1,46 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +#nullable enable + +using System; +using System.Threading; +using System.Threading.Tasks; +using osu.Game.Extensions; + +namespace osu.Game.Utils +{ + /// + /// A chain of s that run sequentially. + /// + public class TaskChain + { + private readonly object taskLock = new object(); + + private Task lastTaskInChain = Task.CompletedTask; + + /// + /// Adds a new task to the end of this . + /// + /// The action to be executed. + /// The for this task. Does not affect further tasks in the chain. + /// The awaitable . + public Task Add(Action action, CancellationToken cancellationToken = default) + { + lock (taskLock) + return lastTaskInChain = lastTaskInChain.ContinueWithSequential(action, cancellationToken); + } + + /// + /// Adds a new task to the end of this . + /// + /// The task to be executed. + /// The for this task. Does not affect further tasks in the chain. + /// The awaitable . + public Task Add(Func task, CancellationToken cancellationToken = default) + { + lock (taskLock) + return lastTaskInChain = lastTaskInChain.ContinueWithSequential(task, cancellationToken); + } + } +}