diff --git a/osu.Game.Tests/Database/BeatmapImporterTests.cs b/osu.Game.Tests/Database/BeatmapImporterTests.cs new file mode 100644 index 0000000000..4cdcf507b6 --- /dev/null +++ b/osu.Game.Tests/Database/BeatmapImporterTests.cs @@ -0,0 +1,820 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using NUnit.Framework; +using osu.Framework.Extensions; +using osu.Framework.Extensions.ObjectExtensions; +using osu.Framework.Logging; +using osu.Game.Beatmaps; +using osu.Game.Database; +using osu.Game.IO.Archives; +using osu.Game.Models; +using osu.Game.Stores; +using osu.Game.Tests.Resources; +using Realms; +using SharpCompress.Archives; +using SharpCompress.Archives.Zip; +using SharpCompress.Common; +using SharpCompress.Writers.Zip; + +#nullable enable + +namespace osu.Game.Tests.Database +{ + [TestFixture] + public class BeatmapImporterTests : RealmTest + { + [Test] + public void TestImportBeatmapThenCleanup() + { + RunTestWithRealmAsync(async (realmFactory, storage) => + { + using (var importer = new BeatmapImporter(realmFactory, storage)) + using (new RealmRulesetStore(realmFactory, storage)) + { + ILive? 
imported; + + using (var reader = new ZipArchiveReader(TestResources.GetTestBeatmapStream())) + imported = await importer.Import(reader); + + Assert.AreEqual(1, realmFactory.Context.All().Count()); + + Assert.NotNull(imported); + Debug.Assert(imported != null); + + imported.PerformWrite(s => s.DeletePending = true); + + Assert.AreEqual(1, realmFactory.Context.All().Count(s => s.DeletePending)); + } + }); + + Logger.Log("Running with no work to purge pending deletions"); + + RunTestWithRealm((realmFactory, _) => { Assert.AreEqual(0, realmFactory.Context.All().Count()); }); + } + + [Test] + public void TestImportWhenClosed() + { + RunTestWithRealmAsync(async (realmFactory, storage) => + { + using var importer = new BeatmapImporter(realmFactory, storage); + using var store = new RealmRulesetStore(realmFactory, storage); + + await LoadOszIntoStore(importer, realmFactory.Context); + }); + } + + [Test] + public void TestImportThenDelete() + { + RunTestWithRealmAsync(async (realmFactory, storage) => + { + using var importer = new BeatmapImporter(realmFactory, storage); + using var store = new RealmRulesetStore(realmFactory, storage); + + var imported = await LoadOszIntoStore(importer, realmFactory.Context); + + deleteBeatmapSet(imported, realmFactory.Context); + }); + } + + [Test] + public void TestImportThenDeleteFromStream() + { + RunTestWithRealmAsync(async (realmFactory, storage) => + { + using var importer = new BeatmapImporter(realmFactory, storage); + using var store = new RealmRulesetStore(realmFactory, storage); + + var tempPath = TestResources.GetTestBeatmapForImport(); + + ILive? 
importedSet; + + using (var stream = File.OpenRead(tempPath)) + { + importedSet = await importer.Import(new ImportTask(stream, Path.GetFileName(tempPath))); + ensureLoaded(realmFactory.Context); + } + + Assert.NotNull(importedSet); + Debug.Assert(importedSet != null); + + Assert.IsTrue(File.Exists(tempPath), "Stream source file somehow went missing"); + File.Delete(tempPath); + + var imported = realmFactory.Context.All().First(beatmapSet => beatmapSet.ID == importedSet.ID); + + deleteBeatmapSet(imported, realmFactory.Context); + }); + } + + [Test] + public void TestImportThenImport() + { + RunTestWithRealmAsync(async (realmFactory, storage) => + { + using var importer = new BeatmapImporter(realmFactory, storage); + using var store = new RealmRulesetStore(realmFactory, storage); + + var imported = await LoadOszIntoStore(importer, realmFactory.Context); + var importedSecondTime = await LoadOszIntoStore(importer, realmFactory.Context); + + // check the newly "imported" beatmap is actually just the restored previous import. since it matches hash. 
+ Assert.IsTrue(imported.ID == importedSecondTime.ID); + Assert.IsTrue(imported.Beatmaps.First().ID == importedSecondTime.Beatmaps.First().ID); + + checkBeatmapSetCount(realmFactory.Context, 1); + checkSingleReferencedFileCount(realmFactory.Context, 18); + }); + } + + [Test] + public void TestImportThenImportWithReZip() + { + RunTestWithRealmAsync(async (realmFactory, storage) => + { + using var importer = new BeatmapImporter(realmFactory, storage); + using var store = new RealmRulesetStore(realmFactory, storage); + + var temp = TestResources.GetTestBeatmapForImport(); + + string extractedFolder = $"{temp}_extracted"; + Directory.CreateDirectory(extractedFolder); + + try + { + var imported = await LoadOszIntoStore(importer, realmFactory.Context); + + string hashBefore = hashFile(temp); + + using (var zip = ZipArchive.Open(temp)) + zip.WriteToDirectory(extractedFolder); + + using (var zip = ZipArchive.Create()) + { + zip.AddAllFromDirectory(extractedFolder); + zip.SaveTo(temp, new ZipWriterOptions(CompressionType.Deflate)); + } + + // zip files differ because different compression or encoder. + Assert.AreNotEqual(hashBefore, hashFile(temp)); + + var importedSecondTime = await importer.Import(new ImportTask(temp)); + + ensureLoaded(realmFactory.Context); + + Assert.NotNull(importedSecondTime); + Debug.Assert(importedSecondTime != null); + + // but contents doesn't, so existing should still be used. 
+ Assert.IsTrue(imported.ID == importedSecondTime.ID); + Assert.IsTrue(imported.Beatmaps.First().ID == importedSecondTime.PerformRead(s => s.Beatmaps.First().ID)); + } + finally + { + Directory.Delete(extractedFolder, true); + } + }); + } + + [Test] + public void TestImportThenImportWithChangedHashedFile() + { + RunTestWithRealmAsync(async (realmFactory, storage) => + { + using var importer = new BeatmapImporter(realmFactory, storage); + using var store = new RealmRulesetStore(realmFactory, storage); + + var temp = TestResources.GetTestBeatmapForImport(); + + string extractedFolder = $"{temp}_extracted"; + Directory.CreateDirectory(extractedFolder); + + try + { + var imported = await LoadOszIntoStore(importer, realmFactory.Context); + + await createScoreForBeatmap(realmFactory.Context, imported.Beatmaps.First()); + + using (var zip = ZipArchive.Open(temp)) + zip.WriteToDirectory(extractedFolder); + + // arbitrary write to hashed file + // this triggers the special BeatmapManager.PreImport deletion/replacement flow. + using (var sw = new FileInfo(Directory.GetFiles(extractedFolder, "*.osu").First()).AppendText()) + await sw.WriteLineAsync("// changed"); + + using (var zip = ZipArchive.Create()) + { + zip.AddAllFromDirectory(extractedFolder); + zip.SaveTo(temp, new ZipWriterOptions(CompressionType.Deflate)); + } + + var importedSecondTime = await importer.Import(new ImportTask(temp)); + + ensureLoaded(realmFactory.Context); + + // check the newly "imported" beatmap is not the original. 
+ Assert.NotNull(importedSecondTime); + Debug.Assert(importedSecondTime != null); + + Assert.IsTrue(imported.ID != importedSecondTime.ID); + Assert.IsTrue(imported.Beatmaps.First().ID != importedSecondTime.PerformRead(s => s.Beatmaps.First().ID)); + } + finally + { + Directory.Delete(extractedFolder, true); + } + }); + } + + [Test] + [Ignore("intentionally broken by import optimisations")] + public void TestImportThenImportWithChangedFile() + { + RunTestWithRealmAsync(async (realmFactory, storage) => + { + using var importer = new BeatmapImporter(realmFactory, storage); + using var store = new RealmRulesetStore(realmFactory, storage); + + var temp = TestResources.GetTestBeatmapForImport(); + + string extractedFolder = $"{temp}_extracted"; + Directory.CreateDirectory(extractedFolder); + + try + { + var imported = await LoadOszIntoStore(importer, realmFactory.Context); + + using (var zip = ZipArchive.Open(temp)) + zip.WriteToDirectory(extractedFolder); + + // arbitrary write to non-hashed file + using (var sw = new FileInfo(Directory.GetFiles(extractedFolder, "*.mp3").First()).AppendText()) + await sw.WriteLineAsync("text"); + + using (var zip = ZipArchive.Create()) + { + zip.AddAllFromDirectory(extractedFolder); + zip.SaveTo(temp, new ZipWriterOptions(CompressionType.Deflate)); + } + + var importedSecondTime = await importer.Import(new ImportTask(temp)); + + ensureLoaded(realmFactory.Context); + + Assert.NotNull(importedSecondTime); + Debug.Assert(importedSecondTime != null); + + // check the newly "imported" beatmap is not the original. 
+ Assert.IsTrue(imported.ID != importedSecondTime.ID); + Assert.IsTrue(imported.Beatmaps.First().ID != importedSecondTime.PerformRead(s => s.Beatmaps.First().ID)); + } + finally + { + Directory.Delete(extractedFolder, true); + } + }); + } + + [Test] + public void TestImportThenImportWithDifferentFilename() + { + RunTestWithRealmAsync(async (realmFactory, storage) => + { + using var importer = new BeatmapImporter(realmFactory, storage); + using var store = new RealmRulesetStore(realmFactory, storage); + + var temp = TestResources.GetTestBeatmapForImport(); + + string extractedFolder = $"{temp}_extracted"; + Directory.CreateDirectory(extractedFolder); + + try + { + var imported = await LoadOszIntoStore(importer, realmFactory.Context); + + using (var zip = ZipArchive.Open(temp)) + zip.WriteToDirectory(extractedFolder); + + // change filename + var firstFile = new FileInfo(Directory.GetFiles(extractedFolder).First()); + firstFile.MoveTo(Path.Combine(firstFile.DirectoryName.AsNonNull(), $"{firstFile.Name}-changed{firstFile.Extension}")); + + using (var zip = ZipArchive.Create()) + { + zip.AddAllFromDirectory(extractedFolder); + zip.SaveTo(temp, new ZipWriterOptions(CompressionType.Deflate)); + } + + var importedSecondTime = await importer.Import(new ImportTask(temp)); + + ensureLoaded(realmFactory.Context); + + Assert.NotNull(importedSecondTime); + Debug.Assert(importedSecondTime != null); + + // check the newly "imported" beatmap is not the original. 
+ Assert.IsTrue(imported.ID != importedSecondTime.ID); + Assert.IsTrue(imported.Beatmaps.First().ID != importedSecondTime.PerformRead(s => s.Beatmaps.First().ID)); + } + finally + { + Directory.Delete(extractedFolder, true); + } + }); + } + + [Test] + [Ignore("intentionally broken by import optimisations")] + public void TestImportCorruptThenImport() + { + RunTestWithRealmAsync(async (realmFactory, storage) => + { + using var importer = new BeatmapImporter(realmFactory, storage); + using var store = new RealmRulesetStore(realmFactory, storage); + + var imported = await LoadOszIntoStore(importer, realmFactory.Context); + + var firstFile = imported.Files.First(); + + long originalLength; + using (var stream = storage.GetStream(firstFile.File.StoragePath)) + originalLength = stream.Length; + + using (var stream = storage.GetStream(firstFile.File.StoragePath, FileAccess.Write, FileMode.Create)) + stream.WriteByte(0); + + var importedSecondTime = await LoadOszIntoStore(importer, realmFactory.Context); + + using (var stream = storage.GetStream(firstFile.File.StoragePath)) + Assert.AreEqual(stream.Length, originalLength, "Corruption was not fixed on second import"); + + // check the newly "imported" beatmap is actually just the restored previous import. since it matches hash. 
+ Assert.IsTrue(imported.ID == importedSecondTime.ID); + Assert.IsTrue(imported.Beatmaps.First().ID == importedSecondTime.Beatmaps.First().ID); + + checkBeatmapSetCount(realmFactory.Context, 1); + checkSingleReferencedFileCount(realmFactory.Context, 18); + }); + } + + [Test] + public void TestRollbackOnFailure() + { + RunTestWithRealmAsync(async (realmFactory, storage) => + { + int loggedExceptionCount = 0; + + Logger.NewEntry += l => + { + if (l.Target == LoggingTarget.Database && l.Exception != null) + Interlocked.Increment(ref loggedExceptionCount); + }; + + using var importer = new BeatmapImporter(realmFactory, storage); + using var store = new RealmRulesetStore(realmFactory, storage); + + var imported = await LoadOszIntoStore(importer, realmFactory.Context); + + realmFactory.Context.Write(() => imported.Hash += "-changed"); + + checkBeatmapSetCount(realmFactory.Context, 1); + checkBeatmapCount(realmFactory.Context, 12); + checkSingleReferencedFileCount(realmFactory.Context, 18); + + var brokenTempFilename = TestResources.GetTestBeatmapForImport(); + + MemoryStream brokenOsu = new MemoryStream(); + MemoryStream brokenOsz = new MemoryStream(await File.ReadAllBytesAsync(brokenTempFilename)); + + File.Delete(brokenTempFilename); + + using (var outStream = File.Open(brokenTempFilename, FileMode.CreateNew)) + using (var zip = ZipArchive.Open(brokenOsz)) + { + zip.AddEntry("broken.osu", brokenOsu, false); + zip.SaveTo(outStream, CompressionType.Deflate); + } + + // this will trigger purging of the existing beatmap (online set id match) but should rollback due to broken osu. 
+ try + { + await importer.Import(new ImportTask(brokenTempFilename)); + } + catch + { + } + + checkBeatmapSetCount(realmFactory.Context, 1); + checkBeatmapCount(realmFactory.Context, 12); + + checkSingleReferencedFileCount(realmFactory.Context, 18); + + Assert.AreEqual(1, loggedExceptionCount); + + File.Delete(brokenTempFilename); + }); + } + + [Test] + public void TestImportThenDeleteThenImport() + { + RunTestWithRealmAsync(async (realmFactory, storage) => + { + using var importer = new BeatmapImporter(realmFactory, storage); + using var store = new RealmRulesetStore(realmFactory, storage); + + var imported = await LoadOszIntoStore(importer, realmFactory.Context); + + deleteBeatmapSet(imported, realmFactory.Context); + + var importedSecondTime = await LoadOszIntoStore(importer, realmFactory.Context); + + // check the newly "imported" beatmap is actually just the restored previous import. since it matches hash. + Assert.IsTrue(imported.ID == importedSecondTime.ID); + Assert.IsTrue(imported.Beatmaps.First().ID == importedSecondTime.Beatmaps.First().ID); + }); + } + + [Test] + public void TestImportThenDeleteThenImportWithOnlineIDsMissing() + { + RunTestWithRealmAsync(async (realmFactory, storage) => + { + using var importer = new BeatmapImporter(realmFactory, storage); + using var store = new RealmRulesetStore(realmFactory, storage); + + var imported = await LoadOszIntoStore(importer, realmFactory.Context); + + realmFactory.Context.Write(() => + { + foreach (var b in imported.Beatmaps) + b.OnlineID = -1; + }); + + deleteBeatmapSet(imported, realmFactory.Context); + + var importedSecondTime = await LoadOszIntoStore(importer, realmFactory.Context); + + // check the newly "imported" beatmap has been reimported due to mismatch (even though hashes matched) + Assert.IsTrue(imported.ID != importedSecondTime.ID); + Assert.IsTrue(imported.Beatmaps.First().ID != importedSecondTime.Beatmaps.First().ID); + }); + } + + [Test] + public void TestImportWithDuplicateBeatmapIDs() + 
{ + RunTestWithRealmAsync(async (realmFactory, storage) => + { + using var importer = new BeatmapImporter(realmFactory, storage); + using var store = new RealmRulesetStore(realmFactory, storage); + + var metadata = new RealmBeatmapMetadata + { + Artist = "SomeArtist", + Author = "SomeAuthor" + }; + + var ruleset = realmFactory.Context.All().First(); + + var toImport = new RealmBeatmapSet + { + OnlineID = 1, + Beatmaps = + { + new RealmBeatmap(ruleset, new RealmBeatmapDifficulty(), metadata) + { + OnlineID = 2, + }, + new RealmBeatmap(ruleset, new RealmBeatmapDifficulty(), metadata) + { + OnlineID = 2, + Status = BeatmapSetOnlineStatus.Loved, + } + } + }; + + var imported = await importer.Import(toImport); + + Assert.NotNull(imported); + Debug.Assert(imported != null); + + Assert.AreEqual(-1, imported.PerformRead(s => s.Beatmaps[0].OnlineID)); + Assert.AreEqual(-1, imported.PerformRead(s => s.Beatmaps[1].OnlineID)); + }); + } + + [Test] + public void TestImportWhenFileOpen() + { + RunTestWithRealmAsync(async (realmFactory, storage) => + { + using var importer = new BeatmapImporter(realmFactory, storage); + using var store = new RealmRulesetStore(realmFactory, storage); + + var temp = TestResources.GetTestBeatmapForImport(); + using (File.OpenRead(temp)) + await importer.Import(temp); + ensureLoaded(realmFactory.Context); + File.Delete(temp); + Assert.IsFalse(File.Exists(temp), "We likely held a read lock on the file when we shouldn't"); + }); + } + + [Test] + public void TestImportWithDuplicateHashes() + { + RunTestWithRealmAsync(async (realmFactory, storage) => + { + using var importer = new BeatmapImporter(realmFactory, storage); + using var store = new RealmRulesetStore(realmFactory, storage); + + var temp = TestResources.GetTestBeatmapForImport(); + + string extractedFolder = $"{temp}_extracted"; + Directory.CreateDirectory(extractedFolder); + + try + { + using (var zip = ZipArchive.Open(temp)) + zip.WriteToDirectory(extractedFolder); + + using (var zip = 
ZipArchive.Create()) + { + zip.AddAllFromDirectory(extractedFolder); + zip.AddEntry("duplicate.osu", Directory.GetFiles(extractedFolder, "*.osu").First()); + zip.SaveTo(temp, new ZipWriterOptions(CompressionType.Deflate)); + } + + await importer.Import(temp); + + ensureLoaded(realmFactory.Context); + } + finally + { + Directory.Delete(extractedFolder, true); + } + }); + } + + [Test] + public void TestImportNestedStructure() + { + RunTestWithRealmAsync(async (realmFactory, storage) => + { + using var importer = new BeatmapImporter(realmFactory, storage); + using var store = new RealmRulesetStore(realmFactory, storage); + + var temp = TestResources.GetTestBeatmapForImport(); + + string extractedFolder = $"{temp}_extracted"; + string subfolder = Path.Combine(extractedFolder, "subfolder"); + + Directory.CreateDirectory(subfolder); + + try + { + using (var zip = ZipArchive.Open(temp)) + zip.WriteToDirectory(subfolder); + + using (var zip = ZipArchive.Create()) + { + zip.AddAllFromDirectory(extractedFolder); + zip.SaveTo(temp, new ZipWriterOptions(CompressionType.Deflate)); + } + + var imported = await importer.Import(new ImportTask(temp)); + + Assert.NotNull(imported); + Debug.Assert(imported != null); + + ensureLoaded(realmFactory.Context); + + Assert.IsFalse(imported.PerformRead(s => s.Files.Any(f => f.Filename.Contains("subfolder"))), "Files contain common subfolder"); + } + finally + { + Directory.Delete(extractedFolder, true); + } + }); + } + + [Test] + public void TestImportWithIgnoredDirectoryInArchive() + { + RunTestWithRealmAsync(async (realmFactory, storage) => + { + using var importer = new BeatmapImporter(realmFactory, storage); + using var store = new RealmRulesetStore(realmFactory, storage); + + var temp = TestResources.GetTestBeatmapForImport(); + + string extractedFolder = $"{temp}_extracted"; + string dataFolder = Path.Combine(extractedFolder, "actual_data"); + string resourceForkFolder = Path.Combine(extractedFolder, "__MACOSX"); + string 
resourceForkFilePath = Path.Combine(resourceForkFolder, ".extracted"); + + Directory.CreateDirectory(dataFolder); + Directory.CreateDirectory(resourceForkFolder); + + using (var resourceForkFile = File.CreateText(resourceForkFilePath)) + { + await resourceForkFile.WriteLineAsync("adding content so that it's not empty"); + } + + try + { + using (var zip = ZipArchive.Open(temp)) + zip.WriteToDirectory(dataFolder); + + using (var zip = ZipArchive.Create()) + { + zip.AddAllFromDirectory(extractedFolder); + zip.SaveTo(temp, new ZipWriterOptions(CompressionType.Deflate)); + } + + var imported = await importer.Import(new ImportTask(temp)); + + Assert.NotNull(imported); + Debug.Assert(imported != null); + + ensureLoaded(realmFactory.Context); + + Assert.IsFalse(imported.PerformRead(s => s.Files.Any(f => f.Filename.Contains("__MACOSX"))), "Files contain resource fork folder, which should be ignored"); + Assert.IsFalse(imported.PerformRead(s => s.Files.Any(f => f.Filename.Contains("actual_data"))), "Files contain common subfolder"); + } + finally + { + Directory.Delete(extractedFolder, true); + } + }); + } + + [Test] + public void TestUpdateBeatmapInfo() + { + RunTestWithRealmAsync(async (realmFactory, storage) => + { + using var importer = new BeatmapImporter(realmFactory, storage); + using var store = new RealmRulesetStore(realmFactory, storage); + + var temp = TestResources.GetTestBeatmapForImport(); + await importer.Import(temp); + + // Update via the beatmap, not the beatmap info, to ensure correct linking + RealmBeatmapSet setToUpdate = realmFactory.Context.All().First(); + + var beatmapToUpdate = setToUpdate.Beatmaps.First(); + + realmFactory.Context.Write(() => beatmapToUpdate.DifficultyName = "updated"); + + RealmBeatmap updatedInfo = realmFactory.Context.All().First(b => b.ID == beatmapToUpdate.ID); + Assert.That(updatedInfo.DifficultyName, Is.EqualTo("updated")); + }); + } + + public static async Task LoadQuickOszIntoOsu(BeatmapImporter importer, Realm realm) + { 
+ var temp = TestResources.GetQuickTestBeatmapForImport(); + + var importedSet = await importer.Import(new ImportTask(temp)); + + Assert.NotNull(importedSet); + + ensureLoaded(realm); + + waitForOrAssert(() => !File.Exists(temp), "Temporary file still exists after standard import", 5000); + + return realm.All().FirstOrDefault(beatmapSet => beatmapSet.ID == importedSet!.ID); + } + + public static async Task LoadOszIntoStore(BeatmapImporter importer, Realm realm, string? path = null, bool virtualTrack = false) + { + var temp = path ?? TestResources.GetTestBeatmapForImport(virtualTrack); + + var importedSet = await importer.Import(new ImportTask(temp)); + + Assert.NotNull(importedSet); + Debug.Assert(importedSet != null); + + ensureLoaded(realm); + + waitForOrAssert(() => !File.Exists(temp), "Temporary file still exists after standard import", 5000); + + return realm.All().First(beatmapSet => beatmapSet.ID == importedSet.ID); + } + + private void deleteBeatmapSet(RealmBeatmapSet imported, Realm realm) + { + realm.Write(() => imported.DeletePending = true); + + checkBeatmapSetCount(realm, 0); + checkBeatmapSetCount(realm, 1, true); + + Assert.IsTrue(realm.All().First(_ => true).DeletePending); + } + + private static Task createScoreForBeatmap(Realm realm, RealmBeatmap beatmap) + { + // TODO: reimplement when we have score support in realm. + // return ImportScoreTest.LoadScoreIntoOsu(osu, new ScoreInfo + // { + // OnlineScoreID = 2, + // Beatmap = beatmap, + // BeatmapInfoID = beatmap.ID + // }, new ImportScoreTest.TestArchiveReader()); + + return Task.CompletedTask; + } + + private static void checkBeatmapSetCount(Realm realm, int expected, bool includeDeletePending = false) + { + Assert.AreEqual(expected, includeDeletePending + ? 
realm.All().Count() + : realm.All().Count(s => !s.DeletePending)); + } + + private static string hashFile(string filename) + { + using (var s = File.OpenRead(filename)) + return s.ComputeMD5Hash(); + } + + private static void checkBeatmapCount(Realm realm, int expected) + { + Assert.AreEqual(expected, realm.All().Where(_ => true).ToList().Count); + } + + private static void checkSingleReferencedFileCount(Realm realm, int expected) + { + int singleReferencedCount = 0; + + foreach (var f in realm.All()) + { + if (f.BacklinksCount == 1) + singleReferencedCount++; + } + + Assert.AreEqual(expected, singleReferencedCount); + } + + private static void ensureLoaded(Realm realm, int timeout = 60000) + { + IQueryable? resultSets = null; + + waitForOrAssert(() => (resultSets = realm.All().Where(s => !s.DeletePending && s.OnlineID == 241526)).Any(), + @"BeatmapSet did not import to the database in allocated time.", timeout); + + // ensure we were stored to beatmap database backing... + Assert.IsTrue(resultSets?.Count() == 1, $@"Incorrect result count found ({resultSets?.Count()} but should be 1)."); + + IEnumerable queryBeatmapSets() => realm.All().Where(s => !s.DeletePending && s.OnlineID == 241526); + + var set = queryBeatmapSets().First(); + + // ReSharper disable once PossibleUnintendedReferenceComparison + IEnumerable queryBeatmaps() => realm.All().Where(s => s.BeatmapSet != null && s.BeatmapSet == set); + + waitForOrAssert(() => queryBeatmaps().Count() == 12, @"Beatmaps did not import to the database in allocated time", timeout); + waitForOrAssert(() => queryBeatmapSets().Count() == 1, @"BeatmapSet did not import to the database in allocated time", timeout); + + int countBeatmapSetBeatmaps = 0; + int countBeatmaps = 0; + + waitForOrAssert(() => + (countBeatmapSetBeatmaps = queryBeatmapSets().First().Beatmaps.Count) == + (countBeatmaps = queryBeatmaps().Count()), + $@"Incorrect database beatmap count post-import ({countBeatmaps} but should be {countBeatmapSetBeatmaps}).", 
timeout); + + foreach (RealmBeatmap b in set.Beatmaps) + Assert.IsTrue(set.Beatmaps.Any(c => c.OnlineID == b.OnlineID)); + Assert.IsTrue(set.Beatmaps.Count > 0); + } + + private static void waitForOrAssert(Func result, string failureMessage, int timeout = 60000) // polls `result` every `sleep` ms until it returns true; fails the test with `failureMessage` once `timeout` (ms) is used up + { + const int sleep = 200; // poll interval in milliseconds + + while (timeout > 0) + { + if (result()) // evaluate before sleeping so an already-satisfied condition returns without delay + return; + + Thread.Sleep(sleep); + timeout -= sleep; + } + + Assert.IsTrue(result(), failureMessage); // last evaluation after the final sleep; also the only evaluation when timeout is zero or negative + } + } +} diff --git a/osu.Game/Beatmaps/BeatmapManager.cs b/osu.Game/Beatmaps/BeatmapManager.cs index 562cbfabf0..0509a9db47 100644 --- a/osu.Game/Beatmaps/BeatmapManager.cs +++ b/osu.Game/Beatmaps/BeatmapManager.cs @@ -29,7 +29,7 @@ namespace osu.Game.Beatmaps /// Handles general operations related to global beatmap management. /// [ExcludeFromDynamicCompile] - public class BeatmapManager : IModelDownloader, IModelManager, IModelFileManager, ICanAcceptFiles, IWorkingBeatmapCache, IDisposable + public class BeatmapManager : IModelDownloader, IModelManager, IModelFileManager, IWorkingBeatmapCache, IDisposable { private readonly BeatmapModelManager beatmapModelManager; private readonly BeatmapModelDownloader beatmapModelDownloader; diff --git a/osu.Game/Beatmaps/BeatmapModelManager.cs b/osu.Game/Beatmaps/BeatmapModelManager.cs index 76019a15ae..16cf6193f9 100644 --- a/osu.Game/Beatmaps/BeatmapModelManager.cs +++ b/osu.Game/Beatmaps/BeatmapModelManager.cs @@ -123,15 +123,15 @@ namespace osu.Game.Beatmaps // check if a set already exists with the same online id, delete if it does. 
if (beatmapSet.OnlineBeatmapSetID != null) { - var existingOnlineId = beatmaps.ConsumableItems.FirstOrDefault(b => b.OnlineBeatmapSetID == beatmapSet.OnlineBeatmapSetID); + var existingSetWithSameOnlineID = beatmaps.ConsumableItems.FirstOrDefault(b => b.OnlineBeatmapSetID == beatmapSet.OnlineBeatmapSetID); - if (existingOnlineId != null) + if (existingSetWithSameOnlineID != null) { - Delete(existingOnlineId); + Delete(existingSetWithSameOnlineID); // in order to avoid a unique key constraint, immediately remove the online ID from the previous set. - existingOnlineId.OnlineBeatmapSetID = null; - foreach (var b in existingOnlineId.Beatmaps) + existingSetWithSameOnlineID.OnlineBeatmapSetID = null; + foreach (var b in existingSetWithSameOnlineID.Beatmaps) b.OnlineBeatmapID = null; LogForModel(beatmapSet, $"Found existing beatmap set with same OnlineBeatmapSetID ({beatmapSet.OnlineBeatmapSetID}). It has been deleted."); diff --git a/osu.Game/Database/ArchiveModelManager.cs b/osu.Game/Database/ArchiveModelManager.cs index 84e33e3f36..9c777d324b 100644 --- a/osu.Game/Database/ArchiveModelManager.cs +++ b/osu.Game/Database/ArchiveModelManager.cs @@ -30,7 +30,7 @@ namespace osu.Game.Database /// /// The model type. /// The associated file join type. - public abstract class ArchiveModelManager : ICanAcceptFiles, IModelManager, IModelFileManager + public abstract class ArchiveModelManager : IModelManager, IModelFileManager where TModel : class, IHasFiles, IHasPrimaryKey, ISoftDelete where TFileModel : class, INamedFileInfo, new() { diff --git a/osu.Game/Database/IHasOnlineID.cs b/osu.Game/Database/IHasOnlineID.cs index 529c68a8f8..6e2be7e1f9 100644 --- a/osu.Game/Database/IHasOnlineID.cs +++ b/osu.Game/Database/IHasOnlineID.cs @@ -8,8 +8,12 @@ namespace osu.Game.Database public interface IHasOnlineID { /// - /// The server-side ID representing this instance, if one exists. -1 denotes a missing ID. + /// The server-side ID representing this instance, if one exists. 
Any value 0 or less denotes a missing ID. /// + /// + /// Generally we use -1 when specifying "missing" in code, but values of 0 are also considered missing as the online source + /// is generally a MySQL autoincrement value, which can never be 0. + /// int OnlineID { get; } } } diff --git a/osu.Game/Database/IModelImporter.cs b/osu.Game/Database/IModelImporter.cs index 479f33c3b4..5d0a044578 100644 --- a/osu.Game/Database/IModelImporter.cs +++ b/osu.Game/Database/IModelImporter.cs @@ -13,21 +13,9 @@ namespace osu.Game.Database /// A class which handles importing of associated models to the game store. /// /// The model type. - public interface IModelImporter : IPostNotifications, IPostImports + public interface IModelImporter : IPostNotifications, IPostImports, ICanAcceptFiles where TModel : class { - /// - /// Import one or more items from filesystem . - /// - /// - /// This will be treated as a low priority import if more than one path is specified; use to always import at standard priority. - /// This will post notifications tracking progress. - /// - /// One or more archive locations on disk. - Task Import(params string[] paths); - - Task Import(params ImportTask[] tasks); - Task>> Import(ProgressNotification notification, params ImportTask[] tasks); /// diff --git a/osu.Game/Database/IPostImports.cs b/osu.Game/Database/IPostImports.cs index f09285089a..b3b83f23ef 100644 --- a/osu.Game/Database/IPostImports.cs +++ b/osu.Game/Database/IPostImports.cs @@ -4,6 +4,8 @@ using System; using System.Collections.Generic; +#nullable enable + namespace osu.Game.Database { public interface IPostImports @@ -12,6 +14,6 @@ namespace osu.Game.Database /// /// Fired when the user requests to view the resulting import. /// - public Action>> PostImport { set; } + public Action>>? 
PostImport { set; } } } diff --git a/osu.Game/Database/RealmContextFactory.cs b/osu.Game/Database/RealmContextFactory.cs index b5c44927ca..3d0bb34dc1 100644 --- a/osu.Game/Database/RealmContextFactory.cs +++ b/osu.Game/Database/RealmContextFactory.cs @@ -77,6 +77,27 @@ namespace osu.Game.Database if (!Filename.EndsWith(realm_extension, StringComparison.Ordinal)) Filename += realm_extension; + + cleanupPendingDeletions(); + } + + private void cleanupPendingDeletions() + { + using (var realm = CreateContext()) + using (var transaction = realm.BeginWrite()) + { + var pendingDeleteSets = realm.All().Where(s => s.DeletePending); + + foreach (var s in pendingDeleteSets) + { + foreach (var b in s.Beatmaps) + realm.Remove(b); + + realm.Remove(s); + } + + transaction.Commit(); + } } /// diff --git a/osu.Game/Models/RealmBeatmapSet.cs b/osu.Game/Models/RealmBeatmapSet.cs index d6e56fd61c..6735510422 100644 --- a/osu.Game/Models/RealmBeatmapSet.cs +++ b/osu.Game/Models/RealmBeatmapSet.cs @@ -63,7 +63,7 @@ namespace osu.Game.Models if (IsManaged && other.IsManaged) return ID == other.ID; - if (OnlineID >= 0 && other.OnlineID >= 0) + if (OnlineID > 0 && other.OnlineID > 0) return OnlineID == other.OnlineID; if (!string.IsNullOrEmpty(Hash) && !string.IsNullOrEmpty(other.Hash)) diff --git a/osu.Game/Scoring/ScoreManager.cs b/osu.Game/Scoring/ScoreManager.cs index 8494cdcd22..a9791fba7e 100644 --- a/osu.Game/Scoring/ScoreManager.cs +++ b/osu.Game/Scoring/ScoreManager.cs @@ -25,7 +25,7 @@ using osu.Game.Rulesets.Scoring; namespace osu.Game.Scoring { - public class ScoreManager : IModelManager, IModelFileManager, IModelDownloader, ICanAcceptFiles + public class ScoreManager : IModelManager, IModelFileManager, IModelDownloader { private readonly Scheduler scheduler; private readonly Func difficulties; diff --git a/osu.Game/Stores/BeatmapImporter.cs b/osu.Game/Stores/BeatmapImporter.cs new file mode 100644 index 0000000000..254127cc7e --- /dev/null +++ 
b/osu.Game/Stores/BeatmapImporter.cs @@ -0,0 +1,331 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using NuGet.Packaging; +using osu.Framework.Audio.Track; +using osu.Framework.Extensions; +using osu.Framework.Extensions.IEnumerableExtensions; +using osu.Framework.Graphics.Textures; +using osu.Framework.Logging; +using osu.Framework.Platform; +using osu.Framework.Testing; +using osu.Game.Beatmaps; +using osu.Game.Beatmaps.Formats; +using osu.Game.Database; +using osu.Game.IO; +using osu.Game.IO.Archives; +using osu.Game.Models; +using osu.Game.Rulesets; +using osu.Game.Rulesets.Objects; +using osu.Game.Skinning; +using Realms; + +#nullable enable + +namespace osu.Game.Stores +{ + /// + /// Handles the storage and retrieval of Beatmaps/WorkingBeatmaps. + /// + [ExcludeFromDynamicCompile] + public class BeatmapImporter : RealmArchiveModelImporter, IDisposable + { + public override IEnumerable HandledExtensions => new[] { ".osz" }; + + protected override string[] HashableFileTypes => new[] { ".osu" }; + + // protected override bool CheckLocalAvailability(RealmBeatmapSet model, System.Linq.IQueryable items) + // => base.CheckLocalAvailability(model, items) || (model.OnlineID > -1)); + + private readonly BeatmapOnlineLookupQueue? onlineLookupQueue; + + public BeatmapImporter(RealmContextFactory contextFactory, Storage storage, BeatmapOnlineLookupQueue? onlineLookupQueue = null) + : base(storage, contextFactory) + { + this.onlineLookupQueue = onlineLookupQueue; + } + + protected override bool ShouldDeleteArchive(string path) => string.Equals(Path.GetExtension(path), ".osz", StringComparison.OrdinalIgnoreCase); // true only for .osz archives (any letter case); ordinal comparison avoids allocating a lowered copy and tolerates a null extension + + protected override Task Populate(RealmBeatmapSet beatmapSet, ArchiveReader? 
archive, Realm realm, CancellationToken cancellationToken = default) + { + if (archive != null) + beatmapSet.Beatmaps.AddRange(createBeatmapDifficulties(beatmapSet.Files, realm)); + + foreach (RealmBeatmap b in beatmapSet.Beatmaps) + b.BeatmapSet = beatmapSet; + + validateOnlineIds(beatmapSet, realm); + + bool hadOnlineBeatmapIDs = beatmapSet.Beatmaps.Any(b => b.OnlineID > 0); + + if (onlineLookupQueue != null) + { + // TODO: this required `BeatmapOnlineLookupQueue` to somehow support new types. + // await onlineLookupQueue.UpdateAsync(beatmapSet, cancellationToken).ConfigureAwait(false); + } + + // ensure at least one beatmap was able to retrieve or keep an online ID, else drop the set ID. + if (hadOnlineBeatmapIDs && !beatmapSet.Beatmaps.Any(b => b.OnlineID > 0)) + { + if (beatmapSet.OnlineID > 0) + { + beatmapSet.OnlineID = -1; + LogForModel(beatmapSet, "Disassociating beatmap set ID due to loss of all beatmap IDs"); + } + } + + return Task.CompletedTask; + } + + protected override void PreImport(RealmBeatmapSet beatmapSet, Realm realm) + { + // We are about to import a new beatmap. Before doing so, ensure that no other set shares the online IDs used by the new one. + // Note that this means if the previous beatmap is restored by the user, it will no longer be linked to its online IDs. + // If this is ever an issue, we can consider marking as pending delete but not resetting the IDs (but care will be required for + // beatmaps, which don't have their own `DeletePending` state). + + if (beatmapSet.OnlineID > 0) + { + var existingSetWithSameOnlineID = realm.All().SingleOrDefault(b => b.OnlineID == beatmapSet.OnlineID); + + if (existingSetWithSameOnlineID != null) + { + existingSetWithSameOnlineID.DeletePending = true; + existingSetWithSameOnlineID.OnlineID = -1; + + foreach (var b in existingSetWithSameOnlineID.Beatmaps) + b.OnlineID = -1; + + LogForModel(beatmapSet, $"Found existing beatmap set with same OnlineID ({beatmapSet.OnlineID}). 
It will be deleted."); + } + } + } + + private void validateOnlineIds(RealmBeatmapSet beatmapSet, Realm realm) + { + var beatmapIds = beatmapSet.Beatmaps.Where(b => b.OnlineID > 0).Select(b => b.OnlineID).ToList(); + + // ensure all IDs are unique + if (beatmapIds.GroupBy(b => b).Any(g => g.Count() > 1)) + { + LogForModel(beatmapSet, "Found non-unique IDs, resetting..."); + resetIds(); + return; + } + + // find any existing beatmaps in the database that have matching online ids + List existingBeatmaps = new List(); + + foreach (var id in beatmapIds) + existingBeatmaps.AddRange(realm.All().Where(b => b.OnlineID == id)); + + if (existingBeatmaps.Any()) + { + // reset the import ids (to force a re-fetch) *unless* they match the candidate CheckForExisting set. + // we can ignore the case where the new ids are contained by the CheckForExisting set as it will either be used (import skipped) or deleted. + + var existing = CheckForExisting(beatmapSet, realm); + + if (existing == null || existingBeatmaps.Any(b => !existing.Beatmaps.Contains(b))) + { + LogForModel(beatmapSet, "Found existing import with online IDs already, resetting..."); + resetIds(); + } + } + + void resetIds() => beatmapSet.Beatmaps.ForEach(b => b.OnlineID = -1); + } + + protected override bool CanSkipImport(RealmBeatmapSet existing, RealmBeatmapSet import) + { + if (!base.CanSkipImport(existing, import)) + return false; + + return existing.Beatmaps.Any(b => b.OnlineID > 0); + } + + protected override bool CanReuseExisting(RealmBeatmapSet existing, RealmBeatmapSet import) + { + if (!base.CanReuseExisting(existing, import)) + return false; + + var existingIds = existing.Beatmaps.Select(b => b.OnlineID).OrderBy(i => i); + var importIds = import.Beatmaps.Select(b => b.OnlineID).OrderBy(i => i); + + // force re-import if we are not in a sane state. 
+ return existing.OnlineID == import.OnlineID && existingIds.SequenceEqual(importIds); + } + + public override string HumanisedModelName => "beatmap"; + + protected override RealmBeatmapSet? CreateModel(ArchiveReader reader) + { + // let's make sure there are actually .osu files to import. + string? mapName = reader.Filenames.FirstOrDefault(f => f.EndsWith(".osu", StringComparison.OrdinalIgnoreCase)); + + if (string.IsNullOrEmpty(mapName)) + { + Logger.Log($"No beatmap files found in the beatmap archive ({reader.Name}).", LoggingTarget.Database); + return null; + } + + Beatmap beatmap; + using (var stream = new LineBufferedReader(reader.GetStream(mapName))) + beatmap = Decoder.GetDecoder(stream).Decode(stream); + + return new RealmBeatmapSet + { + OnlineID = beatmap.BeatmapInfo.BeatmapSet?.OnlineBeatmapSetID ?? -1, + // Metadata = beatmap.Metadata, + DateAdded = DateTimeOffset.UtcNow + }; + } + + /// + /// Create all required s for the provided archive. + /// + private List createBeatmapDifficulties(IList files, Realm realm) + { + var beatmaps = new List(); + + foreach (var file in files.Where(f => f.Filename.EndsWith(".osu", StringComparison.OrdinalIgnoreCase))) + { + using (var memoryStream = new MemoryStream(Files.Store.Get(file.File.StoragePath))) // we need a memory stream so we can seek + { + IBeatmap decoded; + using (var lineReader = new LineBufferedReader(memoryStream, true)) + decoded = Decoder.GetDecoder(lineReader).Decode(lineReader); + + string hash = memoryStream.ComputeSHA2Hash(); + + if (beatmaps.Any(b => b.Hash == hash)) + { + Logger.Log($"Skipping import of {file.Filename} due to duplicate file content.", LoggingTarget.Database); + continue; + } + + var decodedInfo = decoded.BeatmapInfo; + var decodedDifficulty = decodedInfo.BaseDifficulty; + + var ruleset = realm.All().FirstOrDefault(r => r.OnlineID == decodedInfo.RulesetID); + + if (ruleset?.Available != true) + { + Logger.Log($"Skipping import of {file.Filename} due to missing local ruleset 
{decodedInfo.RulesetID}.", LoggingTarget.Database); + continue; + } + + var difficulty = new RealmBeatmapDifficulty + { + DrainRate = decodedDifficulty.DrainRate, + CircleSize = decodedDifficulty.CircleSize, + OverallDifficulty = decodedDifficulty.OverallDifficulty, + ApproachRate = decodedDifficulty.ApproachRate, + SliderMultiplier = decodedDifficulty.SliderMultiplier, + SliderTickRate = decodedDifficulty.SliderTickRate, + }; + + var metadata = new RealmBeatmapMetadata + { + Title = decoded.Metadata.Title, + TitleUnicode = decoded.Metadata.TitleUnicode, + Artist = decoded.Metadata.Artist, + ArtistUnicode = decoded.Metadata.ArtistUnicode, + Author = decoded.Metadata.AuthorString, + Source = decoded.Metadata.Source, + Tags = decoded.Metadata.Tags, + PreviewTime = decoded.Metadata.PreviewTime, + AudioFile = decoded.Metadata.AudioFile, + BackgroundFile = decoded.Metadata.BackgroundFile, + }; + + var beatmap = new RealmBeatmap(ruleset, difficulty, metadata) + { + Hash = hash, + DifficultyName = decodedInfo.Version, + OnlineID = decodedInfo.OnlineBeatmapID ?? 
-1, + AudioLeadIn = decodedInfo.AudioLeadIn, + StackLeniency = decodedInfo.StackLeniency, + SpecialStyle = decodedInfo.SpecialStyle, + LetterboxInBreaks = decodedInfo.LetterboxInBreaks, + WidescreenStoryboard = decodedInfo.WidescreenStoryboard, + EpilepsyWarning = decodedInfo.EpilepsyWarning, + SamplesMatchPlaybackRate = decodedInfo.SamplesMatchPlaybackRate, + DistanceSpacing = decodedInfo.DistanceSpacing, + BeatDivisor = decodedInfo.BeatDivisor, + GridSize = decodedInfo.GridSize, + TimelineZoom = decodedInfo.TimelineZoom, + MD5Hash = memoryStream.ComputeMD5Hash(), + }; + + updateBeatmapStatistics(beatmap, decoded); + + beatmaps.Add(beatmap); + } + } + + return beatmaps; + } + + private void updateBeatmapStatistics(RealmBeatmap beatmap, IBeatmap decoded) + { + var rulesetInstance = ((IRulesetInfo)beatmap.Ruleset).CreateInstance(); + + if (rulesetInstance == null) + return; + + decoded.BeatmapInfo.Ruleset = rulesetInstance.RulesetInfo; + + // TODO: this should be done in a better place once we actually need to dynamically update it. + beatmap.StarRating = rulesetInstance.CreateDifficultyCalculator(new DummyConversionBeatmap(decoded)).Calculate().StarRating; + beatmap.Length = calculateLength(decoded); + beatmap.BPM = 60000 / decoded.GetMostCommonBeatLength(); + } + + private double calculateLength(IBeatmap b) + { + if (!b.HitObjects.Any()) + return 0; + + var lastObject = b.HitObjects.Last(); + + //TODO: this isn't always correct (consider mania where a non-last object may last for longer than the last in the list). + double endTime = lastObject.GetEndTime(); + double startTime = b.HitObjects.First().StartTime; + + return endTime - startTime; + } + + public void Dispose() + { + onlineLookupQueue?.Dispose(); + } + + /// + /// A dummy WorkingBeatmap for the purpose of retrieving a beatmap for star difficulty calculation. 
        ///
        private class DummyConversionBeatmap : WorkingBeatmap
        {
            private readonly IBeatmap beatmap;

            public DummyConversionBeatmap(IBeatmap beatmap)
                : base(beatmap.BeatmapInfo, null)
            {
                this.beatmap = beatmap;
            }

            // Only the beatmap itself is provided; every other resource is reported as absent.
            protected override IBeatmap GetBeatmap() => beatmap;
            protected override Texture? GetBackground() => null;
            protected override Track? GetBeatmapTrack() => null;
            protected internal override ISkin? GetSkin() => null;
            public override Stream? GetStream(string storagePath) => null;
        }
    }
}
diff --git a/osu.Game/Stores/RealmArchiveModelImporter.cs b/osu.Game/Stores/RealmArchiveModelImporter.cs
new file mode 100644
index 0000000000..ec454d25fa
--- /dev/null
+++ b/osu.Game/Stores/RealmArchiveModelImporter.cs
@@ -0,0 +1,550 @@
// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Humanizer;
using NuGet.Packaging;
using osu.Framework.Extensions;
using osu.Framework.Extensions.IEnumerableExtensions;
using osu.Framework.Logging;
using osu.Framework.Platform;
using osu.Framework.Threading;
using osu.Game.Database;
using osu.Game.IO.Archives;
using osu.Game.Models;
using osu.Game.Overlays.Notifications;
using Realms;

