From 8a582eef35b795466a62cdd746906265202a40ca Mon Sep 17 00:00:00 2001 From: Nathen Date: Fri, 8 Mar 2024 12:35:11 -0500 Subject: [PATCH 01/14] Implement aim probability --- .../Difficulty/OsuPerformanceCalculator.cs | 1 - .../Difficulty/Skills/Aim.cs | 6 +- .../Difficulty/Skills/OsuProbSkill.cs | 143 ++++ osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs | 17 + .../Difficulty/Utils/Chandrupatla.cs | 102 +++ .../Difficulty/Utils/SpecialFunctions.cs | 694 ++++++++++++++++++ 6 files changed, 958 insertions(+), 5 deletions(-) create mode 100644 osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs create mode 100644 osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs create mode 100644 osu.Game.Rulesets.Osu/Difficulty/Utils/Chandrupatla.cs create mode 100644 osu.Game.Rulesets.Osu/Difficulty/Utils/SpecialFunctions.cs diff --git a/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs b/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs index b31f4ff519..2df4443915 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs @@ -90,7 +90,6 @@ namespace osu.Game.Rulesets.Osu.Difficulty double lengthBonus = 0.95 + 0.4 * Math.Min(1.0, totalHits / 2000.0) + (totalHits > 2000 ? Math.Log10(totalHits / 2000.0) * 0.5 : 0.0); - aimValue *= lengthBonus; // Penalize misses by assessing # of misses relative to the total # of objects. Default a 3% reduction for any # of misses. if (effectiveMissCount > 0) diff --git a/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs b/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs index 3f6b22bbb1..e3e630cd57 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs @@ -11,7 +11,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills /// /// Represents the skill required to correctly aim at every object in the map with a uniform CircleSize and normalized distances. /// - public class Aim : OsuStrainSkill + public class Aim : OsuProbSkill { public Aim(Mod[] mods, bool withSliders) : base(mods) @@ -23,13 +23,11 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills private double currentStrain; - private double skillMultiplier => 23.55; + private double skillMultiplier => 125; private double strainDecayBase => 0.15; private double strainDecay(double ms) => Math.Pow(strainDecayBase, ms / 1000); - protected override double CalculateInitialStrain(double time, DifficultyHitObject current) => currentStrain * strainDecay(time - current.Previous(0).StartTime); - protected override double StrainValueAt(DifficultyHitObject current) { currentStrain *= strainDecay(current.DeltaTime); diff --git a/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs b/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs new file mode 100644 index 0000000000..ef62ee1250 --- /dev/null +++ b/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs @@ -0,0 +1,143 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +using System; +using System.Collections.Generic; +using System.Linq; +using osu.Game.Rulesets.Difficulty.Preprocessing; +using osu.Game.Rulesets.Difficulty.Skills; +using osu.Game.Rulesets.Mods; +using osu.Game.Rulesets.Osu.Difficulty.Utils; + +namespace osu.Game.Rulesets.Osu.Difficulty.Skills +{ + public abstract class OsuProbSkill : Skill + { + protected OsuProbSkill(Mod[] mods) + : base(mods) + { + } + + private const double fc_probability = 0.02; + + private const int bin_count = 32; + + private readonly List difficulties = new List(); + + /// + /// Returns the strain value at . This value is calculated with or without respect to previous objects. + /// + protected abstract double StrainValueAt(DifficultyHitObject current); + + public override void Process(DifficultyHitObject current) + { + difficulties.Add(StrainValueAt(current)); + } + + private static double hitProbability(double skill, double difficulty) + { + if (skill <= 0) return 0; + if (difficulty <= 0) return 1; + + return SpecialFunctions.Erf(skill / (Math.Sqrt(2) * difficulty)); + } + + private static double fcProbabilityAtSkill(double skill, IEnumerable bins) + { + if (skill <= 0) return 0; + + return bins.Aggregate(1.0, (current, bin) => current * bin.FcProbability(skill)); + } + + private double fcProbabilityAtSkill(double skill) + { + if (skill <= 0) return 0; + + return difficulties.Aggregate(1, (current, d) => current * hitProbability(skill, d)); + } + + /// + /// Create an array of equally spaced bins. Count is linearly interpolated into each bin. + /// For example, if we have bins with values [1,2,3,4,5] and want to insert the value 3.2, + /// we will add 0.8 to the count of 3's and 0.2 to the count of 4's + /// + private Bin[] createBins(double maxDifficulty) + { + var bins = new Bin[bin_count]; + + for (int i = 0; i < bin_count; i++) + { + bins[i].Difficulty = maxDifficulty * (i + 1) / bin_count; + } + + foreach (double d in difficulties) + { + double binIndex = bin_count * (d / maxDifficulty) - 1; + + int lowerBound = (int)Math.Floor(binIndex); + double t = binIndex - lowerBound; + + //This can be -1, corresponding to the zero difficulty bucket. + //We don't store that since it doesn't contribute to difficulty + if (lowerBound >= 0) + { + bins[lowerBound].Count += (1 - t); + } + + int upperBound = lowerBound + 1; + + // this can be == bin_count for the maximum difficulty object, in which case t will be 0 anyway + if (upperBound < bin_count) + { + bins[upperBound].Count += t; + } + } + + return bins; + } + + private double difficultyValueBinned() + { + double maxDiff = difficulties.Max(); + if (maxDiff <= 1e-10) return 0; + + var bins = createBins(maxDiff); + + double lowerBoundEstimate = 0.5 * maxDiff; + double upperBoundEstimate = 3.0 * maxDiff; + + double skill = Chandrupatla.FindRootExpand( + skill => fcProbabilityAtSkill(skill, bins) - fc_probability, + lowerBoundEstimate, + upperBoundEstimate, + accuracy: 1e-4); + + return skill; + } + + private double difficultyValueExact() + { + double maxDiff = difficulties.Max(); + if (maxDiff <= 1e-10) return 0; + + double lowerBoundEstimate = 0.5 * maxDiff; + double upperBoundEstimate = 3.0 * maxDiff; + + double skill = Chandrupatla.FindRootExpand( + skill => fcProbabilityAtSkill(skill) - fc_probability, + lowerBoundEstimate, + upperBoundEstimate, + accuracy: 1e-4); + + return skill; + } + + public override double DifficultyValue() + { + if (difficulties.Count == 0) + return 0; + + return difficulties.Count < 2 * bin_count ? difficultyValueExact() : difficultyValueBinned(); + } + } +} diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs new file mode 100644 index 0000000000..210ef68578 --- /dev/null +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs @@ -0,0 +1,17 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +using System; + +namespace osu.Game.Rulesets.Osu.Difficulty.Utils +{ + public struct Bin + { + public double Difficulty; + public double Count; + + public double HitProbability(double skill) => SpecialFunctions.Erf(skill / (Math.Sqrt(2) * Difficulty)); + + public double FcProbability(double skill) => Math.Pow(HitProbability(skill), Count); + } +} diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/Chandrupatla.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/Chandrupatla.cs new file mode 100644 index 0000000000..1b450f2d29 --- /dev/null +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/Chandrupatla.cs @@ -0,0 +1,102 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +using System; + +namespace osu.Game.Rulesets.Osu.Difficulty.Utils +{ + public static class Chandrupatla + { + /// + /// Finds the root of a , expanding the bounds if the root is not located within. + /// Expansion only occurs for the upward bound, as this function is optimized for functions of range [0, x), + /// which is useful for finding skill level (skill can never be below 0). + /// + /// The function of which to find the root. + /// The lower bound of the function inputs. + /// The upper bound of the function inputs. + /// The maximum number of iterations before the function throws an error. + /// The desired precision in which the root is returned. + /// The multiplier on the upper bound when no root is found within the provided bounds. + /// + public static double FindRootExpand(Func function, double guessLowerBound, double guessUpperBound, int maxIterations = 25, double accuracy = 1e-6D, double expansionFactor = 2) + { + double a = guessLowerBound; + double b = guessUpperBound; + double fa = function(a); + double fb = function(b); + + while (fa * fb > 0) + { + a = b; + b *= expansionFactor; + fa = function(a); + fb = function(b); + } + + double t = 0.5; + + for (int i = 0; i < maxIterations; i++) + { + double xt = a + t * (b - a); + double ft = function(xt); + + double c; + double fc; + + if (Math.Sign(ft) == Math.Sign(fa)) + { + c = a; + fc = fa; + } + else + { + c = b; + b = a; + fc = fb; + fb = fa; + } + + a = xt; + fa = ft; + + double xm, fm; + + if (Math.Abs(fa) < Math.Abs(fb)) + { + xm = a; + fm = fa; + } + else + { + xm = b; + fm = fb; + } + + if (fm == 0) + return xm; + + double tol = 2 * accuracy * Math.Abs(xm) + 2 * accuracy; + double tlim = tol / Math.Abs(b - c); + + if (tlim > 0.5) + { + return xm; + } + + double chi = (a - b) / (c - b); + double phi = (fa - fb) / (fc - fb); + bool iqi = phi * phi < chi && (1 - phi) * (1 - phi) < chi; + + if (iqi) + t = fa / (fb - fa) * fc / (fb - fc) + (c - a) / (b - a) * fa / (fc - fa) * fb / (fc - fb); + else + t = 0.5; + + t = Math.Min(1 - tlim, Math.Max(tlim, t)); + } + + return 0; + } + } +} diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/SpecialFunctions.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/SpecialFunctions.cs new file mode 100644 index 0000000000..672f2396be --- /dev/null +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/SpecialFunctions.cs @@ -0,0 +1,694 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +// All code is referenced from the following: +// https://github.com/mathnet/mathnet-numerics/blob/master/src/Numerics/SpecialFunctions/Erf.cs +// https://github.com/mathnet/mathnet-numerics/blob/master/src/Numerics/Optimization/NelderMeadSimplex.cs + +/* + Copyright (c) 2002-2022 Math.NET +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +using System; + +namespace osu.Game.Rulesets.Osu.Difficulty.Utils +{ + public class SpecialFunctions + { + private const double sqrt2_pi = 2.5066282746310005024157652848110452530069867406099d; + + /// + /// ************************************** + /// COEFFICIENTS FOR METHOD ErfImp * + /// ************************************** + /// + /// Polynomial coefficients for a numerator of ErfImp + /// calculation for Erf(x) in the interval [1e-10, 0.5]. + /// + private static readonly double[] erf_imp_an = { 0.00337916709551257388990745, -0.00073695653048167948530905, -0.374732337392919607868241, 0.0817442448733587196071743, -0.0421089319936548595203468, 0.0070165709512095756344528, -0.00495091255982435110337458, 0.000871646599037922480317225 }; + + /// Polynomial coefficients for a denominator of ErfImp + /// calculation for Erf(x) in the interval [1e-10, 0.5]. + /// + private static readonly double[] erf_imp_ad = { 1, -0.218088218087924645390535, 0.412542972725442099083918, -0.0841891147873106755410271, 0.0655338856400241519690695, -0.0120019604454941768171266, 0.00408165558926174048329689, -0.000615900721557769691924509 }; + + /// Polynomial coefficients for a numerator in ErfImp + /// calculation for Erfc(x) in the interval [0.5, 0.75]. + /// + private static readonly double[] erf_imp_bn = { -0.0361790390718262471360258, 0.292251883444882683221149, 0.281447041797604512774415, 0.125610208862766947294894, 0.0274135028268930549240776, 0.00250839672168065762786937 }; + + /// Polynomial coefficients for a denominator in ErfImp + /// calculation for Erfc(x) in the interval [0.5, 0.75]. + /// + private static readonly double[] erf_imp_bd = { 1, 1.8545005897903486499845, 1.43575803037831418074962, 0.582827658753036572454135, 0.124810476932949746447682, 0.0113724176546353285778481 }; + + /// Polynomial coefficients for a numerator in ErfImp + /// calculation for Erfc(x) in the interval [0.75, 1.25]. + /// + private static readonly double[] erf_imp_cn = { -0.0397876892611136856954425, 0.153165212467878293257683, 0.191260295600936245503129, 0.10276327061989304213645, 0.029637090615738836726027, 0.0046093486780275489468812, 0.000307607820348680180548455 }; + + /// Polynomial coefficients for a denominator in ErfImp + /// calculation for Erfc(x) in the interval [0.75, 1.25]. + /// + private static readonly double[] erf_imp_cd = { 1, 1.95520072987627704987886, 1.64762317199384860109595, 0.768238607022126250082483, 0.209793185936509782784315, 0.0319569316899913392596356, 0.00213363160895785378615014 }; + + /// Polynomial coefficients for a numerator in ErfImp + /// calculation for Erfc(x) in the interval [1.25, 2.25]. + /// + private static readonly double[] erf_imp_dn = { -0.0300838560557949717328341, 0.0538578829844454508530552, 0.0726211541651914182692959, 0.0367628469888049348429018, 0.00964629015572527529605267, 0.00133453480075291076745275, 0.778087599782504251917881e-4 }; + + /// Polynomial coefficients for a denominator in ErfImp + /// calculation for Erfc(x) in the interval [1.25, 2.25]. + /// + private static readonly double[] erf_imp_dd = { 1, 1.75967098147167528287343, 1.32883571437961120556307, 0.552528596508757581287907, 0.133793056941332861912279, 0.0179509645176280768640766, 0.00104712440019937356634038, -0.106640381820357337177643e-7 }; + + /// Polynomial coefficients for a numerator in ErfImp + /// calculation for Erfc(x) in the interval [2.25, 3.5]. + /// + private static readonly double[] erf_imp_en = { -0.0117907570137227847827732, 0.014262132090538809896674, 0.0202234435902960820020765, 0.00930668299990432009042239, 0.00213357802422065994322516, 0.00025022987386460102395382, 0.120534912219588189822126e-4 }; + + /// Polynomial coefficients for a denominator in ErfImp + /// calculation for Erfc(x) in the interval [2.25, 3.5]. + /// + private static readonly double[] erf_imp_ed = { 1, 1.50376225203620482047419, 0.965397786204462896346934, 0.339265230476796681555511, 0.0689740649541569716897427, 0.00771060262491768307365526, 0.000371421101531069302990367 }; + + /// Polynomial coefficients for a numerator in ErfImp + /// calculation for Erfc(x) in the interval [3.5, 5.25]. + /// + private static readonly double[] erf_imp_fn = { -0.00546954795538729307482955, 0.00404190278731707110245394, 0.0054963369553161170521356, 0.00212616472603945399437862, 0.000394984014495083900689956, 0.365565477064442377259271e-4, 0.135485897109932323253786e-5 }; + + /// Polynomial coefficients for a denominator in ErfImp + /// calculation for Erfc(x) in the interval [3.5, 5.25]. + /// + private static readonly double[] erf_imp_fd = { 1, 1.21019697773630784832251, 0.620914668221143886601045, 0.173038430661142762569515, 0.0276550813773432047594539, 0.00240625974424309709745382, 0.891811817251336577241006e-4, -0.465528836283382684461025e-11 }; + + /// Polynomial coefficients for a numerator in ErfImp + /// calculation for Erfc(x) in the interval [5.25, 8]. + /// + private static readonly double[] erf_imp_gn = { -0.00270722535905778347999196, 0.0013187563425029400461378, 0.00119925933261002333923989, 0.00027849619811344664248235, 0.267822988218331849989363e-4, 0.923043672315028197865066e-6 }; + + /// Polynomial coefficients for a denominator in ErfImp + /// calculation for Erfc(x) in the interval [5.25, 8]. + /// + private static readonly double[] erf_imp_gd = { 1, 0.814632808543141591118279, 0.268901665856299542168425, 0.0449877216103041118694989, 0.00381759663320248459168994, 0.000131571897888596914350697, 0.404815359675764138445257e-11 }; + + /// Polynomial coefficients for a numerator in ErfImp + /// calculation for Erfc(x) in the interval [8, 11.5]. + /// + private static readonly double[] erf_imp_hn = { -0.00109946720691742196814323, 0.000406425442750422675169153, 0.000274499489416900707787024, 0.465293770646659383436343e-4, 0.320955425395767463401993e-5, 0.778286018145020892261936e-7 }; + + /// Polynomial coefficients for a denominator in ErfImp + /// calculation for Erfc(x) in the interval [8, 11.5]. + /// + private static readonly double[] erf_imp_hd = { 1, 0.588173710611846046373373, 0.139363331289409746077541, 0.0166329340417083678763028, 0.00100023921310234908642639, 0.24254837521587225125068e-4 }; + + /// Polynomial coefficients for a numerator in ErfImp + /// calculation for Erfc(x) in the interval [11.5, 17]. + /// + private static readonly double[] erf_imp_in = { -0.00056907993601094962855594, 0.000169498540373762264416984, 0.518472354581100890120501e-4, 0.382819312231928859704678e-5, 0.824989931281894431781794e-7 }; + + /// Polynomial coefficients for a denominator in ErfImp + /// calculation for Erfc(x) in the interval [11.5, 17]. + /// + private static readonly double[] erf_imp_id = { 1, 0.339637250051139347430323, 0.043472647870310663055044, 0.00248549335224637114641629, 0.535633305337152900549536e-4, -0.117490944405459578783846e-12 }; + + /// Polynomial coefficients for a numerator in ErfImp + /// calculation for Erfc(x) in the interval [17, 24]. + /// + private static readonly double[] erf_imp_jn = { -0.000241313599483991337479091, 0.574224975202501512365975e-4, 0.115998962927383778460557e-4, 0.581762134402593739370875e-6, 0.853971555085673614607418e-8 }; + + /// Polynomial coefficients for a denominator in ErfImp + /// calculation for Erfc(x) in the interval [17, 24]. + /// + private static readonly double[] erf_imp_jd = { 1, 0.233044138299687841018015, 0.0204186940546440312625597, 0.000797185647564398289151125, 0.117019281670172327758019e-4 }; + + /// Polynomial coefficients for a numerator in ErfImp + /// calculation for Erfc(x) in the interval [24, 38]. + /// + private static readonly double[] erf_imp_kn = { -0.000146674699277760365803642, 0.162666552112280519955647e-4, 0.269116248509165239294897e-5, 0.979584479468091935086972e-7, 0.101994647625723465722285e-8 }; + + /// Polynomial coefficients for a denominator in ErfImp + /// calculation for Erfc(x) in the interval [24, 38]. + /// + private static readonly double[] erf_imp_kd = { 1, 0.165907812944847226546036, 0.0103361716191505884359634, 0.000286593026373868366935721, 0.298401570840900340874568e-5 }; + + /// Polynomial coefficients for a numerator in ErfImp + /// calculation for Erfc(x) in the interval [38, 60]. + /// + private static readonly double[] erf_imp_ln = { -0.583905797629771786720406e-4, 0.412510325105496173512992e-5, 0.431790922420250949096906e-6, 0.993365155590013193345569e-8, 0.653480510020104699270084e-10 }; + + /// Polynomial coefficients for a denominator in ErfImp + /// calculation for Erfc(x) in the interval [38, 60]. + /// + private static readonly double[] erf_imp_ld = { 1, 0.105077086072039915406159, 0.00414278428675475620830226, 0.726338754644523769144108e-4, 0.477818471047398785369849e-6 }; + + /// Polynomial coefficients for a numerator in ErfImp + /// calculation for Erfc(x) in the interval [60, 85]. + /// + private static readonly double[] erf_imp_mn = { -0.196457797609229579459841e-4, 0.157243887666800692441195e-5, 0.543902511192700878690335e-7, 0.317472492369117710852685e-9 }; + + /// Polynomial coefficients for a denominator in ErfImp + /// calculation for Erfc(x) in the interval [60, 85]. + /// + private static readonly double[] erf_imp_md = { 1, 0.052803989240957632204885, 0.000926876069151753290378112, 0.541011723226630257077328e-5, 0.535093845803642394908747e-15 }; + + /// Polynomial coefficients for a numerator in ErfImp + /// calculation for Erfc(x) in the interval [85, 110]. + /// + private static readonly double[] erf_imp_nn = { -0.789224703978722689089794e-5, 0.622088451660986955124162e-6, 0.145728445676882396797184e-7, 0.603715505542715364529243e-10 }; + + /// Polynomial coefficients for a denominator in ErfImp + /// calculation for Erfc(x) in the interval [85, 110]. + /// + private static readonly double[] erf_imp_nd = { 1, 0.0375328846356293715248719, 0.000467919535974625308126054, 0.193847039275845656900547e-5 }; + + /// + /// ************************************** + /// COEFFICIENTS FOR METHOD ErfInvImp * + /// ************************************** + /// + /// Polynomial coefficients for a numerator of ErfInvImp + /// calculation for Erf^-1(z) in the interval [0, 0.5]. + /// + private static readonly double[] erv_inv_imp_an = { -0.000508781949658280665617, -0.00836874819741736770379, 0.0334806625409744615033, -0.0126926147662974029034, -0.0365637971411762664006, 0.0219878681111168899165, 0.00822687874676915743155, -0.00538772965071242932965 }; + + /// Polynomial coefficients for a denominator of ErfInvImp + /// calculation for Erf^-1(z) in the interval [0, 0.5]. + /// + private static readonly double[] erv_inv_imp_ad = { 1, -0.970005043303290640362, -1.56574558234175846809, 1.56221558398423026363, 0.662328840472002992063, -0.71228902341542847553, -0.0527396382340099713954, 0.0795283687341571680018, -0.00233393759374190016776, 0.000886216390456424707504 }; + + /// Polynomial coefficients for a numerator of ErfInvImp + /// calculation for Erf^-1(z) in the interval [0.5, 0.75]. + /// + private static readonly double[] erv_inv_imp_bn = { -0.202433508355938759655, 0.105264680699391713268, 8.37050328343119927838, 17.6447298408374015486, -18.8510648058714251895, -44.6382324441786960818, 17.445385985570866523, 21.1294655448340526258, -3.67192254707729348546 }; + + /// Polynomial coefficients for a denominator of ErfInvImp + /// calculation for Erf^-1(z) in the interval [0.5, 0.75]. + /// + private static readonly double[] erv_inv_imp_bd = { 1, 6.24264124854247537712, 3.9713437953343869095, -28.6608180499800029974, -20.1432634680485188801, 48.5609213108739935468, 10.8268667355460159008, -22.6436933413139721736, 1.72114765761200282724 }; + + /// Polynomial coefficients for a numerator of ErfInvImp + /// calculation for Erf^-1(z) in the interval [0.75, 1] with x less than 3. + /// + private static readonly double[] erv_inv_imp_cn = { -0.131102781679951906451, -0.163794047193317060787, 0.117030156341995252019, 0.387079738972604337464, 0.337785538912035898924, 0.142869534408157156766, 0.0290157910005329060432, 0.00214558995388805277169, -0.679465575181126350155e-6, 0.285225331782217055858e-7, -0.681149956853776992068e-9 }; + + /// Polynomial coefficients for a denominator of ErfInvImp + /// calculation for Erf^-1(z) in the interval [0.75, 1] with x less than 3. + /// + private static readonly double[] erv_inv_imp_cd = { 1, 3.46625407242567245975, 5.38168345707006855425, 4.77846592945843778382, 2.59301921623620271374, 0.848854343457902036425, 0.152264338295331783612, 0.01105924229346489121 }; + + /// Polynomial coefficients for a numerator of ErfInvImp + /// calculation for Erf^-1(z) in the interval [0.75, 1] with x between 3 and 6. + /// + private static readonly double[] erv_inv_imp_dn = { -0.0350353787183177984712, -0.00222426529213447927281, 0.0185573306514231072324, 0.00950804701325919603619, 0.00187123492819559223345, 0.000157544617424960554631, 0.460469890584317994083e-5, -0.230404776911882601748e-9, 0.266339227425782031962e-11 }; + + /// Polynomial coefficients for a denominator of ErfInvImp + /// calculation for Erf^-1(z) in the interval [0.75, 1] with x between 3 and 6. + /// + private static readonly double[] erv_inv_imp_dd = { 1, 1.3653349817554063097, 0.762059164553623404043, 0.220091105764131249824, 0.0341589143670947727934, 0.00263861676657015992959, 0.764675292302794483503e-4 }; + + /// Polynomial coefficients for a numerator of ErfInvImp + /// calculation for Erf^-1(z) in the interval [0.75, 1] with x between 6 and 18. + /// + private static readonly double[] erv_inv_imp_en = { -0.0167431005076633737133, -0.00112951438745580278863, 0.00105628862152492910091, 0.000209386317487588078668, 0.149624783758342370182e-4, 0.449696789927706453732e-6, 0.462596163522878599135e-8, -0.281128735628831791805e-13, 0.99055709973310326855e-16 }; + + /// Polynomial coefficients for a denominator of ErfInvImp + /// calculation for Erf^-1(z) in the interval [0.75, 1] with x between 6 and 18. + /// + private static readonly double[] erv_inv_imp_ed = { 1, 0.591429344886417493481, 0.138151865749083321638, 0.0160746087093676504695, 0.000964011807005165528527, 0.275335474764726041141e-4, 0.282243172016108031869e-6 }; + + /// Polynomial coefficients for a numerator of ErfInvImp + /// calculation for Erf^-1(z) in the interval [0.75, 1] with x between 18 and 44. + /// + private static readonly double[] erv_inv_imp_fn = { -0.0024978212791898131227, -0.779190719229053954292e-5, 0.254723037413027451751e-4, 0.162397777342510920873e-5, 0.396341011304801168516e-7, 0.411632831190944208473e-9, 0.145596286718675035587e-11, -0.116765012397184275695e-17 }; + + /// Polynomial coefficients for a denominator of ErfInvImp + /// calculation for Erf^-1(z) in the interval [0.75, 1] with x between 18 and 44. + /// + private static readonly double[] erv_inv_imp_fd = { 1, 0.207123112214422517181, 0.0169410838120975906478, 0.000690538265622684595676, 0.145007359818232637924e-4, 0.144437756628144157666e-6, 0.509761276599778486139e-9 }; + + /// Polynomial coefficients for a numerator of ErfInvImp + /// calculation for Erf^-1(z) in the interval [0.75, 1] with x greater than 44. + /// + private static readonly double[] erv_inv_imp_gn = { -0.000539042911019078575891, -0.28398759004727721098e-6, 0.899465114892291446442e-6, 0.229345859265920864296e-7, 0.225561444863500149219e-9, 0.947846627503022684216e-12, 0.135880130108924861008e-14, -0.348890393399948882918e-21 }; + + /// Polynomial coefficients for a denominator of ErfInvImp + /// calculation for Erf^-1(z) in the interval [0.75, 1] with x greater than 44. + /// + private static readonly double[] erv_inv_imp_gd = { 1, 0.0845746234001899436914, 0.00282092984726264681981, 0.468292921940894236786e-4, 0.399968812193862100054e-6, 0.161809290887904476097e-8, 0.231558608310259605225e-11 }; + + /// Calculates the error function. + /// The value to evaluate. + /// the error function evaluated at given value. + /// + /// + /// returns 1 if x == double.PositiveInfinity. + /// returns -1 if x == double.NegativeInfinity. + /// + /// + public static double Erf(double x) + { + if (x == 0) + { + return 0; + } + + if (double.IsPositiveInfinity(x)) + { + return 1; + } + + if (double.IsNegativeInfinity(x)) + { + return -1; + } + + if (double.IsNaN(x)) + { + return double.NaN; + } + + return erfImp(x, false); + } + + /// Calculates the complementary error function. + /// The value to evaluate. + /// the complementary error function evaluated at given value. + /// + /// + /// returns 0 if x == double.PositiveInfinity. + /// returns 2 if x == double.NegativeInfinity. + /// + /// + public static double Erfc(double x) + { + if (x == 0) + { + return 1; + } + + if (double.IsPositiveInfinity(x)) + { + return 0; + } + + if (double.IsNegativeInfinity(x)) + { + return 2; + } + + if (double.IsNaN(x)) + { + return double.NaN; + } + + return erfImp(x, true); + } + + /// Calculates the inverse error function evaluated at z. + /// The inverse error function evaluated at given value. + /// + /// + /// returns double.PositiveInfinity if z >= 1.0. + /// returns double.NegativeInfinity if z <= -1.0. + /// + /// + /// Calculates the inverse error function evaluated at z. + /// value to evaluate. + /// the inverse error function evaluated at Z. + public static double ErfInv(double z) + { + if (z == 0.0) + { + return 0.0; + } + + if (z >= 1.0) + { + return double.PositiveInfinity; + } + + if (z <= -1.0) + { + return double.NegativeInfinity; + } + + double p, q, s; + + if (z < 0) + { + p = -z; + q = 1 - p; + s = -1; + } + else + { + p = z; + q = 1 - z; + s = 1; + } + + return erfInvImpl(p, q, s); + } + + /// + /// Implementation of the error function. + /// + /// Where to evaluate the error function. + /// Whether to compute 1 - the error function. + /// the error function. + private static double erfImp(double z, bool invert) + { + if (z < 0) + { + if (!invert) + { + return -erfImp(-z, false); + } + + if (z < -0.5) + { + return 2 - erfImp(-z, true); + } + + return 1 + erfImp(-z, false); + } + + double result; + + // Big bunch of selection statements now to pick which + // implementation to use, try to put most likely options + // first: + if (z < 0.5) + { + // We're going to calculate erf: + if (z < 1e-10) + { + result = (z * 1.125) + (z * 0.003379167095512573896158903121545171688); + } + else + { + // Worst case absolute error found: 6.688618532e-21 + result = (z * 1.125) + (z * evaluatePolynomial(z, erf_imp_an) / evaluatePolynomial(z, erf_imp_ad)); + } + } + else if (z < 110) + { + // We'll be calculating erfc: + invert = !invert; + double r, b; + + if (z < 0.75) + { + // Worst case absolute error found: 5.582813374e-21 + r = evaluatePolynomial(z - 0.5, erf_imp_bn) / evaluatePolynomial(z - 0.5, erf_imp_bd); + b = 0.3440242112F; + } + else if (z < 1.25) + { + // Worst case absolute error found: 4.01854729e-21 + r = evaluatePolynomial(z - 0.75, erf_imp_cn) / evaluatePolynomial(z - 0.75, erf_imp_cd); + b = 0.419990927F; + } + else if (z < 2.25) + { + // Worst case absolute error found: 2.866005373e-21 + r = evaluatePolynomial(z - 1.25, erf_imp_dn) / evaluatePolynomial(z - 1.25, erf_imp_dd); + b = 0.4898625016F; + } + else if (z < 3.5) + { + // Worst case absolute error found: 1.045355789e-21 + r = evaluatePolynomial(z - 2.25, erf_imp_en) / evaluatePolynomial(z - 2.25, erf_imp_ed); + b = 0.5317370892F; + } + else if (z < 5.25) + { + // Worst case absolute error found: 8.300028706e-22 + r = evaluatePolynomial(z - 3.5, erf_imp_fn) / evaluatePolynomial(z - 3.5, erf_imp_fd); + b = 0.5489973426F; + } + else if (z < 8) + { + // Worst case absolute error found: 1.700157534e-21 + r = evaluatePolynomial(z - 5.25, erf_imp_gn) / evaluatePolynomial(z - 5.25, erf_imp_gd); + b = 0.5571740866F; + } + else if (z < 11.5) + { + // Worst case absolute error found: 3.002278011e-22 + r = evaluatePolynomial(z - 8, erf_imp_hn) / evaluatePolynomial(z - 8, erf_imp_hd); + b = 0.5609807968F; + } + else if (z < 17) + { + // Worst case absolute error found: 6.741114695e-21 + r = evaluatePolynomial(z - 11.5, erf_imp_in) / evaluatePolynomial(z - 11.5, erf_imp_id); + b = 0.5626493692F; + } + else if (z < 24) + { + // Worst case absolute error found: 7.802346984e-22 + r = evaluatePolynomial(z - 17, erf_imp_jn) / evaluatePolynomial(z - 17, erf_imp_jd); + b = 0.5634598136F; + } + else if (z < 38) + { + // Worst case absolute error found: 2.414228989e-22 + r = evaluatePolynomial(z - 24, erf_imp_kn) / evaluatePolynomial(z - 24, erf_imp_kd); + b = 0.5638477802F; + } + else if (z < 60) + { + // Worst case absolute error found: 5.896543869e-24 + r = evaluatePolynomial(z - 38, erf_imp_ln) / evaluatePolynomial(z - 38, erf_imp_ld); + b = 0.5640528202F; + } + else if (z < 85) + { + // Worst case absolute error found: 3.080612264e-21 + r = evaluatePolynomial(z - 60, erf_imp_mn) / evaluatePolynomial(z - 60, erf_imp_md); + b = 0.5641309023F; + } + else + { + // Worst case absolute error found: 8.094633491e-22 + r = evaluatePolynomial(z - 85, erf_imp_nn) / evaluatePolynomial(z - 85, erf_imp_nd); + b = 0.5641584396F; + } + + double g = Math.Exp(-z * z) / z; + result = (g * b) + (g * r); + } + else + { + // Any value of z larger than 28 will underflow to zero: + result = 0; + invert = !invert; + } + + if (invert) + { + result = 1 - result; + } + + return result; + } + + /// Calculates the complementary inverse error function evaluated at z. + /// The complementary inverse error function evaluated at given value. + /// We have tested this implementation against the arbitrary precision mpmath library + /// and found cases where we can only guarantee 9 significant figures correct. + /// + /// returns double.PositiveInfinity if z <= 0.0. + /// returns double.NegativeInfinity if z >= 2.0. + /// + /// + /// calculates the complementary inverse error function evaluated at z. + /// value to evaluate. + /// the complementary inverse error function evaluated at Z. + public static double ErfcInv(double z) + { + if (z <= 0.0) + { + return double.PositiveInfinity; + } + + if (z >= 2.0) + { + return double.NegativeInfinity; + } + + double p, q, s; + + if (z > 1) + { + q = 2 - z; + p = 1 - q; + s = -1; + } + else + { + p = 1 - z; + q = z; + s = 1; + } + + return erfInvImpl(p, q, s); + } + + /// + /// The implementation of the inverse error function. + /// + /// First intermediate parameter. + /// Second intermediate parameter. + /// Third intermediate parameter. + /// the inverse error function. + private static double erfInvImpl(double p, double q, double s) + { + double result; + + if (p <= 0.5) + { + // Evaluate inverse erf using the rational approximation: + // + // x = p(p+10)(Y+R(p)) + // + // Where Y is a constant, and R(p) is optimized for a low + // absolute error compared to |Y|. + // + // double: Max error found: 2.001849e-18 + // long double: Max error found: 1.017064e-20 + // Maximum Deviation Found (actual error term at infinite precision) 8.030e-21 + const float y = 0.0891314744949340820313f; + double g = p * (p + 10); + double r = evaluatePolynomial(p, erv_inv_imp_an) / evaluatePolynomial(p, erv_inv_imp_ad); + result = (g * y) + (g * r); + } + else if (q >= 0.25) + { + // Rational approximation for 0.5 > q >= 0.25 + // + // x = sqrt(-2*log(q)) / (Y + R(q)) + // + // Where Y is a constant, and R(q) is optimized for a low + // absolute error compared to Y. + // + // double : Max error found: 7.403372e-17 + // long double : Max error found: 6.084616e-20 + // Maximum Deviation Found (error term) 4.811e-20 + const float y = 2.249481201171875f; + double g = Math.Sqrt(-2 * Math.Log(q)); + double xs = q - 0.25; + double r = evaluatePolynomial(xs, erv_inv_imp_bn) / evaluatePolynomial(xs, erv_inv_imp_bd); + result = g / (y + r); + } + else + { + // For q < 0.25 we have a series of rational approximations all + // of the general form: + // + // let: x = sqrt(-log(q)) + // + // Then the result is given by: + // + // x(Y+R(x-B)) + // + // where Y is a constant, B is the lowest value of x for which + // the approximation is valid, and R(x-B) is optimized for a low + // absolute error compared to Y. + // + // Note that almost all code will really go through the first + // or maybe second approximation. After than we're dealing with very + // small input values indeed: 80 and 128 bit long double's go all the + // way down to ~ 1e-5000 so the "tail" is rather long... + double x = Math.Sqrt(-Math.Log(q)); + + if (x < 3) + { + // Max error found: 1.089051e-20 + const float y = 0.807220458984375f; + double xs = x - 1.125; + double r = evaluatePolynomial(xs, erv_inv_imp_cn) / evaluatePolynomial(xs, erv_inv_imp_cd); + result = (y * x) + (r * x); + } + else if (x < 6) + { + // Max error found: 8.389174e-21 + const float y = 0.93995571136474609375f; + double xs = x - 3; + double r = evaluatePolynomial(xs, erv_inv_imp_dn) / evaluatePolynomial(xs, erv_inv_imp_dd); + result = (y * x) + (r * x); + } + else if (x < 18) + { + // Max error found: 1.481312e-19 + const float y = 0.98362827301025390625f; + double xs = x - 6; + double r = evaluatePolynomial(xs, erv_inv_imp_en) / evaluatePolynomial(xs, erv_inv_imp_ed); + result = (y * x) + (r * x); + } + else if (x < 44) + { + // Max error found: 5.697761e-20 + const float y = 0.99714565277099609375f; + double xs = x - 18; + double r = evaluatePolynomial(xs, erv_inv_imp_fn) / evaluatePolynomial(xs, erv_inv_imp_fd); + result = (y * x) + (r * x); + } + else + { + // Max error found: 1.279746e-20 + const float y = 0.99941349029541015625f; + double xs = x - 44; + double r = evaluatePolynomial(xs, erv_inv_imp_gn) / evaluatePolynomial(xs, erv_inv_imp_gd); + result = (y * x) + (r * x); + } + } + + return s * result; + } + + /// + /// Evaluate a polynomial at point x. + /// Coefficients are ordered ascending by power with power k at index k. + /// Example: coefficients [3,-1,2] represent y=2x^2-x+3. + /// + /// The location where to evaluate the polynomial at. + /// The coefficients of the polynomial, coefficient for power k at index k. + /// + /// is a null reference. + /// + private static double evaluatePolynomial(double z, params double[] coefficients) + { + // 2020-10-07 jbialogrodzki #730 Since this is public API we should probably + // handle null arguments? It doesn't seem to have been done consistently in this class though. + if (coefficients == null) + { + throw new ArgumentNullException(nameof(coefficients)); + } + + // 2020-10-07 jbialogrodzki #730 Zero polynomials need explicit handling. + // Without this check, we attempted to peek coefficients at negative indices! + int n = coefficients.Length; + + if (n == 0) + { + return 0; + } + + double sum = coefficients[n - 1]; + + for (int i = n - 2; i >= 0; --i) + { + sum *= z; + sum += coefficients[i]; + } + + return sum; + } + } +} From 089b27d4d3637beaf635ca7bddb92b3bce6ae47a Mon Sep 17 00:00:00 2001 From: Nathen Date: Fri, 8 Mar 2024 12:41:34 -0500 Subject: [PATCH 02/14] Fix pp counter --- osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs b/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs index ef62ee1250..188eccfd8d 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs @@ -103,12 +103,12 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills var bins = createBins(maxDiff); - double lowerBoundEstimate = 0.5 * maxDiff; + const double lower_bound = 0; double upperBoundEstimate = 3.0 * maxDiff; double skill = Chandrupatla.FindRootExpand( skill => fcProbabilityAtSkill(skill, bins) - fc_probability, - lowerBoundEstimate, + lower_bound, upperBoundEstimate, accuracy: 1e-4); @@ -120,12 +120,12 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills double maxDiff = difficulties.Max(); if (maxDiff <= 1e-10) return 0; - double lowerBoundEstimate = 0.5 * maxDiff; + const double lower_bound = 0; double upperBoundEstimate = 3.0 * maxDiff; double skill = Chandrupatla.FindRootExpand( skill => fcProbabilityAtSkill(skill) - fc_probability, - lowerBoundEstimate, + lower_bound, upperBoundEstimate, accuracy: 1e-4); From 266e6175d2c37e39a95ccf22f8fe1fee47ac2e25 Mon Sep 17 00:00:00 2001 From: nathen Date: Fri, 8 Mar 2024 20:41:27 -0500 Subject: [PATCH 03/14] Make hitProbability abstract --- osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs | 9 +++++++++ .../Difficulty/Skills/OsuProbSkill.cs | 16 ++++++---------- osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs | 6 ------ 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs b/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs index e3e630cd57..1d37c35259 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs @@ -5,6 +5,7 @@ using System; using osu.Game.Rulesets.Difficulty.Preprocessing; using osu.Game.Rulesets.Mods; using osu.Game.Rulesets.Osu.Difficulty.Evaluators; +using osu.Game.Rulesets.Osu.Difficulty.Utils; namespace osu.Game.Rulesets.Osu.Difficulty.Skills { @@ -26,6 +27,14 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills private double skillMultiplier => 125; private double strainDecayBase => 0.15; + protected override double HitProbability(double skill, double difficulty) + { + if (skill <= 0) return 0; + if (difficulty <= 0) return 1; + + return SpecialFunctions.Erf(skill / (Math.Sqrt(2) * difficulty)); + } + private double strainDecay(double ms) => Math.Pow(strainDecayBase, ms / 1000); protected override double StrainValueAt(DifficultyHitObject current) diff --git a/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs b/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs index 188eccfd8d..84c7583665 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs @@ -34,26 +34,22 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills difficulties.Add(StrainValueAt(current)); } - private static double hitProbability(double skill, double difficulty) - { - if (skill <= 0) return 0; - if (difficulty <= 0) return 1; + protected abstract double HitProbability(double skill, double difficulty); - return SpecialFunctions.Erf(skill / (Math.Sqrt(2) * difficulty)); - } - - private static double fcProbabilityAtSkill(double skill, IEnumerable bins) + private double fcProbabilityAtSkill(double skill, IEnumerable bins) { if (skill <= 0) return 0; - return bins.Aggregate(1.0, (current, bin) => current * bin.FcProbability(skill)); + double totalHitProbability(Bin bin) => Math.Pow(HitProbability(skill, bin.Difficulty), bin.Count); + + return bins.Aggregate(1.0, (current, bin) => current * totalHitProbability(bin)); } private double fcProbabilityAtSkill(double skill) { if (skill <= 0) return 0; - return difficulties.Aggregate(1, (current, d) => current * hitProbability(skill, d)); + return difficulties.Aggregate(1, (current, d) => current * HitProbability(skill, d)); } /// diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs index 210ef68578..943cc6958d 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs @@ -1,17 +1,11 @@ // Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. // See the LICENCE file in the repository root for full licence text. -using System; - namespace osu.Game.Rulesets.Osu.Difficulty.Utils { public struct Bin { public double Difficulty; public double Count; - - public double HitProbability(double skill) => SpecialFunctions.Erf(skill / (Math.Sqrt(2) * Difficulty)); - - public double FcProbability(double skill) => Math.Pow(HitProbability(skill), Count); } } From 32e95a809335a3b9cde1d53a4f7fd2547f9bc07b Mon Sep 17 00:00:00 2001 From: Nathen Date: Thu, 21 Mar 2024 19:54:02 -0400 Subject: [PATCH 04/14] Move CreateBins method to Bin.cs --- .../Difficulty/Skills/OsuProbSkill.cs | 50 ++----------------- osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs | 46 +++++++++++++++++ 2 files changed, 51 insertions(+), 45 deletions(-) diff --git a/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs b/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs index 84c7583665..f956427113 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs @@ -36,7 +36,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills protected abstract double HitProbability(double skill, double difficulty); - private double fcProbabilityAtSkill(double skill, IEnumerable bins) + private double fcProbabilityAtSkillBinned(double skill, IEnumerable bins) { if (skill <= 0) return 0; @@ -45,65 +45,25 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills return bins.Aggregate(1.0, (current, bin) => current * totalHitProbability(bin)); } - private double fcProbabilityAtSkill(double skill) + private double fcProbabilityAtSkillExact(double skill) { if (skill <= 0) return 0; return difficulties.Aggregate(1, (current, d) => current * HitProbability(skill, d)); } - /// - /// Create an array of equally spaced bins. Count is linearly interpolated into each bin. - /// For example, if we have bins with values [1,2,3,4,5] and want to insert the value 3.2, - /// we will add 0.8 to the count of 3's and 0.2 to the count of 4's - /// - private Bin[] createBins(double maxDifficulty) - { - var bins = new Bin[bin_count]; - - for (int i = 0; i < bin_count; i++) - { - bins[i].Difficulty = maxDifficulty * (i + 1) / bin_count; - } - - foreach (double d in difficulties) - { - double binIndex = bin_count * (d / maxDifficulty) - 1; - - int lowerBound = (int)Math.Floor(binIndex); - double t = binIndex - lowerBound; - - //This can be -1, corresponding to the zero difficulty bucket. - //We don't store that since it doesn't contribute to difficulty - if (lowerBound >= 0) - { - bins[lowerBound].Count += (1 - t); - } - - int upperBound = lowerBound + 1; - - // this can be == bin_count for the maximum difficulty object, in which case t will be 0 anyway - if (upperBound < bin_count) - { - bins[upperBound].Count += t; - } - } - - return bins; - } - private double difficultyValueBinned() { double maxDiff = difficulties.Max(); if (maxDiff <= 1e-10) return 0; - var bins = createBins(maxDiff); + var bins = Bin.CreateBins(difficulties, bin_count); const double lower_bound = 0; double upperBoundEstimate = 3.0 * maxDiff; double skill = Chandrupatla.FindRootExpand( - skill => fcProbabilityAtSkill(skill, bins) - fc_probability, + skill => fcProbabilityAtSkillBinned(skill, bins) - fc_probability, lower_bound, upperBoundEstimate, accuracy: 1e-4); @@ -120,7 +80,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills double upperBoundEstimate = 3.0 * maxDiff; double skill = Chandrupatla.FindRootExpand( - skill => fcProbabilityAtSkill(skill) - fc_probability, + skill => fcProbabilityAtSkillExact(skill) - fc_probability, lower_bound, upperBoundEstimate, accuracy: 1e-4); diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs index 943cc6958d..0fbc342916 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs @@ -1,11 +1,57 @@ // Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. // See the LICENCE file in the repository root for full licence text. +using System; +using System.Collections.Generic; +using System.Linq; + namespace osu.Game.Rulesets.Osu.Difficulty.Utils { public struct Bin { public double Difficulty; public double Count; + + /// + /// Create an array of equally spaced bins. Count is linearly interpolated into each bin. + /// For example, if we have bins with values [1,2,3,4,5] and want to insert the value 3.2, + /// we will add 0.8 to the count of 3's and 0.2 to the count of 4's + /// + public static Bin[] CreateBins(List difficulties, int binCount) + { + double maxDifficulty = difficulties.Max(); + + var bins = new Bin[binCount]; + + for (int i = 0; i < binCount; i++) + { + bins[i].Difficulty = maxDifficulty * (i + 1) / binCount; + } + + foreach (double d in difficulties) + { + double binIndex = binCount * (d / maxDifficulty) - 1; + + int lowerBound = (int)Math.Floor(binIndex); + double t = binIndex - lowerBound; + + //This can be -1, corresponding to the zero difficulty bucket. + //We don't store that since it doesn't contribute to difficulty + if (lowerBound >= 0) + { + bins[lowerBound].Count += (1 - t); + } + + int upperBound = lowerBound + 1; + + // this can be == bin_count for the maximum difficulty object, in which case t will be 0 anyway + if (upperBound < binCount) + { + bins[upperBound].Count += t; + } + } + + return bins; + } } } From 5f759ff8efaca4805f60666bb3742dcd72323771 Mon Sep 17 00:00:00 2001 From: Nathen Date: Mon, 25 Mar 2024 12:14:15 -0400 Subject: [PATCH 05/14] Make FcProbability abstract --- osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs | 2 ++ osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs | 8 +++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs b/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs index 1d37c35259..72da687d9a 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs @@ -27,6 +27,8 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills private double skillMultiplier => 125; private double strainDecayBase => 0.15; + protected override double FcProbability => 0.02; + protected override double HitProbability(double skill, double difficulty) { if (skill <= 0) return 0; diff --git a/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs b/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs index 84c7583665..d86bd455f0 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs @@ -18,7 +18,9 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills { } - private const double fc_probability = 0.02; + /// The skill level returned from this class will have FcProbability chance of hitting every note correctly. + /// A higher value rewards short, high difficulty sections, whereas a lower value rewards consistent, lower difficulty. + protected abstract double FcProbability { get; } private const int bin_count = 32; @@ -103,7 +105,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills double upperBoundEstimate = 3.0 * maxDiff; double skill = Chandrupatla.FindRootExpand( - skill => fcProbabilityAtSkill(skill, bins) - fc_probability, + skill => fcProbabilityAtSkill(skill, bins) - FcProbability, lower_bound, upperBoundEstimate, accuracy: 1e-4); @@ -120,7 +122,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills double upperBoundEstimate = 3.0 * maxDiff; double skill = Chandrupatla.FindRootExpand( - skill => fcProbabilityAtSkill(skill) - fc_probability, + skill => fcProbabilityAtSkill(skill) - FcProbability, lower_bound, upperBoundEstimate, accuracy: 1e-4); From fd53423121a314a6df0e77586a5d298b03625c57 Mon Sep 17 00:00:00 2001 From: Nathen Date: Mon, 1 Apr 2024 21:05:15 -0400 Subject: [PATCH 06/14] Restructuring --- .../Difficulty/Skills/OsuProbSkill.cs | 38 +++++++++---------- .../Utils/{Chandrupatla.cs => RootFinding.cs} | 5 +-- 2 files changed, 20 insertions(+), 23 deletions(-) rename osu.Game.Rulesets.Osu/Difficulty/Utils/{Chandrupatla.cs => RootFinding.cs} (94%) diff --git a/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs b/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs index da98b3f295..2dabbb1fdd 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs @@ -38,22 +38,6 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills protected abstract double HitProbability(double skill, double difficulty); - private double fcProbabilityAtSkillBinned(double skill, IEnumerable bins) - { - if (skill <= 0) return 0; - - double totalHitProbability(Bin bin) => Math.Pow(HitProbability(skill, bin.Difficulty), bin.Count); - - return bins.Aggregate(1.0, (current, bin) => current * totalHitProbability(bin)); - } - - private double fcProbabilityAtSkillExact(double skill) - { - if (skill <= 0) return 0; - - return difficulties.Aggregate(1, (current, d) => current * HitProbability(skill, d)); - } - private double difficultyValueBinned() { double maxDiff = difficulties.Max(); @@ -64,13 +48,20 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills const double lower_bound = 0; double upperBoundEstimate = 3.0 * maxDiff; - double skill = Chandrupatla.FindRootExpand( - skill => fcProbabilityAtSkillBinned(skill, bins) - FcProbability, + double skill = RootFinding.FindRootExpand( + skill => fcProbability(skill) - FcProbability, lower_bound, upperBoundEstimate, accuracy: 1e-4); return skill; + + double fcProbability(double s) + { + if (s <= 0) return 0; + + return bins.Aggregate(1.0, (current, bin) => current * Math.Pow(HitProbability(s, bin.Difficulty), bin.Count)); + } } private double difficultyValueExact() @@ -81,13 +72,20 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills const double lower_bound = 0; double upperBoundEstimate = 3.0 * maxDiff; - double skill = Chandrupatla.FindRootExpand( - skill => fcProbabilityAtSkillExact(skill) - FcProbability, + double skill = RootFinding.FindRootExpand( + skill => fcProbability(skill) - FcProbability, lower_bound, upperBoundEstimate, accuracy: 1e-4); return skill; + + double fcProbability(double s) + { + if (s <= 0) return 0; + + return difficulties.Aggregate(1, (current, d) => current * HitProbability(s, d)); + } } public override double DifficultyValue() diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/Chandrupatla.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/RootFinding.cs similarity index 94% rename from osu.Game.Rulesets.Osu/Difficulty/Utils/Chandrupatla.cs rename to osu.Game.Rulesets.Osu/Difficulty/Utils/RootFinding.cs index 1b450f2d29..64b4d38ec1 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Utils/Chandrupatla.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/RootFinding.cs @@ -5,10 +5,10 @@ using System; namespace osu.Game.Rulesets.Osu.Difficulty.Utils { - public static class Chandrupatla + public static class RootFinding { /// - /// Finds the root of a , expanding the bounds if the root is not located within. + /// Finds the root of a using the Chandrupatla method, expanding the bounds if the root is not located within. /// Expansion only occurs for the upward bound, as this function is optimized for functions of range [0, x), /// which is useful for finding skill level (skill can never be below 0). /// @@ -18,7 +18,6 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils /// The maximum number of iterations before the function throws an error. /// The desired precision in which the root is returned. /// The multiplier on the upper bound when no root is found within the provided bounds. - /// public static double FindRootExpand(Func function, double guessLowerBound, double guessUpperBound, int maxIterations = 25, double accuracy = 1e-6D, double expansionFactor = 2) { double a = guessLowerBound; From ef0797ad30e1a1a75eb37f7ab8fad00aa244e061 Mon Sep 17 00:00:00 2001 From: Nathen Date: Tue, 23 Apr 2024 19:20:36 -0400 Subject: [PATCH 07/14] Implement quartic curve fitting miss penalty --- .../Difficulty/OsuDifficultyAttributes.cs | 6 + .../Difficulty/OsuDifficultyCalculator.cs | 2 + .../Difficulty/OsuPerformanceCalculator.cs | 12 +- .../Difficulty/Skills/Aim.cs | 32 +++++ .../Difficulty/Skills/OsuProbSkill.cs | 33 ++++- osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs | 14 ++- .../Difficulty/Utils/FitMissCountPoints.cs | 49 ++++++++ .../Difficulty/Utils/PoissonBinomial.cs | 119 ++++++++++++++++++ .../Difficulty/Utils/SpecialFunctions.cs | 38 ++++++ 9 files changed, 293 insertions(+), 12 deletions(-) create mode 100644 osu.Game.Rulesets.Osu/Difficulty/Utils/FitMissCountPoints.cs create mode 100644 osu.Game.Rulesets.Osu/Difficulty/Utils/PoissonBinomial.cs diff --git a/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyAttributes.cs b/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyAttributes.cs index 83538a2f42..b37fa52a81 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyAttributes.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyAttributes.cs @@ -19,6 +19,12 @@ namespace osu.Game.Rulesets.Osu.Difficulty [JsonProperty("aim_difficulty")] public double AimDifficulty { get; set; } + /// + /// The difficulty corresponding to the aim skill. + /// + [JsonProperty("aim_penalty_constants")] + public (double, double, double) AimPenaltyConstants { get; set; } + /// /// The difficulty corresponding to the speed skill. /// diff --git a/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs b/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs index 007cd977e5..b9fc4f0993 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs @@ -37,6 +37,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty return new OsuDifficultyAttributes { Mods = mods }; double aimRating = Math.Sqrt(skills[0].DifficultyValue()) * difficulty_multiplier; + (double, double, double) aimPenaltyConstants = ((Aim)skills[0]).GetMissCountCoefficients(); double aimRatingNoSliders = Math.Sqrt(skills[1].DifficultyValue()) * difficulty_multiplier; double speedRating = Math.Sqrt(skills[2].DifficultyValue()) * difficulty_multiplier; double speedNotes = ((Speed)skills[2]).RelevantNoteCount(); @@ -97,6 +98,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty StarRating = starRating, Mods = mods, AimDifficulty = aimRating, + AimPenaltyConstants = aimPenaltyConstants, SpeedDifficulty = speedRating, SpeedNoteCount = speedNotes, FlashlightDifficulty = flashlightRating, diff --git a/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs b/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs index 2df4443915..336e19495b 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Linq; using osu.Game.Rulesets.Difficulty; +using osu.Game.Rulesets.Osu.Difficulty.Utils; using osu.Game.Rulesets.Osu.Mods; using osu.Game.Rulesets.Scoring; using osu.Game.Scoring; @@ -93,9 +94,16 @@ namespace osu.Game.Rulesets.Osu.Difficulty // Penalize misses by assessing # of misses relative to the total # of objects. Default a 3% reduction for any # of misses. if (effectiveMissCount > 0) - aimValue *= 0.97 * Math.Pow(1 - Math.Pow(effectiveMissCount / totalHits, 0.775), effectiveMissCount); + { + double a = attributes.AimPenaltyConstants.Item1; + double b = attributes.AimPenaltyConstants.Item2; + double c = attributes.AimPenaltyConstants.Item3; + double d = Math.Log(totalHits + 1) - a - b - c; - aimValue *= getComboScalingFactor(attributes); + double penalty = Math.Pow(1 - RootFinding.FindRootExpand(x => Math.Exp(a * x * x * x * x + b * x * x * x + c * x * x + d * x) - 1 - effectiveMissCount, 0, 1), 1.5); + + aimValue *= penalty; + } double approachRateFactor = 0.0; if (attributes.ApproachRate > 10.33) diff --git a/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs b/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs index 72da687d9a..2a8df994b4 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs @@ -46,5 +46,37 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills return currentStrain; } + + /// + /// The coefficients of a quartic fitted to the miss counts at each skill level. + /// + /// The coefficients for ax^4+bx^3+cx^2. The 4th coefficient for dx^1 can be deduced from the first 3 in the performance calculator. + public (double, double, double) GetMissCountCoefficients() + { + const int count = 21; + const double penalty_per_misscount = 1.0 / (count - 1); + + double fcSkill = DifficultyValue(); + + double[] misscounts = new double[count]; + + for (int i = 0; i < count; i++) + { + if (i == 0) + { + misscounts[i] = 0; + continue; + } + + double penalizedSkill = fcSkill - fcSkill * penalty_per_misscount * i; + + // Save misscounts as log form to give higher weight to lower values. Add 1 so that the lowest misscounts remain above 0. + misscounts[i] = Math.Log(GetMissCountAtSkill(penalizedSkill) + 1); + } + + double[] constants = FitMissCountPoints.GetPolynomialCoefficients(misscounts); + + return (constants[0], constants[1], constants[2]); + } } } diff --git a/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs b/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs index 2dabbb1fdd..0fece3069f 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs @@ -22,8 +22,6 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills /// A higher value rewards short, high difficulty sections, whereas a lower value rewards consistent, lower difficulty. protected abstract double FcProbability { get; } - private const int bin_count = 32; - private readonly List difficulties = new List(); /// @@ -43,7 +41,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills double maxDiff = difficulties.Max(); if (maxDiff <= 1e-10) return 0; - var bins = Bin.CreateBins(difficulties, bin_count); + var bins = Bin.CreateBins(difficulties); const double lower_bound = 0; double upperBoundEstimate = 3.0 * maxDiff; @@ -93,7 +91,34 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills if (difficulties.Count == 0) return 0; - return difficulties.Count < 2 * bin_count ? difficultyValueExact() : difficultyValueBinned(); + return difficulties.Count < 64 ? difficultyValueExact() : difficultyValueBinned(); + } + + /// + /// Find the lowest misscount that a player with the provided would have a 2% chance of achieving. + /// + public double GetMissCountAtSkill(double skill) + { + double maxDiff = difficulties.Max(); + + if (maxDiff == 0) + return 0; + if (skill <= 0) + return difficulties.Count; + + PoissonBinomial poiBin; + + if (difficulties.Count > 64) + { + var bins = Bin.CreateBins(difficulties); + poiBin = new PoissonBinomial(bins, skill, HitProbability); + } + else + { + poiBin = new PoissonBinomial(difficulties, skill, HitProbability); + } + + return Math.Max(0, RootFinding.FindRootExpand(x => poiBin.CDF(x) - FcProbability, -50, 1000, accuracy: 1e-4)); } } } diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs index 0fbc342916..7ed3595c35 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs @@ -12,25 +12,27 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils public double Difficulty; public double Count; + private const int bin_count = 32; + /// /// Create an array of equally spaced bins. Count is linearly interpolated into each bin. /// For example, if we have bins with values [1,2,3,4,5] and want to insert the value 3.2, /// we will add 0.8 to the count of 3's and 0.2 to the count of 4's /// - public static Bin[] CreateBins(List difficulties, int binCount) + public static Bin[] CreateBins(List difficulties) { double maxDifficulty = difficulties.Max(); - var bins = new Bin[binCount]; + var bins = new Bin[bin_count]; - for (int i = 0; i < binCount; i++) + for (int i = 0; i < bin_count; i++) { - bins[i].Difficulty = maxDifficulty * (i + 1) / binCount; + bins[i].Difficulty = maxDifficulty * (i + 1) / bin_count; } foreach (double d in difficulties) { - double binIndex = binCount * (d / maxDifficulty) - 1; + double binIndex = bin_count * (d / maxDifficulty) - 1; int lowerBound = (int)Math.Floor(binIndex); double t = binIndex - lowerBound; @@ -45,7 +47,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils int upperBound = lowerBound + 1; // this can be == bin_count for the maximum difficulty object, in which case t will be 0 anyway - if (upperBound < binCount) + if (upperBound < bin_count) { bins[upperBound].Count += t; } diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/FitMissCountPoints.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/FitMissCountPoints.cs new file mode 100644 index 0000000000..45dc5538e4 --- /dev/null +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/FitMissCountPoints.cs @@ -0,0 +1,49 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +using System.Linq; + +namespace osu.Game.Rulesets.Osu.Difficulty.Utils +{ + public class FitMissCountPoints + { + // A few operations are precomputed on the Vandermonde matrix of skill values (0, 0.05, 0.1, ..., 0.95, 1). + // This is to make regression simpler and less resource heavy. https://discord.com/channels/546120878908506119/1203154944492830791/1232333184306512002 + private static double[][] precomputedOperationsMatrix => new[] + { + new[] { 0.0, -6.99428, -9.87548, -9.76922, -7.66867, -4.43461, -0.795376, 2.65313, 5.4474, 7.2564, 7.88146, 7.2564, 5.4474, 2.65313, -0.795376, -4.43461, -7.66867, -9.76922, -9.87548, -6.99428, 0.0 }, + new[] { 0.0, 13.0907, 18.2388, 17.6639, 13.3211, 6.90022, -0.173479, -6.73969, -11.9029, -15.0326, -15.7629, -13.993, -9.88668, -3.87281, 3.35498, 10.8382, 17.3536, 21.4129, 21.2632, 14.8864, 0.0 }, + new[] { 0.0, -7.21754, -9.85841, -9.24217, -6.5276, -2.71265, 1.36553, 5.03057, 7.76692, 9.21984, 9.19538, 7.66039, 4.74253, 0.730255, -3.92717, -8.61967, -12.5764, -14.8657, -14.395, -9.91114, 0.0 } + }; + + public static double[] GetPolynomialCoefficients(double[] missCounts) + { + double[] adjustedMissCounts = missCounts; + + // The polynomial will pass through the point (1, maxMisscount). + double maxMissCount = missCounts.Max(); + + for (int i = 0; i <= 20; i++) + { + adjustedMissCounts[i] -= maxMissCount * i / 20; + } + + // The precomputed matrix assumes the misscounts go in order of greatest to least. + // Temporary fix. + adjustedMissCounts = adjustedMissCounts.Reverse().ToArray(); + + double[] coefficients = new double[3]; + + // Now we dot product the adjusted misscounts with the precomputed matrix. + for (int row = 0; row < precomputedOperationsMatrix.Length; row++) + { + for (int column = 0; column < precomputedOperationsMatrix[row].Length; column++) + { + coefficients[row] += precomputedOperationsMatrix[row][column] * adjustedMissCounts[column]; + } + } + + return coefficients; + } + } +} diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/PoissonBinomial.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/PoissonBinomial.cs new file mode 100644 index 0000000000..1b99307b5f --- /dev/null +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/PoissonBinomial.cs @@ -0,0 +1,119 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +using System; +using System.Collections.Generic; + +namespace osu.Game.Rulesets.Osu.Difficulty.Utils +{ + /// + /// Approximation of the Poisson binomial distribution: + /// https://en.wikipedia.org/wiki/Poisson_binomial_distribution + /// + /// + /// + /// For the approximation method, see "Refined Normal Approximation (RNA)" from: + /// Hong, Y. (2013). On computing the distribution function for the Poisson binomial distribution. Computational Statistics and Data Analysis, Vol. 59, pp. 41-51. + /// (https://www.researchgate.net/publication/257017356_On_computing_the_distribution_function_for_the_Poisson_binomial_distribution) + /// + /// + /// This has been verified against a reference implementation provided by the authors in the R package "poibin", + /// which can be viewed here: + /// https://rdrr.io/cran/poibin/man/poibin-package.html + /// + /// + public class PoissonBinomial + { + /// + /// The expected value of the distribution. + /// + private readonly double mu; + + /// + /// The standard deviation of the distribution. + /// + private readonly double sigma; + + /// + /// The gamma factor from equation (11) in the cited paper, pre-divided by 6 to save on re-computation. + /// + private readonly double v; + + /// + /// Creates a Poisson binomial distribution based on N trials with the provided difficulties, skill, and method for getting the miss probabilities. + /// + /// The list of difficulties in the map. + /// The skill level to get the miss probabilities with. + /// Converts difficulties and skill to miss probabilities. + public PoissonBinomial(IList difficulties, double skill, Func hitProbability) + { + double variance = 0; + double gamma = 0; + + foreach (double d in difficulties) + { + double p = 1 - hitProbability(skill, d); + + mu += p; + variance += p * (1 - p); + gamma += p * (1 - p) * (1 - 2 * p); + } + + sigma = Math.Sqrt(variance); + + v = gamma / (6 * Math.Pow(sigma, 3)); + } + + /// + /// Creates a Poisson binomial distribution based on N trials with the provided bins of difficulties, skill, and method for getting the miss probabilities. + /// + /// The bins of difficulties in the map. + /// The skill level to get the miss probabilities with. + /// /// Converts difficulties and skill to miss probabilities. + public PoissonBinomial(Bin[] bins, double skill, Func hitProbability) + { + double variance = 0; + double gamma = 0; + + foreach (Bin bin in bins) + { + double p = 1 - hitProbability(skill, bin.Difficulty); + + mu += p * bin.Count; + variance += p * (1 - p) * bin.Count; + gamma += p * (1 - p) * (1 - 2 * p) * bin.Count; + } + + sigma = Math.Sqrt(variance); + + v = gamma / (6 * Math.Pow(sigma, 3)); + } + + /// + /// Computes the value of the cumulative distribution function for this Poisson binomial distribution. + /// + /// + /// The argument of the CDF to sample the distribution for. + /// In the discrete case (when it is a whole number), this corresponds to the number + /// of successful Bernoulli trials to query the CDF for. + /// + /// + /// The value of the CDF at . + /// In the discrete case this corresponds to the probability that at most + /// Bernoulli trials ended in a success. + /// + // ReSharper disable once InconsistentNaming + public double CDF(double count) + { + if (sigma == 0) + return 1; + + double k = (count + 0.5 - mu) / sigma; + + // see equation (14) of the cited paper + double result = SpecialFunctions.NormalCdf(0, 1, k) + v * (1 - k * k) * SpecialFunctions.NormalPdf(0, 1, k); + + return Math.Clamp(result, 0, 1); + } + } +} diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/SpecialFunctions.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/SpecialFunctions.cs index 672f2396be..c926df8cbc 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Utils/SpecialFunctions.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/SpecialFunctions.cs @@ -18,6 +18,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils { public class SpecialFunctions { + private const double sqrt2 = 1.4142135623730950488016887242096980785696718753769d; private const double sqrt2_pi = 2.5066282746310005024157652848110452530069867406099d; /// @@ -690,5 +691,42 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils return sum; } + + /// + /// Computes the probability density of the distribution (PDF) at x, i.e. ∂P(X ≤ x)/∂x. + /// + /// The mean (μ) of the normal distribution. + /// The standard deviation (σ) of the normal distribution. Range: σ ≥ 0. + /// The location at which to compute the density. + /// the density at . + /// MATLAB: normpdf + public static double NormalPdf(double mean, double stddev, double x) + { + if (stddev < 0.0) + { + throw new ArgumentException("Invalid parametrization for the distribution."); + } + + double d = (x - mean) / stddev; + return Math.Exp(-0.5 * d * d) / (sqrt2_pi * stddev); + } + + /// + /// Computes the cumulative distribution (CDF) of the distribution at x, i.e. P(X ≤ x). + /// + /// The location at which to compute the cumulative distribution function. + /// The mean (μ) of the normal distribution. + /// The standard deviation (σ) of the normal distribution. Range: σ ≥ 0. + /// the cumulative distribution at location . + /// MATLAB: normcdf + public static double NormalCdf(double mean, double stddev, double x) + { + if (stddev < 0.0) + { + throw new ArgumentException("Invalid parametrization for the distribution."); + } + + return 0.5 * Erfc((mean - x) / (stddev * sqrt2)); + } } } From 85766894654dfcab0f7ded61bd7b6cce9a2425fd Mon Sep 17 00:00:00 2001 From: Nathen Date: Sat, 27 Apr 2024 15:35:03 -0400 Subject: [PATCH 08/14] Reduce polynomial to a cubic, move penalty slightly --- .../Difficulty/OsuDifficultyAttributes.cs | 2 +- .../Difficulty/OsuDifficultyCalculator.cs | 2 +- .../Difficulty/OsuPerformanceCalculator.cs | 27 +++++++++---------- .../Difficulty/Skills/Aim.cs | 4 +-- .../Difficulty/Utils/FitMissCountPoints.cs | 7 +++-- 5 files changed, 20 insertions(+), 22 deletions(-) diff --git a/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyAttributes.cs b/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyAttributes.cs index b37fa52a81..5bf7ff2de8 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyAttributes.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyAttributes.cs @@ -23,7 +23,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty /// The difficulty corresponding to the aim skill. /// [JsonProperty("aim_penalty_constants")] - public (double, double, double) AimPenaltyConstants { get; set; } + public (double, double) AimPenaltyConstants { get; set; } /// /// The difficulty corresponding to the speed skill. diff --git a/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs b/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs index b9fc4f0993..cecf1479e8 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs @@ -37,7 +37,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty return new OsuDifficultyAttributes { Mods = mods }; double aimRating = Math.Sqrt(skills[0].DifficultyValue()) * difficulty_multiplier; - (double, double, double) aimPenaltyConstants = ((Aim)skills[0]).GetMissCountCoefficients(); + (double, double) aimPenaltyConstants = ((Aim)skills[0]).GetMissCountCoefficients(); double aimRatingNoSliders = Math.Sqrt(skills[1].DifficultyValue()) * difficulty_multiplier; double speedRating = Math.Sqrt(skills[2].DifficultyValue()) * difficulty_multiplier; double speedNotes = ((Speed)skills[2]).RelevantNoteCount(); diff --git a/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs b/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs index 336e19495b..c6096823c1 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs @@ -66,13 +66,14 @@ namespace osu.Game.Rulesets.Osu.Difficulty double speedValue = computeSpeedValue(score, osuAttributes); double accuracyValue = computeAccuracyValue(score, osuAttributes); double flashlightValue = computeFlashlightValue(score, osuAttributes); - double totalValue = + double totalValue = aimValue; + /* Math.Pow( Math.Pow(aimValue, 1.1) + Math.Pow(speedValue, 1.1) + Math.Pow(accuracyValue, 1.1) + Math.Pow(flashlightValue, 1.1), 1.0 / 1.1 - ) * multiplier; + ) * multiplier; */ return new OsuPerformanceAttributes { @@ -92,18 +93,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty double lengthBonus = 0.95 + 0.4 * Math.Min(1.0, totalHits / 2000.0) + (totalHits > 2000 ? Math.Log10(totalHits / 2000.0) * 0.5 : 0.0); - // Penalize misses by assessing # of misses relative to the total # of objects. Default a 3% reduction for any # of misses. - if (effectiveMissCount > 0) - { - double a = attributes.AimPenaltyConstants.Item1; - double b = attributes.AimPenaltyConstants.Item2; - double c = attributes.AimPenaltyConstants.Item3; - double d = Math.Log(totalHits + 1) - a - b - c; - - double penalty = Math.Pow(1 - RootFinding.FindRootExpand(x => Math.Exp(a * x * x * x * x + b * x * x * x + c * x * x + d * x) - 1 - effectiveMissCount, 0, 1), 1.5); - - aimValue *= penalty; - } + aimValue *= calculateAimMissPenalty(effectiveMissCount, attributes); double approachRateFactor = 0.0; if (attributes.ApproachRate > 10.33) @@ -253,6 +243,15 @@ namespace osu.Game.Rulesets.Osu.Difficulty return flashlightValue; } + private double calculateAimMissPenalty(double missCount, OsuDifficultyAttributes attributes) + { + double a = attributes.AimPenaltyConstants.Item1; + double b = attributes.AimPenaltyConstants.Item2; + double c = Math.Log(totalHits + 1) - a - b; // Setting the 3rd constant this way ensures that at a penalty of 100%, the number of misses = totalHits. + + return Math.Pow(1 - RootFinding.FindRootExpand(x => a * x * x * x + b * x * x + c * x - Math.Log(missCount + 1), 0, 1), 1.5); + } + private double calculateEffectiveMissCount(OsuDifficultyAttributes attributes) { // Guess the number of misses + slider breaks from combo diff --git a/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs b/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs index 2a8df994b4..9cbb5497dc 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs @@ -51,7 +51,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills /// The coefficients of a quartic fitted to the miss counts at each skill level. /// /// The coefficients for ax^4+bx^3+cx^2. The 4th coefficient for dx^1 can be deduced from the first 3 in the performance calculator. - public (double, double, double) GetMissCountCoefficients() + public (double, double) GetMissCountCoefficients() { const int count = 21; const double penalty_per_misscount = 1.0 / (count - 1); @@ -76,7 +76,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills double[] constants = FitMissCountPoints.GetPolynomialCoefficients(misscounts); - return (constants[0], constants[1], constants[2]); + return (constants[0], constants[1]); } } } diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/FitMissCountPoints.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/FitMissCountPoints.cs index 45dc5538e4..6600c2df91 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Utils/FitMissCountPoints.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/FitMissCountPoints.cs @@ -11,9 +11,8 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils // This is to make regression simpler and less resource heavy. https://discord.com/channels/546120878908506119/1203154944492830791/1232333184306512002 private static double[][] precomputedOperationsMatrix => new[] { - new[] { 0.0, -6.99428, -9.87548, -9.76922, -7.66867, -4.43461, -0.795376, 2.65313, 5.4474, 7.2564, 7.88146, 7.2564, 5.4474, 2.65313, -0.795376, -4.43461, -7.66867, -9.76922, -9.87548, -6.99428, 0.0 }, - new[] { 0.0, 13.0907, 18.2388, 17.6639, 13.3211, 6.90022, -0.173479, -6.73969, -11.9029, -15.0326, -15.7629, -13.993, -9.88668, -3.87281, 3.35498, 10.8382, 17.3536, 21.4129, 21.2632, 14.8864, 0.0 }, - new[] { 0.0, -7.21754, -9.85841, -9.24217, -6.5276, -2.71265, 1.36553, 5.03057, 7.76692, 9.21984, 9.19538, 7.66039, 4.74253, 0.730255, -3.92717, -8.61967, -12.5764, -14.8657, -14.395, -9.91114, 0.0 } + new[] { 0.0, -0.897868, -1.5122, -1.8745, -2.01626, -1.96901, -1.76423, -1.43344, -1.00813, -0.519818, 3.55271e-15, 0.519818, 1.00813, 1.43344, 1.76423, 1.96901, 2.01626, 1.8745, 1.5122, 0.897868, 0.0 }, + new[] { 0.0, 1.27555, 2.1333, 2.62049, 2.78439, 2.67226, 2.33134, 1.8089, 1.1522, 0.408475, -0.375002, -1.15098, -1.8722, -2.49141, -2.96135, -3.23476, -3.2644, -3.00299, -2.4033, -1.41805, 0.0 }, }; public static double[] GetPolynomialCoefficients(double[] missCounts) @@ -32,7 +31,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils // Temporary fix. adjustedMissCounts = adjustedMissCounts.Reverse().ToArray(); - double[] coefficients = new double[3]; + double[] coefficients = new double[2]; // Now we dot product the adjusted misscounts with the precomputed matrix. for (int row = 0; row < precomputedOperationsMatrix.Length; row++) From dd951400a4d1908a6521d59e6a83649c00242df3 Mon Sep 17 00:00:00 2001 From: Nathen Date: Sat, 27 Apr 2024 15:35:49 -0400 Subject: [PATCH 09/14] Remove aimValue testing --- osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs b/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs index c6096823c1..ab32a2a388 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs @@ -66,14 +66,13 @@ namespace osu.Game.Rulesets.Osu.Difficulty double speedValue = computeSpeedValue(score, osuAttributes); double accuracyValue = computeAccuracyValue(score, osuAttributes); double flashlightValue = computeFlashlightValue(score, osuAttributes); - double totalValue = aimValue; - /* + double totalValue = Math.Pow( Math.Pow(aimValue, 1.1) + Math.Pow(speedValue, 1.1) + Math.Pow(accuracyValue, 1.1) + Math.Pow(flashlightValue, 1.1), 1.0 / 1.1 - ) * multiplier; */ + ) * multiplier; return new OsuPerformanceAttributes { From 0e08858b17fd9004dea2191e2e2e46db6367ca5c Mon Sep 17 00:00:00 2001 From: Nathen Date: Tue, 30 Apr 2024 12:44:45 -0400 Subject: [PATCH 10/14] Refactoring and test solving polynomial algebraically --- .../Difficulty/OsuDifficultyAttributes.cs | 3 +- .../Difficulty/OsuDifficultyCalculator.cs | 5 +- .../Difficulty/OsuPerformanceCalculator.cs | 9 +- .../Difficulty/Skills/Aim.cs | 8 +- .../Difficulty/Utils/ExpPolynomial.cs | 90 +++++++ .../Difficulty/Utils/FitMissCountPoints.cs | 48 ---- .../Difficulty/Utils/SpecialFunctions.cs | 240 ++++++++++++++++++ 7 files changed, 344 insertions(+), 59 deletions(-) create mode 100644 osu.Game.Rulesets.Osu/Difficulty/Utils/ExpPolynomial.cs delete mode 100644 osu.Game.Rulesets.Osu/Difficulty/Utils/FitMissCountPoints.cs diff --git a/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyAttributes.cs b/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyAttributes.cs index 5bf7ff2de8..672e6a933b 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyAttributes.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyAttributes.cs @@ -8,6 +8,7 @@ using Newtonsoft.Json; using osu.Game.Beatmaps; using osu.Game.Rulesets.Difficulty; using osu.Game.Rulesets.Mods; +using osu.Game.Rulesets.Osu.Difficulty.Utils; namespace osu.Game.Rulesets.Osu.Difficulty { @@ -23,7 +24,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty /// The difficulty corresponding to the aim skill. /// [JsonProperty("aim_penalty_constants")] - public (double, double) AimPenaltyConstants { get; set; } + public ExpPolynomial AimMissCountPolynomial { get; set; } /// /// The difficulty corresponding to the speed skill. diff --git a/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs b/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs index cecf1479e8..5c05045602 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs @@ -13,6 +13,7 @@ using osu.Game.Rulesets.Difficulty.Skills; using osu.Game.Rulesets.Mods; using osu.Game.Rulesets.Osu.Difficulty.Preprocessing; using osu.Game.Rulesets.Osu.Difficulty.Skills; +using osu.Game.Rulesets.Osu.Difficulty.Utils; using osu.Game.Rulesets.Osu.Mods; using osu.Game.Rulesets.Osu.Objects; using osu.Game.Rulesets.Osu.Scoring; @@ -37,7 +38,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty return new OsuDifficultyAttributes { Mods = mods }; double aimRating = Math.Sqrt(skills[0].DifficultyValue()) * difficulty_multiplier; - (double, double) aimPenaltyConstants = ((Aim)skills[0]).GetMissCountCoefficients(); + ExpPolynomial aimMissCountPolynomial = ((Aim)skills[0]).GetMissCountPolynomial(); double aimRatingNoSliders = Math.Sqrt(skills[1].DifficultyValue()) * difficulty_multiplier; double speedRating = Math.Sqrt(skills[2].DifficultyValue()) * difficulty_multiplier; double speedNotes = ((Speed)skills[2]).RelevantNoteCount(); @@ -98,7 +99,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty StarRating = starRating, Mods = mods, AimDifficulty = aimRating, - AimPenaltyConstants = aimPenaltyConstants, + AimMissCountPolynomial = aimMissCountPolynomial, SpeedDifficulty = speedRating, SpeedNoteCount = speedNotes, FlashlightDifficulty = flashlightRating, diff --git a/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs b/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs index ab32a2a388..1680885853 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs @@ -5,7 +5,6 @@ using System; using System.Collections.Generic; using System.Linq; using osu.Game.Rulesets.Difficulty; -using osu.Game.Rulesets.Osu.Difficulty.Utils; using osu.Game.Rulesets.Osu.Mods; using osu.Game.Rulesets.Scoring; using osu.Game.Scoring; @@ -244,11 +243,11 @@ namespace osu.Game.Rulesets.Osu.Difficulty private double calculateAimMissPenalty(double missCount, OsuDifficultyAttributes attributes) { - double a = attributes.AimPenaltyConstants.Item1; - double b = attributes.AimPenaltyConstants.Item2; - double c = Math.Log(totalHits + 1) - a - b; // Setting the 3rd constant this way ensures that at a penalty of 100%, the number of misses = totalHits. + double penalty = attributes.AimMissCountPolynomial.SolveBetweenZeroAndOne(missCount) ?? 1; - return Math.Pow(1 - RootFinding.FindRootExpand(x => a * x * x * x + b * x * x + c * x - Math.Log(missCount + 1), 0, 1), 1.5); + double multiplier = Math.Pow(1 - penalty, 1.5); + + return multiplier; } private double calculateEffectiveMissCount(OsuDifficultyAttributes attributes) diff --git a/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs b/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs index 9cbb5497dc..ba64121565 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs @@ -51,7 +51,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills /// The coefficients of a quartic fitted to the miss counts at each skill level. /// /// The coefficients for ax^4+bx^3+cx^2. The 4th coefficient for dx^1 can be deduced from the first 3 in the performance calculator. - public (double, double) GetMissCountCoefficients() + public ExpPolynomial GetMissCountPolynomial() { const int count = 21; const double penalty_per_misscount = 1.0 / (count - 1); @@ -74,9 +74,11 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills misscounts[i] = Math.Log(GetMissCountAtSkill(penalizedSkill) + 1); } - double[] constants = FitMissCountPoints.GetPolynomialCoefficients(misscounts); + ExpPolynomial polynomial = new ExpPolynomial(); - return (constants[0], constants[1]); + polynomial.Compute(misscounts, 3); + + return polynomial; } } } diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/ExpPolynomial.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/ExpPolynomial.cs new file mode 100644 index 0000000000..a29a636d30 --- /dev/null +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/ExpPolynomial.cs @@ -0,0 +1,90 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +using System; +using System.Collections.Generic; +using System.Linq; + +namespace osu.Game.Rulesets.Osu.Difficulty.Utils +{ + public struct ExpPolynomial + { + private static double[]? coefficients; + + // The product of this matrix with 21 computed points at X values [0.0, 0.05, ..., 0.95, 1.0] returns the least squares fit polynomial coefficients. + private static double[][] quarticMatrix => new[] + { + new[] { 0.0, -6.99428, -9.87548, -9.76922, -7.66867, -4.43461, -0.795376, 2.65313, 5.4474, 7.2564, 7.88146, 7.2564, 5.4474, 2.65313, -0.795376, -4.43461, -7.66867, -9.76922, -9.87548, -6.99428, 0.0 }, + new[] { 0.0, 13.0907, 18.2388, 17.6639, 13.3211, 6.90022, -0.173479, -6.73969, -11.9029, -15.0326, -15.7629, -13.993, -9.88668, -3.87281, 3.35498, 10.8382, 17.3536, 21.4129, 21.2632, 14.8864, 0.0 }, + new[] { 0.0, -7.21754, -9.85841, -9.24217, -6.5276, -2.71265, 1.36553, 5.03057, 7.76692, 9.21984, 9.19538, 7.66039, 4.74253, 0.730255, -3.92717, -8.61967, -12.5764, -14.8657, -14.395, -9.91114, 0.0 } + }; + + private static double[][] cubicMatrix => new[] + { + new[] { 0.0, -0.897868, -1.5122, -1.8745, -2.01626, -1.96901, -1.76423, -1.43344, -1.00813, -0.519818, 3.55271e-15, 0.519818, 1.00813, 1.43344, 1.76423, 1.96901, 2.01626, 1.8745, 1.5122, 0.897868, 0.0 }, + new[] { 0.0, 1.27555, 2.1333, 2.62049, 2.78439, 2.67226, 2.33134, 1.8089, 1.1522, 0.408475, -0.375002, -1.15098, -1.8722, -2.49141, -2.96135, -3.23476, -3.2644, -3.00299, -2.4033, -1.41805, 0.0 }, + }; + + /// + /// Computes a quartic or cubic function that starts at 0 and ends at the highest judgement count in the array. + /// + /// A list of judgements, with X values [0.0, 0.05, ..., 0.95, 1.0]. + /// The degree of the polynomial. Only supports cubic and quintic functions. + public void Compute(double[] judgementCounts, int degree) + { + if (degree != 3 && degree != 4) + return; + + double[] adjustedMissCounts = judgementCounts; + + // The polynomial will pass through the point (1, maxMisscount). + double maxMissCount = judgementCounts.Max(); + + for (int i = 0; i <= 20; i++) + { + adjustedMissCounts[i] -= maxMissCount * i / 20; + } + + // The precomputed matrix assumes the misscounts go in order of greatest to least. + // Temporary fix. + adjustedMissCounts = adjustedMissCounts.Reverse().ToArray(); + + double[][] matrix = degree == 4 ? quarticMatrix : cubicMatrix; + coefficients = new double[degree]; + + coefficients[degree - 1] = maxMissCount; + + // Now we dot product the adjusted misscounts with the precomputed matrix. + for (int row = 0; row < matrix.Length; row++) + { + for (int column = 0; column < matrix[row].Length; column++) + { + coefficients[row] += matrix[row][column] * adjustedMissCounts[column]; + } + + coefficients[degree - 1] -= coefficients[row]; + } + } + + /// + /// Solve for the largest corresponding x value of a polynomial within x = 0 and x = 1 at a specified y value. + /// + /// A value between 0 and 1, inclusive, to solve the polynomial at. + /// The x value at the specified y value, and null if no value exists. + public double? SolveBetweenZeroAndOne(double y) + { + if (coefficients is null) + return null; + + List listCoefficients = coefficients.ToList(); + listCoefficients.Add(-Math.Log(y + 1)); + + List xVals = SpecialFunctions.SolvePolynomialRoots(listCoefficients); + + const double max_error = 1e-7; + double? largestValue = xVals.Where(x => x >= 0 - max_error && x <= 1 + max_error).OrderDescending().FirstOrDefault(); + + return largestValue != null ? Math.Clamp(largestValue.Value, 0, 1) : null; + } + } +} diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/FitMissCountPoints.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/FitMissCountPoints.cs deleted file mode 100644 index 6600c2df91..0000000000 --- a/osu.Game.Rulesets.Osu/Difficulty/Utils/FitMissCountPoints.cs +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. -// See the LICENCE file in the repository root for full licence text. - -using System.Linq; - -namespace osu.Game.Rulesets.Osu.Difficulty.Utils -{ - public class FitMissCountPoints - { - // A few operations are precomputed on the Vandermonde matrix of skill values (0, 0.05, 0.1, ..., 0.95, 1). - // This is to make regression simpler and less resource heavy. https://discord.com/channels/546120878908506119/1203154944492830791/1232333184306512002 - private static double[][] precomputedOperationsMatrix => new[] - { - new[] { 0.0, -0.897868, -1.5122, -1.8745, -2.01626, -1.96901, -1.76423, -1.43344, -1.00813, -0.519818, 3.55271e-15, 0.519818, 1.00813, 1.43344, 1.76423, 1.96901, 2.01626, 1.8745, 1.5122, 0.897868, 0.0 }, - new[] { 0.0, 1.27555, 2.1333, 2.62049, 2.78439, 2.67226, 2.33134, 1.8089, 1.1522, 0.408475, -0.375002, -1.15098, -1.8722, -2.49141, -2.96135, -3.23476, -3.2644, -3.00299, -2.4033, -1.41805, 0.0 }, - }; - - public static double[] GetPolynomialCoefficients(double[] missCounts) - { - double[] adjustedMissCounts = missCounts; - - // The polynomial will pass through the point (1, maxMisscount). - double maxMissCount = missCounts.Max(); - - for (int i = 0; i <= 20; i++) - { - adjustedMissCounts[i] -= maxMissCount * i / 20; - } - - // The precomputed matrix assumes the misscounts go in order of greatest to least. - // Temporary fix. - adjustedMissCounts = adjustedMissCounts.Reverse().ToArray(); - - double[] coefficients = new double[2]; - - // Now we dot product the adjusted misscounts with the precomputed matrix. - for (int row = 0; row < precomputedOperationsMatrix.Length; row++) - { - for (int column = 0; column < precomputedOperationsMatrix[row].Length; column++) - { - coefficients[row] += precomputedOperationsMatrix[row][column] * adjustedMissCounts[column]; - } - } - - return coefficients; - } - } -} diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/SpecialFunctions.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/SpecialFunctions.cs index c926df8cbc..c183dbecb3 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Utils/SpecialFunctions.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/SpecialFunctions.cs @@ -13,6 +13,9 @@ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLI */ using System; +using System.Collections.Generic; +using System.Linq; +using osu.Framework.Utils; namespace osu.Game.Rulesets.Osu.Difficulty.Utils { @@ -20,6 +23,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils { private const double sqrt2 = 1.4142135623730950488016887242096980785696718753769d; private const double sqrt2_pi = 2.5066282746310005024157652848110452530069867406099d; + private const double m_2_pi = 6.28318530717958647692528676655900576d; /// /// ************************************** @@ -728,5 +732,241 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils return 0.5 * Erfc((mean - x) / (stddev * sqrt2)); } + + /// + /// Solve for the exact real roots of any polynomial up to degree 4. + /// + /// The coefficients of the polynomial, in ascending order ([1, 3, 5] -> x^2 + 3x + 5). + /// The real roots of the polynomial, and null if the root does not exist. + public static List SolvePolynomialRoots(List coefficients) + { + List xVals = new List(); + + switch (coefficients.Count) + { + case 5: + xVals = solveP4(coefficients[0], coefficients[1], coefficients[2], coefficients[3], coefficients[4], out int _).ToList(); + break; + + case 4: + xVals = solveP3(coefficients[0], coefficients[1], coefficients[2], coefficients[3], out int _).ToList(); + break; + + case 3: + xVals = solveP2(coefficients[0], coefficients[1], coefficients[2], out int _).ToList(); + break; + + case 2: + xVals = solveP2(0, coefficients[1], coefficients[2], out int _).ToList(); + break; + } + + return xVals; + } + + private static double?[] solveP4(double a, double b, double c, double d, double e, out int nRoots) + { + double?[] xVals = new double?[4]; + + nRoots = 0; + + if (a == 0) + { + double?[] xValsCubic = solveP3(b, c, d, e, out nRoots); + + xVals[0] = xValsCubic[0]; + xVals[1] = xValsCubic[1]; + xVals[2] = xValsCubic[2]; + xVals[3] = null; + + return xVals; + } + + b /= a; + c /= a; + d /= a; + + double a3 = -c; + double b3 = b * d - 4 * e; + double c3 = -b * b * e - d * d + 4 * c * e; + + double?[] x3 = solveP3(1, a3, b3, c3, out int iZeroes); + + double q1, q2, p1, p2, sqD; + + double y = x3[0]!.Value; + + // Get the y value with the highest absolute value. + if (iZeroes != 1) + { + if (Math.Abs(x3[1]!.Value) > Math.Abs(y)) + y = x3[1]!.Value; + if (Math.Abs(x3[2]!.Value) > Math.Abs(y)) + y = x3[2]!.Value; + } + + double upperD = y * y - 4 * e; + + if (Precision.AlmostEquals(upperD, 0)) + { + q1 = q2 = y * 0.5; + + upperD = b * b - 4 * (c - y); + + if (Precision.AlmostEquals(upperD, 0)) + p1 = p2 = b * 0.5; + + else + { + sqD = Math.Sqrt(upperD); + p1 = (b + sqD) * 0.5; + p2 = (b - sqD) * 0.5; + } + } + else + { + sqD = Math.Sqrt(upperD); + q1 = (y + sqD) * 0.5; + q2 = (y - sqD) * 0.5; + + p1 = (b * q1 - c) / (q1 - q2); + p2 = (d - b * q2) / (q1 - q2); + } + + // solving quadratic eq. - x^2 + p1*x + q1 = 0 + upperD = p1 * p1 - 4 * q1; + + if (upperD >= 0) + { + nRoots += 2; + + sqD = Math.Sqrt(upperD); + xVals[0] = (-p1 + sqD) * 0.5; + xVals[1] = (-p1 - sqD) * 0.5; + } + + // solving quadratic eq. - x^2 + p2*x + q2 = 0 + upperD = p2 * p2 - 4 * q2; + + if (upperD >= 0) + { + nRoots += 2; + + sqD = Math.Sqrt(upperD); + xVals[2] = (-p2 + sqD) * 0.5; + xVals[3] = (-p2 - sqD) * 0.5; + } + + // Put the null roots at the end of the array. + var nonNulls = xVals.Where(x => x != null); + var nulls = xVals.Where(x => x == null); + xVals = nonNulls.Concat(nulls).ToArray(); + + return xVals; + } + + private static double?[] solveP3(double a, double b, double c, double d, out int nRoots) + { + double?[] xVals = new double?[3]; + + nRoots = 0; + + if (a == 0) + { + double?[] xValsQuadratic = solveP2(b, c, d, out nRoots); + + xVals[0] = xValsQuadratic[0]; + xVals[1] = xValsQuadratic[1]; + xVals[2] = null; + + return xVals; + } + + b /= a; + c /= a; + d /= a; + + double b2 = b * b; + double q = (b2 - 3 * c) / 9; + double q3 = q * q * q; + double r = (b * (2 * b2 - 9 * c) + 27 * d) / 54; + double r2 = r * r; + + if (r2 < q3) + { + nRoots = 3; + + double t = r / Math.Sqrt(q3); + t = Math.Clamp(t, -1, 1); + t = Math.Acos(t); + b /= 3; + q = -2 * Math.Sqrt(2); + + xVals[0] = q * Math.Cos(t / 3) - b; + xVals[1] = q * Math.Cos((t + m_2_pi) / 3) - b; + xVals[2] = q * Math.Cos((t - m_2_pi) / 3) - b; + + return xVals; + } + + double upperA = -Math.Cbrt(Math.Abs(r) + double.Sqrt(r2 - q3)); + + if (r < 0) + upperA = -upperA; + + double upperB = upperA == 0 ? 0 : q / upperA; + b /= 3; + + xVals[0] = upperA + upperB - b; + + if (Precision.AlmostEquals(0.5 * Math.Sqrt(3) * (upperA - upperB), 0)) + { + nRoots = 2; + xVals[1] = -0.5 * (upperA + upperB) - b; + + return xVals; + } + + nRoots = 1; + + return xVals; + } + + private static double?[] solveP2(double a, double b, double c, out int nRoots) + { + double?[] xVals = new double?[2]; + + nRoots = 0; + + if (a == 0) + { + if (b == 0) + return xVals; + + nRoots = 1; + xVals[0] = -c / b; + } + + double discriminant = b * b - 4 * a * c; + + switch (discriminant) + { + case < 0: + break; + + case 0: + nRoots = 1; + xVals[0] = -b / (2 * a); + break; + + default: + nRoots = 2; + xVals[0] = (-b + Math.Sqrt(discriminant)) / (2 * a); + xVals[1] = (-b - Math.Sqrt(discriminant)) / (2 * a); + break; + } + + return xVals; + } } } From 52984affc5c9f86f1c7336301123db0d2c36794d Mon Sep 17 00:00:00 2001 From: Nathen Date: Tue, 30 Apr 2024 17:59:24 -0400 Subject: [PATCH 11/14] Bug fixes and a bit of refactorage --- .../Difficulty/Skills/Aim.cs | 3 +-- .../Difficulty/Utils/ExpPolynomial.cs | 22 +++++++++---------- .../Difficulty/Utils/SpecialFunctions.cs | 2 +- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs b/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs index ba64121565..3e27c06292 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs @@ -70,8 +70,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills double penalizedSkill = fcSkill - fcSkill * penalty_per_misscount * i; - // Save misscounts as log form to give higher weight to lower values. Add 1 so that the lowest misscounts remain above 0. - misscounts[i] = Math.Log(GetMissCountAtSkill(penalizedSkill) + 1); + misscounts[i] = GetMissCountAtSkill(penalizedSkill); } ExpPolynomial polynomial = new ExpPolynomial(); diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/ExpPolynomial.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/ExpPolynomial.cs index a29a636d30..f88243da67 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Utils/ExpPolynomial.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/ExpPolynomial.cs @@ -9,17 +9,17 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils { public struct ExpPolynomial { - private static double[]? coefficients; + private double[]? coefficients; // The product of this matrix with 21 computed points at X values [0.0, 0.05, ..., 0.95, 1.0] returns the least squares fit polynomial coefficients. - private static double[][] quarticMatrix => new[] + private static readonly double[][] quartic_matrix = { new[] { 0.0, -6.99428, -9.87548, -9.76922, -7.66867, -4.43461, -0.795376, 2.65313, 5.4474, 7.2564, 7.88146, 7.2564, 5.4474, 2.65313, -0.795376, -4.43461, -7.66867, -9.76922, -9.87548, -6.99428, 0.0 }, new[] { 0.0, 13.0907, 18.2388, 17.6639, 13.3211, 6.90022, -0.173479, -6.73969, -11.9029, -15.0326, -15.7629, -13.993, -9.88668, -3.87281, 3.35498, 10.8382, 17.3536, 21.4129, 21.2632, 14.8864, 0.0 }, new[] { 0.0, -7.21754, -9.85841, -9.24217, -6.5276, -2.71265, 1.36553, 5.03057, 7.76692, 9.21984, 9.19538, 7.66039, 4.74253, 0.730255, -3.92717, -8.61967, -12.5764, -14.8657, -14.395, -9.91114, 0.0 } }; - private static double[][] cubicMatrix => new[] + private static readonly double[][] cubic_matrix = { new[] { 0.0, -0.897868, -1.5122, -1.8745, -2.01626, -1.96901, -1.76423, -1.43344, -1.00813, -0.519818, 3.55271e-15, 0.519818, 1.00813, 1.43344, 1.76423, 1.96901, 2.01626, 1.8745, 1.5122, 0.897868, 0.0 }, new[] { 0.0, 1.27555, 2.1333, 2.62049, 2.78439, 2.67226, 2.33134, 1.8089, 1.1522, 0.408475, -0.375002, -1.15098, -1.8722, -2.49141, -2.96135, -3.23476, -3.2644, -3.00299, -2.4033, -1.41805, 0.0 }, @@ -35,31 +35,31 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils if (degree != 3 && degree != 4) return; - double[] adjustedMissCounts = judgementCounts; + List logJudgementCounts = judgementCounts.Select(x => Math.Log(x + 1)).ToList(); - // The polynomial will pass through the point (1, maxMisscount). - double maxMissCount = judgementCounts.Max(); + // The polynomial will pass through the point (1, endPoint). + double endPoint = logJudgementCounts.Max(); for (int i = 0; i <= 20; i++) { - adjustedMissCounts[i] -= maxMissCount * i / 20; + logJudgementCounts[i] -= endPoint * i / 20; } // The precomputed matrix assumes the misscounts go in order of greatest to least. // Temporary fix. - adjustedMissCounts = adjustedMissCounts.Reverse().ToArray(); + logJudgementCounts.Reverse(); - double[][] matrix = degree == 4 ? quarticMatrix : cubicMatrix; + double[][] matrix = degree == 4 ? quartic_matrix : cubic_matrix; coefficients = new double[degree]; - coefficients[degree - 1] = maxMissCount; + coefficients[degree - 1] = endPoint; // Now we dot product the adjusted misscounts with the precomputed matrix. for (int row = 0; row < matrix.Length; row++) { for (int column = 0; column < matrix[row].Length; column++) { - coefficients[row] += matrix[row][column] * adjustedMissCounts[column]; + coefficients[row] += matrix[row][column] * logJudgementCounts[column]; } coefficients[degree - 1] -= coefficients[row]; diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/SpecialFunctions.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/SpecialFunctions.cs index c183dbecb3..c8853ab8a5 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Utils/SpecialFunctions.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/SpecialFunctions.cs @@ -900,7 +900,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils t = Math.Clamp(t, -1, 1); t = Math.Acos(t); b /= 3; - q = -2 * Math.Sqrt(2); + q = -2 * Math.Sqrt(q); xVals[0] = q * Math.Cos(t / 3) - b; xVals[1] = q * Math.Cos((t + m_2_pi) / 3) - b; From 7bd4c3d3134c50d33e4c32faa4895feea6d1d7a6 Mon Sep 17 00:00:00 2001 From: Nathen Date: Sat, 12 Oct 2024 22:34:53 -0400 Subject: [PATCH 12/14] Clean everything up --- .../{Skills => Aggregation}/OsuProbSkill.cs | 53 +++++++++++++----- .../{Skills => Aggregation}/OsuStrainSkill.cs | 2 +- .../Difficulty/OsuDifficultyAttributes.cs | 2 +- .../Difficulty/OsuDifficultyCalculator.cs | 7 +-- .../Difficulty/OsuPerformanceCalculator.cs | 3 +- .../Difficulty/Skills/Aim.cs | 34 +----------- .../Difficulty/Skills/Speed.cs | 1 + .../Difficulty/Utils/ExpPolynomial.cs | 55 +++++++------------ 8 files changed, 70 insertions(+), 87 deletions(-) rename osu.Game.Rulesets.Osu/Difficulty/{Skills => Aggregation}/OsuProbSkill.cs (71%) rename osu.Game.Rulesets.Osu/Difficulty/{Skills => Aggregation}/OsuStrainSkill.cs (98%) diff --git a/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs b/osu.Game.Rulesets.Osu/Difficulty/Aggregation/OsuProbSkill.cs similarity index 71% rename from osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs rename to osu.Game.Rulesets.Osu/Difficulty/Aggregation/OsuProbSkill.cs index 0fece3069f..25bccf869b 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuProbSkill.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Aggregation/OsuProbSkill.cs @@ -9,7 +9,7 @@ using osu.Game.Rulesets.Difficulty.Skills; using osu.Game.Rulesets.Mods; using osu.Game.Rulesets.Osu.Difficulty.Utils; -namespace osu.Game.Rulesets.Osu.Difficulty.Skills +namespace osu.Game.Rulesets.Osu.Difficulty.Aggregation { public abstract class OsuProbSkill : Skill { @@ -94,10 +94,47 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills return difficulties.Count < 64 ? difficultyValueExact() : difficultyValueBinned(); } + /// + /// The coefficients of a quartic fitted to the miss counts at each skill level. + /// + /// The coefficients for ax^4+bx^3+cx^2. The 4th coefficient for dx^1 can be deduced from the first 3 in the performance calculator. + public ExpPolynomial GetMissPenaltyCurve() + { + double[] missCounts = new double[7]; + double[] penalties = { 1, 0.95, 0.9, 0.8, 0.6, 0.3, 0 }; + + double fcSkill = DifficultyValue(); + + ExpPolynomial curve = new ExpPolynomial(); + + // If there are no notes, we just return the empty polynomial. + if (difficulties.Count == 0 || difficulties.Max() == 0) + return curve; + + var bins = Bin.CreateBins(difficulties); + + for (int i = 0; i < penalties.Length; i++) + { + if (i == 0) + { + missCounts[i] = 0; + continue; + } + + double penalizedSkill = fcSkill * penalties[i]; + + missCounts[i] = getMissCountAtSkill(penalizedSkill, bins); + } + + curve.Fit(missCounts); + + return curve; + } + /// /// Find the lowest misscount that a player with the provided would have a 2% chance of achieving. /// - public double GetMissCountAtSkill(double skill) + private double getMissCountAtSkill(double skill, Bin[] bins) { double maxDiff = difficulties.Max(); @@ -106,17 +143,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills if (skill <= 0) return difficulties.Count; - PoissonBinomial poiBin; - - if (difficulties.Count > 64) - { - var bins = Bin.CreateBins(difficulties); - poiBin = new PoissonBinomial(bins, skill, HitProbability); - } - else - { - poiBin = new PoissonBinomial(difficulties, skill, HitProbability); - } + var poiBin = difficulties.Count > 64 ? new PoissonBinomial(bins, skill, HitProbability) : new PoissonBinomial(difficulties, skill, HitProbability); return Math.Max(0, RootFinding.FindRootExpand(x => poiBin.CDF(x) - FcProbability, -50, 1000, accuracy: 1e-4)); } diff --git a/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuStrainSkill.cs b/osu.Game.Rulesets.Osu/Difficulty/Aggregation/OsuStrainSkill.cs similarity index 98% rename from osu.Game.Rulesets.Osu/Difficulty/Skills/OsuStrainSkill.cs rename to osu.Game.Rulesets.Osu/Difficulty/Aggregation/OsuStrainSkill.cs index 96180c0aa1..2754ec9718 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Skills/OsuStrainSkill.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Aggregation/OsuStrainSkill.cs @@ -8,7 +8,7 @@ using osu.Game.Rulesets.Mods; using System.Linq; using osu.Framework.Utils; -namespace osu.Game.Rulesets.Osu.Difficulty.Skills +namespace osu.Game.Rulesets.Osu.Difficulty.Aggregation { public abstract class OsuStrainSkill : StrainSkill { diff --git a/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyAttributes.cs b/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyAttributes.cs index e970ab1461..e2c6b8fe99 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyAttributes.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyAttributes.cs @@ -24,7 +24,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty /// The difficulty corresponding to the aim skill. /// [JsonProperty("aim_penalty_constants")] - public ExpPolynomial AimMissCountPolynomial { get; set; } + public ExpPolynomial AimMissPenaltyCurve { get; set; } /// /// The difficulty corresponding to the speed skill. diff --git a/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs b/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs index d991b29bcf..64f75b9319 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs @@ -11,6 +11,7 @@ using osu.Game.Rulesets.Difficulty; using osu.Game.Rulesets.Difficulty.Preprocessing; using osu.Game.Rulesets.Difficulty.Skills; using osu.Game.Rulesets.Mods; +using osu.Game.Rulesets.Osu.Difficulty.Aggregation; using osu.Game.Rulesets.Osu.Difficulty.Preprocessing; using osu.Game.Rulesets.Osu.Difficulty.Skills; using osu.Game.Rulesets.Osu.Difficulty.Utils; @@ -38,7 +39,6 @@ namespace osu.Game.Rulesets.Osu.Difficulty return new OsuDifficultyAttributes { Mods = mods }; double aimRating = Math.Sqrt(skills[0].DifficultyValue()) * difficulty_multiplier; - ExpPolynomial aimMissCountPolynomial = ((Aim)skills[0]).GetMissCountPolynomial(); double aimRatingNoSliders = Math.Sqrt(skills[1].DifficultyValue()) * difficulty_multiplier; double speedRating = Math.Sqrt(skills[2].DifficultyValue()) * difficulty_multiplier; double speedNotes = ((Speed)skills[2]).RelevantNoteCount(); @@ -50,7 +50,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty double sliderFactor = aimRating > 0 ? aimRatingNoSliders / aimRating : 1; - double aimDifficultyStrainCount = ((OsuStrainSkill)skills[0]).CountDifficultStrains(); + ExpPolynomial aimMissPenaltyCurve = ((Aim)skills[0]).GetMissPenaltyCurve(); double speedDifficultyStrainCount = ((OsuStrainSkill)skills[2]).CountDifficultStrains(); if (mods.Any(m => m is OsuModTouchDevice)) @@ -101,12 +101,11 @@ namespace osu.Game.Rulesets.Osu.Difficulty StarRating = starRating, Mods = mods, AimDifficulty = aimRating, - AimMissCountPolynomial = aimMissCountPolynomial, SpeedDifficulty = speedRating, SpeedNoteCount = speedNotes, FlashlightDifficulty = flashlightRating, SliderFactor = sliderFactor, - AimDifficultStrainCount = aimDifficultyStrainCount, + AimMissPenaltyCurve = aimMissPenaltyCurve, SpeedDifficultStrainCount = speedDifficultyStrainCount, ApproachRate = preempt > 1200 ? (1800 - preempt) / 120 : (1200 - preempt) / 150 + 5, OverallDifficulty = (80 - hitWindowGreat) / 6, diff --git a/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs b/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs index 4733eefbce..17a9abecce 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Linq; using osu.Game.Rulesets.Difficulty; +using osu.Game.Rulesets.Osu.Difficulty.Aggregation; using osu.Game.Rulesets.Osu.Difficulty.Skills; using osu.Game.Rulesets.Osu.Mods; using osu.Game.Rulesets.Scoring; @@ -247,7 +248,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty private double calculateAimMissPenalty(double missCount, OsuDifficultyAttributes attributes) { - double penalty = attributes.AimMissCountPolynomial.SolveBetweenZeroAndOne(missCount) ?? 1; + double penalty = attributes.AimMissPenaltyCurve.GetPenaltyAt(missCount); double multiplier = Math.Pow(1 - penalty, 1.5); diff --git a/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs b/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs index 1a966b14d7..5c5852c4dd 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs @@ -4,6 +4,7 @@ using System; using osu.Game.Rulesets.Difficulty.Preprocessing; using osu.Game.Rulesets.Mods; +using osu.Game.Rulesets.Osu.Difficulty.Aggregation; using osu.Game.Rulesets.Osu.Difficulty.Evaluators; using osu.Game.Rulesets.Osu.Difficulty.Utils; @@ -46,38 +47,5 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills return currentStrain; } - - /// - /// The coefficients of a quartic fitted to the miss counts at each skill level. - /// - /// The coefficients for ax^4+bx^3+cx^2. The 4th coefficient for dx^1 can be deduced from the first 3 in the performance calculator. - public ExpPolynomial GetMissCountPolynomial() - { - const int count = 21; - const double penalty_per_misscount = 1.0 / (count - 1); - - double fcSkill = DifficultyValue(); - - double[] misscounts = new double[count]; - - for (int i = 0; i < count; i++) - { - if (i == 0) - { - misscounts[i] = 0; - continue; - } - - double penalizedSkill = fcSkill - fcSkill * penalty_per_misscount * i; - - misscounts[i] = GetMissCountAtSkill(penalizedSkill); - } - - ExpPolynomial polynomial = new ExpPolynomial(); - - polynomial.Compute(misscounts, 3); - - return polynomial; - } } } diff --git a/osu.Game.Rulesets.Osu/Difficulty/Skills/Speed.cs b/osu.Game.Rulesets.Osu/Difficulty/Skills/Speed.cs index e5aa25c1eb..88230c54ff 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Skills/Speed.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Skills/Speed.cs @@ -7,6 +7,7 @@ using osu.Game.Rulesets.Mods; using osu.Game.Rulesets.Osu.Difficulty.Evaluators; using osu.Game.Rulesets.Osu.Difficulty.Preprocessing; using System.Linq; +using osu.Game.Rulesets.Osu.Difficulty.Aggregation; namespace osu.Game.Rulesets.Osu.Difficulty.Skills { diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/ExpPolynomial.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/ExpPolynomial.cs index f88243da67..1e814556ed 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Utils/ExpPolynomial.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/ExpPolynomial.cs @@ -11,80 +11,67 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils { private double[]? coefficients; - // The product of this matrix with 21 computed points at X values [0.0, 0.05, ..., 0.95, 1.0] returns the least squares fit polynomial coefficients. - private static readonly double[][] quartic_matrix = + private static readonly double[][] matrix = { - new[] { 0.0, -6.99428, -9.87548, -9.76922, -7.66867, -4.43461, -0.795376, 2.65313, 5.4474, 7.2564, 7.88146, 7.2564, 5.4474, 2.65313, -0.795376, -4.43461, -7.66867, -9.76922, -9.87548, -6.99428, 0.0 }, - new[] { 0.0, 13.0907, 18.2388, 17.6639, 13.3211, 6.90022, -0.173479, -6.73969, -11.9029, -15.0326, -15.7629, -13.993, -9.88668, -3.87281, 3.35498, 10.8382, 17.3536, 21.4129, 21.2632, 14.8864, 0.0 }, - new[] { 0.0, -7.21754, -9.85841, -9.24217, -6.5276, -2.71265, 1.36553, 5.03057, 7.76692, 9.21984, 9.19538, 7.66039, 4.74253, 0.730255, -3.92717, -8.61967, -12.5764, -14.8657, -14.395, -9.91114, 0.0 } - }; - - private static readonly double[][] cubic_matrix = - { - new[] { 0.0, -0.897868, -1.5122, -1.8745, -2.01626, -1.96901, -1.76423, -1.43344, -1.00813, -0.519818, 3.55271e-15, 0.519818, 1.00813, 1.43344, 1.76423, 1.96901, 2.01626, 1.8745, 1.5122, 0.897868, 0.0 }, - new[] { 0.0, 1.27555, 2.1333, 2.62049, 2.78439, 2.67226, 2.33134, 1.8089, 1.1522, 0.408475, -0.375002, -1.15098, -1.8722, -2.49141, -2.96135, -3.23476, -3.2644, -3.00299, -2.4033, -1.41805, 0.0 }, + new[] { 0.0, 3.14395, 5.18439, 6.46975, 1.4638, -9.53526, 0.0 }, + new[] { 0.0, -4.85829, -8.09612, -10.4498, -3.84479, 12.1626, 0.0 } }; /// /// Computes a quartic or cubic function that starts at 0 and ends at the highest judgement count in the array. /// /// A list of judgements, with X values [0.0, 0.05, ..., 0.95, 1.0]. - /// The degree of the polynomial. Only supports cubic and quintic functions. - public void Compute(double[] judgementCounts, int degree) + public void Fit(double[] judgementCounts) { - if (degree != 3 && degree != 4) - return; - - List logJudgementCounts = judgementCounts.Select(x => Math.Log(x + 1)).ToList(); + List logMissCounts = judgementCounts.Select(x => Math.Log(x + 1)).ToList(); // The polynomial will pass through the point (1, endPoint). - double endPoint = logJudgementCounts.Max(); + double endPoint = logMissCounts.Max(); - for (int i = 0; i <= 20; i++) + double[] penalties = { 1, 0.95, 0.9, 0.8, 0.6, 0.3, 0 }; + + for (int i = 0; i < logMissCounts.Count; i++) { - logJudgementCounts[i] -= endPoint * i / 20; + logMissCounts[i] -= endPoint * (1 - penalties[i]); } // The precomputed matrix assumes the misscounts go in order of greatest to least. - // Temporary fix. - logJudgementCounts.Reverse(); + logMissCounts.Reverse(); - double[][] matrix = degree == 4 ? quartic_matrix : cubic_matrix; - coefficients = new double[degree]; + coefficients = new double[3]; - coefficients[degree - 1] = endPoint; + coefficients[2] = endPoint; // Now we dot product the adjusted misscounts with the precomputed matrix. for (int row = 0; row < matrix.Length; row++) { for (int column = 0; column < matrix[row].Length; column++) { - coefficients[row] += matrix[row][column] * logJudgementCounts[column]; + coefficients[row] += matrix[row][column] * logMissCounts[column]; } - coefficients[degree - 1] -= coefficients[row]; + coefficients[2] -= coefficients[row]; } } /// - /// Solve for the largest corresponding x value of a polynomial within x = 0 and x = 1 at a specified y value. + /// Solve for the miss penalty at a specified miss count. /// - /// A value between 0 and 1, inclusive, to solve the polynomial at. - /// The x value at the specified y value, and null if no value exists. - public double? SolveBetweenZeroAndOne(double y) + /// The penalty value at the specified miss count. + public double GetPenaltyAt(double missCount) { if (coefficients is null) - return null; + return 1; List listCoefficients = coefficients.ToList(); - listCoefficients.Add(-Math.Log(y + 1)); + listCoefficients.Add(-Math.Log(missCount + 1)); List xVals = SpecialFunctions.SolvePolynomialRoots(listCoefficients); const double max_error = 1e-7; double? largestValue = xVals.Where(x => x >= 0 - max_error && x <= 1 + max_error).OrderDescending().FirstOrDefault(); - return largestValue != null ? Math.Clamp(largestValue.Value, 0, 1) : null; + return largestValue != null ? Math.Clamp(largestValue.Value, 0, 1) : 1; } } } From 45174a417039dc933e870a93e206ff3c66297e78 Mon Sep 17 00:00:00 2001 From: Nathen Date: Sat, 30 Nov 2024 18:56:57 -0500 Subject: [PATCH 13/14] General cleanup, fixes, and renames --- ...OsuProbSkill.cs => OsuProbabilitySkill.cs} | 110 +- .../Difficulty/OsuDifficultyCalculator.cs | 2 +- .../Difficulty/OsuPerformanceCalculator.cs | 14 +- .../Difficulty/Skills/Aim.cs | 10 +- osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs | 27 +- .../Difficulty/Utils/ExpPolynomial.cs | 33 +- .../Difficulty/Utils/PoissonBinomial.cs | 3 +- .../Difficulty/Utils/RootFinding.cs | 20 +- .../Difficulty/Utils/SpecialFunctions.cs | 972 ------------------ osu.Game/Utils/SpecialFunctions.cs | 280 +++++ 10 files changed, 407 insertions(+), 1064 deletions(-) rename osu.Game.Rulesets.Osu/Difficulty/Aggregation/{OsuProbSkill.cs => OsuProbabilitySkill.cs} (64%) delete mode 100644 osu.Game.Rulesets.Osu/Difficulty/Utils/SpecialFunctions.cs diff --git a/osu.Game.Rulesets.Osu/Difficulty/Aggregation/OsuProbSkill.cs b/osu.Game.Rulesets.Osu/Difficulty/Aggregation/OsuProbabilitySkill.cs similarity index 64% rename from osu.Game.Rulesets.Osu/Difficulty/Aggregation/OsuProbSkill.cs rename to osu.Game.Rulesets.Osu/Difficulty/Aggregation/OsuProbabilitySkill.cs index 25bccf869b..dde817a2e1 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Aggregation/OsuProbSkill.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Aggregation/OsuProbabilitySkill.cs @@ -11,16 +11,22 @@ using osu.Game.Rulesets.Osu.Difficulty.Utils; namespace osu.Game.Rulesets.Osu.Difficulty.Aggregation { - public abstract class OsuProbSkill : Skill + public abstract class OsuProbabilitySkill : Skill { - protected OsuProbSkill(Mod[] mods) + protected OsuProbabilitySkill(Mod[] mods) : base(mods) { } - /// The skill level returned from this class will have FcProbability chance of hitting every note correctly. - /// A higher value rewards short, high difficulty sections, whereas a lower value rewards consistent, lower difficulty. - protected abstract double FcProbability { get; } + // We assume players have a 2% chance to hit every note in the map. + // A higher value of fc_probability increases the influence of difficulty spikes, + // while a lower value increases the influence of length and consistent difficulty. + private const double fc_probability = 0.02; + + private const int bin_count = 32; + + // The number of difficulties there must be before we can be sure that binning difficulties would not change the output significantly. + private double binThreshold => 2 * bin_count; private readonly List difficulties = new List(); @@ -36,32 +42,6 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Aggregation protected abstract double HitProbability(double skill, double difficulty); - private double difficultyValueBinned() - { - double maxDiff = difficulties.Max(); - if (maxDiff <= 1e-10) return 0; - - var bins = Bin.CreateBins(difficulties); - - const double lower_bound = 0; - double upperBoundEstimate = 3.0 * maxDiff; - - double skill = RootFinding.FindRootExpand( - skill => fcProbability(skill) - FcProbability, - lower_bound, - upperBoundEstimate, - accuracy: 1e-4); - - return skill; - - double fcProbability(double s) - { - if (s <= 0) return 0; - - return bins.Aggregate(1.0, (current, bin) => current * Math.Pow(HitProbability(s, bin.Difficulty), bin.Count)); - } - } - private double difficultyValueExact() { double maxDiff = difficulties.Max(); @@ -71,7 +51,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Aggregation double upperBoundEstimate = 3.0 * maxDiff; double skill = RootFinding.FindRootExpand( - skill => fcProbability(skill) - FcProbability, + skill => fcProbability(skill) - fc_probability, lower_bound, upperBoundEstimate, accuracy: 1e-4); @@ -86,32 +66,56 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Aggregation } } - public override double DifficultyValue() + private double difficultyValueBinned() { - if (difficulties.Count == 0) - return 0; + double maxDiff = difficulties.Max(); + if (maxDiff <= 1e-10) return 0; - return difficulties.Count < 64 ? difficultyValueExact() : difficultyValueBinned(); + var bins = Bin.CreateBins(difficulties, bin_count); + + const double lower_bound = 0; + double upperBoundEstimate = 3.0 * maxDiff; + + double skill = RootFinding.FindRootExpand( + skill => fcProbability(skill) - fc_probability, + lower_bound, + upperBoundEstimate, + accuracy: 1e-4); + + return skill; + + double fcProbability(double s) + { + if (s <= 0) return 0; + + return bins.Aggregate(1.0, (current, bin) => current * Math.Pow(HitProbability(s, bin.Difficulty), bin.Count)); + } } - /// - /// The coefficients of a quartic fitted to the miss counts at each skill level. - /// - /// The coefficients for ax^4+bx^3+cx^2. The 4th coefficient for dx^1 can be deduced from the first 3 in the performance calculator. + public override double DifficultyValue() + { + if (difficulties.Count == 0) return 0; + + return difficulties.Count > binThreshold ? difficultyValueBinned() : difficultyValueExact(); + } + + /// + /// A polynomial fitted to the miss counts at each skill level. + /// public ExpPolynomial GetMissPenaltyCurve() { double[] missCounts = new double[7]; double[] penalties = { 1, 0.95, 0.9, 0.8, 0.6, 0.3, 0 }; + ExpPolynomial missPenaltyCurve = new ExpPolynomial(); + + // If there are no notes, we just return the curve with all coefficients set to zero. + if (difficulties.Count == 0 || difficulties.Max() == 0) + return missPenaltyCurve; + double fcSkill = DifficultyValue(); - ExpPolynomial curve = new ExpPolynomial(); - - // If there are no notes, we just return the empty polynomial. - if (difficulties.Count == 0 || difficulties.Max() == 0) - return curve; - - var bins = Bin.CreateBins(difficulties); + var bins = Bin.CreateBins(difficulties, bin_count); for (int i = 0; i < penalties.Length; i++) { @@ -126,15 +130,15 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Aggregation missCounts[i] = getMissCountAtSkill(penalizedSkill, bins); } - curve.Fit(missCounts); + missPenaltyCurve.Fit(missCounts); - return curve; + return missPenaltyCurve; } /// - /// Find the lowest misscount that a player with the provided would have a 2% chance of achieving. + /// Find the lowest miss count that a player with the provided would have a 2% chance of achieving or better. /// - private double getMissCountAtSkill(double skill, Bin[] bins) + private double getMissCountAtSkill(double skill, List bins) { double maxDiff = difficulties.Max(); @@ -143,9 +147,9 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Aggregation if (skill <= 0) return difficulties.Count; - var poiBin = difficulties.Count > 64 ? new PoissonBinomial(bins, skill, HitProbability) : new PoissonBinomial(difficulties, skill, HitProbability); + var poiBin = difficulties.Count > binThreshold ? new PoissonBinomial(bins, skill, HitProbability) : new PoissonBinomial(difficulties, skill, HitProbability); - return Math.Max(0, RootFinding.FindRootExpand(x => poiBin.CDF(x) - FcProbability, -50, 1000, accuracy: 1e-4)); + return Math.Max(0, RootFinding.FindRootExpand(x => poiBin.CDF(x) - fc_probability, -50, 1000, accuracy: 1e-4)); } } } diff --git a/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs b/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs index 55fb27183e..f0f32c4d26 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/OsuDifficultyCalculator.cs @@ -50,7 +50,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty double sliderFactor = aimRating > 0 ? aimRatingNoSliders / aimRating : 1; - ExpPolynomial aimMissPenaltyCurve = ((Aim)skills[0]).GetMissPenaltyCurve(); + ExpPolynomial aimMissPenaltyCurve = ((OsuProbabilitySkill)skills[0]).GetMissPenaltyCurve(); double speedDifficultyStrainCount = ((OsuStrainSkill)skills[2]).CountTopWeightedStrains(); if (mods.Any(m => m is OsuModTouchDevice)) diff --git a/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs b/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs index 21afa122f9..00b22f66ed 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/OsuPerformanceCalculator.cs @@ -306,14 +306,16 @@ namespace osu.Game.Rulesets.Osu.Difficulty return flashlightValue; } - // The miss penalty formulas assume that a player will miss on the hardest parts of a map. - // With the curve fitted miss penalty, we use a pre-computed curve of skill levels for each - // miss count, raised to the power of 1.5 to account for skill -> sr -> pp changing the exponent. + // Due to the unavailability of miss location in PP, the following formulas assume that a player will miss on the hardest parts of a map. + + // With the curve fitted miss penalty, we use a pre-computed curve of skill levels for each miss count, raised to the power of 1.5 as + // the multiple of the exponents on star rating and PP. This power should be changed if either SR or PP begin to use a different exponent. + // As a result, this exponent is not subject to balance. private double calculateCurveFittedMissPenalty(double missCount, ExpPolynomial curve) => Math.Pow(1 - curve.GetPenaltyAt(missCount), 1.5); - // With the strain count miss penalty, we use the amount of relatively difficult sections to - // adjust miss penalty to make it more punishing on maps with lower amount of hard sections. - private double calculateStrainCountMissPenalty(double missCount, double difficultStrainCount) => 0.96 / ((missCount / (4 * Math.Pow(Math.Log(difficultStrainCount), 0.94))) + 1); + // With the strain count miss penalty, we use the amount of relatively difficult sections to adjust the miss penalty, + // to make it more punishing on maps with lower amount of hard sections. This formula is subject to balance. + private double calculateStrainCountMissPenalty(double missCount, double difficultStrainCount) => 0.96 / (missCount / (4 * Math.Pow(Math.Log(difficultStrainCount), 0.94)) + 1); private double getComboScalingFactor(OsuDifficultyAttributes attributes) => attributes.MaxCombo <= 0 ? 1.0 : Math.Min(Math.Pow(scoreMaxCombo, 0.8) / Math.Pow(attributes.MaxCombo, 0.8), 1.0); private int totalHits => countGreat + countOk + countMeh + countMiss; diff --git a/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs b/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs index 5c5852c4dd..751696d349 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Skills/Aim.cs @@ -6,14 +6,14 @@ using osu.Game.Rulesets.Difficulty.Preprocessing; using osu.Game.Rulesets.Mods; using osu.Game.Rulesets.Osu.Difficulty.Aggregation; using osu.Game.Rulesets.Osu.Difficulty.Evaluators; -using osu.Game.Rulesets.Osu.Difficulty.Utils; +using osu.Game.Utils; namespace osu.Game.Rulesets.Osu.Difficulty.Skills { /// /// Represents the skill required to correctly aim at every object in the map with a uniform CircleSize and normalized distances. /// - public class Aim : OsuProbSkill + public class Aim : OsuProbabilitySkill { public Aim(Mod[] mods, bool withSliders) : base(mods) @@ -25,15 +25,13 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills private double currentStrain; - private double skillMultiplier => 130; + private double skillMultiplier => 132; private double strainDecayBase => 0.15; - protected override double FcProbability => 0.02; - protected override double HitProbability(double skill, double difficulty) { - if (skill <= 0) return 0; if (difficulty <= 0) return 1; + if (skill <= 0) return 0; return SpecialFunctions.Erf(skill / (Math.Sqrt(2) * difficulty)); } diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs index 7ed3595c35..f3a8bdbd55 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs @@ -12,27 +12,25 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils public double Difficulty; public double Count; - private const int bin_count = 32; - /// - /// Create an array of equally spaced bins. Count is linearly interpolated into each bin. + /// Create an array of spaced bins. Count is linearly interpolated into each bin. /// For example, if we have bins with values [1,2,3,4,5] and want to insert the value 3.2, /// we will add 0.8 to the count of 3's and 0.2 to the count of 4's /// - public static Bin[] CreateBins(List difficulties) + public static List CreateBins(List difficulties, int totalBins) { double maxDifficulty = difficulties.Max(); - var bins = new Bin[bin_count]; + var binsArray = new Bin[totalBins]; - for (int i = 0; i < bin_count; i++) + for (int i = 0; i < totalBins; i++) { - bins[i].Difficulty = maxDifficulty * (i + 1) / bin_count; + binsArray[i].Difficulty = maxDifficulty * (i + 1) / totalBins; } foreach (double d in difficulties) { - double binIndex = bin_count * (d / maxDifficulty) - 1; + double binIndex = totalBins * (d / maxDifficulty) - 1; int lowerBound = (int)Math.Floor(binIndex); double t = binIndex - lowerBound; @@ -41,19 +39,24 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils //We don't store that since it doesn't contribute to difficulty if (lowerBound >= 0) { - bins[lowerBound].Count += (1 - t); + binsArray[lowerBound].Count += (1 - t); } int upperBound = lowerBound + 1; // this can be == bin_count for the maximum difficulty object, in which case t will be 0 anyway - if (upperBound < bin_count) + if (upperBound < totalBins) { - bins[upperBound].Count += t; + binsArray[upperBound].Count += t; } } - return bins; + var binsList = binsArray.ToList(); + + // For a slight performance improvement, we remove bins that don't contribute to difficulty. + binsList.RemoveAll(bin => bin.Count == 0); + + return binsList; } } } diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/ExpPolynomial.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/ExpPolynomial.cs index 1e814556ed..c76002a770 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Utils/ExpPolynomial.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/ExpPolynomial.cs @@ -4,13 +4,22 @@ using System; using System.Collections.Generic; using System.Linq; +using osu.Game.Utils; namespace osu.Game.Rulesets.Osu.Difficulty.Utils { + /// + /// Represents a polynomial fitted to the logarithm of a given set of points. + /// The resulting polynomial is exponentiated, ensuring low residuals for + /// small inputs while handling exponentially increasing trends in the data. + /// This approach is useful for modelling the results of decreasing skill with few coefficients, + /// as linear decreases in skill correspond with exponential increases in miss counts. + /// public struct ExpPolynomial { private double[]? coefficients; + // The matrix that minimizes the square error at X values [0.0, 0.30, 0.60, 0.80, 0.90, 0.95, 1.0]. private static readonly double[][] matrix = { new[] { 0.0, 3.14395, 5.18439, 6.46975, 1.4638, -9.53526, 0.0 }, @@ -18,14 +27,13 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils }; /// - /// Computes a quartic or cubic function that starts at 0 and ends at the highest judgement count in the array. + /// Computes the coefficients of a quartic polynomial, starting at 0 and ending at the highest miss count in the array. /// - /// A list of judgements, with X values [0.0, 0.05, ..., 0.95, 1.0]. - public void Fit(double[] judgementCounts) + /// A list of miss counts, with X values [1, 0.95, 0.9, 0.8, 0.6, 0.3, 0] corresponding to their skill levels. + public void Fit(double[] missCounts) { - List logMissCounts = judgementCounts.Select(x => Math.Log(x + 1)).ToList(); + List logMissCounts = missCounts.Select(x => Math.Log(x + 1)).ToList(); - // The polynomial will pass through the point (1, endPoint). double endPoint = logMissCounts.Max(); double[] penalties = { 1, 0.95, 0.9, 0.8, 0.6, 0.3, 0 }; @@ -35,14 +43,14 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils logMissCounts[i] -= endPoint * (1 - penalties[i]); } - // The precomputed matrix assumes the misscounts go in order of greatest to least. + // The precomputed matrix assumes the miss counts go in order of greatest to least. logMissCounts.Reverse(); - coefficients = new double[3]; + coefficients = new double[4]; - coefficients[2] = endPoint; + coefficients[3] = endPoint; - // Now we dot product the adjusted misscounts with the precomputed matrix. + // Now we dot product the adjusted miss counts with the precomputed matrix. for (int row = 0; row < matrix.Length; row++) { for (int column = 0; column < matrix[row].Length; column++) @@ -57,7 +65,6 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils /// /// Solve for the miss penalty at a specified miss count. /// - /// The penalty value at the specified miss count. public double GetPenaltyAt(double missCount) { if (coefficients is null) @@ -69,9 +76,11 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils List xVals = SpecialFunctions.SolvePolynomialRoots(listCoefficients); const double max_error = 1e-7; - double? largestValue = xVals.Where(x => x >= 0 - max_error && x <= 1 + max_error).OrderDescending().FirstOrDefault(); - return largestValue != null ? Math.Clamp(largestValue.Value, 0, 1) : 1; + // We find the largest value of x (corresponding to the penalty) found as a root of the function, with a fallback of a 100% penalty if no roots were found. + double largestValue = xVals.Where(x => x >= 0 - max_error && x <= 1 + max_error).OrderDescending().FirstOrDefault() ?? 1; + + return Math.Clamp(largestValue, 0, 1); } } } diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/PoissonBinomial.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/PoissonBinomial.cs index 1b99307b5f..52005513cb 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Utils/PoissonBinomial.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/PoissonBinomial.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using osu.Game.Utils; namespace osu.Game.Rulesets.Osu.Difficulty.Utils { @@ -70,7 +71,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils /// The bins of difficulties in the map. /// The skill level to get the miss probabilities with. /// /// Converts difficulties and skill to miss probabilities. - public PoissonBinomial(Bin[] bins, double skill, Func hitProbability) + public PoissonBinomial(List bins, double skill, Func hitProbability) { double variance = 0; double gamma = 0; diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/RootFinding.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/RootFinding.cs index 64b4d38ec1..bb01e8162c 100644 --- a/osu.Game.Rulesets.Osu/Difficulty/Utils/RootFinding.cs +++ b/osu.Game.Rulesets.Osu/Difficulty/Utils/RootFinding.cs @@ -18,19 +18,29 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils /// The maximum number of iterations before the function throws an error. /// The desired precision in which the root is returned. /// The multiplier on the upper bound when no root is found within the provided bounds. - public static double FindRootExpand(Func function, double guessLowerBound, double guessUpperBound, int maxIterations = 25, double accuracy = 1e-6D, double expansionFactor = 2) + /// The maximum number of times the bounds of the function should increase. + public static double FindRootExpand(Func function, double guessLowerBound, double guessUpperBound, int maxIterations = 25, double accuracy = 1e-6D, double expansionFactor = 2, double maxExpansions = 32) { double a = guessLowerBound; double b = guessUpperBound; double fa = function(a); double fb = function(b); + int expansions = 0; + while (fa * fb > 0) { a = b; b *= expansionFactor; fa = function(a); fb = function(b); + + expansions++; + + if (expansions > maxExpansions) + { + throw new MaximumIterationsException("No root was found within the provided function."); + } } double t = 0.5; @@ -97,5 +107,13 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils return 0; } + + private class MaximumIterationsException : Exception + { + public MaximumIterationsException(string message) + : base(message) + { + } + } } } diff --git a/osu.Game.Rulesets.Osu/Difficulty/Utils/SpecialFunctions.cs b/osu.Game.Rulesets.Osu/Difficulty/Utils/SpecialFunctions.cs deleted file mode 100644 index c8853ab8a5..0000000000 --- a/osu.Game.Rulesets.Osu/Difficulty/Utils/SpecialFunctions.cs +++ /dev/null @@ -1,972 +0,0 @@ -// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. -// See the LICENCE file in the repository root for full licence text. - -// All code is referenced from the following: -// https://github.com/mathnet/mathnet-numerics/blob/master/src/Numerics/SpecialFunctions/Erf.cs -// https://github.com/mathnet/mathnet-numerics/blob/master/src/Numerics/Optimization/NelderMeadSimplex.cs - -/* - Copyright (c) 2002-2022 Math.NET -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - */ - -using System; -using System.Collections.Generic; -using System.Linq; -using osu.Framework.Utils; - -namespace osu.Game.Rulesets.Osu.Difficulty.Utils -{ - public class SpecialFunctions - { - private const double sqrt2 = 1.4142135623730950488016887242096980785696718753769d; - private const double sqrt2_pi = 2.5066282746310005024157652848110452530069867406099d; - private const double m_2_pi = 6.28318530717958647692528676655900576d; - - /// - /// ************************************** - /// COEFFICIENTS FOR METHOD ErfImp * - /// ************************************** - /// - /// Polynomial coefficients for a numerator of ErfImp - /// calculation for Erf(x) in the interval [1e-10, 0.5]. - /// - private static readonly double[] erf_imp_an = { 0.00337916709551257388990745, -0.00073695653048167948530905, -0.374732337392919607868241, 0.0817442448733587196071743, -0.0421089319936548595203468, 0.0070165709512095756344528, -0.00495091255982435110337458, 0.000871646599037922480317225 }; - - /// Polynomial coefficients for a denominator of ErfImp - /// calculation for Erf(x) in the interval [1e-10, 0.5]. - /// - private static readonly double[] erf_imp_ad = { 1, -0.218088218087924645390535, 0.412542972725442099083918, -0.0841891147873106755410271, 0.0655338856400241519690695, -0.0120019604454941768171266, 0.00408165558926174048329689, -0.000615900721557769691924509 }; - - /// Polynomial coefficients for a numerator in ErfImp - /// calculation for Erfc(x) in the interval [0.5, 0.75]. - /// - private static readonly double[] erf_imp_bn = { -0.0361790390718262471360258, 0.292251883444882683221149, 0.281447041797604512774415, 0.125610208862766947294894, 0.0274135028268930549240776, 0.00250839672168065762786937 }; - - /// Polynomial coefficients for a denominator in ErfImp - /// calculation for Erfc(x) in the interval [0.5, 0.75]. - /// - private static readonly double[] erf_imp_bd = { 1, 1.8545005897903486499845, 1.43575803037831418074962, 0.582827658753036572454135, 0.124810476932949746447682, 0.0113724176546353285778481 }; - - /// Polynomial coefficients for a numerator in ErfImp - /// calculation for Erfc(x) in the interval [0.75, 1.25]. - /// - private static readonly double[] erf_imp_cn = { -0.0397876892611136856954425, 0.153165212467878293257683, 0.191260295600936245503129, 0.10276327061989304213645, 0.029637090615738836726027, 0.0046093486780275489468812, 0.000307607820348680180548455 }; - - /// Polynomial coefficients for a denominator in ErfImp - /// calculation for Erfc(x) in the interval [0.75, 1.25]. - /// - private static readonly double[] erf_imp_cd = { 1, 1.95520072987627704987886, 1.64762317199384860109595, 0.768238607022126250082483, 0.209793185936509782784315, 0.0319569316899913392596356, 0.00213363160895785378615014 }; - - /// Polynomial coefficients for a numerator in ErfImp - /// calculation for Erfc(x) in the interval [1.25, 2.25]. - /// - private static readonly double[] erf_imp_dn = { -0.0300838560557949717328341, 0.0538578829844454508530552, 0.0726211541651914182692959, 0.0367628469888049348429018, 0.00964629015572527529605267, 0.00133453480075291076745275, 0.778087599782504251917881e-4 }; - - /// Polynomial coefficients for a denominator in ErfImp - /// calculation for Erfc(x) in the interval [1.25, 2.25]. - /// - private static readonly double[] erf_imp_dd = { 1, 1.75967098147167528287343, 1.32883571437961120556307, 0.552528596508757581287907, 0.133793056941332861912279, 0.0179509645176280768640766, 0.00104712440019937356634038, -0.106640381820357337177643e-7 }; - - /// Polynomial coefficients for a numerator in ErfImp - /// calculation for Erfc(x) in the interval [2.25, 3.5]. - /// - private static readonly double[] erf_imp_en = { -0.0117907570137227847827732, 0.014262132090538809896674, 0.0202234435902960820020765, 0.00930668299990432009042239, 0.00213357802422065994322516, 0.00025022987386460102395382, 0.120534912219588189822126e-4 }; - - /// Polynomial coefficients for a denominator in ErfImp - /// calculation for Erfc(x) in the interval [2.25, 3.5]. - /// - private static readonly double[] erf_imp_ed = { 1, 1.50376225203620482047419, 0.965397786204462896346934, 0.339265230476796681555511, 0.0689740649541569716897427, 0.00771060262491768307365526, 0.000371421101531069302990367 }; - - /// Polynomial coefficients for a numerator in ErfImp - /// calculation for Erfc(x) in the interval [3.5, 5.25]. - /// - private static readonly double[] erf_imp_fn = { -0.00546954795538729307482955, 0.00404190278731707110245394, 0.0054963369553161170521356, 0.00212616472603945399437862, 0.000394984014495083900689956, 0.365565477064442377259271e-4, 0.135485897109932323253786e-5 }; - - /// Polynomial coefficients for a denominator in ErfImp - /// calculation for Erfc(x) in the interval [3.5, 5.25]. - /// - private static readonly double[] erf_imp_fd = { 1, 1.21019697773630784832251, 0.620914668221143886601045, 0.173038430661142762569515, 0.0276550813773432047594539, 0.00240625974424309709745382, 0.891811817251336577241006e-4, -0.465528836283382684461025e-11 }; - - /// Polynomial coefficients for a numerator in ErfImp - /// calculation for Erfc(x) in the interval [5.25, 8]. - /// - private static readonly double[] erf_imp_gn = { -0.00270722535905778347999196, 0.0013187563425029400461378, 0.00119925933261002333923989, 0.00027849619811344664248235, 0.267822988218331849989363e-4, 0.923043672315028197865066e-6 }; - - /// Polynomial coefficients for a denominator in ErfImp - /// calculation for Erfc(x) in the interval [5.25, 8]. - /// - private static readonly double[] erf_imp_gd = { 1, 0.814632808543141591118279, 0.268901665856299542168425, 0.0449877216103041118694989, 0.00381759663320248459168994, 0.000131571897888596914350697, 0.404815359675764138445257e-11 }; - - /// Polynomial coefficients for a numerator in ErfImp - /// calculation for Erfc(x) in the interval [8, 11.5]. - /// - private static readonly double[] erf_imp_hn = { -0.00109946720691742196814323, 0.000406425442750422675169153, 0.000274499489416900707787024, 0.465293770646659383436343e-4, 0.320955425395767463401993e-5, 0.778286018145020892261936e-7 }; - - /// Polynomial coefficients for a denominator in ErfImp - /// calculation for Erfc(x) in the interval [8, 11.5]. - /// - private static readonly double[] erf_imp_hd = { 1, 0.588173710611846046373373, 0.139363331289409746077541, 0.0166329340417083678763028, 0.00100023921310234908642639, 0.24254837521587225125068e-4 }; - - /// Polynomial coefficients for a numerator in ErfImp - /// calculation for Erfc(x) in the interval [11.5, 17]. - /// - private static readonly double[] erf_imp_in = { -0.00056907993601094962855594, 0.000169498540373762264416984, 0.518472354581100890120501e-4, 0.382819312231928859704678e-5, 0.824989931281894431781794e-7 }; - - /// Polynomial coefficients for a denominator in ErfImp - /// calculation for Erfc(x) in the interval [11.5, 17]. - /// - private static readonly double[] erf_imp_id = { 1, 0.339637250051139347430323, 0.043472647870310663055044, 0.00248549335224637114641629, 0.535633305337152900549536e-4, -0.117490944405459578783846e-12 }; - - /// Polynomial coefficients for a numerator in ErfImp - /// calculation for Erfc(x) in the interval [17, 24]. - /// - private static readonly double[] erf_imp_jn = { -0.000241313599483991337479091, 0.574224975202501512365975e-4, 0.115998962927383778460557e-4, 0.581762134402593739370875e-6, 0.853971555085673614607418e-8 }; - - /// Polynomial coefficients for a denominator in ErfImp - /// calculation for Erfc(x) in the interval [17, 24]. - /// - private static readonly double[] erf_imp_jd = { 1, 0.233044138299687841018015, 0.0204186940546440312625597, 0.000797185647564398289151125, 0.117019281670172327758019e-4 }; - - /// Polynomial coefficients for a numerator in ErfImp - /// calculation for Erfc(x) in the interval [24, 38]. - /// - private static readonly double[] erf_imp_kn = { -0.000146674699277760365803642, 0.162666552112280519955647e-4, 0.269116248509165239294897e-5, 0.979584479468091935086972e-7, 0.101994647625723465722285e-8 }; - - /// Polynomial coefficients for a denominator in ErfImp - /// calculation for Erfc(x) in the interval [24, 38]. - /// - private static readonly double[] erf_imp_kd = { 1, 0.165907812944847226546036, 0.0103361716191505884359634, 0.000286593026373868366935721, 0.298401570840900340874568e-5 }; - - /// Polynomial coefficients for a numerator in ErfImp - /// calculation for Erfc(x) in the interval [38, 60]. - /// - private static readonly double[] erf_imp_ln = { -0.583905797629771786720406e-4, 0.412510325105496173512992e-5, 0.431790922420250949096906e-6, 0.993365155590013193345569e-8, 0.653480510020104699270084e-10 }; - - /// Polynomial coefficients for a denominator in ErfImp - /// calculation for Erfc(x) in the interval [38, 60]. - /// - private static readonly double[] erf_imp_ld = { 1, 0.105077086072039915406159, 0.00414278428675475620830226, 0.726338754644523769144108e-4, 0.477818471047398785369849e-6 }; - - /// Polynomial coefficients for a numerator in ErfImp - /// calculation for Erfc(x) in the interval [60, 85]. - /// - private static readonly double[] erf_imp_mn = { -0.196457797609229579459841e-4, 0.157243887666800692441195e-5, 0.543902511192700878690335e-7, 0.317472492369117710852685e-9 }; - - /// Polynomial coefficients for a denominator in ErfImp - /// calculation for Erfc(x) in the interval [60, 85]. - /// - private static readonly double[] erf_imp_md = { 1, 0.052803989240957632204885, 0.000926876069151753290378112, 0.541011723226630257077328e-5, 0.535093845803642394908747e-15 }; - - /// Polynomial coefficients for a numerator in ErfImp - /// calculation for Erfc(x) in the interval [85, 110]. - /// - private static readonly double[] erf_imp_nn = { -0.789224703978722689089794e-5, 0.622088451660986955124162e-6, 0.145728445676882396797184e-7, 0.603715505542715364529243e-10 }; - - /// Polynomial coefficients for a denominator in ErfImp - /// calculation for Erfc(x) in the interval [85, 110]. - /// - private static readonly double[] erf_imp_nd = { 1, 0.0375328846356293715248719, 0.000467919535974625308126054, 0.193847039275845656900547e-5 }; - - /// - /// ************************************** - /// COEFFICIENTS FOR METHOD ErfInvImp * - /// ************************************** - /// - /// Polynomial coefficients for a numerator of ErfInvImp - /// calculation for Erf^-1(z) in the interval [0, 0.5]. - /// - private static readonly double[] erv_inv_imp_an = { -0.000508781949658280665617, -0.00836874819741736770379, 0.0334806625409744615033, -0.0126926147662974029034, -0.0365637971411762664006, 0.0219878681111168899165, 0.00822687874676915743155, -0.00538772965071242932965 }; - - /// Polynomial coefficients for a denominator of ErfInvImp - /// calculation for Erf^-1(z) in the interval [0, 0.5]. - /// - private static readonly double[] erv_inv_imp_ad = { 1, -0.970005043303290640362, -1.56574558234175846809, 1.56221558398423026363, 0.662328840472002992063, -0.71228902341542847553, -0.0527396382340099713954, 0.0795283687341571680018, -0.00233393759374190016776, 0.000886216390456424707504 }; - - /// Polynomial coefficients for a numerator of ErfInvImp - /// calculation for Erf^-1(z) in the interval [0.5, 0.75]. - /// - private static readonly double[] erv_inv_imp_bn = { -0.202433508355938759655, 0.105264680699391713268, 8.37050328343119927838, 17.6447298408374015486, -18.8510648058714251895, -44.6382324441786960818, 17.445385985570866523, 21.1294655448340526258, -3.67192254707729348546 }; - - /// Polynomial coefficients for a denominator of ErfInvImp - /// calculation for Erf^-1(z) in the interval [0.5, 0.75]. - /// - private static readonly double[] erv_inv_imp_bd = { 1, 6.24264124854247537712, 3.9713437953343869095, -28.6608180499800029974, -20.1432634680485188801, 48.5609213108739935468, 10.8268667355460159008, -22.6436933413139721736, 1.72114765761200282724 }; - - /// Polynomial coefficients for a numerator of ErfInvImp - /// calculation for Erf^-1(z) in the interval [0.75, 1] with x less than 3. - /// - private static readonly double[] erv_inv_imp_cn = { -0.131102781679951906451, -0.163794047193317060787, 0.117030156341995252019, 0.387079738972604337464, 0.337785538912035898924, 0.142869534408157156766, 0.0290157910005329060432, 0.00214558995388805277169, -0.679465575181126350155e-6, 0.285225331782217055858e-7, -0.681149956853776992068e-9 }; - - /// Polynomial coefficients for a denominator of ErfInvImp - /// calculation for Erf^-1(z) in the interval [0.75, 1] with x less than 3. - /// - private static readonly double[] erv_inv_imp_cd = { 1, 3.46625407242567245975, 5.38168345707006855425, 4.77846592945843778382, 2.59301921623620271374, 0.848854343457902036425, 0.152264338295331783612, 0.01105924229346489121 }; - - /// Polynomial coefficients for a numerator of ErfInvImp - /// calculation for Erf^-1(z) in the interval [0.75, 1] with x between 3 and 6. - /// - private static readonly double[] erv_inv_imp_dn = { -0.0350353787183177984712, -0.00222426529213447927281, 0.0185573306514231072324, 0.00950804701325919603619, 0.00187123492819559223345, 0.000157544617424960554631, 0.460469890584317994083e-5, -0.230404776911882601748e-9, 0.266339227425782031962e-11 }; - - /// Polynomial coefficients for a denominator of ErfInvImp - /// calculation for Erf^-1(z) in the interval [0.75, 1] with x between 3 and 6. - /// - private static readonly double[] erv_inv_imp_dd = { 1, 1.3653349817554063097, 0.762059164553623404043, 0.220091105764131249824, 0.0341589143670947727934, 0.00263861676657015992959, 0.764675292302794483503e-4 }; - - /// Polynomial coefficients for a numerator of ErfInvImp - /// calculation for Erf^-1(z) in the interval [0.75, 1] with x between 6 and 18. - /// - private static readonly double[] erv_inv_imp_en = { -0.0167431005076633737133, -0.00112951438745580278863, 0.00105628862152492910091, 0.000209386317487588078668, 0.149624783758342370182e-4, 0.449696789927706453732e-6, 0.462596163522878599135e-8, -0.281128735628831791805e-13, 0.99055709973310326855e-16 }; - - /// Polynomial coefficients for a denominator of ErfInvImp - /// calculation for Erf^-1(z) in the interval [0.75, 1] with x between 6 and 18. - /// - private static readonly double[] erv_inv_imp_ed = { 1, 0.591429344886417493481, 0.138151865749083321638, 0.0160746087093676504695, 0.000964011807005165528527, 0.275335474764726041141e-4, 0.282243172016108031869e-6 }; - - /// Polynomial coefficients for a numerator of ErfInvImp - /// calculation for Erf^-1(z) in the interval [0.75, 1] with x between 18 and 44. - /// - private static readonly double[] erv_inv_imp_fn = { -0.0024978212791898131227, -0.779190719229053954292e-5, 0.254723037413027451751e-4, 0.162397777342510920873e-5, 0.396341011304801168516e-7, 0.411632831190944208473e-9, 0.145596286718675035587e-11, -0.116765012397184275695e-17 }; - - /// Polynomial coefficients for a denominator of ErfInvImp - /// calculation for Erf^-1(z) in the interval [0.75, 1] with x between 18 and 44. - /// - private static readonly double[] erv_inv_imp_fd = { 1, 0.207123112214422517181, 0.0169410838120975906478, 0.000690538265622684595676, 0.145007359818232637924e-4, 0.144437756628144157666e-6, 0.509761276599778486139e-9 }; - - /// Polynomial coefficients for a numerator of ErfInvImp - /// calculation for Erf^-1(z) in the interval [0.75, 1] with x greater than 44. - /// - private static readonly double[] erv_inv_imp_gn = { -0.000539042911019078575891, -0.28398759004727721098e-6, 0.899465114892291446442e-6, 0.229345859265920864296e-7, 0.225561444863500149219e-9, 0.947846627503022684216e-12, 0.135880130108924861008e-14, -0.348890393399948882918e-21 }; - - /// Polynomial coefficients for a denominator of ErfInvImp - /// calculation for Erf^-1(z) in the interval [0.75, 1] with x greater than 44. - /// - private static readonly double[] erv_inv_imp_gd = { 1, 0.0845746234001899436914, 0.00282092984726264681981, 0.468292921940894236786e-4, 0.399968812193862100054e-6, 0.161809290887904476097e-8, 0.231558608310259605225e-11 }; - - /// Calculates the error function. - /// The value to evaluate. - /// the error function evaluated at given value. - /// - /// - /// returns 1 if x == double.PositiveInfinity. - /// returns -1 if x == double.NegativeInfinity. - /// - /// - public static double Erf(double x) - { - if (x == 0) - { - return 0; - } - - if (double.IsPositiveInfinity(x)) - { - return 1; - } - - if (double.IsNegativeInfinity(x)) - { - return -1; - } - - if (double.IsNaN(x)) - { - return double.NaN; - } - - return erfImp(x, false); - } - - /// Calculates the complementary error function. - /// The value to evaluate. - /// the complementary error function evaluated at given value. - /// - /// - /// returns 0 if x == double.PositiveInfinity. - /// returns 2 if x == double.NegativeInfinity. - /// - /// - public static double Erfc(double x) - { - if (x == 0) - { - return 1; - } - - if (double.IsPositiveInfinity(x)) - { - return 0; - } - - if (double.IsNegativeInfinity(x)) - { - return 2; - } - - if (double.IsNaN(x)) - { - return double.NaN; - } - - return erfImp(x, true); - } - - /// Calculates the inverse error function evaluated at z. - /// The inverse error function evaluated at given value. - /// - /// - /// returns double.PositiveInfinity if z >= 1.0. - /// returns double.NegativeInfinity if z <= -1.0. - /// - /// - /// Calculates the inverse error function evaluated at z. - /// value to evaluate. - /// the inverse error function evaluated at Z. - public static double ErfInv(double z) - { - if (z == 0.0) - { - return 0.0; - } - - if (z >= 1.0) - { - return double.PositiveInfinity; - } - - if (z <= -1.0) - { - return double.NegativeInfinity; - } - - double p, q, s; - - if (z < 0) - { - p = -z; - q = 1 - p; - s = -1; - } - else - { - p = z; - q = 1 - z; - s = 1; - } - - return erfInvImpl(p, q, s); - } - - /// - /// Implementation of the error function. - /// - /// Where to evaluate the error function. - /// Whether to compute 1 - the error function. - /// the error function. - private static double erfImp(double z, bool invert) - { - if (z < 0) - { - if (!invert) - { - return -erfImp(-z, false); - } - - if (z < -0.5) - { - return 2 - erfImp(-z, true); - } - - return 1 + erfImp(-z, false); - } - - double result; - - // Big bunch of selection statements now to pick which - // implementation to use, try to put most likely options - // first: - if (z < 0.5) - { - // We're going to calculate erf: - if (z < 1e-10) - { - result = (z * 1.125) + (z * 0.003379167095512573896158903121545171688); - } - else - { - // Worst case absolute error found: 6.688618532e-21 - result = (z * 1.125) + (z * evaluatePolynomial(z, erf_imp_an) / evaluatePolynomial(z, erf_imp_ad)); - } - } - else if (z < 110) - { - // We'll be calculating erfc: - invert = !invert; - double r, b; - - if (z < 0.75) - { - // Worst case absolute error found: 5.582813374e-21 - r = evaluatePolynomial(z - 0.5, erf_imp_bn) / evaluatePolynomial(z - 0.5, erf_imp_bd); - b = 0.3440242112F; - } - else if (z < 1.25) - { - // Worst case absolute error found: 4.01854729e-21 - r = evaluatePolynomial(z - 0.75, erf_imp_cn) / evaluatePolynomial(z - 0.75, erf_imp_cd); - b = 0.419990927F; - } - else if (z < 2.25) - { - // Worst case absolute error found: 2.866005373e-21 - r = evaluatePolynomial(z - 1.25, erf_imp_dn) / evaluatePolynomial(z - 1.25, erf_imp_dd); - b = 0.4898625016F; - } - else if (z < 3.5) - { - // Worst case absolute error found: 1.045355789e-21 - r = evaluatePolynomial(z - 2.25, erf_imp_en) / evaluatePolynomial(z - 2.25, erf_imp_ed); - b = 0.5317370892F; - } - else if (z < 5.25) - { - // Worst case absolute error found: 8.300028706e-22 - r = evaluatePolynomial(z - 3.5, erf_imp_fn) / evaluatePolynomial(z - 3.5, erf_imp_fd); - b = 0.5489973426F; - } - else if (z < 8) - { - // Worst case absolute error found: 1.700157534e-21 - r = evaluatePolynomial(z - 5.25, erf_imp_gn) / evaluatePolynomial(z - 5.25, erf_imp_gd); - b = 0.5571740866F; - } - else if (z < 11.5) - { - // Worst case absolute error found: 3.002278011e-22 - r = evaluatePolynomial(z - 8, erf_imp_hn) / evaluatePolynomial(z - 8, erf_imp_hd); - b = 0.5609807968F; - } - else if (z < 17) - { - // Worst case absolute error found: 6.741114695e-21 - r = evaluatePolynomial(z - 11.5, erf_imp_in) / evaluatePolynomial(z - 11.5, erf_imp_id); - b = 0.5626493692F; - } - else if (z < 24) - { - // Worst case absolute error found: 7.802346984e-22 - r = evaluatePolynomial(z - 17, erf_imp_jn) / evaluatePolynomial(z - 17, erf_imp_jd); - b = 0.5634598136F; - } - else if (z < 38) - { - // Worst case absolute error found: 2.414228989e-22 - r = evaluatePolynomial(z - 24, erf_imp_kn) / evaluatePolynomial(z - 24, erf_imp_kd); - b = 0.5638477802F; - } - else if (z < 60) - { - // Worst case absolute error found: 5.896543869e-24 - r = evaluatePolynomial(z - 38, erf_imp_ln) / evaluatePolynomial(z - 38, erf_imp_ld); - b = 0.5640528202F; - } - else if (z < 85) - { - // Worst case absolute error found: 3.080612264e-21 - r = evaluatePolynomial(z - 60, erf_imp_mn) / evaluatePolynomial(z - 60, erf_imp_md); - b = 0.5641309023F; - } - else - { - // Worst case absolute error found: 8.094633491e-22 - r = evaluatePolynomial(z - 85, erf_imp_nn) / evaluatePolynomial(z - 85, erf_imp_nd); - b = 0.5641584396F; - } - - double g = Math.Exp(-z * z) / z; - result = (g * b) + (g * r); - } - else - { - // Any value of z larger than 28 will underflow to zero: - result = 0; - invert = !invert; - } - - if (invert) - { - result = 1 - result; - } - - return result; - } - - /// Calculates the complementary inverse error function evaluated at z. - /// The complementary inverse error function evaluated at given value. - /// We have tested this implementation against the arbitrary precision mpmath library - /// and found cases where we can only guarantee 9 significant figures correct. - /// - /// returns double.PositiveInfinity if z <= 0.0. - /// returns double.NegativeInfinity if z >= 2.0. - /// - /// - /// calculates the complementary inverse error function evaluated at z. - /// value to evaluate. - /// the complementary inverse error function evaluated at Z. - public static double ErfcInv(double z) - { - if (z <= 0.0) - { - return double.PositiveInfinity; - } - - if (z >= 2.0) - { - return double.NegativeInfinity; - } - - double p, q, s; - - if (z > 1) - { - q = 2 - z; - p = 1 - q; - s = -1; - } - else - { - p = 1 - z; - q = z; - s = 1; - } - - return erfInvImpl(p, q, s); - } - - /// - /// The implementation of the inverse error function. - /// - /// First intermediate parameter. - /// Second intermediate parameter. - /// Third intermediate parameter. - /// the inverse error function. - private static double erfInvImpl(double p, double q, double s) - { - double result; - - if (p <= 0.5) - { - // Evaluate inverse erf using the rational approximation: - // - // x = p(p+10)(Y+R(p)) - // - // Where Y is a constant, and R(p) is optimized for a low - // absolute error compared to |Y|. - // - // double: Max error found: 2.001849e-18 - // long double: Max error found: 1.017064e-20 - // Maximum Deviation Found (actual error term at infinite precision) 8.030e-21 - const float y = 0.0891314744949340820313f; - double g = p * (p + 10); - double r = evaluatePolynomial(p, erv_inv_imp_an) / evaluatePolynomial(p, erv_inv_imp_ad); - result = (g * y) + (g * r); - } - else if (q >= 0.25) - { - // Rational approximation for 0.5 > q >= 0.25 - // - // x = sqrt(-2*log(q)) / (Y + R(q)) - // - // Where Y is a constant, and R(q) is optimized for a low - // absolute error compared to Y. - // - // double : Max error found: 7.403372e-17 - // long double : Max error found: 6.084616e-20 - // Maximum Deviation Found (error term) 4.811e-20 - const float y = 2.249481201171875f; - double g = Math.Sqrt(-2 * Math.Log(q)); - double xs = q - 0.25; - double r = evaluatePolynomial(xs, erv_inv_imp_bn) / evaluatePolynomial(xs, erv_inv_imp_bd); - result = g / (y + r); - } - else - { - // For q < 0.25 we have a series of rational approximations all - // of the general form: - // - // let: x = sqrt(-log(q)) - // - // Then the result is given by: - // - // x(Y+R(x-B)) - // - // where Y is a constant, B is the lowest value of x for which - // the approximation is valid, and R(x-B) is optimized for a low - // absolute error compared to Y. - // - // Note that almost all code will really go through the first - // or maybe second approximation. After than we're dealing with very - // small input values indeed: 80 and 128 bit long double's go all the - // way down to ~ 1e-5000 so the "tail" is rather long... - double x = Math.Sqrt(-Math.Log(q)); - - if (x < 3) - { - // Max error found: 1.089051e-20 - const float y = 0.807220458984375f; - double xs = x - 1.125; - double r = evaluatePolynomial(xs, erv_inv_imp_cn) / evaluatePolynomial(xs, erv_inv_imp_cd); - result = (y * x) + (r * x); - } - else if (x < 6) - { - // Max error found: 8.389174e-21 - const float y = 0.93995571136474609375f; - double xs = x - 3; - double r = evaluatePolynomial(xs, erv_inv_imp_dn) / evaluatePolynomial(xs, erv_inv_imp_dd); - result = (y * x) + (r * x); - } - else if (x < 18) - { - // Max error found: 1.481312e-19 - const float y = 0.98362827301025390625f; - double xs = x - 6; - double r = evaluatePolynomial(xs, erv_inv_imp_en) / evaluatePolynomial(xs, erv_inv_imp_ed); - result = (y * x) + (r * x); - } - else if (x < 44) - { - // Max error found: 5.697761e-20 - const float y = 0.99714565277099609375f; - double xs = x - 18; - double r = evaluatePolynomial(xs, erv_inv_imp_fn) / evaluatePolynomial(xs, erv_inv_imp_fd); - result = (y * x) + (r * x); - } - else - { - // Max error found: 1.279746e-20 - const float y = 0.99941349029541015625f; - double xs = x - 44; - double r = evaluatePolynomial(xs, erv_inv_imp_gn) / evaluatePolynomial(xs, erv_inv_imp_gd); - result = (y * x) + (r * x); - } - } - - return s * result; - } - - /// - /// Evaluate a polynomial at point x. - /// Coefficients are ordered ascending by power with power k at index k. - /// Example: coefficients [3,-1,2] represent y=2x^2-x+3. - /// - /// The location where to evaluate the polynomial at. - /// The coefficients of the polynomial, coefficient for power k at index k. - /// - /// is a null reference. - /// - private static double evaluatePolynomial(double z, params double[] coefficients) - { - // 2020-10-07 jbialogrodzki #730 Since this is public API we should probably - // handle null arguments? It doesn't seem to have been done consistently in this class though. - if (coefficients == null) - { - throw new ArgumentNullException(nameof(coefficients)); - } - - // 2020-10-07 jbialogrodzki #730 Zero polynomials need explicit handling. - // Without this check, we attempted to peek coefficients at negative indices! - int n = coefficients.Length; - - if (n == 0) - { - return 0; - } - - double sum = coefficients[n - 1]; - - for (int i = n - 2; i >= 0; --i) - { - sum *= z; - sum += coefficients[i]; - } - - return sum; - } - - /// - /// Computes the probability density of the distribution (PDF) at x, i.e. ∂P(X ≤ x)/∂x. - /// - /// The mean (μ) of the normal distribution. - /// The standard deviation (σ) of the normal distribution. Range: σ ≥ 0. - /// The location at which to compute the density. - /// the density at . - /// MATLAB: normpdf - public static double NormalPdf(double mean, double stddev, double x) - { - if (stddev < 0.0) - { - throw new ArgumentException("Invalid parametrization for the distribution."); - } - - double d = (x - mean) / stddev; - return Math.Exp(-0.5 * d * d) / (sqrt2_pi * stddev); - } - - /// - /// Computes the cumulative distribution (CDF) of the distribution at x, i.e. P(X ≤ x). - /// - /// The location at which to compute the cumulative distribution function. - /// The mean (μ) of the normal distribution. - /// The standard deviation (σ) of the normal distribution. Range: σ ≥ 0. - /// the cumulative distribution at location . - /// MATLAB: normcdf - public static double NormalCdf(double mean, double stddev, double x) - { - if (stddev < 0.0) - { - throw new ArgumentException("Invalid parametrization for the distribution."); - } - - return 0.5 * Erfc((mean - x) / (stddev * sqrt2)); - } - - /// - /// Solve for the exact real roots of any polynomial up to degree 4. - /// - /// The coefficients of the polynomial, in ascending order ([1, 3, 5] -> x^2 + 3x + 5). - /// The real roots of the polynomial, and null if the root does not exist. - public static List SolvePolynomialRoots(List coefficients) - { - List xVals = new List(); - - switch (coefficients.Count) - { - case 5: - xVals = solveP4(coefficients[0], coefficients[1], coefficients[2], coefficients[3], coefficients[4], out int _).ToList(); - break; - - case 4: - xVals = solveP3(coefficients[0], coefficients[1], coefficients[2], coefficients[3], out int _).ToList(); - break; - - case 3: - xVals = solveP2(coefficients[0], coefficients[1], coefficients[2], out int _).ToList(); - break; - - case 2: - xVals = solveP2(0, coefficients[1], coefficients[2], out int _).ToList(); - break; - } - - return xVals; - } - - private static double?[] solveP4(double a, double b, double c, double d, double e, out int nRoots) - { - double?[] xVals = new double?[4]; - - nRoots = 0; - - if (a == 0) - { - double?[] xValsCubic = solveP3(b, c, d, e, out nRoots); - - xVals[0] = xValsCubic[0]; - xVals[1] = xValsCubic[1]; - xVals[2] = xValsCubic[2]; - xVals[3] = null; - - return xVals; - } - - b /= a; - c /= a; - d /= a; - - double a3 = -c; - double b3 = b * d - 4 * e; - double c3 = -b * b * e - d * d + 4 * c * e; - - double?[] x3 = solveP3(1, a3, b3, c3, out int iZeroes); - - double q1, q2, p1, p2, sqD; - - double y = x3[0]!.Value; - - // Get the y value with the highest absolute value. - if (iZeroes != 1) - { - if (Math.Abs(x3[1]!.Value) > Math.Abs(y)) - y = x3[1]!.Value; - if (Math.Abs(x3[2]!.Value) > Math.Abs(y)) - y = x3[2]!.Value; - } - - double upperD = y * y - 4 * e; - - if (Precision.AlmostEquals(upperD, 0)) - { - q1 = q2 = y * 0.5; - - upperD = b * b - 4 * (c - y); - - if (Precision.AlmostEquals(upperD, 0)) - p1 = p2 = b * 0.5; - - else - { - sqD = Math.Sqrt(upperD); - p1 = (b + sqD) * 0.5; - p2 = (b - sqD) * 0.5; - } - } - else - { - sqD = Math.Sqrt(upperD); - q1 = (y + sqD) * 0.5; - q2 = (y - sqD) * 0.5; - - p1 = (b * q1 - c) / (q1 - q2); - p2 = (d - b * q2) / (q1 - q2); - } - - // solving quadratic eq. - x^2 + p1*x + q1 = 0 - upperD = p1 * p1 - 4 * q1; - - if (upperD >= 0) - { - nRoots += 2; - - sqD = Math.Sqrt(upperD); - xVals[0] = (-p1 + sqD) * 0.5; - xVals[1] = (-p1 - sqD) * 0.5; - } - - // solving quadratic eq. - x^2 + p2*x + q2 = 0 - upperD = p2 * p2 - 4 * q2; - - if (upperD >= 0) - { - nRoots += 2; - - sqD = Math.Sqrt(upperD); - xVals[2] = (-p2 + sqD) * 0.5; - xVals[3] = (-p2 - sqD) * 0.5; - } - - // Put the null roots at the end of the array. - var nonNulls = xVals.Where(x => x != null); - var nulls = xVals.Where(x => x == null); - xVals = nonNulls.Concat(nulls).ToArray(); - - return xVals; - } - - private static double?[] solveP3(double a, double b, double c, double d, out int nRoots) - { - double?[] xVals = new double?[3]; - - nRoots = 0; - - if (a == 0) - { - double?[] xValsQuadratic = solveP2(b, c, d, out nRoots); - - xVals[0] = xValsQuadratic[0]; - xVals[1] = xValsQuadratic[1]; - xVals[2] = null; - - return xVals; - } - - b /= a; - c /= a; - d /= a; - - double b2 = b * b; - double q = (b2 - 3 * c) / 9; - double q3 = q * q * q; - double r = (b * (2 * b2 - 9 * c) + 27 * d) / 54; - double r2 = r * r; - - if (r2 < q3) - { - nRoots = 3; - - double t = r / Math.Sqrt(q3); - t = Math.Clamp(t, -1, 1); - t = Math.Acos(t); - b /= 3; - q = -2 * Math.Sqrt(q); - - xVals[0] = q * Math.Cos(t / 3) - b; - xVals[1] = q * Math.Cos((t + m_2_pi) / 3) - b; - xVals[2] = q * Math.Cos((t - m_2_pi) / 3) - b; - - return xVals; - } - - double upperA = -Math.Cbrt(Math.Abs(r) + double.Sqrt(r2 - q3)); - - if (r < 0) - upperA = -upperA; - - double upperB = upperA == 0 ? 0 : q / upperA; - b /= 3; - - xVals[0] = upperA + upperB - b; - - if (Precision.AlmostEquals(0.5 * Math.Sqrt(3) * (upperA - upperB), 0)) - { - nRoots = 2; - xVals[1] = -0.5 * (upperA + upperB) - b; - - return xVals; - } - - nRoots = 1; - - return xVals; - } - - private static double?[] solveP2(double a, double b, double c, out int nRoots) - { - double?[] xVals = new double?[2]; - - nRoots = 0; - - if (a == 0) - { - if (b == 0) - return xVals; - - nRoots = 1; - xVals[0] = -c / b; - } - - double discriminant = b * b - 4 * a * c; - - switch (discriminant) - { - case < 0: - break; - - case 0: - nRoots = 1; - xVals[0] = -b / (2 * a); - break; - - default: - nRoots = 2; - xVals[0] = (-b + Math.Sqrt(discriminant)) / (2 * a); - xVals[1] = (-b - Math.Sqrt(discriminant)) / (2 * a); - break; - } - - return xVals; - } - } -} diff --git a/osu.Game/Utils/SpecialFunctions.cs b/osu.Game/Utils/SpecialFunctions.cs index 0b0f0598bb..bf7a6151ae 100644 --- a/osu.Game/Utils/SpecialFunctions.cs +++ b/osu.Game/Utils/SpecialFunctions.cs @@ -13,12 +13,17 @@ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLI */ using System; +using System.Collections.Generic; +using System.Linq; +using osu.Framework.Utils; namespace osu.Game.Utils { public class SpecialFunctions { + private const double sqrt2 = 1.4142135623730950488016887242096980785696718753769d; private const double sqrt2_pi = 2.5066282746310005024157652848110452530069867406099d; + private const double m_2_pi = 6.28318530717958647692528676655900576d; /// /// ************************************** @@ -690,5 +695,280 @@ namespace osu.Game.Utils return sum; } + + /// + /// Computes the probability density of the distribution (PDF) at x, i.e. ∂P(X ≤ x)/∂x. + /// + /// The mean (μ) of the normal distribution. + /// The standard deviation (σ) of the normal distribution. Range: σ ≥ 0. + /// The location at which to compute the density. + /// the density at . + /// MATLAB: normpdf + public static double NormalPdf(double mean, double stddev, double x) + { + if (stddev < 0.0) + { + throw new ArgumentException("Invalid parametrization for the distribution."); + } + + double d = (x - mean) / stddev; + return Math.Exp(-0.5 * d * d) / (sqrt2_pi * stddev); + } + + /// + /// Computes the cumulative distribution (CDF) of the distribution at x, i.e. P(X ≤ x). + /// + /// The location at which to compute the cumulative distribution function. + /// The mean (μ) of the normal distribution. + /// The standard deviation (σ) of the normal distribution. Range: σ ≥ 0. + /// the cumulative distribution at location . + /// MATLAB: normcdf + public static double NormalCdf(double mean, double stddev, double x) + { + if (stddev < 0.0) + { + throw new ArgumentException("Invalid parametrization for the distribution."); + } + + return 0.5 * Erfc((mean - x) / (stddev * sqrt2)); + } + + /// + /// Solve for the exact real roots of any polynomial up to degree 4. + /// + /// The coefficients of the polynomial, in ascending order ([1, 3, 5] -> x^2 + 3x + 5). + /// The real roots of the polynomial, and null if the root does not exist. + public static List SolvePolynomialRoots(List coefficients) + { + List xVals = new List(); + + switch (coefficients.Count) + { + case 5: + xVals = solveP4(coefficients[0], coefficients[1], coefficients[2], coefficients[3], coefficients[4], out int _).ToList(); + break; + + case 4: + xVals = solveP3(coefficients[0], coefficients[1], coefficients[2], coefficients[3], out int _).ToList(); + break; + + case 3: + xVals = solveP2(coefficients[0], coefficients[1], coefficients[2], out int _).ToList(); + break; + + case 2: + xVals = solveP2(0, coefficients[1], coefficients[2], out int _).ToList(); + break; + } + + return xVals; + } + + // https://github.com/sasamil/Quartic/blob/master/quartic.cpp + private static double?[] solveP4(double a, double b, double c, double d, double e, out int nRoots) + { + double?[] xVals = new double?[4]; + + nRoots = 0; + + if (a == 0) + { + double?[] xValsCubic = solveP3(b, c, d, e, out nRoots); + + xVals[0] = xValsCubic[0]; + xVals[1] = xValsCubic[1]; + xVals[2] = xValsCubic[2]; + xVals[3] = null; + + return xVals; + } + + b /= a; + c /= a; + d /= a; + e /= a; + + double a3 = -c; + double b3 = b * d - 4 * e; + double c3 = -b * b * e - d * d + 4 * c * e; + + double?[] x3 = solveP3(1, a3, b3, c3, out int iZeroes); + + double q1, q2, p1, p2, sqD; + + double y = x3[0]!.Value; + + // Get the y value with the highest absolute value. + if (iZeroes != 1) + { + if (Math.Abs(x3[1]!.Value) > Math.Abs(y)) + y = x3[1]!.Value; + if (Math.Abs(x3[2]!.Value) > Math.Abs(y)) + y = x3[2]!.Value; + } + + double upperD = y * y - 4 * e; + + if (Precision.AlmostEquals(upperD, 0)) + { + q1 = q2 = y * 0.5; + + upperD = b * b - 4 * (c - y); + + if (Precision.AlmostEquals(upperD, 0)) + p1 = p2 = b * 0.5; + + else + { + sqD = Math.Sqrt(upperD); + p1 = (b + sqD) * 0.5; + p2 = (b - sqD) * 0.5; + } + } + else + { + sqD = Math.Sqrt(upperD); + q1 = (y + sqD) * 0.5; + q2 = (y - sqD) * 0.5; + + p1 = (b * q1 - d) / (q1 - q2); + p2 = (d - b * q2) / (q1 - q2); + } + + // solving quadratic eq. - x^2 + p1*x + q1 = 0 + upperD = p1 * p1 - 4 * q1; + + if (upperD >= 0) + { + nRoots += 2; + + sqD = Math.Sqrt(upperD); + xVals[0] = (-p1 + sqD) * 0.5; + xVals[1] = (-p1 - sqD) * 0.5; + } + + // solving quadratic eq. - x^2 + p2*x + q2 = 0 + upperD = p2 * p2 - 4 * q2; + + if (upperD >= 0) + { + nRoots += 2; + + sqD = Math.Sqrt(upperD); + xVals[2] = (-p2 + sqD) * 0.5; + xVals[3] = (-p2 - sqD) * 0.5; + } + + // Put the null roots at the end of the array. + var nonNulls = xVals.Where(x => x != null); + var nulls = xVals.Where(x => x == null); + xVals = nonNulls.Concat(nulls).ToArray(); + + return xVals; + } + + private static double?[] solveP3(double a, double b, double c, double d, out int nRoots) + { + double?[] xVals = new double?[3]; + + nRoots = 0; + + if (a == 0) + { + double?[] xValsQuadratic = solveP2(b, c, d, out nRoots); + + xVals[0] = xValsQuadratic[0]; + xVals[1] = xValsQuadratic[1]; + xVals[2] = null; + + return xVals; + } + + b /= a; + c /= a; + d /= a; + + double a2 = b * b; + double q = (a2 - 3 * c) / 9; + double q3 = q * q * q; + double r = (b * (2 * a2 - 9 * c) + 27 * d) / 54; + double r2 = r * r; + + if (r2 < q3) + { + nRoots = 3; + + double t = r / Math.Sqrt(q3); + t = Math.Clamp(t, -1, 1); + t = Math.Acos(t); + b /= 3; + q = -2 * Math.Sqrt(q); + + xVals[0] = q * Math.Cos(t / 3) - b; + xVals[1] = q * Math.Cos((t + m_2_pi) / 3) - b; + xVals[2] = q * Math.Cos((t - m_2_pi) / 3) - b; + + return xVals; + } + + double upperA = -Math.Cbrt(Math.Abs(r) + Math.Sqrt(r2 - q3)); + + if (r < 0) + upperA = -upperA; + + double upperB = upperA == 0 ? 0 : q / upperA; + b /= 3; + + xVals[0] = upperA + upperB - b; + + if (Precision.AlmostEquals(0.5 * Math.Sqrt(3) * (upperA - upperB), 0)) + { + nRoots = 2; + xVals[1] = -0.5 * (upperA + upperB) - b; + + return xVals; + } + + nRoots = 1; + + return xVals; + } + + private static double?[] solveP2(double a, double b, double c, out int nRoots) + { + double?[] xVals = new double?[2]; + + nRoots = 0; + + if (a == 0) + { + if (b == 0) + return xVals; + + nRoots = 1; + xVals[0] = -c / b; + } + + double discriminant = b * b - 4 * a * c; + + switch (discriminant) + { + case < 0: + break; + + case 0: + nRoots = 1; + xVals[0] = -b / (2 * a); + break; + + default: + nRoots = 2; + xVals[0] = (-b + Math.Sqrt(discriminant)) / (2 * a); + xVals[1] = (-b - Math.Sqrt(discriminant)) / (2 * a); + break; + } + + return xVals; + } } } From 153ff3771544dbfb5e7b7f7c509be66ee172982a Mon Sep 17 00:00:00 2001 From: Nathen Date: Sun, 1 Dec 2024 11:33:40 -0500 Subject: [PATCH 14/14] Code quality tests --- osu.Game/Utils/SpecialFunctions.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/osu.Game/Utils/SpecialFunctions.cs b/osu.Game/Utils/SpecialFunctions.cs index bf7a6151ae..3d8433279f 100644 --- a/osu.Game/Utils/SpecialFunctions.cs +++ b/osu.Game/Utils/SpecialFunctions.cs @@ -23,7 +23,7 @@ namespace osu.Game.Utils { private const double sqrt2 = 1.4142135623730950488016887242096980785696718753769d; private const double sqrt2_pi = 2.5066282746310005024157652848110452530069867406099d; - private const double m_2_pi = 6.28318530717958647692528676655900576d; + private const double pi_mult_2 = 6.28318530717958647692528676655900576d; /// /// ************************************** @@ -905,8 +905,8 @@ namespace osu.Game.Utils q = -2 * Math.Sqrt(q); xVals[0] = q * Math.Cos(t / 3) - b; - xVals[1] = q * Math.Cos((t + m_2_pi) / 3) - b; - xVals[2] = q * Math.Cos((t - m_2_pi) / 3) - b; + xVals[1] = q * Math.Cos((t + pi_mult_2) / 3) - b; + xVals[2] = q * Math.Cos((t - pi_mult_2) / 3) - b; return xVals; }