1
0
mirror of https://github.com/ppy/osu.git synced 2024-12-05 09:42:54 +08:00

Refactoring and test solving polynomial algebraically

This commit is contained in:
Nathen 2024-04-30 12:44:45 -04:00
parent dd951400a4
commit 0e08858b17
7 changed files with 344 additions and 59 deletions

View File

@ -8,6 +8,7 @@ using Newtonsoft.Json;
using osu.Game.Beatmaps;
using osu.Game.Rulesets.Difficulty;
using osu.Game.Rulesets.Mods;
using osu.Game.Rulesets.Osu.Difficulty.Utils;
namespace osu.Game.Rulesets.Osu.Difficulty
{
@ -23,7 +24,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty
/// The difficulty corresponding to the aim skill.
/// </summary>
[JsonProperty("aim_penalty_constants")]
public (double, double) AimPenaltyConstants { get; set; }
public ExpPolynomial AimMissCountPolynomial { get; set; }
/// <summary>
/// The difficulty corresponding to the speed skill.

View File

@ -13,6 +13,7 @@ using osu.Game.Rulesets.Difficulty.Skills;
using osu.Game.Rulesets.Mods;
using osu.Game.Rulesets.Osu.Difficulty.Preprocessing;
using osu.Game.Rulesets.Osu.Difficulty.Skills;
using osu.Game.Rulesets.Osu.Difficulty.Utils;
using osu.Game.Rulesets.Osu.Mods;
using osu.Game.Rulesets.Osu.Objects;
using osu.Game.Rulesets.Osu.Scoring;
@ -37,7 +38,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty
return new OsuDifficultyAttributes { Mods = mods };
double aimRating = Math.Sqrt(skills[0].DifficultyValue()) * difficulty_multiplier;
(double, double) aimPenaltyConstants = ((Aim)skills[0]).GetMissCountCoefficients();
ExpPolynomial aimMissCountPolynomial = ((Aim)skills[0]).GetMissCountPolynomial();
double aimRatingNoSliders = Math.Sqrt(skills[1].DifficultyValue()) * difficulty_multiplier;
double speedRating = Math.Sqrt(skills[2].DifficultyValue()) * difficulty_multiplier;
double speedNotes = ((Speed)skills[2]).RelevantNoteCount();
@ -98,7 +99,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty
StarRating = starRating,
Mods = mods,
AimDifficulty = aimRating,
AimPenaltyConstants = aimPenaltyConstants,
AimMissCountPolynomial = aimMissCountPolynomial,
SpeedDifficulty = speedRating,
SpeedNoteCount = speedNotes,
FlashlightDifficulty = flashlightRating,

View File

@ -5,7 +5,6 @@ using System;
using System.Collections.Generic;
using System.Linq;
using osu.Game.Rulesets.Difficulty;
using osu.Game.Rulesets.Osu.Difficulty.Utils;
using osu.Game.Rulesets.Osu.Mods;
using osu.Game.Rulesets.Scoring;
using osu.Game.Scoring;
@ -244,11 +243,11 @@ namespace osu.Game.Rulesets.Osu.Difficulty
private double calculateAimMissPenalty(double missCount, OsuDifficultyAttributes attributes)
{
double a = attributes.AimPenaltyConstants.Item1;
double b = attributes.AimPenaltyConstants.Item2;
double c = Math.Log(totalHits + 1) - a - b; // Setting the 3rd constant this way ensures that at a penalty of 100%, the number of misses = totalHits.
double penalty = attributes.AimMissCountPolynomial.SolveBetweenZeroAndOne(missCount) ?? 1;
return Math.Pow(1 - RootFinding.FindRootExpand(x => a * x * x * x + b * x * x + c * x - Math.Log(missCount + 1), 0, 1), 1.5);
double multiplier = Math.Pow(1 - penalty, 1.5);
return multiplier;
}
private double calculateEffectiveMissCount(OsuDifficultyAttributes attributes)

View File

@ -51,7 +51,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills
/// The coefficients of a quartic fitted to the miss counts at each skill level.
/// </summary>
/// <returns>The coefficients for ax^4+bx^3+cx^2. The 4th coefficient for dx^1 can be deduced from the first 3 in the performance calculator.</returns>
public (double, double) GetMissCountCoefficients()
public ExpPolynomial GetMissCountPolynomial()
{
const int count = 21;
const double penalty_per_misscount = 1.0 / (count - 1);
@ -74,9 +74,11 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills
misscounts[i] = Math.Log(GetMissCountAtSkill(penalizedSkill) + 1);
}
double[] constants = FitMissCountPoints.GetPolynomialCoefficients(misscounts);
ExpPolynomial polynomial = new ExpPolynomial();
return (constants[0], constants[1]);
polynomial.Compute(misscounts, 3);
return polynomial;
}
}
}

