Last active
March 27, 2019 22:02
-
-
Save mwacc/7b93865a5cf19b19cc0ed7491fad29e3 to your computer and use it in GitHub Desktop.
Root finding for a fourth-degree polynomial using Halley's method.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import tensorflow as tf
# The script below uses TensorFlow 1.x graph-mode APIs (tf.Session, tf.Print,
# tf.gradients), so report which TensorFlow version is installed up front.
print(tf.__version__)
def calc_fx1(fx, x):
    """Return the symbolic first derivative of ``fx`` with respect to ``x``.

    ``tf.gradients`` differentiates the *sum* of ``fx``. The result is the
    element-wise derivative only when each element of ``fx`` depends solely
    on the matching element of ``x``, as for the batched polynomial built in
    ``calc_next_x``.
    """
    gradients = tf.gradients(fx, x)
    return gradients[0]
def calc_fx2(fx, x):
    """Return the symbolic second derivative of ``fx`` with respect to ``x``.

    Differentiates twice with ``tf.gradients``. Each gradient is taken of a
    summed output, so the result is element-wise only when every element of
    ``fx`` depends solely on the matching element of ``x``.
    """
    slope = tf.gradients(fx, x)[0]
    curvature = tf.gradients(slope, x)[0]
    return curvature
def calc_next_x(a, x, xn):
    """Perform one Halley's-method step for a batch of polynomials.

    Used as the ``body`` of ``tf.while_loop``. The loop variables carry the
    previous iterate (``x``) and the current iterate (``xn``) between steps.

    Args:
        a: coefficient tensor read as ``a[:, k]`` for k = 0..4, i.e. the
            coefficient of ``x**k``; one row per polynomial.
        x: previous iterate; only forwarded to the debug print here.
        xn: current iterate, one value per polynomial.

    Returns:
        ``[a, x_new, xn_new]``:
        - ``x_new`` is the current iterate ``xn``, passed through ``tf.Print``.
        - ``xn_new = x_new - 2 f f' / (2 f'^2 - f f'')`` is evaluated at
          ``x_new``.
        If the denominator is zero, the update is inf/NaN.
    """
    a0 = a[:, 0]
    a1 = a[:, 1]
    a2 = a[:, 2]
    a3 = a[:, 3]
    a4 = a[:, 4]
    # Debug side effect: tf.Print is an identity on xn that logs the previous
    # and current iterate each time the loop body executes. Derivatives below
    # are taken with respect to this printed tensor.
    x = tf.Print(xn, [x, xn], "Args: ")
    # Horner form of a0 + a1*x + a2*x^2 + a3*x^3 + a4*x^4.
    fx = a0 + x * (a1 + x * (a2 + x * (a3 + x * a4)))
    fx1 = calc_fx1(fx, x)
    fx2 = calc_fx2(fx, x)
    # Halley's update needs the exact f'. Rounding it (as previously done)
    # breaks the iteration whenever f' is not an integer.
    xn = x - (2 * fx * fx1) / (2 * tf.square(fx1) - fx * fx2)
    return [a, x, xn]
def condition(coeff, x, xn, tol=0.0005):
    """Loop predicate for ``tf.while_loop``: continue until every element converged.

    Args:
        coeff: polynomial coefficients. Unused; present only because
            ``tf.while_loop`` passes every loop variable to the predicate.
        x: previous iterate.
        xn: current iterate.
        tol: absolute step size below which an element counts as converged.

    Returns:
        Scalar boolean tensor that is True while at least one element still
        has ``|xn - x| >= tol``. A NaN step never compares less than ``tol``,
        so a NaN element keeps the loop running until ``maximum_iterations``.
    """
    converged = tf.less(tf.abs(xn - x), tol)
    # reduce_all handles any rank and does not need a statically known shape,
    # unlike counting True values against x.get_shape()[0].value.
    return tf.logical_not(tf.reduce_all(converged))
# Demo: run Halley's method on two polynomials at once and print the roots
# it converges to.
with tf.Session() as sess:
    # Row i holds [a0, a1, a2, a3, a4] for a0 + a1*x + a2*x^2 + a3*x^3 + a4*x^4.
    a = tf.constant(
        [
            [1.0, 3.0, -1.0, 2.0, -1.0],
            [2.3, 4.1, 4.8, -1.0, 0.0],  # a4 == 0: this row is a cubic
        ]
    )
    # Halley's iteration starts from xn. x is only the "previous iterate"
    # seen by the first convergence check; it differs from xn by more than
    # the tolerance, so the loop body runs at least once.
    x = tf.constant([0.0, 0.0])
    xn = tf.constant([100.0, -100.0])  # starting guesses, one per polynomial
    res = tf.while_loop(
        condition,
        calc_next_x,
        loop_vars=[a, x, xn],
        parallel_iterations=1,
        maximum_iterations=3000,
    )
    result = sess.run(res)
    print('result:')
    # res mirrors loop_vars [a, x, xn]; index 2 is the latest iterate.
    print(result[2])
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment