1
0
mirror of https://github.com/ppy/osu.git synced 2024-11-11 09:27:29 +08:00

Merge pull request #2665 from peppy/import-fixes

Fix beatmap importing entering a bad state
This commit is contained in:
Dan Balasescu 2018-05-31 13:58:22 +09:00 committed by GitHub
commit 3fc2bfd313
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 221 additions and 79 deletions

View File

@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
<Import Project="..\osu.Game.props" />
<PropertyGroup Label="Project">
<TargetFrameworks>net471;netcoreapp2.0</TargetFrameworks>
@ -30,10 +30,10 @@
<PackageReference Include="Microsoft.Win32.Registry" Version="4.4.0" />
</ItemGroup>
<ItemGroup Label="Package References">
<PackageReference Include="Microsoft.EntityFrameworkCore.Sqlite" Version="2.0.1" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Sqlite" Version="2.0.3" />
<PackageReference Include="squirrel.windows" Version="1.8.0" Condition="'$(TargetFramework)' == 'net471'" />
</ItemGroup>
<ItemGroup Label="Resources">
<EmbeddedResource Include="lazer.ico" />
</ItemGroup>
</Project>
</Project>

View File

@ -12,6 +12,7 @@ using osu.Framework.Platform;
using osu.Game.IPC;
using osu.Framework.Allocation;
using osu.Game.Beatmaps;
using SharpCompress.Archives.Zip;
namespace osu.Game.Tests.Beatmaps.IO
{
@ -77,8 +78,69 @@ namespace osu.Game.Tests.Beatmaps.IO
var manager = osu.Dependencies.Get<BeatmapManager>();
Assert.IsTrue(manager.GetAllUsableBeatmapSets().Count == 1);
Assert.IsTrue(manager.QueryBeatmapSets(_ => true).ToList().Count == 1);
Assert.AreEqual(1, manager.GetAllUsableBeatmapSets().Count);
Assert.AreEqual(1, manager.QueryBeatmapSets(_ => true).ToList().Count);
}
finally
{
host.Exit();
}
}
}
[Test]
public void TestRollbackOnFailure()
{
//unfortunately for the time being we need to reference osu.Framework.Desktop for a game host here.
using (HeadlessGameHost host = new CleanRunHeadlessGameHost("TestRollbackOnFailure"))
{
try
{
var osu = loadOsu(host);
var manager = osu.Dependencies.Get<BeatmapManager>();
int fireCount = 0;
// ReSharper disable once AccessToModifiedClosure
manager.ItemAdded += _ => fireCount++;
manager.ItemRemoved += _ => fireCount++;
var imported = loadOszIntoOsu(osu);
Assert.AreEqual(0, fireCount -= 1);
imported.Hash += "-changed";
manager.Update(imported);
Assert.AreEqual(0, fireCount -= 2);
var breakTemp = createTemporaryBeatmap();
MemoryStream brokenOsu = new MemoryStream(new byte[] { 1, 3, 3, 7 });
MemoryStream brokenOsz = new MemoryStream(File.ReadAllBytes(breakTemp));
File.Delete(breakTemp);
using (var outStream = File.Open(breakTemp, FileMode.CreateNew))
using (var zip = ZipArchive.Open(brokenOsz))
{
zip.AddEntry("broken.osu", brokenOsu, false);
zip.SaveTo(outStream, SharpCompress.Common.CompressionType.Deflate);
}
Assert.AreEqual(1, manager.GetAllUsableBeatmapSets().Count);
Assert.AreEqual(1, manager.QueryBeatmapSets(_ => true).ToList().Count);
Assert.AreEqual(12, manager.QueryBeatmaps(_ => true).ToList().Count);
// this will trigger purging of the existing beatmap (online set id match) but should rollback due to broken osu.
manager.Import(breakTemp);
// no events should be fired in the case of a rollback.
Assert.AreEqual(0, fireCount);
Assert.AreEqual(1, manager.GetAllUsableBeatmapSets().Count);
Assert.AreEqual(1, manager.QueryBeatmapSets(_ => true).ToList().Count);
Assert.AreEqual(12, manager.QueryBeatmaps(_ => true).ToList().Count);
}
finally
{
@ -100,18 +162,17 @@ namespace osu.Game.Tests.Beatmaps.IO
var imported = loadOszIntoOsu(osu);
//var change = manager.QueryBeatmapSets(_ => true).First();
imported.Hash += "-changed";
manager.Update(imported);
var importedSecondTime = loadOszIntoOsu(osu);
// check the newly "imported" beatmap is actually just the restored previous import. since it matches hash.
Assert.IsTrue(imported.ID != importedSecondTime.ID);
Assert.IsTrue(imported.Beatmaps.First().ID < importedSecondTime.Beatmaps.First().ID);
Assert.IsTrue(manager.GetAllUsableBeatmapSets().Count == 1);
Assert.IsTrue(manager.QueryBeatmapSets(_ => true).ToList().Count == 1);
// only one beatmap will exist as the online set ID matched, causing purging of the first import.
Assert.AreEqual(1, manager.GetAllUsableBeatmapSets().Count);
Assert.AreEqual(1, manager.QueryBeatmapSets(_ => true).ToList().Count);
}
finally
{
@ -162,8 +223,7 @@ namespace osu.Game.Tests.Beatmaps.IO
var osu = loadOsu(host);
var temp = prepareTempCopy(osz_path);
Assert.IsTrue(File.Exists(temp));
var temp = createTemporaryBeatmap();
var importer = new ArchiveImportIPCChannel(client);
if (!importer.ImportAsync(temp).Wait(10000))
@ -188,8 +248,7 @@ namespace osu.Game.Tests.Beatmaps.IO
try
{
var osu = loadOsu(host);
var temp = prepareTempCopy(osz_path);
Assert.IsTrue(File.Exists(temp), "Temporary file copy never substantiated");
var temp = createTemporaryBeatmap();
using (File.OpenRead(temp))
osu.Dependencies.Get<BeatmapManager>().Import(temp);
ensureLoaded(osu);
@ -203,11 +262,16 @@ namespace osu.Game.Tests.Beatmaps.IO
}
}
private BeatmapSetInfo loadOszIntoOsu(OsuGameBase osu)
private string createTemporaryBeatmap()
{
var temp = prepareTempCopy(osz_path);
var temp = new FileInfo(osz_path).CopyTo(Path.GetTempFileName(), true).FullName;
Assert.IsTrue(File.Exists(temp));
return temp;
}
private BeatmapSetInfo loadOszIntoOsu(OsuGameBase osu, string path = null)
{
var temp = path ?? createTemporaryBeatmap();
var manager = osu.Dependencies.Get<BeatmapManager>();
@ -219,7 +283,7 @@ namespace osu.Game.Tests.Beatmaps.IO
waitForOrAssert(() => !File.Exists(temp), "Temporary file still exists after standard import", 5000);
return imported.FirstOrDefault();
return imported.LastOrDefault();
}
private void deleteBeatmapSet(BeatmapSetInfo imported, OsuGameBase osu)
@ -228,16 +292,10 @@ namespace osu.Game.Tests.Beatmaps.IO
manager.Delete(imported);
Assert.IsTrue(manager.GetAllUsableBeatmapSets().Count == 0);
Assert.IsTrue(manager.QueryBeatmapSets(_ => true).ToList().Count == 1);
Assert.AreEqual(1, manager.QueryBeatmapSets(_ => true).ToList().Count);
Assert.IsTrue(manager.QueryBeatmapSets(_ => true).First().DeletePending);
}
private string prepareTempCopy(string path)
{
var temp = Path.GetTempFileName();
return new FileInfo(path).CopyTo(temp, true).FullName;
}
private OsuGameBase loadOsu(GameHost host)
{
var osu = new OsuGameBase();

View File

@ -56,13 +56,49 @@ namespace osu.Game.Database
// ReSharper disable once NotAccessedField.Local (we should keep a reference to this so it is not finalised)
private ArchiveImportIPCChannel ipc;
private readonly List<Action> cachedEvents = new List<Action>();
/// <summary>
/// Allows delaying of outwards events until an operation is confirmed (at a database level).
/// </summary>
private bool delayingEvents;
/// <summary>
/// Begin delaying outwards events.
/// </summary>
private void delayEvents() => delayingEvents = true;
/// <summary>
/// Flush delayed events and disable delaying.
/// </summary>
/// <param name="perform">Whether the flushed events should be performed.</param>
private void flushEvents(bool perform)
{
if (perform)
{
foreach (var a in cachedEvents)
a.Invoke();
}
cachedEvents.Clear();
delayingEvents = false;
}
private void handleEvent(Action a)
{
if (delayingEvents)
cachedEvents.Add(a);
else
a.Invoke();
}
protected ArchiveModelManager(Storage storage, IDatabaseContextFactory contextFactory, MutableDatabaseBackedStore<TModel> modelStore, IIpcHost importHost = null)
{
ContextFactory = contextFactory;
ModelStore = modelStore;
ModelStore.ItemAdded += s => ItemAdded?.Invoke(s);
ModelStore.ItemRemoved += s => ItemRemoved?.Invoke(s);
ModelStore.ItemAdded += s => handleEvent(() => ItemAdded?.Invoke(s));
ModelStore.ItemRemoved += s => handleEvent(() => ItemRemoved?.Invoke(s));
Files = new FileStore(contextFactory, storage);
@ -138,24 +174,50 @@ namespace osu.Game.Database
/// <param name="archive">The archive to be imported.</param>
public TModel Import(ArchiveReader archive)
{
using (ContextFactory.GetForWrite()) // used to share a context for full import. keep in mind this will block all writes.
TModel item = null;
delayEvents();
try
{
// create a new model (don't yet add to database)
var item = CreateModel(archive);
using (var write = ContextFactory.GetForWrite()) // used to share a context for full import. keep in mind this will block all writes.
{
try
{
if (!write.IsTransactionLeader) throw new InvalidOperationException($"Ensure there is no parent transaction so errors can correctly be handled by {this}");
var existing = CheckForExisting(item);
// create a new model (don't yet add to database)
item = CreateModel(archive);
if (existing != null) return existing;
var existing = CheckForExisting(item);
item.Files = createFileInfos(archive, Files);
if (existing != null) return existing;
Populate(item, archive);
item.Files = createFileInfos(archive, Files);
// import to store
ModelStore.Add(item);
Populate(item, archive);
return item;
// import to store
ModelStore.Add(item);
}
catch (Exception e)
{
write.Errors.Add(e);
throw;
}
}
}
catch (Exception e)
{
Logger.Error(e, $"Import of {archive.Name} failed and has been rolled back.", LoggingTarget.Database);
item = null;
}
finally
{
// we only want to flush events after we've confirmed the write context didn't have any errors.
flushEvents(item != null);
}
return item;
}
/// <summary>
@ -178,12 +240,8 @@ namespace osu.Game.Database
/// <param name="item">The item to delete.</param>
public void Delete(TModel item)
{
using (var usage = ContextFactory.GetForWrite())
using (ContextFactory.GetForWrite())
{
var context = usage.Context;
context.ChangeTracker.AutoDetectChangesEnabled = false;
// re-fetch the model on the import context.
var foundModel = queryModel().Include(s => s.Files).ThenInclude(f => f.FileInfo).First(s => s.ID == item.ID);
@ -191,8 +249,6 @@ namespace osu.Game.Database
if (ModelStore.Delete(foundModel))
Files.Dereference(foundModel.Files.Select(f => f.FileInfo).ToArray());
context.ChangeTracker.AutoDetectChangesEnabled = true;
}
}

View File

@ -1,7 +1,9 @@
// Copyright (c) 2007-2018 ppy Pty Ltd <contact@ppy.sh>.
// Licensed under the MIT Licence - https://raw.githubusercontent.com/ppy/osu/master/LICENCE
using System.Linq;
using System.Threading;
using Microsoft.EntityFrameworkCore.Storage;
using osu.Framework.Platform;
namespace osu.Game.Database
@ -17,8 +19,12 @@ namespace osu.Game.Database
private readonly object writeLock = new object();
private bool currentWriteDidWrite;
private bool currentWriteDidError;
private int currentWriteUsages;
private IDbContextTransaction currentWriteTransaction;
public DatabaseContextFactory(GameHost host)
{
this.host = host;
@ -35,14 +41,25 @@ namespace osu.Game.Database
/// Request a context for write usage. Can be consumed in a nested fashion (and will return the same underlying context).
/// This method may block if a write is already active on a different thread.
/// </summary>
/// <param name="withTransaction">Whether to start a transaction for this write.</param>
/// <returns>A usage containing a usable context.</returns>
public DatabaseWriteUsage GetForWrite()
public DatabaseWriteUsage GetForWrite(bool withTransaction = true)
{
Monitor.Enter(writeLock);
if (currentWriteTransaction == null && withTransaction)
{
// this mitigates the fact that changes on tracked entities will not be rolled back with the transaction by ensuring write operations are always executed in isolated contexts.
// if this results in sub-optimal efficiency, we may need to look into removing Database-level transactions in favour of running SaveChanges where we currently commit the transaction.
if (threadContexts.IsValueCreated)
recycleThreadContexts();
currentWriteTransaction = threadContexts.Value.Database.BeginTransaction();
}
Interlocked.Increment(ref currentWriteUsages);
return new DatabaseWriteUsage(threadContexts.Value, usageCompleted);
return new DatabaseWriteUsage(threadContexts.Value, usageCompleted) { IsTransactionLeader = currentWriteTransaction != null && currentWriteUsages == 1 };
}
private void usageCompleted(DatabaseWriteUsage usage)
@ -52,18 +69,27 @@ namespace osu.Game.Database
try
{
currentWriteDidWrite |= usage.PerformedWrite;
currentWriteDidError |= usage.Errors.Any();
if (usages > 0) return;
if (currentWriteDidWrite)
if (usages == 0)
{
// explicitly dispose to ensure any outstanding flushes happen as soon as possible (and underlying resources are purged).
usage.Context.Dispose();
if (currentWriteDidError)
currentWriteTransaction?.Rollback();
else
currentWriteTransaction?.Commit();
if (currentWriteDidWrite || currentWriteDidError)
{
// explicitly dispose to ensure any outstanding flushes happen as soon as possible (and underlying resources are purged).
usage.Context.Dispose();
// once all writes are complete, we want to refresh thread-specific contexts to make sure they don't have stale local caches.
recycleThreadContexts();
}
currentWriteTransaction = null;
currentWriteDidWrite = false;
// once all writes are complete, we want to refresh thread-specific contexts to make sure they don't have stale local caches.
recycleThreadContexts();
currentWriteDidError = false;
}
}
finally

View File

@ -2,34 +2,50 @@
// Licensed under the MIT Licence - https://raw.githubusercontent.com/ppy/osu/master/LICENCE
using System;
using Microsoft.EntityFrameworkCore.Storage;
using System.Collections.Generic;
namespace osu.Game.Database
{
public class DatabaseWriteUsage : IDisposable
{
public readonly OsuDbContext Context;
private readonly IDbContextTransaction transaction;
private readonly Action<DatabaseWriteUsage> usageCompleted;
public DatabaseWriteUsage(OsuDbContext context, Action<DatabaseWriteUsage> onCompleted)
{
Context = context;
transaction = Context.BeginTransaction();
usageCompleted = onCompleted;
}
public bool PerformedWrite { get; private set; }
private bool isDisposed;
public List<Exception> Errors = new List<Exception>();
/// <summary>
/// Whether this write usage will commit a transaction on completion.
/// If false, there is a parent usage responsible for transaction commit.
/// </summary>
public bool IsTransactionLeader = false;
protected void Dispose(bool disposing)
{
if (isDisposed) return;
isDisposed = true;
PerformedWrite |= Context.SaveChanges(transaction) > 0;
usageCompleted?.Invoke(this);
try
{
PerformedWrite |= Context.SaveChanges() > 0;
}
catch (Exception e)
{
Errors.Add(e);
throw;
}
finally
{
usageCompleted?.Invoke(this);
}
}
public void Dispose()

View File

@ -14,7 +14,8 @@ namespace osu.Game.Database
/// Request a context for write usage. Can be consumed in a nested fashion (and will return the same underlying context).
/// This method may block if a write is already active on a different thread.
/// </summary>
/// <param name="withTransaction">Whether to start a transaction for this write.</param>
/// <returns>A usage containing a usable context.</returns>
DatabaseWriteUsage GetForWrite();
DatabaseWriteUsage GetForWrite(bool withTransaction = true);
}
}

View File

@ -50,11 +50,10 @@ namespace osu.Game.Database
/// <param name="item">The item to update.</param>
public void Update(T item)
{
ItemRemoved?.Invoke(item);
using (var usage = ContextFactory.GetForWrite())
usage.Context.Update(item);
ItemRemoved?.Invoke(item);
ItemAdded?.Invoke(item);
}

View File

@ -3,7 +3,6 @@
using System;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.Extensions.Logging;
using osu.Framework.Logging;
@ -104,19 +103,6 @@ namespace osu.Game.Database
modelBuilder.Entity<BeatmapInfo>().HasOne(b => b.BaseDifficulty);
}
public IDbContextTransaction BeginTransaction()
{
// return Database.BeginTransaction();
return null;
}
public int SaveChanges(IDbContextTransaction transaction = null)
{
var ret = base.SaveChanges();
if (ret > 0) transaction?.Commit();
return ret;
}
private class OsuDbLoggerFactory : ILoggerFactory
{
#region Disposal

View File

@ -14,6 +14,6 @@ namespace osu.Game.Database
public OsuDbContext Get() => context;
public DatabaseWriteUsage GetForWrite() => new DatabaseWriteUsage(context, null);
public DatabaseWriteUsage GetForWrite(bool withTransaction = true) => new DatabaseWriteUsage(context, null);
}
}

View File

@ -208,7 +208,7 @@ namespace osu.Game
{
try
{
using (var db = contextFactory.GetForWrite())
using (var db = contextFactory.GetForWrite(false))
db.Context.Migrate();
}
catch (MigrationFailedException e)
@ -220,7 +220,7 @@ namespace osu.Game
contextFactory.ResetDatabase();
Logger.Log("Database purged successfully.", LoggingTarget.Database, LogLevel.Important);
using (var db = contextFactory.GetForWrite())
using (var db = contextFactory.GetForWrite(false))
db.Context.Migrate();
}
}

View File

@ -17,7 +17,7 @@
<ItemGroup Label="Package References">
<PackageReference Include="Humanizer" Version="2.2.0" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Sqlite" Version="2.0.1" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Sqlite.Core" Version="2.0.1" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Sqlite.Core" Version="2.0.3" />
<PackageReference Include="Newtonsoft.Json" Version="10.0.3" />
<PackageReference Include="SharpCompress" Version="0.18.1" />
<PackageReference Include="NUnit" Version="3.10.1" />