///////////////////////////////////////////////////////////////////////////////
// Advantage Updating & Q-Learning Comparison on a Linear Quadratic Regulator
// Capt. Leemon Baird, Wright Laboratory, WPAFB, OH
// bairdlc@wL.wpafb.af.mil  2 Nov 93
// This program runs several simulations of both Q-learning and advantage
// updating on a simple linear quadratic regulator (LQR) problem.
// The names of weights start with "w". The functions of the weights are:
//   Q(x,u) = wqxx*x^2 + wqxu*x*u + wquu*u^2
//   A(x,u) = waxx*x^2 + waxu*x*u + wauu*u^2
//   V(x)   = wav*x^2
// The optimal weights:
//   wqxx=-k2-k3*k1*k1*dt   wqxu=-2*k3*k1*dt   wquu=-k3*dt
//   wav=-k2   waxx=-k3*k1*k1   waxu=-2*k3*k1   wauu=-k3
//   (wav=-k2 because V(x)=max_u Q(x,u), which is attained at u=-k1*x.)
// Each line printed by the program has the time step duration (dt), the number
// of systems simulated in parallel (nsys), the learning rates for learning A,
// learning V, normalizing A, and learning Q (aA,aV,aW,aQ), the noise (N), and
// the number of timesteps required for advantage updating and Q-learning to
// learn the policy with mean absolute error less than 0.001 (tA,tQ).
///////////////////////////////////////////////////////////////////////////////

/* Standard headers use angle brackets so that a same-named file in the
   project directory can never shadow them. */
#include <stdio.h>
#include <math.h>
#include <stdlib.h>

/* Discount factor.
   NOTE(review): some C libraries (e.g. glibc) also declare a function named
   gamma() in <math.h>. This macro is defined after that include, so from here
   on it hides that name for the rest of the file. */
#define gamma 0.9

#define AU_sim 1 /* 1 to simulate advantage updating, 0 otherwise */
#define Q_sim  0 /* 1 to simulate Q-learning, 0 otherwise */

/* Inclusive range of rows of the parameter table k[][] to simulate.
   last_line is parenthesized so that it behaves as a single value inside
   any expression (e.g. 2*last_line). */
#define first_line 17
#define last_line  (17 + 26)

#define mxs 100 /* max number of systems that can learn in parallel */

/* Per-simulation parameters: the columns of row c of the table k[][].
   These expand to references to the names k and c, so they only work where
   those names are in scope. */
#define dt     (k[c][0]) /* time step duration */
#define nsys   (k[c][1]) /* number of systems to simulate in parallel */
#define alphaa (k[c][2]) /* learning rate for learning A */
#define alphav (k[c][3]) /* learning rate for learning V */
#define alphaw (k[c][4]) /* learning rate for normalizing A */
#define alphaq (k[c][5]) /* learning rate for learning Q */
#define noise  (k[c][6]) /* amount of random noise to add to reinforcement */

/* R(x,u): discounted reinforcement for one time step after doing action u in
   state x. It reads the names gd, gd1, ln, ln2 and ln3, which are file-scope
   globals whose initialization is not shown here, as well as dt.
   Every use of x and u is parenthesized, so expression arguments such as
   R(a + b, -u) expand as single values. */
#define R(x, u) ((2*gd1*(u)*(u) + 2*(u)*ln*((u)*dt*gd - gd1*(x)) + \
                  ln2*((gd1 - dt*dt*gd)*(u)*(u) + gd1*(x)*(x) - 2*dt*gd*(u)*(x))) / ln3)
