Source code for cylc.flow.network.client

# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE.
# Copyright (C) NIWA & British Crown (Met Office) & Contributors.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""Client for workflow runtime API."""

from abc import ABCMeta, abstractmethod
import asyncio
import os
from shutil import which
import socket
import sys
from typing import Any, Optional, Union, Dict

import zmq
import zmq.asyncio

from cylc.flow import LOG
from cylc.flow.exceptions import (
    ClientError,
    ClientTimeout,
    ContactFileExists,
    CylcError,
    WorkflowStopped,
)
from cylc.flow.hostuserutil import get_fqdn_by_host
from cylc.flow.network import (
    encode_,
    decode_,
    get_location,
    ZMQSocketBase
)
from cylc.flow.network.client_factory import CommsMeth
from cylc.flow.network.server import PB_METHOD_MAP
from cylc.flow.workflow_files import (
    detect_old_contact_file,
)


class WorkflowRuntimeClientBase(metaclass=ABCMeta):
    """Base class for WorkflowRuntimeClients.

    WorkflowRuntimeClients that inherit from this must implement an async
    method ``async_request()``. This base class provides a ``serial_request()``
    method based on the ``async_request()`` method, callable by ``__call__``.
    It also provides a comms timeout handler method.
    """

    DEFAULT_TIMEOUT = 5  # seconds

    def __init__(
        self,
        workflow: str,
        host: Optional[str] = None,
        port: Union[int, str, None] = None,
        timeout: Union[float, str, None] = None
    ):
        """Record the workflow location and the default request timeout.

        Args:
            workflow: The workflow to talk to.
            host: Workflow host. If either ``host`` or ``port`` is missing,
                both are looked up via ``get_location(workflow)``.
            port: Workflow port (converted to ``int`` when supplied).
            timeout: Default request timeout in seconds (converted to
                ``float``); ``DEFAULT_TIMEOUT`` is used when ``None``.
        """
        self.workflow = workflow
        if not host or not port:
            host, port, _ = get_location(workflow)
        else:
            port = int(port)
        # Keep the originally-targeted location so timeout_handler can detect
        # whether the workflow has since moved.
        self.host = self._orig_host = host
        self.port = self._orig_port = port
        self.timeout = (
            float(timeout) if timeout is not None else self.DEFAULT_TIMEOUT
        )

    @abstractmethod
    async def async_request(
        self,
        command: str,
        args: Optional[Dict[str, Any]] = None,
        timeout: Optional[float] = None,
        req_meta: Optional[Dict[str, Any]] = None
    ) -> object:
        """Send an asynchronous request."""
        ...

    def serial_request(
        self,
        command: str,
        args: Optional[Dict[str, Any]] = None,
        timeout: Optional[float] = None,
        req_meta: Optional[Dict[str, Any]] = None
    ) -> object:
        """Send a request.

        For convenience use ``__call__`` to call this method.

        If the instance has a (non-None) ``loop`` attribute, the request is
        run on that loop, which is left open. Otherwise a private event loop
        is created for the request and closed afterwards, even if the request
        raises.

        Args:
            command: The name of the endpoint to call.
            args: Arguments to pass to the endpoint function.
            timeout: Override the default timeout (seconds).
            req_meta: Extra request metadata, passed through to
                ``async_request``.

        Raises:
            ClientTimeout: If a response takes longer than timeout to arrive.
            ClientError: Coverall for all other issues including failed auth.

        Returns:
            object: The data exactly as returned from the endpoint function,
                nothing more, nothing less.

        """
        # Do not use getattr(self, 'loop', asyncio.new_event_loop()): the
        # default is evaluated eagerly, leaking a new loop on every call even
        # when the instance already has one.
        loop = getattr(self, 'loop', None)
        own_loop = loop is None
        if own_loop:
            loop = asyncio.new_event_loop()
        try:
            task = loop.create_task(
                self.async_request(command, args, timeout, req_meta)
            )
            loop.run_until_complete(task)
        finally:
            if own_loop:
                # Only close a loop we created ourselves; an inheriting
                # class's event loop must be left alone.
                loop.close()
        return task.result()

    __call__ = serial_request

    def timeout_handler(self) -> None:
        """Handle the eventuality of a communication timeout with the workflow.

        Returns normally if the workflow still appears to be running at the
        original location (or if running inside a task, signalled by the
        ``CYLC_TASK_COMMS_METHOD`` environment variable).

        Raises:
            WorkflowStopped: if the workflow has already stopped.
            CylcError: if the workflow has moved to different host/port.
        """
        contact_host, contact_port, _ = get_location(self.workflow)
        if (
            contact_host != get_fqdn_by_host(self._orig_host)
            or contact_port != self._orig_port
        ):
            raise CylcError(
                'The workflow is no longer running at '
                f'{self._orig_host}:{self._orig_port}\n'
                f'It has moved to {contact_host}:{contact_port}'
            )

        if os.getenv('CYLC_TASK_COMMS_METHOD'):
            # don't attempt to clean up old contact files in task messages
            return

        # Cannot connect, perhaps workflow is no longer running and is leaving
        # behind a contact file?
        try:
            detect_old_contact_file(self.workflow)
        except ContactFileExists:
            # old contact file exists and the workflow process still alive
            return
        else:
            # the workflow has stopped
            raise WorkflowStopped(self.workflow)


