示例#1
0
        // Constructor for debugging: it accepts the secret key, which a real server would never receive.
        // Special parameters:
        //  batchSize   - number of samples batched into one ciphertext, if the client packs multiple samples into a ciphertext.
        //  featureSize - number of features in a sample.

        // Debug-only SVC server constructor: it receives the client's secret key, which is used solely to
        // build a Decryptor for printing intermediate values.
        //
        //   vectors      - support vectors, one row per vector; must contain at least one row.
        //   kernel       - name of a Kernel enum member, parsed with Enum.Parse (case-sensitive).
        //   power        - log2 of the CKKS encoding scale; selects the encryption parameters:
        //                    20 <= power < 40 : N = 16384, coefficient-modulus bit sizes {60, 20, 21, ..., 27, 60}
        //                    40 <= power < 60 : N = 16384, bit sizes {60, 40 x 7, 60}
        //                    power == 60      : N = 32768, bit sizes {60 x 9}
        //   secretKey    - optional; when non-null a Decryptor is created (debug printing only).
        //   batchSize    - number of samples packed into one client ciphertext (batched inner-product mode).
        //   featureSize  - number of slots reserved per sample in a batched ciphertext.
        //
        // Throws ArgumentException when no support vectors are supplied and
        // ArgumentOutOfRangeException when power is outside [20, 60].
        private Svc(double[][] vectors, double[][] coefficients, double[] intercepts, String kernel, double gamma, double coef0, ulong degree, int power, PublicKey publicKey, SecretKey secretKey, RelinKeys relinKeys, GaloisKeys galoisKeys, int batchSize, int featureSize)
        {
            // The row/column counts below are read from vectors[0]; fail clearly instead of with an
            // IndexOutOfRange/NullReference exception deep inside the constructor.
            if (vectors == null || vectors.Length == 0 || vectors[0] == null)
            {
                throw new ArgumentException("At least one support vector is required.", nameof(vectors));
            }

            this._vectors      = vectors;
            this._coefficients = coefficients;
            this._intercepts   = intercepts;

            this._kernel = (Kernel)System.Enum.Parse(typeof(Kernel), kernel);
            this._gamma  = gamma;
            this._coef0  = coef0;
            this._degree = degree;
            this._power  = power;

            // CKKS parameters depend on the required precision (power) and on the polynomial kernel degree.
            // The chains below allow kernel degrees up to 4; higher degrees need a longer modulus chain.
            ulong polyModulusDegree;
            int[] coeffModulusBitSizes;
            if (power >= 20 && power < 40)
            {
                polyModulusDegree    = 16384;
                coeffModulusBitSizes = new int[] { 60, 20, 21, 22, 23, 24, 25, 26, 27, 60 };
            }
            else if (power >= 40 && power < 60)
            {
                polyModulusDegree    = 16384;
                coeffModulusBitSizes = new int[] { 60, 40, 40, 40, 40, 40, 40, 40, 60 };
            }
            else if (power == 60)
            {
                polyModulusDegree    = 32768;
                coeffModulusBitSizes = new int[] { 60, 60, 60, 60, 60, 60, 60, 60, 60 };
            }
            else
            {
                // Without a coefficient modulus the SEAL context would be invalid and every later
                // SEAL call would fail with an unrelated-looking error.
                throw new ArgumentOutOfRangeException(nameof(power), power, "power must be between 20 and 60 (inclusive).");
            }

            EncryptionParameters parms = new EncryptionParameters(SchemeType.CKKS);
            parms.CoeffModulus      = CoeffModulus.Create(polyModulusDegree, coeffModulusBitSizes);
            parms.PolyModulusDegree = polyModulusDegree;
            _context = new SEALContext(parms);

            _publicKey  = publicKey;
            _secretKey  = secretKey;
            _relinKeys  = relinKeys;
            _galoisKeys = galoisKeys;

            _evaluator = new Evaluator(_context);
            if (_secretKey != null)
            {
                _decryptor = new Decryptor(_context, _secretKey);                 // FOR DEBUG ONLY (not used in a real server)
            }
            _encoder = new CKKSEncoder(_context);

            Stopwatch serverInitStopwatch = new Stopwatch();
            serverInitStopwatch.Start();

            _numOfrowsCount    = _vectors.Length;       // number of support vectors
            _numOfcolumnsCount = _vectors[0].Length;    // number of features in every support vector
            _scale             = Math.Pow(2.0, _power);

            // Plaintexts used by the batched (rotation-based) inner product.
            _svPlaintexts = new Plaintext[_numOfrowsCount];
            _sums         = new Ciphertext[_numOfrowsCount];

            if (UseBatchInnerProduct)
            {
                // Each support vector is replicated into every featureSize-wide block of the slot vector
                // (zero-padded when featureSize exceeds its length), matching a batched client ciphertext.
                double[] batchVectors = new double[batchSize * featureSize];

                for (int i = 0; i < _numOfrowsCount; i++)
                {
                    for (int k = 0; k < batchSize * featureSize; k++)
                    {
                        var index0 = k % featureSize;

                        batchVectors[k] = index0 < _numOfcolumnsCount ? _vectors[i][index0] : 0;
                    }
                    _svPlaintexts[i] = new Plaintext();
                    _encoder.Encode(batchVectors, _scale, _svPlaintexts[i]);
                    SvcUtilities.PrintScale(_svPlaintexts[i], "batch supportVectorsPlaintext" + i);
                    _sums[i] = new Ciphertext();
                }
            }
            else
            {
                // Simple (one value per ciphertext) inner product: one plaintext per support-vector element.
                _svPlaintextsArr = new Plaintext[_numOfrowsCount, _numOfcolumnsCount];

                for (int i = 0; i < _numOfrowsCount; i++)
                {
                    for (int j = 0; j < _numOfcolumnsCount; j++)
                    {
                        _svPlaintextsArr[i, j] = new Plaintext();

                        // Exact zeros are replaced by the Zero constant — presumably to avoid multiplying by an
                        // all-zero plaintext (SEAL rejects transparent results); confirm against Zero's definition.
                        _encoder.Encode(_vectors[i][j] != 0 ? _vectors[i][j] : Zero, _scale, _svPlaintextsArr[i, j]);
                        SvcUtilities.PrintScale(_svPlaintextsArr[i, j], $"supportVectorsPlaintext[{i}][{j}]");
                    }
                }

                // Accumulators for the simple inner product.
                _innerProdSums = new Ciphertext[_numOfcolumnsCount];
                for (int i = 0; i < _numOfcolumnsCount; i++)
                {
                    _innerProdSums[i] = new Ciphertext();
                }
            }

            // Per-support-vector scratch buffers reused by the secure SVM evaluation.
            _kernels      = new Ciphertext[_numOfrowsCount];
            _decisionsArr = new Ciphertext[_numOfrowsCount];
            _coefArr      = new Plaintext[_numOfrowsCount];

            for (int i = 0; i < _numOfrowsCount; i++)
            {
                _kernels[i]      = new Ciphertext();
                _decisionsArr[i] = new Ciphertext();
                _coefArr[i]      = new Plaintext();
            }

            _gamaPlaintext = new Plaintext();
            _encoder.Encode(_gamma != 0 ? _gamma : Zero, _scale, _gamaPlaintext);

            serverInitStopwatch.Stop();
            Console.WriteLine($"server Init elapsed {serverInitStopwatch.ElapsedMilliseconds} ms");
        }
