Source code for napari_toska.ToskaSkeleton

from napari.layers import Labels
import napari_toska as nts
import networkx as nx
import numpy as np
import pandas as pd


class ToskaSkeleton(Labels):
    """
    A class to represent a skeleton image in napari.

    Parameters:
    -----------
    labels_data : napari.types.LabelsData
        The data to be processed - has to be a napari Label image.
    neighborhood : str
        The neighborhood connectivity of the skeleton. Can be "n4", "n6",
        "n8", "n18", or "n26". "n4" and "n8" apply only to 2D images.
        "n6", "n18", and "n26" apply to 3D images.
    **kwargs
        Additional keyword arguments to pass to the Labels layer.

    Attributes:
    -----------
    neighborhood : str
        The neighborhood connectivity of the skeleton.
    graph : nx.Graph
        A networkx graph representing the skeleton. The graph is built from
        the skeleton data and is stored in `layer.metadata` for easy access.
    features : pd.DataFrame
        A DataFrame containing features of the individual nodes and edges of
        the skeleton graph but also of the skeleton as a whole.

    Methods:
    --------
    analyze()
        Analyze the skeleton data and build a networkx graph from it.
    create_feature_map(feature: str) -> napari.types.ImageData
        Create a feature map from the skeleton data. Any numeric column in
        the features DataFrame can be used as a feature.

    Examples:
    ---------
    >>> import napari
    >>> import napari_toska as nts
    >>> from skimage.data import binary_blobs
    >>> from skimage.measure import label
    >>>
    >>> # create a binary image
    >>> labels = label(binary_blobs(rng=0))
    >>>
    >>> # Build the Skeleton object
    >>> Skeleton = nts.ToskaSkeleton(labels, neighborhood='n8')
    >>> Skeleton.analyze()
    """

    def __init__(self, labels_data: "napari.types.LabelsData",
                 neighborhood: str, **kwargs):
        super().__init__(np.asarray(labels_data), **kwargs)
        self._neighborhood = neighborhood
        # stored in ``self.metadata['graph']`` through the property setter
        self.graph = nx.Graph()

    def analyze(self):
        """
        Analyze the skeleton data and build a networkx graph from it.

        Runs, in order: skeleton parsing, graph construction, branch length
        measurement, spine detection and the per-skeleton summary. The layer
        data is replaced by the parsed skeleton labels in the process.

        Returns:
        --------
        None
        """
        self._parse_skeleton()
        self._build_nx_graph()
        self._measure_branch_length()
        self._detect_spines()
        self._graph_summary()

        # drop index column
        if 'index' in self.features.columns:
            self.features = self.features.drop(columns='index')

    @property
    def neighborhood(self):
        """The neighborhood connectivity string passed at construction."""
        return self._neighborhood

    @property
    def graph(self) -> nx.Graph:
        """The skeleton graph kept in ``self.metadata['graph']``."""
        return self.metadata.get('graph')

    @graph.setter
    def graph(self, graph: nx.Graph):
        self.metadata['graph'] = graph

    def create_feature_map(self, feature: str) -> "napari.types.ImageData":
        """
        Paint every labelled object with its value of a feature column.

        Parameters:
        -----------
        feature : str
            Name of a numeric column in `self.features`.

        Returns:
        --------
        napari.types.ImageData
            Array with the shape of the layer data. Pixels whose label has
            no row in the feature table are 0. The array is of integer type
            unless the column holds non-integral values (e.g. branch
            lengths), in which case it is float so nothing is truncated.
        """
        data = np.asarray(self.data)
        features = self.features
        if len(features) == 0:
            return np.zeros(data.shape, dtype=int)

        labels = np.asarray(features['label']).astype(int)
        values = np.asarray(features[feature])

        # only switch to float when an integer map would lose information
        is_fractional = (
            np.issubdtype(values.dtype, np.floating)
            and not np.all(np.mod(values, 1) == 0)
        )
        dtype = float if is_fractional else int

        # a single lookup table replaces one full-image comparison per label
        largest_label = int(data.max()) if data.size else 0
        lookup = np.zeros(max(int(labels.max()), largest_label) + 1,
                          dtype=dtype)
        lookup[labels] = values
        return lookup[data]

    def _parse_skeleton(self) -> None:
        """
        Skeletonize the layer data and split it into uniquely labelled parts.

        Replaces `self.data` with an image in which every end point, branch
        and branch point carries its own label, and replaces `self.features`
        with one row per label holding the categorical columns `label`,
        `object_type` (1 = end point, 2 = branch, 3 = branch point) and
        `skeleton_id` (label of the skeleton the object belongs to).
        """
        from skimage import measure

        labelled_skeletons = nts.generate_labeled_skeletonization(
            self.data).astype(int)
        parsed_skeletons = nts.parse_all_skeletons(
            labelled_skeletons, neighborhood=self._neighborhood)

        end_points = measure.label(parsed_skeletons == 1)
        branches = measure.label(parsed_skeletons == 2)
        branch_points = measure.label(parsed_skeletons == 3)

        # create a labels image with unique labels for each object:
        # end points first, then branches, then branch points
        n_endpoints = end_points.max()
        n_branches = branches.max()
        branches[branches != 0] += n_endpoints
        branch_points[branch_points != 0] += n_endpoints + n_branches
        unique_labels = (end_points + branches + branch_points).astype(int)
        n_objects = int(unique_labels.max()) if unique_labels.size else 0

        # add object types to features (end point/branch/branch point)
        object_types = np.ones(n_objects, dtype=int)
        object_types[:n_endpoints] = 1  # end points
        object_types[n_endpoints:n_endpoints + n_branches] = 2  # branches
        object_types[n_endpoints + n_branches:] = 3  # branch points

        # skeleton id of each object = skeleton label at the object's first
        # pixel; looked up in one pass instead of one image mask per label
        present, first_pixel = np.unique(unique_labels.ravel(),
                                         return_index=True)
        skeleton_ids = np.zeros(n_objects + 1, dtype=int)
        skeleton_ids[present] = labelled_skeletons.ravel()[first_pixel]

        self.data = unique_labels
        self.features = pd.DataFrame({
            'label': pd.Series(
                np.arange(1, n_objects + 1).astype(int)).astype('category'),
            'object_type': pd.Series(object_types).astype('category'),
            'skeleton_id': pd.Series(skeleton_ids[1:]).astype('category'),
        })

    def _build_nx_graph(self):
        """
        Build the skeleton graph from the parsed labels.

        End points and branch points become nodes; every branch that touches
        exactly two other objects becomes an edge between them, carrying the
        branch label in its `label` attribute. Branches that do not touch
        exactly two objects are re-typed as end points in the feature table
        and are not added to the graph.
        """
        import tqdm

        # add all branch points and end points to Graph
        is_branch = self.features['object_type'] == 2
        for _, row in self.features[~is_branch].iterrows():
            self.graph.add_node(row['label'],
                                object_type=row['object_type'],
                                label=row['label'])

        df_branches = self.features[is_branch]

        # iterate over branches and find neighboring branch points
        # or end points
        for i, row in tqdm.tqdm(df_branches.iterrows(),
                                desc='Building Graph',
                                total=len(df_branches)):
            connecting_labels = self._find_neighboring_labels(row['label'])

            # a branch should connect to exactly two other objects
            if len(connecting_labels) != 2:
                object_type = self.features['object_type']
                # a skeleton without end points (e.g. a closed ring) has no
                # category 1 yet; assigning it would raise
                if (isinstance(object_type.dtype, pd.CategoricalDtype)
                        and 1 not in object_type.cat.categories):
                    self.features['object_type'] = (
                        object_type.cat.add_categories([1]))
                # `i` is an index label coming from iterrows(), so address
                # the row by label rather than by position
                self.features.loc[i, 'object_type'] = 1
                print('detected malformatted label: {}'.format(int(row['label'])), 'changed type from 2 (branch) -> 1 (end point)')
                continue

            self.graph.add_edge(connecting_labels[0], connecting_labels[1],
                                label=row['label'])

        # check if there are any isolated nodes
        isolated_nodes = np.array(
            [node for node in self.graph.nodes
             if self.graph.degree(node) == 0],
            dtype=int)

        if len(isolated_nodes) > 0:
            print('Found isolated nodes: ', isolated_nodes)

            # check for neighborhood around isolated nodes
            for node in isolated_nodes:
                connecting_labels = self._find_neighboring_labels(node)
                if len(connecting_labels) == 2:
                    # NOTE(review): this links the two neighbours of the
                    # isolated node, not the node itself, and the edge has no
                    # branch label - confirm this is intended
                    self.graph.add_edge(connecting_labels[0],
                                        connecting_labels[1],
                                        label=None)
                else:
                    print('Could not connect isolated node: ', node)

    def _measure_branch_length(self):
        """
        Measure every branch and store the result as `branch_length`.

        Adds a `branch_length` column to `self.features` (0 for objects that
        are not branches) and a `branch_length` attribute to every graph
        edge. Edges that carry no branch label get a length of 0.
        """
        # create a label image with only branch labels, mute all the others
        LUT = np.asarray([0] + list(self.features['label']))
        object_type = np.asarray([0] + list(self.features['object_type']))

        # set entries to zero where the object type is 1 or 3
        LUT[object_type == 1] = 0
        LUT[object_type == 3] = 0
        branch_label_image = LUT[self.data]

        # measure the length of each branch
        branch_lengths = nts.calculate_branch_lengths(branch_label_image)

        # merge into features
        self.features = pd.merge(self.features, branch_lengths,
                                 on='label', how='left')
        self.features['branch_length'] = (
            self.features['branch_length'].fillna(0))

        # update graph edge weights with branch lengths; a dictionary lookup
        # tolerates edges whose label is None or unknown
        length_by_label = dict(zip(self.features['label'],
                                   self.features['branch_length']))
        for _, _, data in self.graph.edges(data=True):
            data['branch_length'] = length_by_label.get(data.get('label'), 0)

    def _find_neighboring_labels(self, query_label: int):
        """
        Find the labels of all objects that touch a given object.

        Parameters:
        -----------
        query_label : int
            Label of the object in `self.data` whose neighbors are wanted.

        Returns:
        --------
        np.ndarray
            Sorted, unique, non-zero labels found within one pixel (full
            connectivity, diagonals included) of the queried object. Empty
            if the label does not occur in the data.
        """
        from skimage import morphology

        branch = self.data == query_label
        coordinates = np.argwhere(branch)
        if len(coordinates) == 0:
            return np.array([], dtype=int)

        # bounding box around the object, padded by one pixel and clipped to
        # the image extent
        min_coords = np.maximum(coordinates.min(axis=0) - 1, 0)
        max_coords = np.minimum(coordinates.max(axis=0) + 2, branch.shape)
        slices = tuple(
            slice(lower, upper)
            for lower, upper in zip(min_coords, max_coords)
        )

        # Crop the image data using the slices, then expand the object so
        # that it overlaps with the neighboring objects. The footprint
        # follows the image dimensionality so that 3D skeletons work too.
        cropped_branch = morphology.binary_dilation(
            branch[slices], footprint=np.ones((3,) * branch.ndim))
        cropped_data = self.data[slices]

        # keep only the rim that the dilation added around the object
        touching_labels = np.logical_xor(
            cropped_branch, cropped_data == query_label) * cropped_data
        connecting_labels = np.unique(touching_labels)
        return connecting_labels[connecting_labels != 0]

    def _detect_spines(self):
        """
        Detect spines in the skeleton graph.

        The spine is defined as the longest path between two degree 1
        nodes in the skeleton graph. For every connected component, the
        shortest path (fewest edges) between every pair of degree 1 nodes is
        measured by summing the `branch_length` of its edges; the branches on
        the longest of these paths get `spine = 1` in the feature table.
        Only branch labels are marked, not the nodes on the path.

        Returns:
        --------
        None
        """
        import itertools

        import tqdm

        graph = self.graph
        self.features['spine'] = 0

        # split in connected components
        connected_components = list(nx.connected_components(graph))

        for component in tqdm.tqdm(connected_components,
                                   total=len(connected_components),
                                   desc='Finding spines'):
            # find degree 1 nodes in the connected component
            spine_nodes = [int(node) for node in component
                           if graph.degree(node) == 1]
            if len(spine_nodes) < 2:
                continue

            # measure distances between every pair of spine nodes
            spine_distances = []
            for source, target in itertools.combinations(spine_nodes, 2):
                path = [int(node) for node in nx.shortest_path(
                    graph, source=source, target=target)]

                # measure distance as per the edge attribute 'branch_length'
                distance = 0
                edge_labels = []
                for start, stop in zip(path[:-1], path[1:]):
                    edge = graph[start][stop]
                    distance += edge.get('branch_length', 0)
                    # edges without a branch label have nothing to mark
                    if edge.get('label') is not None:
                        edge_labels.append(int(edge['label']))

                spine_distances.append({
                    'spine_1': source,
                    'spine_2': target,
                    'distance': distance,
                    'edge_labels': edge_labels
                })

            # sort by distance; the last row is the longest shortest path
            spine_distances = pd.DataFrame(spine_distances)
            spine_distances = spine_distances.sort_values(by='distance')
            longest_shortest_path = spine_distances.iloc[-1]

            # set spine attribute to 1 for the branches on that path
            on_spine = self.features['label'].isin(
                longest_shortest_path['edge_labels'])
            self.features.loc[on_spine, 'spine'] = 1

        # make spine column categorical
        self.features['spine'] = self.features['spine'].astype('category')

    def _graph_summary(self):
        """
        Calculate summary features of the skeleton graph.

        The following features are calculated per connected component and
        written to every node and branch that belongs to it (objects that
        are not part of the graph keep 0):

        - n_branches: The number of branches in the skeleton.
        - n_endpoints: The number of endpoints in the skeleton.
        - n_branch_points: The number of branch points in the skeleton.
        - n_nodes: The number of nodes in the skeleton, which is the sum of
          branch points and endpoints.
        - n_cycle_basis: The number of cycle basis in the skeleton. A cycle
          basis is a set of cycles that can be used to generate all other
          cycles in a graph.
        - n_possible_undirected_cycles: The number of possible undirected
          cycles in the skeleton. An undirected cycle is a path that starts
          and ends at the same node and visits each node only once (except
          for the starting node).

        Returns:
        --------
        None
        """
        import tqdm

        # get subgraphs
        connected_components = [
            self.graph.subgraph(c).copy()
            for c in nx.connected_components(self.graph)
        ]

        # initialise every summary column so none is left with NaN
        for column in ('n_branches', 'n_endpoints', 'n_branch_points',
                       'n_nodes', 'n_cycle_basis',
                       'n_possible_undirected_cycles'):
            self.features[column] = 0

        for component in tqdm.tqdm(connected_components,
                                   desc='Calculating summary features',
                                   total=len(connected_components)):
            n_branches = len(component.edges)
            n_endpoints = len([node for node in component.nodes
                               if component.degree(node) == 1])
            n_branchpoints = len([node for node in component.nodes
                                  if component.degree(node) > 2])

            # cycle basis
            n_cycle_basis = len(nx.cycle_basis(component))

            # every undirected cycle shows up once per direction in the
            # directed copy; two-node "cycles" are just an edge walked back
            # and forth and are discarded
            directed_graph = component.to_directed()
            possible_directed_cycles = list(nx.simple_cycles(directed_graph))
            n_possible_undirected_cycles = len(
                [x for x in possible_directed_cycles if len(x) > 2]) // 2

            # get label property of all edges and nodes; nodes created
            # implicitly by add_edge have no 'label' attribute, their node
            # id is the label
            edge_labels = [component[u][v]['label']
                           for u, v in component.edges]
            node_labels = [component.nodes[node].get('label', node)
                           for node in component.nodes]
            labels = edge_labels + node_labels

            # add to features
            in_component = self.features['label'].isin(labels)
            self.features.loc[in_component, 'n_branches'] = n_branches
            self.features.loc[in_component, 'n_endpoints'] = n_endpoints
            self.features.loc[in_component, 'n_branch_points'] = n_branchpoints
            self.features.loc[in_component, 'n_cycle_basis'] = n_cycle_basis
            self.features.loc[in_component, 'n_possible_undirected_cycles'] = n_possible_undirected_cycles
            self.features.loc[in_component, 'n_nodes'] = n_branchpoints + n_endpoints


