状況
numpyのarray同士が等しいか比較しようとしたらValueErrorが発生した。
>>> import numpy as np
>>> a = np.array([1.0, 2.0, 3.0])
>>> b = np.array([1.0, 2.0, 3.0])
>>> assert a == b
E ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
解決法1
エラーメッセージに従って、.all()
を使う。
>>> assert (a == b).all()
解決法2 (個人的におすすめ)
numpy.testing.assert_array_equal
を使う。
>>> np.testing.assert_array_equal(a, b)
その他の便利なassert
このnumpy.testingには他にも便利なassertがあります。
例えば、以下のように微妙に異っているが、等しいと見なしたいarrayがあるとします。
>>> a = np.array([1.0, 2.0, 3.0])
>>> c = np.array([1.0, 2.0, 3.1])
このときassert_array_equal(a, c)
ではAssertionErrorが発生します。
>>> np.testing.assert_array_equal(a, c)
E AssertionError:
E Arrays are not equal
E
E Mismatched elements: 1 / 3 (33.3%)
E Max absolute difference: 0.1
E Max relative difference: 0.03225806
E x: array([1., 2., 3.])
E y: array([1. , 2. , 3.1])
numpy.testing.assert_almost_equal
を使うとエラーを回避できます。
>>> np.testing.assert_array_almost_equal(a, c, decimal=1)
pytestなどでテストコードを書くときに重宝します。
参考:
numpy.testing.assert_array_equal — NumPy v2.1 Manual
numpy.testing.assert_array_almost_equal — NumPy v2.1 Manual
コメント