Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WASI] sockets #106977

Merged
merged 16 commits into from
Oct 2, 2024
Prev Previous commit
Next Next commit
fix PacketInformation and enable the test
  • Loading branch information
pavelsavara committed Sep 26, 2024
commit 1a14c50e4fa5100d2be7bae57510e4b8db257c52
16 changes: 14 additions & 2 deletions src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public partial class Socket : IDisposable
private ProtocolType _protocolType;

// Bool marked true if the native socket option IP_PKTINFO or IPV6_PKTINFO has been set.
private bool _receivingPacketInformation;
private bool _receivingPacketInformation = OperatingSystem.IsWasi();

private int _closeTimeout = Socket.DefaultCloseTimeout;
private bool _disposed;
Expand Down Expand Up @@ -3166,7 +3166,7 @@ private bool ReceiveMessageFromAsync(SocketAsyncEventArgs e, CancellationToken c
// e.m_SocketAddres for Create to work later.
e.RemoteEndPoint = endPointSnapshot;

if (!OperatingSystem.IsWasi()) SetReceivingPacketInformation();
SetReceivingPacketInformation();

// Prepare for and make the native call.
e.StartOperationCommon(this, SocketAsyncOperation.ReceiveMessageFrom);
Expand Down Expand Up @@ -3559,6 +3559,8 @@ internal void SetReceivingPacketInformation()
{
if (!_receivingPacketInformation)
{
if (OperatingSystem.IsWasi()) return; // WASI is always set to receive PacketInformation

// DualMode: When bound to IPv6Any you must enable both socket options.
// When bound to an IPv4 mapped IPv6 address you must enable the IPv4 socket option.
IPEndPoint? ipEndPoint = _rightEndPoint as IPEndPoint;
Expand Down Expand Up @@ -3586,6 +3588,16 @@ internal void SetReceivingPacketInformation()

internal unsafe void SetSocketOption(SocketOptionLevel optionLevel, SocketOptionName optionName, int optionValue, bool silent)
{
// WASI is always set to receive PacketInformation
if (OperatingSystem.IsWasi() && optionName == SocketOptionName.PacketInformation)
{
if (optionValue == 0)
{
UpdateStatusAfterSocketOptionErrorAndThrowException(SocketError.ProtocolOption);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we need to throw if the platform does more than asked for. If we have concerns we can can strip to info internally IMHO.

Copy link
Member Author

@pavelsavara pavelsavara Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is only throwing when user asks to disable PacketInformation

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right. We can still simply throw it out. And not receiving the info is default AFAIK. mostly perf improvement IMHo but that probably does not matter for your use case.

}
return;
}

if (silent && (Disposed || _handle.IsInvalid))
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, "skipping the call");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Threading;
using System.Threading.Tasks;

using Xunit;

namespace System.Net.Sockets.Tests
{
[SkipOnPlatform(TestPlatforms.Wasi, "Wasi doesn't support IPPacketInformation")]
public class IPPacketInformationTest
{
[Fact]
Expand All @@ -28,9 +28,9 @@ public void GetHashCode_DefaultValues_Success()
}

[Fact]
public void Equals_NonDefaultValue_Success()
public async Task Equals_NonDefaultValue_Success()
{
IPPacketInformation packetInfo = GetNonDefaultIPPacketInformation();
IPPacketInformation packetInfo = await GetNonDefaultIPPacketInformation();
IPPacketInformation packetInfoCopy = packetInfo;

Assert.Equal(packetInfo, packetInfoCopy);
Expand All @@ -49,40 +49,28 @@ public void Equals_NonDefaultValue_Success()
}

[Fact]
public void GetHashCode_NonDefaultValue_Success()
public async Task GetHashCode_NonDefaultValue_Success()
{
IPPacketInformation packetInfo = GetNonDefaultIPPacketInformation();
IPPacketInformation packetInfo = await GetNonDefaultIPPacketInformation();

Assert.Equal(packetInfo.GetHashCode(), packetInfo.GetHashCode());
}

private IPPacketInformation GetNonDefaultIPPacketInformation()
private async Task<IPPacketInformation> GetNonDefaultIPPacketInformation()
{
const int ReceiveTimeout = 10000;

using (var receiver = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp))
using (var sender = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp))
{
int port = receiver.BindToAnonymousPort(IPAddress.Loopback);

var waitHandle = new ManualResetEvent(false);

SocketAsyncEventArgs receiveArgs = new SocketAsyncEventArgs {
RemoteEndPoint = new IPEndPoint(IPAddress.Loopback, port),
UserToken = waitHandle
};

receiveArgs.SetBuffer(new byte[1], 0, 1);
receiveArgs.Completed += (_, args) => ((ManualResetEvent)args.UserToken).Set();

Assert.True(receiver.ReceiveMessageFromAsync(receiveArgs), "receiver.ReceiveMessageFromAsync");

// Send a few packets, in case they aren't delivered reliably.
sender.SendTo(new byte[1], new IPEndPoint(IPAddress.Loopback, port));
var receiveTask = receiver.ReceiveMessageFromAsync(new byte[1], new IPEndPoint(IPAddress.Loopback, port));
var sendTask = sender.SendToAsync(new byte[1], new IPEndPoint(IPAddress.Loopback, port));

Assert.True(await Task.WhenAny(receiveTask, Task.Delay(TestSettings.PassingTestTimeout)) == receiveTask, "Timed out");

Assert.True(waitHandle.WaitOne(ReceiveTimeout), "waitHandle.WaitOne");
var result = await receiveTask;

return receiveArgs.ReceiveMessageFromPacketInfo;
return result.PacketInformation;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public async Task ReceiveSent_UDP_Success(bool ipv4)
using Socket receiver = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
using Socket sender = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp);

if (!OperatingSystem.IsWasi()) receiver.SetSocketOption(ipv4 ? SocketOptionLevel.IP : SocketOptionLevel.IPv6, SocketOptionName.PacketInformation, true);
receiver.SetSocketOption(ipv4 ? SocketOptionLevel.IP : SocketOptionLevel.IPv6, SocketOptionName.PacketInformation, true);
ConfigureNonBlocking(sender);
ConfigureNonBlocking(receiver);

Expand All @@ -154,11 +154,8 @@ public async Task ReceiveSent_UDP_Success(bool ipv4)
AssertExtensions.SequenceEqual(emptyBuffer, new ReadOnlySpan<byte>(receiveInternalBuffer, 0, Offset));
AssertExtensions.SequenceEqual(sendBuffer, new ReadOnlySpan<byte>(receiveInternalBuffer, Offset, DatagramSize));
Assert.Equal(sender.LocalEndPoint, result.RemoteEndPoint);
if (!OperatingSystem.IsWasi())
{
IPPacketInformation packetInformation = result.PacketInformation;
Assert.Equal(((IPEndPoint)sender.LocalEndPoint).Address, packetInformation.Address);
}
IPPacketInformation packetInformation = result.PacketInformation;
Assert.Equal(((IPEndPoint)sender.LocalEndPoint).Address, packetInformation.Address);
}
}

Expand Down
48 changes: 33 additions & 15 deletions src/native/libs/System.Native/pal_networking.c
Original file line number Diff line number Diff line change
Expand Up @@ -1054,7 +1054,6 @@ static struct cmsghdr* GET_CMSG_NXTHDR(struct msghdr* mhdr, struct cmsghdr* cmsg
#pragma clang diagnostic pop
#endif
}
#endif // CMSG_SPACE

int32_t
SystemNative_TryGetIPPacketInformation(MessageHeader* messageHeader, int32_t isIPv4, IPPacketInformation* packetInfo)
Expand All @@ -1064,7 +1063,6 @@ SystemNative_TryGetIPPacketInformation(MessageHeader* messageHeader, int32_t isI
return 0;
}

#if defined(CMSG_SPACE)
struct msghdr header;
ConvertMessageHeaderToMsghdr(&header, messageHeader, -1);

Expand Down Expand Up @@ -1093,13 +1091,35 @@ SystemNative_TryGetIPPacketInformation(MessageHeader* messageHeader, int32_t isI
}

return 0;
#else // CMSG_SPACE
(void)messageHeader;
(void)isIPv4;
(void)packetInfo;
return Error_ENOTSUP;
#endif // CMSG_SPACE
}
#else // !CMSG_SPACE
int32_t
SystemNative_TryGetIPPacketInformation(MessageHeader* messageHeader, int32_t isIPv4, IPPacketInformation* packetInfo)
{
if (messageHeader == NULL || packetInfo == NULL)
{
return 0;
}

if (isIPv4 != 0)
{
struct sockaddr_in* inetSockAddr = (struct sockaddr_in*)messageHeader->SocketAddress;

ConvertInAddrToByteArray(&packetInfo->Address.Address[0], NUM_BYTES_IN_IPV4_ADDRESS, &inetSockAddr->sin_addr);
packetInfo->Address.IsIPv6 = 0;
}
else
{
struct sockaddr_in6* inet6SockAddr = (struct sockaddr_in6*)messageHeader->SocketAddress;

ConvertIn6AddrToByteArray(&packetInfo->Address.Address[0], NUM_BYTES_IN_IPV6_ADDRESS, &inet6SockAddr->sin6_addr);
packetInfo->Address.IsIPv6 = 1;
packetInfo->Address.ScopeId = inet6SockAddr->sin6_scope_id;
}
packetInfo->InterfaceIndex = 0;
return 1;
}
#endif // !CMSG_SPACE

