'''A root-finding routine.

See "All Problems Are Simple" by Jack Crenshaw, Embedded Systems
Programming, May 2002, pp. 7-14 (jcrens@earthlink.com).  The code can be
downloaded from www.embedded.com/code.htm.  Translated from Crenshaw's C
code and modified by Don Peterson, 20 May 2003.

Crenshaw states that this routine converges rapidly on most functions,
typically adding 4 digits to the solution on each iteration.  The method
is called "inverse parabolic interpolation".  The routine starts with the
bracketing points x0 and x2 and finds a third point x1 between them by
bisection.  The function is evaluated at all three points, and a
horizontally opening parabola (x as a quadratic function of y) is fitted
through them.  The abscissa where that parabola crosses y = 0 becomes the
new root estimate, and the iteration is repeated.

The function root_find() finds a root of the function f(x) in the
interval [x0, x2].  We must have x0 < x2 and f(x0)*f(x2) < 0; if either
endpoint is already a root, it is returned immediately.  The root value
is returned.

31 Dec 2005: Added the root_find_D() function, which performs the same
calculation using Python's Decimal objects.  This lets you find roots to
any desired accuracy if you're willing to wait.
'''

import math
from decimal import Decimal, getcontext


def root_find(x0, x2, f, eps, itmax):
    '''Find a root of f that lies between x0 and x2.

    f is the function to evaluate; it takes one float argument and
    returns a float.  eps is the precision to find the root to and itmax
    is the maximum number of iterations allowed.

    Returns a tuple (x, numits), where x is the root and numits is the
    number of iterations taken.  The routine raises an exception if it
    receives bad input data or doesn't converge.
    '''
    assert x0 < x2
    assert eps > 0.0
    assert itmax > 0
    y0 = f(x0)
    if y0 == 0.0:
        return x0, 0
    y2 = f(x2)
    if y2 == 0.0:
        return x2, 0
    if y2 * y0 > 0.0:
        raise ValueError("Bad data: y0 = %f, y2 = %f" % (y0, y2))
    # Previous interpolated estimate.  Start with None (rather than
    # overwriting the caller's x0) so the first estimate can never be
    # mistaken for convergence.
    xmlast = None
    for ix in range(itmax):
        # Bisect the current bracket.
        x1 = 0.5 * (x2 + x0)
        y1 = f(x1)
        if (y1 == 0.0) or (math.fabs(x1 - x0) < eps):
            return x1, ix + 1
        # Make y0 and y1 have opposite signs, so the root lies between
        # x0 and x1, by swapping the endpoints if needed.
        if y1 * y0 > 0.0:
            x0, x2 = x2, x0
            y0, y2 = y2, y0
        y10 = y1 - y0
        y21 = y2 - y1
        y20 = y2 - y0
        if y2 * y20 < 2.0 * y1 * y10:
            # The parabola is not trusted here; take the bisection step.
            x2 = x1
            y2 = y1
        else:
            # Inverse parabolic interpolation: fit x as a quadratic in y
            # through the three points and evaluate it at y = 0.
            b = (x1 - x0) / y10
            c = (y10 - y21) / (y21 * y20)
            xm = x0 - b * y0 * (1.0 - c * y1)
            ym = f(xm)
            if ym == 0.0 or (xmlast is not None
                             and math.fabs(xm - xmlast) < eps):
                return xm, ix + 1
            xmlast = xm
            # Shrink the bracket to whichever side of xm holds the root.
            if ym * y0 < 0.0:
                x2 = xm
                y2 = ym
            else:
                x0 = xm
                y0 = ym
                x2 = x1
                y2 = y1
    raise RuntimeError("No convergence")


def root_find_D(x0, x2, f, eps, itmax):
    '''Find a root of f that lies between x0 and x2, using Decimals.

    f is the function to evaluate; it takes one Decimal argument and
    returns a Decimal.  eps is the Decimal precision to find the root to
    and itmax is the maximum number of iterations allowed.

    Returns a tuple (x, numits), where x is the root (a Decimal) and
    numits is the number of iterations taken.  The routine raises an
    exception if it receives bad input data or doesn't converge.
    '''
    # Same algorithm as root_find(), carried out in Decimal arithmetic.
    zero = Decimal(0)
    one = Decimal(1)
    two = Decimal(2)
    assert x0 < x2
    assert eps > zero
    assert itmax > 0
    y0 = f(x0)
    if y0 == zero:
        return x0, 0
    y2 = f(x2)
    if y2 == zero:
        return x2, 0
    if y2 * y0 > zero:
        raise ValueError("Bad data: y0 = %s, y2 = %s" % (y0, y2))
    xmlast = None
    for ix in range(itmax):
        x1 = (x2 + x0) / two
        y1 = f(x1)
        if (y1 == zero) or (abs(x1 - x0) < eps):
            return x1, ix + 1
        if y1 * y0 > zero:
            x0, x2 = x2, x0
            y0, y2 = y2, y0
        y10 = y1 - y0
        y21 = y2 - y1
        y20 = y2 - y0
        if y2 * y20 < two * y1 * y10:
            x2 = x1
            y2 = y1
        else:
            b = (x1 - x0) / y10
            c = (y10 - y21) / (y21 * y20)
            xm = x0 - b * y0 * (one - c * y1)
            ym = f(xm)
            if ym == zero or (xmlast is not None
                              and abs(xm - xmlast) < eps):
                return xm, ix + 1
            xmlast = xm
            if ym * y0 < zero:
                x2 = xm
                y2 = ym
            else:
                x0 = xm
                y0 = ym
                x2 = x1
                y2 = y1
    raise RuntimeError("No convergence")


if __name__ == "__main__":
    # Here's a quick test of the routine.  The function is the polynomial
    # x^8 - 2 = 0; we should get as an answer the 8th root of 2.  You
    # should see the following output:
    #
    #   Calculated root = 1.090507732665258
    #   Correct value = 1.090507732665258
    #   Num iterations = 9
    #
    #   Calculated root = 1.090507732665257659207010655760707978993
    #   Correct value = 1.090507732665258
    #   Num iterations = 14

    def f(x):
        return math.pow(x, 8) - 2

    eps = 1e-10
    itmax = 20
    x0 = 0.0
    x1 = 10.0
    root, numits = root_find(x0, x1, f, eps, itmax)
    print("Calculated root = %.15f" % root)
    print("Correct value = %.15f" % math.pow(2, 0.125))
    print("Num iterations = %d" % numits)
    print()

    # Calculation with Decimals
    def f_D(x):
        return x*x*x*x*x*x*x*x - 2

    # Calculate to a large number of decimal places.
    digits = 40
    eps = Decimal("1e-%d" % digits)
    itmax = 20
    x0 = Decimal("0.0")
    x1 = Decimal("10.0")
    getcontext().prec = digits
    root, numits = root_find_D(x0, x1, f_D, eps, itmax)
    print("Calculated root =", root)
    print("Correct value = %.15f" % math.pow(2, 0.125))
    print("Num iterations = %d" % numits)
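

# ---------------------------------------------------------------------------
# A minimal sketch, added for illustration and not part of Crenshaw's
# original code.  It checks that the update xm = x0 - b*y0*(1 - c*y1) in
# root_find() is inverse quadratic interpolation: fit x as a quadratic
# function of y through (y0, x0), (y1, x1), (y2, x2) and evaluate it at
# y = 0.  The helper names below are hypothetical.  The two forms agree
# only under the assumption that x1 is the midpoint of x0 and x2, which
# root_find() guarantees because it gets x1 by bisection.


def inverse_quadratic_lagrange(x0, x1, x2, y0, y1, y2):
    '''Lagrange form of the inverse quadratic interpolant, at y = 0.

    Assumes y0, y1 and y2 are pairwise distinct.
    '''
    return (x0 * y1 * y2 / ((y0 - y1) * (y0 - y2))
            + x1 * y0 * y2 / ((y1 - y0) * (y1 - y2))
            + x2 * y0 * y1 / ((y2 - y0) * (y2 - y1)))


def crenshaw_step(x0, x1, y0, y1, y2):
    '''The same estimate, computed the way root_find() does it.

    Only valid when x1 = (x0 + x2)/2; x2 enters only through that
    equal spacing, so it is not a parameter here.
    '''
    y10 = y1 - y0
    y21 = y2 - y1
    y20 = y2 - y0
    b = (x1 - x0) / y10
    c = (y10 - y21) / (y21 * y20)
    return x0 - b * y0 * (1.0 - c * y1)


# Example: one step on f(x) = x**8 - 2 with the bracket [1.0, 1.2].  Both
# helpers give the same estimate (up to rounding), and it is closer to
# 2**0.125 than the bisection midpoint 1.1 is:
#
#   ys = [x**8 - 2.0 for x in (1.0, 1.1, 1.2)]
#   inverse_quadratic_lagrange(1.0, 1.1, 1.2, *ys)
#   crenshaw_step(1.0, 1.1, *ys)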