Java源码示例:jcuda.jcudnn.JCudnn
示例1
/**
 * Performs a "softmax" operation on a matrix on the GPU using cuDNN.
 *
 * <p>The input is described to cuDNN as a 4-D tensor built from (rows, columns, 1, 1) and is
 * normalized with CUDNN_SOFTMAX_MODE_CHANNEL, i.e. across the second tensor dimension. Assuming
 * allocateTensorDescriptor takes (N, C, H, W), this normalizes each row independently — TODO confirm.
 * The output matrix is allocated with the same dimensions as the input.
 *
 * @param ec execution context
 * @param gCtx a valid {@link GPUContext}
 * @param instName the invoking instruction's name for record {@link Statistics}.
 * @param in1 input matrix
 * @param outputName output matrix name
 */
public static void softmax(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in1, String outputName) {
	if(LOG.isTraceEnabled()) {
		LOG.trace("GPU : softmax" + ", GPUContext=" + gCtx);
	}
	cudnnTensorDescriptor tensorDesc = allocateTensorDescriptor(toInt(in1.getNumRows()), toInt(in1.getNumColumns()), 1, 1);
	// Destroy the descriptor even if pointer acquisition, output allocation or the cuDNN call fails.
	try {
		Pointer srcPointer = getDensePointerForCuDNN(gCtx, in1, instName);
		MatrixObject out = ec.getMatrixObject(outputName);
		ec.allocateGPUMatrixObject(outputName, in1.getNumRows(), in1.getNumColumns());
		out.getGPUObject(gCtx).allocateAndFillDense(0);
		Pointer dstPointer = getDensePointerForCuDNN(gCtx, out, instName);
		// Scaling factors one()/zero(): dst is overwritten with softmax(src) rather than blended.
		JCudnn.cudnnSoftmaxForward(gCtx.getCudnnHandle(), CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL,
			one(), tensorDesc, srcPointer,
			zero(), tensorDesc, dstPointer);
	}
	finally {
		cudnnDestroyTensorDescriptor(tensorDesc);
	}
}
示例2
/**
 * Performs a "softmax" operation on a matrix on the GPU using cuDNN.
 *
 * <p>The input is described to cuDNN as a 4-D tensor built from (rows, columns, 1, 1) and is
 * normalized with CUDNN_SOFTMAX_MODE_CHANNEL, i.e. across the second tensor dimension. Assuming
 * allocateTensorDescriptor takes (N, C, H, W), this normalizes each row independently — TODO confirm.
 * The output matrix is allocated with the same dimensions as the input.
 *
 * @param ec execution context
 * @param gCtx a valid {@link GPUContext}
 * @param instName the invoking instruction's name for record {@link Statistics}.
 * @param in1 input matrix
 * @param outputName output matrix name
 */
public static void softmax(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in1, String outputName) {
	if(LOG.isTraceEnabled()) {
		LOG.trace("GPU : softmax" + ", GPUContext=" + gCtx);
	}
	cudnnTensorDescriptor tensorDesc = allocateTensorDescriptor(toInt(in1.getNumRows()), toInt(in1.getNumColumns()), 1, 1);
	// Destroy the descriptor even if pointer acquisition, output allocation or the cuDNN call fails.
	try {
		Pointer srcPointer = getDensePointerForCuDNN(gCtx, in1, instName);
		MatrixObject out = ec.getMatrixObject(outputName);
		ec.allocateGPUMatrixObject(outputName, in1.getNumRows(), in1.getNumColumns());
		out.getGPUObject(gCtx).allocateAndFillDense(0);
		Pointer dstPointer = getDensePointerForCuDNN(gCtx, out, instName);
		// Scaling factors one()/zero(): dst is overwritten with softmax(src) rather than blended.
		JCudnn.cudnnSoftmaxForward(gCtx.getCudnnHandle(), CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL,
			one(), tensorDesc, srcPointer,
			zero(), tensorDesc, dstPointer);
	}
	finally {
		cudnnDestroyTensorDescriptor(tensorDesc);
	}
}
示例3
/**
 * Runs the forward pass of a single-layer, unidirectional cuDNN RNN in training mode and writes
 * the results into the named output matrices.
 *
 * <p>cuDNN always writes the full output sequence into a temporary buffer. When
 * {@code return_sequences} is true, that buffer is re-laid-out into {@code outputName}
 * (N x T*M) by the {@code prepare_lstm_output} kernel, and the final hidden state goes to a
 * temporary buffer that is discarded. Otherwise the final hidden state is written directly
 * into {@code outputName} (N x M). The final cell state is written to {@code cyName} only when
 * {@code rnnMode} is "lstm".
 *
 * @param ec execution context used to obtain output pointers
 * @param gCtx GPU context
 * @param instName invoking instruction's name, passed to every device allocation
 * @param x input pointer
 * @param hx initial hidden state pointer
 * @param cx initial cell state pointer
 * @param wPointer cuDNN weight pointer
 * @param outputName name of the output matrix
 * @param cyName name of the final cell-state output (only written for "lstm")
 * @param rnnMode RNN mode name, e.g. "lstm"
 * @param return_sequences whether to emit the output for every time step
 * @param N number of rows (presumably the batch size)
 * @param M hidden size
 * @param D number of input features
 * @param T number of time steps
 * @throws DMLRuntimeException if cuDNN setup or execution fails
 */
