https://svn.lrde.epita.fr/svn/oln/trunk/milena
Index: ChangeLog
from Thierry Geraud <thierry.geraud(a)lrde.epita.fr>
Fix fun::stat::mahalanobis.
* mln/fun/stat/mahalanobis.hh (var_1, mean): Rename as...
(var_1_, mean_): ...these.
Make them protected.
(mean_t, mean): New typedef and method.
(operator()): Fix missing sqrt.
* tests/fun/stat/mahalanobis.cc: Augment.
mln/fun/stat/mahalanobis.hh | 24 +++++++++++++++++++-----
tests/fun/stat/mahalanobis.cc | 17 ++++++++---------
2 files changed, 27 insertions(+), 14 deletions(-)
Index: mln/fun/stat/mahalanobis.hh
--- mln/fun/stat/mahalanobis.hh (revision 3771)
+++ mln/fun/stat/mahalanobis.hh (working copy)
@@ -32,6 +32,7 @@
///
/// Define the FIXME
+# include <cmath>
# include <mln/core/concept/function.hh>
# include <mln/algebra/vec.hh>
# include <mln/algebra/mat.hh>
@@ -59,8 +60,13 @@
float operator()(const V& v) const;
- algebra::mat<n,n,float> var_1;
- algebra::vec<n,float> mean;
+ typedef algebra::vec<n,float> mean_t;
+
+ mean_t mean() const;
+
+ protected:
+ algebra::mat<n,n,float> var_1_;
+ algebra::vec<n,float> mean_;
};
@@ -71,8 +77,8 @@
mahalanobis<V>::mahalanobis(const
algebra::mat<V::dim,V::dim,float>& var,
const algebra::vec<V::dim,float>& mean)
{
- var_1 = var._1();
- mean = mean;
+ var_1_ = var._1();
+ mean_ = mean;
}
template <typename V>
@@ -80,7 +86,15 @@
float
mahalanobis<V>::operator()(const V& v) const
{
- return (v - mean).t() * var_1 * (v - mean);
+ return std::sqrt((v - mean_).t() * var_1_ * (v - mean_));
+ }
+
+ template <typename V>
+ inline
+ typename mahalanobis<V>::mean_t
+ mahalanobis<V>::mean() const
+ {
+ return mean_;
}
# endif // ! MLN_INCLUDE_ONLY
Index: tests/fun/stat/mahalanobis.cc
--- tests/fun/stat/mahalanobis.cc (revision 3771)
+++ tests/fun/stat/mahalanobis.cc (working copy)
@@ -60,15 +60,14 @@
for (int i = 0; i < n; ++i)
a.take(v[i]);
-// vec3f m = a.mean();
-// mln_assertion(m[0] > 0.4 && m[0] < 0.6);
-// mln_assertion(m[1] > 0.9 && m[1] < 1.1);
-// mln_assertion(m[2] > 1.4 && m[2] < 1.6);
-
fun::stat::mahalanobis<vec3f> f(a.variance(), a.mean());
+ mln_assertion(f(a.mean()) == 0.f);
-// algebra::mat<3,3,float> s_1 = a.variance()._1();
-// mln_assertion(s_1(0,0) > 11 && s_1(0,0) < 13);
-// mln_assertion(s_1(1,1) > 2 && s_1(1,1) < 4);
-// mln_assertion(s_1(2,2) > 1.1 && s_1(2,2) < 1.5);
+ float sum = 0.f;
+ for (int i = 0; i < n; ++i)
+ {
+ float f_ = f(v[i]);
+ sum += f_ * f_;
+ }
+ mln_assertion(std::abs(sum / n - 3.f) < 0.00002f);
}