handle access denied

This commit is contained in:
lemonsh 2023-05-01 10:34:42 +02:00
parent 043559ca7c
commit 734cf7a2a8
3 changed files with 103 additions and 14 deletions

View File

@ -1,4 +1,3 @@
use reqwest::header::ToStrError;
use thiserror::Error; use thiserror::Error;
/// The error type used globally by the library. /// The error type used globally by the library.
@ -12,11 +11,15 @@ pub enum Error {
UnexpectedResponse(String), UnexpectedResponse(String),
#[error("you are not logged in, or perhaps the session has expired")] #[error("you are not logged in, or perhaps the session has expired")]
NotAuthorized, NotAuthorized,
#[error("access denied, most likely someone else is already logged in")]
AccessDenied,
#[error("an unexpected redirection has occurred: {0:?}")]
UnexpectedRedirect(String),
#[error(transparent)] #[error(transparent)]
URLParseError(#[from] url::ParseError), URLParseError(#[from] url::ParseError),
#[error(transparent)] #[error(transparent)]
InvalidHeaderValue(#[from] ToStrError), InvalidHeaderValue(#[from] reqwest::header::ToStrError),
#[error(transparent)] #[error(transparent)]
HttpError(#[from] reqwest::Error), HttpError(#[from] reqwest::Error),
#[error(transparent)] #[error(transparent)]

View File

@ -6,7 +6,7 @@ pub use error::Error;
use reqwest::{ use reqwest::{
cookie::{CookieStore, Jar}, cookie::{CookieStore, Jar},
redirect::Policy, redirect::Policy,
Client, Url, Client, Url, header::HeaderValue,
}; };
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
@ -15,7 +15,7 @@ mod functions;
/// Data structures used by the library. /// Data structures used by the library.
pub mod models; pub mod models;
// A Result type based on the library's Error /// A Result type based on the library's Error
pub type Result<T> = std::result::Result<T, error::Error>; pub type Result<T> = std::result::Result<T, error::Error>;
type Field<'a, 'b> = (Cow<'a, str>, Cow<'b, str>); type Field<'a, 'b> = (Cow<'a, str>, Cow<'b, str>);
@ -85,7 +85,7 @@ impl ConnectBox {
if resp.status().is_redirection() { if resp.status().is_redirection() {
if self.auto_reauth && !reauthed { if self.auto_reauth && !reauthed {
reauthed = true; reauthed = true;
tracing::debug!("session has expired, attempting reauth"); tracing::info!("session <{}> has expired, attempting reauth", self.cookie("SID")?.as_deref().unwrap_or("unknown"));
self._login().await?; self._login().await?;
continue; continue;
} }
@ -117,7 +117,7 @@ impl ConnectBox {
if resp.status().is_redirection() { if resp.status().is_redirection() {
if self.auto_reauth && !reauthed { if self.auto_reauth && !reauthed {
reauthed = true; reauthed = true;
tracing::debug!("session has expired, attempting reauth"); tracing::info!("session <{}> has expired, attempting reauth", self.cookie("SID")?.as_deref().unwrap_or("unknown"));
self._login().await?; self._login().await?;
continue; continue;
} }
@ -136,13 +136,25 @@ impl ConnectBox {
("Password".into(), (&self.code).into()), ("Password".into(), (&self.code).into()),
]; ];
let req = self.http.post(self.setter_url.clone()).form(&form); let req = self.http.post(self.setter_url.clone()).form(&form);
let response = req.send().await?.text().await?; let resp = req.send().await?;
if response == "idloginincorrect" { if resp.status().is_redirection() {
if let Some(location) = resp.headers().get("Location").map(HeaderValue::to_str) {
let location = location?;
return if location == "../common_page/Access-denied.html" {
Err(Error::AccessDenied)
} else {
Err(Error::UnexpectedRedirect(location.to_string()))
}
}
}
let resp_text = resp.text().await?;
if resp_text == "idloginincorrect" {
return Err(Error::IncorrectCode); return Err(Error::IncorrectCode);
} }
let sid = response let sid = resp_text
.strip_prefix("successful;SID=") .strip_prefix("successful;SID=")
.ok_or_else(|| Error::UnexpectedResponse(response.clone()))?; .ok_or_else(|| Error::UnexpectedResponse(resp_text.clone()))?;
tracing::info!("session <{sid}>: logged in successfully");
self.cookie_store self.cookie_store
.add_cookie_str(&format!("SID={sid}"), &self.base_url); .add_cookie_str(&format!("SID={sid}"), &self.base_url);

View File

@ -1,6 +1,7 @@
use std::net::Ipv4Addr;
use std::time::Duration; use std::time::Duration;
use serde::de::Error; use serde::de::{Error, self, Unexpected};
use serde::{Deserialize, Deserializer}; use serde::{Deserialize, Deserializer};
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
@ -19,21 +20,94 @@ pub struct LanUserTable {
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
pub struct ClientInfo { pub struct ClientInfo {
pub index: u32,
pub interface: String, pub interface: String,
#[serde(rename = "interfaceid")]
pub interface_id: u32,
#[serde(rename = "IPv4Addr")] #[serde(rename = "IPv4Addr")]
pub ipv4_addr: String, pub ipv4_addr: String,
pub index: u32,
#[serde(rename = "interfaceid")]
pub interface_id: u32,
pub hostname: String, pub hostname: String,
#[serde(rename = "MACAddr")] #[serde(rename = "MACAddr")]
pub mac: String, pub mac: String,
pub method: u32,
#[serde(rename = "leaseTime")] #[serde(rename = "leaseTime")]
#[serde(deserialize_with = "deserialize_lease_time")] #[serde(deserialize_with = "deserialize_lease_time")]
pub lease_time: Duration, pub lease_time: Duration,
pub speed: u32, pub speed: u32,
} }
#[derive(Deserialize, Debug)]
pub struct PortForwards {
#[serde(rename = "LanIP")]
pub lan_ip: Ipv4Addr,
#[serde(rename = "subnetmask")]
pub subnet_mask: Ipv4Addr,
#[serde(rename = "instance")]
#[serde(deserialize_with = "unwrap_xml_list")]
pub entries: Vec<PortForwardEntry>,
}
#[derive(Deserialize, Debug)]
pub struct PortForwardEntry {
pub id: u32,
#[serde(rename = "local_IP")]
pub local_ip: Ipv4Addr,
pub start_port: u16,
pub end_port: u16,
#[serde(rename = "start_portIn")]
pub start_port_in: u16,
#[serde(rename = "end_portIn")]
pub end_port_in: u16,
pub protocol: PortForwardProtocol,
#[serde(deserialize_with = "bool_from_int")]
pub enable: bool
}
#[derive(Debug)]
pub enum PortForwardProtocol {
Tcp,
Udp,
Both,
}
impl PortForwardProtocol {
fn id(&self) -> u8 {
match self {
PortForwardProtocol::Tcp => 1,
PortForwardProtocol::Udp => 2,
PortForwardProtocol::Both => 3,
}
}
}
impl<'de> Deserialize<'de> for PortForwardProtocol {
fn deserialize<D>(d: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
match u8::deserialize(d)? {
1 => Ok(Self::Tcp),
2 => Ok(Self::Udp),
3 => Ok(Self::Both),
_ => Err(D::Error::custom("protocol not in range 1..=3")),
}
}
}
fn bool_from_int<'de, D>(deserializer: D) -> Result<bool, D::Error>
where
D: Deserializer<'de>,
{
match u8::deserialize(deserializer)? {
0 => Ok(false),
1 => Ok(true),
other => Err(de::Error::invalid_value(
Unexpected::Unsigned(other as u64),
&"zero or one",
)),
}
}
#[derive(Deserialize)] #[derive(Deserialize)]
struct List<T> { struct List<T> {
#[serde(rename = "$value")] #[serde(rename = "$value")]