private static void singleLayerUnidirectionalRNNForward(ExecutionContext ec, GPUContext gCtx, String instName,
		Pointer x, Pointer hx, Pointer cx, Pointer wPointer, // input
		String outputName, String cyName, // output
		String rnnMode, boolean return_sequences, int N, int M, int D, int T) throws DMLRuntimeException {
	boolean hasCarry = rnnMode.equalsIgnoreCase("lstm");
	// Byte sizes are computed in long arithmetic: N*T*M*sizeOfDataType can exceed Integer.MAX_VALUE.
	Pointer cudnnYPointer = gCtx.allocate(instName, ((long) N) * T * M * sizeOfDataType);
	try {
		// With return_sequences the final hidden state is not part of the output, so use a scratch buffer.
		Pointer hyPointer = return_sequences ? gCtx.allocate(instName, ((long) N) * M * sizeOfDataType)
			: getDenseOutputPointer(ec, gCtx, instName, outputName, N, M);
		try {
			Pointer cyPointer = hasCarry ? getDenseOutputPointer(ec, gCtx, instName, cyName, N, M) : new Pointer();
			try(LibMatrixCuDNNRnnAlgorithm algo = new LibMatrixCuDNNRnnAlgorithm(ec, gCtx, instName, rnnMode, N, T, M, D, true, wPointer)) {
				JCudnn.cudnnRNNForwardTraining(gCtx.getCudnnHandle(), algo.rnnDesc, T,
					algo.xDesc, x,
					algo.hxDesc, hx,
					algo.cxDesc, cx,
					algo.wDesc, wPointer,
					algo.yDesc, cudnnYPointer,
					algo.hyDesc, hyPointer,
					algo.cyDesc, cyPointer,
					algo.workSpace, algo.sizeInBytes,
					algo.reserveSpace, algo.reserveSpaceSizeInBytes);
			}
		}
		finally {
			// Only the scratch buffer is ours to free; otherwise hyPointer belongs to the output matrix.
			if(return_sequences)
				gCtx.cudaFreeHelper(instName, hyPointer, DMLScript.EAGER_CUDA_FREE);
		}
		if(return_sequences) {
			Pointer sysdsYPointer = getDenseOutputPointer(ec, gCtx, instName, outputName, N, T*M);
			LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_output",
				ExecutionConfig.getConfigForSimpleVectorOperations(N*T*M),
				sysdsYPointer, cudnnYPointer, N, T, M, N*T*M);
		}
	}
	finally {
		gCtx.cudaFreeHelper(instName, cudnnYPointer, DMLScript.EAGER_CUDA_FREE);
	}
}
示例4
/**
 * Entry point of the MNIST cuDNN sample.
 *
 * <p>Enables JCuda/JCudnn/JCublas2 exceptions and prints the runtime and header cuDNN versions.
 * It then classifies three example images, switching the convolution algorithm to
 * CUDNN_CONVOLUTION_FWD_ALGO_FFT before the third one. Finally it reports whether the
 * predicted labels match the expected 1, 3 and 5. The network's resources are released even
 * if classification throws.
 *
 * @param args ignored
 */
public static void main(String[] args)
{
    JCuda.setExceptionsEnabled(true);
    JCudnn.setExceptionsEnabled(true);
    JCublas2.setExceptionsEnabled(true);
    int version = (int) cudnnGetVersion();
    System.out.printf("cudnnGetVersion() : %d , " +
        "CUDNN_VERSION from cudnn.h : %d\n",
        version, CUDNN_VERSION);
    System.out.println("Creating network and layers...");
    Network mnist = new Network();
    try
    {
        System.out.println("Classifying...");
        int i1 = mnist.classifyExample(dataDirectory + first_image);
        int i2 = mnist.classifyExample(dataDirectory + second_image);
        mnist.setConvolutionAlgorithm(CUDNN_CONVOLUTION_FWD_ALGO_FFT);
        int i3 = mnist.classifyExample(dataDirectory + third_image);
        System.out.println(
            "\nResult of classification: " + i1 + " " + i2 + " " + i3);
        if (i1 != 1 || i2 != 3 || i3 != 5)
        {
            System.out.println("\nTest failed!\n");
        }
        else
        {
            System.out.println("\nTest passed!\n");
        }
    }
    finally
    {
        // Release native cuDNN/cuBLAS resources even when classification fails.
        mnist.destroy();
    }
}
示例5
/**
 * Runs the forward pass of a single-layer, unidirectional cuDNN RNN in training mode and writes
 * the results into the named output matrices.
 *
 * <p>cuDNN always writes the full output sequence into a temporary buffer. When
 * {@code return_sequences} is true, that buffer is re-laid-out into {@code outputName}
 * (N x T*M) by the {@code prepare_lstm_output} kernel, and the final hidden state goes to a
 * temporary buffer that is discarded. Otherwise the final hidden state is written directly
 * into {@code outputName} (N x M). The final cell state is written to {@code cyName} only when
 * {@code rnnMode} is "lstm".
 *
 * @param ec execution context used to obtain output pointers
 * @param gCtx GPU context
 * @param instName invoking instruction's name, passed to every device allocation
 * @param x input pointer
 * @param hx initial hidden state pointer
 * @param cx initial cell state pointer
 * @param wPointer cuDNN weight pointer
 * @param outputName name of the output matrix
 * @param cyName name of the final cell-state output (only written for "lstm")
 * @param rnnMode RNN mode name, e.g. "lstm"
 * @param return_sequences whether to emit the output for every time step
 * @param N number of rows (presumably the batch size)
 * @param M hidden size
 * @param D number of input features
 * @param T number of time steps
 * @throws DMLRuntimeException if cuDNN setup or execution fails
 */