示例#2
0
        // Classifies samples.
        // Following the SEAL examples' recommendation, results are relinearized and rescaled after each multiplication.
        // useRelinearizeInplace and useReScale should always be true;
        // these parameters exist only for debugging and learning purposes.
        // Evaluates the SVM decision function on an encrypted sample and returns the encrypted result:
        //   decision = sum_i coefficients[0][i] * (-K_i) + intercepts[0]
        // where K_i is the batched inner product <x, sv_i> for non-polynomial kernels and
        // (gamma * <x, sv_i> + coef0)^degree for Kernel.Poly.
        // NOTE(review): each kernel value is negated before weighting; presumably the caller's sign
        // convention expects this — confirm.
        //
        //   featuresCiphertexts    - the client's encrypted (batched) features.
        //   useRelinearizeInplace,
        //   useReScale             - debug switches; production use should pass true for both.
        //   *Stopwatch arguments   - caller-owned timers, started/stopped around each phase and never reset,
        //                            so time accumulates across calls.
        //
        // The instance scratch arrays (_kernels, _decisionsArr, _coefArr) are overwritten, so calls on the
        // same instance presumably must not overlap — confirm before using concurrently.
        public Ciphertext  Predict(Ciphertext featuresCiphertexts, bool useRelinearizeInplace, bool useReScale,
                                   Stopwatch innerProductStopwatch, Stopwatch degreeStopwatch, Stopwatch negateStopwatch, Stopwatch serverDecisionStopWatch)
        {
            Ciphertext tempCt = new Ciphertext();

            // Level 1: one kernel value per support vector.
            for (int i = 0; i < _numOfrowsCount; i++)
            {
                // Inner product IP = <x, x'>
                innerProductStopwatch.Start();
                if (UseBatchInnerProduct)
                {
                    _kernels[i] = InnerProduct(featuresCiphertexts, _svPlaintexts, i, _sums, _numOfcolumnsCount, tempCt);
                }
                // NOTE(review): there is no non-batched inner-product path here; with UseBatchInnerProduct == false,
                // _kernels[i] stays an empty Ciphertext and the SEAL calls below will reject it.
                innerProductStopwatch.Stop();

                SvcUtilities.PrintCyprherText(_decryptor, _kernels[i], _encoder, $"inner product TotalValue {i}");
                SvcUtilities.PrintScale(_kernels[i], "0. kernels" + i);
                if (useRelinearizeInplace)
                {
                    _evaluator.RelinearizeInplace(_kernels[i], _relinKeys);
                }

                if (useReScale)
                {
                    _evaluator.RescaleToNextInplace(_kernels[i]);
                }

                SvcUtilities.PrintScale(_kernels[i], "1. kernels" + i);
                // Snap the scale back to the nominal value (after rescaling it is only approximately 2^power).
                _kernels[i].Scale = _scale;

                if (_kernel == Kernel.Poly)
                {
                    // (gamma * IP + coef0) ^ degree

                    // Mod-switch a per-call copy of gamma rather than the cached field, so the cached
                    // plaintext keeps its original level and later calls (at any level) still work.
                    Plaintext gammaPlaintext = _gamaPlaintext;
                    if (useReScale)
                    {
                        gammaPlaintext = new Plaintext();
                        _evaluator.ModSwitchTo(_gamaPlaintext, _kernels[i].ParmsId, gammaPlaintext);
                    }

                    // gamma * IP
                    _evaluator.MultiplyPlainInplace(_kernels[i], gammaPlaintext);
                    SvcUtilities.PrintScale(_kernels[i], "2. kernels" + i);
                    if (useRelinearizeInplace)
                    {
                        _evaluator.RelinearizeInplace(_kernels[i], _relinKeys);
                    }

                    if (useReScale)
                    {
                        _evaluator.RescaleToNextInplace(_kernels[i]);
                    }
                    SvcUtilities.PrintScale(_kernels[i], "3.  kernels" + i);

                    // + coef0 (encoded at the kernel's current scale and level so AddPlain accepts it)
                    if (Math.Abs(_coef0) > 0)
                    {
                        Plaintext coef0Plaintext = new Plaintext();
                        _encoder.Encode(_coef0, _kernels[i].Scale, coef0Plaintext);
                        if (useReScale)
                        {
                            _evaluator.ModSwitchToInplace(coef0Plaintext, _kernels[i].ParmsId);
                        }

                        _evaluator.AddPlainInplace(_kernels[i], coef0Plaintext);
                    }

                    SvcUtilities.PrintScale(_kernels[i], "4.  kernels" + i);
                    degreeStopwatch.Start();

                    // Raise to the kernel degree by repeated multiplication with a copy of the base.
                    var kernel = new Ciphertext(_kernels[i]);
                    for (int d = 0; d < (int)_degree - 1; d++)
                    {
                        // Align the base's scale and level with the running product before multiplying.
                        kernel.Scale = _kernels[i].Scale;
                        if (useReScale)
                        {
                            _evaluator.ModSwitchToInplace(kernel, _kernels[i].ParmsId);
                        }
                        _evaluator.MultiplyInplace(_kernels[i], kernel);
                        SvcUtilities.PrintScale(_kernels[i], d + "  5. kernels" + i);
                        if (useRelinearizeInplace)
                        {
                            _evaluator.RelinearizeInplace(_kernels[i], _relinKeys);
                        }

                        if (useReScale)
                        {
                            _evaluator.RescaleToNextInplace(_kernels[i]);
                        }
                        SvcUtilities.PrintScale(_kernels[i], d + " rescale  6. kernels" + i);
                    }
                    SvcUtilities.PrintScale(_kernels[i], "7. kernels" + i);
                    degreeStopwatch.Stop();
                }

                negateStopwatch.Start();
                _evaluator.NegateInplace(_kernels[i]);
                negateStopwatch.Stop();

                SvcUtilities.PrintScale(_kernels[i], "8. kernel" + i);
                SvcUtilities.PrintCyprherText(_decryptor, _kernels[i], _encoder, "kernel" + i);
            }

            serverDecisionStopWatch.Start();

            // Encode the dual coefficients at the kernels' scale and level.
            double scale2 = useReScale ? _kernels[0].Scale : _scale;

            for (int i = 0; i < _numOfrowsCount; i++)
            {
                _encoder.Encode(_coefficients[0][i], scale2, _coefArr[i]);
                SvcUtilities.PrintScale(_coefArr[i], "coefPlainText" + i);
            }

            if (useReScale)
            {
                for (int i = 0; i < _numOfrowsCount; i++)
                {
                    _evaluator.ModSwitchToInplace(_coefArr[i], _kernels[i].ParmsId);
                }
            }

            // Level 2: weighted kernels.
            for (int i = 0; i < _numOfrowsCount; i++)
            {
                _evaluator.MultiplyPlain(_kernels[i], _coefArr[i], _decisionsArr[i]);
                if (useRelinearizeInplace)
                {
                    _evaluator.RelinearizeInplace(_decisionsArr[i], _relinKeys);
                }

                if (useReScale)
                {
                    _evaluator.RescaleToNextInplace(_decisionsArr[i]);
                }
                SvcUtilities.PrintScale(_decisionsArr[i], "decision" + i);
                SvcUtilities.PrintCyprherText(_decryptor, _decisionsArr[i], _encoder, "decision" + i);
            }

            // Sum of the weighted kernels.
            Ciphertext decisionTotal = new Ciphertext();
            _evaluator.AddMany(_decisionsArr, decisionTotal);

            SvcUtilities.PrintScale(decisionTotal, "decisionTotal");
            SvcUtilities.PrintCyprherText(_decryptor, decisionTotal, _encoder, "decision total");

            // Encode the intercept at the total's scale and level.
            // Without rescaling the nominal scale is assumed to be 2^(3 * power).
            Plaintext interceptsPlainText = new Plaintext();
            double scale3 = useReScale ? decisionTotal.Scale : Math.Pow(2.0, _power * 3);

            _encoder.Encode(_intercepts[0], scale3, interceptsPlainText);
            if (useReScale)
            {
                _evaluator.ModSwitchToInplace(interceptsPlainText, decisionTotal.ParmsId);
            }

            SvcUtilities.PrintScale(interceptsPlainText, "interceptsPlainText");
            SvcUtilities.PrintScale(decisionTotal, "decisionTotal");

            // Final decision value = sum + intercept.
            _evaluator.AddPlainInplace(decisionTotal, interceptsPlainText);

            SvcUtilities.PrintScale(decisionTotal, "decisionTotal");  // Level 3
            SvcUtilities.PrintCyprherText(_decryptor, decisionTotal, _encoder, "finalTotal", true);

            serverDecisionStopWatch.Stop();

            return decisionTotal;
        }
