Skip to content

Commit

Permalink
LocalForward: support serving from unix socket. (#260)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmds authored Dec 6, 2024
1 parent a766b92 commit f8e6526
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 23 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ class SshClient : IDisposable

Task<SshDataStream> OpenTcpConnectionAsync(string host, int port, CancellationToken cancellationToken = default);
Task<SshDataStream> OpenUnixConnectionAsync(string path, CancellationToken cancellationToken = default);
Task<LocalForward> StartForwardTcpAsync(EndPoint bindEndpoint, string remoteHost, int remotePort, CancellationToken cancellationToken = default);
Task<LocalForward> StartForwardUnixAsync(EndPoint bindEndpoint, string remotePath, CancellationToken cancellationToken = default);
// bindEP can be an IPEndPoint or a UnixDomainSocketEndPoint.
Task<LocalForward> StartForwardTcpAsync(EndPoint bindEP, string remoteHost, int remotePort, CancellationToken cancellationToken = default);
Task<LocalForward> StartForwardUnixAsync(EndPoint bindEP, string remotePath, CancellationToken cancellationToken = default);

Task<SftpClient> OpenSftpClientAsync(CancellationToken cancellationToken);
Task<SftpClient> OpenSftpClientAsync(SftpClientOptions? options = null, CancellationToken cancellationToken = default)
Expand Down
47 changes: 27 additions & 20 deletions src/Tmds.Ssh/LocalForward.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,19 @@ internal LocalForward(SshSession session, ILogger<LocalForward> logger)
_remoteEndPoint = "";
}

internal void StartUnixForward(EndPoint bindEndpoint, string remotePath)
internal void StartUnixForward(EndPoint bindEP, string remotePath)
{
CheckBindEndPoint(bindEndpoint);
CheckBindEndPoint(bindEP);
ArgumentException.ThrowIfNullOrEmpty(remotePath);

Func<CancellationToken, Task<SshDataStream>> connect = async ct => await _session.OpenUnixConnectionChannelAsync(remotePath, ct).ConfigureAwait(false);

Start(bindEndpoint, remotePath, connect);
Start(bindEP, remotePath, connect);
}

internal void StartTcpForward(EndPoint bindEndpoint, string remoteHost, int remotePort)
internal void StartTcpForward(EndPoint bindEP, string remoteHost, int remotePort)
{
CheckBindEndPoint(bindEndpoint);
CheckBindEndPoint(bindEP);
ArgumentException.ThrowIfNullOrEmpty(remoteHost);
if (remotePort < 0 || remotePort > 0xffff)
{
Expand All @@ -57,38 +57,42 @@ internal void StartTcpForward(EndPoint bindEndpoint, string remoteHost, int remo

Func<CancellationToken, Task<SshDataStream>> connect = async ct => await _session.OpenTcpConnectionChannelAsync(remoteHost, remotePort, ct).ConfigureAwait(false);

Start(bindEndpoint, $"{remoteHost}:{remotePort}", connect);
Start(bindEP, $"{remoteHost}:{remotePort}", connect);
}

private void CheckBindEndPoint(EndPoint bindEndpoint)
private void CheckBindEndPoint(EndPoint bindEP)
{
ArgumentNullException.ThrowIfNull(bindEndpoint);
if (bindEndpoint is not IPEndPoint)
ArgumentNullException.ThrowIfNull(bindEP);
if (bindEP is not IPEndPoint and not UnixDomainSocketEndPoint)
{
throw new ArgumentException($"Unsupported EndPoint type: {bindEndpoint.GetType().FullName}.");
throw new ArgumentException($"Unsupported EndPoint type: {bindEP.GetType().FullName}.");
}
}

private void Start(EndPoint bindEndpoint, string remoteEndpoint, Func<CancellationToken, Task<SshDataStream>> connectToRemote)
private void Start(EndPoint bindEP, string remoteEndPoint, Func<CancellationToken, Task<SshDataStream>> connectToRemote)
{
// Assign to bindEndPoint in case we fail to bind/listen so we have an address for logging.
_localEndPoint = bindEndpoint;
_remoteEndPoint = remoteEndpoint;
// Assign to bindEP in case we fail to bind/listen so we have an address for logging.
_localEndPoint = bindEP;
_remoteEndPoint = remoteEndPoint;
_connectToRemote = connectToRemote;

try
{
if (bindEndpoint is IPEndPoint ipEndPoint)
if (bindEP is IPEndPoint ipEndPoint)
{
_serverSocket = new Socket(ipEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
}
else if (bindEP is UnixDomainSocketEndPoint unixEndPoint)
{
_serverSocket = new Socket(unixEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Unspecified);
}
else
{
// Type must be validated before calling this method.
throw new InvalidOperationException($"Unsupported EndPoint type: {bindEndpoint.GetType().FullName}.");
throw new InvalidOperationException($"Unsupported EndPoint type: {bindEP.GetType().FullName}.");
}

_serverSocket.Bind(bindEndpoint);
_serverSocket.Bind(bindEP);
_serverSocket.Listen();

EndPoint localEndPoint = _serverSocket.LocalEndPoint!;
Expand Down Expand Up @@ -123,16 +127,19 @@ private async Task AcceptLoop(EndPoint localEndPoint)
}
}

private async Task Accept(Socket acceptedSocket, EndPoint localEndpoint)
private async Task Accept(Socket acceptedSocket, EndPoint localEndPoint)
{
Debug.Assert(_connectToRemote is not null);
SshDataStream? forwardStream = null;
EndPoint? peerEndPoint = null;
try
{
peerEndPoint = acceptedSocket.RemoteEndPoint!;
_logger.AcceptConnection(localEndpoint, peerEndPoint, _remoteEndPoint);
acceptedSocket.NoDelay = true;
_logger.AcceptConnection(localEndPoint, peerEndPoint, _remoteEndPoint);
if (acceptedSocket.ProtocolType == ProtocolType.Tcp)
{
acceptedSocket.NoDelay = true;
}

// We may want to add a timeout option, and the ability to stop the lister on some conditions like nr of successive fails to connect to the remote.
forwardStream = await _connectToRemote(_cancel!.Token).ConfigureAwait(false);
Expand Down
17 changes: 16 additions & 1 deletion test/Tmds.Ssh.Tests/LocalForwardTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,29 @@ public async Task ForwardsUnix()
await AssertForwards(localForward);
}

[Fact]
public async Task BindUnixSocket()
{
using var client = await _sshServer.CreateClientAsync();

// start a an echo server using socat.
const int socatPort = 1234;
using var soCatProcess = await client.ExecuteAsync($"socat -v tcp-l:{socatPort},fork exec:'/bin/cat'");
await Task.Delay(SocatStartDelay); // wait a little for socat to start.

var ep = new UnixDomainSocketEndPoint(Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()));
using var localForward = await client.StartForwardTcpAsync(ep, "localhost", socatPort);
await AssertForwards(localForward);
}

private async Task AssertForwards(LocalForward localForward)
{
byte[] helloWorldBytes = Encoding.UTF8.GetBytes("hello world");
byte[] receiveBuffer = new byte[128];
for (int i = 0; i < 2; i++)
{
EndPoint endPoint = localForward.EndPoint!;
using var socket = new Socket(endPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
using var socket = new Socket(endPoint.AddressFamily, SocketType.Stream, endPoint.AddressFamily == AddressFamily.InterNetwork ? ProtocolType.Tcp : ProtocolType.Unspecified);
await socket.ConnectAsync(endPoint);

for (int j = 0; j < 2; j++)
Expand Down

0 comments on commit f8e6526

Please sign in to comment.