1
0
mirror of https://github.com/ppy/osu.git synced 2024-12-05 09:42:54 +08:00
This commit is contained in:
Natelytle 2024-12-03 16:30:45 +08:00 committed by GitHub
commit 4f68a6e7e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 862 additions and 14 deletions

View File

@ -0,0 +1,155 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
using System;
using System.Collections.Generic;
using System.Linq;
using osu.Game.Rulesets.Difficulty.Preprocessing;
using osu.Game.Rulesets.Difficulty.Skills;
using osu.Game.Rulesets.Mods;
using osu.Game.Rulesets.Osu.Difficulty.Utils;
namespace osu.Game.Rulesets.Osu.Difficulty.Aggregation
{
public abstract class OsuProbabilitySkill : Skill
{
protected OsuProbabilitySkill(Mod[] mods)
: base(mods)
{
}
// We assume players have a 2% chance to hit every note in the map.
// A higher value of fc_probability increases the influence of difficulty spikes,
// while a lower value increases the influence of length and consistent difficulty.
private const double fc_probability = 0.02;
private const int bin_count = 32;
// The number of difficulties there must be before we can be sure that binning difficulties would not change the output significantly.
private double binThreshold => 2 * bin_count;
private readonly List<double> difficulties = new List<double>();
/// <summary>
/// Returns the strain value at <see cref="DifficultyHitObject"/>. This value is calculated with or without respect to previous objects.
/// </summary>
protected abstract double StrainValueAt(DifficultyHitObject current);
public override void Process(DifficultyHitObject current)
{
difficulties.Add(StrainValueAt(current));
}
protected abstract double HitProbability(double skill, double difficulty);
private double difficultyValueExact()
{
double maxDiff = difficulties.Max();
if (maxDiff <= 1e-10) return 0;
const double lower_bound = 0;
double upperBoundEstimate = 3.0 * maxDiff;
double skill = RootFinding.FindRootExpand(
skill => fcProbability(skill) - fc_probability,
lower_bound,
upperBoundEstimate,
accuracy: 1e-4);
return skill;
double fcProbability(double s)
{
if (s <= 0) return 0;
return difficulties.Aggregate<double, double>(1, (current, d) => current * HitProbability(s, d));
}
}
private double difficultyValueBinned()
{
double maxDiff = difficulties.Max();
if (maxDiff <= 1e-10) return 0;
var bins = Bin.CreateBins(difficulties, bin_count);
const double lower_bound = 0;
double upperBoundEstimate = 3.0 * maxDiff;
double skill = RootFinding.FindRootExpand(
skill => fcProbability(skill) - fc_probability,
lower_bound,
upperBoundEstimate,
accuracy: 1e-4);
return skill;
double fcProbability(double s)
{
if (s <= 0) return 0;
return bins.Aggregate(1.0, (current, bin) => current * Math.Pow(HitProbability(s, bin.Difficulty), bin.Count));
}
}
public override double DifficultyValue()
{
if (difficulties.Count == 0) return 0;
return difficulties.Count > binThreshold ? difficultyValueBinned() : difficultyValueExact();
}
/// <returns>
/// A polynomial fitted to the miss counts at each skill level.
/// </returns>
public ExpPolynomial GetMissPenaltyCurve()
{
double[] missCounts = new double[7];
double[] penalties = { 1, 0.95, 0.9, 0.8, 0.6, 0.3, 0 };
ExpPolynomial missPenaltyCurve = new ExpPolynomial();
// If there are no notes, we just return the curve with all coefficients set to zero.
if (difficulties.Count == 0 || difficulties.Max() == 0)
return missPenaltyCurve;
double fcSkill = DifficultyValue();
var bins = Bin.CreateBins(difficulties, bin_count);
for (int i = 0; i < penalties.Length; i++)
{
if (i == 0)
{
missCounts[i] = 0;
continue;
}
double penalizedSkill = fcSkill * penalties[i];
missCounts[i] = getMissCountAtSkill(penalizedSkill, bins);
}
missPenaltyCurve.Fit(missCounts);
return missPenaltyCurve;
}
/// <summary>
/// Find the lowest miss count that a player with the provided <paramref name="skill"/> would have a 2% chance of achieving or better.
/// </summary>
private double getMissCountAtSkill(double skill, List<Bin> bins)
{
double maxDiff = difficulties.Max();
if (maxDiff == 0)
return 0;
if (skill <= 0)
return difficulties.Count;
var poiBin = difficulties.Count > binThreshold ? new PoissonBinomial(bins, skill, HitProbability) : new PoissonBinomial(difficulties, skill, HitProbability);
return Math.Max(0, RootFinding.FindRootExpand(x => poiBin.CDF(x) - fc_probability, -50, 1000, accuracy: 1e-4));
}
}
}

