Example #1
0
        /// <summary>
        /// Queues removal of every HTML comment found in an HTML response body,
        /// except Knockout.js and IE conditional comments, which are left in place.
        /// </summary>
        /// <param name="context">Response analysis context that supplies the body and collects text changes.</param>
        /// <param name="cancellationToken">Not used by this inspector.</param>
        public void Inspect(ResponseAnalysisContext context, CancellationToken cancellationToken)
        {
            if (!context.IsHtml())
            {
                return;
            }

            var html = context.ReadAsString();
            if (html == null)
            {
                return;
            }

            foreach (var comment in FastHtmlParser.FindAllComments(html))
            {
                // Knockout.js comment markers carry meaning and must survive
                var isKnockout = comment.StartsWith("<!-- ko", StringComparison.Ordinal)
                                 || comment.StartsWith("<!-- /ko", StringComparison.Ordinal);

                // IE conditional comments carry meaning as well
                var isIeConditional = comment.StartsWith("<!--[if ", StringComparison.Ordinal)
                                      || comment.StartsWith("[endif]-->", StringComparison.Ordinal);

                if (!isKnockout && !isIeConditional)
                {
                    context.AddChange(TextChange.Remove(html, comment.Offset, comment.Length));
                }
            }
        }
Example #2
0
 /// <summary>
 /// Decompresses the response body according to the <c>Content-Encoding</c> header.
 /// Supports <c>gzip</c>, <c>deflate</c> and <c>br</c>; any other coding is skipped.
 /// </summary>
 /// <param name="context">Response analysis context whose body is replaced by the decoded stream.</param>
 /// <param name="cancellationToken">Not used by this inspector.</param>
 public void Inspect(ResponseAnalysisContext context, CancellationToken cancellationToken)
 {
     if (!context.Response.Headers.TryGetValue("Content-Encoding", out var contentEncoding))
     {
         return;
     }

     // The header lists codings in the order they were applied, so they have to be
     // undone from the last one to the first one.
     foreach (var token in contentEncoding.SelectMany(v => new StringTokenizer(v, Separators)).Reverse())
     {
         // tolerate optional whitespace around list separators, e.g. "gzip, br"
         var encoding = token.Trim();

         Stream decoder;
         if (encoding.Equals("gzip", StringComparison.OrdinalIgnoreCase))
         {
             decoder = new GZipStream(context.ReadAsStream().Rewind(), CompressionMode.Decompress, leaveOpen: true);
         }
         else if (encoding.Equals("deflate", StringComparison.OrdinalIgnoreCase))
         {
             decoder = new DeflateStream(context.ReadAsStream().Rewind(), CompressionMode.Decompress, leaveOpen: true);
         }
         else if (encoding.Equals("br", StringComparison.OrdinalIgnoreCase))
         {
             decoder = new BrotliStream(context.ReadAsStream().Rewind(), CompressionMode.Decompress, leaveOpen: true);
         }
         else
         {
             // unknown coding: left untouched, as before
             continue;
         }

         using (decoder)
         {
             // the decoded copy becomes the new body; the source stream stays open (leaveOpen)
             var buffer = new MemoryStream();
             decoder.CopyTo(buffer);
             context.SetBodyFromStream(buffer.Rewind());
         }
     }
 }
        /// <summary>
        /// Queues an inserted "s" after every qualifying "http" so the reference becomes "https:".
        /// In HTML only matches inside <c>src</c>, <c>href</c> and <c>srcset</c> attributes qualify;
        /// in CSS every match qualifies.
        /// </summary>
        /// <param name="context">Response analysis context that supplies the body and collects text changes.</param>
        private static void InspectBody(ResponseAnalysisContext context)
        {
            var text = context.ReadAsString();
            if (text == null)
            {
                return;
            }

            // the 's' goes right behind the "http" part of each match
            var schemeLength = "http".Length;

            // rewrite HTML
            if (context.IsHtml())
            {
                foreach (var match in text.AllIndexesOf("http:", StringComparison.OrdinalIgnoreCase))
                {
                    // only URL-carrying attributes are touched; others such as 'xmlns' are excluded
                    var attribute = FastHtmlParser.GetAttributeNameFromValueIndex(text, match);
                    var isUrlAttribute = attribute.Equals("src", StringComparison.OrdinalIgnoreCase)
                                         || attribute.Equals("href", StringComparison.OrdinalIgnoreCase)
                                         || attribute.Equals("srcset", StringComparison.OrdinalIgnoreCase);
                    if (!isUrlAttribute)
                    {
                        continue;
                    }

                    context.AddChange(TextChange.Insert(text, match + schemeLength, "s"));
                }
            }

            // rewrite CSS
            if (context.IsCss())
            {
                foreach (var match in text.AllIndexesOf("http:", StringComparison.OrdinalIgnoreCase))
                {
                    context.AddChange(TextChange.Insert(text, match + schemeLength, "s"));
                }
            }
        }
        /// <summary>
        /// Queues insertion of an antiforgery hidden input into every form of an HTML response
        /// whose <c>method</c> attribute is "post" (case-insensitive).
        /// </summary>
        /// <param name="context">Response analysis context that supplies the body and collects text changes.</param>
        /// <param name="cancellationToken">Not used by this inspector.</param>
        public void Inspect(ResponseAnalysisContext context, CancellationToken cancellationToken)
        {
            if (!context.IsHtml())
            {
                return;
            }

            // read content
            string markup = context.ReadAsString();
            if (markup == null)
            {
                return;
            }

            // built lazily so tokens are only generated when a POST form exists
            string hiddenField = null;

            foreach (var tagIndex in FastHtmlParser.FindAllTagIndexes(markup, "form"))
            {
                var verb = FastHtmlParser.GetAttributeValueAtTag(markup, "method", tagIndex);
                if (!verb.Equals("post", StringComparison.OrdinalIgnoreCase))
                {
                    continue;
                }

                if (hiddenField == null)
                {
                    hiddenField = CreateHiddenField();
                }

                // insert field at the position reported for the form's closing pair
                int closeIndex = FastHtmlParser.FindClosePairFlatIndex(markup, "form", tagIndex);
                context.AddChange(TextChange.Insert(markup, closeIndex, hiddenField));
            }

            // Generates and stores the tokens, then renders them as a hidden input element.
            string CreateHiddenField()
            {
                var tokens = Antiforgery.GetAndStoreTokens(context.HttpContext);
                return $"<input type=\"hidden\" name=\"{HtmlEncoder.Default.Encode(tokens.FormFieldName)}\" value=\"{HtmlEncoder.Default.Encode(tokens.RequestToken)}\" />";
            }
        }