def analyze_skeleton_comprehensive(
        labels_input: Labels,
        neighborhood: str = 'n8',
        viewer: 'napari.Viewer' = None) -> Labels:
    """
    Run a complete Skeleton analysis using napari-toska

    Parameters:
    -----------
    labels_input : Labels
        The labels layer to analyze. Its `data` is handed to
        `ToskaSkeleton` and its `name` is used to name the result.
    neighborhood : str
        The neighborhood connectivity forwarded to `ToskaSkeleton`.
        Defaults to 'n8'.
    viewer : napari.Viewer, optional
        If given, a table widget for the resulting layer is docked into
        this viewer's window.

    Returns:
    --------
    Labels
        The analyzed `ToskaSkeleton` layer, named
        'Skeleton of <name of labels_input>'.
    """
    from napari_skimage_regionprops import TableWidget

    skeleton_layer = ToskaSkeleton(
        labels_data=labels_input.data,
        neighborhood=neighborhood,
    )
    skeleton_layer.analyze()
    skeleton_layer.name = f'Skeleton of {labels_input.name}'

    # without a viewer there is nowhere to dock the feature table
    if viewer is None:
        return skeleton_layer

    viewer.window.add_dock_widget(
        TableWidget(viewer=viewer, layer=skeleton_layer))
    return skeleton_layer