Skip to content

Handle timeout correctly on Socks5 Proxy #1342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/Renci.SshNet/Connection/Socks5Connector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socke
};
SocketAbstraction.Send(socket, greeting);

var socksVersion = SocketReadByte(socket);
var socksVersion = SocketReadByte(socket, connectionInfo.Timeout);
if (socksVersion != 0x05)
{
throw new ProxyException(string.Format("SOCKS Version '{0}' is not supported.", socksVersion));
}

var authenticationMethod = SocketReadByte(socket);
var authenticationMethod = SocketReadByte(socket, connectionInfo.Timeout);
switch (authenticationMethod)
{
case 0x00:
Expand Down Expand Up @@ -86,13 +86,13 @@ protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socke
SocketAbstraction.Send(socket, connectionRequest);

// Read Server SOCKS5 version
if (SocketReadByte(socket) != 5)
if (SocketReadByte(socket, connectionInfo.Timeout) != 5)
{
throw new ProxyException("SOCKS5: Version 5 is expected.");
}

// Read response code
var status = SocketReadByte(socket);
var status = SocketReadByte(socket, connectionInfo.Timeout);

switch (status)
{
Expand All @@ -119,21 +119,21 @@ protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socke
}

// Read reserved byte
if (SocketReadByte(socket) != 0)
if (SocketReadByte(socket, connectionInfo.Timeout) != 0)
{
throw new ProxyException("SOCKS5: 0 byte is expected.");
}

var addressType = SocketReadByte(socket);
var addressType = SocketReadByte(socket, connectionInfo.Timeout);
switch (addressType)
{
case 0x01:
var ipv4 = new byte[4];
_ = SocketRead(socket, ipv4, 0, 4);
_ = SocketRead(socket, ipv4, 0, 4, connectionInfo.Timeout);
break;
case 0x04:
var ipv6 = new byte[16];
_ =SocketRead(socket, ipv6, 0, 16);
_ =SocketRead(socket, ipv6, 0, 16, connectionInfo.Timeout);
break;
default:
throw new ProxyException(string.Format("Address type '{0}' is not supported.", addressType));
Expand All @@ -142,7 +142,7 @@ protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socke
var port = new byte[2];

// Read 2 bytes to be ignored
_ = SocketRead(socket, port, 0, 2);
_ = SocketRead(socket, port, 0, 2, connectionInfo.Timeout);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Net;
using System.Net.Sockets;

using Microsoft.VisualStudio.TestTools.UnitTesting;

using Moq;

using Renci.SshNet.Common;
using Renci.SshNet.Tests.Common;

namespace Renci.SshNet.Tests.Classes.Connection
{
[TestClass]
public class Socks5ConnectorTest_Connect_TimeoutConnectionReply : Socks5ConnectorTestBase
{
private ConnectionInfo _connectionInfo;
private Exception _actualException;
private AsyncSocketListener _proxyServer;
private Socket _clientSocket;
private List<byte> _bytesReceivedByProxy;
private Stopwatch _stopWatch;

protected override void SetupData()
{
base.SetupData();

var random = new Random();

_connectionInfo = CreateConnectionInfo("proxyUser", "proxyPwd");
_connectionInfo.Timeout = TimeSpan.FromMilliseconds(random.Next(50, 200));
_stopWatch = new Stopwatch();
_bytesReceivedByProxy = new List<byte>();

_clientSocket = SocketFactory.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);

_proxyServer = new AsyncSocketListener(new IPEndPoint(IPAddress.Loopback, _connectionInfo.ProxyPort));
_proxyServer.BytesReceived += (bytesReceived, socket) => {
_bytesReceivedByProxy.AddRange(bytesReceived);

if (_bytesReceivedByProxy.Count == 4) {
_ = socket.Send(new byte[]
{
// SOCKS version
0x05,
// Require no authentication
0x00
});
}
};
_proxyServer.Start();
}

protected override void SetupMocks()
{
_ = SocketFactoryMock.Setup(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
.Returns(_clientSocket);
}

protected override void TearDown()
{
base.TearDown();

_proxyServer?.Dispose();
_clientSocket?.Dispose();
}

protected override void Act()
{
_stopWatch.Start();

try
{
_ = Connector.Connect(_connectionInfo);
Assert.Fail();
}
catch (SocketException ex) {
_actualException = ex;
}
catch (SshOperationTimeoutException ex) {
_actualException = ex;
}
finally
{
_stopWatch.Stop();
}
}

[TestMethod]
public void ConnectShouldHaveThrownSshOperationTimeoutException() {
Assert.IsNull(_actualException.InnerException);
Assert.IsInstanceOfType<SshOperationTimeoutException>(_actualException);
}

[TestMethod]
public void ConnectShouldHaveRespectedTimeout()
{
var errorText = string.Format("Elapsed: {0}, Timeout: {1}",
_stopWatch.ElapsedMilliseconds,
_connectionInfo.Timeout.TotalMilliseconds);

// Compare elapsed time with configured timeout, allowing for a margin of error
Assert.IsTrue(_stopWatch.ElapsedMilliseconds >= _connectionInfo.Timeout.TotalMilliseconds - 10, errorText);
Assert.IsTrue(_stopWatch.ElapsedMilliseconds < _connectionInfo.Timeout.TotalMilliseconds + 100, errorText);
}

[TestMethod]
public void CreateOnSocketFactoryShouldHaveBeenInvokedOnce()
{
SocketFactoryMock.Verify(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp),
Times.Once());
}
}
}