Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 76 additions & 63 deletions machine_learning/apriori_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
Examples: https://www.kaggle.com/code/earthian/apriori-association-rules-mining
"""

from collections import Counter, defaultdict
from itertools import combinations

Check failure on line 15 in machine_learning/apriori_algorithm.py

View workflow job for this annotation

GitHub Actions / ruff

ruff (I001)

machine_learning/apriori_algorithm.py:14:1: I001 Import block is un-sorted or un-formatted help: Organize imports


def load_data() -> list[list[str]]:
    """
    Return a small sample dataset: a list of transactions, each a list of items.

    >>> load_data()
    [['milk'], ['milk', 'butter'], ['milk', 'bread'], ['milk', 'bread', 'chips']]
    """
    return [["milk"], ["milk", "butter"], ["milk", "bread"], ["milk", "bread", "chips"]]


def prune(itemset: list, candidates: list, length: int) -> list:
"""
Prune candidate itemsets that are not frequent.
The goal of pruning is to filter out candidate itemsets that are not frequent. This
is done by checking if all the (k-1) subsets of a candidate itemset are present in
the frequent itemsets of the previous iteration (valid subsequences of the frequent
itemsets from the previous iteration).

Prunes candidate itemsets that are not frequent.

>>> itemset = ['X', 'Y', 'Z']
>>> candidates = [['X', 'Y'], ['X', 'Z'], ['Y', 'Z']]
>>> prune(itemset, candidates, 2)
[['X', 'Y'], ['X', 'Z'], ['Y', 'Z']]

>>> itemset = ['1', '2', '3', '4']
>>> candidates = ['1', '2', '4']
>>> prune(itemset, candidates, 3)
[]
# ---------- Helpers ----------

def get_support(itemset, transactions):
    """Return the number of transactions that contain every item of ``itemset``."""
    count = 0
    for transaction in transactions:
        if itemset.issubset(transaction):
            count += 1
    return count


def generate_candidates(prev_frequent, k):
    """
    Join step: generate candidate itemsets of size ``k`` from the frequent
    itemsets of size ``k - 1``.

    :param prev_frequent: iterable of frozensets, each of size ``k - 1``.
    :param k: target candidate size.
    :return: set of frozensets, each of exactly size ``k``.
    """
    prev_list = list(prev_frequent)
    candidates = set()
    # Union every unordered pair; only unions of exactly size k are valid
    # candidates (pairs sharing k - 2 items).
    for first, second in combinations(prev_list, 2):
        union = first | second
        if len(union) == k:
            candidates.add(union)
    return candidates


def has_infrequent_subset(candidate, prev_frequent):
    """
    Apriori pruning test.

    Return ``True`` when any (k-1)-subset of ``candidate`` is missing from
    ``prev_frequent`` (the frequent (k-1)-itemsets); such a candidate cannot
    itself be frequent and may be discarded before support counting.

    :param candidate: frozenset of k items.
    :param prev_frequent: container of frozensets of size k - 1.
    """
    return any(
        frozenset(subset) not in prev_frequent
        for subset in combinations(candidate, len(candidate) - 1)
    )


# ---------- Main Apriori ----------


def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], int]]:
    """
    Find every frequent itemset and its support count.

    :param data: list of transactions, each a list of item names.
    :param min_support: minimum number of transactions an itemset must
        appear in to be considered frequent.
    :return: deterministically ordered list of ``(sorted_items, count)``
        pairs covering frequent itemsets of every size (including size 1,
        reported as single-element lists for a consistent shape).

    >>> data = [['A', 'B', 'C'], ['A', 'B'], ['A', 'C'], ['A', 'D'], ['B', 'C']]
    >>> apriori(data, 2)
    [(['A'], 4), (['A', 'B'], 2), (['A', 'C'], 2), (['B'], 3), (['B', 'C'], 2), (['C'], 3)]
    """
    transactions = [set(transaction) for transaction in data]

    # 1. Count 1-itemsets.
    item_counts: defaultdict = defaultdict(int)
    for transaction in transactions:
        for item in transaction:
            item_counts[frozenset([item])] += 1

    frequent = {
        itemset for itemset, count in item_counts.items() if count >= min_support
    }

    # Report singletons as sorted lists so every entry has the same shape.
    all_frequents = [
        (sorted(itemset), count)
        for itemset, count in item_counts.items()
        if count >= min_support
    ]

    k = 2
    while frequent:
        # 2. Join step, then Apriori prune: a candidate with any infrequent
        # (k-1)-subset cannot be frequent.
        candidates = {
            candidate
            for candidate in generate_candidates(frequent, k)
            if not has_infrequent_subset(candidate, frequent)
        }

        # 3. Count support of the surviving candidates.
        candidate_counts: defaultdict = defaultdict(int)
        for transaction in transactions:
            for candidate in candidates:
                if candidate.issubset(transaction):
                    candidate_counts[candidate] += 1

        # 4. Keep the frequent candidates; they seed the next iteration.
        frequent = {
            candidate
            for candidate, count in candidate_counts.items()
            if count >= min_support
        }
        all_frequents.extend(
            (sorted(candidate), count)
            for candidate, count in candidate_counts.items()
            if count >= min_support
        )

        k += 1

    # Set iteration order is not deterministic; sort for reproducible output.
    all_frequents.sort()
    return all_frequents


if __name__ == "__main__":
Expand Down
Loading