Backtracking – algorytm z nawrotami

Czytam ostatnio trochę o algorytmach i natrafiłem na algorytm Backtrackingu (pol. algorytm z nawrotami). Zacząłem czytać, aby zrozumieć do czego można to zastosować i jak to działa. Ostatecznie napisałem też implementację tego algorytmu w postaci sudoku solver’a. Szczegóły implementacji oraz wyjaśnienie co do czego i jak to działa w poniższym opisie.

Okej. Pierwsze pytanie (dla niezaznajomionych) co to jest sudoku? Za wikipedią:

Łamigłówka, której celem jest wypełnienie diagramu 9 × 9 w taki sposób, aby w każdym wierszu, w każdej kolumnie i w każdym z dziewięciu pogrubionych kwadratów 3 × 3 (zwanych „blokami” lub „podkwadratami”) znalazło się po jednej cyfrze od 1 do 9.

Czas na implementację 😉 Zdefiniujmy główną funkcję rozwiązującą soduku o nazwie solve. Funkcja przyjmować będzie jeden parametr: grid (o wielkości 9×9 pól)

def solve(grid):
    # szukamy lokalizacji pierwszego zera na gridzie; location to lista zawierająca dwie liczby - pierwsza 
    # wskazuje numer wiersza (licząc od zera), druga wskazuje numer kolumny (licząc od zera)
    location = find_empty_location(grid)

    # jeżeli nie znaleźliśmy lokalizacji zera to znaczy, że jest super i cały grid został uzupełniony
    if not location:
        return True

    # iterujemy po liczbach od 1 do 9
    for num in range(1, 10):
        # sprawdzamy czy lokalizacja jest bezpieczna
        if location_is_safe(num, location, grid):
            row_index, column_index = location

            # aktualizujemy znalezioną lokalizację wybraną liczbą
            grid[row_index][column_index] = num

            # wywołujemy rekurencyjnie funkcję solve - jeśli udało się znaleźć liczbę wychodzimy z funkcji
            if solve(grid):
                return True

            # nie udało się znaleźć odpowiedniej liczby, dlatego ustawiamy lokalizację z powrotem na 0
            grid[row_index][column_index] = 0

    # tutaj odbywa się backtracing, czyli nawrót z funkcji
    return False

Mamy już główną funkcję solve, która jest sercem naszego solvera. Teraz czas na funkcje pomocnicze. Zacznijmy od funkcji szukającej lokalizacji z zerem:

def find_empty_location(grid):
    # przechodzimy po wierszach i kolumnach i zwracamy lokalizację pierwszego znalezionego zera
    for row in range(9):
        for column in range(9):
            if grid[row][column] == 0:
                return [row, column]

Teraz utwórzmy funkcję sprawdzającą, czy wybrana lokalizacja jest bezpieczna:

def location_is_safe(num, location, grid):
    row_index, column_index = location

    # wszystkie liczby wchodzące w wiersz, gdzie znajduje się nasza liczba
    row = [item for item in grid[row_index]]

    # sprawdzamy, czy wiersz jest bezpieczny (nie zawiera wybranej liczby)
    if not check_is_safe(row, num):
        return False

    # wszystkie liczby wchodzące w kolumnę, gdzie znajduje się nasza liczba
    column = [grid[index][column_index] for index in range(9)]

    # sprawdzamy, czy kolumna jest bezpieczna (nie zawiera wybranej liczby)
    if not check_is_safe(column, num):
        return False

    # czas na sprawdzenie, czy box 3x3 zawiera wybraną liczbę
    # żeby to sprawdzić ustalamy index boxu, czyli pierwszą liczbę w lewym górnym rogu
    # możemy mieć takie lokalizacje jak np.: [0,0], [3,0], [3,3] itd.
    box_row_index = row_index - row_index % 3
    box_column_index = column_index - column_index % 3

    # ustalamy cały box, tzn. do wybranej lokalizacji boxu dodajemy + 3 aby uzyskać box o wielkości 3x3
    box = [row[box_column_index:box_column_index + 3] for row in grid[box_row_index:box_row_index + 3]]

    # sprawdzamy czy box jest bezpieczny (nie zawiera wybranej liczby)
    if not check_box_is_safe(box, num):
        return False

    # zwracamy True, to znaczy, że lokalizacja jest bezpieczna
    return True

Dodajmy funkcję sprawdzającą czy lokalizacja jest bezpieczna dla wiersza lub kolumny:

def check_is_safe(row, num):
    # iterujemy po każdej liczbie w wierszu / kolumnie
    # jeśli znajdziemy naszą liczbę to zwracamy False
    for field in row:
        if num == field:
            return False

    # wiesz / kolumna jest bezpieczna
    return True

Teraz sprawdźmy czy box jest bezpieczny:

def check_box_is_safe(box, num):
    # iterujemy po wszystkich wierszach i kolumnach boxu
    # jeśli znajdziemy naszą liczbę to zwracamy False
    for i in range(3):
        for j in range(3):
            if box[i][j] == num:
                return False

    # box jest bezpieczny
    return True

Stwórzmy również funkcję pomocniczą do wyświetlenia na ekranie naszego rozwiązanego grida:

def print_grid(grid):
    for i in range(9):
        for j in range(9):
            print(grid[i][j], end="")
        print('\n')

Zdefiniujmy grid, który chcemy rozwiązać:

grid = [
    [0,5,8,0,0,0,0,0,3],
    [1,7,0,0,5,0,0,0,8],
    [0,0,0,0,0,0,1,0,0],
    [0,0,0,0,0,0,0,0,0],
    [4,0,7,0,8,0,0,0,6],
    [0,8,3,0,6,0,0,1,7],
    [9,1,0,0,0,3,0,7,0],
    [0,0,6,0,0,0,0,8,0],
    [0,0,0,0,0,0,0,3,4],
]

Wszystko gotowe, możemy uruchomić nasz sudoku solver:

if solve(grid):
    print_grid(grid)
else:
    print("Nie znaleziono rozwiazania")

Rozwiązane sudoku:

658142793

172359468

349678152

561237849

497581326

283964517

914823675

736495281

825716934

W przypadku grida zdefiniowanego w powyższym przykładzie funkcja solve potrzebowała 301478 nawrotów.

Cały skrypt znajdziecie tutaj: https://gist.github.com/bgruszka/df6e441761a8e5412300b9392c17d2b9

Dodatkowo poniżej jeszcze filmik (~15 min) z tłumaczeniem na żywo jak to działa. PS. Wybaczcie czasami szumy w filmie – nie wiedzieć czemu laptop mi się strasznie rozgrzał 😉 Dodam jeszcze, że to mój pierwszy film tego typu, więc proszę o wyrozumiałość 🙂

Dodaj komentarz

Twój adres email nie zostanie opublikowany. Pola, których wypełnienie jest wymagane, są oznaczone symbolem *