Kaynağa Gözat

tokio is needed by both temporal & gcloud pubsub, so make just one instance for the whole library, add init/shutdown functions for it exposed to C

Grant Limberg 1 ay önce
ebeveyn
işleme
837f15d01b

+ 4 - 0
controller/CV1.cpp

@@ -67,6 +67,8 @@ CV1::CV1(const Identity& myId, const char* path, int listenPort, RedisConfig* rc
 	auto span = tracer->StartSpan("cv1::CV1");
 	auto scope = tracer->WithActiveSpan(span);
 
+	rustybits::init_async_runtime();
+
 	char myAddress[64];
 	_myAddressStr = myId.address().toString(myAddress);
 	_connString = std::string(path);
@@ -157,6 +159,8 @@ CV1::~CV1()
 		_smee = NULL;
 	}
 
+	rustybits::shutdown_async_runtime();
+
 	_run = 0;
 	std::this_thread::sleep_for(std::chrono::milliseconds(100));
 

+ 5 - 0
controller/CV2.cpp

@@ -26,6 +26,7 @@
 #include <climits>
 #include <iomanip>
 #include <libpq-fe.h>
+#include <rustybits.h>
 #include <sstream>
 
 using json = nlohmann::json;
@@ -43,6 +44,8 @@ CV2::CV2(const Identity& myId, const char* path, int listenPort) : DB(), _pool()
 	auto span = tracer->StartSpan("cv2::CV2");
 	auto scope = tracer->WithActiveSpan(span);
 
+	rustybits::init_async_runtime();
+
 	fprintf(stderr, "CV2::CV2\n");
 	char myAddress[64];
 	_myAddressStr = myId.address().toString(myAddress);
@@ -83,6 +86,8 @@ CV2::CV2(const Identity& myId, const char* path, int listenPort) : DB(), _pool()
 
 CV2::~CV2()
 {
+	rustybits::shutdown_async_runtime();
+
 	_run = 0;
 	std::this_thread::sleep_for(std::chrono::milliseconds(100));
 

+ 33 - 2
rustybits/src/ext.rs

@@ -14,6 +14,38 @@ use std::ffi::{CStr, CString};
 use std::os::raw::c_char;
 use url::Url;
 
+static mut RT: Option<tokio::runtime::Runtime> = None;
+
+static START: std::sync::Once = std::sync::Once::new();
+static SHUTDOWN: std::sync::Once = std::sync::Once::new();
+
+#[no_mangle]
+pub unsafe extern "C" fn init_async_runtime() {
+    START.call_once(|| {
+        let rt = tokio::runtime::Builder::new_multi_thread()
+            .worker_threads(4)
+            .thread_name("rust-async-worker")
+            .enable_all()
+            .build()
+            .expect("Failed to create tokio runtime");
+
+        unsafe { RT = Some(rt) };
+    });
+}
+
+#[no_mangle]
+#[allow(static_mut_refs)]
+pub unsafe extern "C" fn shutdown_async_runtime() {
+    SHUTDOWN.call_once(|| {
+        // Shutdown the tokio runtime
+        unsafe {
+            if let Some(rt) =  RT.take() {
+                rt.shutdown_timeout(std::time::Duration::from_secs(5));
+            }
+        }
+    });
+}
+
 #[cfg(feature = "zeroidc")]
 use crate::zeroidc::ZeroIDC;
 
@@ -419,8 +451,7 @@ pub unsafe extern "C" fn smee_client_delete(ptr: *mut SmeeClient) {
         assert!(!ptr.is_null());
         Box::from_raw(&mut *ptr)
     };
-
-    smee.shutdown();
+    drop(smee);
 }
 
 #[cfg(feature = "ztcontroller")]

+ 18 - 0
rustybits/src/pubsub/mod.rs

@@ -9,3 +9,21 @@
  * On the date above, in accordance with the Business Source License, use
  * of this software will be governed by version 2.0 of the Apache License.
  */
+
+use gcloud_pubsub::client::{Client, ClientConfig};
+
+pub struct PubSubClient {
+    client: Client,
+}
+
+impl PubSubClient {
+    pub async fn new() -> Result<Self, Box<dyn std::error::Error>> {
+        let config = ClientConfig::default().with_auth().await.unwrap();
+        let client = Client::new(config).await?;
+
+        // Assuming a topic name is required for the client
+        let topic_name = "default-topic".to_string();
+
+        Ok(Self { client })
+    }
+}

+ 5 - 14
rustybits/src/smeeclient/mod.rs

@@ -17,6 +17,7 @@ use temporal_sdk_core_protos::{
     coresdk::AsJsonPayloadExt,
     temporal::api::enums::v1::{WorkflowIdConflictPolicy, WorkflowIdReusePolicy},
 };
+use tokio::runtime::{Handle, Runtime};
 use url::Url;
 use uuid::Uuid;
 
@@ -43,16 +44,13 @@ impl NetworkJoinedParams {
 }
 
 pub struct SmeeClient {
-    tokio_rt: tokio::runtime::Runtime,
     client: RetryClient<Client>,
     task_queue: String,
 }
 
 impl SmeeClient {
     pub fn new(temporal_url: &str, namespace: &str, task_queue: &str) -> Result<Self, Box<dyn std::error::Error>> {
-        // start tokio runtime.  Required by temporal
-        let rt = tokio::runtime::Runtime::new()?;
-
+        let rt = Handle::current();
         let c = ClientOptionsBuilder::default()
             .target_url(Url::from_str(temporal_url).unwrap())
             .client_name(CLIENT_NAME)
@@ -61,11 +59,7 @@ impl SmeeClient {
 
         let con = rt.block_on(async { c.connect(namespace.to_string(), None).await })?;
 
-        Ok(Self {
-            tokio_rt: rt,
-            client: con,
-            task_queue: task_queue.to_string(),
-        })
+        Ok(Self { client: con, task_queue: task_queue.to_string() })
     }
 
     pub fn notify_network_joined(&self, params: NetworkJoinedParams) -> Result<(), Box<dyn std::error::Error>> {
@@ -89,7 +83,8 @@ impl SmeeClient {
 
         let workflow_id = Uuid::new_v4();
 
-        self.tokio_rt.block_on(async {
+        let rt = Handle::current();
+        rt.block_on(async {
             println!("calilng start_workflow");
             self.client
                 .start_workflow(
@@ -105,8 +100,4 @@ impl SmeeClient {
 
         Ok(())
     }
-
-    pub fn shutdown(self) {
-        self.tokio_rt.shutdown_timeout(Duration::from_secs(5))
-    }
 }