diff --git a/osu.Game/Database/BeatmapLookupCache.cs b/osu.Game/Database/BeatmapLookupCache.cs index c6f8244494..06edc3e2da 100644 --- a/osu.Game/Database/BeatmapLookupCache.cs +++ b/osu.Game/Database/BeatmapLookupCache.cs @@ -6,20 +6,13 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; using JetBrains.Annotations; -using osu.Framework.Allocation; -using osu.Game.Online.API; using osu.Game.Online.API.Requests; using osu.Game.Online.API.Requests.Responses; namespace osu.Game.Database { - // This class is based on `UserLookupCache` which is well tested. - // If modifications are to be made here, a base abstract implementation should likely be created and shared between the two. - public class BeatmapLookupCache : MemoryCachingComponent + public class BeatmapLookupCache : OnlineLookupCache { - [Resolved] - private IAPIProvider api { get; set; } - /// /// Perform an API lookup on the specified beatmap, populating a model. /// @@ -27,7 +20,7 @@ namespace osu.Game.Database /// An optional cancellation token. /// The populated beatmap, or null if the beatmap does not exist or the request could not be satisfied. [ItemCanBeNull] - public Task GetBeatmapAsync(int beatmapId, CancellationToken token = default) => GetAsync(beatmapId, token); + public Task GetBeatmapAsync(int beatmapId, CancellationToken token = default) => LookupAsync(beatmapId, token); /// /// Perform an API lookup on the specified beatmaps, populating a model. @@ -35,115 +28,10 @@ namespace osu.Game.Database /// The beatmaps to lookup. /// An optional cancellation token. /// The populated beatmaps. May include null results for failed retrievals. - public Task GetBeatmapsAsync(int[] beatmapIds, CancellationToken token = default) - { - var beatmapLookupTasks = new List>(); + public Task GetBeatmapsAsync(int[] beatmapIds, CancellationToken token = default) => LookupAsync(beatmapIds, token); - foreach (int u in beatmapIds) - { - beatmapLookupTasks.Add(GetBeatmapAsync(u, token).ContinueWith(task => - { - if (!task.IsCompletedSuccessfully) - return null; + protected override GetBeatmapsRequest CreateRequest(IEnumerable ids) => new GetBeatmapsRequest(ids.ToArray()); - return task.Result; - }, token)); - } - - return Task.WhenAll(beatmapLookupTasks); - } - - protected override async Task ComputeValueAsync(int lookup, CancellationToken token = default) - => await queryBeatmap(lookup).ConfigureAwait(false); - - private readonly Queue<(int id, TaskCompletionSource)> pendingBeatmapTasks = new Queue<(int, TaskCompletionSource)>(); - private Task pendingRequestTask; - private readonly object taskAssignmentLock = new object(); - - private Task queryBeatmap(int beatmapId) - { - lock (taskAssignmentLock) - { - var tcs = new TaskCompletionSource(); - - // Add to the queue. - pendingBeatmapTasks.Enqueue((beatmapId, tcs)); - - // Create a request task if there's not already one. - if (pendingRequestTask == null) - createNewTask(); - - return tcs.Task; - } - } - - private void performLookup() - { - // contains at most 50 unique beatmap IDs from beatmapTasks, which is used to perform the lookup. - var beatmapTasks = new Dictionary>>(); - - // Grab at most 50 unique beatmap IDs from the queue. - lock (taskAssignmentLock) - { - while (pendingBeatmapTasks.Count > 0 && beatmapTasks.Count < 50) - { - (int id, TaskCompletionSource task) next = pendingBeatmapTasks.Dequeue(); - - // Perform a secondary check for existence, in case the beatmap was queried in a previous batch. - if (CheckExists(next.id, out var existing)) - next.task.SetResult(existing); - else - { - if (beatmapTasks.TryGetValue(next.id, out var tasks)) - tasks.Add(next.task); - else - beatmapTasks[next.id] = new List> { next.task }; - } - } - } - - if (beatmapTasks.Count == 0) - return; - - // Query the beatmaps. - var request = new GetBeatmapsRequest(beatmapTasks.Keys.ToArray()); - - // rather than queueing, we maintain our own single-threaded request stream. - // todo: we probably want retry logic here. - api.Perform(request); - - // Create a new request task if there's still more beatmaps to query. - lock (taskAssignmentLock) - { - pendingRequestTask = null; - if (pendingBeatmapTasks.Count > 0) - createNewTask(); - } - - List foundBeatmaps = request.Response?.Beatmaps; - - if (foundBeatmaps != null) - { - foreach (var beatmap in foundBeatmaps) - { - if (beatmapTasks.TryGetValue(beatmap.OnlineID, out var tasks)) - { - foreach (var task in tasks) - task.SetResult(beatmap); - - beatmapTasks.Remove(beatmap.OnlineID); - } - } - } - - // if any tasks remain which were not satisfied, return null. - foreach (var tasks in beatmapTasks.Values) - { - foreach (var task in tasks) - task.SetResult(null); - } - } - - private void createNewTask() => pendingRequestTask = Task.Run(performLookup); + protected override IEnumerable RetrieveResults(GetBeatmapsRequest request) => request.Response?.Beatmaps; } } diff --git a/osu.Game/Database/OnlineLookupCache.cs b/osu.Game/Database/OnlineLookupCache.cs new file mode 100644 index 0000000000..8a50cc486f --- /dev/null +++ b/osu.Game/Database/OnlineLookupCache.cs @@ -0,0 +1,162 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using JetBrains.Annotations; +using osu.Framework.Allocation; +using osu.Game.Online.API; + +namespace osu.Game.Database +{ + public abstract class OnlineLookupCache : MemoryCachingComponent + where TLookup : IEquatable + where TValue : class, IHasOnlineID + where TRequest : APIRequest + { + [Resolved] + private IAPIProvider api { get; set; } + + /// + /// Creates an to retrieve the values for a given collection of s. + /// + /// The IDs to perform the lookup with. + protected abstract TRequest CreateRequest(IEnumerable ids); + + /// + /// Retrieves a list of s from a successful created by . + /// + [CanBeNull] + protected abstract IEnumerable RetrieveResults(TRequest request); + + /// + /// Perform a lookup using the specified , populating a . + /// + /// The ID to lookup. + /// An optional cancellation token. + /// The populated , or null if the value does not exist or the request could not be satisfied. + [ItemCanBeNull] + protected Task LookupAsync(TLookup id, CancellationToken token = default) => GetAsync(id, token); + + /// + /// Perform an API lookup on the specified , populating a . + /// + /// The IDs to lookup. + /// An optional cancellation token. + /// The populated values. May include null results for failed retrievals. + protected Task LookupAsync(TLookup[] ids, CancellationToken token = default) + { + var lookupTasks = new List>(); + + foreach (var id in ids) + { + lookupTasks.Add(LookupAsync(id, token).ContinueWith(task => + { + if (!task.IsCompletedSuccessfully) + return null; + + return task.Result; + }, token)); + } + + return Task.WhenAll(lookupTasks); + } + + // cannot be sealed due to test usages (see TestUserLookupCache). + protected override async Task ComputeValueAsync(TLookup lookup, CancellationToken token = default) + => await queryValue(lookup).ConfigureAwait(false); + + private readonly Queue<(TLookup id, TaskCompletionSource)> pendingTasks = new Queue<(TLookup, TaskCompletionSource)>(); + private Task pendingRequestTask; + private readonly object taskAssignmentLock = new object(); + + private Task queryValue(TLookup id) + { + lock (taskAssignmentLock) + { + var tcs = new TaskCompletionSource(); + + // Add to the queue. + pendingTasks.Enqueue((id, tcs)); + + // Create a request task if there's not already one. + if (pendingRequestTask == null) + createNewTask(); + + return tcs.Task; + } + } + + private void performLookup() + { + // contains at most 50 unique IDs from tasks, which is used to perform the lookup. + var nextTaskBatch = new Dictionary>>(); + + // Grab at most 50 unique IDs from the queue. + lock (taskAssignmentLock) + { + while (pendingTasks.Count > 0 && nextTaskBatch.Count < 50) + { + (TLookup id, TaskCompletionSource task) next = pendingTasks.Dequeue(); + + // Perform a secondary check for existence, in case the value was queried in a previous batch. + if (CheckExists(next.id, out var existing)) + next.task.SetResult(existing); + else + { + if (nextTaskBatch.TryGetValue(next.id, out var tasks)) + tasks.Add(next.task); + else + nextTaskBatch[next.id] = new List> { next.task }; + } + } + } + + if (nextTaskBatch.Count == 0) + return; + + // Query the values. + var request = CreateRequest(nextTaskBatch.Keys.ToArray()); + + // rather than queueing, we maintain our own single-threaded request stream. + // todo: we probably want retry logic here. + api.Perform(request); + + // Create a new request task if there's still more values to query. + lock (taskAssignmentLock) + { + pendingRequestTask = null; + if (pendingTasks.Count > 0) + createNewTask(); + } + + var foundValues = RetrieveResults(request); + + if (foundValues != null) + { + foreach (var value in foundValues) + { + if (nextTaskBatch.TryGetValue(value.OnlineID, out var tasks)) + { + foreach (var task in tasks) + task.SetResult(value); + + nextTaskBatch.Remove(value.OnlineID); + } + } + } + + // if any tasks remain which were not satisfied, return null. + foreach (var tasks in nextTaskBatch.Values) + { + foreach (var task in tasks) + task.SetResult(null); + } + } + + private void createNewTask() => pendingRequestTask = Task.Run(performLookup); + } +} diff --git a/osu.Game/Database/UserLookupCache.cs b/osu.Game/Database/UserLookupCache.cs index 26f4e9fb3b..5fdd80892d 100644 --- a/osu.Game/Database/UserLookupCache.cs +++ b/osu.Game/Database/UserLookupCache.cs @@ -6,18 +6,13 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; using JetBrains.Annotations; -using osu.Framework.Allocation; -using osu.Game.Online.API; using osu.Game.Online.API.Requests; using osu.Game.Online.API.Requests.Responses; namespace osu.Game.Database { - public class UserLookupCache : MemoryCachingComponent + public class UserLookupCache : OnlineLookupCache { - [Resolved] - private IAPIProvider api { get; set; } - /// /// Perform an API lookup on the specified user, populating a model. /// @@ -25,7 +20,7 @@ namespace osu.Game.Database /// An optional cancellation token. /// The populated user, or null if the user does not exist or the request could not be satisfied. [ItemCanBeNull] - public Task GetUserAsync(int userId, CancellationToken token = default) => GetAsync(userId, token); + public Task GetUserAsync(int userId, CancellationToken token = default) => LookupAsync(userId, token); /// /// Perform an API lookup on the specified users, populating a model. @@ -33,115 +28,10 @@ namespace osu.Game.Database /// The users to lookup. /// An optional cancellation token. /// The populated users. May include null results for failed retrievals. - public Task GetUsersAsync(int[] userIds, CancellationToken token = default) - { - var userLookupTasks = new List>(); + public Task GetUsersAsync(int[] userIds, CancellationToken token = default) => LookupAsync(userIds, token); - foreach (int u in userIds) - { - userLookupTasks.Add(GetUserAsync(u, token).ContinueWith(task => - { - if (!task.IsCompletedSuccessfully) - return null; + protected override GetUsersRequest CreateRequest(IEnumerable ids) => new GetUsersRequest(ids.ToArray()); - return task.Result; - }, token)); - } - - return Task.WhenAll(userLookupTasks); - } - - protected override async Task ComputeValueAsync(int lookup, CancellationToken token = default) - => await queryUser(lookup).ConfigureAwait(false); - - private readonly Queue<(int id, TaskCompletionSource)> pendingUserTasks = new Queue<(int, TaskCompletionSource)>(); - private Task pendingRequestTask; - private readonly object taskAssignmentLock = new object(); - - private Task queryUser(int userId) - { - lock (taskAssignmentLock) - { - var tcs = new TaskCompletionSource(); - - // Add to the queue. - pendingUserTasks.Enqueue((userId, tcs)); - - // Create a request task if there's not already one. - if (pendingRequestTask == null) - createNewTask(); - - return tcs.Task; - } - } - - private void performLookup() - { - // contains at most 50 unique user IDs from userTasks, which is used to perform the lookup. - var userTasks = new Dictionary>>(); - - // Grab at most 50 unique user IDs from the queue. - lock (taskAssignmentLock) - { - while (pendingUserTasks.Count > 0 && userTasks.Count < 50) - { - (int id, TaskCompletionSource task) next = pendingUserTasks.Dequeue(); - - // Perform a secondary check for existence, in case the user was queried in a previous batch. - if (CheckExists(next.id, out var existing)) - next.task.SetResult(existing); - else - { - if (userTasks.TryGetValue(next.id, out var tasks)) - tasks.Add(next.task); - else - userTasks[next.id] = new List> { next.task }; - } - } - } - - if (userTasks.Count == 0) - return; - - // Query the users. - var request = new GetUsersRequest(userTasks.Keys.ToArray()); - - // rather than queueing, we maintain our own single-threaded request stream. - // todo: we probably want retry logic here. - api.Perform(request); - - // Create a new request task if there's still more users to query. - lock (taskAssignmentLock) - { - pendingRequestTask = null; - if (pendingUserTasks.Count > 0) - createNewTask(); - } - - List foundUsers = request.Response?.Users; - - if (foundUsers != null) - { - foreach (var user in foundUsers) - { - if (userTasks.TryGetValue(user.Id, out var tasks)) - { - foreach (var task in tasks) - task.SetResult(user); - - userTasks.Remove(user.Id); - } - } - } - - // if any tasks remain which were not satisfied, return null. - foreach (var tasks in userTasks.Values) - { - foreach (var task in tasks) - task.SetResult(null); - } - } - - private void createNewTask() => pendingRequestTask = Task.Run(performLookup); + protected override IEnumerable RetrieveResults(GetUsersRequest request) => request.Response?.Users; } }