diff --git a/examples/models/models_mesh_picking.c b/examples/models/models_mesh_picking.c
index 31500bce1..a790e978a 100644
--- a/examples/models/models_mesh_picking.c
+++ b/examples/models/models_mesh_picking.c
@@ -41,16 +41,24 @@ int main(void)
 
     Vector3 towerPos = { 0.0f, 0.0f, 0.0f };                    // Set model position
     BoundingBox towerBBox = GetMeshBoundingBox(tower.meshes[0]);   // Get mesh bounding box
-    bool hitMeshBBox = false;
-    bool hitTriangle = false;
+
+    // Ground quad
+    Vector3 g0 = (Vector3){ -50.0f, 0.0f, -50.0f };
+    Vector3 g1 = (Vector3){ -50.0f, 0.0f,  50.0f };
+    Vector3 g2 = (Vector3){  50.0f, 0.0f,  50.0f };
+    Vector3 g3 = (Vector3){  50.0f, 0.0f, -50.0f };
 
     // Test triangle
-    Vector3 ta = (Vector3){ -25.0, 0.5, 0.0 };
-    Vector3 tb = (Vector3){ -4.0, 2.5, 1.0 };
-    Vector3 tc = (Vector3){ -8.0, 6.5, 0.0 };
+    Vector3 ta = (Vector3){ -25.0f, 0.5f, 0.0f };
+    Vector3 tb = (Vector3){ -4.0f, 2.5f, 1.0f };
+    Vector3 tc = (Vector3){ -8.0f, 6.5f, 0.0f };
 
     Vector3 bary = { 0.0f, 0.0f, 0.0f };
 
+    // Test sphere
+    Vector3 sp = (Vector3){ -30.0f, 5.0f, 5.0f };
+    float sr = 4.0f;
+
     SetCameraMode(camera, CAMERA_FREE); // Set a free camera mode
 
     SetTargetFPS(60);                   // Set our game to run at 60 frames-per-second
@@ -69,11 +77,11 @@ int main(void)
         collision.hit = false;
         Color cursorColor = WHITE;
 
-        // Get ray and test against ground, triangle, and mesh
+        // Get ray and test against objects
         ray = GetMouseRay(GetMousePosition(), camera);
 
-        // Check ray collision aginst ground plane
-        RayCollision groundHitInfo = GetRayCollisionGround(ray, 0.0f);
+        // Check ray collision against ground quad
+        RayCollision groundHitInfo = GetRayCollisionQuad(ray, g0, g1, g2, g3);
 
         if ((groundHitInfo.hit) && (groundHitInfo.distance < collision.distance))
         {
@@ -92,30 +100,37 @@ int main(void)
             hitObjectName = "Triangle";
 
             bary = Vector3Barycenter(collision.point, ta, tb, tc);
-            hitTriangle = true;
         }
-        else hitTriangle = false;
-
-        RayCollision meshHitInfo = { 0 };
+        
+        // Check ray collision against test sphere
+        RayCollision sphereHitInfo = GetRayCollisionSphere(ray, sp, sr);
+        
+        if ((sphereHitInfo.hit) && (sphereHitInfo.distance < collision.distance)) {
+            collision = sphereHitInfo;
+            cursorColor = ORANGE;
+            hitObjectName = "Sphere";
+        }
 
         // Check ray collision against bounding box first, before trying the full ray-mesh test
-        if (GetRayCollisionBox(ray, towerBBox).hit)
+        RayCollision boxHitInfo = GetRayCollisionBox(ray, towerBBox);
+
+        if ((boxHitInfo.hit) && (boxHitInfo.distance < collision.distance))
         {
-            hitMeshBBox = true;
+            collision = boxHitInfo;
+            cursorColor = ORANGE;
+            hitObjectName = "Box";
 
             // Check ray collision against model
             // NOTE: It considers model.transform matrix!
-            meshHitInfo = GetRayCollisionModel(ray, tower);
+            RayCollision meshHitInfo = GetRayCollisionModel(ray, tower);
 
-            if ((meshHitInfo.hit) && (meshHitInfo.distance < collision.distance))
+            if (meshHitInfo.hit)
             {
                 collision = meshHitInfo;
                 cursorColor = ORANGE;
                 hitObjectName = "Mesh";
             }
         }
-
-        hitMeshBBox = false;
         //----------------------------------------------------------------------------------
 
         // Draw
@@ -136,8 +151,11 @@ int main(void)
                 DrawLine3D(tb, tc, PURPLE);
                 DrawLine3D(tc, ta, PURPLE);
 
+                // Draw the test sphere
+                DrawSphereWires(sp, sr, 8, 8, PURPLE);
+
                 // Draw the mesh bbox if we hit it
-                if (hitMeshBBox) DrawBoundingBox(towerBBox, LIME);
+                if (boxHitInfo.hit) DrawBoundingBox(towerBBox, LIME);
 
                 // If we hit something, draw the cursor at the hit point
                 if (collision.hit)
