/// <summary>
/// Runs one addressing step for every head against the current memory and
/// packages the results into a new <see cref="MemoryState"/>.
/// </summary>
/// <remarks>
/// For each head the pipeline is:
/// 1. content addressing: a <see cref="BetaSimilarity"/> per memory cell, built from a
///    cosine <see cref="SimilarityMeasure"/> between the head key and the cell plus the head's Beta;
/// 2. gating against this head's previous setting (<c>_headSettings[i]</c>);
/// 3. shifting;
/// 4. a new <see cref="HeadSetting"/> using the head's Gamma.
/// A <see cref="ReadData"/> is then built from that setting and the memory as it was before this step.
/// </remarks>
/// <param name="heads">
/// Head outputs for this step, indexed in parallel with <c>_headSettings</c>.
/// NOTE(review): assumes <c>heads.Length</c> does not exceed <c>_headSettings.Length</c> — confirm.
/// </param>
/// <returns>
/// A state combining a new <see cref="NTMMemory"/> (constructed from the new settings, the heads
/// and the previous memory), the new head settings and the per-head read data.
/// </returns>
internal MemoryState Process(Head[] heads)
{
    int headTotal = heads.Length;
    int cellCount = _memory.CellCountN;

    var reads = new ReadData[headTotal];
    var settings = new HeadSetting[headTotal];

    for (int headIndex = 0; headIndex < headTotal; headIndex++)
    {
        Head current = heads[headIndex];

        // Similarity of this head's key to every memory cell.
        var cellSimilarities = new BetaSimilarity[cellCount];
        for (int cell = 0; cell < cellCount; cell++)
        {
            Unit[] cellData = _memory.Data[cell];
            var measure = new SimilarityMeasure(new CosineSimilarityFunction(), current.KeyVector, cellData);
            cellSimilarities[cell] = new BetaSimilarity(current.Beta, measure);
        }

        // Content -> gated (against the previous setting) -> shifted addressing.
        var content = new ContentAddressing(cellSimilarities);
        var gated = new GatedAddressing(current.Gate, content, _headSettings[headIndex]);
        var shifted = new ShiftedAddressing(current.Shift, gated);

        settings[headIndex] = new HeadSetting(current.Gamma, shifted);
        reads[headIndex] = new ReadData(settings[headIndex], _memory);
    }

    var updatedMemory = new NTMMemory(settings, heads, _memory);
    return new MemoryState(updatedMemory, settings, reads);
}
/// <summary>
/// Creates a head setting whose addressing vector is a fresh vector from
/// <c>UnitFactory.GetVector</c>. The vector holds a copy of the <c>Value</c>s of the first
/// <paramref name="memoryColumnsN"/> entries of the content weighting.
/// </summary>
/// <param name="memoryColumnsN">Number of memory cells; the length of the new addressing vector.</param>
/// <param name="contentAddressing">
/// Content weighting whose values are copied. Only values are copied; the
/// <c>Unit</c> objects themselves are not reused.
/// </param>
internal HeadSetting(int memoryColumnsN, ContentAddressing contentAddressing)
{
    AddressingVector = UnitFactory.GetVector(memoryColumnsN);

    int column = 0;
    while (column < memoryColumnsN)
    {
        AddressingVector[column].Value = contentAddressing.ContentVector[column].Value;
        column++;
    }
}
/// <summary>
/// Builds the gated weighting by interpolating between the current content weighting
/// and the previous step's addressing vector:
/// <c>gated[i] = g * content[i] + (1 - g) * previous[i]</c>, with <c>g = Sigmoid(gate.Value)</c>.
/// </summary>
/// <param name="gate">Unit whose value, passed through a sigmoid, is the interpolation gate <c>g</c>.</param>
/// <param name="contentAddressing">Content weighting for the current step; its length fixes the cell count.</param>
/// <param name="oldHeadSettings">Previous head setting; its <c>AddressingVector</c> supplies <c>previous[i]</c>.</param>
internal GatedAddressing(Unit gate, ContentAddressing contentAddressing, HeadSetting oldHeadSettings)
{
    _gate = gate;
    ContentVector = contentAddressing;
    _oldHeadSettings = oldHeadSettings;

    Unit[] weights = ContentVector.ContentVector;
    _memoryCellCount = weights.Length;
    GatedVector = UnitFactory.GetVector(_memoryCellCount);

    // Focusing by location — page 8, section 3.3.2 ("Focusing by Location").
    // The gate value and its complement are cached on the instance.
    _gt = Sigmoid.GetValue(_gate.Value);
    _oneminusgt = (1 - _gt);

    for (int cell = 0; cell < _memoryCellCount; cell++)
    {
        GatedVector[cell].Value =
            (_gt * weights[cell].Value) +
            (_oneminusgt * _oldHeadSettings.AddressingVector[cell].Value);
    }
}
/// <summary>
/// Returns the content addressings built by <c>ContentAddressing.GetVector</c> for
/// <c>HeadCount</c> indices. The factory maps head index <c>i</c> to <c>_oldSimilarities[i]</c>.
/// </summary>
/// <returns>The array produced by <c>ContentAddressing.GetVector</c>.</returns>
public ContentAddressing[] GetContentAddressing() =>
    ContentAddressing.GetVector(HeadCount, headIndex => _oldSimilarities[headIndex]);