
https://svn.lrde.epita.fr/svn/oln/trunk/milena Index: ChangeLog from Thierry Geraud <thierry.geraud@lrde.epita.fr> Fix fun stat mahalanobis. * mln/fun/stat/mahalanobis.hh (var_1, mean): Rename as... (var_1_, mean_): ...these. Protect them. (mean_t, mean): New typedef and method. (operator()): Fix missing sqrt. * tests/fun/stat/mahalanobis.cc: Augment. mln/fun/stat/mahalanobis.hh | 24 +++++++++++++++++++----- tests/fun/stat/mahalanobis.cc | 17 ++++++++--------- 2 files changed, 27 insertions(+), 14 deletions(-) Index: mln/fun/stat/mahalanobis.hh --- mln/fun/stat/mahalanobis.hh (revision 3771) +++ mln/fun/stat/mahalanobis.hh (working copy) @@ -32,6 +32,7 @@ /// /// Define the FIXME +# include <cmath> # include <mln/core/concept/function.hh> # include <mln/algebra/vec.hh> # include <mln/algebra/mat.hh> @@ -59,8 +60,13 @@ float operator()(const V& v) const; - algebra::mat<n,n,float> var_1; - algebra::vec<n,float> mean; + typedef algebra::vec<n,float> mean_t; + + mean_t mean() const; + + protected: + algebra::mat<n,n,float> var_1_; + algebra::vec<n,float> mean_; }; @@ -71,8 +77,8 @@ mahalanobis<V>::mahalanobis(const algebra::mat<V::dim,V::dim,float>& var, const algebra::vec<V::dim,float>& mean) { - var_1 = var._1(); - mean = mean; + var_1_ = var._1(); + mean_ = mean; } template <typename V> @@ -80,7 +86,15 @@ float mahalanobis<V>::operator()(const V& v) const { - return (v - mean).t() * var_1 * (v - mean); + return std::sqrt((v - mean_).t() * var_1_ * (v - mean_)); + } + + template <typename V> + inline + typename mahalanobis<V>::mean_t + mahalanobis<V>::mean() const + { + return mean_; } # endif // ! MLN_INCLUDE_ONLY Index: tests/fun/stat/mahalanobis.cc --- tests/fun/stat/mahalanobis.cc (revision 3771) +++ tests/fun/stat/mahalanobis.cc (working copy) @@ -60,15 +60,14 @@ for (int i = 0; i < n; ++i) a.take(v[i]); -// vec3f m = a.mean(); -// mln_assertion(m[0] > 0.4 && m[0] < 0.6); -// mln_assertion(m[1] > 0.9 && m[1] < 1.1); -// mln_assertion(m[2] > 1.4 && m[2] < 1.6); - fun::stat::mahalanobis<vec3f> f(a.variance(), a.mean()); + mln_assertion(f(a.mean()) == 0.f); -// algebra::mat<3,3,float> s_1 = a.variance()._1(); -// mln_assertion(s_1(0,0) > 11 && s_1(0,0) < 13); -// mln_assertion(s_1(1,1) > 2 && s_1(1,1) < 4); -// mln_assertion(s_1(2,2) > 1.1 && s_1(2,2) < 1.5); + float sum = 0.f; + for (int i = 0; i < n; ++i) + { + float f_ = f(v[i]); + sum += f_ * f_; + } + mln_assertion(std::abs(sum / n - 3.f) < 0.00002f); }