最小均方算法(LMS Algorithm)理论及DSP实现(2)
- /*
- * zx_lms.c
- * Least Mean Squares Algorithm
- * Created on: 2013-8-4
- * Author: monkeyzx
- */
- #include "zx_lms.h"
- #include "config.h"
- #include <stdio.h>
- #include <stdlib.h>
-
- static double init_y[] = {4.00,3.30,3.69,2.32};
- static double init_x[] = {
- 2.104,3,
- 1.600,3,
- 2.400,3,
- 3.000,4
- };
- static double weight[2] = {0.1, 0.1};
-
- /*
- * Least Mean Square Algorithm
- * return value @Error when stop iteration
- * use @lms_prob->method to choose a method.
- */
- double lms(struct lms_st *lms_prob)
- {
- double err;
- double error;
- int i = 0;
- int j = 0;
- int iter = 0;
- static double *h = 0; /* 加static,防止栈溢出*/
-
- h = (double *)malloc(sizeof(double) * lms_prob->m);
- if (!h) {
- return -1;
- }
- do {
- error = 0;
-
- if (lms_prob->method != STOCHASTIC) {
- i = 0;
- } else {
- /* i=(i+1) mod m */
- i = i + 1;
- if (i >= lms_prob->m) {
- i = 0;
- }
- }
-
- for ( ; i<lms_prob->m; i++) {
- h = 0;
- for (j=0; j<lms_prob->n; j++) {
- h += lms_prob->weight[j] * lms_prob->x[i*lms_prob->n+j]; /* h(x) */
- }
- if (lms_prob->method == STOCHASTIC) break; /* handle STOCHASTIC */
- }
-
- for (j=0; j<lms_prob->n; j++) {
- if (lms_prob->method != STOCHASTIC) {
- i = 0;
- }
- for ( ; i<lms_prob->m; i++) {
- err = lms_prob->lrate
- * (lms_prob->y - h) * lms_prob->x[i*lms_prob->n+j];
- lms_prob->weight[j] += err; /* Update weights */
- error += ABS(err);
- if (lms_prob->method == STOCHASTIC) break; /* handle STOCHASTIC */
- }
- }
-
- iter = iter + 1;
- if ((lms_prob->max_iter > 0) && ((iter > lms_prob->max_iter))) {
- break;
- }
- } while (error >= lms_prob->threshhold);
-
- free(h);
-
- return error;
- }
-
- #define DEBUG
- void zx_lms(void)
- {
- int i = 0;
- double error = 0;
- struct lms_st lms_prob;
-
- lms_prob.lrate = 0.01;
- lms_prob.m = 4;
- lms_prob.n = 2;
- lms_prob.weight = weight;
- lms_prob.threshhold = 0.2;
- lms_prob.max_iter = 1000;
- lms_prob.x = init_x;
- lms_prob.y = init_y;
- // lms_prob.method = STOCHASTIC;
- lms_prob.method = BATCH;
-
- // error = lms(init_x, 2, init_y, 4, weight, 0.01, 0.1, 1000);
- error = lms(&lms_prob);
-
- #ifdef DEBUG
- for (i=0; i<sizeof(weight)/sizeof(weight[0]); i++) {
- printf("%f\n", weight);
- }
- printf("error:%f\n", error);
- #endif
- }
复制代码
输入、输出、初始权值为
- static double init_y[] = {4.00,3.30,3.69,2.32};
- static double init_x[] = { /* 用一维数组保存 */
- 2.104, 3,
- 1.600, 3,
- 2.400, 3,
- 3.000, 4
- };
复制代码
- static double weight[2] = {0.1, 0.1};
复制代码
main函数中只需要调用zx_lms()就可以运行了,本文对两种梯度下降方法做了个简单对比,
需要说明的是:batch算法是达到最大迭代次数1000退出的,而stochastic是收敛退出的,因此这里batch算法应该没有对数据做到较好的拟合。stochastic算法则在时钟周期上只有995,远比batch更有时间上的优势。
注:这里的error没有太大的可比性,因为batch的error针对的整体数据集的error,而stochastic 的error是针对一个随机的数据实例。
LMS有个很重要的问题:收敛。开始时可以根据给定数据集设置w值,使h(x)尽可能与接近y,如果不确定可以将w设置小一点。
这里顺便记录下在调试过程中遇到的一个问题:在程序运行时发现有变量的值为1.#QNAN。
解决:QNAN是Quiet Not a Number简写,是常见的浮点溢出错误,在网上找到了解释
A QNaN is a NaN with the most significant fraction bit set. QNaN’s propagate freely through most arithmetic operations. These values pop out of an operation when the result is not mathematically defined.
在开始调试过程中因为迭代没有收敛,发散使得w和error等值逐渐累积,超过了浮点数的范围,从而出现上面的错误,通过修改使程序收敛后上面的问题自然而然解决了。
|