void init_w (void); // initialize the weights void update_av(void); // perform one step of advantage updating void update_q (void); // perform one step of Q-leasrning void avg_av (void); // find the mean |error| for A and V weights void avg_q (void); // find the mean |error| for Q weights double rnd (void); // return a random number in the range [-1,1] double waxx[mxs],waxu[mxs],wauu[mxs],wav[mxs],wqxx[mxs],wqxu[mxs],wquu[mxs]; double k1,k2,k3,gd,gd1,gd2,ln,ln2,ln3; //IEEE standard 80-bit floating point long c, t, s, donea, doneq; char back[13]={8,8,8,8,8,8,8,8,8,8,8,8,0}; //12 backspaces used for printing unsigned long rnd_seed=1; // should be a 32-bit unsigned integer double k[100][7] = { //this defines the parameters for 34 different simulations // dt nsys aA aV aW aQ noise tA tQ 1e-1, 100, 1.0, .3, .5, 1.4, 0, // 214 239 // 0 1e-1, 100, 1.0, .3, .5, 1.4, 1, // 222 272 // 1 1e-1, 100, .6, .3, .3, .74, 2, // 286 415 // 2 1e-1, 100, .5, .3, .3, .44, 3, // 375 660 // 3 1e-1, 100, .4, .3, .4, .26, 4, // 445 1,128 // 4 1e-1, 100, .3, .2, .3, .17, 5, // 561 1,688 // 5 1e-1, 100, .2, .4, .1, .11, 6, // 755 2,402 // 6 1e-1, 100, .2, .2, .2, .088, 7, // 765 3,250 // 7 1e-1, 100, .1, .1, .07, .073, 8, // 1335 3,441 // 8 1e-1, 100, .1, .09,.05, .054, 9, // 1578 4,668 // 9 1e-1, 100, .1, .1, .06, .050, 10, // 1761 4,880 //10 1e-1, 100, .08,.06,.06, .046, 11, // 1761 5,506 //11 1e-1, 100, .08,.3, .4, .030, 12, // 1643 # 8,725 //12 1e-1, 100, .07,.3, .3, .028, 13, // 1736 # 8,863 //13 1e-1, 100, .07,.1, .3, .022, 14, // 1825 # 11,642 //14 1e-1, 100, .06,.2, .1, .018, 15, // 1880 # 13,131 //15 1e-1, 100, .06,.2, .1, .018, 16, // 1881 # 13,183 //16 1e-1, 100, .10,.2, .3, .028, 11, // 1736 8,863 //13 1e-1, 100, .10,.2, .4, .028, 11, // 1736 8,863 //13 1e-1, 100, .10,.2, .5, .028, 11, // 1736 8,863 //13 1e-1, 100, .10,.3, .3, .028, 11, // 1736 8,863 //13 1e-1, 100, .10,.3, .4, .028, 11, // 1736 8,863 //13 1e-1, 100, .10,.3, .5, .028, 11, // 1736 8,863 //13 1e-1, 100, .10,.4, .3, 
.028, 11, // 1736 8,863 //13 1e-1, 100, .10,.4, .4, .028, 11, // 1736 8,863 //13 1e-1, 100, .10,.4, .5, .028, 11, // 1736 8,863 //13 1e-1, 100, .08,.2, .3, .028, 11, // 1736 8,863 //13 1e-1, 100, .08,.2, .4, .028, 11, // 1736 8,863 //13 1e-1, 100, .08,.2, .5, .028, 11, // 1736 8,863 //13 1e-1, 100, .08,.3, .3, .028, 11, // 1736 8,863 //13 1e-1, 100, .08,.3, .4, .028, 11, // 1736 8,863 //13 1e-1, 100, .08,.3, .5, .028, 11, // 1736 8,863 //13 1e-1, 100, .08,.4, .3, .028, 11, // 1736 8,863 //13 1e-1, 100, .08,.4, .4, .028, 11, // 1736 8,863 //13 1e-1, 100, .08,.4, .5, .028, 11, // 1736 8,863 //13 1e-1, 100, .09,.2, .3, .028, 11, // 1736 8,863 //13 1e-1, 100, .09,.2, .4, .028, 11, // 1736 8,863 //13 1e-1, 100, .09,.2, .5, .028, 11, // 1736 8,863 //13 1e-1, 100, .09,.3, .3, .028, 11, // 1736 8,863 //13 1e-1, 100, .09,.3, .4, .028, 11, // 1736 8,863 //13 1e-1, 100, .09,.3, .5, .028, 11, // 1736 8,863 //13 1e-1, 100, .09,.4, .3, .028, 11, // 1736 8,863 //13 1e-1, 100, .09,.4, .4, .028, 11, // 1736 8,863 //13 1e-1, 100, .09,.4, .5, .028, 11, // 1736 8,863 //13 1e0 , 100, 1.0, .7, .4, .44, 0, // 194 # 382 //17 3e-1, 100, 1.0, .4, .8, 1.0, 0, // 190 # 195 //18 1e-1, 100, 1.0, .3, .5, 1.4, 0, // 214 # 239 //19 3e-2, 100, .9, .3, .5, 1.5, 0, // 216 # 459 //20 1e-2, 100, .9, .3, .5, 1.6, 0, // 216 # 1,003 //21 3e-3, 100, .9, .3, .7, 1.6, 0, // 214 # 2,870 //22 1e-3, 100, .9, .3, .7, 1.5, 0, // 215 # 9,032 //23 3e-4, 100, .9, .3, .7, 1.4, 0, // 214 # 32,117 //24 1e-4, 100, .9, .3, .7, 1.4, 0, // 214 # 96,764 //25 3e-5, 100, .9, .3, .7, 1.2, 0, // 214 # 372,995 //26 1e-5, 100, .9, .3, .7, 1.3, 0, // 214 #1,032,482 //27 3e-6, 100, .9, .3, .7, 1.2, 0, // 214 #3,715,221 //28 1e-6, 10, .9, .3, .7, 1.2, 0, // 214#10,524,463 (nsys=100 for tA) //29 3e-7, 10, .9, .3, .7, 1.2, 0, // 214#34,678,545 (nsys=100 for tA) //30 1e-7, 100, .9, .3, .7, 0, 0, // 214 # //31 3e-8, 100, .9, .3, .7, 0, 0, // 214 # //32 1e-8, 100, .9, .3, .7, 0, 0};// 214 # //33 main() { printf("\n dt nsys aA aV aW aQ N 
tA tQ\n"); for (c=first_line;c<=last_line;c++) { //do several lines from the table init_w(); //init each parallel system with different weights rnd_seed=7; //reseed the random number generator if (!AU_sim) printf (" "); donea=!AU_sim; doneq=!Q_sim; for (t=0;!donea;t++) { //do updates for nsys adv. upd. systems in parallel avg_av(); for (s=0;s1e9 || fabs(waxu[s])>1e9 || fabs(wauu[s])>1e9 || fabs(wav[s])>1e9) inf=1; //if any weight blows up, then learning takes infinite time } if ((t/100)*100==t) printf("%s(%09ld) ",&back,t); if (piam<.001 && !donea) { donea=1; printf("%s%9ld ",&back,t); if (doneq) printf("\n"); } if (inf && !donea) { donea=1; printf("%s infinity ",&back); if (doneq) printf("\n"); } } void avg_q(void) { //find the average error and print current time double piqm, inf=0; for (s=piqm=0;s1e9 || fabs(wqxu[s])>1e9 || fabs(wquu[s])>1e9) inf=1; //if any weight blows up, then learning takes infinite time } if ((t/100)*100==t) printf("%s(%09ld) ",&back,t); if (piqm<.001 && !doneq) { doneq=1; printf("%s%9ld \n",&back,t); } if (inf && !doneq) { doneq=1; printf("%s infinity \n",&back); } } void update_av(void) { //do both learning and normalizing double am1,am0,x1a,ra,a,v0,v1,x0,dv,ev,ea,ua,aum,da,am0n; x0=rnd(); //pick initial state x0 and action ua if (rnd()>0) ua=rnd(); else if (wauu[s]<0) ua=-waxu[s]*x0/2/wauu[s]; else if (waxu[s]>0) ua=1; else ua=-1; if (ua>1) ua=1; else if (ua<-1) ua=-1; x1a=x0+ua*dt; //x1 is state after doing ua in x0 ra=R(x0,ua)+noise*rnd()/10000; //ra is reinforcement + 0 mean noise a=waxx[s]*x0*x0+waxu[s]*x0*ua+wauu[s]*ua*ua; //a is A(x0,ua) v0=wav[s]*x0*x0; //v0 is V(x0) v1=wav[s]*x1a*x1a; //v1 is V(x1) if (wauu[s]>=0) am0=waxx[s]*x0*x0+fabs(waxu[s]*x0)+wauu[s]; else { aum=(-waxu[s]*x0/2./wauu[s]); if (aum<-1.) aum=-1.; else if (aum> 1.) 
aum= 1.; am0=waxx[s]*x0*x0+waxu[s]*x0*aum+wauu[s]*aum*aum; } //am0 is Amax(x0) before the weights change if (wauu[s]>=0) am1=waxx[s]*x1a*x1a+fabs(waxu[s]*x1a)+wauu[s]; else { aum=(-waxu[s]*x1a/2./wauu[s]); if (aum<-1.) aum=-1.; else if (aum> 1.) aum= 1.; am1=waxx[s]*x1a*x1a+waxu[s]*x1a*aum+wauu[s]*aum*aum; } //am1 is Amax(x1) before the weights change da=(-v0+gd*v1+ra)/dt+am0; //desired output of advantage net is da ea=da-a; //error in output is ea waxx[s]+=ea*x0*x0*alphaa; //change 3 weights of advantage net with LMS waxu[s]+=ea*x0*ua*alphaa; wauu[s]+=ea*ua*ua*alphaa; if (wauu[s]>=0) am0n=waxx[s]*x0*x0+fabs(waxu[s]*x0)+wauu[s]; else { aum=(-waxu[s]*x0/2./wauu[s]); if (aum<-1.) aum=-1.; else if (aum> 1.) aum= 1.; am0n=waxx[s]*x0*x0+waxu[s]*x0*aum+wauu[s]*aum*aum; }; //am0n is new Amax(x0) after the weights change dv=(am0n-am0)/alphaa+v0; //desired output of value net is dv ev=dv-v0; //error in output is ev wav[s] +=ev*x0*x0*alphav; //change the weight of the value net with LMS //--------------------------------- normalization --------------------------- x0=rnd();ua=rnd();a=waxx[s]*x0*x0+waxu[s]*x0*ua+wauu[s]*ua*ua; if (wauu[s]>=0) am0=waxx[s]*x0*x0+fabs(waxu[s]*x0)+wauu[s]; else { aum=(-waxu[s]*x0/2./wauu[s]); if (aum<-1.) aum=-1.; else if (aum> 1.) 
aum= 1.; am0=waxx[s]*x0*x0+waxu[s]*x0*aum+wauu[s]*aum*aum; }; //am0 is Amax(x0) after the weights change da=a-am0; //desired output of advantage net is da ea=da-a; //error in output is ea waxx[s]+=ea*x0*x0*alphaw;//change 3 weights of advantage net with delta rule waxu[s]+=ea*x0*ua*alphaw; wauu[s]+=ea*ua*ua*alphaw; } //end update_av void update_q(void) { //do Q learning double x0,x1q,rq,q,v0,v1,qm,uq,x1a,qum,dq,eq; x0=rnd(); //pick initial state x0, and action uq if (rnd()>0) uq=rnd(); else if (wquu[s]<0) uq=-wqxu[s]*x0/2/wquu[s]; else if (wqxu[s]>0) uq=1; else uq=-1; if (uq>1) uq=1; else if (uq<-1) uq=-1; x1q=x0+uq*dt; //x1q is new state after doing uq in x0 rq=R(x0,uq)+noise*rnd()/10000; //rq is reinforcement + zero mean noise q=wqxx[s]*x0*x0+wqxu[s]*x0*uq+wquu[s]*uq*uq;//q is Q(x0,uq) v0=wav[s]*x0*x0; // v0 is Qmax(x0) v1=wav[s]*x1a*x1a; // v1 is Qmax(x1) if (wquu[s]>=0) qm=wqxx[s]*x1q*x1q+fabs(wqxu[s]*x1q)+wquu[s]; else { qum=(-wqxu[s]*x1q/2./wquu[s]); if (qum<-1.) qum=-1.; else if (qum> 1.) qum= 1.; qm=wqxx[s]*x1q*x1q+wqxu[s]*x1q*qum+wquu[s]*qum*qum; } //qm is Qmax(x1) before the weights change dq=gd*qm+rq; //desired output of Q net is dq eq=dq-q; //error in output is eq wqxx[s]+=eq*x0*x0*alphaq;//change 3 weights of Q net with delta rule wqxu[s]+=eq*x0*uq*alphaq; wquu[s]+=eq*uq*uq*alphaq; } //end update_q double rnd() { //return a random double in the range [-1,1] rnd_seed = rnd_seed * 1103515245 + 12345; return (2.*(double)((rnd_seed>>16)&32767)/(double)32767-1.); } //rnd_seed should be a 32-bit unsigned integer