Browse Source

net: add `bound_endpoint` procedure

Laytan Laats 11 months ago
parent
commit
652557bfcd

+ 7 - 0
core/net/socket.odin

@@ -134,6 +134,13 @@ listen_tcp :: proc(interface_endpoint: Endpoint, backlog := 1000) -> (socket: TC
 	return _listen_tcp(interface_endpoint, backlog)
 }
 
+/*
+	Returns the endpoint that the given socket is listening / bound on.
+*/
+bound_endpoint :: proc(socket: Any_Socket) -> (endpoint: Endpoint, err: Network_Error) {
+	return _bound_endpoint(socket)
+}
+
 accept_tcp :: proc(socket: TCP_Socket, options := default_tcp_options) -> (client: TCP_Socket, source: Endpoint, err: Network_Error) {
 	return _accept_tcp(socket, options)
 }

+ 14 - 0
core/net/socket_darwin.odin

@@ -22,6 +22,7 @@ package net
 
 import "core:c"
 import "core:os"
+import "core:sys/posix"
 import "core:time"
 
 Socket_Option :: enum c.int {
@@ -138,6 +139,19 @@ _listen_tcp :: proc(interface_endpoint: Endpoint, backlog := 1000) -> (skt: TCP_
 	return
 }
 
+@(private)
+_bound_endpoint :: proc(sock: Any_Socket) -> (ep: Endpoint, err: Network_Error) {
+	addr: posix.sockaddr_storage
+	addr_len := posix.socklen_t(size_of(addr))
+	res := posix.getsockname(posix.FD(any_socket_to_socket(sock)), (^posix.sockaddr)(&addr), &addr_len)
+	if res != .OK {
+		err = Listen_Error(posix.errno())
+		return
+	}
+	ep = _sockaddr_to_endpoint((^os.SOCKADDR_STORAGE_LH)(&addr))
+	return
+}
+
 @(private)
 _accept_tcp :: proc(sock: TCP_Socket, options := default_tcp_options) -> (client: TCP_Socket, source: Endpoint, err: Network_Error) {
 	sockaddr: os.SOCKADDR_STORAGE_LH

+ 14 - 0
core/net/socket_freebsd.odin

@@ -149,6 +149,20 @@ _listen_tcp :: proc(interface_endpoint: Endpoint, backlog := 1000) -> (socket: T
 	return
 }
 
+@(private)
+_bound_endpoint :: proc(sock: Any_Socket) -> (ep: Endpoint, err: Network_Error) {
+	sockaddr: freebsd.Socket_Address_Storage
+
+	errno := freebsd.getsockname(cast(Fd)any_socket_to_socket(sock), &sockaddr)
+	if errno != nil {
+		err = cast(Listen_Error)errno
+		return
+	}
+
+	ep = _sockaddr_to_endpoint(&sockaddr)
+	return
+}
+
 @(private)
 _accept_tcp :: proc(sock: TCP_Socket, options := default_tcp_options) -> (client: TCP_Socket, source: Endpoint, err: Network_Error) {
 	sockaddr: freebsd.Socket_Address_Storage

+ 13 - 0
core/net/socket_linux.odin

@@ -202,6 +202,19 @@ _listen_tcp :: proc(endpoint: Endpoint, backlog := 1000) -> (TCP_Socket, Network
 	return cast(TCP_Socket) os_sock, nil
 }
 
+@(private)
+_bound_endpoint :: proc(sock: Any_Socket) -> (ep: Endpoint, err: Network_Error) {
+	addr: linux.Sock_Addr_Any
+	errno := linux.getsockname(_unwrap_os_socket(sock), &addr)
+	if errno != .NONE {
+		err = Listen_Error(errno)
+		return
+	}
+
+	ep = _wrap_os_addr(addr)
+	return
+}
+
 @(private)
 _accept_tcp :: proc(sock: TCP_Socket, options := default_tcp_options) -> (tcp_client: TCP_Socket, endpoint: Endpoint, err: Network_Error) {
 	addr: linux.Sock_Addr_Any

+ 14 - 1
core/net/socket_windows.odin

@@ -120,6 +120,19 @@ _listen_tcp :: proc(interface_endpoint: Endpoint, backlog := 1000) -> (socket: T
 	return
 }
 
+@(private)
+_bound_endpoint :: proc(sock: Any_Socket) -> (ep: Endpoint, err: Network_Error) {
+	sockaddr: win.SOCKADDR_STORAGE_LH
+	sockaddrlen := c.int(size_of(sockaddr))
+	if win.getsockname(win.SOCKET(any_socket_to_socket(sock)), &sockaddr, &sockaddrlen) == win.SOCKET_ERROR {
+		err = Listen_Error(win.WSAGetLastError())
+		return
+	}
+
+	ep = _sockaddr_to_endpoint(&sockaddr)
+	return
+}
+
 @(private)
 _accept_tcp :: proc(sock: TCP_Socket, options := default_tcp_options) -> (client: TCP_Socket, source: Endpoint, err: Network_Error) {
 	for {
@@ -368,4 +381,4 @@ _sockaddr_to_endpoint :: proc(native_addr: ^win.SOCKADDR_STORAGE_LH) -> (ep: End
 		panic("native_addr is neither IP4 or IP6 address")
 	}
 	return
-}
+}

+ 1 - 1
core/os/os_darwin.odin

@@ -204,7 +204,7 @@ ENOPROTOOPT     :: _Platform_Error.ENOPROTOOPT
 EPROTONOSUPPORT :: _Platform_Error.EPROTONOSUPPORT
 ESOCKTNOSUPPORT :: _Platform_Error.ESOCKTNOSUPPORT
 ENOTSUP         :: _Platform_Error.ENOTSUP
-EOPNOTSUPP 	:: _Platform_Error.EOPNOTSUPP
+EOPNOTSUPP 	    :: _Platform_Error.EOPNOTSUPP
 EPFNOSUPPORT    :: _Platform_Error.EPFNOSUPPORT
 EAFNOSUPPORT    :: _Platform_Error.EAFNOSUPPORT
 EADDRINUSE      :: _Platform_Error.EADDRINUSE

+ 21 - 0
core/sys/freebsd/syscalls.odin

@@ -21,6 +21,7 @@ SYS_close      : uintptr : 6
 SYS_getpid     : uintptr : 20
 SYS_recvfrom   : uintptr : 29
 SYS_accept     : uintptr : 30
+SYS_getsockname: uintptr : 32
 SYS_fcntl      : uintptr : 92
 SYS_fsync      : uintptr : 95
 SYS_socket     : uintptr : 97
@@ -201,6 +202,26 @@ accept_nil :: proc "contextless" (s: Fd) -> (Fd, Errno) {
 
 accept :: proc { accept_T, accept_nil }
 
+// Get socket name.
+//
+// The getsockname() system call appeared in 4.2BSD.
+getsockname :: proc "contextless" (s: Fd, sockaddr: ^$T) -> Errno {
+	// sockaddr must contain a valid pointer, or this will segfault because
+	// we're telling the syscall that there's memory available to write to.
+	addrlen: socklen_t = size_of(T)
+
+	result, ok := intrinsics.syscall_bsd(SYS_getsockname,
+		cast(uintptr)s,
+		cast(uintptr)sockaddr,
+		cast(uintptr)&addrlen)
+
+	if !ok {
+		return cast(Errno)result
+	}
+
+	return nil
+}
+
 // Synchronize changes to a file.
 //
 // The fsync() system call appeared in 4.2BSD.

+ 27 - 22
tests/core/net/test_core_net.odin

@@ -217,32 +217,21 @@ IP_Address_Parsing_Test_Vectors :: []IP_Address_Parsing_Test_Vector{
 	{ .IP6, "c0a8",                    "", ""},
 }
 
-ENDPOINT_TWO_SERVERS  := net.Endpoint{net.IP4_Address{127, 0, 0, 1}, 9991}
-ENDPOINT_CLOSED_PORT  := net.Endpoint{net.IP4_Address{127, 0, 0, 1}, 9992}
-ENDPOINT_SERVER_SENDS := net.Endpoint{net.IP4_Address{127, 0, 0, 1}, 9993}
-ENDPOINT_UDP_ECHO     := net.Endpoint{net.IP4_Address{127, 0, 0, 1}, 9994}
-ENDPOINT_NONBLOCK     := net.Endpoint{net.IP4_Address{127, 0, 0, 1}, 9995}
-
 @(test)
 two_servers_binding_same_endpoint :: proc(t: ^testing.T) {
-	skt1, err1 := net.listen_tcp(ENDPOINT_TWO_SERVERS)
+	skt1, err1 := net.listen_tcp({address=net.IP4_Address{127, 0, 0, 1}, port=0})
 	defer net.close(skt1)
-	skt2, err2 := net.listen_tcp(ENDPOINT_TWO_SERVERS)
+
+	ep, perr := net.bound_endpoint(skt1)
+
+	skt2, err2 := net.listen_tcp(ep)
 	defer net.close(skt2)
 
 	testing.expect(t, err1 == nil, "expected first server binding to endpoint to do so without error")
+	testing.expect_value(t, perr, nil)
 	testing.expect(t, err2 == net.Bind_Error.Address_In_Use, "expected second server to bind to an endpoint to return .Address_In_Use")
 }
 
-@(test)
-client_connects_to_closed_port :: proc(t: ^testing.T) {
-
-	skt, err := net.dial_tcp(ENDPOINT_CLOSED_PORT)
-	defer net.close(skt)
-	testing.expect(t, err == net.Dial_Error.Refused, "expected dial of a closed endpoint to return .Refused")
-}
-
-
 @(test)
 client_sends_server_data :: proc(t: ^testing.T) {
 	CONTENT: string: "Hellope!"
@@ -250,6 +239,9 @@ client_sends_server_data :: proc(t: ^testing.T) {
 	SEND_TIMEOUT :: time.Duration(1 * time.Second)
 	RECV_TIMEOUT :: time.Duration(1 * time.Second)
 
+	@static endpoint: net.Endpoint
+	endpoint.address = net.IP4_Address{127, 0, 0, 1}
+
 	Thread_Data :: struct {
 		t: ^testing.T,
 		skt: net.Any_Socket,
@@ -266,7 +258,7 @@ client_sends_server_data :: proc(t: ^testing.T) {
 
 		defer sync.wait_group_done(r.wg)
 
-		if r.skt, r.err = net.dial_tcp(ENDPOINT_SERVER_SENDS); r.err != nil {
+		if r.skt, r.err = net.dial_tcp(endpoint); r.err != nil {
 			testing.expectf(r.t, false, "[tcp_client:dial_tcp] %v", r.err)
 			return
 		}
@@ -281,12 +273,17 @@ client_sends_server_data :: proc(t: ^testing.T) {
 
 		defer sync.wait_group_done(r.wg)
 
-		if r.skt, r.err = net.listen_tcp(ENDPOINT_SERVER_SENDS); r.err != nil {
+		if r.skt, r.err = net.listen_tcp(endpoint); r.err != nil {
 			sync.wait_group_done(r.wg)
 			testing.expectf(r.t, false, "[tcp_server:listen_tcp] %v", r.err)
 			return
 		}
 
+		endpoint, r.err = net.bound_endpoint(r.skt.(net.TCP_Socket))
+		if !testing.expect_value(r.t, r.err, nil) {
+			return
+		}
+
 		sync.wait_group_done(r.wg)
 
 		client: net.TCP_Socket
@@ -524,17 +521,25 @@ join_url_test :: proc(t: ^testing.T) {
 
 @test
 test_udp_echo :: proc(t: ^testing.T) {
+	endpoint := net.Endpoint{address=net.IP4_Address{127, 0, 0, 1}, port=0}
+
 	server, make_server_err := net.make_unbound_udp_socket(.IP4)
 	if !testing.expect_value(t, make_server_err, nil) {
 		return
 	}
 	defer net.close(server)
 
-	bind_server_err := net.bind(server, ENDPOINT_UDP_ECHO)
+	bind_server_err := net.bind(server, endpoint)
 	if !testing.expect_value(t, bind_server_err, nil) {
 		return
 	}
 
+	perr: net.Network_Error
+	endpoint, perr = net.bound_endpoint(server)
+	if !testing.expect_value(t, perr, nil) {
+		return
+	}
+
 	client, make_client_err := net.make_unbound_udp_socket(.IP4)
 	if !testing.expect_value(t, make_client_err, nil) {
 		return
@@ -544,7 +549,7 @@ test_udp_echo :: proc(t: ^testing.T) {
 	msg := "Hellope world!"
 	buf: [64]u8
 
-	bytes_written, send_err := net.send_udp(client, transmute([]u8)msg[:], ENDPOINT_UDP_ECHO)
+	bytes_written, send_err := net.send_udp(client, transmute([]u8)msg[:], endpoint)
 	if !testing.expect_value(t, send_err, nil) {
 		return
 	}
@@ -600,7 +605,7 @@ test_dns_resolve :: proc(t: ^testing.T) {
 
 @test
 test_nonblocking_option :: proc(t: ^testing.T) {
-	server, listen_err := net.listen_tcp(ENDPOINT_NONBLOCK)
+	server, listen_err := net.listen_tcp({address=net.IP4_Address{127, 0, 0, 1}, port=0})
 	if !testing.expect_value(t, listen_err, nil) {
 		return
 	}