View File

@ -0,0 +1,90 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
using System;
using System.Collections.Generic;
using System.Linq;
namespace osu.Game.Rulesets.Osu.Difficulty.Utils
{
public struct ExpPolynomial
{
private static double[]? coefficients;
// The product of this matrix with 21 computed points at X values [0.0, 0.05, ..., 0.95, 1.0] returns the least squares fit polynomial coefficients.
private static double[][] quarticMatrix => new[]
{
new[] { 0.0, -6.99428, -9.87548, -9.76922, -7.66867, -4.43461, -0.795376, 2.65313, 5.4474, 7.2564, 7.88146, 7.2564, 5.4474, 2.65313, -0.795376, -4.43461, -7.66867, -9.76922, -9.87548, -6.99428, 0.0 },
new[] { 0.0, 13.0907, 18.2388, 17.6639, 13.3211, 6.90022, -0.173479, -6.73969, -11.9029, -15.0326, -15.7629, -13.993, -9.88668, -3.87281, 3.35498, 10.8382, 17.3536, 21.4129, 21.2632, 14.8864, 0.0 },
new[] { 0.0, -7.21754, -9.85841, -9.24217, -6.5276, -2.71265, 1.36553, 5.03057, 7.76692, 9.21984, 9.19538, 7.66039, 4.74253, 0.730255, -3.92717, -8.61967, -12.5764, -14.8657, -14.395, -9.91114, 0.0 }
};
private static double[][] cubicMatrix => new[]
{
new[] { 0.0, -0.897868, -1.5122, -1.8745, -2.01626, -1.96901, -1.76423, -1.43344, -1.00813, -0.519818, 3.55271e-15, 0.519818, 1.00813, 1.43344, 1.76423, 1.96901, 2.01626, 1.8745, 1.5122, 0.897868, 0.0 },
new[] { 0.0, 1.27555, 2.1333, 2.62049, 2.78439, 2.67226, 2.33134, 1.8089, 1.1522, 0.408475, -0.375002, -1.15098, -1.8722, -2.49141, -2.96135, -3.23476, -3.2644, -3.00299, -2.4033, -1.41805, 0.0 },
};
/// <summary>
/// Computes a quartic or cubic function that starts at 0 and ends at the highest judgement count in the array.
/// </summary>
/// <param name="judgementCounts">A list of judgements, with X values [0.0, 0.05, ..., 0.95, 1.0].</param>
/// <param name="degree">The degree of the polynomial. Only supports cubic and quintic functions.</param>
public void Compute(double[] judgementCounts, int degree)
{
if (degree != 3 && degree != 4)
return;
double[] adjustedMissCounts = judgementCounts;
// The polynomial will pass through the point (1, maxMisscount).
double maxMissCount = judgementCounts.Max();
for (int i = 0; i <= 20; i++)
{
adjustedMissCounts[i] -= maxMissCount * i / 20;
}
// The precomputed matrix assumes the misscounts go in order of greatest to least.
// Temporary fix.
adjustedMissCounts = adjustedMissCounts.Reverse().ToArray();
double[][] matrix = degree == 4 ? quarticMatrix : cubicMatrix;
coefficients = new double[degree];
coefficients[degree - 1] = maxMissCount;
// Now we dot product the adjusted misscounts with the precomputed matrix.
for (int row = 0; row < matrix.Length; row++)
{
for (int column = 0; column < matrix[row].Length; column++)
{
coefficients[row] += matrix[row][column] * adjustedMissCounts[column];
}
coefficients[degree - 1] -= coefficients[row];
}
}
/// <summary>
/// Solve for the largest corresponding x value of a polynomial within x = 0 and x = 1 at a specified y value.
/// </summary>
/// <param name="y">A value between 0 and 1, inclusive, to solve the polynomial at.</param>
/// <returns>The x value at the specified y value, and null if no value exists.</returns>
public double? SolveBetweenZeroAndOne(double y)
{
if (coefficients is null)
return null;
List<double> listCoefficients = coefficients.ToList();
listCoefficients.Add(-Math.Log(y + 1));
List<double?> xVals = SpecialFunctions.SolvePolynomialRoots(listCoefficients);
const double max_error = 1e-7;
double? largestValue = xVals.Where(x => x >= 0 - max_error && x <= 1 + max_error).OrderDescending().FirstOrDefault();
return largestValue != null ? Math.Clamp(largestValue.Value, 0, 1) : null;
}
}
}

