From 80edba7420cc611d36e0ab0d10af39e981e94fb1 Mon Sep 17 00:00:00 2001 From: Cezary Pukownik Date: Thu, 30 May 2019 20:50:13 +0200 Subject: [PATCH] making VAE model in train and generate modules --- project/data/samples.npz | Bin 169003 -> 0 bytes project/generate.py | 11 +++-- project/train.py | 101 +++++++++++++++++++++++++++++++++------ 3 files changed, 93 insertions(+), 19 deletions(-) delete mode 100644 project/data/samples.npz diff --git a/project/data/samples.npz b/project/data/samples.npz deleted file mode 100644 index 0b82e27eb930e5ad63a5d7beb6b1123596abf17f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 169003 zcmeHQ2|yE9`X5Bau1DOhE38+vm)&g@x>m7rL~C2S_3GBGCzrLYRTNM_KnN3R3%j*~ z-ENB&1zQ(giv%xjNYe6G+A2gv2yz9B5+Fdp5CR!y-v3R4kdO(GfNj-B)rKT9^WN`$ z-|su;&6~@uM^3^pYz+K63#*%F^}qDbM_@eOgi(*dcQ)_d?ddSr$1kL@=!NZXZ=Uqp zwbDhONL)UC@h=}e`@Qt$pBGGC`-c_(_wuJtz4K)B_y4MzWj8f6$v@8{$?2r#Fi8|- zOMIoAg`K$`&YE)35SbcXx4s`69?cZ#JTX}^8?x*$+4ww3am zo0rM9idV73@bBvQc#f+}*P^cUsM&Bh=m#Ny5I_hZ1P}rU0fYcT;2}q#MVxlR7F*NC za@R!7C6~rhUoQt&6Y;w1oPc_ERzGjs3@0$9CwH#2$4=Dsd&PdoXEkiYX^tJzSgeO_ zAXo*f{ICbM;@Ut-5mFksBZa&m#bs+bgd&YR6fG%BDi`7ue-|$=wxSgvD~SsxnOB5`1*1`Zx6%h zdQQRf%}Sq(tsUf*na0_3sv*AscdD|m^{wVWp*Nsmbj3PY>XK2SPx@a4y-Vw_*^p7_ zJps=*+p@J8Mkp|>j?C>Ii zi5w9W-ST7^6vL_?SjU$ZD#(y^Lhl3?^Z2l0VHQ^m>y>FJN;52>OFa;W{<39`C6+_a zDH!??Fu@Sa=rCOV-YHskS^|Z><8$v!L;ioYu);Fiq4UIseEr@NfLoTkIK5{X@+*ti zX2@68KY=vB*Bv|EF=)S6&#n(AcXXs7{_b8Dv-tZydA-;JXvz>2j6s{evJdsOc~UXd ztMxsecK~rjHzpQy#PCKz7RQ4<*_FO)Z1L|ru-jro2AN|R1Jtf43&r#|vNppA1%_E1 ztHV^B#j;?0V0C$O@>t#<2h8f}c{e(9GOSU7px#2q(|AIbzf&EDpy)==GAK;x2dBO9 z5YKAn=y&Wcv}}ER%+pt0q#IxM+kev3_xoSM>N!r5vApNZEaAXfl4X}OF_@ffvb|P* zzjkTQvB}LI;}5bVZ(=6H^S0D@Bszwu#yh|QhN7Wd-QZY}?@bFeMBC6bj2V=U@eiz)=;Xuz*CGBQird4t#n$!HGuG4np)aJ` z#|S0;;L zT?`DXI%BqW3{?*M9RHheIQxtV{fZDk2p|Ly0*?*?)p>=}D|lawSTbqc&f_!xHfG$v ziXy$fyY~9zDBtX5{H&D|e5X7*V4#^I1Rg#Fc=&e!kebnYGoxHd{chPC$AMmghUeT4 z^!)9a36y?;>TbRMhsS0Zoapvhz|y-+XB*J_?ci#I4^?&gCz=^U;(Fix;iSQc!Zi9H z{%@JVi9#Cni<*lKr*^x@;6oRK6aBCSn4Vgnc{f}$IMEEEpQ#zn`^^|b089IsTAz9E zH3Tr9;^<-2XKIG?MoX$i%1~O-l4?OgTB?NsX(PABvY1=jJY8Fo(AtXU>4@mENr>-C8rb2Z0qWoA>?(p42i z`WB78ZRUy1u|v&#yltn8nrb>k>1P6~=b7OuAV@IG&>kTvfy#p*fh63okBt-uNeOLb zkd$B;vN*;@I2}gaU0$*YK8sI&J@b>bFVFes6BX~iW$k`-U*hR!t$x0fK4SFK4y#?4 z-F)}X=w@e~zoL~Vam;j~VQkvRiOH06V#-sY@9=fGUV;D*WAu3P}o*6eKC#2+;pS)@KkVSt@MOH_<#0ZQ!`{&CZ))&-+v` z%yCNklmyU4X>so386bFRQ2LmtFe(NqHF{w_4~W-DgjF#3$1kwX zIO%=YYgw~FF_FikkJ#!<(LIP|5H7Rt5IxjR0%> z-$sKkI0un(m;%r&Bd~1!I>esahL!Q85>P%}??PveZ#U>uXsJ-8&{+U>lXc?+z3-KN zW?0rw*$8}t^gb6w!x8HL?*E7ZBr?dhyxz7YKcNaijbxG538buyY}R5E8r1lgwITVN z!~U~qd*xI6C;h=Tf7Lf}-^~5}goU?RelI3XpYOjQ@t^wwAR%hgoTJ5}|rBEzv)>3@$`Gj%6W7+9=tAZ~hii!^+ksKPg1 zajLY*Fc|v}2A0Gx41!W=uplp%8APz_r7B-3v3h^VNdhw*dx$}rJlgjHDKuOVh?~AK zlxI^-p~ZYXfAr<;)4 zA+=-XaGoJw{xe)q!_oZK$XAaX(MyB?LI5Fv5I_hZ1P}rbAp#SdpUM<=ko~l2M$D5L5@q&!IYVMc;pB85NY}VM~k8#=9Aq)kp4(rM> z`}t-1Bf_w{yGm?;2cGF4VEA}681OFsPMN}0RTSB-6Id@kX#V4ndiFaxl8^v;c1$v3 zL|S=Xm@u2_k~;ZThE)@skqpmoqrbHeP)~n|4hu6Ie zSanYP23IUxPOwdn#!=y(t!%0NSbF(YtXX2NJ$zw;(0)b- z#ahAuA^n9yxjmktDdM>1w2-V1`>JjCDHI8spH=Ce9hVrUr>k^SJ6CyzTP-9=W1Q)8 z_<8u|kl)Wv5Q0?Z+21>hAecPqiFs%k&SSVmk!=1PeixT*G-5-=V0xS&yR4atGOhG|<_2fA zw^=L5xxQQ$Zy@FB?+J_WQ~_>;{BjIMikC~Rj3j<=V}fTHBA0a5h+Tsf$DfmeS%$E{ zQ^OxwYu?YFNx_nVafc(39@rG5;0VI&6(p0E_t8G4XHmi{7Ao5Jucld7w1gvvg64vsKuw}N_<2i-N zFE@r~-V8Wb6Ivy^wt4)sA$Wo~Q|**)yf0-af*z7HA}tC*8?J3)cp{v~WJv%;S3l5BRjTVR+ zO_N#}9@vd;pLL+WfMJ(jAUYI$LTIFCy1A2=FEtcW?e>MnG~YeQHD68$a!j%BtbFrX z-2gq(L5wggSWsa<`bi76%(U3;7=*;4a}SQG%)24j=(TgF7kmFVAL zL!eqUt+I4|8_j?_)**UIxxQBo&LQ=TGlu=RSE4m_-F z`^JJ0+N(8`akyokS+nw*7CLWhynH~H?u16D{_%6RU_-lrJ*n?ldLj!9%PQHrs>#Me zk5}u*W0Nb-a$Vjvd5wG4An%{y+37X1m$-8FYYbthJ`R`C&t{z$l3n@BrwcuR=!v)n zwq1gr4A9{NhAoQr$yxaOutw*J{VooC#=0blz_&BYmTsB5T(y$oodsfis=`}#Tnt_6 z02eWqVN&qI)roiZB`fb7!7PrX3?}O@`|kV~ap~;(0fmf{gR6;6-m_~*j9WqLQ?;Rh z)1U-o0Lvfgt{WC+mtUev%&RO|BMvljO7%xhuk0F9(df$&?g@BjuIVfc^klucmqK5D zbAWZfCn_fV32V_poij={WKLj;?+uT4;EStpb&;YZ{Hw-I0IsKr)9T)K5Nz^J5c+Gr zTOGS7Ucx!kaYAEr>7>5C>OAIdUTe%TP9F5}l_dWFP*?A*mEMOoa}8a6Ne(61x~jm& zPA{<2f6h2u*5GZOUh=R`?(d&9Jd`2y>dUMQGHfTVy7Di6pBS@G2za;wS~*8#A|aL zN7nQTkImeoo(+C!!S>kbTDr)y@G+w}(RjV|X4RJsx-r=795lmuQ)US9*}y`UM+6Cl zse9h(rsqDoL5InUqX>%uNXt^S5I8uqGa21{w%wC4UulRm;xm0V)|X4fXZ;yB#Al}DMvo!%#mVSx#Akyj z_8~q){u%PmdTJ2JKkGN#>PbOpeH`SU-OqRv+BSoHlwrInn$$NihD?s)P3mhDZ|aB0 zCitw!ZxShV0cD?f)9)+eAR5J+2JRG!H;qMv7AurL8=M9GP3F&z-W*K+tf$D*U;L_{ zWP+v%{_xKopGv3GVz_)=V@W?i$%{$Od|NNCzUjFZ!m!dbzY;8S^yLJn{hT|?PLKHX z2@GDk<@(+4N;-G1!Yt1F=zBgjnce*<0pt74XG2wBR|yF+IrVXWtI8ks;?4` zN|s#x2MD6b1lLed2^MIl&Ub7AoaS6FSJ4xp?vA zUrEEE|Gq8Xsb;ysqK5c5P9u6>et5R=RA&sS=8)Q?zDqoz;3d7* zPkq$xX1=GZvuD0<@@g`#7#1xxTS&2DBk&6}1Kqvlr1d&Ixi(w%`Yf0#Ut@z!{!FIL zBA?KBCcLE6GAL@eZ|#*Z-l+C8!jn!a3ua^MOs4BC6^TyL(@^Pmophv7l5P^XoZLdt&7F~bpVjWwF3w#hOLtx$H{bIS5wMX z$|ver2q|ptWSh7)1MeWX!u;Kt2Ts}`JR`M=F zwD%TK>i4M*nU#oBSl%p^am~`?Uz8B<7m@1+q}KRj{pa;-pB{+ntOxT818&d`^IBcPz^+}P%G;UKKRlsH$OhJF%fQvZ(^M1K9BSz3-eddw zEe_C$O4K<~3*Q#EgM$NHcTrV4AFsRWWUd_Ox?NyOfw$&vXL;q#laZ z5PT21mJA>yrga?pj7Mh3xTe3Lo8f(NEzwC40tf+w073vEfDk|kAOwbjfanEh+?|+? zdF%2B6PlzIyrxZAxc`VtI%hVySeS{)dmAaE_^leDU~izWtLW!n$La2y<+jY5{v{(A zGg6qp>BNUfK(DyXi8^*LjJX5rs(z0xbWOP2ovNP0Wlyw|Y*a@?sB{yqQQ@89!V2C+ z_-k6ZP&!@|`mi7|nC$#=wU9%%9av%2>%VkLw?3{@qDzE;)$wdirP|8%6Kzp&z6bDE z5LSxrNOCCJjE+SJAOsKsS_G;;c3C)4`|2*fbL89T2O)qEKnNfN5CRARgaASSA%GA- z2p|Ly0tkVJ9f3c#HrZXo`_#MPn zVmxn-Huo+`C*!8JfVgJq5ni^o2Eb3I3PzRc?T%nu&+lvjvAIWL%LTnalRl?e`KGC}kf) zB=e}VAR?K*H#7gT%KDowQTJ1rMI?jdWYAq7h-47SSaM$pA{j(7h-56$oCnpd_!Y*D zNCuG%A{j(71Koc&q=X(K8ALLOWDv<9k{Rm52JS0@zk>VuYh4oMWKd4V5v^4YSSGSWl}=L~pB%d^@PM zzM^+O_OG)+9qQrgr_y}iXcKPWnl6^J!|9ufz2d?4?#Lpokl=AOQT${s=U%IfO2BgH zc^;vAlQfK{-)x1H@DXUDsh&fJ3lXz5Hm-DA-FPghV{Z3$w)1mYV=AB!lx4B7Mje;% zt87O1wBPZQ$T&UuL0x%6%B0>mI=DY-EC~Qo*A%=1?(~51La3zfF43|@3hQ>-Ftw93 z*+8=+DJEP+&i7(c)WG=&aEehXD9iaz#?+~EUQ_&oQksD z%Dw&+Pk&`J&*?}{FR&Q)0@I>I_4+u#(8p=FG3fL}y|Ryt!v!Bui%aC}DqSyd?i5`t z$WKRBRwyDyFcse2yZqar)i>uC4GS_-#4qh0x_ShHVGAm!2>NNIWlDatUA zDo6LhJuv2Ecs+Q2Jq6rT@L+q=8WTCId$KiE-PxLu-fWFf&l16G4TbfXDdcnf8i}@F zp4vny+CWNhjnhIuvB(jtxaAh9%k6saoiQ;3>QJhRjrI7xtZ}_bVC4Hg!b$KV5@dUT z77uNCS^G7(LffFH{g1fD633JqT6ln#R88C$wqi;1EH|vc=g}l zo!Y)wMqcJPU2ap-<~-96;@~akV05KY>-)0pdb2Rd_kEO-z-!;~4s2J26u%5@&~AH)dk5 z8v{8Pg-Z{{8V~#SM6?m=CtOOa<@QvzLvVxLt58lxXTTpV!5Jtg>5&b28Hz{3v5;8d)kMM>A z)Qo4;_xOmoHvKf?NlS8A_af{?3KnvHVQa{veW!TG7xaG5B1er4bI0$_g5KSKfzj^2 zs8aB#$`8D&Nhnvbv=2Z$B=Z>GU63pZWuNi&uClooVINi7SmgY=HddZZa@+Qu=oT(2 zd?u^q?cXi1TH<8wdHmza+vc=Q-u8z#)4s@-vM0ViWtP`xXEtyyTzl=z&D~2@JpIXm zSN%SIkstT&=PUpD#nE}^cbw1jE3Eixo_{SBAC;c|N|7hz@jd@C(x>A|!4Kn~^8LJ(l4`Ibbc5zY zBM;{2_WVF#do%^l$e|16MgBcpc9s&d77DhwE3BvK_A;L@k&_YidJcocjYxMDwLP3A z;YTmv8muiBrO_*H6IsXZcmlR*39P(a!+M&FRVtn+Y`T1}S9OP0F8G;MtT8KkZtI28 zQ(GofSsBpJ1CSoQZ$$*=XP2ON5dsJSgaASSA%GA-2p|L=2?S2owa2AU^(UplDq&4@ zc44?1 z%zlQuXLM>b9=b=S5Ox$a6Szx;@V4tyhn_Tfr={csrG6FobskA^R&}ywH~tuxwEt|C z>}n?{r6RV3xRlCnka!x9o-BE6;XXYxFen{$-NbEpifEY6sQS z^#EeOm92h)#PD2YH#mL`C(3~=?9Qn`>-W7gz-(Xs+)J9yPVHoTE8Wfk_B5s^3G(YT z?B&4kd0Pd0VPKIv7<=U?ynUg8O15|K%;Q#ZJ30#U^2H0LIlhSR0O4R?9z#yXU;@~i zere|j%1hLm6CA*ODXj|l?f|T2`p8m^_r+HnZZDGO4+5&Uc$FFJr^&xQqHdcdHF$;w~#@ClpI3I>5~1l|9~>T9mI0% zcZN^*`itTEX%nB=PlC35-?@?E&CgCL2jy&m&`vJ35ww$``!|!kH7^E&t#b2KHvC)Heh;(8OMtbCw8NRVn1Ya?5xH4`%_c zkmZH%X#!Uj)}`#qB~ahx4x}(mAR4z*0yK(_qVKgb*qyUa-pxlxF{3WQ)hh`b6^F1&jDMS>_S5X|hXy06=Y zPgf^D?y!_;^195%O49nIcQ!;WNN`TzYv&iXg^&hRTC2#`-a^Dl<6KYRn+Cp1!x6C|C#DihSL=LZ(SWJNF66=4^Y%Le8d}U5ceCtQE_QkLymS zrFsJ*5nuS*5%xLs)pg83~u|_^}+7Y9(6sb60Qn%L6 ztpb8NcJzB0c=b6*G%G8Gs!_Gx8RBdfTN={QaWk?evNonks>T=gY(u5r)WH{9=}RZ? z-{-%mRNGib%Pt<5Up}j%@GS)Hhi}UwFY@*?7>9eeQZ~@`r3v^8?%v3rc#n-YG(=COYgxgBqvgAMF5uozXny>B!Yz?0YeC@%5vZ|NHvy($_t;DdX&|56hPp zvS%KQ}R%(kYtSeDC2jJZF}pV9ioEH=YRZi&VT+=8MNoaq|Jxa+qb^I zwQa9+@Ve_)&J`$g&tJYmjBndoNkw#CpMU#0=k}D7z+X9b#ntVkBc7kKTG;w+O;}BW*}Iwk)O`bov^rz&X~4=n4Xljk zN57-{kd+B91Al{8+SruYIqI4JT>$wk#oFToT?XR|j_!BFJKCDoxo({Wu}z4ZUH!Wa zFc2zq)!d0wi#LirZ><;`sX zDpf+#HMf}KNdP!|G(PVJqndF{7<-aG#=caJ-WZD;NHDP(KH zQZ>k@aQlMC9Tpj{<1LMSB3}BR%WC=@Cj!(gx3}SNm#TfkBkCKPmDH_k&9|=o?eZ>n zS!}z4;C#9o99(Yn=Q8<{fgm1y_wSo<+`-Eq8jZ|p&ndvVGVmRB-46z4u_sLkArwwp z4ukTV0G!DJU(vgPYM`BqMCL95iKH|$_g%oTOQ2@O8vT*@k+Ih~JAS%D*wp%D5O~C? zBB)JSI8>%}K?!PB%pO+Es*Ta|MZycT*@MNc-i5vQ+F%CQW0O-=&~h(;mOJk9*}D{; zNU|I0!0g29M$_*BBq`b{Lp>!8$&J zNaU$J-vDR+9IId-%D&YMMxF4EML^o`S-nYY-vn}l zVh^XV#;s%?(qboaDL;*h$UIMRZl9xUXo@lZb58pq{RQo(Q~1DpzIgoi0fA~f_`;?G zwTdv`u+GjPp>OU7F2V~SA>3^{KHckChDq4)-o-W4(d`}6B42^)#qByXr9ffr%gtDK zpj~O3s!FgOt#SNsaG&sq$2KN| zT1AJd-q9oQh_n(XgInX}6YV}{K+s_1*v$NI_S9y296Fn;u_BAUP=n{1a9PRP`5uR7 zdwpmOV9z^c8k<`0R8qJ$-QBx+>gXLF+rZBJz@wgJ;Hq-sRYo&5KY$L4k8$Mf(EZ`h zl3b;A_2?l_uMgiKK;w7kLs{|joqPzwJCO{^op2R~Wqr(<_E&3ndJK=Pi&24Du`15Y z9+5D3q5Tkkflt#Aypm^#NFH{xzG$GoKLrHB(1k>Ob02p*~0MFx9e3d z4o@eNoXiZ51GT%{NqmFcWA$xlh+=DhZ)ZZ}Bj)Zj%w;H3%P5X$Qf_d7|4VDbsq_l~ zS|G9Q_k*`Udv21#Hwc>rcf!t);0+=juy=f7G?iY)R_v!XugM~8_)cyBP24A_HP?PLSjCSJ=HO+vwUqAQ1WFg1y%uDt&D&3;BG6(&VZ5c4$x(qFX zG5P7<*)FF^w;kN%Pmgo-I+sP@yLaVIDUd7qa&VneRCvIQAhYH3&Y+gpvlraCOIrIo zR**u`4XBjkD}|k-P9$ZEiKu1SQOjn}(TcB7%1r%+nHb*4{bd?rpH-Za_5{|40O56t z)lSD-YZyge=hJB|XCmRB?32A`M-}hKhGy{NMb-8fQ^eVPu$&P3hr)Qv0dC7xHkZF zOOqm1nAZwMGXmd~x6@uwM+Njo-yl*rlwv`Gd?%!h66>gXM%jTbh8N6t8pt5_(i;%P z6W%FRb&3nuS`Q@PNtZMPXsRgtHj)&2ebPg8#xKe=P#JDDKrtk`qWtj zp6i6}`Kn^Pp1?H6jUex3KTbRS1t;3kpT7#oYHbf4E7tf@(Vhw6HE-v4R+5q!TSexC zh!+?wN2Wog0K@J+zkO@G{YQT)Phwx?xN9Xb$Ygb!J%BF_NA{xv9dusiS5-EFJ=@Hu)pSk%z5#Wn<{EYBj z{DniXchK0D=-$zW9M8e}R3P-Bm7`az4Fap@^4A>?;a1Vh0gje0ZwLEY8AZaKq7wL8 zlloo;0eIRljt@BTa;F!B+KBMS^wk`yyn3Ev$-9RE-_}!h|!4f2UZ5$-uOFJXO7YB`p*@A0VN&D@U z@=?ygds3#9So4_47^0J0jzen{pbLZW8kzLlYtH;8MGAMrJfVq%!1zjz1N{^Q{U6@6 zM3`A9N?qsM7L;{Y2_tER*`_GFkq&^ZYo4nV`qxoSPyUWUcBW*}rJbt@fdre*3|*UU zSw<_f!k`?0VIg!<%hs}_)=~vbCKyH$LhcVHBJibZnDxUn=Y4=EtbPjZvm7V_bf)Qp zR4PtbBY}ne;$7iEKvi=wl)seGi|C8|{<5!RdpBQdZ?I`wy{-zRTmWAYu@{EmcE34S+7(V}F796dQ4~ zOK?eJ1{59`P^%U#2F4VqeJ6mX{@HGC%k~A50N~40u1s~9)RoAVebbH~9 ze4DLo@@@gKHCblAy$LAxa=jLsa@%yCn?0F#Xsw)`@kU7#@B|%~rJ!95zT$XA!Wd94 z1a1dv_fB&sfvoi{^`aR;KK7)UnqCAI!)oXu0s_R*4GV21F;01hu~u0PsejhDqAvh6JWnAU=3Y6SXokjLu5z zcq0gi-ln~~005!oH@hm>Gn!`1IC^24Ry+hCk1@5CS<$C+;%|)Et?}tXJ3si;CWh>k z?oAjf!YBt;(ah%c@e=&cFu&*~fC;k5Y+!o`rXo!RdjxZ=rMDFDk5Ge;SGR>+h_Vju1yG=zg zLarufR|VNn4&G5;P{I7UG&JZ`r$%&@QEA#(7vOr+v8R(9ES)RI7c*N_?T7Ri;9rQ^ zEsb!kvSKDoYyucw7Jon$sxQSR;A)~|KWz&^eSTuQDII&g{iH&A3U*Dr4%76N`K+XC zU9K%9QK7{jaPYrx0wJ4^;cMyvOqdj(nSp)!H|DhBLGMZm7NzVG2rz3MOwJ^T-4n{e zI%VbrX8#Ndt0hhzCB((#)Ng#c?^c> z?vg!6rCKM?LBJ2sj+Q5Z>yU_+8I2seK?I$cD`~9K!lo&ckOOjO5uIWNo z@={W?GMh+Y+gJHdAF+jz8+HT3iQL|BqtMxN%)GR9oMU*FHcw2O*GtQlBUc>k-^mLb zp9`J5J`CSJ%YnzPb$lgV4)~X@h$EiioUN|jaK0u@C|p|x0}zZfra|pmTzwTT zYl5F@{a~^RItvc3lv+mD)L@M!hFz8X+CZ1V+_+A*zEn$nc}iH}%P@TAIIb4`@thA} zAK|uLc2p#!#tZ+KO5Z*jajTBGs6~4YE9J0%FYnNaVvfU!KmzPNOUP|KzwnS+uX(FX zR}+{aDF(5r&8kxH>qzky$nnIg$=#tq+-w5M+ zeZZ)V8k;gGqNys@m7{}7b<qgSyL~+i zyJRZUnd(D(L7$3Aa4_it(9{4QWo8wmjrU#r4l&9ObTMpB(9dfQLVqqAk|1(nz3oM;4TUKDHgWbB`1>8Jd0j$tBX|wFG9l}JCyT^9I z&PR8x6S52PBCL8Om*C2vE=}9iltqP#W?^_vusCoGqu-!EaOk~l`MZ!-y39 zh7hnA0TAQv^nW`WKN;bWbKvJQd##`O>=T<;E=61%b!&s@^sKi;6O)eprQxq{j5rtZ z;a@~EPyPGLr~bS4@(!1mT6wKppst@g^ET(M(`<_`j}Al#{5B8}r7Ju06={lFQNCml zmQUQ_;EjP@;*<6`uDL6FR|>oO8SXfwZx2InGv1SE`rbT~cQ!+7XM88$gYj+2S=eI) zZ`E9>y*KI~(&E@tq;bG9@Le1@W!npz~gZO zxs|_TF=%tk2z^N^I0YWZE`4ouXQ|OGt4Nw2ycHi3VIFSQ-Z^q8=A!b>-P$W2>`zW; zq;78RQmN~VLt~rAIQ;d&rO`)hDSt8V!8#k7V+HKjdJ^6;C316aB>DKJ6c0(#0WSLr z(}qD0uqj3|#IPZOnPwv*Mk|`MWxBK_#m0RN$v!-`b~G%q5B=Pxc0n<*ISP9Fw;yCI zO538{NUQ#C8Q5s{T9;*eMl9oIQ9K`xV4 zs)U&tSpp%~*}?XN`eRW&k{iao=pJ3A(PU8(pJ}Jxc%6e`S1#sNDKzP-hQQpi4E1-7 zk1j9|G!fiCA#YLx*Y3N?ctt^t?qZjTmmG^*zZR%ibeaaq%cE<3ObK0d(|WpH&@U^x zQf*a31cM-ZS-JXMSCR2%mPqMLf%wQ|eT$B2YqH9-!2-9cBRAzsul)nW?N5>u<@whB z-%YoA?F35QA(qu+SrK0iZ(I0`-s14sPXhKO^%+TP`Ce6)$Ic4EtNoGD8`o^a4MRmN``kZnwii+NtZ{0jsth((_SyHD6mn$vXkZ$Lb{YR zTWJ+(dRKYm%sjfz%o*(+W)92t%P!w3_)4+&*~1OE|E@qTP6T_+U7;8oQH10M1&i>h zdr>_FAypu0{bUgbN)9XqVAXg#=^%*8jh2=qNmScws%*|9i9r%$ATg;@KxlU|f_}^@ z>DeF`^ussAxl*c#@-mcNsBOSb1F6kT|m{=vUh6=>s46eo{`{LC3Q`?GAd}K-B*t&FnXpQ35%X3 zw-%CckMslDEmPm7ot~Ke(fFQnMP+++l}!?Idi0>WA2s6#Ek+Bx38eJHlISN^L%)aD zdZki+)8_k*6N>~T=@h-^Opiy6)C|c^x5F%{h}JfcU!ET%wfR)hPN`28bUGboIWRu6 z3CYmla;Xn(ay>$-HAmh7S|>*4Zs5dxT+f|#2Js9sMus6afRC#AD`_p%7U!@oFnivuDg|KcS!nDNfSYl+x)5M+R@`4ibAD9Sr%EAp{E_I9kutpO{vuSE+G=S3%JL3 z*}T_t%=4&G&ybEE`Lwr0sA7G(f+k&R;<@-X}r_s&WNM8EfpQL*@ ztygY^>X`Spz}x@33Le^v2$?dbT|Vx$z{CbQ1yl+Yx(*m0bO#giZjHmj?hY95E$lJw z=~8RZ?u4pEOge>Rrzf4O+STgQ?W%B*tAZLhjT$&3HtEWWFxgUbR|RE`S3oPd~nOR ze$8(jJd(v8v3I7Q*Hb^t?0lVbKK0y&KL~SWb*X;Ks{Ry&Dv|wFBCStq8)TMW&;)1U z0%doDjAOZ^vh{nnLshmT(otRyF|)2svpX{-hs7m$e-q3NE~K|9TK`GfYT_C?-akg>mV0@&J4;i+&m>&E z!h{NNM|n$@8;81ljgu58_`cU4v-`Abd~<@LQT6o)--5iF_lYbg547Dcx{1Olj<#9l zF5#ScMR(iD-&a;PRVckW!=>P|EKlqN5-^7;D9eScQLi=Xrr3xz+MPq5-pF;Xp+N~hUQ4(k94Fg878>#PuQy>?el51?fe4XAj@6Gq5teQ|Kc}J*_ zDodd|@q&@%M&67;DPKlCL6ZT1kmYIo&KM7whKDfvIZtYEO_LnndSJ*VkL z-wbzuHCNQp?Z$*Uvb_+&pgY@)TUb%T7ez#7l4p0q*Y6Elny_jP5pE|f-;j?|zQ)>x z7)Ofet~!p+F*qh!o|ZVPxQchyG8 z7!+Nb5nZEX3=$VJ#Dxxb-zWK7qV{OFZWjLCcUlLgPL|9(LxS^9Cr zzdroxnZx$q%>90mB0g`D!;j>xP*mf-w-;yXbU7QAo~RUWwQDt9ddibm%G8Zq9(t^8 zZ)>Nd=GAk}Re>;ON5| z&w(YD)vf5GOWyT-?X(99Ur{hek(I;6;kJ&U#@WZQHSUhtm3*l>!-5n6$ zGh)(|z`$dP3WKvTku?*OLS>;Z+@Azb2sK-!2EICdQ45TUNtGIe0Czuo$%HdtazOM4 zSmW9IMTI6*sy`iPA~-|;W&oJrgxX52nT+-T2M%!M{uqWvk#WK44<0^zP1{ZgaX)5v zs0im({jz?VMN2-wBz4K+{xm=lVZ5A`c76Tk;$Pvcomj+MRWir^Qjs?cS4t)&PQ3EVY@Y6@P zTm~j(HoTKDw$Rd)Jz%s;1I=zhLzq%4iW!A<_*eK|1crL zdjB04WLO6Z#Rj58T87Q_Nal=hgKB^S?o`r0OehtMQo(~GeVw0-YOJv4 zOskXyk-%^eKZ;mU#5z0?tElwsmJv{pM{XuJl`l=3Fzt6{XK)8M5L#RM>jh!zDWGWv ztBwJEyfMRyx#`PW>Ne0%y^F6|y3V49>(FzGWR;zx!RvK=i7CJX`-&G!*5`sjWg7i3tq{#$>H zHh{@NhE>O7P|}cLy}ynNGOYb(X!>v;lit#Nhzu(-tjMsMwSg$DHYj3+H?5(Fl@^7E zE+c>pD>AGDmRoz=-2NjflnRCnfg)B@d(s$I31~L`%^B9Z7MqtDZr~rF>DF{MOP5xp zr)okhOlDr`V;;Vne$d^X1Kq|rP==Tm3sV&W`ZK1K7cIUIVQ(k8DaB+&KadpthHpxt zWflD(1P}r=0$z=kWK!&%>7rd@p13|kJanW2U9v8hy6$zp^FU7Ebs}Eb%xKg`9VYMq2Xf`#LsCamwq-FP%v!0OHHww&1srLj2JVR%y{o|iB8GB{Az4$Q&phI z#dA={4x&tD1QV0o4Wj&iPr7<8sfUuOmyvbf?^{;5_crnQ<&&D1c%QHC)rIauBDaF* z*!YFSk9p13i!G2k?P(mG)33N?>`?QqYr$Yet_O9IpIb=N&J^L@xV-L!vKUvjYSky9|yMr3Pbw=O}>{YX~+@m8ucHwv^6PZkohEa4b0O?Jnn zAAPR`2Y%nl8QFU?4V)UCL&3Z>#E+|mX*Gh!l;O_~MvfUU!6Iw6+Pmy(|D5${oJFLf zo$b}YTWnwI+84IbFU{teONKq8vWsV*1||04ot^s<%2KJ#!BjaFaqF6?jqRA!Mu)v! z^wM>1>9=l{FFOjBSRk_9p`u;%9R)>x=btiVo?@THv+9<=;OzwsttZ)gFYm=ZY#Us< zPk=#VTbpH7E)Kgc{a(o(Gmz-bZpS+*r!@Y?zczDC8wneX$IlH-tGK}3 zeIuOgf2BG!IYau_QgB|MW6!d)#%j`AMT_`WLu1%G{ zE_xyBqbK_D39NV*;~WbsGusZTFhj+*gyw%;ff$_E!KADv#;w0_==HCNfF0kK+S@sg z6Sn!C&VO)h{O%fk1y|0Eb^+f+-L0e zsw%R!Nk6&iy^PcB4zBKc2P}^2>on3soL)X#)>wJ|&UoeG0r}08cc7(8{rQ^ry4rfb zqwkSy|N5^75ln?~f21jPQO7QYCMx=I!e-IHMn?XB@ZlS~DkHc0vQ@5wG283I6z&pX zSNbj~SbX(gr1f-DGdIBr5D;E`$`Sl#s4%y7r4q9k+3xdiDbxlZu6NCtwC;w3o|SNz zR_^*@bM_JwLxH=DAHws#TE>n8X{tUx*kBc)+hT5^P7^V6>|T18KR6NY86tMb_5HbS zX`zyYssLbXT1|{&X6{>H|0+*B?9LA#sYjS+XL_%WVJmX)St~`@OtTzWeTekWXzpts z%dPz)9jCk-HZT*xle^Tc=CpOJ@xFA$nrXnagD>*SZSVSN>SULB+xw3OdfZbZbES2{ zq1({jd~?rv=kTN-N;wPpDqF>{LpTy;YG)C0!m%1Oh~iQ z4D{EB$^poNCE8Ez3A)PZL>F`r&7? z+JY(b)KO+Q6DfECjTMD-1|h?-6ZRUi4gm{SYfa9rYp4aOX~qI zIycQou zm=v(vZJyn(0K%`A(vz)zlELW0OYV)HV;AC6b^3*&o7LrgKs9a6%~)&n)!!fj)tsHg z)sDNT!Y5SUE;1+2Vx$0&`n)uUp|k2ax3cR5*JWUK)he476H(g>cnO(5`O|5;>z zn2@$*y9!by-1NqVJ$Zf7t?;imjlM&Hm(Iv3#SsDG4@}z#9P9TTEhU-c_H}; zxbmiFr{~0H87OqHX|%=v!LSLN(#E=-KNLJl@P+3jXPj=i-p zTNg%EKK{|bmKZMcAxc}_aGEDQBQ(c$8=71_9a@y5G9J@vbPee?aXF~7^=&HU)2&9)W)d+Moxh;L>+ckQWfR^E8| z8^2S>W-jsiC~nMGN8+}8J<;dG_MN8>v>%l!yGCtQIvr7Pr7?!jGt^rGQ!kgk82)~1 zR=S7h`k=%F>x{16P_!FSui#Zp9M}12>eHbX1%cWNN6P+{m|p2$Coi(7rm&(4blG!q z9OZK-TIPM}q6in_rx;d6$CV46{d>pSA?F2A`!>zUtkumO}k` zl+U8<<~Lwjw9~+o?~ZVI%QpypJKK`w;!Jz9GX}cUh^X#RXB)u^;^urH?6+UZw7W+S z#w&PjQ%8I4os>C;2(z1EbhGgaNIqP8Zd=}Cgsa(ki+d=yeabG+@ZTL)QaQ2T_?iCx zad{#qHe7l0sO*Y*aQ^MRs~ne!T2GzLEtgf0=D9-M*(g!nhFzP>o1cubZ8E24>s~HP z%ss diff --git a/project/generate.py b/project/generate.py index 333712a..f95d9d9 100644 --- a/project/generate.py +++ b/project/generate.py @@ -3,6 +3,7 @@ import numpy as np import midi import tensorflow as tf +import pypianoroll as roll from keras.layers import Input, Dense, Conv2D from keras.models import Model from tensorflow.keras import layers @@ -18,7 +19,8 @@ output_path = sys.argv[2] treshold = float(sys.argv[3]) #random seed -generate_seed = np.random.rand(12288).reshape(1,96,128) +# generate_seed = np.random.rand(12288).reshape(1,96,128) +generate_seed = np.random.rand(2).reshape(1,-1) # load and predict model = pickle.load(open(trained_model_path, 'rb')) @@ -29,8 +31,7 @@ generated_sample = generated_sample.reshape(96,128) generated_sample = generated_sample > treshold * generated_sample.max() #save to midi -midi.to_midi(generated_sample, output_path='{}.mid'.format(output_path) ) +midi = midi.to_midi(generated_sample, output_path='{}.mid'.format(output_path) ) -#save piano roll to png -plt.imshow(generated_sample, cmap = plt.get_cmap('gray')) -plt.savefig('{}.png'.format(output_path)) +#save plot for preview +roll.plot(midi, filename='{}.png'.format(output_path)) diff --git a/project/train.py b/project/train.py index a5a8ac3..eeda2a7 100644 --- a/project/train.py +++ b/project/train.py @@ -13,27 +13,100 @@ train_data_path = sys.argv[1] save_model_path = sys.argv[2] epochs = int(sys.argv[3]) -model = Sequential() -model.add(LSTM(128,input_shape=(96, 128),return_sequences=True)) -model.add(Dropout(0.3)) -model.add(LSTM(512, return_sequences=True)) -model.add(Dropout(0.3)) -model.add(LSTM(128)) -model.add(Dense(128)) -model.add(Dropout(0.3)) -model.add(Dense(128*96)) -model.add(Activation('softmax')) -model.add(Reshape((96, 128))) -model.compile(loss='categorical_crossentropy', optimizer='rmsprop') +# best model yet - working autoencoder +# model = Sequential() +# model.add(LSTM(128,input_shape=(96, 128),return_sequences=True)) +# model.add(Dropout(0.3)) +# model.add(LSTM(512, return_sequences=True)) +# model.add(Dropout(0.3)) +# model.add(LSTM(128)) +# model.add(Dense(96)) +# model.add(Dropout(0.3)) +# model.add(Dense(128*96)) +# model.add(Activation('softmax')) +# model.add(Reshape((96, 128))) +# model.compile(loss='binary_crossentropy', optimizer='rmsprop') + +# # working model #2 +# model = Sequential() +# model.add(LSTM(128, input_shape=(96, 128), return_sequences=True)) +# model.add(LSTM(512, return_sequences=True)) +# model.add(TimeDistributed(Dense(128))) +# model.add(Activation('softmax')) +# model.add(Reshape((96, 128))) +# model.compile(loss='binary_crossentropy', optimizer='adadelta') + +# VAE model - LSTM +from keras.layers import Lambda, Input, Dense +from keras.models import Model +from keras.datasets import mnist +from keras.losses import mse, binary_crossentropy +from keras.utils import plot_model +from keras import backend as K +import numpy as np +import matplotlib.pyplot as plt +import argparse +import os + +def sampling(args): + z_mean, z_log_var = args + batch = K.shape(z_mean)[0] + dim = K.int_shape(z_mean)[1] + epsilon = K.random_normal(shape=(batch, dim)) + return z_mean + K.exp(0.5 * z_log_var) * epsilon + +# network parameters +original_dim = 96 * 128 +input_shape = (96,128) +intermediate_dim = 128 +batch_size = 128 +latent_dim = 2 + +# Encoder +inputs = Input(shape=input_shape, name='encoder_input') +x = LSTM(intermediate_dim, activation='relu', name='first_lstm')(inputs) + +z_mean = Dense(latent_dim, name='z_mean')(x) +z_log_var = Dense(latent_dim, name='z_log_var')(x) + +z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var]) + +encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder') + +# build decoder model +latent_inputs = Input(shape=(latent_dim,), name='z_sampling') +x = LSTM(intermediate_dim, return_sequences=True, activation='relu')(latent_inputs) +outputs = Dense(original_dim, activation='sigmoid')(x) +reshaped = Reshape((96,128))(outputs) + +# instantiate decoder model +decoder = Model(latent_inputs, outputs, name='decoder') +# plot_model(decoder, to_file='vae_mlp_decoder.png', show_shapes=True) + +# instantiate VAE model +outputs = decoder(encoder(inputs)[2]) +vae = Model(inputs, outputs, name='vae_mlp') # load training data print('Traing Samples: {}'.format(train_data_path)) train_X = np.load(train_data_path)['arr_0'] +# train_X = train_X.reshape((train_X.shape[0], 96*128)) + +# compiling model + +def vae_loss(inputs, outputs): + xent_loss = binary_crossentropy(inputs, outputs) + kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1) + return xent_loss + kl_loss + +vae.compile(optimizer='rmsprop', loss=vae_loss) +# vae.summary() +# plot_model(vae, to_file='vae_mlp.png', show_shapes=True) # model training -model.fit(train_X, train_X, epochs=epochs, batch_size=32) +vae.fit(train_X,train_X, epochs=epochs, batch_size=32) # save trained model pickle_path = '{}.pickle'.format(save_model_path) -pickle.dump(model, open(pickle_path,'wb')) +pickle.dump(decoder, open(pickle_path,'wb')) print("Model save to {}".format(pickle_path))