import logging
import math
import os
import socket
from contextlib import closing
from typing import Optional
import numpy as np
from bmipy import Bmi
import grpc
import numpy
from grpc_status import rpc_status
from google.rpc import error_details_pb2
from . import bmi_pb2, bmi_pb2_grpc
from .constants import GRPC_MAX_MESSAGE_LENGTH
log = logging.getLogger(__name__)
[docs]
class RemoteException(grpc.RpcError):
    """gRPC error that additionally carries the stack trace reported by the remote side.

    Args:
        message: Error message, passed on to the ``grpc.RpcError`` constructor.
        tb: Remote stack trace; stored unchanged on ``remote_stacktrace``.
    """

    def __init__(self, message, tb):
        # Keep the remote traceback around so callers can inspect or print it.
        self.remote_stacktrace = tb
        super().__init__(message)
[docs]
def handle_error(exc):
    """Re-raise a gRPC error, upgrading it to a RemoteException when possible.

    Parses DebugInfo (https://github.com/googleapis/googleapis/blob/07244bb797ddd6e0c1c15b02b4467a9a5729299f/google/rpc/error_details.proto#L46-L52) from the trailing metadata of a grpc.RpcError

    This function never returns normally; it always raises.

    Args:
        exc (grpc.RpcError): Exception to handle

    Raises:
        RemoteException: when the rich status of ``exc`` contains a DebugInfo detail;
            the original exception is attached as ``__cause__``.
        grpc.RpcError: the original exception, in every other case.
    """
    # An error that has already been converted (e.g. by a nested client call)
    # is not a live RPC call object, so there is nothing left to parse.
    if isinstance(exc, RemoteException):
        raise exc
    status = rpc_status.from_call(exc)
    # Explicit `raise exc` instead of a bare `raise`, so this helper also works
    # when it is not invoked from inside an active `except` block.
    if status is None:
        raise exc
    for detail in status.details:
        if detail.Is(error_details_pb2.DebugInfo.DESCRIPTOR):
            info = error_details_pb2.DebugInfo()
            detail.Unpack(info)
            remote_traceback = info.stack_entries
            remote_detail = info.detail
            raise RemoteException(remote_detail, remote_traceback) from exc
    # Rich status present, but without any DebugInfo detail.
    raise exc
def _fits_in_message(array):
    """Tell whether the raw data of ``array`` is small enough for a single gRPC message.

    The payload is estimated as ``array.size * array.itemsize`` bytes and compared
    against ``GRPC_MAX_MESSAGE_LENGTH``.

    Returns:
        bool: True when the estimate does not exceed the limit.
    """
    n_bytes = array.itemsize * array.size
    if n_bytes > GRPC_MAX_MESSAGE_LENGTH:
        return False
    return True