View File

@ -1,48 +0,0 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
using System.Linq;
namespace osu.Game.Rulesets.Osu.Difficulty.Utils
{
public class FitMissCountPoints
{
// A few operations are precomputed on the Vandermonde matrix of skill values (0, 0.05, 0.1, ..., 0.95, 1).
// This is to make regression simpler and less resource heavy. https://discord.com/channels/546120878908506119/1203154944492830791/1232333184306512002
private static double[][] precomputedOperationsMatrix => new[]
{
new[] { 0.0, -0.897868, -1.5122, -1.8745, -2.01626, -1.96901, -1.76423, -1.43344, -1.00813, -0.519818, 3.55271e-15, 0.519818, 1.00813, 1.43344, 1.76423, 1.96901, 2.01626, 1.8745, 1.5122, 0.897868, 0.0 },
new[] { 0.0, 1.27555, 2.1333, 2.62049, 2.78439, 2.67226, 2.33134, 1.8089, 1.1522, 0.408475, -0.375002, -1.15098, -1.8722, -2.49141, -2.96135, -3.23476, -3.2644, -3.00299, -2.4033, -1.41805, 0.0 },
};
public static double[] GetPolynomialCoefficients(double[] missCounts)
{
double[] adjustedMissCounts = missCounts;
// The polynomial will pass through the point (1, maxMisscount).
double maxMissCount = missCounts.Max();
for (int i = 0; i <= 20; i++)
{
adjustedMissCounts[i] -= maxMissCount * i / 20;
}
// The precomputed matrix assumes the misscounts go in order of greatest to least.
// Temporary fix.
adjustedMissCounts = adjustedMissCounts.Reverse().ToArray();
double[] coefficients = new double[2];
// Now we dot product the adjusted misscounts with the precomputed matrix.
for (int row = 0; row < precomputedOperationsMatrix.Length; row++)
{
for (int column = 0; column < precomputedOperationsMatrix[row].Length; column++)
{
coefficients[row] += precomputedOperationsMatrix[row][column] * adjustedMissCounts[column];
}
}
return coefficients;
}
}
}

View File