#nullable enable

namespace osu.Game.Stores
{
    /// <summary>
    /// Encapsulates a model store class to give it import functionality.
    /// Adds cross-functionality with the realm file store to give access to the central file store for the provided model.
    /// </summary>
    /// <typeparam name="TModel">The model type.</typeparam>
    // NOTE(review): angle-bracket spans appear to have been stripped from this text during extraction; the class
    // and interface below presumably carry a TModel type parameter (see the `where TModel` constraint), and
    // `Action>>?` / `Action?` / bare `IEnumerable` further down presumably had type arguments — confirm.
    public abstract class RealmArchiveModelImporter : IModelImporter
        where TModel : RealmObject, IHasRealmFiles, IHasGuidPrimaryKey, ISoftDelete
    {
        // Thread count passed to both import schedulers below.
        private const int import_queue_request_concurrency = 1;

        /// <summary>
        /// The size of a batch import operation before considering it a lower priority operation.
        /// </summary>
        private const int low_priority_import_batch_size = 1;

        /// <summary>
        /// A singleton scheduler shared by all importer instances.
        /// </summary>
        /// <remarks>
        /// This scheduler generally performs IO and CPU intensive work so concurrency is limited harshly.
        /// It is mainly being used as a queue mechanism for large imports.
        /// </remarks>
        private static readonly ThreadedTaskScheduler import_scheduler = new ThreadedTaskScheduler(import_queue_request_concurrency, nameof(RealmArchiveModelImporter));

        /// <summary>
        /// A second scheduler for lower priority imports.
        /// For simplicity, these will just run in parallel with normal priority imports, but a future refactor would see this implemented via a custom scheduler/queue.
        /// See https://gist.github.com/peppy/f0e118a14751fc832ca30dd48ba3876b for an incomplete version of this.
        /// </summary>
        // Constructed with the same name argument as the normal priority scheduler above.
        private static readonly ThreadedTaskScheduler import_scheduler_low_priority = new ThreadedTaskScheduler(import_queue_request_concurrency, nameof(RealmArchiveModelImporter));

        /// <summary>
        /// Archive extensions handled by this importer; ".zip" unless overridden.
        /// </summary>
        public virtual IEnumerable HandledExtensions => new[] { @".zip" };

        protected readonly RealmFileStore Files;

        protected readonly RealmContextFactory ContextFactory;

        /// <summary>
        /// Fired when the user requests to view the resulting import.
        /// </summary>
        public Action>>? PostImport { get; set; }

        /// <summary>
        /// Set an endpoint for notifications to be posted to.
        /// </summary>
        public Action? PostNotification { protected get; set; }

        /// <summary>
        /// Stores the context factory and creates the file store over the given storage.
        /// </summary>
        protected RealmArchiveModelImporter(Storage storage, RealmContextFactory contextFactory)
        {
            ContextFactory = contextFactory;

            Files = new RealmFileStore(contextFactory, storage);
        }

        /// <summary>
        /// Import one or more items from the filesystem paths provided.
        /// </summary>
        /// <remarks>
        /// This will be treated as a low priority import if more than one path is specified; the single-task
        /// overload taking an explicit low priority flag defaults to standard priority.
        /// This will post notifications tracking progress.
        /// </remarks>
        /// <param name="paths">One or more archive locations on disk.</param>
+ public Task Import(params string[] paths) + { + var notification = new ProgressNotification { State = ProgressNotificationState.Active }; + + PostNotification?.Invoke(notification); + + return Import(notification, paths.Select(p => new ImportTask(p)).ToArray()); + } + + public Task Import(params ImportTask[] tasks) + { + var notification = new ProgressNotification { State = ProgressNotificationState.Active }; + + PostNotification?.Invoke(notification); + + return Import(notification, tasks); + } + + public async Task>> Import(ProgressNotification notification, params ImportTask[] tasks) + { + if (tasks.Length == 0) + { + notification.CompletionText = $"No {HumanisedModelName}s were found to import!"; + notification.State = ProgressNotificationState.Completed; + return Enumerable.Empty>(); + } + + notification.Progress = 0; + notification.Text = $"{HumanisedModelName.Humanize(LetterCasing.Title)} import is initialising..."; + + int current = 0; + + var imported = new List>(); + + bool isLowPriorityImport = tasks.Length > low_priority_import_batch_size; + + try + { + await Task.WhenAll(tasks.Select(async task => + { + notification.CancellationToken.ThrowIfCancellationRequested(); + + try + { + var model = await Import(task, isLowPriorityImport, notification.CancellationToken).ConfigureAwait(false); + + lock (imported) + { + if (model != null) + imported.Add(model); + current++; + + notification.Text = $"Imported {current} of {tasks.Length} {HumanisedModelName}s"; + notification.Progress = (float)current / tasks.Length; + } + } + catch (TaskCanceledException) + { + throw; + } + catch (Exception e) + { + Logger.Error(e, $@"Could not import ({task})", LoggingTarget.Database); + } + })).ConfigureAwait(false); + } + catch (OperationCanceledException) + { + if (imported.Count == 0) + { + notification.State = ProgressNotificationState.Cancelled; + return imported; + } + } + + if (imported.Count == 0) + { + notification.Text = 
$"{HumanisedModelName.Humanize(LetterCasing.Title)} import failed!"; + notification.State = ProgressNotificationState.Cancelled; + } + else + { + notification.CompletionText = imported.Count == 1 + ? $"Imported {imported.First()}!" + : $"Imported {imported.Count} {HumanisedModelName}s!"; + + if (imported.Count > 0 && PostImport != null) + { + notification.CompletionText += " Click to view."; + notification.CompletionClickAction = () => + { + PostImport?.Invoke(imported); + return true; + }; + } + + notification.State = ProgressNotificationState.Completed; + } + + return imported; + } + + /// + /// Import one from the filesystem and delete the file on success. + /// Note that this bypasses the UI flow and should only be used for special cases or testing. + /// + /// The containing data about the to import. + /// Whether this is a low priority import. + /// An optional cancellation token. + /// The imported model, if successful. + public async Task?> Import(ImportTask task, bool lowPriority = false, CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + + ILive? import; + using (ArchiveReader reader = task.GetReader()) + import = await Import(reader, lowPriority, cancellationToken).ConfigureAwait(false); + + // We may or may not want to delete the file depending on where it is stored. + // e.g. reconstructing/repairing database with items from default storage. + // Also, not always a single file, i.e. for LegacyFilesystemReader + // TODO: Add a check to prevent files from storage to be deleted. + try + { + if (import != null && File.Exists(task.Path) && ShouldDeleteArchive(task.Path)) + File.Delete(task.Path); + } + catch (Exception e) + { + Logger.Error(e, $@"Could not delete original file after import ({task})"); + } + + return import; + } + + /// + /// Silently import an item from an . + /// + /// The archive to be imported. + /// Whether this is a low priority import. + /// An optional cancellation token. 
+ public async Task?> Import(ArchiveReader archive, bool lowPriority = false, CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + + TModel? model = null; + + try + { + model = CreateModel(archive); + + if (model == null) + return null; + } + catch (TaskCanceledException) + { + throw; + } + catch (Exception e) + { + LogForModel(model, @$"Model creation of {archive.Name} failed.", e); + return null; + } + + var scheduledImport = Task.Factory.StartNew(async () => await Import(model, archive, lowPriority, cancellationToken).ConfigureAwait(false), + cancellationToken, TaskCreationOptions.HideScheduler, lowPriority ? import_scheduler_low_priority : import_scheduler).Unwrap(); + + return await scheduledImport.ConfigureAwait(true); + } + + /// + /// Any file extensions which should be included in hash creation. + /// Generally should include all file types which determine the file's uniqueness. + /// Large files should be avoided if possible. + /// + /// + /// This is only used by the default hash implementation. If is overridden, it will not be used. + /// + protected abstract string[] HashableFileTypes { get; } + + internal static void LogForModel(TModel? model, string message, Exception? e = null) + { + string trimmedHash; + if (model == null || !model.IsValid || string.IsNullOrEmpty(model.Hash)) + trimmedHash = "?????"; + else + trimmedHash = model.Hash.Substring(0, 5); + + string prefix = $"[{trimmedHash}]"; + + if (e != null) + Logger.Error(e, $"{prefix} {message}", LoggingTarget.Database); + else + Logger.Log($"{prefix} {message}", LoggingTarget.Database); + } + + /// + /// Whether the implementation overrides with a custom implementation. + /// Custom hash implementations must bypass the early exit in the import flow (see usage). + /// + protected virtual bool HasCustomHashFunction => false; + + /// + /// Create a SHA-2 hash from the provided archive based on file content of all files matching . 
+ /// + /// + /// In the case of no matching files, a hash will be generated from the passed archive's . + /// + protected virtual string ComputeHash(TModel item, ArchiveReader? reader = null) + { + if (reader != null) + // fast hashing for cases where the item's files may not be populated. + return computeHashFast(reader); + + // for now, concatenate all hashable files in the set to create a unique hash. + MemoryStream hashable = new MemoryStream(); + + foreach (RealmNamedFileUsage file in item.Files.Where(f => HashableFileTypes.Any(ext => f.Filename.EndsWith(ext, StringComparison.OrdinalIgnoreCase))).OrderBy(f => f.Filename)) + { + using (Stream s = Files.Store.GetStream(file.File.StoragePath)) + s.CopyTo(hashable); + } + + if (hashable.Length > 0) + return hashable.ComputeSHA2Hash(); + + return item.Hash; + } + + /// + /// Silently import an item from a . + /// + /// The model to be imported. + /// An optional archive to use for model population. + /// Whether this is a low priority import. + /// An optional cancellation token. + public virtual async Task?> Import(TModel item, ArchiveReader? archive = null, bool lowPriority = false, CancellationToken cancellationToken = default) + { + using (var realm = ContextFactory.CreateContext()) + { + cancellationToken.ThrowIfCancellationRequested(); + + bool checkedExisting = false; + TModel? existing = null; + + if (archive != null && !HasCustomHashFunction) + { + // this is a fast bail condition to improve large import performance. + item.Hash = computeHashFast(archive); + + checkedExisting = true; + existing = CheckForExisting(item, realm); + + if (existing != null) + { + // bare minimum comparisons + // + // note that this should really be checking filesizes on disk (of existing files) for some degree of sanity. + // or alternatively doing a faster hash check. either of these require database changes and reprocessing of existing files. 
+ if (CanSkipImport(existing, item) && + getFilenames(existing.Files).SequenceEqual(getShortenedFilenames(archive).Select(p => p.shortened).OrderBy(f => f))) + { + LogForModel(item, @$"Found existing (optimised) {HumanisedModelName} for {item} (ID {existing.ID}) – skipping import."); + + using (var transaction = realm.BeginWrite()) + { + existing.DeletePending = false; + transaction.Commit(); + } + + return existing.ToLive(); + } + + LogForModel(item, @"Found existing (optimised) but failed pre-check."); + } + } + + try + { + LogForModel(item, @"Beginning import..."); + + // TODO: do we want to make the transaction this local? not 100% sure, will need further investigation. + using (var transaction = realm.BeginWrite()) + { + if (archive != null) + // TODO: look into rollback of file additions (or delayed commit). + item.Files.AddRange(createFileInfos(archive, Files, realm)); + + item.Hash = ComputeHash(item, archive); + + // TODO: we may want to run this outside of the transaction. + await Populate(item, archive, realm, cancellationToken).ConfigureAwait(false); + + if (!checkedExisting) + existing = CheckForExisting(item, realm); + + if (existing != null) + { + if (CanReuseExisting(existing, item)) + { + LogForModel(item, @$"Found existing {HumanisedModelName} for {item} (ID {existing.ID}) – skipping import."); + existing.DeletePending = false; + + return existing.ToLive(); + } + + LogForModel(item, @"Found existing but failed re-use check."); + + existing.DeletePending = true; + + // todo: actually delete? i don't think this is required... 
+ // ModelStore.PurgeDeletable(s => s.ID == existing.ID); + } + + PreImport(item, realm); + + // import to store + realm.Add(item); + + transaction.Commit(); + } + + LogForModel(item, @"Import successfully completed!"); + } + catch (Exception e) + { + if (!(e is TaskCanceledException)) + LogForModel(item, @"Database import or population failed and has been rolled back.", e); + + throw; + } + + return item.ToLive(); + } + } + + private string computeHashFast(ArchiveReader reader) + { + MemoryStream hashable = new MemoryStream(); + + foreach (var file in reader.Filenames.Where(f => HashableFileTypes.Any(ext => f.EndsWith(ext, StringComparison.OrdinalIgnoreCase))).OrderBy(f => f)) + { + using (Stream s = reader.GetStream(file)) + s.CopyTo(hashable); + } + + if (hashable.Length > 0) + return hashable.ComputeSHA2Hash(); + + return reader.Name.ComputeSHA2Hash(); + } + + /// + /// Create all required s for the provided archive, adding them to the global file store. + /// + private List createFileInfos(ArchiveReader reader, RealmFileStore files, Realm realm) + { + var fileInfos = new List(); + + // import files to manager + foreach (var filenames in getShortenedFilenames(reader)) + { + using (Stream s = reader.GetStream(filenames.original)) + { + var item = new RealmNamedFileUsage(files.Add(s, realm), filenames.shortened); + fileInfos.Add(item); + } + } + + return fileInfos; + } + + private IEnumerable<(string original, string shortened)> getShortenedFilenames(ArchiveReader reader) + { + string prefix = reader.Filenames.GetCommonPrefix(); + if (!(prefix.EndsWith('/') || prefix.EndsWith('\\'))) + prefix = string.Empty; + + // import files to manager + foreach (string file in reader.Filenames) + yield return (file, file.Substring(prefix.Length).ToStandardisedPath()); + } + + /// + /// Create a barebones model from the provided archive. + /// Actual expensive population should be done in ; this should just prepare for duplicate checking. 
+ /// + /// The archive to create the model for. + /// A model populated with minimal information. Returning a null will abort importing silently. + protected abstract TModel? CreateModel(ArchiveReader archive); + + /// + /// Populate the provided model completely from the given archive. + /// After this method, the model should be in a state ready to commit to a store. + /// + /// The model to populate. + /// The archive to use as a reference for population. May be null. + /// The current realm context. + /// An optional cancellation token. + protected abstract Task Populate(TModel model, ArchiveReader? archive, Realm realm, CancellationToken cancellationToken = default); + + /// + /// Perform any final actions before the import to database executes. + /// + /// The model prepared for import. + /// The current realm context. + protected virtual void PreImport(TModel model, Realm realm) + { + } + + /// + /// Check whether an existing model already exists for a new import item. + /// + /// The new model proposed for import. + /// The current realm context. + /// An existing model which matches the criteria to skip importing, else null. + protected TModel? CheckForExisting(TModel model, Realm realm) => string.IsNullOrEmpty(model.Hash) ? null : realm.All().FirstOrDefault(b => b.Hash == model.Hash); + + /// + /// Whether import can be skipped after finding an existing import early in the process. + /// Only valid when is not overridden. + /// + /// The existing model. + /// The newly imported model. + /// Whether to skip this import completely. + protected virtual bool CanSkipImport(TModel existing, TModel import) => true; + + /// + /// After an existing is found during an import process, the default behaviour is to use/restore the existing + /// item and skip the import. This method allows changing that behaviour. + /// + /// The existing model. + /// The newly imported model. + /// Whether the existing model should be restored and used. 
Returning false will delete the existing and force a re-import. + protected virtual bool CanReuseExisting(TModel existing, TModel import) => + // for the best or worst, we copy and import files of a new import before checking whether + // it is a duplicate. so to check if anything has changed, we can just compare all File IDs. + getIDs(existing.Files).SequenceEqual(getIDs(import.Files)) && + getFilenames(existing.Files).SequenceEqual(getFilenames(import.Files)); + + /// + /// Whether this specified path should be removed after successful import. + /// + /// The path for consideration. May be a file or a directory. + /// Whether to perform deletion. + protected virtual bool ShouldDeleteArchive(string path) => false; + + private IEnumerable getIDs(IEnumerable files) + { + foreach (var f in files.OrderBy(f => f.Filename)) + yield return f.File.Hash; + } + + private IEnumerable getFilenames(IEnumerable files) + { + foreach (var f in files.OrderBy(f => f.Filename)) + yield return f.Filename; + } + + public virtual string HumanisedModelName => $"{typeof(TModel).Name.Replace(@"Info", "").ToLower()}"; + } +}