2023-12-07 00:04:41 +08:00
|
|
|
use std.math;
|
2023-03-01 00:41:34 +08:00
|
|
|
|
2023-11-16 23:19:03 +08:00
|
|
|
# Build a zero-filled matrix.
# width, height: number of columns and rows.
# Returns a hash {width, height, mat} where `mat` is a flat vector holding
# width*height zeros.
var mat = func(width,height) {
    var cells=[];
    setsize(cells,width*height);
    # setsize only reserves the slots, so every slot is set to 0 explicitly
    for(var k=0;k<size(cells);k+=1) {
        cells[k]=0;
    }
    return {width:width,height:height,mat:cells};
}
|
|
|
|
|
2023-11-16 23:19:03 +08:00
|
|
|
# Overwrite every element of a.mat, in place, with rand()*2-1.
# a: matrix hash {width, height, mat}; only a.mat is touched.
# NOTE(review): presumably rand() yields values in [0,1), giving elements
# in [-1,1) - confirm against the interpreter's rand().
var rand_init = func(a) {
    var cells=a.mat;
    var count=size(cells);
    for(var k=0;k<count;k+=1) {
        cells[k]=rand()*2-1;
    }
}
|
|
|
|
|
2023-11-16 23:19:03 +08:00
|
|
|
# Print a matrix, one row per output line.
# a: matrix hash {width, height, mat}. Each element is followed by a single
# space; each row is terminated with println().
var mat_print = func(a) {
    var cells=a.mat;
    var cols=a.width;
    var rows=a.height;
    for(var r=0;r<rows;r+=1) {
        # index of the first element of row r in the flat storage
        var base=r*cols;
        for(var c=0;c<cols;c+=1) {
            print(cells[base+c]," ");
        }
        println();
    }
}
|
|
|
|
|
2023-11-16 23:19:03 +08:00
|
|
|
# Element-wise sum of two matrices of identical shape.
# a, b: matrix hashes {width, height, mat}.
# Returns a new matrix with a's shape; neither operand is modified.
# When the shapes differ, both operands are printed and die() is raised.
var add = func(a,b) {
    if (a.width!=b.width or a.height!=b.height) {
        println("matrix a: ",a);
        println("matrix b: ",b);
        # die() aborts execution, so nothing placed after it would ever run
        die("width and height must be the same");
    }

    var res=mat(a.width,a.height);
    var (ref,aref,bref)=(res.mat,a.mat,b.mat);
    # storage is flat, so a single pass over the indices visits every cell
    forindex(var i;ref) {
        ref[i]=aref[i]+bref[i];
    }
    return res;
}
|
|
|
|
|
2023-11-16 23:19:03 +08:00
|
|
|
# Element-wise difference a-b of two matrices of identical shape.
# a, b: matrix hashes {width, height, mat}.
# Returns a new matrix with a's shape; neither operand is modified.
# When the shapes differ, both operands are printed and die() is raised.
var sub = func(a,b) {
    if (a.width!=b.width or a.height!=b.height) {
        println("matrix a: ",a);
        println("matrix b: ",b);
        # die() aborts execution, so nothing placed after it would ever run
        die("width and height must be the same");
    }

    var res=mat(a.width,a.height);
    var (ref,aref,bref)=(res.mat,a.mat,b.mat);
    # storage is flat, so a single pass over the indices visits every cell
    forindex(var i;ref) {
        ref[i]=aref[i]-bref[i];
    }
    return res;
}
|
|
|
|
|
2023-11-16 23:19:03 +08:00
|
|
|
# Element-wise (Hadamard) product of two matrices of identical shape.
# The function name keeps its historical spelling because callers use it.
# a, b: matrix hashes {width, height, mat}.
# Returns a new matrix with a's shape; neither operand is modified.
# When the shapes differ, both operands are printed and die() is raised.
var hardamard = func(a,b) {
    if (a.width!=b.width or a.height!=b.height) {
        println("matrix a: ",a);
        println("matrix b: ",b);
        # die() aborts execution, so nothing placed after it would ever run
        die("width and height must be the same");
    }

    var res=mat(a.width,a.height);
    var (ref,aref,bref)=(res.mat,a.mat,b.mat);
    # storage is flat, so a single pass over the indices visits every cell
    forindex(var i;ref) {
        ref[i]=aref[i]*bref[i];
    }
    return res;
}
|
|
|
|
|
2023-11-16 23:19:03 +08:00
|
|
|
# Element-wise negation.
# a: matrix hash {width, height, mat}.
# Returns a new matrix of the same shape holding -a; `a` is not modified.
var neg = func(a) {
    var src=a.mat;
    var out=mat(a.width,a.height);
    var dst=out.mat;
    for(var k=0;k<size(src);k+=1) {
        dst[k]=-src[k];
    }
    return out;
}
|
|
|
|
|
2023-11-16 23:19:03 +08:00
|
|
|
# Sum of all elements of a matrix.
# a: matrix hash; only a.mat is read.
# Returns the accumulated number (0 for empty storage).
var sum = func(a) {
    var total=0;
    foreach(var value;a.mat) {
        total+=value;
    }
    return total;
}
|
|
|
|
|
2023-11-16 23:19:03 +08:00
|
|
|
# Multiply every element of a matrix by a scalar.
# a: matrix hash {width, height, mat}; c: the scalar factor.
# Returns a new matrix of the same shape; `a` is not modified.
var mult_num = func(a,c) {
    var src=a.mat;
    var out=mat(a.width,a.height);
    var dst=out.mat;
    for(var k=0;k<size(src);k+=1) {
        dst[k]=src[k]*c;
    }
    return out;
}
|
|
|
|
|
2023-11-16 23:19:03 +08:00
|
|
|
# Transpose a matrix.
# a: matrix hash {width, height, mat}.
# Returns a new matrix whose width is a.height and whose height is a.width;
# `a` is not modified.
var trans = func(a) {
    var src=a.mat;
    var src_cols=a.width;
    var src_rows=a.height;
    var out=mat(src_rows,src_cols);
    var dst=out.mat;
    # walk the source row by row; element (r,c) lands at (c,r) in the result,
    # whose rows are src_rows elements long
    for(var r=0;r<src_rows;r+=1) {
        for(var c=0;c<src_cols;c+=1) {
            dst[c*src_rows+r]=src[r*src_cols+c];
        }
    }
    return out;
}
|
|
|
|
|
2023-11-16 23:19:03 +08:00
|
|
|
# Apply a one-argument function to every element of a matrix.
# a: matrix hash {width, height, mat}; f: function called as f(element),
# in ascending index order.
# Returns a new matrix of the same shape holding the results.
var activate = func(a,f) {
    var src=a.mat;
    var out=mat(a.width,a.height);
    var dst=out.mat;
    for(var k=0;k<size(src);k+=1) {
        dst[k]=f(src[k]);
    }
    return out;
}
|
|
|
|
|
2023-11-16 23:19:03 +08:00
|
|
|
# Matrix product a*b.
# a, b: matrix hashes {width, height, mat}; a.width must equal b.height.
# Returns a new matrix of width b.width and height a.height; neither
# operand is modified.
# When the shapes are incompatible, both operands are printed and die() is
# raised with the offending dimensions.
var mult = func(a,b) {
    if (a.width!=b.height) {
        println("matrix a: ",a);
        println("matrix b: ",b);
        # die() aborts execution, so nothing placed after it would ever run
        die("a.width must equal to b.height, but get a.width:"~str(a.width)~" and b.height:"~str(b.height));
    }

    var res=mat(b.width,a.height);
    var (res_width,res_height,ref)=(res.width,res.height,res.mat);
    var (a_width,aref)=(a.width,a.mat);
    var (b_width,bref)=(b.width,b.mat);

    for(var j=0;j<res_height;j+=1) {
        for(var i=0;i<res_width;i+=1) {
            # dot product of row j of a with column i of b, accumulated
            # locally and stored once
            var acc=0;
            for(var k=0;k<a_width;k+=1) {
                acc+=aref[j*a_width+k]*bref[k*b_width+i];
            }
            ref[j*res_width+i]=acc;
        }
    }
    return res;
}
|
|
|
|
|
2023-11-16 23:19:03 +08:00
|
|
|
# Logistic function: 1/(1+e^(-x)).
# x: a number. Returns a number.
var sigmoid = func(x) {
    return 1/(1+math.exp(-x));
}
|
|
|
|
|
2023-11-16 23:19:03 +08:00
|
|
|
# Derivative of the logistic function at x, computed as s*(1-s)
# with s=sigmoid(x).
var diffsigmoid = func(x) {
    var s=sigmoid(x);
    return s*(1-s);
}
|
|
|
|
|
2023-11-16 23:19:03 +08:00
|
|
|
# Hyperbolic tangent of x.
# Computed from t=e^(-2|x|) as (1-t)/(1+t), with the sign of x restored
# afterwards. The exponent is never positive, so exp() cannot overflow; the
# direct form (e^x-e^-x)/(e^x+e^-x) evaluates to inf/inf = nan once |x| is
# large enough for e^|x| to overflow.
var tanh = func(x) {
    var t=math.exp(-2*(x<0? -x:x));
    var magnitude=(1-t)/(1+t);
    return x<0? -magnitude:magnitude;
}
|
|
|
|
|
2023-11-16 23:19:03 +08:00
|
|
|
# Derivative of tanh at x, computed as 1-t*t with t=tanh(x).
var difftanh = func(x) {
    var t=tanh(x);
    return 1-t*t;
}
|
|
|
|
|
2023-11-16 23:19:03 +08:00
|
|
|
# Back-propagation demo: trains a small two-layer network (2 inputs,
# 4 tanh hidden units, 3 sigmoid outputs) on the four XOR input pairs, then
# prints the network output for each input.
# Takes no arguments. Prints whether training succeeded or failed and the
# final outputs.
var bp_example = func() {
    srand();
    # upper bound on training epochs; training is reported as failed once
    # the epoch counter exceeds it
    var max_epoch=1e4;

    var input=[
        {width:2,height:1,mat:[0,0]},
        {width:2,height:1,mat:[0,1]},
        {width:2,height:1,mat:[1,0]},
        {width:2,height:1,mat:[1,1]}
    ];
    # the last two columns are not needed for XOR; they are only there to
    # make sure back-propagation runs correctly with several outputs
    var expect=[
        {width:3,height:1,mat:[0,0,0]},
        {width:3,height:1,mat:[1,0,1]},
        {width:3,height:1,mat:[1,1,0]},
        {width:3,height:1,mat:[0,1,1]}
    ];

    var hidden={
        weight:mat(4,2),
        bias:mat(4,1),
        in:nil,
        out:nil,
        diff:nil
    };
    var output={
        weight:mat(3,4),
        bias:mat(3,1),
        in:nil,
        out:nil,
        diff:nil
    };
    rand_init(hidden.weight);
    rand_init(hidden.bias);
    rand_init(output.weight);
    rand_init(output.bias);

    # forward pass for one input row; results are stored in hidden.in/out
    # and output.in/out so that training and the final report share it
    var forward = func(sample) {
        hidden.in=add(mult(sample,hidden.weight),hidden.bias);
        hidden.out=activate(hidden.in,tanh);
        output.in=add(mult(hidden.out,output.weight),output.bias);
        output.out=activate(output.in,sigmoid);
    };

    var epoch=0;
    var total=1e6;
    # train until the summed half squared error over all samples is small
    while(total>0.001) {
        epoch+=1;
        if (epoch>max_epoch) {
            println("Training failed after ",epoch," epoch.");
            break;
        }

        total=0;
        forindex(var i;input) {
            forward(input[i]);

            var error=sub(expect[i],output.out);

            # per-layer deltas; the hidden delta uses the output weights
            # from before this sample's update
            output.diff=hardamard(error,activate(output.in,diffsigmoid));
            hidden.diff=hardamard(trans(mult(output.weight,trans(output.diff))),activate(hidden.in,difftanh));

            # the deltas are added unscaled (no learning-rate factor)
            output.bias=add(output.bias,output.diff);
            hidden.bias=add(hidden.bias,hidden.diff);

            output.weight=add(output.weight,mult(trans(hidden.out),output.diff));
            hidden.weight=add(hidden.weight,mult(trans(input[i]),hidden.diff));

            # error*error^T is a 1x1 matrix holding the squared error norm
            total+=sum(mult_num(mult(error,trans(error)),0.5));
        }
    }
    if (epoch<=max_epoch) {
        println("Training succeeded after ",epoch," epoch.");
    }

    forindex(var i;input) {
        forward(input[i]);
        println(input[i].mat," : ",output.out.mat);
    }
}
|