Browse Source

use async connection pool for db. (#6832)

fakeshadow 3 years ago
parent
commit
8650b68efc

+ 7 - 6
frameworks/Rust/xitca-web/Cargo.toml

@@ -12,20 +12,21 @@ name = "xitca-web-diesel"
 path = "./src/main_diesel.rs"
 
 [dependencies]
-xitca-http = { git = "https://github.com/fakeshadow/xitca-web.git", rev = "7499e8a5aa4f7e1f17bbf6b6ee0816828dff3149" }
-xitca-web = { git = "https://github.com/fakeshadow/xitca-web.git", rev = "7499e8a5aa4f7e1f17bbf6b6ee0816828dff3149" }
+xitca-http = { git = "https://github.com/fakeshadow/xitca-web.git", rev = "a470092d8f1e1c3bb7a9831c175bf112b70f81e7" }
+xitca-web = { git = "https://github.com/fakeshadow/xitca-web.git", rev = "a470092d8f1e1c3bb7a9831c175bf112b70f81e7" }
 
 ahash = { version = "0.7.4", features = ["compile-time-rng"] }
 atoi = "0.4.0"
 bytes = "1"
 core_affinity = "0.5.10"
-diesel = { version = "1.4.7", features = ["postgres", "r2d2"] }
+diesel = { version = "1.4.7", features = ["postgres"] }
 futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
 mimalloc = { version = "0.1.25", default-features = false }
 rand = { version = "0.8", default-features = false, features = ["min_const_gen", "small_rng"] }
 sailfish = "0.3.3"
 serde = "1"
 simd-json = "0.4.6"
+tang-rs = "0.2"
 tokio = { version = "1.7", features = ["macros", "rt"] }
 tokio-postgres = "0.7.2"
 
@@ -36,6 +37,6 @@ codegen-units = 1
 panic = "abort"
 
 [patch.crates-io]
-xitca-http = { git = "https://github.com/fakeshadow/xitca-web.git", rev = "7499e8a5aa4f7e1f17bbf6b6ee0816828dff3149" }
-xitca-server = { git = "https://github.com/fakeshadow/xitca-web.git", rev = "7499e8a5aa4f7e1f17bbf6b6ee0816828dff3149" }
-xitca-service = { git = "https://github.com/fakeshadow/xitca-web.git", rev = "7499e8a5aa4f7e1f17bbf6b6ee0816828dff3149" }
+xitca-http = { git = "https://github.com/fakeshadow/xitca-web.git", rev = "a470092d8f1e1c3bb7a9831c175bf112b70f81e7" }
+xitca-server = { git = "https://github.com/fakeshadow/xitca-web.git", rev = "a470092d8f1e1c3bb7a9831c175bf112b70f81e7" }
+xitca-service = { git = "https://github.com/fakeshadow/xitca-web.git", rev = "a470092d8f1e1c3bb7a9831c175bf112b70f81e7" }

+ 103 - 31
frameworks/Rust/xitca-web/src/db_diesel.rs

@@ -1,28 +1,105 @@
-use std::{error::Error, io};
+use std::{error::Error, fmt, future::Future, io, time::Duration};
 
-use diesel::{prelude::*, r2d2};
+use diesel::prelude::*;
 use rand::{rngs::SmallRng, Rng, SeedableRng};
-use tokio::task::spawn_blocking;
+use tang_rs::{Manager, ManagerFuture, ManagerTimeout, Pool};
+use tokio::{
+    task::spawn_blocking,
+    time::{sleep, Sleep},
+};
 
 use super::ser::{Fortune, Fortunes, World};
 
 type DbResult<T> = Result<T, Box<dyn Error + Send + Sync + 'static>>;
 
+pub struct DieselPoolManager(String);
+
+impl Manager for DieselPoolManager {
+    type Connection = PgConnection;
+    type Error = DieselPoolError;
+    type Timeout = Sleep;
+    type TimeoutError = ();
+
+    fn connect(&self) -> ManagerFuture<Result<Self::Connection, Self::Error>> {
+        let conn = PgConnection::establish(self.0.as_str());
+        Box::pin(async move { Ok(conn?) })
+    }
+
+    fn is_valid<'a>(
+        &'a self,
+        _: &'a mut Self::Connection,
+    ) -> ManagerFuture<'a, Result<(), Self::Error>> {
+        Box::pin(async { Ok(()) })
+    }
+
+    fn is_closed(&self, _: &mut Self::Connection) -> bool {
+        false
+    }
+
+    fn spawn<Fut>(&self, fut: Fut)
+    where
+        Fut: Future<Output = ()> + 'static,
+    {
+        tokio::task::spawn_local(fut);
+    }
+
+    fn timeout<Fut: Future>(&self, fut: Fut, dur: Duration) -> ManagerTimeout<Fut, Self::Timeout> {
+        ManagerTimeout::new(fut, sleep(dur))
+    }
+}
+
+pub enum DieselPoolError {
+    Inner(ConnectionError),
+    TimeOut,
+}
+
+impl fmt::Debug for DieselPoolError {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            DieselPoolError::Inner(e) => e.fmt(f),
+            DieselPoolError::TimeOut => f
+                .debug_struct("DieselPoolError")
+                .field("source", &"Connection Timeout")
+                .finish(),
+        }
+    }
+}
+
+impl fmt::Display for DieselPoolError {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "{:?}", self)
+    }
+}
+
+impl Error for DieselPoolError {}
+
+impl From<ConnectionError> for DieselPoolError {
+    fn from(e: ConnectionError) -> Self {
+        Self::Inner(e)
+    }
+}
+
+impl From<()> for DieselPoolError {
+    fn from(_: ()) -> Self {
+        Self::TimeOut
+    }
+}
+
 #[derive(Clone)]
 pub struct DieselPool {
-    pool: r2d2::Pool<r2d2::ConnectionManager<PgConnection>>,
+    pool: Pool<DieselPoolManager>,
     rng: SmallRng,
 }
 
