diff --git a/osu.Game.Rulesets.Osu/Difficulty/Aggregation/OsuProbabilitySkill.cs b/osu.Game.Rulesets.Osu/Difficulty/Aggregation/OsuProbabilitySkill.cs new file mode 100644 index 0000000000..dde817a2e1 --- /dev/null +++ b/osu.Game.Rulesets.Osu/Difficulty/Aggregation/OsuProbabilitySkill.cs @@ -0,0 +1,155 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +using System; +using System.Collections.Generic; +using System.Linq; +using osu.Game.Rulesets.Difficulty.Preprocessing; +using osu.Game.Rulesets.Difficulty.Skills; +using osu.Game.Rulesets.Mods; +using osu.Game.Rulesets.Osu.Difficulty.Utils; + +namespace osu.Game.Rulesets.Osu.Difficulty.Aggregation +{ + public abstract class OsuProbabilitySkill : Skill + { + protected OsuProbabilitySkill(Mod[] mods) + : base(mods) + { + } + + // We assume players have a 2% chance to hit every note in the map. + // A higher value of fc_probability increases the influence of difficulty spikes, + // while a lower value increases the influence of length and consistent difficulty. + private const double fc_probability = 0.02; + + private const int bin_count = 32; + + // The number of difficulties there must be before we can be sure that binning difficulties would not change the output significantly. + private double binThreshold => 2 * bin_count; + + private readonly List difficulties = new List(); + + /// + /// Returns the strain value at . This value is calculated with or without respect to previous objects. + /// + protected abstract double StrainValueAt(DifficultyHitObject current); + + public override void Process(DifficultyHitObject current) + { + difficulties.Add(StrainValueAt(current)); + } + + protected abstract double HitProbability(double skill, double difficulty); + + private double difficultyValueExact() + { + double maxDiff = difficulties.Max(); + if (maxDiff <= 1e-10) return 0; + + const double lower_bound = 0; + double upperBoundEstimate = 3.0 * maxDiff; + + double skill = RootFinding.FindRootExpand( + skill => fcProbability(skill) - fc_probability, + lower_bound, + upperBoundEstimate, + accuracy: 1e-4); + + return skill; + + double fcProbability(double s) + { + if (s <= 0) return 0; + + return difficulties.Aggregate(1, (current, d) => current * HitProbability(s, d)); + } + } + + private double difficultyValueBinned() + { + double maxDiff = difficulties.Max(); + if (maxDiff <= 1e-10) return 0; + + var bins = Bin.CreateBins(difficulties, bin_count); + + const double lower_bound = 0; + double upperBoundEstimate = 3.0 * maxDiff; + + double skill = RootFinding.FindRootExpand( + skill => fcProbability(skill) - fc_probability, + lower_bound, + upperBoundEstimate, + accuracy: 1e-4); + + return skill; + + double fcProbability(double s) + { + if (s <= 0) return 0; + + return bins.Aggregate(1.0, (current, bin) => current * Math.Pow(HitProbability(s, bin.Difficulty), bin.Count)); + } + } + + public override double DifficultyValue() + { + if (difficulties.Count == 0) return 0; + + return difficulties.Count > binThreshold ? difficultyValueBinned() : difficultyValueExact(); + } + + /// + /// A polynomial fitted to the miss counts at each skill level. + /// + public ExpPolynomial GetMissPenaltyCurve() + { + double[] missCounts = new double[7]; + double[] penalties = { 1, 0.95, 0.9, 0.8, 0.6, 0.3, 0 }; + + ExpPolynomial missPenaltyCurve = new ExpPolynomial(); + + // If there are no notes, we just return the curve with all coefficients set to zero. + if (difficulties.Count == 0 || difficulties.Max() == 0) + return missPenaltyCurve; + + double fcSkill = DifficultyValue(); + + var bins = Bin.CreateBins(difficulties, bin_count); + + for (int i = 0; i < penalties.Length; i++) + { + if (i == 0) + { + missCounts[i] = 0; + continue; + } + + double penalizedSkill = fcSkill * penalties[i]; + + missCounts[i] = getMissCountAtSkill(penalizedSkill, bins); + } + + missPenaltyCurve.Fit(missCounts); + + return missPenaltyCurve; + } + + /// + /// Find the lowest miss count that a player with the provided would have a 2% chance of achieving or better. + /// + private double getMissCountAtSkill(double skill, List bins) + { + double maxDiff = difficulties.Max(); + + if (maxDiff == 0) + return 0; + if (skill <= 0) + return difficulties.Count; + + var poiBin = difficulties.Count > binThreshold ? new PoissonBinomial(bins, skill, HitProbability) : new PoissonBinomial(difficulties, skill, HitProbability); + + return Math.Max(0, RootFinding.FindRootExpand(x => poiBin.CDF(x) - fc_probability, -50, 1000, accuracy: 1e-4)); + } + } +} diff --git a/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuStrainSkill.cs b/osu.Game.Rulesets.Osu/Difficulty/Aggregation/OsuStrainSkill.cs similarity index 97% rename from osu.Game.Rulesets.Osu/Difficulty/Skills/OsuStrainSkill.cs rename to osu.Game.Rulesets.Osu/Difficulty/Aggregation/OsuStrainSkill.cs index 6823512cef..822894a513 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuStrainSkill.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Aggregation/OsuStrainSkill.cs @@ -8,7 +8,7 @@ using osu.Game.Rulesets.Mods; using System.Linq; using osu.Framework.Utils; -namespace osu.Game.Rulesets.Osu.Difficulty.Skills +namespace osu.Game.Rulesets.Osu.Difficulty.Aggregation { public abstract class OsuStrainSkill : StrainSkill { diff --git a/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyAttributes.cs b/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyAttributes.cs index a3c0209a08..e2c6b8fe99 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyAttributes.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyAttributes.cs @@ -8,6 +8,7 @@ using Newtonsoft.Json; using osu.Game.Beatmaps; using osu.Game.Rulesets.Difficulty; using osu.Game.Rulesets.Mods; +using osu.Game.Rulesets.Osu.Difficulty.Utils; namespace osu.Game.Rulesets.Osu.Difficulty { @@ -19,6 +20,12 @@ namespace osu.Game.Rulesets.Osu.Difficulty [JsonProperty("aim_difficulty")] public double AimDifficulty { get; set; } + /// + /// The difficulty corresponding to the aim skill. + /// + [JsonProperty("aim_penalty_constants")] + public ExpPolynomial AimMissPenaltyCurve { get; set; } + /// /// The difficulty corresponding to the speed skill. /// diff --git a/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs b/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs index 575e03051c..f0f32c4d26 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs @@ -11,8 +11,10 @@ using osu.Game.Rulesets.Difficulty; using osu.Game.Rulesets.Difficulty.Preprocessing; using osu.Game.Rulesets.Difficulty.Skills; using osu.Game.Rulesets.Mods; +using osu.Game.Rulesets.Osu.Difficulty.Aggregation; using osu.Game.Rulesets.Osu.Difficulty.Preprocessing; using osu.Game.Rulesets.Osu.Difficulty.Skills; +using osu.Game.Rulesets.Osu.Difficulty.Utils; using osu.Game.Rulesets.Osu.Mods; using osu.Game.Rulesets.Osu.Objects; using osu.Game.Rulesets.Osu.Scoring; @@ -48,7 +50,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty double sliderFactor = aimRating > 0 ? aimRatingNoSliders / aimRating : 1; - double aimDifficultyStrainCount = ((OsuStrainSkill)skills[0]).CountTopWeightedStrains(); + ExpPolynomial aimMissPenaltyCurve = ((OsuProbabilitySkill)skills[0]).GetMissPenaltyCurve(); double speedDifficultyStrainCount = ((OsuStrainSkill)skills[2]).CountTopWeightedStrains(); if (mods.Any(m => m is OsuModTouchDevice)) @@ -103,7 +105,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty SpeedNoteCount = speedNotes, FlashlightDifficulty = flashlightRating, SliderFactor = sliderFactor, - AimDifficultStrainCount = aimDifficultyStrainCount, + AimMissPenaltyCurve = aimMissPenaltyCurve, SpeedDifficultStrainCount = speedDifficultyStrainCount, ApproachRate = preempt > 1200 ? (1800 - preempt) / 120 : (1200 - preempt) / 150 + 5, OverallDifficulty = (80 - hitWindowGreat) / 6, diff --git a/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs b/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs index 31b00dba2b..00b22f66ed 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs @@ -5,7 +5,9 @@ using System; using System.Collections.Generic; using System.Linq; using osu.Game.Rulesets.Difficulty; +using osu.Game.Rulesets.Osu.Difficulty.Aggregation; using osu.Game.Rulesets.Osu.Difficulty.Skills; +using osu.Game.Rulesets.Osu.Difficulty.Utils; using osu.Game.Rulesets.Osu.Mods; using osu.Game.Rulesets.Scoring; using osu.Game.Scoring; @@ -139,10 +141,9 @@ namespace osu.Game.Rulesets.Osu.Difficulty double lengthBonus = 0.95 + 0.4 * Math.Min(1.0, totalHits / 2000.0) + (totalHits > 2000 ? Math.Log10(totalHits / 2000.0) * 0.5 : 0.0); - aimValue *= lengthBonus; if (effectiveMissCount > 0) - aimValue *= calculateMissPenalty(effectiveMissCount, attributes.AimDifficultStrainCount); + aimValue *= calculateCurveFittedMissPenalty(effectiveMissCount, attributes.AimMissPenaltyCurve); double approachRateFactor = 0.0; if (attributes.ApproachRate > 10.33) @@ -206,7 +207,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty speedValue *= lengthBonus; if (effectiveMissCount > 0) - speedValue *= calculateMissPenalty(effectiveMissCount, attributes.SpeedDifficultStrainCount); + speedValue *= calculateStrainCountMissPenalty(effectiveMissCount, attributes.SpeedDifficultStrainCount); double approachRateFactor = 0.0; if (attributes.ApproachRate > 10.33) @@ -305,10 +306,17 @@ namespace osu.Game.Rulesets.Osu.Difficulty return flashlightValue; } - // Miss penalty assumes that a player will miss on the hardest parts of a map, - // so we use the amount of relatively difficult sections to adjust miss penalty - // to make it more punishing on maps with lower amount of hard sections. - private double calculateMissPenalty(double missCount, double difficultStrainCount) => 0.96 / ((missCount / (4 * Math.Pow(Math.Log(difficultStrainCount), 0.94))) + 1); + // Due to the unavailability of miss location in PP, the following formulas assume that a player will miss on the hardest parts of a map. + + // With the curve fitted miss penalty, we use a pre-computed curve of skill levels for each miss count, raised to the power of 1.5 as + // the multiple of the exponents on star rating and PP. This power should be changed if either SR or PP begin to use a different exponent. + // As a result, this exponent is not subject to balance. + private double calculateCurveFittedMissPenalty(double missCount, ExpPolynomial curve) => Math.Pow(1 - curve.GetPenaltyAt(missCount), 1.5); + + // With the strain count miss penalty, we use the amount of relatively difficult sections to adjust the miss penalty, + // to make it more punishing on maps with lower amount of hard sections. This formula is subject to balance. + private double calculateStrainCountMissPenalty(double missCount, double difficultStrainCount) => 0.96 / (missCount / (4 * Math.Pow(Math.Log(difficultStrainCount), 0.94)) + 1); + private double getComboScalingFactor(OsuDifficultyAttributes attributes) => attributes.MaxCombo <= 0 ? 1.0 : Math.Min(Math.Pow(scoreMaxCombo, 0.8) / Math.Pow(attributes.MaxCombo, 0.8), 1.0); private int totalHits => countGreat + countOk + countMeh + countMiss; private int totalImperfectHits => countOk + countMeh + countMiss; diff --git a/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs b/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs index faf91e4652..751696d349 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs @@ -4,14 +4,16 @@ using System; using osu.Game.Rulesets.Difficulty.Preprocessing; using osu.Game.Rulesets.Mods; +using osu.Game.Rulesets.Osu.Difficulty.Aggregation; using osu.Game.Rulesets.Osu.Difficulty.Evaluators; +using osu.Game.Utils; namespace osu.Game.Rulesets.Osu.Difficulty.Skills { /// /// Represents the skill required to correctly aim at every object in the map with a uniform CircleSize and normalized distances. /// - public class Aim : OsuStrainSkill + public class Aim : OsuProbabilitySkill { public Aim(Mod[] mods, bool withSliders) : base(mods) @@ -23,12 +25,18 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills private double currentStrain; - private double skillMultiplier => 25.18; + private double skillMultiplier => 132; private double strainDecayBase => 0.15; - private double strainDecay(double ms) => Math.Pow(strainDecayBase, ms / 1000); + protected override double HitProbability(double skill, double difficulty) + { + if (difficulty <= 0) return 1; + if (skill <= 0) return 0; - protected override double CalculateInitialStrain(double time, DifficultyHitObject current) => currentStrain * strainDecay(time - current.Previous(0).StartTime); + return SpecialFunctions.Erf(skill / (Math.Sqrt(2) * difficulty)); + } + + private double strainDecay(double ms) => Math.Pow(strainDecayBase, ms / 1000); protected override double StrainValueAt(DifficultyHitObject current) { diff --git a/osu.Game.Rulesets.Osu/Difficulty/Skills/Speed.cs b/osu.Game.Rulesets.Osu/Difficulty/Skills/Speed.cs index d2c4bbb618..c4a9dd89f4 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Skills/Speed.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Skills/Speed.cs @@ -7,6 +7,7 @@ using osu.Game.Rulesets.Mods; using osu.Game.Rulesets.Osu.Difficulty.Evaluators; using osu.Game.Rulesets.Osu.Difficulty.Preprocessing; using System.Linq; +using osu.Game.Rulesets.Osu.Difficulty.Aggregation; namespace osu.Game.Rulesets.Osu.Difficulty.Skills { diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs new file mode 100644 index 0000000000..f3a8bdbd55 --- /dev/null +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs @@ -0,0 +1,62 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +using System; +using System.Collections.Generic; +using System.Linq; + +namespace osu.Game.Rulesets.Osu.Difficulty.Utils +{ + public struct Bin + { + public double Difficulty; + public double Count; + + /// + /// Create an array of spaced bins. Count is linearly interpolated into each bin. + /// For example, if we have bins with values [1,2,3,4,5] and want to insert the value 3.2, + /// we will add 0.8 to the count of 3's and 0.2 to the count of 4's + /// + public static List CreateBins(List difficulties, int totalBins) + { + double maxDifficulty = difficulties.Max(); + + var binsArray = new Bin[totalBins]; + + for (int i = 0; i < totalBins; i++) + { + binsArray[i].Difficulty = maxDifficulty * (i + 1) / totalBins; + } + + foreach (double d in difficulties) + { + double binIndex = totalBins * (d / maxDifficulty) - 1; + + int lowerBound = (int)Math.Floor(binIndex); + double t = binIndex - lowerBound; + + //This can be -1, corresponding to the zero difficulty bucket. + //We don't store that since it doesn't contribute to difficulty + if (lowerBound >= 0) + { + binsArray[lowerBound].Count += (1 - t); + } + + int upperBound = lowerBound + 1; + + // this can be == bin_count for the maximum difficulty object, in which case t will be 0 anyway + if (upperBound < totalBins) + { + binsArray[upperBound].Count += t; + } + } + + var binsList = binsArray.ToList(); + + // For a slight performance improvement, we remove bins that don't contribute to difficulty. + binsList.RemoveAll(bin => bin.Count == 0); + + return binsList; + } + } +} diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/ExpPolynomial.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/ExpPolynomial.cs new file mode 100644 index 0000000000..c76002a770 --- /dev/null +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/ExpPolynomial.cs @@ -0,0 +1,86 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +using System; +using System.Collections.Generic; +using System.Linq; +using osu.Game.Utils; + +namespace osu.Game.Rulesets.Osu.Difficulty.Utils +{ + /// + /// Represents a polynomial fitted to the logarithm of a given set of points. + /// The resulting polynomial is exponentiated, ensuring low residuals for + /// small inputs while handling exponentially increasing trends in the data. + /// This approach is useful for modelling the results of decreasing skill with few coefficients, + /// as linear decreases in skill correspond with exponential increases in miss counts. + /// + public struct ExpPolynomial + { + private double[]? coefficients; + + // The matrix that minimizes the square error at X values [0.0, 0.30, 0.60, 0.80, 0.90, 0.95, 1.0]. + private static readonly double[][] matrix = + { + new[] { 0.0, 3.14395, 5.18439, 6.46975, 1.4638, -9.53526, 0.0 }, + new[] { 0.0, -4.85829, -8.09612, -10.4498, -3.84479, 12.1626, 0.0 } + }; + + /// + /// Computes the coefficients of a quartic polynomial, starting at 0 and ending at the highest miss count in the array. + /// + /// A list of miss counts, with X values [1, 0.95, 0.9, 0.8, 0.6, 0.3, 0] corresponding to their skill levels. + public void Fit(double[] missCounts) + { + List logMissCounts = missCounts.Select(x => Math.Log(x + 1)).ToList(); + + double endPoint = logMissCounts.Max(); + + double[] penalties = { 1, 0.95, 0.9, 0.8, 0.6, 0.3, 0 }; + + for (int i = 0; i < logMissCounts.Count; i++) + { + logMissCounts[i] -= endPoint * (1 - penalties[i]); + } + + // The precomputed matrix assumes the miss counts go in order of greatest to least. + logMissCounts.Reverse(); + + coefficients = new double[4]; + + coefficients[3] = endPoint; + + // Now we dot product the adjusted miss counts with the precomputed matrix. + for (int row = 0; row < matrix.Length; row++) + { + for (int column = 0; column < matrix[row].Length; column++) + { + coefficients[row] += matrix[row][column] * logMissCounts[column]; + } + + coefficients[2] -= coefficients[row]; + } + } + + /// + /// Solve for the miss penalty at a specified miss count. + /// + public double GetPenaltyAt(double missCount) + { + if (coefficients is null) + return 1; + + List listCoefficients = coefficients.ToList(); + listCoefficients.Add(-Math.Log(missCount + 1)); + + List xVals = SpecialFunctions.SolvePolynomialRoots(listCoefficients); + + const double max_error = 1e-7; + + // We find the largest value of x (corresponding to the penalty) found as a root of the function, with a fallback of a 100% penalty if no roots were found. + double largestValue = xVals.Where(x => x >= 0 - max_error && x <= 1 + max_error).OrderDescending().FirstOrDefault() ?? 1; + + return Math.Clamp(largestValue, 0, 1); + } + } +} diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/PoissonBinomial.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/PoissonBinomial.cs new file mode 100644 index 0000000000..52005513cb --- /dev/null +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/PoissonBinomial.cs @@ -0,0 +1,120 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +using System; +using System.Collections.Generic; +using osu.Game.Utils; + +namespace osu.Game.Rulesets.Osu.Difficulty.Utils +{ + /// + /// Approximation of the Poisson binomial distribution: + /// https://en.wikipedia.org/wiki/Poisson_binomial_distribution + /// + /// + /// + /// For the approximation method, see "Refined Normal Approximation (RNA)" from: + /// Hong, Y. (2013). On computing the distribution function for the Poisson binomial distribution. Computational Statistics and Data Analysis, Vol. 59, pp. 41-51. + /// (https://www.researchgate.net/publication/257017356_On_computing_the_distribution_function_for_the_Poisson_binomial_distribution) + /// + /// + /// This has been verified against a reference implementation provided by the authors in the R package "poibin", + /// which can be viewed here: + /// https://rdrr.io/cran/poibin/man/poibin-package.html + /// + /// + public class PoissonBinomial + { + /// + /// The expected value of the distribution. + /// + private readonly double mu; + + /// + /// The standard deviation of the distribution. + /// + private readonly double sigma; + + /// + /// The gamma factor from equation (11) in the cited paper, pre-divided by 6 to save on re-computation. + /// + private readonly double v; + + /// + /// Creates a Poisson binomial distribution based on N trials with the provided difficulties, skill, and method for getting the miss probabilities. + /// + /// The list of difficulties in the map. + /// The skill level to get the miss probabilities with. + /// Converts difficulties and skill to miss probabilities. + public PoissonBinomial(IList difficulties, double skill, Func hitProbability) + { + double variance = 0; + double gamma = 0; + + foreach (double d in difficulties) + { + double p = 1 - hitProbability(skill, d); + + mu += p; + variance += p * (1 - p); + gamma += p * (1 - p) * (1 - 2 * p); + } + + sigma = Math.Sqrt(variance); + + v = gamma / (6 * Math.Pow(sigma, 3)); + } + + /// + /// Creates a Poisson binomial distribution based on N trials with the provided bins of difficulties, skill, and method for getting the miss probabilities. + /// + /// The bins of difficulties in the map. + /// The skill level to get the miss probabilities with. + /// /// Converts difficulties and skill to miss probabilities. + public PoissonBinomial(List bins, double skill, Func hitProbability) + { + double variance = 0; + double gamma = 0; + + foreach (Bin bin in bins) + { + double p = 1 - hitProbability(skill, bin.Difficulty); + + mu += p * bin.Count; + variance += p * (1 - p) * bin.Count; + gamma += p * (1 - p) * (1 - 2 * p) * bin.Count; + } + + sigma = Math.Sqrt(variance); + + v = gamma / (6 * Math.Pow(sigma, 3)); + } + + /// + /// Computes the value of the cumulative distribution function for this Poisson binomial distribution. + /// + /// + /// The argument of the CDF to sample the distribution for. + /// In the discrete case (when it is a whole number), this corresponds to the number + /// of successful Bernoulli trials to query the CDF for. + /// + /// + /// The value of the CDF at . + /// In the discrete case this corresponds to the probability that at most + /// Bernoulli trials ended in a success. + /// + // ReSharper disable once InconsistentNaming + public double CDF(double count) + { + if (sigma == 0) + return 1; + + double k = (count + 0.5 - mu) / sigma; + + // see equation (14) of the cited paper + double result = SpecialFunctions.NormalCdf(0, 1, k) + v * (1 - k * k) * SpecialFunctions.NormalPdf(0, 1, k); + + return Math.Clamp(result, 0, 1); + } + } +} diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/RootFinding.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/RootFinding.cs new file mode 100644 index 0000000000..bb01e8162c --- /dev/null +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/RootFinding.cs @@ -0,0 +1,119 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +using System; + +namespace osu.Game.Rulesets.Osu.Difficulty.Utils +{ + public static class RootFinding + { + /// + /// Finds the root of a using the Chandrupatla method, expanding the bounds if the root is not located within. + /// Expansion only occurs for the upward bound, as this function is optimized for functions of range [0, x), + /// which is useful for finding skill level (skill can never be below 0). + /// + /// The function of which to find the root. + /// The lower bound of the function inputs. + /// The upper bound of the function inputs. + /// The maximum number of iterations before the function throws an error. + /// The desired precision in which the root is returned. + /// The multiplier on the upper bound when no root is found within the provided bounds. + /// The maximum number of times the bounds of the function should increase. + public static double FindRootExpand(Func function, double guessLowerBound, double guessUpperBound, int maxIterations = 25, double accuracy = 1e-6D, double expansionFactor = 2, double maxExpansions = 32) + { + double a = guessLowerBound; + double b = guessUpperBound; + double fa = function(a); + double fb = function(b); + + int expansions = 0; + + while (fa * fb > 0) + { + a = b; + b *= expansionFactor; + fa = function(a); + fb = function(b); + + expansions++; + + if (expansions > maxExpansions) + { + throw new MaximumIterationsException("No root was found within the provided function."); + } + } + + double t = 0.5; + + for (int i = 0; i < maxIterations; i++) + { + double xt = a + t * (b - a); + double ft = function(xt); + + double c; + double fc; + + if (Math.Sign(ft) == Math.Sign(fa)) + { + c = a; + fc = fa; + } + else + { + c = b; + b = a; + fc = fb; + fb = fa; + } + + a = xt; + fa = ft; + + double xm, fm; + + if (Math.Abs(fa) < Math.Abs(fb)) + { + xm = a; + fm = fa; + } + else + { + xm = b; + fm = fb; + } + + if (fm == 0) + return xm; + + double tol = 2 * accuracy * Math.Abs(xm) + 2 * accuracy; + double tlim = tol / Math.Abs(b - c); + + if (tlim > 0.5) + { + return xm; + } + + double chi = (a - b) / (c - b); + double phi = (fa - fb) / (fc - fb); + bool iqi = phi * phi < chi && (1 - phi) * (1 - phi) < chi; + + if (iqi) + t = fa / (fb - fa) * fc / (fb - fc) + (c - a) / (b - a) * fa / (fc - fa) * fb / (fc - fb); + else + t = 0.5; + + t = Math.Min(1 - tlim, Math.Max(tlim, t)); + } + + return 0; + } + + private class MaximumIterationsException : Exception + { + public MaximumIterationsException(string message) + : base(message) + { + } + } + } +} diff --git a/osu.Game/Utils/SpecialFunctions.cs b/osu.Game/Utils/SpecialFunctions.cs index 0b0f0598bb..3d8433279f 100644 --- a/osu.Game/Utils/SpecialFunctions.cs +++ b/osu.Game/Utils/SpecialFunctions.cs @@ -13,12 +13,17 @@ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLI */ using System; +using System.Collections.Generic; +using System.Linq; +using osu.Framework.Utils; namespace osu.Game.Utils { public class SpecialFunctions { + private const double sqrt2 = 1.4142135623730950488016887242096980785696718753769d; private const double sqrt2_pi = 2.5066282746310005024157652848110452530069867406099d; + private const double pi_mult_2 = 6.28318530717958647692528676655900576d; /// /// ************************************** @@ -690,5 +695,280 @@ namespace osu.Game.Utils return sum; } + + /// + /// Computes the probability density of the distribution (PDF) at x, i.e. ∂P(X ≤ x)/∂x. + /// + /// The mean (μ) of the normal distribution. + /// The standard deviation (σ) of the normal distribution. Range: σ ≥ 0. + /// The location at which to compute the density. + /// the density at . + /// MATLAB: normpdf + public static double NormalPdf(double mean, double stddev, double x) + { + if (stddev < 0.0) + { + throw new ArgumentException("Invalid parametrization for the distribution."); + } + + double d = (x - mean) / stddev; + return Math.Exp(-0.5 * d * d) / (sqrt2_pi * stddev); + } + + /// + /// Computes the cumulative distribution (CDF) of the distribution at x, i.e. P(X ≤ x). + /// + /// The location at which to compute the cumulative distribution function. + /// The mean (μ) of the normal distribution. + /// The standard deviation (σ) of the normal distribution. Range: σ ≥ 0. + /// the cumulative distribution at location . + /// MATLAB: normcdf + public static double NormalCdf(double mean, double stddev, double x) + { + if (stddev < 0.0) + { + throw new ArgumentException("Invalid parametrization for the distribution."); + } + + return 0.5 * Erfc((mean - x) / (stddev * sqrt2)); + } + + /// + /// Solve for the exact real roots of any polynomial up to degree 4. + /// + /// The coefficients of the polynomial, in ascending order ([1, 3, 5] -> x^2 + 3x + 5). + /// The real roots of the polynomial, and null if the root does not exist. + public static List SolvePolynomialRoots(List coefficients) + { + List xVals = new List(); + + switch (coefficients.Count) + { + case 5: + xVals = solveP4(coefficients[0], coefficients[1], coefficients[2], coefficients[3], coefficients[4], out int _).ToList(); + break; + + case 4: + xVals = solveP3(coefficients[0], coefficients[1], coefficients[2], coefficients[3], out int _).ToList(); + break; + + case 3: + xVals = solveP2(coefficients[0], coefficients[1], coefficients[2], out int _).ToList(); + break; + + case 2: + xVals = solveP2(0, coefficients[1], coefficients[2], out int _).ToList(); + break; + } + + return xVals; + } + + // https://github.com/sasamil/Quartic/blob/master/quartic.cpp + private static double?[] solveP4(double a, double b, double c, double d, double e, out int nRoots) + { + double?[] xVals = new double?[4]; + + nRoots = 0; + + if (a == 0) + { + double?[] xValsCubic = solveP3(b, c, d, e, out nRoots); + + xVals[0] = xValsCubic[0]; + xVals[1] = xValsCubic[1]; + xVals[2] = xValsCubic[2]; + xVals[3] = null; + + return xVals; + } + + b /= a; + c /= a; + d /= a; + e /= a; + + double a3 = -c; + double b3 = b * d - 4 * e; + double c3 = -b * b * e - d * d + 4 * c * e; + + double?[] x3 = solveP3(1, a3, b3, c3, out int iZeroes); + + double q1, q2, p1, p2, sqD; + + double y = x3[0]!.Value; + + // Get the y value with the highest absolute value. + if (iZeroes != 1) + { + if (Math.Abs(x3[1]!.Value) > Math.Abs(y)) + y = x3[1]!.Value; + if (Math.Abs(x3[2]!.Value) > Math.Abs(y)) + y = x3[2]!.Value; + } + + double upperD = y * y - 4 * e; + + if (Precision.AlmostEquals(upperD, 0)) + { + q1 = q2 = y * 0.5; + + upperD = b * b - 4 * (c - y); + + if (Precision.AlmostEquals(upperD, 0)) + p1 = p2 = b * 0.5; + + else + { + sqD = Math.Sqrt(upperD); + p1 = (b + sqD) * 0.5; + p2 = (b - sqD) * 0.5; + } + } + else + { + sqD = Math.Sqrt(upperD); + q1 = (y + sqD) * 0.5; + q2 = (y - sqD) * 0.5; + + p1 = (b * q1 - d) / (q1 - q2); + p2 = (d - b * q2) / (q1 - q2); + } + + // solving quadratic eq. - x^2 + p1*x + q1 = 0 + upperD = p1 * p1 - 4 * q1; + + if (upperD >= 0) + { + nRoots += 2; + + sqD = Math.Sqrt(upperD); + xVals[0] = (-p1 + sqD) * 0.5; + xVals[1] = (-p1 - sqD) * 0.5; + } + + // solving quadratic eq. - x^2 + p2*x + q2 = 0 + upperD = p2 * p2 - 4 * q2; + + if (upperD >= 0) + { + nRoots += 2; + + sqD = Math.Sqrt(upperD); + xVals[2] = (-p2 + sqD) * 0.5; + xVals[3] = (-p2 - sqD) * 0.5; + } + + // Put the null roots at the end of the array. + var nonNulls = xVals.Where(x => x != null); + var nulls = xVals.Where(x => x == null); + xVals = nonNulls.Concat(nulls).ToArray(); + + return xVals; + } + + private static double?[] solveP3(double a, double b, double c, double d, out int nRoots) + { + double?[] xVals = new double?[3]; + + nRoots = 0; + + if (a == 0) + { + double?[] xValsQuadratic = solveP2(b, c, d, out nRoots); + + xVals[0] = xValsQuadratic[0]; + xVals[1] = xValsQuadratic[1]; + xVals[2] = null; + + return xVals; + } + + b /= a; + c /= a; + d /= a; + + double a2 = b * b; + double q = (a2 - 3 * c) / 9; + double q3 = q * q * q; + double r = (b * (2 * a2 - 9 * c) + 27 * d) / 54; + double r2 = r * r; + + if (r2 < q3) + { + nRoots = 3; + + double t = r / Math.Sqrt(q3); + t = Math.Clamp(t, -1, 1); + t = Math.Acos(t); + b /= 3; + q = -2 * Math.Sqrt(q); + + xVals[0] = q * Math.Cos(t / 3) - b; + xVals[1] = q * Math.Cos((t + pi_mult_2) / 3) - b; + xVals[2] = q * Math.Cos((t - pi_mult_2) / 3) - b; + + return xVals; + } + + double upperA = -Math.Cbrt(Math.Abs(r) + Math.Sqrt(r2 - q3)); + + if (r < 0) + upperA = -upperA; + + double upperB = upperA == 0 ? 0 : q / upperA; + b /= 3; + + xVals[0] = upperA + upperB - b; + + if (Precision.AlmostEquals(0.5 * Math.Sqrt(3) * (upperA - upperB), 0)) + { + nRoots = 2; + xVals[1] = -0.5 * (upperA + upperB) - b; + + return xVals; + } + + nRoots = 1; + + return xVals; + } + + private static double?[] solveP2(double a, double b, double c, out int nRoots) + { + double?[] xVals = new double?[2]; + + nRoots = 0; + + if (a == 0) + { + if (b == 0) + return xVals; + + nRoots = 1; + xVals[0] = -c / b; + } + + double discriminant = b * b - 4 * a * c; + + switch (discriminant) + { + case < 0: + break; + + case 0: + nRoots = 1; + xVals[0] = -b / (2 * a); + break; + + default: + nRoots = 2; + xVals[0] = (-b + Math.Sqrt(discriminant)) / (2 * a); + xVals[1] = (-b - Math.Sqrt(discriminant)) / (2 * a); + break; + } + + return xVals; + } } }