private static void singleLayerUnidirectionalRNNForward(ExecutionContext ec, GPUContext gCtx, String instName,
		Pointer x, Pointer hx, Pointer cx, Pointer wPointer, // input
		String outputName, String cyName, // output
		String rnnMode, boolean return_sequences, int N, int M, int D, int T) throws DMLRuntimeException {
	boolean hasCarry = rnnMode.equalsIgnoreCase("lstm");
	// Byte sizes are computed in long arithmetic: N*T*M*sizeOfDataType can exceed Integer.MAX_VALUE.
	Pointer cudnnYPointer = gCtx.allocate(instName, ((long) N) * T * M * sizeOfDataType);
	try {
		// With return_sequences the final hidden state is not part of the output, so use a scratch buffer.
		Pointer hyPointer = return_sequences ? gCtx.allocate(instName, ((long) N) * M * sizeOfDataType)
			: getDenseOutputPointer(ec, gCtx, instName, outputName, N, M);
		try {
			Pointer cyPointer = hasCarry ? getDenseOutputPointer(ec, gCtx, instName, cyName, N, M) : new Pointer();
			try(LibMatrixCuDNNRnnAlgorithm algo = new LibMatrixCuDNNRnnAlgorithm(ec, gCtx, instName, rnnMode, N, T, M, D, true, wPointer)) {
				JCudnn.cudnnRNNForwardTraining(gCtx.getCudnnHandle(), algo.rnnDesc, T,
					algo.xDesc, x,
					algo.hxDesc, hx,
					algo.cxDesc, cx,
					algo.wDesc, wPointer,
					algo.yDesc, cudnnYPointer,
					algo.hyDesc, hyPointer,
					algo.cyDesc, cyPointer,
					algo.workSpace, algo.sizeInBytes,
					algo.reserveSpace, algo.reserveSpaceSizeInBytes);
			}
		}
		finally {
			// Only the scratch buffer is ours to free; otherwise hyPointer belongs to the output matrix.
			if(return_sequences)
				gCtx.cudaFreeHelper(instName, hyPointer, DMLScript.EAGER_CUDA_FREE);
		}
		if(return_sequences) {
			Pointer sysdsYPointer = getDenseOutputPointer(ec, gCtx, instName, outputName, N, T*M);
			LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_output",
				ExecutionConfig.getConfigForSimpleVectorOperations(N*T*M),
				sysdsYPointer, cudnnYPointer, N, T, M, N*T*M);
		}
	}
	finally {
		gCtx.cudaFreeHelper(instName, cudnnYPointer, DMLScript.EAGER_CUDA_FREE);
	}
}
示例6
/**
 * Creates and configures the cuDNN descriptors and device buffers needed to run a
 * single-layer, unidirectional RNN through the cuDNN RNN API.
 *
 * <p>NOTE(review): if the LSTM parameter-count check below throws, the descriptors and the
 * dropout state buffer created before it are not released. A try-with-resources statement
 * does not call close() on an object whose constructor failed — confirm and consider cleaning up.
 *
 * @param ec execution context (not referenced in this constructor)
 * @param gCtx GPU context supplying the cuDNN handle and device allocations
 * @param instName name of the invoking instruction, passed to every gCtx.allocate call
 * @param rnnMode RNN mode name, mapped via getCuDNNRnnMode; for "lstm" the weight count is validated
 * @param N number of rows per time step (presumably the batch size)
 * @param T number of time steps; one x/dx/y/dy descriptor is created per step
 * @param M hidden size
 * @param D number of input features
 * @param isTraining if true, the training reserve space is also sized and allocated
 * @param w weight pointer (not referenced in this constructor)
 * @throws DMLRuntimeException if rnnMode is "lstm" and cuDNN's parameter count differs from (D+M+2)*4*M
 */
public LibMatrixCuDNNRnnAlgorithm(ExecutionContext ec, GPUContext gCtx, String instName,
String rnnMode, int N, int T, int M, int D, boolean isTraining, Pointer w) throws DMLRuntimeException {
this.gCtx = gCtx;
this.instName = instName;
// Allocate one input/output (and gradient) tensor descriptor per time step:
// x/dx are built from (N, D, 1) and y/dy from (N, M, 1) as passed to allocateTensorDescriptorWithStride.
xDesc = new cudnnTensorDescriptor[T];
dxDesc = new cudnnTensorDescriptor[T];
yDesc = new cudnnTensorDescriptor[T];
dyDesc = new cudnnTensorDescriptor[T];
for(int t = 0; t < T; t++) {
xDesc[t] = allocateTensorDescriptorWithStride(N, D, 1);
dxDesc[t] = allocateTensorDescriptorWithStride(N, D, 1);
yDesc[t] = allocateTensorDescriptorWithStride(N, M, 1);
dyDesc[t] = allocateTensorDescriptorWithStride(N, M, 1);
}
// Initial/final hidden and cell state descriptors (and their gradients), all built from (1, N, M).
hxDesc = allocateTensorDescriptorWithStride(1, N, M);
dhxDesc = allocateTensorDescriptorWithStride(1, N, M);
cxDesc = allocateTensorDescriptorWithStride(1, N, M);
dcxDesc = allocateTensorDescriptorWithStride(1, N, M);
hyDesc = allocateTensorDescriptorWithStride(1, N, M);
dhyDesc = allocateTensorDescriptorWithStride(1, N, M);
cyDesc = allocateTensorDescriptorWithStride(1, N, M);
dcyDesc = allocateTensorDescriptorWithStride(1, N, M);
// Initialize the dropout descriptor, which the RNN descriptor below takes as an argument.
// Dropout probability is 0 and the seed is fixed at 12345.
dropoutDesc = new cudnnDropoutDescriptor();
JCudnn.cudnnCreateDropoutDescriptor(dropoutDesc);
long [] _dropOutSizeInBytes = {-1};
JCudnn.cudnnDropoutGetStatesSize(gCtx.getCudnnHandle(), _dropOutSizeInBytes);
dropOutSizeInBytes = _dropOutSizeInBytes[0];
dropOutStateSpace = new Pointer();
if (dropOutSizeInBytes != 0)
dropOutStateSpace = gCtx.allocate(instName, dropOutSizeInBytes);
JCudnn.cudnnSetDropoutDescriptor(dropoutDesc, gCtx.getCudnnHandle(), 0, dropOutStateSpace, dropOutSizeInBytes, 12345);
// Initialize the RNN descriptor: hidden size M, 1 layer, linear input, unidirectional, standard algorithm.
rnnDesc = new cudnnRNNDescriptor();
cudnnCreateRNNDescriptor(rnnDesc);
JCudnn.cudnnSetRNNDescriptor_v6(gCtx.getCudnnHandle(), rnnDesc, M, 1, dropoutDesc,
CUDNN_LINEAR_INPUT, CUDNN_UNIDIRECTIONAL,
getCuDNNRnnMode(rnnMode), CUDNN_RNN_ALGO_STANDARD, LibMatrixCUDA.CUDNN_DATA_TYPE);
// Allocate weight / weight-gradient filter descriptors sized to cuDNN's expected parameter count.
// Requires rnnDesc and xDesc to be set up already.
int expectedNumWeights = getExpectedNumWeights();
if(rnnMode.equalsIgnoreCase("lstm") && (D+M+2)*4*M != expectedNumWeights) {
throw new DMLRuntimeException("Incorrect number of RNN parameters " + (D+M+2)*4*M + " != " + expectedNumWeights + ", where numFeatures=" + D + ", hiddenSize=" + M);
}
wDesc = allocateFilterDescriptor(expectedNumWeights);
dwDesc = allocateFilterDescriptor(expectedNumWeights);
// Set up the workspace (always) and the reserve space (training only).
// A size of 0 leaves the corresponding buffer as an empty Pointer.
workSpace = new Pointer(); reserveSpace = new Pointer();
sizeInBytes = getWorkspaceSize(T);
if(sizeInBytes != 0)
workSpace = gCtx.allocate(instName, sizeInBytes);
reserveSpaceSizeInBytes = 0;
if(isTraining) {
reserveSpaceSizeInBytes = getReservespaceSize(T);
if (reserveSpaceSizeInBytes != 0) {
reserveSpace = gCtx.allocate(instName, reserveSpaceSizeInBytes);
}
}
}
示例7
/**
 * Asks cuDNN how much scratch workspace the configured RNN needs for a given sequence length.
 *
 * @param seqLength number of time steps to unroll
 * @return required workspace size in bytes
 */
