Browse Source

On our way to processing tokens

Grant Limberg 3 years ago
parent
commit
4ce810b421
3 changed files with 80 additions and 11 deletions
  1. 31 2
      service/OneService.cpp
  2. 14 5
      zeroidc/src/ext.rs
  3. 35 4
      zeroidc/src/lib.rs

+ 31 - 2
service/OneService.cpp

@@ -321,6 +321,25 @@ public:
 		return "";
 	}
 
+	void doTokenExchange(const char *code) {
+		if (_ainfo == nullptr || _idc == nullptr) {
+			fprintf(stderr, "ainfo or idc null\n");
+			return;
+		}
+
+		zeroidc::zeroidc_token_exchange(_idc, _ainfo, code);
+		zeroidc::zeroidc_auth_info_delete(_ainfo);
+		_ainfo = zeroidc::zeroidc_get_auth_info(
+			_idc,
+			_config.ssoState,
+			_config.ssoNonce
+		);
+
+		const char* url = zeroidc::zeroidc_get_auth_url(_ainfo);
+		memcpy(_config.authenticationURL, url, strlen(url));
+		_config.authenticationURL[strlen(url)] = 0;
+	}
+
 private:
 	unsigned int _webPort;
 	std::shared_ptr<EthernetTap> _tap;
@@ -1649,11 +1668,21 @@ public:
 				fprintf(stderr, "path: %s\n", path.c_str());
 				fprintf(stderr, "body: %s\n", body.c_str());
 
-				const char* state = zeroidc::zeroidc_get_state_param_value(path.c_str());
+				const char* state = zeroidc::zeroidc_get_url_param_value("state", path.c_str());
 				const char* nwid = zeroidc::zeroidc_network_id_from_state(state);
 				fprintf(stderr, "state: %s\n", state);
 				fprintf(stderr, "nwid: %s\n", nwid);
-				scode = 200;
+
+				const uint64_t id = Utils::hexStrToU64(nwid);
+				Mutex::Lock l(_nets_m);
+				if (_nets.find(id) != _nets.end()) {
+					NetworkState& ns = _nets[id];
+					const char* code = zeroidc::zeroidc_get_url_param_value("code", path.c_str());
+					ns.doTokenExchange(code);
+					scode = 200;
+				} else {
+					scode = 404;
+				}
 			} else {
 				scode = 401; // isAuth == false && !sso
 			}

+ 14 - 5
zeroidc/src/ext.rs

@@ -179,23 +179,32 @@ pub extern "C" fn zeroidc_token_exchange(idc: *mut ZeroIDC, ai: *mut AuthInfo, c
         println!("ai is null");
         return
     }
+    if code.is_null() {
+        println!("code is null");
+        return
+    }
     let idc = unsafe {
         &mut *idc
     };
     let ai = unsafe {
         &mut *ai
     };
+    let code = unsafe{CStr::from_ptr(code)}.to_str().unwrap();
 
-
+    idc.do_token_exchange(ai, code);
 }
 
 #[no_mangle]
-pub extern "C" fn zeroidc_get_state_param_value(path: *const c_char) -> *const c_char {
+pub extern "C" fn zeroidc_get_url_param_value(param: *const c_char, path: *const c_char) -> *const c_char {
+    if param.is_null() {
+        println!("param is null");
+        return std::ptr::null();
+    }
     if path.is_null() {
         println!("path is null");
         return std::ptr::null();
     }
-
+    let param = unsafe {CStr::from_ptr(param)}.to_str().unwrap();
     let path =  unsafe {CStr::from_ptr(path)}.to_str().unwrap();
 
     let url = "http://localhost:9993".to_string() + path;
@@ -203,7 +212,7 @@ pub extern "C" fn zeroidc_get_state_param_value(path: *const c_char) -> *const c
 
     let mut pairs = url.query_pairs();  
     for p in pairs {
-        if p.0 == "state" {
+        if p.0 == param {
             let s = CString::new(p.1.into_owned()).unwrap();
             return s.into_raw()
         }
@@ -229,4 +238,4 @@ pub extern "C" fn zeroidc_network_id_from_state(state: *const c_char) -> *const
 
     let s = CString::new(split[1]).unwrap();
     return s.into_raw();
-}
+}

+ 35 - 4
zeroidc/src/lib.rs

@@ -10,10 +10,11 @@ use std::time::Duration;
 
 use openidconnect::core::{CoreClient, CoreProviderMetadata, CoreResponseType};
 use openidconnect::reqwest::http_client;
-use openidconnect::{AuthenticationFlow, PkceCodeVerifier};
-use openidconnect::{ClientId, CsrfToken, IssuerUrl, Nonce, PkceCodeChallenge, RedirectUrl, Scope};
+use openidconnect::{AuthenticationFlow, PkceCodeVerifier, TokenResponse, OAuth2TokenResponse};
+use openidconnect::{AuthorizationCode, ClientId, CsrfToken, IssuerUrl, Nonce, PkceCodeChallenge, RedirectUrl, RequestTokenError, Scope};
 
 use url::Url;
+use std::borrow::BorrowMut;
 
 pub struct ZeroIDC {
     inner: Arc<Mutex<Inner>>,
@@ -39,7 +40,7 @@ pub struct AuthInfo {
     url: Url,
     csrf_token: CsrfToken,
     nonce: Nonce,
-    pkce_verifier: PkceCodeVerifier,
+    pkce_verifier: Option<PkceCodeVerifier>,
 }
 
 impl ZeroIDC {
@@ -147,6 +148,36 @@ impl ZeroIDC {
         return (*self.inner.lock().unwrap()).network_id.clone()
     }
 
+    fn do_token_exchange(&mut self, auth_info: &mut AuthInfo, code: &str) {
+        if let Some(verifier) = auth_info.pkce_verifier.take() {
+            let token_response = (*self.inner.lock().unwrap()).oidc_client.as_ref().map(|c| {
+                let r = c.exchange_code(AuthorizationCode::new(code.to_string()))
+                    .set_pkce_verifier(verifier)
+                    .request(http_client);
+                match r {
+                    Ok(res) =>{
+                         return Some(res);
+                    },
+                    Err(e) => {
+                        println!("token response error");
+                        return None;
+                    },
+                }
+            });
+            // TODO: do stuff with token response
+            if let Some(Some(tok)) = token_response {
+                let id_token = tok.id_token().unwrap();
+                let claims = (*self.inner.lock().unwrap()).oidc_client.as_ref().map(|c| {
+
+                });
+                let access_token = tok.access_token();
+                let refresh_token = tok.refresh_token();
+            }
+        } else {
+            println!("No pkce verifier!  Can't exchange tokens!!!");
+        }
+    }
+
     fn get_auth_info(&mut self, csrf_token: String, nonce: String) -> Option<AuthInfo> {
         let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
         let network_id = self.get_network_id();
@@ -170,9 +201,9 @@ impl ZeroIDC {
 
             return AuthInfo {
                 url: auth_url,
+                pkce_verifier: Some(pkce_verifier),
                 csrf_token,
                 nonce,
-                pkce_verifier,
             };
         });