【Python】array同士を比較して一致するかassertでチェックする【numpy】

状況

numpyのarray同士が等しいか比較しようとしたらValueErrorが発生した。

>>> import numpy as np
>>> a = np.array([1.0, 2.0, 3.0])
>>> b = np.array([1.0, 2.0, 3.0])
>>> assert a == b
E       ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

解決法1

エラーメッセージに従って、.all()を使う。

>>> assert (a == b).all()

解決法2 (個人的におすすめ)

numpy.testing.assert_array_equalを使う。

>>> np.testing.assert_array_equal(a, b)

 

その他の便利なassert

このnumpy.testingには他にも便利なassertがあります。

例えば、以下のように微妙に異っているが、等しいと見なしたいarrayがあるとします。

>>> a = np.array([1.0, 2.0, 3.0])
>>> c = np.array([1.0, 2.0, 3.1])

このときassert_array_equal(a, c)ではAssertionErrorが発生します。

>>> np.testing.assert_array_equal(a, c)
E       AssertionError:
E       Arrays are not equal
E
E       Mismatched elements: 1 / 3 (33.3%)
E       Max absolute difference: 0.1
E       Max relative difference: 0.03225806
E        x: array([1., 2., 3.])
E        y: array([1. , 2. , 3.1])

numpy.testing.assert_almost_equalを使うとエラーを回避できます。

>>> np.testing.assert_array_almost_equal(a, c, decimal=1)

pytestなどでテストコードを書くときに重宝します。

参考:

numpy.testing.assert_array_equal — NumPy v1.26 Manual
numpy.testing.assert_array_almost_equal — NumPy v1.26 Manual

コメント