示例#3
0
            public int Predict(double[] features, int power, bool useRelinearizeInplace, bool useReScale, Stopwatch timePredictSum)
            {
                EncryptionParameters parms = new EncryptionParameters(SchemeType.CKKS);

                if (power < 60)
                {
                    ulong polyModulusDegree = 8192;
                    parms.PolyModulusDegree = polyModulusDegree;
                    parms.CoeffModulus      = CoeffModulus.Create(polyModulusDegree, new int[] { 60, 40, 40, 60 });
                }
                else
                {
                    ulong polyModulusDegree = 16384;
                    parms.PolyModulusDegree = polyModulusDegree;
                    parms.CoeffModulus      = CoeffModulus.Create(polyModulusDegree, new int[] { 60, 60, 60, 60, 60, 60 });
                }
                //

                double scale = Math.Pow(2.0, power);

                SEALContext context = new SEALContext(parms);

                Console.WriteLine();

                KeyGenerator keygen    = new KeyGenerator(context);
                PublicKey    publicKey = keygen.PublicKey;
                SecretKey    secretKey = keygen.SecretKey;
                RelinKeys    relinKeys = keygen.RelinKeys();
                Encryptor    encryptor = new Encryptor(context, publicKey);
                Evaluator    evaluator = new Evaluator(context);
                Decryptor    decryptor = new Decryptor(context, secretKey);

                CKKSEncoder encoder = new CKKSEncoder(context);

                ulong slotCount = encoder.SlotCount;

                Console.WriteLine($"Number of slots: {slotCount}");

                timePredictSum.Start();
                Plaintext fPlaintext0 = new Plaintext();
                Plaintext fPlaintext1 = new Plaintext();
                Plaintext fPlaintext2 = new Plaintext();
                Plaintext fPlaintext3 = new Plaintext();


                encoder.Encode(features[0], scale, fPlaintext0);
                encoder.Encode(features[1], scale, fPlaintext1);
                encoder.Encode(features[2], scale, fPlaintext2);
                encoder.Encode(features[3], scale, fPlaintext3);


                SvcUtilities.PrintScale(fPlaintext0, "fPlaintext0");
                SvcUtilities.PrintScale(fPlaintext1, "fPlaintext1");
                SvcUtilities.PrintScale(fPlaintext2, "fPlaintext2");

                Ciphertext f0Encrypted = new Ciphertext();
                Ciphertext f1Encrypted = new Ciphertext();
                Ciphertext f2Encrypted = new Ciphertext();
                Ciphertext f3Encrypted = new Ciphertext();

                encryptor.Encrypt(fPlaintext0, f0Encrypted);
                encryptor.Encrypt(fPlaintext1, f1Encrypted);
                encryptor.Encrypt(fPlaintext2, f2Encrypted);
                encryptor.Encrypt(fPlaintext3, f3Encrypted);

                SvcUtilities.PrintScale(f0Encrypted, "f0Encrypted");
                SvcUtilities.PrintScale(f1Encrypted, "f1Encrypted");
                SvcUtilities.PrintScale(f2Encrypted, "f2Encrypted");

                Plaintext v00Plaintext1 = new Plaintext();
                Plaintext v01Plaintext1 = new Plaintext();
                Plaintext v02Plaintext1 = new Plaintext();
                Plaintext v03Plaintext1 = new Plaintext();
                Plaintext v10Plaintext1 = new Plaintext();
                Plaintext v11Plaintext1 = new Plaintext();
                Plaintext v12Plaintext1 = new Plaintext();
                Plaintext v13Plaintext1 = new Plaintext();

                Plaintext v20Plaintext1 = new Plaintext();
                Plaintext v21Plaintext1 = new Plaintext();
                Plaintext v22Plaintext1 = new Plaintext();
                Plaintext v23Plaintext1 = new Plaintext();


                encoder.Encode(_vectors[0][0], scale, v00Plaintext1);
                encoder.Encode(_vectors[0][1], scale, v01Plaintext1);
                encoder.Encode(_vectors[0][2], scale, v02Plaintext1);
                encoder.Encode(_vectors[0][3], scale, v03Plaintext1);
                encoder.Encode(_vectors[1][0], scale, v10Plaintext1);
                encoder.Encode(_vectors[1][1], scale, v11Plaintext1);
                encoder.Encode(_vectors[1][2], scale, v12Plaintext1);
                encoder.Encode(_vectors[1][3], scale, v13Plaintext1);
                encoder.Encode(_vectors[2][0], scale, v20Plaintext1);
                encoder.Encode(_vectors[2][1], scale, v21Plaintext1);
                encoder.Encode(_vectors[2][2], scale, v22Plaintext1);
                encoder.Encode(_vectors[2][3], scale, v23Plaintext1);


                SvcUtilities.PrintScale(v00Plaintext1, "v00Plaintext1");
                SvcUtilities.PrintScale(v01Plaintext1, "v01Plaintext1");
                SvcUtilities.PrintScale(v02Plaintext1, "v02Plaintext1");
                SvcUtilities.PrintScale(v03Plaintext1, "v03Plaintext1");

                SvcUtilities.PrintScale(v10Plaintext1, "v10Plaintext1");
                SvcUtilities.PrintScale(v11Plaintext1, "v11Plaintext1");
                SvcUtilities.PrintScale(v12Plaintext1, "v12Plaintext1");
                SvcUtilities.PrintScale(v13Plaintext1, "v13Plaintext1");

                SvcUtilities.PrintScale(v20Plaintext1, "v20Plaintext1");
                SvcUtilities.PrintScale(v21Plaintext1, "v21Plaintext1");
                SvcUtilities.PrintScale(v22Plaintext1, "v22Plaintext1");
                SvcUtilities.PrintScale(v23Plaintext1, "v23Plaintext1");

                Plaintext coef00PlainText = new Plaintext();
                Plaintext coef01PlainText = new Plaintext();
                Plaintext coef02PlainText = new Plaintext();


                Ciphertext tSum1   = new Ciphertext();
                Ciphertext tSum2   = new Ciphertext();
                Ciphertext tSum3   = new Ciphertext();
                Ciphertext tSum4   = new Ciphertext();
                Ciphertext kernel0 = new Ciphertext();

                //Level 1->2
                //=================================================================
                evaluator.MultiplyPlain(f0Encrypted, v00Plaintext1, tSum1);
                evaluator.MultiplyPlain(f1Encrypted, v01Plaintext1, tSum2);
                evaluator.MultiplyPlain(f2Encrypted, v02Plaintext1, tSum3);
                evaluator.MultiplyPlain(f3Encrypted, v03Plaintext1, tSum4);
                //=================================================================

                if (useRelinearizeInplace)
                {
                    Console.WriteLine("RelinearizeInplace sums 1");
                    evaluator.RelinearizeInplace(tSum1, relinKeys);
                    evaluator.RelinearizeInplace(tSum2, relinKeys);
                    evaluator.RelinearizeInplace(tSum3, relinKeys);
                    evaluator.RelinearizeInplace(tSum4, relinKeys);
                }

                if (useReScale)
                {
                    Console.WriteLine("useReScale sums 1");
                    evaluator.RescaleToNextInplace(tSum1);
                    evaluator.RescaleToNextInplace(tSum2);
                    evaluator.RescaleToNextInplace(tSum3);
                    evaluator.RescaleToNextInplace(tSum4);
                }


                SvcUtilities.PrintScale(tSum1, "tSum1"); //Level 2
                SvcUtilities.PrintScale(tSum2, "tSum2"); //Level 2
                SvcUtilities.PrintScale(tSum3, "tSum3"); //Level 2
                SvcUtilities.PrintScale(tSum4, "tSum4"); //Level 2

                var ciphertexts1 = new List <Ciphertext>();

                ciphertexts1.Add(tSum1);
                ciphertexts1.Add(tSum2);
                ciphertexts1.Add(tSum3);
                ciphertexts1.Add(tSum4);

                //=================================================================
                evaluator.AddMany(ciphertexts1, kernel0);    //Level 2
                //=================================================================
                SvcUtilities.PrintScale(kernel0, "kernel0"); //Level 2

                SvcUtilities.PrintCyprherText(decryptor, kernel0, encoder, "kernel0");

                Ciphertext kernel1 = new Ciphertext();

                //Level 1-> 2
                //=================================================================
                evaluator.MultiplyPlain(f0Encrypted, v10Plaintext1, tSum1);
                evaluator.MultiplyPlain(f1Encrypted, v11Plaintext1, tSum2);
                evaluator.MultiplyPlain(f2Encrypted, v12Plaintext1, tSum3);
                evaluator.MultiplyPlain(f3Encrypted, v13Plaintext1, tSum4);
                //=================================================================
                if (useRelinearizeInplace)
                {
                    Console.WriteLine("RelinearizeInplace sums 2");
                    evaluator.RelinearizeInplace(tSum1, relinKeys);
                    evaluator.RelinearizeInplace(tSum2, relinKeys);
                    evaluator.RelinearizeInplace(tSum3, relinKeys);
                    evaluator.RelinearizeInplace(tSum4, relinKeys);
                }

                if (useReScale)
                {
                    Console.WriteLine("useReScale sums 2");
                    evaluator.RescaleToNextInplace(tSum1);
                    evaluator.RescaleToNextInplace(tSum2);
                    evaluator.RescaleToNextInplace(tSum3);
                    evaluator.RescaleToNextInplace(tSum4);
                }

                ciphertexts1.Add(tSum1);
                ciphertexts1.Add(tSum2);
                ciphertexts1.Add(tSum3);
                ciphertexts1.Add(tSum4);



                Console.WriteLine("Second time : ");
                SvcUtilities.PrintScale(tSum1, "tSum1"); //Level 2
                SvcUtilities.PrintScale(tSum2, "tSum2"); //Level 2
                SvcUtilities.PrintScale(tSum3, "tSum3"); //Level 2
                SvcUtilities.PrintScale(tSum4, "tSum4"); //Level 2

                var ciphertexts2 = new List <Ciphertext>();

                ciphertexts2.Add(tSum1);
                ciphertexts2.Add(tSum2);
                ciphertexts2.Add(tSum3);
                ciphertexts2.Add(tSum4);


                //=================================================================
                evaluator.AddMany(ciphertexts2, kernel1); // Level 2
                //=================================================================
                SvcUtilities.PrintScale(kernel1, "kernel1");
                SvcUtilities.PrintCyprherText(decryptor, kernel1, encoder, "kernel1");

                Ciphertext kernel2 = new Ciphertext();

                //Level 1->2
                //=================================================================
                evaluator.MultiplyPlain(f0Encrypted, v20Plaintext1, tSum1);
                evaluator.MultiplyPlain(f1Encrypted, v21Plaintext1, tSum2);
                evaluator.MultiplyPlain(f2Encrypted, v22Plaintext1, tSum3);
                evaluator.MultiplyPlain(f3Encrypted, v23Plaintext1, tSum4);
                //=================================================================

                if (useRelinearizeInplace)
                {
                    Console.WriteLine("RelinearizeInplace sums 3");
                    evaluator.RelinearizeInplace(tSum1, relinKeys);
                    evaluator.RelinearizeInplace(tSum2, relinKeys);
                    evaluator.RelinearizeInplace(tSum3, relinKeys);
                    evaluator.RelinearizeInplace(tSum4, relinKeys);
                }



                if (useReScale)
                {
                    Console.WriteLine("useReScale sums 3");
                    evaluator.RescaleToNextInplace(tSum1);
                    evaluator.RescaleToNextInplace(tSum2);
                    evaluator.RescaleToNextInplace(tSum3);
                    evaluator.RescaleToNextInplace(tSum4);
                }

                var ciphertexts3 = new List <Ciphertext>();

                ciphertexts3.Add(tSum1);
                ciphertexts3.Add(tSum2);
                ciphertexts3.Add(tSum3);
                ciphertexts3.Add(tSum4);

                Console.WriteLine("Third time : ");
                SvcUtilities.PrintScale(tSum1, "tSum1"); //Level 2
                SvcUtilities.PrintScale(tSum2, "tSum2"); //Level 2
                SvcUtilities.PrintScale(tSum3, "tSum3"); //Level 2
                SvcUtilities.PrintScale(tSum4, "tSum4"); //Level 2

                //=================================================================
                evaluator.AddMany(ciphertexts3, kernel2);
                //=================================================================
                SvcUtilities.PrintScale(kernel2, "kernel2"); //Level 2

                SvcUtilities.PrintCyprherText(decryptor, kernel2, encoder, "kernel2");

                Ciphertext decision1 = new Ciphertext();
                Ciphertext decision2 = new Ciphertext();
                Ciphertext decision3 = new Ciphertext();

                SvcUtilities.PrintScale(decision1, "decision1"); //Level 0
                SvcUtilities.PrintScale(decision2, "decision2"); //Level 0
                SvcUtilities.PrintScale(decision3, "decision3"); //Level 0


                Ciphertext nKernel0 = new Ciphertext();
                Ciphertext nKernel1 = new Ciphertext();
                Ciphertext nKernel2 = new Ciphertext();

                //=================================================================
                evaluator.Negate(kernel0, nKernel0);
                evaluator.Negate(kernel1, nKernel1); //Level 2
                evaluator.Negate(kernel2, nKernel2); //Level 2
                //=================================================================



                //nKernel0.Scale = scale;
                //nKernel1.Scale = scale;
                //nKernel2.Scale = scale;
                double scale2 = Math.Pow(2.0, power);

                if (useReScale)
                {
                    scale2 = nKernel0.Scale;
                }

                encoder.Encode(_coefficients[0][0], scale2, coef00PlainText);
                encoder.Encode(_coefficients[0][1], scale2, coef01PlainText);
                encoder.Encode(_coefficients[0][2], scale2, coef02PlainText);

                SvcUtilities.PrintScale(coef00PlainText, "coef00PlainText");
                SvcUtilities.PrintScale(coef01PlainText, "coef01PlainText");
                SvcUtilities.PrintScale(coef02PlainText, "coef02PlainText");



                if (useReScale)
                {
                    ParmsId lastParmsId = nKernel0.ParmsId;
                    evaluator.ModSwitchToInplace(coef00PlainText, lastParmsId);


                    lastParmsId = nKernel1.ParmsId;
                    evaluator.ModSwitchToInplace(coef01PlainText, lastParmsId);

                    lastParmsId = nKernel2.ParmsId;
                    evaluator.ModSwitchToInplace(coef02PlainText, lastParmsId);
                }

                SvcUtilities.PrintScale(nKernel0, "nKernel0");   //Level 2
                SvcUtilities.PrintScale(nKernel1, "nKernel1");   //Level 2
                SvcUtilities.PrintScale(nKernel2, "nKernel2");   //Level 2

                //Level 2->3
                //=================================================================
                evaluator.MultiplyPlain(nKernel0, coef00PlainText, decision1);
                evaluator.MultiplyPlain(nKernel1, coef01PlainText, decision2);
                evaluator.MultiplyPlain(nKernel2, coef02PlainText, decision3);
                //=================================================================



                if (useRelinearizeInplace)
                {
                    Console.WriteLine("RelinearizeInplace decisions");

                    evaluator.RelinearizeInplace(decision1, relinKeys);
                    evaluator.RelinearizeInplace(decision2, relinKeys);
                    evaluator.RelinearizeInplace(decision3, relinKeys);
                }


                if (useReScale)
                {
                    Console.WriteLine("Rescale decisions");

                    evaluator.RescaleToNextInplace(decision1);
                    evaluator.RescaleToNextInplace(decision2);
                    evaluator.RescaleToNextInplace(decision3);
                }


                SvcUtilities.PrintScale(decision1, "decision1"); //Level 3
                SvcUtilities.PrintScale(decision2, "decision2"); //Level 3
                SvcUtilities.PrintScale(decision3, "decision3"); //Level 3
                SvcUtilities.PrintCyprherText(decryptor, decision1, encoder, "decision1");
                SvcUtilities.PrintCyprherText(decryptor, decision2, encoder, "decision2");
                SvcUtilities.PrintCyprherText(decryptor, decision3, encoder, "decision3");

                //=================================================================
                //evaluator.RelinearizeInplace(decision1,keygen.RelinKeys());
                //evaluator.RelinearizeInplace(decision2, keygen.RelinKeys());
                //evaluator.RelinearizeInplace(decision3, keygen.RelinKeys());
                //=================================================================


                //PrintScale(decision1, "decision1");

                var decisions = new List <Ciphertext>();

                decisions.Add(decision1);
                decisions.Add(decision2);
                decisions.Add(decision3);

                Ciphertext decisionTotal = new Ciphertext();

                //=================================================================
                evaluator.AddMany(decisions, decisionTotal);
                //=================================================================
                SvcUtilities.PrintScale(decisionTotal, "decisionTotal");
                SvcUtilities.PrintCyprherText(decryptor, decisionTotal, encoder, "decision total");


                Ciphertext finalTotal = new Ciphertext();

                Plaintext interceptsPlainText = new Plaintext();

                double scale3 = Math.Pow(2.0, power * 3);

                if (useReScale)
                {
                    scale3 = decisionTotal.Scale;
                }
                encoder.Encode(_intercepts[0], scale3, interceptsPlainText);
                if (useReScale)
                {
                    ParmsId lastParmsId = decisionTotal.ParmsId;
                    evaluator.ModSwitchToInplace(interceptsPlainText, lastParmsId);
                }

                SvcUtilities.PrintScale(interceptsPlainText, "interceptsPlainText");
                SvcUtilities.PrintScale(decisionTotal, "decisionTotal");

                //=================================================================
                evaluator.AddPlainInplace(decisionTotal, interceptsPlainText);
                //=================================================================
                timePredictSum.Stop();
                SvcUtilities.PrintScale(decisionTotal, "decisionTotal");  //Level 3
                List <double> result = SvcUtilities.PrintCyprherText(decryptor, decisionTotal, encoder, "finalTotal");

                using (System.IO.StreamWriter file =
                           new System.IO.StreamWriter(
                               $@"{OutputDir}IrisSimple_IrisSecureSVC_total_{power}_{useRelinearizeInplace}_{useReScale}.txt", !_firstTime)
                       )
                {
                    _firstTime = false;
                    file.WriteLine($"{result[0]}");
                }

                if (result[0] > 0)
                {
                    return(0);
                }

                return(1);
            }
            /// <summary>
            /// Runs an end-to-end CKKS-encrypted prediction with a linear kernel. It generates fresh keys,
            /// encrypts every feature into its own ciphertext and homomorphically evaluates
            /// <c>sum_i coef[0][i] * (-&lt;x, sv_i&gt;) + intercept[0]</c>. It then decrypts the result, writes
            /// it to a text file under <c>OutputDir</c> and maps its sign to a class label.
            /// Scales and decrypted intermediate values are printed through <c>SvcUtilities</c> for debugging.
            /// </summary>
            /// <param name="features">Plain feature vector; must have at least as many entries as each support vector.</param>
            /// <param name="power">
            /// Base-2 exponent of the CKKS scale (scale = 2^power). Values below 60 select N = 8192 with a
            /// {60, 40, 40, 60}-bit modulus chain; otherwise N = 16384 with six 60-bit primes are used.
            /// </param>
            /// <param name="useRelinearizeInplace">Relinearize after each multiplication.</param>
            /// <param name="useReScale">
            /// Rescale after each multiplication and align plaintext operands to the ciphertexts' scale/level.
            /// </param>
            /// <param name="timePredictSum">
            /// Caller-owned stopwatch. It is started just before the features are encrypted and stopped once
            /// the intercept has been added.
            /// </param>
            /// <returns>0 when the decrypted decision value is positive, otherwise 1.</returns>
            /// <exception cref="ArgumentNullException"><paramref name="features"/> or <paramref name="timePredictSum"/> is null.</exception>
            /// <exception cref="ArgumentException"><paramref name="features"/> is shorter than a support vector.</exception>
            public int Predict(double[] features, int power, bool useRelinearizeInplace, bool useReScale, Stopwatch timePredictSum)
            {
                if (features == null)
                {
                    throw new ArgumentNullException(nameof(features));
                }

                if (timePredictSum == null)
                {
                    throw new ArgumentNullException(nameof(timePredictSum));
                }

                var numOfrows    = _vectors.Length;
                var numOfcolumns = _vectors[0].Length;

                // Column j of every support vector is multiplied with encrypted feature j. Fail fast here
                // instead of with an IndexOutOfRangeException halfway through the encrypted evaluation.
                if (features.Length < numOfcolumns)
                {
                    throw new ArgumentException(
                        $"Expected at least {numOfcolumns} features but got {features.Length}.", nameof(features));
                }

                // SEAL objects wrap native memory; release it deterministically instead of via finalizers.
                using (EncryptionParameters parms = CreateLinearPredictParameters(power))
                using (SEALContext context = new SEALContext(parms))
                using (KeyGenerator keygen = new KeyGenerator(context))
                using (PublicKey publicKey = keygen.PublicKey)
                using (SecretKey secretKey = keygen.SecretKey)
                using (RelinKeys relinKeys = keygen.RelinKeys())
                using (Encryptor encryptor = new Encryptor(context, publicKey))
                using (Evaluator evaluator = new Evaluator(context))
                using (Decryptor decryptor = new Decryptor(context, secretKey))
                using (CKKSEncoder encoder = new CKKSEncoder(context))
                {
                    double scale = Math.Pow(2.0, power);

                    Console.WriteLine();

                    ulong slotCount = encoder.SlotCount;

                    Console.WriteLine($"Number of slots: {slotCount}");
                    timePredictSum.Start();
                    var featuresLength = features.Length;

                    var plaintexts          = new Plaintext[featuresLength];
                    var featuresCiphertexts = new Ciphertext[featuresLength];

                    // Encode and encrypt each feature into its own ciphertext.
                    for (int i = 0; i < featuresLength; i++)
                    {
                        plaintexts[i] = new Plaintext();
                        encoder.Encode(features[i], scale, plaintexts[i]);
                        SvcUtilities.PrintScale(plaintexts[i], "featurePlaintext" + i);

                        featuresCiphertexts[i] = new Ciphertext();
                        encryptor.Encrypt(plaintexts[i], featuresCiphertexts[i]);
                        SvcUtilities.PrintScale(featuresCiphertexts[i], "featurefEncrypted" + i);
                    }

                    // Encode every support-vector entry as its own plaintext at the input scale.
                    var svPlaintexts = new Plaintext[numOfrows, numOfcolumns];
                    for (int i = 0; i < numOfrows; i++)
                    {
                        for (int j = 0; j < numOfcolumns; j++)
                        {
                            svPlaintexts[i, j] = new Plaintext();
                            encoder.Encode(_vectors[i][j], scale, svPlaintexts[i, j]);
                            SvcUtilities.PrintScale(svPlaintexts[i, j], "supportVectorsPlaintext" + i + j);
                        }
                    }

                    // Scratch buffers holding the per-column products of one inner product.
                    var sums = new Ciphertext[numOfcolumns];
                    for (int i = 0; i < numOfcolumns; i++)
                    {
                        sums[i] = new Ciphertext();
                    }

                    var kernels      = new Ciphertext[numOfrows];
                    var decisionsArr = new Ciphertext[numOfrows];
                    var coefArr      = new Plaintext[numOfrows];
                    for (int i = 0; i < numOfrows; i++)
                    {
                        kernels[i]      = new Ciphertext();
                        decisionsArr[i] = new Ciphertext();
                        coefArr[i]      = new Plaintext();
                    }

                    // Level 1: kernels[i] = -<x, sv_i>
                    for (int i = 0; i < numOfrows; i++)
                    {
                        for (int j = 0; j < numOfcolumns; j++)
                        {
                            evaluator.MultiplyPlain(featuresCiphertexts[j], svPlaintexts[i, j], sums[j]);

                            if (useRelinearizeInplace)
                            {
                                evaluator.RelinearizeInplace(sums[j], relinKeys);
                            }

                            if (useReScale)
                            {
                                evaluator.RescaleToNextInplace(sums[j]);
                            }

                            SvcUtilities.PrintScale(sums[j], "tSum" + j);
                        }

                        evaluator.AddMany(sums, kernels[i]);
                        evaluator.NegateInplace(kernels[i]);

                        SvcUtilities.PrintScale(kernels[i], "kernel" + i);
                        SvcUtilities.PrintCyprherText(decryptor, kernels[i], encoder, "kernel" + i);
                    }

                    // Coefficients must carry the kernels' scale (and level, when rescaling) before MultiplyPlain.
                    double scale2 = useReScale ? kernels[0].Scale : Math.Pow(2.0, power);
                    for (int i = 0; i < numOfrows; i++)
                    {
                        encoder.Encode(_coefficients[0][i], scale2, coefArr[i]);
                        SvcUtilities.PrintScale(coefArr[i], "coefPlainText" + i);
                    }

                    if (useReScale)
                    {
                        for (int i = 0; i < numOfrows; i++)
                        {
                            ParmsId lastParmsId = kernels[i].ParmsId;
                            evaluator.ModSwitchToInplace(coefArr[i], lastParmsId);
                        }
                    }

                    // Level 2: decisionsArr[i] = coef[0][i] * kernels[i]
                    for (int i = 0; i < numOfrows; i++)
                    {
                        evaluator.MultiplyPlain(kernels[i], coefArr[i], decisionsArr[i]);
                        if (useRelinearizeInplace)
                        {
                            evaluator.RelinearizeInplace(decisionsArr[i], relinKeys);
                        }

                        if (useReScale)
                        {
                            evaluator.RescaleToNextInplace(decisionsArr[i]);
                        }
                        SvcUtilities.PrintScale(decisionsArr[i], "decision" + i);
                        SvcUtilities.PrintCyprherText(decryptor, decisionsArr[i], encoder, "decision" + i);
                    }

                    Ciphertext decisionTotal = new Ciphertext();
                    evaluator.AddMany(decisionsArr, decisionTotal);

                    SvcUtilities.PrintScale(decisionTotal, "decisionTotal");
                    SvcUtilities.PrintCyprherText(decryptor, decisionTotal, encoder, "decision total");

                    // Without rescaling the scales multiply: feature (2^p) * support vector (2^p) * coefficient (2^p) = 2^(3p).
                    Plaintext interceptsPlainText = new Plaintext();
                    double scale3 = useReScale ? decisionTotal.Scale : Math.Pow(2.0, power * 3);
                    encoder.Encode(_intercepts[0], scale3, interceptsPlainText);
                    if (useReScale)
                    {
                        ParmsId lastParmsId = decisionTotal.ParmsId;
                        evaluator.ModSwitchToInplace(interceptsPlainText, lastParmsId);
                    }

                    SvcUtilities.PrintScale(interceptsPlainText, "interceptsPlainText");
                    SvcUtilities.PrintScale(decisionTotal, "decisionTotal");

                    evaluator.AddPlainInplace(decisionTotal, interceptsPlainText);
                    timePredictSum.Stop();

                    SvcUtilities.PrintScale(decisionTotal, "decisionTotal");  //Level 3
                    List<double> result = SvcUtilities.PrintCyprherText(decryptor, decisionTotal, encoder, "finalTotal");

                    // The first call overwrites the file; later calls append to it.
                    using (System.IO.StreamWriter file =
                               new System.IO.StreamWriter(
                                   $@"{OutputDir}IrisLinear_IrisSecureSVC_total_{power}_{useRelinearizeInplace}_{useReScale}.txt", !_firstTime))
                    {
                        _firstTime = false;
                        file.WriteLine($"{result[0]}");
                    }

                    return result[0] > 0 ? 0 : 1;
                }
            }

            /// <summary>
            /// Builds the CKKS encryption parameters used by the linear <c>Predict</c> overload.
            /// For <paramref name="power"/> &lt; 60 it uses N = 8192 with a {60, 40, 40, 60}-bit modulus chain;
            /// otherwise N = 16384 with six 60-bit primes. The caller owns (and should dispose) the result.
            /// </summary>
            private static EncryptionParameters CreateLinearPredictParameters(int power)
            {
                ulong polyModulusDegree = power < 60 ? 8192UL : 16384UL;
                int[] bitSizes = power < 60
                    ? new int[] { 60, 40, 40, 60 }
                    : new int[] { 60, 60, 60, 60, 60, 60 };

                EncryptionParameters parms = new EncryptionParameters(SchemeType.CKKS);
                parms.PolyModulusDegree = polyModulusDegree;
                parms.CoeffModulus      = CoeffModulus.Create(polyModulusDegree, bitSizes);
                return parms;
            }