private long getWorkspaceSize(int seqLength) {
	long [] result = {0L};
	JCudnn.cudnnGetRNNWorkspaceSize(gCtx.getCudnnHandle(), rnnDesc, seqLength, xDesc, result);
	final long workspaceBytes = result[0];
	return workspaceBytes;
}
示例8
/**
 * Asks cuDNN how large the training reserve space must be for the configured RNN.
 *
 * @param seqLength number of time steps to unroll
 * @return required reserve-space size in bytes
 */
private long getReservespaceSize(int seqLength) {
	long [] result = {0L};
	JCudnn.cudnnGetRNNTrainingReserveSize(gCtx.getCudnnHandle(), rnnDesc, seqLength, xDesc, result);
	final long reserveBytes = result[0];
	return reserveBytes;
}
示例9
/**
 * Returns the number of weight elements cuDNN expects for the configured RNN.
 * This is the parameter buffer size reported by cudnnGetRNNParamsSize divided by the element size.
 * For an LSTM the caller compares it against (D+M+2)*4*M.
 *
 * @return expected number of weights
 * @throws DMLRuntimeException if the count does not fit into an int
 */
private int getExpectedNumWeights() throws DMLRuntimeException {
	// Sentinel value; overwritten by cudnnGetRNNParamsSize.
	long [] paramsSizeInBytes = new long[] {-1};
	JCudnn.cudnnGetRNNParamsSize(gCtx.getCudnnHandle(), rnnDesc, xDesc[0], paramsSizeInBytes, LibMatrixCUDA.CUDNN_DATA_TYPE);
	long numWeights = paramsSizeInBytes[0] / LibMatrixCUDA.sizeOfDataType;
	return LibMatrixCUDA.toInt(numWeights);
}
示例10
/**
 * Creates a 3-D NCHW filter descriptor of shape {numWeights, 1, 1}, used to describe
 * a flat RNN weight (or weight-gradient) buffer to cuDNN.
 *
 * @param numWeights number of weight elements
 * @return the initialized filter descriptor
 */
private static cudnnFilterDescriptor allocateFilterDescriptor(int numWeights) {
	final int[] filterDims = {numWeights, 1, 1};
	cudnnFilterDescriptor descriptor = new cudnnFilterDescriptor();
	cudnnCreateFilterDescriptor(descriptor);
	JCudnn.cudnnSetFilterNdDescriptor(descriptor, LibMatrixCUDA.CUDNN_DATA_TYPE, CUDNN_TENSOR_NCHW, filterDims.length, filterDims);
	return descriptor;
}
示例11
/**
 * Computes the backward pass of a single-layer LSTM with cuDNN.
 *
 * <p>The forward pass is recomputed first with cudnnRNNForwardTraining. That call produces the
 * outputs y and fills the reserve space, both of which cudnnRNNBackwardData and
 * cudnnRNNBackwardWeights consume. The resulting gradients are written to:
 * <ul>
 * <li>{@code dxName} (N x T*D)</li>
 * <li>{@code dhxName} and {@code dcxName} (N x M)</li>
 * <li>{@code dwName} ((D+M) x 4M)</li>
 * <li>{@code dbName} (1 x 4M)</li>
 * </ul>
 * Temporary device buffers are freed as soon as they are no longer needed, and on any exception.
 *
 * @param ec execution context used to obtain input/output pointers
 * @param gCtx GPU context
 * @param instName invoking instruction's name, passed to every device allocation
 * @param x input pointer
 * @param hx initial hidden state pointer
 * @param cx initial cell state pointer
 * @param wPointer cuDNN weight pointer
 * @param doutName gradient w.r.t. the output (N x T*M if return_sequences, else N x M)
 * @param dcyName gradient w.r.t. the final cell state (N x M)
 * @param dxName output: gradient w.r.t. the input
 * @param dwName output: gradient w.r.t. the weights
 * @param dbName output: gradient w.r.t. the bias
 * @param dhxName output: gradient w.r.t. the initial hidden state
 * @param dcxName output: gradient w.r.t. the initial cell state
 * @param return_sequences whether the forward pass returned the output for every time step
 * @param N number of rows (presumably the batch size)
 * @param M hidden size
 * @param D number of input features
 * @param T number of time steps
 * @throws DMLRuntimeException if cuDNN setup or execution fails
 */
