浏览代码

rusftormat zeroidc

Grant Limberg 2 年之前
父节点
当前提交
f9af9a15f2
共有 4 个文件被更改,包括 114 次插入201 次删除
  1. 1 4
      rustybits/zeroidc/build.rs
  2. 1 3
      rustybits/zeroidc/src/error.rs
  3. 4 17
      rustybits/zeroidc/src/ext.rs
  4. 108 177
      rustybits/zeroidc/src/lib.rs

+ 1 - 4
rustybits/zeroidc/build.rs

@@ -8,10 +8,7 @@ fn main() {
     let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
     let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
 
 
     let package_name = env::var("CARGO_PKG_NAME").unwrap();
     let package_name = env::var("CARGO_PKG_NAME").unwrap();
-    let output_file = target_dir()
-        .join(format!("{}.h", package_name))
-        .display()
-        .to_string();
+    let output_file = target_dir().join(format!("{}.h", package_name)).display().to_string();
 
 
     let config = Config {
     let config = Config {
         language: Language::C,
         language: Language::C,

+ 1 - 3
rustybits/zeroidc/src/error.rs

@@ -15,9 +15,7 @@ use thiserror::Error;
 #[derive(Error, Debug)]
 #[derive(Error, Debug)]
 pub enum ZeroIDCError {
 pub enum ZeroIDCError {
     #[error(transparent)]
     #[error(transparent)]
-    DiscoveryError(
-        #[from] openidconnect::DiscoveryError<openidconnect::reqwest::Error<reqwest::Error>>,
-    ),
+    DiscoveryError(#[from] openidconnect::DiscoveryError<openidconnect::reqwest::Error<reqwest::Error>>),
 
 
     #[error(transparent)]
     #[error(transparent)]
     ParseError(#[from] url::ParseError),
     ParseError(#[from] url::ParseError),

+ 4 - 17
rustybits/zeroidc/src/ext.rs

@@ -160,11 +160,7 @@ pub extern "C" fn zeroidc_get_exp_time(ptr: *mut ZeroIDC) -> u64 {
     target_os = "macos",
     target_os = "macos",
 ))]
 ))]
 #[no_mangle]
 #[no_mangle]
-pub extern "C" fn zeroidc_set_nonce_and_csrf(
-    ptr: *mut ZeroIDC,
-    csrf_token: *const c_char,
-    nonce: *const c_char,
-) {
+pub extern "C" fn zeroidc_set_nonce_and_csrf(ptr: *mut ZeroIDC, csrf_token: *const c_char, nonce: *const c_char) {
     let idc = unsafe {
     let idc = unsafe {
         assert!(!ptr.is_null());
         assert!(!ptr.is_null());
         &mut *ptr
         &mut *ptr
@@ -180,14 +176,8 @@ pub extern "C" fn zeroidc_set_nonce_and_csrf(
         return;
         return;
     }
     }
 
 
-    let csrf_token = unsafe { CStr::from_ptr(csrf_token) }
-        .to_str()
-        .unwrap()
-        .to_string();
-    let nonce = unsafe { CStr::from_ptr(nonce) }
-        .to_str()
-        .unwrap()
-        .to_string();
+    let csrf_token = unsafe { CStr::from_ptr(csrf_token) }.to_str().unwrap().to_string();
+    let nonce = unsafe { CStr::from_ptr(nonce) }.to_str().unwrap().to_string();
 
 
     idc.set_nonce_and_csrf(csrf_token, nonce);
     idc.set_nonce_and_csrf(csrf_token, nonce);
 }
 }
@@ -275,10 +265,7 @@ pub extern "C" fn zeroidc_token_exchange(idc: *mut ZeroIDC, code: *const c_char)
 }
 }
 
 
 #[no_mangle]
 #[no_mangle]
-pub extern "C" fn zeroidc_get_url_param_value(
-    param: *const c_char,
-    path: *const c_char,
-) -> *mut c_char {
+pub extern "C" fn zeroidc_get_url_param_value(param: *const c_char, path: *const c_char) -> *mut c_char {
     if param.is_null() {
     if param.is_null() {
         println!("param is null");
         println!("param is null");
         return std::ptr::null_mut();
         return std::ptr::null_mut();

+ 108 - 177
rustybits/zeroidc/src/lib.rs

@@ -26,9 +26,8 @@ use jwt::Token;
 use openidconnect::core::{CoreClient, CoreProviderMetadata, CoreResponseType};
 use openidconnect::core::{CoreClient, CoreProviderMetadata, CoreResponseType};
 use openidconnect::reqwest::http_client;
 use openidconnect::reqwest::http_client;
 use openidconnect::{
 use openidconnect::{
-    AccessToken, AccessTokenHash, AuthenticationFlow, AuthorizationCode, ClientId, CsrfToken,
-    IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl,
-    RefreshToken, Scope, TokenResponse,
+    AccessToken, AccessTokenHash, AuthenticationFlow, AuthorizationCode, ClientId, CsrfToken, IssuerUrl, Nonce,
+    OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, Scope, TokenResponse,
 };
 };
 use std::error::Error;
 use std::error::Error;
 use std::str::from_utf8;
 use std::str::from_utf8;
@@ -153,13 +152,9 @@ impl ZeroIDC {
         let redirect = RedirectUrl::new(redir_url.to_string())?;
         let redirect = RedirectUrl::new(redir_url.to_string())?;
 
 
         idc.inner.lock().unwrap().oidc_client = Some(
         idc.inner.lock().unwrap().oidc_client = Some(
-            CoreClient::from_provider_metadata(
-                provider_meta,
-                ClientId::new(client_id.to_string()),
-                None,
-            )
-            .set_redirect_uri(redirect)
-            .set_auth_type(openidconnect::AuthType::RequestBody),
+            CoreClient::from_provider_metadata(provider_meta, ClientId::new(client_id.to_string()), None)
+                .set_redirect_uri(redirect)
+                .set_auth_type(openidconnect::AuthType::RequestBody),
         );
         );
 
 
         Ok(idc)
         Ok(idc)
@@ -184,22 +179,15 @@ impl ZeroIDC {
                 let nonce = inner_local.lock().unwrap().nonce.clone();
                 let nonce = inner_local.lock().unwrap().nonce.clone();
 
 
                 while running {
                 while running {
-                    let exp =
-                        UNIX_EPOCH + Duration::from_secs(inner_local.lock().unwrap().exp_time);
+                    let exp = UNIX_EPOCH + Duration::from_secs(inner_local.lock().unwrap().exp_time);
                     let now = SystemTime::now();
                     let now = SystemTime::now();
 
 
                     #[cfg(debug_assertions)]
                     #[cfg(debug_assertions)]
                     {
                     {
                         println!(
                         println!(
                             "refresh token thread tick, now: {}, exp: {}",
                             "refresh token thread tick, now: {}, exp: {}",
-                            systemtime_strftime(
-                                now,
-                                "[year]-[month]-[day] [hour]:[minute]:[second]"
-                            ),
-                            systemtime_strftime(
-                                exp,
-                                "[year]-[month]-[day] [hour]:[minute]:[second]"
-                            )
+                            systemtime_strftime(now, "[year]-[month]-[day] [hour]:[minute]:[second]"),
+                            systemtime_strftime(exp, "[year]-[month]-[day] [hour]:[minute]:[second]")
                         );
                         );
                     }
                     }
                     let refresh_token = inner_local.lock().unwrap().refresh_token.clone();
                     let refresh_token = inner_local.lock().unwrap().refresh_token.clone();
@@ -220,14 +208,11 @@ impl ZeroIDC {
                                 println!("Refresh Token: {}", refresh_token.secret());
                                 println!("Refresh Token: {}", refresh_token.secret());
                             }
                             }
 
 
-                            let token_response =
-                                inner_local.lock().unwrap().oidc_client.as_ref().map(|c| {
-                                    let res = c
-                                        .exchange_refresh_token(&refresh_token)
-                                        .request(http_client);
+                            let token_response = inner_local.lock().unwrap().oidc_client.as_ref().map(|c| {
+                                let res = c.exchange_refresh_token(&refresh_token).request(http_client);
 
 
-                                    res
-                                });
+                                res
+                            });
 
 
                             if let Some(res) = token_response {
                             if let Some(res) = token_response {
                                 match res {
                                 match res {
@@ -246,20 +231,11 @@ impl ZeroIDC {
                                                 ];
                                                 ];
                                                 #[cfg(debug_assertions)]
                                                 #[cfg(debug_assertions)]
                                                 {
                                                 {
-                                                    println!(
-                                                        "New ID token: {}",
-                                                        id_token.to_string()
-                                                    );
+                                                    println!("New ID token: {}", id_token.to_string());
                                                 }
                                                 }
                                                 let client = reqwest::blocking::Client::new();
                                                 let client = reqwest::blocking::Client::new();
                                                 let r = client
                                                 let r = client
-                                                    .post(
-                                                        inner_local
-                                                            .lock()
-                                                            .unwrap()
-                                                            .auth_endpoint
-                                                            .clone(),
-                                                    )
+                                                    .post(inner_local.lock().unwrap().auth_endpoint.clone())
                                                     .form(&params)
                                                     .form(&params)
                                                     .send();
                                                     .send();
 
 
@@ -268,10 +244,7 @@ impl ZeroIDC {
                                                         if r.status().is_success() {
                                                         if r.status().is_success() {
                                                             #[cfg(debug_assertions)]
                                                             #[cfg(debug_assertions)]
                                                             {
                                                             {
-                                                                println!(
-                                                                    "hit url: {}",
-                                                                    r.url().as_str()
-                                                                );
+                                                                println!("hit url: {}", r.url().as_str());
                                                                 println!("status: {}", r.status());
                                                                 println!("status: {}", r.status());
                                                             }
                                                             }
 
 
@@ -279,24 +252,16 @@ impl ZeroIDC {
                                                             let idt = &id_token.to_string();
                                                             let idt = &id_token.to_string();
 
 
                                                             let t: Result<
                                                             let t: Result<
-                                                                Token<
-                                                                    jwt::Header,
-                                                                    jwt::Claims,
-                                                                    jwt::Unverified<'_>,
-                                                                >,
+                                                                Token<jwt::Header, jwt::Claims, jwt::Unverified<'_>>,
                                                                 jwt::Error,
                                                                 jwt::Error,
                                                             > = Token::parse_unverified(idt);
                                                             > = Token::parse_unverified(idt);
 
 
                                                             if let Ok(t) = t {
                                                             if let Ok(t) = t {
-                                                                let claims =
-                                                                    t.claims().registered.clone();
+                                                                let claims = t.claims().registered.clone();
                                                                 match claims.expiration {
                                                                 match claims.expiration {
                                                                     Some(exp) => {
                                                                     Some(exp) => {
                                                                         println!("exp: {}", exp);
                                                                         println!("exp: {}", exp);
-                                                                        inner_local
-                                                                            .lock()
-                                                                            .unwrap()
-                                                                            .exp_time = exp;
+                                                                        inner_local.lock().unwrap().exp_time = exp;
                                                                     }
                                                                     }
                                                                     None => {
                                                                     None => {
                                                                         panic!("expiration is None.  This shouldn't happen")
                                                                         panic!("expiration is None.  This shouldn't happen")
@@ -306,17 +271,11 @@ impl ZeroIDC {
                                                                 panic!("error parsing claims");
                                                                 panic!("error parsing claims");
                                                             }
                                                             }
 
 
-                                                            inner_local
-                                                                .lock()
-                                                                .unwrap()
-                                                                .access_token =
+                                                            inner_local.lock().unwrap().access_token =
                                                                 Some(access_token.clone());
                                                                 Some(access_token.clone());
                                                             if let Some(t) = res.refresh_token() {
                                                             if let Some(t) = res.refresh_token() {
                                                                 // println!("New Refresh Token: {}", t.secret());
                                                                 // println!("New Refresh Token: {}", t.secret());
-                                                                inner_local
-                                                                    .lock()
-                                                                    .unwrap()
-                                                                    .refresh_token =
+                                                                inner_local.lock().unwrap().refresh_token =
                                                                     Some(t.clone());
                                                                     Some(t.clone());
                                                             }
                                                             }
                                                             #[cfg(debug_assertions)]
                                                             #[cfg(debug_assertions)]
@@ -324,35 +283,22 @@ impl ZeroIDC {
                                                                 println!("Central post succeeded");
                                                                 println!("Central post succeeded");
                                                             }
                                                             }
                                                         } else {
                                                         } else {
-                                                            println!(
-                                                                "Central post failed: {}",
-                                                                r.status()
-                                                            );
-                                                            println!(
-                                                                "hit url: {}",
-                                                                r.url().as_str()
-                                                            );
+                                                            println!("Central post failed: {}", r.status());
+                                                            println!("hit url: {}", r.url().as_str());
                                                             println!("Status: {}", r.status());
                                                             println!("Status: {}", r.status());
                                                             if let Ok(body) = r.bytes() {
                                                             if let Ok(body) = r.bytes() {
-                                                                if let Ok(body) =
-                                                                    std::str::from_utf8(&body)
-                                                                {
+                                                                if let Ok(body) = std::str::from_utf8(&body) {
                                                                     println!("Body: {}", body);
                                                                     println!("Body: {}", body);
                                                                 }
                                                                 }
                                                             }
                                                             }
 
 
-                                                            inner_local.lock().unwrap().exp_time =
-                                                                0;
-                                                            inner_local.lock().unwrap().running =
-                                                                false;
+                                                            inner_local.lock().unwrap().exp_time = 0;
+                                                            inner_local.lock().unwrap().running = false;
                                                         }
                                                         }
                                                     }
                                                     }
                                                     Err(e) => {
                                                     Err(e) => {
                                                         println!("Central post failed: {}", e);
                                                         println!("Central post failed: {}", e);
-                                                        println!(
-                                                            "hit url: {}",
-                                                            e.url().unwrap().as_str()
-                                                        );
+                                                        println!("hit url: {}", e.url().unwrap().as_str());
                                                         println!("Status: {}", e.status().unwrap());
                                                         println!("Status: {}", e.status().unwrap());
                                                         inner_local.lock().unwrap().exp_time = 0;
                                                         inner_local.lock().unwrap().exp_time = 0;
                                                         inner_local.lock().unwrap().running = false;
                                                         inner_local.lock().unwrap().running = false;
@@ -421,88 +367,86 @@ impl ZeroIDC {
 
 
     pub fn set_nonce_and_csrf(&mut self, csrf_token: String, nonce: String) {
     pub fn set_nonce_and_csrf(&mut self, csrf_token: String, nonce: String) {
         let local = Arc::clone(&self.inner);
         let local = Arc::clone(&self.inner);
-        (*local.lock().expect("can't lock inner"))
-            .as_opt()
-            .map(|i| {
-                if i.running {
-                    println!("refresh thread running. not setting new nonce or csrf");
-                    return;
-                }
+        (*local.lock().expect("can't lock inner")).as_opt().map(|i| {
+            if i.running {
+                println!("refresh thread running. not setting new nonce or csrf");
+                return;
+            }
 
 
-                let need_verifier = matches!(i.pkce_verifier, None);
+            let need_verifier = matches!(i.pkce_verifier, None);
 
 
-                let csrf_diff = if let Some(csrf) = i.csrf_token.clone() {
-                    *csrf.secret() != csrf_token
-                } else {
-                    false
-                };
+            let csrf_diff = if let Some(csrf) = i.csrf_token.clone() {
+                *csrf.secret() != csrf_token
+            } else {
+                false
+            };
 
 
-                let nonce_diff = if let Some(n) = i.nonce.clone() {
-                    *n.secret() != nonce
-                } else {
-                    false
-                };
-
-                if need_verifier || csrf_diff || nonce_diff {
-                    let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
-                    let r = i.oidc_client.as_ref().map(|c| {
-                        let mut auth_builder = c
-                            .authorize_url(
-                                AuthenticationFlow::<CoreResponseType>::AuthorizationCode,
-                                csrf_func(csrf_token),
-                                nonce_func(nonce),
-                            )
-                            .set_pkce_challenge(pkce_challenge);
-                        match i.provider.as_str() {
-                            "auth0" => {
-                                auth_builder = auth_builder
-                                    .add_scope(Scope::new("profile".to_string()))
-                                    .add_scope(Scope::new("email".to_string()))
-                                    .add_scope(Scope::new("offline_access".to_string()));
-                            }
-                            "okta" => {
-                                auth_builder = auth_builder
-                                    .add_scope(Scope::new("profile".to_string()))
-                                    .add_scope(Scope::new("email".to_string()))
-                                    .add_scope(Scope::new("groups".to_string()))
-                                    .add_scope(Scope::new("offline_access".to_string()));
-                            }
-                            "keycloak" => {
-                                auth_builder = auth_builder
-                                    .add_scope(Scope::new("profile".to_string()))
-                                    .add_scope(Scope::new("email".to_string()));
-                            }
-                            "onelogin" => {
-                                auth_builder = auth_builder
-                                    .add_scope(Scope::new("profile".to_string()))
-                                    .add_scope(Scope::new("email".to_string()))
-                                    .add_scope(Scope::new("groups".to_string()))
-                            }
-                            "default" => {
-                                auth_builder = auth_builder
-                                    .add_scope(Scope::new("profile".to_string()))
-                                    .add_scope(Scope::new("email".to_string()))
-                                    .add_scope(Scope::new("offline_access".to_string()));
-                            }
-                            _ => {
-                                auth_builder = auth_builder
-                                    .add_scope(Scope::new("profile".to_string()))
-                                    .add_scope(Scope::new("email".to_string()))
-                                    .add_scope(Scope::new("offline_access".to_string()));
-                            }
+            let nonce_diff = if let Some(n) = i.nonce.clone() {
+                *n.secret() != nonce
+            } else {
+                false
+            };
+
+            if need_verifier || csrf_diff || nonce_diff {
+                let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
+                let r = i.oidc_client.as_ref().map(|c| {
+                    let mut auth_builder = c
+                        .authorize_url(
+                            AuthenticationFlow::<CoreResponseType>::AuthorizationCode,
+                            csrf_func(csrf_token),
+                            nonce_func(nonce),
+                        )
+                        .set_pkce_challenge(pkce_challenge);
+                    match i.provider.as_str() {
+                        "auth0" => {
+                            auth_builder = auth_builder
+                                .add_scope(Scope::new("profile".to_string()))
+                                .add_scope(Scope::new("email".to_string()))
+                                .add_scope(Scope::new("offline_access".to_string()));
                         }
                         }
+                        "okta" => {
+                            auth_builder = auth_builder
+                                .add_scope(Scope::new("profile".to_string()))
+                                .add_scope(Scope::new("email".to_string()))
+                                .add_scope(Scope::new("groups".to_string()))
+                                .add_scope(Scope::new("offline_access".to_string()));
+                        }
+                        "keycloak" => {
+                            auth_builder = auth_builder
+                                .add_scope(Scope::new("profile".to_string()))
+                                .add_scope(Scope::new("email".to_string()));
+                        }
+                        "onelogin" => {
+                            auth_builder = auth_builder
+                                .add_scope(Scope::new("profile".to_string()))
+                                .add_scope(Scope::new("email".to_string()))
+                                .add_scope(Scope::new("groups".to_string()))
+                        }
+                        "default" => {
+                            auth_builder = auth_builder
+                                .add_scope(Scope::new("profile".to_string()))
+                                .add_scope(Scope::new("email".to_string()))
+                                .add_scope(Scope::new("offline_access".to_string()));
+                        }
+                        _ => {
+                            auth_builder = auth_builder
+                                .add_scope(Scope::new("profile".to_string()))
+                                .add_scope(Scope::new("email".to_string()))
+                                .add_scope(Scope::new("offline_access".to_string()));
+                        }
+                    }
 
 
-                        auth_builder.url()
-                    });
+                    auth_builder.url()
+                });
 
 
-                    if let Some(r) = r {
-                        i.url = Some(r.0);
-                        i.csrf_token = Some(r.1);
-                        i.nonce = Some(r.2);
-                        i.pkce_verifier = Some(pkce_verifier);
-                    }
+                if let Some(r) = r {
+                    i.url = Some(r.0);
+                    i.csrf_token = Some(r.1);
+                    i.nonce = Some(r.2);
+                    i.pkce_verifier = Some(pkce_verifier);
                 }
                 }
-            });
+            }
+        });
     }
     }
 
 
     pub fn auth_url(&self) -> String {
     pub fn auth_url(&self) -> String {
@@ -572,10 +516,7 @@ impl ZeroIDC {
                             };
                             };
 
 
                             if let Some(expected_hash) = claims.access_token_hash() {
                             if let Some(expected_hash) = claims.access_token_hash() {
-                                let actual_hash = match AccessTokenHash::from_token(
-                                    res.access_token(),
-                                    &signing_algo,
-                                ) {
+                                let actual_hash = match AccessTokenHash::from_token(res.access_token(), &signing_algo) {
                                     Ok(h) => h,
                                     Ok(h) => h,
                                     Err(e) => {
                                     Err(e) => {
                                         println!("Error hashing access token: {}", e);
                                         println!("Error hashing access token: {}", e);
@@ -616,10 +557,7 @@ impl ZeroIDC {
                     let split = split.split('_').collect::<Vec<&str>>();
                     let split = split.split('_').collect::<Vec<&str>>();
 
 
                     if split.len() == 2 {
                     if split.len() == 2 {
-                        let params = [
-                            ("id_token", id_token.to_string()),
-                            ("state", split[0].to_string()),
-                        ];
+                        let params = [("id_token", id_token.to_string()), ("state", split[0].to_string())];
                         let client = reqwest::blocking::Client::new();
                         let client = reqwest::blocking::Client::new();
                         let res = client.post(i.auth_endpoint.clone()).form(&params).send();
                         let res = client.post(i.auth_endpoint.clone()).form(&params).send();
 
 
@@ -634,10 +572,8 @@ impl ZeroIDC {
 
 
                                     let idt = &id_token.to_string();
                                     let idt = &id_token.to_string();
 
 
-                                    let t: Result<
-                                        Token<jwt::Header, jwt::Claims, jwt::Unverified<'_>>,
-                                        jwt::Error,
-                                    > = Token::parse_unverified(idt);
+                                    let t: Result<Token<jwt::Header, jwt::Claims, jwt::Unverified<'_>>, jwt::Error> =
+                                        Token::parse_unverified(idt);
 
 
                                     if let Ok(t) = t {
                                     if let Ok(t) = t {
                                         let claims = t.claims().registered.clone();
                                         let claims = t.claims().registered.clone();
@@ -682,13 +618,12 @@ impl ZeroIDC {
                                 } else if res.status() == 402 {
                                 } else if res.status() == 402 {
                                     i.running = false;
                                     i.running = false;
                                     Err(SSOExchangeError::new(
                                     Err(SSOExchangeError::new(
-                                        "additional license seats required. Please contact your network administrator.".to_string(),
+                                        "additional license seats required. Please contact your network administrator."
+                                            .to_string(),
                                     ))
                                     ))
                                 } else {
                                 } else {
                                     i.running = false;
                                     i.running = false;
-                                    Err(SSOExchangeError::new(
-                                        "error from central endpoint".to_string(),
-                                    ))
+                                    Err(SSOExchangeError::new("error from central endpoint".to_string()))
                                 }
                                 }
                             }
                             }
                             Err(res) => {
                             Err(res) => {
@@ -697,16 +632,12 @@ impl ZeroIDC {
                                 println!("Post error: {}", res);
                                 println!("Post error: {}", res);
                                 i.exp_time = 0;
                                 i.exp_time = 0;
                                 i.running = false;
                                 i.running = false;
-                                Err(SSOExchangeError::new(
-                                    "error from central endpoint".to_string(),
-                                ))
+                                Err(SSOExchangeError::new("error from central endpoint".to_string()))
                             }
                             }
                         }
                         }
                     } else {
                     } else {
                         i.running = false;
                         i.running = false;
-                        Err(SSOExchangeError::new(
-                            "error splitting state token".to_string(),
-                        ))
+                        Err(SSOExchangeError::new("error splitting state token".to_string()))
                     }
                     }
                 } else {
                 } else {
                     i.running = false;
                     i.running = false;