-pub fn connect(config: &str) -> io::Result<DieselPool> {
-    let manager = r2d2::ConnectionManager::new(config);
-    let pool = r2d2::Builder::new()
+pub async fn create(config: &str) -> io::Result<DieselPool> {
+    let pool = tang_rs::Builder::new()
         .max_size(5)
-        .min_idle(Some(5))
-        .test_on_check_out(false)
+        .min_idle(5)
+        .always_check(false)
         .idle_timeout(None)
         .max_lifetime(None)
-        .build(manager)
+        .build(DieselPoolManager(String::from(config)))
+        .await
         .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
 
     Ok(DieselPool {
@@ -33,17 +110,16 @@ pub fn connect(config: &str) -> io::Result<DieselPool> {
 
 impl DieselPool {
     pub async fn get_world(&self) -> DbResult<World> {
-        let mut this = self.clone();
+        let mut rng = self.rng.clone();
+        let conn = self.pool.get_owned().await?;
 
         spawn_blocking(move || {
             use crate::schema::world::dsl::*;
 
-            let conn = this.pool.get()?;
-
-            let random_id = this.rng.gen_range(1..10_001);
+            let random_id = rng.gen_range(1..10_001);
             let w = world
                 .filter(id.eq(random_id))
-                .load::<World>(&conn)?
+                .load::<World>(&*conn)?
                 .pop()
                 .unwrap();
 
@@ -53,19 +129,18 @@ impl DieselPool {
     }
 
     pub async fn get_worlds(&self, num: u16) -> DbResult<Vec<World>> {
-        let mut this = self.clone();
+        let mut rng = self.rng.clone();
+        let conn = self.pool.get_owned().await?;
 
         spawn_blocking(move || {
             use crate::schema::world::dsl::*;
 
-            let conn = this.pool.get()?;
-
             (0..num)
                 .map(|_| {
-                    let w_id = this.rng.gen_range(1..10_001);
+                    let w_id = rng.gen_range(1..10_001);
                     let w = world
                         .filter(id.eq(w_id))
-                        .load::<World>(&conn)?
+                        .load::<World>(&*conn)?
                         .pop()
                         .unwrap();
                     Ok(w)
@@ -76,22 +151,21 @@ impl DieselPool {
     }
 
     pub async fn update(&self, num: u16) -> DbResult<Vec<World>> {
-        let mut this = self.clone();
+        let mut rng = self.rng.clone();
+        let conn = self.pool.get_owned().await?;
 
         spawn_blocking(move || {
             use crate::schema::world::dsl::*;
 
-            let conn = this.pool.get()?;
-
             let mut worlds = (0..num)
                 .map(|_| {
-                    let w_id: i32 = this.rng.gen_range(1..10_001);
+                    let w_id: i32 = rng.gen_range(1..10_001);
                     let mut w = world
                         .filter(id.eq(w_id))
-                        .load::<World>(&conn)?
+                        .load::<World>(&*conn)?
                         .pop()
                         .unwrap();
-                    w.randomnumber = this.rng.gen_range(1..10_001);
+                    w.randomnumber = rng.gen_range(1..10_001);
                     Ok(w)
                 })
                 .collect::<DbResult<Vec<_>>>()?;
@@ -103,7 +177,7 @@ impl DieselPool {
                     diesel::update(world)
                         .filter(id.eq(w.id))
                         .set(randomnumber.eq(w.randomnumber))
-                        .execute(&conn)?;
+                        .execute(&*conn)?;
                 }
                 Ok(())
             })?;
@@ -114,14 +188,12 @@ impl DieselPool {
     }
 
     pub async fn tell_fortune(&self) -> DbResult<Fortunes> {
-        let this = self.clone();
+        let conn = self.pool.get_owned().await?;
 
         spawn_blocking(move || {
             use crate::schema::fortune::dsl::*;
 
-            let conn = this.pool.get()?;
-
-            let mut items = fortune.load::<Fortune>(&conn)?;
+            let mut items = fortune.load::<Fortune>(&*conn)?;
 
             items.push(Fortune::new(0, "Additional fortune added at request time."));
             items.sort_by(|it, next| it.message.cmp(&next.message));

+ 3 - 6
frameworks/Rust/xitca-web/src/main.rs

@@ -14,7 +14,7 @@ use std::{
 
 use bytes::Bytes;
 use xitca_http::http::{
-    header::{HeaderValue, CONTENT_TYPE, SERVER},
+    header::{CONTENT_TYPE, SERVER},
     Method,
 };
 use xitca_web::{dev::fn_service, request::WebRequest, App, HttpServer};
@@ -79,12 +79,9 @@ async fn fortunes(req: &mut WebRequest<'_, State>) -> HandleResult {
         Ok(body) => {
             let mut res = req.as_response(body);
 
+            res.headers_mut().append(SERVER, util::SERVER_HEADER_VALUE);
             res.headers_mut()
-                .append(SERVER, HeaderValue::from_static("TFB"));
-            res.headers_mut().append(
-                CONTENT_TYPE,
-                HeaderValue::from_static("text/html; charset=utf-8"),
-            );
+                .append(CONTENT_TYPE, util::HTML_HEADER_VALUE);
 
             Ok(res)
         }

+ 7 - 11
frameworks/Rust/xitca-web/src/main_diesel.rs

@@ -13,12 +13,12 @@ use std::{error::Error, io};
 
 use bytes::Bytes;
 use xitca_http::http::{
-    header::{HeaderValue, CONTENT_TYPE, SERVER},
+    header::{CONTENT_TYPE, SERVER},
     Method,
 };
 use xitca_web::{dev::fn_service, request::WebRequest, App, HttpServer};
 
-use self::db_diesel::{connect, DieselPool};
+use self::db_diesel::{create, DieselPool};
 use self::util::{
     internal, json, json_response, not_found, plain_text, AppState, HandleResult, QueryParse,
 };
@@ -30,10 +30,9 @@ async fn main() -> io::Result<()> {
     let config = "postgres://benchmarkdbuser:benchmarkdbpass@tfb-database/hello_world";
 
     HttpServer::new(move || {
-        let pool = connect(config).unwrap();
-        App::with_async_state(move || {
-            let pool = pool.clone();
-            async move { AppState::new(pool.clone()) }
+        App::with_async_state(move || async move {
+            let pool = create(config).await.unwrap();
+            AppState::new(pool)
         })
         .service(fn_service(handle))
     })
@@ -70,12 +69,9 @@ async fn fortunes(req: &mut WebRequest<'_, State>) -> HandleResult {
         Ok(body) => {
             let mut res = req.as_response(body);
 
+            res.headers_mut().append(SERVER, util::SERVER_HEADER_VALUE);
             res.headers_mut()
-                .append(SERVER, HeaderValue::from_static("TFB"));
-            res.headers_mut().append(
-                CONTENT_TYPE,
-                HeaderValue::from_static("text/html; charset=utf-8"),
-            );
+                .append(CONTENT_TYPE, util::HTML_HEADER_VALUE);
 
             Ok(res)
         }

+ 13 - 9
frameworks/Rust/xitca-web/src/util.rs

@@ -83,13 +83,19 @@ impl<C> AppState<C> {
     }
 }
 
+pub const SERVER_HEADER_VALUE: HeaderValue = HeaderValue::from_static("TFB");
+
+pub const HTML_HEADER_VALUE: HeaderValue = HeaderValue::from_static("text/html; charset=utf-8");
+
+const TEXT_HEADER_VALUE: HeaderValue = HeaderValue::from_static("text/plain");
+
+const JSON_HEADER_VALUE: HeaderValue = HeaderValue::from_static("application/json");
+
 pub(super) fn plain_text<D>(req: &mut WebRequest<'_, D>) -> HandleResult {
     let mut res = req.as_response(Bytes::from_static(b"Hello, World!"));
 
-    res.headers_mut()
-        .append(SERVER, HeaderValue::from_static("TFB"));
-    res.headers_mut()
-        .append(CONTENT_TYPE, HeaderValue::from_static("text/plain"));
+    res.headers_mut().append(SERVER, SERVER_HEADER_VALUE);
+    res.headers_mut().append(CONTENT_TYPE, TEXT_HEADER_VALUE);
 
     Ok(res)
 }
@@ -109,10 +115,8 @@ where
     let body = writer.take();
 
     let mut res = req.as_response(body);
-    res.headers_mut()
-        .append(SERVER, HeaderValue::from_static("TFB"));
-    res.headers_mut()
-        .append(CONTENT_TYPE, HeaderValue::from_static("application/json"));
+    res.headers_mut().append(SERVER, SERVER_HEADER_VALUE);
+    res.headers_mut().append(CONTENT_TYPE, JSON_HEADER_VALUE);
 
     Ok(res)
 }
@@ -124,7 +128,7 @@ macro_rules! error {
         pub(super) fn $error() -> HandleResult {
             Ok(WebResponseBuilder::new()
                 .status($code)
-                .header(SERVER, HeaderValue::from_static("TFB"))
+                .header(SERVER, SERVER_HEADER_VALUE)
                 .body(Bytes::new().into())
                 .unwrap())
         }