/// <summary>
/// Callback invoked by the transport stack whenever bytes arrive. Decodes at most one
/// packet from the supplied buffer by delegating to DecodeTransportPayload.
/// </summary>
/// <param name="endPoint">Where the bytes were received from (not referenced by this method).</param>
/// <param name="messageBytes">The bytes received so far.</param>
/// <param name="consumedLength">[OUT] How many bytes of <paramref name="messageBytes"/> were consumed.</param>
/// <param name="expectedLength">[OUT] How many bytes the decoder expects when the buffer is incomplete.</param>
/// <returns>A one-element array holding the decoded packet, or null when nothing was decoded.</returns>
public StackPacket[] Smb2DecodePacketCallback(
    object endPoint,
    byte[] messageBytes,
    out int consumedLength,
    out int expectedLength)
{
    // An empty buffer carries nothing to decode: report no progress and no expectation.
    if (messageBytes.Length == 0)
    {
        consumedLength = 0;
        expectedLength = 0;
        return null;
    }

    // The final "false" is ignoreCompoundFlag: compound packets are decoded as compounds.
    Smb2Packet decoded = DecodeTransportPayload(
        messageBytes,
        decodeRole,
        transportType,
        false,
        out consumedLength,
        out expectedLength);

    return decoded == null ? null : new StackPacket[] { decoded };
}
/// <summary>
/// Decodes the message as a single (non-compound) SMB2 request packet.
/// </summary>
/// <param name="messageBytes">The received packet, starting at the SMB2 header.</param>
/// <param name="consumedLength">[OUT] The consumed length of the message.</param>
/// <param name="expectedLength">[OUT] The expected length.</param>
/// <returns>The decoded request packet.</returns>
/// <exception cref="InvalidOperationException">
/// Thrown when the header's Command is not supported by this decoder (only NEGOTIATE is handled).
/// </exception>
private static Smb2Packet DecodeSingleRequestPacket(byte[] messageBytes, out int consumedLength, out int expectedLength)
{
    int offset = 0;
    Packet_Header smb2Header = TypeMarshal.ToStruct<Packet_Header>(messageBytes, ref offset);

    // Only NEGOTIATE is recognised; every other command (OPLOCK_BREAK included) is rejected,
    // so nothing beyond the header needs to be pre-read to pick the packet type.
    Smb2Packet packet;
    switch (smb2Header.Command)
    {
        case Smb2Command.NEGOTIATE:
            packet = new Smb2NegotiateRequestPacket();
            break;

        default:
            throw new InvalidOperationException("Received an unknown packet! the type of the packet is " + smb2Header.Command.ToString());
    }

    // The concrete packet re-parses the whole buffer, header included.
    packet.FromBytes(messageBytes, out consumedLength, out expectedLength);
    return packet;
}
/// <summary>
/// Decodes a complete message (without the TCP length prefix, if any) and, for unencrypted SMB2
/// packets, attempts signature verification of the decoded packet.
/// </summary>
/// <param name="messageBytes">The received packet.</param>
/// <param name="role">The role of this decoder, client or server.</param>
/// <param name="ignoreCompoundFlag">Whether to decode the packet as a single packet even when the
/// compound flag is set.</param>
/// <param name="realSessionId">The real sessionId for this packet (passed to the SMB2 decoders).</param>
/// <param name="realTreeId">The real treeId for this packet (passed to the SMB2 decoders).</param>
/// <param name="consumedLength">[OUT] The consumed length of the message.</param>
/// <param name="expectedLength">[OUT] The expected length.</param>
/// <returns>The decoded packet.</returns>
public Smb2Packet DecodeCompletePacket(
    byte[] messageBytes,
    Smb2Role role,
    bool ignoreCompoundFlag,
    ulong realSessionId,
    uint realTreeId,
    out int consumedLength,
    out int expectedLength)
{
    // The protocol version occupies the first 4 bytes of the message.
    byte[] protocolVersion = new byte[sizeof(uint)];
    Array.Copy(messageBytes, 0, protocolVersion, 0, protocolVersion.Length);
    SmbVersion version = DecodeVersion(protocolVersion);

    if (version == SmbVersion.Version1)
    {
        // SMB Negotiate packet
        return DecodeSmbPacket(messageBytes, role, out consumedLength, out expectedLength);
    }

    if (version == SmbVersion.Version2Encrypted)
    {
        // SMB2 encrypted packet
        return DecodeEncryptedSmb2Packet(
            messageBytes,
            role,
            ignoreCompoundFlag,
            realSessionId,
            realTreeId,
            out consumedLength,
            out expectedLength);
    }

    // SMB2 packet, not encrypted
    Smb2Packet decodedPacket = DecodeSmb2Packet(
        messageBytes,
        role,
        ignoreCompoundFlag,
        realSessionId,
        realTreeId,
        out consumedLength,
        out expectedLength);

    // Cast once with "as" rather than an "is" test followed by a second cast.
    Smb2SinglePacket singlePacket = decodedPacket as Smb2SinglePacket;
    if (singlePacket != null)
    {
        // Single packet: verify its signature (session setup responses are excepted by the helper).
        TryVerifySignatureExceptSessionSetupResponse(singlePacket, singlePacket.Header.SessionId, messageBytes);
    }
    else
    {
        Smb2CompoundPacket compoundPacket = decodedPacket as Smb2CompoundPacket;
        if (compoundPacket != null)
        {
            // Compound packet: verify the signature of the whole compound.
            TryVerifySignature(compoundPacket, messageBytes);
        }
    }

    return decodedPacket;
}
/// <summary>
/// Increments the StructureSize of a CREATE request by one (presumably to build an invalid
/// request for negative testing). Packets that are not CREATE requests are left untouched.
/// </summary>
/// <param name="packet">The packet to modify.</param>
private void ReplacePacketByStructureSize(Smb2Packet packet)
{
    Smb2CreateRequestPacket createRequest = packet as Smb2CreateRequestPacket;
    if (createRequest != null)
    {
        createRequest.PayLoad.StructureSize++;
    }
}
/// <summary>
/// Decodes one message from a transport payload. For TCP the 4-byte length prefix is validated and
/// stripped first; for other transports the whole buffer is treated as one message.
/// </summary>
/// <param name="messageBytes">The received bytes.</param>
/// <param name="role">The role of this decoder, client or server.</param>
/// <param name="transportType">The transport the bytes came from.</param>
/// <param name="ignoreCompoundFlag">Whether to decode as a single packet even when the compound flag is set.</param>
/// <param name="consumedLength">[OUT] Number of bytes consumed from <paramref name="messageBytes"/>.</param>
/// <param name="expectedLength">[OUT] Number of bytes needed when the TCP frame is incomplete.</param>
/// <returns>
/// null when the TCP frame is incomplete; otherwise whatever DecodeCompletePacket returns.
/// </returns>
public Smb2Packet DecodeTransportPayload(
    byte[] messageBytes,
    Smb2Role role,
    Smb2TransportType transportType,
    bool ignoreCompoundFlag,
    out int consumedLength,
    out int expectedLength)
{
    // TCP transport prefixes a 4-byte length at the beginning; NetBIOS does not.
    if (transportType == Smb2TransportType.Tcp)
    {
        if (messageBytes.Length < Smb2Consts.TcpPrefixedLenByteCount)
        {
            // Not even the prefix is here yet: ask for the full prefix, using the same constant as the check above.
            consumedLength = 0;
            expectedLength = Smb2Consts.TcpPrefixedLenByteCount;
            return null;
        }

        // Of the 4 prefix bytes only bytes 1..3 carry the (big-endian) SMB2 message length; byte 0 is ignored.
        int dataLenShouldHave = (messageBytes[1] << 16) + (messageBytes[2] << 8) + messageBytes[3];

        if (dataLenShouldHave > (messageBytes.Length - Smb2Consts.TcpPrefixedLenByteCount))
        {
            // Partial frame: report the total size (prefix + body) required.
            consumedLength = 0;
            expectedLength = Smb2Consts.TcpPrefixedLenByteCount + dataLenShouldHave;
            return null;
        }

        byte[] smb2Message = new byte[dataLenShouldHave];
        Array.Copy(messageBytes, Smb2Consts.TcpPrefixedLenByteCount, smb2Message, 0, smb2Message.Length);

        Smb2Packet packet = DecodeCompletePacket(
            smb2Message,
            role,
            ignoreCompoundFlag,
            0,
            0,
            out consumedLength,
            out expectedLength);

        // Ignore the consumedLength returned by DecodeCompletePacket(): there may be TCP padding
        // at the end of the frame that we are not interested in.
        consumedLength = dataLenShouldHave + Smb2Consts.TcpPrefixedLenByteCount;
        return packet;
    }
    else
    {
        Smb2Packet packet = DecodeCompletePacket(
            messageBytes,
            role,
            ignoreCompoundFlag,
            0,
            0,
            out consumedLength,
            out expectedLength);

        // Some packets carry unknown padding at the end, so the whole buffer is treated as consumed.
        consumedLength = messageBytes.Length;
        return packet;
    }
}
/// <summary>
/// Appends a deliberately invalid create context to a CREATE request, according to the
/// Create_ContextType field. Packets that are not CREATE requests are left untouched.
/// </summary>
/// <param name="packet">The packet to modify.</param>
/// <exception cref="ArgumentException">
/// Thrown when Create_ContextType is neither InvalidCreateContext nor InvalidCreateContextSize.
/// </exception>
private void ReplacePacketByInvalidCreateContext(Smb2Packet packet)
{
    Smb2CreateRequestPacket request = packet as Smb2CreateRequestPacket;
    if (request == null)
    {
        return;
    }

    string name = null;
    byte[] dataBuffer = null;

    if (Create_ContextType == CreateContextType.InvalidCreateContext)
    {
        // <235> Section 3.3.5.9: Windows Vista, Windows Server 2008, Windows 7, Windows Server 2008 R2, Windows 8, and Windows Server 2012 ignore create contexts having a NameLength
        // greater than 4 and ignores create contexts with length of 4 that are not specified in section 2.2.13.2.
        // So use three characters' name here
        name = "Inv";
    }
    else if (Create_ContextType == CreateContextType.InvalidCreateContextSize)
    {
        // Use SMB2_CREATE_QUERY_MAXIMAL_ACCESS_REQUEST to test InvalidCreateContextSize since it contains a data buffer.
        name = CreateContextNames.SMB2_CREATE_QUERY_MAXIMAL_ACCESS_REQUEST;
        var createQueryMaximalAccessRequestStruct = new CREATE_QUERY_MAXIMAL_ACCESS_REQUEST
        {
            Timestamp = new _FILETIME()
        };
        dataBuffer = TypeMarshal.ToBytes(createQueryMaximalAccessRequestStruct);
    }
    else
    {
        // Report the offending value; only the two invalid-context variants above are supported.
        throw new ArgumentException("Unsupported Create_ContextType: " + Create_ContextType);
    }

    var nameBuffer = Encoding.ASCII.GetBytes(name);

    var createContextStruct = new CREATE_CONTEXT();
    createContextStruct.Buffer = new byte[0];
    // 16 = presumably the size of the fixed CREATE_CONTEXT fields preceding Buffer (MS-SMB2 2.2.13.2); the name starts right after them.
    createContextStruct.NameOffset = 16;
    createContextStruct.NameLength = (ushort)nameBuffer.Length;
    createContextStruct.Buffer = createContextStruct.Buffer.Concat(nameBuffer).ToArray();

    if (dataBuffer != null && dataBuffer.Length > 0)
    {
        // Data is placed on an 8-byte boundary after the name.
        Smb2Utility.Align8(ref createContextStruct.Buffer);
        createContextStruct.DataOffset = (ushort)(16 + createContextStruct.Buffer.Length);
        createContextStruct.DataLength = (uint)dataBuffer.Length;
        createContextStruct.Buffer = createContextStruct.Buffer.Concat(dataBuffer).ToArray();
    }

    byte[] createContextValuesBuffer = new byte[0];
    Smb2Utility.Align8(ref createContextValuesBuffer);
    createContextValuesBuffer = createContextValuesBuffer.Concat(TypeMarshal.ToBytes(createContextStruct)).ToArray();

    if (Create_ContextType == CreateContextType.InvalidCreateContextSize)
    {
        // Change DataLength to invalid here, after marshalling.
        // Byte 12 is presumably the low-order byte of the little-endian DataLength field.
        createContextValuesBuffer[12] += 1;
    }

    // Append the create contexts to the request buffer on an 8-byte boundary and point the payload at them.
    Smb2Utility.Align8(ref request.Buffer);
    request.PayLoad.CreateContextsOffset = (uint)(request.BufferOffset + request.Buffer.Length);
    request.PayLoad.CreateContextsLength = (uint)createContextValuesBuffer.Length;
    request.Buffer = request.Buffer.Concat(createContextValuesBuffer).ToArray();
}
/// <summary>
/// Send packet to a client specified by the endpoint. This method is for negative tests; for normal
/// use, please use SendPacket(Smb2Packet packet).
/// A PacketSent event is recorded in the context before the packet bytes are sent.
/// </summary>
/// <param name="endpoint">The client endpoint</param>
/// <param name="packet">The packet</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="packet"/> is null.</exception>
public void SendPacket(Smb2Endpoint endpoint, Smb2Packet packet)
{
    // Fail before touching the context, so a null packet is never recorded as "sent".
    if (packet == null)
    {
        throw new ArgumentNullException("packet");
    }

    Smb2Event smb2Event = new Smb2Event();
    smb2Event.ConnectionId = endpoint.EndpointId;
    smb2Event.Packet = packet;
    smb2Event.Type = Smb2EventType.PacketSent;

    context.UpdateContext(smb2Event);

    SendPacket(endpoint, packet.ToBytes());
}
/// <summary>
/// Send packet to a client, using the endpoint stored in the packet itself.
/// </summary>
/// <param name="packet">The packet</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="packet"/> is null.</exception>
public void SendPacket(Smb2Packet packet)
{
    if (packet == null)
    {
        throw new ArgumentNullException("packet");
    }

    SendPacket(packet.Endpoint, packet);
}
/// <summary>
/// Checks that the message id(s) consumed by a received request are in the connection's command
/// sequence window, and removes every consumed id that was found from the window.
/// CANCEL requests are accepted without touching the window.
/// </summary>
/// <param name="packet">The received packet: an SMB negotiate request or an SMB2 single packet.</param>
/// <param name="connectionId">Key of the connection in connectionList.</param>
/// <returns>true when every consumed message id was in the window; otherwise false.</returns>
/// <exception cref="InvalidOperationException">
/// Thrown when the CreditCharge of a request is smaller than its payload size requires.
/// </exception>
/// <exception cref="ArgumentException">
/// Thrown when <paramref name="packet"/> is neither an SMB negotiate request nor an SMB2 single packet.
/// </exception>
private bool VerifyMessageId(Smb2Packet packet, int connectionId)
{
    List<ulong> messageIds = new List<ulong>();
    bool allMessageIdValid = true;

    SmbNegotiateRequestPacket smbNegotiate = packet as SmbNegotiateRequestPacket;
    if (smbNegotiate != null)
    {
        messageIds.Add(smbNegotiate.Header.Mid);
    }
    else
    {
        Smb2SinglePacket singlePacket = packet as Smb2SinglePacket;
        if (singlePacket == null)
        {
            // Previously this fell through to a NullReferenceException on singlePacket.Header.
            throw new ArgumentException("Only SMB negotiate requests and SMB2 single packets carry a verifiable message id.", "packet");
        }

        if (singlePacket is Smb2CancelRequestPacket)
        {
            return true;
        }

        ulong messageId = singlePacket.Header.MessageId;
        messageIds.Add(messageId);

        if (!connectionList[connectionId].commandSequenceWindow.Contains(messageId))
        {
            return false;
        }

        if (transportType == Smb2TransportType.Tcp
            && connectionList[connectionId].dialect == Smb2Consts.NegotiateDialect2_10String
            && singlePacket.Header.CreditCharge != 0)
        {
            uint maxLen = 0;
            switch (singlePacket.Header.Command)
            {
                case Smb2Command.READ:
                    maxLen = ((Smb2ReadRequestPacket)singlePacket).PayLoad.Length;
                    break;
                case Smb2Command.WRITE:
                    maxLen = ((Smb2WriteRequestPacket)singlePacket).PayLoad.Length;
                    break;
                case Smb2Command.CHANGE_NOTIFY:
                    maxLen = ((Smb2ChangeNotifyRequestPacket)singlePacket).PayLoad.OutputBufferLength;
                    break;
                case Smb2Command.QUERY_DIRECTORY:
                    maxLen = ((Smb2QueryDirectoryRequestPacket)singlePacket).PayLoad.OutputBufferLength;
                    break;
            }

            // CreditCharge >= (max(SendPayloadSize, Expected ResponsePayloadSize) - 1) / 65536 + 1
            // Computed in long: casting a uint length above int.MaxValue to int would wrap negative
            // and yield a too-small expected charge. maxLen == 0 still yields 1.
            long expectedCreditCharge = 1 + ((long)maxLen - 1) / 65536;
            if (expectedCreditCharge > singlePacket.Header.CreditCharge)
            {
                throw new InvalidOperationException(string.Format("The CreditCharge in header is not valid. The expected value is {0}, " +
                    "and the actual value is {1}", expectedCreditCharge, singlePacket.Header.CreditCharge));
            }

            // A multi-credit request consumes every id from MessageId to MessageId + CreditCharge - 1.
            // Each of them is checked against the window below; one that is missing marks the request invalid.
            for (int i = 1; i < singlePacket.Header.CreditCharge; i++)
            {
                messageIds.Add(messageId + (ulong)i);
            }
        }
    }

    foreach (ulong item in messageIds)
    {
        if (connectionList[connectionId].commandSequenceWindow.Contains(item))
        {
            connectionList[connectionId].RemoveMessageId(item);
        }
        else
        {
            allMessageIdValid = false;
        }
    }

    return allMessageIdValid;
}
/// <summary>
/// Copies the session key of the packet's session (looked up in globalSessionTable) into the
/// packet's SessionKey field. Only signed SMB2 single packets are touched.
/// </summary>
/// <param name="connectionId">Identifies the connection (not referenced by this method).</param>
/// <param name="packet">The packet</param>
private void SetSessionKeyInPacket(int connectionId, Smb2Packet packet)
{
    Smb2SinglePacket single = packet as Smb2SinglePacket;

    // Anything other than an SMB2 single packet is expected to be the SMB negotiate packet,
    // which needs no signature verification.
    if (single == null)
    {
        return;
    }

    bool isSigned = (single.Header.Flags & Packet_Header_Flags_Values.FLAGS_SIGNED) == Packet_Header_Flags_Values.FLAGS_SIGNED;
    if (!isSigned)
    {
        return;
    }

    single.SessionKey = globalSessionTable[single.GetSessionId()].sessionKey;
}