1
0
mirror of https://github.com/ppy/osu.git synced 2024-12-14 07:42:57 +08:00

Merge pull request #15494 from Tollii/beatmap-cancellation-token

Add support for cancellation tokens for beatmap difficulty calculation
This commit is contained in:
Dan Balasescu 2021-11-19 10:54:32 +09:00 committed by GitHub
commit 6ebe54b183
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 255 additions and 119 deletions

View File

@ -0,0 +1,114 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Moq;
using NUnit.Framework;
using osu.Game.Beatmaps;
using osu.Game.Rulesets;
using osu.Game.Rulesets.Objects;
using osu.Game.Rulesets.Osu;
using osu.Game.Rulesets.Osu.Beatmaps;
namespace osu.Game.Tests.Beatmaps
{
[TestFixture]
public class WorkingBeatmapTest
{
[Test]
public void TestGetPlayableSuccess()
{
var working = new TestNeverLoadsWorkingBeatmap();
working.ResetEvent.Set();
Assert.NotNull(working.GetPlayableBeatmap(new OsuRuleset().RulesetInfo));
}
[Test]
public void TestGetPlayableCancellationToken()
{
var working = new TestNeverLoadsWorkingBeatmap();
var cts = new CancellationTokenSource();
var loadStarted = new ManualResetEventSlim();
var loadCompleted = new ManualResetEventSlim();
Task.Factory.StartNew(() =>
{
loadStarted.Set();
Assert.Throws<OperationCanceledException>(() => working.GetPlayableBeatmap(new OsuRuleset().RulesetInfo, cancellationToken: cts.Token));
loadCompleted.Set();
}, TaskCreationOptions.LongRunning);
Assert.IsTrue(loadStarted.Wait(10000));
cts.Cancel();
Assert.IsTrue(loadCompleted.Wait(10000));
working.ResetEvent.Set();
}
[Test]
public void TestGetPlayableDefaultTimeout()
{
var working = new TestNeverLoadsWorkingBeatmap();
Assert.Throws<OperationCanceledException>(() => working.GetPlayableBeatmap(new OsuRuleset().RulesetInfo));
working.ResetEvent.Set();
}
[Test]
public void TestGetPlayableRulesetLoadFailure()
{
var working = new TestWorkingBeatmap(new Beatmap());
// by default mocks return nulls if not set up, which is actually desired here to simulate a ruleset load failure scenario.
var ruleset = new Mock<IRulesetInfo>();
Assert.Throws<RulesetLoadException>(() => working.GetPlayableBeatmap(ruleset.Object));
}
public class TestNeverLoadsWorkingBeatmap : TestWorkingBeatmap
{
public ManualResetEventSlim ResetEvent = new ManualResetEventSlim();
public TestNeverLoadsWorkingBeatmap()
: base(new Beatmap())
{
}
protected override IBeatmapConverter CreateBeatmapConverter(IBeatmap beatmap, Ruleset ruleset) => new TestConverter(beatmap, ResetEvent);
public class TestConverter : IBeatmapConverter
{
private readonly ManualResetEventSlim resetEvent;
public TestConverter(IBeatmap beatmap, ManualResetEventSlim resetEvent)
{
this.resetEvent = resetEvent;
Beatmap = beatmap;
}
public event Action<HitObject, IEnumerable<HitObject>> ObjectConverted;
protected virtual void OnObjectConverted(HitObject arg1, IEnumerable<HitObject> arg2) => ObjectConverted?.Invoke(arg1, arg2);
public IBeatmap Beatmap { get; }
public bool CanConvert() => true;
public IBeatmap Convert(CancellationToken cancellationToken = default)
{
resetEvent.Wait(cancellationToken);
return new OsuBeatmap();
}
}
}
}
}

View File

