Skip to content

Binder

Multiple Conditions with case_when

import janitor
import pandas as pd
from janitor.functions import case_when

janitor.__version__
'0.22.0'
# https://stackoverflow.com/q/19913659/7175713
df = pd.DataFrame({"col1": list("ABBC"), "col2": list("ZZXY")})

df
col1 col2
0 A Z
1 B Z
2 B X
3 C Y
  • Single Condition:
df.case_when(
    df.col1 == "Z",  # condition
    "green",  # value if True
    "red",  # value if False
    column_name="color",
)
col1 col2 color
0 A Z red
1 B Z red
2 B X red
3 C Y red
  • Multiple Conditions:
df.case_when(
    df.col2.eq("Z") & df.col1.eq("A"),
    "yellow",  # first condition and value
    df.col2.eq("Z") & df.col1.eq("B"),
    "blue",  # second condition and value
    df.col1.eq("B"),
    "purple",  # third condition and value
    "black",  # default if no condition is True
    column_name="color",
)
col1 col2 color
0 A Z yellow
1 B Z blue
2 B X purple
3 C Y black

Anonymous functions (lambda) are supported as well:

# https://stackoverflow.com/q/43391591/7175713
raw_data = {"age1": [23, 45, 21], "age2": [10, 20, 50]}
df = pd.DataFrame(raw_data, columns=["age1", "age2"])
df
age1 age2
0 23 10
1 45 20
2 21 50
df.case_when(
    lambda df: (df.age1 - df.age2) > 0,  # condition
    lambda df: df.age1 - df.age2,  # value if True
    lambda df: df.age2 - df.age1,  # default if False
    column_name="diff",
)
age1 age2 diff
0 23 10 13
1 45 20 25
2 21 50 29

Data types are preserved; under the hood, case_when uses pd.Series.mask:

df = df.astype("Int64")
df.dtypes
age1    Int64
age2    Int64
dtype: object
result = df.case_when(
    lambda df: (df.age1 - df.age2) > 0,
    lambda df: df.age1 - df.age2,
    lambda df: df.age2 - df.age1,
    column_name="diff",
)

result
age1 age2 diff
0 23 10 13
1 45 20 25
2 21 50 29
result.dtypes
age1    Int64
age2    Int64
diff    Int64
dtype: object

A condition can also be given as a string, as long as it can be evaluated with pd.eval on the DataFrame and returns a boolean array:

# https://stackoverflow.com/q/54653356/7175713
data = {
    "name": ["Jason", "Molly", "Tina", "Jake", "Amy"],
    "age": [42, 52, 36, 24, 73],
    "preTestScore": [4, 24, 31, 2, 3],
    "postTestScore": [25, 94, 57, 62, 70],
}
df = pd.DataFrame(data, columns=["name", "age", "preTestScore", "postTestScore"])
df
name age preTestScore postTestScore
0 Jason 42 4 25
1 Molly 52 24 94
2 Tina 36 31 57
3 Jake 24 2 62
4 Amy 73 3 70
df.case_when(
    "age < 10",
    "baby",
    "10 <= age < 20",
    "kid",
    "20 <= age < 30",
    "young",
    "30 <= age < 50",
    "mature",
    "grandpa",
    column_name="elderly",
)
name age preTestScore postTestScore elderly
0 Jason 42 4 25 mature
1 Molly 52 24 94 grandpa
2 Tina 36 31 57 mature
3 Jake 24 2 62 young
4 Amy 73 3 70 grandpa

When multiple conditions are satisfied, the first one is used:

df = range(3, 30, 3)
df = pd.DataFrame(df, columns=["odd"])
df
odd
0 3
1 6
2 9
3 12
4 15
5 18
6 21
7 24
8 27
df.case_when(
    df.odd % 9 == 0, "divisible by 9", "divisible by 3", column_name="div_by_3_or_9"
)
odd div_by_3_or_9
0 3 divisible by 3
1 6 divisible by 3
2 9 divisible by 9
3 12 divisible by 3
4 15 divisible by 3
5 18 divisible by 9
6 21 divisible by 3
7 24 divisible by 3
8 27 divisible by 9

The values in rows 2, 5 and 8 (9, 18 and 27) are divisible by 3 as well as by 9; however, because the first condition tests whether a value is divisible by 9, that outcome is used instead.

If column_name exists in the DataFrame, then that column's values will be replaced with the outcome of case_when:

df.case_when(df.odd % 9 == 0, "divisible by 9", "divisible by 3", column_name="odd")
odd
0 divisible by 3
1 divisible by 3
2 divisible by 9
3 divisible by 3
4 divisible by 3
5 divisible by 9
6 divisible by 3
7 divisible by 3
8 divisible by 9