Write Python code to draw a directed graph using matplotlib and networkx. Show how to use the code with some example data. Include all necessary import statements.
Response
Let's draw a directed graph using Matplotlib and NetworkX in Python.
First, ensure you have these libraries installed. If not, you can install them using pip:
```bash
pip install matplotlib networkx
```
Here's the Python code with an example:
```python
import matplotlib.pyplot as plt
import networkx as nx


def draw_directed_graph(edges, pos=None, node_size=300, font_size=10, arrow_size=20, ax=None):
    """
    Draws a directed graph using Matplotlib and NetworkX.

    Args:
        edges (list of tuples): A list of tuples, where each tuple represents
            a directed edge (source, target).
        pos (dict, optional): A dictionary with nodes as keys and (x, y)
            positions as values. If None, a spring layout will be used.
        node_size (int, optional): The size of the nodes. Defaults to 300.
        font_size (int, optional): The font size for node labels. Defaults to 10.
        arrow_size (int, optional): The size of the arrowheads on directed edges.
            Defaults to 20.
        ax (matplotlib.axes.Axes, optional): The Matplotlib axes object to draw
            on. If None, a new figure and axes are created and the plot is
            shown at the end.
    """
    # Create a directed graph object
    G = nx.DiGraph()
    # Add edges to the graph (missing nodes are created automatically)
    G.add_edges_from(edges)

    # If no positions are provided, compute a layout
    if pos is None:
        pos = nx.spring_layout(G)  # Alternatives: nx.circular_layout(G), nx.shell_layout(G), etc.

    # If no axes object is provided, create a new figure and axes,
    # and remember that this function is responsible for showing them
    created_figure = ax is None
    if created_figure:
        _, ax = plt.subplots()

    # Draw the nodes
    nx.draw_networkx_nodes(G, pos, node_size=node_size, ax=ax, node_color='skyblue', alpha=0.9)
    # Draw the edges; passing node_size makes the arrowheads stop at the node border
    nx.draw_networkx_edges(G, pos, arrowstyle='-|>', arrowsize=arrow_size, edge_color='gray',
                           width=1.5, node_size=node_size, ax=ax)
    # Draw the labels (node names)
    nx.draw_networkx_labels(G, pos, font_size=font_size, ax=ax)

    # Set plot title and remove axes
    ax.set_title("Directed Graph Example")
    ax.axis('off')

    # Show the plot only if this function created the figure; a caller that
    # passes its own ax decides when to call plt.show() (e.g. after adding a title)
    if created_figure:
        plt.show()


# --- Example Usage ---
if __name__ == "__main__":
    # Define the edges for the directed graph
    # Each tuple (source, target) represents a directed edge from source to target
    example_edges = [
        ('A', 'B'),
        ('A', 'C'),
        ('B', 'D'),
        ('C', 'D'),
        ('D', 'E'),
        ('E', 'A'),
        ('C', 'E'),
    ]

    # --- Basic Example ---
    print("Drawing basic directed graph...")
    draw_directed_graph(example_edges)

    # --- Example with Custom Layout and Styling ---
    print("\nDrawing directed graph with custom layout and styling...")
    # You can pre-define node positions if you want a specific arrangement
    custom_pos = {
        'A': (0, 0),
        'B': (1, 1),
        'C': (1, -1),
        'D': (2, 0),
        'E': (3, 0),
    }
    plt.figure(figsize=(8, 6))  # Create a new figure for this example
    draw_directed_graph(
        example_edges,
        pos=custom_pos,
        node_size=500,
        font_size=12,
        arrow_size=30,
        ax=plt.gca(),  # Draw on the current figure's axes
    )
    plt.title("Directed Graph with Custom Layout and Styling")  # Replaces the default title
    plt.show()

    # --- Example with a Circular Layout ---
    print("\nDrawing directed graph with circular layout...")
    plt.figure(figsize=(6, 6))
    draw_directed_graph(
        example_edges,
        pos=nx.circular_layout(nx.DiGraph(example_edges)),  # Compute circular layout
        node_size=400,
        font_size=11,
        arrow_size=25,
        ax=plt.gca(),
    )
    plt.title("Directed Graph with Circular Layout")
    plt.show()
```
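A custom `pos` such as `custom_pos` above must contain an entry for every node. If any node is missing, the NetworkX drawing functions typically stop with an error saying that a node has no position. The sketch below checks coverage before drawing so the problem is easier to diagnose. `missing_positions` is a hypothetical helper written for illustration and is not part of NetworkX.

```python
import networkx as nx


def missing_positions(edges, pos):
    """Return the nodes that appear in `edges` but have no entry in `pos`.

    Hypothetical helper for illustration; it is not part of NetworkX.
    """
    G = nx.DiGraph(edges)  # builds the graph (and its nodes) from the edge list
    return [node for node in G if node not in pos]


example_edges = [('A', 'B'), ('A', 'C'), ('B', 'D'), ('C', 'D'),
                 ('D', 'E'), ('E', 'A'), ('C', 'E')]

complete_pos = {'A': (0, 0), 'B': (1, 1), 'C': (1, -1), 'D': (2, 0), 'E': (3, 0)}
partial_pos = {'A': (0, 0), 'B': (1, 1), 'C': (1, -1)}

print(missing_positions(example_edges, complete_pos))  # []
print(missing_positions(example_edges, partial_pos))   # ['D', 'E']
```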
Explanation:
Import Libraries:
- matplotlib.pyplot as plt: This is the standard way to import Matplotlib's plotting module.
- networkx as nx: This imports the NetworkX library, conventionally aliased as nx.
draw_directed_graph Function:
- edges: This is the most important input. It's a list of tuples, where each tuple (source, target) defines a directed connection from source to target.
- pos (optional): A dictionary that maps each node to its (x, y) position on the plot. If you don't provide it, NetworkX calculates a layout automatically (here, spring_layout).
- Styling parameters (node_size, font_size, arrow_size): These customize the appearance of the nodes, their labels, and the arrowheads on the directed edges.
- ax (optional): For more advanced usage, such as drawing several graphs on the same figure or embedding the graph in a larger Matplotlib plot. If ax is None, a new figure and axes are created.
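The `ax` parameter is what makes it possible to put several graphs side by side. Here is a minimal sketch that creates one figure with two axes and draws the same graph on each with a different layout. So that it stays self-contained, it calls `nx.draw_networkx` directly instead of `draw_directed_graph`. With the function above you would pass `ax=ax_left` and `ax=ax_right` instead. The non-interactive `Agg` backend is an assumption, chosen so the sketch also runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import networkx as nx

G = nx.DiGraph([('A', 'B'), ('B', 'C'), ('C', 'A')])

# One figure, two axes: the same graph drawn with two different layouts
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(10, 4))
layouts = [
    (ax_left, nx.circular_layout(G), "Circular"),
    (ax_right, nx.spring_layout(G, seed=1), "Spring"),
]
for ax, pos, title in layouts:
    # Each call draws only on the axes it is given
    nx.draw_networkx(G, pos, ax=ax, node_color='skyblue', arrowsize=20)
    ax.set_title(title)
    ax.axis('off')

fig.tight_layout()
# In an interactive session, call plt.show() here to display both panels
```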
Creating the Graph:
- G = nx.DiGraph(): Initializes an empty directed graph object.
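The short sketch below shows what "directed" means in practice. In a `DiGraph`, an edge from A to B does not imply an edge from B to A, so each node keeps separate successors and predecessors. For comparison, an undirected `nx.Graph` is included.

```python
import networkx as nx

G = nx.DiGraph()
G.add_edge('A', 'B')  # a single directed edge A -> B

print(G.has_edge('A', 'B'))       # True
print(G.has_edge('B', 'A'))       # False: direction matters in a DiGraph
print(list(G.successors('A')))    # ['B']
print(list(G.predecessors('B')))  # ['A']
print(G.out_degree('A'), G.in_degree('A'))  # 1 0

# An undirected Graph built from the same edge ignores direction
U = nx.Graph([('A', 'B')])
print(U.has_edge('B', 'A'))       # True
```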
Adding Edges:
- G.add_edges_from(edges): Adds all the edges from your edges list to the graph G in a single call. NetworkX automatically creates any nodes that don't already exist.
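The following sketch demonstrates two consequences of `add_edges_from`. First, nodes are created from the edge list in the order they first appear. Second, a `DiGraph` stores at most one edge for each (source, target) pair, so a repeated edge in the input is drawn only once.

```python
import networkx as nx

edges = [('A', 'B'), ('A', 'C'), ('A', 'B')]  # note the repeated ('A', 'B')

G = nx.DiGraph()
G.add_edges_from(edges)

print(list(G.nodes))        # ['A', 'B', 'C']: created automatically from the edges
print(G.number_of_edges())  # 2: the duplicate ('A', 'B') collapses into one edge
```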
Calculating the Layout (if pos is None):
- pos = nx.spring_layout(G): The spring layout is a force-directed algorithm. It simulates attractive forces along edges and repulsive forces between nodes, which tends to produce a readable arrangement. Other layouts are available (e.g., circular_layout, shell_layout, spectral_layout).
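`spring_layout` starts from random node positions, so the same graph can come out differently on each run. Passing `seed` makes the layout reproducible. The sketch below also shows what `pos` contains: a dictionary that maps each node to a NumPy array of (x, y) coordinates. The seed value 42 is arbitrary.

```python
import numpy as np
import networkx as nx

G = nx.DiGraph([('A', 'B'), ('A', 'C'), ('B', 'D'), ('C', 'D'), ('D', 'E')])

# Same seed -> same starting positions -> same final layout
pos1 = nx.spring_layout(G, seed=42)
pos2 = nx.spring_layout(G, seed=42)

# pos maps each node to a NumPy array holding its (x, y) coordinates
for node, (x, y) in pos1.items():
    print(f"{node}: x={x:.2f}, y={y:.2f}")

print(all(np.allclose(pos1[n], pos2[n]) for n in G))  # True
```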
Drawing Components:
- nx.draw_networkx_nodes(G, pos, ...): Draws the nodes of the graph.
  - node_color='skyblue' and alpha=0.9 are styling options.
- nx.draw_networkx_edges(G, pos, ...): Draws the edges.
  - arrowstyle='-|>': A Matplotlib arrow style that draws a filled triangular arrowhead at the target end of each edge, indicating direction.
  - arrowsize=arrow_size: Controls the size of the arrowhead.
  - edge_color='gray' and width=1.5 are styling options.
- nx.draw_networkx_labels(G, pos, ...): Draws the labels (node names) on top of the nodes.
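When two nodes point at each other (A to B and B to A), the straight edges drawn by `draw_networkx_edges` lie on top of each other, and only one direction is visible. The sketch below uses the `connectionstyle` parameter to bend each edge slightly so both arrows stay visible. The radius 0.2 is an arbitrary choice. For a directed graph, the function returns one arrow patch per edge.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import networkx as nx

# A -> B and B -> A would overlap if drawn as straight lines
G = nx.DiGraph([('A', 'B'), ('B', 'A'), ('B', 'C')])
pos = {'A': (0, 0), 'B': (1, 0), 'C': (2, 0)}

fig, ax = plt.subplots()
nx.draw_networkx_nodes(G, pos, node_size=500, node_color='skyblue', ax=ax)
# 'arc3,rad=0.2' curves each edge to one side relative to its direction,
# so the two opposite edges between A and B bend apart
arrows = nx.draw_networkx_edges(
    G, pos, node_size=500, arrowstyle='-|>', arrowsize=20,
    connectionstyle='arc3,rad=0.2', edge_color='gray', ax=ax,
)
nx.draw_networkx_labels(G, pos, ax=ax)
ax.axis('off')
```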
Plot Customization:
- ax.set_title(...): Sets the title for the plot.
- ax.axis('off'): Hides the x and y axes, which are usually not needed for graph visualizations.
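Displaying the Plot:
- plt.show(): Displays the Matplotlib figure containing the drawn graph. It is meant to be called only when the function created the figure and axes itself. When you pass your own ax, you call plt.show() yourself, which lets you add titles or other plots first.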
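`plt.show()` needs a display to be useful. On a server, in CI, or in any other headless environment, a common alternative is to render with Matplotlib's non-interactive `Agg` backend and save the figure to a file. The sketch below does that. The output filename and the temporary directory are arbitrary choices.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # render without a display (e.g. on a server or in CI)
import matplotlib.pyplot as plt
import networkx as nx

G = nx.DiGraph([('A', 'B'), ('B', 'C'), ('C', 'A')])
pos = nx.circular_layout(G)

fig, ax = plt.subplots(figsize=(4, 4))
nx.draw_networkx(G, pos, ax=ax, node_color='skyblue', arrowsize=20)
ax.axis('off')

# Write the figure to a PNG file instead of opening a window
out_path = os.path.join(tempfile.gettempdir(), "directed_graph.png")
fig.savefig(out_path, dpi=150, bbox_inches="tight")
plt.close(fig)  # release the figure once it has been saved
print("Saved to", out_path)
```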
Example Usage (if __name__ == "__main__":)
This block demonstrates how to call the draw_directed_graph function in different scenarios:
- Basic Example: Shows the default behavior with an automatically computed layout.
- Custom Layout and Styling: Demonstrates how to pre-define node positions and apply different styling for nodes, fonts, and arrows. It uses plt.figure() and ax=plt.gca() to draw on a specific figure.
- Circular Layout: Shows how to use nx.circular_layout to arrange nodes in a circle.
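When you run this script, three Matplotlib windows open one after another, each displaying a version of the directed graph. In a regular (non-interactive) script, each plt.show() call blocks until you close its window, and the next window appears only after that.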
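The circular-layout example works because `nx.circular_layout` spaces nodes evenly around a circle in the graph's node order. For a graph built from an edge list, that is the order in which nodes first appear. The sketch below checks both properties on the example edges. It also shows one way to control the order around the circle: add the nodes explicitly before the edges. That node order is arbitrary and chosen only for illustration.

```python
import numpy as np
import networkx as nx

example_edges = [('A', 'B'), ('A', 'C'), ('B', 'D'), ('C', 'D'),
                 ('D', 'E'), ('E', 'A'), ('C', 'E')]
G = nx.DiGraph(example_edges)

# Nodes are placed in the graph's node order (first appearance in the edge list)
pos = nx.circular_layout(G)
print(list(pos))  # ['A', 'B', 'C', 'D', 'E']

# Every node sits at the same distance from the centre of the layout
radii = [np.hypot(x, y) for x, y in pos.values()]
print(np.allclose(radii, radii[0], atol=1e-4))  # True

# Adding the nodes first fixes their order around the circle
H = nx.DiGraph()
H.add_nodes_from(['A', 'C', 'E', 'B', 'D'])
H.add_edges_from(example_edges)
print(list(nx.circular_layout(H)))  # ['A', 'C', 'E', 'B', 'D']
```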
Model: gemini-2.5-pro | Temperature: 0.2 | Max output tokens: 8192 | TopK: 40.0 | TopP: 0.95