Rust: Extend an existing trait

My favorite trait in Rust is the Iterator trait. It provides a rich set of built-in methods that I don't have to write myself, and those methods let me write Rust in a functional programming style.

Today, we'll look at how to extend an existing trait. As a concrete example, we will extend the Iterator trait with a custom method that flattens an iterator of Result<Vec<T>, E> items into a single Result<Vec<T>, E>.

Suppose we have a library that provides a function parse_line(line: &str) -> Result<Vec<i32>, E>, which takes a line and returns a vector of numbers wrapped in a Result. Our task is to implement parse_lines(lines: &[&str]) -> Result<Vec<i32>, E>, which takes multiple lines and uses parse_line() to parse all of their numbers into a single vector.

use std::num::ParseIntError;
type Result<R> = std::result::Result<R, ParseIntError>;

// this is a provided function from the library
fn parse_line(line: &str) -> Result<Vec<i32>> {
    // some library implementation goes here
    todo!() // placeholder so the skeleton compiles
}

// we need to implement this function using `parse_line`
fn parse_lines(lines: &[&str]) -> Result<Vec<i32>> {
    // this is what we need to implement
    todo!() // placeholder so the skeleton compiles
}

One way is to implement it directly, but that doesn't scale: flattening a sequence of Result<Vec<T>, E> values is probably a common task, and we don't want to reimplement it over and over. Instead, we can extend the Iterator trait to take care of this!
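For comparison, the direct approach might look like the sketch below. It assumes a stand-in parse_line that parses whitespace-separated integers (the same body the full listing at the end uses); the loop itself is just one way to write it, not code from the library.

```rust
use std::num::ParseIntError;
type Result<R> = std::result::Result<R, ParseIntError>;

// Stand-in for the library function (assumed: whitespace-separated integers).
fn parse_line(line: &str) -> Result<Vec<i32>> {
    line.split_whitespace().map(str::parse).collect::<Result<Vec<_>>>()
}

// Direct implementation: loop over the lines and bail out on the first error.
fn parse_lines(lines: &[&str]) -> Result<Vec<i32>> {
    let mut numbers = Vec::new();
    for line in lines {
        // `?` returns early with the error; otherwise append this line's numbers.
        numbers.extend(parse_line(line)?);
    }
    Ok(numbers)
}

fn main() {
    assert_eq!(parse_lines(&["1 2", "3"]), Ok(vec![1, 2, 3]));
    assert!(parse_lines(&["1", "oops"]).is_err());
}
```

The `?` operator keeps this short, but the loop is tied to parse_line and i32, so it would have to be rewritten for every other function and element type.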

Imagine how you would like your code to look once we extend the iterator. Here is how I'd like to use it:

fn parse_lines(lines: &[&str]) -> Result<Vec<i32>> {
    lines.iter().map(|line| parse_line(line)).flatten_results()
}

That is, we want to add a flatten_results() method to any iterator whose items are Result<Vec<T>, E>. To do this, we first create our own trait that declares just this method:

pub trait FlattenResults {
    fn flatten_results<T, E>(self) -> std::result::Result<Vec<T>, E>
    where
        Self: Iterator<Item = std::result::Result<Vec<T>, E>> + Sized;
}

Our trait, FlattenResults, has a single method, flatten_results(), whose where clause requires Self to be an iterator whose items are of type Result<Vec<T>, E> for some generic T and E (and Sized, because the method takes self by value). Next, we implement our trait for every Iterator with a blanket implementation:

impl<It> FlattenResults for It
where
    It: Iterator + Sized,
{
    fn flatten_results<T, E>(mut self) -> std::result::Result<Vec<T>, E>
    where
        Self: Iterator<Item = std::result::Result<Vec<T>, E>> + Sized,
    {
        let mut xs = Vec::new();
        loop {
            match self.next() {
                Some(Ok(x)) => xs.extend(x),
                Some(Err(e)) => {
                    return Err(e); // propagate error
                }
                None => {
                    break;
                }
            }
        }
        Ok(xs)
    }
}

The implementation itself is straightforward: we keep extending xs as long as we get the Ok variant. As soon as we encounter an error, we return it immediately, without consuming the rest of the iterator.
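Because the method returns on the first Err, it never pulls the remaining items from the iterator. The sketch below checks this by counting pulled items with Iterator::inspect; the sample data and the counter are made up for illustration, and the trait is repeated (with the loop condensed using `?`) so the snippet compiles on its own.

```rust
// Condensed copy of the trait from above so this snippet is self-contained.
pub trait FlattenResults {
    fn flatten_results<T, E>(self) -> Result<Vec<T>, E>
    where
        Self: Iterator<Item = Result<Vec<T>, E>> + Sized;
}

impl<It> FlattenResults for It
where
    It: Iterator + Sized,
{
    fn flatten_results<T, E>(self) -> Result<Vec<T>, E>
    where
        Self: Iterator<Item = Result<Vec<T>, E>> + Sized,
    {
        let mut xs = Vec::new();
        for item in self {
            // `?` returns the first `Err` immediately, like the `match` version.
            xs.extend(item?);
        }
        Ok(xs)
    }
}

fn main() {
    // Hypothetical input: the second item is an error.
    let items: Vec<Result<Vec<i32>, String>> =
        vec![Ok(vec![1, 2]), Err("bad line".to_string()), Ok(vec![3])];

    let mut pulled = 0;
    let result = items
        .into_iter()
        .inspect(|_| pulled += 1) // count how many items the method requests
        .flatten_results();

    assert_eq!(result, Err("bad line".to_string()));
    assert_eq!(pulled, 2); // the third item was never requested
}
```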

Finally, below is the full code, with a main function that verifies it works as expected.

use std::num::ParseIntError;
type Result<R> = std::result::Result<R, ParseIntError>;

// this is a provided function
fn parse_line(line: &str) -> Result<Vec<i32>> {
    line.split_whitespace().map(str::parse).collect::<Result<Vec<_>>>()
}

// we need to implement this function
fn parse_lines(lines: &[&str]) -> Result<Vec<i32>> {
    lines.iter().map(|line| parse_line(line)).flatten_results()
}

// a separate trait for methods to be added
pub trait FlattenResults {
    fn flatten_results<T, E>(self) -> std::result::Result<Vec<T>, E>
    where
        Self: Iterator<Item = std::result::Result<Vec<T>, E>> + Sized;
}

// blanket implementation for any iterator
impl<It> FlattenResults for It
where
    It: Iterator + Sized,
{
    fn flatten_results<T, E>(mut self) -> std::result::Result<Vec<T>, E>
    where
        Self: Iterator<Item = std::result::Result<Vec<T>, E>> + Sized,
    {
        let mut xs = Vec::new();
        loop {
            match self.next() {
                Some(Ok(x)) => xs.extend(x),
                Some(Err(e)) => {
                    return Err(e); // propagate the first error
                }
                None => {
                    break;
                }
            }
        }
        Ok(xs)
    }
}

fn main() {
    assert_eq!(parse_lines(&["1 2 3", "4 5"]), Ok(vec![1, 2, 3, 4, 5]));
    assert!(parse_lines(&["1 2 3", "x 5", "6 7"]).is_err());
}
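The standard library's FromIterator implementation for Result already stops at the first Err, so the loop could also be replaced by collect followed by a flatten. The sketch below is an alternative design, not the one used above: it moves the item bound onto the trait itself as a supertrait with a default method, so the blanket impl body is empty. The name FlattenResultsAlt is hypothetical.

```rust
// Alternative sketch: T and E are trait parameters and the bound lives on the trait.
pub trait FlattenResultsAlt<T, E>: Iterator<Item = Result<Vec<T>, E>> + Sized {
    fn flatten_results(self) -> Result<Vec<T>, E> {
        // Collecting into `Result<Vec<Vec<T>>, E>` stops at the first `Err`.
        let chunks = self.collect::<Result<Vec<Vec<T>>, E>>()?;
        // Concatenate the inner vectors by moving their elements (no cloning).
        Ok(chunks.into_iter().flatten().collect())
    }
}

// Blanket impl: every iterator with a matching item type gets the method.
impl<I, T, E> FlattenResultsAlt<T, E> for I where I: Iterator<Item = Result<Vec<T>, E>> {}

fn main() {
    let ok: Vec<Result<Vec<i32>, String>> = vec![Ok(vec![1, 2]), Ok(vec![3])];
    assert_eq!(ok.into_iter().flatten_results(), Ok(vec![1, 2, 3]));

    let bad: Vec<Result<Vec<i32>, String>> = vec![Ok(vec![1]), Err("oops".to_string())];
    assert_eq!(bad.into_iter().flatten_results(), Err("oops".to_string()));
}
```

The trade-off is an intermediate Vec<Vec<T>>, whereas the hand-written loop extends a single vector as it goes.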