Runs NER prediction on the input data.
This method will return one of 3 things:
- If timings_only=True kwarg is passed in -> a single Timings object.
E.g. predict(data, timings_only=True)
- if in_data is a str, dict or JSONRecord -> a list with a single result.
- if in_data is a list -> a list of results, of the same length as in_data.
In a case where an item from in_data had no entities in it, the result will be an empty list.
E.g. result = [[], [NERPrediction(), NERPrediction()], [], []].
In this example, there were 4 records in in_data, 3 of them had no entities,
and the second one had 2.
Source code in src/nemo_safe_synthesizer/pii_replacer/ner/ner_mp.py
def predict(self, in_data: InData, **kwargs) -> Timings | PipelineResult:
    """
    Runs NER prediction on the input data.

    This method will return one of 3 things:
    - If `timings_only=True` kwarg is passed in -> a single `Timings` object.
      E.g. `predict(data, timings_only=True)`
    - if `in_data` is a str, dict or JSONRecord -> a list with a single result.
    - if `in_data` is a list -> a list of results, of the same length as `in_data`.
      In a case where an item from `in_data` had no entities in it, the result will be an empty list.
      E.g. `result = [[], [NERPrediction(), NERPrediction()], [], []]`.
      In this example, there were 4 records in `in_data`, 3 of them had no entities,
      and the second one had 2.
    """
    if "pipeline" in kwargs:
        raise ValueError("A pipeline object cannot be passed to predict() in MP mode!")
    logger.info(f"Starting NER prediction, using {self.num_proc} workers.")
    timings_only = kwargs.get("timings_only", False)
    list_input = isinstance(in_data, list)
    if list_input:
        if len(in_data) > 0 and isinstance(in_data[0], JSONRecord):
            # Send pure dicts to NER, as it's much faster to pickle/unpickle
            # pure dicts than JSONRecords (and that's what multiprocessing is doing)
            in_data = [record.original for record in in_data]
        # Lazily split the records into fixed-size chunks, one chunk per worker task.
        record_chunks = iter_record_chunks(iter(in_data), CHUNK_SIZE)
    else:
        # We need to handle the case where a non-list is sent in
        # as our prediction object and we need to make sure the payload
        # we send to the worker is the raw input, not a list
        record_chunks = [in_data]
    result_data_tracker = _ResultData()
    total_chunks = 0
    submitted_chunks: list[_ChunkInfo] = []
    for i, p in enumerate(record_chunks):
        # NOTE(review): the lock is acquired here but never released in this
        # method — presumably the tracker's done-callback releases it, so this
        # throttles how many chunks are in flight at once. Confirm against
        # _ResultData.handle_results.
        result_data_tracker.lock.acquire()
        logger.info(f"Submitting chunk number {i + 1} to NER workers.")
        data = list(p) if list_input else p
        # seq records the original chunk order so results can be re-assembled
        # deterministically after out-of-order completion.
        payload = _ProcPayload(seq=i, in_data=data)
        submitted_chunks.append(_ChunkInfo(seq=i, chunk_size=len(data) if list_input else 1))
        self.pool.submit(_predict, payload, **kwargs).add_done_callback(result_data_tracker.handle_results)
        logger.info(f"Chunk number {i + 1} has been submitted successfully.")
        total_chunks += 1
    logger.info("All chunks submitted, waiting for work to complete...")
    # When we get here, all chunks have been submitted, and we just
    # want to wait for them to finish up before moving on
    start = time.time()
    while len(result_data_tracker.results) < total_chunks:
        time_passed_seconds = time.time() - start
        # Optional hard runtime cap: kill the workers and bail out early,
        # leaving any unfinished chunks to be padded with empty results below.
        if self.ner_max_runtime_seconds is not None and self.ner_max_runtime_seconds < time_passed_seconds:
            logger.error(
                "NER took more than %d seconds to finish, ending early",
                self.ner_max_runtime_seconds,
            )
            self.pool.shutdown(wait=False, kill_workers=True)
            self._initialize_pool()  # in case it's reused
            break
        # Keep progress reporting current while results trickle in from workers.
        result_data_tracker.progress_callback.flush()
    result_data = result_data_tracker.results
    logger.info("NER prediction completed.")
    if timings_only:
        # Fold all chunk Timings into the first one, then average over workers.
        result_timings = iter(result_data)
        timings = next(result_timings).out_data
        for other_timings in result_timings:
            timings.join(other_timings.out_data)
        timings.set_avg(num_cpu=self.num_proc)
        return timings
    completed_chunks: dict[int, _ProcPayload] = {r.seq: r for r in result_data}
    all_chunks: dict[int, _ChunkInfo] = {ci.seq: ci for ci in submitted_chunks}
    # Restore the predictions to the order they would
    # have been if predicting on a single worker
    preds = []
    for seq in sorted(list(all_chunks.keys())):
        if (payload := completed_chunks.get(seq, None)) is not None:
            preds.extend(payload.out_data)
        else:
            # Add an empty spot for each record in the chunk that wasn't completed
            # (e.g. it was killed by the runtime cap above).
            logger.warning(f"NER for chunk number {seq + 1} did not complete.")
            preds.extend([[]] * all_chunks[seq].chunk_size)
    return preds