示例#5
0
			/// <summary>
			/// Encrypts <paramref name="features"/> into a single CKKS ciphertext and homomorphically evaluates the
			/// SVM decision value against the stored support vectors. It then decrypts that value and maps its sign
			/// to a class label.
			/// Each per-row kernel value starts as the inner product &lt;x, sv_i&gt;. That product is computed with
			/// MultiplyPlain followed by ceil(log2(columns)) rotate-and-add steps, which accumulate the products
			/// into slot 0.
			/// For <c>Kernel.Poly</c> the value becomes <c>(gamma * &lt;x, sv_i&gt; + coef0)^degree</c>.
			/// NOTE(review): any other kernel value falls through as the plain inner product; other kernel types
			/// are not implemented here.
			/// The negated kernels are weighted by <c>_coefficients[0]</c> and summed, and <c>_intercepts[0]</c>
			/// is added. Phase timings and debug decryptions are written to the console.
			/// The rotations assume <c>_galoisKeys</c> covers steps 1, 2, 4, ... — TODO confirm against key
			/// generation.
			/// </summary>
			/// <param name="features">Plain feature vector, encoded into one plaintext.</param>
			/// <param name="useRelinearizeInplace">Relinearize after each multiplication.</param>
			/// <param name="useReScale">
			/// Rescale after each multiplication. The kernel scale is then pinned to 2^_power, and plaintext
			/// operands are switched down to the ciphertexts' level.
			/// </param>
			/// <param name="finalResult">The decrypted decision value (first slot).</param>
			/// <returns>0 when <paramref name="finalResult"/> is positive, otherwise 1.</returns>
			/// <exception cref="ArgumentNullException"><paramref name="features"/> is null.</exception>
			public int Predict(double[] features, bool useRelinearizeInplace,bool useReScale,out double finalResult)
			{
				if (features == null)
				{
					throw new ArgumentNullException(nameof(features));
				}

				Console.WriteLine();

				var plaintexts          = new Plaintext();
				var featuresCiphertexts = new Ciphertext();

				// ---- Client side: encode the whole feature vector into one plaintext and encrypt it ----
				Stopwatch clientStopwatch = new Stopwatch();
				clientStopwatch.Start();
				double scale = Math.Pow(2.0, _power);
				_encoder.Encode(features, scale, plaintexts);
				_encryptor.Encrypt(plaintexts, featuresCiphertexts);
				SvcUtilities.PrintScale(plaintexts, "featurePlaintext");
				SvcUtilities.PrintScale(featuresCiphertexts, "featurefEncrypted");
				clientStopwatch.Stop();

				// ---- Server setup: encode support vectors and allocate working buffers ----
				Stopwatch serverInitStopwatch = new Stopwatch();
				serverInitStopwatch.Start();
				var numOfrowsCount    = _vectors.Length;
				var numOfcolumnsCount = _vectors[0].Length;

				var svPlaintexts = new Plaintext[numOfrowsCount];
				var sums         = new Ciphertext[numOfrowsCount];
				var kernels      = new Ciphertext[numOfrowsCount]; // filled from sums[] in the Level 1 loop
				var decisionsArr = new Ciphertext[numOfrowsCount];
				var coefArr      = new Plaintext[numOfrowsCount];

				for (int i = 0; i < numOfrowsCount; i++)
				{
					svPlaintexts[i] = new Plaintext();
					_encoder.Encode(_vectors[i], scale, svPlaintexts[i]);
					SvcUtilities.PrintScale(svPlaintexts[i], "supportVectorsPlaintext"+i);
					sums[i]         = new Ciphertext();
					decisionsArr[i] = new Ciphertext();
					coefArr[i]      = new Plaintext();
				}

				Plaintext gamaPlaintext = new Plaintext();
				_encoder.Encode(_gamma, scale, gamaPlaintext);

				Ciphertext tempCt = new Ciphertext();

				// Number of rotate-and-add steps after which slot 0 holds the sum of all column products.
				int numOfRotations = (int)Math.Ceiling(Math.Log2(numOfcolumnsCount));
				serverInitStopwatch.Stop();

				Stopwatch innerProductStopwatch = new Stopwatch();
				Stopwatch negateStopwatch       = new Stopwatch();
				Stopwatch degreeStopwatch       = new Stopwatch();

				// Level 1: kernels[i] = -K(x, sv_i)
				for (int i = 0; i < numOfrowsCount; i++)
				{
					innerProductStopwatch.Start();
					_evaluator.MultiplyPlain(featuresCiphertexts, svPlaintexts[i], sums[i]);
					for (int k = 1, m = 1; m <= numOfRotations; k <<= 1, m++)
					{
						_evaluator.RotateVector(sums[i], k, _galoisKeys, tempCt);
						_evaluator.AddInplace(sums[i], tempCt);
					}
					innerProductStopwatch.Stop();
					kernels[i] = sums[i];

					SvcUtilities.PrintCyprherText(_decryptor, kernels[i], _encoder, $"inner product TotalValue {i}" );
					SvcUtilities.PrintScale(kernels[i], "0. kernels" + i);
					if (useRelinearizeInplace)
					{
						_evaluator.RelinearizeInplace(kernels[i], _relinKeys);
					}

					if (useReScale)
					{
						_evaluator.RescaleToNextInplace(kernels[i]);
					}

					SvcUtilities.PrintScale(kernels[i], "1. kernels" + i);

					if (useReScale)
					{
						// After rescaling, pin the scale back to 2^_power so it lines up with the plaintexts
						// encoded at that scale. Without rescaling the ciphertext genuinely carries
						// 2^(2*_power). Overriding it there would misreport the encoded value and make the final
						// intercept addition fail with a scale mismatch.
						kernels[i].Scale = scale;
					}

					if (_kernel == Kernel.Poly)
					{
						// gamma * <x, sv_i>
						if (useReScale)
						{
							ParmsId lastParmsId = kernels[i].ParmsId;
							_evaluator.ModSwitchToInplace(gamaPlaintext, lastParmsId);
						}
						_evaluator.MultiplyPlainInplace(kernels[i], gamaPlaintext);
						SvcUtilities.PrintScale(kernels[i], "2. kernels" + i);
						if (useRelinearizeInplace)
						{
							_evaluator.RelinearizeInplace(kernels[i], _relinKeys);
						}

						if (useReScale)
						{
							_evaluator.RescaleToNextInplace(kernels[i]);
						}
						SvcUtilities.PrintScale(kernels[i], "3.  kernels" + i);

						// + coef0, encoded at the kernel's current scale/level so AddPlainInplace accepts it.
						if (Math.Abs(_coef0) > 0)
						{
							Plaintext coef0Plaintext = new Plaintext();
							_encoder.Encode(_coef0, kernels[i].Scale, coef0Plaintext);
							if (useReScale)
							{
								ParmsId lastParmsId = kernels[i].ParmsId;
								_evaluator.ModSwitchToInplace(coef0Plaintext, lastParmsId);
							}
							_evaluator.AddPlainInplace(kernels[i], coef0Plaintext);
						}

						SvcUtilities.PrintScale(kernels[i], "4.  kernels" + i);

						degreeStopwatch.Start();

						// ^degree by repeated multiplication with a copy of the base value.
						var kernel = new Ciphertext(kernels[i]);
						for (int d = 0; d < (int)_degree - 1; d++)
						{
							// The copy is relabelled with the running product's scale. When rescaling it is also
							// switched down to the product's level, since the multiplication needs matching parms.
							// NOTE(review): the relabel assumes the copy's true scale is (close to) the product's;
							// confirm for _degree > 2.
							kernel.Scale = kernels[i].Scale;
							if (useReScale)
							{
								ParmsId lastParmsId = kernels[i].ParmsId;
								_evaluator.ModSwitchToInplace(kernel, lastParmsId);
							}
							_evaluator.MultiplyInplace(kernels[i], kernel);
							SvcUtilities.PrintScale(kernels[i], d + "  5. kernels" + i);
							if (useRelinearizeInplace)
							{
								_evaluator.RelinearizeInplace(kernels[i], _relinKeys);
							}

							if (useReScale)
							{
								_evaluator.RescaleToNextInplace(kernels[i]);
							}
							SvcUtilities.PrintScale(kernels[i], d + " rescale  6. kernels" + i);
						}
						SvcUtilities.PrintScale(kernels[i], "7. kernels" + i);

						degreeStopwatch.Stop();
					}

					negateStopwatch.Start();
					_evaluator.NegateInplace(kernels[i]);
					negateStopwatch.Stop();

					SvcUtilities.PrintScale(kernels[i], "8. kernel"+i);
					SvcUtilities.PrintCyprherText(_decryptor, kernels[i], _encoder, "kernel"+i);
				}

				// ---- Server decision: weight kernels by the coefficients, sum, add the intercept ----
				Stopwatch serverDecisionStopWatch = new Stopwatch();
				serverDecisionStopWatch.Start();

				// Coefficients must carry the kernels' scale (and level, when rescaling) before MultiplyPlain.
				double scale2 = useReScale ? kernels[0].Scale : Math.Pow(2.0, _power);
				for (int i = 0; i < numOfrowsCount; i++)
				{
					_encoder.Encode(_coefficients[0][i], scale2, coefArr[i]);
					SvcUtilities.PrintScale(coefArr[i], "coefPlainText"+i);
				}

				if (useReScale)
				{
					for (int i = 0; i < numOfrowsCount; i++)
					{
						ParmsId lastParmsId = kernels[i].ParmsId;
						_evaluator.ModSwitchToInplace(coefArr[i], lastParmsId);
					}
				}

				// Level 2: decisionsArr[i] = coef[0][i] * kernels[i]
				for (int i = 0; i < numOfrowsCount; i++)
				{
					_evaluator.MultiplyPlain(kernels[i], coefArr[i], decisionsArr[i]);
					if (useRelinearizeInplace)
					{
						_evaluator.RelinearizeInplace(decisionsArr[i], _relinKeys);
					}

					if (useReScale)
					{
						_evaluator.RescaleToNextInplace(decisionsArr[i]);
					}
					SvcUtilities.PrintScale(decisionsArr[i], "decision"+i);
					SvcUtilities.PrintCyprherText(_decryptor, decisionsArr[i], _encoder, "decision" + i);
				}

				Ciphertext decisionTotal = new Ciphertext();
				_evaluator.AddMany(decisionsArr, decisionTotal);

				SvcUtilities.PrintScale(decisionTotal, "decisionTotal");
				SvcUtilities.PrintCyprherText(_decryptor, decisionTotal, _encoder, "decision total");

				// Without rescaling (linear kernel) the scales multiply: features (2^p) * SV (2^p) * coef (2^p) = 2^(3p).
				Plaintext interceptsPlainText = new Plaintext();
				double scale3 = useReScale ? decisionTotal.Scale : Math.Pow(2.0, _power*3);
				_encoder.Encode(_intercepts[0], scale3, interceptsPlainText);
				if (useReScale)
				{
					ParmsId lastParmsId = decisionTotal.ParmsId;
					_evaluator.ModSwitchToInplace(interceptsPlainText, lastParmsId);
				}

				SvcUtilities.PrintScale(interceptsPlainText, "interceptsPlainText");
				SvcUtilities.PrintScale(decisionTotal, "decisionTotal");

				_evaluator.AddPlainInplace(decisionTotal, interceptsPlainText);

				SvcUtilities.PrintScale(decisionTotal, "decisionTotal");  //Level 3
				List<double> result = SvcUtilities.PrintCyprherText(_decryptor, decisionTotal, _encoder, "finalTotal",true);

				serverDecisionStopWatch.Stop();

				Console.WriteLine($"client Init elapsed {clientStopwatch.ElapsedMilliseconds} ms");
				Console.WriteLine($"server Init elapsed {serverInitStopwatch.ElapsedMilliseconds} ms");
				Console.WriteLine($"server innerProductStopwatch elapsed {innerProductStopwatch.ElapsedMilliseconds} ms");
				Console.WriteLine($"server negateStopwatch elapsed {negateStopwatch.ElapsedMilliseconds} ms");
				Console.WriteLine($"server degreeStopwatch elapsed {degreeStopwatch.ElapsedMilliseconds} ms");
				Console.WriteLine($"server Decision elapsed {serverDecisionStopWatch.ElapsedMilliseconds} ms");

				finalResult = result[0];

				return finalResult > 0 ? 0 : 1;
			}