Beispiel #1
0
        //IMPLEMENTATION OF SHIFT - page 9
        /// <summary>
        /// Builds <c>ShiftedVector</c> by circularly shifting the gated weighting of
        /// <paramref name="gatedAddressing"/> by a fractional amount derived from <paramref name="shift"/>.
        /// Each output cell i is a linear interpolation of two neighbouring gated cells:
        /// <c>g[(i + c) % n] * simj + g[(i + c + 1) % n] * (1 - simj)</c>, where c is the integer
        /// part of the wrapped shift and simj is one minus its fractional part.
        /// </summary>
        /// <param name="shift">Unit whose raw <c>Value</c> is squashed with <c>Sigmoid.GetValue</c> and mapped to a shift.</param>
        /// <param name="gatedAddressing">Source of the gated weighting (<c>GatedVector</c>) that is shifted.</param>
        /// <exception cref="InvalidOperationException">
        /// The shift evaluates to NaN/infinity (for a non-empty vector), or a resulting weight is negative or NaN.
        /// </exception>
        internal ShiftedAddressing(Unit shift, GatedAddressing gatedAddressing)
        {
            _shift = shift;
            GatedAddressing = gatedAddressing;
            _gatedVector = GatedAddressing.GatedVector;
            _cellCount = _gatedVector.Length;

            ShiftedVector = UnitFactory.GetVector(_cellCount);
            double cellCountDbl = _cellCount;

            //Max shift is from range -1 to 1
            _shiftWeight = Sigmoid.GetValue(_shift.Value);
            double maxShift = ((2 * _shiftWeight) - 1);

            // Adding cellCount before the modulo wraps a negative shift to the end of the
            // circular buffer (for maxShift in [-1, 1] the result lies in [0, cellCount)).
            double convolutionDbl = (maxShift + cellCountDbl) % cellCountDbl;

            // The fractional part of the shift decides how the weight is split between
            // the two neighbouring source cells.
            _simj = 1 - (convolutionDbl - Math.Floor(convolutionDbl));
            _oneMinusSimj = (1 - _simj);

            // A NaN/infinite shift yields a NaN convolution; casting NaN to int gives an
            // unspecified value, which would make the modulo indexing below go out of range.
            // Fail explicitly instead. An empty vector indexes nothing, so it is not checked.
            if (_cellCount > 0 && double.IsNaN(convolutionDbl))
            {
                throw new InvalidOperationException("Error - shift should not be nan or infinite");
            }

            _convolution = (int)convolutionDbl;

            for (int i = 0; i < _cellCount; i++)
            {
                // Circular index of the first of the two source cells for output cell i.
                int imj = (i + _convolution) % _cellCount;

                Unit vectorItem = ShiftedVector[i];

                vectorItem.Value = (_gatedVector[imj].Value * _simj) +
                                   (_gatedVector[(imj + 1) % _cellCount].Value * _oneMinusSimj);
                if (vectorItem.Value < 0 || double.IsNaN(vectorItem.Value))
                {
                    throw new InvalidOperationException("Error - weight should not be smaller than zero or nan");
                }
            }
        }
