Python visualization

Write Python code to visualize a weighted directed graph with example data

You can query a model directly and test the results it returns with different parameter values, either in the Google Cloud console or by calling the Vertex AI API directly.

Prompt Freeform

Write Python code to draw a directed graph using matplotlib and networkx. Show how to use the code with some example data. Include all necessary import statements.

Response

Let's draw a directed graph using Matplotlib and NetworkX in Python.

First, ensure you have these libraries installed. If not, you can install them using pip:

pip install matplotlib networkx
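To confirm both packages are installed before running the example, a small check like the one below can help. It is a sketch that uses the standard library's importlib.metadata (Python 3.8+) to look up installed versions without importing the packages. The function name check_packages is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version


def check_packages(names=("matplotlib", "networkx")):
    """Return {package name: installed version string, or None if missing}."""
    status = {}
    for name in names:
        try:
            # Reads the installed distribution's metadata; does not import it
            status[name] = version(name)
        except PackageNotFoundError:
            status[name] = None
    return status


if __name__ == "__main__":
    for name, ver in check_packages().items():
        if ver is None:
            print(f"{name}: NOT INSTALLED (try: pip install {name})")
        else:
            print(f"{name}: {ver}")
```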

Here's the Python code with an example:

import matplotlib.pyplot as plt
import networkx as nx

def draw_directed_graph(edges, pos=None, node_size=300, font_size=10, arrow_size=20, ax=None):
    """
    Draws a directed graph using Matplotlib and NetworkX.

    Args:
        edges (list of tuples): A list of tuples, where each tuple represents
                                 a directed edge (source, target).
        pos (dict, optional): A dictionary with nodes as keys and positions
                              as values. If None, a spring layout will be used.
        node_size (int, optional): The size of the nodes. Defaults to 300.
        font_size (int, optional): The font size for node labels. Defaults to 10.
        arrow_size (int, optional): The size of the arrows for directed edges.
                                    Defaults to 20.
        ax (matplotlib.axes.Axes, optional): The Matplotlib axes object to draw
                                             on. If None, a new figure and axes
                                             will be created.
    """
    # Create a directed graph object
    G = nx.DiGraph()

    # Add edges to the graph
    G.add_edges_from(edges)

    # If no positions are provided, compute a layout
    if pos is None:
        pos = nx.spring_layout(G) # You can choose other layouts like nx.circular_layout(G), nx.shell_layout(G), etc.

    # If no axes object is provided, create a new figure and axes,
    # and remember that this function owns the figure
    created_figure = ax is None
    if created_figure:
        _, ax = plt.subplots()

    # Draw the nodes
    nx.draw_networkx_nodes(G, pos, node_size=node_size, ax=ax, node_color='skyblue', alpha=0.9)

    # Draw the edges
    nx.draw_networkx_edges(G, pos, arrowstyle='-|>', arrowsize=arrow_size, edge_color='gray', width=1.5, ax=ax)

    # Draw the labels (node names)
    nx.draw_networkx_labels(G, pos, font_size=font_size, ax=ax)

    # Set plot title and remove axes
    ax.set_title("Directed Graph Example")
    ax.axis('off')

    # Display the plot only if this function created the figure. When the caller
    # passes ax (even ax=plt.gca()), the caller decides when to call plt.show(),
    # so later changes such as plt.title(...) are not lost.
    if created_figure:
        plt.show()

# --- Example Usage ---

if __name__ == "__main__":
    # Define the edges for the directed graph
    # Each tuple (source, target) represents a directed edge from source to target
    example_edges = [
        ('A', 'B'),
        ('A', 'C'),
        ('B', 'D'),
        ('C', 'D'),
        ('D', 'E'),
        ('E', 'A'),
        ('C', 'E')
    ]

    # --- Basic Example ---
    print("Drawing basic directed graph...")
    draw_directed_graph(example_edges)

    # --- Example with Custom Layout and Styling ---
    print("\nDrawing directed graph with custom layout and styling...")
    # You can pre-define node positions if you want a specific arrangement
    custom_pos = {
        'A': (0, 0),
        'B': (1, 1),
        'C': (1, -1),
        'D': (2, 0),
        'E': (3, 0)
    }
    plt.figure(figsize=(8, 6)) # Create a new figure for this example
    draw_directed_graph(
        example_edges,
        pos=custom_pos,
        node_size=500,
        font_size=12,
        arrow_size=30,
        ax=plt.gca() # Draw on the current figure's axes
    )
    plt.title("Directed Graph with Custom Layout and Styling")
    plt.show()

    # --- Example with a Circular Layout ---
    print("\nDrawing directed graph with circular layout...")
    plt.figure(figsize=(6, 6))
    draw_directed_graph(
        example_edges,
        pos=nx.circular_layout(nx.DiGraph(example_edges)), # Compute circular layout
        node_size=400,
        font_size=11,
        arrow_size=25,
        ax=plt.gca()
    )
    plt.title("Directed Graph with Circular Layout")
    plt.show()
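The code above draws an unweighted graph, although the page title mentions a weighted one. As a minimal sketch, assuming weights are supplied as (source, target, weight) triples with positive numeric weights, the hypothetical helper below stores them with DiGraph.add_weighted_edges_from, makes heavier edges thicker, and prints each weight on its edge with nx.draw_networkx_edge_labels. The example weights are arbitrary illustration values.

```python
import matplotlib.pyplot as plt
import networkx as nx


def draw_weighted_directed_graph(weighted_edges, ax=None, seed=42):
    """Draw a directed graph whose edges carry numeric weights (hypothetical helper).

    weighted_edges: list of (source, target, weight) triples; weights are
    assumed to be positive numbers.
    """
    G = nx.DiGraph()
    # Stores each weight under the edge attribute key "weight"
    G.add_weighted_edges_from(weighted_edges)

    # A fixed seed makes the spring layout reproducible between runs
    pos = nx.spring_layout(G, seed=seed)

    created_figure = ax is None
    if created_figure:
        _, ax = plt.subplots(figsize=(7, 5))

    # Map weights to line widths between 1 and 4 (heavier edge -> thicker line).
    # G.edges(data=True) iterates in the same order draw_networkx_edges uses.
    weights = [data["weight"] for _, _, data in G.edges(data=True)]
    max_weight = max(weights)
    widths = [1 + 3 * w / max_weight for w in weights]

    nx.draw_networkx_nodes(G, pos, node_color="skyblue", node_size=500, ax=ax)
    nx.draw_networkx_edges(G, pos, arrowstyle="-|>", arrowsize=20,
                           edge_color="gray", width=widths, ax=ax)
    nx.draw_networkx_labels(G, pos, ax=ax)

    # Write each edge's weight near the middle of that edge
    edge_labels = nx.get_edge_attributes(G, "weight")
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, ax=ax)

    ax.set_title("Weighted Directed Graph")
    ax.axis("off")
    if created_figure:
        plt.show()
    return G


if __name__ == "__main__":
    # Arbitrary example weights, for illustration only
    example_weighted_edges = [
        ("A", "B", 4), ("A", "C", 2), ("B", "D", 5),
        ("C", "D", 8), ("D", "E", 3), ("E", "A", 1), ("C", "E", 10),
    ]
    draw_weighted_directed_graph(example_weighted_edges)
```

Scaling line width by weight is one design choice among several; mapping weight to edge color (via edge_color with a colormap) is a common alternative.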

Explanation:

  1. Import Libraries:

    • matplotlib.pyplot as plt: This is the standard way to import Matplotlib's plotting module.
    • networkx as nx: This imports the NetworkX library, conventionally aliased as nx.
  2. draw_directed_graph Function:

    • edges: The only required argument. It's a list of tuples, where each tuple (source, target) defines a directed connection from source to target.
    • pos (Optional): Lets you specify the exact positions of your nodes on the plot, as a dictionary mapping each node to an (x, y) coordinate. If you don't provide it, the function computes a layout with nx.spring_layout.
    • Styling Parameters (node_size, font_size, arrow_size): These allow you to customize the appearance of the nodes, their labels, and the arrows on the directed edges.
    • ax (Optional): This is for more advanced usage where you might want to draw multiple graphs on the same figure or integrate the graph drawing into a larger Matplotlib plot. If ax is None, a new figure and axes are created.
  3. Creating the Graph:

    • G = nx.DiGraph(): Initializes an empty directed graph object.
  4. Adding Edges:

    • G.add_edges_from(edges): This efficiently adds all the edges from your edges list to the graph G. NetworkX automatically creates nodes if they don't already exist.
  5. Calculating Layout (if pos is None):

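    • pos = nx.spring_layout(G): spring_layout is a force-directed layout (the Fruchterman-Reingold algorithm): edges act like springs that pull connected nodes together while all nodes repel each other, which tends to spread the graph out evenly. The result is randomized, so the drawing can change between runs unless you pass a seed, e.g. nx.spring_layout(G, seed=42). Other layouts are available (e.g., circular_layout, shell_layout, spectral_layout).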
  6. Drawing Components:

    • nx.draw_networkx_nodes(G, pos, ...): Draws the nodes of the graph.
      • node_color='skyblue' and alpha=0.9 are styling options.
    • nx.draw_networkx_edges(G, pos, ...): Draws the edges.
      • arrowstyle='-|>': This is a Matplotlib arrow style that creates a filled triangle at the end of the line, indicating direction.
      • arrowsize=arrow_size: Controls the size of the arrowhead.
      • edge_color='gray' and width=1.5 are styling options.
    • nx.draw_networkx_labels(G, pos, ...): Draws the labels (node names) on top of the nodes.
  7. Plot Customization:

    • ax.set_title(...): Sets the title for the plot.
    • ax.axis('off'): Hides the x and y axes, which are usually not needed for graph visualizations.
  8. Displaying the Plot:

    • plt.show(): Displays the Matplotlib figure containing the drawn graph. The function calls it only when it created the figure itself; when you pass your own ax, call plt.show() yourself after any further customization, as the examples do.
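Steps 3 and 4 can be checked without drawing anything. This sketch builds the same graph as the example and inspects it, showing that nodes are created implicitly by add_edges_from and that edge direction matters in a DiGraph. The helper name build_directed_graph is hypothetical.

```python
import networkx as nx


def build_directed_graph(edges):
    """Build a DiGraph from (source, target) tuples, as draw_directed_graph does."""
    G = nx.DiGraph()
    G.add_edges_from(edges)  # nodes are created automatically from the tuples
    return G


if __name__ == "__main__":
    edges = [("A", "B"), ("A", "C"), ("B", "D"), ("C", "D"),
             ("D", "E"), ("E", "A"), ("C", "E")]
    G = build_directed_graph(edges)
    print(sorted(G.nodes()))                           # ['A', 'B', 'C', 'D', 'E']
    print(G.number_of_edges())                         # 7
    print(sorted(G.successors("C")))                   # ['D', 'E']
    print(sorted(G.predecessors("A")))                 # ['E']
    print(G.has_edge("A", "B"), G.has_edge("B", "A"))  # True False
```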

Example Usage (if __name__ == "__main__":)

This block demonstrates how to call the draw_directed_graph function with different scenarios:

  • Basic Example: Shows the default behavior with a generated layout.
  • Custom Layout and Styling: Demonstrates how to pre-define node positions and apply different styling for nodes, fonts, and arrows. It uses plt.figure() and ax=plt.gca() to draw on a specific figure.
  • Circular Layout: Shows how to use nx.circular_layout to arrange nodes in a circle.
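To compare layouts side by side, one option is to give each layout its own axes in a single figure, which is the same idea behind the function's ax parameter. The sketch below calls nx.draw_networkx directly rather than draw_directed_graph so that it stands alone; the helper name compare_layouts is hypothetical.

```python
import matplotlib.pyplot as plt
import networkx as nx


def compare_layouts(edges, seed=42):
    """Draw one directed graph with three layouts on adjacent axes (hypothetical helper)."""
    G = nx.DiGraph(edges)
    layouts = {
        "spring": nx.spring_layout(G, seed=seed),  # force-directed, seeded
        "circular": nx.circular_layout(G),         # nodes evenly spaced on a circle
        "shell": nx.shell_layout(G),               # concentric circles (one shell by default)
    }
    fig, axes = plt.subplots(1, len(layouts), figsize=(4 * len(layouts), 4))
    for ax, (name, pos) in zip(axes, layouts.items()):
        nx.draw_networkx(G, pos, ax=ax, node_color="skyblue",
                         arrowstyle="-|>", arrowsize=15)
        ax.set_title(f"{name}_layout")
        ax.axis("off")
    fig.tight_layout()
    return fig, layouts


if __name__ == "__main__":
    example_edges = [("A", "B"), ("A", "C"), ("B", "D"), ("C", "D"),
                     ("D", "E"), ("E", "A"), ("C", "E")]
    compare_layouts(example_edges)
    plt.show()
```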

When you run this script, you will see three separate Matplotlib windows, each displaying a version of the directed graph.
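On a machine without a display (for example, a remote server or a CI job), those windows cannot open. A common alternative, sketched here under that assumption, is to select Matplotlib's non-interactive Agg backend and write the figure to a file with savefig. The helper name save_directed_graph and the output filename are illustrative.

```python
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # non-interactive backend: renders to files, never opens a window
import matplotlib.pyplot as plt
import networkx as nx


def save_directed_graph(edges, path, seed=42):
    """Render a directed graph to an image file instead of a window (hypothetical helper)."""
    G = nx.DiGraph(edges)
    pos = nx.spring_layout(G, seed=seed)
    fig, ax = plt.subplots(figsize=(6, 6))
    nx.draw_networkx(G, pos, ax=ax, node_color="skyblue",
                     arrowstyle="-|>", arrowsize=20)
    ax.axis("off")
    # The file format is inferred from the extension (.png, .svg, .pdf, ...)
    fig.savefig(path, dpi=150, bbox_inches="tight")
    plt.close(fig)  # release the figure, since it is never shown
    return Path(path)


if __name__ == "__main__":
    out = save_directed_graph([("A", "B"), ("B", "C"), ("C", "A")], "directed_graph.png")
    print(f"Saved {out} ({out.stat().st_size} bytes)")
```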

Model: gemini-2.5-pro
Temperature: 0.2
Max output tokens: 8192
TopK: 40.0
TopP: 0.95