Example #5
0
 /// <summary>
 /// Runs the body inspection for HTML responses; other responses are ignored.
 /// </summary>
 /// <param name="context">Response analysis context to inspect.</param>
 /// <param name="cancellationToken">Not used by this inspector.</param>
 public void Inspect(ResponseAnalysisContext context, CancellationToken cancellationToken)
 {
     if (!context.IsHtml())
     {
         return;
     }

     InspectBody(context);
 }
Example #6
0
        /// <summary>
        /// Reports a diagnostic for each query-string value that <c>ShouldEncode</c> flags
        /// and that appears verbatim (ordinal match) in the HTML response body.
        /// </summary>
        /// <param name="context">Response analysis context that supplies the body and collects diagnostics.</param>
        /// <param name="cancellationToken">Not used by this inspector.</param>
        public void Inspect(ResponseAnalysisContext context, CancellationToken cancellationToken)
        {
            if (!context.IsHtml())
            {
                return;
            }

            var markup = context.ReadAsString();
            if (markup == null)
            {
                return;
            }

            foreach (var parameter in context.Response.HttpContext.Request.Query)
            {
                foreach (var raw in parameter.Value)
                {
                    // only values that need encoding and are echoed unchanged are of interest
                    if (!ShouldEncode(raw) || !markup.Contains(raw, StringComparison.Ordinal))
                    {
                        continue;
                    }

                    context.ReportDiagnostic(new Diagnostic(Rule, Location.QueryString(parameter.Key)));
                }
            }
        }
        /// <summary>
        /// Always inspects the response headers; additionally inspects the body
        /// when the response is HTML, CSS or JavaScript.
        /// </summary>
        /// <param name="context">Response analysis context to inspect.</param>
        /// <param name="cancellationToken">Not used by this inspector.</param>
        public void Inspect(ResponseAnalysisContext context, CancellationToken cancellationToken)
        {
            InspectHeader(context.Response);

            if (!context.IsHtmlOrCssOrJs())
            {
                return;
            }

            InspectBody(context);
        }
