/// <summary>
/// Compress an SMB2 packet, using chained compression when the connection supports it
/// and non-chained compression otherwise.
/// </summary>
/// <param name="packet">The SMB2 packet.</param>
/// <param name="compressionInfo">Compression info.</param>
/// <param name="role">SMB2 role.</param>
/// <returns>
/// The compressed packet when it is strictly smaller on the wire than <paramref name="packet"/>;
/// otherwise the original packet.
/// </returns>
public static Smb2Packet Compress(Smb2CompressiblePacket packet, Smb2CompressionInfo compressionInfo, Smb2Role role)
{
    Smb2Packet candidate = compressionInfo.SupportChainedCompression
        ? CompressForChained(packet, compressionInfo, role)
        : CompressForNonChained(packet, compressionInfo, role);

    // The helpers return the very same instance when compression does not apply.
    if (candidate == packet)
    {
        return packet;
    }

    int originalLength = packet.ToBytes().Length;
    int candidateLength = candidate.ToBytes().Length;

    // Keep the compressed form only if it actually saves bytes on the wire.
    return candidateLength < originalLength ? candidate : packet;
}
/// <summary>
/// Decompress the Smb2CompressedPacket.
/// </summary>
/// <param name="packet">
/// The compressed packet. When the header does not carry SMB2_COMPRESSION_FLAG_CHAINED,
/// it must be a <see cref="Smb2NonChainedCompressedPacket"/>.
/// </param>
/// <param name="compressionInfo">Compression info.</param>
/// <param name="role">SMB2 role.</param>
/// <returns>Byte array containing the decompressed packet.</returns>
/// <exception cref="InvalidOperationException">
/// The packet type does not match the chained flag, or the decompressed length does not equal
/// OriginalCompressedSegmentSize plus the length of the uncompressed prefix.
/// </exception>
public static byte[] Decompress(Smb2CompressedPacket packet, Smb2CompressionInfo compressionInfo, Smb2Role role)
{
    bool isChained = packet.Header.Flags.HasFlag(Compression_Transform_Header_Flags.SMB2_COMPRESSION_FLAG_CHAINED);

    byte[] decompressedData;

    // Length of the bytes carried uncompressed in front of the compressed segment.
    // Only non-chained packets have such a prefix; for chained packets it stays 0.
    int uncompressedDataSize = 0;

    if (isChained)
    {
        decompressedData = DecompressForChained(packet, compressionInfo, role);
    }
    else
    {
        // Validate the concrete type up front instead of dereferencing an unchecked 'as' cast.
        var nonChained = packet as Smb2NonChainedCompressedPacket;
        if (nonChained == null)
        {
            throw new InvalidOperationException("A compressed packet without SMB2_COMPRESSION_FLAG_CHAINED must be a non-chained compressed packet.");
        }

        decompressedData = DecompressForNonChained(packet, compressionInfo, role);
        uncompressedDataSize = nonChained.UncompressedData.Length;
    }

    // If the packet is not chained, the length of UncompressedData may not be 0,
    // so it is added to OriginalCompressedSegmentSize before comparing.
    if (decompressedData.Length != packet.Header.OriginalCompressedSegmentSize + uncompressedDataSize)
    {
        throw new InvalidOperationException($"The length of decompressed data ({decompressedData.Length}) is inconsistent with the sum of compression header length ({packet.Header.OriginalCompressedSegmentSize}) and UncompressedData length ({uncompressedDataSize}).");
    }

    return decompressedData;
}
/// <summary>
/// Chained-compression path: uses Pattern_V1 scanning when that algorithm was negotiated,
/// otherwise falls back to non-chained compression.
/// </summary>
/// <param name="packet">The SMB2 packet.</param>
/// <param name="compressionInfo">Compression info.</param>
/// <param name="role">SMB2 role.</param>
/// <returns>The compressed packet, or <paramref name="packet"/> itself when no compression is needed.</returns>
private static Smb2Packet CompressForChained(Smb2CompressiblePacket packet, Smb2CompressionInfo compressionInfo, Smb2Role role)
{
    if (!IsCompressionNeeded(packet, compressionInfo, role))
    {
        return packet;
    }

    // Without Pattern_V1 there is no pattern scanning, so regress to non-chained compression.
    return compressionInfo.CompressionIds.Contains(CompressionAlgorithm.Pattern_V1)
        ? CompressWithPatternV1(packet, compressionInfo, role)
        : CompressForNonChained(packet, compressionInfo, role);
}
/// <summary>
/// Compress an SMB2 packet with the negotiated algorithm, leaving the first
/// <paramref name="offset"/> bytes uncompressed.
/// </summary>
/// <param name="packet">The SMB2 packet.</param>
/// <param name="compressionInfo">Compression info.</param>
/// <param name="role">SMB2 role.</param>
/// <param name="offset">The offset where compression starts, default zero.</param>
/// <returns>
/// A <see cref="Smb2CompressedPacket"/> (with OriginalPacket set) when it is strictly smaller on the
/// wire; otherwise <paramref name="packet"/> unchanged.
/// </returns>
public static Smb2Packet Compress(Smb2CompressiblePacket packet, Smb2CompressionInfo compressionInfo, Smb2Role role, uint offset = 0)
{
    var algorithm = GetCompressionAlgorithm(packet, compressionInfo, role);
    if (algorithm == CompressionAlgorithm.NONE)
    {
        return packet;
    }

    byte[] plainBytes = packet.ToBytes();
    var compressor = GetCompressor(algorithm);
    int splitAt = (int)offset;

    var candidate = new Smb2CompressedPacket();
    candidate.Header.ProtocolId = Smb2Consts.ProtocolIdInCompressionTransformHeader;
    // NOTE(review): this records the full packet length, including the uncompressed prefix.
    // Confirm against [MS-SMB2] 2.2.42 whether the prefix should be excluded when offset > 0.
    candidate.Header.OriginalCompressedSegmentSize = (uint)plainBytes.Length;
    candidate.Header.CompressionAlgorithm = algorithm;
    candidate.Header.Reserved = 0;
    candidate.Header.Offset = offset;
    candidate.UncompressedData = plainBytes.Take(splitAt).ToArray();
    candidate.CompressedData = compressor.Compress(plainBytes.Skip(splitAt).ToArray());

    // Not worth sending compressed unless it shrinks the on-wire size.
    if (candidate.ToBytes().Length >= plainBytes.Length)
    {
        return packet;
    }

    candidate.OriginalPacket = packet;
    return candidate;
}
/// <summary>
/// Decide whether an outgoing packet should be compressed.
/// </summary>
/// <param name="packet">The outgoing packet.</param>
/// <param name="compressionInfo">Compression info.</param>
/// <param name="role">SMB2 role of the sender.</param>
/// <returns>
/// True when at least one non-NONE algorithm is negotiated and the role-specific rules select the
/// packet. In buffer-only mode the packet must also expose a non-empty buffer.
/// </returns>
/// <exception cref="InvalidOperationException">The role is neither client nor server.</exception>
private static bool IsCompressionNeeded(Smb2CompressiblePacket packet, Smb2CompressionInfo compressionInfo, Smb2Role role)
{
    // No real algorithm negotiated: nothing to compress with.
    if (compressionInfo.CompressionIds.All(id => id == CompressionAlgorithm.NONE))
    {
        return false;
    }

    bool wanted;
    if (role == Smb2Role.Client)
    {
        // Client: every request when CompressAllPackets is set, otherwise only WRITE requests marked eligible.
        wanted = compressionInfo.CompressAllPackets
            || (packet is Smb2WriteRequestPacket && packet.EligibleForCompression);
    }
    else if (role == Smb2Role.Server)
    {
        // Server: eligible responses, when CompressAllPackets is set or the response is a READ response.
        wanted = (compressionInfo.CompressAllPackets || packet is Smb2ReadResponsePacket)
            && packet.EligibleForCompression;
    }
    else
    {
        throw new InvalidOperationException("Unknown SMB2 role!");
    }

    if (!wanted)
    {
        return false;
    }

    if (!compressionInfo.CompressBufferOnly)
    {
        return true;
    }

    // Buffer-only mode: skip packets that carry no buffer or an empty one.
    var buffered = packet as IPacketBuffer;
    return buffered != null && buffered.BufferLength != 0;
}
/// <summary>
/// Decompress a chained compressed packet by expanding each payload in order and
/// concatenating the results.
/// </summary>
/// <param name="packet">The compressed packet; must be a <see cref="Smb2ChainedCompressedPacket"/>.</param>
/// <param name="compressionInfo">Compression info (not read here).</param>
/// <param name="role">SMB2 role (not read here).</param>
/// <returns>The concatenated bytes of all expanded payloads.</returns>
/// <exception cref="InvalidOperationException">
/// The packet is not chained, a payload has the wrong data type, a decompressed payload length does
/// not match OriginalPayloadSize, or a payload uses an unexpected algorithm.
/// </exception>
private static byte[] DecompressForChained(Smb2CompressedPacket packet, Smb2CompressionInfo compressionInfo, Smb2Role role)
{
    var chained = packet as Smb2ChainedCompressedPacket;
    if (chained == null)
    {
        throw new InvalidOperationException("A compressed packet with SMB2_COMPRESSION_FLAG_CHAINED must be a chained compressed packet.");
    }

    var output = new List<byte>();

    foreach (var payload in chained.Payloads)
    {
        var payloadHeader = payload.Item1;
        // NONE and the LZ-family payloads carry raw bytes; Pattern_V1 carries a pattern struct.
        var payloadBytes = payload.Item2 as byte[];

        switch (payloadHeader.CompressionAlgorithm)
        {
            case CompressionAlgorithm.NONE:
                if (payloadBytes == null)
                {
                    throw new InvalidOperationException("Chained payload without compression must carry a byte array.");
                }
                output.AddRange(payloadBytes);
                break;

            case CompressionAlgorithm.LZNT1:
            case CompressionAlgorithm.LZ77:
            case CompressionAlgorithm.LZ77Huffman:
                {
                    if (payloadBytes == null)
                    {
                        throw new InvalidOperationException("Compressed chained payload must carry a byte array.");
                    }

                    var decompressor = GetDecompressor(payloadHeader.CompressionAlgorithm);
                    var decompressedData = decompressor.Decompress(payloadBytes);

                    if (decompressedData.Length != payloadHeader.OriginalPayloadSize)
                    {
                        throw new InvalidOperationException("The length decompressed chained payload is inconsistent with OriginalPayloadSize of payload header.");
                    }

                    output.AddRange(decompressedData);
                }
                break;

            case CompressionAlgorithm.Pattern_V1:
                {
                    if (!(payload.Item2 is SMB2_COMPRESSION_PATTERN_PAYLOAD_V1))
                    {
                        throw new InvalidOperationException("Pattern_V1 chained payload must carry a pattern payload.");
                    }

                    var pattern = (SMB2_COMPRESSION_PATTERN_PAYLOAD_V1)payload.Item2;
                    output.AddRange(Enumerable.Repeat(pattern.Pattern, (int)pattern.Repetitions));
                }
                break;

            default:
                throw new InvalidOperationException("Unexpected compression algorithm!");
        }
    }

    return output.ToArray();
}
/// <summary>
/// Compress an SMB2 packet with the negotiated algorithm, leaving the first
/// <paramref name="offset"/> bytes uncompressed.
/// </summary>
/// <param name="packet">The SMB2 packet. It is not modified.</param>
/// <param name="compressionInfo">Compression info.</param>
/// <param name="role">SMB2 role.</param>
/// <param name="offset">The offset where compression starts, default zero.</param>
/// <returns>
/// A <see cref="Smb2CompressedPacket"/> (with OriginalPacket set) when it is strictly smaller on the
/// wire; otherwise <paramref name="packet"/> unchanged.
/// </returns>
/// <remarks>
/// The header's OriginalCompressedSegmentSize is derived from the actual serialized packet, so a
/// receiver can validate the decompressed size. The caller's packet is never mutated.
/// </remarks>
public static Smb2Packet Compress(Smb2CompressiblePacket packet, Smb2CompressionInfo compressionInfo, Smb2Role role, uint offset = 0)
{
    var compressionAlgorithm = GetCompressionAlgorithm(packet, compressionInfo, role);

    if (compressionAlgorithm == CompressionAlgorithm.NONE)
    {
        return packet;
    }

    var packetBytes = packet.ToBytes();

    var compressor = GetCompressor(compressionAlgorithm);

    var compressedPacket = new Smb2CompressedPacket();
    compressedPacket.Header.ProtocolId = Smb2Consts.ProtocolIdInCompressionTransformHeader;
    // NOTE(review): this records the full packet length, including the uncompressed prefix.
    // Confirm against [MS-SMB2] 2.2.42 whether the prefix should be excluded when offset > 0.
    compressedPacket.Header.OriginalCompressedSegmentSize = (uint)packetBytes.Length;
    compressedPacket.Header.CompressionAlgorithm = compressionAlgorithm;
    compressedPacket.Header.Reserved = 0;
    compressedPacket.Header.Offset = offset;
    compressedPacket.UncompressedData = packetBytes.Take((int)offset).ToArray();
    compressedPacket.CompressedData = compressor.Compress(packetBytes.Skip((int)offset).ToArray());

    var compressedPacketBytes = compressedPacket.ToBytes();

    // Check whether compression shrinks the on-wire packet size.
    if (compressedPacketBytes.Length < packetBytes.Length)
    {
        compressedPacket.OriginalPacket = packet;
        return compressedPacket;
    }

    return packet;
}
/// <summary>
/// Decompress a non-chained Smb2CompressedPacket.
/// </summary>
/// <param name="packet">The compressed packet.</param>
/// <param name="compressionInfo">Compression info; its CompressionIds must contain the header's algorithm.</param>
/// <param name="role">SMB2 role.</param>
/// <returns>UncompressedData followed by the decompressed CompressedData.</returns>
/// <exception cref="InvalidOperationException">
/// The header algorithm is NONE or is not among the negotiated CompressionIds.
/// </exception>
public static byte[] Decompress(Smb2CompressedPacket packet, Smb2CompressionInfo compressionInfo, Smb2Role role)
{
    var algorithm = packet.Header.CompressionAlgorithm;

    if (algorithm == CompressionAlgorithm.NONE)
    {
        throw new InvalidOperationException("Invalid CompressionAlgorithm in header!");
    }

    if (!compressionInfo.CompressionIds.Contains(algorithm))
    {
        throw new InvalidOperationException("The CompressionAlgorithm is not supported!");
    }

    byte[] inflated = GetDecompressor(algorithm).Decompress(packet.CompressedData);

    // The uncompressed prefix precedes the decompressed segment in the original message.
    return packet.UncompressedData.Concat(inflated).ToArray();
}
/// <summary>
/// Decompress an Smb2CompressedPacket, dispatching on the SMB2_COMPRESSION_FLAG_CHAINED header flag.
/// </summary>
/// <param name="packet">The compressed packet.</param>
/// <param name="compressionInfo">Compression info.</param>
/// <param name="role">SMB2 role.</param>
/// <returns>Byte array containing the decompressed packet.</returns>
/// <exception cref="InvalidOperationException">
/// The decompressed length differs from the header's OriginalCompressedSegmentSize.
/// </exception>
public static byte[] Decompress(Smb2CompressedPacket packet, Smb2CompressionInfo compressionInfo, Smb2Role role)
{
    bool chained = packet.Header.Flags.HasFlag(Compression_Transform_Header_Flags.SMB2_COMPRESSION_FLAG_CHAINED);

    byte[] result = chained
        ? DecompressForChained(packet, compressionInfo, role)
        : DecompressForNonChained(packet, compressionInfo, role);

    // NOTE(review): for non-chained packets with Offset > 0, result also contains the uncompressed prefix.
    // Confirm that the peer's OriginalCompressedSegmentSize includes that prefix as well.
    uint expected = packet.Header.OriginalCompressedSegmentSize;
    if (result.Length != expected)
    {
        throw new InvalidOperationException($"The length of decompressed data (0x{result.Length:X08}) is inconsistent with compression header (0x{expected:X08}).");
    }

    return result;
}
/// <summary>
/// Pick the compression algorithm to use for this connection.
/// </summary>
/// <param name="compressionInfo">Compression info.</param>
/// <returns>
/// The explicitly preferred algorithm when one is set. Otherwise the first entry returned by
/// Smb2Utility.GetSupportedCompressionAlgorithms, or NONE when that list is empty.
/// </returns>
/// <exception cref="InvalidOperationException">The preferred algorithm is not in CompressionIds.</exception>
private static CompressionAlgorithm GetPreferredCompressionAlgorithm(Smb2CompressionInfo compressionInfo)
{
    var preferred = compressionInfo.PreferredCompressionAlgorithm;

    if (preferred != CompressionAlgorithm.NONE)
    {
        if (!compressionInfo.CompressionIds.Contains(preferred))
        {
            throw new InvalidOperationException("Specified preferred compression algorithm is not supported by SUT!");
        }

        return preferred;
    }

    var common = Smb2Utility.GetSupportedCompressionAlgorithms(compressionInfo.CompressionIds);
    return common.Length > 0 ? common.First() : CompressionAlgorithm.NONE;
}
/// <summary>
/// Sign, compress and encrypt for Single or Compound packet.
/// </summary>
/// <param name="originalPacket">The packet to protect; must be a Smb2SinglePacket or Smb2CompoundPacket.</param>
/// <param name="cryptoInfoTable">Crypto info keyed by SessionId.</param>
/// <param name="compressioninfo">Compression info passed to Smb2Compression.Compress.</param>
/// <param name="role">SMB2 role.</param>
/// <returns>
/// The encrypted packet when encryption is attempted and Encrypt returns non-null. Otherwise the packet
/// (signed when session signing is enabled), compressed when it is a Smb2CompressiblePacket.
/// </returns>
/// <exception cref="NotImplementedException">The packet is neither a single nor a compound packet.</exception>
public static Smb2Packet SignCompressAndEncrypt(Smb2Packet originalPacket, Dictionary<ulong, Smb2CryptoInfo> cryptoInfoTable, Smb2CompressionInfo compressioninfo, Smb2Role role)
{
    ulong sessionId;
    bool isCompound = false;
    bool notEncryptNotSign = false;
    bool notEncrypt = false;
    var compressedPacket = originalPacket;

    if (originalPacket is Smb2SinglePacket)
    {
        Smb2SinglePacket singlePacket = originalPacket as Smb2SinglePacket;
        sessionId = singlePacket.Header.SessionId;

        // [MS-SMB2] Section 3.2.4.1.8: if the request being sent is SMB2 NEGOTIATE,
        // or SMB2 SESSION_SETUP with the SMB2_SESSION_FLAG_BINDING bit set in the Flags field,
        // the client MUST NOT encrypt the message.
        if (sessionId == 0 || (singlePacket.Header.Command == Smb2Command.NEGOTIATE && (singlePacket is Smb2NegotiateRequestPacket)))
        {
            notEncryptNotSign = true;
        }
        else if (singlePacket.Header.Command == Smb2Command.SESSION_SETUP
            && (singlePacket is Smb2SessionSetupRequestPacket)
            // Test the binding bit rather than comparing the whole Flags value,
            // so requests that also carry other bits are still recognized.
            && (singlePacket as Smb2SessionSetupRequestPacket).PayLoad.Flags.HasFlag(SESSION_SETUP_Request_Flags.SESSION_FLAG_BINDING))
        {
            notEncrypt = true;
        }
    }
    else if (originalPacket is Smb2CompoundPacket)
    {
        isCompound = true;
        // The subsequent requests in a compound packet use the SessionId of the first request for encryption.
        sessionId = (originalPacket as Smb2CompoundPacket).Packets[0].Header.SessionId;
    }
    else
    {
        throw new NotImplementedException(string.Format("Signing and encryption are not implemented for packet: {0}", originalPacket.ToString()));
    }

    // Single lookup instead of ContainsKey followed by the indexer.
    Smb2CryptoInfo cryptoInfo;
    if (sessionId == 0 || notEncryptNotSign || !cryptoInfoTable.TryGetValue(sessionId, out cryptoInfo))
    {
        // No crypto context: only compression (if applicable) is performed.
        if (originalPacket is Smb2CompressiblePacket)
        {
            compressedPacket = Smb2Compression.Compress(originalPacket as Smb2CompressiblePacket, compressioninfo, role);
        }
        return compressedPacket;
    }

    #region Encrypt
    // Try to encrypt the message whether or not encryption is supported, except for session setup.
    // If it's not supported, do it for negative test.
    // For compound packet, the encryption is done for the entire message.
    if (!notEncrypt)
    {
        // Compression is applied before encryption.
        if (originalPacket is Smb2CompressiblePacket)
        {
            compressedPacket = Smb2Compression.Compress(originalPacket as Smb2CompressiblePacket, compressioninfo, role);
        }

        var encryptedPacket = Encrypt(sessionId, cryptoInfo, role, compressedPacket, originalPacket);

        if (encryptedPacket != null)
        {
            return encryptedPacket;
        }
    }
    #endregion

    #region Sign
    if (cryptoInfo.EnableSessionSigning)
    {
        if (isCompound)
        {
            // Calculate signature for every packet in the chain.
            foreach (Smb2SinglePacket packet in (originalPacket as Smb2CompoundPacket).Packets)
            {
                // An unrelated packet (including the first one) switches to its own SessionId.
                // A related packet keeps the SessionId/crypto info of the most recent unrelated packet.
                if (!packet.Header.Flags.HasFlag(Packet_Header_Flags_Values.FLAGS_RELATED_OPERATIONS))
                {
                    sessionId = packet.Header.SessionId;
                    cryptoInfo = cryptoInfoTable[sessionId];
                }

                packet.Header.Signature = Sign(cryptoInfo, packet.ToBytes());
            }
        }
        else
        {
            (originalPacket as Smb2SinglePacket).Header.Signature = Sign(cryptoInfo, originalPacket.ToBytes());
        }
    }
    #endregion

    // Compress after signing so the compressed form carries the computed signature(s).
    if (originalPacket is Smb2CompressiblePacket)
    {
        compressedPacket = Smb2Compression.Compress(originalPacket as Smb2CompressiblePacket, compressioninfo, role);
    }

    return compressedPacket;
}
/// <summary>
/// Decide which compression algorithm, if any, to apply to an outgoing packet.
/// </summary>
/// <param name="packet">The outgoing packet.</param>
/// <param name="compressionInfo">Compression info.</param>
/// <param name="role">SMB2 role of the sender.</param>
/// <returns>
/// NONE when the packet should not be compressed. Otherwise the preferred algorithm, or the first
/// non-NONE entry of CompressionIds when no preference is set.
/// </returns>
/// <exception cref="InvalidOperationException">
/// The role is unknown, or the preferred algorithm is not in CompressionIds.
/// </exception>
private static CompressionAlgorithm GetCompressionAlgorithm(Smb2CompressiblePacket packet, Smb2CompressionInfo compressionInfo, Smb2Role role)
{
    if (compressionInfo.CompressionIds.All(id => id == CompressionAlgorithm.NONE))
    {
        return CompressionAlgorithm.NONE;
    }

    bool shouldCompress;
    if (role == Smb2Role.Client)
    {
        // Client: every request when CompressAllPackets is set, otherwise only eligible WRITE requests.
        shouldCompress = compressionInfo.CompressAllPackets
            || (packet is Smb2WriteRequestPacket && packet.EligibleForCompression);
    }
    else if (role == Smb2Role.Server)
    {
        // Server: eligible responses, when CompressAllPackets is set or the response is a READ response.
        shouldCompress = (compressionInfo.CompressAllPackets || packet is Smb2ReadResponsePacket)
            && packet.EligibleForCompression;
    }
    else
    {
        throw new InvalidOperationException("Unknown SMB2 role!");
    }

    if (!shouldCompress)
    {
        return CompressionAlgorithm.NONE;
    }

    var preferred = compressionInfo.PreferredCompressionAlgorithm;
    if (preferred == CompressionAlgorithm.NONE)
    {
        // At least one non-NONE id exists, per the first check.
        return compressionInfo.CompressionIds.First(id => id != CompressionAlgorithm.NONE);
    }

    if (!compressionInfo.CompressionIds.Contains(preferred))
    {
        throw new InvalidOperationException("Specified preferred compression algorithm is not supported by SUT!");
    }

    return preferred;
}
/// <summary>
/// Decide which compression algorithm, if any, to apply to an outgoing packet.
/// </summary>
/// <param name="packet">The outgoing packet.</param>
/// <param name="compressionInfo">Compression info.</param>
/// <param name="role">SMB2 role of the sender.</param>
/// <returns>
/// NONE when IsCompressionNeeded rejects the packet; otherwise the result of
/// GetPreferredCompressionAlgorithm.
/// </returns>
private static CompressionAlgorithm GetCompressionAlgorithm(Smb2CompressiblePacket packet, Smb2CompressionInfo compressionInfo, Smb2Role role)
{
    return IsCompressionNeeded(packet, compressionInfo, role)
        ? GetPreferredCompressionAlgorithm(compressionInfo)
        : CompressionAlgorithm.NONE;
}
/// <summary>
/// Build a non-chained compressed packet: one compressed segment, optionally preceded by an
/// uncompressed prefix (the bytes before the buffer in buffer-only mode).
/// </summary>
/// <param name="packet">The SMB2 packet.</param>
/// <param name="compressionInfo">Compression info.</param>
/// <param name="role">SMB2 role.</param>
/// <returns>
/// A <see cref="Smb2NonChainedCompressedPacket"/> with OriginalPacket set, or <paramref name="packet"/>
/// itself when no algorithm applies. Unlike the public Compress entry point, no size comparison is done here.
/// </returns>
private static Smb2Packet CompressForNonChained(Smb2CompressiblePacket packet, Smb2CompressionInfo compressionInfo, Smb2Role role)
{
    var algorithm = GetCompressionAlgorithm(packet, compressionInfo, role);
    if (algorithm == CompressionAlgorithm.NONE)
    {
        return packet;
    }

    byte[] wireBytes = packet.ToBytes();
    var compressor = GetCompressor(algorithm);

    // In buffer-only mode everything before the buffer is sent uncompressed.
    uint prefixLength = 0;
    if (compressionInfo.CompressBufferOnly)
    {
        prefixLength = (packet as IPacketBuffer).BufferOffset;
    }

    var nonChained = new Smb2NonChainedCompressedPacket();
    nonChained.Header.ProtocolId = Smb2Consts.ProtocolIdInCompressionTransformHeader;
    // NOTE(review): this records the full packet length, including the uncompressed prefix.
    // Confirm against [MS-SMB2] 2.2.42 whether the prefix should be excluded when Offset > 0.
    nonChained.Header.OriginalCompressedSegmentSize = (uint)wireBytes.Length;
    nonChained.Header.CompressionAlgorithm = algorithm;
    nonChained.Header.Flags = Compression_Transform_Header_Flags.SMB2_COMPRESSION_FLAG_NONE;
    nonChained.Header.Offset = prefixLength;
    nonChained.UncompressedData = wireBytes.Take((int)prefixLength).ToArray();
    nonChained.CompressedData = compressor.Compress(wireBytes.Skip((int)prefixLength).ToArray());
    nonChained.OriginalPacket = packet;

    return nonChained;
}
/// <summary>
/// Compress a packet as a chained message, using Pattern_V1 payloads for a repeated-byte run found
/// at the front and/or the end of the compressible data.
/// </summary>
/// <param name="packet">The SMB2 packet.</param>
/// <param name="compressionInfo">Compression info.</param>
/// <param name="role">SMB2 role.</param>
/// <returns>
/// A <see cref="Smb2ChainedCompressedPacket"/> whose payloads are, in order:
/// - the uncompressed prefix (buffer-only mode);
/// - the leading pattern;
/// - the middle section (compressed or raw);
/// - the trailing pattern.
/// Falls back to CompressForNonChained when ScanForPatternV1 finds neither pattern.
/// </returns>
private static Smb2Packet CompressWithPatternV1(Smb2CompressiblePacket packet, Smb2CompressionInfo compressionInfo, Smb2Role role)
{
    byte[] packetBytes = packet.ToBytes();

    // In buffer-only mode the bytes before the buffer are excluded from pattern scanning.
    byte[] scanTarget;
    if (compressionInfo.CompressBufferOnly)
    {
        scanTarget = packetBytes.Skip((packet as IPacketBuffer).BufferOffset).ToArray();
    }
    else
    {
        scanTarget = packetBytes;
    }

    SMB2_COMPRESSION_PATTERN_PAYLOAD_V1? leading;
    SMB2_COMPRESSION_PATTERN_PAYLOAD_V1? trailing;
    ScanForPatternV1(scanTarget, out leading, out trailing);

    if (!leading.HasValue && !trailing.HasValue)
    {
        // Regress to non-chained since no pattern is found at front or end.
        return CompressForNonChained(packet, compressionInfo, role);
    }

    var chained = new Smb2ChainedCompressedPacket();
    chained.OriginalPacket = packet;
    chained.Header = new Compression_Transform_Header();
    chained.Header.ProtocolId = Smb2Consts.ProtocolIdInCompressionTransformHeader;
    chained.Header.OriginalCompressedSegmentSize = (UInt32)packetBytes.Length;
    chained.Header.Flags = Compression_Transform_Header_Flags.SMB2_COMPRESSION_FLAG_CHAINED;

    var chain = new List<Tuple<SMB2_COMPRESSION_PAYLOAD_HEADER, object>>();
    // Threaded through every payload-header creation, in payload order.
    bool isFirstPayload = true;

    if (compressionInfo.CompressBufferOnly)
    {
        byte[] prefix = packetBytes.Take((packet as IPacketBuffer).BufferOffset).ToArray();
        var prefixHeader = SMB2_COMPRESSION_PAYLOAD_HEADER.Create(CompressionAlgorithm.NONE, (UInt32)prefix.Length, 0, ref isFirstPayload);
        chain.Add(new Tuple<SMB2_COMPRESSION_PAYLOAD_HEADER, object>(prefixHeader, prefix));
    }

    int leadingLength = 0;
    if (leading.HasValue)
    {
        var leadingHeader = SMB2_COMPRESSION_PAYLOAD_HEADER.Create(CompressionAlgorithm.Pattern_V1, (UInt32)PatternPayloadLength, 0, ref isFirstPayload);
        chain.Add(new Tuple<SMB2_COMPRESSION_PAYLOAD_HEADER, object>(leadingHeader, leading));
        leadingLength = (int)leading.Value.Repetitions;
    }

    int trailingLength = trailing.HasValue ? (int)trailing.Value.Repetitions : 0;

    // Whatever lies between the two pattern runs is compressed (or kept raw) as a single payload.
    byte[] middle = scanTarget
        .Skip(leadingLength)
        .Take(scanTarget.Length - leadingLength - trailingLength)
        .ToArray();

    if (middle.Length > 0)
    {
        chain.Add(ChainedCompressWithCompressionAlgorithm(middle, compressionInfo, ref isFirstPayload));
    }

    if (trailing.HasValue)
    {
        var trailingHeader = SMB2_COMPRESSION_PAYLOAD_HEADER.Create(CompressionAlgorithm.Pattern_V1, (UInt32)PatternPayloadLength, 0, ref isFirstPayload);
        chain.Add(new Tuple<SMB2_COMPRESSION_PAYLOAD_HEADER, object>(trailingHeader, trailing));
    }

    chained.Payloads = chain.ToArray();
    return chained;
}
/// <summary>
/// Build one chained payload for <paramref name="data"/>. The data is compressed with the preferred
/// algorithm when that yields fewer bytes (header included); otherwise it is kept raw under a NONE header.
/// </summary>
/// <param name="data">The bytes to place in the payload.</param>
/// <param name="compressionInfo">Compression info.</param>
/// <param name="isFirst">
/// First-payload state passed to SMB2_COMPRESSION_PAYLOAD_HEADER.Create. On return it holds the value
/// produced by the Create call for the uncompressed header.
/// </param>
/// <returns>The payload header paired with the payload bytes.</returns>
private static Tuple<SMB2_COMPRESSION_PAYLOAD_HEADER, object> ChainedCompressWithCompressionAlgorithm(byte[] data, Smb2CompressionInfo compressionInfo, ref bool isFirst)
{
    var algorithm = GetPreferredCompressionAlgorithm(compressionInfo);

    // Both candidate headers start from the same isFirst state.
    bool firstForCompressed = isFirst;
    var compressedHeader = SMB2_COMPRESSION_PAYLOAD_HEADER.Create(algorithm, 0, (UInt32)data.Length, ref firstForCompressed);

    bool firstForRaw = isFirst;
    var rawHeader = SMB2_COMPRESSION_PAYLOAD_HEADER.Create(CompressionAlgorithm.NONE, (UInt32)data.Length, 0, ref firstForRaw);

    isFirst = firstForRaw;

    if (algorithm == CompressionAlgorithm.NONE)
    {
        return new Tuple<SMB2_COMPRESSION_PAYLOAD_HEADER, object>(rawHeader, data);
    }

    byte[] packed = GetCompressor(algorithm).Compress(data);

    int packedSize = TypeMarshal.ToBytes(compressedHeader).Length + packed.Length;
    int rawSize = TypeMarshal.ToBytes(rawHeader).Length + data.Length;

    if (packedSize >= rawSize)
    {
        // Compression does not pay off; send the data as-is.
        return new Tuple<SMB2_COMPRESSION_PAYLOAD_HEADER, object>(rawHeader, data);
    }

    // The compressed payload's Length also covers the OriginalPayloadSize field.
    compressedHeader.Length = (UInt32)(packed.Length + FieldSizeOriginalPayloadSize);
    return new Tuple<SMB2_COMPRESSION_PAYLOAD_HEADER, object>(compressedHeader, packed);
}