diff --git a/osu.Game.Tests/Beatmaps/IO/ImportBeatmapTest.cs b/osu.Game.Tests/Beatmaps/IO/ImportBeatmapTest.cs index f020c2a805..5fc05a4b2f 100644 --- a/osu.Game.Tests/Beatmaps/IO/ImportBeatmapTest.cs +++ b/osu.Game.Tests/Beatmaps/IO/ImportBeatmapTest.cs @@ -11,7 +11,9 @@ using NUnit.Framework; using osu.Framework.Platform; using osu.Game.IPC; using osu.Framework.Allocation; +using osu.Framework.Logging; using osu.Game.Beatmaps; +using osu.Game.IO; using osu.Game.Tests.Resources; using SharpCompress.Archives.Zip; @@ -21,14 +23,14 @@ namespace osu.Game.Tests.Beatmaps.IO public class ImportBeatmapTest { [Test] - public void TestImportWhenClosed() + public async Task TestImportWhenClosed() { //unfortunately for the time being we need to reference osu.Framework.Desktop for a game host here. using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestImportWhenClosed")) { try { - LoadOszIntoOsu(loadOsu(host)); + await LoadOszIntoOsu(loadOsu(host)); } finally { @@ -38,7 +40,7 @@ namespace osu.Game.Tests.Beatmaps.IO } [Test] - public void TestImportThenDelete() + public async Task TestImportThenDelete() { //unfortunately for the time being we need to reference osu.Framework.Desktop for a game host here. using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestImportThenDelete")) @@ -47,7 +49,7 @@ namespace osu.Game.Tests.Beatmaps.IO { var osu = loadOsu(host); - var imported = LoadOszIntoOsu(osu); + var imported = await LoadOszIntoOsu(osu); deleteBeatmapSet(imported, osu); } @@ -59,7 +61,7 @@ namespace osu.Game.Tests.Beatmaps.IO } [Test] - public void TestImportThenImport() + public async Task TestImportThenImport() { //unfortunately for the time being we need to reference osu.Framework.Desktop for a game host here. using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestImportThenImport")) @@ -68,17 +70,15 @@ namespace osu.Game.Tests.Beatmaps.IO { var osu = loadOsu(host); - var imported = LoadOszIntoOsu(osu); - var importedSecondTime = LoadOszIntoOsu(osu); + var imported = await LoadOszIntoOsu(osu); + var importedSecondTime = await LoadOszIntoOsu(osu); // check the newly "imported" beatmap is actually just the restored previous import. since it matches hash. Assert.IsTrue(imported.ID == importedSecondTime.ID); Assert.IsTrue(imported.Beatmaps.First().ID == importedSecondTime.Beatmaps.First().ID); - var manager = osu.Dependencies.Get(); - - Assert.AreEqual(1, manager.GetAllUsableBeatmapSets().Count); - Assert.AreEqual(1, manager.QueryBeatmapSets(_ => true).ToList().Count); + checkBeatmapSetCount(osu, 1); + checkSingleReferencedFileCount(osu, 18); } finally { @@ -88,30 +88,41 @@ namespace osu.Game.Tests.Beatmaps.IO } [Test] - public void TestRollbackOnFailure() + public async Task TestRollbackOnFailure() { //unfortunately for the time being we need to reference osu.Framework.Desktop for a game host here. using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestRollbackOnFailure")) { try { + int itemAddRemoveFireCount = 0; + int loggedExceptionCount = 0; + + Logger.NewEntry += l => + { + if (l.Target == LoggingTarget.Database && l.Exception != null) + Interlocked.Increment(ref loggedExceptionCount); + }; + var osu = loadOsu(host); var manager = osu.Dependencies.Get(); - int fireCount = 0; - // ReSharper disable once AccessToModifiedClosure - manager.ItemAdded += (_, __) => fireCount++; - manager.ItemRemoved += _ => fireCount++; + manager.ItemAdded += (_, __) => Interlocked.Increment(ref itemAddRemoveFireCount); + manager.ItemRemoved += _ => Interlocked.Increment(ref itemAddRemoveFireCount); - var imported = LoadOszIntoOsu(osu); + var imported = await LoadOszIntoOsu(osu); - Assert.AreEqual(0, fireCount -= 1); + Assert.AreEqual(0, itemAddRemoveFireCount -= 1); imported.Hash += "-changed"; manager.Update(imported); - Assert.AreEqual(0, fireCount -= 2); + Assert.AreEqual(0, itemAddRemoveFireCount -= 2); + + checkBeatmapSetCount(osu, 1); + checkBeatmapCount(osu, 12); + checkSingleReferencedFileCount(osu, 18); var breakTemp = TestResources.GetTestBeatmapForImport(); @@ -127,19 +138,24 @@ namespace osu.Game.Tests.Beatmaps.IO zip.SaveTo(outStream, SharpCompress.Common.CompressionType.Deflate); } - Assert.AreEqual(1, manager.GetAllUsableBeatmapSets().Count); - Assert.AreEqual(1, manager.QueryBeatmapSets(_ => true).ToList().Count); - Assert.AreEqual(12, manager.QueryBeatmaps(_ => true).ToList().Count); - // this will trigger purging of the existing beatmap (online set id match) but should rollback due to broken osu. - manager.Import(breakTemp); + try + { + await manager.Import(breakTemp); + } + catch + { + } // no events should be fired in the case of a rollback. - Assert.AreEqual(0, fireCount); + Assert.AreEqual(0, itemAddRemoveFireCount); - Assert.AreEqual(1, manager.GetAllUsableBeatmapSets().Count); - Assert.AreEqual(1, manager.QueryBeatmapSets(_ => true).ToList().Count); - Assert.AreEqual(12, manager.QueryBeatmaps(_ => true).ToList().Count); + checkBeatmapSetCount(osu, 1); + checkBeatmapCount(osu, 12); + + checkSingleReferencedFileCount(osu, 18); + + Assert.AreEqual(1, loggedExceptionCount); } finally { @@ -149,7 +165,7 @@ namespace osu.Game.Tests.Beatmaps.IO } [Test] - public void TestImportThenImportDifferentHash() + public async Task TestImportThenImportDifferentHash() { //unfortunately for the time being we need to reference osu.Framework.Desktop for a game host here. using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestImportThenImportDifferentHash")) @@ -159,19 +175,18 @@ namespace osu.Game.Tests.Beatmaps.IO var osu = loadOsu(host); var manager = osu.Dependencies.Get(); - var imported = LoadOszIntoOsu(osu); + var imported = await LoadOszIntoOsu(osu); imported.Hash += "-changed"; manager.Update(imported); - var importedSecondTime = LoadOszIntoOsu(osu); + var importedSecondTime = await LoadOszIntoOsu(osu); Assert.IsTrue(imported.ID != importedSecondTime.ID); Assert.IsTrue(imported.Beatmaps.First().ID < importedSecondTime.Beatmaps.First().ID); // only one beatmap will exist as the online set ID matched, causing purging of the first import. - Assert.AreEqual(1, manager.GetAllUsableBeatmapSets().Count); - Assert.AreEqual(1, manager.QueryBeatmapSets(_ => true).ToList().Count); + checkBeatmapSetCount(osu, 1); } finally { @@ -181,7 +196,7 @@ namespace osu.Game.Tests.Beatmaps.IO } [Test] - public void TestImportThenDeleteThenImport() + public async Task TestImportThenDeleteThenImport() { //unfortunately for the time being we need to reference osu.Framework.Desktop for a game host here. using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestImportThenDeleteThenImport")) @@ -190,11 +205,11 @@ namespace osu.Game.Tests.Beatmaps.IO { var osu = loadOsu(host); - var imported = LoadOszIntoOsu(osu); + var imported = await LoadOszIntoOsu(osu); deleteBeatmapSet(imported, osu); - var importedSecondTime = LoadOszIntoOsu(osu); + var importedSecondTime = await LoadOszIntoOsu(osu); // check the newly "imported" beatmap is actually just the restored previous import. since it matches hash. Assert.IsTrue(imported.ID == importedSecondTime.ID); @@ -209,7 +224,7 @@ namespace osu.Game.Tests.Beatmaps.IO [TestCase(true)] [TestCase(false)] - public void TestImportThenDeleteThenImportWithOnlineIDMismatch(bool set) + public async Task TestImportThenDeleteThenImportWithOnlineIDMismatch(bool set) { //unfortunately for the time being we need to reference osu.Framework.Desktop for a game host here. using (HeadlessGameHost host = new CleanRunHeadlessGameHost($"TestImportThenDeleteThenImport-{set}")) @@ -218,7 +233,7 @@ namespace osu.Game.Tests.Beatmaps.IO { var osu = loadOsu(host); - var imported = LoadOszIntoOsu(osu); + var imported = await LoadOszIntoOsu(osu); if (set) imported.OnlineBeatmapSetID = 1234; @@ -229,7 +244,7 @@ namespace osu.Game.Tests.Beatmaps.IO deleteBeatmapSet(imported, osu); - var importedSecondTime = LoadOszIntoOsu(osu); + var importedSecondTime = await LoadOszIntoOsu(osu); // check the newly "imported" beatmap has been reimported due to mismatch (even though hashes matched) Assert.IsTrue(imported.ID != importedSecondTime.ID); @@ -243,7 +258,7 @@ namespace osu.Game.Tests.Beatmaps.IO } [Test] - public void TestImportWithDuplicateBeatmapIDs() + public async Task TestImportWithDuplicateBeatmapIDs() { //unfortunately for the time being we need to reference osu.Framework.Desktop for a game host here. using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestImportWithDuplicateBeatmapID")) @@ -284,7 +299,7 @@ namespace osu.Game.Tests.Beatmaps.IO var manager = osu.Dependencies.Get(); - var imported = manager.Import(toImport); + var imported = await manager.Import(toImport); Assert.NotNull(imported); Assert.AreEqual(null, imported.Beatmaps[0].OnlineBeatmapID); @@ -330,7 +345,7 @@ namespace osu.Game.Tests.Beatmaps.IO } [Test] - public void TestImportWhenFileOpen() + public async Task TestImportWhenFileOpen() { using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestImportWhenFileOpen")) { @@ -339,7 +354,7 @@ namespace osu.Game.Tests.Beatmaps.IO var osu = loadOsu(host); var temp = TestResources.GetTestBeatmapForImport(); using (File.OpenRead(temp)) - osu.Dependencies.Get().Import(temp); + await osu.Dependencies.Get().Import(temp); ensureLoaded(osu); File.Delete(temp); Assert.IsFalse(File.Exists(temp), "We likely held a read lock on the file when we shouldn't"); @@ -351,13 +366,13 @@ namespace osu.Game.Tests.Beatmaps.IO } } - public static BeatmapSetInfo LoadOszIntoOsu(OsuGameBase osu, string path = null) + public static async Task LoadOszIntoOsu(OsuGameBase osu, string path = null) { var temp = path ?? TestResources.GetTestBeatmapForImport(); var manager = osu.Dependencies.Get(); - manager.Import(temp); + await manager.Import(temp); var imported = manager.GetAllUsableBeatmapSets(); @@ -373,11 +388,32 @@ namespace osu.Game.Tests.Beatmaps.IO var manager = osu.Dependencies.Get(); manager.Delete(imported); - Assert.IsTrue(manager.GetAllUsableBeatmapSets().Count == 0); - Assert.AreEqual(1, manager.QueryBeatmapSets(_ => true).ToList().Count); + checkBeatmapSetCount(osu, 0); + checkBeatmapSetCount(osu, 1, true); + checkSingleReferencedFileCount(osu, 0); + Assert.IsTrue(manager.QueryBeatmapSets(_ => true).First().DeletePending); } + private void checkBeatmapSetCount(OsuGameBase osu, int expected, bool includeDeletePending = false) + { + var manager = osu.Dependencies.Get(); + + Assert.AreEqual(expected, includeDeletePending + ? manager.QueryBeatmapSets(_ => true).ToList().Count + : manager.GetAllUsableBeatmapSets().Count); + } + + private void checkBeatmapCount(OsuGameBase osu, int expected) + { + Assert.AreEqual(expected, osu.Dependencies.Get().QueryBeatmaps(_ => true).ToList().Count); + } + + private void checkSingleReferencedFileCount(OsuGameBase osu, int expected) + { + Assert.AreEqual(expected, osu.Dependencies.Get().QueryFiles(f => f.ReferenceCount == 1).Count()); + } + private OsuGameBase loadOsu(GameHost host) { var osu = new OsuGameBase(); diff --git a/osu.Game.Tests/Scores/IO/ImportScoreTest.cs b/osu.Game.Tests/Scores/IO/ImportScoreTest.cs index e39f18c3cd..4babb07213 100644 --- a/osu.Game.Tests/Scores/IO/ImportScoreTest.cs +++ b/osu.Game.Tests/Scores/IO/ImportScoreTest.cs @@ -23,13 +23,13 @@ namespace osu.Game.Tests.Scores.IO public class ImportScoreTest { [Test] - public void TestBasicImport() + public async Task TestBasicImport() { using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestBasicImport")) { try { - var osu = loadOsu(host); + var osu = await loadOsu(host); var toImport = new ScoreInfo { @@ -43,7 +43,7 @@ namespace osu.Game.Tests.Scores.IO OnlineScoreID = 12345, }; - var imported = loadIntoOsu(osu, toImport); + var imported = await loadIntoOsu(osu, toImport); Assert.AreEqual(toImport.Rank, imported.Rank); Assert.AreEqual(toImport.TotalScore, imported.TotalScore); @@ -62,20 +62,20 @@ namespace osu.Game.Tests.Scores.IO } [Test] - public void TestImportMods() + public async Task TestImportMods() { using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestImportMods")) { try { - var osu = loadOsu(host); + var osu = await loadOsu(host); var toImport = new ScoreInfo { Mods = new Mod[] { new OsuModHardRock(), new OsuModDoubleTime() }, }; - var imported = loadIntoOsu(osu, toImport); + var imported = await loadIntoOsu(osu, toImport); Assert.IsTrue(imported.Mods.Any(m => m is OsuModHardRock)); Assert.IsTrue(imported.Mods.Any(m => m is OsuModDoubleTime)); @@ -88,13 +88,13 @@ namespace osu.Game.Tests.Scores.IO } [Test] - public void TestImportStatistics() + public async Task TestImportStatistics() { using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestImportStatistics")) { try { - var osu = loadOsu(host); + var osu = await loadOsu(host); var toImport = new ScoreInfo { @@ -105,7 +105,7 @@ namespace osu.Game.Tests.Scores.IO } }; - var imported = loadIntoOsu(osu, toImport); + var imported = await loadIntoOsu(osu, toImport); Assert.AreEqual(toImport.Statistics[HitResult.Perfect], imported.Statistics[HitResult.Perfect]); Assert.AreEqual(toImport.Statistics[HitResult.Miss], imported.Statistics[HitResult.Miss]); @@ -117,7 +117,7 @@ namespace osu.Game.Tests.Scores.IO } } - private ScoreInfo loadIntoOsu(OsuGameBase osu, ScoreInfo score) + private async Task loadIntoOsu(OsuGameBase osu, ScoreInfo score) { var beatmapManager = osu.Dependencies.Get(); @@ -125,20 +125,24 @@ namespace osu.Game.Tests.Scores.IO score.Ruleset = new OsuRuleset().RulesetInfo; var scoreManager = osu.Dependencies.Get(); - scoreManager.Import(score); + await scoreManager.Import(score); return scoreManager.GetAllUsableScores().First(); } - private OsuGameBase loadOsu(GameHost host) + private async Task loadOsu(GameHost host) { var osu = new OsuGameBase(); + +#pragma warning disable 4014 Task.Run(() => host.Run(osu)); +#pragma warning restore 4014 + waitForOrAssert(() => osu.IsLoaded, @"osu! failed to start in a reasonable amount of time"); var beatmapFile = TestResources.GetTestBeatmapForImport(); var beatmapManager = osu.Dependencies.Get(); - beatmapManager.Import(beatmapFile); + await beatmapManager.Import(beatmapFile); return osu; } diff --git a/osu.Game.Tests/Visual/Background/TestSceneBackgroundScreenBeatmap.cs b/osu.Game.Tests/Visual/Background/TestSceneBackgroundScreenBeatmap.cs index 7104a420a3..8b941e4633 100644 --- a/osu.Game.Tests/Visual/Background/TestSceneBackgroundScreenBeatmap.cs +++ b/osu.Game.Tests/Visual/Background/TestSceneBackgroundScreenBeatmap.cs @@ -72,7 +72,7 @@ namespace osu.Game.Tests.Visual.Background Dependencies.Cache(manager = new BeatmapManager(LocalStorage, factory, rulesets, null, audio, host, Beatmap.Default)); Dependencies.Cache(new OsuConfigManager(LocalStorage)); - manager.Import(TestResources.GetTestBeatmapForImport()); + manager.Import(TestResources.GetTestBeatmapForImport()).Wait(); Beatmap.SetDefault(); } diff --git a/osu.Game.Tests/Visual/SongSelect/TestScenePlaySongSelect.cs b/osu.Game.Tests/Visual/SongSelect/TestScenePlaySongSelect.cs index 2664c7a42c..f5115c50a9 100644 --- a/osu.Game.Tests/Visual/SongSelect/TestScenePlaySongSelect.cs +++ b/osu.Game.Tests/Visual/SongSelect/TestScenePlaySongSelect.cs @@ -255,7 +255,7 @@ namespace osu.Game.Tests.Visual.SongSelect private void addRulesetImportStep(int id) => AddStep($"import test map for ruleset {id}", () => importForRuleset(id)); - private void importForRuleset(int id) => manager.Import(createTestBeatmapSet(getImportId(), rulesets.AvailableRulesets.Where(r => r.ID == id).ToArray())); + private void importForRuleset(int id) => manager.Import(createTestBeatmapSet(getImportId(), rulesets.AvailableRulesets.Where(r => r.ID == id).ToArray())).Wait(); private static int importId; private int getImportId() => ++importId; @@ -277,7 +277,7 @@ namespace osu.Game.Tests.Visual.SongSelect var usableRulesets = rulesets.AvailableRulesets.Where(r => r.ID != 2).ToArray(); for (int i = 0; i < 100; i += 10) - manager.Import(createTestBeatmapSet(i, usableRulesets)); + manager.Import(createTestBeatmapSet(i, usableRulesets)).Wait(); }); } diff --git a/osu.Game.Tests/Visual/UserInterface/TestSceneUpdateableBeatmapBackgroundSprite.cs b/osu.Game.Tests/Visual/UserInterface/TestSceneUpdateableBeatmapBackgroundSprite.cs index f59458ef8d..c361598354 100644 --- a/osu.Game.Tests/Visual/UserInterface/TestSceneUpdateableBeatmapBackgroundSprite.cs +++ b/osu.Game.Tests/Visual/UserInterface/TestSceneUpdateableBeatmapBackgroundSprite.cs @@ -32,7 +32,7 @@ namespace osu.Game.Tests.Visual.UserInterface this.api = api; this.rulesets = rulesets; - testBeatmap = ImportBeatmapTest.LoadOszIntoOsu(osu); + testBeatmap = ImportBeatmapTest.LoadOszIntoOsu(osu).Result; } [Test] diff --git a/osu.Game/Beatmaps/BeatmapInfo.cs b/osu.Game/Beatmaps/BeatmapInfo.cs index 52238c26fe..3c082bb71e 100644 --- a/osu.Game/Beatmaps/BeatmapInfo.cs +++ b/osu.Game/Beatmaps/BeatmapInfo.cs @@ -119,7 +119,7 @@ namespace osu.Game.Beatmaps /// public List Scores { get; set; } - public override string ToString() => $"{Metadata} [{Version}]"; + public override string ToString() => $"{Metadata} [{Version}]".Trim(); public bool Equals(BeatmapInfo other) { diff --git a/osu.Game/Beatmaps/BeatmapManager.cs b/osu.Game/Beatmaps/BeatmapManager.cs index b6fe7f88fa..d90657bff5 100644 --- a/osu.Game/Beatmaps/BeatmapManager.cs +++ b/osu.Game/Beatmaps/BeatmapManager.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; using System.Linq.Expressions; +using System.Threading; using System.Threading.Tasks; using Microsoft.EntityFrameworkCore; using osu.Framework.Audio; @@ -14,6 +15,7 @@ using osu.Framework.Extensions; using osu.Framework.Graphics.Textures; using osu.Framework.Logging; using osu.Framework.Platform; +using osu.Framework.Threading; using osu.Game.Beatmaps.Formats; using osu.Game.Database; using osu.Game.IO.Archives; @@ -72,6 +74,8 @@ namespace osu.Game.Beatmaps private readonly List currentDownloads = new List(); + private readonly BeatmapUpdateQueue updateQueue; + public BeatmapManager(Storage storage, IDatabaseContextFactory contextFactory, RulesetStore rulesets, IAPIProvider api, AudioManager audioManager, GameHost host = null, WorkingBeatmap defaultBeatmap = null) : base(storage, contextFactory, new BeatmapStore(contextFactory), host) @@ -86,9 +90,11 @@ namespace osu.Game.Beatmaps beatmaps = (BeatmapStore)ModelStore; beatmaps.BeatmapHidden += b => BeatmapHidden?.Invoke(b); beatmaps.BeatmapRestored += b => BeatmapRestored?.Invoke(b); + + updateQueue = new BeatmapUpdateQueue(api); } - protected override void Populate(BeatmapSetInfo beatmapSet, ArchiveReader archive) + protected override Task Populate(BeatmapSetInfo beatmapSet, ArchiveReader archive, CancellationToken cancellationToken = default) { if (archive != null) beatmapSet.Beatmaps = createBeatmapDifficulties(archive); @@ -104,8 +110,7 @@ namespace osu.Game.Beatmaps validateOnlineIds(beatmapSet); - foreach (BeatmapInfo b in beatmapSet.Beatmaps) - fetchAndPopulateOnlineValues(b); + return updateQueue.UpdateAsync(beatmapSet, cancellationToken); } protected override void PreImport(BeatmapSetInfo beatmapSet) @@ -122,7 +127,7 @@ namespace osu.Game.Beatmaps { Delete(existingOnlineId); beatmaps.PurgeDeletable(s => s.ID == existingOnlineId.ID); - Logger.Log($"Found existing beatmap set with same OnlineBeatmapSetID ({beatmapSet.OnlineBeatmapSetID}). It has been purged.", LoggingTarget.Database); + LogForModel(beatmapSet, $"Found existing beatmap set with same OnlineBeatmapSetID ({beatmapSet.OnlineBeatmapSetID}). It has been purged."); } } } @@ -181,10 +186,10 @@ namespace osu.Game.Beatmaps request.Success += filename => { - Task.Factory.StartNew(() => + Task.Factory.StartNew(async () => { // This gets scheduled back to the update thread, but we want the import to run in the background. - Import(downloadNotification, filename); + await Import(downloadNotification, filename); currentDownloads.Remove(request); }, TaskCreationOptions.LongRunning); }; @@ -381,47 +386,6 @@ namespace osu.Game.Beatmaps return beatmapInfos; } - /// - /// Query the API to populate missing values like OnlineBeatmapID / OnlineBeatmapSetID or (Rank-)Status. - /// - /// The beatmap to populate. - /// Whether to re-query if the provided beatmap already has populated values. - /// True if population was successful. - private bool fetchAndPopulateOnlineValues(BeatmapInfo beatmap, bool force = false) - { - if (api?.State != APIState.Online) - return false; - - if (!force && beatmap.OnlineBeatmapID != null && beatmap.BeatmapSet.OnlineBeatmapSetID != null - && beatmap.Status != BeatmapSetOnlineStatus.None && beatmap.BeatmapSet.Status != BeatmapSetOnlineStatus.None) - return true; - - Logger.Log("Attempting online lookup for the missing values...", LoggingTarget.Database); - - try - { - var req = new GetBeatmapRequest(beatmap); - - req.Perform(api); - - var res = req.Result; - - Logger.Log($"Successfully mapped to {res.OnlineBeatmapSetID} / {res.OnlineBeatmapID}.", LoggingTarget.Database); - - beatmap.Status = res.Status; - beatmap.BeatmapSet.Status = res.BeatmapSet.Status; - beatmap.BeatmapSet.OnlineBeatmapSetID = res.OnlineBeatmapSetID; - beatmap.OnlineBeatmapID = res.OnlineBeatmapID; - - return true; - } - catch (Exception e) - { - Logger.Log($"Failed ({e})", LoggingTarget.Database); - return false; - } - } - /// /// A dummy WorkingBeatmap for the purpose of retrieving a beatmap for star difficulty calculation. /// @@ -455,5 +419,55 @@ namespace osu.Game.Beatmaps public override bool IsImportant => false; } } + + private class BeatmapUpdateQueue + { + private readonly IAPIProvider api; + + private const int update_queue_request_concurrency = 4; + + private readonly ThreadedTaskScheduler updateScheduler = new ThreadedTaskScheduler(update_queue_request_concurrency, nameof(BeatmapUpdateQueue)); + + public BeatmapUpdateQueue(IAPIProvider api) + { + this.api = api; + } + + public Task UpdateAsync(BeatmapSetInfo beatmapSet, CancellationToken cancellationToken) + { + if (api?.State != APIState.Online) + return Task.CompletedTask; + + LogForModel(beatmapSet, "Performing online lookups..."); + return Task.WhenAll(beatmapSet.Beatmaps.Select(b => UpdateAsync(beatmapSet, b, cancellationToken)).ToArray()); + } + + // todo: expose this when we need to do individual difficulty lookups. + protected Task UpdateAsync(BeatmapSetInfo beatmapSet, BeatmapInfo beatmap, CancellationToken cancellationToken) + => Task.Factory.StartNew(() => update(beatmapSet, beatmap), cancellationToken, TaskCreationOptions.HideScheduler, updateScheduler); + + private void update(BeatmapSetInfo set, BeatmapInfo beatmap) + { + if (api?.State != APIState.Online) + return; + + var req = new GetBeatmapRequest(beatmap); + + req.Success += res => + { + LogForModel(set, $"Online retrieval mapped {beatmap} to {res.OnlineBeatmapSetID} / {res.OnlineBeatmapID}."); + + beatmap.Status = res.Status; + beatmap.BeatmapSet.Status = res.BeatmapSet.Status; + beatmap.BeatmapSet.OnlineBeatmapSetID = res.OnlineBeatmapSetID; + beatmap.OnlineBeatmapID = res.OnlineBeatmapID; + }; + + req.Failure += e => { LogForModel(set, $"Online retrieval failed for {beatmap}", e); }; + + // intentionally blocking to limit web request concurrency + req.Perform(api); + } + } } } diff --git a/osu.Game/Database/ArchiveModelManager.cs b/osu.Game/Database/ArchiveModelManager.cs index b5a9c70e47..20919f0899 100644 --- a/osu.Game/Database/ArchiveModelManager.cs +++ b/osu.Game/Database/ArchiveModelManager.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.IO; using System.Linq; +using System.Threading; using System.Threading.Tasks; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore; @@ -13,6 +14,7 @@ using osu.Framework.Extensions; using osu.Framework.IO.File; using osu.Framework.Logging; using osu.Framework.Platform; +using osu.Framework.Threading; using osu.Game.IO; using osu.Game.IO.Archives; using osu.Game.IPC; @@ -29,7 +31,7 @@ namespace osu.Game.Database /// /// The model type. /// The associated file join type. - public abstract class ArchiveModelManager : ICanAcceptFiles + public abstract class ArchiveModelManager : ArchiveModelManager, ICanAcceptFiles where TModel : class, IHasFiles, IHasPrimaryKey, ISoftDelete where TFileModel : INamedFileInfo, new() { @@ -130,56 +132,50 @@ namespace osu.Game.Database /// This will post notifications tracking progress. /// /// One or more archive locations on disk. - public void Import(params string[] paths) + public Task Import(params string[] paths) { var notification = new ProgressNotification { State = ProgressNotificationState.Active }; PostNotification?.Invoke(notification); - Import(notification, paths); + + return Import(notification, paths); } - protected void Import(ProgressNotification notification, params string[] paths) + protected async Task Import(ProgressNotification notification, params string[] paths) { notification.Progress = 0; notification.Text = "Import is initialising..."; - var term = $"{typeof(TModel).Name.Replace("Info", "").ToLower()}"; - - List imported = new List(); - int current = 0; - foreach (string path in paths) + var imported = new List(); + + await Task.WhenAll(paths.Select(async path => { - if (notification.State == ProgressNotificationState.Cancelled) - // user requested abort - return; + notification.CancellationToken.ThrowIfCancellationRequested(); try { - var text = "Importing "; + var model = await Import(path, notification.CancellationToken); - if (path.Length > 1) - text += $"{++current} of {paths.Length} {term}s.."; - else - text += $"{term}.."; + lock (imported) + { + imported.Add(model); + current++; - // only show the filename if it isn't a temporary one (as those look ugly). - if (!path.Contains(Path.GetTempPath())) - text += $"\n{Path.GetFileName(path)}"; - - notification.Text = text; - - imported.Add(Import(path)); - - notification.Progress = (float)current / paths.Length; + notification.Text = $"Imported {current} of {paths.Length} {humanisedModelName}s"; + notification.Progress = (float)current / paths.Length; + } + } + catch (TaskCanceledException) + { + throw; } catch (Exception e) { - e = e.InnerException ?? e; - Logger.Error(e, $@"Could not import ({Path.GetFileName(path)})"); + Logger.Error(e, $@"Could not import ({Path.GetFileName(path)})", LoggingTarget.Database); } - } + })); if (imported.Count == 0) { @@ -190,7 +186,7 @@ namespace osu.Game.Database { notification.CompletionText = imported.Count == 1 ? $"Imported {imported.First()}!" - : $"Imported {current} {term}s!"; + : $"Imported {current} {humanisedModelName}s!"; if (imported.Count > 0 && PresentImport != null) { @@ -210,12 +206,15 @@ namespace osu.Game.Database /// Import one from the filesystem and delete the file on success. /// /// The archive location on disk. + /// An optional cancellation token. /// The imported model, if successful. - public TModel Import(string path) + public async Task Import(string path, CancellationToken cancellationToken = default) { + cancellationToken.ThrowIfCancellationRequested(); + TModel import; using (ArchiveReader reader = getReaderFrom(path)) - import = Import(reader); + import = await Import(reader, cancellationToken); // We may or may not want to delete the file depending on where it is stored. // e.g. reconstructing/repairing database with items from default storage. @@ -228,7 +227,7 @@ namespace osu.Game.Database } catch (Exception e) { - Logger.Error(e, $@"Could not delete original file after import ({Path.GetFileName(path)})"); + LogForModel(import, $@"Could not delete original file after import ({Path.GetFileName(path)})", e); } return import; @@ -243,23 +242,32 @@ namespace osu.Game.Database /// Import an item from an . /// /// The archive to be imported. - public TModel Import(ArchiveReader archive) + /// An optional cancellation token. + public Task Import(ArchiveReader archive, CancellationToken cancellationToken = default) { + cancellationToken.ThrowIfCancellationRequested(); + + TModel model = null; + try { - var model = CreateModel(archive); + model = CreateModel(archive); if (model == null) return null; model.Hash = computeHash(archive); - - return Import(model, archive); + } + catch (TaskCanceledException) + { + throw; } catch (Exception e) { - Logger.Error(e, $"Model creation of {archive.Name} failed.", LoggingTarget.Database); + LogForModel(model, $"Model creation of {archive.Name} failed.", e); return null; } + + return Import(model, archive, cancellationToken); } /// @@ -269,6 +277,16 @@ namespace osu.Game.Database /// protected abstract string[] HashableFileTypes { get; } + protected static void LogForModel(TModel model, string message, Exception e = null) + { + string prefix = $"[{(model?.Hash ?? "?????").Substring(0, 5)}]"; + + if (e != null) + Logger.Error(e, $"{prefix} {message}", LoggingTarget.Database); + else + Logger.Log($"{prefix} {message}", LoggingTarget.Database); + } + /// /// Create a SHA-2 hash from the provided archive based on file content of all files matching . /// @@ -288,13 +306,30 @@ namespace osu.Game.Database /// /// The model to be imported. /// An optional archive to use for model population. - public TModel Import(TModel item, ArchiveReader archive = null) + /// An optional cancellation token. + public async Task Import(TModel item, ArchiveReader archive = null, CancellationToken cancellationToken = default) => await Task.Factory.StartNew(async () => { + cancellationToken.ThrowIfCancellationRequested(); + delayEvents(); + void rollback() + { + if (!Delete(item)) + { + // We may have not yet added the model to the underlying table, but should still clean up files. + LogForModel(item, "Dereferencing files for incomplete import."); + Files.Dereference(item.Files.Select(f => f.FileInfo).ToArray()); + } + } + try { - Logger.Log($"Importing {item}...", LoggingTarget.Database); + LogForModel(item, "Beginning import..."); + + item.Files = archive != null ? createFileInfos(archive, Files) : new List(); + + await Populate(item, archive, cancellationToken); using (var write = ContextFactory.GetForWrite()) // used to share a context for full import. keep in mind this will block all writes. { @@ -302,11 +337,6 @@ namespace osu.Game.Database { if (!write.IsTransactionLeader) throw new InvalidOperationException($"Ensure there is no parent transaction so errors can correctly be handled by {this}"); - if (archive != null) - item.Files = createFileInfos(archive, Files); - - Populate(item, archive); - var existing = CheckForExisting(item); if (existing != null) @@ -314,15 +344,17 @@ namespace osu.Game.Database if (CanUndelete(existing, item)) { Undelete(existing); - Logger.Log($"Found existing {typeof(TModel)} for {item} (ID {existing.ID}). Skipping import.", LoggingTarget.Database); + LogForModel(item, $"Found existing {humanisedModelName} for {item} (ID {existing.ID}) – skipping import."); handleEvent(() => ItemAdded?.Invoke(existing, true)); + + // existing item will be used; rollback new import and exit early. + rollback(); + flushEvents(true); return existing; } - else - { - Delete(existing); - ModelStore.PurgeDeletable(s => s.ID == existing.ID); - } + + Delete(existing); + ModelStore.PurgeDeletable(s => s.ID == existing.ID); } PreImport(item); @@ -337,21 +369,21 @@ namespace osu.Game.Database } } - Logger.Log($"Import of {item} successfully completed!", LoggingTarget.Database); + LogForModel(item, "Import successfully completed!"); } catch (Exception e) { - Logger.Error(e, $"Import of {item} failed and has been rolled back.", LoggingTarget.Database); - item = null; - } - finally - { - // we only want to flush events after we've confirmed the write context didn't have any errors. - flushEvents(item != null); + if (!(e is TaskCanceledException)) + LogForModel(item, "Database import or population failed and has been rolled back.", e); + + rollback(); + flushEvents(false); + throw; } + flushEvents(true); return item; - } + }, cancellationToken, TaskCreationOptions.HideScheduler, IMPORT_SCHEDULER).Unwrap(); /// /// Perform an update of the specified item. @@ -534,7 +566,7 @@ namespace osu.Game.Database return Task.CompletedTask; } - return Task.Factory.StartNew(() => Import(stable.GetDirectories(ImportFromStablePath).Select(f => stable.GetFullPath(f)).ToArray()), TaskCreationOptions.LongRunning); + return Task.Run(async () => await Import(stable.GetDirectories(ImportFromStablePath).Select(f => stable.GetFullPath(f)).ToArray())); } #endregion @@ -553,9 +585,8 @@ namespace osu.Game.Database /// /// The model to populate. /// The archive to use as a reference for population. May be null. - protected virtual void Populate(TModel model, [CanBeNull] ArchiveReader archive) - { - } + /// An optional cancellation token. + protected virtual Task Populate(TModel model, [CanBeNull] ArchiveReader archive, CancellationToken cancellationToken = default) => Task.CompletedTask; /// /// Perform any final actions before the import to database executes. @@ -583,6 +614,8 @@ namespace osu.Game.Database private DbSet queryModel() => ContextFactory.Get().Set(); + private string humanisedModelName => $"{typeof(TModel).Name.Replace("Info", "").ToLower()}"; + /// /// Creates an from a valid storage path. /// @@ -600,4 +633,18 @@ namespace osu.Game.Database throw new InvalidFormatException($"{path} is not a valid archive"); } } + + public abstract class ArchiveModelManager + { + private const int import_queue_request_concurrency = 1; + + /// + /// A singleton scheduler shared by all . + /// + /// + /// This scheduler generally performs IO and CPU intensive work so concurrency is limited harshly. + /// It is mainly being used as a queue mechanism for large imports. + /// + protected static readonly ThreadedTaskScheduler IMPORT_SCHEDULER = new ThreadedTaskScheduler(import_queue_request_concurrency, nameof(ArchiveModelManager)); + } } diff --git a/osu.Game/Database/ICanAcceptFiles.cs b/osu.Game/Database/ICanAcceptFiles.cs index f55d0c389e..b9f882468d 100644 --- a/osu.Game/Database/ICanAcceptFiles.cs +++ b/osu.Game/Database/ICanAcceptFiles.cs @@ -1,6 +1,8 @@ // Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. // See the LICENCE file in the repository root for full licence text. +using System.Threading.Tasks; + namespace osu.Game.Database { /// @@ -12,7 +14,7 @@ namespace osu.Game.Database /// Import the specified paths. /// /// The files which should be imported. - void Import(params string[] paths); + Task Import(params string[] paths); /// /// An array of accepted file extensions (in the standard format of ".abc"). diff --git a/osu.Game/IO/FileStore.cs b/osu.Game/IO/FileStore.cs index 458f8964f9..370d6786f5 100644 --- a/osu.Game/IO/FileStore.cs +++ b/osu.Game/IO/FileStore.cs @@ -2,8 +2,11 @@ // See the LICENCE file in the repository root for full licence text. using System; +using System.Collections.Generic; using System.IO; using System.Linq; +using System.Linq.Expressions; +using Microsoft.EntityFrameworkCore; using osu.Framework.Extensions; using osu.Framework.IO.Stores; using osu.Framework.Logging; @@ -27,6 +30,13 @@ namespace osu.Game.IO Store = new StorageBackedResourceStore(Storage); } + /// + /// Perform a lookup query on available s. + /// + /// The query. + /// Results from the provided query. + public IEnumerable QueryFiles(Expression> query) => ContextFactory.Get().Set().AsNoTracking().Where(f => f.ReferenceCount > 0).Where(query); + public FileInfo Add(Stream data, bool reference = true) { using (var usage = ContextFactory.GetForWrite()) diff --git a/osu.Game/IPC/ArchiveImportIPCChannel.cs b/osu.Game/IPC/ArchiveImportIPCChannel.cs index fc747cd446..484db932f8 100644 --- a/osu.Game/IPC/ArchiveImportIPCChannel.cs +++ b/osu.Game/IPC/ArchiveImportIPCChannel.cs @@ -38,7 +38,7 @@ namespace osu.Game.IPC } if (importer.HandledExtensions.Contains(Path.GetExtension(path)?.ToLowerInvariant())) - importer.Import(path); + await importer.Import(path); } } diff --git a/osu.Game/OsuGameBase.cs b/osu.Game/OsuGameBase.cs index f9128687d6..d6b8caaf5b 100644 --- a/osu.Game/OsuGameBase.cs +++ b/osu.Game/OsuGameBase.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; using System.Reflection; +using System.Threading.Tasks; using osu.Framework.Allocation; using osu.Framework.Audio; using osu.Framework.Bindables; @@ -268,13 +269,13 @@ namespace osu.Game private readonly List fileImporters = new List(); - public void Import(params string[] paths) + public async Task Import(params string[] paths) { var extension = Path.GetExtension(paths.First())?.ToLowerInvariant(); foreach (var importer in fileImporters) if (importer.HandledExtensions.Contains(extension)) - importer.Import(paths); + await importer.Import(paths); } public string[] HandledExtensions => fileImporters.SelectMany(i => i.HandledExtensions).ToArray(); diff --git a/osu.Game/Overlays/Notifications/ProgressNotification.cs b/osu.Game/Overlays/Notifications/ProgressNotification.cs index 857a0bda9e..c8e081d29f 100644 --- a/osu.Game/Overlays/Notifications/ProgressNotification.cs +++ b/osu.Game/Overlays/Notifications/ProgressNotification.cs @@ -2,6 +2,7 @@ // See the LICENCE file in the repository root for full licence text. using System; +using System.Threading; using osu.Framework.Allocation; using osu.Framework.Graphics; using osu.Framework.Graphics.Containers; @@ -36,6 +37,10 @@ namespace osu.Game.Overlays.Notifications State = state; } + private readonly CancellationTokenSource cancellationTokenSource = new CancellationTokenSource(); + + public CancellationToken CancellationToken => cancellationTokenSource.Token; + public virtual ProgressNotificationState State { get => state; @@ -62,6 +67,8 @@ namespace osu.Game.Overlays.Notifications break; case ProgressNotificationState.Cancelled: + cancellationTokenSource.Cancel(); + Light.Colour = colourCancelled; Light.Pulsate = false; progressBar.Active = false; diff --git a/osu.Game/Screens/Menu/Intro.cs b/osu.Game/Screens/Menu/Intro.cs index 98a2fe8f13..cf5d247482 100644 --- a/osu.Game/Screens/Menu/Intro.cs +++ b/osu.Game/Screens/Menu/Intro.cs @@ -66,7 +66,7 @@ namespace osu.Game.Screens.Menu if (setInfo == null) { // we need to import the default menu background beatmap - setInfo = beatmaps.Import(new ZipArchiveReader(game.Resources.GetStream(@"Tracks/circles.osz"), "circles.osz")); + setInfo = beatmaps.Import(new ZipArchiveReader(game.Resources.GetStream(@"Tracks/circles.osz"), "circles.osz")).Result; setInfo.Protected = true; beatmaps.Update(setInfo); diff --git a/osu.Game/Screens/Play/Player.cs b/osu.Game/Screens/Play/Player.cs index 35ef7b3200..d69d64c2b1 100644 --- a/osu.Game/Screens/Play/Player.cs +++ b/osu.Game/Screens/Play/Player.cs @@ -279,7 +279,7 @@ namespace osu.Game.Screens.Play var score = CreateScore(); if (DrawableRuleset.ReplayScore == null) - scoreManager.Import(score); + scoreManager.Import(score).Wait(); this.Push(CreateResults(score)); diff --git a/osu.Game/Screens/Select/SongSelect.cs b/osu.Game/Screens/Select/SongSelect.cs index d0645dbab6..f9df8c3a39 100644 --- a/osu.Game/Screens/Select/SongSelect.cs +++ b/osu.Game/Screens/Select/SongSelect.cs @@ -32,6 +32,7 @@ using osuTK.Input; using System; using System.Collections.Generic; using System.Linq; +using System.Threading.Tasks; using osu.Framework.Graphics.Sprites; namespace osu.Game.Screens.Select @@ -256,8 +257,8 @@ namespace osu.Game.Screens.Select if (!beatmaps.GetAllUsableBeatmapSets().Any() && beatmaps.StableInstallationAvailable) dialogOverlay.Push(new ImportFromStablePopup(() => { - beatmaps.ImportFromStableAsync(); - skins.ImportFromStableAsync(); + Task.Run(beatmaps.ImportFromStableAsync); + Task.Run(skins.ImportFromStableAsync); })); }); } diff --git a/osu.Game/Skinning/SkinManager.cs b/osu.Game/Skinning/SkinManager.cs index 3a4d44f608..73cc47ea47 100644 --- a/osu.Game/Skinning/SkinManager.cs +++ b/osu.Game/Skinning/SkinManager.cs @@ -5,6 +5,8 @@ using System; using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; +using System.Threading; +using System.Threading.Tasks; using Microsoft.EntityFrameworkCore; using osu.Framework.Audio; using osu.Framework.Audio.Sample; @@ -71,9 +73,9 @@ namespace osu.Game.Skinning protected override SkinInfo CreateModel(ArchiveReader archive) => new SkinInfo { Name = archive.Name }; - protected override void Populate(SkinInfo model, ArchiveReader archive) + protected override async Task Populate(SkinInfo model, ArchiveReader archive, CancellationToken cancellationToken = default) { - base.Populate(model, archive); + await base.Populate(model, archive, cancellationToken); Skin reference = getSkin(model);