[docs]
class BmiClient(Bmi):
    """
    Client BMI interface, implementing BMI by forwarding every function call via GRPC to the server connected to the
    same port. A GRPC channel can be passed to the constructor; if not, it constructs an insecure channel on a free
    port itself. The timeout parameter indicates the model BMI startup timeout parameter (s).

    Every forwarding method catches ``grpc.RpcError`` and hands it to ``handle_error``, which re-raises it
    (possibly as a ``RemoteException``).

    >>> import grpc
    >>> from grpc4bmi.bmi_grpc_client import BmiClient
    >>> mymodel = BmiClient(grpc.insecure_channel("localhost:<PORT>"))
    >>> print(mymodel.get_component_name())
    Hello world
    """

    def __init__(self, channel=None, timeout=None, stub=None):
        """Create the client.

        Args:
            channel: gRPC channel to use; when None a channel from ``create_grpc_channel()`` is used.
                Ignored when ``stub`` is given.
            timeout: Passed to the channel-ready future; how long to wait for the channel to become ready.
            stub: Ready-made service stub. When given, no channel is created or waited for.
        """
        if stub is None:
            c = BmiClient.create_grpc_channel() if channel is None else channel
            self.stub = bmi_pb2_grpc.BmiServiceStub(c)
            # Block until the channel is usable (or the timeout expires).
            future = grpc.channel_ready_future(c)
            future.result(timeout=timeout)
        else:
            self.stub = stub

    def __del__(self):
        # __init__ may have raised before self.stub was assigned; do not fail a second time here.
        if hasattr(self, "stub"):
            del self.stub

    @staticmethod
    def create_grpc_channel(port=0, host=None):
        """Open an insecure gRPC channel to ``host:port``.

        Args:
            port: Port to connect to. 0 means: take the ``BMI_PORT`` environment variable, or 50051 if unset.
            host: Host name; None means "localhost".
        """
        h = "localhost" if host is None else host
        p = os.environ.get("BMI_PORT", 50051) if port == 0 else port
        return grpc.insecure_channel(':'.join([h, str(p)]))

    @staticmethod
    def get_unique_port(host=None):
        """Return a port number picked by the operating system.

        A TCP socket is bound to port 0 on ``host`` (all interfaces when None) and the assigned port is
        returned. The socket is closed before returning, so the port is not reserved.
        """
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
            s.bind(("" if host is None else host, 0))
            return int(s.getsockname()[1])

    @staticmethod
    def _flat_indices(indices):
        """Convert indices (list, tuple or array) into a flat numpy array for a request message."""
        return numpy.asarray(indices).flatten()

    @staticmethod
    def _array_message_kwargs(values):
        """Wrap ``values`` in the array message matching its dtype.

        Returns:
            dict: a single keyword argument (``values_int``, ``values_float`` or ``values_double``)
            to pass to a set-value request constructor.

        Raises:
            NotImplementedError: for dtypes other than int16/32/64 and float16/32/64.
        """
        flat = values.flatten()
        if values.dtype in (numpy.int16, numpy.int32, numpy.int64):
            return {"values_int": bmi_pb2.IntArrayMessage(values=flat)}
        if values.dtype in (numpy.float32, numpy.float16):
            return {"values_float": bmi_pb2.FloatArrayMessage(values=flat)}
        if values.dtype == numpy.float64:
            return {"values_double": bmi_pb2.DoubleArrayMessage(values=flat)}
        raise NotImplementedError("Arrays with type %s cannot be transmitted through this GRPC channel" % values.dtype)

    def initialize(self, filename: Optional[str]):
        """Forward ``initialize``; a None filename is sent as an empty config file name."""
        fname = "" if filename is None else filename
        try:
            return self.stub.initialize(bmi_pb2.InitializeRequest(config_file=fname))
        except grpc.RpcError as e:
            handle_error(e)

    def update(self):
        """Forward ``update`` to the server."""
        try:
            self.stub.update(bmi_pb2.Empty())
        except grpc.RpcError as e:
            handle_error(e)

    def update_until(self, time: float) -> None:
        """Forward ``updateUntil`` with the given time."""
        try:
            self.stub.updateUntil(bmi_pb2.GetTimeResponse(time=time))
        except grpc.RpcError as e:
            handle_error(e)

    def finalize(self):
        """Forward ``finalize`` to the server."""
        try:
            self.stub.finalize(bmi_pb2.Empty())
        except grpc.RpcError as e:
            handle_error(e)

    def get_component_name(self):
        """Return the component name reported by the server, as ``str``."""
        try:
            return str(self.stub.getComponentName(bmi_pb2.Empty()).name)
        except grpc.RpcError as e:
            handle_error(e)

    def get_output_item_count(self) -> int:
        """Return the output item count reported by the server."""
        try:
            return self.stub.getOutputItemCount(bmi_pb2.Empty()).count
        except grpc.RpcError as e:
            handle_error(e)

    def get_output_var_names(self):
        """Return the output variable names reported by the server, as a tuple of ``str``."""
        try:
            return tuple(str(s) for s in self.stub.getOutputVarNames(bmi_pb2.Empty()).names)
        except grpc.RpcError as e:
            handle_error(e)

    def get_time_units(self):
        """Return the time units string, or None when the server replies with an empty string."""
        try:
            response = str(self.stub.getTimeUnits(bmi_pb2.Empty()).units)
            return None if not response else response
        except grpc.RpcError as e:
            handle_error(e)

    def get_time_step(self):
        """Return the time step interval reported by the server."""
        try:
            return self.stub.getTimeStep(bmi_pb2.Empty()).interval
        except grpc.RpcError as e:
            handle_error(e)

    def get_current_time(self):
        """Return the current time reported by the server."""
        try:
            return self.stub.getCurrentTime(bmi_pb2.Empty()).time
        except grpc.RpcError as e:
            handle_error(e)

    def get_start_time(self):
        """Return the start time reported by the server."""
        try:
            return self.stub.getStartTime(bmi_pb2.Empty()).time
        except grpc.RpcError as e:
            handle_error(e)

    def get_end_time(self):
        """Return the end time reported by the server."""
        try:
            return self.stub.getEndTime(bmi_pb2.Empty()).time
        except grpc.RpcError as e:
            handle_error(e)

    def get_var_grid(self, name):
        """Return the grid id of variable ``name``."""
        try:
            return self.stub.getVarGrid(bmi_pb2.GetVarRequest(name=name)).grid_id
        except grpc.RpcError as e:
            handle_error(e)

    def get_var_type(self, name):
        """Return the type of variable ``name`` as ``str``."""
        try:
            return str(self.stub.getVarType(bmi_pb2.GetVarRequest(name=name)).type)
        except grpc.RpcError as e:
            handle_error(e)

    def get_var_itemsize(self, name):
        """Return the item size of variable ``name``.

        When the server answers 0, the size is derived from the numpy dtype named by ``get_var_type``.

        Raises:
            ValueError: if the server answers 0 and the variable type is not a valid numpy dtype.
        """
        try:
            item_size = self.stub.getVarItemSize(bmi_pb2.GetVarRequest(name=name)).size
        except grpc.RpcError as e:
            handle_error(e)
        if item_size == 0:
            # BMI < v2.0 did not have get_var_itemsize, so old server will return 0
            # fallback to getting item size from var type.
            # Done outside the try above: get_var_type() already converts its own gRPC errors.
            var_type = self.get_var_type(name)
            try:
                item_size = numpy.dtype(var_type).itemsize
            except TypeError as err:
                raise ValueError('get_var_itemsize returned 0, which is impossible') from err
            log.info(f'get_var_itemsize returned 0, corrected to {item_size} using get_var_type.')
        return item_size

    def get_var_units(self, name):
        """Return the units of variable ``name``, or None when the server replies with an empty string."""
        try:
            response = str(self.stub.getVarUnits(bmi_pb2.GetVarRequest(name=name)).units)
            return None if not response else response
        except grpc.RpcError as e:
            handle_error(e)

    def get_var_nbytes(self, name):
        """Return the number of bytes of variable ``name`` as reported by the server."""
        try:
            return self.stub.getVarNBytes(bmi_pb2.GetVarRequest(name=name)).nbytes
        except grpc.RpcError as e:
            handle_error(e)

    def get_var_location(self, name: str) -> str:
        """Return the location of variable ``name``: the lower-cased name of the Location enum value."""
        try:
            location = self.stub.getVarLocation(bmi_pb2.GetVarRequest(name=name)).location
            return bmi_pb2.GetVarLocationResponse.Location.Name(location).lower()
        except grpc.RpcError as e:
            handle_error(e)

    def get_value(self, name, dest):
        """Copy the values of variable ``name`` into ``dest`` and return ``dest``.

        When ``dest`` is too big for a single message, the values are fetched in chunks instead.
        """
        if not _fits_in_message(dest):
            return self._chunked_get_value(name, dest)
        try:
            response = self.stub.getValue(bmi_pb2.GetVarRequest(name=name))
            numpy.copyto(src=BmiClient.make_array(response), dst=dest)
            return dest
        except grpc.RpcError as e:
            handle_error(e)

    def _chunked_get_value(self, name: str, dest: np.ndarray) -> np.ndarray:
        """Fill ``dest`` with the values of ``name`` using several index-range requests."""
        # Number of items that fit in one message, minus a margin of `itemsize` items.
        chunk_size = math.floor(GRPC_MAX_MESSAGE_LENGTH / dest.dtype.itemsize) - dest.dtype.itemsize
        chunks = []
        log.info(f'Too many items ({dest.size}) for single call, '
                 f'using multiple get_value_at_indices() with into chunks of {chunk_size} items')
        for start in range(0, dest.size, chunk_size):
            # Last chunk can be smaller
            stop = min(start + chunk_size, dest.size)
            chunks.append(self._get_value_at_range(name, start, stop))
        numpy.concatenate(chunks, out=dest)
        return dest

    def _get_value_at_range(self, name, start, stop):
        """Fetch the values of ``name`` at flat indices ``start`` (inclusive) to ``stop`` (exclusive)."""
        log.info(f'Fetching value range {start} - {stop}')
        try:
            response = self.stub.getValueAtIndices(bmi_pb2.GetValueAtIndicesRequest(name=name,
                                                                                     indices=range(start, stop)))
            return BmiClient.make_array(response)
        except grpc.RpcError as e:
            handle_error(e)

    def get_value_ptr(self, name: str) -> np.ndarray:
        """Not possible, unable give reference to data structure in another process and possibly another machine"""
        raise NotImplementedError("Array references cannot be transmitted through this GRPC channel")

    def get_value_at_indices(self, name, dest, indices):
        """Copy the values of ``name`` at ``indices`` into ``dest`` and return ``dest``.

        ``indices`` may be a list or a numpy array; it is flattened before sending.
        """
        try:
            request = bmi_pb2.GetValueAtIndicesRequest(name=name,
                                                       indices=BmiClient._flat_indices(indices))
            response = self.stub.getValueAtIndices(request)
            numpy.copyto(src=BmiClient.make_array(response), dst=dest)
            return dest
        except grpc.RpcError as e:
            handle_error(e)

    def set_value(self, name, values):
        """Send ``values`` (a numpy array) as the new value of variable ``name``.

        Raises:
            NotImplementedError: when the dtype of ``values`` is not a supported int or float type.
        """
        try:
            request = bmi_pb2.SetValueRequest(name=name, **BmiClient._array_message_kwargs(values))
            self.stub.setValue(request)
        except grpc.RpcError as e:
            handle_error(e)

    def set_value_at_indices(self, name, inds, src):
        """Send ``src`` (a numpy array) as the new values of variable ``name`` at indices ``inds``.

        ``inds`` may be a list or a numpy array; it is flattened before sending.

        Raises:
            NotImplementedError: when the dtype of ``src`` is not a supported int or float type.
        """
        try:
            request = bmi_pb2.SetValueAtIndicesRequest(name=name,
                                                       indices=BmiClient._flat_indices(inds),
                                                       **BmiClient._array_message_kwargs(src))
            self.stub.setValueAtIndices(request)
        except grpc.RpcError as e:
            handle_error(e)

    def get_grid_size(self, grid):
        """Return the size of grid ``grid``."""
        try:
            return self.stub.getGridSize(bmi_pb2.GridRequest(grid_id=grid)).size
        except grpc.RpcError as e:
            handle_error(e)

    def get_grid_rank(self, grid):
        """Return the rank of grid ``grid``."""
        try:
            return self.stub.getGridRank(bmi_pb2.GridRequest(grid_id=grid)).rank
        except grpc.RpcError as e:
            handle_error(e)

    def get_grid_type(self, grid):
        """Return the type of grid ``grid`` as ``str``."""
        try:
            return str(self.stub.getGridType(bmi_pb2.GridRequest(grid_id=grid)).type)
        except grpc.RpcError as e:
            handle_error(e)

    def get_grid_x(self, grid, x):
        """Copy the x coordinates of grid ``grid`` into ``x`` and return ``x``."""
        try:
            src = numpy.array(self.stub.getGridX(bmi_pb2.GridRequest(grid_id=grid)).coordinates)
            numpy.copyto(src=src, dst=x)
            return x
        except grpc.RpcError as e:
            handle_error(e)

    def get_grid_y(self, grid, y):
        """Copy the y coordinates of grid ``grid`` into ``y`` and return ``y``."""
        try:
            src = numpy.array(self.stub.getGridY(bmi_pb2.GridRequest(grid_id=grid)).coordinates)
            numpy.copyto(src=src, dst=y)
            return y
        except grpc.RpcError as e:
            handle_error(e)

    def get_grid_z(self, grid, z):
        """Copy the z coordinates of grid ``grid`` into ``z`` and return ``z``."""
        try:
            src = numpy.array(self.stub.getGridZ(bmi_pb2.GridRequest(grid_id=grid)).coordinates)
            numpy.copyto(src=src, dst=z)
            return z
        except grpc.RpcError as e:
            handle_error(e)

    def get_grid_shape(self, grid, shape):
        """Copy the shape of grid ``grid`` into ``shape`` and return ``shape``."""
        try:
            src = tuple(self.stub.getGridShape(bmi_pb2.GridRequest(grid_id=grid)).shape)
            numpy.copyto(src=src, dst=shape)
            return shape
        except grpc.RpcError as e:
            handle_error(e)

    def get_grid_spacing(self, grid, spacing):
        """Copy the spacing of grid ``grid`` into ``spacing`` and return ``spacing``."""
        try:
            src = tuple(self.stub.getGridSpacing(bmi_pb2.GridRequest(grid_id=grid)).spacing)
            numpy.copyto(src=src, dst=spacing)
            return spacing
        except grpc.RpcError as e:
            handle_error(e)

    def get_grid_origin(self, grid, origin):
        """Copy the origin of grid ``grid`` into ``origin`` and return ``origin``."""
        try:
            src = tuple(self.stub.getGridOrigin(bmi_pb2.GridRequest(grid_id=grid)).origin)
            numpy.copyto(src=src, dst=origin)
            return origin
        except grpc.RpcError as e:
            handle_error(e)

    def get_grid_node_count(self, grid: int) -> int:
        """Return the node count of grid ``grid``."""
        try:
            return self.stub.getGridNodeCount(bmi_pb2.GridRequest(grid_id=grid)).count
        except grpc.RpcError as e:
            handle_error(e)

    def get_grid_edge_count(self, grid: int) -> int:
        """Return the edge count of grid ``grid``."""
        try:
            return self.stub.getGridEdgeCount(bmi_pb2.GridRequest(grid_id=grid)).count
        except grpc.RpcError as e:
            handle_error(e)

    def get_grid_face_count(self, grid: int) -> int:
        """Return the face count of grid ``grid``."""
        try:
            return self.stub.getGridFaceCount(bmi_pb2.GridRequest(grid_id=grid)).count
        except grpc.RpcError as e:
            handle_error(e)

    def get_grid_edge_nodes(self, grid: int, edge_nodes: np.ndarray) -> np.ndarray:
        """Copy the edge-node connectivity of grid ``grid`` into ``edge_nodes`` and return it."""
        try:
            links = self.stub.getGridEdgeNodes(bmi_pb2.GridRequest(grid_id=grid)).edge_nodes
            numpy.copyto(src=links, dst=edge_nodes)
            return edge_nodes
        except grpc.RpcError as e:
            handle_error(e)

    def get_grid_face_nodes(self, grid: int, face_nodes: np.ndarray) -> np.ndarray:
        """Copy the face-node connectivity of grid ``grid`` into ``face_nodes`` and return it."""
        try:
            links = self.stub.getGridFaceNodes(bmi_pb2.GridRequest(grid_id=grid)).face_nodes
            numpy.copyto(src=links, dst=face_nodes)
            return face_nodes
        except grpc.RpcError as e:
            handle_error(e)

    def get_grid_face_edges(self, grid: int, face_edges: np.ndarray) -> np.ndarray:
        """Copy the face-edge connectivity of grid ``grid`` into ``face_edges`` and return it."""
        try:
            links = self.stub.getGridFaceEdges(bmi_pb2.GridRequest(grid_id=grid)).face_edges
            numpy.copyto(src=links, dst=face_edges)
            return face_edges
        except grpc.RpcError as e:
            handle_error(e)

    def get_grid_nodes_per_face(self, grid: int, nodes_per_face: np.ndarray) -> np.ndarray:
        """Copy the nodes-per-face counts of grid ``grid`` into ``nodes_per_face`` and return it."""
        try:
            links = self.stub.getGridNodesPerFace(bmi_pb2.GridRequest(grid_id=grid)).nodes_per_face
            numpy.copyto(src=links, dst=nodes_per_face)
            return nodes_per_face
        except grpc.RpcError as e:
            handle_error(e)

    @staticmethod
    def make_array(response):
        """Build a numpy array from whichever values field is set on ``response``.

        Checks ``values_int``, ``values_float`` and ``values_double`` in that order.

        Returns:
            numpy.ndarray, or None when none of the three fields is set.
        """
        if response.HasField("values_int"):
            return numpy.array(response.values_int.values)
        if response.HasField("values_float"):
            return numpy.array(response.values_float.values)
        if response.HasField("values_double"):
            return numpy.array(response.values_double.values)
        return None