static int8_t GetMulticastOptionName(int32_t multicastOption, int8_t isIPv6, int* optionName)
{
Expand Down Expand Up @@ -1553,18 +1573,16 @@ int32_t SystemNative_ReceiveMessage(intptr_t socket, MessageHeader* messageHeade
}

ssize_t res;
#if defined(CMSG_SPACE)
#if !defined(CMSG_SPACE)
// we will only use 0th buffer
struct iovec* msg_iov = (struct iovec*)messageHeader->IOVectors;
while ((res = recvfrom(fd, msg_iov[0].iov_base, msg_iov[0].iov_len, socketFlags, (sockaddr *)messageHeader->SocketAddress, (socklen_t*) &(messageHeader->SocketAddressLen))) < 0 && errno == EINTR);
#else // CMSG_SPACE
struct msghdr header;
ConvertMessageHeaderToMsghdr(&header, messageHeader, fd);

while ((res = recvmsg(fd, &header, socketFlags)) < 0 && errno == EINTR);
#else // CMSG_SPACE
// we will only use 0th buffer
struct iovec* msg_iov = (struct iovec*)messageHeader->IOVectors;
while ((res = recvfrom(fd, msg_iov[0].iov_base, msg_iov[0].iov_len, socketFlags, (sockaddr *)messageHeader->SocketAddress, (socklen_t*) &(messageHeader->SocketAddressLen))) < 0 && errno == EINTR);
#endif // CMSG_SPACE

#if defined(CMSG_SPACE)
assert(header.msg_name == messageHeader->SocketAddress); // should still be the same location as set in ConvertMessageHeaderToMsghdr
assert(header.msg_control == messageHeader->ControlBuffer);

Expand Down