Skip to content

Commit

Permalink
more implementation details
Browse files Browse the repository at this point in the history
  • Loading branch information
decathorpe committed Apr 5, 2022
1 parent c7d3d90 commit 6db7114
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 92 deletions.
2 changes: 2 additions & 0 deletions .cargo/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[build]
rustflags = ["--cfg", "tokio_unstable"]
6 changes: 6 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,15 @@ exclude = ["/rustfmt.toml"]
anyhow = "1.0.56"
base64 = "0.13"
boringtun = { version = "0.4", default-features = false }
console-subscriber = "0.1.3"
env_logger = "0.9"
log = "0.4.14"
pretty-hex = "0.2.1"
smoltcp = "0.8"
tokio = { version = "1.15", features = ["macros", "net", "rt-multi-thread", "sync", "time"] }

[profile.release]
codegen-units = 1
lto = true
opt-level = 3

70 changes: 4 additions & 66 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,78 +1,14 @@
use anyhow::anyhow;
use std::net::SocketAddr;

use anyhow::anyhow;

use boringtun::crypto::{X25519PublicKey, X25519SecretKey};

#[allow(unused)]
mod tcp;

mod wg;
use wg::WgServer;

/*
let src_ip = ip_packet.src_addr();
let dst_ip = ip_packet.dst_addr();
log::debug!("WireGuard: IPv4 src address: {}", src_ip);
log::debug!("WireGuard: IPv4 dst address: {}", dst_ip);
let tcp_packet = TcpPacket::new_checked(ip_packet.payload_mut()).unwrap();
log::debug!("TCP packet: {:?}", &tcp_packet);
let src_sock_addr = SocketAddr::new(IpAddr::V4(src_ip.into()), tcp_packet.src_port());
let dst_sock_addr = SocketAddr::new(IpAddr::V4(dst_ip.into()), tcp_packet.dst_port());
if tcp_packet.syn() {
log::debug!("TCP SYN: {}", dst_sock_addr);
let mut tcp_socket = TcpSocket::new(
TcpSocketBuffer::new(vec![0u8; 4096]),
TcpSocketBuffer::new(vec![0u8; 4096]),
);
tcp_socket.set_ack_delay(None);
tcp_socket.listen(dst_sock_addr).unwrap();
iface.add_socket(tcp_socket);
}
let mut recv_buf: Vec<u8> = Vec::new();
iface.device_mut().recv_packet(recv_buf);
//let res = iface.poll(Instant::now()).unwrap();
while let Some(vec) = iface.device_mut().send_packet() {
let mut send_buf = [0u8; 1500];
let mut send_packet = tun.encapsulate(&recv_buf, &mut send_buf);
match send_packet {
TunnResult::Done => log::debug!("WireGuard: Done"),
TunnResult::Err(e) => log::error!("WireGuard: Error {:?}", e),
TunnResult::WriteToNetwork(b) => {
wg_socket.send_to(b, addr).await.unwrap();
},
_ => unreachable!(),
}
}
if tcp_packet.fin() {
log::debug!("TCP FIN: {}", dst_sock_addr);
tcp_connections.remove(&dst_sock_addr);
}
if tcp_packet.ack() {
log::debug!("TCP ACK: {}", dst_sock_addr);
if tcp_connections.contains(&dst_sock_addr) {
// TODO: this does not work
mp_socket.try_write(tcp_packet.into_inner()).unwrap();
} else {
log::debug!("TCP ACK for unknown destination");
}
}
// TODO: handle other TCP packet types (?)
*/

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// TODO: lower default verbosity to LevelFilter::Info
Expand All @@ -81,6 +17,8 @@ async fn main() -> Result<(), anyhow::Error> {
.parse_env("MG_LOG")
.init();

console_subscriber::init();

// TODO: make configurable
let server_priv_key: X25519SecretKey = "c72d788fd0916b1185177fd7fa392451192773c889d17ac739571a63482c18bb"
.parse()
Expand Down
121 changes: 109 additions & 12 deletions src/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,114 @@
use std::collections::VecDeque;
use std::collections::{HashMap, VecDeque};
use std::net::{IpAddr, SocketAddr};
use std::time::Duration;

use smoltcp::iface::{Interface, InterfaceBuilder};
use smoltcp::iface::{Interface, InterfaceBuilder, SocketHandle};
use smoltcp::phy::{Device, DeviceCapabilities, Medium, RxToken, TxToken};
use smoltcp::socket::{TcpSocket, TcpSocketBuffer};
use smoltcp::time::Instant;
use smoltcp::wire::TcpPacket;

pub struct VirtualInterface {
use tokio::sync::{broadcast, mpsc};

#[derive(Clone, Debug)]
pub struct TcpMessage {
pub src_ip: IpAddr,
pub dst_ip: IpAddr,
pub packet: TcpPacket<Vec<u8>>,
}

impl TcpMessage {
pub fn new(src_ip: IpAddr, dst_ip: IpAddr, packet: TcpPacket<Vec<u8>>) -> TcpMessage {
TcpMessage { src_ip, dst_ip, packet }
}
}

pub struct TcpHandler {
pusher: mpsc::Receiver<TcpMessage>,
puller: broadcast::Sender<TcpMessage>,
handles: HashMap<SocketAddr, SocketHandle>,
memory: HashMap<SocketHandle, (IpAddr, IpAddr)>,
iface: VirtualInterface,
}

impl TcpHandler {
pub fn new(pusher: mpsc::Receiver<TcpMessage>, puller: broadcast::Sender<TcpMessage>) -> TcpHandler {
let handles = HashMap::new();
let memory = HashMap::new();
let iface = VirtualInterface::new();

TcpHandler { pusher, puller, handles, memory, iface }
}

pub async fn handle(mut self) -> Result<(), anyhow::Error> {
loop {
if let Ok(true) = self.iface.iface.poll(Instant::now()) {
tokio::select!(
// handle outgoing TCP packets
ret = self.pusher.recv() => {
if let Some(message) = ret {
let (src_ip, dst_ip, packet) = (message.src_ip, message.dst_ip, message.packet);
log::debug!("Outgoing TCP packet: {} -> {}: {:?}", src_ip, dst_ip, packet);

let _src_addr = SocketAddr::new(src_ip, packet.src_port());
let dst_addr = SocketAddr::new(dst_ip, packet.dst_port());

if packet.syn() {
let mut socket = TcpSocket::new(
TcpSocketBuffer::new(vec![0u8; 4096]),
TcpSocketBuffer::new(vec![0u8; 4096]),
);

socket.set_ack_delay(None);
socket.listen(dst_addr).unwrap();

let handle = self.iface.iface.add_socket(socket);
self.handles.insert(dst_addr, handle);
self.memory.insert(handle, (src_ip, dst_ip));
}

let fin = packet.fin();
self.iface.recv_packet(packet.into_inner());

if fin {
if let Some(handle) = self.handles.get(&dst_addr) {
self.iface.iface.remove_socket(*handle);
}
}

while let Some(resp_packet) = self.iface.resp_packet() {
let pack = TcpPacket::new_checked(resp_packet).unwrap();
self.puller.send(TcpMessage::new(dst_ip, src_ip, pack)).unwrap();
}
}
},
// handle TCP response packets
/*
ret = std::future::ready(self.iface.resp_packet()) => {
if let Some(response) = ret {
// TODO: lookup src_addr, dst_addr based on which socket has received data
let resp_packet = TcpPacket::new_checked(response)?;
log::debug!("Response TCP packet: {:?}", resp_packet);
//self.puller.send((src_addr, resp_packet))?;
}
}
*/
)
}

if let Some(dur) = self.iface.iface.poll_delay(Instant::now()) {
log::debug!("TCP poll delay: {}", dur);
tokio::time::sleep(dur.into()).await;
} else {
// FIXME: Interface::poll_delay seems to always return `None`
// This statement was only added to avoid busy sleeping.
tokio::time::sleep(Duration::from_millis(200)).await;
}
}
}
}

struct VirtualInterface {
iface: Interface<'static, VirtualDevice>,
}

Expand All @@ -27,11 +131,6 @@ impl VirtualInterface {
pub fn resp_packet(&mut self) -> Option<Vec<u8>> {
self.iface.device_mut().resp_packet()
}

/// send a response packet
pub fn send_packet(&mut self, packet: Vec<u8>) {
self.iface.device_mut().send_packet(packet)
}
}

#[derive(Debug, Default)]
Expand All @@ -48,10 +147,6 @@ impl VirtualDevice {
pub fn resp_packet(&mut self) -> Option<Vec<u8>> {
self.tx_buffer.pop_front()
}

pub fn send_packet(&mut self, packet: Vec<u8>) {
self.tx_buffer.push_back(packet);
}
}

#[derive(Debug)]
Expand All @@ -64,6 +159,7 @@ impl RxToken for VirtualRxToken {
where
F: FnOnce(&mut [u8]) -> smoltcp::Result<R>,
{
// TODO: actually receive packet
f(&mut self.buffer)
}
}
Expand All @@ -78,6 +174,7 @@ impl<'a> TxToken for VirtualTxToken<'a> {
where
F: FnOnce(&mut [u8]) -> smoltcp::Result<R>,
{
// TODO: actually send packet
let mut buffer = vec![0; len];
let result = f(&mut buffer);
self.device.tx_buffer.push_back(buffer);
Expand Down
32 changes: 18 additions & 14 deletions src/wg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ use smoltcp::wire::{IpProtocol, Ipv4Packet, TcpPacket};
use tokio::net::UdpSocket;
use tokio::sync::{broadcast, mpsc};

use crate::tcp::{TcpMessage, TcpHandler};

pub struct WgServer {
addr: SocketAddr,
sec_key: Arc<X25519SecretKey>,
Expand All @@ -22,13 +24,13 @@ pub struct WgServer {
wg_pull: mpsc::Sender<(SocketAddr, Vec<u8>)>,

// outgoing TCP packets
tcp_push: mpsc::Sender<(SocketAddr, Vec<u8>)>,
tcp_push: mpsc::Sender<TcpMessage>,
// incoming TCP packets
tcp_pull: broadcast::Sender<(SocketAddr, Vec<u8>)>,
tcp_pull: broadcast::Sender<TcpMessage>,

// store single receivers for later use
wg_puller: mpsc::Receiver<(SocketAddr, Vec<u8>)>,
tcp_pusher: mpsc::Receiver<(SocketAddr, Vec<u8>)>,
tcp_pusher: mpsc::Receiver<TcpMessage>,
}

impl WgServer {
Expand Down Expand Up @@ -77,7 +79,8 @@ impl WgServer {
}

// spawn handler for virtual TCP interface
// TODO
let tcp_handler = TcpHandler::new(self.tcp_pusher, self.tcp_pull);
tokio::spawn(tcp_handler.handle());

// listen for incoming WireGuard connections
let socket = UdpSocket::bind(self.addr).await?;
Expand Down Expand Up @@ -118,8 +121,8 @@ impl WgPeerTunn {
self,
mut wg_push: broadcast::Receiver<(SocketAddr, Vec<u8>)>,
wg_pull: mpsc::Sender<(SocketAddr, Vec<u8>)>,
tcp_push: mpsc::Sender<(SocketAddr, Vec<u8>)>,
mut tcp_pull: broadcast::Receiver<(SocketAddr, Vec<u8>)>,
tcp_push: mpsc::Sender<TcpMessage>,
mut tcp_pull: broadcast::Receiver<TcpMessage>,
) -> Result<(), anyhow::Error> {
let mut wg_buf = [0u8; 1500];

Expand Down Expand Up @@ -168,16 +171,14 @@ impl WgPeerTunn {
continue;
}

let src_ip = ip_packet.src_addr();
let dst_ip = ip_packet.dst_addr();
let src_ip: IpAddr = IpAddr::V4(ip_packet.src_addr().into());
let dst_ip: IpAddr = IpAddr::V4(ip_packet.dst_addr().into());

log::debug!("WireGuard: IPv4 src address: {}", src_ip);
log::debug!("WireGuard: IPv4 dst address: {}", dst_ip);

let tcp_packet = TcpPacket::new_checked(ip_packet.payload_mut()).unwrap();
let sock_addr = SocketAddr::new(IpAddr::V4(src_ip.into()), tcp_packet.src_port());

tcp_push.send((sock_addr, ip_packet.into_inner().to_vec())).await.unwrap();
let tcp_packet = TcpPacket::new_checked(ip_packet.payload_mut().to_vec()).unwrap();
tcp_push.send(TcpMessage::new(src_ip, dst_ip, tcp_packet)).await.unwrap();
},
// IPv6 packet
TunnResult::WriteToTunnelV6(buf, src_addr) => {
Expand All @@ -189,15 +190,18 @@ impl WgPeerTunn {
},
// wait for outgoing data
ret = tcp_pull.recv() => {
let (dst_addr, packet) = ret.unwrap();
let message = ret.unwrap();
let (_, dst_ip, packet) = (message.src_ip, message.dst_ip, message.packet);
let dst_port = packet.dst_port();

let mut result = self.tunn.encapsulate(&packet, &mut wg_buf);
let mut result = self.tunn.encapsulate(&packet.into_inner(), &mut wg_buf);

// encode and handle outgoing WireGuard packet(s)
loop {
match result {
TunnResult::WriteToNetwork(b) => {
log::debug!("WireGuard: WriteToNetwork");
let dst_addr = SocketAddr::new(dst_ip, dst_port);
wg_pull.send((dst_addr, b.to_vec())).await.unwrap();
},
_ => break,
Expand Down

0 comments on commit 6db7114

Please sign in to comment.