public static void lstmBackward(ExecutionContext ec, GPUContext gCtx, String instName,
		Pointer x, Pointer hx, Pointer cx, Pointer wPointer, String doutName, String dcyName, // input
		String dxName, String dwName, String dbName, String dhxName, String dcxName, // output
		boolean return_sequences, int N, int M, int D, int T) throws DMLRuntimeException {
	// Temporary device buffers. Each is set to null right after being freed so that the finally
	// block only releases what is still held when an exception escapes.
	Pointer dy = null;
	Pointer yPointer = null;
	Pointer cudnnDx = null;
	Pointer cudnnDwPointer = null;
	try {
		// Transform the input dout and prepare it for cudnnRNNBackwardData.
		// Byte sizes use long arithmetic: N*T*M*sizeOfDataType can exceed Integer.MAX_VALUE.
		dy = gCtx.allocate(instName, ((long) N) * T * M * sizeOfDataType);
		int size = return_sequences ? N*T*M : N*M;
		LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_backward_gradients",
			ExecutionConfig.getConfigForSimpleVectorOperations(size),
			getDenseInputPointer(ec, gCtx, instName, doutName, N, return_sequences ? T*M : M),
			dy, N, T, M, size, return_sequences ? 1 : 0);
		ec.releaseMatrixInputForGPUInstruction(doutName);
		// Allocate intermediate pointers computed by forward
		yPointer = gCtx.allocate(instName, ((long) N) * T * M * sizeOfDataType);
		try(LibMatrixCuDNNRnnAlgorithm algo = new LibMatrixCuDNNRnnAlgorithm(ec, gCtx, instName, "lstm", N, T, M, D, true, wPointer)) {
			JCudnn.cudnnRNNForwardTraining(gCtx.getCudnnHandle(), algo.rnnDesc, T,
				algo.xDesc, x,
				algo.hxDesc, hx,
				algo.cxDesc, cx,
				algo.wDesc, wPointer,
				algo.yDesc, yPointer,
				algo.hyDesc, new Pointer(),
				algo.cyDesc, new Pointer(),
				algo.workSpace, algo.sizeInBytes,
				algo.reserveSpace, algo.reserveSpaceSizeInBytes);
			
			cudnnDx = gCtx.allocate(instName, ((long) N) * T * D * LibMatrixCUDA.sizeOfDataType);
			JCudnn.cudnnRNNBackwardData(gCtx.getCudnnHandle(), algo.rnnDesc, T,
				algo.yDesc, yPointer,
				// ----------------------
				// Additional inputs:
				algo.dyDesc, dy,
				algo.dhyDesc, new Pointer(),
				algo.dcyDesc, getDenseInputPointer(ec, gCtx, instName, dcyName, N, M),
				// ----------------------
				algo.wDesc, wPointer,
				algo.hxDesc, hx,
				algo.cxDesc, cx,
				// ----------------------
				// Output:
				algo.dxDesc, cudnnDx,
				algo.dhxDesc, getDenseOutputPointer(ec, gCtx, instName, dhxName, N, M),
				algo.dcxDesc, getDenseOutputPointer(ec, gCtx, instName, dcxName, N, M),
				// ----------------------
				algo.workSpace, algo.sizeInBytes,
				algo.reserveSpace, algo.reserveSpaceSizeInBytes);
			gCtx.cudaFreeHelper(instName, dy, DMLScript.EAGER_CUDA_FREE);
			dy = null;
			ec.releaseMatrixInputForGPUInstruction(dcyName);
			ec.releaseMatrixOutputForGPUInstruction(dhxName);
			ec.releaseMatrixOutputForGPUInstruction(dcxName);
			
			// Convert cuDNN's dx layout into the SystemDS N x T*D output.
			Pointer smlDx = getDenseOutputPointer(ec, gCtx, instName, dxName, N, T*D);
			LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_dinput",
				ExecutionConfig.getConfigForSimpleVectorOperations(N*T*D),
				smlDx, cudnnDx, N, D, T*D, N*T*D);
			ec.releaseMatrixOutputForGPUInstruction(dxName);
			gCtx.cudaFreeHelper(instName, cudnnDx, DMLScript.EAGER_CUDA_FREE);
			cudnnDx = null;
			// -------------------------------------------------------------------------------------------
			cudnnDwPointer = gCtx.allocate(instName, ((long) (D+M+2)) * (4*M) * LibMatrixCUDA.sizeOfDataType);
			// NOTE(review): cuDNN documents cudnnRNNBackwardWeights as adding into dw. This relies on
			// gCtx.allocate returning zero-initialized memory — confirm.
			JCudnn.cudnnRNNBackwardWeights(gCtx.getCudnnHandle(), algo.rnnDesc, T,
				algo.xDesc, x,
				algo.hxDesc, hx,
				algo.yDesc, yPointer,
				algo.workSpace, algo.sizeInBytes,
				algo.dwDesc, cudnnDwPointer,
				algo.reserveSpace, algo.reserveSpaceSizeInBytes);
			// Split cuDNN's packed weight gradient into the SystemDS dw and db matrices.
			LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_dweight",
				ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)),
				getDenseOutputPointer(ec, gCtx, instName, dwName, D+M, 4*M),
				getDenseOutputPointer(ec, gCtx, instName, dbName, 1, 4*M), cudnnDwPointer, D, M);
			gCtx.cudaFreeHelper(instName, cudnnDwPointer, DMLScript.EAGER_CUDA_FREE);
			cudnnDwPointer = null;
			ec.releaseMatrixOutputForGPUInstruction(dwName);
			ec.releaseMatrixOutputForGPUInstruction(dbName);
			// -------------------------------------------------------------------------------------------
			gCtx.cudaFreeHelper(instName, yPointer, DMLScript.EAGER_CUDA_FREE);
			yPointer = null;
		}
	}
	finally {
		// Release any temporary buffer still held because an exception interrupted the normal path.
		if(dy != null)
			gCtx.cudaFreeHelper(instName, dy, DMLScript.EAGER_CUDA_FREE);
		if(cudnnDx != null)
			gCtx.cudaFreeHelper(instName, cudnnDx, DMLScript.EAGER_CUDA_FREE);
		if(cudnnDwPointer != null)
			gCtx.cudaFreeHelper(instName, cudnnDwPointer, DMLScript.EAGER_CUDA_FREE);
		if(yPointer != null)
			gCtx.cudaFreeHelper(instName, yPointer, DMLScript.EAGER_CUDA_FREE);
	}
}
示例12
/**
 * Creates and configures the cuDNN descriptors and device buffers needed to run a
 * single-layer, unidirectional RNN through the cuDNN RNN API.
 *
 * <p>NOTE(review): if the LSTM parameter-count check below throws, the descriptors and the
 * dropout state buffer created before it are not released. A try-with-resources statement
 * does not call close() on an object whose constructor failed — confirm and consider cleaning up.
 *
 * @param ec execution context (not referenced in this constructor)
 * @param gCtx GPU context supplying the cuDNN handle and device allocations
 * @param instName name of the invoking instruction, passed to every gCtx.allocate call
 * @param rnnMode RNN mode name, mapped via getCuDNNRnnMode; for "lstm" the weight count is validated
 * @param N number of rows per time step (presumably the batch size)
 * @param T number of time steps; one x/dx/y/dy descriptor is created per step
 * @param M hidden size
 * @param D number of input features
 * @param isTraining if true, the training reserve space is also sized and allocated
 * @param w weight pointer (not referenced in this constructor)
 * @throws DMLRuntimeException if rnnMode is "lstm" and cuDNN's parameter count differs from (D+M+2)*4*M
 */
