// Debugging constructor: unlike a real deployment it receives the client's secret key, so that
// intermediate ciphertexts can be decrypted and printed.
//   vectors      - support vectors, one row per support vector; the feature count is taken from vectors[0].
//   coefficients - dual coefficients (Predict reads coefficients[0][i] for support vector i).
//   intercepts   - Predict reads intercepts[0].
//   kernel       - name of a Kernel enum member (parsed with Enum.Parse).
//   gamma, coef0, degree - polynomial kernel parameters: (gamma * <x, sv> + coef0)^degree.
//   power        - log2 of the CKKS scale; only 20..60 have a modulus chain defined below.
//   secretKey    - optional; when non-null a Decryptor is created (debug printing only).
//   batchSize    - number of samples the client packs into one ciphertext (batched inner product only).
//   featureSize  - number of slots reserved per sample in a batched ciphertext.
// Throws ArgumentException when vectors is null/empty and ArgumentOutOfRangeException when power
// is outside 20..60 (previously the encryption parameters were left without a coefficient modulus).
private Svc(double[][] vectors, double[][] coefficients, double[] intercepts, String kernel, double gamma, double coef0, ulong degree, int power, PublicKey publicKey, SecretKey secretKey, RelinKeys relinKeys, GaloisKeys galoisKeys, int batchSize, int featureSize)
{
    if (vectors == null || vectors.Length == 0)
    {
        throw new ArgumentException("At least one support vector is required.", nameof(vectors));
    }
    if (power < 20 || power > 60)
    {
        // No coefficient modulus chain exists for other scales; the SEAL context would be unusable.
        throw new ArgumentOutOfRangeException(nameof(power), power, "Supported CKKS scale exponents are 20..60.");
    }

    this._vectors = vectors;
    this._coefficients = coefficients;
    this._intercepts = intercepts;
    this._kernel = (Kernel)System.Enum.Parse(typeof(Kernel), kernel);
    this._gamma = gamma;
    this._coef0 = coef0;
    this._degree = degree;
    this._power = power;

    // Use the CKKS scheme.
    EncryptionParameters parms = new EncryptionParameters(SchemeType.CKKS);

    // polyModulusDegree and CoeffModulus serve the general SVM algorithm and depend on the polynomial
    // kernel degree (and the precision constraint).
    // This implementation supports kernel degrees up to 4; higher degrees would need a longer modulus
    // chain plus some optimizations.
    ulong polyModulusDegree = 16384;
    if (power < 40)
    {
        // 20 <= power < 40
        parms.CoeffModulus = CoeffModulus.Create(polyModulusDegree, new int[] { 60, 20, 21, 22, 23, 24, 25, 26, 27, 60 });
    }
    else if (power < 60)
    {
        // 40 <= power < 60
        parms.CoeffModulus = CoeffModulus.Create(polyModulusDegree, new int[] { 60, 40, 40, 40, 40, 40, 40, 40, 60 });
    }
    else
    {
        // power == 60 uses the larger ring.
        polyModulusDegree = 32768;
        parms.CoeffModulus = CoeffModulus.Create(polyModulusDegree, new int[] { 60, 60, 60, 60, 60, 60, 60, 60, 60 });
    }
    parms.PolyModulusDegree = polyModulusDegree;

    _context = new SEALContext(parms);
    _publicKey = publicKey;
    _secretKey = secretKey;
    _relinKeys = relinKeys;
    _galoisKeys = galoisKeys;
    _evaluator = new Evaluator(_context);
    if (_secretKey != null)
    {
        _decryptor = new Decryptor(_context, _secretKey); // FOR DEBUG ONLY (not used in a real server)
    }
    _encoder = new CKKSEncoder(_context);

    Stopwatch serverInitStopwatch = new Stopwatch();
    serverInitStopwatch.Start();
    _numOfrowsCount = _vectors.Length;       // number of support vectors
    _numOfcolumnsCount = _vectors[0].Length; // number of features in every support vector
    _scale = Math.Pow(2.0, _power);

    // Buffers for the batched (rotation-based) inner product.
    _svPlaintexts = new Plaintext[_numOfrowsCount];
    _sums = new Ciphertext[_numOfrowsCount];
    if (UseBatchInnerProduct)
    {
        // Each support vector is repeated once per batched sample: slot k holds feature k % featureSize,
        // zero-padded when featureSize exceeds the support-vector length.
        double[] batchVectors = new double[batchSize * featureSize];
        for (int i = 0; i < _numOfrowsCount; i++)
        {
            for (int k = 0; k < batchSize * featureSize; k++)
            {
                var index0 = k % featureSize;
                batchVectors[k] = index0 < _numOfcolumnsCount ? _vectors[i][index0] : 0;
            }
            _svPlaintexts[i] = new Plaintext();
            _encoder.Encode(batchVectors, _scale, _svPlaintexts[i]);
            SvcUtilities.PrintScale(_svPlaintexts[i], "batch supportVectorsPlaintext" + i);
            _sums[i] = new Ciphertext();
        }
    }
    else
    {
        // Buffers for the simple (one value per ciphertext) inner product:
        // one plaintext per support-vector component.
        _svPlaintextsArr = new Plaintext[_numOfrowsCount, _numOfcolumnsCount];
        for (int i = 0; i < _numOfrowsCount; i++)
        {
            for (int j = 0; j < _numOfcolumnsCount; j++)
            {
                _svPlaintextsArr[i, j] = new Plaintext();
                // NOTE(review): Zero presumably is a small non-zero stand-in for 0.0 - confirm its definition.
                _encoder.Encode(_vectors[i][j] != 0 ? _vectors[i][j] : Zero, _scale, _svPlaintextsArr[i, j]);
                SvcUtilities.PrintScale(_svPlaintextsArr[i, j], $"supportVectorsPlaintext[{i}][{j}]");
            }
        }

        // Per-feature partial products of the inner product.
        _innerProdSums = new Ciphertext[_numOfcolumnsCount];
        for (int i = 0; i < _numOfcolumnsCount; i++)
        {
            _innerProdSums[i] = new Ciphertext();
        }
    }

    // Allocate the buffers reused by every Predict call.
    _kernels = new Ciphertext[_numOfrowsCount];
    _decisionsArr = new Ciphertext[_numOfrowsCount];
    _coefArr = new Plaintext[_numOfrowsCount];
    for (int i = 0; i < _numOfrowsCount; i++)
    {
        _kernels[i] = new Ciphertext();
        _decisionsArr[i] = new Ciphertext();
        _coefArr[i] = new Plaintext();
    }

    _gamaPlaintext = new Plaintext();
    _encoder.Encode(_gamma != 0 ? _gamma : Zero, _scale, _gamaPlaintext);
    serverInitStopwatch.Stop();
    Console.WriteLine($"server Init elapsed {serverInitStopwatch.ElapsedMilliseconds} ms");
}
// Evaluates the SVM decision function on an encrypted sample (server side).
// Following the SEAL examples, every multiplication is followed by relinearization and rescaling;
// useRelinearizeInplace and useReScale should always be true and exist only for debugging/learning.
//   featuresCiphertexts - the encrypted sample, laid out as InnerProduct expects (batched mode).
//   innerProductStopwatch, degreeStopwatch, negateStopwatch, serverDecisionStopWatch - accumulate the
//       time spent in the inner products, the polynomial-degree loop, the negations and the final
//       decision (coefficients, sum, intercept); they are started/stopped here, never reset.
// Returns the encrypted decision value: sum_i coef_i * (-K(x, sv_i)) + intercept.
// Throws NotSupportedException when UseBatchInnerProduct is false: this overload only implements the
// batched inner product, so the kernels would otherwise never be computed.
public Ciphertext Predict(Ciphertext featuresCiphertexts, bool useRelinearizeInplace, bool useReScale, Stopwatch innerProductStopwatch, Stopwatch degreeStopwatch, Stopwatch negateStopwatch, Stopwatch serverDecisionStopWatch)
{
    if (!UseBatchInnerProduct)
    {
        throw new NotSupportedException("Predict(Ciphertext, ...) supports only the batched inner product (UseBatchInnerProduct).");
    }

    Ciphertext tempCt = new Ciphertext();

    // Level 1: one kernel value per support vector.
    for (int i = 0; i < _numOfrowsCount; i++)
    {
        // Inner product IP = <x, x'>.
        innerProductStopwatch.Start();
        _kernels[i] = InnerProduct(featuresCiphertexts, _svPlaintexts, i, _sums, _numOfcolumnsCount, tempCt);
        innerProductStopwatch.Stop();
        SvcUtilities.PrintCyprherText(_decryptor, _kernels[i], _encoder, $"inner product TotalValue {i}");
        SvcUtilities.PrintScale(_kernels[i], "0. kernels" + i);
        if (useRelinearizeInplace)
        {
            _evaluator.RelinearizeInplace(_kernels[i], _relinKeys);
        }
        if (useReScale)
        {
            _evaluator.RescaleToNextInplace(_kernels[i]);
        }
        SvcUtilities.PrintScale(_kernels[i], "1. kernels" + i);
        // Reset the scale bookkeeping to _scale so it matches plaintexts encoded at _scale.
        _kernels[i].Scale = _scale;

        if (_kernel == Kernel.Poly)
        {
            // Polynomial kernel: (gamma * IP + coef0)^degree.
            if (useReScale)
            {
                // The shared gamma plaintext is mod-switched in place down to the kernel's level.
                ParmsId lastParmsId = _kernels[i].ParmsId;
                _evaluator.ModSwitchToInplace(_gamaPlaintext, lastParmsId);
            }
            // gamma * IP
            _evaluator.MultiplyPlainInplace(_kernels[i], _gamaPlaintext);
            SvcUtilities.PrintScale(_kernels[i], "2. kernels" + i);
            if (useRelinearizeInplace)
            {
                _evaluator.RelinearizeInplace(_kernels[i], _relinKeys);
            }
            if (useReScale)
            {
                _evaluator.RescaleToNextInplace(_kernels[i]);
            }
            SvcUtilities.PrintScale(_kernels[i], "3. kernels" + i);

            // + coef0 (skipped when coef0 is zero).
            if (Math.Abs(_coef0) > 0)
            {
                Plaintext coef0Plaintext = new Plaintext();
                _encoder.Encode(_coef0, _kernels[i].Scale, coef0Plaintext);
                if (useReScale)
                {
                    ParmsId lastParmsId = _kernels[i].ParmsId;
                    _evaluator.ModSwitchToInplace(coef0Plaintext, lastParmsId);
                }
                _evaluator.AddPlainInplace(_kernels[i], coef0Plaintext);
            }
            SvcUtilities.PrintScale(_kernels[i], "4. kernels" + i);

            // Raise to the kernel degree by repeatedly multiplying with a copy of the base value.
            degreeStopwatch.Start();
            var kernel = new Ciphertext(_kernels[i]);
            for (int d = 0; d < (int)_degree - 1; d++)
            {
                // Align the copy's scale and level with the running product before multiplying.
                kernel.Scale = _kernels[i].Scale;
                if (useReScale)
                {
                    ParmsId lastParmsId = _kernels[i].ParmsId;
                    _evaluator.ModSwitchToInplace(kernel, lastParmsId);
                }
                _evaluator.MultiplyInplace(_kernels[i], kernel);
                SvcUtilities.PrintScale(_kernels[i], d + " 5. kernels" + i);
                if (useRelinearizeInplace)
                {
                    _evaluator.RelinearizeInplace(_kernels[i], _relinKeys);
                }
                if (useReScale)
                {
                    _evaluator.RescaleToNextInplace(_kernels[i]);
                }
                SvcUtilities.PrintScale(_kernels[i], d + " rescale 6. kernels" + i);
            }
            SvcUtilities.PrintScale(_kernels[i], "7. kernels" + i);
            degreeStopwatch.Stop();
        }

        negateStopwatch.Start();
        _evaluator.NegateInplace(_kernels[i]);
        negateStopwatch.Stop();
        SvcUtilities.PrintScale(_kernels[i], "8. kernel" + i);
        SvcUtilities.PrintCyprherText(_decryptor, _kernels[i], _encoder, "kernel" + i);
    }

    serverDecisionStopWatch.Start();

    // Encode the coefficients at the kernels' scale and level.
    double scale2 = Math.Pow(2.0, _power);
    if (useReScale)
    {
        scale2 = _kernels[0].Scale;
    }
    for (int i = 0; i < _numOfrowsCount; i++)
    {
        _encoder.Encode(_coefficients[0][i], scale2, _coefArr[i]);
        SvcUtilities.PrintScale(_coefArr[i], "coefPlainText" + i);
    }
    if (useReScale)
    {
        for (int i = 0; i < _numOfrowsCount; i++)
        {
            ParmsId lastParmsId = _kernels[i].ParmsId;
            _evaluator.ModSwitchToInplace(_coefArr[i], lastParmsId);
        }
    }

    // Level 2: decision_i = coef_i * (-K_i).
    for (int i = 0; i < _numOfrowsCount; i++)
    {
        _evaluator.MultiplyPlain(_kernels[i], _coefArr[i], _decisionsArr[i]);
        if (useRelinearizeInplace)
        {
            _evaluator.RelinearizeInplace(_decisionsArr[i], _relinKeys);
        }
        if (useReScale)
        {
            _evaluator.RescaleToNextInplace(_decisionsArr[i]);
        }
        SvcUtilities.PrintScale(_decisionsArr[i], "decision" + i);
        SvcUtilities.PrintCyprherText(_decryptor, _decisionsArr[i], _encoder, "decision" + i);
    }

    // Sum of all decisions.
    Ciphertext decisionTotal = new Ciphertext();
    _evaluator.AddMany(_decisionsArr, decisionTotal);
    SvcUtilities.PrintScale(decisionTotal, "decisionTotal");
    SvcUtilities.PrintCyprherText(_decryptor, decisionTotal, _encoder, "decision total");

    // Encode the intercept at the total's scale and level.
    Plaintext interceptsPlainText = new Plaintext();
    double scale3 = Math.Pow(2.0, _power * 3);
    if (useReScale)
    {
        scale3 = decisionTotal.Scale;
    }
    _encoder.Encode(_intercepts[0], scale3, interceptsPlainText);
    if (useReScale)
    {
        ParmsId lastParmsId = decisionTotal.ParmsId;
        _evaluator.ModSwitchToInplace(interceptsPlainText, lastParmsId);
    }
    SvcUtilities.PrintScale(interceptsPlainText, "interceptsPlainText");
    SvcUtilities.PrintScale(decisionTotal, "decisionTotal");

    // Final decision value = sum + intercept.
    _evaluator.AddPlainInplace(decisionTotal, interceptsPlainText);
    SvcUtilities.PrintScale(decisionTotal, "decisionTotal");
    // Level 3 (debug): decrypt and print the final value.
    SvcUtilities.PrintCyprherText(_decryptor, decisionTotal, _encoder, "finalTotal", true);
    serverDecisionStopWatch.Stop();
    return decisionTotal;
}
// Classifies one sample with a linear SVM. Each feature is encrypted into its own ciphertext and each
// support-vector component is encoded into its own plaintext (no batching).
// Debug/learning variant: it generates its own CKKS keys, decrypts intermediate values for printing,
// and decrypts the final decision value itself.
//   features              - plaintext sample; must hold at least as many values as a support vector.
//   power                 - log2 of the CKKS scale. power < 60 uses N = 8192 with {60, 40, 40, 60};
//                           otherwise N = 16384 with six 60-bit primes.
//   useRelinearizeInplace - relinearize after each multiplication (debug switch; keep true).
//   useReScale            - rescale after each multiplication (debug switch; keep true).
//   timePredictSum        - started before the features are encoded and stopped after the intercept
//                           is added (never reset here).
// The decrypted decision value is appended to
// {OutputDir}IrisSimple_IrisSecureSVC_total_{power}_{useRelinearizeInplace}_{useReScale}.txt
// (the file is overwritten while _firstTime is true).
// Returns 0 when the decision value is positive, otherwise 1.
public int Predict(double[] features, int power, bool useRelinearizeInplace, bool useReScale, Stopwatch timePredictSum)
{
    if (features == null)
    {
        throw new ArgumentNullException(nameof(features));
    }
    if (_vectors == null || _vectors.Length == 0)
    {
        throw new InvalidOperationException("The model has no support vectors.");
    }
    int numOfRows = _vectors.Length;       // number of support vectors
    int numOfColumns = _vectors[0].Length; // number of features per support vector
    if (features.Length < numOfColumns)
    {
        throw new ArgumentException($"Expected at least {numOfColumns} features but got {features.Length}.", nameof(features));
    }

    EncryptionParameters parms = new EncryptionParameters(SchemeType.CKKS);
    if (power < 60)
    {
        ulong polyModulusDegree = 8192;
        parms.PolyModulusDegree = polyModulusDegree;
        parms.CoeffModulus = CoeffModulus.Create(polyModulusDegree, new int[] { 60, 40, 40, 60 });
    }
    else
    {
        ulong polyModulusDegree = 16384;
        parms.PolyModulusDegree = polyModulusDegree;
        parms.CoeffModulus = CoeffModulus.Create(polyModulusDegree, new int[] { 60, 60, 60, 60, 60, 60 });
    }

    // Initial scale; scale2/scale3 below (used when not rescaling) assume it equals 2^power.
    double scale = Math.Pow(2.0, power);
    SEALContext context = new SEALContext(parms);
    Console.WriteLine();
    KeyGenerator keygen = new KeyGenerator(context);
    PublicKey publicKey = keygen.PublicKey;
    SecretKey secretKey = keygen.SecretKey;
    RelinKeys relinKeys = keygen.RelinKeys();
    Encryptor encryptor = new Encryptor(context, publicKey);
    Evaluator evaluator = new Evaluator(context);
    Decryptor decryptor = new Decryptor(context, secretKey);
    CKKSEncoder encoder = new CKKSEncoder(context);
    ulong slotCount = encoder.SlotCount;
    Console.WriteLine($"Number of slots: {slotCount}");

    timePredictSum.Start();

    // Encode and encrypt each used feature into its own ciphertext.
    var featuresCiphertexts = new Ciphertext[numOfColumns];
    for (int j = 0; j < numOfColumns; j++)
    {
        Plaintext featurePlaintext = new Plaintext();
        encoder.Encode(features[j], scale, featurePlaintext);
        SvcUtilities.PrintScale(featurePlaintext, "fPlaintext" + j);
        featuresCiphertexts[j] = new Ciphertext();
        encryptor.Encrypt(featurePlaintext, featuresCiphertexts[j]);
        SvcUtilities.PrintScale(featuresCiphertexts[j], "f" + j + "Encrypted");
    }

    // Encode each support-vector component into its own plaintext.
    var svPlaintexts = new Plaintext[numOfRows, numOfColumns];
    for (int i = 0; i < numOfRows; i++)
    {
        for (int j = 0; j < numOfColumns; j++)
        {
            svPlaintexts[i, j] = new Plaintext();
            encoder.Encode(_vectors[i][j], scale, svPlaintexts[i, j]);
            SvcUtilities.PrintScale(svPlaintexts[i, j], "v" + i + j + "Plaintext1");
        }
    }

    // Level 1 -> 2: kernel_i = <x, sv_i>, then negated.
    var tSums = new Ciphertext[numOfColumns];
    for (int j = 0; j < numOfColumns; j++)
    {
        tSums[j] = new Ciphertext();
    }
    var nKernels = new Ciphertext[numOfRows];
    for (int i = 0; i < numOfRows; i++)
    {
        for (int j = 0; j < numOfColumns; j++)
        {
            evaluator.MultiplyPlain(featuresCiphertexts[j], svPlaintexts[i, j], tSums[j]);
            if (useRelinearizeInplace)
            {
                evaluator.RelinearizeInplace(tSums[j], relinKeys);
            }
            if (useReScale)
            {
                evaluator.RescaleToNextInplace(tSums[j]);
            }
            SvcUtilities.PrintScale(tSums[j], "tSum" + (j + 1)); // Level 2
        }

        Ciphertext kernel = new Ciphertext();
        evaluator.AddMany(tSums, kernel); // Level 2
        SvcUtilities.PrintScale(kernel, "kernel" + i);
        SvcUtilities.PrintCyprherText(decryptor, kernel, encoder, "kernel" + i);

        nKernels[i] = new Ciphertext();
        evaluator.Negate(kernel, nKernels[i]); // Level 2
    }

    // Encode the coefficients at the kernels' scale and level, then
    // Level 2 -> 3: decision_i = coef_i * (-kernel_i).
    double scale2 = Math.Pow(2.0, power);
    if (useReScale)
    {
        scale2 = nKernels[0].Scale;
    }
    var decisions = new Ciphertext[numOfRows];
    for (int i = 0; i < numOfRows; i++)
    {
        Plaintext coefPlainText = new Plaintext();
        encoder.Encode(_coefficients[0][i], scale2, coefPlainText);
        SvcUtilities.PrintScale(coefPlainText, "coef0" + i + "PlainText");
        if (useReScale)
        {
            evaluator.ModSwitchToInplace(coefPlainText, nKernels[i].ParmsId);
        }
        SvcUtilities.PrintScale(nKernels[i], "nKernel" + i); // Level 2

        decisions[i] = new Ciphertext();
        evaluator.MultiplyPlain(nKernels[i], coefPlainText, decisions[i]);
        if (useRelinearizeInplace)
        {
            evaluator.RelinearizeInplace(decisions[i], relinKeys);
        }
        if (useReScale)
        {
            evaluator.RescaleToNextInplace(decisions[i]);
        }
        SvcUtilities.PrintScale(decisions[i], "decision" + (i + 1)); // Level 3
        SvcUtilities.PrintCyprherText(decryptor, decisions[i], encoder, "decision" + (i + 1));
    }

    // Sum of all decisions.
    Ciphertext decisionTotal = new Ciphertext();
    evaluator.AddMany(decisions, decisionTotal);
    SvcUtilities.PrintScale(decisionTotal, "decisionTotal");
    SvcUtilities.PrintCyprherText(decryptor, decisionTotal, encoder, "decision total");

    // Encode the intercept at the total's scale and level.
    Plaintext interceptsPlainText = new Plaintext();
    double scale3 = Math.Pow(2.0, power * 3);
    if (useReScale)
    {
        scale3 = decisionTotal.Scale;
    }
    encoder.Encode(_intercepts[0], scale3, interceptsPlainText);
    if (useReScale)
    {
        ParmsId lastParmsId = decisionTotal.ParmsId;
        evaluator.ModSwitchToInplace(interceptsPlainText, lastParmsId);
    }
    SvcUtilities.PrintScale(interceptsPlainText, "interceptsPlainText");
    SvcUtilities.PrintScale(decisionTotal, "decisionTotal");

    evaluator.AddPlainInplace(decisionTotal, interceptsPlainText);
    timePredictSum.Stop();
    SvcUtilities.PrintScale(decisionTotal, "decisionTotal"); // Level 3
    List<double> result = SvcUtilities.PrintCyprherText(decryptor, decisionTotal, encoder, "finalTotal");

    using (System.IO.StreamWriter file = new System.IO.StreamWriter(
        $@"{OutputDir}IrisSimple_IrisSecureSVC_total_{power}_{useRelinearizeInplace}_{useReScale}.txt", !_firstTime))
    {
        _firstTime = false;
        file.WriteLine($"{result[0]}");
    }

    return result[0] > 0 ? 0 : 1;
}
// Classifies one sample with a linear SVM over the whole support-vector matrix. Each feature is
// encrypted into its own ciphertext and each support-vector component is encoded into its own plaintext.
// Debug/learning variant: it generates its own CKKS keys, decrypts intermediate values for printing,
// and decrypts the final decision value itself.
//   features              - plaintext sample; must hold at least as many values as a support vector
//                           (only the first _vectors[0].Length values are used).
//   power                 - log2 of the CKKS scale. power < 60 uses N = 8192 with {60, 40, 40, 60};
//                           otherwise N = 16384 with six 60-bit primes.
//   useRelinearizeInplace - relinearize after each multiplication (debug switch; keep true).
//   useReScale            - rescale after each multiplication (debug switch; keep true).
//   timePredictSum        - started before the features are encoded and stopped after the intercept
//                           is added (never reset here).
// The decrypted decision value is appended to
// {OutputDir}IrisLinear_IrisSecureSVC_total_{power}_{useRelinearizeInplace}_{useReScale}.txt
// (the file is overwritten while _firstTime is true).
// Returns 0 when the decision value is positive, otherwise 1.
public int Predict(double[] features, int power, bool useRelinearizeInplace, bool useReScale, Stopwatch timePredictSum)
{
    if (features == null)
    {
        throw new ArgumentNullException(nameof(features));
    }
    var numOfrows = _vectors.Length;       // number of support vectors
    var numOfcolumns = _vectors[0].Length; // number of features per support vector
    if (features.Length < numOfcolumns)
    {
        throw new ArgumentException($"Expected at least {numOfcolumns} features but got {features.Length}.", nameof(features));
    }

    EncryptionParameters parms = new EncryptionParameters(SchemeType.CKKS);
    if (power < 60)
    {
        ulong polyModulusDegree = 8192;
        parms.PolyModulusDegree = polyModulusDegree;
        parms.CoeffModulus = CoeffModulus.Create(polyModulusDegree, new int[] { 60, 40, 40, 60 });
    }
    else
    {
        ulong polyModulusDegree = 16384;
        parms.PolyModulusDegree = polyModulusDegree;
        parms.CoeffModulus = CoeffModulus.Create(polyModulusDegree, new int[] { 60, 60, 60, 60, 60, 60 });
    }

    // Initial scale; scale2/scale3 below (used when not rescaling) assume it equals 2^power.
    double scale = Math.Pow(2.0, power);
    SEALContext context = new SEALContext(parms);
    Console.WriteLine();
    KeyGenerator keygen = new KeyGenerator(context);
    PublicKey publicKey = keygen.PublicKey;
    SecretKey secretKey = keygen.SecretKey;
    RelinKeys relinKeys = keygen.RelinKeys();
    Encryptor encryptor = new Encryptor(context, publicKey);
    Evaluator evaluator = new Evaluator(context);
    Decryptor decryptor = new Decryptor(context, secretKey);
    CKKSEncoder encoder = new CKKSEncoder(context);
    ulong slotCount = encoder.SlotCount;
    Console.WriteLine($"Number of slots: {slotCount}");

    timePredictSum.Start();

    // Encode and encrypt the features that take part in the inner product.
    var plaintexts = new Plaintext[numOfcolumns];
    var featuresCiphertexts = new Ciphertext[numOfcolumns];
    for (int i = 0; i < numOfcolumns; i++)
    {
        plaintexts[i] = new Plaintext();
        encoder.Encode(features[i], scale, plaintexts[i]);
        SvcUtilities.PrintScale(plaintexts[i], "featurePlaintext" + i);
        featuresCiphertexts[i] = new Ciphertext();
        encryptor.Encrypt(plaintexts[i], featuresCiphertexts[i]);
        SvcUtilities.PrintScale(featuresCiphertexts[i], "featurefEncrypted" + i);
    }

    // Encode each support-vector component into its own plaintext.
    var svPlaintexts = new Plaintext[numOfrows, numOfcolumns];
    for (int i = 0; i < numOfrows; i++)
    {
        for (int j = 0; j < numOfcolumns; j++)
        {
            svPlaintexts[i, j] = new Plaintext();
            encoder.Encode(_vectors[i][j], scale, svPlaintexts[i, j]);
            SvcUtilities.PrintScale(svPlaintexts[i, j], $"supportVectorsPlaintext[{i}][{j}]");
        }
    }

    // Per-feature partial products of the inner product (reused for every support vector).
    var sums = new Ciphertext[numOfcolumns];
    for (int i = 0; i < numOfcolumns; i++)
    {
        sums[i] = new Ciphertext();
    }

    var kernels = new Ciphertext[numOfrows];
    var decisionsArr = new Ciphertext[numOfrows];
    var coefArr = new Plaintext[numOfrows];
    for (int i = 0; i < numOfrows; i++)
    {
        kernels[i] = new Ciphertext();
        decisionsArr[i] = new Ciphertext();
        coefArr[i] = new Plaintext();
    }

    // Level 1: kernel_i = -<x, sv_i>.
    for (int i = 0; i < numOfrows; i++)
    {
        for (int j = 0; j < numOfcolumns; j++)
        {
            evaluator.MultiplyPlain(featuresCiphertexts[j], svPlaintexts[i, j], sums[j]);
            if (useRelinearizeInplace)
            {
                evaluator.RelinearizeInplace(sums[j], relinKeys);
            }
            if (useReScale)
            {
                evaluator.RescaleToNextInplace(sums[j]);
            }
            SvcUtilities.PrintScale(sums[j], "tSum" + j);
        }
        evaluator.AddMany(sums, kernels[i]);
        evaluator.NegateInplace(kernels[i]);
        SvcUtilities.PrintScale(kernels[i], "kernel" + i);
        SvcUtilities.PrintCyprherText(decryptor, kernels[i], encoder, "kernel" + i);
    }

    // Encode the coefficients at the kernels' scale and level.
    double scale2 = Math.Pow(2.0, power);
    if (useReScale)
    {
        scale2 = kernels[0].Scale;
    }
    for (int i = 0; i < numOfrows; i++)
    {
        encoder.Encode(_coefficients[0][i], scale2, coefArr[i]);
        SvcUtilities.PrintScale(coefArr[i], "coefPlainText" + i);
    }
    if (useReScale)
    {
        for (int i = 0; i < numOfrows; i++)
        {
            ParmsId lastParmsId = kernels[i].ParmsId;
            evaluator.ModSwitchToInplace(coefArr[i], lastParmsId);
        }
    }

    // Level 2: decision_i = coef_i * kernel_i.
    for (int i = 0; i < numOfrows; i++)
    {
        evaluator.MultiplyPlain(kernels[i], coefArr[i], decisionsArr[i]);
        if (useRelinearizeInplace)
        {
            evaluator.RelinearizeInplace(decisionsArr[i], relinKeys);
        }
        if (useReScale)
        {
            evaluator.RescaleToNextInplace(decisionsArr[i]);
        }
        SvcUtilities.PrintScale(decisionsArr[i], "decision" + i);
        SvcUtilities.PrintCyprherText(decryptor, decisionsArr[i], encoder, "decision" + i);
    }

    // Sum of all decisions.
    Ciphertext decisionTotal = new Ciphertext();
    evaluator.AddMany(decisionsArr, decisionTotal);
    SvcUtilities.PrintScale(decisionTotal, "decisionTotal");
    SvcUtilities.PrintCyprherText(decryptor, decisionTotal, encoder, "decision total");

    // Encode the intercept at the total's scale and level.
    Plaintext interceptsPlainText = new Plaintext();
    double scale3 = Math.Pow(2.0, power * 3);
    if (useReScale)
    {
        scale3 = decisionTotal.Scale;
    }
    encoder.Encode(_intercepts[0], scale3, interceptsPlainText);
    if (useReScale)
    {
        ParmsId lastParmsId = decisionTotal.ParmsId;
        evaluator.ModSwitchToInplace(interceptsPlainText, lastParmsId);
    }
    SvcUtilities.PrintScale(interceptsPlainText, "interceptsPlainText");
    SvcUtilities.PrintScale(decisionTotal, "decisionTotal");

    evaluator.AddPlainInplace(decisionTotal, interceptsPlainText);
    timePredictSum.Stop();
    SvcUtilities.PrintScale(decisionTotal, "decisionTotal"); // Level 3
    List<double> result = SvcUtilities.PrintCyprherText(decryptor, decisionTotal, encoder, "finalTotal");

    using (System.IO.StreamWriter file = new System.IO.StreamWriter(
        $@"{OutputDir}IrisLinear_IrisSecureSVC_total_{power}_{useRelinearizeInplace}_{useReScale}.txt", !_firstTime))
    {
        _firstTime = false;
        file.WriteLine($"{result[0]}");
    }

    return result[0] > 0 ? 0 : 1;
}
/// <summary>
/// Encrypts <paramref name="features"/> with CKKS and evaluates the SVM decision function
/// homomorphically against every support vector. It then decrypts the result and classifies the sample.
/// </summary>
/// <remarks>
/// <para>Only the first row of <c>_coefficients</c> and the first entry of <c>_intercepts</c> are
/// used, so a single decision function is evaluated.</para>
/// <para>The final value is obtained through <c>SvcUtilities.PrintCyprherText</c> with
/// <c>_decryptor</c>. The debugging constructor only creates that decryptor when a secret key is passed.
/// NOTE(review): confirm how PrintCyprherText behaves when the decryptor is null.</para>
/// <para>All SEAL plaintexts/ciphertexts allocated here are disposed before the method returns.
/// They hold native memory that the garbage collector does not see.</para>
/// </remarks>
/// <param name="features">Feature values of one sample, encoded starting at slot 0.</param>
/// <param name="useRelinearizeInplace">Relinearize after each multiplication.</param>
/// <param name="useReScale">Rescale after each multiplication and mod-switch plaintext operands
/// down to the level of the ciphertext they are combined with.</param>
/// <param name="finalResult">First decrypted slot of the final decision value.</param>
/// <returns>0 when the decrypted decision value is greater than 0, otherwise 1.</returns>
/// <exception cref="ArgumentNullException"><paramref name="features"/> is null.</exception>
public int Predict(double[] features, bool useRelinearizeInplace, bool useReScale, out double finalResult)
{
    if (features == null)
    {
        throw new ArgumentNullException(nameof(features));
    }

    Console.WriteLine();

    int numOfrowsCount = _vectors.Length;       // number of support vectors
    int numOfcolumnsCount = _vectors[0].Length; // number of features in every support vector
    double scale = Math.Pow(2.0, _power);

    // Per-support-vector SEAL objects; released in the finally block below.
    var svPlaintexts = new Plaintext[numOfrowsCount];
    var kernels = new Ciphertext[numOfrowsCount];
    var decisionsArr = new Ciphertext[numOfrowsCount];
    var coefArr = new Plaintext[numOfrowsCount];

    try
    {
        // ---- Client side: encode and encrypt the features ----
        using var featuresPlaintext = new Plaintext();
        using var featuresCiphertext = new Ciphertext();
        Stopwatch clientStopwatch = Stopwatch.StartNew();
        _encoder.Encode(features, scale, featuresPlaintext);
        _encryptor.Encrypt(featuresPlaintext, featuresCiphertext);
        SvcUtilities.PrintScale(featuresPlaintext, "featurePlaintext");
        SvcUtilities.PrintScale(featuresCiphertext, "featurefEncrypted");
        clientStopwatch.Stop();

        // ---- Server init: encode support vectors, allocate result buffers ----
        Stopwatch serverInitStopwatch = Stopwatch.StartNew();
        for (int i = 0; i < numOfrowsCount; i++)
        {
            svPlaintexts[i] = new Plaintext();
            _encoder.Encode(_vectors[i], scale, svPlaintexts[i]);
            SvcUtilities.PrintScale(svPlaintexts[i], "supportVectorsPlaintext" + i);

            kernels[i] = new Ciphertext();
            decisionsArr[i] = new Ciphertext();
            coefArr[i] = new Plaintext();
        }
        using var gamaPlaintext = new Plaintext();
        _encoder.Encode(_gamma, scale, gamaPlaintext);
        using var tempCt = new Ciphertext();
        serverInitStopwatch.Stop();

        Stopwatch innerProductStopwatch = new Stopwatch();
        Stopwatch negateStopwatch = new Stopwatch();
        Stopwatch degreeStopwatch = new Stopwatch();

        // Rotate-and-add steps with strides 1, 2, 4, ... that fold the first
        // 2^numOfRotations slots of the element-wise product into slot 0.
        int numOfRotations = (int)Math.Ceiling(Math.Log2(numOfcolumnsCount));

        // Level 1: kernel value for every support vector
        for (int i = 0; i < numOfrowsCount; i++)
        {
            // Inner product <features, sv_i>
            innerProductStopwatch.Start();
            _evaluator.MultiplyPlain(featuresCiphertext, svPlaintexts[i], kernels[i]);
            for (int k = 1, m = 1; m <= numOfRotations; k <<= 1, m++)
            {
                _evaluator.RotateVector(kernels[i], k, _galoisKeys, tempCt);
                _evaluator.AddInplace(kernels[i], tempCt);
            }
            innerProductStopwatch.Stop();

            SvcUtilities.PrintCyprherText(_decryptor, kernels[i], _encoder, $"inner product TotalValue {i}");
            SvcUtilities.PrintScale(kernels[i], "0. kernels" + i);
            if (useRelinearizeInplace)
            {
                _evaluator.RelinearizeInplace(kernels[i], _relinKeys);
            }
            if (useReScale)
            {
                _evaluator.RescaleToNextInplace(kernels[i]);
            }
            SvcUtilities.PrintScale(kernels[i], "1. kernels" + i);

            // The scale metadata is reset to 2^_power whatever the rescale produced.
            // NOTE(review): this overrides the scale SEAL tracked; confirm it is intended.
            kernels[i].Scale = scale;

            if (_kernel == Kernel.Poly)
            {
                // gamma * <x, sv>
                if (useReScale)
                {
                    _evaluator.ModSwitchToInplace(gamaPlaintext, kernels[i].ParmsId);
                }
                _evaluator.MultiplyPlainInplace(kernels[i], gamaPlaintext);
                SvcUtilities.PrintScale(kernels[i], "2. kernels" + i);
                if (useRelinearizeInplace)
                {
                    _evaluator.RelinearizeInplace(kernels[i], _relinKeys);
                }
                if (useReScale)
                {
                    _evaluator.RescaleToNextInplace(kernels[i]);
                }
                SvcUtilities.PrintScale(kernels[i], "3. kernels" + i);

                // + coef0, encoded at the ciphertext's current scale so the addition is valid
                if (Math.Abs(_coef0) > 0)
                {
                    using var coef0Plaintext = new Plaintext();
                    _encoder.Encode(_coef0, kernels[i].Scale, coef0Plaintext);
                    if (useReScale)
                    {
                        _evaluator.ModSwitchToInplace(coef0Plaintext, kernels[i].ParmsId);
                    }
                    _evaluator.AddPlainInplace(kernels[i], coef0Plaintext);
                }
                SvcUtilities.PrintScale(kernels[i], "4. kernels" + i);

                // Raise to _degree by multiplying (_degree - 1) times with a copy of the base value.
                degreeStopwatch.Start();
                using (var kernel = new Ciphertext(kernels[i]))
                {
                    for (int d = 0; d < (int)_degree - 1; d++)
                    {
                        // Align the base copy's scale and level with the running product.
                        kernel.Scale = kernels[i].Scale;
                        if (useReScale)
                        {
                            _evaluator.ModSwitchToInplace(kernel, kernels[i].ParmsId);
                        }
                        _evaluator.MultiplyInplace(kernels[i], kernel);
                        SvcUtilities.PrintScale(kernels[i], d + " 5. kernels" + i);
                        if (useRelinearizeInplace)
                        {
                            _evaluator.RelinearizeInplace(kernels[i], _relinKeys);
                        }
                        if (useReScale)
                        {
                            _evaluator.RescaleToNextInplace(kernels[i]);
                        }
                        SvcUtilities.PrintScale(kernels[i], d + " rescale 6. kernels" + i);
                    }
                }
                SvcUtilities.PrintScale(kernels[i], "7. kernels" + i);
                degreeStopwatch.Stop();
            }

            // The kernel value is negated for every kernel type before weighting.
            // NOTE(review): presumably matches the sign convention of the exported
            // coefficients/intercept; confirm against the training side.
            negateStopwatch.Start();
            _evaluator.NegateInplace(kernels[i]);
            negateStopwatch.Stop();
            SvcUtilities.PrintScale(kernels[i], "8. kernel" + i);
            SvcUtilities.PrintCyprherText(_decryptor, kernels[i], _encoder, "kernel" + i);
        }

        Stopwatch serverDecisionStopWatch = Stopwatch.StartNew();

        // Encode coefficients: scale (and, when rescaling, level) must match the kernels.
        double scale2 = useReScale ? kernels[0].Scale : scale;
        for (int i = 0; i < numOfrowsCount; i++)
        {
            _encoder.Encode(_coefficients[0][i], scale2, coefArr[i]);
            SvcUtilities.PrintScale(coefArr[i], "coefPlainText" + i);
        }
        if (useReScale)
        {
            for (int i = 0; i < numOfrowsCount; i++)
            {
                _evaluator.ModSwitchToInplace(coefArr[i], kernels[i].ParmsId);
            }
        }

        // Level 2: decision_i = coef_i * kernel_i
        for (int i = 0; i < numOfrowsCount; i++)
        {
            _evaluator.MultiplyPlain(kernels[i], coefArr[i], decisionsArr[i]);
            if (useRelinearizeInplace)
            {
                _evaluator.RelinearizeInplace(decisionsArr[i], _relinKeys);
            }
            if (useReScale)
            {
                _evaluator.RescaleToNextInplace(decisionsArr[i]);
            }
            SvcUtilities.PrintScale(decisionsArr[i], "decision" + i);
            SvcUtilities.PrintCyprherText(_decryptor, decisionsArr[i], _encoder, "decision" + i);
        }

        // decisionTotal = sum_i decision_i
        using var decisionTotal = new Ciphertext();
        _evaluator.AddMany(decisionsArr, decisionTotal);
        SvcUtilities.PrintScale(decisionTotal, "decisionTotal");
        SvcUtilities.PrintCyprherText(_decryptor, decisionTotal, _encoder, "decision total");

        // Encode the intercept: without rescaling the sum sits at scale 2^(3*_power);
        // with rescaling, use the sum's actual scale and level.
        using var interceptsPlainText = new Plaintext();
        double scale3 = useReScale ? decisionTotal.Scale : Math.Pow(2.0, _power * 3);
        _encoder.Encode(_intercepts[0], scale3, interceptsPlainText);
        if (useReScale)
        {
            _evaluator.ModSwitchToInplace(interceptsPlainText, decisionTotal.ParmsId);
        }
        SvcUtilities.PrintScale(interceptsPlainText, "interceptsPlainText");
        SvcUtilities.PrintScale(decisionTotal, "decisionTotal");

        _evaluator.AddPlainInplace(decisionTotal, interceptsPlainText);
        SvcUtilities.PrintScale(decisionTotal, "decisionTotal");

        // Level 3: decrypt; the decision stopwatch includes this final call.
        List<double> result = SvcUtilities.PrintCyprherText(_decryptor, decisionTotal, _encoder, "finalTotal", true);
        serverDecisionStopWatch.Stop();

        Console.WriteLine($"client Init elapsed {clientStopwatch.ElapsedMilliseconds} ms");
        Console.WriteLine($"server Init elapsed {serverInitStopwatch.ElapsedMilliseconds} ms");
        Console.WriteLine($"server innerProductStopwatch elapsed {innerProductStopwatch.ElapsedMilliseconds} ms");
        Console.WriteLine($"server negateStopwatch elapsed {negateStopwatch.ElapsedMilliseconds} ms");
        Console.WriteLine($"server degreeStopwatch elapsed {degreeStopwatch.ElapsedMilliseconds} ms");
        Console.WriteLine($"server Decision elapsed {serverDecisionStopWatch.ElapsedMilliseconds} ms");

        finalResult = result[0];
        return result[0] > 0 ? 0 : 1;
    }
    finally
    {
        // Release native SEAL memory deterministically instead of waiting for finalizers.
        for (int i = 0; i < numOfrowsCount; i++)
        {
            svPlaintexts[i]?.Dispose();
            kernels[i]?.Dispose();
            decisionsArr[i]?.Dispose();
            coefArr[i]?.Dispose();
        }
    }
}