diff --git a/osu.Game.Tests/Visual/TestCasePlaySongSelect.cs b/osu.Game.Tests/Visual/TestCasePlaySongSelect.cs index 809de2b8db..f54eb77c6b 100644 --- a/osu.Game.Tests/Visual/TestCasePlaySongSelect.cs +++ b/osu.Game.Tests/Visual/TestCasePlaySongSelect.cs @@ -63,12 +63,10 @@ namespace osu.Game.Tests.Visual var storage = new TestStorage(@"TestCasePlaySongSelect"); // this is by no means clean. should be replacing inside of OsuGameBase somehow. - var context = new OsuDbContext(); + DatabaseContextFactory factory = new SingletonContextFactory(new OsuDbContext()); - OsuDbContext contextFactory() => context; - - dependencies.Cache(rulesets = new RulesetStore(contextFactory)); - dependencies.Cache(manager = new BeatmapManager(storage, contextFactory, rulesets, null) + dependencies.Cache(rulesets = new RulesetStore(factory)); + dependencies.Cache(manager = new BeatmapManager(storage, factory, rulesets, null) { DefaultBeatmap = defaultBeatmap = game.Beatmap.Default }); diff --git a/osu.Game/Beatmaps/BeatmapManager.cs b/osu.Game/Beatmaps/BeatmapManager.cs index cbaa8a1066..4ec153c78f 100644 --- a/osu.Game/Beatmaps/BeatmapManager.cs +++ b/osu.Game/Beatmaps/BeatmapManager.cs @@ -60,7 +60,7 @@ namespace osu.Game.Beatmaps /// public WorkingBeatmap DefaultBeatmap { private get; set; } - private readonly Func createContext; + private readonly DatabaseContextFactory contextFactory; private readonly FileStore files; @@ -85,29 +85,18 @@ namespace osu.Game.Beatmaps /// public Func GetStableStorage { private get; set; } - private void refreshImportContext() + public BeatmapManager(Storage storage, DatabaseContextFactory contextFactory, RulesetStore rulesets, APIAccess api, IIpcHost importHost = null) { - lock (importContextLock) - { - importContext?.Value?.Dispose(); + this.contextFactory = contextFactory; - importContext = new Lazy(() => - { - var c = createContext(); - c.Database.AutoTransactionsEnabled = false; - return c; - }); - } - } + beatmaps = new BeatmapStore(contextFactory); - public BeatmapManager(Storage storage, Func context, RulesetStore rulesets, APIAccess api, IIpcHost importHost = null) - { - createContext = context; + beatmaps.BeatmapSetAdded += s => BeatmapSetAdded?.Invoke(s); + beatmaps.BeatmapSetRemoved += s => BeatmapSetRemoved?.Invoke(s); + beatmaps.BeatmapHidden += b => BeatmapHidden?.Invoke(b); + beatmaps.BeatmapRestored += b => BeatmapRestored?.Invoke(b); - refreshImportContext(); - - beatmaps = getBeatmapStoreWithContext(context); - files = new FileStore(context, storage); + files = new FileStore(contextFactory, storage); this.rulesets = rulesets; this.api = api; @@ -170,7 +159,6 @@ namespace osu.Game.Beatmaps { e = e.InnerException ?? e; Logger.Error(e, $@"Could not import beatmap set ({Path.GetFileName(path)})"); - refreshImportContext(); } } @@ -178,80 +166,57 @@ namespace osu.Game.Beatmaps return imported; } - private readonly object importContextLock = new object(); - private Lazy importContext; - /// /// Import a beatmap from an . /// /// The beatmap to be imported. public BeatmapSetInfo Import(ArchiveReader archive) { - // let's only allow one concurrent import at a time for now - lock (importContextLock) + using ( contextFactory.GetForWrite()) // used to share a context for full import. keep in mind this will block all writes. { - var context = importContext.Value; + // create a new set info (don't yet add to database) + var beatmapSet = createBeatmapSetInfo(archive); - using (var transaction = context.BeginTransaction()) + // check if this beatmap has already been imported and exit early if so + var existingHashMatch = beatmaps.BeatmapSets.FirstOrDefault(b => b.Hash == beatmapSet.Hash); + if (existingHashMatch != null) { - // create a new set info (don't yet add to database) - var beatmapSet = createBeatmapSetInfo(archive); - - // check if this beatmap has already been imported and exit early if so - var existingHashMatch = beatmaps.BeatmapSets.FirstOrDefault(b => b.Hash == beatmapSet.Hash); - if (existingHashMatch != null) - { - undelete(beatmaps, files, existingHashMatch); - return existingHashMatch; - } - - // check if a set already exists with the same online id - if (beatmapSet.OnlineBeatmapSetID != null) - { - var existingOnlineId = beatmaps.BeatmapSets.FirstOrDefault(b => b.OnlineBeatmapSetID == beatmapSet.OnlineBeatmapSetID); - if (existingOnlineId != null) - { - // {Microsoft.EntityFrameworkCore.DbUpdateConcurrencyException: Database operation expected to affect 1 row(s) but actually affected 0 row(s). Data may have been modified or deleted since entities were loaded. See http://go.microsoft.com/fwlink/?LinkId=527962…} - - Delete(existingOnlineId); - beatmaps.Cleanup(s => s.ID == existingOnlineId.ID); - } - } - - beatmapSet.Files = createFileInfos(archive, getFileStoreWithContext(context)); - beatmapSet.Beatmaps = createBeatmapDifficulties(archive); - - // remove metadata from difficulties where it matches the set - foreach (BeatmapInfo b in beatmapSet.Beatmaps) - if (beatmapSet.Metadata.Equals(b.Metadata)) - b.Metadata = null; - - // import to beatmap store - import(beatmapSet, context); - - context.SaveChanges(transaction); - return beatmapSet; + undelete(existingHashMatch); + return existingHashMatch; } + + // check if a set already exists with the same online id + if (beatmapSet.OnlineBeatmapSetID != null) + { + var existingOnlineId = beatmaps.BeatmapSets.FirstOrDefault(b => b.OnlineBeatmapSetID == beatmapSet.OnlineBeatmapSetID); + if (existingOnlineId != null) + { + // {Microsoft.EntityFrameworkCore.DbUpdateConcurrencyException: Database operation expected to affect 1 row(s) but actually affected 0 row(s). Data may have been modified or deleted since entities were loaded. See http://go.microsoft.com/fwlink/?LinkId=527962…} + + Delete(existingOnlineId); + beatmaps.Cleanup(s => s.ID == existingOnlineId.ID); + } + } + + beatmapSet.Files = createFileInfos(archive, files); + beatmapSet.Beatmaps = createBeatmapDifficulties(archive); + + // remove metadata from difficulties where it matches the set + foreach (BeatmapInfo b in beatmapSet.Beatmaps) + if (beatmapSet.Metadata.Equals(b.Metadata)) + b.Metadata = null; + + // import to beatmap store + Import(beatmapSet); + return beatmapSet; } } /// /// Import a beatmap from a . /// - /// The beatmap to be imported. - public void Import(BeatmapSetInfo beatmapSetInfo) - { - lock (importContextLock) - { - var context = importContext.Value; - - using (var transaction = context.BeginTransaction()) - { - import(beatmapSetInfo, context); - context.SaveChanges(transaction); - } - } - } + /// The beatmap to be imported. + public void Import(BeatmapSetInfo beatmapSet) => beatmaps.Add(beatmapSet); /// /// Downloads a beatmap. @@ -350,26 +315,22 @@ namespace osu.Game.Beatmaps /// The beatmap set to delete. public void Delete(BeatmapSetInfo beatmapSet) { - lock (importContextLock) + using (var db = contextFactory.GetForWrite()) { - var context = importContext.Value; + var context = db.Context; - using (var transaction = context.BeginTransaction()) + context.ChangeTracker.AutoDetectChangesEnabled = false; + + // re-fetch the beatmap set on the import context. + beatmapSet = context.BeatmapSetInfo.Include(s => s.Files).ThenInclude(f => f.FileInfo).First(s => s.ID == beatmapSet.ID); + + if (beatmaps.Delete(beatmapSet)) { - context.ChangeTracker.AutoDetectChangesEnabled = false; - - // re-fetch the beatmap set on the import context. - beatmapSet = context.BeatmapSetInfo.Include(s => s.Files).ThenInclude(f => f.FileInfo).First(s => s.ID == beatmapSet.ID); - - if (getBeatmapStoreWithContext(context).Delete(beatmapSet)) - { - if (!beatmapSet.Protected) - getFileStoreWithContext(context).Dereference(beatmapSet.Files.Select(f => f.FileInfo).ToArray()); - } - - context.ChangeTracker.AutoDetectChangesEnabled = true; - context.SaveChanges(transaction); + if (!beatmapSet.Protected) + files.Dereference(beatmapSet.Files.Select(f => f.FileInfo).ToArray()); } + + context.ChangeTracker.AutoDetectChangesEnabled = true; } } @@ -417,19 +378,11 @@ namespace osu.Game.Beatmaps if (beatmapSet.Protected) return; - lock (importContextLock) + using (var db = contextFactory.GetForWrite()) { - var context = importContext.Value; - - using (var transaction = context.BeginTransaction()) - { - context.ChangeTracker.AutoDetectChangesEnabled = false; - - undelete(getBeatmapStoreWithContext(context), getFileStoreWithContext(context), beatmapSet); - - context.ChangeTracker.AutoDetectChangesEnabled = true; - context.SaveChanges(transaction); - } + db.Context.ChangeTracker.AutoDetectChangesEnabled = false; + undelete(beatmapSet); + db.Context.ChangeTracker.AutoDetectChangesEnabled = true; } } @@ -452,7 +405,7 @@ namespace osu.Game.Beatmaps /// The store to restore beatmaps from. /// The store to restore beatmap files from. /// The beatmap to restore. - private void undelete(BeatmapStore beatmaps, FileStore files, BeatmapSetInfo beatmapSet) + private void undelete(BeatmapSetInfo beatmapSet) { if (!beatmaps.Undelete(beatmapSet)) return; @@ -578,11 +531,6 @@ namespace osu.Game.Beatmaps notification.State = ProgressNotificationState.Completed; } - /// - /// Import a into the beatmap store. - /// - private void import(BeatmapSetInfo beatmapSet, OsuDbContext context) => getBeatmapStoreWithContext(context).Add(beatmapSet); - /// /// Creates an from a valid storage path. /// @@ -689,19 +637,5 @@ namespace osu.Game.Beatmaps return beatmapInfos; } - - private FileStore getFileStoreWithContext(OsuDbContext context) => new FileStore(() => context, files.Storage); - - private BeatmapStore getBeatmapStoreWithContext(OsuDbContext context) => getBeatmapStoreWithContext(() => context); - - private BeatmapStore getBeatmapStoreWithContext(Func context) - { - var store = new BeatmapStore(context); - store.BeatmapSetAdded += s => BeatmapSetAdded?.Invoke(s); - store.BeatmapSetRemoved += s => BeatmapSetRemoved?.Invoke(s); - store.BeatmapHidden += b => BeatmapHidden?.Invoke(b); - store.BeatmapRestored += b => BeatmapRestored?.Invoke(b); - return store; - } } } diff --git a/osu.Game/Beatmaps/BeatmapStore.cs b/osu.Game/Beatmaps/BeatmapStore.cs index f2c3eddec9..67a2bbbd90 100644 --- a/osu.Game/Beatmaps/BeatmapStore.cs +++ b/osu.Game/Beatmaps/BeatmapStore.cs @@ -20,7 +20,7 @@ namespace osu.Game.Beatmaps public event Action BeatmapHidden; public event Action BeatmapRestored; - public BeatmapStore(Func factory) + public BeatmapStore(DatabaseContextFactory factory) : base(factory) { } @@ -31,24 +31,25 @@ namespace osu.Game.Beatmaps /// The beatmap to add. public void Add(BeatmapSetInfo beatmapSet) { - var context = GetContext(); - - foreach (var beatmap in beatmapSet.Beatmaps.Where(b => b.Metadata != null)) + using (var db = ContextFactory.GetForWrite()) { - // If we detect a new metadata object it'll be attached to the current context so it can be reused - // to prevent duplicate entries when persisting. To accomplish this we look in the cache (.Local) - // of the corresponding table (.Set()) for matching entries to our criteria. - var contextMetadata = context.Set().Local.SingleOrDefault(e => e.Equals(beatmap.Metadata)); - if (contextMetadata != null) - beatmap.Metadata = contextMetadata; - else - context.BeatmapMetadata.Attach(beatmap.Metadata); + var context = db.Context; + + foreach (var beatmap in beatmapSet.Beatmaps.Where(b => b.Metadata != null)) + { + // If we detect a new metadata object it'll be attached to the current context so it can be reused + // to prevent duplicate entries when persisting. To accomplish this we look in the cache (.Local) + // of the corresponding table (.Set()) for matching entries to our criteria. + var contextMetadata = context.Set().Local.SingleOrDefault(e => e.Equals(beatmap.Metadata)); + if (contextMetadata != null) + beatmap.Metadata = contextMetadata; + else + context.BeatmapMetadata.Attach(beatmap.Metadata); + } + + context.BeatmapSetInfo.Attach(beatmapSet); + BeatmapSetAdded?.Invoke(beatmapSet); } - - context.BeatmapSetInfo.Attach(beatmapSet); - context.SaveChanges(); - - BeatmapSetAdded?.Invoke(beatmapSet); } /// @@ -59,10 +60,8 @@ namespace osu.Game.Beatmaps { BeatmapSetRemoved?.Invoke(beatmapSet); - var context = GetContext(); - - context.BeatmapSetInfo.Update(beatmapSet); - context.SaveChanges(); + using (var usage = ContextFactory.GetForWrite()) + usage.Context.BeatmapSetInfo.Update(beatmapSet); BeatmapSetAdded?.Invoke(beatmapSet); } @@ -74,13 +73,13 @@ namespace osu.Game.Beatmaps /// Whether the beatmap's was changed. public bool Delete(BeatmapSetInfo beatmapSet) { - var context = GetContext(); + using ( ContextFactory.GetForWrite()) + { + Refresh(ref beatmapSet, BeatmapSets); - Refresh(ref beatmapSet, BeatmapSets); - - if (beatmapSet.DeletePending) return false; - beatmapSet.DeletePending = true; - context.SaveChanges(); + if (beatmapSet.DeletePending) return false; + beatmapSet.DeletePending = true; + } BeatmapSetRemoved?.Invoke(beatmapSet); return true; @@ -93,13 +92,13 @@ namespace osu.Game.Beatmaps /// Whether the beatmap's was changed. public bool Undelete(BeatmapSetInfo beatmapSet) { - var context = GetContext(); + using ( ContextFactory.GetForWrite()) + { + Refresh(ref beatmapSet, BeatmapSets); - Refresh(ref beatmapSet, BeatmapSets); - - if (!beatmapSet.DeletePending) return false; - beatmapSet.DeletePending = false; - context.SaveChanges(); + if (!beatmapSet.DeletePending) return false; + beatmapSet.DeletePending = false; + } BeatmapSetAdded?.Invoke(beatmapSet); return true; @@ -112,15 +111,16 @@ namespace osu.Game.Beatmaps /// Whether the beatmap's was changed. public bool Hide(BeatmapInfo beatmap) { - var context = GetContext(); + using (ContextFactory.GetForWrite()) + { + Refresh(ref beatmap, Beatmaps); - Refresh(ref beatmap, Beatmaps); + if (beatmap.Hidden) return false; + beatmap.Hidden = true; - if (beatmap.Hidden) return false; - beatmap.Hidden = true; - context.SaveChanges(); + BeatmapHidden?.Invoke(beatmap); + } - BeatmapHidden?.Invoke(beatmap); return true; } @@ -131,13 +131,13 @@ namespace osu.Game.Beatmaps /// Whether the beatmap's was changed. public bool Restore(BeatmapInfo beatmap) { - var context = GetContext(); + using (ContextFactory.GetForWrite()) + { + Refresh(ref beatmap, Beatmaps); - Refresh(ref beatmap, Beatmaps); - - if (!beatmap.Hidden) return false; - beatmap.Hidden = false; - context.SaveChanges(); + if (!beatmap.Hidden) return false; + beatmap.Hidden = false; + } BeatmapRestored?.Invoke(beatmap); return true; @@ -147,34 +147,36 @@ namespace osu.Game.Beatmaps public void Cleanup(Expression> query) { - var context = GetContext(); + using (var usage = ContextFactory.GetForWrite()) + { + var context = usage.Context; - var purgeable = context.BeatmapSetInfo.Where(s => s.DeletePending && !s.Protected) - .Where(query) - .Include(s => s.Beatmaps).ThenInclude(b => b.Metadata) - .Include(s => s.Beatmaps).ThenInclude(b => b.BaseDifficulty) - .Include(s => s.Metadata); + var purgeable = context.BeatmapSetInfo.Where(s => s.DeletePending && !s.Protected) + .Where(query) + .Include(s => s.Beatmaps).ThenInclude(b => b.Metadata) + .Include(s => s.Beatmaps).ThenInclude(b => b.BaseDifficulty) + .Include(s => s.Metadata); - // metadata is M-N so we can't rely on cascades - context.BeatmapMetadata.RemoveRange(purgeable.Select(s => s.Metadata)); - context.BeatmapMetadata.RemoveRange(purgeable.SelectMany(s => s.Beatmaps.Select(b => b.Metadata).Where(m => m != null))); + // metadata is M-N so we can't rely on cascades + context.BeatmapMetadata.RemoveRange(purgeable.Select(s => s.Metadata)); + context.BeatmapMetadata.RemoveRange(purgeable.SelectMany(s => s.Beatmaps.Select(b => b.Metadata).Where(m => m != null))); - // todo: we can probably make cascades work here with a FK in BeatmapDifficulty. just make to make it work correctly. - context.BeatmapDifficulty.RemoveRange(purgeable.SelectMany(s => s.Beatmaps.Select(b => b.BaseDifficulty))); + // todo: we can probably make cascades work here with a FK in BeatmapDifficulty. just make to make it work correctly. + context.BeatmapDifficulty.RemoveRange(purgeable.SelectMany(s => s.Beatmaps.Select(b => b.BaseDifficulty))); - // cascades down to beatmaps. - context.BeatmapSetInfo.RemoveRange(purgeable); - context.SaveChanges(); + // cascades down to beatmaps. + context.BeatmapSetInfo.RemoveRange(purgeable); + } } - public IQueryable BeatmapSets => GetContext().BeatmapSetInfo + public IQueryable BeatmapSets => ContextFactory.Get().BeatmapSetInfo .Include(s => s.Metadata) .Include(s => s.Beatmaps).ThenInclude(s => s.Ruleset) .Include(s => s.Beatmaps).ThenInclude(b => b.BaseDifficulty) .Include(s => s.Beatmaps).ThenInclude(b => b.Metadata) .Include(s => s.Files).ThenInclude(f => f.FileInfo); - public IQueryable Beatmaps => GetContext().BeatmapInfo + public IQueryable Beatmaps => ContextFactory.Get().BeatmapInfo .Include(b => b.BeatmapSet).ThenInclude(s => s.Metadata) .Include(b => b.BeatmapSet).ThenInclude(s => s.Files).ThenInclude(f => f.FileInfo) .Include(b => b.Metadata) diff --git a/osu.Game/Configuration/SettingsStore.cs b/osu.Game/Configuration/SettingsStore.cs index 9b18151c84..7b66002a79 100644 --- a/osu.Game/Configuration/SettingsStore.cs +++ b/osu.Game/Configuration/SettingsStore.cs @@ -12,8 +12,8 @@ namespace osu.Game.Configuration { public event Action SettingChanged; - public SettingsStore(Func createContext) - : base(createContext) + public SettingsStore(DatabaseContextFactory contextFactory) + : base(contextFactory) { } @@ -24,19 +24,16 @@ namespace osu.Game.Configuration /// An optional variant. /// public List Query(int? rulesetId = null, int? variant = null) => - GetContext().DatabasedSetting.Where(b => b.RulesetID == rulesetId && b.Variant == variant).ToList(); + ContextFactory.Get().DatabasedSetting.Where(b => b.RulesetID == rulesetId && b.Variant == variant).ToList(); public void Update(DatabasedSetting setting) { - var context = GetContext(); - - var newValue = setting.Value; - - Refresh(ref setting); - - setting.Value = newValue; - - context.SaveChanges(); + using (ContextFactory.GetForWrite()) + { + var newValue = setting.Value; + Refresh(ref setting); + setting.Value = newValue; + } SettingChanged?.Invoke(); } diff --git a/osu.Game/Database/DatabaseBackedStore.cs b/osu.Game/Database/DatabaseBackedStore.cs index ec9967e097..da66167b14 100644 --- a/osu.Game/Database/DatabaseBackedStore.cs +++ b/osu.Game/Database/DatabaseBackedStore.cs @@ -1,10 +1,8 @@ // Copyright (c) 2007-2018 ppy Pty Ltd . // Licensed under the MIT Licence - https://raw.githubusercontent.com/ppy/osu/master/LICENCE -using System; using System.Collections.Generic; using System.Linq; -using System.Threading; using Microsoft.EntityFrameworkCore; using osu.Framework.Platform; @@ -17,9 +15,7 @@ namespace osu.Game.Database /// /// Create a new instance (separate from the shared context via for performing isolated operations. /// - protected readonly Func CreateContext; - - private readonly ThreadLocal queryContext; + protected readonly DatabaseContextFactory ContextFactory; /// /// Refresh an instance potentially from a different thread with a local context-tracked instance. @@ -29,33 +25,27 @@ namespace osu.Game.Database /// A valid EF-stored type. protected virtual void Refresh(ref T obj, IEnumerable lookupSource = null) where T : class, IHasPrimaryKey { - var context = GetContext(); - - if (context.Entry(obj).State != EntityState.Detached) return; - - var id = obj.ID; - var foundObject = lookupSource?.SingleOrDefault(t => t.ID == id) ?? context.Find(id); - if (foundObject != null) + using (var usage = ContextFactory.GetForWrite()) { - obj = foundObject; - context.Entry(obj).Reload(); + var context = usage.Context; + + if (context.Entry(obj).State != EntityState.Detached) return; + + var id = obj.ID; + var foundObject = lookupSource?.SingleOrDefault(t => t.ID == id) ?? context.Find(id); + if (foundObject != null) + { + obj = foundObject; + context.Entry(obj).Reload(); + } + else + context.Add(obj); } - else - context.Add(obj); } - /// - /// Retrieve a shared context for performing lookups (or write operations on the update thread, for now). - /// - protected OsuDbContext GetContext() => queryContext.Value; - - protected DatabaseBackedStore(Func createContext, Storage storage = null) + protected DatabaseBackedStore(DatabaseContextFactory contextFactory, Storage storage = null) { - CreateContext = createContext; - - // todo: while this seems to work quite well, we need to consider that contexts could enter a state where they are never cleaned up. - queryContext = new ThreadLocal(CreateContext); - + ContextFactory = contextFactory; Storage = storage; } diff --git a/osu.Game/Database/DatabaseContextFactory.cs b/osu.Game/Database/DatabaseContextFactory.cs index b1917d92c4..c092ed377f 100644 --- a/osu.Game/Database/DatabaseContextFactory.cs +++ b/osu.Game/Database/DatabaseContextFactory.cs @@ -1,6 +1,7 @@ // Copyright (c) 2007-2018 ppy Pty Ltd . // Licensed under the MIT Licence - https://raw.githubusercontent.com/ppy/osu/master/LICENCE +using System.Threading; using osu.Framework.Platform; namespace osu.Game.Database @@ -11,17 +12,70 @@ namespace osu.Game.Database private const string database_name = @"client"; + private ThreadLocal threadContexts; + + private readonly object writeLock = new object(); + + private OsuDbContext writeContext; + + private volatile int currentWriteUsages; + public DatabaseContextFactory(GameHost host) { this.host = host; + recycleThreadContexts(); } - public OsuDbContext GetContext() => new OsuDbContext(host.Storage.GetDatabaseConnectionString(database_name)); + /// + /// Get a context for read-only usage. + /// + public OsuDbContext Get() => threadContexts.Value; + + /// + /// Request a context for write usage. Can be consumed in a nested fashion (and will return the same underlying context). + /// This method may block if a write is already active on a different thread. + /// + /// A usage containing a usable context. + public DatabaseWriteUsage GetForWrite() + { + lock (writeLock) + { + var usage = new DatabaseWriteUsage(writeContext ?? (writeContext = threadContexts.Value), usageCompleted); + Interlocked.Increment(ref currentWriteUsages); + return usage; + } + } + + private void usageCompleted(DatabaseWriteUsage usage) + { + int usages = Interlocked.Decrement(ref currentWriteUsages); + if (usages == 0) + { + writeContext.Dispose(); + writeContext = null; + + // once all writes are complete, we want to refresh thread-specific contexts to make sure they don't have stale local caches. + recycleThreadContexts(); + } + } + + private void recycleThreadContexts() => threadContexts = new ThreadLocal(CreateContext); + + protected virtual OsuDbContext CreateContext() + { + var ctx = new OsuDbContext(host.Storage.GetDatabaseConnectionString(database_name)); + ctx.Database.AutoTransactionsEnabled = false; + + return ctx; + } public void ResetDatabase() { - // todo: we probably want to make sure there are no active contexts before performing this operation. - host.Storage.DeleteDatabase(database_name); + lock (writeLock) + { + recycleThreadContexts(); + host.Storage.DeleteDatabase(database_name); + } } } } diff --git a/osu.Game/Database/DatabaseWriteUsage.cs b/osu.Game/Database/DatabaseWriteUsage.cs new file mode 100644 index 0000000000..0dc5a4cfe9 --- /dev/null +++ b/osu.Game/Database/DatabaseWriteUsage.cs @@ -0,0 +1,28 @@ +// Copyright (c) 2007-2018 ppy Pty Ltd . +// Licensed under the MIT Licence - https://raw.githubusercontent.com/ppy/osu/master/LICENCE + +using System; +using Microsoft.EntityFrameworkCore.Storage; + +namespace osu.Game.Database +{ + public class DatabaseWriteUsage : IDisposable + { + public readonly OsuDbContext Context; + private readonly IDbContextTransaction transaction; + private readonly Action usageCompleted; + + public DatabaseWriteUsage(OsuDbContext context, Action onCompleted) + { + Context = context; + transaction = Context.BeginTransaction(); + usageCompleted = onCompleted; + } + + public void Dispose() + { + Context.SaveChanges(transaction); + usageCompleted?.Invoke(this); + } + } +} diff --git a/osu.Game/Database/SingletonContextFactory.cs b/osu.Game/Database/SingletonContextFactory.cs new file mode 100644 index 0000000000..88a43dc836 --- /dev/null +++ b/osu.Game/Database/SingletonContextFactory.cs @@ -0,0 +1,21 @@ +// Copyright (c) 2007-2018 ppy Pty Ltd . +// Licensed under the MIT Licence - https://raw.githubusercontent.com/ppy/osu/master/LICENCE + +namespace osu.Game.Database +{ + public class SingletonContextFactory : DatabaseContextFactory + { + private readonly OsuDbContext context; + + public SingletonContextFactory(OsuDbContext context) + : base(null) + { + this.context = context; + } + + protected override OsuDbContext CreateContext() + { + return context; + } + } +} diff --git a/osu.Game/IO/FileStore.cs b/osu.Game/IO/FileStore.cs index 31c608a5f4..1bfe4db81a 100644 --- a/osu.Game/IO/FileStore.cs +++ b/osu.Game/IO/FileStore.cs @@ -21,86 +21,91 @@ namespace osu.Game.IO public new Storage Storage => base.Storage; - public FileStore(Func createContext, Storage storage) : base(createContext, storage.GetStorageForDirectory(@"files")) + public FileStore(DatabaseContextFactory contextFactory, Storage storage) : base(contextFactory, storage.GetStorageForDirectory(@"files")) { Store = new StorageBackedResourceStore(Storage); } public FileInfo Add(Stream data, bool reference = true) { - var context = GetContext(); - - string hash = data.ComputeSHA2Hash(); - - var existing = context.FileInfo.FirstOrDefault(f => f.Hash == hash); - - var info = existing ?? new FileInfo { Hash = hash }; - - string path = info.StoragePath; - - // we may be re-adding a file to fix missing store entries. - if (!Storage.Exists(path)) + using (var usage = ContextFactory.GetForWrite()) { - data.Seek(0, SeekOrigin.Begin); + var context = usage.Context; - using (var output = Storage.GetStream(path, FileAccess.Write)) - data.CopyTo(output); + string hash = data.ComputeSHA2Hash(); - data.Seek(0, SeekOrigin.Begin); + var existing = context.FileInfo.FirstOrDefault(f => f.Hash == hash); + + var info = existing ?? new FileInfo { Hash = hash }; + + string path = info.StoragePath; + + // we may be re-adding a file to fix missing store entries. + if (!Storage.Exists(path)) + { + data.Seek(0, SeekOrigin.Begin); + + using (var output = Storage.GetStream(path, FileAccess.Write)) + data.CopyTo(output); + + data.Seek(0, SeekOrigin.Begin); + } + + if (reference || existing == null) + Reference(info); + + return info; } - - if (reference || existing == null) - Reference(info); - - return info; } - public void Reference(params FileInfo[] files) => reference(GetContext(), files); - - private void reference(OsuDbContext context, FileInfo[] files) + public void Reference(params FileInfo[] files) { - foreach (var f in files.GroupBy(f => f.ID)) + using (var usage = ContextFactory.GetForWrite()) { - var refetch = context.Find(f.First().ID) ?? f.First(); - refetch.ReferenceCount += f.Count(); - context.FileInfo.Update(refetch); - } + var context = usage.Context; - context.SaveChanges(); + foreach (var f in files.GroupBy(f => f.ID)) + { + var refetch = context.Find(f.First().ID) ?? f.First(); + refetch.ReferenceCount += f.Count(); + context.FileInfo.Update(refetch); + } + } } - public void Dereference(params FileInfo[] files) => dereference(GetContext(), files); - - private void dereference(OsuDbContext context, FileInfo[] files) + public void Dereference(params FileInfo[] files) { - foreach (var f in files.GroupBy(f => f.ID)) + using (var usage = ContextFactory.GetForWrite()) { - var refetch = context.FileInfo.Find(f.Key); - refetch.ReferenceCount -= f.Count(); - context.FileInfo.Update(refetch); + var context = usage.Context; + foreach (var f in files.GroupBy(f => f.ID)) + { + var refetch = context.FileInfo.Find(f.Key); + refetch.ReferenceCount -= f.Count(); + context.FileInfo.Update(refetch); + } } - - context.SaveChanges(); } public override void Cleanup() { - var context = GetContext(); - - foreach (var f in context.FileInfo.Where(f => f.ReferenceCount < 1)) + using (var usage = ContextFactory.GetForWrite()) { - try + var context = usage.Context; + + foreach (var f in context.FileInfo.Where(f => f.ReferenceCount < 1)) { - Storage.Delete(f.StoragePath); - context.FileInfo.Remove(f); - } - catch (Exception e) - { - Logger.Error(e, $@"Could not delete beatmap {f}"); + try + { + Storage.Delete(f.StoragePath); + context.FileInfo.Remove(f); + } + catch (Exception e) + { + Logger.Error(e, $@"Could not delete beatmap {f}"); + } } } - - context.SaveChanges(); } } } diff --git a/osu.Game/Input/KeyBindingStore.cs b/osu.Game/Input/KeyBindingStore.cs index 92159ab491..4aad684959 100644 --- a/osu.Game/Input/KeyBindingStore.cs +++ b/osu.Game/Input/KeyBindingStore.cs @@ -16,14 +16,17 @@ namespace osu.Game.Input { public event Action KeyBindingChanged; - public KeyBindingStore(Func createContext, RulesetStore rulesets, Storage storage = null) - : base(createContext, storage) + public KeyBindingStore(DatabaseContextFactory contextFactory, RulesetStore rulesets, Storage storage = null) + : base(contextFactory, storage) { - foreach (var info in rulesets.AvailableRulesets) + using (ContextFactory.GetForWrite()) { - var ruleset = info.CreateInstance(); - foreach (var variant in ruleset.AvailableVariants) - insertDefaults(ruleset.GetDefaultKeyBindings(variant), info.ID, variant); + foreach (var info in rulesets.AvailableRulesets) + { + var ruleset = info.CreateInstance(); + foreach (var variant in ruleset.AvailableVariants) + insertDefaults(ruleset.GetDefaultKeyBindings(variant), info.ID, variant); + } } } @@ -31,10 +34,10 @@ namespace osu.Game.Input private void insertDefaults(IEnumerable defaults, int? rulesetId = null, int? variant = null) { - var context = GetContext(); - - using (var transaction = context.BeginTransaction()) + using (var usage = ContextFactory.GetForWrite()) { + var context = usage.Context; + // compare counts in database vs defaults foreach (var group in defaults.GroupBy(k => k.Action)) { @@ -54,8 +57,6 @@ namespace osu.Game.Input Variant = variant }); } - - context.SaveChanges(transaction); } } @@ -66,19 +67,16 @@ namespace osu.Game.Input /// An optional variant. /// public List Query(int? rulesetId = null, int? variant = null) => - GetContext().DatabasedKeyBinding.Where(b => b.RulesetID == rulesetId && b.Variant == variant).ToList(); + ContextFactory.Get().DatabasedKeyBinding.Where(b => b.RulesetID == rulesetId && b.Variant == variant).ToList(); public void Update(KeyBinding keyBinding) { - var dbKeyBinding = (DatabasedKeyBinding)keyBinding; - - var context = GetContext(); - - Refresh(ref dbKeyBinding); - - dbKeyBinding.KeyCombination = keyBinding.KeyCombination; - - context.SaveChanges(); + using (ContextFactory.GetForWrite()) + { + var dbKeyBinding = (DatabasedKeyBinding)keyBinding; + Refresh(ref dbKeyBinding); + dbKeyBinding.KeyCombination = keyBinding.KeyCombination; + } KeyBindingChanged?.Invoke(); } diff --git a/osu.Game/OsuGameBase.cs b/osu.Game/OsuGameBase.cs index a7eac27056..505577416d 100644 --- a/osu.Game/OsuGameBase.cs +++ b/osu.Game/OsuGameBase.cs @@ -106,12 +106,12 @@ namespace osu.Game Token = LocalConfig.Get(OsuSetting.Token) }); - dependencies.Cache(RulesetStore = new RulesetStore(contextFactory.GetContext)); - dependencies.Cache(FileStore = new FileStore(contextFactory.GetContext, Host.Storage)); - dependencies.Cache(BeatmapManager = new BeatmapManager(Host.Storage, contextFactory.GetContext, RulesetStore, API, Host)); - dependencies.Cache(ScoreStore = new ScoreStore(Host.Storage, contextFactory.GetContext, Host, BeatmapManager, RulesetStore)); - dependencies.Cache(KeyBindingStore = new KeyBindingStore(contextFactory.GetContext, RulesetStore)); - dependencies.Cache(SettingsStore = new SettingsStore(contextFactory.GetContext)); + dependencies.Cache(RulesetStore = new RulesetStore(contextFactory)); + dependencies.Cache(FileStore = new FileStore(contextFactory, Host.Storage)); + dependencies.Cache(BeatmapManager = new BeatmapManager(Host.Storage, contextFactory, RulesetStore, API, Host)); + dependencies.Cache(ScoreStore = new ScoreStore(Host.Storage, contextFactory, Host, BeatmapManager, RulesetStore)); + dependencies.Cache(KeyBindingStore = new KeyBindingStore(contextFactory, RulesetStore)); + dependencies.Cache(SettingsStore = new SettingsStore(contextFactory)); dependencies.Cache(new OsuColour()); //this completely overrides the framework default. will need to change once we make a proper FontStore. @@ -179,8 +179,8 @@ namespace osu.Game { try { - using (var context = contextFactory.GetContext()) - context.Migrate(); + using (var db = contextFactory.GetForWrite()) + db.Context.Migrate(); } catch (MigrationFailedException e) { @@ -191,8 +191,8 @@ namespace osu.Game contextFactory.ResetDatabase(); Logger.Log("Database purged successfully.", LoggingTarget.Database, LogLevel.Important); - using (var context = contextFactory.GetContext()) - context.Migrate(); + using (var db = contextFactory.GetForWrite()) + db.Context.Migrate(); } } diff --git a/osu.Game/Rulesets/RulesetStore.cs b/osu.Game/Rulesets/RulesetStore.cs index 01e3b6848f..f66a126211 100644 --- a/osu.Game/Rulesets/RulesetStore.cs +++ b/osu.Game/Rulesets/RulesetStore.cs @@ -25,7 +25,7 @@ namespace osu.Game.Rulesets loadRulesetFromFile(file); } - public RulesetStore(Func factory) + public RulesetStore(DatabaseContextFactory factory) : base(factory) { AddMissingRulesets(); @@ -56,47 +56,50 @@ namespace osu.Game.Rulesets protected void AddMissingRulesets() { - var context = GetContext(); - - var instances = loaded_assemblies.Values.Select(r => (Ruleset)Activator.CreateInstance(r, (RulesetInfo)null)).ToList(); - - //add all legacy modes in correct order - foreach (var r in instances.Where(r => r.LegacyID >= 0).OrderBy(r => r.LegacyID)) + using (var usage = ContextFactory.GetForWrite()) { - if (context.RulesetInfo.SingleOrDefault(rsi => rsi.ID == r.RulesetInfo.ID) == null) - context.RulesetInfo.Add(r.RulesetInfo); - } + var context = usage.Context; - context.SaveChanges(); + var instances = loaded_assemblies.Values.Select(r => (Ruleset)Activator.CreateInstance(r, (RulesetInfo)null)).ToList(); - //add any other modes - foreach (var r in instances.Where(r => r.LegacyID < 0)) - if (context.RulesetInfo.FirstOrDefault(ri => ri.InstantiationInfo == r.RulesetInfo.InstantiationInfo) == null) - context.RulesetInfo.Add(r.RulesetInfo); - - context.SaveChanges(); - - //perform a consistency check - foreach (var r in context.RulesetInfo) - { - try + //add all legacy modes in correct order + foreach (var r in instances.Where(r => r.LegacyID >= 0).OrderBy(r => r.LegacyID)) { - var instance = r.CreateInstance(); - - r.Name = instance.Description; - r.ShortName = instance.ShortName; - - r.Available = true; + if (context.RulesetInfo.SingleOrDefault(rsi => rsi.ID == r.RulesetInfo.ID) == null) + context.RulesetInfo.Add(r.RulesetInfo); } - catch + + context.SaveChanges(); + + //add any other modes + foreach (var r in instances.Where(r => r.LegacyID < 0)) + if (context.RulesetInfo.FirstOrDefault(ri => ri.InstantiationInfo == r.RulesetInfo.InstantiationInfo) == null) + context.RulesetInfo.Add(r.RulesetInfo); + + context.SaveChanges(); + + //perform a consistency check + foreach (var r in context.RulesetInfo) { - r.Available = false; + try + { + var instance = r.CreateInstance(); + + r.Name = instance.Description; + r.ShortName = instance.ShortName; + + r.Available = true; + } + catch + { + r.Available = false; + } } + + context.SaveChanges(); + + AvailableRulesets = context.RulesetInfo.Where(r => r.Available).ToList(); } - - context.SaveChanges(); - - AvailableRulesets = context.RulesetInfo.Where(r => r.Available).ToList(); } private static void loadRulesetFromFile(string file) diff --git a/osu.Game/Rulesets/Scoring/ScoreStore.cs b/osu.Game/Rulesets/Scoring/ScoreStore.cs index d21ca79736..8bde2747a2 100644 --- a/osu.Game/Rulesets/Scoring/ScoreStore.cs +++ b/osu.Game/Rulesets/Scoring/ScoreStore.cs @@ -1,7 +1,6 @@ // Copyright (c) 2007-2018 ppy Pty Ltd . // Licensed under the MIT Licence - https://raw.githubusercontent.com/ppy/osu/master/LICENCE -using System; using System.Collections.Generic; using System.IO; using osu.Framework.Platform; @@ -27,7 +26,7 @@ namespace osu.Game.Rulesets.Scoring // ReSharper disable once NotAccessedField.Local (we should keep a reference to this so it is not finalised) private ScoreIPCChannel ipc; - public ScoreStore(Storage storage, Func factory, IIpcHost importHost = null, BeatmapManager beatmaps = null, RulesetStore rulesets = null) : base(factory) + public ScoreStore(Storage storage, DatabaseContextFactory factory, IIpcHost importHost = null, BeatmapManager beatmaps = null, RulesetStore rulesets = null) : base(factory) { this.storage = storage; this.beatmaps = beatmaps; diff --git a/osu.Game/osu.Game.csproj b/osu.Game/osu.Game.csproj index c16767c02c..71f1629c19 100644 --- a/osu.Game/osu.Game.csproj +++ b/osu.Game/osu.Game.csproj @@ -275,7 +275,9 @@ + +