/// <summary>
/// Runs one CBOW forward pass (<c>FeedFoward</c>) on the first word read from the
/// fruit-and-juice CBOW preset, then checks row 0 of the result against eight
/// pre-computed reference values within an absolute tolerance.
/// </summary>
/// <remarks>
/// The reference values are each close to 1/8 and sum to roughly 1, which suggests they are
/// softmax probabilities over an 8-word vocabulary. NOTE(review): confirm against FeedFoward.
/// </remarks>
public void TestFeedFowardCbow()
{
    // Absolute tolerance used when comparing floating-point outputs.
    const double tolerance = 0.000001;

    var testSubject = Word2VecTestClass.GetFruitAndJuiceCbowPreset();
    var nextWord = testSubject.ReadNextWord();

    // Echo the sample being tested so a failing run shows which target/context was used.
    Console.WriteLine($"Target Word: {nextWord.TargetWord.Word}");
    Console.WriteLine("ContextWords:");
    foreach (var f in nextWord.ContextWords)
    {
        Console.WriteLine(f.Word);
    }

    var testResult = testSubject.FeedFoward(nextWord);
    Console.WriteLine(testResult.Print());

    var expect = new[]
    {
        0.12369933153402476,
        0.12534084483955113,
        0.12465557473825152,
        0.12537247712099528,
        0.12502507795456907,
        0.1248127122711621,
        0.12537976081050728,
        0.1257142207309389
    };

    for (var i = 0; i < expect.Length; i++)
    {
        var tr = testResult[0, i];
        var ex = expect[i];
        // Report the index and both values so a failure is diagnosable from the test log alone.
        Assert.IsTrue(
            System.Math.Abs(tr - ex) < tolerance,
            $"Output[0, {i}]: expected {ex} but was {tr} (tolerance {tolerance}).");
    }
}
/// <summary>
/// Installs fixed weight matrices <c>WI</c> (8x5) and <c>WO</c> (5x8) on the fruit-and-juice
/// preset, runs a single <c>Backpropagate</c> step for a one-hot input at index 0, and checks
/// both matrices against pre-computed reference values.
/// </summary>
/// <remarks>
/// NOTE(review): despite the "SkipGram" name, this test never sets <c>IsCbow = false</c>
/// (unlike TestFeedFowardSkipGram). Confirm whether Backpropagate depends on that flag before
/// changing it, because the expected matrices below were produced with the current setup.
/// </remarks>
public void TestBackpropagateSkipGram()
{
    var WI = new[,]
    {
        { 0.0368690499741906, 0.0135390551358178, -0.0377502766147956, 0.0336466886725494, -0.0270974517460435 },
        { 0.016440156808328, 0.0169175251931499, 0.0922190173492855, -0.031504701977365, -0.0114259869379112 },
        { -0.0127991043556477, 0.0487416296958652, -0.0701314578625054, -0.0536020711779604, -0.0154984786713023 },
        { 0.0657846277420338, 0.00954156336818894, -0.0136355384782122, -0.0162582306732695, 0.0366232909898382 },
        { 0.0211230152850612, 0.0597162170148065, 0.057514813056921, -0.0456053966403033, 0.0368811796125402 },
        { 0.0958493649940236, -0.0770609196168654, -0.0213859166583912, -0.0380667455206005, -0.059526237547177 },
        { 0.0255941354323151, -0.0347667583891967, 0.0813340035180254, -0.0742663991517696, 0.064415111748695 },
        { -0.09638258432801, 0.0478759082722831, -0.00176863414317772, 0.0383402774754634, -0.048905799700369 }
    };

    var WO = new[,]
    {
        { -0.0702967214259769, 0.0859662250550306, -0.0469181300825058, 0.0750174499931826, -0.0862165990221391, -0.025250033999444, 0.0277256570419882, 0.0687815535668198 },
        { -0.0643791042102404, -0.0336792340193313, -0.0982879727139547, 0.011368115391288, 0.0984938144676824, 0.0220246880883466, -0.00235211108920729, 0.021364244130144 },
        { -0.0938820250303866, 0.0536523073695844, 0.015810846870677, 0.0204317684846147, 0.0960061897505104, 0.0433280812312514, 0.0465293434199548, 0.0599207696318258 },
        { 0.0455908204175489, 0.0266330542632533, 0.0242098380458587, -0.0212760037375968, 0.0244631621634882, 0.0349988715420472, -0.0567920621283315, -0.0829100279057911 },
        { -0.0940704868147478, -0.0859107999996798, 0.093844279644007, -0.0646651228725282, 0.000752150221146719, -0.0685755731391607, -0.0400362076889892, -0.0787345594627478 }
    };

    // Input selects vocabulary index 0.
    var oneHot = new double[,] { { 1, 0, 0, 0, 0, 0, 0, 0 } };

    // Targets at indices 1 and 3; presumably the two context words of the sample (TODO confirm).
    var expectedOutput = new double[,] { { 0, 1, 0, 1, 0, 0, 0, 0 } };

    // Forward pass computed here with the fixed weights, so the test controls Backpropagate's inputs.
    var actualOutput = oneHot.DotProduct(WI).DotProduct(WO).GetSoftmax();

    var testSubject = Word2VecTestClass.GetFruitAndJuiceCbowPreset();
    // NOTE(review): these assignments share the array references, so if Backpropagate updates
    // the matrices in place, the local WI/WO arrays are modified as well.
    testSubject.WI = WI;
    testSubject.WO = WO;
    testSubject.Backpropagate(oneHot, actualOutput, expectedOutput);

    // Only row 0 differs from WI: the row selected by oneHot. All other rows are unchanged.
    var expectedWI = new[,]
    {
        { 0.06832391877374158, 0.0102198806703854, -0.02896077629135472, 0.03484245859783144, -0.048742468748119175 },
        { 0.016440156808328, 0.0169175251931499, 0.0922190173492855, -0.031504701977365, -0.0114259869379112 },
        { -0.0127991043556477, 0.0487416296958652, -0.0701314578625054, -0.0536020711779604, -0.0154984786713023 },
        { 0.0657846277420338, 0.00954156336818894, -0.0136355384782122, -0.0162582306732695, 0.0366232909898382 },
        { 0.0211230152850612, 0.0597162170148065, 0.057514813056921, -0.0456053966403033, 0.0368811796125402 },
        { 0.0958493649940236, -0.0770609196168654, -0.0213859166583912, -0.0380667455206005, -0.059526237547177 },
        { 0.0255941354323151, -0.0347667583891967, 0.0813340035180254, -0.0742663991517696, 0.064415111748695 },
        { -0.09638258432801, 0.0478759082722831, -0.00176863414317772, 0.0383402774754634, -0.048905799700369 },
    };

    var expectedWO = new[,]
    {
        { -0.0712222542811498, 0.0924147349563861, -0.0478348666338179, 0.0814666326878734, -0.0871339949705385, -0.0261724288026183, 0.0268054360904665, 0.067859952075192 },
        { -0.0647189784506933, -0.0313112117578562, -0.0986246167750825, 0.0137363847159791, 0.0981569282627097, 0.0216859662035703, -0.00269003469202615, 0.0210258135652898 },
        { -0.092934370525502, 0.0470496682264595, 0.0167494948265954, 0.0138284404673656, 0.0969455128641142, 0.0442725226798988, 0.0474715590585729, 0.0608643988075684 },
        { 0.0447461792929195, 0.0325179632914124, 0.0233732244266215, -0.0153904807183556, 0.0236259467785824, 0.0341570942031326, -0.0576318556106014, -0.0837510812687247 },
        { -0.0933902528000395, -0.0906502274058155, 0.0945180486855311, -0.0694050447580429, 0.00142640389627994, -0.0678976454820252, -0.0393598777361693, -0.0780572148616271 }
    };

    // Separate messages so a failure tells which matrix diverged.
    Assert.IsTrue(
        MatrixOps.AreEqual(testSubject.WI, expectedWI),
        "WI does not match the expected values after Backpropagate.");
    Assert.IsTrue(
        MatrixOps.AreEqual(testSubject.WO, expectedWO),
        "WO does not match the expected values after Backpropagate.");
}
/// <summary>
/// Switches the fruit-and-juice preset to skip-gram mode (<c>IsCbow = false</c>), runs one
/// forward pass (<c>FeedFoward</c>) on the first word read, and checks row 0 of the result
/// against eight pre-computed reference values within an absolute tolerance.
/// </summary>
/// <remarks>
/// The reference values are each close to 1/8 and sum to roughly 1, which suggests they are
/// softmax probabilities over an 8-word vocabulary. NOTE(review): confirm against FeedFoward.
/// </remarks>
public void TestFeedFowardSkipGram()
{
    // Absolute tolerance used when comparing floating-point outputs.
    const double tolerance = 0.000001;

    var testSubject = Word2VecTestClass.GetFruitAndJuiceCbowPreset();
    // The preset is built for CBOW; flip it to skip-gram for this test.
    testSubject.IsCbow = false;
    var nextWord = testSubject.ReadNextWord();

    // Echo the sample being tested so a failing run shows which target/context was used.
    Console.WriteLine($"Target Word: {nextWord.TargetWord.Word}");
    Console.WriteLine("ContextWords:");
    foreach (var f in nextWord.ContextWords)
    {
        Console.WriteLine(f.Word);
    }

    var testResult = testSubject.FeedFoward(nextWord);
    Console.WriteLine(testResult.Print());

    var expect = new[]
    {
        0.12551623323909344,
        0.12548466723855337,
        0.12432332158733014,
        0.12539342630127645,
        0.12441274579106223,
        0.12509066599491453,
        0.12479585887972285,
        0.12498308096804703
    };

    for (var i = 0; i < expect.Length; i++)
    {
        var tr = testResult[0, i];
        var ex = expect[i];
        // Report the index and both values so a failure is diagnosable from the test log alone.
        Assert.IsTrue(
            System.Math.Abs(tr - ex) < tolerance,
            $"Output[0, {i}]: expected {ex} but was {tr} (tolerance {tolerance}).");
    }
}