Skip to content

Commit

Permalink
Update per review
Browse files Browse the repository at this point in the history
  • Loading branch information
scott-xu committed Dec 30, 2024
1 parent 24d142b commit 510a795
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 52 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -566,8 +566,7 @@ Supported server key algorithms:
- rsa-sha2-256

Supported key exchange methods:
- curve25519-sha256
- curve25519-sha256@libssh.org
- curve25519-sha256, curve25519-sha256@libssh.org
- ecdh-sha2-nistp256
- ecdh-sha2-nistp384
- ecdh-sha2-nistp521
Expand Down
21 changes: 8 additions & 13 deletions src/Tmds.Ssh/Curve25519KeyExchange.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,11 @@
namespace Tmds.Ssh;

// Curve25519 Key Exchange: https://datatracker.ietf.org/doc/html/rfc8731
class Curve25519KeyExchange : KeyExchange, IKeyExchangeAlgorithm
class Curve25519KeyExchange : KeyExchange
{
public Curve25519KeyExchange()
:base(HashAlgorithmName.SHA256)
{ }
private readonly HashAlgorithmName _hashAlgorithmName = HashAlgorithmName.SHA256;

public async Task<KeyExchangeOutput> TryExchangeAsync(KeyExchangeContext context, IHostKeyVerification hostKeyVerification, Packet firstPacket, KeyExchangeInput input, ILogger logger, CancellationToken ct)
public override async Task<KeyExchangeOutput> TryExchangeAsync(KeyExchangeContext context, IHostKeyVerification hostKeyVerification, Packet firstPacket, KeyExchangeInput input, ILogger logger, CancellationToken ct)
{
var sequencePool = context.SequencePool;
var connectionInfo = input.ConnectionInfo;
Expand Down Expand Up @@ -57,15 +55,15 @@ public async Task<KeyExchangeOutput> TryExchangeAsync(KeyExchangeContext context
}

// Generate exchange hash.
byte[] exchangeHash = CalculateExchangeHash(sequencePool, input.ConnectionInfo, input.ClientKexInitMsg, input.ServerKexInitMsg, ecdhReply.public_host_key.Data, q_c, ecdhReply.q_s, sharedSecret);
byte[] exchangeHash = CalculateExchangeHash(sequencePool, input.ConnectionInfo, input.ClientKexInitMsg, input.ServerKexInitMsg, ecdhReply.public_host_key.Data, q_c, ecdhReply.q_s, sharedSecret, _hashAlgorithmName);

// Verify the server's signature.
VerifySignature(publicHostKey, input.HostKeyAlgorithms, exchangeHash, ecdhReply.exchange_hash_signature, connectionInfo);

return CalculateKeyExchangeOutput(input, sequencePool, sharedSecret, exchangeHash);
return CalculateKeyExchangeOutput(input, sequencePool, sharedSecret, exchangeHash, _hashAlgorithmName);
}

private byte[] CalculateExchangeHash(SequencePool sequencePool, SshConnectionInfo connectionInfo, ReadOnlyPacket clientKexInitMsg, ReadOnlyPacket serverKexInitMsg, byte[] public_host_key, byte[] q_c, byte[] q_s, BigInteger sharedSecret)
private static byte[] CalculateExchangeHash(SequencePool sequencePool, SshConnectionInfo connectionInfo, ReadOnlyPacket clientKexInitMsg, ReadOnlyPacket serverKexInitMsg, byte[] public_host_key, byte[] q_c, byte[] q_s, BigInteger sharedSecret, HashAlgorithmName hashAlgorithmName)
{
/*
string V_C, client's identification string (CR and LF excluded)
Expand All @@ -88,15 +86,15 @@ private byte[] CalculateExchangeHash(SequencePool sequencePool, SshConnectionInf
writer.WriteString(q_s);
writer.WriteMPInt(sharedSecret);

using IncrementalHash hash = IncrementalHash.CreateHash(_hashAlgorithmName);
using IncrementalHash hash = IncrementalHash.CreateHash(hashAlgorithmName);
foreach (var segment in sequence.AsReadOnlySequence())
{
hash.AppendData(segment.Span);
}
return hash.GetHashAndReset();
}

private BigInteger DeriveSharedSecret(AsymmetricKeyParameter privateKey, AsymmetricKeyParameter peerPublicKey)
private static BigInteger DeriveSharedSecret(AsymmetricKeyParameter privateKey, AsymmetricKeyParameter peerPublicKey)
{
var keyAgreement = new X25519Agreement();
keyAgreement.Init(privateKey);
Expand All @@ -108,9 +106,6 @@ private BigInteger DeriveSharedSecret(AsymmetricKeyParameter privateKey, Asymmet
return sharedSecret;
}

public void Dispose()
{ }

private static Packet CreateEcdhInitMessage(SequencePool sequencePool, ReadOnlySpan<byte> q_c)
{
using var packet = sequencePool.RentPacket();
Expand Down
24 changes: 11 additions & 13 deletions src/Tmds.Ssh/ECDHKeyExchange.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@
namespace Tmds.Ssh;

// ECDH Key Exchange: https://tools.ietf.org/html/rfc5656#section-4
class ECDHKeyExchange : KeyExchange, IKeyExchangeAlgorithm
class ECDHKeyExchange : KeyExchange
{
private readonly ECCurve _ecCurve;
private readonly HashAlgorithmName _hashAlgorithmName;

public ECDHKeyExchange(ECCurve ecCurve, HashAlgorithmName hashAlgorithmName)
: base(hashAlgorithmName)
{
_ecCurve = ecCurve;
this._hashAlgorithmName = hashAlgorithmName;
}

public async Task<KeyExchangeOutput> TryExchangeAsync(KeyExchangeContext context, IHostKeyVerification hostKeyVerification, Packet firstPacket, KeyExchangeInput input, ILogger logger, CancellationToken ct)
public override async Task<KeyExchangeOutput> TryExchangeAsync(KeyExchangeContext context, IHostKeyVerification hostKeyVerification, Packet firstPacket, KeyExchangeInput input, ILogger logger, CancellationToken ct)
{
var sequencePool = context.SequencePool;
var connectionInfo = input.ConnectionInfo;
Expand All @@ -41,23 +42,23 @@ public async Task<KeyExchangeOutput> TryExchangeAsync(KeyExchangeContext context
BigInteger sharedSecret;
try
{
sharedSecret = DeriveSharedSecret(ecdh, ecdhReply.q_s);
sharedSecret = DeriveSharedSecret(ecdh, _ecCurve, ecdhReply.q_s);
}
catch (Exception ex)
{
throw new ConnectFailedException(ConnectFailedReason.KeyExchangeFailed, "Cannot determine shared secret.", connectionInfo, ex);
}

// Generate exchange hash.
byte[] exchangeHash = CalculateExchangeHash(sequencePool, input.ConnectionInfo, input.ClientKexInitMsg, input.ServerKexInitMsg, ecdhReply.public_host_key.Data, q_c, ecdhReply.q_s, sharedSecret);
byte[] exchangeHash = CalculateExchangeHash(sequencePool, input.ConnectionInfo, input.ClientKexInitMsg, input.ServerKexInitMsg, ecdhReply.public_host_key.Data, q_c, ecdhReply.q_s, sharedSecret, _hashAlgorithmName);

// Verify the server's signature.
VerifySignature(publicHostKey, input.HostKeyAlgorithms, exchangeHash, ecdhReply.exchange_hash_signature, connectionInfo);

return CalculateKeyExchangeOutput(input, sequencePool, sharedSecret, exchangeHash);
return CalculateKeyExchangeOutput(input, sequencePool, sharedSecret, exchangeHash, _hashAlgorithmName);
}

private byte[] CalculateExchangeHash(SequencePool sequencePool, SshConnectionInfo connectionInfo, ReadOnlyPacket clientKexInitMsg, ReadOnlyPacket serverKexInitMsg, byte[] public_host_key, ECPoint q_c, ECPoint q_s, BigInteger sharedSecret)
private static byte[] CalculateExchangeHash(SequencePool sequencePool, SshConnectionInfo connectionInfo, ReadOnlyPacket clientKexInitMsg, ReadOnlyPacket serverKexInitMsg, byte[] public_host_key, ECPoint q_c, ECPoint q_s, BigInteger sharedSecret, HashAlgorithmName hashAlgorithmName)
{
/*
string V_C, client's identification string (CR and LF excluded)
Expand All @@ -80,19 +81,19 @@ private byte[] CalculateExchangeHash(SequencePool sequencePool, SshConnectionInf
writer.WriteString(q_s);
writer.WriteMPInt(sharedSecret);

using IncrementalHash hash = IncrementalHash.CreateHash(_hashAlgorithmName);
using IncrementalHash hash = IncrementalHash.CreateHash(hashAlgorithmName);
foreach (var segment in sequence.AsReadOnlySequence())
{
hash.AppendData(segment.Span);
}
return hash.GetHashAndReset();
}

private BigInteger DeriveSharedSecret(ECDiffieHellman ecdh, ECPoint q)
private static BigInteger DeriveSharedSecret(ECDiffieHellman ecdh, ECCurve curve, ECPoint q)
{
ECParameters parameters = new ECParameters
{
Curve = _ecCurve,
Curve = curve,
Q = q
};
using ECDiffieHellman peerEcdh = ECDiffieHellman.Create(parameters);
Expand All @@ -103,9 +104,6 @@ private BigInteger DeriveSharedSecret(ECDiffieHellman ecdh, ECPoint q)
return sharedSecret;
}

public void Dispose()
{ }

private static Packet CreateEcdhInitMessage(SequencePool sequencePool, ECPoint q_c)
{
using var packet = sequencePool.RentPacket();
Expand Down
47 changes: 23 additions & 24 deletions src/Tmds.Ssh/KeyExchange.cs
Original file line number Diff line number Diff line change
@@ -1,24 +1,15 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using System.Buffers;
using System.Numerics;
using System.Security.Cryptography;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;

namespace Tmds.Ssh;

abstract class KeyExchange
abstract class KeyExchange : IKeyExchangeAlgorithm
{
protected readonly HashAlgorithmName _hashAlgorithmName;
public abstract Task<KeyExchangeOutput> TryExchangeAsync(KeyExchangeContext context, IHostKeyVerification hostKeyVerification, Packet firstPacket, KeyExchangeInput input, ILogger logger, CancellationToken ct);

public KeyExchange(HashAlgorithmName hashAlgorithmName)
{
this._hashAlgorithmName = hashAlgorithmName;
}

public static async Task<PublicKey> VerifyHostKeyAsync(IHostKeyVerification hostKeyVerification, KeyExchangeInput input, SshKey public_host_key, CancellationToken ct)
protected static async Task<PublicKey> VerifyHostKeyAsync(IHostKeyVerification hostKeyVerification, KeyExchangeInput input, SshKey public_host_key, CancellationToken ct)
{
var connectionInfo = input.ConnectionInfo;
connectionInfo.ServerKey = new HostKey(public_host_key);
Expand All @@ -33,30 +24,30 @@ public static async Task<PublicKey> VerifyHostKeyAsync(IHostKeyVerification host
return publicHostKey;
}

public static void VerifySignature(PublicKey publicHostKey, IReadOnlyList<Name> allowedAlgorithms, byte[] exchangeHash, ReadOnlySequence<byte> exchange_hash_signature, SshConnectionInfo connectionInfo)
protected static void VerifySignature(PublicKey publicHostKey, IReadOnlyList<Name> allowedAlgorithms, byte[] exchangeHash, ReadOnlySequence<byte> exchange_hash_signature, SshConnectionInfo connectionInfo)
{
if (!publicHostKey.VerifySignature(allowedAlgorithms, exchangeHash, exchange_hash_signature))
{
throw new ConnectFailedException(ConnectFailedReason.KeyExchangeFailed, "Signature does not match host key.", connectionInfo);
}
}

public KeyExchangeOutput CalculateKeyExchangeOutput(KeyExchangeInput input, SequencePool sequencePool, BigInteger sharedSecret, byte[] exchangeHash)
protected static KeyExchangeOutput CalculateKeyExchangeOutput(KeyExchangeInput input, SequencePool sequencePool, BigInteger sharedSecret, byte[] exchangeHash, HashAlgorithmName hashAlgorithmName)
{
byte[] sessionId = input.ConnectionInfo.SessionId ?? exchangeHash;
byte[] initialIVC2S = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'A', sessionId, input.InitialIVC2SLength);
byte[] initialIVS2C = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'B', sessionId, input.InitialIVS2CLength);
byte[] encryptionKeyC2S = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'C', sessionId, input.EncryptionKeyC2SLength);
byte[] encryptionKeyS2C = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'D', sessionId, input.EncryptionKeyS2CLength);
byte[] integrityKeyC2S = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'E', sessionId, input.IntegrityKeyC2SLength);
byte[] integrityKeyS2C = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'F', sessionId, input.IntegrityKeyS2CLength);
byte[] initialIVC2S = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'A', sessionId, input.InitialIVC2SLength, hashAlgorithmName);
byte[] initialIVS2C = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'B', sessionId, input.InitialIVS2CLength, hashAlgorithmName);
byte[] encryptionKeyC2S = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'C', sessionId, input.EncryptionKeyC2SLength, hashAlgorithmName);
byte[] encryptionKeyS2C = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'D', sessionId, input.EncryptionKeyS2CLength, hashAlgorithmName);
byte[] integrityKeyC2S = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'E', sessionId, input.IntegrityKeyC2SLength, hashAlgorithmName);
byte[] integrityKeyS2C = CalculateKey(sequencePool, sharedSecret, exchangeHash, (byte)'F', sessionId, input.IntegrityKeyS2CLength, hashAlgorithmName);

return new KeyExchangeOutput(exchangeHash,
initialIVS2C, encryptionKeyS2C, integrityKeyS2C,
initialIVC2S, encryptionKeyC2S, integrityKeyC2S);
}

private byte[] CalculateKey(SequencePool sequencePool, BigInteger sharedSecret, byte[] exchangeHash, byte c, byte[] sessionId, int keyLength)
protected static byte[] CalculateKey(SequencePool sequencePool, BigInteger sharedSecret, byte[] exchangeHash, byte c, byte[] sessionId, int keyLength, HashAlgorithmName hashAlgorithmName)
{
// https://tools.ietf.org/html/rfc4253#section-7.2

Expand All @@ -71,7 +62,7 @@ private byte[] CalculateKey(SequencePool sequencePool, BigInteger sharedSecret,
writer.WriteByte(c);
writer.Write(sessionId);

using IncrementalHash hash = IncrementalHash.CreateHash(_hashAlgorithmName);
using IncrementalHash hash = IncrementalHash.CreateHash(hashAlgorithmName);
foreach (var segment in sequence.AsReadOnlySequence())
{
hash.AppendData(segment.Span);
Expand Down Expand Up @@ -108,4 +99,12 @@ static void Append(byte[] key, byte[] append, ref int offset)
}
}

protected virtual void Dispose(bool disposing)
{ }

public void Dispose()
{
Dispose(disposing: true);
GC.SuppressFinalize(this);
}
}

0 comments on commit 510a795

Please sign in to comment.