class WorkflowRuntimeClient(  # type: ignore[misc]
    ZMQSocketBase, WorkflowRuntimeClientBase
):
    # (Ignoring mypy 'definition of "host" in base class "ZMQSocketBase" is
    # incompatible with definition in base class "WorkflowRuntimeClientBase"')
    """Initiate a client to the scheduler API.

    Initiates the REQ part of a ZMQ REQ-REP pair.

    This class contains the logic for the ZMQ message interface and client -
    server communication.

    Determine host and port from the contact file unless provided.

    If there is no socket bound to the specified host/port the client will
    bail after ``timeout`` seconds.

    Args:
        workflow:
            Name of the workflow to connect to.
        timeout:
            Set the default timeout in seconds. The default is
            ``WorkflowRuntimeClient.DEFAULT_TIMEOUT``.
            Note the default timeout can be overridden for individual requests.
        host:
            The host where the flow is running if known.

            If both host and port are provided it is not necessary to load
            the contact file.
        port:
            The port on which the REQ-REP TCP server is listening.

            If both host and port are provided it is not necessary to load
            the contact file.

    Attributes:
        host:
            Workflow host name.
        port:
            Workflow host port.
        timeout_handler:
            Method which runs before ClientTimeout is raised.
            This provides an interface for raising more specific exceptions in
            the event of a communication timeout.
        header:
            Request "header" data to attach to each request.

    Usage:
        Call endpoints using ``WorkflowRuntimeClient.__call__``.

    Message interface:
        * Accepts responses of the format: {"data": {...}}
        * Accepts error in the format: {"error": {"message": MSG}}
        * Sends requests of the format: {"command": CMD,
          "args": {...}}

    Raises:
        WorkflowStopped: if the workflow is not running.

    Call server "endpoints" using:
        ``__call__``, ``serial_request``
            .. automethod::
                cylc.flow.network.client.WorkflowRuntimeClient.serial_request

        ``async_request``
            .. automethod::
                cylc.flow.network.client.WorkflowRuntimeClient.async_request

    """
    # socket & event loop not None - get assigned on init by self.start():
    socket: zmq.asyncio.Socket
    loop: asyncio.AbstractEventLoop

    def __init__(
        self,
        workflow: str,
        host: Optional[str] = None,
        port: Union[int, str, None] = None,
        timeout: Union[float, str, None] = None,
        context: Optional[zmq.asyncio.Context] = None,
        srv_public_key_loc: Optional[str] = None
    ):
        ZMQSocketBase.__init__(self, zmq.REQ, workflow, context=context)
        WorkflowRuntimeClientBase.__init__(self, workflow, host, port, timeout)
        # convert to milliseconds:
        self.timeout *= 1000
        self.poller: Any = None
        # Connect the ZMQ socket on instantiation
        self.start(self.host, self.port, srv_public_key_loc)
        # gather header info post start
        self.header = self.get_header()

    def _socket_options(self):
        """Set socket options after socket instantiation before connect.

        Overwrites Base method.

        """
        # if there is no server don't keep the client hanging around
        # NOTE(review): ZMQ_LINGER is in milliseconds, so this is a 5 ms
        # linger (DEFAULT_TIMEOUT is documented in seconds) - confirm intent.
        self.socket.setsockopt(zmq.LINGER, int(self.DEFAULT_TIMEOUT))

        # create a poller to handle timeouts
        self.poller = zmq.Poller()
        self.poller.register(self.socket, zmq.POLLIN)

    async def async_request(
        self,
        command: str,
        args: Optional[Dict[str, Any]] = None,
        timeout: Optional[float] = None,
        req_meta: Optional[Dict[str, Any]] = None
    ) -> object:
        """Send an asynchronous request using asyncio.

        Has the same arguments and return values as ``serial_request``.

        ``req_meta`` is merged into this request's ``meta`` only; the
        client's ``header`` is not modified.

        Raises:
            ClientTimeout: If no response arrives within the timeout.
            ClientError: If the response carries no ``data`` key.
        """
        # A falsy per-request timeout (None or 0) falls back to the default;
        # the per-request value is given in seconds, self.timeout is in ms.
        timeout = (float(timeout) * 1000 if timeout else None) or self.timeout
        if not args:
            args = {}

        # Note: we are using CurveZMQ to secure the messages (see
        # self.curve_auth, self.socket.curve_...key etc.). We have set up
        # public-key cryptography on the ZMQ messaging and sockets, so
        # there is no need to encrypt messages ourselves before sending.

        # send message
        msg: Dict[str, Any] = {'command': command, 'args': args}
        msg.update(self.header)
        # add the request metadata
        if req_meta:
            # msg.update() above is a shallow copy, so msg['meta'] is the very
            # dict held in self.header; build a new one instead of updating it
            # in place, otherwise req_meta would leak into later requests.
            msg['meta'] = {**msg.get('meta', {}), **req_meta}
        LOG.debug('zmq:send %s', msg)
        message = encode_(msg)
        self.socket.send_string(message)

        # receive response
        # NOTE(review): zmq.Poller.poll is synchronous, so this blocks the
        # event loop for up to ``timeout`` ms while waiting.
        if self.poller.poll(timeout):
            res = await self.socket.recv()
        else:
            self.timeout_handler()
            raise ClientTimeout(
                'Timeout waiting for server response.'
                ' This could be due to network or server issues.'
                '\n* You might want to increase the timeout using the'
                ' --comms-timeout option;'
                '\n* or check the workflow log.'
            )

        if msg['command'] in PB_METHOD_MAP:
            # protobuf endpoints return raw bytes; pass them through undecoded
            response = {'data': res}
        else:
            response = decode_(
                res.decode() if isinstance(res, bytes) else res
            )
        LOG.debug('zmq:recv %s', response)

        try:
            return response['data']
        except KeyError:
            error = response.get(
                'error',
                {'message': f'Received invalid response: {response}'},
            )
            raise ClientError(
                error.get('message'),  # type: ignore
                error.get('traceback'),  # type: ignore
            ) from None

    def get_header(self) -> dict:
        """Return "header" data to attach to each request for traceability.

        Returns:
            dict: dictionary with the header information, such as
                program and hostname.
        """
        host = socket.gethostname()
        # Use the first command-line argument if there is one, otherwise the
        # program path itself.
        if len(sys.argv) > 1:
            cmd = sys.argv[1]
        else:
            cmd = sys.argv[0]

        cylc_executable_location = which("cylc")
        if cylc_executable_location:
            cylc_bin_dir = os.path.abspath(
                os.path.join(cylc_executable_location, os.pardir)
            )
            if not cylc_bin_dir.endswith("/"):
                cylc_bin_dir = f"{cylc_bin_dir}/"

            if cmd.startswith(cylc_bin_dir):
                # Strip only the leading bin directory, not any later
                # occurrence of the same text.
                cmd = cmd[len(cylc_bin_dir):]
        return {
            'meta': {
                'prog': cmd,
                'host': host,
                'comms_method':
                    os.getenv(
                        "CLIENT_COMMS_METH",
                        default=CommsMeth.ZMQ.value
                    )
            }
        }