diff options
Diffstat (limited to 'java/com/android/voicemail/impl/transcribe/TranscriptionTaskAsync.java')
-rw-r--r-- | java/com/android/voicemail/impl/transcribe/TranscriptionTaskAsync.java | 123 |
1 files changed, 123 insertions, 0 deletions
diff --git a/java/com/android/voicemail/impl/transcribe/TranscriptionTaskAsync.java b/java/com/android/voicemail/impl/transcribe/TranscriptionTaskAsync.java new file mode 100644 index 000000000..3c41aef89 --- /dev/null +++ b/java/com/android/voicemail/impl/transcribe/TranscriptionTaskAsync.java @@ -0,0 +1,123 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License + */ +package com.android.voicemail.impl.transcribe; + +import android.app.job.JobWorkItem; +import android.content.Context; +import android.util.Pair; +import com.android.dialer.common.Assert; +import com.android.dialer.logging.DialerImpression; +import com.android.voicemail.impl.VvmLog; +import com.android.voicemail.impl.transcribe.TranscriptionService.JobCallback; +import com.android.voicemail.impl.transcribe.grpc.GetTranscriptResponseAsync; +import com.android.voicemail.impl.transcribe.grpc.TranscriptionClientFactory; +import com.android.voicemail.impl.transcribe.grpc.TranscriptionResponseAsync; +import com.google.internal.communications.voicemailtranscription.v1.GetTranscriptRequest; +import com.google.internal.communications.voicemailtranscription.v1.TranscribeVoicemailAsyncRequest; +import com.google.internal.communications.voicemailtranscription.v1.TranscriptionStatus; + +/** + * Background task to get a voicemail transcription using the asynchronous API. The async API works + * as follows: + * + * <ol> + * <li>client uploads voicemail data to the server + * <li>server responds with a transcription-id and an estimated transcription wait time + * <li>client waits appropriate amount of time then begins polling for the result + * </ol> + * + * This implementation blocks until the response or an error is received, even though it is using + * the asynchronous server API. + */ +public class TranscriptionTaskAsync extends TranscriptionTask { + private static final String TAG = "TranscriptionTaskAsync"; + + public TranscriptionTaskAsync( + Context context, + JobCallback callback, + JobWorkItem workItem, + TranscriptionClientFactory clientFactory, + TranscriptionConfigProvider configProvider) { + super(context, callback, workItem, clientFactory, configProvider); + } + + @Override + protected Pair<String, TranscriptionStatus> getTranscription() { + VvmLog.i(TAG, "getTranscription"); + + TranscriptionResponseAsync uploadResponse = + (TranscriptionResponseAsync) + sendRequest((client) -> client.sendUploadRequest(getUploadRequest())); + + if (uploadResponse == null) { + VvmLog.i(TAG, "getTranscription, failed to upload voicemail."); + return new Pair<>(null, TranscriptionStatus.FAILED_NO_RETRY); + } else { + waitForTranscription(uploadResponse); + return pollForTranscription(uploadResponse); + } + } + + @Override + protected DialerImpression.Type getRequestSentImpression() { + return DialerImpression.Type.VVM_TRANSCRIPTION_REQUEST_SENT_ASYNC; + } + + private static void waitForTranscription(TranscriptionResponseAsync uploadResponse) { + long millis = uploadResponse.getEstimatedWaitMillis(); + VvmLog.i(TAG, "waitForTranscription, " + millis + " millis"); + sleep(millis); + } + + private Pair<String, TranscriptionStatus> pollForTranscription( + TranscriptionResponseAsync uploadResponse) { + VvmLog.i(TAG, "pollForTranscription"); + GetTranscriptRequest request = getGetTranscriptRequest(uploadResponse); + for (int i = 0; i < configProvider.getMaxGetTranscriptPolls(); i++) { + GetTranscriptResponseAsync response = + (GetTranscriptResponseAsync) + sendRequest((client) -> client.sendGetTranscriptRequest(request)); + if (response == null) { + VvmLog.i(TAG, "pollForTranscription, no transcription result."); + } else if (response.isTranscribing()) { + VvmLog.i(TAG, "pollForTranscription, poll count: " + (i + 1)); + } else if (response.hasFatalError()) { + VvmLog.i(TAG, "pollForTranscription, fail. " + response.getErrorDescription()); + return new Pair<>(null, response.getTranscriptionStatus()); + } else { + VvmLog.i(TAG, "pollForTranscription, got transcription"); + return new Pair<>(response.getTranscript(), TranscriptionStatus.SUCCESS); + } + sleep(configProvider.getGetTranscriptPollIntervalMillis()); + } + VvmLog.i(TAG, "pollForTranscription, timed out."); + return new Pair<>(null, TranscriptionStatus.FAILED_NO_RETRY); + } + + private TranscribeVoicemailAsyncRequest getUploadRequest() { + return TranscribeVoicemailAsyncRequest.newBuilder() + .setVoicemailData(audioData) + .setAudioFormat(encoding) + .build(); + } + + private GetTranscriptRequest getGetTranscriptRequest(TranscriptionResponseAsync uploadResponse) { + Assert.checkArgument(uploadResponse.getTranscriptionId() != null); + return GetTranscriptRequest.newBuilder() + .setTranscriptionId(uploadResponse.getTranscriptionId()) + .build(); + } +} |