"""
Kernel core where MDAnalysis code runs
"""
import asyncio
from dataclasses import asdict
import comm
import MDAnalysis as mda
from mdadash.backend.widgets.base import WidgetManager
[docs]
class CommHandler:
"""Comm Handler
This class is responsible for handling all the communication to and
from this kernel. This is the class that interfaces with the
:class:`~mdadash.backend.kernel.manager.KernelManager` on the server side.
"""
def __init__(self):
self._comm = None
self._handlers = {}
comm.get_comm_manager().register_target(
"kernel_comm_handler", self._handle_comm_open
)
[docs]
def register_handler(self, msg_type: str, handler_func: callable) -> None:
"""Register a handler function for a message type
Parameters
----------
msg_type: str
A message type to identify the handler
handler_func: callable
The handler function to invoke when a message with this message
type is received
"""
self._handlers[msg_type] = handler_func
[docs]
def send(self, msg: dict) -> None:
"""Send a message (response) back on the comm
Parameters
----------
msg: dict
A generic message dictionary
"""
if self._comm is not None:
self._comm.send(msg)
else:
raise ValueError("comm is not open yet") # pragma: no cover
def _handle_comm_open(self, _comm: comm.base_comm.BaseComm, _msg):
"""Internal: Handler when the comm is opened (comm_open)"""
self._comm = _comm
# set the handler for comm messages (comm_msg)
self._comm.on_msg(self._handle_msg)
def _handle_msg(self, msg):
"""Internal: Dispatch the message to the registered handler"""
content_data = msg["content"]["data"]
msg_type = content_data["msg_type"]
if msg_type in self._handlers:
self._handlers[msg_type](content_data["data"])
else:
error_msg = f"{msg_type} does not have a registered handler"
self.send({"status": "error", "message": error_msg})
raise ValueError(error_msg)
[docs]
class UniverseManager:
"""Universe Manager
This class is responsible for managing all MDAnalysis universes. It has
handlers to interact with the MD simulation. These handlers are invoked by
comm messages sent from the server.
This also provides an iterable and indexable access to the individual
universes.
"""
def __init__(self, _wm: WidgetManager, _comm_handler: CommHandler):
self._universes = []
self._universes_step = []
self._iter_loop_task = None
self._iter_loop_running = False
self._iter_loop_resumed = asyncio.Event()
self._iter_loop_resumed.clear()
self._wm = _wm
self._comm_handler = _comm_handler
self._connected = False
self._running = False
def __iter__(self) -> iter:
"""To support iteration"""
return iter(self._universes)
def __len__(self) -> int:
"""Number of universes"""
return len(self._universes)
def __getitem__(self, index: int):
"""Return universe based on index"""
# numeric index based array access
_max = len(self._universes)
if 0 <= index < _max:
return self._universes[index]
raise ValueError(f"Invalid index {index} of {_max} items")
[docs]
def init_n_universes(self, n: int) -> None:
"""Initialize array for n universes
Parameters
----------
n: int
Number of universes to initialize
"""
self._universes = [None] * n
self._universes_step = [1] * n
[docs]
def connect_to_simulations(self, universe_configs: list[dict]) -> None:
"""Connect to MD simulations
Parameters
----------
universe_configs: list[dict]
A list of configurations for universe(s) creation.
Each dict has universe related config like topology, trajectory,
imdclient params, user-defined kwargs etc
"""
try:
for uid, config in enumerate(universe_configs):
kwargs = {}
topology = config.get("topology")
trajectory = config.get("trajectory")
for key, value in config.items():
if key in (
"topology",
"trajectory",
"step",
"total_steps",
"kwargs",
):
continue
if value is not None:
kwargs[key] = value
for name, value in config["kwargs"]:
if name.strip():
kwargs[name] = value
# create universe
u = mda.Universe(
topology,
trajectory,
**kwargs,
)
if uid == 0:
self._send_tsdata(u)
self._send_sessioninfo(u)
self._universes[uid] = u
# set universe for the widget instances with this uid
self._wm._set_universe(uid, u)
# save universe step config
self._universes_step = [uc.get("step", 1) for uc in universe_configs]
# start iter loop for trajectories
self._iter_loop_resumed.clear()
self._iter_loop_running = True
self._iter_loop_task = asyncio.create_task(self._iter_loop())
self._connected = True
self._comm_handler.send({"status": "ok"})
except Exception as e: # pylint: disable=broad-exception-caught
self._comm_handler.send({"status": "error", "message": str(e)})
def _send_tsdata(self, u: mda.Universe):
"""Internal: Send timestep data out"""
self._comm_handler.send(
{
"tsinfo": {
"frame": u.trajectory.frame,
"tsdata": u.trajectory.ts.data,
}
}
)
def _send_sessioninfo(self, u: mda.Universe):
self._comm_handler.send(
{"sessioninfo": asdict(u.trajectory._imdclient.get_imdsessioninfo())}
)
[docs]
def disconnect_from_simulations(self, _data: dict) -> None:
"""Disconnect from MD simulations"""
self._iter_loop_running = False
self._iter_loop_task.cancel()
for u in self._universes:
u.trajectory.close()
self._connected = False
self._wm._invoke_lifecycle_method("post_disconnect")
self._comm_handler.send({"status": "ok"})
[docs]
def pause_simulations(self, _data: dict) -> None:
"""Pause MD simulations"""
self._iter_loop_resumed.clear()
self._running = False
self._wm._invoke_lifecycle_method("post_pause")
self._comm_handler.send({"status": "ok"})
[docs]
def resume_simulations(self, _data: dict) -> None:
"""Resume MD simulations"""
self._wm._invoke_lifecycle_method("pre_resume")
self._iter_loop_resumed.set()
self._running = True
self._comm_handler.send({"status": "ok"})
def _trajectory_next(self, u, step):
"""Internal: Iterate trajectory by `step` frame(s)"""
try:
# TODO: Modify this when `imdclient` and `IMDReader` are updated
# with support for 'Transmission rate' packet
# Burn the timesteps until we reach the desired step
# Don't use next() to avoid unnecessary transformations
while (u.trajectory._frame + 1) % step != 0:
u.trajectory._read_next_timestep()
except (EOFError, IOError): # pragma: no cover
raise StopIteration from None
return u.trajectory.next()
async def _iter_loop(self):
"""Internal: Iteration loop for trajectories"""
try:
while self._iter_loop_running:
await self._iter_loop_resumed.wait()
for uid, u in enumerate(self._universes):
try:
# iterate in thread to not block on a network call here
await asyncio.to_thread(
self._trajectory_next,
u,
self._universes_step[uid],
)
if uid == 0:
self._send_tsdata(u)
# run widgets for this timestep
self._wm.run_widgets(uid)
# pylint: disable=broad-exception-caught
except Exception as e: # pragma: no cover
print(e)
await asyncio.sleep(0)
except asyncio.CancelledError:
pass
[docs]
def init_n_universes(data: dict) -> None:
"""Initialize `n` universes in :class:`UniverseManager`"""
um.init_n_universes(data.get("n"))
wm = WidgetManager()
comm_handler = CommHandler()
um = UniverseManager(wm, comm_handler)
widgets_comm = WidgetsComm(um, wm, comm_handler)
# register handlers
comm_handler.register_handler("init_n_universes", init_n_universes)
comm_handler.register_handler("connect_to_simulations", um.connect_to_simulations)
comm_handler.register_handler(
"disconnect_from_simulations", um.disconnect_from_simulations
)
comm_handler.register_handler("pause_simulations", um.pause_simulations)
comm_handler.register_handler("resume_simulations", um.resume_simulations)
comm_handler.register_handler(
"widgets:get_available_widgets", widgets_comm.get_available_widgets
)
comm_handler.register_handler("widgets:add_instance", widgets_comm.add_instance)
comm_handler.register_handler(
"widgets:duplicate_instance", widgets_comm.duplicate_instance
)
comm_handler.register_handler("widgets:remove_instance", widgets_comm.remove_instance)
comm_handler.register_handler("widget:get_inputs", widgets_comm.get_inputs)
comm_handler.register_handler("widget:set_input", widgets_comm.set_input)