Example #8
0
 /// <summary>
 /// Rewrites the <c>Set-Cookie</c> header through <c>MakeSecure</c> when at least one
 /// of its values fails the <c>IsSecure</c> check.
 /// </summary>
 /// <param name="context">Response analysis context whose response headers are updated.</param>
 /// <param name="cancellationToken">Not used by this inspector.</param>
 public void Inspect(ResponseAnalysisContext context, CancellationToken cancellationToken)
 {
     if (!context.Response.Headers.TryGetValue("Set-Cookie", out var values))
     {
         return;
     }

     // nothing to do when every cookie already passes the check
     if (values.All(v => IsSecure(v)))
     {
         return;
     }

     var secured = values.Select(MakeSecure).ToArray();
     context.Response.Headers["Set-Cookie"] = new StringValues(secured);
 }
Example #9
0
 /// <summary>
 /// Runs the synchronous response inspectors in order. Stops after the first inspector
 /// that leaves the context malicious when the configured depth is <c>FindFirst</c>.
 /// </summary>
 /// <param name="context">Response analysis context shared by all inspectors.</param>
 /// <param name="cancellationToken">Token forwarded to each inspector.</param>
 protected void AnalyzeResponse(ResponseAnalysisContext context, CancellationToken cancellationToken)
 {
     foreach (var inspector in ResponseInspectors)
     {
         inspector.Inspect(context, cancellationToken);

         var stopEarly = context.IsMalicious && Options.Value.Depth == AnalysisDepth.FindFirst;
         if (stopEarly)
         {
             break;
         }
     }
 }
Example #10
0
        /// <summary>
        /// Awaits the asynchronous response inspectors one after another. Stops after the first
        /// inspector that leaves the context malicious when the configured depth is <c>FindFirst</c>.
        /// </summary>
        /// <param name="context">Response analysis context shared by all inspectors.</param>
        /// <param name="cancellationToken">Token forwarded to each inspector.</param>
        protected async Task AnalyzeResponseAsync(ResponseAnalysisContext context, CancellationToken cancellationToken)
        {
            foreach (var inspector in AsyncResponseInspectors)
            {
                await inspector.InspectAsync(context, cancellationToken);

                var stopEarly = context.IsMalicious && Options.Value.Depth == AnalysisDepth.FindFirst;
                if (stopEarly)
                {
                    break;
                }
            }
        }
Example #11
0
        /// <summary>
        /// Queues removal of every <c>meta</c> tag whose <c>name</c> attribute equals
        /// "generator" (case-insensitive).
        /// </summary>
        /// <param name="context">Response analysis context that supplies the body and collects text changes.</param>
        private static void InspectBody(ResponseAnalysisContext context)
        {
            var markup = context.ReadAsString();
            if (markup == null)
            {
                return;
            }

            foreach (var tagStart in FastHtmlParser.FindAllTagIndexes(markup, "meta"))
            {
                var metaName = FastHtmlParser.GetAttributeValueAtTag(markup, "name", tagStart);
                if (!metaName.Equals("generator", StringComparison.OrdinalIgnoreCase))
                {
                    continue;
                }

                // remove from the tag start through the position after the end of the open tag
                var tagEnd = FastHtmlParser.FindEndOfOpenTag(markup, tagStart) + 1;
                context.AddChange(TextChange.Remove(markup, tagStart, tagEnd - tagStart));
            }
        }
