diff --git a/osu.Game/Beatmaps/BeatmapStore.cs b/osu.Game/Beatmaps/BeatmapStore.cs index f3d3caeb0f..69aadb470e 100644 --- a/osu.Game/Beatmaps/BeatmapStore.cs +++ b/osu.Game/Beatmaps/BeatmapStore.cs @@ -29,12 +29,14 @@ namespace osu.Game.Beatmaps { if (reset) { + var context = GetContext(); + // https://stackoverflow.com/a/10450893 - Context.Database.ExecuteSqlCommand("DELETE FROM BeatmapMetadata"); - Context.Database.ExecuteSqlCommand("DELETE FROM BeatmapDifficulty"); - Context.Database.ExecuteSqlCommand("DELETE FROM BeatmapSetInfo"); - Context.Database.ExecuteSqlCommand("DELETE FROM BeatmapSetFileInfo"); - Context.Database.ExecuteSqlCommand("DELETE FROM BeatmapInfo"); + context.Database.ExecuteSqlCommand("DELETE FROM BeatmapMetadata"); + context.Database.ExecuteSqlCommand("DELETE FROM BeatmapDifficulty"); + context.Database.ExecuteSqlCommand("DELETE FROM BeatmapSetInfo"); + context.Database.ExecuteSqlCommand("DELETE FROM BeatmapSetFileInfo"); + context.Database.ExecuteSqlCommand("DELETE FROM BeatmapInfo"); } } @@ -50,8 +52,10 @@ namespace osu.Game.Beatmaps /// The beatmap to add. public void Add(BeatmapSetInfo beatmapSet) { - Context.BeatmapSetInfo.Attach(beatmapSet); - Context.SaveChanges(); + var context = GetContext(); + + context.BeatmapSetInfo.Attach(beatmapSet); + context.SaveChanges(); BeatmapSetAdded?.Invoke(beatmapSet); } @@ -63,11 +67,13 @@ namespace osu.Game.Beatmaps /// Whether the beatmap's was changed. public bool Delete(BeatmapSetInfo beatmapSet) { + var context = GetContext(); + if (beatmapSet.DeletePending) return false; beatmapSet.DeletePending = true; - Context.BeatmapSetInfo.Update(beatmapSet); - Context.SaveChanges(); + context.BeatmapSetInfo.Update(beatmapSet); + context.SaveChanges(); BeatmapSetRemoved?.Invoke(beatmapSet); return true; @@ -80,11 +86,13 @@ namespace osu.Game.Beatmaps /// Whether the beatmap's was changed. public bool Undelete(BeatmapSetInfo beatmapSet) { + var context = GetContext(); + if (!beatmapSet.DeletePending) return false; beatmapSet.DeletePending = false; - Context.BeatmapSetInfo.Update(beatmapSet); - Context.SaveChanges(); + context.BeatmapSetInfo.Update(beatmapSet); + context.SaveChanges(); BeatmapSetAdded?.Invoke(beatmapSet); return true; @@ -97,11 +105,13 @@ namespace osu.Game.Beatmaps /// Whether the beatmap's was changed. public bool Hide(BeatmapInfo beatmap) { + var context = GetContext(); + if (beatmap.Hidden) return false; beatmap.Hidden = true; - Context.BeatmapInfo.Update(beatmap); - Context.SaveChanges(); + context.BeatmapInfo.Update(beatmap); + context.SaveChanges(); BeatmapHidden?.Invoke(beatmap); return true; @@ -114,11 +124,13 @@ namespace osu.Game.Beatmaps /// Whether the beatmap's was changed. public bool Restore(BeatmapInfo beatmap) { + var context = GetContext(); + if (!beatmap.Hidden) return false; beatmap.Hidden = false; - Context.BeatmapInfo.Update(beatmap); - Context.SaveChanges(); + context.BeatmapInfo.Update(beatmap); + context.SaveChanges(); BeatmapRestored?.Invoke(beatmap); return true; @@ -126,21 +138,23 @@ namespace osu.Game.Beatmaps private void cleanupPendingDeletions() { - Context.BeatmapSetInfo.RemoveRange(Context.BeatmapSetInfo.Where(b => b.DeletePending && !b.Protected)); - Context.SaveChanges(); + var context = GetContext(); + + context.BeatmapSetInfo.RemoveRange(context.BeatmapSetInfo.Where(b => b.DeletePending && !b.Protected)); + context.SaveChanges(); } - public IEnumerable BeatmapSets => Context.BeatmapSetInfo - .Include(s => s.Metadata) - .Include(s => s.Beatmaps).ThenInclude(s => s.Ruleset) - .Include(s => s.Beatmaps).ThenInclude(b => b.Difficulty) - .Include(s => s.Beatmaps).ThenInclude(b => b.Metadata) - .Include(s => s.Files).ThenInclude(f => f.FileInfo); + public IEnumerable BeatmapSets => GetContext().BeatmapSetInfo + .Include(s => s.Metadata) + .Include(s => s.Beatmaps).ThenInclude(s => s.Ruleset) + .Include(s => s.Beatmaps).ThenInclude(b => b.Difficulty) + .Include(s => s.Beatmaps).ThenInclude(b => b.Metadata) + .Include(s => s.Files).ThenInclude(f => f.FileInfo); - public IEnumerable Beatmaps => Context.BeatmapInfo - .Include(b => b.BeatmapSet).ThenInclude(s => s.Metadata) - .Include(b => b.Metadata) - .Include(b => b.Ruleset) - .Include(b => b.Difficulty); + public IEnumerable Beatmaps => GetContext().BeatmapInfo + .Include(b => b.BeatmapSet).ThenInclude(s => s.Metadata) + .Include(b => b.Metadata) + .Include(b => b.Ruleset) + .Include(b => b.Difficulty); } } diff --git a/osu.Game/Database/DatabaseBackedStore.cs b/osu.Game/Database/DatabaseBackedStore.cs index 9d3d020250..79aea7863a 100644 --- a/osu.Game/Database/DatabaseBackedStore.cs +++ b/osu.Game/Database/DatabaseBackedStore.cs @@ -11,14 +11,12 @@ namespace osu.Game.Database { protected readonly Storage Storage; - private readonly Func contextSource; + protected readonly Func GetContext; - protected OsuDbContext Context => contextSource(); - - protected DatabaseBackedStore(Func contextSource, Storage storage = null) + protected DatabaseBackedStore(Func getContext, Storage storage = null) { Storage = storage; - this.contextSource = contextSource; + GetContext = getContext; try { diff --git a/osu.Game/IO/FileStore.cs b/osu.Game/IO/FileStore.cs index d715ccd0a7..b60d82d61c 100644 --- a/osu.Game/IO/FileStore.cs +++ b/osu.Game/IO/FileStore.cs @@ -22,7 +22,7 @@ namespace osu.Game.IO public readonly ResourceStore Store; - public FileStore(Func contextSource, Storage storage) : base(contextSource, storage) + public FileStore(Func getContext, Storage storage) : base(getContext, storage) { Store = new NamespacedResourceStore(new StorageBackedResourceStore(storage), prefix); } @@ -34,7 +34,7 @@ namespace osu.Game.IO if (Storage.ExistsDirectory(prefix)) Storage.DeleteDirectory(prefix); - Context.Database.ExecuteSqlCommand("DELETE FROM FileInfo"); + GetContext().Database.ExecuteSqlCommand("DELETE FROM FileInfo"); } } @@ -46,9 +46,11 @@ namespace osu.Game.IO public FileInfo Add(Stream data, bool reference = true) { + var context = GetContext(); + string hash = data.ComputeSHA2Hash(); - var existing = Context.FileInfo.FirstOrDefault(f => f.Hash == hash); + var existing = context.FileInfo.FirstOrDefault(f => f.Hash == hash); var info = existing ?? new FileInfo { Hash = hash }; @@ -71,38 +73,44 @@ namespace osu.Game.IO return info; } - public void Reference(params FileInfo[] files) + public void Reference(params FileInfo[] files) => reference(GetContext(), files); + + private void reference(OsuDbContext context, FileInfo[] files) { foreach (var f in files.GroupBy(f => f.ID)) { - var refetch = Context.Find(f.First().ID) ?? f.First(); + var refetch = context.Find(f.First().ID) ?? f.First(); refetch.ReferenceCount += f.Count(); - Context.FileInfo.Update(refetch); + context.FileInfo.Update(refetch); } - Context.SaveChanges(); + context.SaveChanges(); } - public void Dereference(params FileInfo[] files) + public void Dereference(params FileInfo[] files) => dereference(GetContext(), files); + + private void dereference(OsuDbContext context, FileInfo[] files) { foreach (var f in files.GroupBy(f => f.ID)) { - var refetch = Context.Find(f.First().ID); + var refetch = context.Find(f.First().ID); refetch.ReferenceCount -= f.Count(); - Context.Update(refetch); + context.Update(refetch); } - Context.SaveChanges(); + context.SaveChanges(); } private void deletePending() { - foreach (var f in Context.FileInfo.Where(f => f.ReferenceCount < 1)) + var context = GetContext(); + + foreach (var f in context.FileInfo.Where(f => f.ReferenceCount < 1)) { try { Storage.Delete(Path.Combine(prefix, f.StoragePath)); - Context.FileInfo.Remove(f); + context.FileInfo.Remove(f); } catch (Exception e) { @@ -110,7 +118,7 @@ namespace osu.Game.IO } } - Context.SaveChanges(); + context.SaveChanges(); } } } diff --git a/osu.Game/Input/KeyBindingStore.cs b/osu.Game/Input/KeyBindingStore.cs index 5c41179418..1e9a2aa22f 100644 --- a/osu.Game/Input/KeyBindingStore.cs +++ b/osu.Game/Input/KeyBindingStore.cs @@ -15,9 +15,16 @@ namespace osu.Game.Input { public class KeyBindingStore : DatabaseBackedStore { - public KeyBindingStore(Func contextSource, RulesetStore rulesets, Storage storage = null) - : base(contextSource, storage) + /// + /// As we do a lot of lookups, let's share a context between them to hopefully improve performance. + /// + private readonly OsuDbContext queryContext; + + public KeyBindingStore(Func getContext, RulesetStore rulesets, Storage storage = null) + : base(getContext, storage) { + queryContext = GetContext(); + foreach (var info in rulesets.AvailableRulesets) { var ruleset = info.CreateInstance(); @@ -31,15 +38,17 @@ namespace osu.Game.Input protected override void Prepare(bool reset = false) { if (reset) - Context.Database.ExecuteSqlCommand("DELETE FROM KeyBinding"); + GetContext().Database.ExecuteSqlCommand("DELETE FROM KeyBinding"); } private void insertDefaults(IEnumerable defaults, int? rulesetId = null, int? variant = null) { + var context = GetContext(); + // compare counts in database vs defaults foreach (var group in defaults.GroupBy(k => k.Action)) { - int count = Query(rulesetId, variant).Count(k => (int)k.Action == (int)group.Key); + int count = query(context, rulesetId, variant).Count(k => (int)k.Action == (int)group.Key); int aimCount = group.Count(); if (aimCount <= count) @@ -47,7 +56,7 @@ namespace osu.Game.Input foreach (var insertable in group.Skip(count).Take(aimCount - count)) // insert any defaults which are missing. - Context.DatabasedKeyBinding.Add(new DatabasedKeyBinding + context.DatabasedKeyBinding.Add(new DatabasedKeyBinding { KeyCombination = insertable.KeyCombination, Action = insertable.Action, @@ -56,7 +65,7 @@ namespace osu.Game.Input }); } - Context.SaveChanges(); + context.SaveChanges(); } /// @@ -65,12 +74,16 @@ namespace osu.Game.Input /// The ruleset's internal ID. /// An optional variant. /// - public IEnumerable Query(int? rulesetId = null, int? variant = null) => Context.DatabasedKeyBinding.Where(b => b.RulesetID == rulesetId && b.Variant == variant); + public IEnumerable Query(int? rulesetId = null, int? variant = null) => query(queryContext, rulesetId, variant); + + private IEnumerable query(OsuDbContext context, int? rulesetId = null, int? variant = null) => + context.DatabasedKeyBinding.Where(b => b.RulesetID == rulesetId && b.Variant == variant); public void Update(KeyBinding keyBinding) { - Context.Update(keyBinding); - Context.SaveChanges(); + var context = GetContext(); + context.Update(keyBinding); + context.SaveChanges(); } } } diff --git a/osu.Game/Rulesets/RulesetStore.cs b/osu.Game/Rulesets/RulesetStore.cs index bd3c22fc42..7d982eb39e 100644 --- a/osu.Game/Rulesets/RulesetStore.cs +++ b/osu.Game/Rulesets/RulesetStore.cs @@ -41,7 +41,7 @@ namespace osu.Game.Rulesets /// /// All available rulesets. /// - public IEnumerable AvailableRulesets => Context.RulesetInfo.Where(r => r.Available); + public IEnumerable AvailableRulesets => GetContext().RulesetInfo.Where(r => r.Available); private static Assembly currentDomain_AssemblyResolve(object sender, ResolveEventArgs args) => loaded_assemblies.Keys.FirstOrDefault(a => a.FullName == args.Name); @@ -49,9 +49,11 @@ namespace osu.Game.Rulesets protected override void Prepare(bool reset = false) { + var context = GetContext(); + if (reset) { - Context.Database.ExecuteSqlCommand("DELETE FROM RulesetInfo"); + context.Database.ExecuteSqlCommand("DELETE FROM RulesetInfo"); } var instances = loaded_assemblies.Values.Select(r => (Ruleset)Activator.CreateInstance(r, new RulesetInfo())).ToList(); @@ -60,29 +62,29 @@ namespace osu.Game.Rulesets foreach (var r in instances.Where(r => r.LegacyID >= 0).OrderBy(r => r.LegacyID)) { var rulesetInfo = createRulesetInfo(r); - if (Context.RulesetInfo.SingleOrDefault(rsi => rsi.ID == rulesetInfo.ID) == null) + if (context.RulesetInfo.SingleOrDefault(rsi => rsi.ID == rulesetInfo.ID) == null) { - Context.RulesetInfo.Add(rulesetInfo); + context.RulesetInfo.Add(rulesetInfo); } } - Context.SaveChanges(); + context.SaveChanges(); //add any other modes foreach (var r in instances.Where(r => r.LegacyID < 0)) { var us = createRulesetInfo(r); - var existing = Context.RulesetInfo.FirstOrDefault(ri => ri.InstantiationInfo == us.InstantiationInfo); + var existing = context.RulesetInfo.FirstOrDefault(ri => ri.InstantiationInfo == us.InstantiationInfo); if (existing == null) - Context.RulesetInfo.Add(us); + context.RulesetInfo.Add(us); } - Context.SaveChanges(); + context.SaveChanges(); //perform a consistency check - foreach (var r in Context.RulesetInfo) + foreach (var r in context.RulesetInfo) { try { @@ -95,7 +97,7 @@ namespace osu.Game.Rulesets } } - Context.SaveChanges(); + context.SaveChanges(); } private static void loadRulesetFromFile(string file)