/// <summary>
/// Advances the memory by one step. For every head it builds a new addressing
/// from the current memory, performs a read with that addressing, and finally
/// writes all heads into a new memory.
/// </summary>
/// <param name="heads">
/// Controller heads for this step. Head <c>h</c> is combined with this state's
/// previous head setting <c>h</c>, so there must be at least as many stored
/// settings as heads.
/// </param>
/// <returns>
/// A new <see cref="MemoryState"/> holding the written memory, the new head
/// settings and the per-head read results.
/// </returns>
/// <remarks>
/// Reads are taken from the memory held by this state, i.e. before this
/// step's write is applied.
/// </remarks>
internal MemoryState Process(Head[] heads)
{
    int cellCount = _memory.CellCountN;
    HeadSetting[] settings = new HeadSetting[heads.Length];
    ReadData[] reads = new ReadData[heads.Length];

    for (int h = 0; h < heads.Length; h++)
    {
        Head current = heads[h];

        // Score the head's key against every memory cell (cosine similarity), each paired with the head's beta.
        BetaSimilarity[] scored = new BetaSimilarity[cellCount];
        for (int c = 0; c < cellCount; c++)
        {
            Unit[] cell = _memory.Data[c];
            // A new similarity function is created per cell; it may carry per-comparison
            // state, so do not hoist it without checking its implementation.
            SimilarityMeasure measure = new SimilarityMeasure(new CosineSimilarityFunction(), current.KeyVector, cell);
            scored[c] = new BetaSimilarity(current.Beta, measure);
        }

        // Addressing pipeline: content -> gated with the previous setting -> shifted -> gamma.
        ContentAddressing content = new ContentAddressing(scored);
        GatedAddressing gated = new GatedAddressing(current.Gate, content, _headSettings[h]);
        ShiftedAddressing shifted = new ShiftedAddressing(current.Shift, gated);

        settings[h] = new HeadSetting(current.Gamma, shifted);
        reads[h] = new ReadData(settings[h], _memory);
    }

    NTMMemory written = new NTMMemory(settings, heads, _memory);
    return new MemoryState(written, settings, reads);
}
/// <summary>
/// Creates the memory for the next step by applying every head's erase and add
/// vectors to <paramref name="memory"/>, weighted by each head's addressing.
/// </summary>
/// <param name="headSettings">
/// Per-head addressing for this step; <c>headSettings[k].AddressingVector[i]</c>
/// is head k's weight on cell i. Must have at least <c>memory.HeadCount</c> entries.
/// </param>
/// <param name="heads">
/// Heads supplying the raw erase/add vectors (length CellSizeM each). Must have
/// at least <c>memory.HeadCount</c> entries.
/// </param>
/// <param name="memory">Previous memory; only read here, and kept as the old memory.</param>
/// <remarks>
/// With e_k = Sigmoid(EraseVector_k), a_k = Sigmoid(AddVector_k) and w_k = head k's addressing:
/// <code>
/// erase(i,j) = Π_k (1 - w_k[i] * e_k[j])
/// add(i,j)   = Σ_k w_k[i] * a_k[j]
/// new[i][j]  = erase(i,j) * old[i][j] + add(i,j)
/// </code>
/// The squashed vectors are kept in <c>_erase</c> / <c>_add</c>.
/// </remarks>
internal NTMMemory(HeadSetting[] headSettings, Head[] heads, NTMMemory memory)
{
    CellCountN = memory.CellCountN;
    CellSizeM = memory.CellSizeM;
    HeadCount = memory.HeadCount;
    HeadSettings = headSettings;
    _heads = heads;
    _oldMemory = memory;
    Data = UnitFactory.GetTensor2(memory.CellCountN, memory.CellSizeM);

    _erase = GetTensor2(HeadCount, memory.CellSizeM);
    _add = GetTensor2(HeadCount, memory.CellSizeM);

    // Squash each head's raw erase/add vectors through the sigmoid.
    for (int i = 0; i < HeadCount; i++)
    {
        Unit[] eraseVector = _heads[i].EraseVector;
        Unit[] addVector = _heads[i].AddVector;
        double[] erases = _erase[i];
        double[] adds = _add[i];

        for (int j = 0; j < CellSizeM; j++)
        {
            erases[j] = Sigmoid.GetValue(eraseVector[j].Value);
            adds[j] = Sigmoid.GetValue(addVector[j].Value);
        }
    }

    // Apply the combined erase (product over heads) and add (sum over heads) to every cell.
    // The previous version also filled a local N x M "erasures" buffer here that was
    // never read; it has been removed to save an allocation per step.
    for (int i = 0; i < CellCountN; i++)
    {
        Unit[] oldRow = _oldMemory.Data[i];
        Unit[] row = Data[i];

        for (int j = 0; j < CellSizeM; j++)
        {
            double erase = 1;
            double add = 0;
            for (int k = 0; k < HeadCount; k++)
            {
                double addressingValue = HeadSettings[k].AddressingVector[i].Value;
                erase *= 1 - (addressingValue * _erase[k][j]);
                add += addressingValue * _add[k][j];
            }

            // NOTE(review): "+=" relies on UnitFactory.GetTensor2 returning zero-valued
            // units, making it equivalent to "=" — confirm before changing.
            row[j].Value += (erase * oldRow[j].Value) + add;
        }
    }
}
/// <summary>
/// Performs a read: the read vector is the addressing-weighted sum of the memory cells,
/// <c>ReadVector[i] = Σ_j AddressingVector[j] * Data[j][i]</c>.
/// </summary>
/// <param name="headSetting">
/// Addressing used for the read; its <c>AddressingVector</c> must have at least
/// <c>CellCountN</c> entries.
/// </param>
/// <param name="controllerMemory">Memory to read from; it is not modified.</param>
internal ReadData(HeadSetting headSetting, NTMMemory controllerMemory)
{
    HeadSetting = headSetting;
    _controllerMemory = controllerMemory;
    _cellSize = _controllerMemory.CellSizeM;
    _cellCount = _controllerMemory.CellCountN;

    // Accumulate cell by cell so each memory row is traversed contiguously and the
    // cell's weight is fetched once. Every sums[i] still receives its terms in the
    // order j = 0..N-1 starting from 0, so results match a column-wise loop exactly.
    double[] sums = new double[_cellSize];
    for (int j = 0; j < _cellCount; j++)
    {
        double weight = headSetting.AddressingVector[j].Value;
        Unit[] cell = _controllerMemory.Data[j];
        for (int i = 0; i < _cellSize; i++)
        {
            sums[i] += weight * cell[i].Value;
        }
    }

    ReadVector = new Unit[_cellSize];
    for (int i = 0; i < _cellSize; i++)
    {
        ReadVector[i] = new Unit(sums[i]);
    }
}
/// <summary>
/// Creates a memory state from a memory together with the head settings and
/// read results that accompany it.
/// </summary>
/// <param name="memory">Memory held by this state.</param>
/// <param name="headSettings">Per-head settings; the array is stored as-is, not copied.</param>
/// <param name="readDatas">Per-head read results; stored as-is and exposed via <c>ReadData</c>.</param>
internal MemoryState(NTMMemory memory, HeadSetting[] headSettings, ReadData[] readDatas)
{
    this._memory = memory;
    this._headSettings = headSettings;
    this.ReadData = readDatas;
}
/// <summary>
/// Creates a memory state that wraps only <paramref name="memory"/>.
/// This constructor does not assign the head settings or read data.
/// </summary>
/// <param name="memory">Memory held by this state.</param>
internal MemoryState(NTMMemory memory)
{
    this._memory = memory;
}
/// <summary>
/// Runs one memory step for all heads: addressing, read, then write.
/// </summary>
/// <param name="heads">
/// Heads for this step; head <c>n</c> is paired with this state's previous
/// head setting <c>n</c>.
/// </param>
/// <returns>
/// A new <see cref="MemoryState"/> with the written memory, the new head
/// settings and the reads.
/// </returns>
/// <remarks>
/// Each read uses the memory held by this state, before the new write.
/// </remarks>
internal MemoryState Process(Head[] heads)
{
    int totalHeads = heads.Length;
    int totalCells = _memory.CellCountN;

    var headSettingsOut = new HeadSetting[totalHeads];
    var readsOut = new ReadData[totalHeads];

    for (int headIndex = 0; headIndex < totalHeads; headIndex++)
    {
        var head = heads[headIndex];
        var betaScores = new BetaSimilarity[totalCells];

        for (int cellIndex = 0; cellIndex < totalCells; cellIndex++)
        {
            var cellUnits = _memory.Data[cellIndex];
            // Fresh function instance per comparison (it may keep internal state).
            var cosine = new SimilarityMeasure(new CosineSimilarityFunction(), head.KeyVector, cellUnits);
            betaScores[cellIndex] = new BetaSimilarity(head.Beta, cosine);
        }

        var byContent = new ContentAddressing(betaScores);
        var byGate = new GatedAddressing(head.Gate, byContent, _headSettings[headIndex]);
        var byShift = new ShiftedAddressing(head.Shift, byGate);

        var setting = new HeadSetting(head.Gamma, byShift);
        headSettingsOut[headIndex] = setting;
        readsOut[headIndex] = new ReadData(setting, _memory);
    }

    var nextMemory = new NTMMemory(headSettingsOut, heads, _memory);
    return new MemoryState(nextMemory, headSettingsOut, readsOut);
}