/// <summary>
        /// Builds a <see cref="DescribeTrainingJobResponse"/> from the JSON body returned by the service.
        /// </summary>
        /// <param name="context">Reader over the JSON response body.</param>
        /// <returns>
        /// A <see cref="DescribeTrainingJobResponse"/> with every recognized member assigned; members that
        /// do not appear in the payload keep the values of a freshly constructed response.
        /// </returns>
        public override AmazonWebServiceResponse Unmarshall(JsonUnmarshallerContext context)
        {
            var result = new DescribeTrainingJobResponse();

            // Advance once and capture the depth at which member names are matched below.
            context.Read();
            int memberDepth = context.CurrentDepth;

            while (context.ReadAtDepth(memberDepth))
            {
                // Only the first matching branch runs per iteration; unrecognized members are skipped.
                if (context.TestExpression("AlgorithmSpecification", memberDepth))
                {
                    result.AlgorithmSpecification = AlgorithmSpecificationUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("CreationTime", memberDepth))
                {
                    result.CreationTime = DateTimeUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("FailureReason", memberDepth))
                {
                    result.FailureReason = StringUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("HyperParameters", memberDepth))
                {
                    result.HyperParameters = new DictionaryUnmarshaller<string, string, StringUnmarshaller, StringUnmarshaller>(StringUnmarshaller.Instance, StringUnmarshaller.Instance).Unmarshall(context);
                }
                else if (context.TestExpression("InputDataConfig", memberDepth))
                {
                    result.InputDataConfig = new ListUnmarshaller<Channel, ChannelUnmarshaller>(ChannelUnmarshaller.Instance).Unmarshall(context);
                }
                else if (context.TestExpression("LastModifiedTime", memberDepth))
                {
                    result.LastModifiedTime = DateTimeUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("ModelArtifacts", memberDepth))
                {
                    result.ModelArtifacts = ModelArtifactsUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("OutputDataConfig", memberDepth))
                {
                    result.OutputDataConfig = OutputDataConfigUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("ResourceConfig", memberDepth))
                {
                    result.ResourceConfig = ResourceConfigUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("RoleArn", memberDepth))
                {
                    result.RoleArn = StringUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("SecondaryStatus", memberDepth))
                {
                    result.SecondaryStatus = StringUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("SecondaryStatusTransitions", memberDepth))
                {
                    result.SecondaryStatusTransitions = new ListUnmarshaller<SecondaryStatusTransition, SecondaryStatusTransitionUnmarshaller>(SecondaryStatusTransitionUnmarshaller.Instance).Unmarshall(context);
                }
                else if (context.TestExpression("StoppingCondition", memberDepth))
                {
                    result.StoppingCondition = StoppingConditionUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("TrainingEndTime", memberDepth))
                {
                    result.TrainingEndTime = DateTimeUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("TrainingJobArn", memberDepth))
                {
                    result.TrainingJobArn = StringUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("TrainingJobName", memberDepth))
                {
                    result.TrainingJobName = StringUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("TrainingJobStatus", memberDepth))
                {
                    result.TrainingJobStatus = StringUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("TrainingStartTime", memberDepth))
                {
                    result.TrainingStartTime = DateTimeUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("TuningJobArn", memberDepth))
                {
                    result.TuningJobArn = StringUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("VpcConfig", memberDepth))
                {
                    result.VpcConfig = VpcConfigUnmarshaller.Instance.Unmarshall(context);
                }
            }

            return result;
        }
        /// <summary>
        /// Builds a <see cref="DescribeTrainingJobResponse"/> from the JSON body returned by the service.
        /// </summary>
        /// <param name="context">Reader over the JSON response body.</param>
        /// <returns>
        /// A <see cref="DescribeTrainingJobResponse"/> with every recognized member assigned; members that
        /// do not appear in the payload keep the values of a freshly constructed response.
        /// </returns>
        public override AmazonWebServiceResponse Unmarshall(JsonUnmarshallerContext context)
        {
            var result = new DescribeTrainingJobResponse();

            // Advance once and capture the depth at which member names are matched below.
            context.Read();
            int memberDepth = context.CurrentDepth;

            while (context.ReadAtDepth(memberDepth))
            {
                // Only the first matching branch runs per iteration; unrecognized members are skipped.
                if (context.TestExpression("AlgorithmSpecification", memberDepth))
                {
                    result.AlgorithmSpecification = AlgorithmSpecificationUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("AutoMLJobArn", memberDepth))
                {
                    result.AutoMLJobArn = StringUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("BillableTimeInSeconds", memberDepth))
                {
                    result.BillableTimeInSeconds = IntUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("CheckpointConfig", memberDepth))
                {
                    result.CheckpointConfig = CheckpointConfigUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("CreationTime", memberDepth))
                {
                    result.CreationTime = DateTimeUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("DebugHookConfig", memberDepth))
                {
                    result.DebugHookConfig = DebugHookConfigUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("DebugRuleConfigurations", memberDepth))
                {
                    result.DebugRuleConfigurations = new ListUnmarshaller<DebugRuleConfiguration, DebugRuleConfigurationUnmarshaller>(DebugRuleConfigurationUnmarshaller.Instance).Unmarshall(context);
                }
                else if (context.TestExpression("DebugRuleEvaluationStatuses", memberDepth))
                {
                    result.DebugRuleEvaluationStatuses = new ListUnmarshaller<DebugRuleEvaluationStatus, DebugRuleEvaluationStatusUnmarshaller>(DebugRuleEvaluationStatusUnmarshaller.Instance).Unmarshall(context);
                }
                else if (context.TestExpression("EnableInterContainerTrafficEncryption", memberDepth))
                {
                    result.EnableInterContainerTrafficEncryption = BoolUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("EnableManagedSpotTraining", memberDepth))
                {
                    result.EnableManagedSpotTraining = BoolUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("EnableNetworkIsolation", memberDepth))
                {
                    result.EnableNetworkIsolation = BoolUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("ExperimentConfig", memberDepth))
                {
                    result.ExperimentConfig = ExperimentConfigUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("FailureReason", memberDepth))
                {
                    result.FailureReason = StringUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("FinalMetricDataList", memberDepth))
                {
                    result.FinalMetricDataList = new ListUnmarshaller<MetricData, MetricDataUnmarshaller>(MetricDataUnmarshaller.Instance).Unmarshall(context);
                }
                else if (context.TestExpression("HyperParameters", memberDepth))
                {
                    result.HyperParameters = new DictionaryUnmarshaller<string, string, StringUnmarshaller, StringUnmarshaller>(StringUnmarshaller.Instance, StringUnmarshaller.Instance).Unmarshall(context);
                }
                else if (context.TestExpression("InputDataConfig", memberDepth))
                {
                    result.InputDataConfig = new ListUnmarshaller<Channel, ChannelUnmarshaller>(ChannelUnmarshaller.Instance).Unmarshall(context);
                }
                else if (context.TestExpression("LabelingJobArn", memberDepth))
                {
                    result.LabelingJobArn = StringUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("LastModifiedTime", memberDepth))
                {
                    result.LastModifiedTime = DateTimeUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("ModelArtifacts", memberDepth))
                {
                    result.ModelArtifacts = ModelArtifactsUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("OutputDataConfig", memberDepth))
                {
                    result.OutputDataConfig = OutputDataConfigUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("ResourceConfig", memberDepth))
                {
                    result.ResourceConfig = ResourceConfigUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("RoleArn", memberDepth))
                {
                    result.RoleArn = StringUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("SecondaryStatus", memberDepth))
                {
                    result.SecondaryStatus = StringUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("SecondaryStatusTransitions", memberDepth))
                {
                    result.SecondaryStatusTransitions = new ListUnmarshaller<SecondaryStatusTransition, SecondaryStatusTransitionUnmarshaller>(SecondaryStatusTransitionUnmarshaller.Instance).Unmarshall(context);
                }
                else if (context.TestExpression("StoppingCondition", memberDepth))
                {
                    result.StoppingCondition = StoppingConditionUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("TensorBoardOutputConfig", memberDepth))
                {
                    result.TensorBoardOutputConfig = TensorBoardOutputConfigUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("TrainingEndTime", memberDepth))
                {
                    result.TrainingEndTime = DateTimeUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("TrainingJobArn", memberDepth))
                {
                    result.TrainingJobArn = StringUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("TrainingJobName", memberDepth))
                {
                    result.TrainingJobName = StringUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("TrainingJobStatus", memberDepth))
                {
                    result.TrainingJobStatus = StringUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("TrainingStartTime", memberDepth))
                {
                    result.TrainingStartTime = DateTimeUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("TrainingTimeInSeconds", memberDepth))
                {
                    result.TrainingTimeInSeconds = IntUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("TuningJobArn", memberDepth))
                {
                    result.TuningJobArn = StringUnmarshaller.Instance.Unmarshall(context);
                }
                else if (context.TestExpression("VpcConfig", memberDepth))
                {
                    result.VpcConfig = VpcConfigUnmarshaller.Instance.Unmarshall(context);
                }
            }

            return result;
        }
