python apriori 알고리즘 구현하기

Apriori algorithm?



예시데이터
구매날짜 구매목록
1일 토리,펭수
2일 펭수,맥주,돼지
3일 토리,곰,맥주,고양이
4일 토리,펭수,맥주,고양이

위의 그림은 A 가 발생하고 나서 (A,B) (A,C) (A,D) 등이 일어날 수 있고 
그 후에 (A,B,C) 가 일어날수 있음에 대한 단순한 트리이다. 
즉, 'A(토리가방)을 구매한 고객은 B,C,D를 구매할 가능성이 높다' 라는 식의 알고리즘이다.

지지도(support)

토리,펭수의 지지도 = 2/4 = P(토리n펭수) = 전체집합군중 토리 펭수가 모두 포함된 집합 수

신뢰도(confidence)

토리를 샀을때 펭수도 살 확률 = P(토리|펭수) = 1

향상도(lift)

토리와 펭수를 동시에 살 확률 / P(토리)와 P(펭수)가 동시에 일어날 확률 
    = P(토리n펭수) / P(토리)*P(펭수)
    = P(토리|펭수)/P(토리)
    = 토리와 펭수가 아무 관계가 없다면 값은 1, 만약 값이 1보다 높다면 연관성이 높다고 할 수 있음. 
    = 1 * 4/5  = 연관성이 낮음


파이썬 구현

dataset=[['사과','치즈','생수'],
['생수','호두','치즈','고등어'],
['수박','사과','생수'],
['생수','호두','치즈','옥수수']]

from apriori import apriori,printResults
items,result=apriori(dataset,
                    minSupport=0.5,
                    minConfidence=0.5)
printResults(items,result)



apriori.py


from itertools import chain, combinations
from collections import defaultdict

def _subsets(arr):
    return chain(*[combinations(arr, i + 1) for i, a in enumerate(arr)])

def _returnItemsWithMinSupport(itemSet, transactionList, minSupport, freqSet):
        _itemSet = set()
        localSet = defaultdict(int)

        for item in itemSet:
                for transaction in transactionList:
                        if item.issubset(transaction):
                                freqSet[item] += 1
                                localSet[item] += 1

        for item, count in localSet.items():
                support = float(count)/len(transactionList)

                if support >= minSupport:
                        _itemSet.add(item)

        return _itemSet

def _joinSet(itemSet, length):
        return set([i.union(j) for i in itemSet for j in itemSet if len(i.union(j)) == length])

def _getItemSetTransactionList(data_iterator):
    transactionList = list()
    itemSet = set()
    for record in data_iterator:
        transaction = frozenset(record)
        transactionList.append(transaction)
        for item in transaction:
            itemSet.add(frozenset([item]))              # Generate 1-itemSets
    return itemSet, transactionList

def apriori(data_iter, minSupport=0.15, minConfidence=0.6):
    itemSet, transactionList = _getItemSetTransactionList(data_iter)

    freqSet = defaultdict(int)
    largeSet = dict()
    # Global dictionary which stores (key=n-itemSets,value=support)
    # which satisfy minSupport

    oneCSet = _returnItemsWithMinSupport(itemSet,
                                        transactionList,
                                        minSupport,
                                        freqSet)

    currentLSet = oneCSet
    k = 2
    while currentLSet != set([]):
        largeSet[k-1] = currentLSet
        currentLSet = _joinSet(currentLSet, k)
        currentCSet = _returnItemsWithMinSupport(currentLSet,
                                                transactionList,
                                                minSupport,
                                                freqSet)
        currentLSet = currentCSet
        k = k + 1

    def getSupport(item):
            return float(freqSet[item])/len(transactionList)

    toRetItems = []
    for key, value in largeSet.items():
        toRetItems.extend([(tuple(item), getSupport(item))
                           for item in value])

    toRetRules = []
    for key, value in list(largeSet.items())[1:]:
        for item in value:
            __subsets = map(frozenset, [x for x in _subsets(item)])
            for element in __subsets:
                remain = item.difference(element)
                if len(remain) > 0:
                    confidence = getSupport(item)/getSupport(element)
                    if confidence >= minConfidence:
                        toRetRules.append(((tuple(element), tuple(remain)),
                                           confidence))
    return toRetItems, toRetRules

def printResults(items, rules):
    for item, support in sorted(items, key=lambda v: v[1]):
        print("item: {} , {:.3f}".format(str(item), support))
    print("\n------------------------ RULES:")
    for rule, confidence in sorted(rules, key=lambda v: v[1]):
        pre, post = rule
        print("Rule: {} ==> {} , {:.3f}".format(str(pre), str(post), confidence))



댓글