@@ -154,7 +172,7 @@ int main(void)
                 }
 
                 DrawRay(ray, MAROON);
-
+                
                 DrawGrid(10, 10.0f);
 
             EndMode3D();
@@ -178,7 +196,8 @@ int main(void)
                                     collision.normal.y,
                                     collision.normal.z), 10, ypos + 30, 10, BLACK);
 
-                if (hitTriangle) DrawText(TextFormat("Barycenter: %3.2f %3.2f %3.2f",  bary.x, bary.y, bary.z), 10, ypos + 45, 10, BLACK);
+                if (triHitInfo.hit && strcmp(hitObjectName, "Triangle") == 0)
+                    DrawText(TextFormat("Barycenter: %3.2f %3.2f %3.2f",  bary.x, bary.y, bary.z), 10, ypos + 45, 10, BLACK);
             }
 
             DrawText("Use Mouse to Move Camera", 10, 430, 10, GRAY);
diff --git a/src/models.c b/src/models.c
index 9217cc2db..50cc6cb75 100644
--- a/src/models.c
+++ b/src/models.c
@@ -2985,19 +2985,31 @@ RayCollision GetRayCollisionSphere(Ray ray, Vector3 center, float radius)
     RayCollision collision = { 0 };
 
     Vector3 raySpherePos = Vector3Subtract(center, ray.position);
-    float distance = Vector3Length(raySpherePos);
     float vector = Vector3DotProduct(raySpherePos, ray.direction);
-    float d = radius*radius - (distance*distance - vector*vector);
+    float distance = Vector3Length(raySpherePos);
+    float d = radius*radius - (distance * distance - vector*vector);
 
-    if (d >= 0.0f) collision.hit = true;
+    collision.hit = d >= 0.0f;
 
     // Check if ray origin is inside the sphere to calculate the correct collision point
-    if (distance < radius) collision.distance = vector + sqrtf(d);
-    else collision.distance = vector - sqrtf(d);
+    if (distance < radius) { // inside
+        collision.distance = vector + sqrtf(d);
 
-    // Calculate collision point
-    collision.point = Vector3Add(ray.position, Vector3Scale(ray.direction, collision.distance));
+        // Calculate collision point
+        collision.point = Vector3Add(ray.position, Vector3Scale(ray.direction, collision.distance));
 
+        // Calculate collision normal (pointing outwards)
+        collision.normal = Vector3Negate(Vector3Normalize(Vector3Subtract(collision.point, center)));
+    } else { // outside
+        collision.distance = vector - sqrtf(d);
+
+        // Calculate collision point
+        collision.point = Vector3Add(ray.position, Vector3Scale(ray.direction, collision.distance));
+
+        // Calculate collision normal (pointing inwards)
+        collision.normal = Vector3Normalize(Vector3Subtract(collision.point, center));
+    }
+    
     return collision;
 }
 
@@ -3006,19 +3018,60 @@ RayCollision GetRayCollisionBox(Ray ray, BoundingBox box)
 {
     RayCollision collision = { 0 };
 
-    float t[8] = { 0 };
-    t[0] = (box.min.x - ray.position.x)/ray.direction.x;
-    t[1] = (box.max.x - ray.position.x)/ray.direction.x;
-    t[2] = (box.min.y - ray.position.y)/ray.direction.y;
-    t[3] = (box.max.y - ray.position.y)/ray.direction.y;
-    t[4] = (box.min.z - ray.position.z)/ray.direction.z;
-    t[5] = (box.max.z - ray.position.z)/ray.direction.z;
+    // Note: If ray.position is inside the box, the distance is negative (as if the ray was reversed)
+    // Reversing ray.direction will give use the correct result.
+    bool insideBox = 
+        ray.position.x > box.min.x && ray.position.x < box.max.x &&
+        ray.position.y > box.min.y && ray.position.y < box.max.y &&
+        ray.position.z > box.min.z && ray.position.z < box.max.z;
+
+    if (insideBox) {
+        ray.direction = Vector3Negate(ray.direction);
+    }
+
+    float t[11] = { 0 };
+
+    t[8] = 1.0f / ray.direction.x;
+    t[9] = 1.0f / ray.direction.y;
+    t[10] = 1.0f / ray.direction.z;
+
+    t[0] = (box.min.x - ray.position.x) * t[8];
+    t[1] = (box.max.x - ray.position.x) * t[8];
+    t[2] = (box.min.y - ray.position.y) * t[9];
+    t[3] = (box.max.y - ray.position.y) * t[9];
+    t[4] = (box.min.z - ray.position.z) * t[10];
+    t[5] = (box.max.z - ray.position.z) * t[10];
     t[6] = (float)fmax(fmax(fmin(t[0], t[1]), fmin(t[2], t[3])), fmin(t[4], t[5]));
     t[7] = (float)fmin(fmin(fmax(t[0], t[1]), fmax(t[2], t[3])), fmax(t[4], t[5]));
 
     collision.hit = !(t[7] < 0 || t[6] > t[7]);
-    
-    // TODO: Calculate other RayCollision data
+    collision.distance = t[6];
+    collision.point = Vector3Add(ray.position, Vector3Scale(ray.direction, collision.distance));
+
+    // Get box center point
+    collision.normal = Vector3Lerp(box.min, box.max, 0.5f);
+    // Get vector center point->hit point
+    collision.normal = Vector3Subtract(collision.point, collision.normal);
+    // Scale vector to unit cube
+    //  we use an additional .01 to fix numerical errors
+    collision.normal = Vector3Scale(collision.normal, 2.01f);
+    collision.normal = Vector3Divide(collision.normal, Vector3Subtract(box.max, box.min));
+    //  the relevant elemets of the vector are now slightly larger than 1.0f (or smaller than -1.0f)
+    //  and the others are somewhere between -1.0 and 1.0
+    //  casting to int is exactly our wanted normal!
+    collision.normal.x = (int)collision.normal.x;
+    collision.normal.y = (int)collision.normal.y;
+    collision.normal.z = (int)collision.normal.z;
+
+    collision.normal = Vector3Normalize(collision.normal);
+
+    if (insideBox) {
+        // Reset ray.direction
+        ray.direction = Vector3Negate(ray.direction);
+        // Fix result
+        collision.distance *= -1.0f;
+        collision.normal = Vector3Negate(collision.normal);
+    }
 
     return collision;
 }
