Browse Source

refactor out the separate AuthInfo struct

consolidated everything into the single IDC struct.  Should help keep from rotating the pkce token as often & causing issues with the login window flapping
Grant Limberg 3 years ago
parent
commit
df9a7497b1
3 changed files with 176 additions and 150 deletions
  1. 9 22
      service/OneService.cpp
  2. 14 33
      zeroidc/src/ext.rs
  3. 153 95
      zeroidc/src/lib.rs

+ 9 - 22
service/OneService.cpp

@@ -155,7 +155,6 @@ public:
 		: _webPort(9993)
 		: _webPort(9993)
 		, _tap((EthernetTap *)0)
 		, _tap((EthernetTap *)0)
 		, _idc(nullptr)
 		, _idc(nullptr)
-		, _ainfo(nullptr)
 	{
 	{
 		// Real defaults are in network 'up' code in network event handler
 		// Real defaults are in network 'up' code in network event handler
 		_settings.allowManaged = true;
 		_settings.allowManaged = true;
@@ -170,11 +169,6 @@ public:
 		this->_managedRoutes.clear();
 		this->_managedRoutes.clear();
 		this->_tap.reset();
 		this->_tap.reset();
 
 
-		if (_ainfo) {
-			zeroidc::zeroidc_auth_info_delete(_ainfo);
-			_ainfo = nullptr;
-		}
-
 		if (_idc) {
 		if (_idc) {
 			zeroidc::zeroidc_stop(_idc);
 			zeroidc::zeroidc_stop(_idc);
 			zeroidc::zeroidc_delete(_idc);
 			zeroidc::zeroidc_delete(_idc);
@@ -284,18 +278,13 @@ public:
 				// fprintf(stderr, "idc created (%s, %s, %s)\n", _config.issuerURL, _config.ssoClientID, _config.centralAuthURL);
 				// fprintf(stderr, "idc created (%s, %s, %s)\n", _config.issuerURL, _config.ssoClientID, _config.centralAuthURL);
 			}
 			}
 
 
-			if (_ainfo != nullptr) {
-				zeroidc::zeroidc_auth_info_delete(_ainfo);
-				_ainfo = nullptr;
-			}
-
-			_ainfo = zeroidc::zeroidc_get_auth_info(
+			zeroidc::zeroidc_set_nonce_and_csrf(
 				_idc,
 				_idc,
 				_config.ssoState,
 				_config.ssoState,
 				_config.ssoNonce
 				_config.ssoNonce
 			);
 			);
 
 
-			const char* url = zeroidc::zeroidc_get_auth_url(_ainfo);
+			const char* url = zeroidc::zeroidc_get_auth_url(_idc);
 			memcpy(_config.authenticationURL, url, strlen(url));
 			memcpy(_config.authenticationURL, url, strlen(url));
 			_config.authenticationURL[strlen(url)] = 0;
 			_config.authenticationURL[strlen(url)] = 0;
 		}
 		}
@@ -314,28 +303,27 @@ public:
 	}
 	}
 
 
 	const char* getAuthURL() {
 	const char* getAuthURL() {
-		if (_ainfo != nullptr) {
-			return zeroidc::zeroidc_get_auth_url(_ainfo);
+		if (_idc != nullptr) {
+			return zeroidc::zeroidc_get_auth_url(_idc);
 		}
 		}
-		fprintf(stderr, "_ainfo is null\n");
+		fprintf(stderr, "_idc is null\n");
 		return "";
 		return "";
 	}
 	}
 
 
 	void doTokenExchange(const char *code) {
 	void doTokenExchange(const char *code) {
-		if (_ainfo == nullptr || _idc == nullptr) {
+		if (_idc == nullptr) {
 			fprintf(stderr, "ainfo or idc null\n");
 			fprintf(stderr, "ainfo or idc null\n");
 			return;
 			return;
 		}
 		}
 
 
-		zeroidc::zeroidc_token_exchange(_idc, _ainfo, code);
-		zeroidc::zeroidc_auth_info_delete(_ainfo);
-		_ainfo = zeroidc::zeroidc_get_auth_info(
+		zeroidc::zeroidc_token_exchange(_idc, code);
+		zeroidc::zeroidc_set_nonce_and_csrf(
 			_idc,
 			_idc,
 			_config.ssoState,
 			_config.ssoState,
 			_config.ssoNonce
 			_config.ssoNonce
 		);
 		);
 
 
-		const char* url = zeroidc::zeroidc_get_auth_url(_ainfo);
+		const char* url = zeroidc::zeroidc_get_auth_url(_idc);
 		memcpy(_config.authenticationURL, url, strlen(url));
 		memcpy(_config.authenticationURL, url, strlen(url));
 		_config.authenticationURL[strlen(url)] = 0;
 		_config.authenticationURL[strlen(url)] = 0;
 	}
 	}
