# Source code for cylc.flow.network.client

# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE.
# Copyright (C) NIWA & British Crown (Met Office) & Contributors.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""Client for workflow runtime API."""

from abc import (
    ABCMeta,
    abstractmethod,
)
import asyncio
import os
from shutil import which
import socket
import sys
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Optional,
    Union,
)

import zmq
import zmq.asyncio

from cylc.flow import (
    LOG,
    __version__ as CYLC_VERSION,
)
from cylc.flow.exceptions import (
    ClientTimeout,
    ContactFileExists,
    CylcError,
    RequestError,
    WorkflowStopped,
)
from cylc.flow.hostuserutil import get_fqdn_by_host
from cylc.flow.network import (
    ZMQSocketBase,
    deserialize,
    get_location,
    serialize,
)
from cylc.flow.network.client_factory import CommsMeth
from cylc.flow.network.server import PB_METHOD_MAP
from cylc.flow.workflow_files import detect_old_contact_file


if TYPE_CHECKING:
    from cylc.flow.network import ResponseDict


class WorkflowRuntimeClientBase(metaclass=ABCMeta):
    """Base class for WorkflowRuntimeClients.

    WorkflowRuntimeClients that inherit from this must implement an async
    method ``async_request()``. This base class provides a ``serial_request()``
    method based on the ``async_request()`` method, callable by ``__call__``.
    It also provides a comms timeout handler method.
    """

    DEFAULT_TIMEOUT = 5  # seconds

    def __init__(
        self,
        workflow: str,
        host: Optional[str] = None,
        port: Union[int, str, None] = None,
        timeout: Union[float, str, None] = None
    ):
        """Record the workflow location and default timeout.

        Args:
            workflow: Name of the workflow to connect to.
            host: Workflow host; looked up via ``get_location`` unless both
                host and port are given.
            port: Workflow port; looked up via ``get_location`` unless both
                host and port are given, otherwise coerced to ``int``.
            timeout: Default timeout in seconds (``DEFAULT_TIMEOUT`` if None).
        """
        self.workflow = workflow
        if not host or not port:
            host, port, _ = get_location(workflow)
        else:
            port = int(port)
        # Remember the original location so timeout_handler can detect a
        # workflow that has since moved to a different host/port.
        self.host = self._orig_host = host
        self.port = self._orig_port = port
        self.timeout = (
            float(timeout) if timeout is not None else self.DEFAULT_TIMEOUT
        )

    @abstractmethod
    async def async_request(
        self,
        command: str,
        args: Optional[Dict[str, Any]] = None,
        timeout: Optional[float] = None,
        req_meta: Optional[Dict[str, Any]] = None
    ) -> object:
        """Send an asynchronous request."""
        ...

    def serial_request(
        self,
        command: str,
        args: Optional[Dict[str, Any]] = None,
        timeout: Optional[float] = None,
        req_meta: Optional[Dict[str, Any]] = None
    ) -> object:
        """Send a request.

        For convenience use ``__call__`` to call this method.

        Args:
            command: The name of the endpoint to call.
            args: Arguments to pass to the endpoint function.
            timeout: Override the default timeout (seconds).
            req_meta: Extra request metadata passed through to
                ``async_request``.

        Raises:
            ClientTimeout: If a response takes longer than timeout to arrive.
            ClientError: Coverall for all other issues including failed auth.

        Returns:
            object: The data exactly as returned from the endpoint function,
                nothing more, nothing less.

        """
        # Only create an event loop when the subclass does not provide one.
        # (Passing asyncio.new_event_loop() as a getattr default would create
        # and leak a new loop on every call, even when self.loop exists.)
        owns_loop = not hasattr(self, 'loop')
        loop = asyncio.new_event_loop() if owns_loop else self.loop
        try:
            task = loop.create_task(
                self.async_request(command, args, timeout, req_meta)
            )
            loop.run_until_complete(task)
        finally:
            if owns_loop:
                # Close our own loop even if the request raised.
                # (If inheriting class does have an event loop, don't mess
                # with it.)
                loop.close()
        return task.result()

    __call__ = serial_request

    def timeout_handler(self) -> None:
        """Handle the eventuality of a communication timeout with the workflow.

        Raises:
            WorkflowStopped: if the workflow has already stopped.
            CylcError: if the workflow has moved to different host/port.
        """
        contact_host, contact_port, _ = get_location(self.workflow)
        if (
            contact_host != get_fqdn_by_host(self._orig_host)
            or contact_port != self._orig_port
        ):
            raise CylcError(
                'The workflow is no longer running at '
                f'{self._orig_host}:{self._orig_port}\n'
                f'It has moved to {contact_host}:{contact_port}'
            )

        if os.getenv('CYLC_TASK_COMMS_METHOD'):
            # don't attempt to clean up old contact files in task messages
            return

        # Cannot connect, perhaps workflow is no longer running and is leaving
        # behind a contact file?
        try:
            detect_old_contact_file(self.workflow)
        except ContactFileExists:
            # old contact file exists and the workflow process still alive
            return
        else:
            # the workflow has stopped
            raise WorkflowStopped(self.workflow)


