Generator functions and the yield statement

A normal function in Python can be terminated with a return statement, which may or may not return a value back to the statement that called the function. After executing a return, the function’s role in the program is over, at least until the next time the function is called. This means that the state of the function and all its local variables are deleted from memory.

In some cases, it would be useful to have a function that can return a sequence of intermediate results, with the state of the function being retained between each of these results. Such a function is called a generator function.
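Before the main example, here is a minimal sketch that contrasts the two kinds of function. The names squares_list() and squares_gen() are hypothetical and exist only for this illustration. The first builds and returns a complete list. The second hands back one value per yield and keeps its loop variable alive between values.

```python
# A normal function: builds the whole list, then return ends the function
# and its local variables (result, k) are discarded.
def squares_list(n):
    result = []
    for k in range(1, n + 1):
        result.append(k * k)
    return result

# A generator function: each yield hands back one value and pauses,
# keeping k (and the rest of its state) until the next value is requested.
def squares_gen(n):
    for k in range(1, n + 1):
        yield k * k

print(squares_list(4))       # [1, 4, 9, 16]
print(list(squares_gen(4)))  # [1, 4, 9, 16]: list() requests every value in turn
```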

It’s easiest to understand how a generator function works by looking at a specific example. Suppose we would like to generate the sequence of prime numbers from the smallest such number (2) up to some maximum value. To do this, we need an algorithm for generating all the prime numbers in a given range. The method we’ll use is often loosely associated with the famous sieve of Eratosthenes, but strictly speaking it is trial division: each odd candidate is tested for divisibility by the primes found so far. The algorithm is as follows.

First, recall that a prime number is an integer greater than 1 that is divisible only by itself and 1. An integer greater than 1 which is not prime is known as composite, and any composite number must have at least one prime divisor that is less than or equal to its square root. (If n = a × b with both factors greater than the square root of n, their product would be greater than n; so the smaller factor, and therefore each of its prime factors, is at most the square root of n.) One other useful observation is that 2 is the only even prime, so after listing 2 as a prime, we need to examine only odd numbers. With that in mind, here’s the code for generating the prime numbers from 2 up to some value max.

from math import sqrt, floor

def genPrimes(max):
    primes = [2]
    yield 2

    for n in range(3, max + 1, 2):
        maxCheck = floor(sqrt(n))
        isPrime = True
        for fac in primes[1:]:
            if fac > maxCheck: break
            if n % fac == 0:
                isPrime = False
                break
        if isPrime:
            primes.append(n)
            yield n

On line 4, we define the list primes, which will hold the prime numbers found so far. It starts out as [2], since 2 is the first prime number. We’ll get to the yield statement in a minute.

Next, we need to find all the primes larger than 2 but less than or equal to max, so we use the loop over n on line 7 to do this. The loop starts at 3 and goes up in steps of 2, since we’re skipping even numbers. Because range() stops just before its second argument, we pass max + 1 so that max itself is tested when it is odd.

The integer maxCheck is the largest integer that is less than or equal to the square root of n. If n is composite, it must have at least one prime factor in the range from 3 up to maxCheck. Because the primes are found in increasing order, every such prime is already in the primes list by the time n is tested. The loop on line 10 runs over the prime numbers in the primes list, skipping primes[0], which is 2 and cannot be a factor of an odd number. If the prime fac is larger than maxCheck, we break out of the loop, since no larger prime needs to be checked. Otherwise, we check whether fac divides n, and if so, we know that n is composite, so we can again break out of the loop. If none of the primes divides n, then n is prime and we add it to the primes list on line 16.
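To make the inner loop concrete, here is a small sketch that pulls it out into a stand-alone helper. firstPrimeFactor() is a hypothetical name, and the sketch uses math.isqrt() (available from Python 3.8) in place of floor(sqrt(n)); for numbers of this size the two give the same maxCheck. The example 91 = 7 × 13 shows why testing up to the square root is enough: the factor 13 is larger than isqrt(91) = 9, but its partner 7 is not.

```python
from math import isqrt

def firstPrimeFactor(n, primes):
    """Return the first prime in primes that divides odd n, or None if none does.

    Mirrors the inner loop of genPrimes(): primes must start with 2 and list
    every prime up to isqrt(n) in increasing order.
    """
    maxCheck = isqrt(n)        # largest integer <= sqrt(n)
    for fac in primes[1:]:     # skip 2, which can't divide an odd n
        if fac > maxCheck:
            break              # no prime factor <= sqrt(n), so n is prime
        if n % fac == 0:
            return fac         # n is composite
    return None

primes = [2, 3, 5, 7]
print(firstPrimeFactor(91, primes))  # 7, since 91 = 7 * 13 and isqrt(91) = 9
print(firstPrimeFactor(97, primes))  # None: 3, 5 and 7 don't divide 97, so it is prime
```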

This algorithm (without the yield statements) would create the primes list containing all the primes from 2 up to max, so we could just return this list at the end of the function and then use the list in some other part of the program. However, we have written genPrimes() so that it will ‘yield’ the primes one at a time. Let’s see how this works.
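For comparison, here is a sketch of that list-returning version. primesUpTo() is a hypothetical name, and unlike genPrimes() it returns an empty list when limit is below 2. It does all of its work before handing anything back, whereas genPrimes() passes each prime to the caller as soon as it is found; list(genPrimes(30)) produces the same list as primesUpTo(30).

```python
from math import isqrt

def primesUpTo(limit):
    """Same trial-division algorithm as genPrimes(), but returns a complete list."""
    primes = [2] if limit >= 2 else []
    for n in range(3, limit + 1, 2):
        maxCheck = isqrt(n)
        isPrime = True
        for fac in primes[1:]:
            if fac > maxCheck:
                break
            if n % fac == 0:
                isPrime = False
                break
        if isPrime:
            primes.append(n)
    return primes  # nothing reaches the caller until every prime has been found

print(primesUpTo(30))  # [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
```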

Strictly speaking, calling genPrimes() doesn’t run any of its code; it creates a generator object. The first time that generator is asked for a value, genPrimes() runs up to line 5, where the first yield statement is found. At this point, the function sends the value 2 back to the code that asked for it. However, unlike a return statement, yield does not end the function. Rather, the function’s state, including all its local variables, is retained.

After the calling code has done whatever it needs to with the yielded value, it can ask the same generator for another value. This time, execution of genPrimes() picks up where it left off after the previous yield, so it now continues at line 7, with the for loop.

The code in the for loop will run and determine the next prime number, which is 3. When execution reaches line 17, n will be 3 and another yield statement is encountered. This time, the value of n (3) is sent back to the calling statement. The process continues until the for loop on line 7 finishes.

Once the for loop on line 7 has finished, execution reaches the end of genPrimes(). If another value is requested after that, there is no yield left to supply it, so a StopIteration exception is raised. Depending on how the generator is being used, this might terminate the program, or it may be handled cleanly by the calling code. The important point is that the generator keeps its state until its code runs past the final yield and reaches the end of the function.
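This lifecycle can be watched directly with a small instrumented generator. traced() is a hypothetical example, and the sketch uses the built-in next() function (discussed later in this post) to request one value at a time. Note that the line gen = traced() prints nothing: calling a generator function only creates a generator object, and the body starts running when the first value is requested.

```python
def traced():
    print('  body starts')
    yield 1
    print('  resumed after first yield')
    yield 2
    print('  body ends')

gen = traced()           # nothing is printed: the body has not started yet
print('first:', next(gen))
print('second:', next(gen))
try:
    next(gen)            # runs to the end of the body, then raises StopIteration
except StopIteration:
    print('StopIteration raised')

# Output:
#   body starts
# first: 1
#   resumed after first yield
# second: 2
#   body ends
# StopIteration raised
```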

To get an idea of how a generator function can be used, we’ll add the following code after the genPrimes() function (both blocks of code can be placed in the same file):

lastPrime = 2   # genPrimes() yields 2 first, so the first gap computed is 0
numPrimes = 0
gapStats = {}
maxNum = input('Enter largest number: ')

for n in genPrimes(int(maxNum)):
    numPrimes += 1
    gap = n - lastPrime
    if gap in gapStats:
        gapStats[gap] += 1
    elif gap > 0:   # skip the meaningless gap of 0 from the first prime, 2
        gapStats[gap] = 1
    lastPrime = n
sortGaps = sorted(gapStats.items())
print(f'Number of primes: {numPrimes}\n{sortGaps}')

This code uses genPrimes() to generate a sequence of primes, counts how many primes are generated, and records how often each gap between successive primes occurs. These frequencies are stored in a dictionary called gapStats, whose keys are the gaps and whose values are the counts.

The call to genPrimes() is made in the for loop on line 6. A for loop needs something it can iterate over, and the generator object returned by genPrimes() qualifies because it is an iterator. (Technically, generators are a special kind of iterator: all generators are iterators, but the converse is not true.) genPrimes() itself is called only once, when the loop starts. On each pass through the loop, the for loop asks that generator for its next value, which resumes genPrimes() from where it left off until it reaches the next yield. Each yielded value is placed in the loop variable n, which is then used to increment the count numPrimes, to calculate the gap between the current prime and the previous prime, lastPrime, and to update the counts in the dictionary.

A for loop contains an implicit handler for the StopIteration exception, so once genPrimes() has run past its final yield, the loop receives a StopIteration and ends quietly. After this, line 14 sorts the (gap, count) pairs from the dictionary into order of increasing gap, and the results are printed.
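As an aside, the gap counting could also be written with collections.Counter from the standard library, which treats missing keys as zero, so no if/else is needed. This sketch is an alternative, not the original program, and gapCounts() is a hypothetical name. It accepts any iterable of primes, so genPrimes(30) could be passed to it directly, and it only computes a gap once it has a previous prime, which sidesteps the initial 2.

```python
from collections import Counter

def gapCounts(primes):
    """Count how often each gap between successive primes occurs."""
    counts = Counter()
    lastPrime = None
    for p in primes:
        if lastPrime is not None:   # no gap for the very first prime
            counts[p - lastPrime] += 1
        lastPrime = p
    return counts

primesTo30 = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
print(sorted(gapCounts(primesTo30).items()))  # [(1, 1), (2, 4), (4, 3), (6, 1)]
```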

A couple of important points should be noted here. First, the call to genPrimes() on line 6 above produces a specific instance of the genPrimes() generator. To step through the yield statements within the generator in order, we must keep using that same instance. The for loop does this automatically: the instance of genPrimes() created on line 6 is retained from one iteration of the loop to the next, so its yields are reached in sequence.
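A small sketch shows both sides of this point: each call to a generator function gives a separate object with its own state, and once an object is used up it stays exhausted. evens() is a hypothetical stand-in for genPrimes().

```python
def evens(limit):
    for n in range(0, limit + 1, 2):
        yield n

a = evens(6)
b = evens(6)             # a second, independent generator
print(next(a), next(a))  # 0 2
print(next(b))           # 0: b has its own state, unaffected by a
print(list(a))           # [4, 6]: a resumes where it left off
print(list(a))           # []: a is exhausted and cannot be restarted
```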

The other important point is that the program must provide a way of dealing with the StopIteration exception that is generated after the last yield is run. As mentioned above, a for loop does this automatically, but if you use a generator somewhere else in your program, you may need to provide an explicit way of catching this exception.
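To see what the for loop does for us, here is a rough sketch of equivalent explicit code. It is a simplification (the real mechanism is built into the interpreter), and countdown() is a hypothetical generator used only for illustration. Only the next() call sits inside the try block, so a StopIteration from the generator ends the loop, while errors raised in the loop body are not silently swallowed.

```python
def countdown(n):
    while n > 0:
        yield n
        n -= 1

# Roughly what `for x in countdown(3): print(x)` does:
it = iter(countdown(3))  # for a generator, iter() returns the generator itself
while True:
    try:
        x = next(it)     # resume the generator until its next yield
    except StopIteration:
        break            # the generator has finished, so the loop ends quietly
    print(x)             # the loop body: prints 3, then 2, then 1
```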

The next() function

To illustrate these points, we’ll consider using genPrimes() without a for loop. We replace the second block of code above with this:

gp = genPrimes(10)
while True:
    print(next(gp), ' ', end = '')

The first line creates a specific instance of genPrimes() and stores it in the variable gp. The while loop iterates over this generator by using the built-in next() function. Each time next() is called, the generator runs up to its next yield statement, and next() returns the value yielded there. Thus this loop prints all the primes between 2 and 10.

However, if we use next() to access successive yields, there is no handler in place to deal with the StopIteration exception after all the yields have run. If we run the above code, we get the output:

2  3  5  7  Traceback (most recent call last):
  File "D:\Documents\Programming\programmingpages\Python\Primes 01\Primes 01\Primes_01.py", line 3, in <module>
    print(next(gp), ' ', end = '')
StopIteration

The four primes less than 10 are printed, but then the StopIteration causes an error which stops the program. We can add a try block to handle the exception, as in:

gp = genPrimes(10)
try:
    while True:
        print(next(gp), ' ', end = '')
except StopIteration:
    print('\nAll primes found')

Now we get the output:

2  3  5  7
All primes found
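Another option is the optional second argument of the built-in next(): when the iterator is exhausted, next() returns that default value instead of raising StopIteration. This sketch assumes None can never be a genuine value from the generator, which holds for primes. smallPrimes() is a hypothetical stand-in for genPrimes(10), so that the block runs on its own.

```python
def smallPrimes():
    # Hypothetical stand-in for genPrimes(10): yields the primes up to 10
    yield from (2, 3, 5, 7)

gp = smallPrimes()
while True:
    p = next(gp, None)  # returns None instead of raising StopIteration when exhausted
    if p is None:
        break
    print(p, ' ', end='')
print('\nAll primes found')
```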

Exercise

Write a generator function that generates the Fibonacci sequence as an infinite sequence (that is, the generator will keep producing values, no matter how many are requested from it). The Fibonacci sequence \(f_n\) is defined as

    \[ f_{n+1}=f_{n}+f_{n-1} \]

with \(f_1=f_2=1\). That is, each term is the sum of the previous two terms. [Some definitions of the sequence start with \(f_1=0\) and \(f_2=1\), but for our purposes it’s easier to take the first two terms as 1.]

Use this generator to test the theorem that the ratio of two adjacent Fibonacci numbers tends to the Golden Ratio, which is defined as \(\phi=\frac{1+\sqrt{5}}{2}\). That is, test that the following is true:

    \[ \lim_{n\rightarrow\infty}\frac{f_{n}}{f_{n-1}}=\phi=\frac{1+\sqrt{5}}{2} \]
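To get a feel for how quickly the ratios settle down, here are the first few, worked out by hand from the terms 1, 1, 2, 3, 5, 8, 13, 21. They alternate above and below \(\phi\approx 1.618\), closing in from both sides:

    \[ \frac{f_2}{f_1}=1,\quad \frac{f_3}{f_2}=2,\quad \frac{f_4}{f_3}=1.5,\quad \frac{f_5}{f_4}=\frac{5}{3}\approx 1.667,\quad \frac{f_6}{f_5}=1.6,\quad \frac{f_7}{f_6}=1.625,\quad \frac{f_8}{f_7}=\frac{21}{13}\approx 1.615 \]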

Since \(\phi\) is irrational, no ratio of two Fibonacci numbers ever equals it exactly, and a numerical calculation can’t prove that the limit holds, but you can verify it to a given precision. Your code should ask the user for a tolerance value (see the earlier post on floating point numbers for a refresher on tolerances) and then iterate through the generator until this precision is reached. At this point, you should print out the value of \(\frac{f_{n}}{f_{n-1}}\), the value of \(\phi\), and the number of Fibonacci numbers at which the tolerance was reached.

The program is as follows:

from math import sqrt, isclose

def fibon():
    fnm1 = 1
    yield fnm1
    fn = 1
    yield fn
    while True:
        fnp1 = fnm1 + fn
        yield fnp1
        fnm1 = fn
        fn = fnp1

phi = (1 + sqrt(5))/2
tol = input('Tolerance: ')
tol = float(tol)
genFibon = fibon()
f1 = next(genFibon)
count = 1
for f2 in genFibon:
    count += 1
    ratio = f2 / f1
    if isclose(ratio, phi, rel_tol=tol): break 
    f1 = f2

print(f'Ratio = {ratio}; Phi = {phi}; Count = {count}')

The fibon() function on line 3 takes no arguments, as it generates an infinite sequence of values. We specify the first two values as fnm1 = 1 and fn = 1, and include a specific yield for each of these values. The infinite loop on line 8 generates the sequence using the formula above by adding together the two previous values to get the next one, with a yield after each value is calculated.

The main program sets up the required values and then creates an instance of the generator on line 17. As we need two Fibonacci numbers to calculate a ratio, we use next() to get the first number on line 18, then enter the for loop on line 20. Because the for loop iterates over the same generator instance, it picks up where next() left off, starting with the second Fibonacci number. On each pass, the loop takes the next Fibonacci number from the generator, calculates the ratio, and compares it with \(\phi\) using the math.isclose() function we met when discussing floating point numbers. If the required tolerance has been met, we break out of the loop and print the results.
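The fibon() generator could also be written more compactly with tuple assignment, so that a single yield produces every term, including the first two. This is a sketch of an alternative, not the version above; fibonCompact() is a hypothetical name, and itertools.islice() is used to take a finite slice from the infinite sequence.

```python
from itertools import islice

def fibonCompact():
    """Infinite Fibonacci generator with f_1 = f_2 = 1."""
    a, b = 1, 1
    while True:
        yield a
        a, b = b, a + b  # both right-hand values are computed before assigning

print(list(islice(fibonCompact(), 10)))  # [1, 1, 2, 3, 5, 8, 13, 21, 34, 55]
```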

This program is a bit risky: if the required tolerance is never met, the program will be stuck in an infinite loop, so we should probably include some safeguards to cope with this situation. For example, we could break out of the for loop if count exceeded some value. However, you’ll find that the result converges surprisingly quickly. The ratio is equal to \(\phi\) within a relative tolerance of \(10^{-10}\) after only 26 Fibonacci numbers have been generated.
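One way to add that safeguard is to wrap the search in a function that gives up after a fixed number of terms. In this sketch, findConvergence() and its maxCount parameter are hypothetical, and the cap of 1000 is an arbitrary choice; the function returns None if the tolerance is not met in time. It also takes the tolerance as an argument rather than calling input(), which makes it easier to test.

```python
from math import sqrt, isclose

def fibon():
    """Infinite Fibonacci generator with f_1 = f_2 = 1."""
    a, b = 1, 1
    while True:
        yield a
        a, b = b, a + b

def findConvergence(tol, maxCount=1000):
    """Return (ratio, count) once f_n / f_(n-1) is within relative tolerance tol
    of phi, or None if maxCount Fibonacci numbers are generated first."""
    phi = (1 + sqrt(5)) / 2
    genFibon = fibon()
    f1 = next(genFibon)
    count = 1
    for f2 in genFibon:
        count += 1
        ratio = f2 / f1
        if isclose(ratio, phi, rel_tol=tol):
            return ratio, count
        if count >= maxCount:
            return None  # safeguard: give up instead of looping forever
        f1 = f2

print(findConvergence(1e-10))               # the count in the returned tuple is 26
print(findConvergence(1e-10, maxCount=10))  # None: 10 terms are not enough
```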