@ -140,21 +140,21 @@ namespace osu.Game.Beatmaps
return GetAsync(new DifficultyCacheLookup(localBeatmapInfo, localRulesetInfo, mods), cancellationToken);
}
protected override Task<StarDifficulty> ComputeValueAsync(DifficultyCacheLookup lookup, CancellationToken token = default)
protected override Task<StarDifficulty> ComputeValueAsync(DifficultyCacheLookup lookup, CancellationToken cancellationToken = default)
{
return Task.Factory.StartNew(() =>
{
if (CheckExists(lookup, out var existing))
return existing;
return computeDifficulty(lookup);
}, token, TaskCreationOptions.HideScheduler | TaskCreationOptions.RunContinuationsAsynchronously, updateScheduler);
return computeDifficulty(lookup, cancellationToken);
}, cancellationToken, TaskCreationOptions.HideScheduler | TaskCreationOptions.RunContinuationsAsynchronously, updateScheduler);
}
public Task<List<TimedDifficultyAttributes>> GetTimedDifficultyAttributesAsync(IWorkingBeatmap beatmap, Ruleset ruleset, Mod[] mods, CancellationToken token = default)
public Task<List<TimedDifficultyAttributes>> GetTimedDifficultyAttributesAsync(IWorkingBeatmap beatmap, Ruleset ruleset, Mod[] mods, CancellationToken cancellationToken = default)
{
return Task.Factory.StartNew(() => ruleset.CreateDifficultyCalculator(beatmap).CalculateTimed(mods),
token,
return Task.Factory.StartNew(() => ruleset.CreateDifficultyCalculator(beatmap).CalculateTimed(mods, cancellationToken),
cancellationToken,
TaskCreationOptions.HideScheduler | TaskCreationOptions.RunContinuationsAsynchronously,
updateScheduler);
}
@ -270,8 +270,9 @@ namespace osu.Game.Beatmaps
/// Computes the difficulty defined by a <see cref="DifficultyCacheLookup"/> key, and stores it to the timed cache.
/// </summary>
/// <param name="key">The <see cref="DifficultyCacheLookup"/> that defines the computation parameters.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>The <see cref="StarDifficulty"/>.</returns>
private StarDifficulty computeDifficulty(in DifficultyCacheLookup key)
private StarDifficulty computeDifficulty(in DifficultyCacheLookup key, CancellationToken cancellationToken = default)
{
// In the case that the user hasn't given us a ruleset, use the beatmap's default ruleset.
var beatmapInfo = key.BeatmapInfo;
@ -283,7 +284,7 @@ namespace osu.Game.Beatmaps
Debug.Assert(ruleset != null);
var calculator = ruleset.CreateDifficultyCalculator(beatmapManager.GetWorkingBeatmap(key.BeatmapInfo));
var attributes = calculator.Calculate(key.OrderedMods);
var attributes = calculator.Calculate(key.OrderedMods, cancellationToken);
return new StarDifficulty(attributes);
}

View File

@ -1,9 +1,9 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
using System;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using osu.Framework.Audio.Track;
using osu.Framework.Graphics.Textures;
using osu.Game.Rulesets;
@ -92,10 +92,10 @@ namespace osu.Game.Beatmaps
/// </summary>
/// <param name="ruleset">The <see cref="RulesetInfo"/> to create a playable <see cref="IBeatmap"/> for.</param>
/// <param name="mods">The <see cref="Mod"/>s to apply to the <see cref="IBeatmap"/>.</param>
/// <param name="timeout">The maximum length in milliseconds to wait for load to complete. Defaults to 10,000ms.</param>
/// <param name="cancellationToken">Cancellation token that cancels the beatmap loading process. If not provided, a default timeout of 10,000ms will be applied to the load process.</param>
/// <returns>The converted <see cref="IBeatmap"/>.</returns>
/// <exception cref="BeatmapInvalidForRulesetException">If <see cref="Beatmap"/> could not be converted to <paramref name="ruleset"/>.</exception>
IBeatmap GetPlayableBeatmap(IRulesetInfo ruleset, IReadOnlyList<Mod> mods = null, TimeSpan? timeout = null);
IBeatmap GetPlayableBeatmap(IRulesetInfo ruleset, IReadOnlyList<Mod> mods = null, CancellationToken? cancellationToken = null);
/// <summary>
/// Load a new audio track instance for this beatmap. This should be called once before accessing <see cref="Track"/>.

View File

@ -79,100 +79,101 @@ namespace osu.Game.Beatmaps
/// <returns>The applicable <see cref="IBeatmapConverter"/>.</returns>
protected virtual IBeatmapConverter CreateBeatmapConverter(IBeatmap beatmap, Ruleset ruleset) => ruleset.CreateBeatmapConverter(beatmap);
public virtual IBeatmap GetPlayableBeatmap(IRulesetInfo ruleset, IReadOnlyList<Mod> mods = null, TimeSpan? timeout = null)
public virtual IBeatmap GetPlayableBeatmap(IRulesetInfo ruleset, IReadOnlyList<Mod> mods = null, CancellationToken? cancellationToken = null)
{
using (var cancellationSource = createCancellationTokenSource(timeout))
var token = cancellationToken ??
// don't apply the default timeout when debugger is attached (may be breakpointing / debugging).
(Debugger.IsAttached ? new CancellationToken() : new CancellationTokenSource(10000).Token);
mods ??= Array.Empty<Mod>();
var rulesetInstance = ruleset.CreateInstance();
if (rulesetInstance == null)
throw new RulesetLoadException("Creating ruleset instance failed when attempting to create playable beatmap.");
IBeatmapConverter converter = CreateBeatmapConverter(Beatmap, rulesetInstance);
// Check if the beatmap can be converted
if (Beatmap.HitObjects.Count > 0 && !converter.CanConvert())
throw new BeatmapInvalidForRulesetException($"{nameof(Beatmaps.Beatmap)} can not be converted for the ruleset (ruleset: {ruleset.InstantiationInfo}, converter: {converter}).");
// Apply conversion mods
foreach (var mod in mods.OfType<IApplicableToBeatmapConverter>())
{
mods ??= Array.Empty<Mod>();
var rulesetInstance = ruleset.CreateInstance();
if (rulesetInstance == null)
throw new RulesetLoadException("Creating ruleset instance failed when attempting to create playable beatmap.");
IBeatmapConverter converter = CreateBeatmapConverter(Beatmap, rulesetInstance);
// Check if the beatmap can be converted
if (Beatmap.HitObjects.Count > 0 && !converter.CanConvert())
throw new BeatmapInvalidForRulesetException($"{nameof(Beatmaps.Beatmap)} can not be converted for the ruleset (ruleset: {ruleset.InstantiationInfo}, converter: {converter}).");
// Apply conversion mods
foreach (var mod in mods.OfType<IApplicableToBeatmapConverter>())
{
if (cancellationSource.IsCancellationRequested)
throw new BeatmapLoadTimeoutException(BeatmapInfo);
mod.ApplyToBeatmapConverter(converter);
}
// Convert
IBeatmap converted = converter.Convert(cancellationSource.Token);
// Apply conversion mods to the result
foreach (var mod in mods.OfType<IApplicableAfterBeatmapConversion>())
{
if (cancellationSource.IsCancellationRequested)
throw new BeatmapLoadTimeoutException(BeatmapInfo);
mod.ApplyToBeatmap(converted);
}
// Apply difficulty mods
if (mods.Any(m => m is IApplicableToDifficulty))
{
foreach (var mod in mods.OfType<IApplicableToDifficulty>())
{
if (cancellationSource.IsCancellationRequested)
throw new BeatmapLoadTimeoutException(BeatmapInfo);
mod.ApplyToDifficulty(converted.Difficulty);
}
}
IBeatmapProcessor processor = rulesetInstance.CreateBeatmapProcessor(converted);
foreach (var mod in mods.OfType<IApplicableToBeatmapProcessor>())
mod.ApplyToBeatmapProcessor(processor);
processor?.PreProcess();
// Compute default values for hitobjects, including creating nested hitobjects in-case they're needed
try
{
foreach (var obj in converted.HitObjects)
{
if (cancellationSource.IsCancellationRequested)
throw new BeatmapLoadTimeoutException(BeatmapInfo);
obj.ApplyDefaults(converted.ControlPointInfo, converted.Difficulty, cancellationSource.Token);
}
}
catch (OperationCanceledException)
{
if (token.IsCancellationRequested)
throw new BeatmapLoadTimeoutException(BeatmapInfo);
}
foreach (var mod in mods.OfType<IApplicableToHitObject>())
{
foreach (var obj in converted.HitObjects)
{
if (cancellationSource.IsCancellationRequested)
throw new BeatmapLoadTimeoutException(BeatmapInfo);
mod.ApplyToHitObject(obj);
}
}
processor?.PostProcess();
foreach (var mod in mods.OfType<IApplicableToBeatmap>())
{
cancellationSource.Token.ThrowIfCancellationRequested();
mod.ApplyToBeatmap(converted);
}
return converted;
mod.ApplyToBeatmapConverter(converter);
}
// Convert
IBeatmap converted = converter.Convert(token);
// Apply conversion mods to the result
foreach (var mod in mods.OfType<IApplicableAfterBeatmapConversion>())
{
if (token.IsCancellationRequested)
throw new BeatmapLoadTimeoutException(BeatmapInfo);
mod.ApplyToBeatmap(converted);
}
// Apply difficulty mods
if (mods.Any(m => m is IApplicableToDifficulty))
{
foreach (var mod in mods.OfType<IApplicableToDifficulty>())
{
if (token.IsCancellationRequested)
throw new BeatmapLoadTimeoutException(BeatmapInfo);
mod.ApplyToDifficulty(converted.Difficulty);
}
}
IBeatmapProcessor processor = rulesetInstance.CreateBeatmapProcessor(converted);
foreach (var mod in mods.OfType<IApplicableToBeatmapProcessor>())
mod.ApplyToBeatmapProcessor(processor);
processor?.PreProcess();
// Compute default values for hitobjects, including creating nested hitobjects in-case they're needed
try
{
foreach (var obj in converted.HitObjects)
{
if (token.IsCancellationRequested)
throw new BeatmapLoadTimeoutException(BeatmapInfo);
obj.ApplyDefaults(converted.ControlPointInfo, converted.Difficulty, token);
}
}
catch (OperationCanceledException)
{
throw new BeatmapLoadTimeoutException(BeatmapInfo);
}
foreach (var mod in mods.OfType<IApplicableToHitObject>())
{
foreach (var obj in converted.HitObjects)
{
if (token.IsCancellationRequested)
throw new BeatmapLoadTimeoutException(BeatmapInfo);
mod.ApplyToHitObject(obj);
}
}
processor?.PostProcess();
foreach (var mod in mods.OfType<IApplicableToBeatmap>())
{
token.ThrowIfCancellationRequested();
mod.ApplyToBeatmap(converted);
}
return converted;
}
private CancellationTokenSource loadCancellation = new CancellationTokenSource();
@ -191,15 +192,6 @@ namespace osu.Game.Beatmaps
}
}
private CancellationTokenSource createCancellationTokenSource(TimeSpan? timeout)
{
if (Debugger.IsAttached)
// ignore timeout when debugger is attached (may be breakpointing / debugging).
return new CancellationTokenSource();
return new CancellationTokenSource(timeout ?? TimeSpan.FromSeconds(10));
}
private readonly object beatmapFetchLock = new object();
private Task<IBeatmap> loadBeatmapAsync()

