Skip to content

ner_mp

ner_mp

Classes:

Name Description
NERParallel

NERParallel(pipeline_factory, *, num_proc=None, ner_max_runtime_seconds=None)

Methods:

Name Description
predict

Runs NER prediction on the input data.

Source code in src/nemo_safe_synthesizer/pii_replacer/ner/ner_mp.py
def __init__(
    self,
    pipeline_factory: Callable[[], pipeline.Pipeline],
    *,
    num_proc: Optional[int] = None,
    ner_max_runtime_seconds: Optional[int] = None,
):
    """Set up the parallel NER runner and spin up its worker pool.

    Args:
        pipeline_factory: Zero-argument callable that builds a fresh NER
            pipeline inside each worker process.
        num_proc: Number of worker processes; defaults to the machine's
            CPU count when not given.
        ner_max_runtime_seconds: Optional wall-clock cap on a prediction
            run; ``None`` means no limit.
    """
    # Fall back to one worker per CPU when the caller did not choose.
    self.num_proc = mp.cpu_count() if num_proc is None else num_proc
    self.pipeline_factory = pipeline_factory
    self._initialize_pool()
    self.ner_max_runtime_seconds = ner_max_runtime_seconds

predict(in_data, **kwargs)

Runs NER prediction on the input data.

This method will return one of 3 things: - If `timings_only=True` kwarg is passed in -> a single `Timings` object. E.g. `predict(data, timings_only=True)` - if `in_data` is a str, dict or JSONRecord -> a list with a single result. - if `in_data` is a list -> a list of results, of the same length as `in_data`. In a case where an item from `in_data` had no entities, the result will be an empty list. E.g. `result = [[], [NERPrediction(), NERPrediction()], [], []]`. In this example, there were 4 records in `in_data`, 3 of them had no entities, and the second one had 2.

Source code in src/nemo_safe_synthesizer/pii_replacer/ner/ner_mp.py
def predict(self, in_data: InData, **kwargs) -> Timings | PipelineResult:
    """
    Runs NER prediction on the input data.

    This method will return one of 3 things:
    - If `timings_only=True` kwarg is passed in -> a single `Timings` object.
      E.g. `predict(data, timings_only=True)`
    - if `in_data` is a str, dict or JSONRecord -> a list with a single result.
    - if `in_data` is a list -> a list of results, of the same length as `in_data`.
      In a case where an item from `in_data` had no entities, the result will be an empty list.
      E.g. `result = [[], [NERPrediction(), NERPrediction()], [], []]`.
      In this example, there were 4 records in `in_data`, 3 of them had no entities,
      and the second one had 2.

    Raises:
        ValueError: If a `pipeline` kwarg is passed; workers build their own
            pipelines in MP mode, so one cannot be injected here.
    """
    if "pipeline" in kwargs:
        raise ValueError("A pipeline object cannot be passed to predict() in MP mode!")

    logger.info(f"Starting NER prediction, using {self.num_proc} workers.")

    timings_only = kwargs.get("timings_only", False)

    list_input = isinstance(in_data, list)
    if list_input:
        if len(in_data) > 0 and isinstance(in_data[0], JSONRecord):
            # Send pure dicts to NER, as it's much faster to pickle/unpickle
            # pure dicts than JSONRecords (and that's what multiprocessing is doing)
            in_data = [record.original for record in in_data]

        record_chunks = iter_record_chunks(iter(in_data), CHUNK_SIZE)
    else:
        # We need to handle the case where a non-list is sent in
        # as our prediction object and we need to make sure the payload
        # we send to the worker is the raw input, not a list
        record_chunks = [in_data]

    result_data_tracker = _ResultData()

    total_chunks = 0
    submitted_chunks = []
    for i, p in enumerate(record_chunks):
        # NOTE(review): acquired once per chunk and (presumably) released by
        # `handle_results`, throttling how far submission can run ahead of
        # completion — confirm against _ResultData.
        result_data_tracker.lock.acquire()
        logger.info(f"Submitting chunk number {i + 1} to NER workers.")

        data = list(p) if list_input else p
        payload = _ProcPayload(seq=i, in_data=data)
        submitted_chunks.append(_ChunkInfo(seq=i, chunk_size=len(data) if list_input else 1))

        self.pool.submit(_predict, payload, **kwargs).add_done_callback(result_data_tracker.handle_results)

        logger.info(f"Chunk number {i + 1} has been submitted successfully.")
        total_chunks += 1

    logger.info("All chunks submitted, waiting for work to complete...")

    # When we get here, all chunks have been submitted, and we just
    # want to wait for them to finish up before moving on
    start = time.time()
    while len(result_data_tracker.results) < total_chunks:
        time_passed_seconds = time.time() - start
        if self.ner_max_runtime_seconds is not None and self.ner_max_runtime_seconds < time_passed_seconds:
            logger.error(
                "NER took more than %d seconds to finish, ending early",
                self.ner_max_runtime_seconds,
            )
            self.pool.shutdown(wait=False, kill_workers=True)
            self._initialize_pool()  # in case it's reused
            break
        # Sleep briefly between polls so this wait doesn't busy-spin a CPU
        # core at 100% while workers are running.
        time.sleep(0.1)

    result_data_tracker.progress_callback.flush()
    result_data = result_data_tracker.results

    logger.info("NER prediction completed.")

    if timings_only:
        # NOTE(review): if the run ended early above, `result_data` can hold
        # fewer entries than `total_chunks` (or none, raising StopIteration
        # here) — confirm whether timings_only runs should tolerate timeouts.
        result_timings = iter(result_data)
        timings = next(result_timings).out_data
        for other_timings in result_timings:
            timings.join(other_timings.out_data)
        timings.set_avg(num_cpu=self.num_proc)
        return timings

    completed_chunks: dict[int, _ProcPayload] = {r.seq: r for r in result_data}
    all_chunks: dict[int, _ChunkInfo] = {ci.seq: ci for ci in submitted_chunks}

    # Restore the predictions to the order they would
    # have been if predicting on a single worker
    preds = []
    for seq in sorted(list(all_chunks.keys())):
        if (payload := completed_chunks.get(seq, None)) is not None:
            preds.extend(payload.out_data)

        else:
            # Add an empty spot for each record in the chunk that wasn't completed
            logger.warning(f"NER for chunk number {seq + 1} did not complete.")
            preds.extend([[]] * all_chunks[seq].chunk_size)

    return preds