Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Also return the interface name #11

Merged
merged 23 commits into from
Sep 4, 2024
Merged
Prev Previous commit
Next Next commit
Return interface name
  • Loading branch information
larseggert committed Sep 3, 2024
commit dce0de85896c9a5a629bcd013f77caf4602365e5
49 changes: 17 additions & 32 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@ use std::{
use log::trace;

/// Prepare a default error result.
fn default_result<T>() -> Result<(InterfaceId, T), Error> {
fn default_result<T>() -> Result<(String, T), Error> {
Err(Error::new(
ErrorKind::NotFound,
"Local interface MTU not found",
))
}

type InterfaceId = u64;

/// Return a unique interface ID and the maximum transmission unit (MTU) of the local network
/// interface towards the destination [`SocketAddr`] given in `remote`.
///
Expand All @@ -37,14 +35,14 @@ type InterfaceId = u64;
///
/// ```
/// let saddr = "127.0.0.1:443".parse().unwrap();
/// let (id, mtu) = mtu::get_interface_and_mtu(&saddr).unwrap();
/// println!("MTU towards {:?} is {}", saddr, mtu);
/// let (name, mtu) = mtu::get_interface_and_mtu(&saddr).unwrap();
/// println!("MTU towards {saddr:?} is {mtu} on {name}");
/// ```
///
/// # Errors
///
/// This function returns an error if the local interface MTU cannot be determined.
pub fn get_interface_and_mtu(remote: &SocketAddr) -> Result<(InterfaceId, usize), Error> {
pub fn get_interface_and_mtu(remote: &SocketAddr) -> Result<(String, usize), Error> {
#[cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))]
#[allow(unused_assignments)] // Yes, res is reassigned in the platform-specific code.
let mut res = default_result();
Expand Down Expand Up @@ -79,13 +77,10 @@ pub fn get_interface_and_mtu(remote: &SocketAddr) -> Result<(InterfaceId, usize)
}

#[cfg(any(target_os = "macos", target_os = "linux"))]
fn get_interface_and_mtu_linux_macos(socket: &UdpSocket) -> Result<(InterfaceId, usize), Error> {
fn get_interface_and_mtu_linux_macos(socket: &UdpSocket) -> Result<(String, usize), Error> {
use std::ffi::{c_int, CStr};
#[cfg(target_os = "linux")]
use std::{ffi::c_char, mem, os::fd::AsRawFd};
use std::{
ffi::{c_int, CStr},
hash::{DefaultHasher, Hash, Hasher},
};

use libc::{
freeifaddrs, getifaddrs, ifaddrs, in_addr_t, sockaddr_in, sockaddr_in6, AF_INET, AF_INET6,
Expand All @@ -95,12 +90,6 @@ fn get_interface_and_mtu_linux_macos(socket: &UdpSocket) -> Result<(InterfaceId,
#[cfg(target_os = "linux")]
use libc::{ifreq, ioctl};

fn hash_interface_name(name: &str) -> InterfaceId {
let mut hasher = DefaultHasher::new();
name.hash(&mut hasher);
hasher.finish()
}

// Get the interface list.
let mut ifap: *mut ifaddrs = ptr::null_mut();
if unsafe { getifaddrs(&mut ifap) } != 0 {
Expand Down Expand Up @@ -163,7 +152,7 @@ fn get_interface_and_mtu_linux_macos(socket: &UdpSocket) -> Result<(InterfaceId,
{
let data = unsafe { &*(ifa.ifa_data as *const if_data) };
if let Ok(mtu) = usize::try_from(data.ifi_mtu) {
res = Ok((hash_interface_name(iface), mtu));
res = Ok((iface.to_string(), mtu));
}
break;
}
Expand All @@ -182,7 +171,7 @@ fn get_interface_and_mtu_linux_macos(socket: &UdpSocket) -> Result<(InterfaceId,
if unsafe { ioctl(socket.as_raw_fd(), libc::SIOCGIFMTU, &ifr) } != 0 {
res = Err(Error::last_os_error());
} else if let Ok(mtu) = usize::try_from(unsafe { ifr.ifr_ifru.ifru_mtu }) {
res = Ok((hash_interface_name(iface), mtu));
res = Ok((iface.to_string(), mtu));
}
}
}
Expand All @@ -192,18 +181,15 @@ fn get_interface_and_mtu_linux_macos(socket: &UdpSocket) -> Result<(InterfaceId,
}

#[cfg(target_os = "windows")]
fn get_interface_and_mtu_windows(socket: &UdpSocket) -> Result<(InterfaceId, usize), Error> {
use std::{
ffi::c_void,
hash::{DefaultHasher, Hash, Hasher},
slice,
};
fn get_interface_and_mtu_windows(socket: &UdpSocket) -> Result<(String, usize), Error> {
use std::{ffi::c_void, slice};

use windows::Win32::{
Foundation::NO_ERROR,
NetworkManagement::IpHelper::{
FreeMibTable, GetIpInterfaceTable, GetUnicastIpAddressTable, MIB_IPINTERFACE_ROW,
MIB_IPINTERFACE_TABLE, MIB_UNICASTIPADDRESS_ROW, MIB_UNICASTIPADDRESS_TABLE,
if_indextoname, FreeMibTable, GetIpInterfaceTable, GetUnicastIpAddressTable,
MIB_IPINTERFACE_ROW, MIB_IPINTERFACE_TABLE, MIB_UNICASTIPADDRESS_ROW,
MIB_UNICASTIPADDRESS_TABLE,
},
Networking::WinSock::{AF_INET, AF_INET6, AF_UNSPEC},
};
Expand Down Expand Up @@ -257,10 +243,9 @@ fn get_interface_and_mtu_windows(socket: &UdpSocket) -> Result<(InterfaceId, usi
for iface in ifaces {
if iface.InterfaceIndex == addr.InterfaceIndex {
if let Ok(mtu) = iface.NlMtu.try_into() {
let mut hasher = DefaultHasher::new();
iface.InterfaceIndex.hash(&mut hasher);
let id = hasher.finish();
res = Ok((id, mtu));
if_indextoname(iface.InterfaceIndex, |name| {
res = Ok((name, mtu));
});
}
break 'addr_loop;
}
Expand Down Expand Up @@ -289,7 +274,7 @@ fn get_interface_and_mtu_windows(socket: &UdpSocket) -> Result<(InterfaceId, usi
/// ```
/// let saddr = "127.0.0.1:443".parse().unwrap();
/// let mtu = mtu::get_interface_mtu(&saddr).unwrap();
/// println!("MTU towards {:?} is {}", saddr, mtu);
/// println!("MTU towards {saddr:?} is {mtu}");
/// ```
///
/// # Errors
Expand Down
Loading