diff --git a/src/Database.elm b/src/Database.elm index a1032cb..6ce0634 100644 --- a/src/Database.elm +++ b/src/Database.elm @@ -147,7 +147,7 @@ type QueryPlan = Read String | Select Selection QueryPlan | Project (List Field) QueryPlan - | Join { left : QueryPlan, leftFields : List Field, right : QueryPlan, rightFields : List Field } + | Join { left : QueryPlan, right : QueryPlan, fields : List ( Field, Field ) } runPlan : QueryPlan -> Database -> Result Problem Relation @@ -197,14 +197,20 @@ runPlan plan ((Database db) as db_) = Err err -> Err err ) + + leftFields = + List.map Tuple.first config.fields + + rightFields = + List.map Tuple.second config.fields in Result.map2 (\left right -> - if takeFields config.leftFields left.schema /= takeFields config.rightFields right.schema then + if takeFields leftFields left.schema /= takeFields rightFields right.schema then Err (SchemaMismatch - { wanted = takeFields config.leftFields left.schema - , got = takeFields config.rightFields right.schema + { wanted = takeFields leftFields left.schema + , got = takeFields rightFields right.schema } ) @@ -213,7 +219,7 @@ runPlan plan ((Database db) as db_) = leftIndex = List.foldl (\row -> - Sort.Dict.update (takeFields config.leftFields row) + Sort.Dict.update (takeFields leftFields row) (\maybeRows -> case maybeRows of Just rows -> @@ -231,7 +237,7 @@ runPlan plan ((Database db) as db_) = , rows = List.concatMap (\rightRow -> - case Sort.Dict.get (takeFields config.rightFields rightRow) leftIndex of + case Sort.Dict.get (takeFields rightFields rightRow) leftIndex of Just rows -> List.map (\leftRow -> Array.append leftRow rightRow) rows @@ -241,8 +247,8 @@ runPlan plan ((Database db) as db_) = right.rows } ) - (runInput config.left config.leftFields) - (runInput config.right config.rightFields) + (runInput config.left leftFields) + (runInput config.right rightFields) |> Result.andThen identity diff --git a/src/Datalog.elm b/src/Datalog.elm index a9d51cc..936b3f9 100644 --- a/src/Datalog.elm +++ b/src/Datalog.elm @@ -34,8 +34,12 @@ ruleToPlan (Rule (Atom _ headTerms) bodyAtoms) = let ( leftNames, leftPlan ) = atomToPlan nextAtom - - fields = + in + ( leftNames ++ rightNames + , Database.Join + { left = leftPlan + , right = rightPlan + , fields = Dict.merge (\_ _ soFar -> soFar) (\_ left right soFar -> ( left, right ) :: soFar) @@ -43,13 +47,6 @@ ruleToPlan (Rule (Atom _ headTerms) bodyAtoms) = (Dict.fromList (List.indexedMap (\i field -> ( field, i )) leftNames)) (Dict.fromList (List.indexedMap (\i field -> ( field, i )) rightNames)) [] - in - ( leftNames ++ rightNames - , Database.Join - { left = leftPlan - , leftFields = List.map Tuple.first fields - , right = rightPlan - , rightFields = List.map Tuple.second fields } ) ) diff --git a/tests/DatabaseTests.elm b/tests/DatabaseTests.elm index 82f6d25..fb28942 100644 --- a/tests/DatabaseTests.elm +++ b/tests/DatabaseTests.elm @@ -191,9 +191,8 @@ runPlanTests = (runPlan (Join { left = Read "mascots" - , leftFields = [ 3 ] , right = Read "teams" - , rightFields = [ 0 ] + , fields = [ ( 3, 0 ) ] } ) ) @@ -205,9 +204,8 @@ runPlanTests = (runPlan (Join { left = Read "mascots" - , leftFields = [ 0 ] , right = Read "teams" - , rightFields = [ 4 ] + , fields = [ ( 0, 4 ) ] } ) ) @@ -219,9 +217,8 @@ runPlanTests = (runPlan (Join { left = Read "mascots" - , leftFields = [ 0 ] , right = Read "teams" - , rightFields = [ 3 ] + , fields = [ ( 0, 3 ) ] } ) ) @@ -233,27 +230,6 @@ runPlanTests = } ) ) - , test "it's an error if you join on different numbers of keys" <| - \_ -> - mascotsDb - |> Result.andThen - (runPlan - (Join - { left = Read "mascots" - , leftFields = [ 0, 1 ] - , right = Read "teams" - , rightFields = [ 0 ] - } - ) - ) - |> Expect.equal - (Err - (SchemaMismatch - { wanted = Array.fromList [ StringType, StringType ] - , got = Array.fromList [ StringType ] - } - ) - ) , test "joins on fields in order" <| \_ -> mascotsDb @@ -261,9 +237,8 @@ runPlanTests = (runPlan (Join { left = Read "mascots" - , leftFields = [ 1 ] , right = Read "teams" - , rightFields = [ 0 ] + , fields = [ ( 1, 0 ) ] } ) ) @@ -284,9 +259,8 @@ runPlanTests = (runPlan (Join { left = Read "mascots" - , leftFields = [] , right = Read "teams" - , rightFields = [] + , fields = [] } ) ) diff --git a/tests/DatalogTests.elm b/tests/DatalogTests.elm index 73c4b03..569b596 100644 --- a/tests/DatalogTests.elm +++ b/tests/DatalogTests.elm @@ -37,9 +37,8 @@ datalogTests = |> Expect.equal (Database.Join { left = Database.Read "reachable" - , leftFields = [ 0 ] , right = Database.Read "link" - , rightFields = [ 1 ] + , fields = [ ( 0, 1 ) ] } |> Database.Project [ 2, 1 ] |> Ok