Machine Learning in Rust, Smartcore

Foreword: this is the third part of a three-part series. If you haven't yet had the chance to read the first and second parts, I recommend pausing here and taking a few minutes to catch up before moving ahead.

In this third and final article, I want to introduce a library that lets you quickly fit and evaluate statistical models in Rust. It provides a good selection of efficient tools for machine learning and statistical modeling, including classification, regression, clustering, and dimensionality reduction, all behind a clean, uniform, and streamlined API.

Meet Smartcore

When you do statistical modeling in Python, you don't have to work through matrix decompositions or call optimization routines to find your model's parameters, because Scikit Learn lets you fit a machine learning algorithm to your dataset in one line of code. The question is: are there any Rust libraries with functionality similar to Scikit Learn?

The full answer is complicated, but the short answer is yes. Several Rust crates implement a range of machine learning algorithms. If you are looking for a robust library to bring machine learning into a production system, consider Smartcore.

Like Scikit Learn, Smartcore provides tools for supervised and unsupervised learning, as well as various tools for model selection and evaluation. The set of algorithm offerings includes:

All algorithms in Smartcore share a common interface drawn from a limited set of traits and methods. As in Scikit Learn, the API is deliberately simple and is defined by a handful of interfaces:

Smartcore's design makes it easy to use data defined as an n-dimensional array with ndarray or nalgebra. If you don't want to depend on either of these crates, you can use plain Rust vectors with any algorithm in Smartcore. If you have specific requirements for your data, it is easy to integrate Smartcore with a new type of matrix or vector, as long as you implement a simple interface designed specifically for this purpose.
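
The following plain-Rust sketch illustrates the general idea behind that last point. The trait `SimpleMatrix`, the type `RowMajor` and the function `column_means` are hypothetical names invented for this illustration. Smartcore's own matrix traits are considerably larger and are not reproduced here. The algorithm is written once against the trait, and it then works unchanged with nested vectors and with a custom row-major type.

```rust
// Hypothetical, deliberately tiny matrix interface (NOT Smartcore's API).
// An algorithm written against it works with any type that implements it.
trait SimpleMatrix {
    fn shape(&self) -> (usize, usize);
    fn get(&self, row: usize, col: usize) -> f64;
}

// Plain Rust nested vectors: one inner Vec per row.
impl SimpleMatrix for Vec<Vec<f64>> {
    fn shape(&self) -> (usize, usize) {
        let ncols = self.first().map_or(0, |r| r.len());
        (self.len(), ncols)
    }
    fn get(&self, row: usize, col: usize) -> f64 {
        self[row][col]
    }
}

// A custom row-major type that could wrap data from any other source.
struct RowMajor {
    nrows: usize,
    ncols: usize,
    values: Vec<f64>,
}

impl SimpleMatrix for RowMajor {
    fn shape(&self) -> (usize, usize) {
        (self.nrows, self.ncols)
    }
    fn get(&self, row: usize, col: usize) -> f64 {
        self.values[row * self.ncols + col]
    }
}

// An "algorithm" written once against the trait: per-column means,
// e.g. the first step of centering data before dimensionality reduction.
fn column_means<M: SimpleMatrix>(m: &M) -> Vec<f64> {
    let (nrows, ncols) = m.shape();
    (0..ncols)
        .map(|c| (0..nrows).map(|r| m.get(r, c)).sum::<f64>() / nrows as f64)
        .collect()
}
```

Calling `column_means` on `vec![vec![1.0, 2.0], vec![3.0, 6.0]]` and on the equivalent `RowMajor` value gives the same result. The algorithm never needs to know how the data is stored.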

Machine Learning with Smartcore

To give you a taste of how easy it is to train and evaluate a machine learning algorithm with Smartcore, here's how to do it for the linear regression model from the first part of this series.

use nalgebra::{DMatrix, RowDVector};
use smartcore::dataset::boston;
use smartcore::error::Failed;
use smartcore::linear::linear_regression::LinearRegression;
use smartcore::metrics::mean_absolute_error;
use smartcore::model_selection::train_test_split;

fn load_bos_dataset() -> (DMatrix<f32>, RowDVector<f32>) {
    // Load The Boston Housing dataset
    let data = boston::load_dataset();
    // Turn Boston Housing dataset into nalgebra matrix
    let x = DMatrix::from_row_slice(data.num_samples, data.num_features, &data.data);
    // These are our target values
    let y = RowDVector::from_iterator(data.num_samples, data.target.into_iter());

    (x, y)
}

fn linear_regression() -> Result<(), Failed> {
    let (x, y) = load_bos_dataset();
    // Split data into training/test sets
    let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true);
    // Fit linear regression and predict value estimates using test set
    let y_hat = LinearRegression::fit(&x_train, &y_train, Default::default())
        .and_then(|lr| lr.predict(&x_test))?;
    // Validate model performance on a test set
    println!(
        "Linear Regression MAE: {}",
        mean_absolute_error(&y_test, &y_hat)
    );
    Ok(())
}
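
Under the hood, `train_test_split(&x, &y, 0.2, true)` sets aside 20% of the rows as a test set. It shuffles the rows first when the last argument is `true`. The dependency-free sketch below shows one way such a split can work on row indices. It is not Smartcore's implementation. The `xorshift64` generator and the `holdout_split` function are hypothetical helpers, and real code would normally use the `rand` crate for shuffling.

```rust
/// Minimal xorshift64 pseudo-random generator (hypothetical helper so the
/// sketch has no dependencies; prefer the `rand` crate in real code).
fn xorshift64(state: &mut u64) -> u64 {
    let mut x = *state;
    x ^= x << 13;
    x ^= x >> 7;
    x ^= x << 17;
    *state = x;
    x
}

/// Split the row indices 0..n into (train, test). The test set gets
/// `n * test_size` rows, truncated toward zero and capped at n. When
/// `shuffle` is true the indices are permuted with a Fisher-Yates shuffle.
fn holdout_split(n: usize, test_size: f64, shuffle: bool, seed: u64) -> (Vec<usize>, Vec<usize>) {
    let mut idx: Vec<usize> = (0..n).collect();
    if shuffle {
        // xorshift gets stuck at zero, so never seed it with 0.
        let mut state = seed.max(1);
        for i in (1..n).rev() {
            let j = (xorshift64(&mut state) % (i as u64 + 1)) as usize;
            idx.swap(i, j);
        }
    }
    let n_test = ((n as f64 * test_size) as usize).min(n);
    let test = idx[..n_test].to_vec();
    let train = idx[n_test..].to_vec();
    (train, test)
}
```

The rows of `x` and `y` are then gathered with the same index lists, which keeps every feature row paired with its target. The Boston Housing dataset has 506 rows, so this sketch puts 101 of them in the test set.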

If you later decide to try a Random Forest Regressor on your data, all you need to do to get a new test score is swap one algorithm for another.

use smartcore::ensemble::random_forest_regressor::RandomForestRegressor;

fn random_forest() -> Result<(), Failed> {
    let (x, y) = load_bos_dataset();
    // Split data into training/test sets
    let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true);
    // Fit random forest regression and predict value estimates using test set
    let y_hat = RandomForestRegressor::fit(&x_train, &y_train, Default::default())
        .and_then(|rf| rf.predict(&x_test))?;
    // Validate model performance on a test set
    println!(
        "Random Forest MAE: {}",
        mean_absolute_error(&y_test, &y_hat)
    );
    Ok(())
}
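
Both examples report the mean absolute error (MAE), which is the average absolute difference between the true and predicted targets. The Boston Housing target is the median home value in thousands of dollars, so the MAE is expressed in those units too. A minimal stand-alone version of the metric might look like this. The `mae` function is an illustration, not Smartcore's `mean_absolute_error`.

```rust
/// Mean absolute error: the average of |y_true[i] - y_pred[i]|.
/// Lower is better; the result has the same units as the target.
fn mae(y_true: &[f64], y_pred: &[f64]) -> f64 {
    assert_eq!(y_true.len(), y_pred.len(), "inputs must have equal length");
    let total: f64 = y_true
        .iter()
        .zip(y_pred)
        .map(|(t, p)| (t - p).abs())
        .sum();
    total / y_true.len() as f64
}
```

For targets `[3.0, 5.0, 2.0]` and predictions `[2.0, 5.0, 4.0]`, the absolute errors are 1, 0 and 2, so the MAE is 1.0.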

To evaluate the linear and random forest regressors, we split the dataset into training and test sets, used the training portion to fit each model, and used the test set to estimate the test error. This method, however, is not very reliable, since the error measured on one test set can differ considerably from the error measured on another. K-fold cross-validation gives a better estimate of the test error. It divides the data into k folds, trains on k - 1 of them and tests on the remaining one, and repeats this until every fold has served as the test set once.
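
The procedure can be sketched without any crates. In this hypothetical version, `kfold_indices` assigns contiguous, unshuffled folds. `kfold_cv` accepts the fitting function as a value, much as Smartcore's `cross_validate` takes `RandomForestRegressor::fit`. The names, the single-feature input and the `fit_mean` baseline model are all assumptions made to keep the example short.

```rust
/// Test-fold indices for `n` samples split into `k` contiguous folds
/// (no shuffling). The first `n % k` folds receive one extra sample.
fn kfold_indices(n: usize, k: usize) -> Vec<Vec<usize>> {
    assert!(k >= 2 && k <= n, "need 2 <= k <= n");
    let (base, extra) = (n / k, n % k);
    let mut folds: Vec<Vec<usize>> = Vec::with_capacity(k);
    let mut start = 0;
    for f in 0..k {
        let size = base + if f < extra { 1 } else { 0 };
        folds.push((start..start + size).collect());
        start += size;
    }
    folds
}

/// In this sketch a fitted model is just a prediction function.
type Predictor = Box<dyn Fn(f64) -> f64>;

/// Hypothetical k-fold cross-validation over a single feature column.
/// `fit` trains on the k - 1 training folds; `score` compares targets
/// with predictions on the held-out fold. Returns the mean fold score.
fn kfold_cv<F, S>(x: &[f64], y: &[f64], k: usize, fit: F, score: S) -> f64
where
    F: Fn(&[f64], &[f64]) -> Predictor,
    S: Fn(&[f64], &[f64]) -> f64,
{
    let folds = kfold_indices(x.len(), k);
    let mut total = 0.0;
    for test in &folds {
        // Every sample outside the current fold is used for training.
        let train: Vec<usize> = (0..x.len()).filter(|i| !test.contains(i)).collect();
        let x_tr: Vec<f64> = train.iter().map(|&i| x[i]).collect();
        let y_tr: Vec<f64> = train.iter().map(|&i| y[i]).collect();
        let model = fit(x_tr.as_slice(), y_tr.as_slice());
        // Predict and score on the held-out fold.
        let y_hat: Vec<f64> = test.iter().map(|&i| model(x[i])).collect();
        let y_te: Vec<f64> = test.iter().map(|&i| y[i]).collect();
        total += score(y_te.as_slice(), y_hat.as_slice());
    }
    total / k as f64
}

/// Hypothetical baseline: always predicts the mean of the training targets.
fn fit_mean(_x: &[f64], y: &[f64]) -> Predictor {
    let mean = y.iter().sum::<f64>() / y.len() as f64;
    Box::new(move |_x: f64| mean)
}
```

Because these folds are contiguous, sorted data can leave every fold unrepresentative. With targets `[0, 0, 0, 6, 6, 6]` and two folds, the mean baseline is trained on one half and tested on the other, so it scores an MAE of 6 on each fold. Shuffling the indices before splitting avoids this.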

This is how you can estimate the test error of a Random Forest model using 3-fold cross-validation in Smartcore.

use smartcore::model_selection::{cross_validate, KFold};

fn cv_random_forest() -> Result<(), Failed> {
    let (x, y) = load_bos_dataset();
    // cross-validate our model
    let results = cross_validate(
        RandomForestRegressor::fit,        // estimator
        &x,                                // features
        &y,                                // target
        Default::default(),                // hyperparameters
        KFold::default().with_n_splits(3), // 3-fold split
        mean_absolute_error,               // MAE metric
    )?;
    println!("Random Forest CV MAE: {}", results.mean_test_score());
    Ok(())
}

Let's say your efforts have paid off and you've found a useful model that estimates the price of a house in Boston with great accuracy. Now you want to save the model to disk so you can use it later. All Smartcore algorithms can be serialized and deserialized with the Serde crate. Here is how you can save a Random Forest model to a file on disk, using bincode as the binary format.

use std::fs::File;
use std::io::prelude::*;

use nalgebra::{DMatrix, RowDVector};

use smartcore::dataset::boston;
use smartcore::ensemble::random_forest_regressor::RandomForestRegressor;
use smartcore::error::Failed;

fn save_random_forest_model() -> Result<(), Failed> {
    // load_bos_dataset() is the helper defined in the first example
    let (x, y) = load_bos_dataset();
    // Train the model on the full dataset
    let model = RandomForestRegressor::fit(&x, &y, Default::default())?;
    // Serialize the model with bincode, which relies on Smartcore's Serde support
    let model_bytes = bincode::serialize(&model).expect("Cannot serialize the model");
    // Write the bytes to a file on disk
    File::create("random_forest.model")
        .and_then(|mut f| f.write_all(&model_bytes))
        .expect("Cannot persist the model");
    Ok(())
}
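
The snippet above only writes the model. Loading it back is the mirror image: read the bytes, then call `bincode::deserialize`. The sketch below shows the same round trip without Serde or bincode, using a hypothetical `LinearModel` with a hand-written little-endian byte layout. It only illustrates the idea and does not reflect how Smartcore stores its models. In practice the Serde route above is less error-prone.

```rust
use std::convert::TryInto;
use std::fs;
use std::io;
use std::path::Path;

/// Hypothetical toy model: y = intercept + sum(coef[i] * x[i]).
#[derive(Debug, PartialEq)]
struct LinearModel {
    intercept: f64,
    coef: Vec<f64>,
}

impl LinearModel {
    /// Layout: intercept (f64), coefficient count (u64), then each
    /// coefficient (f64), all little-endian, 8 bytes per word.
    fn to_bytes(&self) -> Vec<u8> {
        let mut out = Vec::with_capacity(16 + 8 * self.coef.len());
        out.extend_from_slice(&self.intercept.to_le_bytes());
        out.extend_from_slice(&(self.coef.len() as u64).to_le_bytes());
        for c in &self.coef {
            out.extend_from_slice(&c.to_le_bytes());
        }
        out
    }

    /// Inverse of `to_bytes`; returns None if the input is truncated.
    fn from_bytes(bytes: &[u8]) -> Option<LinearModel> {
        // Read the i-th 8-byte word, if it is present.
        let word = |i: usize| -> Option<[u8; 8]> { bytes.get(8 * i..8 * i + 8)?.try_into().ok() };
        let intercept = f64::from_le_bytes(word(0)?);
        let n = u64::from_le_bytes(word(1)?) as usize;
        let coef = (0..n)
            .map(|i| word(2 + i).map(f64::from_le_bytes))
            .collect::<Option<Vec<f64>>>()?;
        Some(LinearModel { intercept, coef })
    }

    /// Write the serialized model to `path`.
    fn save(&self, path: &Path) -> io::Result<()> {
        fs::write(path, self.to_bytes())
    }

    /// Read a model previously written by `save`.
    fn load(path: &Path) -> io::Result<LinearModel> {
        let bytes = fs::read(path)?;
        LinearModel::from_bytes(&bytes)
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "truncated model file"))
    }
}
```

A model saved with `save` and read back with `load` compares equal to the original. Serde-based formats such as bincode generate this kind of encoding and decoding code for you through derived traits.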

Conclusion

Implementing machine learning algorithms from scratch might be a great way to learn the theory, but it is not practical in production settings. What you want instead is a good framework that provides robust implementations of the core machine learning algorithms out of the box.

Rust is a great language for scientific computing thanks to its speed, expressiveness, and memory safety, but the lack of good machine learning tools makes it harder to use for statistical modeling. Smartcore was created to help establish Rust as a leading language for data science and machine learning.

While the framework already provides a good selection of ML algorithms and tools, a lot remains to be done to match the flexibility and scope of Scikit Learn. If you've enjoyed reading about machine learning algorithms and want to learn more while helping to expand Smartcore's capabilities, check out the Developer's Guide.

Vlad Orlov

Data Scientist, Open Source Contributor, Technology Enthusiast
