Source code for pymoskito.tools

# -*- coding: utf-8 -*-
"""
Tools, functions and other funny things
"""
import copy
import logging
import os
import re

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PyQt5.QtGui import QColor

logger = logging.getLogger(__name__)

__all__ = ["rotation_matrix_xyz", "get_resource"]


def sort_lists(a, b):
    b = [x for (y, x) in sorted(zip(a, b))]
    a = sorted(a)
    return a, b


[docs]def get_resource(res_name, res_type="icons"): """ Build absolute path to specified resource within the package Args: res_name (str): name of the resource res_type (str): subdir Return: str: path to resource """ own_path = os.path.dirname(__file__) resource_path = os.path.abspath(os.path.join(own_path, "resources", res_type)) return os.path.join(resource_path, res_name)
def sort_tree(data_list, sort_key_path): """ Helper method for data sorting. Takes a list of simulation results and sorts them into a tree whose index is given by the sort_key_path. Args: data_list(list): List of simulation results sort_key_path(list): List of dictionary keys to sort for. Return: dict: sorted dictionary """ result = {} for elem in data_list: temp_element = copy.deepcopy(elem) sort_name = get_sub_value(temp_element, sort_key_path) if sort_name not in result: result.update({sort_name: {}}) while temp_element: val, keys = _remove_deepest(temp_element) if keys: _add_sub_value(result[sort_name], keys, val) return result def get_sub_value(source, key_path): sub_dict = source for key in key_path: sub_dict = sub_dict[key] return sub_dict def _remove_deepest(top_dict, keys=None): """ Iterates recursively over dict and removes deepest entry. Args: top_dict (dict): dictionary keys (list): select entries to remove Return: tuple: entry and path to entry """ if not keys: keys = [] for key in list(top_dict.keys()): val = top_dict[key] if isinstance(val, dict): if val: keys.append(key) return _remove_deepest(val, keys) else: del top_dict[key] continue else: del top_dict[key] keys.append(key) return val, keys return None, None def _add_sub_value(top_dict, keys, val): if len(keys) == 1: # we are here if keys[0] in top_dict: top_dict[keys[0]].append(val) else: top_dict.update({keys[0]: [val]}) return # keep iterating if keys[0] not in top_dict: top_dict.update({keys[0]: {}}) _add_sub_value(top_dict[keys[0]], keys[1:], val) return
[docs]def rotation_matrix_xyz(axis, angle, angle_dim): """ Calculate the rotation matrix for a rotation around a given axis with the angle :math:`\\varphi`. Args: axis (str): choose rotation axis "x", "y" or "z" angle (int or float): rotation angle :math:`\\varphi` angle_dim (str): choose "deg" for degree or "rad" for radiant Return: :obj:`numpy.ndarray`: rotation matrix """ assert angle_dim is "deg" or angle_dim is "rad" assert axis is "x" or axis is "y" or axis is "z" x = 0 y = 0 z = 0 if angle_dim is "deg": a = np.deg2rad(angle) else: a = angle if axis is "x": x = 1 y = 0 z = 0 if axis is "y": x = 0 y = 1 z = 0 if axis is "z": x = 0 y = 0 z = 1 s = np.sin(a) c = np.cos(a) rotation_matrix = np.array([[c + x ** 2 * (1 - c), x * y * (1 - c) - z * s, x * z * (1 - c) + y * s], [y * x * (1 - c) + z * s, c + y ** 2 * (1 - c), y * z * (1 - c) - x * s], [z * x * (1 - c) - y * s, z * y * (1 - c) + x * s, c + z ** 2 * (1 - c)]]) return rotation_matrix
class PlainTextLogger(logging.Handler): """ Logging handler hat formats log data for line display """ def __init__(self, settings, level=logging.NOTSET): logging.Handler.__init__(self, level) self.name = "PlainTextLogger" self.settings = settings formatter = logging.Formatter( fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%H:%M:%S") self.setFormatter(formatter) log_filter = PostFilter(invert=True) self.addFilter(log_filter) self.cb = None def set_target_cb(self, cb): self.cb = cb def emit(self, record): msg = self.format(record) if self.cb: clr = QColor(self.settings.value("log_colors/" + record.levelname, "#000000")) self.cb.setTextColor(clr) self.cb.append(msg) else: logging.getLogger().error("No callback configured!") class PostFilter(logging.Filter): """ Filter to sort out all not PostProcessing related log information """ def __init__(self, invert=False): logging.Filter.__init__(self) self._invert = invert self.exp = re.compile(r"Post|Meta|Process") def filter(self, record): m = self.exp.match(record.name) if self._invert: return not bool(m) else: return bool(m) def swap_cols(arr, frm, to): """ Swap the column `frm` from a given index `to` the given index. """ arr[:, [frm, to]] = arr[:, [to, frm]] return arr def swap_rows(arr, frm, to): """ Swap the rows `frm` from a given index `to` the given index. """ if len(arr.shape) == 1: arr[[frm, to]] = arr[[to, frm]] elif len(arr.shape) == 2: arr[[frm, to], :] = arr[[to, frm], :] return arr class LengthList(object): def __init__(self, maxLength): self.maxLength = maxLength self.ls = [] def push(self, st): if len(self.ls) == self.maxLength: self.ls.pop(0) self.ls.append(st) def get_list(self): return self.ls def __len__(self): return len(self.ls) def __getitem__(self, key): return self.ls[key] def get_figure_size(scale): """ calculate optimal figure size with the golden ratio :param scale: :return: """ # TODO: Get this from LaTeX using \the\textwidth fig_width_pt = 448.13095 inches_per_pt = 1.0 / 72.27 # Convert pt to inch (stupid imperial system) golden_ratio = (np.sqrt(5.0) - 1.0) / 2.0 # Aesthetic ratio fig_width = fig_width_pt * inches_per_pt * scale # width in inches fig_height = fig_width * golden_ratio # height in inches fig_size = [fig_width, fig_height] return fig_size class Exporter: def __init__(self, **kwargs): data_points = kwargs.get('data_points', None) if data_points is None: raise Exception("Given data points are None!") # build pandas data frame self.df = pd.DataFrame.from_dict(data_points) if 'time' in self.df.columns: self.df.set_index('time', inplace=True) def export_png(self, file_name): fig = plt.figure(figsize=(10, 6)) gs = gridspec.GridSpec(1, 1, hspace=0.1) axes = plt.Subplot(fig, gs[0]) for col in self.df.columns: self.df[col].plot(ax=axes, label=col) axes.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=4, ncol=4, mode="expand", borderaxespad=0., framealpha=0.5) axes.grid(True) if self.df.index.name == 'time': axes.set_xlabel(r"Time (s)") fig.add_subplot(axes) fig.savefig(file_name, dpi=300) def export_csv(self, file_name, sep=','): self.df.to_csv(file_name, sep=sep)