public LibMatrixCuDNNRnnAlgorithm(ExecutionContext ec, GPUContext gCtx, String instName,
String rnnMode, int N, int T, int M, int D, boolean isTraining, Pointer w) throws DMLRuntimeException {
this.gCtx = gCtx;
this.instName = instName;
// Allocate one input/output (and gradient) tensor descriptor per time step:
// x/dx are built from (N, D, 1) and y/dy from (N, M, 1) as passed to allocateTensorDescriptorWithStride.
xDesc = new cudnnTensorDescriptor[T];
dxDesc = new cudnnTensorDescriptor[T];
yDesc = new cudnnTensorDescriptor[T];
dyDesc = new cudnnTensorDescriptor[T];
for(int t = 0; t < T; t++) {
xDesc[t] = allocateTensorDescriptorWithStride(N, D, 1);
dxDesc[t] = allocateTensorDescriptorWithStride(N, D, 1);
yDesc[t] = allocateTensorDescriptorWithStride(N, M, 1);
dyDesc[t] = allocateTensorDescriptorWithStride(N, M, 1);
}
// Initial/final hidden and cell state descriptors (and their gradients), all built from (1, N, M).
hxDesc = allocateTensorDescriptorWithStride(1, N, M);
dhxDesc = allocateTensorDescriptorWithStride(1, N, M);
cxDesc = allocateTensorDescriptorWithStride(1, N, M);
dcxDesc = allocateTensorDescriptorWithStride(1, N, M);
hyDesc = allocateTensorDescriptorWithStride(1, N, M);
dhyDesc = allocateTensorDescriptorWithStride(1, N, M);
cyDesc = allocateTensorDescriptorWithStride(1, N, M);
dcyDesc = allocateTensorDescriptorWithStride(1, N, M);
// Initialize the dropout descriptor, which the RNN descriptor below takes as an argument.
// Dropout probability is 0 and the seed is fixed at 12345.
dropoutDesc = new cudnnDropoutDescriptor();
JCudnn.cudnnCreateDropoutDescriptor(dropoutDesc);
long [] _dropOutSizeInBytes = {-1};
JCudnn.cudnnDropoutGetStatesSize(gCtx.getCudnnHandle(), _dropOutSizeInBytes);
dropOutSizeInBytes = _dropOutSizeInBytes[0];
dropOutStateSpace = new Pointer();
if (dropOutSizeInBytes != 0)
dropOutStateSpace = gCtx.allocate(instName, dropOutSizeInBytes);
JCudnn.cudnnSetDropoutDescriptor(dropoutDesc, gCtx.getCudnnHandle(), 0, dropOutStateSpace, dropOutSizeInBytes, 12345);
// Initialize the RNN descriptor: hidden size M, 1 layer, linear input, unidirectional, standard algorithm.
rnnDesc = new cudnnRNNDescriptor();
cudnnCreateRNNDescriptor(rnnDesc);
JCudnn.cudnnSetRNNDescriptor_v6(gCtx.getCudnnHandle(), rnnDesc, M, 1, dropoutDesc,
CUDNN_LINEAR_INPUT, CUDNN_UNIDIRECTIONAL,
getCuDNNRnnMode(rnnMode), CUDNN_RNN_ALGO_STANDARD, LibMatrixCUDA.CUDNN_DATA_TYPE);
// Allocate weight / weight-gradient filter descriptors sized to cuDNN's expected parameter count.
// Requires rnnDesc and xDesc to be set up already.
int expectedNumWeights = getExpectedNumWeights();
if(rnnMode.equalsIgnoreCase("lstm") && (D+M+2)*4*M != expectedNumWeights) {
throw new DMLRuntimeException("Incorrect number of RNN parameters " + (D+M+2)*4*M + " != " + expectedNumWeights + ", where numFeatures=" + D + ", hiddenSize=" + M);
}
wDesc = allocateFilterDescriptor(expectedNumWeights);
dwDesc = allocateFilterDescriptor(expectedNumWeights);
// Set up the workspace (always) and the reserve space (training only).
// A size of 0 leaves the corresponding buffer as an empty Pointer.
workSpace = new Pointer(); reserveSpace = new Pointer();
sizeInBytes = getWorkspaceSize(T);
if(sizeInBytes != 0)
workSpace = gCtx.allocate(instName, sizeInBytes);
reserveSpaceSizeInBytes = 0;
if(isTraining) {
reserveSpaceSizeInBytes = getReservespaceSize(T);
if (reserveSpaceSizeInBytes != 0) {
reserveSpace = gCtx.allocate(instName, reserveSpaceSizeInBytes);
}
}
}
示例13
/**
 * Asks cuDNN how much scratch workspace the configured RNN needs for a given sequence length.
 *
 * @param seqLength number of time steps to unroll
 * @return required workspace size in bytes
 */