class WorkflowRuntimeClient(  # type: ignore[misc]
    ZMQSocketBase, WorkflowRuntimeClientBase
):
    # (Ignoring mypy 'definition of "host" in base class "ZMQSocketBase" is
    # incompatible with definition in base class "WorkflowRuntimeClientBase"')
    """Initiate a client to the scheduler API.

    Initiates the REQ part of a ZMQ REQ-REP pair.

    This class contains the logic for the ZMQ message interface and client -
    server communication.

    Determine host and port from the contact file unless provided.

    If there is no socket bound to the specified host/port the client will
    bail after ``timeout`` seconds.

    Args:
        workflow:
            Name of the workflow to connect to.
        timeout:
            Set the default timeout in seconds. The default is
            ``WorkflowRuntimeClient.DEFAULT_TIMEOUT``.
            Note the default timeout can be overridden for individual requests.
        host:
            The host where the flow is running if known.

            If both host and port are provided it is not necessary to load
            the contact file.
        port:
            The port on which the REQ-REP TCP server is listening.

            If both host and port are provided it is not necessary to load
            the contact file.

    Attributes:
        host:
            Workflow host name.
        port:
            Workflow host port.
        timeout_handler:
            Optional function which runs before ClientTimeout is raised.
            This provides an interface for raising more specific exceptions in
            the event of a communication timeout.
        header:
            Request "header" data to attach to each request.

    Usage:
        Call endpoints using ``WorkflowRuntimeClient.__call__``.

    Message interface:
        * Accepts responses of the format: {"data": {...}}
        * Accepts error in the format: {"error": {"message": MSG}}
        * Returns requests of the format: {"command": CMD, "args": {...}}

    Raises:
        WorkflowStopped: if the workflow is not running.

    Call server "endpoints" using:
        ``__call__``, ``serial_request``
            .. automethod::
                cylc.flow.network.client.WorkflowRuntimeClient.serial_request

        ``async_request``
            .. automethod::
                cylc.flow.network.client.WorkflowRuntimeClient.async_request

    """

    # socket & event loop not None - get assigned on init by self.start():
    socket: zmq.asyncio.Socket
    loop: asyncio.AbstractEventLoop

    def __init__(
        self,
        workflow: str,
        host: Optional[str] = None,
        port: Union[int, str, None] = None,
        timeout: Union[float, str, None] = None,
        context: Optional[zmq.asyncio.Context] = None,
        srv_public_key_loc: Optional[str] = None
    ):
        ZMQSocketBase.__init__(self, zmq.REQ, workflow, context=context)
        WorkflowRuntimeClientBase.__init__(self, workflow, host, port, timeout)
        # convert to milliseconds (the unit used by the poller):
        self.timeout *= 1000
        self.poller: Any = None
        # Connect the ZMQ socket on instantiation
        self.start(self.host, self.port, srv_public_key_loc)
        # gather header info post start
        self.header = self.get_header()

    def _socket_options(self):
        """Set socket options after socket instantiation before connect.

        Overwrites Base method.

        """
        # if there is no server don't keep the client hanging around
        # NOTE(review): ZMQ_LINGER is specified in milliseconds, so this is
        # 5 ms, not DEFAULT_TIMEOUT seconds - confirm this is intended.
        self.socket.setsockopt(zmq.LINGER, int(self.DEFAULT_TIMEOUT))

        # create a poller to handle timeouts
        self.poller = zmq.Poller()
        self.poller.register(self.socket, zmq.POLLIN)

    async def async_request(
        self,
        command: str,
        args: Optional[Dict[str, Any]] = None,
        timeout: Optional[float] = None,
        req_meta: Optional[Dict[str, Any]] = None
    ) -> object:
        """Send an asynchronous request using asyncio.

        Has the same arguments and return values as ``serial_request``.

        """
        # The per-request timeout is given in seconds. Convert it to ms to
        # match self.timeout. A falsy value (None/0) uses the client default.
        timeout = (float(timeout) * 1000 if timeout else None) or self.timeout
        if not args:
            args = {}

        # Note: we are using CurveZMQ to secure the messages (see
        # self.curve_auth, self.socket.curve_...key etc.). We have set up
        # public-key cryptography on the ZMQ messaging and sockets, so
        # there is no need to encrypt messages ourselves before sending.

        # send message
        msg: Dict[str, Any] = {'command': command, 'args': args}
        msg.update(self.header)
        # add the request metadata
        if req_meta:
            # Build a new dict rather than updating in place: after the
            # shallow ``update`` above, msg['meta'] is the very same object as
            # self.header['meta'], so an in-place update would leak this
            # request's metadata into every subsequent request.
            msg['meta'] = {**msg.get('meta', {}), **req_meta}
        LOG.debug('zmq:send %s', msg)
        message = serialize(msg)
        self.socket.send_string(message)

        # receive response
        # NOTE(review): self.poller is a zmq.Poller (see _socket_options),
        # whose poll() is synchronous and so blocks the event loop for up to
        # ``timeout`` ms - consider an awaitable poll.
        if self.poller.poll(timeout):
            res: bytes = await self.socket.recv()
        else:
            self.timeout_handler()
            raise ClientTimeout(
                'Timeout waiting for server response.'
                ' This could be due to network or server issues.'
                '\n* You might want to increase the timeout using the'
                ' --comms-timeout option;'
                '\n* or check the workflow log.'
            )
        LOG.debug('zmq:recv %s', res)

        # Protobuf endpoints return raw bytes; don't attempt to deserialize.
        if command in PB_METHOD_MAP:
            return res

        response: ResponseDict = deserialize(res.decode())

        try:
            return response['data']
        except KeyError:
            error = response.get('error')
            if isinstance(error, dict):
                error = error.get('message', error)
            if not error:
                error = (
                    f"Received invalid response for Cylc {CYLC_VERSION}: "
                    f"{response}"
                )
            raise RequestError(
                str(error), response.get('cylc_version')
            ) from None

    def get_header(self) -> dict:
        """Return "header" data to attach to each request for traceability.

        Returns:
            dict: dictionary with the header information, such as
                program and hostname.
        """
        host = socket.gethostname()
        cmd = sys.argv[1] if len(sys.argv) > 1 else sys.argv[0]

        cylc_executable_location = which("cylc")
        if cylc_executable_location:
            cylc_bin_dir = os.path.abspath(
                os.path.join(cylc_executable_location, os.pardir)
            )
            if not cylc_bin_dir.endswith("/"):
                cylc_bin_dir = f"{cylc_bin_dir}/"

            if cmd.startswith(cylc_bin_dir):
                # Strip only the leading bin directory (str.replace would
                # remove every occurrence of it within the path).
                cmd = cmd[len(cylc_bin_dir):]
        return {
            'meta': {
                'prog': cmd,
                'host': host,
                'comms_method':
                    os.getenv(
                        "CLIENT_COMMS_METH",
                        default=CommsMeth.ZMQ.value
                    )
            }
        }