Example #12
0
        /// <summary>
        /// Scans a text-like response body for the configured <c>ResponseBody</c> terms and reports
        /// a diagnostic for each term found. Stops after the first hit when the configured depth
        /// is <c>FindFirst</c>.
        /// </summary>
        /// <param name="context">Response analysis context that supplies the body and collects diagnostics.</param>
        /// <param name="cancellationToken">Not used by this inspector.</param>
        public void Inspect(ResponseAnalysisContext context, CancellationToken cancellationToken)
        {
            if (!context.IsTextLike())
            {
                return;
            }

            var content = context.ReadAsString();
            if (content == null)
            {
                // no readable body; without this guard the term lookup below would throw
                return;
            }

            foreach (var found in ResponseBody.Where(b => content.Contains(b.Term)))
            {
                context.ReportDiagnostic(new Diagnostic(Rule.With(found), Location.ResponseBody));

                if (Options.CurrentValue.Depth == AnalysisDepth.FindFirst)
                {
                    return;
                }
            }
        }
Example #13
0
        /// <summary>
        /// Looks for session-like cookies (name contains "sess") in the <c>Set-Cookie</c> header whose
        /// value contains a value taken from the request query string, and reports the first match.
        /// </summary>
        /// <param name="context">Response analysis context that supplies headers, the request and collects diagnostics.</param>
        /// <param name="cancellationToken">Not used by this inspector.</param>
        public void Inspect(ResponseAnalysisContext context, CancellationToken cancellationToken)
        {
            if (!context.Response.Headers.TryGetValue("Set-Cookie", out var setCookies))
            {
                return;
            }

            // inspect cookies
            foreach (var setCookie in setCookies)
            {
                if (string.IsNullOrEmpty(setCookie))
                {
                    continue;
                }

                // parse name; a header without '=' has no name/value pair to inspect
                int delimiterIndex = setCookie.IndexOf('=');
                if (delimiterIndex < 0)
                {
                    continue;
                }

                var name = new StringSegment(setCookie, 0, delimiterIndex);
                if (!name.Contains("sess", StringComparison.OrdinalIgnoreCase))
                {
                    continue;
                }

                // parse value: everything after '=' up to the first ';' (or the end of the header)
                int valueStart     = delimiterIndex + 1;
                int semicolonIndex = setCookie.IndexOf(';', valueStart);
                int valueEnd       = semicolonIndex != -1 ? semicolonIndex : setCookie.Length;
                var value          = new StringSegment(setCookie, valueStart, valueEnd - valueStart);
                if (value.Length < MinimumCookieLength)
                {
                    continue;
                }

                // match to query string
                foreach (var query in context.HttpContext.Request.Query)
                {
                    foreach (var queryValue in query.Value)
                    {
                        // assumed: an empty value would match any cookie trivially, so it is skipped
                        if (string.IsNullOrEmpty(queryValue))
                        {
                            continue;
                        }

                        if (value.Contains(queryValue))
                        {
                            context.ReportDiagnostic(new Diagnostic(Rule, Location.QueryString(query.Key)));
                            return;
                        }
                    }
                }
            }
        }
        /// <summary>
        /// Removes the <c>Server</c>, <c>Via</c>, <c>X-Generator</c> and <c>X-Powered-By</c>
        /// response headers when they are present.
        /// </summary>
        /// <param name="context">Response analysis context whose response headers are updated.</param>
        /// <param name="cancellationToken">Not used by this inspector.</param>
        public void Inspect(ResponseAnalysisContext context, CancellationToken cancellationToken)
        {
            var headers = context.Response.Headers;

            // processed in the same order as listed
            foreach (var headerName in new[] { "Server", "Via", "X-Generator", "X-Powered-By" })
            {
                if (headers.ContainsKey(headerName))
                {
                    headers.Remove(headerName);
                }
            }
        }