private long getWorkspaceSize(int seqLength) {
	long [] result = {0L};
	JCudnn.cudnnGetRNNWorkspaceSize(gCtx.getCudnnHandle(), rnnDesc, seqLength, xDesc, result);
	final long workspaceBytes = result[0];
	return workspaceBytes;
}
示例14
/**
 * Asks cuDNN how large the training reserve space must be for the configured RNN.
 *
 * @param seqLength number of time steps to unroll
 * @return required reserve-space size in bytes
 */
private long getReservespaceSize(int seqLength) {
	long [] result = {0L};
	JCudnn.cudnnGetRNNTrainingReserveSize(gCtx.getCudnnHandle(), rnnDesc, seqLength, xDesc, result);
	final long reserveBytes = result[0];
	return reserveBytes;
}
示例15
/**
 * Returns the number of weight elements cuDNN expects for the configured RNN.
 * This is the parameter buffer size reported by cudnnGetRNNParamsSize divided by the element size.
 * For an LSTM the caller compares it against (D+M+2)*4*M.
 *
 * @return expected number of weights
 * @throws DMLRuntimeException if the count does not fit into an int
 */
private int getExpectedNumWeights() throws DMLRuntimeException {
	// Sentinel value; overwritten by cudnnGetRNNParamsSize.
	long [] paramsSizeInBytes = new long[] {-1};
	JCudnn.cudnnGetRNNParamsSize(gCtx.getCudnnHandle(), rnnDesc, xDesc[0], paramsSizeInBytes, LibMatrixCUDA.CUDNN_DATA_TYPE);
	long numWeights = paramsSizeInBytes[0] / LibMatrixCUDA.sizeOfDataType;
	return LibMatrixCUDA.toInt(numWeights);
}
示例16
/**
 * Creates a 3-D NCHW filter descriptor of shape {numWeights, 1, 1}, used to describe
 * a flat RNN weight (or weight-gradient) buffer to cuDNN.
 *
 * @param numWeights number of weight elements
 * @return the initialized filter descriptor
 */
private static cudnnFilterDescriptor allocateFilterDescriptor(int numWeights) {
	final int[] filterDims = {numWeights, 1, 1};
	cudnnFilterDescriptor descriptor = new cudnnFilterDescriptor();
	cudnnCreateFilterDescriptor(descriptor);
	JCudnn.cudnnSetFilterNdDescriptor(descriptor, LibMatrixCUDA.CUDNN_DATA_TYPE, CUDNN_TENSOR_NCHW, filterDims.length, filterDims);
	return descriptor;
}
示例17
/**
 * Computes the backward pass of a single-layer LSTM with cuDNN.
 *
 * <p>The forward pass is recomputed first with cudnnRNNForwardTraining. That call produces the
 * outputs y and fills the reserve space, both of which cudnnRNNBackwardData and
 * cudnnRNNBackwardWeights consume. The resulting gradients are written to:
 * <ul>
 * <li>{@code dxName} (N x T*D)</li>
 * <li>{@code dhxName} and {@code dcxName} (N x M)</li>
 * <li>{@code dwName} ((D+M) x 4M)</li>
 * <li>{@code dbName} (1 x 4M)</li>
 * </ul>
 * Temporary device buffers are freed as soon as they are no longer needed, and on any exception.
 *
 * @param ec execution context used to obtain input/output pointers
 * @param gCtx GPU context
 * @param instName invoking instruction's name, passed to every device allocation
 * @param x input pointer
 * @param hx initial hidden state pointer
 * @param cx initial cell state pointer
 * @param wPointer cuDNN weight pointer
 * @param doutName gradient w.r.t. the output (N x T*M if return_sequences, else N x M)
 * @param dcyName gradient w.r.t. the final cell state (N x M)
 * @param dxName output: gradient w.r.t. the input
 * @param dwName output: gradient w.r.t. the weights
 * @param dbName output: gradient w.r.t. the bias
 * @param dhxName output: gradient w.r.t. the initial hidden state
 * @param dcxName output: gradient w.r.t. the initial cell state
 * @param return_sequences whether the forward pass returned the output for every time step
 * @param N number of rows (presumably the batch size)
 * @param M hidden size
 * @param D number of input features
 * @param T number of time steps
 * @throws DMLRuntimeException if cuDNN setup or execution fails
 */
