dsplib 1.1.0
C++ DSP library for MATLAB-like coding
Loading...
Searching...
No Matches
rls.h
#pragma once

#include <dsplib/array.h>

#include <algorithm>
#include <utility>
4
5namespace dsplib {
6
7//RLS adaptive filter
8template<typename T>
10{
11public:
12 RlsFilter(int filter_len, real_t forget_factor = 0.9, real_t diag_load = 1.0)
13 : _n{filter_len}
14 , _mu{forget_factor}
15 , _u(_n)
16 , _w(_n)
17 , _p(_n * _n) {
18 for (auto i = 0; i < _n; i++) {
19 _p[i * _n + i] = diag_load;
20 }
21 }
22
23 struct Result
24 {
25 base_array<T> y; //output
26 base_array<T> e; //error
27 };
28
29 Result operator()(span_t<T> x, span_t<T> d) {
30 return this->process(x, d);
31 }
32
33 Result process(span_t<T> x, span_t<T> d);
34
35 void set_lock_coeffs(bool locked) {
36 _locked = locked;
37 }
38
39 [[nodiscard]] bool coeffs_locked() const {
40 return _locked;
41 }
42
43 span_t<T> coeffs() const {
44 return _w;
45 }
46
47private:
48 int _n;
49 real_t _mu;
50 base_array<T> _u;
51 base_array<T> _w;
52 base_array<T> _p;
53 bool _locked{false};
54};
55
56using RlsFilterR = RlsFilter<real_t>;
57using RlsFilterC = RlsFilter<cmplx_t>;
58
59//-----------------------------------------------------------------------------------------------
60template<typename T>
61typename RlsFilter<T>::Result RlsFilter<T>::process(span_t<T> x, span_t<T> d) {
62 DSPLIB_ASSERT(x.size() == d.size(), "vector size error: len(x) != len(d)");
63
64 const int nx = x.size();
65 base_array<T> y(nx);
66 base_array<T> e(nx);
67 base_array<T> g(_n);
68
69 //TODO: use matrix syntax
70
71 base_array<T> Pu(_n);
72 base_array<T> uTP(_n);
73 base_array<T> guP(_n * _n);
74
75 for (int idx = 0; idx < nx; idx++) {
76 std::memmove(_u.data() + 1, _u.data(), (_n - 1) * sizeof(T));
77 _u[0] = x[idx];
78
79 y[idx] = dot(_w, _u);
80 e[idx] = d[idx] - y[idx];
81
82 if (_locked) {
83 continue;
84 }
85
86 //g = (P * u) / (mu + u' * P * u);
87 {
88 //Pu = P * u
89 std::fill(Pu.begin(), Pu.end(), 0);
90 for (int i = 0; i < _n; i++) {
91 for (int k = 0; k < _n; k++) {
92 Pu[i] += _p[i * _n + k] * _u[k];
93 }
94 }
95
96 //u' * P
97 std::fill(uTP.begin(), uTP.end(), 0);
98 for (int i = 0; i < _n; i++) {
99 for (int k = 0; k < _n; k++) {
100 uTP[i] += conj(_u[k]) * _p[k * _n + i];
101 }
102 }
103
104 //mu + u' * P * u
105 g = Pu / (_mu + dot(uTP, _u));
106 }
107
108 //P = (1/mu) * (P - g * u' * P);
109 {
110 //guP = g * u' * P
111 for (int i = 0; i < _n; i++) {
112 for (int k = 0; k < _n; k++) {
113 guP[i * _n + k] = g[i] * uTP[k];
114 }
115 }
116
117 //P - g * u' * P
118 for (int i = 0; i < (_n * _n); i++) {
119 _p[i] = (1 / _mu) * (_p[i] - guP[i]);
120 }
121 }
122
123 for (int i = 0; i < _n; i++) {
124 _w[i] += conj(g[i]) * e[idx];
125 }
126 }
127
128 return {y, e};
129}
130
131} // namespace dsplib
Definition rls.h:10
Definition span.h:295
Definition rls.h:24