diff --git a/src/main/java/graphql/execution/AsynchronousExecutionStrategy.java b/src/main/java/graphql/execution/AsynchronousExecutionStrategy.java new file mode 100644 index 0000000000..c7c12025b6 --- /dev/null +++ b/src/main/java/graphql/execution/AsynchronousExecutionStrategy.java @@ -0,0 +1,223 @@ +package graphql.execution; + +import graphql.ExecutionResult; +import graphql.ExecutionResultImpl; +import graphql.execution.instrumentation.Instrumentation; +import graphql.execution.instrumentation.InstrumentationContext; +import graphql.execution.instrumentation.parameters.FieldFetchParameters; +import graphql.execution.instrumentation.parameters.FieldParameters; +import graphql.language.Field; +import graphql.schema.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.*; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.ExecutionException; + +import static graphql.execution.FieldCollectorParameters.newParameters; +import static graphql.execution.TypeInfo.newTypeInfo; + +public class AsynchronousExecutionStrategy extends ExecutionStrategy { + + private static final Logger log = LoggerFactory.getLogger(AsynchronousExecutionStrategy.class); + + @Override + public ExecutionResult execute(ExecutionContext executionContext, + ExecutionParameters parameters) throws NonNullableFieldWasNullException { + + Map> fields = parameters.fields(); + Map results = Collections.synchronizedMap(new HashMap<>()); + CompletionStage future = CompletableFuture.completedFuture(null); + + for (String fieldName : fields.keySet()) { + final List fieldList = fields.get(fieldName); + CompletionStage resolveFieldFuture = + resolveFieldAsync(executionContext, parameters, fieldList). + thenApplyAsync(executionResult -> { + if (executionResult != null) { + results.put(fieldName, executionResult.getData()); + } else { + results.put(fieldName, null); + } + return executionResult; + }); + future = future.thenCombineAsync(resolveFieldFuture,(t,executionResult)-> t); + } + + return new ExecutionResultImpl(future.thenApplyAsync(t -> results),executionContext.getErrors()); + } + + protected CompletionStage resolveFieldAsync(ExecutionContext executionContext, ExecutionParameters parameters, List fields) { + GraphQLObjectType type = parameters.typeInfo().castType(GraphQLObjectType.class); + GraphQLFieldDefinition + fieldDef = getFieldDef(executionContext.getGraphQLSchema(), type, fields.get(0)); + + Map argumentValues = valuesResolver.getArgumentValues(fieldDef.getArguments(), fields.get(0).getArguments(), executionContext.getVariables()); + + GraphQLOutputType fieldType = fieldDef.getType(); + DataFetchingFieldSelectionSet fieldCollector = DataFetchingFieldSelectionSetImpl + .newCollector(executionContext, fieldType, fields); + + DataFetchingEnvironment environment = new DataFetchingEnvironmentImpl( + parameters.source(), + argumentValues, + executionContext.getRoot(), + fields, + fieldType, + type, + executionContext.getGraphQLSchema(), + executionContext.getFragmentsByName(), + executionContext.getExecutionId(), + fieldCollector); + + Instrumentation instrumentation = executionContext.getInstrumentation(); + + InstrumentationContext + fieldCtx = instrumentation.beginField(new FieldParameters(executionContext, fieldDef, environment)); + + InstrumentationContext fetchCtx = instrumentation.beginFieldFetch(new FieldFetchParameters(executionContext, fieldDef, environment)); + Object resolvedValue = null; + + CompletableFuture dataFetcherResult = null; + try { + resolvedValue = fieldDef.getDataFetcher().get(environment); + + if(resolvedValue instanceof CompletionStage) { + dataFetcherResult = (CompletableFuture) resolvedValue; + } else { + dataFetcherResult = CompletableFuture.completedFuture(resolvedValue); + } + } catch (Exception e) { + log.warn("Exception while fetching data", e); + dataFetcherResult = new CompletableFuture(); + dataFetcherResult.completeExceptionally(e); + } + + return dataFetcherResult.handleAsync((value,th)-> { + if(th != null) { + log.warn("Exception while fetching data", th); + handleDataFetchingException(executionContext, fieldDef, argumentValues, new ExecutionException(th)); + fetchCtx.onEnd(th); + } + + TypeInfo fieldTypeInfo = newTypeInfo() + .type(fieldType) + .parentInfo(parameters.typeInfo()) + .build(); + + ExecutionParameters newParameters = ExecutionParameters.newParameters() + .typeInfo(fieldTypeInfo) + .fields(parameters.fields()) + .arguments(argumentValues) + .source(value).build(); + + return newParameters; + }).thenComposeAsync(newParameters -> completeValueAsync(executionContext, newParameters, + fields)); + + + } + protected CompletionStage completeValueAsync(ExecutionContext executionContext, ExecutionParameters parameters, List fields) { + TypeInfo typeInfo = parameters.typeInfo(); + Object result = parameters.source(); + GraphQLType fieldType = parameters.typeInfo().type(); + + if (result == null) { + if (typeInfo.typeIsNonNull()) { + // see http://facebook.github.io/graphql/#sec-Errors-and-Non-Nullability + NonNullableFieldWasNullException nonNullException = new NonNullableFieldWasNullException(typeInfo); + executionContext.addError(nonNullException); + throw nonNullException; + } + return CompletableFuture.completedFuture(null); + } else if (fieldType instanceof GraphQLList) { + return completeValueForListAsync(executionContext, parameters, fields, toIterable(result)); + } else if (fieldType instanceof GraphQLScalarType) { + return CompletableFuture.completedFuture(completeValueForScalar((GraphQLScalarType) fieldType, result)); + } else if (fieldType instanceof GraphQLEnumType) { + return CompletableFuture.completedFuture(completeValueForEnum((GraphQLEnumType) fieldType, result)); + } + + + GraphQLObjectType resolvedType; + if (fieldType instanceof GraphQLInterfaceType) { + TypeResolutionParameters resolutionParams = TypeResolutionParameters.newParameters() + .graphQLInterfaceType((GraphQLInterfaceType) fieldType) + .field(fields.get(0)) + .value(parameters.source()) + .argumentValues(parameters.arguments()) + .schema(executionContext.getGraphQLSchema()).build(); + resolvedType = resolveTypeForInterface(resolutionParams); + + } else if (fieldType instanceof GraphQLUnionType) { + TypeResolutionParameters resolutionParams = TypeResolutionParameters.newParameters() + .graphQLUnionType((GraphQLUnionType) fieldType) + .field(fields.get(0)) + .value(parameters.source()) + .argumentValues(parameters.arguments()) + .schema(executionContext.getGraphQLSchema()).build(); + resolvedType = resolveTypeForUnion(resolutionParams); + } else { + resolvedType = (GraphQLObjectType) fieldType; + } + + FieldCollectorParameters collectorParameters = newParameters(executionContext.getGraphQLSchema(), resolvedType) + .fragments(executionContext.getFragmentsByName()) + .variables(executionContext.getVariables()) + .build(); + + Map> subFields = fieldCollector.collectFields(collectorParameters, fields); + + ExecutionParameters newParameters = ExecutionParameters.newParameters() + .typeInfo(typeInfo.asType(resolvedType)) + .fields(subFields) + .source(result).build(); + + // Calling this from the executionContext to ensure we shift back from mutation strategy to the query strategy. + + ExecutionResult executionResult = executionContext.getQueryStrategy().execute(executionContext, newParameters); + if(!(executionResult.getData() instanceof CompletionStage)) { + return CompletableFuture.completedFuture(executionResult); + } else { + return ((CompletionStage) executionResult.getData()).handleAsync((resultMap,th) -> + new ExecutionResultImpl(resultMap,executionResult.getErrors()) + ); + + } + } + + protected CompletionStage completeValueForListAsync(ExecutionContext executionContext, ExecutionParameters parameters, List fields, Iterable result) { + TypeInfo typeInfo = parameters.typeInfo(); + GraphQLList fieldType = typeInfo.castType(GraphQLList.class); + List resultList = Collections.synchronizedList(new ArrayList<>()); + CompletionStage future = CompletableFuture.completedFuture(null); + + for (Object item : result) { + ExecutionParameters newParameters = ExecutionParameters.newParameters() + .typeInfo(typeInfo.asType(fieldType.getWrappedType())) + .fields(parameters.fields()) + .source(item).build(); + + CompletionStage completedValueFuture = + completeValueAsync(executionContext, newParameters, fields); + + future = future.thenCombineAsync(completedValueFuture, (t,executionResult) -> { + resultList.add(executionResult.getData()); + return null; + }); + } + + return future.thenApplyAsync(t -> new ExecutionResultImpl(resultList,null)); + } + + private Iterable toIterable(Object result) { + if (result.getClass().isArray()) { + result = Arrays.asList((Object[]) result); + } + //noinspection unchecked + return (Iterable) result; + } +} + diff --git a/src/test/groovy/graphql/execution/AsynchronousExecutionStrategyTest.groovy b/src/test/groovy/graphql/execution/AsynchronousExecutionStrategyTest.groovy new file mode 100644 index 0000000000..f6a968ba9b --- /dev/null +++ b/src/test/groovy/graphql/execution/AsynchronousExecutionStrategyTest.groovy @@ -0,0 +1,130 @@ +package graphql.execution + +import graphql.GraphQL +import graphql.schema.DataFetcher +import graphql.schema.GraphQLObjectType +import graphql.schema.GraphQLSchema +import spock.lang.Specification + +import java.util.concurrent.CompletableFuture +import java.util.concurrent.CompletionStage + +import static graphql.Scalars.GraphQLString +import static graphql.schema.GraphQLFieldDefinition.newFieldDefinition +import static graphql.schema.GraphQLObjectType.newObject + +class AsynchronousExecutionStrategyTest extends Specification { + + def "Example usage of AsynchronousExecutionStrategy."() { + given: + + GraphQLObjectType queryType = newObject() + .name("data") + .field( + newFieldDefinition().type(GraphQLString).name("key1").dataFetcher({env -> CompletableFuture.completedFuture("value1")})) + .field( + newFieldDefinition().type(GraphQLString).name("key2").staticValue("value2")) + .build(); + + GraphQLSchema schema = GraphQLSchema.newSchema() + .query(queryType) + .build(); + + def expected = [key1:"value1",key2:"value2"] + + when: + GraphQL graphQL = GraphQL.newGraphQL(schema) + .queryExecutionStrategy(new AsynchronousExecutionStrategy()) + .build(); + + Map result = ((CompletionStage) graphQL.execute("{key1,key2}").data).toCompletableFuture().get(); + + then: + assert expected == result; + } + + def "Ensure the execution order." () { + given: + Timer timer = new Timer(); + + DataFetcher> grandFetcher = { + env -> + CompletableFuture future = new CompletableFuture<>() + timer.schedule({_-> future.complete([field:"grandValue"]) },50) + return future + } + + DataFetcher> parentFetcher = { + env -> + CompletableFuture future = new CompletableFuture<>() + timer.schedule({_-> future.complete([field:"parentValue"]) },20) + return future + } + + DataFetcher> childFetcher = { + env -> + CompletableFuture future = new CompletableFuture<>() + timer.schedule({_-> future.complete([field:"childValue"]) },10) + return future + } + + GraphQLObjectType childObjectType = newObject().name("ChildObject"). + field(newFieldDefinition().name("field").type(GraphQLString)). + build(); + + GraphQLObjectType parentObjectType = newObject().name("ParentObject"). + field(newFieldDefinition().name("field").type(GraphQLString)). + field(newFieldDefinition().name("child").type(childObjectType).dataFetcher(childFetcher)). + build(); + + GraphQLObjectType grandObjectType = newObject().name("GrandObject"). + field(newFieldDefinition().name("field").type(GraphQLString)). + field(newFieldDefinition().name("parent").type(parentObjectType).dataFetcher(parentFetcher)). + build(); + + GraphQLObjectType rootObjectType = newObject().name("Root"). + field( + newFieldDefinition().name("grand").type(grandObjectType).dataFetcher(grandFetcher) + ).build(); + + GraphQLSchema schema = GraphQLSchema.newSchema() + .query(rootObjectType) + .build(); + when: + + GraphQL graphQL = GraphQL.newGraphQL(schema) + .queryExecutionStrategy(new AsynchronousExecutionStrategy()) + .build(); + + String queryString = + """ + { + grand { + field + parent { + field + child { + field + } + } + } + } + """ + Map result = ((CompletionStage) graphQL.execute(queryString).data).toCompletableFuture().get(); + + def expected = [ + grand:[ + field: "grandValue", + parent:[ + field:"parentValue", + child: [ + field: "childValue" + ] + ] + ] + ] + + then: + assert result == expected + } +}