
Folds

Polars provides expressions/methods for horizontal aggregations like sum, min, mean, etc. However, when you need a more complex aggregation, the default methods Polars supplies may not be sufficient. That's when folds come in handy.

The fold expression operates on whole columns rather than on individual values, which is what makes it fast: it uses the columnar data layout efficiently and can often be executed in a vectorized way.

Manual sum

Let's start with an example: implementing the sum operation ourselves with a fold.

Python · fold

import polars as pl

df = pl.DataFrame(
    {
        "a": [1, 2, 3],
        "b": [10, 20, 30],
    }
)

out = df.select(
    pl.fold(acc=pl.lit(0), function=lambda acc, x: acc + x, exprs=pl.all()).alias(
        "sum"
    ),
)
print(out)

Rust · fold_exprs

use polars::prelude::*;

let df = df!(
    "a" => &[1, 2, 3],
    "b" => &[10, 20, 30],
)?;

let out = df
    .lazy()
    .select([fold_exprs(lit(0), |acc, x| Ok(Some(acc + x)), [col("*")]).alias("sum")])
    .collect()?;
println!("{}", out);

shape: (3, 1)
┌─────┐
│ sum │
│ --- │
│ i64 │
╞═════╡
│ 11  │
│ 22  │
│ 33  │
└─────┘

The snippet above repeatedly applies the function f(acc, x) -> acc to an accumulator acc and the next column x; the result becomes the accumulator for the following column. Because the function operates on whole columns one at a time, it can take advantage of cache efficiency and vectorization.

Conditional

If you want to apply a condition/predicate to all columns in a DataFrame, a fold can be a very concise way to express it.

Python · fold

import polars as pl

df = pl.DataFrame(
    {
        "a": [1, 2, 3],
        "b": [0, 1, 2],
    }
)

out = df.filter(
    pl.fold(
        acc=pl.lit(True),
        function=lambda acc, x: acc & x,
        exprs=pl.col("*") > 1,
    )
)
print(out)

Rust · fold_exprs

use polars::prelude::*;
use std::ops::BitAnd; // brings .bitand() into scope for the fold below

let df = df!(
    "a" => &[1, 2, 3],
    "b" => &[0, 1, 2],
)?;

let out = df
    .lazy()
    .filter(fold_exprs(
        lit(true),
        // bitand is fallible, so wrap its Ok value in Some to match the fold's return type
        |acc, x| acc.bitand(&x).map(Some),
        [col("*").gt(1)],
    ))
    .collect()?;
println!("{}", out);

shape: (1, 2)
┌─────┬─────┐
│ a   ┆ b   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 3   ┆ 2   │
└─────┴─────┘

The snippet keeps only the rows where every column value is > 1.

Folds and string data

Folds could be used to concatenate string data. However, because every step materializes an intermediate column holding the strings concatenated so far, this operation has quadratic complexity.

Therefore, we recommend using the concat_str expression for this.

Python · concat_str

import polars as pl

df = pl.DataFrame(
    {
        "a": ["a", "b", "c"],
        "b": [1, 2, 3],
    }
)

out = df.select(pl.concat_str(["a", "b"]))
print(out)

Rust · concat_str · Available on feature concat_str

use polars::prelude::*;

let df = df!(
    "a" => &["a", "b", "c"],
    "b" => &[1, 2, 3],
)?;

let out = df
    .lazy()
    .select([concat_str([col("a"), col("b")], "")])
    .collect()?;
println!("{}", out);

shape: (3, 1)
┌─────┐
│ a   │
│ --- │
│ str │
╞═════╡
│ a1  │
│ b2  │
│ c3  │
└─────┘