Example #15
0
        /// <summary>
        /// Middleware entry point. Analyzes the request, runs the rest of the pipeline against a
        /// temporary response buffer, analyzes the buffered response and finally writes either the
        /// (possibly modified) response or a JSON denial to the real response stream.
        /// </summary>
        /// <param name="context">Current HTTP context.</param>
        /// <param name="next">Next delegate in the pipeline.</param>
        public async Task InvokeAsync(HttpContext context, RequestDelegate next)
        {
            var options = Options.Value;

            context.Request.EnableBuffering();

            // analyze request
            var requestAnalysisContext = new RequestAnalysisContext(context.Request);

            AnalyzeRequest(requestAnalysisContext, context.RequestAborted);
            if (requestAnalysisContext.IsMalicious && options.Depth == AnalysisDepth.FindFirst)
            {
                await Handle(requestAnalysisContext.Diagnostics);

                return;
            }

            await AnalyzeRequestAsync(requestAnalysisContext, context.RequestAborted);

            if (requestAnalysisContext.IsMalicious && options.Depth == AnalysisDepth.FindFirst)
            {
                await Handle(requestAnalysisContext.Diagnostics);

                return;
            }

            // switch body to buffer
            // NOTE(review): the pooled buffer is never handed back in this method — confirm whether
            // MemoryStreamPool expects an explicit return/dispose.
            var original = context.Response.Body;
            var buffer   = MemoryStreamPool.Get();

            context.Response.Body = buffer;

            var responseAnalysisContext = new ResponseAnalysisContext(context.Response);
            var responseDenied          = false;

            try
            {
                // execute
                await next(context);

                // analyze response
                AnalyzeResponse(responseAnalysisContext, context.RequestAborted);
                responseDenied = responseAnalysisContext.IsMalicious && options.Depth == AnalysisDepth.FindFirst;

                if (!responseDenied)
                {
                    await AnalyzeResponseAsync(responseAnalysisContext, context.RequestAborted);

                    responseDenied = responseAnalysisContext.IsMalicious && options.Depth == AnalysisDepth.FindFirst;
                }
            }
            finally
            {
                context.Response.Body = original;
            }

            // The denial must be written only after the real body stream is restored; writing it
            // while the buffer is still installed would leave the client without the announced body.
            if (responseDenied)
            {
                // the JSON body replaces the downstream body, so its content headers no longer apply
                if (context.Response.Headers.ContainsKey("Content-Encoding"))
                {
                    context.Response.Headers.Remove("Content-Encoding");
                }

                if (context.Response.Headers.ContainsKey("Content-MD5"))
                {
                    context.Response.Headers.Remove("Content-MD5");
                }

                await Handle(responseAnalysisContext.Diagnostics);

                return;
            }

            // write response
            if (responseAnalysisContext.Version != 0)
            {
                // remove content-related headers, because may be outdated
                if (context.Response.Headers.ContainsKey("Content-Encoding"))
                {
                    context.Response.Headers.Remove("Content-Encoding");
                }

                if (context.Response.Headers.ContainsKey("Content-MD5"))
                {
                    context.Response.Headers.Remove("Content-MD5");
                }

                // based on body type
                switch (responseAnalysisContext.SnapshotBodyType)
                {
                case ResponseBodyType.String:
                {
                    var    body    = responseAnalysisContext.ReadAsString(applyPendingChanges: true);
                    byte[] buffer1 = Encoding.UTF8.GetBytes(body);
                    context.Response.ContentLength             = buffer1.Length;
                    context.Response.Headers["Content-Length"] = buffer1.Length.ToString();
                    await context.Response.Body.WriteAsync(buffer1, 0, buffer1.Length);
                }
                break;

                case ResponseBodyType.Stream:
                {
                    // NOTE(review): copies from the stream's current position — confirm that
                    // ReadAsStream() hands back a rewound stream.
                    var body = responseAnalysisContext.ReadAsStream();
                    context.Response.ContentLength             = body.Length;
                    context.Response.Headers["Content-Length"] = body.Length.ToString();
                    await body.CopyToAsync(context.Response.Body);
                }
                break;

                case ResponseBodyType.ByteArray:
                {
                    var body = responseAnalysisContext.ReadAsByteArray();
                    context.Response.ContentLength             = body.Length;
                    context.Response.Headers["Content-Length"] = body.Length.ToString();
                    await context.Response.Body.WriteAsync(body, 0, body.Length);
                }
                break;
                }
            }
            else
            {
                // write original response
                if (buffer.Length > 0)
                {
                    await buffer.Rewind().CopyToAsync(context.Response.Body);
                }
            }


            // Logs the denial, applies the denied status code in prevention mode and writes the
            // diagnostics as indented camel-case JSON to the current response body.
            async Task Handle(IReadOnlyCollection <Diagnostic> diagnostics)
            {
                if (diagnostics.Count == 1)
                {
                    var diagnostic = diagnostics.Single();
                    Logger.LogWarning($"Request denied. {diagnostic.Rule.Category} ({diagnostic.Rule.Id}) attack found in {diagnostic.Location}: {diagnostic.Rule.Description}.");
                }
                else
                {
                    Logger.LogWarning($"Request denied. Found {diagnostics.Count} diagnostics.");
                }

                if (options.Mode == FirewallMode.Prevention)
                {
                    context.Response.StatusCode = options.DeniedResponseStatusCode;
                }

                var settings = new JsonSerializerOptions();

                settings.PropertyNamingPolicy = JsonNamingPolicy.CamelCase;
                settings.WriteIndented        = true;
                var dto  = new { diagnostics };
                var json = JsonSerializer.Serialize(dto, settings);

                context.Response.ContentLength = Encoding.UTF8.GetByteCount(json);
                context.Response.ContentType   = "application/json";
                await context.Response.WriteAsync(json);
            }
        }
