You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
gobang/src/utils.rs

203 lines
6.3 KiB
Rust

use async_trait::async_trait;
use chrono::NaiveDate;
use database_tree::{Database, Table};
use futures::TryStreamExt;
use sqlx::mysql::{MySqlColumn, MySqlPool as MPool, MySqlRow};
use sqlx::{Column as _, Row, TypeInfo};
/// Number of rows requested per `get_records` call; used as the SQL `LIMIT` row count.
pub const RECORDS_LIMIT_PER_PAGE: u8 = 200;

/// Asynchronous access to a database server.
///
/// Tabular results are returned as `(headers, rows)`: `headers` holds the column
/// names and each entry of `rows` holds one rendered string per column.
#[async_trait]
pub trait Pool {
    /// Lists the databases on the server, each carrying its tables.
    async fn get_databases(&self) -> anyhow::Result<Vec<Database>>;

    /// Lists the tables that belong to `database`.
    async fn get_tables(&self, database: String) -> anyhow::Result<Vec<Table>>;

    /// Describes the columns of `database`.`table` as `(headers, rows)`.
    async fn get_columns(
        &self,
        database: &str,
        table: &str,
    ) -> anyhow::Result<(Vec<String>, Vec<Vec<String>>)>;

    /// Fetches rows of `database`.`table` as `(headers, rows)`.
    ///
    /// How `page` and the optional `filter` are applied is up to the implementation.
    async fn get_records(
        &self,
        database: &str,
        table: &str,
        page: u16,
        filter: Option<String>,
    ) -> anyhow::Result<(Vec<String>, Vec<Vec<String>>)>;

    /// Shuts down the underlying connection pool.
    async fn close(&self);
}
/// [`Pool`] implementation backed by an `sqlx` MySQL connection pool.
pub struct MySqlPool {
    pool: MPool,
}

impl MySqlPool {
    /// Connects to `database_url` and wraps the resulting pool.
    ///
    /// # Errors
    /// Returns the connection error reported by `sqlx`.
    pub async fn new(database_url: &str) -> anyhow::Result<Self> {
        let connected = MPool::connect(database_url).await?;
        Ok(Self { pool: connected })
    }
}
#[async_trait]
impl Pool for MySqlPool {
async fn get_databases(&self) -> anyhow::Result<Vec<Database>> {
let databases = sqlx::query("SHOW DATABASES")
.fetch_all(&self.pool)
.await?
.iter()
.map(|table| table.get(0))
.collect::<Vec<String>>();
let mut list = vec![];
for db in databases {
list.push(Database::new(
db.clone(),
get_tables(db.clone(), &self.pool).await?,
))
}
Ok(list)
}
async fn get_tables(&self, database: String) -> anyhow::Result<Vec<Table>> {
let tables =
sqlx::query_as::<_, Table>(format!("SHOW TABLE STATUS FROM `{}`", database).as_str())
.fetch_all(&self.pool)
.await?;
Ok(tables)
}
async fn get_records(
&self,
database: &str,
table: &str,
page: u16,
filter: Option<String>,
) -> anyhow::Result<(Vec<String>, Vec<Vec<String>>)> {
let query = if let Some(filter) = filter {
format!(
"SELECT * FROM `{database}`.`{table}` WHERE {filter} LIMIT {page}, {limit}",
database = database,
table = table,
filter = filter,
page = page,
limit = RECORDS_LIMIT_PER_PAGE
)
} else {
format!(
"SELECT * FROM `{}`.`{}` limit {page}, {limit}",
database,
table,
page = page,
limit = RECORDS_LIMIT_PER_PAGE
)
};
let mut rows = sqlx::query(query.as_str()).fetch(&self.pool);
let mut headers = vec![];
let mut records = vec![];
while let Some(row) = rows.try_next().await? {
headers = row
.columns()
.iter()
.map(|column| column.name().to_string())
.collect();
let mut new_row = vec![];
for column in row.columns() {
new_row.push(convert_column_value_to_string(&row, column)?)
}
records.push(new_row)
}
Ok((headers, records))
}
async fn get_columns(
&self,
database: &str,
table: &str,
) -> anyhow::Result<(Vec<String>, Vec<Vec<String>>)> {
let query = format!("SHOW FULL COLUMNS FROM `{}`.`{}`", database, table);
let mut rows = sqlx::query(query.as_str()).fetch(&self.pool);
let mut headers = vec![];
let mut records = vec![];
while let Some(row) = rows.try_next().await? {
headers = row
.columns()
.iter()
.map(|column| column.name().to_string())
.collect();
let mut new_row = vec![];
for column in row.columns() {
new_row.push(convert_column_value_to_string(&row, column)?)
}
records.push(new_row)
}
Ok((headers, records))
}
async fn close(&self) {
self.pool.close().await;
}
}
/// Returns the tables of `database`, as reported by `SHOW TABLE STATUS`.
///
/// The database name is backtick-quoted with any embedded backtick doubled, so a
/// name containing `` ` `` still produces a valid statement.
///
/// # Errors
/// Propagates any error from executing the query or mapping rows into [`Table`].
pub async fn get_tables(database: String, pool: &MPool) -> anyhow::Result<Vec<Table>> {
    let query = format!(
        "SHOW TABLE STATUS FROM `{}`",
        database.replace('`', "``")
    );
    let tables = sqlx::query_as::<_, Table>(query.as_str())
        .fetch_all(pool)
        .await?;
    Ok(tables)
}
pub fn convert_column_value_to_string(
row: &MySqlRow,
column: &MySqlColumn,
) -> anyhow::Result<String> {
let column_name = column.name();
match column.type_info().clone().name() {
"INT" | "SMALLINT" | "BIGINT" => {
if let Ok(value) = row.try_get(column_name) {
let value: Option<i64> = value;
return Ok(value.map_or("NULL".to_string(), |v| v.to_string()));
}
}
"DECIMAL" => {
if let Ok(value) = row.try_get(column_name) {
let value: Option<rust_decimal::Decimal> = value;
return Ok(value.map_or("NULL".to_string(), |v| v.to_string()));
}
}
"INT UNSIGNED" => {
if let Ok(value) = row.try_get(column_name) {
let value: Option<u64> = value;
return Ok(value.map_or("NULL".to_string(), |v| v.to_string()));
}
}
"VARCHAR" | "CHAR" | "ENUM" | "TEXT" | "LONGTEXT" => {
return Ok(row
.try_get(column_name)
.unwrap_or_else(|_| "NULL".to_string()))
}
"DATE" => {
if let Ok(value) = row.try_get(column_name) {
let value: Option<NaiveDate> = value;
return Ok(value.map_or("NULL".to_string(), |v| v.to_string()));
}
}
"TIMESTAMP" => {
if let Ok(value) = row.try_get(column_name) {
let value: Option<chrono::DateTime<chrono::Utc>> = value;
return Ok(value.map_or("NULL".to_string(), |v| v.to_string()));
}
}
"BOOLEAN" => {
if let Ok(value) = row.try_get(column_name) {
let value: Option<bool> = value;
return Ok(value.map_or("NULL".to_string(), |v| v.to_string()));
}
}
_ => (),
}
Err(anyhow::anyhow!(
"column type not implemented: `{}` {}",
column_name,
column.type_info().clone().name()
))
}