Browse Source

net: fix leaking sockets in listen_tcp if an error occurs

Laytan Laats 10 months ago
parent
commit
5c63617191

+ 1 - 0
core/net/socket_darwin.odin

@@ -120,6 +120,7 @@ _listen_tcp :: proc(interface_endpoint: Endpoint, backlog := 1000) -> (skt: TCP_
 	family := family_from_endpoint(interface_endpoint)
 	family := family_from_endpoint(interface_endpoint)
 	sock := create_socket(family, .TCP) or_return
 	sock := create_socket(family, .TCP) or_return
 	skt = sock.(TCP_Socket)
 	skt = sock.(TCP_Socket)
+	defer if err != nil { close(skt) }
 
 
 	// NOTE(tetra): This is so that if we crash while the socket is open, we can
 	// NOTE(tetra): This is so that if we crash while the socket is open, we can
 	// bypass the cooldown period, and allow the next run of the program to
 	// bypass the cooldown period, and allow the next run of the program to

+ 1 - 0
core/net/socket_freebsd.odin

@@ -137,6 +137,7 @@ _listen_tcp :: proc(interface_endpoint: Endpoint, backlog := 1000) -> (socket: T
 	family := family_from_endpoint(interface_endpoint)
 	family := family_from_endpoint(interface_endpoint)
 	new_socket := create_socket(family, .TCP) or_return
 	new_socket := create_socket(family, .TCP) or_return
 	socket = new_socket.(TCP_Socket)
 	socket = new_socket.(TCP_Socket)
+	defer if err != nil { close(socket) }
 
 
 	bind(socket, interface_endpoint) or_return
 	bind(socket, interface_endpoint) or_return
 
 

+ 21 - 13
core/net/socket_linux.odin

@@ -167,40 +167,48 @@ _bind :: proc(sock: Any_Socket, endpoint: Endpoint) -> (Network_Error) {
 }
 }
 
 
 @(private)
 @(private)
-_listen_tcp :: proc(endpoint: Endpoint, backlog := 1000) -> (TCP_Socket, Network_Error) {
+_listen_tcp :: proc(endpoint: Endpoint, backlog := 1000) -> (socket: TCP_Socket, err: Network_Error) {
 	errno: linux.Errno
 	errno: linux.Errno
 	assert(backlog > 0 && i32(backlog) < max(i32))
 	assert(backlog > 0 && i32(backlog) < max(i32))
+
 	// Figure out the address family and address of the endpoint
 	// Figure out the address family and address of the endpoint
 	ep_family := _unwrap_os_family(family_from_endpoint(endpoint))
 	ep_family := _unwrap_os_family(family_from_endpoint(endpoint))
 	ep_address := _unwrap_os_addr(endpoint)
 	ep_address := _unwrap_os_addr(endpoint)
+
 	// Create TCP socket
 	// Create TCP socket
 	os_sock: linux.Fd
 	os_sock: linux.Fd
 	os_sock, errno = linux.socket(ep_family, .STREAM, {.CLOEXEC}, .TCP)
 	os_sock, errno = linux.socket(ep_family, .STREAM, {.CLOEXEC}, .TCP)
 	if errno != .NONE {
 	if errno != .NONE {
-		// TODO(flysand): should return invalid file descriptor here casted as TCP_Socket
-		return {}, Create_Socket_Error(errno)
+		err = Create_Socket_Error(errno)
+		return
 	}
 	}
+	socket = cast(TCP_Socket)os_sock
+	defer if err != nil { close(socket) }
+
 	// NOTE(tetra): This is so that if we crash while the socket is open, we can
 	// NOTE(tetra): This is so that if we crash while the socket is open, we can
 	// bypass the cooldown period, and allow the next run of the program to
 	// bypass the cooldown period, and allow the next run of the program to
 	// use the same address immediately.
 	// use the same address immediately.
 	//
 	//
 	// TODO(tetra, 2022-02-15): Confirm that this doesn't mean other processes can hijack the address!
 	// TODO(tetra, 2022-02-15): Confirm that this doesn't mean other processes can hijack the address!
 	do_reuse_addr: b32 = true
 	do_reuse_addr: b32 = true
-	errno = linux.setsockopt(os_sock, linux.SOL_SOCKET, linux.Socket_Option.REUSEADDR, &do_reuse_addr)
-	if errno != .NONE {
-		return cast(TCP_Socket) os_sock, Listen_Error(errno)
+	if errno = linux.setsockopt(os_sock, linux.SOL_SOCKET, linux.Socket_Option.REUSEADDR, &do_reuse_addr); errno != .NONE {
+		err = Listen_Error(errno)
+		return
 	}
 	}
+
 	// Bind the socket to endpoint address
 	// Bind the socket to endpoint address
-	errno = linux.bind(os_sock, &ep_address)
-	if errno != .NONE {
-		return cast(TCP_Socket) os_sock, Bind_Error(errno)
+	if errno = linux.bind(os_sock, &ep_address); errno != .NONE {
+		err = Bind_Error(errno)
+		return
 	}
 	}
+
 	// Listen on bound socket
 	// Listen on bound socket
-	errno = linux.listen(os_sock, cast(i32) backlog)
-	if errno != .NONE {
-		return cast(TCP_Socket) os_sock, Listen_Error(errno)
+	if errno = linux.listen(os_sock, cast(i32) backlog); errno != .NONE {
+		err = Listen_Error(errno)
+		return
 	}
 	}
-	return cast(TCP_Socket) os_sock, nil
+
+	return
 }
 }
 
 
 @(private)
 @(private)

+ 1 - 0
core/net/socket_windows.odin

@@ -107,6 +107,7 @@ _listen_tcp :: proc(interface_endpoint: Endpoint, backlog := 1000) -> (socket: T
 	family := family_from_endpoint(interface_endpoint)
 	family := family_from_endpoint(interface_endpoint)
 	sock := create_socket(family, .TCP) or_return
 	sock := create_socket(family, .TCP) or_return
 	socket = sock.(TCP_Socket)
 	socket = sock.(TCP_Socket)
+	defer if err != nil { close(socket) }
 
 
 	// NOTE(tetra): While I'm not 100% clear on it, my understanding is that this will
 	// NOTE(tetra): While I'm not 100% clear on it, my understanding is that this will
 	// prevent hijacking of the server's endpoint by other applications.
 	// prevent hijacking of the server's endpoint by other applications.