Window functions
Window functions are expressions with superpowers. They allow you to perform aggregations on groups in the select context. Let's get a feel for what that means. First we load a dataset: the CSV read in the snippet below contains information about pokemon:
import polars as pl
# then let's load some csv data with information about pokemon
df = pl.read_csv(
"https://gist.githubusercontent.com/ritchie46/cac6b337ea52281aa23c049250a4ff03/raw/89a957ff3919d90e6ef2d34235e6bf22304f3366/pokemon.csv"
)
print(df.head())
CsvReader (the Rust reader used below is available with the csv feature)
use polars::prelude::*;
use reqwest::blocking::Client;
let data: Vec<u8> = Client::new()
.get("https://gist.githubusercontent.com/ritchie46/cac6b337ea52281aa23c049250a4ff03/raw/89a957ff3919d90e6ef2d34235e6bf22304f3366/pokemon.csv")
.send()?
.text()?
.bytes()
.collect();
let df = CsvReader::new(std::io::Cursor::new(data))
.has_header(true)
.finish()?;
println!("{}", df);
shape: (5, 13)
┌─────┬───────────────────────┬────────┬────────┬───┬─────────┬───────┬────────────┬───────────┐
│ # ┆ Name ┆ Type 1 ┆ Type 2 ┆ … ┆ Sp. Def ┆ Speed ┆ Generation ┆ Legendary │
│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ str ┆ str ┆ str ┆ ┆ i64 ┆ i64 ┆ i64 ┆ bool │
╞═════╪═══════════════════════╪════════╪════════╪═══╪═════════╪═══════╪════════════╪═══════════╡
│ 1 ┆ Bulbasaur ┆ Grass ┆ Poison ┆ … ┆ 65 ┆ 45 ┆ 1 ┆ false │
│ 2 ┆ Ivysaur ┆ Grass ┆ Poison ┆ … ┆ 80 ┆ 60 ┆ 1 ┆ false │
│ 3 ┆ Venusaur ┆ Grass ┆ Poison ┆ … ┆ 100 ┆ 80 ┆ 1 ┆ false │
│ 3 ┆ VenusaurMega Venusaur ┆ Grass ┆ Poison ┆ … ┆ 120 ┆ 80 ┆ 1 ┆ false │
│ 4 ┆ Charmander ┆ Fire ┆ null ┆ … ┆ 50 ┆ 65 ┆ 1 ┆ false │
└─────┴───────────────────────┴────────┴────────┴───┴─────────┴───────┴────────────┴───────────┘
Groupby aggregations in selection
Below we show how to use window functions to group over different columns and perform an aggregation on them. Doing so allows us to run multiple groupby operations in parallel, using a single query. The results of the aggregation are projected back to the original rows, so a window function will always lead to a DataFrame of the same size as the original.
Note how we call .over("Type 1") and .over(["Type 1", "Type 2"]). Using window functions we can aggregate over different groups in a single select call! Note that, in Rust, the argument to over() must be a collection, so even when you group by a single column, you must provide it in an array.
The best part is that this won't cost you anything: the computed groups are cached and shared between the different window expressions.
out = df.select(
[
"Type 1",
"Type 2",
pl.col("Attack").mean().over("Type 1").alias("avg_attack_by_type"),
pl.col("Defense")
.mean()
.over(["Type 1", "Type 2"])
.alias("avg_defense_by_type_combination"),
pl.col("Attack").mean().alias("avg_attack"),
]
)
print(out)
let out = df
.clone()
.lazy()
.select([
col("Type 1"),
col("Type 2"),
col("Attack")
.mean()
.over(["Type 1"])
.alias("avg_attack_by_type"),
col("Defense")
.mean()
.over(["Type 1", "Type 2"])
.alias("avg_defense_by_type_combination"),
col("Attack").mean().alias("avg_attack"),
])
.collect()?;
println!("{}", out);
shape: (163, 5)
┌─────────┬────────┬────────────────────┬─────────────────────────────────┬────────────┐
│ Type 1 ┆ Type 2 ┆ avg_attack_by_type ┆ avg_defense_by_type_combination ┆ avg_attack │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ f64 ┆ f64 ┆ f64 │
╞═════════╪════════╪════════════════════╪═════════════════════════════════╪════════════╡
│ Grass ┆ Poison ┆ 72.923077 ┆ 67.8 ┆ 75.349693 │
│ Grass ┆ Poison ┆ 72.923077 ┆ 67.8 ┆ 75.349693 │
│ Grass ┆ Poison ┆ 72.923077 ┆ 67.8 ┆ 75.349693 │
│ Grass ┆ Poison ┆ 72.923077 ┆ 67.8 ┆ 75.349693 │
│ … ┆ … ┆ … ┆ … ┆ … │
│ Dragon ┆ null ┆ 94.0 ┆ 55.0 ┆ 75.349693 │
│ Dragon ┆ null ┆ 94.0 ┆ 55.0 ┆ 75.349693 │
│ Dragon ┆ Flying ┆ 94.0 ┆ 95.0 ┆ 75.349693 │
│ Psychic ┆ null ┆ 53.875 ┆ 51.428571 ┆ 75.349693 │
└─────────┴────────┴────────────────────┴─────────────────────────────────┴────────────┘
Operations per group
Window functions can do more than aggregation. They can also be viewed as an operation within a group. If, for instance, you want to sort the values within a group, you can write col("value").sort().over("group") and voilà, the values are sorted within each group!
Let's filter out some rows to make this more clear.
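The exact filter behind the snapshot below is not shown here; a sketch that reproduces a comparable subset (an assumption: these seven rows are the ones whose Type 2 is Psychic) looks like this:
import polars as pl

# keep only the pokemon with a Psychic second type and a few columns of interest
filtered = df.filter(pl.col("Type 2") == "Psychic").select(
    [
        "Name",
        "Type 1",
        "Speed",
    ]
)
print(filtered)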
shape: (7, 3)
┌─────────────────────┬────────┬───────┐
│ Name ┆ Type 1 ┆ Speed │
│ --- ┆ --- ┆ --- │
│ str ┆ str ┆ i64 │
╞═════════════════════╪════════╪═══════╡
│ Slowpoke ┆ Water ┆ 15 │
│ Slowbro ┆ Water ┆ 30 │
│ SlowbroMega Slowbro ┆ Water ┆ 30 │
│ Exeggcute ┆ Grass ┆ 40 │
│ Exeggutor ┆ Grass ┆ 55 │
│ Starmie ┆ Water ┆ 115 │
│ Jynx ┆ Ice ┆ 95 │
└─────────────────────┴────────┴───────┘
Observe that the Water group of column Type 1 is not contiguous; there are two rows of Grass in between. Also note that the pokemon within each group are sorted by Speed in ascending order. Unfortunately, for this example we want them sorted in descending speed order. Luckily, with window functions this is easy to accomplish.
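One way to produce the result below (a sketch, assuming the filtered frame from the previous step and a recent Polars version; older releases spell the keyword reverse instead of descending):
# sort Name and Speed by descending Speed within each Type 1 group,
# and write the sorted values back to the rows of that group
out = filtered.with_columns(
    [
        pl.col(["Name", "Speed"]).sort_by("Speed", descending=True).over("Type 1"),
    ]
)
print(out)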
shape: (7, 3)
┌─────────────────────┬────────┬───────┐
│ Name ┆ Type 1 ┆ Speed │
│ --- ┆ --- ┆ --- │
│ str ┆ str ┆ i64 │
╞═════════════════════╪════════╪═══════╡
│ Starmie ┆ Water ┆ 115 │
│ Slowbro ┆ Water ┆ 30 │
│ SlowbroMega Slowbro ┆ Water ┆ 30 │
│ Exeggutor ┆ Grass ┆ 55 │
│ Exeggcute ┆ Grass ┆ 40 │
│ Slowpoke ┆ Water ┆ 15 │
│ Jynx ┆ Ice ┆ 95 │
└─────────────────────┴────────┴───────┘
Polars keeps track of each group's location and maps the expressions to the proper row locations. This will also work over different groups in a single select.
The power of window expressions is that you often don't need a groupby -> explode combination; you can put the logic in a single expression. It also makes the API cleaner. If properly used:
- groupby marks that groups are aggregated and that we expect a DataFrame of size n_groups
- over marks that we want to compute something within a group, without modifying the original size of the DataFrame
The size difference is sketched right below.
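A minimal sketch, assuming the pokemon df loaded above (recent Polars versions spell the method group_by; older ones use groupby):
import polars as pl

# groupby/agg: one row per Type 1 group
print(df.group_by("Type 1").agg(pl.col("Attack").mean()).shape)

# over: the group mean is broadcast back, one row per original row
print(df.select(pl.col("Attack").mean().over("Type 1")).shape)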
Window expression rules
Window expressions are evaluated as follows (assuming we apply them to a pl.Int32 column):
# aggregate and broadcast within a group
# output type: -> Int32
pl.sum("foo").over("groups")
# sum within a group and multiply with group elements
# output type: -> Int32
(pl.col("x").sum() * pl.col("y")).over("groups")
# sum within a group and multiply with group elements
# and aggregate the group into a list
# note that this requires an explicit `implode()` call
# output type: -> List(Int32)
(pl.col("x").sum() * pl.col("y")).implode().over("groups")
# sum within a group and multiply with group elements
# and aggregate the group to a list
# the flatten call explodes that list
# This is the fastest method to do things over groups when the groups are sorted
(pl.col("x").sum() * pl.col("y")).implode().over("groups").flatten()
// aggregate and broadcast within a group
// output type: -> i32
sum("foo").over([col("groups")])
// sum within a group and multiply with group elements
// output type: -> i32
(col("x").sum() * col("y"))
.over([col("groups")])
.alias("x1")
// sum within a group and multiply with group elements
// and aggregate the group into a list
// note that this requires an explicit `implode()` call
// output type: -> List(i32)
(col("x").sum() * col("y"))
.implode()
.over([col("groups")])
.alias("x2")
// sum within a group and multiply with group elements
// and aggregate the group into a list
// the flatten call explodes that list
// This is the fastest method to do things over groups when the groups are sorted
(col("x").sum() * col("y"))
.implode()
.over([col("groups")])
.flatten()
.alias("x3");
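To see the first two rules in action, here is a small sketch on a hypothetical frame (the column names groups, foo, x and y mirror the snippets above and are not part of the pokemon data):
import polars as pl

df_rules = pl.DataFrame(
    {
        "groups": ["a", "a", "b", "b", "b"],
        "foo": [1, 2, 3, 4, 5],
        "x": [1, 2, 3, 4, 5],
        "y": [1, 1, 1, 1, 1],
    }
)

out = df_rules.select(
    [
        pl.col("groups"),
        # rule 1: the group sum of foo is broadcast back to every row of its group
        pl.sum("foo").over("groups").alias("foo_sum"),
        # rule 2: the group sum of x is combined with each y element of the group
        (pl.col("x").sum() * pl.col("y")).over("groups").alias("x_sum_times_y"),
    ]
)
print(out)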
More examples
For more exercise, below are some window functions for us to compute:
- sort all pokemon by type
- select the first 3 pokemon per type as "Type 1"
- sort the pokemon within a type by speed and select the first 3 as "fastest/group"
- sort the pokemon within a type by attack and select the first 3 as "strongest/group"
- sort the pokemon by name within a type and select the first 3 as "sorted_by_alphabet"
out = df.sort("Type 1").select(
[
pl.col("Type 1").head(3).implode().over("Type 1").flatten(),
pl.col("Name")
.sort_by(pl.col("Speed"))
.head(3)
.implode()
.over("Type 1")
.flatten()
.alias("fastest/group"),
pl.col("Name")
.sort_by(pl.col("Attack"))
.head(3)
.implode()
.over("Type 1")
.flatten()
.alias("strongest/group"),
pl.col("Name")
.sort()
.head(3)
.implode()
.over("Type 1")
.flatten()
.alias("sorted_by_alphabet"),
]
)
print(out)
let out = df
.clone()
.lazy()
.select([
col("Type 1")
.head(Some(3))
.implode()
.over(["Type 1"])
.flatten(),
col("Name")
.sort_by(["Speed"], [false])
.head(Some(3))
.implode()
.over(["Type 1"])
.flatten()
.alias("fastest/group"),
col("Name")
.sort_by(["Attack"], [false])
.head(Some(3))
.implode()
.over(["Type 1"])
.flatten()
.alias("strongest/group"),
col("Name")
.sort(false)
.head(Some(3))
.implode()
.over(["Type 1"])
.flatten()
.alias("sorted_by_alphabet"),
])
.collect()?;
println!("{:?}", out);
shape: (43, 4)
┌────────┬─────────────────────┬─────────────────┬─────────────────────────┐
│ Type 1 ┆ fastest/group ┆ strongest/group ┆ sorted_by_alphabet │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ str │
╞════════╪═════════════════════╪═════════════════╪═════════════════════════╡
│ Bug ┆ Paras ┆ Metapod ┆ Beedrill │
│ Bug ┆ Metapod ┆ Kakuna ┆ BeedrillMega Beedrill │
│ Bug ┆ Parasect ┆ Caterpie ┆ Butterfree │
│ Dragon ┆ Dratini ┆ Dratini ┆ Dragonair │
│ … ┆ … ┆ … ┆ … │
│ Rock ┆ Omanyte ┆ Omastar ┆ Geodude │
│ Water ┆ Slowpoke ┆ Magikarp ┆ Blastoise │
│ Water ┆ Slowbro ┆ Tentacool ┆ BlastoiseMega Blastoise │
│ Water ┆ SlowbroMega Slowbro ┆ Horsea ┆ Cloyster │
└────────┴─────────────────────┴─────────────────┴─────────────────────────┘
Flattened window function
Suppose we have a window function that aggregates to a list, like the example above, with the following Python expression:
pl.col("Name").sort_by(pl.col("Speed")).head(3).implode().over("Type 1")
and in Rust:
col("Name").sort_by(["Speed"], [false]).head(Some(3)).implode().over(["Type 1"])
This still works, but it gives us a column of type List, which might not be what we want (and it would significantly increase our memory usage!).
Instead we could flatten. This just turns our 2D list into a 1D array and projects that array/column back to our DataFrame. It is very fast because the reshape is often free, and adding the column back to the original DataFrame is also a lot cheaper, since we don't require a join like in a normal window function.
However, for this operation to make sense, it is important that the columns used in over([..]) are sorted!
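To make the memory argument concrete, here is a quick sketch of the non-flattened variant (assuming the pokemon df from above; the exact string dtype name, Utf8 or String, depends on the Polars version):
# without flatten: every row carries a list of (up to) three names for its group
out = df.sort("Type 1").select(
    pl.col("Name")
    .sort_by(pl.col("Speed"))
    .head(3)
    .implode()
    .over("Type 1")
    .alias("fastest/group")
)
print(out.schema)  # {'fastest/group': List(...)} - a List column, not a flat column of strings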