@@ -357,7 +345,6 @@ private:
 	std::map< InetAddress, SharedPtr<ManagedRoute> > _managedRoutes;
 	std::map< InetAddress, SharedPtr<ManagedRoute> > _managedRoutes;
 	OneService::NetworkSettings _settings;
 	OneService::NetworkSettings _settings;
 	zeroidc::ZeroIDC *_idc;
 	zeroidc::ZeroIDC *_idc;
-	zeroidc::AuthInfo *_ainfo;
 };
 };
 
 
 namespace {
 namespace {

+ 14 - 33
zeroidc/src/ext.rs

@@ -2,7 +2,7 @@ use std::ffi::{CStr, CString};
 use std::os::raw::c_char;
 use std::os::raw::c_char;
 use url::{Url};
 use url::{Url};
 
 
-use crate::{AuthInfo, ZeroIDC};
+use crate::ZeroIDC;
 
 
 #[no_mangle]
 #[no_mangle]
 pub extern "C" fn zeroidc_new(
 pub extern "C" fn zeroidc_new(
@@ -120,11 +120,10 @@ pub extern "C" fn zeroidc_get_exp_time(ptr: *mut ZeroIDC) -> u64 {
 // }
 // }
 
 
 #[no_mangle]
 #[no_mangle]
-pub extern "C" fn zeroidc_get_auth_info(
+pub extern "C" fn zeroidc_set_nonce_and_csrf(
     ptr: *mut ZeroIDC,
     ptr: *mut ZeroIDC,
     csrf_token: *const c_char,
     csrf_token: *const c_char,
-    nonce: *const c_char,
-) -> *mut AuthInfo {
+    nonce: *const c_char) {
     let idc = unsafe {
     let idc = unsafe {
         assert!(!ptr.is_null());
         assert!(!ptr.is_null());
         &mut *ptr
         &mut *ptr
@@ -132,12 +131,12 @@ pub extern "C" fn zeroidc_get_auth_info(
 
 
     if csrf_token.is_null() {
     if csrf_token.is_null() {
         println!("csrf_token is null");
         println!("csrf_token is null");
-        return std::ptr::null_mut();
+        return;
     }
     }
 
 
     if nonce.is_null() {
     if nonce.is_null() {
         println!("nonce is null");
         println!("nonce is null");
-        return std::ptr::null_mut();
+        return;
     }
     }
 
 
     let csrf_token = unsafe { CStr::from_ptr(csrf_token) }
     let csrf_token = unsafe { CStr::from_ptr(csrf_token) }
@@ -148,47 +147,31 @@ pub extern "C" fn zeroidc_get_auth_info(
         .to_str()
         .to_str()
         .unwrap()
         .unwrap()
         .to_string();
         .to_string();
-
-    match idc.get_auth_info(csrf_token, nonce) {
-        Some(a) => Box::into_raw(Box::new(a)),
-        None => std::ptr::null_mut(),
-    }
+     
+    idc.set_nonce_and_csrf(csrf_token, nonce);
 }
 }
 
 
 #[no_mangle]
 #[no_mangle]
-pub extern "C" fn zeroidc_auth_info_delete(ptr: *mut AuthInfo) {
-    if ptr.is_null() {
-        return;
-    }
-    unsafe {
-        Box::from_raw(ptr);
-    }
-}
-
-#[no_mangle]
-pub extern "C" fn zeroidc_get_auth_url(ptr: *mut AuthInfo) -> *const c_char {
+pub extern "C" fn zeroidc_get_auth_url(ptr: *mut ZeroIDC) -> *const c_char {
     if ptr.is_null() {
     if ptr.is_null() {
         println!("passed a null object");
         println!("passed a null object");
         return std::ptr::null_mut();
         return std::ptr::null_mut();
     }
     }
-    let ai = unsafe {
+    let idc = unsafe {
         &mut *ptr
         &mut *ptr
     };
     };
     
     
-    let s = CString::new(ai.url.to_string()).unwrap();
+    let s = CString::new(idc.auth_url()).unwrap();
     return s.into_raw();
     return s.into_raw();
 }
 }
 
 
 #[no_mangle]
 #[no_mangle]
-pub extern "C" fn zeroidc_token_exchange(idc: *mut ZeroIDC, ai: *mut AuthInfo, code: *const c_char ) {
+pub extern "C" fn zeroidc_token_exchange(idc: *mut ZeroIDC, code: *const c_char ) {
     if idc.is_null() {
     if idc.is_null() {
         println!("idc is null");
         println!("idc is null");
         return
         return
     }
     }
-    if ai.is_null() {
-        println!("ai is null");
-        return
-    }
+
     if code.is_null() {
     if code.is_null() {
         println!("code is null");
         println!("code is null");
         return
         return
@@ -196,12 +179,10 @@ pub extern "C" fn zeroidc_token_exchange(idc: *mut ZeroIDC, ai: *mut AuthInfo, c
     let idc = unsafe {
     let idc = unsafe {
         &mut *idc
         &mut *idc
     };
     };
-    let ai = unsafe {
-        &mut *ai
-    };
+
     let code = unsafe{CStr::from_ptr(code)}.to_str().unwrap();
     let code = unsafe{CStr::from_ptr(code)}.to_str().unwrap();
 
 
-    idc.do_token_exchange(ai, code);
+    idc.do_token_exchange( code);
 }
 }
 
 
 #[no_mangle]
 #[no_mangle]

+ 153 - 95
zeroidc/src/lib.rs

@@ -30,6 +30,18 @@ struct Inner {
     access_token: Option<AccessToken>,
     access_token: Option<AccessToken>,
     refresh_token: Option<RefreshToken>,
     refresh_token: Option<RefreshToken>,
     exp_time: u64,
     exp_time: u64,
+
+    url: Option<Url>,
+    csrf_token: Option<CsrfToken>,
+    nonce: Option<Nonce>,
+    pkce_verifier: Option<PkceCodeVerifier>,
+}
+
+impl Inner {
+    #[inline]
+    fn as_opt(&mut self) -> Option<&mut Inner> {
+        Some(self)
+    }
 }
 }
 
 
 #[derive(Debug, Serialize, Deserialize)]
 #[derive(Debug, Serialize, Deserialize)]
@@ -45,13 +57,6 @@ fn nonce_func(nonce: String) -> Box<dyn Fn() -> Nonce> {
     return Box::new(move || Nonce::new(nonce.to_string()));
     return Box::new(move || Nonce::new(nonce.to_string()));
 }
 }
 
 
-pub struct AuthInfo {
-    url: Url,
-    csrf_token: CsrfToken,
-    nonce: Nonce,
-    pkce_verifier: Option<PkceCodeVerifier>,
-}
-
 fn systemtime_strftime<T>(dt: T, format: &str) -> String
 fn systemtime_strftime<T>(dt: T, format: &str) -> String
    where T: Into<OffsetDateTime>
    where T: Into<OffsetDateTime>
 {
 {
@@ -87,6 +92,11 @@ impl ZeroIDC {
                 access_token: None,
                 access_token: None,
                 refresh_token: None,
                 refresh_token: None,
                 exp_time: 0,
                 exp_time: 0,
+
+                url: None,
+                csrf_token: None,
+                nonce: None,
+                pkce_verifier: None, 
             })),
             })),
         };
         };
 
 
@@ -264,102 +274,150 @@ impl ZeroIDC {
         return (*self.inner.lock().unwrap()).exp_time;
         return (*self.inner.lock().unwrap()).exp_time;
     }
     }
 
 
-    fn do_token_exchange(&mut self, auth_info: &mut AuthInfo, code: &str) {
-        if let Some(verifier) = auth_info.pkce_verifier.take() {
-            let token_response = (*self.inner.lock().unwrap()).oidc_client.as_ref().map(|c| {
-                let r = c.exchange_code(AuthorizationCode::new(code.to_string()))
-                    .set_pkce_verifier(verifier)
-                    .request(http_client);
-                match r {
-                    Ok(res) =>{
-                         return Some(res);
-                    },
-                    Err(e) => {
-                        println!("token response error: {}", e.to_string());
-                        
-                        return None;
-                    },
+    fn set_nonce_and_csrf(&mut self, csrf_token: String, nonce: String) {
+        let local = Arc::clone(&self.inner);
+        (*local.lock().expect("can't lock inner")).as_opt().map(|i| {
+            let mut csrf_diff = false;
+            let mut nonce_diff = false;
+            let mut need_verifier = false;
+        
+            match i.pkce_verifier {
+                None => {
+                    need_verifier = true;
+                },
+                _ => (),
+            }
+            if let Some(csrf) = i.csrf_token.clone() {
+                if *csrf.secret() != csrf_token {
+                    csrf_diff = true;
                 }
                 }
-            });
-            // TODO: do stuff with token response
-            if let Some(Some(tok)) = token_response {
-                let id_token = tok.id_token().unwrap();
-                println!("ID token: {}", id_token.to_string());
-
-                let split = auth_info.csrf_token.secret().split("_");
-                let split = split.collect::<Vec<&str>>();
-                
-                let params = [("id_token", id_token.to_string()),("state", split[0].to_string())];
-                let client = reqwest::blocking::Client::new();
-                let res = client.post((*self.inner.lock().unwrap()).auth_endpoint.clone())
-                    .form(&params)
-                    .send();
-
-                match res {
-                    Ok(res) => {
-                        println!("hit url: {}", res.url().as_str());
-                        println!("Status: {}", res.status());
-
-                        let at = tok.access_token().secret();
-                        let exp = dangerous_insecure_decode::<Exp>(&at);
-                        if let Ok(e) = exp {
-                            (*self.inner.lock().unwrap()).exp_time = e.claims.exp
-                        }
-
-                        (*self.inner.lock().unwrap()).access_token = Some(tok.access_token().clone());
-                        if let Some(t) = tok.refresh_token() {
-                            (*self.inner.lock().unwrap()).refresh_token = Some(t.clone());
-                            self.start();
-                        }
-                    },
-                    Err(res) => {
-                        println!("hit url: {}", res.url().unwrap().as_str());
-                        println!("Status: {}", res.status().unwrap());
-                        println!("Post error: {}", res.to_string());
-                        (*self.inner.lock().unwrap()).exp_time = 0;
-                    }
+            }
+            if let Some(n) = i.nonce.clone() {
+                if *n.secret() != nonce {
+                    nonce_diff = true;
                 }
                 }
+            }
 
 
-                let access_token = tok.access_token();
-                println!("Access Token: {}", access_token.secret());
+            if need_verifier || csrf_diff || nonce_diff {
+                let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
+                let r = i.oidc_client.as_ref().map(|c| {
+                    let (auth_url, csrf_token, nonce) = c
+                    .authorize_url(
+                        AuthenticationFlow::<CoreResponseType>::AuthorizationCode,
+                        csrf_func(csrf_token),
+                        nonce_func(nonce),
+                    )
+                    .add_scope(Scope::new("profile".to_string()))
+                    .add_scope(Scope::new("email".to_string()))
+                    .add_scope(Scope::new("offline_access".to_string()))
+                    .add_scope(Scope::new("openid".to_string()))
+                    .set_pkce_challenge(pkce_challenge)
+                    .url();
+
+                    (auth_url, csrf_token, nonce)
+                });
+
+                if let Some(r) = r {
+                    i.url = Some(r.0);
+                    i.csrf_token = Some(r.1);
+                    i.nonce = Some(r.2);
+                    i.pkce_verifier = Some(pkce_verifier);
+                }
+            }
+        });
+    }
 
 
-                let refresh_token = tok.refresh_token();
-                println!("Refresh Token: {}", refresh_token.unwrap().secret());
+    fn auth_url(&self) -> String {
+        let url = (*self.inner.lock().expect("can't lock inner")).as_opt().map(|i| {
+            match i.url.clone() {
+                Some(u) => u.to_string(),
+                _ => "".to_string(),
             }
             }
-        } else {
-            println!("No pkce verifier!  Can't exchange tokens!!!");
+        });
+
+        match url {
+            Some(url) => url.to_string(),
+            None => "".to_string(),
         }
         }
     }
     }
 
 
-    fn get_auth_info(&mut self, csrf_token: String, nonce: String) -> Option<AuthInfo> {
-        let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
-        let network_id = self.get_network_id();
-
-        let r = (*self.inner.lock().unwrap()).oidc_client.as_ref().map(|c| {
-            let (auth_url, csrf_token, nonce) = c
-                .authorize_url(
-                    AuthenticationFlow::<CoreResponseType>::AuthorizationCode,
-                    csrf_func(csrf_token),
-                    nonce_func(nonce),
-                )
-                .add_scope(Scope::new("profile".to_string()))
-                .add_scope(Scope::new("email".to_string()))
-                .add_scope(Scope::new("offline_access".to_string()))
-                .add_scope(Scope::new("openid".to_string()))
-                .set_pkce_challenge(pkce_challenge)
-                .add_extra_param("network_id", network_id)
-                .url();
-
-            // println!("URL: {}", auth_url);
-
-            return AuthInfo {
-                url: auth_url,
-                pkce_verifier: Some(pkce_verifier),
-                csrf_token,
-                nonce,
-            };
-        });
+    fn do_token_exchange(&mut self, code: &str) {
+        let local = Arc::clone(&self.inner);
+        (*local.lock().unwrap()).as_opt().map(|i| {
+            if let Some(verifier) = i.pkce_verifier.take() {
+                let token_response = i.oidc_client.as_ref().map(|c| {
+                    let r = c.exchange_code(AuthorizationCode::new(code.to_string()))
+                        .set_pkce_verifier(verifier)
+                        .request(http_client);
+                    match r {
+                        Ok(res) =>{
+                            return Some(res);
+                        },
+                        Err(e) => {
+                            println!("token response error: {}", e.to_string());
+                            
+                            return None;
+                        },
+                    }
+                });
+                // TODO: do stuff with token response
+                if let Some(Some(tok)) = token_response {
+                    let id_token = tok.id_token().unwrap();
+                    println!("ID token: {}", id_token.to_string());
+
+                    let mut split = "".to_string();
+                    match i.csrf_token.clone() {
+                        Some(csrf_token) => {
+                            split = csrf_token.secret().to_owned();
+                        },
+                        _ => (),
+                    }
+
+                    let split = split.split("_").collect::<Vec<&str>>();
+                    
+                    if split.len() == 2 {
+                        let params = [("id_token", id_token.to_string()),("state", split[0].to_string())];
+                        let client = reqwest::blocking::Client::new();
+                        let res = client.post((*self.inner.lock().unwrap()).auth_endpoint.clone())
+                            .form(&params)
+                            .send();
+
+                        match res {
+                            Ok(res) => {
+                                println!("hit url: {}", res.url().as_str());
+                                println!("Status: {}", res.status());
+
+                                let at = tok.access_token().secret();
+                                let exp = dangerous_insecure_decode::<Exp>(&at);
+                                if let Ok(e) = exp {
+                                    (*self.inner.lock().unwrap()).exp_time = e.claims.exp
+                                }
 
 
-        r
+                                (*self.inner.lock().unwrap()).access_token = Some(tok.access_token().clone());
+                                if let Some(t) = tok.refresh_token() {
+                                    (*self.inner.lock().unwrap()).refresh_token = Some(t.clone());
+                                    self.start();
+                                }
+                            },
+                            Err(res) => {
+                                println!("hit url: {}", res.url().unwrap().as_str());
+                                println!("Status: {}", res.status().unwrap());
+                                println!("Post error: {}", res.to_string());
+                                (*self.inner.lock().unwrap()).exp_time = 0;
+                            }
+                        }
+
+                        let access_token = tok.access_token();
+                        println!("Access Token: {}", access_token.secret());
+
+                        let refresh_token = tok.refresh_token();
+                        println!("Refresh Token: {}", refresh_token.unwrap().secret());
+                    } else {
+                        println!("invalid split length?!?");
+                    }
+                }
+            }
+        });
     }
     }
 }
 }
+