// dsplib 1.1.0 — C++ DSP library for MATLAB-like coding
// File: lms.h
#pragma once

#include <utility>

#include <dsplib/array.h>
#include <dsplib/utils.h>
#include <dsplib/math.h>

7namespace dsplib {
8
9enum class LmsType
10{
11 LMS,
12 NLMS
13};
14
15//LMS adaptive filter
16template<typename T>
18{
19public:
20 //len: length of FIR filter weights
21 //step_size: adaptation step size
22 //method: method to calculate filter weights
23 //leak: leakage factor
24 explicit LmsFilter(int len, real_t step_size, LmsType method = LmsType::LMS, real_t leak = 1)
25 : _u(len - 1)
26 , _w(len)
27 , _mu{step_size}
28 , _len{len}
29 , _method{method}
30 , _lk{leak} {
31 }
32
33 struct Result
34 {
35 base_array<T> y; //output
36 base_array<T> e; //error
37 };
38
39 Result operator()(span_t<T> x, span_t<T> d) {
40 return this->process(x, d);
41 }
42
43 Result process(span_t<T> x, span_t<T> d) {
44 DSPLIB_ASSERT(x.size() == d.size(), "vector size error: len(x) != len(d)");
45
46 int nx = x.size();
47 base_array<T> y(nx);
48 base_array<T> e(nx);
49 base_array<T> tu = concatenate(_u, x);
50 arr_real tu2 = (_method == LmsType::NLMS) ? abs2(tu) : arr_real{};
51
52 //update delay
53 _u = tu.slice(nx, nx + _len - 1);
54
55 for (int k = 0; k < nx; k++) {
56 //y(n) = w.T(n) * u(n)
57 for (int i = 0; i < _len; i++) {
58 y[k] += _w[i] * tu[i + k];
59 }
60
61 e[k] = d[k] - y[k];
62
63 if (_locked) {
64 continue;
65 }
66
67 if (_method == LmsType::LMS) {
68 //TODO: dont use cycles
69 //w(n) = w(n-1) * a + mu * e(n) * u(n)
70 for (int i = 0; i < _len; i++) {
71 _w[i] = (_w[i] * _lk) + (_mu * e[k] * conj(tu[i + k]));
72 }
73 } else if (_method == LmsType::NLMS) {
74 //pu = u.T(n) * u(n)
75 //TODO: use recurrent sum
76 real_t pu = 0;
77 for (int i = 0; i < _len; i++) {
78 pu += tu2[i + k];
79 }
80
81 //w(n) = w(n-1) * a + mu * e(n) * u(n) / norm
82 const real_t norm = pu + eps();
83 for (int i = 0; i < _len; i++) {
84 _w[i] = (_w[i] * _lk) + (_mu * e[k] * conj(tu[i + k]) / norm);
85 }
86 }
87 }
88
89 return {y, e};
90 }
91
92 void set_lock_coeffs(bool locked) {
93 _locked = locked;
94 }
95
96 [[nodiscard]] bool coeffs_locked() const {
97 return _locked;
98 }
99
100 //TODO: return span
101 base_array<T> coeffs() const {
102 return flip(_w);
103 }
104
105 //TODO: reset()
106 //TODO: step_size control
107
108private:
109 base_array<T> _u;
110 base_array<T> _w;
111 real_t _mu;
112 int _len;
113 bool _locked{false};
114 LmsType _method;
115 real_t _lk;
116};
117
118using LmsFilterR = LmsFilter<real_t>;
119using LmsFilterC = LmsFilter<cmplx_t>;
120
121} // namespace dsplib
// Definition locations referenced by the generated documentation:
//   lms.h:18   — LmsFilter
//   array.h:25 — base_array, the base dsplib array type
//   span.h:295
//   lms.h:34   — LmsFilter::Result