Using Classes As Decorators in Python

Improve code quality with advanced Python techniques

Abhishek Saini
7 min read · Jun 28, 2024

Decorators in Python are very powerful, yet many developers don't use them properly. This concept lets you add extra functionality to a function simply by adding a single line above its definition.

Author-generated image via DALL·E

Decorators can be defined as functions or classes in Python. I used to define decorators only as functions, but recently I experienced the clarity and readability that classes provide when used as decorators.

Recently, I created an LRU caching class and wanted to use it as a decorator to cache function responses. After some research on the topic, I found a solution that worked perfectly. Let’s walk through the process of solving this problem together.

First, let’s write the LRU caching class inside the lru.py file.

from typing import Any, Callable, List, Optional, TypeVar, Union

K = TypeVar("K")  # type of cache keys
V = TypeVar("V")  # type of cached values


class Node:
    """A node in the doubly linked list that backs the LRU cache."""

    def __init__(self, key: Optional[K], value: Optional[V]) -> None:
        """
        Initializes a Node with a given key and value.

        Args:
            key (K, optional): The key of the node. The list's head sentinel
                uses None.
            value (V, optional): The value of the node.
        """
        self.key: Optional[K] = key
        self.value: Optional[V] = value
        self.next: Optional["Node"] = None
        self.prev: Optional["Node"] = None


class LRU:
    """
    A Least Recently Used (LRU) cache backed by a doubly linked list.

    Entries are kept in recency order: ``head.next`` is the least recently
    used entry (the next to be evicted) and ``tail`` is the most recently used
    one. ``head`` itself is a permanent sentinel node (key and value None)
    that is never counted in ``size`` and is never matched by ``get``.

    Lookups scan the list, so ``get`` and ``add`` take O(n) time in the number
    of cached entries. Keys only need to support ``==``; they do not have to
    be hashable.
    """

    def __init__(self, max_size: int = 10, debug: bool = False) -> None:
        """
        Initializes an LRU (Least Recently Used) cache with a given maximum size and debug flag.

        Args:
            max_size (int): The maximum number of entries held. Default is 10.
                A value of 0 or less means nothing is ever stored.
            debug (bool): Flag to enable or disable debug logging. When False,
                nothing is recorded in ``debug_logs``. Default is False.
        """
        self.max_size: int = max_size
        sentinel = Node(None, None)
        self.head: Node = sentinel
        self.tail: Node = sentinel
        self.size: int = 0
        self.debug_logs: List[str] = []
        self.debug: bool = debug

    def _log(self, message: str) -> None:
        """
        Records a debug message, but only when debug logging is enabled.

        Without this guard every insert/move/eviction appended to
        ``debug_logs``, so a long-lived cache grew without bound even with
        ``debug=False``.
        """
        if self.debug:
            self.debug_logs.append(message)

    def _find_node(self, key: K) -> Optional[Node]:
        """
        Returns the node stored under ``key``, or None if there is none.

        The scan starts *after* the head sentinel, so ``key=None`` can never
        match (and then try to unlink) the sentinel itself.
        """
        node = self.head.next
        while node is not None:
            if node.key == key:
                return node
            node = node.next
        return None

    def _remove_node(self, node: Node) -> None:
        """
        Unlinks a node from the LRU cache.

        Args:
            node (Node): The node to remove. Must be a real entry currently in
                the list, never the head sentinel.
        """
        self.size -= 1

        # A real entry always has a predecessor (at worst, the sentinel).
        prev_node: Optional[Node] = node.prev
        next_node: Optional[Node] = node.next

        prev_node.next = next_node
        if next_node is None:
            # Removing the last node: its predecessor becomes the new tail.
            self.tail = prev_node
        else:
            next_node.prev = prev_node

        # Drop stale links so a detached node keeps no references into the list.
        node.prev = None
        node.next = None

        self._log(
            f"Node deleted from cache: {node.key} - {node.value} \n size of cache is {self.size}"
        )

    def _add_trailing_node(self, node: Node) -> None:
        """
        Appends a node at the tail (most recently used position).

        Args:
            node (Node): The node to add.
        """
        self.size += 1

        node.prev = self.tail
        self.tail.next = node
        node.next = None
        self.tail = node

        self._log(
            f"New node added to cache: {node.key} - {node.value} \n size of cache is {self.size}"
        )

    def add(self, key: K, value: V) -> None:
        """
        Adds or updates a key-value pair and marks it as most recently used.

        If ``key`` is already cached, its value is replaced and the entry is
        moved to the tail. Previously a duplicate node was created, and ``get``
        kept returning the older, stale value. Otherwise, least recently used
        entries are evicted until there is room for the new one.

        Args:
            key (K): The key to add.
            value (V): The value to add.
        """
        if self.max_size <= 0:
            # Zero-capacity cache: store nothing (there is no node to evict).
            return

        existing = self._find_node(key)
        if existing is not None:
            existing.value = value
            self._remove_node(existing)
            self._add_trailing_node(existing)
            return

        # `while` instead of `if` so the cache also shrinks back into bounds
        # if max_size was lowered after entries were added.
        while self.size >= self.max_size:
            self._remove_node(self.head.next)

        self._add_trailing_node(Node(key, value))

    def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
        """
        Retrieves a value by its key and marks it as most recently used.

        Args:
            key (K): The key to search for.
            default (V, optional): Returned when ``key`` is not cached. Pass a
                unique sentinel object to tell a miss apart from a cached None.

        Returns:
            Optional[V]: The cached value, or ``default`` if the key is not found.
        """
        node = self._find_node(key)
        if node is None:
            return default

        # Move the hit to the tail (most recently used position).
        self._remove_node(node)
        self._add_trailing_node(node)
        return node.value

    def __str__(self) -> str:
        """
        Returns a string representation of the LRU cache.

        The head sentinel is included first, followed by entries from least
        to most recently used.

        Returns:
            str: The string representation of the cache.
        """
        parts = []
        node = self.head
        while node is not None:
            parts.append(f" -> Node({node.key}, {node.value}) ")
            node = node.next
        return "".join(parts)

The LRU class accepts two parameters — max_size and debug.
max_size defines the maximum number of entries the cache can hold, and debug defines whether to print the debug statements or not.

Now, how can we use this class as a decorator on a function to cache its return values?

One approach is to write a decorator function that wraps around the class. Let’s create the lru_cache_decorator function inside the lru.py file.

def lru_cache_decorator(max_size: int = 10, debug: bool = False) -> Callable:
    """
    A decorator factory that adds LRU (Least Recently Used) caching to a function.

    Each decorated function gets its own ``LRU`` instance. The instance is
    created once, when the decorator is applied, so it persists across calls.

    Args:
        max_size (int): The maximum size of the LRU cache. Default is 10.
        debug (bool): Flag to enable or disable debug logging. Default is False.

    Returns:
        Callable: A decorator that wraps a function with LRU caching.
    """
    import functools

    def decorator(func: Callable) -> Callable:
        # One cache per decorated function, created at decoration time.
        lru_cache = LRU(max_size=max_size, debug=debug)

        @functools.wraps(func)  # keep func's __name__/__doc__/__wrapped__
        def wrapper(*args, **kwargs):
            # NOTE(review): the key is built from the repr of the arguments, so
            # arguments whose repr is not stable/unique (e.g. default object
            # reprs that embed an id) may miss or collide.
            key = f"{func.__name__}|{args}|{kwargs}"

            # Results are stored boxed in a 1-tuple. LRU.get returns None for a
            # miss, so a bare None result would look like a miss: it would be
            # recomputed on every call and re-added as a duplicate entry each
            # time. The box is never None, so any non-None hit is genuine.
            cached_entry = lru_cache.get(key=key)
            if cached_entry is not None:
                if debug:
                    print("Returning cached data")
                    print(lru_cache)
                    print("Printing debug logs: \n", lru_cache.debug_logs)
                return cached_entry[0]

            result = func(*args, **kwargs)
            lru_cache.add(key=key, value=(result,))
            if debug:
                print(lru_cache)
                print("Printing debug logs: \n", lru_cache.debug_logs)
            return result

        return wrapper

    return decorator

Now that we have created our decorator function, let’s create a main.py file to test it out.

from lru import lru_cache_decorator
import time


# Demo: two slow functions, each with its own 3-entry LRU cache.
@lru_cache_decorator(3)
def add(a, b):
    time.sleep(5)  # mocking time taking process
    return a + b


@lru_cache_decorator(3)
def sub(a, b):
    time.sleep(5)  # mocking time taking process
    return a - b


if __name__ == "__main__":
    # perf_counter is monotonic and high-resolution, which makes it the right
    # clock for elapsed time; time.time() can jump if the system clock changes.
    start = time.perf_counter()
    print(add(1, 2))
    print(sub(2, 2))
    print(add(3, 4))

    print(sub(2, 2))  # Duplicate function call
    print(add(1, 2))  # Duplicate function call

    end = time.perf_counter()
    print(f"Total time taken : {end-start} secs")

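Following is the output generated by running main.py. (The "No cached data found !" and "Returning cached data" lines are status messages like the ones printed by the __call__ version later in this article; the wrapper above, as written, prints only when debug=True.)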

🍺 abhisheksaini caching () => python3 main.py
No cached data found !
3
No cached data found !
0
No cached data found !
7
Returning cached data
0
Returning cached data
3

Total time taken : 15.017259120941162 secs

As you can see, it cached the results of the first 3 function calls and then returned the cached data for the next 2 repeated calls. The overall execution time is about 15 secs, whereas without the caching decorator it would have taken around 25 secs. So we know it's working.

All this is fine, but did you also ask yourself:

How is it retaining cached data if there are multiple function calls, and thus multiple calls to initialize the LRU class?

Well, the lru_cache_decorator works by creating an instance of the LRU class only once, when the decorator is applied to the function at definition time, not on every call. This instance retains its state across multiple calls to the decorated function, allowing it to maintain the cache.

This approach works fine; the only overhead is writing an extra function so that the LRU class can be used as a decorator. What can we do better to improve readability and maintainability?

Another approach is to use a feature of the class itself. Python classes can define a __call__ method, which runs when an instance of the class is called like a function. Let's add a __call__ method to our LRU class inside the lru.py file.

import functools
from typing import Any, Callable, List, Optional, TypeVar, Union

K = TypeVar("K")  # type of cache keys
V = TypeVar("V")  # type of cached values


class Node:
    """A node in the doubly linked list that backs the LRU cache."""

    def __init__(self, key: Optional[K], value: Optional[V]) -> None:
        """
        Initializes a Node with a given key and value.

        Args:
            key (K, optional): The key of the node. The list's head sentinel
                uses None.
            value (V, optional): The value of the node.
        """
        self.key: Optional[K] = key
        self.value: Optional[V] = value
        self.next: Optional["Node"] = None
        self.prev: Optional["Node"] = None


class LRU:
    """
    A Least Recently Used (LRU) cache that can also be used as a decorator.

    Entries are kept in recency order: ``head.next`` is the least recently
    used entry (the next to be evicted) and ``tail`` is the most recently used
    one. ``head`` itself is a permanent sentinel node (key and value None)
    that is never counted in ``size`` and is never matched by ``get``.

    Lookups scan the list, so ``get`` and ``add`` take O(n) time in the number
    of cached entries. Keys only need to support ``==``.
    """

    def __init__(self, max_size: int = 10, debug: bool = False) -> None:
        """
        Initializes an LRU (Least Recently Used) cache with a given maximum size and debug flag.

        Args:
            max_size (int): The maximum number of entries held. Default is 10.
                A value of 0 or less means nothing is ever stored.
            debug (bool): Flag to enable or disable debug logging. When False,
                nothing is recorded in ``debug_logs``. Default is False.
        """
        self.max_size: int = max_size
        sentinel = Node(None, None)
        self.head: Node = sentinel
        self.tail: Node = sentinel
        self.size: int = 0
        self.debug_logs: List[str] = []
        self.debug: bool = debug

    def _log(self, message: str) -> None:
        """
        Records a debug message, but only when debug logging is enabled.

        Without this guard every insert/move/eviction appended to
        ``debug_logs``, so a long-lived cache grew without bound even with
        ``debug=False``.
        """
        if self.debug:
            self.debug_logs.append(message)

    def _find_node(self, key: K) -> Optional[Node]:
        """
        Returns the node stored under ``key``, or None if there is none.

        The scan starts *after* the head sentinel, so ``key=None`` can never
        match (and then try to unlink) the sentinel itself.
        """
        node = self.head.next
        while node is not None:
            if node.key == key:
                return node
            node = node.next
        return None

    def _remove_node(self, node: Node) -> None:
        """
        Unlinks a node from the LRU cache.

        Args:
            node (Node): The node to remove. Must be a real entry currently in
                the list, never the head sentinel.
        """
        self.size -= 1

        # A real entry always has a predecessor (at worst, the sentinel).
        prev_node: Optional[Node] = node.prev
        next_node: Optional[Node] = node.next

        prev_node.next = next_node
        if next_node is None:
            # Removing the last node: its predecessor becomes the new tail.
            self.tail = prev_node
        else:
            next_node.prev = prev_node

        # Drop stale links so a detached node keeps no references into the list.
        node.prev = None
        node.next = None

        self._log(
            f"Node deleted from cache: {node.key} - {node.value} \n size of cache is {self.size}"
        )

    def _add_trailing_node(self, node: Node) -> None:
        """
        Appends a node at the tail (most recently used position).

        Args:
            node (Node): The node to add.
        """
        self.size += 1

        node.prev = self.tail
        self.tail.next = node
        node.next = None
        self.tail = node

        self._log(
            f"New node added to cache: {node.key} - {node.value} \n size of cache is {self.size}"
        )

    def add(self, key: K, value: V) -> None:
        """
        Adds or updates a key-value pair and marks it as most recently used.

        If ``key`` is already cached, its value is replaced and the entry is
        moved to the tail. Previously a duplicate node was created, and ``get``
        kept returning the older, stale value. Otherwise, least recently used
        entries are evicted until there is room for the new one.

        Args:
            key (K): The key to add.
            value (V): The value to add.
        """
        if self.max_size <= 0:
            # Zero-capacity cache: store nothing (there is no node to evict).
            return

        existing = self._find_node(key)
        if existing is not None:
            existing.value = value
            self._remove_node(existing)
            self._add_trailing_node(existing)
            return

        # `while` instead of `if` so the cache also shrinks back into bounds
        # if max_size was lowered after entries were added.
        while self.size >= self.max_size:
            self._remove_node(self.head.next)

        self._add_trailing_node(Node(key, value))

    def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
        """
        Retrieves a value by its key and marks it as most recently used.

        Args:
            key (K): The key to search for.
            default (V, optional): Returned when ``key`` is not cached. Pass a
                unique sentinel object to tell a miss apart from a cached None.

        Returns:
            Optional[V]: The cached value, or ``default`` if the key is not found.
        """
        node = self._find_node(key)
        if node is None:
            return default

        # Move the hit to the tail (most recently used position).
        self._remove_node(node)
        self._add_trailing_node(node)
        return node.value

    def __str__(self) -> str:
        """
        Returns a string representation of the LRU cache.

        The head sentinel is included first, followed by entries from least
        to most recently used.

        Returns:
            str: The string representation of the cache.
        """
        parts = []
        node = self.head
        while node is not None:
            parts.append(f" -> Node({node.key}, {node.value}) ")
            node = node.next
        return "".join(parts)

    def __call__(self, func: Callable) -> Callable:
        """
        Decorator method to add LRU caching functionality to a function.

        ``@LRU(3)`` first builds an ``LRU`` instance and then calls it with the
        decorated function, which invokes this method. The instance is created
        once, at decoration time, so its cache persists across calls.

        Args:
            func (Callable): The function to be decorated.

        Returns:
            Callable: A wrapper function with LRU caching. ``functools.wraps``
            preserves the original ``__name__``, ``__doc__`` and ``__wrapped__``.
        """
        # Unique "not cached" marker: lets a cached None result count as a hit
        # instead of being recomputed (and re-added as a duplicate) every call.
        missing = object()

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Cache key: function name plus the repr of the call's arguments.
            # NOTE(review): arguments whose repr is not stable/unique (e.g.
            # default object reprs embedding an id) may miss or collide.
            key = f"{func.__name__}|{args}|{kwargs}"
            cached_data = self.get(key=key, default=missing)
            if cached_data is not missing:
                # NOTE(review): status lines are printed unconditionally (the
                # demo output relies on them); consider gating them on
                # self.debug for library use.
                print("Returning cached data")
                if self.debug:
                    print("Printing debug logs: \n", self.debug_logs)
                return cached_data

            result = func(*args, **kwargs)
            self.add(key=key, value=result)
            print("No cached data found !")
            if self.debug:
                print("Printing debug logs: \n", self.debug_logs)
            return result

        return wrapper

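The __call__ method now acts as the decorator function for the LRU class: with @LRU(3), Python first creates an LRU instance and then calls it with the decorated function, which runs __call__. So we no longer need the lru_cache_decorator function. Let's update our main.py too and execute the code.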

from lru import LRU
import time


# Demo: two slow functions, each decorated with its own 3-entry LRU instance.
@LRU(3)
def add(a, b):
    time.sleep(5)  # mocking time taking process
    return a + b


@LRU(3)
def sub(a, b):
    time.sleep(5)  # mocking time taking process
    return a - b


if __name__ == "__main__":
    # perf_counter is monotonic and high-resolution, which makes it the right
    # clock for elapsed time; time.time() can jump if the system clock changes.
    start = time.perf_counter()
    print(add(1, 2))
    print(sub(2, 2))
    print(add(3, 4))

    print(sub(2, 2))  # Duplicate function call
    print(add(1, 2))  # Duplicate function call

    end = time.perf_counter()
    print(f"Total time taken : {end-start} secs")
🍺 abhisheksaini caching () => python3 main.py
No cached data found !
3
No cached data found !
0
No cached data found !
7
Returning cached data
0
Returning cached data
3
Total time taken : 15.013511896133423 secs

It didn't affect our output, but it made our code more readable and maintainable.

You can apply this concept to use any of your classes as decorators. Using classes as decorators in Python can bring a significant boost to your code's efficiency, readability, and maintainability.

If you found this Medium story helpful, kindly let me know in the comments. I would love to hear your opinion.

--

--

Abhishek Saini

Experienced Software Developer • Problem Solver • System Architect • Python • Rest-API • Micro-services • AWS