View File

@ -8,7 +8,7 @@ using osu.Game.Rulesets.Mods;
using System.Linq;
using osu.Framework.Utils;
namespace osu.Game.Rulesets.Osu.Difficulty.Skills
namespace osu.Game.Rulesets.Osu.Difficulty.Aggregation
{
public abstract class OsuStrainSkill : StrainSkill
{

View File

@ -8,6 +8,7 @@ using Newtonsoft.Json;
using osu.Game.Beatmaps;
using osu.Game.Rulesets.Difficulty;
using osu.Game.Rulesets.Mods;
using osu.Game.Rulesets.Osu.Difficulty.Utils;
namespace osu.Game.Rulesets.Osu.Difficulty
{
@ -19,6 +20,12 @@ namespace osu.Game.Rulesets.Osu.Difficulty
[JsonProperty("aim_difficulty")]
public double AimDifficulty { get; set; }
/// <summary>
/// The difficulty corresponding to the aim skill.
/// </summary>
[JsonProperty("aim_penalty_constants")]
public ExpPolynomial AimMissPenaltyCurve { get; set; }
/// <summary>
/// The difficulty corresponding to the speed skill.
/// </summary>

View File

@ -11,8 +11,10 @@ using osu.Game.Rulesets.Difficulty;
using osu.Game.Rulesets.Difficulty.Preprocessing;
using osu.Game.Rulesets.Difficulty.Skills;
using osu.Game.Rulesets.Mods;
using osu.Game.Rulesets.Osu.Difficulty.Aggregation;
using osu.Game.Rulesets.Osu.Difficulty.Preprocessing;
using osu.Game.Rulesets.Osu.Difficulty.Skills;
using osu.Game.Rulesets.Osu.Difficulty.Utils;
using osu.Game.Rulesets.Osu.Mods;
using osu.Game.Rulesets.Osu.Objects;
using osu.Game.Rulesets.Osu.Scoring;
@ -48,7 +50,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty
double sliderFactor = aimRating > 0 ? aimRatingNoSliders / aimRating : 1;
double aimDifficultyStrainCount = ((OsuStrainSkill)skills[0]).CountTopWeightedStrains();
ExpPolynomial aimMissPenaltyCurve = ((OsuProbabilitySkill)skills[0]).GetMissPenaltyCurve();
double speedDifficultyStrainCount = ((OsuStrainSkill)skills[2]).CountTopWeightedStrains();
if (mods.Any(m => m is OsuModTouchDevice))
@ -103,7 +105,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty
SpeedNoteCount = speedNotes,
FlashlightDifficulty = flashlightRating,
SliderFactor = sliderFactor,
AimDifficultStrainCount = aimDifficultyStrainCount,
AimMissPenaltyCurve = aimMissPenaltyCurve,
SpeedDifficultStrainCount = speedDifficultyStrainCount,
ApproachRate = preempt > 1200 ? (1800 - preempt) / 120 : (1200 - preempt) / 150 + 5,
OverallDifficulty = (80 - hitWindowGreat) / 6,

View File

@ -5,7 +5,9 @@ using System;
using System.Collections.Generic;
using System.Linq;
using osu.Game.Rulesets.Difficulty;
using osu.Game.Rulesets.Osu.Difficulty.Aggregation;
using osu.Game.Rulesets.Osu.Difficulty.Skills;
using osu.Game.Rulesets.Osu.Difficulty.Utils;
using osu.Game.Rulesets.Osu.Mods;
using osu.Game.Rulesets.Scoring;
using osu.Game.Scoring;
@ -139,10 +141,9 @@ namespace osu.Game.Rulesets.Osu.Difficulty
double lengthBonus = 0.95 + 0.4 * Math.Min(1.0, totalHits / 2000.0) +
(totalHits > 2000 ? Math.Log10(totalHits / 2000.0) * 0.5 : 0.0);
aimValue *= lengthBonus;
if (effectiveMissCount > 0)
aimValue *= calculateMissPenalty(effectiveMissCount, attributes.AimDifficultStrainCount);
aimValue *= calculateCurveFittedMissPenalty(effectiveMissCount, attributes.AimMissPenaltyCurve);
double approachRateFactor = 0.0;
if (attributes.ApproachRate > 10.33)
@ -206,7 +207,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty
speedValue *= lengthBonus;
if (effectiveMissCount > 0)
speedValue *= calculateMissPenalty(effectiveMissCount, attributes.SpeedDifficultStrainCount);
speedValue *= calculateStrainCountMissPenalty(effectiveMissCount, attributes.SpeedDifficultStrainCount);
double approachRateFactor = 0.0;
if (attributes.ApproachRate > 10.33)
@ -305,10 +306,17 @@ namespace osu.Game.Rulesets.Osu.Difficulty
return flashlightValue;
}
// Miss penalty assumes that a player will miss on the hardest parts of a map,
// so we use the amount of relatively difficult sections to adjust miss penalty
// to make it more punishing on maps with lower amount of hard sections.
private double calculateMissPenalty(double missCount, double difficultStrainCount) => 0.96 / ((missCount / (4 * Math.Pow(Math.Log(difficultStrainCount), 0.94))) + 1);
// Due to the unavailability of miss location in PP, the following formulas assume that a player will miss on the hardest parts of a map.
// With the curve fitted miss penalty, we use a pre-computed curve of skill levels for each miss count, raised to the power of 1.5 as
// the multiple of the exponents on star rating and PP. This power should be changed if either SR or PP begin to use a different exponent.
// As a result, this exponent is not subject to balance.
private double calculateCurveFittedMissPenalty(double missCount, ExpPolynomial curve) => Math.Pow(1 - curve.GetPenaltyAt(missCount), 1.5);
// With the strain count miss penalty, we use the amount of relatively difficult sections to adjust the miss penalty,
// to make it more punishing on maps with lower amount of hard sections. This formula is subject to balance.
private double calculateStrainCountMissPenalty(double missCount, double difficultStrainCount) => 0.96 / (missCount / (4 * Math.Pow(Math.Log(difficultStrainCount), 0.94)) + 1);
private double getComboScalingFactor(OsuDifficultyAttributes attributes) => attributes.MaxCombo <= 0 ? 1.0 : Math.Min(Math.Pow(scoreMaxCombo, 0.8) / Math.Pow(attributes.MaxCombo, 0.8), 1.0);
private int totalHits => countGreat + countOk + countMeh + countMiss;
private int totalImperfectHits => countOk + countMeh + countMiss;

View File

@ -4,14 +4,16 @@
using System;
using osu.Game.Rulesets.Difficulty.Preprocessing;
using osu.Game.Rulesets.Mods;
using osu.Game.Rulesets.Osu.Difficulty.Aggregation;
using osu.Game.Rulesets.Osu.Difficulty.Evaluators;
using osu.Game.Utils;
namespace osu.Game.Rulesets.Osu.Difficulty.Skills
{
/// <summary>
/// Represents the skill required to correctly aim at every object in the map with a uniform CircleSize and normalized distances.
/// </summary>
public class Aim : OsuStrainSkill
public class Aim : OsuProbabilitySkill
{
public Aim(Mod[] mods, bool withSliders)
: base(mods)
@ -23,12 +25,18 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills
private double currentStrain;
private double skillMultiplier => 25.18;
private double skillMultiplier => 132;
private double strainDecayBase => 0.15;
private double strainDecay(double ms) => Math.Pow(strainDecayBase, ms / 1000);
protected override double HitProbability(double skill, double difficulty)
{
if (difficulty <= 0) return 1;
if (skill <= 0) return 0;
protected override double CalculateInitialStrain(double time, DifficultyHitObject current) => currentStrain * strainDecay(time - current.Previous(0).StartTime);
return SpecialFunctions.Erf(skill / (Math.Sqrt(2) * difficulty));
}
private double strainDecay(double ms) => Math.Pow(strainDecayBase, ms / 1000);
protected override double StrainValueAt(DifficultyHitObject current)
{

View File

@ -7,6 +7,7 @@ using osu.Game.Rulesets.Mods;
using osu.Game.Rulesets.Osu.Difficulty.Evaluators;
using osu.Game.Rulesets.Osu.Difficulty.Preprocessing;
using System.Linq;
using osu.Game.Rulesets.Osu.Difficulty.Aggregation;
namespace osu.Game.Rulesets.Osu.Difficulty.Skills
{

View File

@ -0,0 +1,62 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
using System;
using System.Collections.Generic;
using System.Linq;
namespace osu.Game.Rulesets.Osu.Difficulty.Utils
{
public struct Bin
{
public double Difficulty;
public double Count;
/// <summary>
/// Create an array of spaced bins. Count is linearly interpolated into each bin.
/// For example, if we have bins with values [1,2,3,4,5] and want to insert the value 3.2,
/// we will add 0.8 to the count of 3's and 0.2 to the count of 4's
/// </summary>
public static List<Bin> CreateBins(List<double> difficulties, int totalBins)
{
double maxDifficulty = difficulties.Max();
var binsArray = new Bin[totalBins];
for (int i = 0; i < totalBins; i++)
{
binsArray[i].Difficulty = maxDifficulty * (i + 1) / totalBins;
}
foreach (double d in difficulties)
{
double binIndex = totalBins * (d / maxDifficulty) - 1;
int lowerBound = (int)Math.Floor(binIndex);
double t = binIndex - lowerBound;
//This can be -1, corresponding to the zero difficulty bucket.
//We don't store that since it doesn't contribute to difficulty
if (lowerBound >= 0)
{
binsArray[lowerBound].Count += (1 - t);
}
int upperBound = lowerBound + 1;
// this can be == bin_count for the maximum difficulty object, in which case t will be 0 anyway
if (upperBound < totalBins)
{
binsArray[upperBound].Count += t;
}
}
var binsList = binsArray.ToList();
// For a slight performance improvement, we remove bins that don't contribute to difficulty.
binsList.RemoveAll(bin => bin.Count == 0);
return binsList;
}
}
}

View File

@ -0,0 +1,86 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
using System;
using System.Collections.Generic;
using System.Linq;
using osu.Game.Utils;
namespace osu.Game.Rulesets.Osu.Difficulty.Utils
{
/// <summary>
/// Represents a polynomial fitted to the logarithm of a given set of points.
/// The resulting polynomial is exponentiated, ensuring low residuals for
/// small inputs while handling exponentially increasing trends in the data.
/// This approach is useful for modelling the results of decreasing skill with few coefficients,
/// as linear decreases in skill correspond with exponential increases in miss counts.
/// </summary>
public struct ExpPolynomial
{
private double[]? coefficients;
// The matrix that minimizes the square error at X values [0.0, 0.30, 0.60, 0.80, 0.90, 0.95, 1.0].
private static readonly double[][] matrix =
{
new[] { 0.0, 3.14395, 5.18439, 6.46975, 1.4638, -9.53526, 0.0 },
new[] { 0.0, -4.85829, -8.09612, -10.4498, -3.84479, 12.1626, 0.0 }
};
/// <summary>
/// Computes the coefficients of a quartic polynomial, starting at 0 and ending at the highest miss count in the array.
/// </summary>
/// <param name="missCounts">A list of miss counts, with X values [1, 0.95, 0.9, 0.8, 0.6, 0.3, 0] corresponding to their skill levels.</param>
public void Fit(double[] missCounts)
{
List<double> logMissCounts = missCounts.Select(x => Math.Log(x + 1)).ToList();
double endPoint = logMissCounts.Max();
double[] penalties = { 1, 0.95, 0.9, 0.8, 0.6, 0.3, 0 };
for (int i = 0; i < logMissCounts.Count; i++)
{
logMissCounts[i] -= endPoint * (1 - penalties[i]);
}
// The precomputed matrix assumes the miss counts go in order of greatest to least.
logMissCounts.Reverse();
coefficients = new double[4];
coefficients[3] = endPoint;
// Now we dot product the adjusted miss counts with the precomputed matrix.
for (int row = 0; row < matrix.Length; row++)
{
for (int column = 0; column < matrix[row].Length; column++)
{
coefficients[row] += matrix[row][column] * logMissCounts[column];
}
coefficients[2] -= coefficients[row];
}
}
/// <summary>
/// Solve for the miss penalty at a specified miss count.
/// </summary>
public double GetPenaltyAt(double missCount)
{
if (coefficients is null)
return 1;
List<double> listCoefficients = coefficients.ToList();
listCoefficients.Add(-Math.Log(missCount + 1));
List<double?> xVals = SpecialFunctions.SolvePolynomialRoots(listCoefficients);
const double max_error = 1e-7;
// We find the largest value of x (corresponding to the penalty) found as a root of the function, with a fallback of a 100% penalty if no roots were found.
double largestValue = xVals.Where(x => x >= 0 - max_error && x <= 1 + max_error).OrderDescending().FirstOrDefault() ?? 1;
return Math.Clamp(largestValue, 0, 1);
}
}
}

View File

@ -0,0 +1,120 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
using System;
using System.Collections.Generic;
using osu.Game.Utils;
namespace osu.Game.Rulesets.Osu.Difficulty.Utils
{
/// <summary>
/// Approximation of the Poisson binomial distribution:
/// https://en.wikipedia.org/wiki/Poisson_binomial_distribution
/// </summary>
/// <remarks>
/// <para>
/// For the approximation method, see "Refined Normal Approximation (RNA)" from:
/// Hong, Y. (2013). On computing the distribution function for the Poisson binomial distribution. Computational Statistics and Data Analysis, Vol. 59, pp. 41-51.
/// (https://www.researchgate.net/publication/257017356_On_computing_the_distribution_function_for_the_Poisson_binomial_distribution)
/// </para>
/// <para>
/// This has been verified against a reference implementation provided by the authors in the R package "poibin",
/// which can be viewed here:
/// https://rdrr.io/cran/poibin/man/poibin-package.html
/// </para>
/// </remarks>
public class PoissonBinomial
{
/// <summary>
/// The expected value of the distribution.
/// </summary>
private readonly double mu;
/// <summary>
/// The standard deviation of the distribution.
/// </summary>
private readonly double sigma;
/// <summary>
/// The gamma factor from equation (11) in the cited paper, pre-divided by 6 to save on re-computation.
/// </summary>
private readonly double v;
/// <summary>
/// Creates a Poisson binomial distribution based on N trials with the provided difficulties, skill, and method for getting the miss probabilities.
/// </summary>
/// <param name="difficulties">The list of difficulties in the map.</param>
/// <param name="skill">The skill level to get the miss probabilities with.</param>
/// <param name="hitProbability">Converts difficulties and skill to miss probabilities.</param>
public PoissonBinomial(IList<double> difficulties, double skill, Func<double, double, double> hitProbability)
{
double variance = 0;
double gamma = 0;
foreach (double d in difficulties)
{
double p = 1 - hitProbability(skill, d);
mu += p;
variance += p * (1 - p);
gamma += p * (1 - p) * (1 - 2 * p);
}
sigma = Math.Sqrt(variance);
v = gamma / (6 * Math.Pow(sigma, 3));
}
/// <summary>
/// Creates a Poisson binomial distribution based on N trials with the provided bins of difficulties, skill, and method for getting the miss probabilities.
/// </summary>
/// <param name="bins">The bins of difficulties in the map.</param>
/// <param name="skill">The skill level to get the miss probabilities with.</param>
/// /// <param name="hitProbability">Converts difficulties and skill to miss probabilities.</param>
public PoissonBinomial(List<Bin> bins, double skill, Func<double, double, double> hitProbability)
{
double variance = 0;
double gamma = 0;
foreach (Bin bin in bins)
{
double p = 1 - hitProbability(skill, bin.Difficulty);
mu += p * bin.Count;
variance += p * (1 - p) * bin.Count;
gamma += p * (1 - p) * (1 - 2 * p) * bin.Count;
}
sigma = Math.Sqrt(variance);
v = gamma / (6 * Math.Pow(sigma, 3));
}
/// <summary>
/// Computes the value of the cumulative distribution function for this Poisson binomial distribution.
/// </summary>
/// <param name="count">
/// The argument of the CDF to sample the distribution for.
/// In the discrete case (when it is a whole number), this corresponds to the number
/// of successful Bernoulli trials to query the CDF for.
/// </param>
/// <returns>
/// The value of the CDF at <paramref name="count"/>.
/// In the discrete case this corresponds to the probability that at most <paramref name="count"/>
/// Bernoulli trials ended in a success.
/// </returns>
// ReSharper disable once InconsistentNaming
public double CDF(double count)
{
if (sigma == 0)
return 1;
double k = (count + 0.5 - mu) / sigma;
// see equation (14) of the cited paper
double result = SpecialFunctions.NormalCdf(0, 1, k) + v * (1 - k * k) * SpecialFunctions.NormalPdf(0, 1, k);
return Math.Clamp(result, 0, 1);
}
}
}

View File

@ -0,0 +1,119 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
using System;
namespace osu.Game.Rulesets.Osu.Difficulty.Utils
{
public static class RootFinding
{
/// <summary>
/// Finds the root of a <paramref name="function"/> using the Chandrupatla method, expanding the bounds if the root is not located within.
/// Expansion only occurs for the upward bound, as this function is optimized for functions of range [0, x),
/// which is useful for finding skill level (skill can never be below 0).
/// </summary>
/// <param name="function">The function of which to find the root.</param>
/// <param name="guessLowerBound">The lower bound of the function inputs.</param>
/// <param name="guessUpperBound">The upper bound of the function inputs.</param>
/// <param name="maxIterations">The maximum number of iterations before the function throws an error.</param>
/// <param name="accuracy">The desired precision in which the root is returned.</param>
/// <param name="expansionFactor">The multiplier on the upper bound when no root is found within the provided bounds.</param>
/// <param name="maxExpansions">The maximum number of times the bounds of the function should increase.</param>
public static double FindRootExpand(Func<double, double> function, double guessLowerBound, double guessUpperBound, int maxIterations = 25, double accuracy = 1e-6D, double expansionFactor = 2, double maxExpansions = 32)
{
double a = guessLowerBound;
double b = guessUpperBound;
double fa = function(a);
double fb = function(b);
int expansions = 0;
while (fa * fb > 0)
{
a = b;
b *= expansionFactor;
fa = function(a);
fb = function(b);
expansions++;
if (expansions > maxExpansions)
{
throw new MaximumIterationsException("No root was found within the provided function.");
}
}
double t = 0.5;
for (int i = 0; i < maxIterations; i++)
{
double xt = a + t * (b - a);
double ft = function(xt);
double c;
double fc;
if (Math.Sign(ft) == Math.Sign(fa))
{
c = a;
fc = fa;
}
else
{
c = b;
b = a;
fc = fb;
fb = fa;
}
a = xt;
fa = ft;
double xm, fm;
if (Math.Abs(fa) < Math.Abs(fb))
{
xm = a;
fm = fa;
}
else
{
xm = b;
fm = fb;
}
if (fm == 0)
return xm;
double tol = 2 * accuracy * Math.Abs(xm) + 2 * accuracy;
double tlim = tol / Math.Abs(b - c);
if (tlim > 0.5)
{
return xm;
}
double chi = (a - b) / (c - b);
double phi = (fa - fb) / (fc - fb);
bool iqi = phi * phi < chi && (1 - phi) * (1 - phi) < chi;
if (iqi)
t = fa / (fb - fa) * fc / (fb - fc) + (c - a) / (b - a) * fa / (fc - fa) * fb / (fc - fb);
else
t = 0.5;
t = Math.Min(1 - tlim, Math.Max(tlim, t));
}
return 0;
}
private class MaximumIterationsException : Exception
{
public MaximumIterationsException(string message)
: base(message)
{
}
}
}
}

View File

@ -13,12 +13,17 @@ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLI
*/
using System;
using System.Collections.Generic;
using System.Linq;
using osu.Framework.Utils;
namespace osu.Game.Utils
{
public class SpecialFunctions
{
private const double sqrt2 = 1.4142135623730950488016887242096980785696718753769d;
private const double sqrt2_pi = 2.5066282746310005024157652848110452530069867406099d;
private const double pi_mult_2 = 6.28318530717958647692528676655900576d;
/// <summary>
/// **************************************
@ -690,5 +695,280 @@ namespace osu.Game.Utils
return sum;
}
/// <summary>
/// Computes the probability density of the distribution (PDF) at x, i.e. ∂P(X ≤ x)/∂x.
/// </summary>
/// <param name="mean">The mean (μ) of the normal distribution.</param>
/// <param name="stddev">The standard deviation (σ) of the normal distribution. Range: σ ≥ 0.</param>
/// <param name="x">The location at which to compute the density.</param>
/// <returns>the density at <paramref name="x"/>.</returns>
/// <remarks>MATLAB: normpdf</remarks>
public static double NormalPdf(double mean, double stddev, double x)
{
if (stddev < 0.0)
{
throw new ArgumentException("Invalid parametrization for the distribution.");
}
double d = (x - mean) / stddev;
return Math.Exp(-0.5 * d * d) / (sqrt2_pi * stddev);
}
/// <summary>
/// Computes the cumulative distribution (CDF) of the distribution at x, i.e. P(X ≤ x).
/// </summary>
/// <param name="x">The location at which to compute the cumulative distribution function.</param>
/// <param name="mean">The mean (μ) of the normal distribution.</param>
/// <param name="stddev">The standard deviation (σ) of the normal distribution. Range: σ ≥ 0.</param>
/// <returns>the cumulative distribution at location <paramref name="x"/>.</returns>
/// <remarks>MATLAB: normcdf</remarks>
public static double NormalCdf(double mean, double stddev, double x)
{
if (stddev < 0.0)
{
throw new ArgumentException("Invalid parametrization for the distribution.");
}
return 0.5 * Erfc((mean - x) / (stddev * sqrt2));
}
/// <summary>
/// Solve for the exact real roots of any polynomial up to degree 4.
/// </summary>
/// <param name="coefficients">The coefficients of the polynomial, in ascending order ([1, 3, 5] -> x^2 + 3x + 5).</param>
/// <returns>The real roots of the polynomial, and null if the root does not exist.</returns>
public static List<double?> SolvePolynomialRoots(List<double> coefficients)
{
List<double?> xVals = new List<double?>();
switch (coefficients.Count)
{
case 5:
xVals = solveP4(coefficients[0], coefficients[1], coefficients[2], coefficients[3], coefficients[4], out int _).ToList();
break;
case 4:
xVals = solveP3(coefficients[0], coefficients[1], coefficients[2], coefficients[3], out int _).ToList();
break;
case 3:
xVals = solveP2(coefficients[0], coefficients[1], coefficients[2], out int _).ToList();
break;
case 2:
xVals = solveP2(0, coefficients[1], coefficients[2], out int _).ToList();
break;
}
return xVals;
}
// https://github.com/sasamil/Quartic/blob/master/quartic.cpp
private static double?[] solveP4(double a, double b, double c, double d, double e, out int nRoots)
{
double?[] xVals = new double?[4];
nRoots = 0;
if (a == 0)
{
double?[] xValsCubic = solveP3(b, c, d, e, out nRoots);
xVals[0] = xValsCubic[0];
xVals[1] = xValsCubic[1];
xVals[2] = xValsCubic[2];
xVals[3] = null;
return xVals;
}
b /= a;
c /= a;
d /= a;
e /= a;
double a3 = -c;
double b3 = b * d - 4 * e;
double c3 = -b * b * e - d * d + 4 * c * e;
double?[] x3 = solveP3(1, a3, b3, c3, out int iZeroes);
double q1, q2, p1, p2, sqD;
double y = x3[0]!.Value;
// Get the y value with the highest absolute value.
if (iZeroes != 1)
{
if (Math.Abs(x3[1]!.Value) > Math.Abs(y))
y = x3[1]!.Value;
if (Math.Abs(x3[2]!.Value) > Math.Abs(y))
y = x3[2]!.Value;
}
double upperD = y * y - 4 * e;
if (Precision.AlmostEquals(upperD, 0))
{
q1 = q2 = y * 0.5;
upperD = b * b - 4 * (c - y);
if (Precision.AlmostEquals(upperD, 0))
p1 = p2 = b * 0.5;
else
{
sqD = Math.Sqrt(upperD);
p1 = (b + sqD) * 0.5;
p2 = (b - sqD) * 0.5;
}
}
else
{
sqD = Math.Sqrt(upperD);
q1 = (y + sqD) * 0.5;
q2 = (y - sqD) * 0.5;
p1 = (b * q1 - d) / (q1 - q2);
p2 = (d - b * q2) / (q1 - q2);
}
// solving quadratic eq. - x^2 + p1*x + q1 = 0
upperD = p1 * p1 - 4 * q1;
if (upperD >= 0)
{
nRoots += 2;
sqD = Math.Sqrt(upperD);
xVals[0] = (-p1 + sqD) * 0.5;
xVals[1] = (-p1 - sqD) * 0.5;
}
// solving quadratic eq. - x^2 + p2*x + q2 = 0
upperD = p2 * p2 - 4 * q2;
if (upperD >= 0)
{
nRoots += 2;
sqD = Math.Sqrt(upperD);
xVals[2] = (-p2 + sqD) * 0.5;
xVals[3] = (-p2 - sqD) * 0.5;
}
// Put the null roots at the end of the array.
var nonNulls = xVals.Where(x => x != null);
var nulls = xVals.Where(x => x == null);
xVals = nonNulls.Concat(nulls).ToArray();
return xVals;
}
private static double?[] solveP3(double a, double b, double c, double d, out int nRoots)
{
double?[] xVals = new double?[3];
nRoots = 0;
if (a == 0)
{
double?[] xValsQuadratic = solveP2(b, c, d, out nRoots);
xVals[0] = xValsQuadratic[0];
xVals[1] = xValsQuadratic[1];
xVals[2] = null;
return xVals;
}
b /= a;
c /= a;
d /= a;
double a2 = b * b;
double q = (a2 - 3 * c) / 9;
double q3 = q * q * q;
double r = (b * (2 * a2 - 9 * c) + 27 * d) / 54;
double r2 = r * r;
if (r2 < q3)
{
nRoots = 3;
double t = r / Math.Sqrt(q3);
t = Math.Clamp(t, -1, 1);
t = Math.Acos(t);
b /= 3;
q = -2 * Math.Sqrt(q);
xVals[0] = q * Math.Cos(t / 3) - b;
xVals[1] = q * Math.Cos((t + pi_mult_2) / 3) - b;
xVals[2] = q * Math.Cos((t - pi_mult_2) / 3) - b;
return xVals;
}
double upperA = -Math.Cbrt(Math.Abs(r) + Math.Sqrt(r2 - q3));
if (r < 0)
upperA = -upperA;
double upperB = upperA == 0 ? 0 : q / upperA;
b /= 3;
xVals[0] = upperA + upperB - b;
if (Precision.AlmostEquals(0.5 * Math.Sqrt(3) * (upperA - upperB), 0))
{
nRoots = 2;
xVals[1] = -0.5 * (upperA + upperB) - b;
return xVals;
}
nRoots = 1;
return xVals;
}
private static double?[] solveP2(double a, double b, double c, out int nRoots)
{
double?[] xVals = new double?[2];
nRoots = 0;
if (a == 0)
{
if (b == 0)
return xVals;
nRoots = 1;
xVals[0] = -c / b;
}
double discriminant = b * b - 4 * a * c;
switch (discriminant)
{
case < 0:
break;
case 0:
nRoots = 1;
xVals[0] = -b / (2 * a);
break;
default:
nRoots = 2;
xVals[0] = (-b + Math.Sqrt(discriminant)) / (2 * a);
xVals[1] = (-b - Math.Sqrt(discriminant)) / (2 * a);
break;
}
return xVals;
}
}
}