View File

@ -4,6 +4,8 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using JetBrains.Annotations;
using osu.Framework.Audio.Track;
using osu.Framework.Extensions.IEnumerableExtensions;
using osu.Game.Beatmaps;
@ -35,14 +37,24 @@ namespace osu.Game.Rulesets.Difficulty
this.beatmap = beatmap;
}
/// <summary>
/// Calculates the difficulty of the beatmap with no mods applied.
/// </summary>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A structure describing the difficulty of the beatmap.</returns>
public DifficultyAttributes Calculate(CancellationToken cancellationToken = default)
=> Calculate(Array.Empty<Mod>(), cancellationToken);
/// <summary>
/// Calculates the difficulty of the beatmap using a specific mod combination.
/// </summary>
/// <param name="mods">The mods that should be applied to the beatmap.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A structure describing the difficulty of the beatmap.</returns>
public DifficultyAttributes Calculate(params Mod[] mods)
public DifficultyAttributes Calculate([NotNull] IEnumerable<Mod> mods, CancellationToken cancellationToken = default)
{
preProcess(mods);
cancellationToken.ThrowIfCancellationRequested();
preProcess(mods, cancellationToken);
var skills = CreateSkills(Beatmap, playableMods, clockRate);
@ -52,20 +64,33 @@ namespace osu.Game.Rulesets.Difficulty
foreach (var hitObject in getDifficultyHitObjects())
{
foreach (var skill in skills)
{
cancellationToken.ThrowIfCancellationRequested();
skill.ProcessInternal(hitObject);
}
}
return CreateDifficultyAttributes(Beatmap, playableMods, skills, clockRate);
}
/// <summary>
/// Calculates the difficulty of the beatmap and returns a set of <see cref="TimedDifficultyAttributes"/> representing the difficulty at every relevant time value in the beatmap.
/// Calculates the difficulty of the beatmap with no mods applied and returns a set of <see cref="TimedDifficultyAttributes"/> representing the difficulty at every relevant time value in the beatmap.
/// </summary>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>The set of <see cref="TimedDifficultyAttributes"/>.</returns>
public List<TimedDifficultyAttributes> CalculateTimed(CancellationToken cancellationToken = default)
=> CalculateTimed(Array.Empty<Mod>(), cancellationToken);
/// <summary>
/// Calculates the difficulty of the beatmap using a specific mod combination and returns a set of <see cref="TimedDifficultyAttributes"/> representing the difficulty at every relevant time value in the beatmap.
/// </summary>
/// <param name="mods">The mods that should be applied to the beatmap.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>The set of <see cref="TimedDifficultyAttributes"/>.</returns>
public List<TimedDifficultyAttributes> CalculateTimed(params Mod[] mods)
public List<TimedDifficultyAttributes> CalculateTimed([NotNull] IEnumerable<Mod> mods, CancellationToken cancellationToken = default)
{
preProcess(mods);
cancellationToken.ThrowIfCancellationRequested();
preProcess(mods, cancellationToken);
var attribs = new List<TimedDifficultyAttributes>();
@ -80,7 +105,10 @@ namespace osu.Game.Rulesets.Difficulty
progressiveBeatmap.HitObjects.Add(hitObject.BaseObject);
foreach (var skill in skills)
{
cancellationToken.ThrowIfCancellationRequested();
skill.ProcessInternal(hitObject);
}
attribs.Add(new TimedDifficultyAttributes(hitObject.EndTime * clockRate, CreateDifficultyAttributes(progressiveBeatmap, playableMods, skills, clockRate)));
}
@ -99,7 +127,7 @@ namespace osu.Game.Rulesets.Difficulty
if (combination is MultiMod multi)
yield return Calculate(multi.Mods);
else
yield return Calculate(combination);
yield return Calculate(combination.Yield());
}
}
@ -112,11 +140,12 @@ namespace osu.Game.Rulesets.Difficulty
/// Performs required tasks before every calculation.
/// </summary>
/// <param name="mods">The original list of <see cref="Mod"/>s.</param>
private void preProcess(Mod[] mods)
/// <param name="cancellationToken">The cancellation token.</param>
private void preProcess([NotNull] IEnumerable<Mod> mods, CancellationToken cancellationToken = default)
{
playableMods = mods.Select(m => m.DeepClone()).ToArray();
Beatmap = beatmap.GetPlayableBeatmap(ruleset, playableMods);
Beatmap = beatmap.GetPlayableBeatmap(ruleset, playableMods, cancellationToken);
var track = new TrackVirtual(10000);
playableMods.OfType<IApplicableToTrack>().ForEach(m => m.ApplyToTrack(track));

View File

@ -7,7 +7,7 @@ namespace osu.Game.Rulesets.Difficulty
{
/// <summary>
/// Wraps a <see cref="DifficultyAttributes"/> object and adds a time value for which the attribute is valid.
/// Output by <see cref="DifficultyCalculator.CalculateTimed"/>.
/// Output by DifficultyCalculator.CalculateTimed methods.
/// </summary>
public class TimedDifficultyAttributes : IComparable<TimedDifficultyAttributes>
{

View File

@ -216,7 +216,7 @@ namespace osu.Game.Screens.Play.HUD
this.gameplayBeatmap = gameplayBeatmap;
}
public override IBeatmap GetPlayableBeatmap(IRulesetInfo ruleset, IReadOnlyList<Mod> mods = null, TimeSpan? timeout = null)
public override IBeatmap GetPlayableBeatmap(IRulesetInfo ruleset, IReadOnlyList<Mod> mods = null, CancellationToken? cancellationToken = null)
=> gameplayBeatmap;
protected override IBeatmap GetBeatmap() => gameplayBeatmap;