@@ -3089,6 +3142,7 @@ RayCollision GetRayCollisionModel(Ray ray, Model model)
 }
 
 // Get collision info between ray and triangle
+// NOTE: The points are expected to be in counter-clockwise winding
 // NOTE: Based on https://en.wikipedia.org/wiki/M%C3%B6ller%E2%80%93Trumbore_intersection_algorithm
 RayCollision GetRayCollisionTriangle(Ray ray, Vector3 p1, Vector3 p2, Vector3 p3)
 {
@@ -3147,26 +3201,14 @@ RayCollision GetRayCollisionTriangle(Ray ray, Vector3 p1, Vector3 p2, Vector3 p3
     return collision;
 }
 
-// Get collision info between ray and ground plane (Y-normal plane)
-RayCollision GetRayCollisionGround(Ray ray, float groundHeight)
-{
-    #define EPSILON 0.000001        // A small number
-
+// Get collision info between ray and quad
+// NOTE: The points are expected to be in counter-clockwise winding
+RayCollision GetRayCollisionQuad(Ray ray, Vector3 p1, Vector3 p2, Vector3 p3, Vector3 p4) {
     RayCollision collision = { 0 };
 
-    if (fabsf(ray.direction.y) > EPSILON)
-    {
-        float distance = (ray.position.y - groundHeight)/-ray.direction.y;
+    collision = GetRayCollisionTriangle(ray, p1, p2, p4);
 
-        if (distance >= 0.0)
-        {
-            collision.hit = true;
-            collision.distance = distance;
-            collision.normal = (Vector3){ 0.0, 1.0, 0.0 };
-            collision.point = Vector3Add(ray.position, Vector3Scale(ray.direction, distance));
-            collision.point.y = groundHeight;
-        }
-    }
+    if (!collision.hit) collision = GetRayCollisionTriangle(ray, p2, p3, p4);
 
     return collision;
 }
diff --git a/src/raylib.h b/src/raylib.h
index db9bb4019..0babd0ef4 100644
--- a/src/raylib.h
+++ b/src/raylib.h
@@ -1447,10 +1447,10 @@ RLAPI bool CheckCollisionBoxes(BoundingBox box1, BoundingBox box2);
 RLAPI bool CheckCollisionBoxSphere(BoundingBox box, Vector3 center, float radius);                      // Detect collision between box and sphere
 RLAPI RayCollision GetRayCollisionSphere(Ray ray, Vector3 center, float radius);                        // Get collision info between ray and sphere
 RLAPI RayCollision GetRayCollisionBox(Ray ray, BoundingBox box);                                        // Get collision info between ray and box
-RLAPI RayCollision GetRayCollisionMesh(Ray ray, Mesh mesh, Matrix transform);                           // Get collision info between ray and mesh
 RLAPI RayCollision GetRayCollisionModel(Ray ray, Model model);                                          // Get collision info between ray and model
+RLAPI RayCollision GetRayCollisionMesh(Ray ray, Mesh mesh, Matrix transform);                           // Get collision info between ray and mesh
 RLAPI RayCollision GetRayCollisionTriangle(Ray ray, Vector3 p1, Vector3 p2, Vector3 p3);                // Get collision info between ray and triangle
-RLAPI RayCollision GetRayCollisionGround(Ray ray, float groundHeight);                                  // Get collision info between ray and ground plane (Y-normal plane)
+RLAPI RayCollision GetRayCollisionQuad(Ray ray, Vector3 p1, Vector3 p2, Vector3 p3, Vector3 p4);        // Get collision info between ray and quad
 
 //------------------------------------------------------------------------------------
 // Audio Loading and Playing Functions (Module: audio)