public static void lstmBackward(ExecutionContext ec, GPUContext gCtx, String instName,
		Pointer x, Pointer hx, Pointer cx, Pointer wPointer, String doutName, String dcyName, // input
		String dxName, String dwName, String dbName, String dhxName, String dcxName, // output
		boolean return_sequences, int N, int M, int D, int T) throws DMLRuntimeException {
	// Temporary device buffers. Each is set to null right after being freed so that the finally
	// block only releases what is still held when an exception escapes.
	Pointer dy = null;
	Pointer yPointer = null;
	Pointer cudnnDx = null;
	Pointer cudnnDwPointer = null;
	try {
		// Transform the input dout and prepare it for cudnnRNNBackwardData.
		// Byte sizes use long arithmetic: N*T*M*sizeOfDataType can exceed Integer.MAX_VALUE.
		dy = gCtx.allocate(instName, ((long) N) * T * M * sizeOfDataType);
		int size = return_sequences ? N*T*M : N*M;
		LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_backward_gradients",
			ExecutionConfig.getConfigForSimpleVectorOperations(size),
			getDenseInputPointer(ec, gCtx, instName, doutName, N, return_sequences ? T*M : M),
			dy, N, T, M, size, return_sequences ? 1 : 0);
		ec.releaseMatrixInputForGPUInstruction(doutName);
		// Allocate intermediate pointers computed by forward
		yPointer = gCtx.allocate(instName, ((long) N) * T * M * sizeOfDataType);
		try(LibMatrixCuDNNRnnAlgorithm algo = new LibMatrixCuDNNRnnAlgorithm(ec, gCtx, instName, "lstm", N, T, M, D, true, wPointer)) {
			JCudnn.cudnnRNNForwardTraining(gCtx.getCudnnHandle(), algo.rnnDesc, T,
				algo.xDesc, x,
				algo.hxDesc, hx,
				algo.cxDesc, cx,
				algo.wDesc, wPointer,
				algo.yDesc, yPointer,
				algo.hyDesc, new Pointer(),
				algo.cyDesc, new Pointer(),
				algo.workSpace, algo.sizeInBytes,
				algo.reserveSpace, algo.reserveSpaceSizeInBytes);
			
			cudnnDx = gCtx.allocate(instName, ((long) N) * T * D * LibMatrixCUDA.sizeOfDataType);
			JCudnn.cudnnRNNBackwardData(gCtx.getCudnnHandle(), algo.rnnDesc, T,
				algo.yDesc, yPointer,
				// ----------------------
				// Additional inputs:
				algo.dyDesc, dy,
				algo.dhyDesc, new Pointer(),
				algo.dcyDesc, getDenseInputPointer(ec, gCtx, instName, dcyName, N, M),
				// ----------------------
				algo.wDesc, wPointer,
				algo.hxDesc, hx,
				algo.cxDesc, cx,
				// ----------------------
				// Output:
				algo.dxDesc, cudnnDx,
				algo.dhxDesc, getDenseOutputPointer(ec, gCtx, instName, dhxName, N, M),
				algo.dcxDesc, getDenseOutputPointer(ec, gCtx, instName, dcxName, N, M),
				// ----------------------
				algo.workSpace, algo.sizeInBytes,
				algo.reserveSpace, algo.reserveSpaceSizeInBytes);
			gCtx.cudaFreeHelper(instName, dy, DMLScript.EAGER_CUDA_FREE);
			dy = null;
			ec.releaseMatrixInputForGPUInstruction(dcyName);
			ec.releaseMatrixOutputForGPUInstruction(dhxName);
			ec.releaseMatrixOutputForGPUInstruction(dcxName);
			
			// Convert cuDNN's dx layout into the SystemDS N x T*D output.
			Pointer smlDx = getDenseOutputPointer(ec, gCtx, instName, dxName, N, T*D);
			LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_dinput",
				ExecutionConfig.getConfigForSimpleVectorOperations(N*T*D),
				smlDx, cudnnDx, N, D, T*D, N*T*D);
			ec.releaseMatrixOutputForGPUInstruction(dxName);
			gCtx.cudaFreeHelper(instName, cudnnDx, DMLScript.EAGER_CUDA_FREE);
			cudnnDx = null;
			// -------------------------------------------------------------------------------------------
			cudnnDwPointer = gCtx.allocate(instName, ((long) (D+M+2)) * (4*M) * LibMatrixCUDA.sizeOfDataType);
			// NOTE(review): cuDNN documents cudnnRNNBackwardWeights as adding into dw. This relies on
			// gCtx.allocate returning zero-initialized memory — confirm.
			JCudnn.cudnnRNNBackwardWeights(gCtx.getCudnnHandle(), algo.rnnDesc, T,
				algo.xDesc, x,
				algo.hxDesc, hx,
				algo.yDesc, yPointer,
				algo.workSpace, algo.sizeInBytes,
				algo.dwDesc, cudnnDwPointer,
				algo.reserveSpace, algo.reserveSpaceSizeInBytes);
			// Split cuDNN's packed weight gradient into the SystemDS dw and db matrices.
			LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_dweight",
				ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)),
				getDenseOutputPointer(ec, gCtx, instName, dwName, D+M, 4*M),
				getDenseOutputPointer(ec, gCtx, instName, dbName, 1, 4*M), cudnnDwPointer, D, M);
			gCtx.cudaFreeHelper(instName, cudnnDwPointer, DMLScript.EAGER_CUDA_FREE);
			cudnnDwPointer = null;
			ec.releaseMatrixOutputForGPUInstruction(dwName);
			ec.releaseMatrixOutputForGPUInstruction(dbName);
			// -------------------------------------------------------------------------------------------
			gCtx.cudaFreeHelper(instName, yPointer, DMLScript.EAGER_CUDA_FREE);
			yPointer = null;
		}
	}
	finally {
		// Release any temporary buffer still held because an exception interrupted the normal path.
		if(dy != null)
			gCtx.cudaFreeHelper(instName, dy, DMLScript.EAGER_CUDA_FREE);
		if(cudnnDx != null)
			gCtx.cudaFreeHelper(instName, cudnnDx, DMLScript.EAGER_CUDA_FREE);
		if(cudnnDwPointer != null)
			gCtx.cudaFreeHelper(instName, cudnnDwPointer, DMLScript.EAGER_CUDA_FREE);
		if(yPointer != null)
			gCtx.cudaFreeHelper(instName, yPointer, DMLScript.EAGER_CUDA_FREE);
	}
}