Beispiel #2
0
        //IMPLEMENTATION OF SHIFT - page 9
        /// <summary>
        /// Builds <c>ShiftedVector</c> by circularly shifting the gated weighting of
        /// <paramref name="gatedAddressing"/> by a fractional amount derived from <paramref name="shift"/>.
        /// Each output cell i is a linear interpolation of two neighbouring gated cells:
        /// <c>g[(i + c) % n] * simj + g[(i + c + 1) % n] * (1 - simj)</c>, where c is the integer
        /// part of the wrapped shift and simj is one minus its fractional part.
        /// </summary>
        /// <param name="shift">Unit whose raw <c>Value</c> is squashed with <c>Sigmoid.GetValue</c> and mapped to a shift.</param>
        /// <param name="gatedAddressing">Source of the gated weighting (<c>GatedVector</c>) that is shifted.</param>
        /// <exception cref="InvalidOperationException">
        /// The shift evaluates to NaN/infinity (for a non-empty vector), or a resulting weight is negative or NaN.
        /// </exception>
        internal ShiftedAddressing(Unit shift, GatedAddressing gatedAddressing)
        {
            _shift          = shift;
            GatedAddressing = gatedAddressing;
            _gatedVector    = GatedAddressing.GatedVector;
            _cellCount      = _gatedVector.Length;

            ShiftedVector = UnitFactory.GetVector(_cellCount);
            double cellCountDbl = _cellCount;

            //Max shift is from range -1 to 1
            _shiftWeight = Sigmoid.GetValue(_shift.Value);
            double maxShift       = ((2 * _shiftWeight) - 1);

            // Adding cellCount before the modulo wraps a negative shift to the end of the
            // circular buffer (for maxShift in [-1, 1] the result lies in [0, cellCount)).
            double convolutionDbl = (maxShift + cellCountDbl) % cellCountDbl;

            // The fractional part of the shift decides how the weight is split between
            // the two neighbouring source cells.
            _simj         = 1 - (convolutionDbl - Math.Floor(convolutionDbl));
            _oneMinusSimj = (1 - _simj);

            // A NaN/infinite shift yields a NaN convolution; casting NaN to int gives an
            // unspecified value, which would make the modulo indexing below go out of range.
            // Fail explicitly instead. An empty vector indexes nothing, so it is not checked.
            if (_cellCount > 0 && double.IsNaN(convolutionDbl))
            {
                throw new InvalidOperationException("Error - shift should not be nan or infinite");
            }

            _convolution  = (int)convolutionDbl;

            for (int i = 0; i < _cellCount; i++)
            {
                // Circular index of the first of the two source cells for output cell i.
                int imj = (i + _convolution) % _cellCount;

                Unit vectorItem = ShiftedVector[i];

                vectorItem.Value = (_gatedVector[imj].Value * _simj) +
                                   (_gatedVector[(imj + 1) % _cellCount].Value * _oneMinusSimj);
                if (vectorItem.Value < 0 || double.IsNaN(vectorItem.Value))
                {
                    throw new InvalidOperationException("Error - weight should not be smaller than zero or nan");
                }
            }
        }
Beispiel #3
0
        /// <summary>
        /// Runs one addressing step for every head against the current memory and returns
        /// the resulting memory state (new memory, new head settings and per-head read data).
        /// </summary>
        /// <param name="heads">The heads to process; head i is paired with the stored head setting at index i.</param>
        /// <returns>
        /// A <see cref="MemoryState"/> built from the new <see cref="NTMMemory"/>, the new head
        /// settings and the read data produced for each head.
        /// </returns>
        internal MemoryState Process(Head[] heads)
        {
            int headCount = heads.Length;
            int memoryColumnsN = _memory.CellCountN;

            ReadData[] newReadDatas = new ReadData[headCount];
            HeadSetting[] newHeadSettings = new HeadSetting[headCount];
            for (int i = 0; i < headCount; i++)
            {
                Head head = heads[i];

                // One similarity per memory cell. The array is sized by the same local that
                // bounds the fill loop, so the two can never disagree.
                BetaSimilarity[] similarities = new BetaSimilarity[memoryColumnsN];

                for (int j = 0; j < memoryColumnsN; j++)
                {
                    Unit[] memoryColumn = _memory.Data[j];
                    // A fresh similarity function is created per (head, cell) pair on purpose;
                    // NOTE(review): it may hold per-computation state, so do not share one instance.
                    SimilarityMeasure similarity = new SimilarityMeasure(new CosineSimilarityFunction(), head.KeyVector, memoryColumn);
                    similarities[j] = new BetaSimilarity(head.Beta, similarity);
                }

                // Addressing pipeline: content -> gated (with this head's stored setting,
                // presumably from the previous step) -> shifted.
                ContentAddressing ca = new ContentAddressing(similarities);
                GatedAddressing ga = new GatedAddressing(head.Gate, ca, _headSettings[i]);
                ShiftedAddressing sa = new ShiftedAddressing(head.Shift, ga);

                newHeadSettings[i] = new HeadSetting(head.Gamma, sa);
                // Reads are taken from the current memory (_memory), not from newMemory,
                // which is only constructed after all heads have been processed.
                newReadDatas[i] = new ReadData(newHeadSettings[i], _memory);
            }

            NTMMemory newMemory = new NTMMemory(newHeadSettings, heads, _memory);

            return new MemoryState(newMemory, newHeadSettings, newReadDatas);
        }