@ -13,6 +13,9 @@ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLI
*/
using System;
using System.Collections.Generic;
using System.Linq;
using osu.Framework.Utils;
namespace osu.Game.Rulesets.Osu.Difficulty.Utils
{
@ -20,6 +23,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils
{
private const double sqrt2 = 1.4142135623730950488016887242096980785696718753769d;
private const double sqrt2_pi = 2.5066282746310005024157652848110452530069867406099d;
private const double m_2_pi = 6.28318530717958647692528676655900576d;
/// <summary>
/// **************************************
@ -728,5 +732,241 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils
return 0.5 * Erfc((mean - x) / (stddev * sqrt2));
}
/// <summary>
/// Solve for the exact real roots of any polynomial up to degree 4.
/// </summary>
/// <param name="coefficients">The coefficients of the polynomial, in ascending order ([1, 3, 5] -> x^2 + 3x + 5).</param>
/// <returns>The real roots of the polynomial, and null if the root does not exist.</returns>
public static List<double?> SolvePolynomialRoots(List<double> coefficients)
{
List<double?> xVals = new List<double?>();
switch (coefficients.Count)
{
case 5:
xVals = solveP4(coefficients[0], coefficients[1], coefficients[2], coefficients[3], coefficients[4], out int _).ToList();
break;
case 4:
xVals = solveP3(coefficients[0], coefficients[1], coefficients[2], coefficients[3], out int _).ToList();
break;
case 3:
xVals = solveP2(coefficients[0], coefficients[1], coefficients[2], out int _).ToList();
break;
case 2:
xVals = solveP2(0, coefficients[1], coefficients[2], out int _).ToList();
break;
}
return xVals;
}
private static double?[] solveP4(double a, double b, double c, double d, double e, out int nRoots)
{
double?[] xVals = new double?[4];
nRoots = 0;
if (a == 0)
{
double?[] xValsCubic = solveP3(b, c, d, e, out nRoots);
xVals[0] = xValsCubic[0];
xVals[1] = xValsCubic[1];
xVals[2] = xValsCubic[2];
xVals[3] = null;
return xVals;
}
b /= a;
c /= a;
d /= a;
double a3 = -c;
double b3 = b * d - 4 * e;
double c3 = -b * b * e - d * d + 4 * c * e;
double?[] x3 = solveP3(1, a3, b3, c3, out int iZeroes);
double q1, q2, p1, p2, sqD;
double y = x3[0]!.Value;
// Get the y value with the highest absolute value.
if (iZeroes != 1)
{
if (Math.Abs(x3[1]!.Value) > Math.Abs(y))
y = x3[1]!.Value;
if (Math.Abs(x3[2]!.Value) > Math.Abs(y))
y = x3[2]!.Value;
}
double upperD = y * y - 4 * e;
if (Precision.AlmostEquals(upperD, 0))
{
q1 = q2 = y * 0.5;
upperD = b * b - 4 * (c - y);
if (Precision.AlmostEquals(upperD, 0))
p1 = p2 = b * 0.5;
else
{
sqD = Math.Sqrt(upperD);
p1 = (b + sqD) * 0.5;
p2 = (b - sqD) * 0.5;
}
}
else
{
sqD = Math.Sqrt(upperD);
q1 = (y + sqD) * 0.5;
q2 = (y - sqD) * 0.5;
p1 = (b * q1 - c) / (q1 - q2);
p2 = (d - b * q2) / (q1 - q2);
}
// solving quadratic eq. - x^2 + p1*x + q1 = 0
upperD = p1 * p1 - 4 * q1;
if (upperD >= 0)
{
nRoots += 2;
sqD = Math.Sqrt(upperD);
xVals[0] = (-p1 + sqD) * 0.5;
xVals[1] = (-p1 - sqD) * 0.5;
}
// solving quadratic eq. - x^2 + p2*x + q2 = 0
upperD = p2 * p2 - 4 * q2;
if (upperD >= 0)
{
nRoots += 2;
sqD = Math.Sqrt(upperD);
xVals[2] = (-p2 + sqD) * 0.5;
xVals[3] = (-p2 - sqD) * 0.5;
}
// Put the null roots at the end of the array.
var nonNulls = xVals.Where(x => x != null);
var nulls = xVals.Where(x => x == null);
xVals = nonNulls.Concat(nulls).ToArray();
return xVals;
}
private static double?[] solveP3(double a, double b, double c, double d, out int nRoots)
{
double?[] xVals = new double?[3];
nRoots = 0;
if (a == 0)
{
double?[] xValsQuadratic = solveP2(b, c, d, out nRoots);
xVals[0] = xValsQuadratic[0];
xVals[1] = xValsQuadratic[1];
xVals[2] = null;
return xVals;
}
b /= a;
c /= a;
d /= a;
double b2 = b * b;
double q = (b2 - 3 * c) / 9;
double q3 = q * q * q;
double r = (b * (2 * b2 - 9 * c) + 27 * d) / 54;
double r2 = r * r;
if (r2 < q3)
{
nRoots = 3;
double t = r / Math.Sqrt(q3);
t = Math.Clamp(t, -1, 1);
t = Math.Acos(t);
b /= 3;
q = -2 * Math.Sqrt(2);
xVals[0] = q * Math.Cos(t / 3) - b;
xVals[1] = q * Math.Cos((t + m_2_pi) / 3) - b;
xVals[2] = q * Math.Cos((t - m_2_pi) / 3) - b;
return xVals;
}
double upperA = -Math.Cbrt(Math.Abs(r) + double.Sqrt(r2 - q3));
if (r < 0)
upperA = -upperA;
double upperB = upperA == 0 ? 0 : q / upperA;
b /= 3;
xVals[0] = upperA + upperB - b;
if (Precision.AlmostEquals(0.5 * Math.Sqrt(3) * (upperA - upperB), 0))
{
nRoots = 2;
xVals[1] = -0.5 * (upperA + upperB) - b;
return xVals;
}
nRoots = 1;
return xVals;
}
private static double?[] solveP2(double a, double b, double c, out int nRoots)
{
double?[] xVals = new double?[2];
nRoots = 0;
if (a == 0)
{
if (b == 0)
return xVals;
nRoots = 1;
xVals[0] = -c / b;
}
double discriminant = b * b - 4 * a * c;
switch (discriminant)
{
case < 0:
break;
case 0:
nRoots = 1;
xVals[0] = -b / (2 * a);
break;
default:
nRoots = 2;
xVals[0] = (-b + Math.Sqrt(discriminant)) / (2 * a);
xVals[1] = (-b - Math.Sqrt(discriminant)) / (2 * a);
break;
}
return xVals;
}
}
}