Mercurial > hg > tvii
comparison tests/test_logistic_regression.py @ 31:fa7a51df0d90
[logistic regression] test gradient descent
| author | Jeff Hammel <k0scist@gmail.com> |
|---|---|
| date | Mon, 04 Sep 2017 12:37:45 -0700 |
| parents | cf7584f0a29f |
| children | 0f29b02f4806 |
Comparison legend: equal, deleted, inserted, replaced
| 30:ae0c345ea09d | 31:fa7a51df0d90 |
|---|---|
| 9 import unittest | 9 import unittest |
| 10 from tvii import logistic_regression | 10 from tvii import logistic_regression |
| 11 | 11 |
| 12 | 12 |
| 13 class LogisticRegresionTests(unittest.TestCase): | 13 class LogisticRegresionTests(unittest.TestCase): |
| | 14 |
| | 15 def compare_arrays(self, a, b): |
| | 16 assert a.shape == b.shape |
| | 17 for x, y in zip(a.flatten(), |
| | 18 b.flatten()): |
| | 19 self.assertAlmostEqual(x, y) |
| | 20 |
| 14 | 21 |
| 15 def test_cost(self): | 22 def test_cost(self): |
| 16 """test cost function""" | 23 """test cost function""" |
| 17 | 24 |
| 18 w, b, X, Y = (np.array([[1],[2]]), | 25 w, b, X, Y = (np.array([[1],[2]]), |
| 46 assert grads['dw'].shape == dw_expected.shape | 53 assert grads['dw'].shape == dw_expected.shape |
| 47 for a, b in zip(grads['dw'].flatten(), | 54 for a, b in zip(grads['dw'].flatten(), |
| 48 dw_expected.flatten()): | 55 dw_expected.flatten()): |
| 49 self.assertAlmostEqual(a, b) | 56 self.assertAlmostEqual(a, b) |
| 50 | 57 |
| | 58 def test_optimize(self): |
| | 59 """test gradient descent method""" |
| | 60 |
| | 61 # test examples |
| | 62 w, b, X, Y = np.array([[1],[2]]), 2, np.array([[1,2],[3,4]]), np.array([[1,0]]) |
| | 63 |
| | 64 params, grads, costs = logistic_regression.optimize(w, b, X, Y, num_iterations= 100, learning_rate = 0.009, print_cost = False) |
| | 65 |
| | 66 # expected output |
| | 67 w_expected = np.array([[0.1124579 ], |
| | 68 [0.23106775]]) |
| | 69 dw_expected = np.array([[ 0.90158428], |
| | 70 [ 1.76250842]]) |
| | 71 b_expected = 1.55930492484 |
| | 72 db_expected = 0.430462071679 |
| | 73 |
| | 74 # compare output |
| | 75 self.assertAlmostEqual(params['b'], b_expected) |
| | 76 self.assertAlmostEqual(grads['db'], db_expected) |
| | 77 self.compare_arrays(w_expected, params['w']) |
| | 78 self.compare_arrays(dw_expected, grads['dw']) |
| | 79 |
| 51 | 80 |
| 52 if __name__ == '__main__': | 81 if __name__ == '__main__': |
| 53 unittest.main() | 82 unittest.main() |
