Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions Maths/EulerMethod.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package Maths;

import java.util.ArrayList;
import java.util.function.BiFunction;

/**
* In mathematics and computational science, the Euler method (also called forward Euler method) is
* a first-order numerical procedure for solving ordinary differential equations (ODEs) with a given
* initial value. It is the most basic explicit method for numerical integration of ordinary
* differential equations. The method proceeds in a series of steps. At each step the y-value is
* calculated by evaluating the differential equation at the previous step, multiplying the result
* with the step-size and adding it to the last y-value: y_n+1 = y_n + stepSize * f(x_n, y_n).
* (description adapted from https://en.wikipedia.org/wiki/Euler_method ) (see also:
* https://www.geeksforgeeks.org/euler-method-solving-differential-equation/ )
*/
public class EulerMethod {

/** Illustrates how the algorithm is used in 3 examples and prints the results to the console. */
public static void main(String[] args) {
System.out.println("example 1:");
BiFunction<Double, Double, Double> exampleEquation1 = (x, y) -> x;
ArrayList<double[]> points1 = eulerFull(0, 4, 0.1, 0, exampleEquation1);
assert points1.get(points1.size() - 1)[1] == 7.800000000000003;
points1.forEach(
point -> System.out.println(String.format("x: %1$f; y: %2$f", point[0], point[1])));

// example from https://en.wikipedia.org/wiki/Euler_method
System.out.println("\n\nexample 2:");
BiFunction<Double, Double, Double> exampleEquation2 = (x, y) -> y;
ArrayList<double[]> points2 = eulerFull(0, 4, 0.1, 1, exampleEquation2);
assert points2.get(points2.size() - 1)[1] == 45.25925556817596;
points2.forEach(
point -> System.out.println(String.format("x: %1$f; y: %2$f", point[0], point[1])));

// example from https://www.geeksforgeeks.org/euler-method-solving-differential-equation/
System.out.println("\n\nexample 3:");
BiFunction<Double, Double, Double> exampleEquation3 = (x, y) -> x + y + x * y;
ArrayList<double[]> points3 = eulerFull(0, 0.1, 0.025, 1, exampleEquation3);
assert points3.get(points3.size() - 1)[1] == 1.1116729841674804;
points3.forEach(
point -> System.out.println(String.format("x: %1$f; y: %2$f", point[0], point[1])));
}

/**
* calculates the next y-value based on the current value of x, y and the stepSize the console.
*
* @param xCurrent Current x-value.
* @param stepSize Step-size on the x-axis.
* @param yCurrent Current y-value.
* @param differentialEquation The differential equation to be solved.
* @return The next y-value.
*/
public static double eulerStep(
double xCurrent,
double stepSize,
double yCurrent,
BiFunction<Double, Double, Double> differentialEquation) {
if (stepSize <= 0) {
throw new IllegalArgumentException("stepSize should be greater than zero");
}
double yNext = yCurrent + stepSize * differentialEquation.apply(xCurrent, yCurrent);
return yNext;
}

/**
* Loops through all the steps until xEnd is reached, adds a point for each step and then returns
* all the points
*
* @param xStart First x-value.
* @param xEnd Last x-value.
* @param stepSize Step-size on the x-axis.
* @param yStart First y-value.
* @param differentialEquation The differential equation to be solved.
* @return The points constituting the solution of the differential equation.
*/
public static ArrayList<double[]> eulerFull(
double xStart,
double xEnd,
double stepSize,
double yStart,
BiFunction<Double, Double, Double> differentialEquation) {
if (xStart >= xEnd) {
throw new IllegalArgumentException("xEnd should be greater than xStart");
}
if (stepSize <= 0) {
throw new IllegalArgumentException("stepSize should be greater than zero");
}

ArrayList<double[]> points = new ArrayList<double[]>();
double[] firstPoint = {xStart, yStart};
points.add(firstPoint);
double yCurrent = yStart;
double xCurrent = xStart;

while (xCurrent < xEnd) {
// Euler method for next step
yCurrent = eulerStep(xCurrent, stepSize, yCurrent, differentialEquation);
xCurrent += stepSize;
double[] point = {xCurrent, yCurrent};
points.add(point);
}

return points;
}
}