فهرست منبع

[rust/viz]: use multiple current thread runtimes in json (#9750)

Fangdun Tsai 4 ماه پیش
والد
کامیت
d052086d88

+ 2 - 0
frameworks/Rust/viz/Cargo.toml

@@ -35,6 +35,8 @@ mime = "0.3"
 rand = { version = "0.9", features = ["small_rng"] }
 thiserror = "2.0"
 futures-util = "0.3"
+socket2 = "0.5.8"
+num_cpus = "1.16.0"
 
 [target.'cfg(not(unix))'.dependencies]
 nanorand = { version = "0.7" }

+ 12 - 7
frameworks/Rust/viz/src/main.rs

@@ -3,7 +3,7 @@
 use serde::Serialize;
 use viz::{
     header::{HeaderValue, CONTENT_TYPE, SERVER},
-    Bytes, Error, Request, Response, ResponseExt, Result, Router,
+    Bytes, Request, Response, ResponseExt, Result, Router,
 };
 
 mod server;
@@ -14,6 +14,7 @@ struct Message {
     message: &'static str,
 }
 
+#[inline(always)]
 async fn plaintext(_: Request) -> Result<Response> {
     let mut res = Response::text("Hello, World!");
     res.headers_mut()
@@ -21,8 +22,9 @@ async fn plaintext(_: Request) -> Result<Response> {
     Ok(res)
 }
 
+#[inline(always)]
 async fn json(_: Request) -> Result<Response> {
-    let mut resp = Response::builder()
+    let mut res = Response::builder()
         .body(
             http_body_util::Full::new(Bytes::from(
                 serde_json::to_vec(&Message {
@@ -33,20 +35,23 @@ async fn json(_: Request) -> Result<Response> {
             .into(),
         )
         .unwrap();
-    let headers = resp.headers_mut();
+    let headers = res.headers_mut();
     headers.insert(SERVER, HeaderValue::from_static("Viz"));
     headers.insert(
         CONTENT_TYPE,
         HeaderValue::from_static(mime::APPLICATION_JSON.as_ref()),
     );
-    Ok(resp)
+    Ok(res)
 }
 
-#[tokio::main]
-async fn main() -> Result<()> {
+async fn app() {
     let app = Router::new()
         .get("/plaintext", plaintext)
         .get("/json", json);
 
-    server::serve(app).await.map_err(Error::Boxed)
+    server::serve(app).await.unwrap();
+}
+
+fn main() {
+    server::run(app)
 }

+ 5 - 2
frameworks/Rust/viz/src/main_diesel.rs

@@ -77,8 +77,7 @@ async fn updates(req: Request) -> Result<Response> {
     Ok(res)
 }
 
-#[tokio::main]
-async fn main() {
+async fn app() {
     let max = available_parallelism().map(|n| n.get()).unwrap_or(16) as u32;
 
     let pool = Pool::<AsyncPgConnection>::builder()
@@ -99,3 +98,7 @@ async fn main() {
 
     server::serve(app).await.unwrap()
 }
+
+fn main() {
+    server::run(app)
+}

+ 6 - 24
frameworks/Rust/viz/src/main_pg.rs

@@ -1,7 +1,4 @@
-use std::{
-    sync::Arc,
-    thread::{available_parallelism, spawn},
-};
+use std::sync::Arc;
 
 use viz::{
     header::{HeaderValue, SERVER},
@@ -76,26 +73,7 @@ async fn updates(req: Request) -> Result<Response> {
     Ok(res)
 }
 
-fn main() {
-    let rt = tokio::runtime::Builder::new_current_thread()
-        .enable_all()
-        .build()
-        .unwrap();
-
-    for _ in 1..available_parallelism().map(|n| n.get()).unwrap_or(16) {
-        spawn(move || {
-            let rt = tokio::runtime::Builder::new_current_thread()
-                .enable_all()
-                .build()
-                .unwrap();
-            rt.block_on(serve());
-        });
-    }
-
-    rt.block_on(serve());
-}
-
-async fn serve() {
+async fn app() {
     let conn = PgConnection::connect(DB_URL).await;
 
     let app = Router::new()
@@ -107,3 +85,7 @@ async fn serve() {
 
     server::serve(app).await.unwrap()
 }
+
+fn main() {
+    server::run(app)
+}

+ 5 - 2
frameworks/Rust/viz/src/main_sqlx.rs

@@ -79,8 +79,7 @@ async fn updates(mut req: Request) -> Result<Response> {
     Ok(res)
 }
 
-#[tokio::main]
-async fn main() -> Result<()> {
+async fn app() -> Result<()> {
     let max = available_parallelism().map(|n| n.get()).unwrap_or(16) as u32;
 
     let pool = PgPoolOptions::new()
@@ -103,6 +102,10 @@ async fn main() -> Result<()> {
     server::serve(app).await.map_err(Error::Boxed)
 }
 
+fn main() {
+    server::run(app)
+}
+
 markup::define! {
     FortunesTemplate(items: Vec<Fortune>) {
         {markup::doctype()}

+ 54 - 21
frameworks/Rust/viz/src/server.rs

@@ -1,49 +1,82 @@
 use std::error::Error;
+use std::future::Future;
 use std::io;
 use std::net::{Ipv4Addr, SocketAddr};
 use std::sync::Arc;
+use std::thread;
 
+use socket2::{Domain, SockAddr, Socket};
 use hyper::server::conn::http1::Builder;
 use hyper_util::rt::TokioIo;
-use tokio::net::{TcpListener, TcpSocket};
+use tokio::{net::TcpListener, runtime};
 use viz::{Responder, Router, Tree};
 
 pub async fn serve(router: Router) -> Result<(), Box<dyn Error + Send + Sync>> {
-    let tree = Arc::<Tree>::new(router.into());
     let addr = SocketAddr::from((Ipv4Addr::UNSPECIFIED, 8080));
-    let listener = reuse_listener(addr).expect("couldn't bind to addr");
+    let socket = create_socket(addr).expect("couldn't bind to addr");
+    let listener = TcpListener::from_std(socket.into())?;
+
+    let tree = Arc::<Tree>::new(router.into());
+
+    let mut http = Builder::new();
+    http.pipeline_flush(true);
 
     println!("Started viz server at 8080");
 
     loop {
         let (tcp, _) = listener.accept().await?;
-        let io = TokioIo::new(tcp);
+        tcp.set_nodelay(true).expect("couldn't set TCP_NODELAY!");
+
+        let http = http.clone();
         let tree = tree.clone();
 
-        tokio::task::spawn(async move {
-            Builder::new()
-                .pipeline_flush(true)
-                .serve_connection(io, Responder::<Arc<SocketAddr>>::new(tree, None))
-                .with_upgrades()
+        tokio::spawn(async move {
+            http
+                .serve_connection(
+                    TokioIo::new(tcp),
+                    Responder::<Arc<SocketAddr>>::new(tree, None),
+                )
                 .await
         });
     }
 }
 
-fn reuse_listener(addr: SocketAddr) -> io::Result<TcpListener> {
-    let socket = match addr {
-        SocketAddr::V4(_) => TcpSocket::new_v4()?,
-        SocketAddr::V6(_) => TcpSocket::new_v6()?,
+fn create_socket(addr: SocketAddr) -> Result<Socket, io::Error> {
+    let domain = match addr {
+        SocketAddr::V4(_) => Domain::IPV4,
+        SocketAddr::V6(_) => Domain::IPV6,
     };
-
+    let addr = SockAddr::from(addr);
+    let socket = Socket::new(domain, socket2::Type::STREAM, None)?;
+    let backlog = 4096;
     #[cfg(unix)]
-    {
-        if let Err(e) = socket.set_reuseport(true) {
-            eprintln!("error setting SO_REUSEPORT: {e}");
-        }
+    socket.set_reuse_port(true)?;
+    socket.set_reuse_address(true)?;
+    socket.set_nodelay(true)?;
+    socket.set_nonblocking(true)?;
+    socket.bind(&addr)?;
+    socket.listen(backlog)?;
+
+    Ok(socket)
+}
+
+pub fn run<Fut>(f: fn() -> Fut)
+where
+    Fut: Future + Send + 'static,
+{
+    for _ in 1..num_cpus::get() {
+        let runtime = runtime::Builder::new_current_thread()
+            .enable_all()
+            .build()
+            .unwrap();
+        thread::spawn(move || {
+            runtime.block_on(f());
+        });
     }
 
-    socket.set_reuseaddr(true)?;
-    socket.bind(addr)?;
-    socket.listen(1024)
+    let runtime = runtime::Builder::new_current_thread()
+        .enable_all()
+        .build()
+        .unwrap();
+    runtime.block_on(f());
 }