Example #1
0
        /// <summary>
        /// Loads a <see cref="LinearBinaryPredictor"/> from <paramref name="ctx"/> and, if a
        /// calibrator sub-model named "Calibrator" is loaded as well, wraps the predictor with it.
        /// </summary>
        /// <param name="env">Host environment; checked to be non-null.</param>
        /// <param name="ctx">Model load context; checked to be non-null and at this model's version.</param>
        /// <returns>
        /// The bare linear predictor when no calibrator is loaded; a
        /// <see cref="ParameterMixingCalibratedPredictor"/> when the calibrator implements
        /// <see cref="IParameterMixer"/>; otherwise a <see cref="SchemaBindableCalibratedPredictor"/>.
        /// </returns>
        public static IPredictorProducing<Float> Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            var linear = new LinearBinaryPredictor(env, ctx);

            // The calibrator is optional; a null result means the uncalibrated predictor is returned.
            ICalibrator cal;
            ctx.LoadModelOrNull<ICalibrator, SignatureLoadModel>(env, out cal, @"Calibrator");

            if (cal != null)
            {
                if (cal is IParameterMixer)
                {
                    return new ParameterMixingCalibratedPredictor(env, linear, cal);
                }
                return new SchemaBindableCalibratedPredictor(env, linear, cal);
            }

            return linear;
        }
Example #2
0
        /// <summary>
        /// Attempts to load a predictor from the model file named by <paramref name="inputModelFile"/>.
        /// </summary>
        /// <param name="ch">Channel used for trace messages and handed to the repository reader.</param>
        /// <param name="env">Host environment used to open the file and load the model.</param>
        /// <param name="inputModelFile">Path of the model file; null or empty means there is nothing to load.</param>
        /// <param name="inputPredictor">
        /// Set to null when <paramref name="inputModelFile"/> is null or empty; otherwise the value
        /// produced by <c>ModelLoadContext.LoadModelOrNull</c>.
        /// </param>
        /// <returns>
        /// False when no file name is given; otherwise the result of <c>ModelLoadContext.LoadModelOrNull</c>.
        /// </returns>
        public static bool TryLoadPredictor(IChannel ch, IHostEnvironment env, string inputModelFile, out IPredictor inputPredictor)
        {
            Contracts.AssertValue(env);
            Contracts.AssertValue(ch);

            // No model file supplied: report failure without touching the file system.
            if (string.IsNullOrEmpty(inputModelFile))
            {
                inputPredictor = null;
                return false;
            }

            ch.Trace("Constructing predictor from input model");
            using (var modelFile = env.OpenInputFile(inputModelFile))
            {
                using (var stream = modelFile.OpenReadStream())
                {
                    using (var repository = RepositoryReader.Open(stream, ch))
                    {
                        ch.Trace("Loading predictor");
                        return ModelLoadContext.LoadModelOrNull<IPredictor, SignatureLoadModel>(
                            env, out inputPredictor, repository, ModelFileUtils.DirPredictor);
                    }
                }
            }
        }
            /// <summary>
            /// Deserializes a transform-info record from <paramref name="ctx"/>. Reads the dimensions,
            /// the UseSin flag, and the random-generator seeds, then loads the distribution sampler
            /// sub-model, allocates the coefficient buffers, and calls InitializeFourierCoefficients.
            /// </summary>
            /// <param name="env">Host environment used for assertions and decode checks.</param>
            /// <param name="ctx">Model load context whose reader and repository are consumed.</param>
            /// <param name="colValueCount">Expected number of untransformed features; must be positive.</param>
            /// <param name="directoryName">Sub-model name under which the sampler is stored.</param>
            public TransformInfo(IHostEnvironment env, ModelLoadContext ctx, int colValueCount, string directoryName)
            {
                // Validate env through the static helper: env.AssertValue(env) would dereference env
                // before checking it, so that assertion could never report a null env.
                Contracts.AssertValue(env);
                env.AssertValue(ctx);
                env.Assert(colValueCount > 0);

                // *** Binary format ***
                // int: SrcDim (number of untransformed features; must equal colValueCount)
                // int: NewDim (number of transformed features; must be positive)
                // bool (byte): UseSin
                // int: number of seeds (must be 4)
                // uint[4]: the seeds for the pseudo random number generator.
                // An IFourierDistributionSampler sub-model is then loaded from directoryName.

                SrcDim = ctx.Reader.ReadInt32();
                env.CheckDecode(SrcDim == colValueCount);

                NewDim = ctx.Reader.ReadInt32();
                env.CheckDecode(NewDim > 0);

                _useSin = ctx.Reader.ReadBoolByte();

                int seedCount = ctx.Reader.ReadInt32();
                env.CheckDecode(seedCount == 4);
                _state = TauswortheHybrid.State.Load(ctx.Reader);
                _rand = new TauswortheHybrid(_state);

                // The sampler is required: decoding fails if there is no repository or it cannot be loaded.
                env.CheckDecode(ctx.Repository != null &&
                    ctx.LoadModelOrNull<IFourierDistributionSampler, SignatureLoadModel>(env, out _matrixGenerator, directoryName));

                // Initialize the transform matrix. Both dimensions are rounded up with _cfltAlign,
                // presumably so the AlignedArray buffers suit vectorized CpuMath routines.
                int roundedUpNewDim = RoundUp(NewDim, _cfltAlign);
                int roundedUpSrcDim = RoundUp(SrcDim, _cfltAlign);

                RndFourierVectors = new AlignedArray(roundedUpNewDim * roundedUpSrcDim, CpuMathUtils.GetVectorAlignment());
                // RotationTerms is allocated only when _useSin is false.
                RotationTerms = _useSin ? null : new AlignedArray(roundedUpNewDim, CpuMathUtils.GetVectorAlignment());
                InitializeFourierCoefficients(roundedUpSrcDim, roundedUpNewDim);
            }