Пример #3
0
        /// <summary>
        /// SageMaker walkthrough: starts a training job on CSV channels stored in S3 and waits for it.
        /// If training completes, it creates a model, an endpoint configuration and an endpoint,
        /// sends the test CSV from S3 to the endpoint, and logs the response. Finally it deletes
        /// whichever of those hosting resources were created.
        /// </summary>
        /// <param name="args">Unused.</param>
        static void Main(string[] args)
        {
            // NOTE(review): placeholders to fill in before running. Do not commit real keys; the SDK's
            // default credential chain (environment, shared profile, instance role) avoids secrets in source.
            string accessKeyId     = "";
            string accessKeySecret = "";
            string trainingImage   = "";
            string roleArn         = "";

            var region = Amazon.RegionEndpoint.EUCentral1;

            string trainingJobName    = $"ontology-training-job-{DateTime.UtcNow.Ticks}";
            string endpointName       = $"ontology-endpoint-{DateTime.UtcNow.Ticks}";
            string modelName          = $"{trainingJobName}-model";
            string endpointConfigName = $"{endpointName}-config";

            using (AmazonS3Client s3Client = new AmazonS3Client(accessKeyId, accessKeySecret, region))
            using (AmazonSageMakerClient sageMakerClient = new AmazonSageMakerClient(accessKeyId, accessKeySecret, region))
            {
                // TODO: upload the csv train/validation/test files (e.g. with TransferUtility) before training.

                // GetAwaiter().GetResult() instead of .Result so SDK failures surface as their own
                // exception type rather than wrapped in an AggregateException.
                sageMakerClient.CreateTrainingJobAsync(BuildTrainingJobRequest(trainingJobName, trainingImage, roleArn))
                               .GetAwaiter().GetResult();

                DescribeTrainingJobResponse info = WaitForTrainingJob(sageMakerClient, trainingJobName);

                Logger.Info($"Training job creation has been finished. With status {info.TrainingJobStatus}. {info.FailureReason}");

                if (info.TrainingJobStatus == TrainingJobStatus.Completed)
                {
                    // Track what actually got created so the finally block removes exactly that.
                    bool modelCreated          = false;
                    bool endpointConfigCreated = false;
                    bool endpointCreated       = false;

                    try
                    {
                        sageMakerClient.CreateModelAsync(new CreateModelRequest()
                        {
                            ModelName        = modelName,
                            ExecutionRoleArn = roleArn,
                            PrimaryContainer = new ContainerDefinition()
                            {
                                ModelDataUrl = info.ModelArtifacts.S3ModelArtifacts,
                                Image        = trainingImage
                            }
                        }).GetAwaiter().GetResult();
                        modelCreated = true;

                        // Wait for this call so the config exists before CreateEndpoint references it.
                        sageMakerClient.CreateEndpointConfigAsync(new CreateEndpointConfigRequest()
                        {
                            EndpointConfigName = endpointConfigName,
                            ProductionVariants = new List<ProductionVariant>()
                            {
                                new ProductionVariant()
                                {
                                    InstanceType         = ProductionVariantInstanceType.MlM4Xlarge,
                                    InitialVariantWeight = 1,
                                    InitialInstanceCount = 1,
                                    ModelName            = modelName,
                                    VariantName          = "AllTraffic"
                                }
                            }
                        }).GetAwaiter().GetResult();
                        endpointConfigCreated = true;

                        sageMakerClient.CreateEndpointAsync(new CreateEndpointRequest()
                        {
                            EndpointConfigName = endpointConfigName,
                            EndpointName       = endpointName
                        }).GetAwaiter().GetResult();
                        endpointCreated = true;

                        EndpointStatus endpointStatus = WaitForEndpoint(sageMakerClient, endpointName);

                        Logger.Info("Endpoint creation has been finished.");

                        if (endpointStatus == EndpointStatus.InService)
                        {
                            using (AmazonSageMakerRuntimeClient runtimeClient = new AmazonSageMakerRuntimeClient(accessKeyId, accessKeySecret, region))
                            {
                                Logger.Info(InvokeEndpointWithS3Csv(s3Client, runtimeClient, endpointName));
                            }
                        }
                    }
                    finally
                    {
                        // Runs even when deployment or invocation fails, so partially created hosting
                        // resources are not left behind. Deletion goes in reverse creation order.
                        Logger.Info("Performing clean up...");

                        if (endpointCreated)
                        {
                            BestEffort("delete endpoint", () => sageMakerClient.DeleteEndpointAsync(new DeleteEndpointRequest()
                            {
                                EndpointName = endpointName
                            }).GetAwaiter().GetResult());
                        }

                        if (endpointConfigCreated)
                        {
                            BestEffort("delete endpoint config", () => sageMakerClient.DeleteEndpointConfigAsync(new DeleteEndpointConfigRequest()
                            {
                                EndpointConfigName = endpointConfigName
                            }).GetAwaiter().GetResult());
                        }

                        if (modelCreated)
                        {
                            BestEffort("delete model", () => sageMakerClient.DeleteModelAsync(new DeleteModelRequest()
                            {
                                ModelName = modelName
                            }).GetAwaiter().GetResult());
                        }

                        Logger.Info("Clean up finished.");
                    }
                }
            }

            Console.ReadLine();
        }

        /// <summary>Delay between status polls of the training job and the endpoint.</summary>
        private static readonly TimeSpan PollInterval = TimeSpan.FromSeconds(10);

        /// <summary>
        /// Builds the CreateTrainingJob request: File input mode, one ml.m4.xlarge instance,
        /// the multi-class XGBoost-style hyper-parameters, and CSV "train"/"validation" channels.
        /// </summary>
        private static CreateTrainingJobRequest BuildTrainingJobRequest(string trainingJobName, string trainingImage, string roleArn)
        {
            // NOTE(review): the SageMaker docs show S3 locations as s3://bucket/prefix; confirm that the
            // https:// form below is accepted for S3OutputPath / S3Uri.
            return new CreateTrainingJobRequest()
            {
                AlgorithmSpecification = new AlgorithmSpecification()
                {
                    TrainingInputMode = TrainingInputMode.File,
                    TrainingImage     = trainingImage
                },
                OutputDataConfig = new OutputDataConfig()
                {
                    S3OutputPath = "https://s3.eu-central-1.amazonaws.com/sagemaker-ovechko/sagemaker/test-csv/output"
                },
                ResourceConfig = new ResourceConfig()
                {
                    InstanceCount  = 1,
                    InstanceType   = TrainingInstanceType.MlM4Xlarge,
                    VolumeSizeInGB = 5
                },
                TrainingJobName = trainingJobName,
                HyperParameters = new Dictionary<string, string>()
                {
                    { "eta", "0.1" },
                    { "objective", "multi:softmax" },
                    { "num_round", "5" },
                    { "num_class", "3" }
                },
                StoppingCondition = new StoppingCondition()
                {
                    MaxRuntimeInSeconds = 3600
                },
                RoleArn         = roleArn,
                InputDataConfig = new List<Channel>()
                {
                    CsvChannel("train", "https://s3.eu-central-1.amazonaws.com/sagemaker-ovechko/sagemaker/test-csv/train/"),
                    CsvChannel("validation", "https://s3.eu-central-1.amazonaws.com/sagemaker-ovechko/sagemaker/test-csv/validation/")
                }
            };
        }

        /// <summary>Creates an uncompressed, fully replicated S3-prefix channel with content type "csv".</summary>
        private static Channel CsvChannel(string channelName, string s3Uri)
        {
            return new Channel()
            {
                ChannelName = channelName,
                DataSource  = new DataSource()
                {
                    S3DataSource = new S3DataSource()
                    {
                        S3DataType             = S3DataType.S3Prefix,
                        S3Uri                  = s3Uri,
                        S3DataDistributionType = S3DataDistribution.FullyReplicated
                    }
                },
                ContentType     = "csv",
                CompressionType = Amazon.SageMaker.CompressionType.None
            };
        }

        /// <summary>
        /// Polls the training job every <see cref="PollInterval"/> until it leaves InProgress/Stopping,
        /// then returns the final description.
        /// </summary>
        private static DescribeTrainingJobResponse WaitForTrainingJob(AmazonSageMakerClient client, string trainingJobName)
        {
            while (true)
            {
                DescribeTrainingJobResponse info = client.DescribeTrainingJobAsync(new DescribeTrainingJobRequest()
                {
                    TrainingJobName = trainingJobName
                }).GetAwaiter().GetResult();

                // Stopping is transitional (it ends in Stopped), so keep polling through it as well.
                bool stillRunning = info.TrainingJobStatus == TrainingJobStatus.InProgress
                                 || info.TrainingJobStatus == TrainingJobStatus.Stopping;
                if (!stillRunning)
                {
                    return info;
                }

                Logger.Info("Training job creation is in progress...");
                Thread.Sleep(PollInterval);
            }
        }

        /// <summary>
        /// Polls the endpoint every <see cref="PollInterval"/> while it is Creating and returns the first
        /// other status. Logs the service's failure reason when that status is not InService.
        /// </summary>
        private static EndpointStatus WaitForEndpoint(AmazonSageMakerClient client, string endpointName)
        {
            while (true)
            {
                var description = client.DescribeEndpointAsync(new DescribeEndpointRequest()
                {
                    EndpointName = endpointName
                }).GetAwaiter().GetResult();

                if (description.EndpointStatus != EndpointStatus.Creating)
                {
                    if (description.EndpointStatus != EndpointStatus.InService)
                    {
                        Logger.Info($"Endpoint {endpointName} finished creation with status {description.EndpointStatus}. {description.FailureReason}");
                    }

                    return description.EndpointStatus;
                }

                Logger.Info("Endpoint creation is in progress...");
                Thread.Sleep(PollInterval);
            }
        }

        /// <summary>
        /// Downloads the test CSV from S3, posts it to the endpoint as text/csv, and returns the
        /// response body as text.
        /// </summary>
        private static string InvokeEndpointWithS3Csv(AmazonS3Client s3Client, AmazonSageMakerRuntimeClient runtimeClient, string endpointName)
        {
            string csv;

            // GetObjectResponse owns the network stream, so dispose it once the text has been read.
            using (GetObjectResponse s3Response = s3Client.GetObjectAsync("sagemaker-ovechko", "sagemaker/test-csv/test/test.csv").GetAwaiter().GetResult())
            using (StreamReader reader = new StreamReader(s3Response.ResponseStream))
            {
                csv = reader.ReadToEnd();
            }

            // An empty search string makes string.Replace throw ArgumentException, so the character to
            // strip is spelled out. NOTE(review): a byte-order mark (U+FEFF) is the presumed target — confirm.
            csv = csv.Replace("\uFEFF", string.Empty);

            // UTF-8 rather than ASCII so non-ASCII characters are not silently turned into '?'.
            using (MemoryStream requestBody = new MemoryStream(Encoding.UTF8.GetBytes(csv)))
            {
                InvokeEndpointResponse endpointResponse = runtimeClient.InvokeEndpointAsync(new InvokeEndpointRequest()
                {
                    ContentType  = "text/csv",
                    EndpointName = endpointName,
                    Body         = requestBody,
                }).GetAwaiter().GetResult();

                using (StreamReader responseReader = new StreamReader(endpointResponse.Body))
                {
                    return responseReader.ReadToEnd();
                }
            }
        }

        /// <summary>
        /// Runs one clean-up step and logs, rather than propagates, any failure so the remaining
        /// steps still run and an exception already in flight is not masked.
        /// </summary>
        private static void BestEffort(string stepName, Action step)
        {
            try
            {
                step();
            }
            catch (Exception ex)
            {
                // Broad catch is deliberate: this runs inside a finally block, where rethrowing would
                // hide the original failure and skip the remaining deletions.
                Logger.Info($"Clean up step '{stepName}' failed: {ex.Message}");
            }
        }