Example #16
0
        /// <summary>
        /// Adds a Subresource Integrity <c>integrity</c> attribute (and <c>crossorigin="anonymous"</c>
        /// when missing) to script tags of an HTML response whose <c>src</c> contains ':' or starts
        /// with "//". The hash is computed by downloading the script and is cached per source.
        /// </summary>
        /// <param name="context">Response analysis context that supplies the body and collects text changes.</param>
        /// <param name="cancellationToken">Token used to cancel the script download.</param>
        public async Task InspectAsync(ResponseAnalysisContext context, CancellationToken cancellationToken)
        {
            if (context.IsHtml())
            {
                var html = context.ReadAsString();
                if (html == null)
                {
                    // no readable body, nothing to annotate
                    return;
                }

                // find all script tags
                foreach (var index in FastHtmlParser.FindAllTagIndexes(html, "script"))
                {
                    // get 'src' attribute
                    var src = FastHtmlParser.GetAttributeValueAtTag(html, "src", index);
                    if (src.IndexOf(':') != -1 || src.StartsWith("//", StringComparison.Ordinal))
                    {
                        // check whether there is an 'integrity' attribute already
                        var integrity = FastHtmlParser.GetAttributeValueAtTag(html, "integrity", index);
                        if (integrity.Length > 0)
                        {
                            continue;
                        }

                        // compute integrity; a null result (download failed or not allowed) is cached too
                        var sri = await MemoryCache.GetOrCreateAsync($"SRI_{src}", async ce =>
                        {
                            ce.SetAbsoluteExpiration(DateTimeOffset.MaxValue);
                            ce.SetPriority(CacheItemPriority.High);

                            var client = HttpClientFactory.CreateClient();
                            try
                            {
                                // download script
                                using var request = new HttpRequestMessage(HttpMethod.Get, RelativeTo(src).ToString());
                                request.Headers.TryAddWithoutValidation("Origin", $"{context.HttpContext.Request.Scheme}://{context.HttpContext.Request.Host}");
                                using var response = await client.SendAsync(request, cancellationToken);
                                if (!response.IsSuccessStatusCode)
                                {
                                    return(null);
                                }

                                if (response.Content == null)
                                {
                                    return(null);
                                }

                                // skip scripts whose response carries no Access-Control-Allow* header
                                if (!response.Headers.Any(h => h.Key.StartsWith("Access-Control-Allow", StringComparison.OrdinalIgnoreCase)))
                                {
                                    return(null);
                                }

                                var stream = await response.Content.ReadAsStreamAsync();

                                // compute hash
                                var hash   = HashAlgorithmPool.Sha256.ComputeHash(stream);
                                var base64 = Convert.ToBase64String(hash);
                                return($"sha256-{base64}");
                            }
                            catch (HttpRequestException)
                            {
                                return(null);
                            }
                        });

                        // add attribute
                        if (sri != null)
                        {
                            context.AddChange(FastHtmlParser.CreateInsertAttributeChange(html, index, "script", "integrity", sri));
                            if (FastHtmlParser.GetAttributeValueAtTag(html, "crossorigin", index).Length == 0)
                            {
                                context.AddChange(FastHtmlParser.CreateInsertAttributeChange(html, index, "script", "crossorigin", "anonymous"));
                            }
                        }
                    }
                }
            }
        }