打印

最小均方算法(LMS Algorithm)理论及DSP实现(2)

[复制链接]
751|3
手机看帖
扫描二维码
随时随地手机跟帖
跳转到指定楼层
楼主
Peonys|  楼主 | 2017-11-15 10:21 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
最小均方算法(LMS Algorithm)理论及DSP实现(2)




  • /*
  • * zx_lms.c
  • * Least Mean Squares Algorithm
  • *  Created on: 2013-8-4
  • *      Author: monkeyzx
  • */  
  • #include "zx_lms.h"  
  • #include "config.h"  
  • #include <stdio.h>  
  • #include <stdlib.h>  
  •   
  • static double init_y[] = {4.00,3.30,3.69,2.32};  
  • static double init_x[] = {  
  •         2.104,3,  
  •         1.600,3,  
  •         2.400,3,  
  •         3.000,4  
  • };  
  • static double weight[2] = {0.1, 0.1};  
  •   
  • /*
  • * Least Mean Square Algorithm
  • * return value @Error when stop iteration
  • * use @lms_prob->method to choose a method.
  • */  
  • double lms(struct lms_st *lms_prob)  
  • {  
  •     double err;  
  •     double error;  
  •     int i = 0;  
  •     int j = 0;  
  •     int iter = 0;  
  •     static double *h = 0;       /* 加static,防止栈溢出*/  
  •   
  •     h = (double *)malloc(sizeof(double) * lms_prob->m);  
  •     if (!h) {  
  •         return -1;  
  •     }  
  •     do {  
  •         error = 0;  
  •   
  •         if (lms_prob->method != STOCHASTIC) {  
  •             i = 0;  
  •         } else {  
  •             /* i=(i+1) mod m */  
  •             i = i + 1;  
  •             if (i >= lms_prob->m) {  
  •                 i = 0;  
  •             }  
  •         }  
  •   
  •         for ( ; i<lms_prob->m; i++) {  
  •             h = 0;  
  •             for (j=0; j<lms_prob->n; j++) {  
  •                 h += lms_prob->weight[j] * lms_prob->x[i*lms_prob->n+j]; /* h(x) */  
  •             }  
  •             if (lms_prob->method == STOCHASTIC) break;   /* handle STOCHASTIC */  
  •         }  
  •   
  •         for (j=0; j<lms_prob->n; j++) {  
  •             if (lms_prob->method != STOCHASTIC) {  
  •                 i = 0;  
  •             }  
  •             for ( ; i<lms_prob->m; i++) {  
  •                 err = lms_prob->lrate  
  •                         * (lms_prob->y - h) * lms_prob->x[i*lms_prob->n+j];  
  •                 lms_prob->weight[j] += err;            /* Update weights */  
  •                 error += ABS(err);  
  •                 if (lms_prob->method == STOCHASTIC) break; /* handle STOCHASTIC */  
  •             }  
  •         }  
  •   
  •         iter = iter + 1;  
  •         if ((lms_prob->max_iter > 0) && ((iter > lms_prob->max_iter))) {  
  •             break;  
  •         }  
  •     } while (error >= lms_prob->threshhold);  
  •   
  •     free(h);  
  •   
  •     return error;  
  • }  
  •   
  • #define DEBUG  
  • void zx_lms(void)  
  • {  
  •     int i = 0;  
  •     double error = 0;  
  •     struct lms_st lms_prob;  
  •   
  •     lms_prob.lrate = 0.01;  
  •     lms_prob.m = 4;  
  •     lms_prob.n = 2;  
  •     lms_prob.weight = weight;  
  •     lms_prob.threshhold = 0.2;  
  •     lms_prob.max_iter = 1000;  
  •     lms_prob.x = init_x;  
  •     lms_prob.y = init_y;  
  • //  lms_prob.method = STOCHASTIC;  
  •     lms_prob.method = BATCH;  
  •   
  • //  error = lms(init_x, 2, init_y, 4, weight, 0.01, 0.1, 1000);  
  •     error = lms(&lms_prob);  
  •   
  • #ifdef DEBUG  
  •     for (i=0; i<sizeof(weight)/sizeof(weight[0]); i++) {  
  •         printf("%f\n", weight);  
  •     }  
  •     printf("error:%f\n", error);  
  • #endif  
  • }  

复制代码

         输入、输出、初始权值为
  • static double init_y[] = {4.00,3.30,3.69,2.32};
  • static double init_x[] = {        /* 用一维数组保存 */
  • 2.104, 3,
  • 1.600, 3,
  • 2.400, 3,
  • 3.000, 4
  • };

复制代码
  • static double weight[2] = {0.1, 0.1};

复制代码
        main函数中只需要调用zx_lms()就可以运行了,本文对两种梯度下降方法做了个简单对比,

         需要说明的是:batch算法是达到最大迭代次数1000退出的,而stochastic是收敛退出的,因此这里batch算法应该没有对数据做到较好的拟合。stochastic算法则在时钟周期上只有995,远比batch更有时间上的优势。
         注:这里的error没有太大的可比性,因为batch的error针对的整体数据集的error,而stochastic 的error是针对一个随机的数据实例。
         LMS有个很重要的问题:收敛。开始时可以根据给定数据集设置w值,使h(x)尽可能与接近y,如果不确定可以将w设置小一点。

         这里顺便记录下在调试过程中遇到的一个问题:在程序运行时发现有变量的值为1.#QNAN。
         解决:QNAN是Quiet Not a Number简写,是常见的浮点溢出错误,在网上找到了解释
         A QNaN is a NaN with the most significant fraction bit set. QNaN’s propagate freely through most arithmetic operations. These values pop out of an operation when the result is not mathematically defined.
         在开始调试过程中因为迭代没有收敛,发散使得w和error等值逐渐累积,超过了浮点数的范围,从而出现上面的错误,通过修改使程序收敛后上面的问题自然而然解决了。

相关帖子

沙发
firstblood| | 2017-11-15 18:01 | 只看该作者
这个程序的设计看着还是有点复杂的

使用特权

评论回复
板凳
smilingangel| | 2017-11-15 19:41 | 只看该作者
最小均方算法的设计要看具体应用场景的哈

使用特权

评论回复
地板
comeon201208| | 2017-11-15 20:40 | 只看该作者
这个就涉及到数据的收敛设计了

使用特权

评论回复
发新帖 我要提问
您需要登录后才可以回帖 登录 | 注册

本版积分规则

640

主题

901

帖子

5

粉丝